Add domain constraint validation
This commit is contained in:
@@ -46,7 +46,7 @@ def get_by_token(
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
obj = session.exec(select(cls).where(cls.token == token)).first()
|
||||
obj = session.exec(select(cls).where(cls.token == token)).first() # type: ignore
|
||||
if obj is None:
|
||||
raise HTTPException(
|
||||
detail="Unauthorized",
|
||||
|
||||
@@ -39,6 +39,7 @@ class BotDomainConstraint(BotConstraint):
|
||||
Constraints the bot to send messages only to the provided domains
|
||||
"""
|
||||
|
||||
# List of domains that are allowed to send messages to
|
||||
domains: list[str]
|
||||
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ import uuid
|
||||
from fastapi import FastAPI, HTTPException, Request, Response
|
||||
import sqlalchemy
|
||||
from sqlmodel import SQLModel, select
|
||||
from slixmpp.jid import JID
|
||||
|
||||
from xmpp_api.config.config import load_config
|
||||
from xmpp_api.api.bot import (
|
||||
@@ -16,7 +17,7 @@ from xmpp_api.api.bot import (
|
||||
BotConstraint,
|
||||
BotDomainConstraint,
|
||||
)
|
||||
from xmpp_api.db import BotDep, SessionDep, UserDep, EngineDep, get_engine
|
||||
from xmpp_api.db import BotDep, SessionDep, UserDep, get_engine
|
||||
from xmpp_api.db.bot import AllowedJid, Bot, JIDType
|
||||
import xmpp_api.db.bot as db_bot
|
||||
from xmpp_api.util.token import generate_token
|
||||
@@ -87,6 +88,16 @@ def post_create_bot_jid(
|
||||
detail="Unknown bot",
|
||||
)
|
||||
|
||||
# Validate the domain constraint
|
||||
parsed_jid = JID(creation_request.jid)
|
||||
for constraint in bot.constraints:
|
||||
if isinstance(constraint, BotDomainConstraint):
|
||||
if parsed_jid.domain not in constraint.domains:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f'Domain "{parsed_jid.domain}" is not allowed',
|
||||
)
|
||||
|
||||
# Add the JID
|
||||
# TODO: Query for the JID type
|
||||
# TODO: If this is a groupchat, then join it
|
||||
@@ -145,6 +156,13 @@ def post_bot_message(
|
||||
status_code=404,
|
||||
)
|
||||
|
||||
# Validate the domain constraint
|
||||
parsed_jid = JID(jid.jid)
|
||||
for constraint in bot.constraints:
|
||||
if isinstance(constraint, BotDomainConstraint):
|
||||
if parsed_jid.domain not in constraint.domains:
|
||||
raise HTTPException(status_code=400)
|
||||
|
||||
# TODO: Send a message
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user