Add domain constraint validation

This commit is contained in:
2025-04-20 23:31:56 +02:00
parent 054f182215
commit 0fb0d62fb7
5 changed files with 130 additions and 3 deletions

View File

@@ -46,7 +46,7 @@ def get_by_token(
status_code=400,
)
obj = session.exec(select(cls).where(cls.token == token)).first()
obj = session.exec(select(cls).where(cls.token == token)).first() # type: ignore
if obj is None:
raise HTTPException(
detail="Unauthorized",

View File

@@ -39,6 +39,7 @@ class BotDomainConstraint(BotConstraint):
Constraints the bot to send messages only to the provided domains
"""
# List of domains that are allowed to send messages to
domains: list[str]

View File

@@ -2,6 +2,7 @@ import uuid
from fastapi import FastAPI, HTTPException, Request, Response
import sqlalchemy
from sqlmodel import SQLModel, select
from slixmpp.jid import JID
from xmpp_api.config.config import load_config
from xmpp_api.api.bot import (
@@ -16,7 +17,7 @@ from xmpp_api.api.bot import (
BotConstraint,
BotDomainConstraint,
)
from xmpp_api.db import BotDep, SessionDep, UserDep, EngineDep, get_engine
from xmpp_api.db import BotDep, SessionDep, UserDep, get_engine
from xmpp_api.db.bot import AllowedJid, Bot, JIDType
import xmpp_api.db.bot as db_bot
from xmpp_api.util.token import generate_token
@@ -87,6 +88,16 @@ def post_create_bot_jid(
detail="Unknown bot",
)
# Validate the domain constraint
parsed_jid = JID(creation_request.jid)
for constraint in bot.constraints:
if isinstance(constraint, BotDomainConstraint):
if parsed_jid.domain not in constraint.domains:
raise HTTPException(
status_code=400,
detail=f'Domain "{parsed_jid.domain}" is not allowed',
)
# Add the JID
# TODO: Query for the JID type
# TODO: If this is a groupchat, then join it
@@ -145,6 +156,13 @@ def post_bot_message(
status_code=404,
)
# Validate the domain constraint
parsed_jid = JID(jid.jid)
for constraint in bot.constraints:
if isinstance(constraint, BotDomainConstraint):
if parsed_jid.domain not in constraint.domains:
raise HTTPException(status_code=400)
# TODO: Send a message
return Response(status_code=200)