Prepare for fully typechecked codebase

This commit is contained in:
Jonas Schäfer
2021-01-16 16:02:25 +01:00
parent d3777d3b07
commit 9e3fcbaf67
7 changed files with 249 additions and 90 deletions

View File

@@ -1 +1,2 @@
pyscss~=1.3
mypy

31
mypy.ini Normal file
View File

@@ -0,0 +1,31 @@
[mypy]
python_version = 3.7
#warn_return_any = True
warn_unused_configs = True
disallow_untyped_calls = True
disallow_untyped_defs = True
disallow_incomplete_defs = True
#check_untyped_defs = True
disallow_untyped_decorators = False
#disallow_any_unimported = True
#disallow_any_expr = True
#disallow_any_decorated = True
disallow_any_explicit = False
#disallow_any_generics = True
disallow_subclassing_any = False
no_implicit_optional = True
warn_redundant_casts = True
warn_unused_ignores = True
warn_unreachable = True
[mypy-hsluv.*]
ignore_missing_imports = True
[mypy-flask_wtf.*]
ignore_missing_imports = True
[mypy-flask_babel.*]
ignore_missing_imports = True
[mypy-wtforms.*]
ignore_missing_imports = True

View File

@@ -3,6 +3,7 @@ import binascii
import itertools
import logging
import pathlib
import typing
from datetime import datetime, timedelta
@@ -46,14 +47,14 @@ class LoginForm(FlaskForm):
@babel.localeselector
def selected_locale():
def selected_locale() -> str:
return request.accept_languages.best_match(
current_app.config['LANGUAGES']
)
@app.route("/login", methods=["GET", "POST"])
async def login():
async def login() -> typing.Union[str, quart.Response]:
if client.has_session and (await client.test_session()):
return redirect(url_for('user.index'))
@@ -78,7 +79,7 @@ async def login():
@app.route("/")
async def home():
async def home() -> quart.Response:
if client.has_session:
return redirect(url_for('user.index'))
@@ -86,21 +87,21 @@ async def home():
@app.route("/meta/about.html")
async def about():
async def about() -> str:
return await render_template("about.html", version=version)
@app.route("/meta/demo.html")
async def demo():
async def demo() -> str:
return await render_template("demo.html")
def repad(s):
def repad(s: str) -> str:
return s + "=" * (4 - len(s) % 4)
@app.route("/avatar/<from_>/<code>")
async def avatar(from_, code):
async def avatar(from_: str, code: str) -> quart.Response:
try:
etag = request.headers["if-none-match"]
except KeyError:
@@ -116,7 +117,7 @@ async def avatar(from_, code):
300,
))
response = Response(None, mimetype=info["type"])
response = Response("", mimetype=info["type"])
response.headers["etag"] = new_etag
# XXX: It seems to me that quart expects localtime(?!) in this field...
response.expires = datetime.now() + cache_ttl
@@ -125,15 +126,17 @@ async def avatar(from_, code):
if etag is not None and new_etag == etag:
response.status_code = 304
response.set_data("")
return response
data = await client.get_avatar_data(address, info["sha1"])
if data is None:
response.status_code = 404
return response
response.status_code = 200
if request.method == "HEAD":
response.content_length = len(data)
response.set_data("")
return response
response.set_data(data)
@@ -141,8 +144,9 @@ async def avatar(from_, code):
@app.context_processor
def proc():
def url_for_avatar(entity, hash_, **kwargs):
def proc() -> typing.Dict[str, typing.Any]:
def url_for_avatar(entity: str, hash_: str,
**kwargs: typing.Any) -> str:
return url_for(
"avatar",
from_=base64.urlsafe_b64encode(
@@ -165,7 +169,7 @@ app.template_filter("repr")(repr)
@app.template_filter("flatten")
def flatten(a, levels=1):
def flatten(a: typing.Iterable, levels: int = 1) -> typing.Iterable:
for i in range(levels):
a = itertools.chain(*a)
return a

View File

@@ -1,12 +1,16 @@
import functools
import hashlib
import typing
import hsluv
# This is essentially an implementation of XEP-0392.
def clip_rgb(r, g, b):
RGBf = typing.Tuple[float, float, float]
def clip_rgb(r: float, g: float, b: float) -> RGBf:
return (
min(max(r, 0), 1),
min(max(g, 0), 1),
@@ -15,7 +19,7 @@ def clip_rgb(r, g, b):
@functools.lru_cache(128)
def text_to_colour(text):
def text_to_colour(text: str) -> RGBf:
MASK = 0xffff
h = hashlib.sha1()
h.update(text.encode("utf-8"))
@@ -26,7 +30,7 @@ def text_to_colour(text):
return r, g, b
def text_to_css(text):
def text_to_css(text: str) -> str:
return "#{:02x}{:02x}{:02x}".format(
*(round(v * 255) for v in text_to_colour(text))
)

View File

@@ -2,6 +2,8 @@ import functools
import hashlib
import logging
import secrets
import types
import typing
import aiohttp
@@ -17,8 +19,11 @@ from . import xmpputil
from .xmpputil import split_jid
T = typing.TypeVar("T")
class HTTPSessionManager:
def __init__(self, app_context_attribute):
def __init__(self, app_context_attribute: str):
self._app_context_attribute = app_context_attribute
async def _create(self) -> aiohttp.ClientSession:
@@ -26,13 +31,14 @@ class HTTPSessionManager:
"Accept": "application/json",
})
async def teardown(self, exc):
async def teardown(self, exc: typing.Optional[BaseException]) -> None:
app_ctx = _app_ctx_stack.top
try:
session = getattr(app_ctx, self._app_context_attribute)
except AttributeError:
return
exc_type: typing.Optional[typing.Type[BaseException]]
if exc is not None:
exc_type = type(exc)
traceback = getattr(exc, "__traceback__", None)
@@ -54,14 +60,22 @@ class HTTPSessionManager:
setattr(app_ctx, self._app_context_attribute, session)
return session
async def __aexit__(self, exc_type, exc_value, traceback):
async def __aexit__(
self,
exc_type: typing.Optional[typing.Type[BaseException]],
exc_value: typing.Optional[BaseException],
traceback: typing.Optional[types.TracebackType],
) -> None:
# we do nothing on aexit, since the session will be kept alive until
# the application context is torn down in teardown.
pass
class HTTPAuthSessionManager(HTTPSessionManager):
def __init__(self, app_context_attribute, session_token_key):
def __init__(
self,
app_context_attribute: str,
session_token_key: str):
super().__init__(app_context_attribute)
self._session_token_key = session_token_key
@@ -79,9 +93,20 @@ class HTTPAuthSessionManager(HTTPSessionManager):
)
def autosession(f):
class AuthSessionProvider(typing.Protocol):
_auth_session: HTTPAuthSessionManager
def autosession(
f: typing.Callable[..., typing.Coroutine[typing.Any, typing.Any, T]]
) -> typing.Callable[...,
typing.Coroutine[typing.Any, typing.Any, T]]:
@functools.wraps(f)
async def wrapper(self, *args, session=None, **kwargs):
async def wrapper(
self: AuthSessionProvider,
*args: typing.Any,
session: typing.Optional[aiohttp.ClientSession] = None,
**kwargs: typing.Any) -> T:
if session is None:
async with self._auth_session as session:
return (await f(self, *args, session=session, **kwargs))
@@ -96,8 +121,8 @@ class ProsodyClient:
SESSION_TOKEN = "prosody_access_token"
SESSION_ADDRESS = "prosody_jid"
def __init__(self, app=None):
self._default_login_redirect = None
def __init__(self, app: typing.Optional[quart.Quart] = None):
self._default_login_redirect: typing.Optional[str] = None
self._plain_session = HTTPSessionManager(self.CTX_PLAIN_SESSION)
self._auth_session = HTTPAuthSessionManager(self.CTX_AUTH_SESSION,
self.SESSION_TOKEN)
@@ -109,34 +134,34 @@ class ProsodyClient:
self.init_app(app)
@property
def default_login_redirect(self):
def default_login_redirect(self) -> typing.Optional[str]:
return self._default_login_redirect
@default_login_redirect.setter
def default_login_redirect(self, v):
def default_login_redirect(self, v: str) -> None:
self._default_login_redirect = v
def init_app(self, app):
def init_app(self, app: quart.Quart) -> None:
app.config[self.CONFIG_ENDPOINT]
app.teardown_appcontext(self._plain_session.teardown)
app.teardown_appcontext(self._auth_session.teardown)
@property
def _endpoint_base(self):
def _endpoint_base(self) -> str:
return current_app.config[self.CONFIG_ENDPOINT]
@property
def _login_endpoint(self):
def _login_endpoint(self) -> str:
return "{}/oauth2/token".format(self._endpoint_base)
@property
def _rest_endpoint(self):
def _rest_endpoint(self) -> str:
return "{}/rest".format(self._endpoint_base)
async def _oauth2_bearer_token(self,
session: aiohttp.ClientSession,
jid: str,
password: str):
password: str) -> None:
request = aiohttp.FormData()
request.add_field("grant_type", "password")
request.add_field("username", jid)
@@ -172,7 +197,7 @@ class ProsodyClient:
)
)
async def login(self, jid: str, password: str):
async def login(self, jid: str, password: str) -> bool:
async with self._plain_session as session:
token = await self._oauth2_bearer_token(session, jid, password)
@@ -181,27 +206,39 @@ class ProsodyClient:
return True
@property
def session_token(self):
def session_token(self) -> str:
try:
return http_session[self.SESSION_TOKEN]
except KeyError:
raise abort(401, "no session")
@property
def session_address(self):
def session_address(self) -> str:
try:
return http_session[self.SESSION_ADDRESS]
except KeyError:
raise abort(401, "no session")
@property
def has_session(self):
def has_session(self) -> bool:
return self.SESSION_TOKEN in http_session
def require_session(self, redirect_to: str = None):
def decorator(f):
def require_session(
self,
redirect_to: typing.Optional[str] = None,
) -> typing.Callable[
[typing.Callable[..., typing.Awaitable[T]]],
typing.Callable[..., typing.Awaitable[
typing.Union[T, quart.Response]]]]:
def decorator(
f: typing.Callable[..., typing.Awaitable[T]],
) -> typing.Callable[..., typing.Awaitable[
typing.Union[T, quart.Response]]]:
@functools.wraps(f)
async def wrapped(*args, **kwargs):
async def wrapped(
*args: typing.Any,
**kwargs: typing.Any,
) -> typing.Union[T, quart.Response]:
if not self.has_session or not (await self.test_session()):
nonlocal redirect_to
if redirect_to is not False:
@@ -215,10 +252,18 @@ class ProsodyClient:
return wrapped
return decorator
async def _xml_iq_call(self, session, payload, *, headers=None,
sensitive=False):
headers = headers or {}
headers.update({
async def _xml_iq_call(
self,
session: aiohttp.ClientSession,
payload: ET.Element,
*,
headers: typing.Optional[typing.Mapping[str, str]] = None,
sensitive: bool = False,
) -> ET.Element:
final_headers: typing.MutableMapping[str, str] = {}
if headers is not None:
final_headers.update(headers)
final_headers.update({
"Content-Type": "application/xmpp+xml",
"Accept": "application/xmpp+xml",
})
@@ -232,7 +277,7 @@ class ProsodyClient:
id_, "(sensitive)" if sensitive else serialised,
)
async with session.post(self._rest_endpoint,
headers=headers,
headers=final_headers,
data=serialised) as resp:
if resp.status != 200:
abort(resp.status)
@@ -243,7 +288,7 @@ class ProsodyClient:
)
return ET.fromstring(reply_payload)
async def get_user_info(self):
async def get_user_info(self) -> typing.Mapping:
localpart, domain, _ = split_jid(self.session_address)
async with self._auth_session as session:
@@ -267,7 +312,7 @@ class ProsodyClient:
}
@autosession
async def test_session(self, session):
async def test_session(self, session: aiohttp.ClientSession) -> bool:
req = {
"kind": "iq",
"type": "get",
@@ -279,7 +324,11 @@ class ProsodyClient:
return resp.status == 200
@autosession
async def get_user_nickname(self, session):
async def get_user_nickname(
self,
*,
session: aiohttp.ClientSession,
) -> typing.Optional[str]:
iq_resp = await self._xml_iq_call(
session,
xmpputil.make_nickname_get_request(self.session_address)
@@ -287,7 +336,12 @@ class ProsodyClient:
return xmpputil.extract_nickname_get_reply(iq_resp)
@autosession
async def set_user_nickname(self, new_nickname, session):
async def set_user_nickname(
self,
new_nickname: str,
*,
session: aiohttp.ClientSession,
) -> None:
iq_resp = await self._xml_iq_call(
session,
xmpputil.make_nickname_set_request(self.session_address,
@@ -297,8 +351,13 @@ class ProsodyClient:
xmpputil.extract_iq_reply(iq_resp)
@autosession
async def get_avatar(self, from_, session,
metadata_only=False):
async def get_avatar(
self,
from_: str,
metadata_only: bool = False,
*,
session: aiohttp.ClientSession,
) -> typing.Mapping:
metadata_resp = await self._xml_iq_call(
session,
xmpputil.make_avatar_metadata_request(from_)
@@ -316,7 +375,13 @@ class ProsodyClient:
return info
@autosession
async def get_avatar_data(self, from_, id_, session):
async def get_avatar_data(
self,
from_: str,
id_: str,
*,
session: aiohttp.ClientSession,
) -> typing.Optional[bytes]:
data_resp = await self._xml_iq_call(
session,
xmpputil.make_avatar_data_request(from_, id_)
@@ -324,7 +389,13 @@ class ProsodyClient:
return xmpputil.extract_avatar_data_get_reply(data_resp)
@autosession
async def set_user_avatar(self, data, mimetype, session):
async def set_user_avatar(
self,
data: bytes,
mimetype: str,
*,
session: aiohttp.ClientSession,
) -> None:
id_ = hashlib.sha1(data).hexdigest()
data_resp = await self._xml_iq_call(
@@ -348,7 +419,11 @@ class ProsodyClient:
)
xmpputil.extract_iq_reply(metadata_resp)
async def change_password(self, current_password, new_password):
async def change_password(
self,
current_password: str,
new_password: str,
) -> None:
# we play it safe here and do not use the existing auth session;
# instead, we do a login on the plain session and use the token we
# got there, replacing the current session token on the way.
@@ -375,7 +450,7 @@ class ProsodyClient:
# server to expire/revoke all tokens on password change.
http_session[self.SESSION_TOKEN] = token
async def logout(self):
async def logout(self) -> None:
# this currently only kills the cookie stuff, we may want to invalidate
# the token on the server side, toos
# See-Also: https://issues.prosody.im/1503

View File

@@ -1,3 +1,5 @@
import typing
import quart.flask_patch
from quart import Blueprint, render_template, request, redirect, url_for
@@ -12,7 +14,7 @@ user_bp = Blueprint('user', __name__, url_prefix="/user")
@user_bp.context_processor
async def proc():
async def proc() -> typing.Mapping[str, typing.Any]:
return {"user_info": await client.get_user_info()}
@@ -53,14 +55,14 @@ class ProfileForm(FlaskForm):
@user_bp.route("/")
@client.require_session()
async def index():
async def index() -> str:
user_info = await client.get_user_info()
return await render_template("user_home.html", user_info=user_info)
@user_bp.route('/passwd', methods=["GET", "POST"])
@client.require_session()
async def change_pw():
async def change_pw() -> typing.Union[str, quart.Response]:
form = ChangePasswordForm()
if form.validate_on_submit():
try:
@@ -81,7 +83,7 @@ async def change_pw():
@user_bp.route("/profile", methods=["GET", "POST"])
@client.require_session()
async def profile():
async def profile() -> typing.Union[str, quart.Response]:
form = ProfileForm()
if request.method != "POST":
user_info = await client.get_user_info()
@@ -106,7 +108,7 @@ async def profile():
@user_bp.route("/logout", methods=["GET", "POST"])
@client.require_session()
async def logout():
async def logout() -> typing.Union[quart.Response, str]:
form = LogoutForm()
if form.validate_on_submit():
await client.logout()

View File

@@ -37,7 +37,13 @@ NS_USER_AVATAR_DATA = "urn:xmpp:avatar:data"
TAG_USER_AVATAR_DATA = "{{{}}}data".format(NS_USER_AVATAR_DATA)
def split_jid(s):
SimpleJID = typing.Tuple[typing.Optional[str], str, typing.Optional[str]]
T = typing.TypeVar("T")
def split_jid(s: str) -> SimpleJID:
resource: typing.Optional[str]
localpart: typing.Optional[str]
bare, sep, resource = s.partition("/")
if not sep:
resource = None
@@ -48,32 +54,38 @@ def split_jid(s):
return localpart, domain, resource
def raise_iq_error(err: ET.Element):
def raise_iq_error(err: ET.Element) -> None:
err_condition_el = None
err_text_el = None
err_app_def_condition_el = None
# err_text_el = None
# err_app_def_condition_el = None
for el in err:
if el.tag == TAG_XMPP_ERROR_TEXT:
err_text_el = el
# err_text_el = el
continue
elif el.tag.startswith("{{{}}}".format(NS_XMPP_ERROR_CONDITION)):
err_condition_el = el
else:
err_app_def_condition_el = el
# else:
# err_app_def_condition_el = el
print(err_text_el, err_condition_el, err_app_def_condition_el)
if err_condition_el is None:
condition_tag = "undefined-condition"
else:
condition_tag = err_condition_el.tag
abort(ERROR_CODE_MAP.get(err_condition_el.tag, 500),
err_condition_el.tag)
# print(err_text_el, err_condition_el, err_app_def_condition_el)
abort(ERROR_CODE_MAP.get(condition_tag, 500), condition_tag)
def extract_iq_reply(tree: ET.Element,
require_tag: str = None) -> typing.Optional[ET.Element]:
def extract_iq_reply(
tree: ET.Element,
require_tag: typing.Optional[str] = None,
) -> typing.Optional[ET.Element]:
iq_type = tree.get("type")
if iq_type == "error":
error = tree.find(TAG_XMPP_ERROR)
if error is not None:
raise raise_iq_error(error)
raise_iq_error(error)
raise abort(500, "malformed reply")
elif iq_type == "result":
if len(tree) > 0:
@@ -88,7 +100,7 @@ def extract_iq_reply(tree: ET.Element,
raise abort(500, "unsupported reply")
def make_password_change_request(jid, password):
def make_password_change_request(jid: str, password: str) -> ET.Element:
username, domain, _ = split_jid(jid)
# XXX: this is due to a problem with mod_rest / mod_register in prosody:
# it doesnt recognize the password change stanza unless we send it to
@@ -100,7 +112,10 @@ def make_password_change_request(jid, password):
return req
def make_pubsub_item_put_request(to, node, id_=None):
def make_pubsub_item_put_request(
to: str, node: str,
id_: typing.Optional[str] = None,
) -> typing.Tuple[ET.Element, ET.Element]:
req = ET.Element("iq", type="set", to=to)
q = ET.SubElement(req, "pubsub", xmlns=NS_PUBSUB)
publish = ET.SubElement(q, "publish", node=node)
@@ -110,7 +125,7 @@ def make_pubsub_item_put_request(to, node, id_=None):
return req, item
def make_nickname_set_request(to, nickname):
def make_nickname_set_request(to: str, nickname: str) -> ET.Element:
req, item = make_pubsub_item_put_request(
to,
NODE_USER_NICKNAME,
@@ -119,7 +134,11 @@ def make_nickname_set_request(to, nickname):
return req
def make_pubsub_item_request(to, node, id_=None):
def make_pubsub_item_request(
to: str,
node: str,
id_: typing.Optional[str] = None,
) -> ET.Element:
req = ET.Element("iq", type="get", to=to)
q = ET.SubElement(req, "pubsub", xmlns=NS_PUBSUB)
items = ET.SubElement(q, "items", node=node)
@@ -131,19 +150,23 @@ def make_pubsub_item_request(to, node, id_=None):
return req
def make_nickname_get_request(to):
def make_nickname_get_request(to: str) -> ET.Element:
return make_pubsub_item_request(to, NODE_USER_NICKNAME)
def make_avatar_metadata_request(to):
def make_avatar_metadata_request(to: str) -> ET.Element:
return make_pubsub_item_request(to, NODE_USER_AVATAR_METADATA)
def make_avatar_data_request(to, sha1):
def make_avatar_data_request(to: str, sha1: str) -> ET.Element:
return make_pubsub_item_request(to, NODE_USER_AVATAR_DATA, id_=sha1)
def make_avatar_data_set_request(to, data, id_):
def make_avatar_data_set_request(
to: str,
data: bytes,
id_: str,
) -> ET.Element:
req, item = make_pubsub_item_put_request(
to,
NODE_USER_AVATAR_DATA,
@@ -154,9 +177,14 @@ def make_avatar_data_set_request(to, data, id_):
return req
def make_avatar_metadata_set_request(to, mimetype: str, id_: str, size: int,
width: int = None,
height: int = None):
def make_avatar_metadata_set_request(
to: str,
mimetype: str,
id_: str,
size: int,
width: typing.Optional[int] = None,
height: typing.Optional[int] = None,
) -> ET.Element:
req, item = make_pubsub_item_put_request(
to,
NODE_USER_AVATAR_METADATA,
@@ -166,7 +194,7 @@ def make_avatar_metadata_set_request(to, mimetype: str, id_: str, size: int,
item,
"metadata", xmlns=NS_USER_AVATAR_METADATA)
attr = {
attr: typing.MutableMapping[str, str] = {
"id": id_,
"bytes": str(size),
"type": mimetype,
@@ -187,12 +215,19 @@ def _require_child(t: ET.Element, tag: str) -> ET.Element:
return el
def extract_pubsub_item_get_reply(iq_tree, payload_tag):
def extract_pubsub_item_get_reply(
iq_tree: ET.Element,
payload_tag: str,
) -> typing.Optional[ET.Element]:
try:
pubsub = extract_iq_reply(iq_tree, TAG_PUBSUB)
except quart.exceptions.NotFound:
return None
if pubsub is None:
# no payload in IQ reply
raise abort(500, "malformed reply")
items = _require_child(pubsub, TAG_PUBSUB_ITEMS)
if len(items) == 0:
return None
@@ -200,14 +235,16 @@ def extract_pubsub_item_get_reply(iq_tree, payload_tag):
return _require_child(_require_child(items, TAG_PUBSUB_ITEM), payload_tag)
def extract_nickname_get_reply(iq_tree):
def extract_nickname_get_reply(iq_tree: ET.Element) -> typing.Optional[str]:
nick = extract_pubsub_item_get_reply(iq_tree, TAG_USER_NICKNAME_NICK)
if nick is None:
return None
return nick.text
def extract_avatar_metadata_get_reply(iq_tree):
def extract_avatar_metadata_get_reply(
iq_tree: ET.Element,
) -> typing.Optional[typing.MutableMapping[str, typing.Any]]:
metadata = extract_pubsub_item_get_reply(iq_tree, TAG_USER_AVATAR_METADATA)
if metadata is None:
return None
@@ -218,12 +255,15 @@ def extract_avatar_metadata_get_reply(iq_tree):
info = metadata[0]
attrs = info.attrib
result = {
result: typing.MutableMapping[str, typing.Optional[str]] = {
"sha1": attrs["id"],
"type": attrs.get("type", "image/png"),
}
def extract_optional(key, type_=int):
def extract_optional(
key: str,
type_: typing.Callable[[str], typing.Any] = lambda x: int(x),
) -> None:
try:
result[key] = type_(attrs[key])
except (KeyError, ValueError, TypeError):
@@ -236,8 +276,10 @@ def extract_avatar_metadata_get_reply(iq_tree):
return result
def extract_avatar_data_get_reply(iq_tree):
def extract_avatar_data_get_reply(
iq_tree: ET.Element,
) -> typing.Optional[bytes]:
data = extract_pubsub_item_get_reply(iq_tree, TAG_USER_AVATAR_DATA)
if data.text is None:
if data is None or data.text is None:
return None
return base64.b64decode(data.text)