diff --git a/src/openec2/actions/deregister_image.py b/src/openec2/actions/deregister_image.py new file mode 100644 index 0000000..8bebcf7 --- /dev/null +++ b/src/openec2/actions/deregister_image.py @@ -0,0 +1,26 @@ +from fastapi import HTTPException +from fastapi.datastructures import QueryParams +from sqlmodel import select + +from openec2.config import OpenEC2Config +from openec2.db import DatabaseDep +from openec2.db.image import AMI +from openec2.images import garbage_collect_image + +def deregister_image( + params: QueryParams, + config: OpenEC2Config, + db: DatabaseDep, +): + image_id = params["ImageId"] + ami = db.exec(select(AMI).where(AMI.id == image_id)).one() + if ami is None: + raise HTTPException(status_code=404, detail="Unknown AMI") + + # Mark the image as deregistered + ami.deregistered = True + db.add(ami) + db.commit() + + # First round of garbage collection + garbage_collect_image(image_id, db) diff --git a/src/openec2/actions/run_instances.py b/src/openec2/actions/run_instances.py index 6f2f531..27657a5 100644 --- a/src/openec2/actions/run_instances.py +++ b/src/openec2/actions/run_instances.py @@ -4,23 +4,28 @@ from typing import cast import uuid import os +from fastapi import HTTPException from fastapi.datastructures import QueryParams from sqlmodel import select from openec2.libvirt import LibvirtSingleton from openec2.config import OpenEC2Config +from openec2.utils.qemu import create_cow_copy from openec2.db import DatabaseDep from openec2.db.instance import Instance from openec2.db.image import AMI from openec2.api.run_instances import RunInstanceResponse, RunInstanceInstanceSet from openec2.api.describe_instances import describe_instance from openec2.utils.array import parse_array_objects +from openec2.ipam import get_available_ipv4, is_ipv4_available, add_instance_dhcp_mapping +from openec2.utils.ip import generate_available_mac def create_libvirt_domain( name: str, memory: int, vcpu: int, ami_path: str, + mac: str, user_data: str | None, ) -> str: return f""" @@ -70,7 +75,7 @@ def create_libvirt_domain( - + @@ -93,31 +98,51 @@ def run_instances( if ami is None: raise Exception(f"Unknown AMI {image_id}") + if ami.deregistered: + raise HTTPException( + status_code=400, + detail="AMI is deregistered and cannot be used anymore", + ) + # Parse tags + # TODO: broken tags: dict[str, str] = {} for spec in parse_array_objects("TagSpecification", cast(dict, params)): for raw_tag in parse_array_objects("Tag", spec): tags[raw_tag["Key"]] = raw_tag["Value"] - # Prepare the instance directory + # Get a private IPv4 instance_id = uuid.uuid4().hex + private_ipv4 = params.get( + "PrivateIpAddress", + get_available_ipv4(db), + ) + if not is_ipv4_available(private_ipv4, db): + raise HTTPException( + status_code=400, + detail="Used IPv4", + ) + mac = generate_available_mac(db) + add_instance_dhcp_mapping(instance_id, mac, private_ipv4, db) + + # Prepare the instance directory config.instances.location.mkdir(exist_ok=True) disk = config.instances.location / instance_id - shutil.copy( - str(config.images / ami.id), - str(disk), + create_cow_copy( + config.images / ami.id, + disk, + f"{instance_type.disk}G", ) - os.system(f"qemu-img resize {disk} {instance_type.disk}G") instance = Instance( id=instance_id, imageId=image_id, tags=tags, userData=base64.b64decode(value).decode() if (value := params.get("UserData")) is not None else None, + privateIPv4=private_ipv4, + interfaceMac=mac, ) db.add(instance) - db.flush() - db.commit() print("Inserted new instance") conn = LibvirtSingleton.of().connection @@ -127,12 +152,14 @@ def run_instances( instance_type.memory, int(instance_type.vcpu), str(config.instances.location / instance_id), + mac, None, ), ) domain.create() description = describe_instance(instance, domain) + db.commit() return RunInstanceResponse( request_id=uuid.uuid4().hex, instance_set=RunInstanceInstanceSet( diff --git a/src/openec2/actions/terminate_instances.py b/src/openec2/actions/terminate_instances.py index 4eb15c5..5b07371 100644 --- a/src/openec2/actions/terminate_instances.py +++ b/src/openec2/actions/terminate_instances.py @@ -9,6 +9,8 @@ from openec2.config import OpenEC2Config from openec2.db import DatabaseDep from openec2.db.instance import Instance from openec2.utils.array import parse_array_plain +from openec2.images import garbage_collect_image +from openec2.ipam import remove_instance_dhcp_mapping logger = logging.getLogger() @@ -19,6 +21,7 @@ def terminate_instances( db: DatabaseDep, ): conn = LibvirtSingleton.of().connection + image_ids: set[str] = set() for instance_id in parse_array_plain("InstanceId", params): instance = db.exec(select(Instance).where(Instance.id == instance_id)).first() if instance is None: @@ -34,7 +37,14 @@ def terminate_instances( instance_disk = config.instances.location / instance_id instance_disk.unlink() + image_ids.add(instance.imageId) + remove_instance_dhcp_mapping(instance.id, instance.interfaceMac, instance.privateIPv4, db) db.delete(instance) - db.commit() + + db.commit() + + # Garbage collect AMIs + for image_id in image_ids: + garbage_collect_image(image_id, db) return "OK" diff --git a/src/openec2/db/image.py b/src/openec2/db/image.py index 4b05004..60be1fc 100644 --- a/src/openec2/db/image.py +++ b/src/openec2/db/image.py @@ -1,8 +1,14 @@ from sqlmodel import SQLModel, Field class AMI(SQLModel, table=True): + # ID of the AMI id: str = Field(default=None, primary_key=True) + # Description of the image description: str | None = None + # Filename that got imported originalFilename: str + + # Was the image registered + deregistered: bool = Field(default=False) diff --git a/src/openec2/db/instance.py b/src/openec2/db/instance.py index 3d440e2..a6428e6 100644 --- a/src/openec2/db/instance.py +++ b/src/openec2/db/instance.py @@ -11,3 +11,9 @@ class Instance(SQLModel, table=True): # Optional user data associated with the VM userData: str | None + + # MAC of the network interface + interfaceMac: str + + # Private IPv4 of the instance + privateIPv4: str diff --git a/src/openec2/db/ipam.py b/src/openec2/db/ipam.py new file mode 100644 index 0000000..b528291 --- /dev/null +++ b/src/openec2/db/ipam.py @@ -0,0 +1,24 @@ +from sqlmodel import SQLModel, Field, PrimaryKeyConstraint + +from openec2.utils.ip import int_to_ipv4, ipv4_to_int + + +class IPAMEntry(SQLModel, table=True): + # IP Address + ipv4_addr_raw: int = Field(primary_key=True) + + # Instance this IP is assigned to + instance_id: str = Field(primary_key=True) + + # VPC ID + vpc_id: str + + def ipv4(self) -> str: + return int_to_ipv4(self.ipv4_addr_raw) + + def set_ipv4(self, addr: str): + self.ipv4_addr_raw = ipv4_to_int(addr) + + __table_args = ( + PrimaryKeyConstraint("ipv4_addr_raw", "vpc_id"), + ) diff --git a/src/openec2/db/vpc.py b/src/openec2/db/vpc.py new file mode 100644 index 0000000..2253c9e --- /dev/null +++ b/src/openec2/db/vpc.py @@ -0,0 +1,11 @@ +from sqlmodel import SQLModel, Field, PrimaryKeyConstraint + +class VPC(SQLModel, table=True): + # ID of the VPC + id: str = Field(default=None, primary_key=True) + + # Subnet mask + subnet: str + + # Base IPv4 + ipv4_base: str diff --git a/src/openec2/images.py b/src/openec2/images.py new file mode 100644 index 0000000..6c804ee --- /dev/null +++ b/src/openec2/images.py @@ -0,0 +1,22 @@ +from sqlmodel import select + +from openec2.config import ConfigSingleton +from openec2.db import DatabaseDep +from openec2.db.instance import Instance +from openec2.db.image import AMI + + +def garbage_collect_image(image_id: str, db: DatabaseDep): + instances = db.exec(select(Instance).where(Instance.imageId == image_id)).all() + if instances: + print("Instances sill using AMI. Not cleaning up") + print(instances) + return + + ami = db.exec(select(AMI).where(AMI.id == image_id, AMI.deregistered == True)).first() + if ami is not None: + db.delete(ami) + db.commit() + image = ConfigSingleton.of().config.images / image_id + image.unlink() + print(f"Removing {image}") diff --git a/src/openec2/ipam.py b/src/openec2/ipam.py new file mode 100644 index 0000000..8cb6539 --- /dev/null +++ b/src/openec2/ipam.py @@ -0,0 +1,61 @@ +from sqlmodel import select +import libvirt + +from openec2.libvirt import LibvirtSingleton +from openec2.db import DatabaseDep +from openec2.db.ipam import IPAMEntry +from openec2.utils.ip import ipv4_to_int, int_to_ipv4 + + +def _libvirt_host_update(instance_id: str, mac: str, ipv4: str) -> str: + return f"" + + +def add_instance_dhcp_mapping(instance_id: str, mac: str, ipv4: str, db: DatabaseDep): + """ + Adds a DHCP entry for the network to give the instance a static + private IPv4 address. + """ + entry = IPAMEntry( + ipv4_addr_raw=ipv4_to_int(ipv4), + instance_id=instance_id, + # TODO + vpc_id="default", + ) + db.add(entry) + + # Tell libvirt about this mapping + conn = LibvirtSingleton.of().connection + conn.networkLookupByName("default").update( + libvirt.VIR_NETWORK_UPDATE_COMMAND_ADD_LAST, + libvirt.VIR_NETWORK_SECTION_IP_DHCP_HOST, + 0, + _libvirt_host_update(instance_id, mac, ipv4), + flags=libvirt.VIR_NETWORK_UPDATE_AFFECT_LIVE, + ) + +def remove_instance_dhcp_mapping(instance_id: str, mac: str ,ipv4: str, db: DatabaseDep): + i = ipv4_to_int(ipv4) + entry = db.exec(select(IPAMEntry).where(IPAMEntry.ipv4_addr_raw == i, IPAMEntry.instance_id == instance_id)).first() + db.delete(entry) + + # Tell libvirt about this mapping + conn = LibvirtSingleton.of().connection + conn.networkLookupByName("default").update( + libvirt.VIR_NETWORK_UPDATE_COMMAND_DELETE, + libvirt.VIR_NETWORK_SECTION_IP_DHCP_HOST, + 0, + _libvirt_host_update(instance_id, mac, ipv4), + flags=libvirt.VIR_NETWORK_UPDATE_AFFECT_LIVE, + ) + +def is_ipv4_available(ipv4: str, db: DatabaseDep) -> bool: + i = ipv4_to_int(ipv4) + return db.exec(select(IPAMEntry).where(IPAMEntry.ipv4_addr_raw == i)).first() is None + +def get_available_ipv4(db: DatabaseDep) -> str: + entries = db.exec(select(IPAMEntry)).all() + # TODO: Use the VPC's subnet + max_ip = max(e.ipv4_addr_raw for e in entries) if entries else ipv4_to_int("192.168.122.2") + # TODO: Check if we're still inside the subnet + return int_to_ipv4(max_ip + 1) diff --git a/src/openec2/main.py b/src/openec2/main.py index c6e30ef..40a50d1 100644 --- a/src/openec2/main.py +++ b/src/openec2/main.py @@ -12,6 +12,7 @@ from openec2.actions.run_instances import run_instances from openec2.actions.terminate_instances import terminate_instances from openec2.actions.start_instances import start_instances from openec2.actions.stop_instances import stop_instances +from openec2.actions.deregister_image import deregister_image from openec2.db.instance import Instance app = FastAPI() @@ -37,6 +38,7 @@ def action(request: Request, config: OpenEC2Config, db: DatabaseDep): "TerminateInstances": terminate_instances, "StartInstances": start_instances, "StopInstances": stop_instances, + "DeregisterImage": deregister_image, }[action](request.query_params, config, db) @app.get("/private/cloudinit/{instance_id}/{entry}") diff --git a/src/openec2/utils/ip.py b/src/openec2/utils/ip.py new file mode 100644 index 0000000..d12f0b3 --- /dev/null +++ b/src/openec2/utils/ip.py @@ -0,0 +1,44 @@ +import random + +from sqlmodel import select + +from openec2.db import DatabaseDep +from openec2.db.instance import Instance + + +def ipv4_to_int(ip: str) -> int: + i = 0 + for idx, p in enumerate(ip.split(".")): + i += (int(p) << (3-idx)*8) + return i + +def int_to_ipv4(ip: int) -> str: + parts: list[int] = [] + for i in reversed(range(4)): + parts.append( + (ip >> i*8) & 255, + ) + return ".".join(str(p) for p in parts) + +def generate_mac() -> str: + mac_bytes = random.randbytes(6) + mac = "" + for idx, b in enumerate(mac_bytes): + # Ensure we have a unicast MAC + if idx == 0: + b = b & (255 - 1) + + h = hex(b)[2:] + if len(h) == 1: + mac += f"0{h}:" + else: + mac += f"{h}:" + return mac[:-1] + +def generate_available_mac(db: DatabaseDep) -> str: + mac = "" + while True: + mac = generate_mac() + if db.exec(select(Instance).where(Instance.interfaceMac == mac)).first() is None: + break + return mac diff --git a/src/openec2/utils/qemu.py b/src/openec2/utils/qemu.py new file mode 100644 index 0000000..542a0e6 --- /dev/null +++ b/src/openec2/utils/qemu.py @@ -0,0 +1,14 @@ +from pathlib import Path +import subprocess + + +def create_cow_copy(src: Path, dst: Path, size: str): + subprocess.call([ + "qemu-img", + "create", + "-f", "qcow2", + "-b", str(src), + "-F", "qcow2", + str(dst), + size, + ]) diff --git a/tests/openec2/test_ip.py b/tests/openec2/test_ip.py new file mode 100644 index 0000000..54678c6 --- /dev/null +++ b/tests/openec2/test_ip.py @@ -0,0 +1,8 @@ +from openec2.utils.ip import ipv4_to_int, int_to_ipv4 + + +def test_idempotent(): + ip = "127.0.0.1" + ip_int = ipv4_to_int(ip) + print(ip_int) + assert int_to_ipv4(ip_int) == ip