Signup: use the get_user_queryset function when checking if user exists.

This commit is contained in:
Tom Hacohen 2020-07-13 16:03:34 +03:00
parent af86d877f2
commit 46b4f08afa
2 changed files with 19 additions and 3 deletions

View File

@ -394,7 +394,9 @@ class AuthenticationSignupSerializer(serializers.Serializer):
with transaction.atomic(): with transaction.atomic():
try: try:
instance = User.objects.get_by_natural_key(user_data['username']) view = self.context.get('view', None)
user_queryset = get_user_queryset(User.objects.all(), view)
instance = user_queryset.get(**{User.USERNAME_FIELD: user_data['username'].lower()})
except User.DoesNotExist: except User.DoesNotExist:
# Create the user and save the casing the user chose as the first name # Create the user and save the casing the user chose as the first name
instance = User.objects.create_user(**user_data, password=None, first_name=user_data['username']) instance = User.objects.create_user(**user_data, password=None, first_name=user_data['username'])

View File

@ -601,6 +601,13 @@ class AuthenticationViewSet(viewsets.ViewSet):
def get_queryset(self): def get_queryset(self):
return get_user_queryset(User.objects.all(), self) return get_user_queryset(User.objects.all(), self)
def get_serializer_context(self):
return {
'request': self.request,
'format': self.format_kwarg,
'view': self
}
def login_response_data(self, user): def login_response_data(self, user):
return { return {
'token': AuthToken.objects.create(user=user).key, 'token': AuthToken.objects.create(user=user).key,
@ -612,7 +619,7 @@ class AuthenticationViewSet(viewsets.ViewSet):
@action_decorator(detail=False, methods=['POST']) @action_decorator(detail=False, methods=['POST'])
def signup(self, request, *args, **kwargs): def signup(self, request, *args, **kwargs):
serializer = AuthenticationSignupSerializer(data=request.data) serializer = AuthenticationSignupSerializer(data=request.data, context=self.get_serializer_context())
serializer.is_valid(raise_exception=True) serializer.is_valid(raise_exception=True)
user = serializer.save() user = serializer.save()
@ -748,6 +755,13 @@ class TestAuthenticationViewSet(viewsets.ViewSet):
renderer_classes = BaseViewSet.renderer_classes renderer_classes = BaseViewSet.renderer_classes
parser_classes = BaseViewSet.parser_classes parser_classes = BaseViewSet.parser_classes
def get_serializer_context(self):
return {
'request': self.request,
'format': self.format_kwarg,
'view': self
}
def list(self, request, *args, **kwargs): def list(self, request, *args, **kwargs):
return Response(status=status.HTTP_405_METHOD_NOT_ALLOWED) return Response(status=status.HTTP_405_METHOD_NOT_ALLOWED)
@ -768,7 +782,7 @@ class TestAuthenticationViewSet(viewsets.ViewSet):
if hasattr(user, 'userinfo'): if hasattr(user, 'userinfo'):
user.userinfo.delete() user.userinfo.delete()
serializer = AuthenticationSignupSerializer(data=request.data) serializer = AuthenticationSignupSerializer(data=request.data, context=self.get_serializer_context())
serializer.is_valid(raise_exception=True) serializer.is_valid(raise_exception=True)
serializer.save() serializer.save()