Initial commit
This commit is contained in:
121
src/xmpp_api/db/__init__.py
Normal file
121
src/xmpp_api/db/__init__.py
Normal file
@@ -0,0 +1,121 @@
|
||||
from typing import Annotated, Generator, Sequence, cast
|
||||
|
||||
from fastapi import Depends, Request, HTTPException
|
||||
from sqlmodel import SQLModel, Session, create_engine, select
|
||||
from sqlalchemy import Engine
|
||||
|
||||
from xmpp_api.config.config import ConfigDep
|
||||
from xmpp_api.db.user import User
|
||||
from xmpp_api.db.bot import Bot, AllowedJid
|
||||
|
||||
|
||||
def get_engine(config: ConfigDep) -> Engine:
|
||||
return create_engine(
|
||||
config.database,
|
||||
connect_args={
|
||||
"check_same_thread": False,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
EngineDep = Annotated[Engine, Depends(get_engine)]
|
||||
|
||||
|
||||
def get_session(engine: EngineDep) -> Generator[Session]:
|
||||
with Session(engine) as session:
|
||||
yield session
|
||||
|
||||
|
||||
SessionDep = Annotated[Session, Depends(get_session)]
|
||||
|
||||
|
||||
def get_by_token(
|
||||
cls: type[SQLModel], request: Request, session: SessionDep
|
||||
) -> SQLModel:
|
||||
authorization = request.headers.get("Authorization")
|
||||
if authorization is None:
|
||||
raise HTTPException(
|
||||
detail="No authentication provided",
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
auth_type, token = authorization.split(" ")
|
||||
if auth_type != "Bearer":
|
||||
raise HTTPException(
|
||||
detail="Invalid token type provided",
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
obj = session.exec(select(cls).where(cls.token == token)).first()
|
||||
if obj is None:
|
||||
raise HTTPException(
|
||||
detail="Unauthorized",
|
||||
status_code=403,
|
||||
)
|
||||
return obj
|
||||
|
||||
|
||||
def get_user(request: Request, session: SessionDep) -> User:
|
||||
return cast(User, get_by_token(User, request, session))
|
||||
|
||||
|
||||
UserDep = Annotated[User, Depends(get_user)]
|
||||
|
||||
|
||||
def get_bot(request: Request, session: SessionDep) -> Bot:
|
||||
return cast(Bot, get_by_token(Bot, request, session))
|
||||
|
||||
|
||||
BotDep = Annotated[Bot, Depends(get_bot)]
|
||||
|
||||
|
||||
def get_bot_by_id(bot_id: str, user_id: str, session: SessionDep) -> Bot | None:
|
||||
"""
|
||||
Fetches the specified bot from the database.
|
||||
|
||||
Args
|
||||
:bot_id The ID of the bot.
|
||||
:user_id The ID of the authenticated user.
|
||||
:session_dep The database session
|
||||
|
||||
Returns
|
||||
Bot | None: The bot object, if found, or None.
|
||||
"""
|
||||
|
||||
return session.exec(
|
||||
select(Bot).where(Bot.id == bot_id, Bot.owner_id == user_id)
|
||||
).first()
|
||||
|
||||
|
||||
def get_jids_by_bot_id(bot_id: str, session: SessionDep) -> Sequence[AllowedJid]:
|
||||
"""
|
||||
Retrieve all AllowedJid objects associated with a given bot_id from the database.
|
||||
|
||||
Args:
|
||||
bot_id (str): The ID of the bot for which to retrieve AllowedJids.
|
||||
session (SessionDep): A FastAPI dependency that provides access to the database session.
|
||||
|
||||
Returns:
|
||||
Sequence[AllowedJid]: A sequence of AllowedJid objects associated with the given bot_id.
|
||||
"""
|
||||
|
||||
return session.exec(
|
||||
select(AllowedJid).where(AllowedJid.bot_id == bot_id),
|
||||
).all()
|
||||
|
||||
|
||||
def get_jid_by_jid_token(jid_token: str, session: SessionDep) -> AllowedJid | None:
|
||||
"""
|
||||
Retrieve an AllowedJid object from the database based on the provided jid_token.
|
||||
|
||||
Args:
|
||||
jid_token (str): The token associated with the AllowedJid to retrieve.
|
||||
session (SessionDep): A FastAPI dependency that provides access to the database session.
|
||||
|
||||
Returns:
|
||||
AllowedJid | None: The AllowedJid object associated with the given jid_token, or None if no match is found.
|
||||
"""
|
||||
|
||||
return session.exec(
|
||||
select(AllowedJid).where(AllowedJid.token == jid_token),
|
||||
).first()
|
||||
Reference in New Issue
Block a user