Initial commit

This commit is contained in:
2025-04-20 23:22:04 +02:00
commit 054f182215
17 changed files with 1312 additions and 0 deletions

121
src/xmpp_api/db/__init__.py Normal file
View File

@@ -0,0 +1,121 @@
from typing import Annotated, Generator, Sequence, cast
from fastapi import Depends, Request, HTTPException
from sqlmodel import SQLModel, Session, create_engine, select
from sqlalchemy import Engine
from xmpp_api.config.config import ConfigDep
from xmpp_api.db.user import User
from xmpp_api.db.bot import Bot, AllowedJid
def get_engine(config: ConfigDep) -> Engine:
return create_engine(
config.database,
connect_args={
"check_same_thread": False,
},
)
EngineDep = Annotated[Engine, Depends(get_engine)]
def get_session(engine: EngineDep) -> Generator[Session]:
with Session(engine) as session:
yield session
SessionDep = Annotated[Session, Depends(get_session)]
def get_by_token(
cls: type[SQLModel], request: Request, session: SessionDep
) -> SQLModel:
authorization = request.headers.get("Authorization")
if authorization is None:
raise HTTPException(
detail="No authentication provided",
status_code=400,
)
auth_type, token = authorization.split(" ")
if auth_type != "Bearer":
raise HTTPException(
detail="Invalid token type provided",
status_code=400,
)
obj = session.exec(select(cls).where(cls.token == token)).first()
if obj is None:
raise HTTPException(
detail="Unauthorized",
status_code=403,
)
return obj
def get_user(request: Request, session: SessionDep) -> User:
return cast(User, get_by_token(User, request, session))
UserDep = Annotated[User, Depends(get_user)]
def get_bot(request: Request, session: SessionDep) -> Bot:
return cast(Bot, get_by_token(Bot, request, session))
BotDep = Annotated[Bot, Depends(get_bot)]
def get_bot_by_id(bot_id: str, user_id: str, session: SessionDep) -> Bot | None:
"""
Fetches the specified bot from the database.
Args
:bot_id The ID of the bot.
:user_id The ID of the authenticated user.
:session_dep The database session
Returns
Bot | None: The bot object, if found, or None.
"""
return session.exec(
select(Bot).where(Bot.id == bot_id, Bot.owner_id == user_id)
).first()
def get_jids_by_bot_id(bot_id: str, session: SessionDep) -> Sequence[AllowedJid]:
"""
Retrieve all AllowedJid objects associated with a given bot_id from the database.
Args:
bot_id (str): The ID of the bot for which to retrieve AllowedJids.
session (SessionDep): A FastAPI dependency that provides access to the database session.
Returns:
Sequence[AllowedJid]: A sequence of AllowedJid objects associated with the given bot_id.
"""
return session.exec(
select(AllowedJid).where(AllowedJid.bot_id == bot_id),
).all()
def get_jid_by_jid_token(jid_token: str, session: SessionDep) -> AllowedJid | None:
"""
Retrieve an AllowedJid object from the database based on the provided jid_token.
Args:
jid_token (str): The token associated with the AllowedJid to retrieve.
session (SessionDep): A FastAPI dependency that provides access to the database session.
Returns:
AllowedJid | None: The AllowedJid object associated with the given jid_token, or None if no match is found.
"""
return session.exec(
select(AllowedJid).where(AllowedJid.token == jid_token),
).first()