Handle creating/destroying instances from Terraform/tofu

This commit is contained in:
PapaTutuWawa 2025-04-06 17:20:26 +02:00
parent 38d37a7d5b
commit 6d99b446a0
19 changed files with 195 additions and 88 deletions

View File

@ -11,12 +11,49 @@ provider "aws" {
region = "eu-west-1"
}
resource "aws_instance" "test-instance" {
ami = "0c4dcaafb6a14dbb93b402f1fd6a9dfb"
# https://geo.mirror.pkgbuild.com/images/v20250315.322357/Arch-Linux-x86_64-cloudimg.qcow2
# Import using:
# aws ec2 import-image --disk-container "Url=https://geo.mirror.pkgbuild.com/images/v20250315.322357/Arch-Linux-x86_64-cloudimg.qcow2" --tag-specification 'Tags=[{Key="Linux",Value="ArchLinux-nocloud"}]'
data "aws_ami" "archlinux-nocloud" {
filter {
name = "tag:Linux"
values = ["ArchLinux-nocloud"]
}
}
resource "aws_instance" "test-instance-1" {
ami = data.aws_ami.archlinux-nocloud.id
instance_type = "micro"
availability_zone = "az-1"
private_ip = "192.168.122.3"
tags = {
TestTag = "TestValue"
UseCase = "k8s-control-plane"
}
}
# resource "aws_instance" "test-instance-2" {
# ami = "0c4dcaafb6a14dbb93b402f1fd6a9dfb"
# instance_type = "micro"
# availability_zone = "az-1"
# private_ip = "192.168.122.4"
# tags = {
# UseCase = "k8s-control-plane"
# }
# }
# resource "aws_instance" "test-instance-3" {
# ami = "0c4dcaafb6a14dbb93b402f1fd6a9dfb"
# instance_type = "micro"
# availability_zone = "az-1"
# private_ip = "192.168.122.5"
# tags = {
# UseCase = "k8s-control-plane"
# }
# }

View File

@ -26,10 +26,10 @@ def attach_volume(
instance_id = params["InstanceId"]
volume_id = params["VolumeId"]
volume = db.exec(select(EBSVolume).where(EBSVolume.id == volume_id)).first()
volume = db.exec(select(EBSVolume).where(EBSVolume.id == volume_id, Instance.owner_id == user.id)).first()
if volume is None:
return
instance = db.exec(select(Instance).where(Instance.id == instance_id)).first()
instance = db.exec(select(Instance).where(Instance.id == instance_id, Instance.owner_id == user.id, Instance.terminated == False)).first()
if instance is None:
return

View File

@ -31,7 +31,7 @@ def create_tags(
tags[tag["Key"]] = tag["Value"]
for instance_id in parse_array_plain("ResourceId", cast(dict, params)):
instance = db.exec(select(Instance).where(Instance.id == instance_id, Instance.owner_id == user.id)).first()
instance = db.exec(select(Instance).where(Instance.id == instance_id, Instance.owner_id == user.id, Instance.terminated == False)).first()
if instance is None:
print(f"Unknown instance {instance_id}")
continue

View File

@ -1,4 +1,6 @@
import uuid
from typing import cast
from dataclasses import dataclass
from fastapi import Response
from fastapi.datastructures import QueryParams
@ -8,31 +10,64 @@ from openec2.config import OpenEC2Config
from openec2.db import DatabaseDep
from openec2.db.user import User
from openec2.db.image import AMI
from openec2.api.describe_images import DescribeImagesResponse, ImagesSet, Image
from openec2.api.describe_images import DescribeImagesResponse, Image
from openec2.api.shared import Tag
from openec2.utils.array import parse_array_objects, parse_array_plain
@dataclass
class Filter:
name: str
values: list[str]
def match(self, image: AMI) -> bool:
value: str | None
if self.name.startswith("tag:"):
value = image.tags.get(self.name.replace("tag:", ""))
else:
raise Exception(f"Unknown filter name {self.name}")
return value in self.values
def describe_images(
params: QueryParams,
config: OpenEC2Config,
db: DatabaseDep,
_: User,
user: User,
):
filters: list[Filter] = []
for filter in parse_array_objects("Filter", cast(dict, params)):
filters.append(
Filter(
name=filter["Name"],
values=parse_array_plain("Value", filter),
)
)
images: list[Image] = []
for ami in db.exec(select(AMI)).all():
for ami in db.exec(select(AMI).where(AMI.owner_id == user.id)).all():
if not all(f.match(ami) for f in filters):
continue
images.append(
Image(
imageId=ami.id,
imageState="available",
name=ami.originalFilename,
tagSet=[
Tag(
key=key,
value=value,
) for key, value in ami.tags.items()
],
),
)
return Response(
DescribeImagesResponse(
requestId=uuid.uuid4().hex,
imagesSet=ImagesSet(
items=images,
),
imagesSet=images,
).to_xml(),
media_type="application/xml",
)

View File

@ -19,7 +19,7 @@ def describe_instance_attribute(
):
instance_id = params["InstanceId"]
attribute = params["Attribute"]
instance = db.exec(select(Instance).where(Instance.id == instance_id, Instance.owner_id == user.id)).first()
instance = db.exec(select(Instance).where(Instance.id == instance_id, Instance.owner_id == user.id, Instance.terminated == False)).first()
if instance is None:
raise HTTPException(
status_code=404,

View File

@ -18,11 +18,10 @@ def describe_instance_types(
user: User,
):
response: list[InstanceTypeInfo] = []
instances = db.exec(select(Instance).where(Instance.owner_id == user.id)).all()
for instance in instances:
for name, instanceConfig in config.instances.types.items():
response.append(
InstanceTypeInfo(
instanceType=instance.instanceType,
instanceType=name,
),
)

View File

@ -1,23 +1,25 @@
import uuid
from typing import cast
import datetime
from fastapi import Response
from fastapi.datastructures import QueryParams
from sqlmodel import select
from sqlmodel import select, or_
from openec2.libvirt import LibvirtSingleton
from openec2.api.describe_instances import (
InstanceDescription,
DescribeInstancesResponse,
DescribeInstancesResponseReservationSet,
ReservationSet,
ReservationSetInstancesSet,
InstanceState,
describe_instance,
)
from openec2.api.shared import Tag
from openec2.db.user import User
from openec2.api.shared import InstanceState
from openec2.config import OpenEC2Config
from openec2.db import DatabaseDep
from openec2.db.instance import Instance
from openec2.utils.array import parse_array_plain
def describe_instances(
@ -26,12 +28,42 @@ def describe_instances(
db: DatabaseDep,
user: User,
):
if "InstanceId.1" in params:
instance_ids = parse_array_plain("InstanceId", cast(dict, params))
instance_expr = [Instance.id == instance_id for instance_id in instance_ids]
instances = db.exec(
select(Instance).where(Instance.owner_id == user.id, or_(*instance_expr)),
).all()
else:
instances = db.exec(
select(Instance).where(Instance.owner_id == user.id),
).all()
response_items: list[InstanceDescription] = []
conn = LibvirtSingleton.of().connection
for instance in db.exec(select(Instance)).all():
# Check for permission issues
if instance.owner_id != user.id:
# TODO: Add the error to the response
now = datetime.datetime.now()
for instance in instances:
# Include terminated instances for an hour
if instance.terminated:
assert instance.terminationDate is not None
if now <= instance.terminationDate + datetime.timedelta(hours=1):
response_items.append(
InstanceDescription(
instanceId=instance.id,
imageId=instance.imageId,
instanceState=InstanceState(
code=48,
name="terminated",
),
tagSet=[
Tag(
key=key,
value=value,
)
for key, value in instance.tags.items()
],
),
)
continue
dom = conn.lookupByName(instance.id)
@ -41,18 +73,14 @@ def describe_instances(
return Response(
DescribeInstancesResponse(
request_id=uuid.uuid4().hex,
reservation_set=DescribeInstancesResponseReservationSet(
item=[
ReservationSet(
reservationId=instance.instance_id,
instancesSet=ReservationSetInstancesSet(
item=[instance],
),
)
for instance in response_items
],
),
requestId=uuid.uuid4().hex,
reservationSet=[
ReservationSet(
reservationId=instance.instanceId,
instancesSet=[instance],
)
for instance in response_items
],
).to_xml(),
media_type="application/xml",
)

View File

@ -12,6 +12,7 @@ from openec2.config import OpenEC2Config, ConfigSingleton
from openec2.db import DatabaseDep
from openec2.db.user import User
from openec2.db.image import AMI
from openec2.utils.array import parse_tag_specification
def import_image(
@ -24,6 +25,7 @@ def import_image(
url = urlparse(first_disk_image_url)
ami_id = uuid.uuid4().hex
tags: dict[str, str] = parse_tag_specification(cast(dict, params))
imageLocation = cast(Path, config.images)
imageLocation.mkdir(exist_ok=True)
dst = imageLocation / ami_id
@ -54,6 +56,7 @@ def import_image(
description=None,
originalFilename=filename,
owner_id=user.id,
tags=tags,
),
)
db.commit()

View File

@ -155,6 +155,8 @@ def run_instances(
interfaceMac=mac,
instanceType=instance_type_name,
owner_id=user.id,
terminated=False,
terminationDate=None,
)
db.add(instance)
print("Inserted new instance")

View File

@ -24,7 +24,7 @@ def start_instances(
conn = LibvirtSingleton.of().connection
instances: list[InstanceInfo] = []
for instance_id in parse_array_plain("InstanceId", params):
instance = db.exec(select(Instance).where(Instance.id == instance_id)).first()
instance = db.exec(select(Instance).where(Instance.id == instance_id, Instance.terminated == False)).first()
if instance is None:
raise HTTPException(status_code=404, detail="Unknown instance")

View File

@ -24,7 +24,7 @@ def stop_instances(
conn = LibvirtSingleton.of().connection
instances: list[InstanceInfo] = []
for instance_id in parse_array_plain("InstanceId", params):
instance = db.exec(select(Instance).where(Instance.id == instance_id)).first()
instance = db.exec(select(Instance).where(Instance.id == instance_id, Instance.terminated == False)).first()
if instance is None:
raise HTTPException(status_code=404, detail="Unknown instance")

View File

@ -1,7 +1,9 @@
import logging
from typing import cast
import uuid
import datetime
from fastapi import HTTPException, Response
from fastapi import Response
from fastapi.datastructures import QueryParams
from sqlmodel import select
@ -31,16 +33,11 @@ def terminate_instances(
conn = LibvirtSingleton.of().connection
image_ids: set[str] = set()
for instance_id in parse_array_plain("InstanceId", params):
instance = db.exec(select(Instance).where(Instance.id == instance_id)).first()
instance = db.exec(select(Instance).where(Instance.id == instance_id, Instance.owner_id == user.id)).first()
if instance is None:
continue
# raise HTTPException(status_code=404, detail="Unknown instance")
# Check for permission issues
if instance.owner_id != user.id:
# TODO: Add the error to the response
continue
dom = conn.lookupByName(instance_id)
prev_state = describe_instance_state(dom)
if dom.isActive():
@ -64,11 +61,15 @@ def terminate_instances(
for volume in instance.ebs_volumes:
volume.instances.remove(instance)
image_ids.add(instance.imageId)
image_ids.add(cast(str, instance.imageId))
remove_instance_dhcp_mapping(
instance.id, instance.interfaceMac, instance.privateIPv4, db
)
db.delete(instance)
# Mark the instance as terminated
instance.terminated = True
instance.terminationDate = datetime.datetime.now()
db.commit()

View File

@ -1,4 +1,6 @@
from pydantic_xml import BaseXmlModel, element
from pydantic_xml import BaseXmlModel, wrapped, element
from openec2.api.shared import Tag
class Image(BaseXmlModel):
@ -8,9 +10,7 @@ class Image(BaseXmlModel):
name: str = element()
class ImagesSet(BaseXmlModel, tag="imagesSet"):
items: list[Image] = element(tag="item")
tagSet: list[Tag] = wrapped("tagSet", element(tag="item"))
class DescribeImagesResponse(
@ -20,4 +20,4 @@ class DescribeImagesResponse(
):
requestId: str = element()
imagesSet: ImagesSet = element()
imagesSet: list[Image] = wrapped("imagesSet", element(tag="item"))

View File

@ -2,42 +2,21 @@ from pydantic_xml import BaseXmlModel, wrapped, element
import libvirt
from openec2.db.instance import Instance
from openec2.api.shared import InstanceState
class InstanceTag(BaseXmlModel):
key: str = element()
value: str = element()
class InstanceTagSet(BaseXmlModel):
item: list[InstanceTag] = element()
from openec2.api.shared import InstanceState, Tag
class InstanceDescription(
BaseXmlModel,
tag="item",
):
instance_id: str = element(tag="instanceId")
image_id: str = element(tag="imageId")
instance_state: InstanceState = element(tag="instanceState")
tagSet: list[InstanceTag] = wrapped("tagSet", element(tag="item"))
class ReservationSetInstancesSet(BaseXmlModel):
item: list[InstanceDescription] = element()
instanceId: str = element()
imageId: str = element()
instanceState: InstanceState = element()
tagSet: list[Tag] = wrapped("tagSet", element(tag="item"))
class ReservationSet(BaseXmlModel):
reservationId: str = element()
instancesSet: ReservationSetInstancesSet = element()
class DescribeInstancesResponseReservationSet(
BaseXmlModel,
tag="reservationSet",
):
item: list[ReservationSet] = element("")
instancesSet: list[InstanceDescription] = wrapped("instancesSet", element(tag="item"))
class DescribeInstancesResponse(
@ -45,8 +24,8 @@ class DescribeInstancesResponse(
tag="DescribeInstancesResponse",
nsmap={"": "http://ec2.amazonaws.com/doc/2016-11-15/"},
):
request_id: str = element(tag="requestId")
reservation_set: DescribeInstancesResponseReservationSet = element("reservationSet")
requestId: str = element()
reservationSet: list[ReservationSet] = wrapped("reservationSet", element("item"))
def describe_instance_state(domain: libvirt.virDomain) -> InstanceState:
@ -61,11 +40,11 @@ def describe_instance(
instance: Instance, domain: libvirt.virDomain
) -> InstanceDescription:
return InstanceDescription(
instance_id=instance.id,
image_id=instance.imageId,
instance_state=describe_instance_state(domain),
instanceId=instance.id,
imageId=instance.imageId,
instanceState=describe_instance_state(domain),
tagSet=[
InstanceTag(
Tag(
key=key,
value=value,
)

View File

@ -14,3 +14,8 @@ class InstanceInfo(BaseXmlModel):
class InstancesSet(BaseXmlModel, tag="instancesSet"):
item: list[InstanceInfo] = element()
class Tag(BaseXmlModel):
key: str = element()
value: str = element()

View File

@ -1,4 +1,4 @@
from sqlmodel import SQLModel, Field
from sqlmodel import SQLModel, Field, Column, JSON
class AMI(SQLModel, table=True):
@ -16,3 +16,6 @@ class AMI(SQLModel, table=True):
# Owner of the image who created it
owner_id: int = Field(foreign_key="user.id")
# Tags associated with the AMI
tags: dict = Field(sa_column=Column(JSON), default={})

View File

@ -1,4 +1,5 @@
from pathlib import Path
from datetime import datetime
from sqlmodel import SQLModel, Field, JSON, Column, Relationship
@ -40,7 +41,7 @@ class Instance(SQLModel, table=True):
instanceType: str
# ImageID of the used AMI
# ImageID of the used AMI. None only if terminated == True.
imageId: str
# Optional user data associated with the VM
@ -59,3 +60,9 @@ class Instance(SQLModel, table=True):
ebs_volumes: list[EBSVolume] = Relationship(
back_populates="instances", link_model=EBSVolumeInstanceLink
)
# Is the instance terminated
terminated: bool
# Date at which the instance got terminated
terminationDate: datetime | None

View File

@ -7,7 +7,7 @@ from openec2.db.image import AMI
def garbage_collect_image(image_id: str, db: DatabaseDep):
instances = db.exec(select(Instance).where(Instance.imageId == image_id)).all()
instances = db.exec(select(Instance).where(Instance.imageId == image_id, Instance.terminated == False)).all()
if instances:
print("Instances sill using AMI. Not cleaning up")
print(instances)

View File

@ -35,3 +35,11 @@ def find[T](l: list[T], pred: Callable[[T], bool]) -> T | None:
if pred(item):
return item
return None
def parse_tag_specification(params: dict) -> dict[str, str]:
tags: dict[str, str] = {}
for spec in parse_array_objects("TagSpecification", params):
for raw_tag in parse_array_objects("Tag", spec):
tags[raw_tag["Key"]] = raw_tag["Value"]
return tags