Allow reading the DB URI from a file

This commit is contained in:
PapaTutuWawa 2025-04-21 16:24:43 +02:00
parent 899d90424e
commit c3581e330e
2 changed files with 41 additions and 6 deletions

View File

@ -27,11 +27,11 @@ class _ComponentConfig(BaseModel):
@model_validator(mode="after")
def validate_secret(self) -> Self:
if self.secret is None and self.secret_file is None:
if self.secret_plain is None and self.secret_file is None:
raise ConfigurationException("No component secret is specified")
if self.secret is not None and self.secret_file is not None:
if self.secret_plain is not None and self.secret_file is not None:
log.warn(
"Both secret and secret_file specified! secret_file takes precedence"
"Both component.secret and component.secret_file specified! component.secret_file takes precedence"
)
return self
@ -46,9 +46,38 @@ class _ComponentConfig(BaseModel):
return cast(str, self.secret_plain)
class _DatabaseConfig(BaseModel):
# The URI to connect to
uri_plain: str | None = Field(default=None, alias="uri")
# The file to read the database URI from
uri_file: str | None = Field(default=None)
@model_validator(mode="after")
def validate_secret(self) -> Self:
if self.uri_plain is None and self.uri_file is None:
raise ConfigurationException("No database URI is specified")
if self.uri_plain is not None and self.uri_file is not None:
log.warn(
"Both database.uri and database.uri_file specified! database.uri_file takes precedence"
)
return self
@property
def uri(self) -> str:
"""
The database URI.
The backing may be None if uri_file is used instead. uri_plain is, however, replaced
with that file's contents during configuration loading.
"""
return cast(str, self.uri_plain)
class _Config(BaseModel):
# DB URI for sqlmodel
database: str
# Database config
database: _DatabaseConfig
# Component configuration
component: _ComponentConfig
@ -73,6 +102,12 @@ def load_config() -> _Config:
with open(config.component.secret_file, "r", encoding="utf8") as f:
config.component.secret_plain = f.read().strip().replace("\n", "")
# Read the database URI from a file, if specified
if config.database.uri_file is not None:
log.info("Reading database URI from %s", config.database.uri_file)
with open(config.database.uri_file, "r", encoding="utf8") as f:
config.database.uri_plain = f.read().strip().replace("\n", "")
return config

View File

@ -11,7 +11,7 @@ from xmpp_api.db.bot import Bot, AllowedJid
def get_engine(config: ConfigDep) -> Engine:
return create_engine(
config.database,
config.database.uri,
connect_args={
"check_same_thread": False,
},