diff --git a/snikket_web/prosodyclient.py b/snikket_web/prosodyclient.py index bee17fe..57eac0c 100644 --- a/snikket_web/prosodyclient.py +++ b/snikket_web/prosodyclient.py @@ -359,7 +359,7 @@ class ProsodyClient: if app is not None: self.init_app(app) - self._client_info = None + self._client_info: typing.Optional[typing.Mapping[str, str]] = None @property def default_login_redirect(self) -> typing.Optional[str]: @@ -418,13 +418,10 @@ class ProsodyClient: "scope", " ".join([SCOPE_RESTRICTED, SCOPE_DEFAULT, SCOPE_ADMIN]) ) - auth = BasicAuth( - login=self._client_info["client_id"], - password=self._client_info["client_secret"], - ) - self.logger.debug("sending OAuth2 request (payload omitted)") - async with session.post(self._login_endpoint, auth=auth, data=request) as resp: + async with session.post( + self._login_endpoint, auth=self.get_client_auth(), data=request + ) as resp: auth_status = resp.status auth_info: typing.Mapping[str, str] = await resp.json() @@ -458,10 +455,10 @@ class ProsodyClient: http_session[self.SESSION_TOKEN] = token_info.token http_session[self.SESSION_CACHED_SCOPE] = " ".join(token_info.scopes) - def is_client_registered(self): + def is_client_registered(self) -> bool: return self._client_info is not None - async def register_client(self): + async def register_client(self) -> None: self.logger.debug( "sending OAuth2 client registration request (payload omitted)" ) @@ -493,6 +490,15 @@ class ProsodyClient: self._client_info = await resp.json() + def get_client_auth(self) -> BasicAuth: + if self._client_info is None: + raise RuntimeError("Client is not registered") + + return BasicAuth( + login=self._client_info["client_id"], + password=self._client_info["client_secret"], + ) + async def login(self, jid: str, password: str) -> bool: async with self._plain_session as session: token_info = await self._oauth2_bearer_token(