Use dependency injection for getting collection/item queryset.

This commit is contained in:
Tom Hacohen 2020-12-27 22:27:33 +02:00
parent 8160a33384
commit df19887af7
3 changed files with 46 additions and 62 deletions

View File

@ -194,18 +194,19 @@ def collection_list_common(
return MsgpackResponse(content=ret) return MsgpackResponse(content=ret)
def get_collection_queryset(user: User) -> QuerySet: def get_collection_queryset(user: User = Depends(get_authenticated_user)) -> QuerySet:
return default_queryset.filter(members__user=user) return default_queryset.filter(members__user=user)
def get_item_queryset( def get_collection(collection_uid: str, queryset: QuerySet = Depends(get_collection_queryset)) -> models.Collection:
user: User, collection_uid: str, queryset: QuerySet = default_item_queryset return get_object_or_404(queryset, uid=collection_uid)
) -> t.Tuple[models.Collection, QuerySet]:
collection = get_object_or_404(get_collection_queryset(user), uid=collection_uid)
# XXX Potentially add this for performance: .prefetch_related('revisions__chunks')
queryset = queryset.filter(collection__pk=collection.pk, revisions__current=True)
return collection, queryset
def get_item_queryset(collection: models.Collection = Depends(get_collection)) -> QuerySet:
# XXX Potentially add this for performance: .prefetch_related('revisions__chunks')
queryset = default_item_queryset.filter(collection__pk=collection.pk, revisions__current=True)
return queryset
@collection_router.post("/list_multi/") @collection_router.post("/list_multi/")
@ -213,11 +214,10 @@ async def list_multi(
data: ListMulti, data: ListMulti,
stoken: t.Optional[str] = None, stoken: t.Optional[str] = None,
limit: int = 50, limit: int = 50,
queryset: QuerySet = Depends(get_collection_queryset),
user: User = Depends(get_authenticated_user), user: User = Depends(get_authenticated_user),
prefetch: Prefetch = PrefetchQuery, prefetch: Prefetch = PrefetchQuery,
): ):
queryset = get_collection_queryset(user)
# FIXME: Remove the isnull part once we attach collection types to all objects ("collection-type-migration") # FIXME: Remove the isnull part once we attach collection types to all objects ("collection-type-migration")
queryset = queryset.filter( queryset = queryset.filter(
Q(members__collectionType__uid__in=data.collectionTypes) | Q(members__collectionType__isnull=True) Q(members__collectionType__uid__in=data.collectionTypes) | Q(members__collectionType__isnull=True)
@ -228,13 +228,12 @@ async def list_multi(
@collection_router.post("/list/") @collection_router.post("/list/")
async def collection_list( async def collection_list(
req: Request,
stoken: t.Optional[str] = None, stoken: t.Optional[str] = None,
limit: int = 50, limit: int = 50,
prefetch: Prefetch = PrefetchQuery, prefetch: Prefetch = PrefetchQuery,
user: User = Depends(get_authenticated_user), user: User = Depends(get_authenticated_user),
queryset: QuerySet = Depends(get_collection_queryset),
): ):
queryset = get_collection_queryset(user)
return await collection_list_common(queryset, user, stoken, limit, prefetch) return await collection_list_common(queryset, user, stoken, limit, prefetch)
@ -309,9 +308,12 @@ async def create(data: CollectionIn, user: User = Depends(get_authenticated_user
return MsgpackResponse({}, status_code=status.HTTP_201_CREATED) return MsgpackResponse({}, status_code=status.HTTP_201_CREATED)
@collection_router.get("/{uid}/") @collection_router.get("/{collection_uid}/")
def collection_get(uid: str, user: User = Depends(get_authenticated_user), prefetch: Prefetch = PrefetchQuery): def collection_get(
obj = get_collection_queryset(user).get(uid=uid) obj: models.Collection = Depends(get_collection),
user: User = Depends(get_authenticated_user),
prefetch: Prefetch = PrefetchQuery
):
ret = CollectionOut.from_orm_context(obj, Context(user, prefetch)) ret = CollectionOut.from_orm_context(obj, Context(user, prefetch))
return MsgpackResponse(ret) return MsgpackResponse(ret)
@ -358,9 +360,10 @@ def item_create(item_model: CollectionItemIn, collection: models.Collection, val
@collection_router.get("/{collection_uid}/item/{uid}/") @collection_router.get("/{collection_uid}/item/{uid}/")
def item_get( def item_get(
collection_uid: str, uid: str, user: User = Depends(get_authenticated_user), prefetch: Prefetch = PrefetchQuery uid: str,
queryset: QuerySet = Depends(get_item_queryset),
user: User = Depends(get_authenticated_user), prefetch: Prefetch = PrefetchQuery,
): ):
_, queryset = get_item_queryset(user, collection_uid)
obj = queryset.get(uid=uid) obj = queryset.get(uid=uid)
ret = CollectionItemOut.from_orm_context(obj, Context(user, prefetch)) ret = CollectionItemOut.from_orm_context(obj, Context(user, prefetch))
return MsgpackResponse(ret) return MsgpackResponse(ret)
@ -386,14 +389,13 @@ def item_list_common(
@collection_router.get("/{collection_uid}/item/") @collection_router.get("/{collection_uid}/item/")
async def item_list( async def item_list(
collection_uid: str, queryset: QuerySet = Depends(get_item_queryset),
stoken: t.Optional[str] = None, stoken: t.Optional[str] = None,
limit: int = 50, limit: int = 50,
prefetch: Prefetch = PrefetchQuery, prefetch: Prefetch = PrefetchQuery,
withCollection: bool = False, withCollection: bool = False,
user: User = Depends(get_authenticated_user), user: User = Depends(get_authenticated_user),
): ):
_, queryset = await sync_to_async(get_item_queryset)(user, collection_uid)
if not withCollection: if not withCollection:
queryset = queryset.filter(parent__isnull=True) queryset = queryset.filter(parent__isnull=True)
@ -419,14 +421,13 @@ def item_bulk_common(data: ItemBatchIn, user: User, stoken: t.Optional[str], uid
@collection_router.get("/{collection_uid}/item/{uid}/revision/") @collection_router.get("/{collection_uid}/item/{uid}/revision/")
def item_revisions( def item_revisions(
collection_uid: str,
uid: str, uid: str,
limit: int = 50, limit: int = 50,
iterator: t.Optional[str] = None, iterator: t.Optional[str] = None,
prefetch: Prefetch = PrefetchQuery, prefetch: Prefetch = PrefetchQuery,
user: User = Depends(get_authenticated_user), user: User = Depends(get_authenticated_user),
items: QuerySet = Depends(get_item_queryset),
): ):
_, items = get_item_queryset(user, collection_uid)
item = get_object_or_404(items, uid=uid) item = get_object_or_404(items, uid=uid)
queryset = item.revisions.order_by("-id") queryset = item.revisions.order_by("-id")
@ -456,13 +457,12 @@ def item_revisions(
@collection_router.post("/{collection_uid}/item/fetch_updates/") @collection_router.post("/{collection_uid}/item/fetch_updates/")
def fetch_updates( def fetch_updates(
collection_uid: str,
data: t.List[CollectionItemBulkGetIn], data: t.List[CollectionItemBulkGetIn],
stoken: t.Optional[str] = None, stoken: t.Optional[str] = None,
prefetch: Prefetch = PrefetchQuery, prefetch: Prefetch = PrefetchQuery,
user: User = Depends(get_authenticated_user), user: User = Depends(get_authenticated_user),
queryset: QuerySet = Depends(get_item_queryset),
): ):
_, queryset = get_item_queryset(user, collection_uid)
# FIXME: make configurable? # FIXME: make configurable?
item_limit = 200 item_limit = 200

View File

@ -73,12 +73,12 @@ class InvitationListResponse(BaseModel):
done: bool done: bool
def get_incoming_queryset(user: User, queryset=default_queryset): def get_incoming_queryset(user: User = Depends(get_authenticated_user)):
return queryset.filter(user=user) return default_queryset.filter(user=user)
def get_outgoing_queryset(user: User, queryset=default_queryset): def get_outgoing_queryset(user: User = Depends(get_authenticated_user)):
return queryset.filter(fromMember__user=user) return default_queryset.filter(fromMember__user=user)
def list_common( def list_common(
@ -114,17 +114,16 @@ def list_common(
def incoming_list( def incoming_list(
iterator: t.Optional[str] = None, iterator: t.Optional[str] = None,
limit: int = 50, limit: int = 50,
user: User = Depends(get_authenticated_user), queryset: QuerySet = Depends(get_incoming_queryset),
): ):
return list_common(get_incoming_queryset(user), iterator, limit) return list_common(queryset, iterator, limit)
@invitation_incoming_router.get("/{invitation_uid}/", response_model=CollectionInvitationOut) @invitation_incoming_router.get("/{invitation_uid}/", response_model=CollectionInvitationOut)
def incoming_get( def incoming_get(
invitation_uid: str, invitation_uid: str,
user: User = Depends(get_authenticated_user), queryset: QuerySet = Depends(get_incoming_queryset),
): ):
queryset = get_incoming_queryset(user)
obj = get_object_or_404(queryset, uid=invitation_uid) obj = get_object_or_404(queryset, uid=invitation_uid)
ret = CollectionInvitationOut.from_orm(obj) ret = CollectionInvitationOut.from_orm(obj)
return MsgpackResponse(ret) return MsgpackResponse(ret)
@ -133,9 +132,8 @@ def incoming_get(
@invitation_incoming_router.delete("/{invitation_uid}/", status_code=status.HTTP_204_NO_CONTENT) @invitation_incoming_router.delete("/{invitation_uid}/", status_code=status.HTTP_204_NO_CONTENT)
def incoming_delete( def incoming_delete(
invitation_uid: str, invitation_uid: str,
user: User = Depends(get_authenticated_user), queryset: QuerySet = Depends(get_incoming_queryset),
): ):
queryset = get_incoming_queryset(user)
obj = get_object_or_404(queryset, uid=invitation_uid) obj = get_object_or_404(queryset, uid=invitation_uid)
obj.delete() obj.delete()
@ -144,9 +142,8 @@ def incoming_delete(
def incoming_accept( def incoming_accept(
invitation_uid: str, invitation_uid: str,
data: CollectionInvitationAcceptIn, data: CollectionInvitationAcceptIn,
user: User = Depends(get_authenticated_user), queryset: QuerySet = Depends(get_incoming_queryset),
): ):
queryset = get_incoming_queryset(user)
invitation = get_object_or_404(queryset, uid=invitation_uid) invitation = get_object_or_404(queryset, uid=invitation_uid)
with transaction.atomic(): with transaction.atomic():
@ -201,17 +198,16 @@ def outgoing_create(
def outgoing_list( def outgoing_list(
iterator: t.Optional[str] = None, iterator: t.Optional[str] = None,
limit: int = 50, limit: int = 50,
user: User = Depends(get_authenticated_user), queryset: QuerySet = Depends(get_outgoing_queryset),
): ):
return list_common(get_outgoing_queryset(user), iterator, limit) return list_common(queryset, iterator, limit)
@invitation_outgoing_router.delete("/{invitation_uid}/", status_code=status.HTTP_204_NO_CONTENT) @invitation_outgoing_router.delete("/{invitation_uid}/", status_code=status.HTTP_204_NO_CONTENT)
def outgoing_delete( def outgoing_delete(
invitation_uid: str, invitation_uid: str,
user: User = Depends(get_authenticated_user), queryset: QuerySet = Depends(get_outgoing_queryset),
): ):
queryset = get_outgoing_queryset(user)
obj = get_object_or_404(queryset, uid=invitation_uid) obj = get_object_or_404(queryset, uid=invitation_uid)
obj.delete() obj.delete()

View File

@ -12,15 +12,18 @@ from .msgpack import MsgpackResponse
from .utils import get_object_or_404 from .utils import get_object_or_404
from .stoken_handler import filter_by_stoken_and_limit from .stoken_handler import filter_by_stoken_and_limit
from .collection import collection_router, get_collection_queryset from .collection import collection_router, get_collection
User = get_user_model() User = get_user_model()
default_queryset: QuerySet = models.CollectionMember.objects.all() default_queryset: QuerySet = models.CollectionMember.objects.all()
def get_queryset(user: User, collection_uid: str, queryset=default_queryset) -> t.Tuple[models.Collection, QuerySet]: def get_queryset(collection: models.Collection = Depends(get_collection)) -> QuerySet:
collection = get_object_or_404(get_collection_queryset(user), uid=collection_uid) return default_queryset.filter(collection=collection)
return collection, queryset.filter(collection=collection)
def get_member(username: str, queryset: QuerySet = Depends(get_queryset)) -> QuerySet:
return get_object_or_404(queryset, user__username__iexact=username)
class CollectionMemberModifyAccessLevelIn(BaseModel): class CollectionMemberModifyAccessLevelIn(BaseModel):
@ -47,12 +50,10 @@ class MemberListResponse(BaseModel):
@collection_router.get("/{collection_uid}/member/", response_model=MemberListResponse) @collection_router.get("/{collection_uid}/member/", response_model=MemberListResponse)
def member_list( def member_list(
collection_uid: str,
iterator: t.Optional[str] = None, iterator: t.Optional[str] = None,
limit: int = 50, limit: int = 50,
user: User = Depends(get_authenticated_user), queryset: QuerySet = Depends(get_queryset),
): ):
_, queryset = get_queryset(user, collection_uid)
queryset = queryset.order_by("id") queryset = queryset.order_by("id")
result, new_stoken_obj, done = filter_by_stoken_and_limit( result, new_stoken_obj, done = filter_by_stoken_and_limit(
iterator, limit, queryset, models.CollectionMember.stoken_annotation iterator, limit, queryset, models.CollectionMember.stoken_annotation
@ -69,25 +70,16 @@ def member_list(
@collection_router.delete("/{collection_uid}/member/{username}/", status_code=status.HTTP_204_NO_CONTENT) @collection_router.delete("/{collection_uid}/member/{username}/", status_code=status.HTTP_204_NO_CONTENT)
def member_delete( def member_delete(
collection_uid: str, obj: models.CollectionMember = Depends(get_member),
username: str,
user: User = Depends(get_authenticated_user),
): ):
_, queryset = get_queryset(user, collection_uid)
obj = get_object_or_404(queryset, user__username__iexact=username)
obj.revoke() obj.revoke()
@collection_router.patch("/{collection_uid}/member/{username}/", status_code=status.HTTP_204_NO_CONTENT) @collection_router.patch("/{collection_uid}/member/{username}/", status_code=status.HTTP_204_NO_CONTENT)
def member_patch( def member_patch(
collection_uid: str,
username: str,
data: CollectionMemberModifyAccessLevelIn, data: CollectionMemberModifyAccessLevelIn,
user: User = Depends(get_authenticated_user), instance: models.CollectionMember = Depends(get_member),
): ):
_, queryset = get_queryset(user, collection_uid)
instance = get_object_or_404(queryset, user__username__iexact=username)
with transaction.atomic(): with transaction.atomic():
# We only allow updating accessLevel # We only allow updating accessLevel
if instance.accessLevel != data.accessLevel: if instance.accessLevel != data.accessLevel:
@ -97,10 +89,6 @@ def member_patch(
@collection_router.post("/{collection_uid}/member/leave/", status_code=status.HTTP_204_NO_CONTENT) @collection_router.post("/{collection_uid}/member/leave/", status_code=status.HTTP_204_NO_CONTENT)
def member_leave( def member_leave(user: User = Depends(get_authenticated_user), collection: models.Collection = Depends(get_collection)):
collection_uid: str,
user: User = Depends(get_authenticated_user),
):
collection, _ = get_queryset(user, collection_uid)
obj = get_object_or_404(collection.members, user=user) obj = get_object_or_404(collection.members, user=user)
obj.revoke() obj.revoke()