Pass generic context to callbacks instead of the whole view
This commit is contained in:
parent
5a6c8a1d05
commit
c2eb4fd30c
@ -20,7 +20,7 @@ from django.contrib.auth import get_user_model
|
|||||||
from django.db import IntegrityError, transaction
|
from django.db import IntegrityError, transaction
|
||||||
from rest_framework import serializers, status
|
from rest_framework import serializers, status
|
||||||
from . import models
|
from . import models
|
||||||
from .utils import get_user_queryset, create_user
|
from .utils import get_user_queryset, create_user, CallbackContext
|
||||||
|
|
||||||
from .exceptions import EtebaseValidationError
|
from .exceptions import EtebaseValidationError
|
||||||
|
|
||||||
@ -102,7 +102,7 @@ class CollectionTypeField(BinaryBase64Field):
|
|||||||
class UserSlugRelatedField(serializers.SlugRelatedField):
|
class UserSlugRelatedField(serializers.SlugRelatedField):
|
||||||
def get_queryset(self):
|
def get_queryset(self):
|
||||||
view = self.context.get("view", None)
|
view = self.context.get("view", None)
|
||||||
return get_user_queryset(super().get_queryset(), view)
|
return get_user_queryset(super().get_queryset(), context=CallbackContext(view.kwargs))
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
super().__init__(slug_field=User.USERNAME_FIELD, **kwargs)
|
super().__init__(slug_field=User.USERNAME_FIELD, **kwargs)
|
||||||
@ -515,12 +515,17 @@ class AuthenticationSignupSerializer(BetterErrorsMixin, serializers.Serializer):
|
|||||||
with transaction.atomic():
|
with transaction.atomic():
|
||||||
try:
|
try:
|
||||||
view = self.context.get("view", None)
|
view = self.context.get("view", None)
|
||||||
user_queryset = get_user_queryset(User.objects.all(), view)
|
user_queryset = get_user_queryset(User.objects.all(), context=CallbackContext(view.kwargs))
|
||||||
instance = user_queryset.get(**{User.USERNAME_FIELD: user_data["username"].lower()})
|
instance = user_queryset.get(**{User.USERNAME_FIELD: user_data["username"].lower()})
|
||||||
except User.DoesNotExist:
|
except User.DoesNotExist:
|
||||||
# Create the user and save the casing the user chose as the first name
|
# Create the user and save the casing the user chose as the first name
|
||||||
try:
|
try:
|
||||||
instance = create_user(**user_data, password=None, first_name=user_data["username"], view=view)
|
instance = create_user(
|
||||||
|
**user_data,
|
||||||
|
password=None,
|
||||||
|
first_name=user_data["username"],
|
||||||
|
context=CallbackContext(view.kwargs)
|
||||||
|
)
|
||||||
instance.full_clean()
|
instance.full_clean()
|
||||||
except EtebaseValidationError as e:
|
except EtebaseValidationError as e:
|
||||||
raise e
|
raise e
|
||||||
|
@ -1,3 +1,6 @@
|
|||||||
|
import typing as t
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from django.contrib.auth import get_user_model
|
from django.contrib.auth import get_user_model
|
||||||
from django.core.exceptions import PermissionDenied
|
from django.core.exceptions import PermissionDenied
|
||||||
|
|
||||||
@ -7,18 +10,24 @@ from . import app_settings
|
|||||||
User = get_user_model()
|
User = get_user_model()
|
||||||
|
|
||||||
|
|
||||||
def get_user_queryset(queryset, view):
|
@dataclass
|
||||||
|
class CallbackContext:
|
||||||
|
"""Class for passing extra context to callbacks"""
|
||||||
|
|
||||||
|
url_kwargs: t.Dict[str, t.Any]
|
||||||
|
|
||||||
|
|
||||||
|
def get_user_queryset(queryset, context: CallbackContext):
|
||||||
custom_func = app_settings.GET_USER_QUERYSET_FUNC
|
custom_func = app_settings.GET_USER_QUERYSET_FUNC
|
||||||
if custom_func is not None:
|
if custom_func is not None:
|
||||||
return custom_func(queryset, view)
|
return custom_func(queryset, context)
|
||||||
return queryset
|
return queryset
|
||||||
|
|
||||||
|
|
||||||
def create_user(*args, **kwargs):
|
def create_user(context: CallbackContext, *args, **kwargs):
|
||||||
custom_func = app_settings.CREATE_USER_FUNC
|
custom_func = app_settings.CREATE_USER_FUNC
|
||||||
if custom_func is not None:
|
if custom_func is not None:
|
||||||
return custom_func(*args, **kwargs)
|
return custom_func(*args, **kwargs)
|
||||||
_ = kwargs.pop("view")
|
|
||||||
return User.objects.create_user(*args, **kwargs)
|
return User.objects.create_user(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@ -73,7 +73,7 @@ from .serializers import (
|
|||||||
UserInfoPubkeySerializer,
|
UserInfoPubkeySerializer,
|
||||||
UserSerializer,
|
UserSerializer,
|
||||||
)
|
)
|
||||||
from .utils import get_user_queryset
|
from .utils import get_user_queryset, CallbackContext
|
||||||
from .exceptions import EtebaseValidationError
|
from .exceptions import EtebaseValidationError
|
||||||
from .parsers import ChunkUploadParser
|
from .parsers import ChunkUploadParser
|
||||||
from .signals import user_signed_up
|
from .signals import user_signed_up
|
||||||
@ -598,7 +598,7 @@ class InvitationOutgoingViewSet(InvitationBaseViewSet):
|
|||||||
def fetch_user_profile(self, request, *args, **kwargs):
|
def fetch_user_profile(self, request, *args, **kwargs):
|
||||||
username = request.GET.get("username")
|
username = request.GET.get("username")
|
||||||
kwargs = {User.USERNAME_FIELD: username.lower()}
|
kwargs = {User.USERNAME_FIELD: username.lower()}
|
||||||
user = get_object_or_404(get_user_queryset(User.objects.all(), self), **kwargs)
|
user = get_object_or_404(get_user_queryset(User.objects.all(), CallbackContext(self.kwargs)), **kwargs)
|
||||||
user_info = get_object_or_404(UserInfo.objects.all(), owner=user)
|
user_info = get_object_or_404(UserInfo.objects.all(), owner=user)
|
||||||
serializer = UserInfoPubkeySerializer(user_info)
|
serializer = UserInfoPubkeySerializer(user_info)
|
||||||
return Response(serializer.data)
|
return Response(serializer.data)
|
||||||
@ -642,7 +642,7 @@ class AuthenticationViewSet(viewsets.ViewSet):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def get_queryset(self):
|
def get_queryset(self):
|
||||||
return get_user_queryset(User.objects.all(), self)
|
return get_user_queryset(User.objects.all(), CallbackContext(self.kwargs))
|
||||||
|
|
||||||
def get_serializer_context(self):
|
def get_serializer_context(self):
|
||||||
return {"request": self.request, "format": self.format_kwarg, "view": self}
|
return {"request": self.request, "format": self.format_kwarg, "view": self}
|
||||||
@ -837,7 +837,7 @@ class TestAuthenticationViewSet(viewsets.ViewSet):
|
|||||||
return HttpResponseBadRequest("Only allowed in debug mode.")
|
return HttpResponseBadRequest("Only allowed in debug mode.")
|
||||||
|
|
||||||
with transaction.atomic():
|
with transaction.atomic():
|
||||||
user_queryset = get_user_queryset(User.objects.all(), self)
|
user_queryset = get_user_queryset(User.objects.all(), CallbackContext(self.kwargs))
|
||||||
user = get_object_or_404(user_queryset, username=request.data.get("user").get("username"))
|
user = get_object_or_404(user_queryset, username=request.data.get("user").get("username"))
|
||||||
|
|
||||||
# Only allow test users for extra safety
|
# Only allow test users for extra safety
|
||||||
|
Loading…
Reference in New Issue
Block a user