Prepare for fully typechecked codebase

This commit is contained in:
Jonas Schäfer
2021-01-16 16:02:25 +01:00
parent d3777d3b07
commit 9e3fcbaf67
7 changed files with 249 additions and 90 deletions

View File

@@ -2,6 +2,8 @@ import functools
import hashlib
import logging
import secrets
import types
import typing
import aiohttp
@@ -17,8 +19,11 @@ from . import xmpputil
from .xmpputil import split_jid
T = typing.TypeVar("T")
class HTTPSessionManager:
def __init__(self, app_context_attribute):
def __init__(self, app_context_attribute: str):
self._app_context_attribute = app_context_attribute
async def _create(self) -> aiohttp.ClientSession:
@@ -26,13 +31,14 @@ class HTTPSessionManager:
"Accept": "application/json",
})
async def teardown(self, exc):
async def teardown(self, exc: typing.Optional[BaseException]) -> None:
app_ctx = _app_ctx_stack.top
try:
session = getattr(app_ctx, self._app_context_attribute)
except AttributeError:
return
exc_type: typing.Optional[typing.Type[BaseException]]
if exc is not None:
exc_type = type(exc)
traceback = getattr(exc, "__traceback__", None)
@@ -54,14 +60,22 @@ class HTTPSessionManager:
setattr(app_ctx, self._app_context_attribute, session)
return session
async def __aexit__(self, exc_type, exc_value, traceback):
async def __aexit__(
self,
exc_type: typing.Optional[typing.Type[BaseException]],
exc_value: typing.Optional[BaseException],
traceback: typing.Optional[types.TracebackType],
) -> None:
# we do nothing on aexit, since the session will be kept alive until
# the application context is torn down in teardown.
pass
class HTTPAuthSessionManager(HTTPSessionManager):
def __init__(self, app_context_attribute, session_token_key):
def __init__(
self,
app_context_attribute: str,
session_token_key: str):
super().__init__(app_context_attribute)
self._session_token_key = session_token_key
@@ -79,9 +93,20 @@ class HTTPAuthSessionManager(HTTPSessionManager):
)
def autosession(f):
class AuthSessionProvider(typing.Protocol):
_auth_session: HTTPAuthSessionManager
def autosession(
f: typing.Callable[..., typing.Coroutine[typing.Any, typing.Any, T]]
) -> typing.Callable[...,
typing.Coroutine[typing.Any, typing.Any, T]]:
@functools.wraps(f)
async def wrapper(self, *args, session=None, **kwargs):
async def wrapper(
self: AuthSessionProvider,
*args: typing.Any,
session: typing.Optional[aiohttp.ClientSession] = None,
**kwargs: typing.Any) -> T:
if session is None:
async with self._auth_session as session:
return (await f(self, *args, session=session, **kwargs))
@@ -96,8 +121,8 @@ class ProsodyClient:
SESSION_TOKEN = "prosody_access_token"
SESSION_ADDRESS = "prosody_jid"
def __init__(self, app=None):
self._default_login_redirect = None
def __init__(self, app: typing.Optional[quart.Quart] = None):
self._default_login_redirect: typing.Optional[str] = None
self._plain_session = HTTPSessionManager(self.CTX_PLAIN_SESSION)
self._auth_session = HTTPAuthSessionManager(self.CTX_AUTH_SESSION,
self.SESSION_TOKEN)
@@ -109,34 +134,34 @@ class ProsodyClient:
self.init_app(app)
@property
def default_login_redirect(self):
def default_login_redirect(self) -> typing.Optional[str]:
return self._default_login_redirect
@default_login_redirect.setter
def default_login_redirect(self, v):
def default_login_redirect(self, v: str) -> None:
self._default_login_redirect = v
def init_app(self, app):
def init_app(self, app: quart.Quart) -> None:
app.config[self.CONFIG_ENDPOINT]
app.teardown_appcontext(self._plain_session.teardown)
app.teardown_appcontext(self._auth_session.teardown)
@property
def _endpoint_base(self):
def _endpoint_base(self) -> str:
return current_app.config[self.CONFIG_ENDPOINT]
@property
def _login_endpoint(self):
def _login_endpoint(self) -> str:
return "{}/oauth2/token".format(self._endpoint_base)
@property
def _rest_endpoint(self):
def _rest_endpoint(self) -> str:
return "{}/rest".format(self._endpoint_base)
async def _oauth2_bearer_token(self,
session: aiohttp.ClientSession,
jid: str,
password: str):
password: str) -> None:
request = aiohttp.FormData()
request.add_field("grant_type", "password")
request.add_field("username", jid)
@@ -172,7 +197,7 @@ class ProsodyClient:
)
)
async def login(self, jid: str, password: str):
async def login(self, jid: str, password: str) -> bool:
async with self._plain_session as session:
token = await self._oauth2_bearer_token(session, jid, password)
@@ -181,27 +206,39 @@ class ProsodyClient:
return True
@property
def session_token(self):
def session_token(self) -> str:
try:
return http_session[self.SESSION_TOKEN]
except KeyError:
raise abort(401, "no session")
@property
def session_address(self):
def session_address(self) -> str:
try:
return http_session[self.SESSION_ADDRESS]
except KeyError:
raise abort(401, "no session")
@property
def has_session(self):
def has_session(self) -> bool:
return self.SESSION_TOKEN in http_session
def require_session(self, redirect_to: str = None):
def decorator(f):
def require_session(
self,
redirect_to: typing.Optional[str] = None,
) -> typing.Callable[
[typing.Callable[..., typing.Awaitable[T]]],
typing.Callable[..., typing.Awaitable[
typing.Union[T, quart.Response]]]]:
def decorator(
f: typing.Callable[..., typing.Awaitable[T]],
) -> typing.Callable[..., typing.Awaitable[
typing.Union[T, quart.Response]]]:
@functools.wraps(f)
async def wrapped(*args, **kwargs):
async def wrapped(
*args: typing.Any,
**kwargs: typing.Any,
) -> typing.Union[T, quart.Response]:
if not self.has_session or not (await self.test_session()):
nonlocal redirect_to
if redirect_to is not False:
@@ -215,10 +252,18 @@ class ProsodyClient:
return wrapped
return decorator
async def _xml_iq_call(self, session, payload, *, headers=None,
sensitive=False):
headers = headers or {}
headers.update({
async def _xml_iq_call(
self,
session: aiohttp.ClientSession,
payload: ET.Element,
*,
headers: typing.Optional[typing.Mapping[str, str]] = None,
sensitive: bool = False,
) -> ET.Element:
final_headers: typing.MutableMapping[str, str] = {}
if headers is not None:
final_headers.update(headers)
final_headers.update({
"Content-Type": "application/xmpp+xml",
"Accept": "application/xmpp+xml",
})
@@ -232,7 +277,7 @@ class ProsodyClient:
id_, "(sensitive)" if sensitive else serialised,
)
async with session.post(self._rest_endpoint,
headers=headers,
headers=final_headers,
data=serialised) as resp:
if resp.status != 200:
abort(resp.status)
@@ -243,7 +288,7 @@ class ProsodyClient:
)
return ET.fromstring(reply_payload)
async def get_user_info(self):
async def get_user_info(self) -> typing.Mapping:
localpart, domain, _ = split_jid(self.session_address)
async with self._auth_session as session:
@@ -267,7 +312,7 @@ class ProsodyClient:
}
@autosession
async def test_session(self, session):
async def test_session(self, session: aiohttp.ClientSession) -> bool:
req = {
"kind": "iq",
"type": "get",
@@ -279,7 +324,11 @@ class ProsodyClient:
return resp.status == 200
@autosession
async def get_user_nickname(self, session):
async def get_user_nickname(
self,
*,
session: aiohttp.ClientSession,
) -> typing.Optional[str]:
iq_resp = await self._xml_iq_call(
session,
xmpputil.make_nickname_get_request(self.session_address)
@@ -287,7 +336,12 @@ class ProsodyClient:
return xmpputil.extract_nickname_get_reply(iq_resp)
@autosession
async def set_user_nickname(self, new_nickname, session):
async def set_user_nickname(
self,
new_nickname: str,
*,
session: aiohttp.ClientSession,
) -> None:
iq_resp = await self._xml_iq_call(
session,
xmpputil.make_nickname_set_request(self.session_address,
@@ -297,8 +351,13 @@ class ProsodyClient:
xmpputil.extract_iq_reply(iq_resp)
@autosession
async def get_avatar(self, from_, session,
metadata_only=False):
async def get_avatar(
self,
from_: str,
metadata_only: bool = False,
*,
session: aiohttp.ClientSession,
) -> typing.Mapping:
metadata_resp = await self._xml_iq_call(
session,
xmpputil.make_avatar_metadata_request(from_)
@@ -316,7 +375,13 @@ class ProsodyClient:
return info
@autosession
async def get_avatar_data(self, from_, id_, session):
async def get_avatar_data(
self,
from_: str,
id_: str,
*,
session: aiohttp.ClientSession,
) -> typing.Optional[bytes]:
data_resp = await self._xml_iq_call(
session,
xmpputil.make_avatar_data_request(from_, id_)
@@ -324,7 +389,13 @@ class ProsodyClient:
return xmpputil.extract_avatar_data_get_reply(data_resp)
@autosession
async def set_user_avatar(self, data, mimetype, session):
async def set_user_avatar(
self,
data: bytes,
mimetype: str,
*,
session: aiohttp.ClientSession,
) -> None:
id_ = hashlib.sha1(data).hexdigest()
data_resp = await self._xml_iq_call(
@@ -348,7 +419,11 @@ class ProsodyClient:
)
xmpputil.extract_iq_reply(metadata_resp)
async def change_password(self, current_password, new_password):
async def change_password(
self,
current_password: str,
new_password: str,
) -> None:
# we play it safe here and do not use the existing auth session;
# instead, we do a login on the plain session and use the token we
# got there, replacing the current session token on the way.
@@ -375,7 +450,7 @@ class ProsodyClient:
# server to expire/revoke all tokens on password change.
http_session[self.SESSION_TOKEN] = token
async def logout(self):
async def logout(self) -> None:
# this currently only kills the cookie stuff, we may want to invalidate
# the token on the server side, toos
# See-Also: https://issues.prosody.im/1503