diff --git a/.gitignore b/.gitignore index 54813b8..20a81a9 100644 --- a/.gitignore +++ b/.gitignore @@ -17,3 +17,5 @@ wheels/ # Testing artifacts config.yaml db.sqlite3 +prosody.cfg.lua +prosody/ diff --git a/src/xmpp_api/api/bot.py b/src/xmpp_api/api/bot.py index f4d3306..2891a47 100644 --- a/src/xmpp_api/api/bot.py +++ b/src/xmpp_api/api/bot.py @@ -1,5 +1,7 @@ from pydantic import BaseModel, Field +from xmpp_api.db.bot import JIDType + class BotInformation(BaseModel): # The bot's ID @@ -8,6 +10,9 @@ class BotInformation(BaseModel): # The bot's name name: str + # The bot's localpart + localpart: str + # The bot's description description: str | None @@ -38,6 +43,9 @@ class CreateBotRequest(BaseModel): # List of constraints constraints: list[BotConstraint] = Field(default_factory=list) + # The localpart of the bot + localpart: str + class BotCreationResponse(BotInformation): # The bot's token @@ -45,8 +53,12 @@ class BotCreationResponse(BotInformation): class AddJidRequest(BaseModel): + # The JID that the message will be sent to jid: str + # The JID type. If set, then we will not discover it automatically + type: JIDType | None = Field(default=None) + class AddJidResponse(BaseModel): token: str diff --git a/src/xmpp_api/config/config.py b/src/xmpp_api/config/config.py index 6cfc0dd..abefefa 100644 --- a/src/xmpp_api/config/config.py +++ b/src/xmpp_api/config/config.py @@ -4,10 +4,24 @@ from pydantic import BaseModel from fastapi import Depends +class _ComponentConfig(BaseModel): + # The JID of the component + jid: str + + # Address of server's component port + server: str + + # The component's secret. + secret: str + + class _Config(BaseModel): # DB URI for sqlmodel database: str + # Component configuration + component: _ComponentConfig + def load_config() -> _Config: """ @@ -16,6 +30,11 @@ def load_config() -> _Config: # TODO: Actually load it return _Config( database="sqlite:///db.sqlite3", + component=_ComponentConfig( + jid="test.localhost", + server="localhost:5869", + secret="abc123", + ), ) diff --git a/src/xmpp_api/db/bot.py b/src/xmpp_api/db/bot.py index 8173ad3..6b094d2 100644 --- a/src/xmpp_api/db/bot.py +++ b/src/xmpp_api/db/bot.py @@ -11,10 +11,10 @@ class JIDType(Enum): """ # JID points to an entity we can directly send messages to - DIRECT = 1 + DIRECT = "DIRECT" # JID points to a MUC. - GC = 2 + GC = "GC" class AllowedJid(SQLModel, table=True): @@ -50,6 +50,9 @@ class Bot(SQLModel, table=True): # The bot's human readable name name: str = Field(unique=True) + # The bot JID's localpart + localpart: str = Field(unique=True) + # The bot's description description: str | None diff --git a/src/xmpp_api/main.py b/src/xmpp_api/main.py index abcb793..5c4c9f2 100644 --- a/src/xmpp_api/main.py +++ b/src/xmpp_api/main.py @@ -23,6 +23,7 @@ import xmpp_api.db.bot as db_bot from xmpp_api.util.token import generate_token from xmpp_api.db import get_bot_by_id, get_jids_by_bot_id, get_jid_by_jid_token from xmpp_api.util.constraints import bot_constraint_to_db, bot_constraint_from_db +from xmpp_api.xmpp.component import XmppApiComponent, XmppApiComponentDep app = FastAPI() @@ -31,13 +32,16 @@ app = FastAPI() @app.on_event("startup") def startup(): # TODO: This is kinda ugly - engine = app.dependency_overrides.get(get_engine, get_engine)( - load_config(), - ) + config = load_config() + engine = app.dependency_overrides.get(get_engine, get_engine)(config) SQLModel.metadata.create_all(engine) + # App startup is done. Connect to the XMPP server + instance = XmppApiComponent.of(config) + instance.run() -@app.post("/api/v1/bot/create") + +@app.post("/api/v1/bot") def post_create_bot( bot_request: CreateBotRequest, user: UserDep, session: SessionDep ) -> BotCreationResponse: @@ -52,6 +56,7 @@ def post_create_bot( bot = Bot( name=bot_request.name, description=bot_request.description, + localpart=bot_request.localpart, token=generate_token(64), constraints=constraints, owner_id=user.id, @@ -71,14 +76,19 @@ def post_create_bot( id=bot.id, name=bot.name, description=bot.description, + localpart=bot.localpart, token=bot.token, constraints=bot_request.constraints, ) @app.post("/api/v1/bot/{bot_id}/jid") -def post_create_bot_jid( - bot_id: str, creation_request: AddJidRequest, user: UserDep, session: SessionDep +async def post_create_bot_jid( + bot_id: str, + creation_request: AddJidRequest, + user: UserDep, + session: SessionDep, + component: XmppApiComponentDep, ) -> AddJidResponse: # Check if the bot exists and we own it bot = get_bot_by_id(bot_id, user.id, session) @@ -98,12 +108,36 @@ def post_create_bot_jid( detail=f'Domain "{parsed_jid.domain}" is not allowed', ) + # Query the JID for its type + allowed_jid_type: JIDType + if creation_request.type is None: + jid_type = await component.get_entity_type(parsed_jid) + if jid_type is None: + raise HTTPException( + status_code=500, + detail=f"Failed to query entity at {creation_request.jid}", + ) + + match jid_type: + case "account": + allowed_jid_type = JIDType.DIRECT + case "groupchat": + allowed_jid_type = JIDType.GC + else: + allowed_jid_type = creation_request.type + + # Deal with groupchat shenanigans + if allowed_jid_type == JIDType.GC: + # TODO: Join the groupchat + raise HTTPException( + status_code=500, + detail="Groupchats are not yet handled", + ) + # Add the JID - # TODO: Query for the JID type - # TODO: If this is a groupchat, then join it jid = AllowedJid( jid=creation_request.jid, - type=JIDType.DIRECT, + type=allowed_jid_type, # This token is only for identification token=uuid.uuid4().hex, bot_id=bot.id, @@ -143,6 +177,7 @@ def post_bot_message( request: Request, bot: BotDep, session: SessionDep, + component: XmppApiComponentDep, ): # Is the bot allowed to access this JID? jid = session.exec( @@ -163,7 +198,17 @@ def post_bot_message( if parsed_jid.domain not in constraint.domains: raise HTTPException(status_code=400) - # TODO: Send a message + match jid.type: + case JIDType.DIRECT: + component.send_direct_message( + body=message.body, + localpart=bot.localpart, + nick=bot.name, + recipient=jid.jid, + ) + case _: + raise HTTPException(status_code=500) + return Response(status_code=200) @@ -187,6 +232,7 @@ def get_bot_information( id=bot.id, name=bot.name, description=bot.description, + localpart=bot.localpart, jids=[ AllowedJidInformation( jid=jid.jid, @@ -206,6 +252,7 @@ def get_bots(user: UserDep, session: SessionDep) -> list[BotInformation]: id=bot.id, name=bot.name, description=bot.description, + localpart=bot.localpart, constraints=[bot_constraint_from_db(c) for c in bot.constraints], ) for bot in bots diff --git a/src/xmpp_api/xmpp/component.py b/src/xmpp_api/xmpp/component.py new file mode 100644 index 0000000..5fd6f0f --- /dev/null +++ b/src/xmpp_api/xmpp/component.py @@ -0,0 +1,85 @@ +from typing import Annotated, Literal + +from fastapi import Depends +from slixmpp.componentxmpp import ComponentXMPP +from slixmpp.jid import JID +from slixmpp.exceptions import IqError, IqTimeout + +from xmpp_api.config.config import ConfigDep + + +class XmppApiComponent(ComponentXMPP): + """ + The XMPP server component that sends the messages + """ + + # The component's bare JID + _jid: str + + # Singleton instance + _instance = None + + def __init__(self, jid: str, secret: str, host: str, port: int): + super().__init__(jid, secret, host, port) + self._jid = jid + + # Register plugins + self.register_plugin("xep_0030") + + # Event handlers + self.add_event_handler("disconnected", self.on_disconnected) + self.add_event_handler("connected", self.on_connected) + + @staticmethod + def of(config: ConfigDep) -> "XmppApiComponent": + if XmppApiComponent._instance is None: + host, port = config.component.server.split(":") + XmppApiComponent._instance = XmppApiComponent( + config.component.jid, + config.component.secret, + host, + int(port), + ) + return XmppApiComponent._instance + + def on_disconnected(self, event): + # Reconnect + self.connect() + + def on_connected(self, event): + # TODO: Join all groupchats that we know of + pass + + def run(self): + # NOTE: We do not have to deal with asyncio here because we get that + # due to fastapi for free! + self.connect() + + async def get_entity_type( + self, jid: JID + ) -> Literal["groupchat"] | Literal["account"] | None: + try: + info = await self.plugin["xep_0030"].get_info(jid=jid) + if "http://jabber.org/protocol/muc" in info["disco_info"]["features"]: + return "groupchat" + return "account" + except IqError: + return None + except IqTimeout: + return None + + def send_direct_message(self, localpart: str, nick: str, recipient: str, body: str): + self.send_message( + mto=JID(recipient), + mfrom=JID(f"{localpart}@{self._jid}"), + mtype="chat", + mbody=body, + mnick=nick, + ) + + +def get_xmpp_component(config: ConfigDep) -> XmppApiComponent: + return XmppApiComponent.of(config) + + +XmppApiComponentDep = Annotated[XmppApiComponent, Depends(get_xmpp_component)]