This commit is contained in:
PapaTutuWawa 2025-03-30 23:03:10 +02:00
parent 97f3a12617
commit 9d4867d74e
30 changed files with 240 additions and 86 deletions

View File

@ -21,3 +21,8 @@ openec2 = "openec2:main"
[build-system] [build-system]
requires = ["hatchling"] requires = ["hatchling"]
build-backend = "hatchling.build" build-backend = "hatchling.build"
[dependency-groups]
dev = [
"ruff>=0.11.2",
]

View File

@ -7,6 +7,7 @@ from openec2.db import DatabaseDep
from openec2.db.image import AMI from openec2.db.image import AMI
from openec2.images import garbage_collect_image from openec2.images import garbage_collect_image
def deregister_image( def deregister_image(
params: QueryParams, params: QueryParams,
config: OpenEC2Config, config: OpenEC2Config,

View File

@ -30,7 +30,7 @@ def describe_images(
requestId=uuid.uuid4().hex, requestId=uuid.uuid4().hex,
imagesSet=ImagesSet( imagesSet=ImagesSet(
items=images, items=images,
) ),
).to_xml(), ).to_xml(),
media_type="application/xml", media_type="application/xml",
) )

View File

@ -5,12 +5,20 @@ from fastapi.datastructures import QueryParams
from sqlmodel import select from sqlmodel import select
from openec2.libvirt import LibvirtSingleton from openec2.libvirt import LibvirtSingleton
from openec2.api.describe_instances import InstanceDescription, DescribeInstancesResponse, DescribeInstancesResponseReservationSet, ReservationSet, ReservationSetInstancesSet, describe_instance from openec2.api.describe_instances import (
InstanceDescription,
DescribeInstancesResponse,
DescribeInstancesResponseReservationSet,
ReservationSet,
ReservationSetInstancesSet,
describe_instance,
)
from openec2.api.shared import InstanceState from openec2.api.shared import InstanceState
from openec2.config import OpenEC2Config from openec2.config import OpenEC2Config
from openec2.db import DatabaseDep from openec2.db import DatabaseDep
from openec2.db.instance import Instance from openec2.db.instance import Instance
def describe_instances( def describe_instances(
params: QueryParams, params: QueryParams,
config: OpenEC2Config, config: OpenEC2Config,
@ -36,7 +44,8 @@ def describe_instances(
instancesSet=ReservationSetInstancesSet( instancesSet=ReservationSetInstancesSet(
item=[instance], item=[instance],
), ),
) for instance in response_items )
for instance in response_items
], ],
), ),
).to_xml(), ).to_xml(),

View File

@ -11,6 +11,7 @@ from openec2.config import OpenEC2Config
from openec2.db import DatabaseDep from openec2.db import DatabaseDep
from openec2.db.image import AMI from openec2.db.image import AMI
def import_image( def import_image(
params: QueryParams, params: QueryParams,
config: OpenEC2Config, config: OpenEC2Config,

View File

@ -15,9 +15,14 @@ from openec2.db.image import AMI
from openec2.api.run_instances import RunInstanceResponse, RunInstanceInstanceSet from openec2.api.run_instances import RunInstanceResponse, RunInstanceInstanceSet
from openec2.api.describe_instances import describe_instance from openec2.api.describe_instances import describe_instance
from openec2.utils.array import parse_array_objects from openec2.utils.array import parse_array_objects
from openec2.ipam import get_available_ipv4, is_ipv4_available, add_instance_dhcp_mapping from openec2.ipam import (
get_available_ipv4,
is_ipv4_available,
add_instance_dhcp_mapping,
)
from openec2.utils.ip import generate_available_mac from openec2.utils.ip import generate_available_mac
def create_libvirt_domain( def create_libvirt_domain(
name: str, name: str,
memory: int, memory: int,
@ -84,6 +89,7 @@ def create_libvirt_domain(
</domain> </domain>
""" """
def run_instances( def run_instances(
params: QueryParams, params: QueryParams,
config: OpenEC2Config, config: OpenEC2Config,
@ -94,7 +100,7 @@ def run_instances(
instance_type = config.instances.types.get(params["InstanceType"]) instance_type = config.instances.types.get(params["InstanceType"])
if instance_type is None: if instance_type is None:
raise Exception(f"Unknown instance type {params["InstanceType"]}") raise Exception(f"Unknown instance type {params['InstanceType']}")
ami = db.exec(select(AMI).where(AMI.id == image_id)).first() ami = db.exec(select(AMI).where(AMI.id == image_id)).first()
if ami is None: if ami is None:
@ -140,7 +146,9 @@ def run_instances(
id=instance_id, id=instance_id,
imageId=image_id, imageId=image_id,
tags=tags, tags=tags,
userData=base64.b64decode(value).decode() if (value := params.get("UserData")) is not None else None, userData=base64.b64decode(value).decode()
if (value := params.get("UserData")) is not None
else None,
privateIPv4=private_ipv4, privateIPv4=private_ipv4,
interfaceMac=mac, interfaceMac=mac,
) )

View File

@ -18,6 +18,7 @@ from openec2.api.terminate_instances import TerminateInstancesResponse, Instance
logger = logging.getLogger() logger = logging.getLogger()
def terminate_instances( def terminate_instances(
params: QueryParams, params: QueryParams,
config: OpenEC2Config, config: OpenEC2Config,
@ -31,7 +32,7 @@ def terminate_instances(
instance = db.exec(select(Instance).where(Instance.id == instance_id)).first() instance = db.exec(select(Instance).where(Instance.id == instance_id)).first()
if instance is None: if instance is None:
continue continue
#raise HTTPException(status_code=404, detail="Unknown instance") # raise HTTPException(status_code=404, detail="Unknown instance")
dom = conn.lookupByName(instance_id) dom = conn.lookupByName(instance_id)
prev_state = describe_instance_state(dom) prev_state = describe_instance_state(dom)
@ -54,7 +55,9 @@ def terminate_instances(
instance_disk.unlink() instance_disk.unlink()
image_ids.add(instance.imageId) image_ids.add(instance.imageId)
remove_instance_dhcp_mapping(instance.id, instance.interfaceMac, instance.privateIPv4, db) remove_instance_dhcp_mapping(
instance.id, instance.interfaceMac, instance.privateIPv4, db
)
db.delete(instance) db.delete(instance)
db.commit() db.commit()

View File

@ -1,5 +1,6 @@
from pydantic_xml import BaseXmlModel, element from pydantic_xml import BaseXmlModel, element
class Image(BaseXmlModel): class Image(BaseXmlModel):
imageId: str = element() imageId: str = element()
@ -7,13 +8,15 @@ class Image(BaseXmlModel):
name: str = element() name: str = element()
class ImagesSet(BaseXmlModel, tag="imagesSet"): class ImagesSet(BaseXmlModel, tag="imagesSet"):
items: list[Image] = element(tag="item") items: list[Image] = element(tag="item")
class DescribeImagesResponse( class DescribeImagesResponse(
BaseXmlModel, BaseXmlModel,
tag="DescribeImagesResponse", tag="DescribeImagesResponse",
nsmap={"": "http://ec2.amazonaws.com/doc/2016-11-15/"} nsmap={"": "http://ec2.amazonaws.com/doc/2016-11-15/"},
): ):
requestId: str = element() requestId: str = element()

View File

@ -9,9 +9,11 @@ class InstanceTag(BaseXmlModel):
key: str = element() key: str = element()
value: str = element() value: str = element()
class InstanceTagSet(BaseXmlModel): class InstanceTagSet(BaseXmlModel):
item: list[InstanceTag] = element() item: list[InstanceTag] = element()
class InstanceDescription( class InstanceDescription(
BaseXmlModel, BaseXmlModel,
tag="item", tag="item",
@ -25,20 +27,23 @@ class InstanceDescription(
class ReservationSetInstancesSet(BaseXmlModel): class ReservationSetInstancesSet(BaseXmlModel):
item: list[InstanceDescription] = element() item: list[InstanceDescription] = element()
class ReservationSet(BaseXmlModel): class ReservationSet(BaseXmlModel):
reservationId: str = element() reservationId: str = element()
instancesSet: ReservationSetInstancesSet = element() instancesSet: ReservationSetInstancesSet = element()
class DescribeInstancesResponseReservationSet( class DescribeInstancesResponseReservationSet(
BaseXmlModel, BaseXmlModel,
tag="reservationSet", tag="reservationSet",
): ):
item: list[ReservationSet] = element("") item: list[ReservationSet] = element("")
class DescribeInstancesResponse( class DescribeInstancesResponse(
BaseXmlModel, BaseXmlModel,
tag="DescribeInstancesResponse", tag="DescribeInstancesResponse",
nsmap={"": "http://ec2.amazonaws.com/doc/2016-11-15/"} nsmap={"": "http://ec2.amazonaws.com/doc/2016-11-15/"},
): ):
request_id: str = element(tag="requestId") request_id: str = element(tag="requestId")
reservation_set: DescribeInstancesResponseReservationSet = element("reservationSet") reservation_set: DescribeInstancesResponseReservationSet = element("reservationSet")
@ -51,7 +56,10 @@ def describe_instance_state(domain: libvirt.virDomain) -> InstanceState:
name="running" if running else "stopped", name="running" if running else "stopped",
) )
def describe_instance(instance: Instance, domain: libvirt.virDomain) -> InstanceDescription:
def describe_instance(
instance: Instance, domain: libvirt.virDomain
) -> InstanceDescription:
return InstanceDescription( return InstanceDescription(
instance_id=instance.id, instance_id=instance.id,
image_id=instance.imageId, image_id=instance.imageId,
@ -61,7 +69,8 @@ def describe_instance(instance: Instance, domain: libvirt.virDomain) -> Instance
InstanceTag( InstanceTag(
key=key, key=key,
value=value, value=value,
) for key, value in instance.tags.items() )
for key, value in instance.tags.items()
], ],
), ),
) )

View File

@ -1,12 +1,12 @@
from pydantic_xml import BaseXmlModel, element from pydantic_xml import BaseXmlModel, element
from openec2.db.instance import Instance
from openec2.api.shared import InstanceState
from openec2.api.describe_instances import InstanceDescription from openec2.api.describe_instances import InstanceDescription
class RunInstanceInstanceSet(BaseXmlModel): class RunInstanceInstanceSet(BaseXmlModel):
item: list[InstanceDescription] = element() item: list[InstanceDescription] = element()
class RunInstanceResponse(BaseXmlModel): class RunInstanceResponse(BaseXmlModel):
request_id: str = element(tag="requestId") request_id: str = element(tag="requestId")

View File

@ -6,7 +6,7 @@ from openec2.api.shared import InstancesSet
class StartInstancesResponse( class StartInstancesResponse(
BaseXmlModel, BaseXmlModel,
tag="StartInstancesResponse", tag="StartInstancesResponse",
nsmap={"": "http://ec2.amazonaws.com/doc/2016-11-15/"} nsmap={"": "http://ec2.amazonaws.com/doc/2016-11-15/"},
): ):
request_id: str = element(tag="requestId") request_id: str = element(tag="requestId")
instancesSet: InstancesSet = element() instancesSet: InstancesSet = element()

View File

@ -6,7 +6,7 @@ from openec2.api.shared import InstancesSet
class StopInstancesResponse( class StopInstancesResponse(
BaseXmlModel, BaseXmlModel,
tag="StopInstancesResponse", tag="StopInstancesResponse",
nsmap={"": "http://ec2.amazonaws.com/doc/2016-11-15/"} nsmap={"": "http://ec2.amazonaws.com/doc/2016-11-15/"},
): ):
requestId: str = element(tag="requestId") requestId: str = element(tag="requestId")
instancesSet: InstancesSet = element() instancesSet: InstancesSet = element()

View File

@ -41,5 +41,6 @@ def main():
print(f"Access key: {access_key}") print(f"Access key: {access_key}")
print(f"Secret access key: {secret_access_key}") print(f"Secret access key: {secret_access_key}")
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@ -3,19 +3,23 @@ from fastapi import Depends
from pydantic import BaseModel from pydantic import BaseModel
from typing import Annotated from typing import Annotated
class _OpenEC2InstanceType(BaseModel): class _OpenEC2InstanceType(BaseModel):
memory: int # MiB memory: int # MiB
vcpu: float vcpu: float
disk: int # Gi disk: int # Gi
class _OpenEC2InstanceConfig(BaseModel): class _OpenEC2InstanceConfig(BaseModel):
location: Path location: Path
types: dict[str, _OpenEC2InstanceType] types: dict[str, _OpenEC2InstanceType]
class _OpenEC2LibvirtConfig(BaseModel): class _OpenEC2LibvirtConfig(BaseModel):
connection: str connection: str
class _OpenEC2DatabaseConfig(BaseModel): class _OpenEC2DatabaseConfig(BaseModel):
# DB URL for sqlmodel # DB URL for sqlmodel
url: str url: str
@ -23,6 +27,7 @@ class _OpenEC2DatabaseConfig(BaseModel):
# Print SQL statements # Print SQL statements
debug: bool debug: bool
class _OpenEC2Config(BaseModel): class _OpenEC2Config(BaseModel):
images: Path images: Path
seed: Path seed: Path
@ -32,6 +37,7 @@ class _OpenEC2Config(BaseModel):
debug: bool debug: bool
insecure: bool insecure: bool
def _get_config() -> _OpenEC2Config: def _get_config() -> _OpenEC2Config:
# TODO: Read from disk # TODO: Read from disk
return _OpenEC2Config( return _OpenEC2Config(
@ -47,9 +53,7 @@ def _get_config() -> _OpenEC2Config:
), ),
}, },
), ),
libvirt=_OpenEC2LibvirtConfig( libvirt=_OpenEC2LibvirtConfig(connection="qemu:///system"),
connection="qemu:///system"
),
debug=True, debug=True,
insecure=False, insecure=False,
database=_OpenEC2DatabaseConfig( database=_OpenEC2DatabaseConfig(
@ -58,6 +62,7 @@ def _get_config() -> _OpenEC2Config:
), ),
) )
class ConfigSingleton: class ConfigSingleton:
__instance: "ConfigSingleton | None" = None __instance: "ConfigSingleton | None" = None
@ -75,4 +80,5 @@ class ConfigSingleton:
ConfigSingleton.__instance = ConfigSingleton() ConfigSingleton.__instance = ConfigSingleton()
return ConfigSingleton.__instance return ConfigSingleton.__instance
OpenEC2Config = Annotated[_OpenEC2Config, Depends(ConfigSingleton.of().get_config)] OpenEC2Config = Annotated[_OpenEC2Config, Depends(ConfigSingleton.of().get_config)]

View File

@ -12,8 +12,10 @@ engine = create_engine(
echo=ConfigSingleton.of().config.database.debug, echo=ConfigSingleton.of().config.database.debug,
) )
def get_session() -> Generator[Session]: def get_session() -> Generator[Session]:
with Session(engine) as session: with Session(engine) as session:
yield session yield session
DatabaseDep = Annotated[Session, Depends(get_session)] DatabaseDep = Annotated[Session, Depends(get_session)]

View File

@ -1,5 +1,6 @@
from sqlmodel import SQLModel, Field from sqlmodel import SQLModel, Field
class AMI(SQLModel, table=True): class AMI(SQLModel, table=True):
# ID of the AMI # ID of the AMI
id: str = Field(default=None, primary_key=True) id: str = Field(default=None, primary_key=True)

View File

@ -1,5 +1,6 @@
from sqlmodel import SQLModel, Field, JSON, Column from sqlmodel import SQLModel, Field, JSON, Column
class Instance(SQLModel, table=True): class Instance(SQLModel, table=True):
id: str = Field(default=None, primary_key=True) id: str = Field(default=None, primary_key=True)

View File

@ -19,6 +19,4 @@ class IPAMEntry(SQLModel, table=True):
def set_ipv4(self, addr: str): def set_ipv4(self, addr: str):
self.ipv4_addr_raw = ipv4_to_int(addr) self.ipv4_addr_raw = ipv4_to_int(addr)
__table_args = ( __table_args = (PrimaryKeyConstraint("ipv4_addr_raw", "vpc_id"),)
PrimaryKeyConstraint("ipv4_addr_raw", "vpc_id"),
)

View File

@ -1,5 +1,6 @@
from sqlmodel import SQLModel, Field from sqlmodel import SQLModel, Field
class User(SQLModel, table=True): class User(SQLModel, table=True):
id: int = Field(default=None, primary_key=True) id: int = Field(default=None, primary_key=True)

View File

@ -1,5 +1,6 @@
from sqlmodel import SQLModel, Field, PrimaryKeyConstraint from sqlmodel import SQLModel, Field, PrimaryKeyConstraint
class VPC(SQLModel, table=True): class VPC(SQLModel, table=True):
# ID of the VPC # ID of the VPC
id: str = Field(default=None, primary_key=True) id: str = Field(default=None, primary_key=True)

View File

@ -13,7 +13,9 @@ def garbage_collect_image(image_id: str, db: DatabaseDep):
print(instances) print(instances)
return return
ami = db.exec(select(AMI).where(AMI.id == image_id, AMI.deregistered == True)).first() ami = db.exec(
select(AMI).where(AMI.id == image_id, AMI.deregistered == True)
).first()
if ami is not None: if ami is not None:
db.delete(ami) db.delete(ami)
db.commit() db.commit()

View File

@ -34,9 +34,16 @@ def add_instance_dhcp_mapping(instance_id: str, mac: str, ipv4: str, db: Databas
flags=libvirt.VIR_NETWORK_UPDATE_AFFECT_LIVE, flags=libvirt.VIR_NETWORK_UPDATE_AFFECT_LIVE,
) )
def remove_instance_dhcp_mapping(instance_id: str, mac: str ,ipv4: str, db: DatabaseDep):
def remove_instance_dhcp_mapping(
instance_id: str, mac: str, ipv4: str, db: DatabaseDep
):
i = ipv4_to_int(ipv4) i = ipv4_to_int(ipv4)
entry = db.exec(select(IPAMEntry).where(IPAMEntry.ipv4_addr_raw == i, IPAMEntry.instance_id == instance_id)).first() entry = db.exec(
select(IPAMEntry).where(
IPAMEntry.ipv4_addr_raw == i, IPAMEntry.instance_id == instance_id
)
).first()
db.delete(entry) db.delete(entry)
# Tell libvirt about this mapping # Tell libvirt about this mapping
@ -49,13 +56,21 @@ def remove_instance_dhcp_mapping(instance_id: str, mac: str ,ipv4: str, db: Data
flags=libvirt.VIR_NETWORK_UPDATE_AFFECT_LIVE, flags=libvirt.VIR_NETWORK_UPDATE_AFFECT_LIVE,
) )
def is_ipv4_available(ipv4: str, db: DatabaseDep) -> bool: def is_ipv4_available(ipv4: str, db: DatabaseDep) -> bool:
i = ipv4_to_int(ipv4) i = ipv4_to_int(ipv4)
return db.exec(select(IPAMEntry).where(IPAMEntry.ipv4_addr_raw == i)).first() is None return (
db.exec(select(IPAMEntry).where(IPAMEntry.ipv4_addr_raw == i)).first() is None
)
def get_available_ipv4(db: DatabaseDep) -> str: def get_available_ipv4(db: DatabaseDep) -> str:
entries = db.exec(select(IPAMEntry)).all() entries = db.exec(select(IPAMEntry)).all()
# TODO: Use the VPC's subnet # TODO: Use the VPC's subnet
max_ip = max(e.ipv4_addr_raw for e in entries) if entries else ipv4_to_int("192.168.122.2") max_ip = (
max(e.ipv4_addr_raw for e in entries)
if entries
else ipv4_to_int("192.168.122.2")
)
# TODO: Check if we're still inside the subnet # TODO: Check if we're still inside the subnet
return int_to_ipv4(max_ip + 1) return int_to_ipv4(max_ip + 1)

View File

@ -21,27 +21,35 @@ from openec2.db.instance import Instance
app = FastAPI() app = FastAPI()
@app.on_event("startup") @app.on_event("startup")
def on_startup(): def on_startup():
SQLModel.metadata.create_all(engine) SQLModel.metadata.create_all(engine)
@app.get("/healthz", response_model=None) @app.get("/healthz", response_model=None)
def healthz(): def healthz():
return { return {
"status": "OK", "status": "OK",
} }
@app.get("/Action", response_model=None) @app.get("/Action", response_model=None)
def action(request: Request, config: OpenEC2Config, db: DatabaseDep, _: AWSSignature): def action(request: Request, config: OpenEC2Config, db: DatabaseDep, _: AWSSignature):
return run_action(request, config, db, cast(dict, request.query_params)) return run_action(request, config, db, cast(dict, request.query_params))
@app.post("/Action", response_model=None) @app.post("/Action", response_model=None)
async def test(request: Request, config: OpenEC2Config, db: DatabaseDep, _: AWSSignature): async def test(
request: Request, config: OpenEC2Config, db: DatabaseDep, _: AWSSignature
):
query_params = { query_params = {
key: value[0] for key, value in parse_qs((await request.body()).decode()).items() key: value[0]
for key, value in parse_qs((await request.body()).decode()).items()
} }
return run_action(request, config, db, cast(dict, query_params)) return run_action(request, config, db, cast(dict, query_params))
def run_action( def run_action(
request: Request, request: Request,
config: OpenEC2Config, config: OpenEC2Config,
@ -68,7 +76,9 @@ def cloud_init_data(instance_id: str, entry: str, db: DatabaseDep):
raise HTTPException(status_code=404, detail="Unknown cloud-init file") raise HTTPException(status_code=404, detail="Unknown cloud-init file")
if entry == "user-data": if entry == "user-data":
instance = db.exec(select(Instance).where(Instance.id == instance_id)).first()[0] instance = db.exec(select(Instance).where(Instance.id == instance_id)).first()[
0
]
if instance is None: if instance is None:
raise HTTPException(status_code=404, detail="Unknown instance") raise HTTPException(status_code=404, detail="Unknown instance")
@ -79,17 +89,21 @@ def cloud_init_data(instance_id: str, entry: str, db: DatabaseDep):
media_type="application/yaml", media_type="application/yaml",
) )
elif entry == "meta-data": elif entry == "meta-data":
return multiline_yaml_response([ return multiline_yaml_response(
[
f"instance-id: {instance_id}", f"instance-id: {instance_id}",
f"local-hostname: {instance_id}", f"local-hostname: {instance_id}",
]) ]
)
elif entry == "vendor-data": elif entry == "vendor-data":
return multiline_yaml_response([ return multiline_yaml_response(
[
"#cloud-config", "#cloud-config",
"growpart:", "growpart:",
" devices: [/]", " devices: [/]",
" ignore_growroot_disabled: true", " ignore_growroot_disabled: true",
]) ]
)
elif entry == "network-config": elif entry == "network-config":
return Response( return Response(
"", "",

View File

@ -19,6 +19,7 @@ def _hmac_sha256(key: bytes, payload: bytes) -> bytes:
h.update(payload) h.update(payload)
return h.finalize() return h.finalize()
@dataclass @dataclass
class AWSRequest: class AWSRequest:
# The entire used URL # The entire used URL
@ -36,43 +37,57 @@ class AWSRequest:
# The payload, if we used a POST/PUT # The payload, if we used a POST/PUT
payload: str | None payload: str | None
def sign(self, secret_access_key: str, region: str, product: str, credential_scope: str) -> str: def sign(
self, secret_access_key: str, region: str, product: str, credential_scope: str
) -> str:
dt = datetime.datetime.fromisoformat(self.headers["X-Amz-Date"]) dt = datetime.datetime.fromisoformat(self.headers["X-Amz-Date"])
canonical_query_string_keys = sorted(self.params.keys()) canonical_query_string_keys = sorted(self.params.keys())
canonical_query_string = "&".join([ canonical_query_string = "&".join(
f"{key}={quote(self.params[key][0])}" for key in canonical_query_string_keys if key not in ( [
"X-Amz-Signature", f"{key}={quote(self.params[key][0])}"
for key in canonical_query_string_keys
if key not in ("X-Amz-Signature",)
]
)
canonical_header_string_keys = sorted(
[name for name in self.headers.keys() if include_in_canonical_string(name)]
)
canonical_header_string = (
"\n".join(
[
f"{name.lower()}:{self.headers[name].strip()}"
for name in canonical_header_string_keys
]
)
+ "\n"
) )
])
canonical_header_string_keys = sorted([
name for name in self.headers.keys() if include_in_canonical_string(name)
])
canonical_header_string = "\n".join([
f"{name.lower()}:{self.headers[name].strip()}" for name in canonical_header_string_keys
]) + "\n"
signed_headers = ";".join(canonical_header_string_keys) signed_headers = ";".join(canonical_header_string_keys)
hashed_payload = sha256((self.payload or "").encode()).hexdigest() hashed_payload = sha256((self.payload or "").encode()).hexdigest()
canonical_request = "\n".join([ canonical_request = "\n".join(
[
self.method.upper(), self.method.upper(),
self.url.path or "/", self.url.path or "/",
#canonical_query_string, # canonical_query_string,
"", "",
canonical_header_string, canonical_header_string,
signed_headers, signed_headers,
hashed_payload, hashed_payload,
]) ]
)
print("Canonical request") print("Canonical request")
print(canonical_request) print(canonical_request)
hashed_canonical_request = sha256(canonical_request.encode()).hexdigest() hashed_canonical_request = sha256(canonical_request.encode()).hexdigest()
date = dt.strftime("%Y%m%d") date = dt.strftime("%Y%m%d")
string_to_sign = "\n".join([ string_to_sign = "\n".join(
[
"AWS4-HMAC-SHA256", "AWS4-HMAC-SHA256",
dt.strftime("%Y%m%dT%H%M%SZ"), dt.strftime("%Y%m%dT%H%M%SZ"),
credential_scope, credential_scope,
hashed_canonical_request, hashed_canonical_request,
]) ]
)
print("String to sign") print("String to sign")
print(string_to_sign) print(string_to_sign)
@ -83,6 +98,7 @@ class AWSRequest:
signing_key = _hmac_sha256(date_region_service_key, "aws4_request".encode()) signing_key = _hmac_sha256(date_region_service_key, "aws4_request".encode())
return _hmac_sha256(signing_key, string_to_sign.encode()).hex() return _hmac_sha256(signing_key, string_to_sign.encode()).hex()
@dataclass @dataclass
class AWSAuthentication: class AWSAuthentication:
x_amz_algorithm: str x_amz_algorithm: str
@ -91,10 +107,12 @@ class AWSAuthentication:
x_amz_signature: str x_amz_signature: str
def include_in_canonical_string(name: str) -> bool: def include_in_canonical_string(name: str) -> bool:
lower = name.lower() lower = name.lower()
return lower in ("host", "content-type") or lower.startswith("x-amz") return lower in ("host", "content-type") or lower.startswith("x-amz")
def get_authentication_info(request: Request) -> AWSAuthentication: def get_authentication_info(request: Request) -> AWSAuthentication:
if request.method == "POST": if request.method == "POST":
algorithm, rest = request.headers["Authorization"].split(" ", 1) algorithm, rest = request.headers["Authorization"].split(" ", 1)
@ -110,21 +128,30 @@ def get_authentication_info(request: Request) -> AWSAuthentication:
) )
return AWSAuthentication( return AWSAuthentication(
"", "", "", "",
"",
"",
) )
async def check_request_signature(request: Request, db: DatabaseDep): async def check_request_signature(request: Request, db: DatabaseDep):
# Do not check if we don't care # Do not check if we don't care
if ConfigSingleton.of().config.insecure: if ConfigSingleton.of().config.insecure:
return return
body = (await request.body()).decode() body = (await request.body()).decode()
query_params = cast(dict, parse_qs(body)) if request.method == "POST" else cast(dict, request.query_params) query_params = (
cast(dict, parse_qs(body))
if request.method == "POST"
else cast(dict, request.query_params)
)
auth_info = get_authentication_info(request) auth_info = get_authentication_info(request)
if auth_info.x_amz_algorithm != "AWS4-HMAC-SHA256": if auth_info.x_amz_algorithm != "AWS4-HMAC-SHA256":
raise HTTPException(status_code=400, detail=f"Invalid signature algorithm: {x_amz_algorithm}") raise HTTPException(
status_code=400, detail=f"Invalid signature algorithm: {x_amz_algorithm}"
)
x_amz_credential = auth_info.x_amz_credential x_amz_credential = auth_info.x_amz_credential
@ -145,13 +172,12 @@ async def check_request_signature(request: Request, db: DatabaseDep):
user.secret_access_key, user.secret_access_key,
region, region,
service, service,
"/".join([ "/".join([date, region, service, key]),
date, region, service, key
]),
) )
print(x_amz_signature, signature) print(x_amz_signature, signature)
if x_amz_signature != signature: if x_amz_signature != signature:
raise HTTPException(status_code=401) raise HTTPException(status_code=401)
AWSSignature = Annotated[None, Depends(check_request_signature)] AWSSignature = Annotated[None, Depends(check_request_signature)]

View File

@ -10,6 +10,7 @@ def parse_array_objects(prefix: str, params: dict) -> list[dict[str, str]]:
items[parts[1]][".".join(parts[2:])] = value items[parts[1]][".".join(parts[2:])] = value
return list(items.values()) return list(items.values())
def parse_array_plain(prefix: str, params: dict[str, str]) -> list[str]: def parse_array_plain(prefix: str, params: dict[str, str]) -> list[str]:
items: dict[str, str] = {} items: dict[str, str] = {}
for key, value in params.items(): for key, value in params.items():

View File

@ -9,17 +9,19 @@ from openec2.db.instance import Instance
def ipv4_to_int(ip: str) -> int: def ipv4_to_int(ip: str) -> int:
i = 0 i = 0
for idx, p in enumerate(ip.split(".")): for idx, p in enumerate(ip.split(".")):
i += (int(p) << (3-idx)*8) i += int(p) << (3 - idx) * 8
return i return i
def int_to_ipv4(ip: int) -> str: def int_to_ipv4(ip: int) -> str:
parts: list[int] = [] parts: list[int] = []
for i in reversed(range(4)): for i in reversed(range(4)):
parts.append( parts.append(
(ip >> i*8) & 255, (ip >> i * 8) & 255,
) )
return ".".join(str(p) for p in parts) return ".".join(str(p) for p in parts)
def generate_mac() -> str: def generate_mac() -> str:
mac_bytes = random.randbytes(6) mac_bytes = random.randbytes(6)
mac = "" mac = ""
@ -35,10 +37,14 @@ def generate_mac() -> str:
mac += f"{h}:" mac += f"{h}:"
return mac[:-1] return mac[:-1]
def generate_available_mac(db: DatabaseDep) -> str: def generate_available_mac(db: DatabaseDep) -> str:
mac = "" mac = ""
while True: while True:
mac = generate_mac() mac = generate_mac()
if db.exec(select(Instance).where(Instance.interfaceMac == mac)).first() is None: if (
db.exec(select(Instance).where(Instance.interfaceMac == mac)).first()
is None
):
break break
return mac return mac

View File

@ -3,12 +3,17 @@ import subprocess
def create_cow_copy(src: Path, dst: Path, size: str): def create_cow_copy(src: Path, dst: Path, size: str):
subprocess.call([ subprocess.call(
[
"qemu-img", "qemu-img",
"create", "create",
"-f", "qcow2", "-f",
"-b", str(src), "qcow2",
"-F", "qcow2", "-b",
str(src),
"-F",
"qcow2",
str(dst), str(dst),
size, size,
]) ]
)

View File

@ -1,5 +1,6 @@
from fastapi import Response from fastapi import Response
def multiline_yaml_response(lines: list[str]) -> Response: def multiline_yaml_response(lines: list[str]) -> Response:
return Response( return Response(
"\n".join(lines), "\n".join(lines),

View File

@ -14,6 +14,7 @@ def test_array_parsing_keys():
assert parsed[1]["a"] == "3" assert parsed[1]["a"] == "3"
assert parsed[1]["b"] == "4" assert parsed[1]["b"] == "4"
def test_array_plain_parsing(): def test_array_plain_parsing():
params = { params = {
"Key.1": "1", "Key.1": "1",

33
uv.lock
View File

@ -373,6 +373,11 @@ dependencies = [
{ name = "sqlmodel" }, { name = "sqlmodel" },
] ]
[package.dev-dependencies]
dev = [
{ name = "ruff" },
]
[package.metadata] [package.metadata]
requires-dist = [ requires-dist = [
{ name = "cryptography", specifier = ">=44.0.2" }, { name = "cryptography", specifier = ">=44.0.2" },
@ -384,6 +389,9 @@ requires-dist = [
{ name = "sqlmodel", specifier = ">=0.0.24" }, { name = "sqlmodel", specifier = ">=0.0.24" },
] ]
[package.metadata.requires-dev]
dev = [{ name = "ruff", specifier = ">=0.11.2" }]
[[package]] [[package]]
name = "packaging" name = "packaging"
version = "24.2" version = "24.2"
@ -568,6 +576,31 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/70/a2/dc0ae0b61d5fce9eec3763c98d5a471f7b07c891a2cbfb3fd6a0f632a9a1/rich_toolkit-0.14.0-py3-none-any.whl", hash = "sha256:75ff4b3e70e27e9cb145164bfe8d8e56758162fa3f87594067f4d85630b98bf9", size = 24062 }, { url = "https://files.pythonhosted.org/packages/70/a2/dc0ae0b61d5fce9eec3763c98d5a471f7b07c891a2cbfb3fd6a0f632a9a1/rich_toolkit-0.14.0-py3-none-any.whl", hash = "sha256:75ff4b3e70e27e9cb145164bfe8d8e56758162fa3f87594067f4d85630b98bf9", size = 24062 },
] ]
[[package]]
name = "ruff"
version = "0.11.2"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/90/61/fb87430f040e4e577e784e325351186976516faef17d6fcd921fe28edfd7/ruff-0.11.2.tar.gz", hash = "sha256:ec47591497d5a1050175bdf4e1a4e6272cddff7da88a2ad595e1e326041d8d94", size = 3857511 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/62/99/102578506f0f5fa29fd7e0df0a273864f79af044757aef73d1cae0afe6ad/ruff-0.11.2-py3-none-linux_armv6l.whl", hash = "sha256:c69e20ea49e973f3afec2c06376eb56045709f0212615c1adb0eda35e8a4e477", size = 10113146 },
{ url = "https://files.pythonhosted.org/packages/74/ad/5cd4ba58ab602a579997a8494b96f10f316e874d7c435bcc1a92e6da1b12/ruff-0.11.2-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:2c5424cc1c4eb1d8ecabe6d4f1b70470b4f24a0c0171356290b1953ad8f0e272", size = 10867092 },
{ url = "https://files.pythonhosted.org/packages/fc/3e/d3f13619e1d152c7b600a38c1a035e833e794c6625c9a6cea6f63dbf3af4/ruff-0.11.2-py3-none-macosx_11_0_arm64.whl", hash = "sha256:ecf20854cc73f42171eedb66f006a43d0a21bfb98a2523a809931cda569552d9", size = 10224082 },
{ url = "https://files.pythonhosted.org/packages/90/06/f77b3d790d24a93f38e3806216f263974909888fd1e826717c3ec956bbcd/ruff-0.11.2-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0c543bf65d5d27240321604cee0633a70c6c25c9a2f2492efa9f6d4b8e4199bb", size = 10394818 },
{ url = "https://files.pythonhosted.org/packages/99/7f/78aa431d3ddebfc2418cd95b786642557ba8b3cb578c075239da9ce97ff9/ruff-0.11.2-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:20967168cc21195db5830b9224be0e964cc9c8ecf3b5a9e3ce19876e8d3a96e3", size = 9952251 },
{ url = "https://files.pythonhosted.org/packages/30/3e/f11186d1ddfaca438c3bbff73c6a2fdb5b60e6450cc466129c694b0ab7a2/ruff-0.11.2-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:955a9ce63483999d9f0b8f0b4a3ad669e53484232853054cc8b9d51ab4c5de74", size = 11563566 },
{ url = "https://files.pythonhosted.org/packages/22/6c/6ca91befbc0a6539ee133d9a9ce60b1a354db12c3c5d11cfdbf77140f851/ruff-0.11.2-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:86b3a27c38b8fce73bcd262b0de32e9a6801b76d52cdb3ae4c914515f0cef608", size = 12208721 },
{ url = "https://files.pythonhosted.org/packages/19/b0/24516a3b850d55b17c03fc399b681c6a549d06ce665915721dc5d6458a5c/ruff-0.11.2-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a3b66a03b248c9fcd9d64d445bafdf1589326bee6fc5c8e92d7562e58883e30f", size = 11662274 },
{ url = "https://files.pythonhosted.org/packages/d7/65/76be06d28ecb7c6070280cef2bcb20c98fbf99ff60b1c57d2fb9b8771348/ruff-0.11.2-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0397c2672db015be5aa3d4dac54c69aa012429097ff219392c018e21f5085147", size = 13792284 },
{ url = "https://files.pythonhosted.org/packages/ce/d2/4ceed7147e05852876f3b5f3fdc23f878ce2b7e0b90dd6e698bda3d20787/ruff-0.11.2-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:869bcf3f9abf6457fbe39b5a37333aa4eecc52a3b99c98827ccc371a8e5b6f1b", size = 11327861 },
{ url = "https://files.pythonhosted.org/packages/c4/78/4935ecba13706fd60ebe0e3dc50371f2bdc3d9bc80e68adc32ff93914534/ruff-0.11.2-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:2a2b50ca35457ba785cd8c93ebbe529467594087b527a08d487cf0ee7b3087e9", size = 10276560 },
{ url = "https://files.pythonhosted.org/packages/81/7f/1b2435c3f5245d410bb5dc80f13ec796454c21fbda12b77d7588d5cf4e29/ruff-0.11.2-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:7c69c74bf53ddcfbc22e6eb2f31211df7f65054bfc1f72288fc71e5f82db3eab", size = 9945091 },
{ url = "https://files.pythonhosted.org/packages/39/c4/692284c07e6bf2b31d82bb8c32f8840f9d0627d92983edaac991a2b66c0a/ruff-0.11.2-py3-none-musllinux_1_2_i686.whl", hash = "sha256:6e8fb75e14560f7cf53b15bbc55baf5ecbe373dd5f3aab96ff7aa7777edd7630", size = 10977133 },
{ url = "https://files.pythonhosted.org/packages/94/cf/8ab81cb7dd7a3b0a3960c2769825038f3adcd75faf46dd6376086df8b128/ruff-0.11.2-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:842a472d7b4d6f5924e9297aa38149e5dcb1e628773b70e6387ae2c97a63c58f", size = 11378514 },
{ url = "https://files.pythonhosted.org/packages/d9/3a/a647fa4f316482dacf2fd68e8a386327a33d6eabd8eb2f9a0c3d291ec549/ruff-0.11.2-py3-none-win32.whl", hash = "sha256:aca01ccd0eb5eb7156b324cfaa088586f06a86d9e5314b0eb330cb48415097cc", size = 10319835 },
{ url = "https://files.pythonhosted.org/packages/86/54/3c12d3af58012a5e2cd7ebdbe9983f4834af3f8cbea0e8a8c74fa1e23b2b/ruff-0.11.2-py3-none-win_amd64.whl", hash = "sha256:3170150172a8f994136c0c66f494edf199a0bbea7a409f649e4bc8f4d7084080", size = 11373713 },
{ url = "https://files.pythonhosted.org/packages/d6/d4/dd813703af8a1e2ac33bf3feb27e8a5ad514c9f219df80c64d69807e7f71/ruff-0.11.2-py3-none-win_arm64.whl", hash = "sha256:52933095158ff328f4c77af3d74f0379e34fd52f175144cefc1b192e7ccd32b4", size = 10441990 },
]
[[package]] [[package]]
name = "shellingham" name = "shellingham"
version = "1.5.4" version = "1.5.4"