From 6d99b446a04ad9a79e669e1dd4cac4850cb849a3 Mon Sep 17 00:00:00 2001 From: "Alexander \"PapaTutuWawa" Date: Sun, 6 Apr 2025 17:20:26 +0200 Subject: [PATCH] Handle creating/destroying instances from Terraform/tofu --- examples/tofu/main.tf | 43 +++++++++++- src/openec2/actions/attach_volume.py | 4 +- src/openec2/actions/create_tags.py | 2 +- src/openec2/actions/describe_images.py | 47 +++++++++++-- .../actions/describe_instance_attribute.py | 2 +- .../actions/describe_instance_types.py | 5 +- src/openec2/actions/describe_instances.py | 68 +++++++++++++------ src/openec2/actions/import_image.py | 3 + src/openec2/actions/run_instances.py | 2 + src/openec2/actions/start_instances.py | 2 +- src/openec2/actions/stop_instances.py | 2 +- src/openec2/actions/terminate_instances.py | 19 +++--- src/openec2/api/describe_images.py | 10 +-- src/openec2/api/describe_instances.py | 45 ++++-------- src/openec2/api/shared.py | 5 ++ src/openec2/db/image.py | 5 +- src/openec2/db/instance.py | 9 ++- src/openec2/images.py | 2 +- src/openec2/utils/array.py | 8 +++ 19 files changed, 195 insertions(+), 88 deletions(-) diff --git a/examples/tofu/main.tf b/examples/tofu/main.tf index e6e772f..2fd3874 100644 --- a/examples/tofu/main.tf +++ b/examples/tofu/main.tf @@ -11,12 +11,49 @@ provider "aws" { region = "eu-west-1" } -resource "aws_instance" "test-instance" { - ami = "0c4dcaafb6a14dbb93b402f1fd6a9dfb" +# https://geo.mirror.pkgbuild.com/images/v20250315.322357/Arch-Linux-x86_64-cloudimg.qcow2 +# Import using: +# aws ec2 import-image --disk-container "Url=https://geo.mirror.pkgbuild.com/images/v20250315.322357/Arch-Linux-x86_64-cloudimg.qcow2" --tag-specification 'Tags=[{Key="Linux",Value="ArchLinux-nocloud"}]' +data "aws_ami" "archlinux-nocloud" { + filter { + name = "tag:Linux" + values = ["ArchLinux-nocloud"] + } +} + +resource "aws_instance" "test-instance-1" { + ami = data.aws_ami.archlinux-nocloud.id instance_type = "micro" availability_zone = "az-1" + private_ip = "192.168.122.3" + tags = { - TestTag = "TestValue" + UseCase = "k8s-control-plane" } } + +# resource "aws_instance" "test-instance-2" { +# ami = "0c4dcaafb6a14dbb93b402f1fd6a9dfb" +# instance_type = "micro" +# availability_zone = "az-1" + +# private_ip = "192.168.122.4" + +# tags = { +# UseCase = "k8s-control-plane" +# } +# } + + +# resource "aws_instance" "test-instance-3" { +# ami = "0c4dcaafb6a14dbb93b402f1fd6a9dfb" +# instance_type = "micro" +# availability_zone = "az-1" + +# private_ip = "192.168.122.5" + +# tags = { +# UseCase = "k8s-control-plane" +# } +# } diff --git a/src/openec2/actions/attach_volume.py b/src/openec2/actions/attach_volume.py index 107bb60..1109581 100644 --- a/src/openec2/actions/attach_volume.py +++ b/src/openec2/actions/attach_volume.py @@ -26,10 +26,10 @@ def attach_volume( instance_id = params["InstanceId"] volume_id = params["VolumeId"] - volume = db.exec(select(EBSVolume).where(EBSVolume.id == volume_id)).first() + volume = db.exec(select(EBSVolume).where(EBSVolume.id == volume_id, Instance.owner_id == user.id)).first() if volume is None: return - instance = db.exec(select(Instance).where(Instance.id == instance_id)).first() + instance = db.exec(select(Instance).where(Instance.id == instance_id, Instance.owner_id == user.id, Instance.terminated == False)).first() if instance is None: return diff --git a/src/openec2/actions/create_tags.py b/src/openec2/actions/create_tags.py index decf99e..e3dcad2 100644 --- a/src/openec2/actions/create_tags.py +++ b/src/openec2/actions/create_tags.py @@ -31,7 +31,7 @@ def create_tags( tags[tag["Key"]] = tag["Value"] for instance_id in parse_array_plain("ResourceId", cast(dict, params)): - instance = db.exec(select(Instance).where(Instance.id == instance_id, Instance.owner_id == user.id)).first() + instance = db.exec(select(Instance).where(Instance.id == instance_id, Instance.owner_id == user.id, Instance.terminated == False)).first() if instance is None: print(f"Unknown instance {instance_id}") continue diff --git a/src/openec2/actions/describe_images.py b/src/openec2/actions/describe_images.py index 1df32ba..3c566c7 100644 --- a/src/openec2/actions/describe_images.py +++ b/src/openec2/actions/describe_images.py @@ -1,4 +1,6 @@ import uuid +from typing import cast +from dataclasses import dataclass from fastapi import Response from fastapi.datastructures import QueryParams @@ -8,31 +10,64 @@ from openec2.config import OpenEC2Config from openec2.db import DatabaseDep from openec2.db.user import User from openec2.db.image import AMI -from openec2.api.describe_images import DescribeImagesResponse, ImagesSet, Image +from openec2.api.describe_images import DescribeImagesResponse, Image +from openec2.api.shared import Tag +from openec2.utils.array import parse_array_objects, parse_array_plain + + +@dataclass +class Filter: + name: str + values: list[str] + + def match(self, image: AMI) -> bool: + value: str | None + if self.name.startswith("tag:"): + value = image.tags.get(self.name.replace("tag:", "")) + else: + raise Exception(f"Unknown filter name {self.name}") + + return value in self.values def describe_images( params: QueryParams, config: OpenEC2Config, db: DatabaseDep, - _: User, + user: User, ): + filters: list[Filter] = [] + for filter in parse_array_objects("Filter", cast(dict, params)): + filters.append( + Filter( + name=filter["Name"], + values=parse_array_plain("Value", filter), + ) + ) + images: list[Image] = [] - for ami in db.exec(select(AMI)).all(): + for ami in db.exec(select(AMI).where(AMI.owner_id == user.id)).all(): + if not all(f.match(ami) for f in filters): + continue + images.append( Image( imageId=ami.id, imageState="available", name=ami.originalFilename, + tagSet=[ + Tag( + key=key, + value=value, + ) for key, value in ami.tags.items() + ], ), ) return Response( DescribeImagesResponse( requestId=uuid.uuid4().hex, - imagesSet=ImagesSet( - items=images, - ), + imagesSet=images, ).to_xml(), media_type="application/xml", ) diff --git a/src/openec2/actions/describe_instance_attribute.py b/src/openec2/actions/describe_instance_attribute.py index 03dca47..195965c 100644 --- a/src/openec2/actions/describe_instance_attribute.py +++ b/src/openec2/actions/describe_instance_attribute.py @@ -19,7 +19,7 @@ def describe_instance_attribute( ): instance_id = params["InstanceId"] attribute = params["Attribute"] - instance = db.exec(select(Instance).where(Instance.id == instance_id, Instance.owner_id == user.id)).first() + instance = db.exec(select(Instance).where(Instance.id == instance_id, Instance.owner_id == user.id, Instance.terminated == False)).first() if instance is None: raise HTTPException( status_code=404, diff --git a/src/openec2/actions/describe_instance_types.py b/src/openec2/actions/describe_instance_types.py index 33820cc..0c9cfd4 100644 --- a/src/openec2/actions/describe_instance_types.py +++ b/src/openec2/actions/describe_instance_types.py @@ -18,11 +18,10 @@ def describe_instance_types( user: User, ): response: list[InstanceTypeInfo] = [] - instances = db.exec(select(Instance).where(Instance.owner_id == user.id)).all() - for instance in instances: + for name, instanceConfig in config.instances.types.items(): response.append( InstanceTypeInfo( - instanceType=instance.instanceType, + instanceType=name, ), ) diff --git a/src/openec2/actions/describe_instances.py b/src/openec2/actions/describe_instances.py index e11f0ac..e0958ed 100644 --- a/src/openec2/actions/describe_instances.py +++ b/src/openec2/actions/describe_instances.py @@ -1,23 +1,25 @@ import uuid +from typing import cast +import datetime from fastapi import Response from fastapi.datastructures import QueryParams -from sqlmodel import select +from sqlmodel import select, or_ from openec2.libvirt import LibvirtSingleton from openec2.api.describe_instances import ( InstanceDescription, DescribeInstancesResponse, - DescribeInstancesResponseReservationSet, ReservationSet, - ReservationSetInstancesSet, + InstanceState, describe_instance, ) +from openec2.api.shared import Tag from openec2.db.user import User -from openec2.api.shared import InstanceState from openec2.config import OpenEC2Config from openec2.db import DatabaseDep from openec2.db.instance import Instance +from openec2.utils.array import parse_array_plain def describe_instances( @@ -26,12 +28,42 @@ def describe_instances( db: DatabaseDep, user: User, ): + if "InstanceId.1" in params: + instance_ids = parse_array_plain("InstanceId", cast(dict, params)) + instance_expr = [Instance.id == instance_id for instance_id in instance_ids] + instances = db.exec( + select(Instance).where(Instance.owner_id == user.id, or_(*instance_expr)), + ).all() + else: + instances = db.exec( + select(Instance).where(Instance.owner_id == user.id), + ).all() + response_items: list[InstanceDescription] = [] conn = LibvirtSingleton.of().connection - for instance in db.exec(select(Instance)).all(): - # Check for permission issues - if instance.owner_id != user.id: - # TODO: Add the error to the response + now = datetime.datetime.now() + for instance in instances: + # Include terminated instances for an hour + if instance.terminated: + assert instance.terminationDate is not None + if now <= instance.terminationDate + datetime.timedelta(hours=1): + response_items.append( + InstanceDescription( + instanceId=instance.id, + imageId=instance.imageId, + instanceState=InstanceState( + code=48, + name="terminated", + ), + tagSet=[ + Tag( + key=key, + value=value, + ) + for key, value in instance.tags.items() + ], + ), + ) continue dom = conn.lookupByName(instance.id) @@ -41,18 +73,14 @@ def describe_instances( return Response( DescribeInstancesResponse( - request_id=uuid.uuid4().hex, - reservation_set=DescribeInstancesResponseReservationSet( - item=[ - ReservationSet( - reservationId=instance.instance_id, - instancesSet=ReservationSetInstancesSet( - item=[instance], - ), - ) - for instance in response_items - ], - ), + requestId=uuid.uuid4().hex, + reservationSet=[ + ReservationSet( + reservationId=instance.instanceId, + instancesSet=[instance], + ) + for instance in response_items + ], ).to_xml(), media_type="application/xml", ) diff --git a/src/openec2/actions/import_image.py b/src/openec2/actions/import_image.py index ee4d4c6..4d0ee3e 100644 --- a/src/openec2/actions/import_image.py +++ b/src/openec2/actions/import_image.py @@ -12,6 +12,7 @@ from openec2.config import OpenEC2Config, ConfigSingleton from openec2.db import DatabaseDep from openec2.db.user import User from openec2.db.image import AMI +from openec2.utils.array import parse_tag_specification def import_image( @@ -24,6 +25,7 @@ def import_image( url = urlparse(first_disk_image_url) ami_id = uuid.uuid4().hex + tags: dict[str, str] = parse_tag_specification(cast(dict, params)) imageLocation = cast(Path, config.images) imageLocation.mkdir(exist_ok=True) dst = imageLocation / ami_id @@ -54,6 +56,7 @@ def import_image( description=None, originalFilename=filename, owner_id=user.id, + tags=tags, ), ) db.commit() diff --git a/src/openec2/actions/run_instances.py b/src/openec2/actions/run_instances.py index 14c30f3..8d45c1e 100644 --- a/src/openec2/actions/run_instances.py +++ b/src/openec2/actions/run_instances.py @@ -155,6 +155,8 @@ def run_instances( interfaceMac=mac, instanceType=instance_type_name, owner_id=user.id, + terminated=False, + terminationDate=None, ) db.add(instance) print("Inserted new instance") diff --git a/src/openec2/actions/start_instances.py b/src/openec2/actions/start_instances.py index 56d09c9..b40bb7e 100644 --- a/src/openec2/actions/start_instances.py +++ b/src/openec2/actions/start_instances.py @@ -24,7 +24,7 @@ def start_instances( conn = LibvirtSingleton.of().connection instances: list[InstanceInfo] = [] for instance_id in parse_array_plain("InstanceId", params): - instance = db.exec(select(Instance).where(Instance.id == instance_id)).first() + instance = db.exec(select(Instance).where(Instance.id == instance_id, Instance.terminated == False)).first() if instance is None: raise HTTPException(status_code=404, detail="Unknown instance") diff --git a/src/openec2/actions/stop_instances.py b/src/openec2/actions/stop_instances.py index 039f38c..f884602 100644 --- a/src/openec2/actions/stop_instances.py +++ b/src/openec2/actions/stop_instances.py @@ -24,7 +24,7 @@ def stop_instances( conn = LibvirtSingleton.of().connection instances: list[InstanceInfo] = [] for instance_id in parse_array_plain("InstanceId", params): - instance = db.exec(select(Instance).where(Instance.id == instance_id)).first() + instance = db.exec(select(Instance).where(Instance.id == instance_id, Instance.terminated == False)).first() if instance is None: raise HTTPException(status_code=404, detail="Unknown instance") diff --git a/src/openec2/actions/terminate_instances.py b/src/openec2/actions/terminate_instances.py index 0fcc8e6..6839cb3 100644 --- a/src/openec2/actions/terminate_instances.py +++ b/src/openec2/actions/terminate_instances.py @@ -1,7 +1,9 @@ import logging +from typing import cast import uuid +import datetime -from fastapi import HTTPException, Response +from fastapi import Response from fastapi.datastructures import QueryParams from sqlmodel import select @@ -31,16 +33,11 @@ def terminate_instances( conn = LibvirtSingleton.of().connection image_ids: set[str] = set() for instance_id in parse_array_plain("InstanceId", params): - instance = db.exec(select(Instance).where(Instance.id == instance_id)).first() + instance = db.exec(select(Instance).where(Instance.id == instance_id, Instance.owner_id == user.id)).first() if instance is None: continue # raise HTTPException(status_code=404, detail="Unknown instance") - # Check for permission issues - if instance.owner_id != user.id: - # TODO: Add the error to the response - continue - dom = conn.lookupByName(instance_id) prev_state = describe_instance_state(dom) if dom.isActive(): @@ -64,11 +61,15 @@ def terminate_instances( for volume in instance.ebs_volumes: volume.instances.remove(instance) - image_ids.add(instance.imageId) + + image_ids.add(cast(str, instance.imageId)) remove_instance_dhcp_mapping( instance.id, instance.interfaceMac, instance.privateIPv4, db ) - db.delete(instance) + + # Mark the instance as terminated + instance.terminated = True + instance.terminationDate = datetime.datetime.now() db.commit() diff --git a/src/openec2/api/describe_images.py b/src/openec2/api/describe_images.py index 950ad0f..6f74363 100644 --- a/src/openec2/api/describe_images.py +++ b/src/openec2/api/describe_images.py @@ -1,4 +1,6 @@ -from pydantic_xml import BaseXmlModel, element +from pydantic_xml import BaseXmlModel, wrapped, element + +from openec2.api.shared import Tag class Image(BaseXmlModel): @@ -8,9 +10,7 @@ class Image(BaseXmlModel): name: str = element() - -class ImagesSet(BaseXmlModel, tag="imagesSet"): - items: list[Image] = element(tag="item") + tagSet: list[Tag] = wrapped("tagSet", element(tag="item")) class DescribeImagesResponse( @@ -20,4 +20,4 @@ class DescribeImagesResponse( ): requestId: str = element() - imagesSet: ImagesSet = element() + imagesSet: list[Image] = wrapped("imagesSet", element(tag="item")) diff --git a/src/openec2/api/describe_instances.py b/src/openec2/api/describe_instances.py index e259a5e..73e1139 100644 --- a/src/openec2/api/describe_instances.py +++ b/src/openec2/api/describe_instances.py @@ -2,42 +2,21 @@ from pydantic_xml import BaseXmlModel, wrapped, element import libvirt from openec2.db.instance import Instance -from openec2.api.shared import InstanceState - - -class InstanceTag(BaseXmlModel): - key: str = element() - value: str = element() - - -class InstanceTagSet(BaseXmlModel): - item: list[InstanceTag] = element() - +from openec2.api.shared import InstanceState, Tag class InstanceDescription( BaseXmlModel, tag="item", ): - instance_id: str = element(tag="instanceId") - image_id: str = element(tag="imageId") - instance_state: InstanceState = element(tag="instanceState") - tagSet: list[InstanceTag] = wrapped("tagSet", element(tag="item")) - - -class ReservationSetInstancesSet(BaseXmlModel): - item: list[InstanceDescription] = element() + instanceId: str = element() + imageId: str = element() + instanceState: InstanceState = element() + tagSet: list[Tag] = wrapped("tagSet", element(tag="item")) class ReservationSet(BaseXmlModel): reservationId: str = element() - instancesSet: ReservationSetInstancesSet = element() - - -class DescribeInstancesResponseReservationSet( - BaseXmlModel, - tag="reservationSet", -): - item: list[ReservationSet] = element("") + instancesSet: list[InstanceDescription] = wrapped("instancesSet", element(tag="item")) class DescribeInstancesResponse( @@ -45,8 +24,8 @@ class DescribeInstancesResponse( tag="DescribeInstancesResponse", nsmap={"": "http://ec2.amazonaws.com/doc/2016-11-15/"}, ): - request_id: str = element(tag="requestId") - reservation_set: DescribeInstancesResponseReservationSet = element("reservationSet") + requestId: str = element() + reservationSet: list[ReservationSet] = wrapped("reservationSet", element("item")) def describe_instance_state(domain: libvirt.virDomain) -> InstanceState: @@ -61,11 +40,11 @@ def describe_instance( instance: Instance, domain: libvirt.virDomain ) -> InstanceDescription: return InstanceDescription( - instance_id=instance.id, - image_id=instance.imageId, - instance_state=describe_instance_state(domain), + instanceId=instance.id, + imageId=instance.imageId, + instanceState=describe_instance_state(domain), tagSet=[ - InstanceTag( + Tag( key=key, value=value, ) diff --git a/src/openec2/api/shared.py b/src/openec2/api/shared.py index 9782c01..6d83822 100644 --- a/src/openec2/api/shared.py +++ b/src/openec2/api/shared.py @@ -14,3 +14,8 @@ class InstanceInfo(BaseXmlModel): class InstancesSet(BaseXmlModel, tag="instancesSet"): item: list[InstanceInfo] = element() + + +class Tag(BaseXmlModel): + key: str = element() + value: str = element() diff --git a/src/openec2/db/image.py b/src/openec2/db/image.py index 17374f6..a00a77e 100644 --- a/src/openec2/db/image.py +++ b/src/openec2/db/image.py @@ -1,4 +1,4 @@ -from sqlmodel import SQLModel, Field +from sqlmodel import SQLModel, Field, Column, JSON class AMI(SQLModel, table=True): @@ -16,3 +16,6 @@ class AMI(SQLModel, table=True): # Owner of the image who created it owner_id: int = Field(foreign_key="user.id") + + # Tags associated with the AMI + tags: dict = Field(sa_column=Column(JSON), default={}) diff --git a/src/openec2/db/instance.py b/src/openec2/db/instance.py index 7e7a15b..89c81c3 100644 --- a/src/openec2/db/instance.py +++ b/src/openec2/db/instance.py @@ -1,4 +1,5 @@ from pathlib import Path +from datetime import datetime from sqlmodel import SQLModel, Field, JSON, Column, Relationship @@ -40,7 +41,7 @@ class Instance(SQLModel, table=True): instanceType: str - # ImageID of the used AMI + # ImageID of the used AMI. None only if terminated == True. imageId: str # Optional user data associated with the VM @@ -59,3 +60,9 @@ class Instance(SQLModel, table=True): ebs_volumes: list[EBSVolume] = Relationship( back_populates="instances", link_model=EBSVolumeInstanceLink ) + + # Is the instance terminated + terminated: bool + + # Date at which the instance got terminated + terminationDate: datetime | None diff --git a/src/openec2/images.py b/src/openec2/images.py index 8d221d4..a9abfb9 100644 --- a/src/openec2/images.py +++ b/src/openec2/images.py @@ -7,7 +7,7 @@ from openec2.db.image import AMI def garbage_collect_image(image_id: str, db: DatabaseDep): - instances = db.exec(select(Instance).where(Instance.imageId == image_id)).all() + instances = db.exec(select(Instance).where(Instance.imageId == image_id, Instance.terminated == False)).all() if instances: print("Instances sill using AMI. Not cleaning up") print(instances) diff --git a/src/openec2/utils/array.py b/src/openec2/utils/array.py index cab6a6a..ac02209 100644 --- a/src/openec2/utils/array.py +++ b/src/openec2/utils/array.py @@ -35,3 +35,11 @@ def find[T](l: list[T], pred: Callable[[T], bool]) -> T | None: if pred(item): return item return None + + +def parse_tag_specification(params: dict) -> dict[str, str]: + tags: dict[str, str] = {} + for spec in parse_array_objects("TagSpecification", params): + for raw_tag in parse_array_objects("Tag", spec): + tags[raw_tag["Key"]] = raw_tag["Value"] + return tags