etesync-server/etebase_fastapi/msgpack.py
Tom Hacohen 5b8f667e55 Cleanup django db connections before every request and every dependency.
This is instead of the commit we reverted in the previous commit.
The problem is that django keeps the connection per thread and it relies
on django itself to clean them up before/after connections.
We can't do this, because django is unaware of fastapi, so we have to
manage this ourselves.

The easiest way is to call it at the beginning of every route and every dep.
We need to do it for each because unfortunately fastapi may send them to
different worker threads.
2020-12-30 17:16:58 +02:00

90 lines
3.3 KiB
Python

import typing as t
from fastapi import params
from fastapi.routing import APIRoute, get_request_handler
from pydantic import BaseModel
from starlette.requests import Request
from starlette.responses import Response
from .utils import msgpack_encode, msgpack_decode
from .db_hack import django_db_cleanup_decorator
class MsgpackRequest(Request):
    """Request whose ``json()`` yields the msgpack-decoded request body."""

    media_type = "application/msgpack"

    async def json(self) -> bytes:
        """Decode the request body with ``msgpack_decode`` and memoize it.

        The decoded value is stored on ``self._json`` so that the body is
        only read and decoded once per request object.
        """
        # NOTE(review): annotated as ``bytes``, but the value returned is
        # whatever ``msgpack_decode`` produces -- confirm the annotation.
        if hasattr(self, "_json"):
            return self._json
        raw_body = await super().body()
        self._json = msgpack_decode(raw_body)
        return self._json
class MsgpackResponse(Response):
    """Response whose body is produced by ``msgpack_encode``."""

    media_type = "application/msgpack"

    def render(self, content: t.Optional[t.Any]) -> bytes:
        """Serialize *content* into the response body.

        ``None`` renders as an empty body. A pydantic ``BaseModel`` is first
        converted with ``.dict()``; the resulting value (or *content* itself
        for any other type) is passed to ``msgpack_encode``.
        """
        if content is None:
            return b""
        payload = content.dict() if isinstance(content, BaseModel) else content
        return msgpack_encode(payload)
class MsgpackRoute(APIRoute):
    """APIRoute that picks request/response classes from the request headers.

    The ``Content-Type`` header selects the request class used to read the
    body and the ``Accept`` header selects the response class used to render
    the result. Anything that is not registered in the class-level tables
    falls back to the plain request object and ``self.response_class``.

    The endpoint and every route-level dependency are additionally wrapped
    with ``django_db_cleanup_decorator``.
    """

    # keep track of content-type -> request classes
    REQUESTS_CLASSES = {MsgpackRequest.media_type: MsgpackRequest}
    # keep track of content-type -> response classes
    ROUTES_HANDLERS_CLASSES = {MsgpackResponse.media_type: MsgpackResponse}

    def __init__(
        self,
        path: str,
        endpoint: t.Callable[..., t.Any],
        *args,
        dependencies: t.Optional[t.Sequence[params.Depends]] = None,
        **kwargs
    ):
        """Create the route with db-cleanup wrappers around all callables.

        Args:
            path: URL path of the route, forwarded to ``APIRoute``.
            endpoint: The route function; it is wrapped with
                ``django_db_cleanup_decorator`` before being registered.
            dependencies: Optional route-level dependencies. Each one is
                re-created as a ``params.Depends`` around the wrapped
                dependency callable, keeping its ``use_cache`` setting.
            *args, **kwargs: Forwarded unchanged to ``APIRoute.__init__``.
        """
        if dependencies is not None:
            # Each dependency is wrapped separately from the endpoint.
            dependencies = [
                params.Depends(django_db_cleanup_decorator(dep.dependency), use_cache=dep.use_cache)
                for dep in dependencies
            ]
        endpoint = django_db_cleanup_decorator(endpoint)
        super().__init__(path, endpoint, *args, dependencies=dependencies, **kwargs)

    def _get_media_type_route_handler(self, media_type):
        """Build a request handler that renders with the class for *media_type*.

        Args:
            media_type: Value looked up in ``ROUTES_HANDLERS_CLASSES``; may be
                ``None`` or unregistered, in which case ``self.response_class``
                is used.

        Returns:
            The handler produced by ``fastapi.routing.get_request_handler``
            for this route's settings.
        """
        return get_request_handler(
            dependant=self.dependant,
            body_field=self.body_field,
            status_code=self.status_code,
            # use custom response class or fallback on default self.response_class
            response_class=self.ROUTES_HANDLERS_CLASSES.get(media_type, self.response_class),
            response_field=self.secure_cloned_response_field,
            response_model_include=self.response_model_include,
            response_model_exclude=self.response_model_exclude,
            response_model_by_alias=self.response_model_by_alias,
            response_model_exclude_unset=self.response_model_exclude_unset,
            response_model_exclude_defaults=self.response_model_exclude_defaults,
            response_model_exclude_none=self.response_model_exclude_none,
            dependency_overrides_provider=self.dependency_overrides_provider,
        )

    def get_route_handler(self) -> t.Callable:
        """Return the ASGI-level handler that performs the header dispatch."""

        async def custom_route_handler(request: Request) -> Response:
            content_type = request.headers.get("Content-Type")
            # Exact match first, so existing registrations keep working as-is.
            request_cls = self.REQUESTS_CLASSES.get(content_type)
            if request_cls is None and content_type is not None:
                # Media types are case-insensitive and may carry parameters
                # (e.g. "application/msgpack; charset=utf-8"); retry with the
                # bare, lower-cased type so such requests are not misrouted.
                bare_type = content_type.split(";", 1)[0].strip().lower()
                request_cls = self.REQUESTS_CLASSES.get(bare_type)
            if request_cls is not None:
                # Built outside any try/except so that errors raised while
                # constructing the request are not silently swallowed.
                request = request_cls(request.scope, request.receive)
            # else: nothing registered to handle content_type, process the
            # given request as-is

            # The Accept header is matched verbatim; unknown values fall back
            # to the default response class.
            accept = request.headers.get("Accept")
            route_handler = self._get_media_type_route_handler(accept)
            return await route_handler(request)

        return custom_route_handler