Change to standalone stoken objects (+ small optimisation).

Makes it possible to now generate Stokens as we need so we can add them to
non-revision objects, for example, membership changes.

We also slightly improved how we filter by revs.
This commit is contained in:
Tom Hacohen
2020-05-26 18:52:44 +03:00
parent 3cdb7783fe
commit 2a39f3538e
3 changed files with 33 additions and 12 deletions

View File

@@ -35,7 +35,15 @@ import nacl.secret
import nacl.hash
from . import app_settings, permissions
from .models import Collection, CollectionItem, CollectionItemRevision, CollectionMember, CollectionInvitation, UserInfo
from .models import (
Collection,
CollectionItem,
CollectionItemRevision,
CollectionMember,
CollectionInvitation,
Stoken,
UserInfo,
)
from .serializers import (
b64encode,
AuthenticationSignupSerializer,
@@ -94,18 +102,18 @@ class BaseViewSet(viewsets.ModelViewSet):
user = self.request.user
return queryset.filter(members__user=user)
def get_cstoken_rev(self, request):
def get_cstoken_obj(self, request):
cstoken = request.GET.get('cstoken', None)
if cstoken is not None:
return get_object_or_404(CollectionItemRevision.objects.all(), uid=cstoken)
return get_object_or_404(Stoken.objects.all(), uid=cstoken)
return None
def filter_by_cstoken(self, request, queryset):
cstoken_id_field = self.cstoken_id_field + '__id'
cstoken_rev = self.get_cstoken_rev(request)
cstoken_rev = self.get_cstoken_obj(request)
if cstoken_rev is not None:
filter_by = {cstoken_id_field + '__gt': cstoken_rev.id}
queryset = queryset.filter(**filter_by)
@@ -116,7 +124,7 @@ class BaseViewSet(viewsets.ModelViewSet):
cstoken_id_field = self.cstoken_id_field + '__id'
new_cstoken_id = queryset.aggregate(cstoken_id=Max(cstoken_id_field))['cstoken_id']
new_cstoken = new_cstoken_id and CollectionItemRevision.objects.get(id=new_cstoken_id).uid
new_cstoken = new_cstoken_id and Stoken.objects.get(id=new_cstoken_id).uid
return queryset, new_cstoken
@@ -139,7 +147,7 @@ class CollectionViewSet(BaseViewSet):
queryset = Collection.objects.all()
serializer_class = CollectionSerializer
lookup_field = 'uid'
cstoken_id_field = 'items__revisions'
cstoken_id_field = 'items__revisions__stoken'
def get_queryset(self, queryset=None):
if queryset is None:
@@ -199,7 +207,7 @@ class CollectionItemViewSet(BaseViewSet):
queryset = CollectionItem.objects.all()
serializer_class = CollectionItemSerializer
lookup_field = 'uid'
cstoken_id_field = 'revisions'
cstoken_id_field = 'revisions__stoken'
def get_queryset(self):
collection_uid = self.kwargs['collection_uid']
@@ -290,8 +298,8 @@ class CollectionItemViewSet(BaseViewSet):
queryset, cstoken_rev = self.filter_by_cstoken(request, queryset)
uids, stokens = zip(*[(item['uid'], item.get('stoken')) for item in serializer.validated_data])
rev_ids = CollectionItemRevision.objects.filter(uid__in=stokens, current=True).values_list('id', flat=True)
queryset = queryset.filter(uid__in=uids).exclude(revisions__id__in=rev_ids)
revs = CollectionItemRevision.objects.filter(stoken__uid__in=stokens, current=True)
queryset = queryset.filter(uid__in=uids).exclude(revisions__in=revs)
queryset, new_cstoken = self.get_queryset_cstoken(queryset)
cstoken = cstoken_rev and cstoken_rev.uid