2020-12-23 21:29:08 +00:00
|
|
|
import typing as t
|
2020-12-30 15:09:16 +00:00
|
|
|
|
2020-12-23 21:29:08 +00:00
|
|
|
from fastapi.routing import APIRoute, get_request_handler
|
2020-12-25 09:10:43 +00:00
|
|
|
from pydantic import BaseModel
|
2020-12-23 21:29:08 +00:00
|
|
|
from starlette.requests import Request
|
|
|
|
from starlette.responses import Response
|
|
|
|
|
2020-12-29 13:44:52 +00:00
|
|
|
from .utils import msgpack_encode, msgpack_decode
|
2020-12-30 15:09:16 +00:00
|
|
|
from .db_hack import django_db_cleanup_decorator
|
2020-12-29 13:44:52 +00:00
|
|
|
|
2020-12-23 21:29:08 +00:00
|
|
|
|
|
|
|
class MsgpackRequest(Request):
    """Request whose ``json()`` decodes the body with ``msgpack_decode``.

    ``media_type`` is the content type this class is registered under
    (see ``MsgpackRoute.REQUESTS_CLASSES``).
    """

    media_type = "application/msgpack"

    async def json(self) -> t.Any:
        """Return the request body decoded by ``msgpack_decode``.

        The decoded value is stored on ``self._json`` the first time this is
        called, so the body is decoded at most once per request object.

        Returns:
            Whatever ``msgpack_decode`` produces for the raw body bytes (the
            decoded payload, not ``bytes``).
        """
        # NOTE(review): the method is presumably named ``json`` so that callers
        # which invoke ``request.json()`` transparently get the msgpack
        # payload -- confirm against the FastAPI version in use.
        if not hasattr(self, "_json"):
            body = await super().body()
            self._json = msgpack_decode(body)
        return self._json
|
|
|
|
|
|
|
|
|
|
|
|
class MsgpackResponse(Response):
    """Response that serialises its content with ``msgpack_encode``."""

    media_type = "application/msgpack"

    def render(self, content: t.Optional[t.Any]) -> bytes:
        """Serialise ``content`` into the response body.

        Args:
            content: The value to send; ``None`` yields an empty body and a
                pydantic ``BaseModel`` is converted with ``.dict()`` first.

        Returns:
            ``b""`` for ``None``, otherwise the result of ``msgpack_encode``.
        """
        if content is None:
            return b""

        # Pydantic models are reduced to plain dicts before encoding.
        payload = content.dict() if isinstance(content, BaseModel) else content
        return msgpack_encode(payload)
|
2020-12-23 21:29:08 +00:00
|
|
|
|
|
|
|
|
|
|
|
class MsgpackRoute(APIRoute):
    """API route that picks request/response classes from the HTTP headers.

    The ``Content-Type`` header selects the request class from
    ``REQUESTS_CLASSES``; the ``Accept`` header selects the response class
    from ``ROUTES_HANDLERS_CLASSES``. When nothing registered matches, the
    incoming request object and ``self.response_class`` are used unchanged.
    """

    # keep track of content-type -> request classes
    REQUESTS_CLASSES = {MsgpackRequest.media_type: MsgpackRequest}

    # keep track of content-type -> response classes
    ROUTES_HANDLERS_CLASSES = {MsgpackResponse.media_type: MsgpackResponse}

    def __init__(self, path: str, endpoint: t.Callable[..., t.Any], *args, **kwargs):
        """Wrap ``endpoint`` with ``django_db_cleanup_decorator`` and build the route."""
        endpoint = django_db_cleanup_decorator(endpoint)
        super().__init__(path, endpoint, *args, **kwargs)

    @staticmethod
    def _parse_media_types(header_value: t.Optional[str]) -> t.List[str]:
        """Return the bare, lower-cased media types listed in a header value.

        Parameters such as ``; charset=utf-8`` or ``; q=0.8`` are dropped and
        the order in which the types appear in the header is kept.
        """
        if not header_value:
            return []
        media_types = []
        for item in header_value.split(","):
            media_type = item.split(";", 1)[0].strip().lower()
            if media_type:
                media_types.append(media_type)
        return media_types

    @classmethod
    def _match_media_type(
        cls,
        registry: t.Mapping[t.Any, t.Any],
        header_value: t.Optional[str],
        default: t.Any = None,
    ) -> t.Any:
        """Look up the class registered for a ``Content-Type``/``Accept`` value.

        The raw header value is tried first so exact keys keep working; then
        each media type found in the header is tried in order. ``default`` is
        returned when nothing matches. Quality values (``q=``) are ignored.
        """
        try:
            return registry[header_value]
        except KeyError:
            pass
        for media_type in cls._parse_media_types(header_value):
            if media_type in registry:
                return registry[media_type]
        return default

    def _get_media_type_route_handler(self, media_type):
        """Build a request handler whose response class fits ``media_type``.

        Args:
            media_type: A media type or a raw ``Accept`` header value (may be
                ``None``).

        Returns:
            The callable produced by ``get_request_handler`` for this route.
        """
        # use custom response class or fallback on default self.response_class
        response_class = self._match_media_type(
            self.ROUTES_HANDLERS_CLASSES, media_type, self.response_class
        )
        return get_request_handler(
            dependant=self.dependant,
            body_field=self.body_field,
            status_code=self.status_code,
            response_class=response_class,
            response_field=self.secure_cloned_response_field,
            response_model_include=self.response_model_include,
            response_model_exclude=self.response_model_exclude,
            response_model_by_alias=self.response_model_by_alias,
            response_model_exclude_unset=self.response_model_exclude_unset,
            response_model_exclude_defaults=self.response_model_exclude_defaults,
            response_model_exclude_none=self.response_model_exclude_none,
            dependency_overrides_provider=self.dependency_overrides_provider,
        )

    def get_route_handler(self) -> t.Callable:
        """Return the handler that dispatches on ``Content-Type`` and ``Accept``."""

        async def custom_route_handler(request: Request) -> Response:
            request_cls = self._match_media_type(
                self.REQUESTS_CLASSES, request.headers.get("Content-Type")
            )
            # nothing registered to handle content_type, process given requests as-is
            if request_cls is not None:
                # Built outside any try/except so construction errors surface.
                request = request_cls(request.scope, request.receive)

            # NOTE: a handler is built for every request, as before.
            accept = request.headers.get("Accept")
            route_handler = self._get_media_type_route_handler(accept)
            return await route_handler(request)

        return custom_route_handler
|