Files
xmpp-api/src/xmpp_api/db/__init__.py

148 lines
4.1 KiB
Python

import threading
from enum import auto, Enum
from typing import Annotated, Generator, Sequence, cast

from fastapi import Depends, Request, HTTPException
from sqlmodel import SQLModel, Session, create_engine, select
from sqlalchemy import Engine

from xmpp_api.config.config import ConfigDep
from xmpp_api.db.user import User
from xmpp_api.db.bot import Bot, AllowedJid
# Engines own a connection pool and are meant to be long-lived, so one engine is
# kept per distinct database configuration and shared across requests.
_engine_cache: dict[tuple[str, str], Engine] = {}
# Sync FastAPI dependencies run in a threadpool; the lock prevents two requests
# from racing to build (and then orphaning) duplicate engines.
_engine_cache_lock = threading.Lock()


def get_engine(config: ConfigDep) -> Engine:
    """
    Return the SQLAlchemy engine for the configured database.

    FastAPI resolves this dependency on every request, so the engine is cached
    per (uri, connect_args) rather than rebuilt each time; rebuilding it would
    allocate a fresh connection pool for every request.

    Args:
        config (ConfigDep): The application configuration.

    Returns:
        Engine: A shared engine for ``config.database.uri``.
    """
    uri = config.database.uri
    connect_args = config.database.connect_args
    # connect_args is not guaranteed to be hashable, so key on its repr.
    key = (str(uri), repr(connect_args))
    with _engine_cache_lock:
        engine = _engine_cache.get(key)
        if engine is None:
            engine = create_engine(uri, connect_args=connect_args)
            _engine_cache[key] = engine
    return engine


EngineDep = Annotated[Engine, Depends(get_engine)]
def get_session(engine: EngineDep) -> Generator[Session, None, None]:
    """
    Yield a database session bound to ``engine`` for the duration of a request.

    The session is opened in a ``with`` block, so it is closed once FastAPI
    resumes the generator after the request has been handled.

    The three-argument ``Generator[Session, None, None]`` form is used because
    the single-argument ``Generator[Session]`` only subscripts on Python 3.13+.
    """
    with Session(engine) as session:
        yield session


SessionDep = Annotated[Session, Depends(get_session)]
class TokenType(Enum):
    """
    Where a client places its authentication token in the request.
    """

    #: The raw token is sent in the ``X-Token`` request header.
    X_TOKEN = auto()
    #: The token is sent in the ``Authorization`` header using the "Bearer" scheme.
    BEARER = auto()
def get_by_token(
cls: type[SQLModel],
request: Request,
session: SessionDep,
token_type: TokenType = TokenType.X_TOKEN,
) -> SQLModel:
token: str | None = None
match token_type:
case TokenType.X_TOKEN:
token = request.headers.get("X-Token")
case TokenType.BEARER:
token = request.headers.get("Authorization")
if token is not None:
parts = token.split(" ")
if len(parts) == 2 and parts[0] == "Bearer":
token = parts[1]
if token is None:
raise HTTPException(
detail="No token provided",
status_code=400,
)
obj = session.exec(select(cls).where(cls.token == token)).first() # type: ignore
if obj is None:
raise HTTPException(
detail="Unauthorized",
status_code=403,
)
return obj
def get_user(request: Request, session: SessionDep) -> User:
    """
    FastAPI dependency resolving the ``User`` identified by the ``X-Token`` header.

    Raises:
        HTTPException: 400 when no token is sent, 403 when it matches no user.
    """
    found = get_by_token(User, request, session, token_type=TokenType.X_TOKEN)
    return cast(User, found)


UserDep = Annotated[User, Depends(get_user)]
def get_bot(request: Request, session: SessionDep) -> Bot:
    """
    FastAPI dependency resolving the ``Bot`` identified by the ``X-Token`` header.

    Raises:
        HTTPException: 400 when no token is sent, 403 when it matches no bot.
    """
    found = get_by_token(Bot, request, session, token_type=TokenType.X_TOKEN)
    return cast(Bot, found)


BotDep = Annotated[Bot, Depends(get_bot)]
def get_authorization_bot(request: Request, session: SessionDep) -> Bot:
    """
    FastAPI dependency resolving the ``Bot`` from an ``Authorization: Bearer`` header.

    Raises:
        HTTPException: 400 when no token is sent, 403 when it matches no bot.
    """
    found = get_by_token(Bot, request, session, token_type=TokenType.BEARER)
    return cast(Bot, found)


AuthorizationBotDep = Annotated[Bot, Depends(get_authorization_bot)]
def get_bot_by_id(bot_id: str, user_id: str, session: SessionDep) -> Bot | None:
    """
    Fetch a bot by ID, restricted to bots owned by the given user.

    Args:
        bot_id (str): The ID of the bot.
        user_id (str): The ID of the authenticated user; a bot is only returned
            if its ``owner_id`` equals this value.
        session (SessionDep): The database session.

    Returns:
        Bot | None: The bot, or None if no bot has this ID or it belongs to a
        different user.
    """
    return session.exec(
        select(Bot).where(Bot.id == bot_id, Bot.owner_id == user_id)
    ).first()
def get_jids_by_bot_id(bot_id: str, session: SessionDep) -> Sequence[AllowedJid]:
    """
    List every AllowedJid row that belongs to ``bot_id``.

    Args:
        bot_id (str): ID of the bot whose allowed JIDs are requested.
        session (SessionDep): Database session used to run the query.

    Returns:
        Sequence[AllowedJid]: All matching rows (empty if the bot has none).
    """
    statement = select(AllowedJid).where(AllowedJid.bot_id == bot_id)
    results = session.exec(statement)
    return results.all()
def get_jid_by_jid_token(jid_token: str, session: SessionDep) -> AllowedJid | None:
    """
    Find the AllowedJid row whose ``token`` equals ``jid_token``.

    Args:
        jid_token (str): Token identifying the AllowedJid.
        session (SessionDep): Database session used to run the query.

    Returns:
        AllowedJid | None: The first matching row, or None if there is none.
    """
    statement = select(AllowedJid).where(AllowedJid.token == jid_token)
    results = session.exec(statement)
    return results.first()