Fix many type errors.

This commit is contained in:
Tom Hacohen 2020-12-29 13:22:36 +02:00
parent e13f26ec56
commit 794b5f3983
12 changed files with 87 additions and 64 deletions

View File

@ -1,9 +1,9 @@
from django.contrib.auth import get_user_model
from django.db import models
from django.utils import timezone
from django.utils.crypto import get_random_string
from myauth.models import get_typed_user_model
User = get_user_model()
User = get_typed_user_model()
def generate_key():

View File

@ -1,13 +1,13 @@
import typing as t
from dataclasses import dataclass
from django.contrib.auth import get_user_model
from django.core.exceptions import PermissionDenied
from myauth.models import UserType, get_typed_user_model
from . import app_settings
User = get_user_model()
User = get_typed_user_model()
@dataclass
@ -15,7 +15,7 @@ class CallbackContext:
"""Class for passing extra context to callbacks"""
url_kwargs: t.Dict[str, t.Any]
user: t.Optional[User] = None
user: t.Optional[UserType] = None
def get_user_queryset(queryset, context: CallbackContext):

View File

@ -9,7 +9,7 @@ import nacl.secret
import nacl.signing
from asgiref.sync import sync_to_async
from django.conf import settings
from django.contrib.auth import get_user_model, user_logged_out, user_logged_in
from django.contrib.auth import user_logged_out, user_logged_in
from django.core import exceptions as django_exceptions
from django.db import transaction
from fastapi import APIRouter, Depends, status, Request
@ -19,12 +19,13 @@ from django_etebase.token_auth.models import AuthToken
from django_etebase.models import UserInfo
from django_etebase.signals import user_signed_up
from django_etebase.utils import create_user, get_user_queryset, CallbackContext
from myauth.models import UserType, get_typed_user_model
from .exceptions import AuthenticationFailed, transform_validation_error, HttpError
from .msgpack import MsgpackRoute
from .utils import BaseModel, permission_responses, msgpack_encode, msgpack_decode
from .dependencies import AuthData, get_auth_data, get_authenticated_user
User = get_user_model()
User = get_typed_user_model()
authentication_router = APIRouter(route_class=MsgpackRoute)
@ -52,7 +53,7 @@ class UserOut(BaseModel):
encryptedContent: bytes
@classmethod
def from_orm(cls: t.Type["UserOut"], obj: User) -> "UserOut":
def from_orm(cls: t.Type["UserOut"], obj: UserType) -> "UserOut":
return cls(
username=obj.username,
email=obj.email,
@ -66,7 +67,7 @@ class LoginOut(BaseModel):
user: UserOut
@classmethod
def from_orm(cls: t.Type["LoginOut"], obj: User) -> "LoginOut":
def from_orm(cls: t.Type["LoginOut"], obj: UserType) -> "LoginOut":
token = AuthToken.objects.create(user=obj).key
user = UserOut.from_orm(obj)
return cls(token=token, user=user)
@ -111,7 +112,7 @@ class SignupIn(BaseModel):
@sync_to_async
def __get_login_user(username: str) -> User:
def __get_login_user(username: str) -> UserType:
kwargs = {User.USERNAME_FIELD + "__iexact": username.lower()}
try:
user = User.objects.get(**kwargs)
@ -122,7 +123,7 @@ def __get_login_user(username: str) -> User:
raise AuthenticationFailed(code="user_not_found", detail="User not found")
async def get_login_user(challenge: LoginChallengeIn) -> User:
async def get_login_user(challenge: LoginChallengeIn) -> UserType:
user = await __get_login_user(challenge.username)
return user
@ -138,7 +139,7 @@ def get_encryption_key(salt):
)
def save_changed_password(data: ChangePassword, user: User):
def save_changed_password(data: ChangePassword, user: UserType):
response_data = data.response_data
user_info: UserInfo = user.userinfo
user_info.loginPubkey = response_data.loginPubkey
@ -150,7 +151,7 @@ def save_changed_password(data: ChangePassword, user: User):
def validate_login_request(
validated_data: LoginResponse,
challenge_sent_to_user: Authentication,
user: User,
user: UserType,
expected_action: str,
host_from_request: str,
):
@ -159,7 +160,7 @@ def validate_login_request(
challenge_data = msgpack_decode(box.decrypt(validated_data.challenge))
now = int(datetime.now().timestamp())
if validated_data.action != expected_action:
raise HttpError("wrong_action", f'Expected "{challenge_sent_to_user.response}" but got something else')
raise HttpError("wrong_action", f'Expected "{expected_action}" but got something else')
elif now - challenge_data["timestamp"] > app_settings.CHALLENGE_VALID_SECONDS:
raise HttpError("challenge_expired", "Login challenge has expired")
elif challenge_data["userId"] != user.id:
@ -181,7 +182,7 @@ async def is_etebase():
@authentication_router.post("/login_challenge/", response_model=LoginChallengeOut)
def login_challenge(user: User = Depends(get_login_user)):
def login_challenge(user: UserType = Depends(get_login_user)):
salt = bytes(user.userinfo.salt)
enc_key = get_encryption_key(salt)
box = nacl.secret.SecretBox(enc_key)
@ -210,14 +211,14 @@ def logout(auth_data: AuthData = Depends(get_auth_data)):
@authentication_router.post("/change_password/", status_code=status.HTTP_204_NO_CONTENT, responses=permission_responses)
async def change_password(data: ChangePassword, request: Request, user: User = Depends(get_authenticated_user)):
async def change_password(data: ChangePassword, request: Request, user: UserType = Depends(get_authenticated_user)):
host = request.headers.get("Host")
await validate_login_request(data.response_data, data, user, "changePassword", host)
await sync_to_async(save_changed_password)(data, user)
@authentication_router.post("/dashboard_url/", responses=permission_responses)
def dashboard_url(request: Request, user: User = Depends(get_authenticated_user)):
def dashboard_url(request: Request, user: UserType = Depends(get_authenticated_user)):
get_dashboard_url = app_settings.DASHBOARD_URL_FUNC
if get_dashboard_url is None:
raise HttpError("not_supported", "This server doesn't have a user dashboard.")
@ -228,7 +229,7 @@ def dashboard_url(request: Request, user: User = Depends(get_authenticated_user)
return ret
def signup_save(data: SignupIn, request: Request) -> User:
def signup_save(data: SignupIn, request: Request) -> UserType:
user_data = data.user
with transaction.atomic():
try:

View File

@ -1,7 +1,6 @@
import typing as t
from asgiref.sync import sync_to_async
from django.contrib.auth import get_user_model
from django.core import exceptions as django_exceptions
from django.core.files.base import ContentFile
from django.db import transaction, IntegrityError
@ -9,6 +8,7 @@ from django.db.models import Q, QuerySet
from fastapi import APIRouter, Depends, status, Request
from django_etebase import models
from myauth.models import UserType, get_typed_user_model
from .authentication import get_authenticated_user
from .exceptions import HttpError, transform_validation_error, PermissionDenied, ValidationError
from .msgpack import MsgpackRoute
@ -27,7 +27,7 @@ from .utils import (
from .dependencies import get_collection_queryset, get_item_queryset, get_collection
from .sendfile import sendfile
User = get_user_model()
User = get_typed_user_model
collection_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses)
item_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses)
@ -36,11 +36,14 @@ class ListMulti(BaseModel):
collectionTypes: t.List[bytes]
ChunkType = t.Tuple[str, t.Optional[bytes]]
class CollectionItemRevisionInOut(BaseModel):
uid: str
meta: bytes
deleted: bool
chunks: t.List[t.Tuple[str, t.Optional[bytes]]]
chunks: t.List[ChunkType]
class Config:
orm_mode = True
@ -49,7 +52,7 @@ class CollectionItemRevisionInOut(BaseModel):
def from_orm_context(
cls: t.Type["CollectionItemRevisionInOut"], obj: models.CollectionItemRevision, context: Context
) -> "CollectionItemRevisionInOut":
chunks = []
chunks: t.List[ChunkType] = []
for chunk_relation in obj.chunks_relation.all():
chunk_obj = chunk_relation.chunk
if context.prefetch == "auto":
@ -185,7 +188,7 @@ class ItemBatchIn(BaseModel):
@sync_to_async
def collection_list_common(
queryset: QuerySet,
user: User,
user: UserType,
stoken: t.Optional[str],
limit: int,
prefetch: Prefetch,
@ -210,7 +213,7 @@ def collection_list_common(
remed = remed_qs.values_list("collection__uid", flat=True)
if len(remed) > 0:
ret.removedMemberships = [{"uid": x} for x in remed]
ret.removedMemberships = [RemovedMembershipOut(uid=x) for x in remed]
return ret
@ -219,14 +222,14 @@ def collection_list_common(
def verify_collection_admin(
collection: models.Collection = Depends(get_collection), user: User = Depends(get_authenticated_user)
collection: models.Collection = Depends(get_collection), user: UserType = Depends(get_authenticated_user)
):
if not is_collection_admin(collection, user):
raise PermissionDenied("admin_access_required", "Only collection admins can perform this operation.")
def has_write_access(
collection: models.Collection = Depends(get_collection), user: User = Depends(get_authenticated_user)
collection: models.Collection = Depends(get_collection), user: UserType = Depends(get_authenticated_user)
):
member = collection.members.get(user=user)
if member.accessLevel == models.AccessLevels.READ_ONLY:
@ -247,7 +250,7 @@ async def list_multi(
stoken: t.Optional[str] = None,
limit: int = 50,
queryset: QuerySet = Depends(get_collection_queryset),
user: User = Depends(get_authenticated_user),
user: UserType = Depends(get_authenticated_user),
prefetch: Prefetch = PrefetchQuery,
):
# FIXME: Remove the isnull part once we attach collection types to all objects ("collection-type-migration")
@ -263,7 +266,7 @@ async def collection_list(
stoken: t.Optional[str] = None,
limit: int = 50,
prefetch: Prefetch = PrefetchQuery,
user: User = Depends(get_authenticated_user),
user: UserType = Depends(get_authenticated_user),
queryset: QuerySet = Depends(get_collection_queryset),
):
return await collection_list_common(queryset, user, stoken, limit, prefetch)
@ -299,7 +302,7 @@ def process_revisions_for_item(item: models.CollectionItem, revision_data: Colle
return revision
def _create(data: CollectionIn, user: User):
def _create(data: CollectionIn, user: UserType):
with transaction.atomic():
if data.item.etag is not None:
raise ValidationError("bad_etag", "etag is not null")
@ -335,14 +338,14 @@ def _create(data: CollectionIn, user: User):
@collection_router.post("/", status_code=status.HTTP_201_CREATED, dependencies=PERMISSIONS_READWRITE)
async def create(data: CollectionIn, user: User = Depends(get_authenticated_user)):
async def create(data: CollectionIn, user: UserType = Depends(get_authenticated_user)):
await sync_to_async(_create)(data, user)
@collection_router.get("/{collection_uid}/", response_model=CollectionOut, dependencies=PERMISSIONS_READ)
def collection_get(
obj: models.Collection = Depends(get_collection),
user: User = Depends(get_authenticated_user),
user: UserType = Depends(get_authenticated_user),
prefetch: Prefetch = PrefetchQuery,
):
return CollectionOut.from_orm_context(obj, Context(user, prefetch))
@ -393,7 +396,7 @@ def item_create(item_model: CollectionItemIn, collection: models.Collection, val
def item_get(
item_uid: str,
queryset: QuerySet = Depends(get_item_queryset),
user: User = Depends(get_authenticated_user),
user: UserType = Depends(get_authenticated_user),
prefetch: Prefetch = PrefetchQuery,
):
obj = queryset.get(uid=item_uid)
@ -403,7 +406,7 @@ def item_get(
@sync_to_async
def item_list_common(
queryset: QuerySet,
user: User,
user: UserType,
stoken: t.Optional[str],
limit: int,
prefetch: Prefetch,
@ -424,7 +427,7 @@ async def item_list(
limit: int = 50,
prefetch: Prefetch = PrefetchQuery,
withCollection: bool = False,
user: User = Depends(get_authenticated_user),
user: UserType = Depends(get_authenticated_user),
):
if not withCollection:
queryset = queryset.filter(parent__isnull=True)
@ -433,7 +436,7 @@ async def item_list(
return response
def item_bulk_common(data: ItemBatchIn, user: User, stoken: t.Optional[str], uid: str, validate_etag: bool):
def item_bulk_common(data: ItemBatchIn, user: UserType, stoken: t.Optional[str], uid: str, validate_etag: bool):
queryset = get_collection_queryset(user)
with transaction.atomic(): # We need this for locking the collection object
collection_object = queryset.select_for_update().get(uid=uid)
@ -467,7 +470,7 @@ def item_revisions(
limit: int = 50,
iterator: t.Optional[str] = None,
prefetch: Prefetch = PrefetchQuery,
user: User = Depends(get_authenticated_user),
user: UserType = Depends(get_authenticated_user),
items: QuerySet = Depends(get_item_queryset),
):
item = get_object_or_404(items, uid=item_uid)
@ -501,7 +504,7 @@ def fetch_updates(
data: t.List[CollectionItemBulkGetIn],
stoken: t.Optional[str] = None,
prefetch: Prefetch = PrefetchQuery,
user: User = Depends(get_authenticated_user),
user: UserType = Depends(get_authenticated_user),
queryset: QuerySet = Depends(get_item_queryset),
):
# FIXME: make configurable?
@ -531,14 +534,14 @@ def fetch_updates(
@item_router.post("/item/transaction/", dependencies=[Depends(has_write_access), *PERMISSIONS_READWRITE])
def item_transaction(
collection_uid: str, data: ItemBatchIn, stoken: t.Optional[str] = None, user: User = Depends(get_authenticated_user)
collection_uid: str, data: ItemBatchIn, stoken: t.Optional[str] = None, user: UserType = Depends(get_authenticated_user)
):
return item_bulk_common(data, user, stoken, collection_uid, validate_etag=True)
@item_router.post("/item/batch/", dependencies=[Depends(has_write_access), *PERMISSIONS_READWRITE])
def item_batch(
collection_uid: str, data: ItemBatchIn, stoken: t.Optional[str] = None, user: User = Depends(get_authenticated_user)
collection_uid: str, data: ItemBatchIn, stoken: t.Optional[str] = None, user: UserType = Depends(get_authenticated_user)
):
return item_bulk_common(data, user, stoken, collection_uid, validate_etag=False)

View File

@ -3,17 +3,17 @@ import dataclasses
from fastapi import Depends
from fastapi.security import APIKeyHeader
from django.contrib.auth import get_user_model
from django.utils import timezone
from django.db.models import QuerySet
from django_etebase import models
from django_etebase.token_auth.models import AuthToken, get_default_expiry
from myauth.models import UserType, get_typed_user_model
from .exceptions import AuthenticationFailed
from .utils import get_object_or_404
User = get_user_model()
User = get_typed_user_model()
token_scheme = APIKeyHeader(name="Authorization")
AUTO_REFRESH = True
MIN_REFRESH_INTERVAL = 60
@ -21,7 +21,7 @@ MIN_REFRESH_INTERVAL = 60
@dataclasses.dataclass(frozen=True)
class AuthData:
user: User
user: UserType
token: AuthToken
@ -60,12 +60,12 @@ def get_auth_data(api_token: str = Depends(token_scheme)) -> AuthData:
return AuthData(user, token)
def get_authenticated_user(api_token: str = Depends(token_scheme)) -> User:
def get_authenticated_user(api_token: str = Depends(token_scheme)) -> UserType:
user, _ = __get_authenticated_user(api_token)
return user
def get_collection_queryset(user: User = Depends(get_authenticated_user)) -> QuerySet:
def get_collection_queryset(user: UserType = Depends(get_authenticated_user)) -> QuerySet:
default_queryset: QuerySet = models.Collection.objects.all()
return default_queryset.filter(members__user=user)

View File

@ -1,12 +1,12 @@
import typing as t
from django.contrib.auth import get_user_model
from django.db import transaction, IntegrityError
from django.db.models import QuerySet
from fastapi import APIRouter, Depends, status, Request
from django_etebase import models
from django_etebase.utils import get_user_queryset, CallbackContext
from myauth.models import UserType, get_typed_user_model
from .authentication import get_authenticated_user
from .exceptions import HttpError, PermissionDenied
from .msgpack import MsgpackRoute
@ -20,7 +20,7 @@ from .utils import (
PERMISSIONS_READWRITE,
)
User = get_user_model()
User = get_typed_user_model()
invitation_incoming_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses)
invitation_outgoing_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses)
default_queryset: QuerySet = models.CollectionInvitation.objects.all()
@ -53,7 +53,8 @@ class CollectionInvitationCommon(BaseModel):
class CollectionInvitationIn(CollectionInvitationCommon):
def validate_db(self, context: Context):
if context.user.username == self.username.lower():
user = context.user
if user is not None and (user.username == self.username.lower()):
raise HttpError("no_self_invite", "Inviting yourself is not allowed")
@ -84,11 +85,11 @@ class InvitationListResponse(BaseModel):
done: bool
def get_incoming_queryset(user: User = Depends(get_authenticated_user)):
def get_incoming_queryset(user: UserType = Depends(get_authenticated_user)):
return default_queryset.filter(user=user)
def get_outgoing_queryset(user: User = Depends(get_authenticated_user)):
def get_outgoing_queryset(user: UserType = Depends(get_authenticated_user)):
return default_queryset.filter(fromMember__user=user)
@ -183,7 +184,7 @@ def incoming_accept(
def outgoing_create(
data: CollectionInvitationIn,
request: Request,
user: User = Depends(get_authenticated_user),
user: UserType = Depends(get_authenticated_user),
):
collection = get_object_or_404(models.Collection.objects, uid=data.collection)
to_user = get_object_or_404(
@ -231,7 +232,7 @@ def outgoing_delete(
def outgoing_fetch_user_profile(
username: str,
request: Request,
user: User = Depends(get_authenticated_user),
user: UserType = Depends(get_authenticated_user),
):
kwargs = {User.USERNAME_FIELD: username.lower()}
user = get_object_or_404(get_user_queryset(User.objects.all(), CallbackContext(request.path_params)), **kwargs)

View File

@ -1,11 +1,11 @@
import typing as t
from django.contrib.auth import get_user_model
from django.db import transaction
from django.db.models import QuerySet
from fastapi import APIRouter, Depends, status
from django_etebase import models
from myauth.models import UserType, get_typed_user_model
from .authentication import get_authenticated_user
from .msgpack import MsgpackRoute
from .utils import get_object_or_404, BaseModel, permission_responses, PERMISSIONS_READ, PERMISSIONS_READWRITE
@ -13,7 +13,7 @@ from .stoken_handler import filter_by_stoken_and_limit
from .collection import get_collection, verify_collection_admin
User = get_user_model()
User = get_typed_user_model()
member_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses)
default_queryset: QuerySet = models.CollectionMember.objects.all()
@ -98,6 +98,8 @@ def member_patch(
@member_router.post("/member/leave/", status_code=status.HTTP_204_NO_CONTENT, dependencies=PERMISSIONS_READ)
def member_leave(user: User = Depends(get_authenticated_user), collection: models.Collection = Depends(get_collection)):
def member_leave(
user: UserType = Depends(get_authenticated_user), collection: models.Collection = Depends(get_collection)
):
obj = get_object_or_404(collection.members, user=user)
obj.revoke()

View File

@ -19,13 +19,15 @@ class MsgpackRequest(Request):
class MsgpackResponse(Response):
media_type = "application/msgpack"
def render(self, content: t.Optional[t.Any]) -> t.Optional[bytes]:
def render(self, content: t.Optional[t.Any]) -> bytes:
if content is None:
return b""
if isinstance(content, BaseModel):
content = content.dict()
return msgpack.packb(content, use_bin_type=True)
ret = msgpack.packb(content, use_bin_type=True)
assert ret is not None
return ret
class MsgpackRoute(APIRoute):

View File

@ -1,5 +1,4 @@
from django.conf import settings
from django.contrib.auth import get_user_model
from django.db import transaction
from django.shortcuts import get_object_or_404
from fastapi import APIRouter, Request, status
@ -8,9 +7,10 @@ from django_etebase.utils import get_user_queryset, CallbackContext
from etebase_fastapi.authentication import SignupIn, signup_save
from etebase_fastapi.msgpack import MsgpackRoute
from etebase_fastapi.exceptions import HttpError
from myauth.models import get_typed_user_model
test_reset_view_router = APIRouter(route_class=MsgpackRoute, tags=["test helpers"])
User = get_user_model()
User = get_typed_user_model()
@test_reset_view_router.post("/reset/", status_code=status.HTTP_204_NO_CONTENT)

View File

@ -8,14 +8,14 @@ from pydantic import BaseModel as PyBaseModel
from django.db.models import QuerySet
from django.core.exceptions import ObjectDoesNotExist
from django.contrib.auth import get_user_model
from django_etebase import app_settings
from django_etebase.models import AccessLevels
from myauth.models import UserType, get_typed_user_model
from .exceptions import HttpError, HttpErrorOut
User = get_user_model()
User = get_typed_user_model()
Prefetch = t.Literal["auto", "medium"]
PrefetchQuery = Query(default="auto")
@ -30,7 +30,7 @@ class BaseModel(PyBaseModel):
@dataclasses.dataclass
class Context:
user: t.Optional[User]
user: t.Optional[UserType]
prefetch: t.Optional[Prefetch]

View File

@ -1,8 +1,8 @@
from django import forms
from django.contrib.auth import get_user_model
from django.contrib.auth.forms import UsernameField
from myauth.models import get_typed_user_model
User = get_user_model()
User = get_typed_user_model()
class AdminUserCreationForm(forms.ModelForm):

View File

@ -1,3 +1,5 @@
import typing as t
from django.contrib.auth.models import AbstractUser, UserManager as DjangoUserManager
from django.core import validators
from django.db import models
@ -28,9 +30,21 @@ class User(AbstractUser):
unique=True,
help_text=_("Required. 150 characters or fewer. Letters, digits and ./-/_ only."),
validators=[username_validator],
error_messages={"unique": _("A user with that username already exists."),},
error_messages={
"unique": _("A user with that username already exists."),
},
)
@classmethod
def normalize_username(cls, username):
return super().normalize_username(username).lower()
UserType = t.Type[User]
def get_typed_user_model() -> UserType:
from django.contrib.auth import get_user_model
ret: t.Any = get_user_model()
return ret