Fix many type errors.
This commit is contained in:
parent
e13f26ec56
commit
794b5f3983
@ -1,9 +1,9 @@
|
|||||||
from django.contrib.auth import get_user_model
|
|
||||||
from django.db import models
|
from django.db import models
|
||||||
from django.utils import timezone
|
from django.utils import timezone
|
||||||
from django.utils.crypto import get_random_string
|
from django.utils.crypto import get_random_string
|
||||||
|
from myauth.models import get_typed_user_model
|
||||||
|
|
||||||
User = get_user_model()
|
User = get_typed_user_model()
|
||||||
|
|
||||||
|
|
||||||
def generate_key():
|
def generate_key():
|
||||||
|
@ -1,13 +1,13 @@
|
|||||||
import typing as t
|
import typing as t
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from django.contrib.auth import get_user_model
|
|
||||||
from django.core.exceptions import PermissionDenied
|
from django.core.exceptions import PermissionDenied
|
||||||
|
from myauth.models import UserType, get_typed_user_model
|
||||||
|
|
||||||
from . import app_settings
|
from . import app_settings
|
||||||
|
|
||||||
|
|
||||||
User = get_user_model()
|
User = get_typed_user_model()
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -15,7 +15,7 @@ class CallbackContext:
|
|||||||
"""Class for passing extra context to callbacks"""
|
"""Class for passing extra context to callbacks"""
|
||||||
|
|
||||||
url_kwargs: t.Dict[str, t.Any]
|
url_kwargs: t.Dict[str, t.Any]
|
||||||
user: t.Optional[User] = None
|
user: t.Optional[UserType] = None
|
||||||
|
|
||||||
|
|
||||||
def get_user_queryset(queryset, context: CallbackContext):
|
def get_user_queryset(queryset, context: CallbackContext):
|
||||||
|
@ -9,7 +9,7 @@ import nacl.secret
|
|||||||
import nacl.signing
|
import nacl.signing
|
||||||
from asgiref.sync import sync_to_async
|
from asgiref.sync import sync_to_async
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django.contrib.auth import get_user_model, user_logged_out, user_logged_in
|
from django.contrib.auth import user_logged_out, user_logged_in
|
||||||
from django.core import exceptions as django_exceptions
|
from django.core import exceptions as django_exceptions
|
||||||
from django.db import transaction
|
from django.db import transaction
|
||||||
from fastapi import APIRouter, Depends, status, Request
|
from fastapi import APIRouter, Depends, status, Request
|
||||||
@ -19,12 +19,13 @@ from django_etebase.token_auth.models import AuthToken
|
|||||||
from django_etebase.models import UserInfo
|
from django_etebase.models import UserInfo
|
||||||
from django_etebase.signals import user_signed_up
|
from django_etebase.signals import user_signed_up
|
||||||
from django_etebase.utils import create_user, get_user_queryset, CallbackContext
|
from django_etebase.utils import create_user, get_user_queryset, CallbackContext
|
||||||
|
from myauth.models import UserType, get_typed_user_model
|
||||||
from .exceptions import AuthenticationFailed, transform_validation_error, HttpError
|
from .exceptions import AuthenticationFailed, transform_validation_error, HttpError
|
||||||
from .msgpack import MsgpackRoute
|
from .msgpack import MsgpackRoute
|
||||||
from .utils import BaseModel, permission_responses, msgpack_encode, msgpack_decode
|
from .utils import BaseModel, permission_responses, msgpack_encode, msgpack_decode
|
||||||
from .dependencies import AuthData, get_auth_data, get_authenticated_user
|
from .dependencies import AuthData, get_auth_data, get_authenticated_user
|
||||||
|
|
||||||
User = get_user_model()
|
User = get_typed_user_model()
|
||||||
authentication_router = APIRouter(route_class=MsgpackRoute)
|
authentication_router = APIRouter(route_class=MsgpackRoute)
|
||||||
|
|
||||||
|
|
||||||
@ -52,7 +53,7 @@ class UserOut(BaseModel):
|
|||||||
encryptedContent: bytes
|
encryptedContent: bytes
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_orm(cls: t.Type["UserOut"], obj: User) -> "UserOut":
|
def from_orm(cls: t.Type["UserOut"], obj: UserType) -> "UserOut":
|
||||||
return cls(
|
return cls(
|
||||||
username=obj.username,
|
username=obj.username,
|
||||||
email=obj.email,
|
email=obj.email,
|
||||||
@ -66,7 +67,7 @@ class LoginOut(BaseModel):
|
|||||||
user: UserOut
|
user: UserOut
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_orm(cls: t.Type["LoginOut"], obj: User) -> "LoginOut":
|
def from_orm(cls: t.Type["LoginOut"], obj: UserType) -> "LoginOut":
|
||||||
token = AuthToken.objects.create(user=obj).key
|
token = AuthToken.objects.create(user=obj).key
|
||||||
user = UserOut.from_orm(obj)
|
user = UserOut.from_orm(obj)
|
||||||
return cls(token=token, user=user)
|
return cls(token=token, user=user)
|
||||||
@ -111,7 +112,7 @@ class SignupIn(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
@sync_to_async
|
@sync_to_async
|
||||||
def __get_login_user(username: str) -> User:
|
def __get_login_user(username: str) -> UserType:
|
||||||
kwargs = {User.USERNAME_FIELD + "__iexact": username.lower()}
|
kwargs = {User.USERNAME_FIELD + "__iexact": username.lower()}
|
||||||
try:
|
try:
|
||||||
user = User.objects.get(**kwargs)
|
user = User.objects.get(**kwargs)
|
||||||
@ -122,7 +123,7 @@ def __get_login_user(username: str) -> User:
|
|||||||
raise AuthenticationFailed(code="user_not_found", detail="User not found")
|
raise AuthenticationFailed(code="user_not_found", detail="User not found")
|
||||||
|
|
||||||
|
|
||||||
async def get_login_user(challenge: LoginChallengeIn) -> User:
|
async def get_login_user(challenge: LoginChallengeIn) -> UserType:
|
||||||
user = await __get_login_user(challenge.username)
|
user = await __get_login_user(challenge.username)
|
||||||
return user
|
return user
|
||||||
|
|
||||||
@ -138,7 +139,7 @@ def get_encryption_key(salt):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def save_changed_password(data: ChangePassword, user: User):
|
def save_changed_password(data: ChangePassword, user: UserType):
|
||||||
response_data = data.response_data
|
response_data = data.response_data
|
||||||
user_info: UserInfo = user.userinfo
|
user_info: UserInfo = user.userinfo
|
||||||
user_info.loginPubkey = response_data.loginPubkey
|
user_info.loginPubkey = response_data.loginPubkey
|
||||||
@ -150,7 +151,7 @@ def save_changed_password(data: ChangePassword, user: User):
|
|||||||
def validate_login_request(
|
def validate_login_request(
|
||||||
validated_data: LoginResponse,
|
validated_data: LoginResponse,
|
||||||
challenge_sent_to_user: Authentication,
|
challenge_sent_to_user: Authentication,
|
||||||
user: User,
|
user: UserType,
|
||||||
expected_action: str,
|
expected_action: str,
|
||||||
host_from_request: str,
|
host_from_request: str,
|
||||||
):
|
):
|
||||||
@ -159,7 +160,7 @@ def validate_login_request(
|
|||||||
challenge_data = msgpack_decode(box.decrypt(validated_data.challenge))
|
challenge_data = msgpack_decode(box.decrypt(validated_data.challenge))
|
||||||
now = int(datetime.now().timestamp())
|
now = int(datetime.now().timestamp())
|
||||||
if validated_data.action != expected_action:
|
if validated_data.action != expected_action:
|
||||||
raise HttpError("wrong_action", f'Expected "{challenge_sent_to_user.response}" but got something else')
|
raise HttpError("wrong_action", f'Expected "{expected_action}" but got something else')
|
||||||
elif now - challenge_data["timestamp"] > app_settings.CHALLENGE_VALID_SECONDS:
|
elif now - challenge_data["timestamp"] > app_settings.CHALLENGE_VALID_SECONDS:
|
||||||
raise HttpError("challenge_expired", "Login challenge has expired")
|
raise HttpError("challenge_expired", "Login challenge has expired")
|
||||||
elif challenge_data["userId"] != user.id:
|
elif challenge_data["userId"] != user.id:
|
||||||
@ -181,7 +182,7 @@ async def is_etebase():
|
|||||||
|
|
||||||
|
|
||||||
@authentication_router.post("/login_challenge/", response_model=LoginChallengeOut)
|
@authentication_router.post("/login_challenge/", response_model=LoginChallengeOut)
|
||||||
def login_challenge(user: User = Depends(get_login_user)):
|
def login_challenge(user: UserType = Depends(get_login_user)):
|
||||||
salt = bytes(user.userinfo.salt)
|
salt = bytes(user.userinfo.salt)
|
||||||
enc_key = get_encryption_key(salt)
|
enc_key = get_encryption_key(salt)
|
||||||
box = nacl.secret.SecretBox(enc_key)
|
box = nacl.secret.SecretBox(enc_key)
|
||||||
@ -210,14 +211,14 @@ def logout(auth_data: AuthData = Depends(get_auth_data)):
|
|||||||
|
|
||||||
|
|
||||||
@authentication_router.post("/change_password/", status_code=status.HTTP_204_NO_CONTENT, responses=permission_responses)
|
@authentication_router.post("/change_password/", status_code=status.HTTP_204_NO_CONTENT, responses=permission_responses)
|
||||||
async def change_password(data: ChangePassword, request: Request, user: User = Depends(get_authenticated_user)):
|
async def change_password(data: ChangePassword, request: Request, user: UserType = Depends(get_authenticated_user)):
|
||||||
host = request.headers.get("Host")
|
host = request.headers.get("Host")
|
||||||
await validate_login_request(data.response_data, data, user, "changePassword", host)
|
await validate_login_request(data.response_data, data, user, "changePassword", host)
|
||||||
await sync_to_async(save_changed_password)(data, user)
|
await sync_to_async(save_changed_password)(data, user)
|
||||||
|
|
||||||
|
|
||||||
@authentication_router.post("/dashboard_url/", responses=permission_responses)
|
@authentication_router.post("/dashboard_url/", responses=permission_responses)
|
||||||
def dashboard_url(request: Request, user: User = Depends(get_authenticated_user)):
|
def dashboard_url(request: Request, user: UserType = Depends(get_authenticated_user)):
|
||||||
get_dashboard_url = app_settings.DASHBOARD_URL_FUNC
|
get_dashboard_url = app_settings.DASHBOARD_URL_FUNC
|
||||||
if get_dashboard_url is None:
|
if get_dashboard_url is None:
|
||||||
raise HttpError("not_supported", "This server doesn't have a user dashboard.")
|
raise HttpError("not_supported", "This server doesn't have a user dashboard.")
|
||||||
@ -228,7 +229,7 @@ def dashboard_url(request: Request, user: User = Depends(get_authenticated_user)
|
|||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
def signup_save(data: SignupIn, request: Request) -> User:
|
def signup_save(data: SignupIn, request: Request) -> UserType:
|
||||||
user_data = data.user
|
user_data = data.user
|
||||||
with transaction.atomic():
|
with transaction.atomic():
|
||||||
try:
|
try:
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
import typing as t
|
import typing as t
|
||||||
|
|
||||||
from asgiref.sync import sync_to_async
|
from asgiref.sync import sync_to_async
|
||||||
from django.contrib.auth import get_user_model
|
|
||||||
from django.core import exceptions as django_exceptions
|
from django.core import exceptions as django_exceptions
|
||||||
from django.core.files.base import ContentFile
|
from django.core.files.base import ContentFile
|
||||||
from django.db import transaction, IntegrityError
|
from django.db import transaction, IntegrityError
|
||||||
@ -9,6 +8,7 @@ from django.db.models import Q, QuerySet
|
|||||||
from fastapi import APIRouter, Depends, status, Request
|
from fastapi import APIRouter, Depends, status, Request
|
||||||
|
|
||||||
from django_etebase import models
|
from django_etebase import models
|
||||||
|
from myauth.models import UserType, get_typed_user_model
|
||||||
from .authentication import get_authenticated_user
|
from .authentication import get_authenticated_user
|
||||||
from .exceptions import HttpError, transform_validation_error, PermissionDenied, ValidationError
|
from .exceptions import HttpError, transform_validation_error, PermissionDenied, ValidationError
|
||||||
from .msgpack import MsgpackRoute
|
from .msgpack import MsgpackRoute
|
||||||
@ -27,7 +27,7 @@ from .utils import (
|
|||||||
from .dependencies import get_collection_queryset, get_item_queryset, get_collection
|
from .dependencies import get_collection_queryset, get_item_queryset, get_collection
|
||||||
from .sendfile import sendfile
|
from .sendfile import sendfile
|
||||||
|
|
||||||
User = get_user_model()
|
User = get_typed_user_model
|
||||||
collection_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses)
|
collection_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses)
|
||||||
item_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses)
|
item_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses)
|
||||||
|
|
||||||
@ -36,11 +36,14 @@ class ListMulti(BaseModel):
|
|||||||
collectionTypes: t.List[bytes]
|
collectionTypes: t.List[bytes]
|
||||||
|
|
||||||
|
|
||||||
|
ChunkType = t.Tuple[str, t.Optional[bytes]]
|
||||||
|
|
||||||
|
|
||||||
class CollectionItemRevisionInOut(BaseModel):
|
class CollectionItemRevisionInOut(BaseModel):
|
||||||
uid: str
|
uid: str
|
||||||
meta: bytes
|
meta: bytes
|
||||||
deleted: bool
|
deleted: bool
|
||||||
chunks: t.List[t.Tuple[str, t.Optional[bytes]]]
|
chunks: t.List[ChunkType]
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
orm_mode = True
|
orm_mode = True
|
||||||
@ -49,7 +52,7 @@ class CollectionItemRevisionInOut(BaseModel):
|
|||||||
def from_orm_context(
|
def from_orm_context(
|
||||||
cls: t.Type["CollectionItemRevisionInOut"], obj: models.CollectionItemRevision, context: Context
|
cls: t.Type["CollectionItemRevisionInOut"], obj: models.CollectionItemRevision, context: Context
|
||||||
) -> "CollectionItemRevisionInOut":
|
) -> "CollectionItemRevisionInOut":
|
||||||
chunks = []
|
chunks: t.List[ChunkType] = []
|
||||||
for chunk_relation in obj.chunks_relation.all():
|
for chunk_relation in obj.chunks_relation.all():
|
||||||
chunk_obj = chunk_relation.chunk
|
chunk_obj = chunk_relation.chunk
|
||||||
if context.prefetch == "auto":
|
if context.prefetch == "auto":
|
||||||
@ -185,7 +188,7 @@ class ItemBatchIn(BaseModel):
|
|||||||
@sync_to_async
|
@sync_to_async
|
||||||
def collection_list_common(
|
def collection_list_common(
|
||||||
queryset: QuerySet,
|
queryset: QuerySet,
|
||||||
user: User,
|
user: UserType,
|
||||||
stoken: t.Optional[str],
|
stoken: t.Optional[str],
|
||||||
limit: int,
|
limit: int,
|
||||||
prefetch: Prefetch,
|
prefetch: Prefetch,
|
||||||
@ -210,7 +213,7 @@ def collection_list_common(
|
|||||||
|
|
||||||
remed = remed_qs.values_list("collection__uid", flat=True)
|
remed = remed_qs.values_list("collection__uid", flat=True)
|
||||||
if len(remed) > 0:
|
if len(remed) > 0:
|
||||||
ret.removedMemberships = [{"uid": x} for x in remed]
|
ret.removedMemberships = [RemovedMembershipOut(uid=x) for x in remed]
|
||||||
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
@ -219,14 +222,14 @@ def collection_list_common(
|
|||||||
|
|
||||||
|
|
||||||
def verify_collection_admin(
|
def verify_collection_admin(
|
||||||
collection: models.Collection = Depends(get_collection), user: User = Depends(get_authenticated_user)
|
collection: models.Collection = Depends(get_collection), user: UserType = Depends(get_authenticated_user)
|
||||||
):
|
):
|
||||||
if not is_collection_admin(collection, user):
|
if not is_collection_admin(collection, user):
|
||||||
raise PermissionDenied("admin_access_required", "Only collection admins can perform this operation.")
|
raise PermissionDenied("admin_access_required", "Only collection admins can perform this operation.")
|
||||||
|
|
||||||
|
|
||||||
def has_write_access(
|
def has_write_access(
|
||||||
collection: models.Collection = Depends(get_collection), user: User = Depends(get_authenticated_user)
|
collection: models.Collection = Depends(get_collection), user: UserType = Depends(get_authenticated_user)
|
||||||
):
|
):
|
||||||
member = collection.members.get(user=user)
|
member = collection.members.get(user=user)
|
||||||
if member.accessLevel == models.AccessLevels.READ_ONLY:
|
if member.accessLevel == models.AccessLevels.READ_ONLY:
|
||||||
@ -247,7 +250,7 @@ async def list_multi(
|
|||||||
stoken: t.Optional[str] = None,
|
stoken: t.Optional[str] = None,
|
||||||
limit: int = 50,
|
limit: int = 50,
|
||||||
queryset: QuerySet = Depends(get_collection_queryset),
|
queryset: QuerySet = Depends(get_collection_queryset),
|
||||||
user: User = Depends(get_authenticated_user),
|
user: UserType = Depends(get_authenticated_user),
|
||||||
prefetch: Prefetch = PrefetchQuery,
|
prefetch: Prefetch = PrefetchQuery,
|
||||||
):
|
):
|
||||||
# FIXME: Remove the isnull part once we attach collection types to all objects ("collection-type-migration")
|
# FIXME: Remove the isnull part once we attach collection types to all objects ("collection-type-migration")
|
||||||
@ -263,7 +266,7 @@ async def collection_list(
|
|||||||
stoken: t.Optional[str] = None,
|
stoken: t.Optional[str] = None,
|
||||||
limit: int = 50,
|
limit: int = 50,
|
||||||
prefetch: Prefetch = PrefetchQuery,
|
prefetch: Prefetch = PrefetchQuery,
|
||||||
user: User = Depends(get_authenticated_user),
|
user: UserType = Depends(get_authenticated_user),
|
||||||
queryset: QuerySet = Depends(get_collection_queryset),
|
queryset: QuerySet = Depends(get_collection_queryset),
|
||||||
):
|
):
|
||||||
return await collection_list_common(queryset, user, stoken, limit, prefetch)
|
return await collection_list_common(queryset, user, stoken, limit, prefetch)
|
||||||
@ -299,7 +302,7 @@ def process_revisions_for_item(item: models.CollectionItem, revision_data: Colle
|
|||||||
return revision
|
return revision
|
||||||
|
|
||||||
|
|
||||||
def _create(data: CollectionIn, user: User):
|
def _create(data: CollectionIn, user: UserType):
|
||||||
with transaction.atomic():
|
with transaction.atomic():
|
||||||
if data.item.etag is not None:
|
if data.item.etag is not None:
|
||||||
raise ValidationError("bad_etag", "etag is not null")
|
raise ValidationError("bad_etag", "etag is not null")
|
||||||
@ -335,14 +338,14 @@ def _create(data: CollectionIn, user: User):
|
|||||||
|
|
||||||
|
|
||||||
@collection_router.post("/", status_code=status.HTTP_201_CREATED, dependencies=PERMISSIONS_READWRITE)
|
@collection_router.post("/", status_code=status.HTTP_201_CREATED, dependencies=PERMISSIONS_READWRITE)
|
||||||
async def create(data: CollectionIn, user: User = Depends(get_authenticated_user)):
|
async def create(data: CollectionIn, user: UserType = Depends(get_authenticated_user)):
|
||||||
await sync_to_async(_create)(data, user)
|
await sync_to_async(_create)(data, user)
|
||||||
|
|
||||||
|
|
||||||
@collection_router.get("/{collection_uid}/", response_model=CollectionOut, dependencies=PERMISSIONS_READ)
|
@collection_router.get("/{collection_uid}/", response_model=CollectionOut, dependencies=PERMISSIONS_READ)
|
||||||
def collection_get(
|
def collection_get(
|
||||||
obj: models.Collection = Depends(get_collection),
|
obj: models.Collection = Depends(get_collection),
|
||||||
user: User = Depends(get_authenticated_user),
|
user: UserType = Depends(get_authenticated_user),
|
||||||
prefetch: Prefetch = PrefetchQuery,
|
prefetch: Prefetch = PrefetchQuery,
|
||||||
):
|
):
|
||||||
return CollectionOut.from_orm_context(obj, Context(user, prefetch))
|
return CollectionOut.from_orm_context(obj, Context(user, prefetch))
|
||||||
@ -393,7 +396,7 @@ def item_create(item_model: CollectionItemIn, collection: models.Collection, val
|
|||||||
def item_get(
|
def item_get(
|
||||||
item_uid: str,
|
item_uid: str,
|
||||||
queryset: QuerySet = Depends(get_item_queryset),
|
queryset: QuerySet = Depends(get_item_queryset),
|
||||||
user: User = Depends(get_authenticated_user),
|
user: UserType = Depends(get_authenticated_user),
|
||||||
prefetch: Prefetch = PrefetchQuery,
|
prefetch: Prefetch = PrefetchQuery,
|
||||||
):
|
):
|
||||||
obj = queryset.get(uid=item_uid)
|
obj = queryset.get(uid=item_uid)
|
||||||
@ -403,7 +406,7 @@ def item_get(
|
|||||||
@sync_to_async
|
@sync_to_async
|
||||||
def item_list_common(
|
def item_list_common(
|
||||||
queryset: QuerySet,
|
queryset: QuerySet,
|
||||||
user: User,
|
user: UserType,
|
||||||
stoken: t.Optional[str],
|
stoken: t.Optional[str],
|
||||||
limit: int,
|
limit: int,
|
||||||
prefetch: Prefetch,
|
prefetch: Prefetch,
|
||||||
@ -424,7 +427,7 @@ async def item_list(
|
|||||||
limit: int = 50,
|
limit: int = 50,
|
||||||
prefetch: Prefetch = PrefetchQuery,
|
prefetch: Prefetch = PrefetchQuery,
|
||||||
withCollection: bool = False,
|
withCollection: bool = False,
|
||||||
user: User = Depends(get_authenticated_user),
|
user: UserType = Depends(get_authenticated_user),
|
||||||
):
|
):
|
||||||
if not withCollection:
|
if not withCollection:
|
||||||
queryset = queryset.filter(parent__isnull=True)
|
queryset = queryset.filter(parent__isnull=True)
|
||||||
@ -433,7 +436,7 @@ async def item_list(
|
|||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
def item_bulk_common(data: ItemBatchIn, user: User, stoken: t.Optional[str], uid: str, validate_etag: bool):
|
def item_bulk_common(data: ItemBatchIn, user: UserType, stoken: t.Optional[str], uid: str, validate_etag: bool):
|
||||||
queryset = get_collection_queryset(user)
|
queryset = get_collection_queryset(user)
|
||||||
with transaction.atomic(): # We need this for locking the collection object
|
with transaction.atomic(): # We need this for locking the collection object
|
||||||
collection_object = queryset.select_for_update().get(uid=uid)
|
collection_object = queryset.select_for_update().get(uid=uid)
|
||||||
@ -467,7 +470,7 @@ def item_revisions(
|
|||||||
limit: int = 50,
|
limit: int = 50,
|
||||||
iterator: t.Optional[str] = None,
|
iterator: t.Optional[str] = None,
|
||||||
prefetch: Prefetch = PrefetchQuery,
|
prefetch: Prefetch = PrefetchQuery,
|
||||||
user: User = Depends(get_authenticated_user),
|
user: UserType = Depends(get_authenticated_user),
|
||||||
items: QuerySet = Depends(get_item_queryset),
|
items: QuerySet = Depends(get_item_queryset),
|
||||||
):
|
):
|
||||||
item = get_object_or_404(items, uid=item_uid)
|
item = get_object_or_404(items, uid=item_uid)
|
||||||
@ -501,7 +504,7 @@ def fetch_updates(
|
|||||||
data: t.List[CollectionItemBulkGetIn],
|
data: t.List[CollectionItemBulkGetIn],
|
||||||
stoken: t.Optional[str] = None,
|
stoken: t.Optional[str] = None,
|
||||||
prefetch: Prefetch = PrefetchQuery,
|
prefetch: Prefetch = PrefetchQuery,
|
||||||
user: User = Depends(get_authenticated_user),
|
user: UserType = Depends(get_authenticated_user),
|
||||||
queryset: QuerySet = Depends(get_item_queryset),
|
queryset: QuerySet = Depends(get_item_queryset),
|
||||||
):
|
):
|
||||||
# FIXME: make configurable?
|
# FIXME: make configurable?
|
||||||
@ -531,14 +534,14 @@ def fetch_updates(
|
|||||||
|
|
||||||
@item_router.post("/item/transaction/", dependencies=[Depends(has_write_access), *PERMISSIONS_READWRITE])
|
@item_router.post("/item/transaction/", dependencies=[Depends(has_write_access), *PERMISSIONS_READWRITE])
|
||||||
def item_transaction(
|
def item_transaction(
|
||||||
collection_uid: str, data: ItemBatchIn, stoken: t.Optional[str] = None, user: User = Depends(get_authenticated_user)
|
collection_uid: str, data: ItemBatchIn, stoken: t.Optional[str] = None, user: UserType = Depends(get_authenticated_user)
|
||||||
):
|
):
|
||||||
return item_bulk_common(data, user, stoken, collection_uid, validate_etag=True)
|
return item_bulk_common(data, user, stoken, collection_uid, validate_etag=True)
|
||||||
|
|
||||||
|
|
||||||
@item_router.post("/item/batch/", dependencies=[Depends(has_write_access), *PERMISSIONS_READWRITE])
|
@item_router.post("/item/batch/", dependencies=[Depends(has_write_access), *PERMISSIONS_READWRITE])
|
||||||
def item_batch(
|
def item_batch(
|
||||||
collection_uid: str, data: ItemBatchIn, stoken: t.Optional[str] = None, user: User = Depends(get_authenticated_user)
|
collection_uid: str, data: ItemBatchIn, stoken: t.Optional[str] = None, user: UserType = Depends(get_authenticated_user)
|
||||||
):
|
):
|
||||||
return item_bulk_common(data, user, stoken, collection_uid, validate_etag=False)
|
return item_bulk_common(data, user, stoken, collection_uid, validate_etag=False)
|
||||||
|
|
||||||
|
@ -3,17 +3,17 @@ import dataclasses
|
|||||||
from fastapi import Depends
|
from fastapi import Depends
|
||||||
from fastapi.security import APIKeyHeader
|
from fastapi.security import APIKeyHeader
|
||||||
|
|
||||||
from django.contrib.auth import get_user_model
|
|
||||||
from django.utils import timezone
|
from django.utils import timezone
|
||||||
from django.db.models import QuerySet
|
from django.db.models import QuerySet
|
||||||
|
|
||||||
from django_etebase import models
|
from django_etebase import models
|
||||||
from django_etebase.token_auth.models import AuthToken, get_default_expiry
|
from django_etebase.token_auth.models import AuthToken, get_default_expiry
|
||||||
|
from myauth.models import UserType, get_typed_user_model
|
||||||
from .exceptions import AuthenticationFailed
|
from .exceptions import AuthenticationFailed
|
||||||
from .utils import get_object_or_404
|
from .utils import get_object_or_404
|
||||||
|
|
||||||
|
|
||||||
User = get_user_model()
|
User = get_typed_user_model()
|
||||||
token_scheme = APIKeyHeader(name="Authorization")
|
token_scheme = APIKeyHeader(name="Authorization")
|
||||||
AUTO_REFRESH = True
|
AUTO_REFRESH = True
|
||||||
MIN_REFRESH_INTERVAL = 60
|
MIN_REFRESH_INTERVAL = 60
|
||||||
@ -21,7 +21,7 @@ MIN_REFRESH_INTERVAL = 60
|
|||||||
|
|
||||||
@dataclasses.dataclass(frozen=True)
|
@dataclasses.dataclass(frozen=True)
|
||||||
class AuthData:
|
class AuthData:
|
||||||
user: User
|
user: UserType
|
||||||
token: AuthToken
|
token: AuthToken
|
||||||
|
|
||||||
|
|
||||||
@ -60,12 +60,12 @@ def get_auth_data(api_token: str = Depends(token_scheme)) -> AuthData:
|
|||||||
return AuthData(user, token)
|
return AuthData(user, token)
|
||||||
|
|
||||||
|
|
||||||
def get_authenticated_user(api_token: str = Depends(token_scheme)) -> User:
|
def get_authenticated_user(api_token: str = Depends(token_scheme)) -> UserType:
|
||||||
user, _ = __get_authenticated_user(api_token)
|
user, _ = __get_authenticated_user(api_token)
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
def get_collection_queryset(user: User = Depends(get_authenticated_user)) -> QuerySet:
|
def get_collection_queryset(user: UserType = Depends(get_authenticated_user)) -> QuerySet:
|
||||||
default_queryset: QuerySet = models.Collection.objects.all()
|
default_queryset: QuerySet = models.Collection.objects.all()
|
||||||
return default_queryset.filter(members__user=user)
|
return default_queryset.filter(members__user=user)
|
||||||
|
|
||||||
|
@ -1,12 +1,12 @@
|
|||||||
import typing as t
|
import typing as t
|
||||||
|
|
||||||
from django.contrib.auth import get_user_model
|
|
||||||
from django.db import transaction, IntegrityError
|
from django.db import transaction, IntegrityError
|
||||||
from django.db.models import QuerySet
|
from django.db.models import QuerySet
|
||||||
from fastapi import APIRouter, Depends, status, Request
|
from fastapi import APIRouter, Depends, status, Request
|
||||||
|
|
||||||
from django_etebase import models
|
from django_etebase import models
|
||||||
from django_etebase.utils import get_user_queryset, CallbackContext
|
from django_etebase.utils import get_user_queryset, CallbackContext
|
||||||
|
from myauth.models import UserType, get_typed_user_model
|
||||||
from .authentication import get_authenticated_user
|
from .authentication import get_authenticated_user
|
||||||
from .exceptions import HttpError, PermissionDenied
|
from .exceptions import HttpError, PermissionDenied
|
||||||
from .msgpack import MsgpackRoute
|
from .msgpack import MsgpackRoute
|
||||||
@ -20,7 +20,7 @@ from .utils import (
|
|||||||
PERMISSIONS_READWRITE,
|
PERMISSIONS_READWRITE,
|
||||||
)
|
)
|
||||||
|
|
||||||
User = get_user_model()
|
User = get_typed_user_model()
|
||||||
invitation_incoming_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses)
|
invitation_incoming_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses)
|
||||||
invitation_outgoing_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses)
|
invitation_outgoing_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses)
|
||||||
default_queryset: QuerySet = models.CollectionInvitation.objects.all()
|
default_queryset: QuerySet = models.CollectionInvitation.objects.all()
|
||||||
@ -53,7 +53,8 @@ class CollectionInvitationCommon(BaseModel):
|
|||||||
|
|
||||||
class CollectionInvitationIn(CollectionInvitationCommon):
|
class CollectionInvitationIn(CollectionInvitationCommon):
|
||||||
def validate_db(self, context: Context):
|
def validate_db(self, context: Context):
|
||||||
if context.user.username == self.username.lower():
|
user = context.user
|
||||||
|
if user is not None and (user.username == self.username.lower()):
|
||||||
raise HttpError("no_self_invite", "Inviting yourself is not allowed")
|
raise HttpError("no_self_invite", "Inviting yourself is not allowed")
|
||||||
|
|
||||||
|
|
||||||
@ -84,11 +85,11 @@ class InvitationListResponse(BaseModel):
|
|||||||
done: bool
|
done: bool
|
||||||
|
|
||||||
|
|
||||||
def get_incoming_queryset(user: User = Depends(get_authenticated_user)):
|
def get_incoming_queryset(user: UserType = Depends(get_authenticated_user)):
|
||||||
return default_queryset.filter(user=user)
|
return default_queryset.filter(user=user)
|
||||||
|
|
||||||
|
|
||||||
def get_outgoing_queryset(user: User = Depends(get_authenticated_user)):
|
def get_outgoing_queryset(user: UserType = Depends(get_authenticated_user)):
|
||||||
return default_queryset.filter(fromMember__user=user)
|
return default_queryset.filter(fromMember__user=user)
|
||||||
|
|
||||||
|
|
||||||
@ -183,7 +184,7 @@ def incoming_accept(
|
|||||||
def outgoing_create(
|
def outgoing_create(
|
||||||
data: CollectionInvitationIn,
|
data: CollectionInvitationIn,
|
||||||
request: Request,
|
request: Request,
|
||||||
user: User = Depends(get_authenticated_user),
|
user: UserType = Depends(get_authenticated_user),
|
||||||
):
|
):
|
||||||
collection = get_object_or_404(models.Collection.objects, uid=data.collection)
|
collection = get_object_or_404(models.Collection.objects, uid=data.collection)
|
||||||
to_user = get_object_or_404(
|
to_user = get_object_or_404(
|
||||||
@ -231,7 +232,7 @@ def outgoing_delete(
|
|||||||
def outgoing_fetch_user_profile(
|
def outgoing_fetch_user_profile(
|
||||||
username: str,
|
username: str,
|
||||||
request: Request,
|
request: Request,
|
||||||
user: User = Depends(get_authenticated_user),
|
user: UserType = Depends(get_authenticated_user),
|
||||||
):
|
):
|
||||||
kwargs = {User.USERNAME_FIELD: username.lower()}
|
kwargs = {User.USERNAME_FIELD: username.lower()}
|
||||||
user = get_object_or_404(get_user_queryset(User.objects.all(), CallbackContext(request.path_params)), **kwargs)
|
user = get_object_or_404(get_user_queryset(User.objects.all(), CallbackContext(request.path_params)), **kwargs)
|
||||||
|
@ -1,11 +1,11 @@
|
|||||||
import typing as t
|
import typing as t
|
||||||
|
|
||||||
from django.contrib.auth import get_user_model
|
|
||||||
from django.db import transaction
|
from django.db import transaction
|
||||||
from django.db.models import QuerySet
|
from django.db.models import QuerySet
|
||||||
from fastapi import APIRouter, Depends, status
|
from fastapi import APIRouter, Depends, status
|
||||||
|
|
||||||
from django_etebase import models
|
from django_etebase import models
|
||||||
|
from myauth.models import UserType, get_typed_user_model
|
||||||
from .authentication import get_authenticated_user
|
from .authentication import get_authenticated_user
|
||||||
from .msgpack import MsgpackRoute
|
from .msgpack import MsgpackRoute
|
||||||
from .utils import get_object_or_404, BaseModel, permission_responses, PERMISSIONS_READ, PERMISSIONS_READWRITE
|
from .utils import get_object_or_404, BaseModel, permission_responses, PERMISSIONS_READ, PERMISSIONS_READWRITE
|
||||||
@ -13,7 +13,7 @@ from .stoken_handler import filter_by_stoken_and_limit
|
|||||||
|
|
||||||
from .collection import get_collection, verify_collection_admin
|
from .collection import get_collection, verify_collection_admin
|
||||||
|
|
||||||
User = get_user_model()
|
User = get_typed_user_model()
|
||||||
member_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses)
|
member_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses)
|
||||||
default_queryset: QuerySet = models.CollectionMember.objects.all()
|
default_queryset: QuerySet = models.CollectionMember.objects.all()
|
||||||
|
|
||||||
@ -98,6 +98,8 @@ def member_patch(
|
|||||||
|
|
||||||
|
|
||||||
@member_router.post("/member/leave/", status_code=status.HTTP_204_NO_CONTENT, dependencies=PERMISSIONS_READ)
|
@member_router.post("/member/leave/", status_code=status.HTTP_204_NO_CONTENT, dependencies=PERMISSIONS_READ)
|
||||||
def member_leave(user: User = Depends(get_authenticated_user), collection: models.Collection = Depends(get_collection)):
|
def member_leave(
|
||||||
|
user: UserType = Depends(get_authenticated_user), collection: models.Collection = Depends(get_collection)
|
||||||
|
):
|
||||||
obj = get_object_or_404(collection.members, user=user)
|
obj = get_object_or_404(collection.members, user=user)
|
||||||
obj.revoke()
|
obj.revoke()
|
||||||
|
@ -19,13 +19,15 @@ class MsgpackRequest(Request):
|
|||||||
class MsgpackResponse(Response):
|
class MsgpackResponse(Response):
|
||||||
media_type = "application/msgpack"
|
media_type = "application/msgpack"
|
||||||
|
|
||||||
def render(self, content: t.Optional[t.Any]) -> t.Optional[bytes]:
|
def render(self, content: t.Optional[t.Any]) -> bytes:
|
||||||
if content is None:
|
if content is None:
|
||||||
return b""
|
return b""
|
||||||
|
|
||||||
if isinstance(content, BaseModel):
|
if isinstance(content, BaseModel):
|
||||||
content = content.dict()
|
content = content.dict()
|
||||||
return msgpack.packb(content, use_bin_type=True)
|
ret = msgpack.packb(content, use_bin_type=True)
|
||||||
|
assert ret is not None
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
class MsgpackRoute(APIRoute):
|
class MsgpackRoute(APIRoute):
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django.contrib.auth import get_user_model
|
|
||||||
from django.db import transaction
|
from django.db import transaction
|
||||||
from django.shortcuts import get_object_or_404
|
from django.shortcuts import get_object_or_404
|
||||||
from fastapi import APIRouter, Request, status
|
from fastapi import APIRouter, Request, status
|
||||||
@ -8,9 +7,10 @@ from django_etebase.utils import get_user_queryset, CallbackContext
|
|||||||
from etebase_fastapi.authentication import SignupIn, signup_save
|
from etebase_fastapi.authentication import SignupIn, signup_save
|
||||||
from etebase_fastapi.msgpack import MsgpackRoute
|
from etebase_fastapi.msgpack import MsgpackRoute
|
||||||
from etebase_fastapi.exceptions import HttpError
|
from etebase_fastapi.exceptions import HttpError
|
||||||
|
from myauth.models import get_typed_user_model
|
||||||
|
|
||||||
test_reset_view_router = APIRouter(route_class=MsgpackRoute, tags=["test helpers"])
|
test_reset_view_router = APIRouter(route_class=MsgpackRoute, tags=["test helpers"])
|
||||||
User = get_user_model()
|
User = get_typed_user_model()
|
||||||
|
|
||||||
|
|
||||||
@test_reset_view_router.post("/reset/", status_code=status.HTTP_204_NO_CONTENT)
|
@test_reset_view_router.post("/reset/", status_code=status.HTTP_204_NO_CONTENT)
|
||||||
|
@ -8,14 +8,14 @@ from pydantic import BaseModel as PyBaseModel
|
|||||||
|
|
||||||
from django.db.models import QuerySet
|
from django.db.models import QuerySet
|
||||||
from django.core.exceptions import ObjectDoesNotExist
|
from django.core.exceptions import ObjectDoesNotExist
|
||||||
from django.contrib.auth import get_user_model
|
|
||||||
|
|
||||||
from django_etebase import app_settings
|
from django_etebase import app_settings
|
||||||
from django_etebase.models import AccessLevels
|
from django_etebase.models import AccessLevels
|
||||||
|
from myauth.models import UserType, get_typed_user_model
|
||||||
|
|
||||||
from .exceptions import HttpError, HttpErrorOut
|
from .exceptions import HttpError, HttpErrorOut
|
||||||
|
|
||||||
User = get_user_model()
|
User = get_typed_user_model()
|
||||||
|
|
||||||
Prefetch = t.Literal["auto", "medium"]
|
Prefetch = t.Literal["auto", "medium"]
|
||||||
PrefetchQuery = Query(default="auto")
|
PrefetchQuery = Query(default="auto")
|
||||||
@ -30,7 +30,7 @@ class BaseModel(PyBaseModel):
|
|||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class Context:
|
class Context:
|
||||||
user: t.Optional[User]
|
user: t.Optional[UserType]
|
||||||
prefetch: t.Optional[Prefetch]
|
prefetch: t.Optional[Prefetch]
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
from django import forms
|
from django import forms
|
||||||
from django.contrib.auth import get_user_model
|
|
||||||
from django.contrib.auth.forms import UsernameField
|
from django.contrib.auth.forms import UsernameField
|
||||||
|
from myauth.models import get_typed_user_model
|
||||||
|
|
||||||
User = get_user_model()
|
User = get_typed_user_model()
|
||||||
|
|
||||||
|
|
||||||
class AdminUserCreationForm(forms.ModelForm):
|
class AdminUserCreationForm(forms.ModelForm):
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
import typing as t
|
||||||
|
|
||||||
from django.contrib.auth.models import AbstractUser, UserManager as DjangoUserManager
|
from django.contrib.auth.models import AbstractUser, UserManager as DjangoUserManager
|
||||||
from django.core import validators
|
from django.core import validators
|
||||||
from django.db import models
|
from django.db import models
|
||||||
@ -28,9 +30,21 @@ class User(AbstractUser):
|
|||||||
unique=True,
|
unique=True,
|
||||||
help_text=_("Required. 150 characters or fewer. Letters, digits and ./-/_ only."),
|
help_text=_("Required. 150 characters or fewer. Letters, digits and ./-/_ only."),
|
||||||
validators=[username_validator],
|
validators=[username_validator],
|
||||||
error_messages={"unique": _("A user with that username already exists."),},
|
error_messages={
|
||||||
|
"unique": _("A user with that username already exists."),
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def normalize_username(cls, username):
|
def normalize_username(cls, username):
|
||||||
return super().normalize_username(username).lower()
|
return super().normalize_username(username).lower()
|
||||||
|
|
||||||
|
|
||||||
|
UserType = t.Type[User]
|
||||||
|
|
||||||
|
|
||||||
|
def get_typed_user_model() -> UserType:
|
||||||
|
from django.contrib.auth import get_user_model
|
||||||
|
|
||||||
|
ret: t.Any = get_user_model()
|
||||||
|
return ret
|
||||||
|
Loading…
Reference in New Issue
Block a user