Handle creating/destroying instances from Terraform/tofu

This commit is contained in:
PapaTutuWawa 2025-04-06 17:20:26 +02:00
parent 38d37a7d5b
commit 6d99b446a0
19 changed files with 195 additions and 88 deletions

View File

@ -11,12 +11,49 @@ provider "aws" {
region = "eu-west-1" region = "eu-west-1"
} }
resource "aws_instance" "test-instance" { # https://geo.mirror.pkgbuild.com/images/v20250315.322357/Arch-Linux-x86_64-cloudimg.qcow2
ami = "0c4dcaafb6a14dbb93b402f1fd6a9dfb" # Import using:
# aws ec2 import-image --disk-container "Url=https://geo.mirror.pkgbuild.com/images/v20250315.322357/Arch-Linux-x86_64-cloudimg.qcow2" --tag-specification 'Tags=[{Key="Linux",Value="ArchLinux-nocloud"}]'
data "aws_ami" "archlinux-nocloud" {
filter {
name = "tag:Linux"
values = ["ArchLinux-nocloud"]
}
}
resource "aws_instance" "test-instance-1" {
ami = data.aws_ami.archlinux-nocloud.id
instance_type = "micro" instance_type = "micro"
availability_zone = "az-1" availability_zone = "az-1"
private_ip = "192.168.122.3"
tags = { tags = {
TestTag = "TestValue" UseCase = "k8s-control-plane"
} }
} }
# resource "aws_instance" "test-instance-2" {
# ami = "0c4dcaafb6a14dbb93b402f1fd6a9dfb"
# instance_type = "micro"
# availability_zone = "az-1"
# private_ip = "192.168.122.4"
# tags = {
# UseCase = "k8s-control-plane"
# }
# }
# resource "aws_instance" "test-instance-3" {
# ami = "0c4dcaafb6a14dbb93b402f1fd6a9dfb"
# instance_type = "micro"
# availability_zone = "az-1"
# private_ip = "192.168.122.5"
# tags = {
# UseCase = "k8s-control-plane"
# }
# }

View File

@ -26,10 +26,10 @@ def attach_volume(
instance_id = params["InstanceId"] instance_id = params["InstanceId"]
volume_id = params["VolumeId"] volume_id = params["VolumeId"]
volume = db.exec(select(EBSVolume).where(EBSVolume.id == volume_id)).first() volume = db.exec(select(EBSVolume).where(EBSVolume.id == volume_id, Instance.owner_id == user.id)).first()
if volume is None: if volume is None:
return return
instance = db.exec(select(Instance).where(Instance.id == instance_id)).first() instance = db.exec(select(Instance).where(Instance.id == instance_id, Instance.owner_id == user.id, Instance.terminated == False)).first()
if instance is None: if instance is None:
return return

View File

@ -31,7 +31,7 @@ def create_tags(
tags[tag["Key"]] = tag["Value"] tags[tag["Key"]] = tag["Value"]
for instance_id in parse_array_plain("ResourceId", cast(dict, params)): for instance_id in parse_array_plain("ResourceId", cast(dict, params)):
instance = db.exec(select(Instance).where(Instance.id == instance_id, Instance.owner_id == user.id)).first() instance = db.exec(select(Instance).where(Instance.id == instance_id, Instance.owner_id == user.id, Instance.terminated == False)).first()
if instance is None: if instance is None:
print(f"Unknown instance {instance_id}") print(f"Unknown instance {instance_id}")
continue continue

View File

@ -1,4 +1,6 @@
import uuid import uuid
from typing import cast
from dataclasses import dataclass
from fastapi import Response from fastapi import Response
from fastapi.datastructures import QueryParams from fastapi.datastructures import QueryParams
@ -8,31 +10,64 @@ from openec2.config import OpenEC2Config
from openec2.db import DatabaseDep from openec2.db import DatabaseDep
from openec2.db.user import User from openec2.db.user import User
from openec2.db.image import AMI from openec2.db.image import AMI
from openec2.api.describe_images import DescribeImagesResponse, ImagesSet, Image from openec2.api.describe_images import DescribeImagesResponse, Image
from openec2.api.shared import Tag
from openec2.utils.array import parse_array_objects, parse_array_plain
@dataclass
class Filter:
name: str
values: list[str]
def match(self, image: AMI) -> bool:
value: str | None
if self.name.startswith("tag:"):
value = image.tags.get(self.name.replace("tag:", ""))
else:
raise Exception(f"Unknown filter name {self.name}")
return value in self.values
def describe_images( def describe_images(
params: QueryParams, params: QueryParams,
config: OpenEC2Config, config: OpenEC2Config,
db: DatabaseDep, db: DatabaseDep,
_: User, user: User,
): ):
filters: list[Filter] = []
for filter in parse_array_objects("Filter", cast(dict, params)):
filters.append(
Filter(
name=filter["Name"],
values=parse_array_plain("Value", filter),
)
)
images: list[Image] = [] images: list[Image] = []
for ami in db.exec(select(AMI)).all(): for ami in db.exec(select(AMI).where(AMI.owner_id == user.id)).all():
if not all(f.match(ami) for f in filters):
continue
images.append( images.append(
Image( Image(
imageId=ami.id, imageId=ami.id,
imageState="available", imageState="available",
name=ami.originalFilename, name=ami.originalFilename,
tagSet=[
Tag(
key=key,
value=value,
) for key, value in ami.tags.items()
],
), ),
) )
return Response( return Response(
DescribeImagesResponse( DescribeImagesResponse(
requestId=uuid.uuid4().hex, requestId=uuid.uuid4().hex,
imagesSet=ImagesSet( imagesSet=images,
items=images,
),
).to_xml(), ).to_xml(),
media_type="application/xml", media_type="application/xml",
) )

View File

@ -19,7 +19,7 @@ def describe_instance_attribute(
): ):
instance_id = params["InstanceId"] instance_id = params["InstanceId"]
attribute = params["Attribute"] attribute = params["Attribute"]
instance = db.exec(select(Instance).where(Instance.id == instance_id, Instance.owner_id == user.id)).first() instance = db.exec(select(Instance).where(Instance.id == instance_id, Instance.owner_id == user.id, Instance.terminated == False)).first()
if instance is None: if instance is None:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,

View File

@ -18,11 +18,10 @@ def describe_instance_types(
user: User, user: User,
): ):
response: list[InstanceTypeInfo] = [] response: list[InstanceTypeInfo] = []
instances = db.exec(select(Instance).where(Instance.owner_id == user.id)).all() for name, instanceConfig in config.instances.types.items():
for instance in instances:
response.append( response.append(
InstanceTypeInfo( InstanceTypeInfo(
instanceType=instance.instanceType, instanceType=name,
), ),
) )

View File

@ -1,23 +1,25 @@
import uuid import uuid
from typing import cast
import datetime
from fastapi import Response from fastapi import Response
from fastapi.datastructures import QueryParams from fastapi.datastructures import QueryParams
from sqlmodel import select from sqlmodel import select, or_
from openec2.libvirt import LibvirtSingleton from openec2.libvirt import LibvirtSingleton
from openec2.api.describe_instances import ( from openec2.api.describe_instances import (
InstanceDescription, InstanceDescription,
DescribeInstancesResponse, DescribeInstancesResponse,
DescribeInstancesResponseReservationSet,
ReservationSet, ReservationSet,
ReservationSetInstancesSet, InstanceState,
describe_instance, describe_instance,
) )
from openec2.api.shared import Tag
from openec2.db.user import User from openec2.db.user import User
from openec2.api.shared import InstanceState
from openec2.config import OpenEC2Config from openec2.config import OpenEC2Config
from openec2.db import DatabaseDep from openec2.db import DatabaseDep
from openec2.db.instance import Instance from openec2.db.instance import Instance
from openec2.utils.array import parse_array_plain
def describe_instances( def describe_instances(
@ -26,12 +28,42 @@ def describe_instances(
db: DatabaseDep, db: DatabaseDep,
user: User, user: User,
): ):
if "InstanceId.1" in params:
instance_ids = parse_array_plain("InstanceId", cast(dict, params))
instance_expr = [Instance.id == instance_id for instance_id in instance_ids]
instances = db.exec(
select(Instance).where(Instance.owner_id == user.id, or_(*instance_expr)),
).all()
else:
instances = db.exec(
select(Instance).where(Instance.owner_id == user.id),
).all()
response_items: list[InstanceDescription] = [] response_items: list[InstanceDescription] = []
conn = LibvirtSingleton.of().connection conn = LibvirtSingleton.of().connection
for instance in db.exec(select(Instance)).all(): now = datetime.datetime.now()
# Check for permission issues for instance in instances:
if instance.owner_id != user.id: # Include terminated instances for an hour
# TODO: Add the error to the response if instance.terminated:
assert instance.terminationDate is not None
if now <= instance.terminationDate + datetime.timedelta(hours=1):
response_items.append(
InstanceDescription(
instanceId=instance.id,
imageId=instance.imageId,
instanceState=InstanceState(
code=48,
name="terminated",
),
tagSet=[
Tag(
key=key,
value=value,
)
for key, value in instance.tags.items()
],
),
)
continue continue
dom = conn.lookupByName(instance.id) dom = conn.lookupByName(instance.id)
@ -41,18 +73,14 @@ def describe_instances(
return Response( return Response(
DescribeInstancesResponse( DescribeInstancesResponse(
request_id=uuid.uuid4().hex, requestId=uuid.uuid4().hex,
reservation_set=DescribeInstancesResponseReservationSet( reservationSet=[
item=[
ReservationSet( ReservationSet(
reservationId=instance.instance_id, reservationId=instance.instanceId,
instancesSet=ReservationSetInstancesSet( instancesSet=[instance],
item=[instance],
),
) )
for instance in response_items for instance in response_items
], ],
),
).to_xml(), ).to_xml(),
media_type="application/xml", media_type="application/xml",
) )

View File

@ -12,6 +12,7 @@ from openec2.config import OpenEC2Config, ConfigSingleton
from openec2.db import DatabaseDep from openec2.db import DatabaseDep
from openec2.db.user import User from openec2.db.user import User
from openec2.db.image import AMI from openec2.db.image import AMI
from openec2.utils.array import parse_tag_specification
def import_image( def import_image(
@ -24,6 +25,7 @@ def import_image(
url = urlparse(first_disk_image_url) url = urlparse(first_disk_image_url)
ami_id = uuid.uuid4().hex ami_id = uuid.uuid4().hex
tags: dict[str, str] = parse_tag_specification(cast(dict, params))
imageLocation = cast(Path, config.images) imageLocation = cast(Path, config.images)
imageLocation.mkdir(exist_ok=True) imageLocation.mkdir(exist_ok=True)
dst = imageLocation / ami_id dst = imageLocation / ami_id
@ -54,6 +56,7 @@ def import_image(
description=None, description=None,
originalFilename=filename, originalFilename=filename,
owner_id=user.id, owner_id=user.id,
tags=tags,
), ),
) )
db.commit() db.commit()

View File

@ -155,6 +155,8 @@ def run_instances(
interfaceMac=mac, interfaceMac=mac,
instanceType=instance_type_name, instanceType=instance_type_name,
owner_id=user.id, owner_id=user.id,
terminated=False,
terminationDate=None,
) )
db.add(instance) db.add(instance)
print("Inserted new instance") print("Inserted new instance")

View File

@ -24,7 +24,7 @@ def start_instances(
conn = LibvirtSingleton.of().connection conn = LibvirtSingleton.of().connection
instances: list[InstanceInfo] = [] instances: list[InstanceInfo] = []
for instance_id in parse_array_plain("InstanceId", params): for instance_id in parse_array_plain("InstanceId", params):
instance = db.exec(select(Instance).where(Instance.id == instance_id)).first() instance = db.exec(select(Instance).where(Instance.id == instance_id, Instance.terminated == False)).first()
if instance is None: if instance is None:
raise HTTPException(status_code=404, detail="Unknown instance") raise HTTPException(status_code=404, detail="Unknown instance")

View File

@ -24,7 +24,7 @@ def stop_instances(
conn = LibvirtSingleton.of().connection conn = LibvirtSingleton.of().connection
instances: list[InstanceInfo] = [] instances: list[InstanceInfo] = []
for instance_id in parse_array_plain("InstanceId", params): for instance_id in parse_array_plain("InstanceId", params):
instance = db.exec(select(Instance).where(Instance.id == instance_id)).first() instance = db.exec(select(Instance).where(Instance.id == instance_id, Instance.terminated == False)).first()
if instance is None: if instance is None:
raise HTTPException(status_code=404, detail="Unknown instance") raise HTTPException(status_code=404, detail="Unknown instance")

View File

@ -1,7 +1,9 @@
import logging import logging
from typing import cast
import uuid import uuid
import datetime
from fastapi import HTTPException, Response from fastapi import Response
from fastapi.datastructures import QueryParams from fastapi.datastructures import QueryParams
from sqlmodel import select from sqlmodel import select
@ -31,16 +33,11 @@ def terminate_instances(
conn = LibvirtSingleton.of().connection conn = LibvirtSingleton.of().connection
image_ids: set[str] = set() image_ids: set[str] = set()
for instance_id in parse_array_plain("InstanceId", params): for instance_id in parse_array_plain("InstanceId", params):
instance = db.exec(select(Instance).where(Instance.id == instance_id)).first() instance = db.exec(select(Instance).where(Instance.id == instance_id, Instance.owner_id == user.id)).first()
if instance is None: if instance is None:
continue continue
# raise HTTPException(status_code=404, detail="Unknown instance") # raise HTTPException(status_code=404, detail="Unknown instance")
# Check for permission issues
if instance.owner_id != user.id:
# TODO: Add the error to the response
continue
dom = conn.lookupByName(instance_id) dom = conn.lookupByName(instance_id)
prev_state = describe_instance_state(dom) prev_state = describe_instance_state(dom)
if dom.isActive(): if dom.isActive():
@ -64,11 +61,15 @@ def terminate_instances(
for volume in instance.ebs_volumes: for volume in instance.ebs_volumes:
volume.instances.remove(instance) volume.instances.remove(instance)
image_ids.add(instance.imageId)
image_ids.add(cast(str, instance.imageId))
remove_instance_dhcp_mapping( remove_instance_dhcp_mapping(
instance.id, instance.interfaceMac, instance.privateIPv4, db instance.id, instance.interfaceMac, instance.privateIPv4, db
) )
db.delete(instance)
# Mark the instance as terminated
instance.terminated = True
instance.terminationDate = datetime.datetime.now()
db.commit() db.commit()

View File

@ -1,4 +1,6 @@
from pydantic_xml import BaseXmlModel, element from pydantic_xml import BaseXmlModel, wrapped, element
from openec2.api.shared import Tag
class Image(BaseXmlModel): class Image(BaseXmlModel):
@ -8,9 +10,7 @@ class Image(BaseXmlModel):
name: str = element() name: str = element()
tagSet: list[Tag] = wrapped("tagSet", element(tag="item"))
class ImagesSet(BaseXmlModel, tag="imagesSet"):
items: list[Image] = element(tag="item")
class DescribeImagesResponse( class DescribeImagesResponse(
@ -20,4 +20,4 @@ class DescribeImagesResponse(
): ):
requestId: str = element() requestId: str = element()
imagesSet: ImagesSet = element() imagesSet: list[Image] = wrapped("imagesSet", element(tag="item"))

View File

@ -2,42 +2,21 @@ from pydantic_xml import BaseXmlModel, wrapped, element
import libvirt import libvirt
from openec2.db.instance import Instance from openec2.db.instance import Instance
from openec2.api.shared import InstanceState from openec2.api.shared import InstanceState, Tag
class InstanceTag(BaseXmlModel):
key: str = element()
value: str = element()
class InstanceTagSet(BaseXmlModel):
item: list[InstanceTag] = element()
class InstanceDescription( class InstanceDescription(
BaseXmlModel, BaseXmlModel,
tag="item", tag="item",
): ):
instance_id: str = element(tag="instanceId") instanceId: str = element()
image_id: str = element(tag="imageId") imageId: str = element()
instance_state: InstanceState = element(tag="instanceState") instanceState: InstanceState = element()
tagSet: list[InstanceTag] = wrapped("tagSet", element(tag="item")) tagSet: list[Tag] = wrapped("tagSet", element(tag="item"))
class ReservationSetInstancesSet(BaseXmlModel):
item: list[InstanceDescription] = element()
class ReservationSet(BaseXmlModel): class ReservationSet(BaseXmlModel):
reservationId: str = element() reservationId: str = element()
instancesSet: ReservationSetInstancesSet = element() instancesSet: list[InstanceDescription] = wrapped("instancesSet", element(tag="item"))
class DescribeInstancesResponseReservationSet(
BaseXmlModel,
tag="reservationSet",
):
item: list[ReservationSet] = element("")
class DescribeInstancesResponse( class DescribeInstancesResponse(
@ -45,8 +24,8 @@ class DescribeInstancesResponse(
tag="DescribeInstancesResponse", tag="DescribeInstancesResponse",
nsmap={"": "http://ec2.amazonaws.com/doc/2016-11-15/"}, nsmap={"": "http://ec2.amazonaws.com/doc/2016-11-15/"},
): ):
request_id: str = element(tag="requestId") requestId: str = element()
reservation_set: DescribeInstancesResponseReservationSet = element("reservationSet") reservationSet: list[ReservationSet] = wrapped("reservationSet", element("item"))
def describe_instance_state(domain: libvirt.virDomain) -> InstanceState: def describe_instance_state(domain: libvirt.virDomain) -> InstanceState:
@ -61,11 +40,11 @@ def describe_instance(
instance: Instance, domain: libvirt.virDomain instance: Instance, domain: libvirt.virDomain
) -> InstanceDescription: ) -> InstanceDescription:
return InstanceDescription( return InstanceDescription(
instance_id=instance.id, instanceId=instance.id,
image_id=instance.imageId, imageId=instance.imageId,
instance_state=describe_instance_state(domain), instanceState=describe_instance_state(domain),
tagSet=[ tagSet=[
InstanceTag( Tag(
key=key, key=key,
value=value, value=value,
) )

View File

@ -14,3 +14,8 @@ class InstanceInfo(BaseXmlModel):
class InstancesSet(BaseXmlModel, tag="instancesSet"): class InstancesSet(BaseXmlModel, tag="instancesSet"):
item: list[InstanceInfo] = element() item: list[InstanceInfo] = element()
class Tag(BaseXmlModel):
key: str = element()
value: str = element()

View File

@ -1,4 +1,4 @@
from sqlmodel import SQLModel, Field from sqlmodel import SQLModel, Field, Column, JSON
class AMI(SQLModel, table=True): class AMI(SQLModel, table=True):
@ -16,3 +16,6 @@ class AMI(SQLModel, table=True):
# Owner of the image who created it # Owner of the image who created it
owner_id: int = Field(foreign_key="user.id") owner_id: int = Field(foreign_key="user.id")
# Tags associated with the AMI
tags: dict = Field(sa_column=Column(JSON), default={})

View File

@ -1,4 +1,5 @@
from pathlib import Path from pathlib import Path
from datetime import datetime
from sqlmodel import SQLModel, Field, JSON, Column, Relationship from sqlmodel import SQLModel, Field, JSON, Column, Relationship
@ -40,7 +41,7 @@ class Instance(SQLModel, table=True):
instanceType: str instanceType: str
# ImageID of the used AMI # ImageID of the used AMI. None only if terminated == True.
imageId: str imageId: str
# Optional user data associated with the VM # Optional user data associated with the VM
@ -59,3 +60,9 @@ class Instance(SQLModel, table=True):
ebs_volumes: list[EBSVolume] = Relationship( ebs_volumes: list[EBSVolume] = Relationship(
back_populates="instances", link_model=EBSVolumeInstanceLink back_populates="instances", link_model=EBSVolumeInstanceLink
) )
# Is the instance terminated
terminated: bool
# Date at which the instance got terminated
terminationDate: datetime | None

View File

@ -7,7 +7,7 @@ from openec2.db.image import AMI
def garbage_collect_image(image_id: str, db: DatabaseDep): def garbage_collect_image(image_id: str, db: DatabaseDep):
instances = db.exec(select(Instance).where(Instance.imageId == image_id)).all() instances = db.exec(select(Instance).where(Instance.imageId == image_id, Instance.terminated == False)).all()
if instances: if instances:
print("Instances sill using AMI. Not cleaning up") print("Instances sill using AMI. Not cleaning up")
print(instances) print(instances)

View File

@ -35,3 +35,11 @@ def find[T](l: list[T], pred: Callable[[T], bool]) -> T | None:
if pred(item): if pred(item):
return item return item
return None return None
def parse_tag_specification(params: dict) -> dict[str, str]:
tags: dict[str, str] = {}
for spec in parse_array_objects("TagSpecification", params):
for raw_tag in parse_array_objects("Tag", spec):
tags[raw_tag["Key"]] = raw_tag["Value"]
return tags