Move common dependencies to their own file.
This commit is contained in:
parent
3e39aa88a1
commit
c2a2e710c9
@ -1,4 +1,3 @@
|
|||||||
import dataclasses
|
|
||||||
import typing as t
|
import typing as t
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
@ -13,33 +12,22 @@ from django.conf import settings
|
|||||||
from django.contrib.auth import get_user_model, user_logged_out, user_logged_in
|
from django.contrib.auth import get_user_model, user_logged_out, user_logged_in
|
||||||
from django.core import exceptions as django_exceptions
|
from django.core import exceptions as django_exceptions
|
||||||
from django.db import transaction
|
from django.db import transaction
|
||||||
from django.utils import timezone
|
|
||||||
from fastapi import APIRouter, Depends, status, Request
|
from fastapi import APIRouter, Depends, status, Request
|
||||||
from fastapi.security import APIKeyHeader
|
|
||||||
|
|
||||||
from django_etebase import app_settings, models
|
from django_etebase import app_settings, models
|
||||||
|
from django_etebase.token_auth.models import AuthToken
|
||||||
from django_etebase.models import UserInfo
|
from django_etebase.models import UserInfo
|
||||||
from django_etebase.signals import user_signed_up
|
from django_etebase.signals import user_signed_up
|
||||||
from django_etebase.token_auth.models import AuthToken
|
|
||||||
from django_etebase.token_auth.models import get_default_expiry
|
|
||||||
from django_etebase.utils import create_user, get_user_queryset, CallbackContext
|
from django_etebase.utils import create_user, get_user_queryset, CallbackContext
|
||||||
from .exceptions import AuthenticationFailed, transform_validation_error, HttpError
|
from .exceptions import AuthenticationFailed, transform_validation_error, HttpError
|
||||||
from .msgpack import MsgpackRoute
|
from .msgpack import MsgpackRoute
|
||||||
from .utils import BaseModel, permission_responses, msgpack_encode, msgpack_decode
|
from .utils import BaseModel, permission_responses, msgpack_encode, msgpack_decode
|
||||||
|
from .dependencies import AuthData, get_auth_data, get_authenticated_user
|
||||||
|
|
||||||
User = get_user_model()
|
User = get_user_model()
|
||||||
token_scheme = APIKeyHeader(name="Authorization")
|
|
||||||
AUTO_REFRESH = True
|
|
||||||
MIN_REFRESH_INTERVAL = 60
|
|
||||||
authentication_router = APIRouter(route_class=MsgpackRoute)
|
authentication_router = APIRouter(route_class=MsgpackRoute)
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass(frozen=True)
|
|
||||||
class AuthData:
|
|
||||||
user: User
|
|
||||||
token: AuthToken
|
|
||||||
|
|
||||||
|
|
||||||
class LoginChallengeIn(BaseModel):
|
class LoginChallengeIn(BaseModel):
|
||||||
username: str
|
username: str
|
||||||
|
|
||||||
@ -115,47 +103,6 @@ class SignupIn(BaseModel):
|
|||||||
encryptedContent: bytes
|
encryptedContent: bytes
|
||||||
|
|
||||||
|
|
||||||
def __renew_token(auth_token: AuthToken):
|
|
||||||
current_expiry = auth_token.expiry
|
|
||||||
new_expiry = get_default_expiry()
|
|
||||||
# Throttle refreshing of token to avoid db writes
|
|
||||||
delta = (new_expiry - current_expiry).total_seconds()
|
|
||||||
if delta > MIN_REFRESH_INTERVAL:
|
|
||||||
auth_token.expiry = new_expiry
|
|
||||||
auth_token.save(update_fields=("expiry",))
|
|
||||||
|
|
||||||
|
|
||||||
@sync_to_async
|
|
||||||
def __get_authenticated_user(api_token: str):
|
|
||||||
api_token = api_token.split()[1]
|
|
||||||
try:
|
|
||||||
token: AuthToken = AuthToken.objects.select_related("user").get(key=api_token)
|
|
||||||
except AuthToken.DoesNotExist:
|
|
||||||
raise AuthenticationFailed(detail="Invalid token.")
|
|
||||||
if not token.user.is_active:
|
|
||||||
raise AuthenticationFailed(detail="User inactive or deleted.")
|
|
||||||
|
|
||||||
if token.expiry is not None:
|
|
||||||
if token.expiry < timezone.now():
|
|
||||||
token.delete()
|
|
||||||
raise AuthenticationFailed(detail="Invalid token.")
|
|
||||||
|
|
||||||
if AUTO_REFRESH:
|
|
||||||
__renew_token(token)
|
|
||||||
|
|
||||||
return token.user, token
|
|
||||||
|
|
||||||
|
|
||||||
async def get_auth_data(api_token: str = Depends(token_scheme)) -> AuthData:
|
|
||||||
user, token = await __get_authenticated_user(api_token)
|
|
||||||
return AuthData(user, token)
|
|
||||||
|
|
||||||
|
|
||||||
async def get_authenticated_user(api_token: str = Depends(token_scheme)) -> User:
|
|
||||||
user, token = await __get_authenticated_user(api_token)
|
|
||||||
return user
|
|
||||||
|
|
||||||
|
|
||||||
@sync_to_async
|
@sync_to_async
|
||||||
def __get_login_user(username: str) -> User:
|
def __get_login_user(username: str) -> User:
|
||||||
kwargs = {User.USERNAME_FIELD + "__iexact": username.lower()}
|
kwargs = {User.USERNAME_FIELD + "__iexact": username.lower()}
|
||||||
|
@ -5,8 +5,7 @@ from django.contrib.auth import get_user_model
|
|||||||
from django.core import exceptions as django_exceptions
|
from django.core import exceptions as django_exceptions
|
||||||
from django.core.files.base import ContentFile
|
from django.core.files.base import ContentFile
|
||||||
from django.db import transaction
|
from django.db import transaction
|
||||||
from django.db.models import Q
|
from django.db.models import Q, QuerySet
|
||||||
from django.db.models import QuerySet
|
|
||||||
from fastapi import APIRouter, Depends, status
|
from fastapi import APIRouter, Depends, status
|
||||||
|
|
||||||
from django_etebase import models
|
from django_etebase import models
|
||||||
@ -25,12 +24,11 @@ from .utils import (
|
|||||||
PERMISSIONS_READ,
|
PERMISSIONS_READ,
|
||||||
PERMISSIONS_READWRITE,
|
PERMISSIONS_READWRITE,
|
||||||
)
|
)
|
||||||
|
from .dependencies import get_collection_queryset, get_item_queryset, get_collection
|
||||||
|
|
||||||
User = get_user_model()
|
User = get_user_model()
|
||||||
collection_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses)
|
collection_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses)
|
||||||
item_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses)
|
item_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses)
|
||||||
default_queryset: QuerySet = models.Collection.objects.all()
|
|
||||||
default_item_queryset: QuerySet = models.CollectionItem.objects.all()
|
|
||||||
|
|
||||||
|
|
||||||
class ListMulti(BaseModel):
|
class ListMulti(BaseModel):
|
||||||
@ -203,21 +201,6 @@ def collection_list_common(
|
|||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
def get_collection_queryset(user: User = Depends(get_authenticated_user)) -> QuerySet:
|
|
||||||
return default_queryset.filter(members__user=user)
|
|
||||||
|
|
||||||
|
|
||||||
def get_collection(collection_uid: str, queryset: QuerySet = Depends(get_collection_queryset)) -> models.Collection:
|
|
||||||
return get_object_or_404(queryset, uid=collection_uid)
|
|
||||||
|
|
||||||
|
|
||||||
def get_item_queryset(collection: models.Collection = Depends(get_collection)) -> QuerySet:
|
|
||||||
# XXX Potentially add this for performance: .prefetch_related('revisions__chunks')
|
|
||||||
queryset = default_item_queryset.filter(collection__pk=collection.pk, revisions__current=True)
|
|
||||||
|
|
||||||
return queryset
|
|
||||||
|
|
||||||
|
|
||||||
# permissions
|
# permissions
|
||||||
|
|
||||||
|
|
||||||
|
82
etebase_fastapi/dependencies.py
Normal file
82
etebase_fastapi/dependencies.py
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
import dataclasses
|
||||||
|
|
||||||
|
from fastapi import Depends
|
||||||
|
from fastapi.security import APIKeyHeader
|
||||||
|
|
||||||
|
from django.contrib.auth import get_user_model
|
||||||
|
from django.utils import timezone
|
||||||
|
from django.db.models import QuerySet
|
||||||
|
|
||||||
|
from django_etebase import models
|
||||||
|
from django_etebase.token_auth.models import AuthToken, get_default_expiry
|
||||||
|
from .exceptions import AuthenticationFailed
|
||||||
|
from .utils import get_object_or_404
|
||||||
|
|
||||||
|
|
||||||
|
User = get_user_model()
|
||||||
|
token_scheme = APIKeyHeader(name="Authorization")
|
||||||
|
AUTO_REFRESH = True
|
||||||
|
MIN_REFRESH_INTERVAL = 60
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass(frozen=True)
|
||||||
|
class AuthData:
|
||||||
|
user: User
|
||||||
|
token: AuthToken
|
||||||
|
|
||||||
|
|
||||||
|
def __renew_token(auth_token: AuthToken):
|
||||||
|
current_expiry = auth_token.expiry
|
||||||
|
new_expiry = get_default_expiry()
|
||||||
|
# Throttle refreshing of token to avoid db writes
|
||||||
|
delta = (new_expiry - current_expiry).total_seconds()
|
||||||
|
if delta > MIN_REFRESH_INTERVAL:
|
||||||
|
auth_token.expiry = new_expiry
|
||||||
|
auth_token.save(update_fields=("expiry",))
|
||||||
|
|
||||||
|
|
||||||
|
def __get_authenticated_user(api_token: str):
|
||||||
|
api_token = api_token.split()[1]
|
||||||
|
try:
|
||||||
|
token: AuthToken = AuthToken.objects.select_related("user").get(key=api_token)
|
||||||
|
except AuthToken.DoesNotExist:
|
||||||
|
raise AuthenticationFailed(detail="Invalid token.")
|
||||||
|
if not token.user.is_active:
|
||||||
|
raise AuthenticationFailed(detail="User inactive or deleted.")
|
||||||
|
|
||||||
|
if token.expiry is not None:
|
||||||
|
if token.expiry < timezone.now():
|
||||||
|
token.delete()
|
||||||
|
raise AuthenticationFailed(detail="Invalid token.")
|
||||||
|
|
||||||
|
if AUTO_REFRESH:
|
||||||
|
__renew_token(token)
|
||||||
|
|
||||||
|
return token.user, token
|
||||||
|
|
||||||
|
|
||||||
|
def get_auth_data(api_token: str = Depends(token_scheme)) -> AuthData:
|
||||||
|
user, token = __get_authenticated_user(api_token)
|
||||||
|
return AuthData(user, token)
|
||||||
|
|
||||||
|
|
||||||
|
def get_authenticated_user(api_token: str = Depends(token_scheme)) -> User:
|
||||||
|
user, _ = __get_authenticated_user(api_token)
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
def get_collection_queryset(user: User = Depends(get_authenticated_user)) -> QuerySet:
|
||||||
|
default_queryset: QuerySet = models.Collection.objects.all()
|
||||||
|
return default_queryset.filter(members__user=user)
|
||||||
|
|
||||||
|
|
||||||
|
def get_collection(collection_uid: str, queryset: QuerySet = Depends(get_collection_queryset)) -> models.Collection:
|
||||||
|
return get_object_or_404(queryset, uid=collection_uid)
|
||||||
|
|
||||||
|
|
||||||
|
def get_item_queryset(collection: models.Collection = Depends(get_collection)) -> QuerySet:
|
||||||
|
default_item_queryset: QuerySet = models.CollectionItem.objects.all()
|
||||||
|
# XXX Potentially add this for performance: .prefetch_related('revisions__chunks')
|
||||||
|
queryset = default_item_queryset.filter(collection__pk=collection.pk, revisions__current=True)
|
||||||
|
|
||||||
|
return queryset
|
Loading…
Reference in New Issue
Block a user