diff --git a/src/xmpp_api/config/config.py b/src/xmpp_api/config/config.py index 09ea351..7743c25 100644 --- a/src/xmpp_api/config/config.py +++ b/src/xmpp_api/config/config.py @@ -27,11 +27,11 @@ class _ComponentConfig(BaseModel): @model_validator(mode="after") def validate_secret(self) -> Self: - if self.secret is None and self.secret_file is None: + if self.secret_plain is None and self.secret_file is None: raise ConfigurationException("No component secret is specified") - if self.secret is not None and self.secret_file is not None: + if self.secret_plain is not None and self.secret_file is not None: log.warn( - "Both secret and secret_file specified! secret_file takes precedence" + "Both component.secret and component.secret_file specified! component.secret_file takes precedence" ) return self @@ -46,9 +46,38 @@ class _ComponentConfig(BaseModel): return cast(str, self.secret_plain) +class _DatabaseConfig(BaseModel): + # The URI to connect to + uri_plain: str | None = Field(default=None, alias="uri") + + # The file to read the database URI from + uri_file: str | None = Field(default=None) + + @model_validator(mode="after") + def validate_secret(self) -> Self: + if self.uri_plain is None and self.uri_file is None: + raise ConfigurationException("No database URI is specified") + if self.uri_plain is not None and self.uri_file is not None: + log.warn( + "Both database.uri and database.uri_file specified! database.uri_file takes precedence" + ) + return self + + @property + def uri(self) -> str: + """ + The database URI. + + The backing may be None if uri_file is used instead. uri_plain is, however, replaced + with that file's contents during configuration loading. + """ + + return cast(str, self.uri_plain) + + class _Config(BaseModel): - # DB URI for sqlmodel - database: str + # Database config + database: _DatabaseConfig # Component configuration component: _ComponentConfig @@ -73,6 +102,12 @@ def load_config() -> _Config: with open(config.component.secret_file, "r", encoding="utf8") as f: config.component.secret_plain = f.read().strip().replace("\n", "") + # Read the database URI from a file, if specified + if config.database.uri_file is not None: + log.info("Reading database URI from %s", config.database.uri_file) + with open(config.database.uri_file, "r", encoding="utf8") as f: + config.database.uri_plain = f.read().strip().replace("\n", "") + return config diff --git a/src/xmpp_api/db/__init__.py b/src/xmpp_api/db/__init__.py index f31dc96..7f14468 100644 --- a/src/xmpp_api/db/__init__.py +++ b/src/xmpp_api/db/__init__.py @@ -11,7 +11,7 @@ from xmpp_api.db.bot import Bot, AllowedJid def get_engine(config: ConfigDep) -> Engine: return create_engine( - config.database, + config.database.uri, connect_args={ "check_same_thread": False, },