change response content to pydantic models and error handling
This commit is contained in:
		
							parent
							
								
									a0d1d23d2d
								
							
						
					
					
						commit
						31e0e0b832
					
				@ -2,7 +2,6 @@ import dataclasses
 | 
				
			|||||||
import typing as t
 | 
					import typing as t
 | 
				
			||||||
from datetime import datetime
 | 
					from datetime import datetime
 | 
				
			||||||
from functools import cached_property
 | 
					from functools import cached_property
 | 
				
			||||||
from django.core import exceptions as django_exceptions
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
import nacl
 | 
					import nacl
 | 
				
			||||||
import nacl.encoding
 | 
					import nacl.encoding
 | 
				
			||||||
@ -12,6 +11,7 @@ import nacl.signing
 | 
				
			|||||||
from asgiref.sync import sync_to_async
 | 
					from asgiref.sync import sync_to_async
 | 
				
			||||||
from django.conf import settings
 | 
					from django.conf import settings
 | 
				
			||||||
from django.contrib.auth import get_user_model, user_logged_out, user_logged_in
 | 
					from django.contrib.auth import get_user_model, user_logged_out, user_logged_in
 | 
				
			||||||
 | 
					from django.core import exceptions as django_exceptions
 | 
				
			||||||
from django.db import transaction
 | 
					from django.db import transaction
 | 
				
			||||||
from django.utils import timezone
 | 
					from django.utils import timezone
 | 
				
			||||||
from fastapi import APIRouter, Depends, status, Request, Response
 | 
					from fastapi import APIRouter, Depends, status, Request, Response
 | 
				
			||||||
@ -21,7 +21,6 @@ from pydantic import BaseModel
 | 
				
			|||||||
from django_etebase import app_settings, models
 | 
					from django_etebase import app_settings, models
 | 
				
			||||||
from django_etebase.exceptions import EtebaseValidationError
 | 
					from django_etebase.exceptions import EtebaseValidationError
 | 
				
			||||||
from django_etebase.models import UserInfo
 | 
					from django_etebase.models import UserInfo
 | 
				
			||||||
from django_etebase.serializers import UserSerializer
 | 
					 | 
				
			||||||
from django_etebase.signals import user_signed_up
 | 
					from django_etebase.signals import user_signed_up
 | 
				
			||||||
from django_etebase.token_auth.models import AuthToken
 | 
					from django_etebase.token_auth.models import AuthToken
 | 
				
			||||||
from django_etebase.token_auth.models import get_default_expiry
 | 
					from django_etebase.token_auth.models import get_default_expiry
 | 
				
			||||||
@ -43,10 +42,16 @@ class AuthData:
 | 
				
			|||||||
    token: AuthToken
 | 
					    token: AuthToken
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class LoginChallengeData(BaseModel):
 | 
					class LoginChallengeIn(BaseModel):
 | 
				
			||||||
    username: str
 | 
					    username: str
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class LoginChallengeOut(BaseModel):
 | 
				
			||||||
 | 
					    salt: bytes
 | 
				
			||||||
 | 
					    challenge: bytes
 | 
				
			||||||
 | 
					    version: int
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class LoginResponse(BaseModel):
 | 
					class LoginResponse(BaseModel):
 | 
				
			||||||
    username: str
 | 
					    username: str
 | 
				
			||||||
    challenge: bytes
 | 
					    challenge: bytes
 | 
				
			||||||
@ -54,6 +59,26 @@ class LoginResponse(BaseModel):
 | 
				
			|||||||
    action: t.Literal["login", "changePassword"]
 | 
					    action: t.Literal["login", "changePassword"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class UserOut(BaseModel):
 | 
				
			||||||
 | 
					    pubkey: bytes
 | 
				
			||||||
 | 
					    encryptedContent: bytes
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @classmethod
 | 
				
			||||||
 | 
					    def from_orm(cls: t.Type["UserOut"], obj: User) -> "UserOut":
 | 
				
			||||||
 | 
					        return cls(pubkey=obj.userinfo.pubkey, encryptedContent=obj.userinfo.encryptedContent)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class LoginOut(BaseModel):
 | 
				
			||||||
 | 
					    token: str
 | 
				
			||||||
 | 
					    user: UserOut
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @classmethod
 | 
				
			||||||
 | 
					    def from_orm(cls: t.Type["LoginOut"], obj: User) -> "LoginOut":
 | 
				
			||||||
 | 
					        token = AuthToken.objects.create(user=obj).key
 | 
				
			||||||
 | 
					        user = UserOut.from_orm(obj)
 | 
				
			||||||
 | 
					        return cls(token=token, user=user)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class Authentication(BaseModel):
 | 
					class Authentication(BaseModel):
 | 
				
			||||||
    class Config:
 | 
					    class Config:
 | 
				
			||||||
        keep_untouched = (cached_property,)
 | 
					        keep_untouched = (cached_property,)
 | 
				
			||||||
@ -145,7 +170,7 @@ def __get_login_user(username: str) -> User:
 | 
				
			|||||||
        raise AuthenticationFailed(code="user_not_found", detail="User not found")
 | 
					        raise AuthenticationFailed(code="user_not_found", detail="User not found")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
async def get_login_user(challenge: LoginChallengeData) -> User:
 | 
					async def get_login_user(challenge: LoginChallengeIn) -> User:
 | 
				
			||||||
    user = await __get_login_user(challenge.username)
 | 
					    user = await __get_login_user(challenge.username)
 | 
				
			||||||
    return user
 | 
					    return user
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -161,7 +186,6 @@ def get_encryption_key(salt):
 | 
				
			|||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@sync_to_async
 | 
					 | 
				
			||||||
def save_changed_password(data: ChangePassword, user: User):
 | 
					def save_changed_password(data: ChangePassword, user: User):
 | 
				
			||||||
    response_data = data.response_data
 | 
					    response_data = data.response_data
 | 
				
			||||||
    user_info: UserInfo = user.userinfo
 | 
					    user_info: UserInfo = user.userinfo
 | 
				
			||||||
@ -170,24 +194,6 @@ def save_changed_password(data: ChangePassword, user: User):
 | 
				
			|||||||
    user_info.save()
 | 
					    user_info.save()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@sync_to_async
 | 
					 | 
				
			||||||
def login_response_data(user: User):
 | 
					 | 
				
			||||||
    return {
 | 
					 | 
				
			||||||
        "token": AuthToken.objects.create(user=user).key,
 | 
					 | 
				
			||||||
        "user": UserSerializer(user).data,
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
@sync_to_async
 | 
					 | 
				
			||||||
def send_user_logged_in_async(user: User, request: Request):
 | 
					 | 
				
			||||||
    user_logged_in.send(sender=user.__class__, request=request, user=user)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
@sync_to_async
 | 
					 | 
				
			||||||
def send_user_logged_out_async(user: User, request: Request):
 | 
					 | 
				
			||||||
    user_logged_out.send(sender=user.__class__, request=request, user=user)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
@sync_to_async
 | 
					@sync_to_async
 | 
				
			||||||
def validate_login_request(
 | 
					def validate_login_request(
 | 
				
			||||||
    validated_data: LoginResponse,
 | 
					    validated_data: LoginResponse,
 | 
				
			||||||
@ -195,39 +201,26 @@ def validate_login_request(
 | 
				
			|||||||
    user: User,
 | 
					    user: User,
 | 
				
			||||||
    expected_action: str,
 | 
					    expected_action: str,
 | 
				
			||||||
    host_from_request: str,
 | 
					    host_from_request: str,
 | 
				
			||||||
) -> t.Optional[MsgpackResponse]:
 | 
					):
 | 
				
			||||||
 | 
					 | 
				
			||||||
    enc_key = get_encryption_key(bytes(user.userinfo.salt))
 | 
					    enc_key = get_encryption_key(bytes(user.userinfo.salt))
 | 
				
			||||||
    box = nacl.secret.SecretBox(enc_key)
 | 
					    box = nacl.secret.SecretBox(enc_key)
 | 
				
			||||||
    challenge_data = msgpack_decode(box.decrypt(validated_data.challenge))
 | 
					    challenge_data = msgpack_decode(box.decrypt(validated_data.challenge))
 | 
				
			||||||
    now = int(datetime.now().timestamp())
 | 
					    now = int(datetime.now().timestamp())
 | 
				
			||||||
    if validated_data.action != expected_action:
 | 
					    if validated_data.action != expected_action:
 | 
				
			||||||
        content = {
 | 
					        raise ValidationError("wrong_action", f'Expected "{challenge_sent_to_user.response}" but got something else')
 | 
				
			||||||
            "code": "wrong_action",
 | 
					 | 
				
			||||||
            "detail": 'Expected "{}" but got something else'.format(challenge_sent_to_user.response),
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
        return MsgpackResponse(content, status_code=status.HTTP_400_BAD_REQUEST)
 | 
					 | 
				
			||||||
    elif now - challenge_data["timestamp"] > app_settings.CHALLENGE_VALID_SECONDS:
 | 
					    elif now - challenge_data["timestamp"] > app_settings.CHALLENGE_VALID_SECONDS:
 | 
				
			||||||
        content = {"code": "challenge_expired", "detail": "Login challenge has expired"}
 | 
					        raise ValidationError("challenge_expired", "Login challenge has expired")
 | 
				
			||||||
        return MsgpackResponse(content, status_code=status.HTTP_400_BAD_REQUEST)
 | 
					 | 
				
			||||||
    elif challenge_data["userId"] != user.id:
 | 
					    elif challenge_data["userId"] != user.id:
 | 
				
			||||||
        content = {"code": "wrong_user", "detail": "This challenge is for the wrong user"}
 | 
					        raise ValidationError("wrong_user", "This challenge is for the wrong user")
 | 
				
			||||||
        return MsgpackResponse(content, status_code=status.HTTP_400_BAD_REQUEST)
 | 
					 | 
				
			||||||
    elif not settings.DEBUG and validated_data.host.split(":", 1)[0] != host_from_request:
 | 
					    elif not settings.DEBUG and validated_data.host.split(":", 1)[0] != host_from_request:
 | 
				
			||||||
        detail = 'Found wrong host name. Got: "{}" expected: "{}"'.format(validated_data.host, host_from_request)
 | 
					        raise ValidationError(
 | 
				
			||||||
        content = {"code": "wrong_host", "detail": detail}
 | 
					            "wrong_host", f'Found wrong host name. Got: "{validated_data.host}" expected: "{host_from_request}"'
 | 
				
			||||||
        return MsgpackResponse(content, status_code=status.HTTP_400_BAD_REQUEST)
 | 
					        )
 | 
				
			||||||
    verify_key = nacl.signing.VerifyKey(bytes(user.userinfo.loginPubkey), encoder=nacl.encoding.RawEncoder)
 | 
					    verify_key = nacl.signing.VerifyKey(bytes(user.userinfo.loginPubkey), encoder=nacl.encoding.RawEncoder)
 | 
				
			||||||
 | 
					 | 
				
			||||||
    try:
 | 
					    try:
 | 
				
			||||||
        verify_key.verify(challenge_sent_to_user.response, challenge_sent_to_user.signature)
 | 
					        verify_key.verify(challenge_sent_to_user.response, challenge_sent_to_user.signature)
 | 
				
			||||||
    except nacl.exceptions.BadSignatureError:
 | 
					    except nacl.exceptions.BadSignatureError:
 | 
				
			||||||
        return MsgpackResponse(
 | 
					        raise ValidationError("login_bad_signature", "Wrong password for user.", status.HTTP_401_UNAUTHORIZED)
 | 
				
			||||||
            {"code": "login_bad_signature", "detail": "Wrong password for user."},
 | 
					 | 
				
			||||||
            status_code=status.HTTP_401_UNAUTHORIZED,
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    return None
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@authentication_router.post("/login_challenge/")
 | 
					@authentication_router.post("/login_challenge/")
 | 
				
			||||||
@ -239,35 +232,34 @@ async def login_challenge(user: User = Depends(get_login_user)):
 | 
				
			|||||||
        "userId": user.id,
 | 
					        "userId": user.id,
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    challenge = bytes(box.encrypt(msgpack_encode(challenge_data), encoder=nacl.encoding.RawEncoder))
 | 
					    challenge = bytes(box.encrypt(msgpack_encode(challenge_data), encoder=nacl.encoding.RawEncoder))
 | 
				
			||||||
    return MsgpackResponse({"salt": user.userinfo.salt, "version": user.userinfo.version, "challenge": challenge})
 | 
					    return MsgpackResponse(
 | 
				
			||||||
 | 
					        LoginChallengeOut(salt=user.userinfo.salt, challenge=challenge, version=user.userinfo.version)
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@authentication_router.post("/login/")
 | 
					@authentication_router.post("/login/")
 | 
				
			||||||
async def login(data: Login, request: Request):
 | 
					async def login(data: Login, request: Request):
 | 
				
			||||||
    user = await get_login_user(LoginChallengeData(username=data.response_data.username))
 | 
					    user = await get_login_user(LoginChallengeIn(username=data.response_data.username))
 | 
				
			||||||
    host = request.headers.get("Host")
 | 
					    host = request.headers.get("Host")
 | 
				
			||||||
    bad_login_response = await validate_login_request(data.response_data, data, user, "login", host)
 | 
					    await validate_login_request(data.response_data, data, user, "login", host)
 | 
				
			||||||
    if bad_login_response is not None:
 | 
					    data = await sync_to_async(LoginOut.from_orm)(user)
 | 
				
			||||||
        return bad_login_response
 | 
					    await sync_to_async(user_logged_in.send)(sender=user.__class__, request=None, user=user)
 | 
				
			||||||
    data = await login_response_data(user)
 | 
					    return MsgpackResponse(content=data, status_code=status.HTTP_200_OK)
 | 
				
			||||||
    await send_user_logged_in_async(user, request)
 | 
					 | 
				
			||||||
    return MsgpackResponse(data, status_code=status.HTTP_200_OK)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@authentication_router.post("/logout/")
 | 
					@authentication_router.post("/logout/")
 | 
				
			||||||
async def logout(request: Request, auth_data: AuthData = Depends(get_auth_data)):
 | 
					async def logout(request: Request, auth_data: AuthData = Depends(get_auth_data)):
 | 
				
			||||||
    await sync_to_async(auth_data.token.delete)()
 | 
					    await sync_to_async(auth_data.token.delete)()
 | 
				
			||||||
    await send_user_logged_out_async(auth_data.user, request)
 | 
					    # XXX-TOM
 | 
				
			||||||
 | 
					    await sync_to_async(user_logged_out.send)(sender=auth_data.user.__class__, request=None, user=auth_data.user)
 | 
				
			||||||
    return Response(status_code=status.HTTP_204_NO_CONTENT)
 | 
					    return Response(status_code=status.HTTP_204_NO_CONTENT)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@authentication_router.post("/change_password/")
 | 
					@authentication_router.post("/change_password/")
 | 
				
			||||||
async def change_password(data: ChangePassword, request: Request, user: User = Depends(get_authenticated_user)):
 | 
					async def change_password(data: ChangePassword, request: Request, user: User = Depends(get_authenticated_user)):
 | 
				
			||||||
    host = request.headers.get("Host")
 | 
					    host = request.headers.get("Host")
 | 
				
			||||||
    bad_login_response = await validate_login_request(data.response_data, data, user, "changePassword", host)
 | 
					    await validate_login_request(data.response_data, data, user, "changePassword", host)
 | 
				
			||||||
    if bad_login_response is not None:
 | 
					    await sync_to_async(save_changed_password)(data, user)
 | 
				
			||||||
        return bad_login_response
 | 
					 | 
				
			||||||
    await save_changed_password(data, user)
 | 
					 | 
				
			||||||
    return Response(status_code=status.HTTP_204_NO_CONTENT)
 | 
					    return Response(status_code=status.HTTP_204_NO_CONTENT)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -300,15 +292,10 @@ def signup_save(data: SignupIn) -> User:
 | 
				
			|||||||
    return instance
 | 
					    return instance
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@sync_to_async
 | 
					 | 
				
			||||||
def send_user_signed_up_async(user: User, request):
 | 
					 | 
				
			||||||
    user_signed_up.send(sender=user.__class__, request=request, user=user)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
@authentication_router.post("/signup/")
 | 
					@authentication_router.post("/signup/")
 | 
				
			||||||
async def signup(data: SignupIn):
 | 
					async def signup(data: SignupIn):
 | 
				
			||||||
    user = await sync_to_async(signup_save)(data)
 | 
					    user = await sync_to_async(signup_save)(data)
 | 
				
			||||||
    # XXX-TOM
 | 
					    # XXX-TOM
 | 
				
			||||||
    data = await login_response_data(user)
 | 
					    data = await sync_to_async(LoginOut.from_orm)(user)
 | 
				
			||||||
    await send_user_signed_up_async(user, None)
 | 
					    await sync_to_async(user_signed_up.send)(sender=user.__class__, request=None, user=user)
 | 
				
			||||||
    return MsgpackResponse(content=data, status_code=status.HTTP_201_CREATED)
 | 
					    return MsgpackResponse(content=data, status_code=status.HTTP_201_CREATED)
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
		Reference in New Issue
	
	Block a user