148 lines
4.1 KiB
Python
148 lines
4.1 KiB
Python
from enum import auto, Enum
|
|
from typing import Annotated, Generator, Sequence, cast
|
|
|
|
from fastapi import Depends, Request, HTTPException
|
|
from sqlmodel import SQLModel, Session, create_engine, select
|
|
from sqlalchemy import Engine
|
|
|
|
from xmpp_api.config.config import ConfigDep
|
|
from xmpp_api.db.user import User
|
|
from xmpp_api.db.bot import Bot, AllowedJid
|
|
|
|
|
|
# Engines keyed by (database URI, repr of connect_args). An Engine owns a
# connection pool and is meant to be created once per database and reused,
# not rebuilt for every request that depends on it.
_ENGINES: dict[tuple[object, str], Engine] = {}


def get_engine(config: ConfigDep) -> Engine:
    """
    Return the shared SQLAlchemy engine for the configured database.

    An engine is created the first time a given ``config.database.uri`` /
    ``config.database.connect_args`` combination is seen and reused after
    that. Requests therefore share one connection pool instead of each
    building its own.

    Args:
        config (ConfigDep): Application configuration; ``database.uri`` and
            ``database.connect_args`` are passed to ``create_engine``.

    Returns:
        Engine: The cached engine for this database configuration.
    """
    uri = config.database.uri
    connect_args = config.database.connect_args
    # The URI is used as-is (not str()'d) so URIs differing only in a masked
    # password still get separate engines; connect_args may be an unhashable
    # dict, so its repr stands in for it.
    key = (uri, repr(connect_args))
    engine = _ENGINES.get(key)
    if engine is None:
        # create_engine() does not connect yet, so if two threads race here
        # the losing engine is simply discarded; setdefault keeps exactly one.
        engine = _ENGINES.setdefault(
            key, create_engine(uri, connect_args=connect_args)
        )
    return engine


EngineDep = Annotated[Engine, Depends(get_engine)]
|
|
|
|
|
|
# All three Generator type arguments are spelled out: the one-argument form
# Generator[Session] is only accepted by typing on Python 3.13+ and raises
# TypeError at import time on older interpreters.
def get_session(engine: EngineDep) -> Generator[Session, None, None]:
    """
    Yield a database session bound to ``engine``.

    The session is opened in a ``with`` block. It is closed when the
    generator finishes, which for a FastAPI ``yield`` dependency happens
    after the dependent handler has run.

    Args:
        engine (EngineDep): Engine the session is bound to.

    Yields:
        Session: An open SQLModel session.
    """
    with Session(engine) as session:
        yield session


SessionDep = Annotated[Session, Depends(get_session)]
|
|
|
|
|
|
class TokenType(Enum):
    """
    The supported ways a client can present its authentication token.

    Members:
        X_TOKEN: The raw token is sent in the ``X-Token`` request header.
        BEARER: The token is sent in the ``Authorization`` header using the
            ``Bearer`` type, i.e. ``Authorization: Bearer <token>``.
    """

    X_TOKEN = auto()
    BEARER = auto()
|
|
|
|
|
|
def get_by_token(
|
|
cls: type[SQLModel],
|
|
request: Request,
|
|
session: SessionDep,
|
|
token_type: TokenType = TokenType.X_TOKEN,
|
|
) -> SQLModel:
|
|
token: str | None = None
|
|
match token_type:
|
|
case TokenType.X_TOKEN:
|
|
token = request.headers.get("X-Token")
|
|
case TokenType.BEARER:
|
|
token = request.headers.get("Authorization")
|
|
if token is not None:
|
|
parts = token.split(" ")
|
|
if len(parts) == 2 and parts[0] == "Bearer":
|
|
token = parts[1]
|
|
|
|
if token is None:
|
|
raise HTTPException(
|
|
detail="No token provided",
|
|
status_code=400,
|
|
)
|
|
|
|
obj = session.exec(select(cls).where(cls.token == token)).first() # type: ignore
|
|
if obj is None:
|
|
raise HTTPException(
|
|
detail="Unauthorized",
|
|
status_code=403,
|
|
)
|
|
return obj
|
|
|
|
|
|
def get_user(request: Request, session: SessionDep) -> User:
    """
    FastAPI dependency resolving the ``User`` whose token is in ``X-Token``.

    Raises:
        HTTPException: Propagated from ``get_by_token`` when the token is
            missing or does not match any user.
    """
    user = get_by_token(User, request, session, token_type=TokenType.X_TOKEN)
    return cast(User, user)


UserDep = Annotated[User, Depends(get_user)]
|
|
|
|
|
|
def get_bot(request: Request, session: SessionDep) -> Bot:
    """
    FastAPI dependency resolving the ``Bot`` whose token is in ``X-Token``.

    Raises:
        HTTPException: Propagated from ``get_by_token`` when the token is
            missing or does not match any bot.
    """
    bot = get_by_token(Bot, request, session, token_type=TokenType.X_TOKEN)
    return cast(Bot, bot)


BotDep = Annotated[Bot, Depends(get_bot)]
|
|
|
|
|
|
def get_authorization_bot(request: Request, session: SessionDep) -> Bot:
    """
    FastAPI dependency resolving the ``Bot`` identified by a Bearer token.

    The token is read from the ``Authorization`` header
    (``TokenType.BEARER``) rather than from ``X-Token``.

    Raises:
        HTTPException: Propagated from ``get_by_token`` when the token is
            missing or does not match any bot.
    """
    bot = get_by_token(Bot, request, session, token_type=TokenType.BEARER)
    return cast(Bot, bot)


AuthorizationBotDep = Annotated[Bot, Depends(get_authorization_bot)]
|
|
|
|
|
|
def get_bot_by_id(bot_id: str, user_id: str, session: SessionDep) -> Bot | None:
    """
    Fetch a bot by ID, restricted to bots owned by the given user.

    Args:
        bot_id (str): The ID of the bot.
        user_id (str): The ID of the authenticated user; only a bot whose
            ``owner_id`` equals this value is returned.
        session (SessionDep): The database session.

    Returns:
        Bot | None: The matching bot, or None if no bot with this ID is
        owned by ``user_id``.
    """
    return session.exec(
        select(Bot).where(Bot.id == bot_id, Bot.owner_id == user_id)
    ).first()
|
|
|
|
|
|
def get_jids_by_bot_id(bot_id: str, session: SessionDep) -> Sequence[AllowedJid]:
    """
    Load every AllowedJid row whose ``bot_id`` matches the given bot.

    Args:
        bot_id (str): ID of the bot whose allowed JIDs should be listed.
        session (SessionDep): Open database session used for the query.

    Returns:
        Sequence[AllowedJid]: All matching rows (possibly empty).
    """
    statement = select(AllowedJid).where(AllowedJid.bot_id == bot_id)
    result = session.exec(statement)
    return result.all()
|
|
|
|
|
|
def get_jid_by_jid_token(jid_token: str, session: SessionDep) -> AllowedJid | None:
    """
    Look up the AllowedJid row whose ``token`` equals ``jid_token``.

    Args:
        jid_token (str): Token identifying the AllowedJid.
        session (SessionDep): Open database session used for the query.

    Returns:
        AllowedJid | None: The first matching row, or None when no row has
        this token.
    """
    statement = select(AllowedJid).where(AllowedJid.token == jid_token)
    result = session.exec(statement)
    return result.first()
|