Pass generic context to callbacks instead of the whole view
This commit is contained in:
parent
5a6c8a1d05
commit
c2eb4fd30c
@ -20,7 +20,7 @@ from django.contrib.auth import get_user_model
|
||||
from django.db import IntegrityError, transaction
|
||||
from rest_framework import serializers, status
|
||||
from . import models
|
||||
from .utils import get_user_queryset, create_user
|
||||
from .utils import get_user_queryset, create_user, CallbackContext
|
||||
|
||||
from .exceptions import EtebaseValidationError
|
||||
|
||||
@ -102,7 +102,7 @@ class CollectionTypeField(BinaryBase64Field):
|
||||
class UserSlugRelatedField(serializers.SlugRelatedField):
|
||||
def get_queryset(self):
|
||||
view = self.context.get("view", None)
|
||||
return get_user_queryset(super().get_queryset(), view)
|
||||
return get_user_queryset(super().get_queryset(), context=CallbackContext(view.kwargs))
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(slug_field=User.USERNAME_FIELD, **kwargs)
|
||||
@ -515,12 +515,17 @@ class AuthenticationSignupSerializer(BetterErrorsMixin, serializers.Serializer):
|
||||
with transaction.atomic():
|
||||
try:
|
||||
view = self.context.get("view", None)
|
||||
user_queryset = get_user_queryset(User.objects.all(), view)
|
||||
user_queryset = get_user_queryset(User.objects.all(), context=CallbackContext(view.kwargs))
|
||||
instance = user_queryset.get(**{User.USERNAME_FIELD: user_data["username"].lower()})
|
||||
except User.DoesNotExist:
|
||||
# Create the user and save the casing the user chose as the first name
|
||||
try:
|
||||
instance = create_user(**user_data, password=None, first_name=user_data["username"], view=view)
|
||||
instance = create_user(
|
||||
**user_data,
|
||||
password=None,
|
||||
first_name=user_data["username"],
|
||||
context=CallbackContext(view.kwargs)
|
||||
)
|
||||
instance.full_clean()
|
||||
except EtebaseValidationError as e:
|
||||
raise e
|
||||
|
@ -1,3 +1,6 @@
|
||||
import typing as t
|
||||
from dataclasses import dataclass
|
||||
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.core.exceptions import PermissionDenied
|
||||
|
||||
@ -7,18 +10,24 @@ from . import app_settings
|
||||
User = get_user_model()
|
||||
|
||||
|
||||
def get_user_queryset(queryset, view):
|
||||
@dataclass
|
||||
class CallbackContext:
|
||||
"""Class for passing extra context to callbacks"""
|
||||
|
||||
url_kwargs: t.Dict[str, t.Any]
|
||||
|
||||
|
||||
def get_user_queryset(queryset, context: CallbackContext):
|
||||
custom_func = app_settings.GET_USER_QUERYSET_FUNC
|
||||
if custom_func is not None:
|
||||
return custom_func(queryset, view)
|
||||
return custom_func(queryset, context)
|
||||
return queryset
|
||||
|
||||
|
||||
def create_user(*args, **kwargs):
|
||||
def create_user(context: CallbackContext, *args, **kwargs):
|
||||
custom_func = app_settings.CREATE_USER_FUNC
|
||||
if custom_func is not None:
|
||||
return custom_func(*args, **kwargs)
|
||||
_ = kwargs.pop("view")
|
||||
return User.objects.create_user(*args, **kwargs)
|
||||
|
||||
|
||||
|
@ -73,7 +73,7 @@ from .serializers import (
|
||||
UserInfoPubkeySerializer,
|
||||
UserSerializer,
|
||||
)
|
||||
from .utils import get_user_queryset
|
||||
from .utils import get_user_queryset, CallbackContext
|
||||
from .exceptions import EtebaseValidationError
|
||||
from .parsers import ChunkUploadParser
|
||||
from .signals import user_signed_up
|
||||
@ -598,7 +598,7 @@ class InvitationOutgoingViewSet(InvitationBaseViewSet):
|
||||
def fetch_user_profile(self, request, *args, **kwargs):
|
||||
username = request.GET.get("username")
|
||||
kwargs = {User.USERNAME_FIELD: username.lower()}
|
||||
user = get_object_or_404(get_user_queryset(User.objects.all(), self), **kwargs)
|
||||
user = get_object_or_404(get_user_queryset(User.objects.all(), CallbackContext(self.kwargs)), **kwargs)
|
||||
user_info = get_object_or_404(UserInfo.objects.all(), owner=user)
|
||||
serializer = UserInfoPubkeySerializer(user_info)
|
||||
return Response(serializer.data)
|
||||
@ -642,7 +642,7 @@ class AuthenticationViewSet(viewsets.ViewSet):
|
||||
)
|
||||
|
||||
def get_queryset(self):
|
||||
return get_user_queryset(User.objects.all(), self)
|
||||
return get_user_queryset(User.objects.all(), CallbackContext(self.kwargs))
|
||||
|
||||
def get_serializer_context(self):
|
||||
return {"request": self.request, "format": self.format_kwarg, "view": self}
|
||||
@ -837,7 +837,7 @@ class TestAuthenticationViewSet(viewsets.ViewSet):
|
||||
return HttpResponseBadRequest("Only allowed in debug mode.")
|
||||
|
||||
with transaction.atomic():
|
||||
user_queryset = get_user_queryset(User.objects.all(), self)
|
||||
user_queryset = get_user_queryset(User.objects.all(), CallbackContext(self.kwargs))
|
||||
user = get_object_or_404(user_queryset, username=request.data.get("user").get("username"))
|
||||
|
||||
# Only allow test users for extra safety
|
||||
|
Loading…
Reference in New Issue
Block a user