from enum import auto, Enum from typing import Annotated, Generator, Sequence, cast from fastapi import Depends, Request, HTTPException from sqlmodel import SQLModel, Session, create_engine, select from sqlalchemy import Engine from xmpp_api.config.config import ConfigDep from xmpp_api.db.user import User from xmpp_api.db.bot import Bot, AllowedJid def get_engine(config: ConfigDep) -> Engine: return create_engine( config.database.uri, connect_args=config.database.connect_args, ) EngineDep = Annotated[Engine, Depends(get_engine)] def get_session(engine: EngineDep) -> Generator[Session]: with Session(engine) as session: yield session SessionDep = Annotated[Session, Depends(get_session)] class TokenType(Enum): """ Types of authentication tokens. """ # Token is put into the X-Token header X_TOKEN = auto() # Token is put into the "Authorization" header with type "Bearer". BEARER = auto() def get_by_token( cls: type[SQLModel], request: Request, session: SessionDep, token_type: TokenType = TokenType.X_TOKEN, ) -> SQLModel: token: str | None = None match token_type: case TokenType.X_TOKEN: token = request.headers.get("X-Token") case TokenType.BEARER: token = request.headers.get("Authorization") if token is not None: parts = token.split(" ") if len(parts) == 2 and parts[0] == "Bearer": token = parts[1] if token is None: raise HTTPException( detail="No token provided", status_code=400, ) obj = session.exec(select(cls).where(cls.token == token)).first() # type: ignore if obj is None: raise HTTPException( detail="Unauthorized", status_code=403, ) return obj def get_user(request: Request, session: SessionDep) -> User: return cast( User, get_by_token(User, request, session, token_type=TokenType.X_TOKEN) ) UserDep = Annotated[User, Depends(get_user)] def get_bot(request: Request, session: SessionDep) -> Bot: return cast(Bot, get_by_token(Bot, request, session, token_type=TokenType.X_TOKEN)) BotDep = Annotated[Bot, Depends(get_bot)] def get_authorization_bot(request: Request, session: SessionDep) -> Bot: return cast(Bot, get_by_token(Bot, request, session, token_type=TokenType.BEARER)) AuthorizationBotDep = Annotated[Bot, Depends(get_authorization_bot)] def get_bot_by_id(bot_id: str, user_id: str, session: SessionDep) -> Bot | None: """ Fetches the specified bot from the database. Args :bot_id The ID of the bot. :user_id The ID of the authenticated user. :session_dep The database session Returns Bot | None: The bot object, if found, or None. """ return session.exec( select(Bot).where(Bot.id == bot_id, Bot.owner_id == user_id) ).first() def get_jids_by_bot_id(bot_id: str, session: SessionDep) -> Sequence[AllowedJid]: """ Retrieve all AllowedJid objects associated with a given bot_id from the database. Args: bot_id (str): The ID of the bot for which to retrieve AllowedJids. session (SessionDep): A FastAPI dependency that provides access to the database session. Returns: Sequence[AllowedJid]: A sequence of AllowedJid objects associated with the given bot_id. """ return session.exec( select(AllowedJid).where(AllowedJid.bot_id == bot_id), ).all() def get_jid_by_jid_token(jid_token: str, session: SessionDep) -> AllowedJid | None: """ Retrieve an AllowedJid object from the database based on the provided jid_token. Args: jid_token (str): The token associated with the AllowedJid to retrieve. session (SessionDep): A FastAPI dependency that provides access to the database session. Returns: AllowedJid | None: The AllowedJid object associated with the given jid_token, or None if no match is found. """ return session.exec( select(AllowedJid).where(AllowedJid.token == jid_token), ).first()