diff --git a/snikket_web/prosodyclient.py b/snikket_web/prosodyclient.py index f084cd2..56d5f86 100644 --- a/snikket_web/prosodyclient.py +++ b/snikket_web/prosodyclient.py @@ -12,6 +12,7 @@ import typing_extensions from datetime import datetime, timezone import aiohttp +from aiohttp import BasicAuth import xml.etree.ElementTree as ET @@ -361,6 +362,8 @@ class ProsodyClient: if app is not None: self.init_app(app) + self._client_info = None + @property def default_login_redirect(self) -> typing.Optional[str]: return self._default_login_redirect @@ -390,6 +393,10 @@ class ProsodyClient: def _rest_endpoint(self) -> str: return "{}/rest".format(self._endpoint_base) + @property + def _register_client_endpoint(self) -> str: + return "{}/oauth2/register".format(self._endpoint_base) + def _admin_v1_endpoint(self, subpath: str) -> str: return "{}/admin_api{}".format(self._endpoint_base, subpath) @@ -403,17 +410,26 @@ class ProsodyClient: session: aiohttp.ClientSession, jid: str, password: str) -> TokenInfo: + if not self.is_client_registered(): + self.logger.debug("registering oauth client...") + await self.register_client() + self.logger.debug("registered client!") request = aiohttp.FormData() request.add_field("grant_type", "password") - request.add_field("username", jid) + request.add_field("username", jid.split("@")[0]) request.add_field("password", password) request.add_field( "scope", " ".join([SCOPE_RESTRICTED, SCOPE_DEFAULT, SCOPE_ADMIN]) ) + auth = BasicAuth( + login=self._client_info["client_id"], + password=self._client_info["client_secret"], + ) + self.logger.debug("sending OAuth2 request (payload omitted)") - async with session.post(self._login_endpoint, data=request) as resp: + async with session.post(self._login_endpoint, auth=auth, data=request) as resp: auth_status = resp.status auth_info: typing.Mapping[str, str] = (await resp.json()) @@ -449,6 +465,37 @@ class ProsodyClient: http_session[self.SESSION_TOKEN] = token_info.token http_session[self.SESSION_CACHED_SCOPE] = " ".join(token_info.scopes) + def is_client_registered(self): + return self._client_info is not None + + async def register_client(self): + self.logger.debug("sending OAuth2 client registration request (payload omitted)") + registration_data = { + "client_name": "Snikket web portal", + "client_uri": "https://{}".format(current_app.config["SNIKKET_DOMAIN"]), + # This redirect URI is not used, because we use the password grant type. + # However, we're registering it with a sensible value because 1) Prosody + # requires us to provide at least one redirect_uri, and 2) if we ever + # need it in the future, we won't have to re-register. + "redirect_uris": ["https://{}/login_result".format(current_app.config["SNIKKET_DOMAIN"])], + "grant_types": ["password"], + "response_types": ["code"], + } + async with self._plain_session as session: + async with session.post(self._register_client_endpoint, json=registration_data) as resp: + reg_status = resp.status + auth_info: typing.Mapping[str, str] = (await resp.json()) + + if reg_status != 201: + raise RuntimeError( + "Failed to register with backend server: ({}): {}", + reg_status, + await resp.text() + ) + + self._client_info = await resp.json() + + async def login(self, jid: str, password: str) -> bool: async with self._plain_session as session: token_info = await self._oauth2_bearer_token(