Compare commits

..

3 Commits

44 changed files with 1109 additions and 119 deletions

6
.gitignore vendored
View File

@@ -8,3 +8,9 @@ wheels/
# Virtual environments
.venv
# Terraform/Tofu
examples/tofu/**/.terraform
examples/tofu/**/.terraform.lock.hcl
examples/tofu/**/.terraform.tfstate
examples/tofu/**/.terraform.tfstate.backup

View File

@@ -1 +1,7 @@
Test
# Private Compute Stack (Pieces)
Pieces is a mostly API-compatible implementation of AWS services.
## EC2
A very small subset of EC2 functionality is implemented.

59
examples/tofu/main.tf Normal file
View File

@@ -0,0 +1,59 @@
terraform {
required_providers {
aws = {
source = "hashicorp/aws"
version = "~> 5.0"
}
}
}
provider "aws" {
region = "eu-west-1"
}
# https://geo.mirror.pkgbuild.com/images/v20250315.322357/Arch-Linux-x86_64-cloudimg.qcow2
# Import using:
# aws ec2 import-image --disk-container "Url=https://geo.mirror.pkgbuild.com/images/v20250315.322357/Arch-Linux-x86_64-cloudimg.qcow2" --tag-specification 'Tags=[{Key="Linux",Value="ArchLinux-nocloud"}]'
data "aws_ami" "archlinux-nocloud" {
filter {
name = "tag:Linux"
values = ["ArchLinux-nocloud"]
}
}
resource "aws_instance" "test-instance-1" {
ami = data.aws_ami.archlinux-nocloud.id
instance_type = "micro"
availability_zone = "az-1"
private_ip = "192.168.122.3"
tags = {
UseCase = "k8s-control-plane"
}
}
# resource "aws_instance" "test-instance-2" {
# ami = "0c4dcaafb6a14dbb93b402f1fd6a9dfb"
# instance_type = "micro"
# availability_zone = "az-1"
# private_ip = "192.168.122.4"
# tags = {
# UseCase = "k8s-control-plane"
# }
# }
# resource "aws_instance" "test-instance-3" {
# ami = "0c4dcaafb6a14dbb93b402f1fd6a9dfb"
# instance_type = "micro"
# availability_zone = "az-1"
# private_ip = "192.168.122.5"
# tags = {
# UseCase = "k8s-control-plane"
# }
# }

View File

@@ -0,0 +1 @@
{"version":4,"terraform_version":"1.9.0","serial":8,"lineage":"a013da38-6954-7573-33dd-c05f6b0ec61f","outputs":{},"resources":[{"mode":"managed","type":"aws_instance","name":"test-instance","provider":"provider[\"registry.opentofu.org/hashicorp/aws\"]","instances":[{"schema_version":1,"attributes":{"ami":"0c4dcaafb6a14dbb93b402f1fd6a9dfb","arn":"arn:aws:ec2:eu-west-1:1:instance/19b606d3c0b543c991f2b0cd632013a2","associate_public_ip_address":false,"availability_zone":"az-1","capacity_reservation_specification":[],"cpu_core_count":null,"cpu_options":[],"cpu_threads_per_core":null,"credit_specification":[],"disable_api_stop":false,"disable_api_termination":false,"ebs_block_device":[],"ebs_optimized":false,"enable_primary_ipv6":null,"enclave_options":[],"ephemeral_block_device":[],"get_password_data":false,"hibernation":null,"host_id":null,"host_resource_group_arn":null,"iam_instance_profile":"","id":"19b606d3c0b543c991f2b0cd632013a2","instance_initiated_shutdown_behavior":null,"instance_lifecycle":"","instance_market_options":[],"instance_state":"running","instance_type":"","ipv6_address_count":0,"ipv6_addresses":[],"key_name":"","launch_template":[],"maintenance_options":[],"metadata_options":[],"monitoring":null,"network_interface":[],"outpost_arn":"","password_data":"","placement_group":null,"placement_partition_number":null,"primary_network_interface_id":"","private_dns":"","private_dns_name_options":[],"private_ip":"","public_dns":"","public_ip":"","root_block_device":[],"secondary_private_ips":[],"security_groups":[],"source_dest_check":true,"spot_instance_request_id":"","subnet_id":"","tags":{"TestTag":"TestValue"},"tags_all":{"TestTag":"TestValue"},"tenancy":null,"timeouts":null,"user_data":null,"user_data_base64":null,"user_data_replace_on_change":false,"volume_tags":null,"vpc_security_group_ids":[]},"sensitive_attributes":[],"private":"eyJlMmJmYjczMC1lY2FhLTExZTYtOGY4OC0zNDM2M2JjN2M0YzAiOnsiY3JlYXRlIjo2MDAwMDAwMDAwMDAsImRlbGV0ZSI6MTIwMDAwMDAwMDAwMCwicmVhZCI6OTAwMDAwMDAwMDAwLCJ1cGRhdGUiOjYwMDAwMDAwMDAwMH0sInNjaGVtYV92ZXJzaW9uIjoiMSJ9"}]}],"check_results":null}

View File

@@ -0,0 +1 @@
{"version":4,"terraform_version":"1.9.0","serial":7,"lineage":"a013da38-6954-7573-33dd-c05f6b0ec61f","outputs":{},"resources":[],"check_results":null}

View File

@@ -0,0 +1,90 @@
import uuid
import xml.etree.ElementTree as ET
from fastapi import Response, HTTPException
from fastapi.datastructures import QueryParams
from sqlmodel import select
import libvirt
from openec2.api.attach_volume import AttachVolumeResponse
from openec2.libvirt import LibvirtSingleton
from openec2.utils.libvirt import instance_to_libvirt_xml
from openec2.db.user import User
from openec2.config import OpenEC2Config
from openec2.db import DatabaseDep
from openec2.db.instance import Instance, EBSVolume
from openec2.utils.libvirt import ebs_volume_to_libvirt_xml
def attach_volume(
params: QueryParams,
config: OpenEC2Config,
db: DatabaseDep,
user: User,
):
device = params["Device"]
instance_id = params["InstanceId"]
volume_id = params["VolumeId"]
volume = db.exec(select(EBSVolume).where(EBSVolume.id == volume_id, Instance.owner_id == user.id)).first()
if volume is None:
return
instance = db.exec(select(Instance).where(Instance.id == instance_id, Instance.owner_id == user.id, Instance.terminated == False)).first()
if instance is None:
return
attached_volume_ids = [i.id for i in instance.ebs_volumes]
if volume_id in attached_volume_ids:
print("CANNOT ATTACH THE SAME VOLUME TO THE EC2 TWICE")
return
if not volume.multi_attach_enabled and volume.instances:
print("CANNOT ATTACH NON-MULTIATTACH again")
return
# Add the required data to libvirt
conn = LibvirtSingleton.of().connection
domain = conn.lookupByName(instance_id)
# Add the memory backing if required
running = domain.isActive()
if not instance.ebs_volumes:
if running:
raise HTTPException(
status_code=500,
detail="Instance is running",
)
# Update the instance
volume.instances.append(instance)
domain_xml = domain.XMLDesc()
domain_uuid = ET.fromstring(domain_xml).find("uuid").text
print(f"Updating XML for {instance.id} with {domain_uuid}")
instance_xml = instance_to_libvirt_xml(instance, config, domain_uuid)
print(instance_xml)
conn.defineXML(instance_xml)
else:
# Attach the device
volume.instances.append(instance)
domain.attachDeviceFlags(
ebs_volume_to_libvirt_xml(volume, config),
libvirt.VIR_DOMAIN_DEVICE_MODIFY_LIVE
if running
else libvirt.VIR_DOMAIN_DEVICE_MODIFY_CONFIG,
)
db.add(volume)
db.commit()
return Response(
AttachVolumeResponse(
requestId=uuid.uuid4().hex,
volumeId=volume_id,
instanceId=instance_id,
device=device,
status="attached",
).to_xml(),
media_type="application/xml",
)

View File

@@ -0,0 +1,51 @@
from dataclasses import dataclass
from typing import cast
import uuid
from fastapi import Response
from fastapi.datastructures import QueryParams
from sqlmodel import select
from openec2.config import OpenEC2Config
from openec2.db import DatabaseDep
from openec2.db.user import User
from openec2.db.instance import Instance
from openec2.utils.array import parse_array_objects, parse_array_plain
from openec2.api.create_tags import CreateTagsResponse
@dataclass
class Tag:
key: str
value: str
def create_tags(
params: QueryParams,
config: OpenEC2Config,
db: DatabaseDep,
user: User,
):
tags: dict[str, str] = {}
for tag in parse_array_objects("Tag", cast(dict, params)):
tags[tag["Key"]] = tag["Value"]
for instance_id in parse_array_plain("ResourceId", cast(dict, params)):
instance = db.exec(select(Instance).where(Instance.id == instance_id, Instance.owner_id == user.id, Instance.terminated == False)).first()
if instance is None:
print(f"Unknown instance {instance_id}")
continue
instance.tags = {
**instance.tags,
**tags,
}
print(instance)
db.commit()
return Response(
CreateTagsResponse(
requestId=uuid.uuid4().hex,
return_=True,
).to_xml(),
media_type="application/xml",
)

View File

@@ -0,0 +1,41 @@
import uuid
from fastapi import Response
from fastapi.datastructures import QueryParams
from openec2.config import OpenEC2Config
from openec2.db import DatabaseDep
from openec2.db.user import User
from openec2.db.instance import EBSVolume
from openec2.api.create_volume import CreateVolumeResponse
def create_volume(
params: QueryParams,
config: OpenEC2Config,
db: DatabaseDep,
user: User,
):
availabilityZone = params["AvailabilityZone"]
volume_id = f"vol-{uuid.uuid4().hex}"
volume = EBSVolume(
id=volume_id,
availability_zone=availabilityZone,
multi_attach_enabled=params.get("MultiAttachEnabled", "false") == "true",
owner_id=user.id,
)
volume.path(config).mkdir()
db.add(volume)
db.commit()
return Response(
CreateVolumeResponse(
requestId=uuid.uuid4().hex,
volumeId=volume_id,
availabilityZone=volume.availability_zone,
multiAttachEnabled=volume.multi_attach_enabled,
).to_xml(),
media_type="application/xml",
)

View File

@@ -0,0 +1,17 @@
from fastapi import Response, HTTPException
from fastapi.datastructures import QueryParams
from sqlmodel import select
from openec2.db.user import User
from openec2.config import OpenEC2Config
from openec2.db import DatabaseDep
from openec2.db.vpc import VPC
def create_vpc(
params: QueryParams,
config: OpenEC2Config,
db: DatabaseDep,
user: User,
):
cidr_block = params["CidrBlock"]

View File

@@ -1,4 +1,6 @@
import uuid
from typing import cast
from dataclasses import dataclass
from fastapi import Response
from fastapi.datastructures import QueryParams
@@ -8,31 +10,64 @@ from openec2.config import OpenEC2Config
from openec2.db import DatabaseDep
from openec2.db.user import User
from openec2.db.image import AMI
from openec2.api.describe_images import DescribeImagesResponse, ImagesSet, Image
from openec2.api.describe_images import DescribeImagesResponse, Image
from openec2.api.shared import Tag
from openec2.utils.array import parse_array_objects, parse_array_plain
@dataclass
class Filter:
name: str
values: list[str]
def match(self, image: AMI) -> bool:
value: str | None
if self.name.startswith("tag:"):
value = image.tags.get(self.name.replace("tag:", ""))
else:
raise Exception(f"Unknown filter name {self.name}")
return value in self.values
def describe_images(
params: QueryParams,
config: OpenEC2Config,
db: DatabaseDep,
_: User,
user: User,
):
filters: list[Filter] = []
for filter in parse_array_objects("Filter", cast(dict, params)):
filters.append(
Filter(
name=filter["Name"],
values=parse_array_plain("Value", filter),
)
)
images: list[Image] = []
for ami in db.exec(select(AMI)).all():
for ami in db.exec(select(AMI).where(AMI.owner_id == user.id)).all():
if not all(f.match(ami) for f in filters):
continue
images.append(
Image(
imageId=ami.id,
imageState="available",
name=ami.originalFilename,
tagSet=[
Tag(
key=key,
value=value,
) for key, value in ami.tags.items()
],
),
)
return Response(
DescribeImagesResponse(
requestId=uuid.uuid4().hex,
imagesSet=ImagesSet(
items=images,
),
imagesSet=images,
).to_xml(),
media_type="application/xml",
)

View File

@@ -0,0 +1,43 @@
import uuid
from fastapi import Response, HTTPException
from fastapi.datastructures import QueryParams
from sqlmodel import select
from openec2.config import OpenEC2Config
from openec2.db import DatabaseDep
from openec2.db.user import User
from openec2.db.instance import Instance
from openec2.api.describe_instance_attribute import DescribeInstanceAttributeResponse, InstanceInitiatedShutdownBehaviour, DisableApiTermination, DisableApiStop
def describe_instance_attribute(
params: QueryParams,
config: OpenEC2Config,
db: DatabaseDep,
user: User,
):
instance_id = params["InstanceId"]
attribute = params["Attribute"]
instance = db.exec(select(Instance).where(Instance.id == instance_id, Instance.owner_id == user.id, Instance.terminated == False)).first()
if instance is None:
raise HTTPException(
status_code=404,
)
return Response(
DescribeInstanceAttributeResponse(
requestId=uuid.uuid4().hex,
instanceId=instance.id,
instanceInitiatedShutdownBehaviour=InstanceInitiatedShutdownBehaviour(
value="stop",
) if attribute == "instanceInitiatedShutdownBehaviour" else None,
disableApiTermination=DisableApiTermination(
value="false",
) if attribute == "disableApiTermination" else None,
disableApiStop=DisableApiStop(
value="false",
) if attribute == "disableApiStop" else None,
).to_xml(),
media_type="application/xml",
)

View File

@@ -0,0 +1,35 @@
import uuid
from fastapi import Response
from fastapi.datastructures import QueryParams
from sqlmodel import select
from openec2.config import OpenEC2Config
from openec2.db import DatabaseDep
from openec2.db.user import User
from openec2.db.instance import Instance
from openec2.api.describe_instance_types import DescribeInstanceTypesResponse, InstanceTypeInfo
def describe_instance_types(
params: QueryParams,
config: OpenEC2Config,
db: DatabaseDep,
user: User,
):
response: list[InstanceTypeInfo] = []
for name, instanceConfig in config.instances.types.items():
response.append(
InstanceTypeInfo(
instanceType=name,
),
)
return Response(
DescribeInstanceTypesResponse(
requestId=uuid.uuid4().hex,
instanceTypeSet=response,
nextToken=None,
).to_xml(),
media_type="application/xml",
)

View File

@@ -1,23 +1,25 @@
import uuid
from typing import cast
import datetime
from fastapi import Response
from fastapi.datastructures import QueryParams
from sqlmodel import select
from sqlmodel import select, or_
from openec2.libvirt import LibvirtSingleton
from openec2.api.describe_instances import (
InstanceDescription,
DescribeInstancesResponse,
DescribeInstancesResponseReservationSet,
ReservationSet,
ReservationSetInstancesSet,
InstanceState,
describe_instance,
)
from openec2.api.shared import Tag
from openec2.db.user import User
from openec2.api.shared import InstanceState
from openec2.config import OpenEC2Config
from openec2.db import DatabaseDep
from openec2.db.instance import Instance
from openec2.utils.array import parse_array_plain
def describe_instances(
@@ -26,12 +28,42 @@ def describe_instances(
db: DatabaseDep,
user: User,
):
if "InstanceId.1" in params:
instance_ids = parse_array_plain("InstanceId", cast(dict, params))
instance_expr = [Instance.id == instance_id for instance_id in instance_ids]
instances = db.exec(
select(Instance).where(Instance.owner_id == user.id, or_(*instance_expr)),
).all()
else:
instances = db.exec(
select(Instance).where(Instance.owner_id == user.id),
).all()
response_items: list[InstanceDescription] = []
conn = LibvirtSingleton.of().connection
for instance in db.exec(select(Instance)).all():
# Check for permission issues
if instance.owner_id != user.id:
# TODO: Add the error to the response
now = datetime.datetime.now()
for instance in instances:
# Include terminated instances for an hour
if instance.terminated:
assert instance.terminationDate is not None
if now <= instance.terminationDate + datetime.timedelta(hours=1):
response_items.append(
InstanceDescription(
instanceId=instance.id,
imageId=instance.imageId,
instanceState=InstanceState(
code=48,
name="terminated",
),
tagSet=[
Tag(
key=key,
value=value,
)
for key, value in instance.tags.items()
],
),
)
continue
dom = conn.lookupByName(instance.id)
@@ -41,18 +73,14 @@ def describe_instances(
return Response(
DescribeInstancesResponse(
request_id=uuid.uuid4().hex,
reservation_set=DescribeInstancesResponseReservationSet(
item=[
requestId=uuid.uuid4().hex,
reservationSet=[
ReservationSet(
reservationId=instance.instance_id,
instancesSet=ReservationSetInstancesSet(
item=[instance],
),
reservationId=instance.instanceId,
instancesSet=[instance],
)
for instance in response_items
],
),
).to_xml(),
media_type="application/xml",
)

View File

@@ -0,0 +1,99 @@
from typing import Literal, cast
from dataclasses import dataclass
import uuid
from fastapi import Response
from fastapi.datastructures import QueryParams
from sqlmodel import select
from openec2.config import OpenEC2Config
from openec2.db import DatabaseDep
from openec2.db.user import User
from openec2.db.instance import Instance
from openec2.api.describe_tags import DescribeTagsResponse, TagDescription
from openec2.utils.array import parse_array_objects, parse_array_plain, find
@dataclass
class Filter:
name: Literal["resource-id"] | Literal["resource-type"] | Literal["key"] | str
values: list[str]
def match(self, instance: Instance) -> bool:
if self.name not in instance.tags:
return False
value: str | None
if self.name == "resource-type":
value = "instance"
elif self.name == "resource-id":
value = instance.id
elif self.name == "key":
return any(key in instance.tags for key in self.values)
elif self.name.startswith("tag:"):
value = instance.tags.get(self.name.replace("tag:", ""))
else:
raise Exception(f"Unknown filter name {self.name}")
return value in self.values
def value(self, instance: Instance) -> TagDescription:
if self.name == "resource-type":
return TagDescription(resourceType="instance")
elif self.name == "resource-id":
return TagDescription(resourceId=instance.id)
elif self.name == "key":
key_name = find(
self.values,
lambda x: x in instance.tags,
)
assert key_name is not None
key_value = instance.tags[key_name]
return TagDescription(
key=key_name,
value=key_value,
)
elif self.name.startswith("tag:"):
tag = self.name.replace("tag:", "")
return TagDescription(
key=tag,
value=instance.tags[tag],
)
else:
raise Exception(f"Unknown filter name {self.name}")
def describe_tags(
params: QueryParams,
config: OpenEC2Config,
db: DatabaseDep,
user: User,
):
filters: list[Filter] = []
for filter in parse_array_objects("Filter", cast(dict, params)):
filters.append(
Filter(
name=filter["Name"],
values=parse_array_plain("Values", filter),
)
)
response: list[TagDescription] = []
for instance in db.exec(select(Instance).where(Instance.owner_id == user.id)).all():
filter = find(
filters,
lambda f: f.match(instance),
)
if filter is None:
continue
response.append(filter.value(instance))
return Response(
DescribeTagsResponse(
requestId=uuid.uuid4().hex,
nextToken=None,
tagSet=response,
).to_xml(),
media_type="application/xml",
)

View File

@@ -0,0 +1,35 @@
import uuid
from sqlmodel import select
from fastapi import Response
from openec2.db import DatabaseDep
from openec2.config import OpenEC2Config
from fastapi.datastructures import QueryParams
from openec2.db.user import User
from openec2.db.instance import EBSVolume
from openec2.api.describe_volumes import DescribeVolumesResponse, VolumeSet, Volume
def describe_volumes(
params: QueryParams,
config: OpenEC2Config,
db: DatabaseDep,
user: User,
):
volumes = db.exec(select(EBSVolume).where(EBSVolume.owner_id == user.id)).all()
return Response(
DescribeVolumesResponse(
requestId=uuid.uuid4().hex,
volumeSet=VolumeSet(
item=[
Volume(
volumeId=volume.id,
multiAttachEnabled=volume.multi_attach_enabled,
)
for volume in volumes
],
),
).to_xml(),
media_type="application/xml",
)

View File

@@ -0,0 +1,61 @@
import uuid
import libvirt
from sqlmodel import select
from fastapi import Response
from fastapi.datastructures import QueryParams
from openec2.libvirt import LibvirtSingleton
from openec2.config import OpenEC2Config
from openec2.db import DatabaseDep
from openec2.db.user import User
from openec2.db.instance import Instance, EBSVolume
from openec2.api.detach_volume import DetachVolumeResponse
from openec2.utils.libvirt import ebs_volume_to_libvirt_xml
def detach_volume(
params: QueryParams,
config: OpenEC2Config,
db: DatabaseDep,
user: User,
):
instance_id = params["InstanceId"]
volume_id = params["VolumeId"]
# Find the instance
instance = db.exec(
select(Instance).where(Instance.id == instance_id, Instance.owner_id == user.id)
).first()
if instance is None:
return
# Find the volume
volume = db.exec(
select(EBSVolume).where(
EBSVolume.id == volume_id, EBSVolume.owner_id == user.id
)
).first()
if volume is None:
return
if instance_id not in [i.id for i in volume.instances]:
return
# Remove the volume from the instance
domain = LibvirtSingleton.of().connection.lookupByName(instance_id)
domain.detachDeviceFlags(
ebs_volume_to_libvirt_xml(volume, config),
libvirt.VIR_DOMAIN_DEVICE_MODIFY_LIVE
if domain.isActive()
else libvirt.VIR_DOMAIN_DEVICE_MODIFY_CONFIG,
)
return Response(
DetachVolumeResponse(
requestId=uuid.uuid4().hex,
volumeId=volume_id,
instanceId=instance_id,
status="detached",
).to_xml(),
media_type="application/xml",
)

View File

@@ -0,0 +1,30 @@
import uuid
from fastapi import Response
from fastapi.datastructures import QueryParams
from openec2.config import OpenEC2Config
from openec2.db import DatabaseDep
from openec2.db.user import User
from openec2.api.get_caller_identity import GetCallerIdentityResponse, GetCallerIdentityResult, ResponseMetadata
def get_caller_identity(
params: QueryParams,
config: OpenEC2Config,
db: DatabaseDep,
user: User,
):
return Response(
GetCallerIdentityResponse(
result=GetCallerIdentityResult(
arn=f"arn:aws:iam::{user.id}:user/{user.name}",
user_id=str(user.id),
account=str(user.id),
),
metadata=ResponseMetadata(
request_id=uuid.uuid4().hex,
),
).to_xml(),
media_type="application/xml",
)

View File

@@ -12,6 +12,7 @@ from openec2.config import OpenEC2Config, ConfigSingleton
from openec2.db import DatabaseDep
from openec2.db.user import User
from openec2.db.image import AMI
from openec2.utils.array import parse_tag_specification
def import_image(
@@ -24,6 +25,7 @@ def import_image(
url = urlparse(first_disk_image_url)
ami_id = uuid.uuid4().hex
tags: dict[str, str] = parse_tag_specification(cast(dict, params))
imageLocation = cast(Path, config.images)
imageLocation.mkdir(exist_ok=True)
dst = imageLocation / ami_id
@@ -54,6 +56,7 @@ def import_image(
description=None,
originalFilename=filename,
owner_id=user.id,
tags=tags,
),
)
db.commit()

View File

@@ -98,9 +98,9 @@ def run_instances(
user: User,
):
image_id = params["ImageId"]
instance_type = params["InstanceType"]
instance_type_name = params["InstanceType"]
instance_type = config.instances.types.get(params["InstanceType"])
instance_type = config.instances.types.get(instance_type_name)
if instance_type is None:
raise Exception(f"Unknown instance type {params['InstanceType']}")
@@ -115,11 +115,11 @@ def run_instances(
)
# Parse tags
# TODO: broken
tags: dict[str, str] = {}
for spec in parse_array_objects("TagSpecification", cast(dict, params)):
for raw_tag in parse_array_objects("Tag", spec):
tags[raw_tag["Key"]] = raw_tag["Value"]
print(f"Creating with tags {tags}")
# Get a private IPv4
instance_id = uuid.uuid4().hex
@@ -153,7 +153,10 @@ def run_instances(
else None,
privateIPv4=private_ipv4,
interfaceMac=mac,
instanceType=instance_type_name,
owner_id=user.id,
terminated=False,
terminationDate=None,
)
db.add(instance)
print("Inserted new instance")
@@ -175,7 +178,7 @@ def run_instances(
return Response(
RunInstanceResponse(
request_id=uuid.uuid4().hex,
instance_set=RunInstanceInstanceSet(
instancesSet=RunInstanceInstanceSet(
item=[description],
),
).to_xml(),

View File

@@ -24,7 +24,7 @@ def start_instances(
conn = LibvirtSingleton.of().connection
instances: list[InstanceInfo] = []
for instance_id in parse_array_plain("InstanceId", params):
instance = db.exec(select(Instance).where(Instance.id == instance_id)).first()
instance = db.exec(select(Instance).where(Instance.id == instance_id, Instance.terminated == False)).first()
if instance is None:
raise HTTPException(status_code=404, detail="Unknown instance")

View File

@@ -24,7 +24,7 @@ def stop_instances(
conn = LibvirtSingleton.of().connection
instances: list[InstanceInfo] = []
for instance_id in parse_array_plain("InstanceId", params):
instance = db.exec(select(Instance).where(Instance.id == instance_id)).first()
instance = db.exec(select(Instance).where(Instance.id == instance_id, Instance.terminated == False)).first()
if instance is None:
raise HTTPException(status_code=404, detail="Unknown instance")

View File

@@ -1,7 +1,9 @@
import logging
from typing import cast
import uuid
import datetime
from fastapi import HTTPException, Response
from fastapi import Response
from fastapi.datastructures import QueryParams
from sqlmodel import select
@@ -31,16 +33,11 @@ def terminate_instances(
conn = LibvirtSingleton.of().connection
image_ids: set[str] = set()
for instance_id in parse_array_plain("InstanceId", params):
instance = db.exec(select(Instance).where(Instance.id == instance_id)).first()
instance = db.exec(select(Instance).where(Instance.id == instance_id, Instance.owner_id == user.id)).first()
if instance is None:
continue
# raise HTTPException(status_code=404, detail="Unknown instance")
# Check for permission issues
if instance.owner_id != user.id:
# TODO: Add the error to the response
continue
dom = conn.lookupByName(instance_id)
prev_state = describe_instance_state(dom)
if dom.isActive():
@@ -61,11 +58,18 @@ def terminate_instances(
instance_disk = config.instances.location / instance_id
instance_disk.unlink()
image_ids.add(instance.imageId)
for volume in instance.ebs_volumes:
volume.instances.remove(instance)
image_ids.add(cast(str, instance.imageId))
remove_instance_dhcp_mapping(
instance.id, instance.interfaceMac, instance.privateIPv4, db
)
db.delete(instance)
# Mark the instance as terminated
instance.terminated = True
instance.terminationDate = datetime.datetime.now()
db.commit()

View File

@@ -0,0 +1,17 @@
from pydantic_xml import BaseXmlModel, element
class AttachVolumeResponse(
BaseXmlModel,
tag="AttachVolumeResponse",
nsmap={"": "http://ec2.amazonaws.com/doc/2016-11-15/"},
):
requestId: str = element()
volumeId: str = element()
instanceId: str = element()
device: str = element()
status: str = element()

View File

@@ -0,0 +1,10 @@
from pydantic_xml import BaseXmlModel, element
class CreateTagsResponse(
BaseXmlModel,
nsmap={"": ""},
):
requestId: str = element()
return_: bool = element(tag="return")

View File

@@ -0,0 +1,15 @@
from pydantic_xml import BaseXmlModel, element
class CreateVolumeResponse(
BaseXmlModel,
tag="CreateVolumeResponse",
nsmap={"": "http://ec2.amazonaws.com/doc/2016-11-15/"},
):
requestId: str = element()
volumeId: str = element()
availabilityZone: str = element()
multiAttachEnabled: bool = element()

View File

@@ -1,4 +1,6 @@
from pydantic_xml import BaseXmlModel, element
from pydantic_xml import BaseXmlModel, wrapped, element
from openec2.api.shared import Tag
class Image(BaseXmlModel):
@@ -8,9 +10,7 @@ class Image(BaseXmlModel):
name: str = element()
class ImagesSet(BaseXmlModel, tag="imagesSet"):
items: list[Image] = element(tag="item")
tagSet: list[Tag] = wrapped("tagSet", element(tag="item"))
class DescribeImagesResponse(
@@ -20,4 +20,4 @@ class DescribeImagesResponse(
):
requestId: str = element()
imagesSet: ImagesSet = element()
imagesSet: list[Image] = wrapped("imagesSet", element(tag="item"))

View File

@@ -0,0 +1,26 @@
from pydantic_xml import BaseXmlModel, element
class InstanceInitiatedShutdownBehaviour(BaseXmlModel):
value: str = element()
class DisableApiTermination(BaseXmlModel):
value: str = element()
class DisableApiStop(BaseXmlModel):
value: str = element()
class DescribeInstanceAttributeResponse(
BaseXmlModel,
tag="DescribeInstanceAttributeResponse",
nsmap={"": "http://ec2.amazonaws.com/doc/2016-11-15/"},
):
requestId: str = element()
instanceId: str = element()
instanceInitiatedShutdownBehaviour: InstanceInitiatedShutdownBehaviour | None = element(defaut=None)
disableApiTermination: DisableApiTermination | None = element(default=None)
disableApiStop: DisableApiStop | None = element(default=None)

View File

@@ -0,0 +1,18 @@
from pydantic_xml import BaseXmlModel, element, wrapped
class InstanceTypeInfo(BaseXmlModel, tag="item"):
instanceType: str = element()
class DescribeInstanceTypesResponse(
BaseXmlModel,
nsmap={"": "http://ec2.amazonaws.com/doc/2016-11-15/"},
):
requestId: str = element()
instanceTypeSet: list[InstanceTypeInfo] = wrapped(
"instanceTypeSet",
element(tag="item"),
)
nextToken: str | None = element()

View File

@@ -1,43 +1,22 @@
from pydantic_xml import BaseXmlModel, element
from pydantic_xml import BaseXmlModel, wrapped, element
import libvirt
from openec2.db.instance import Instance
from openec2.api.shared import InstanceState
class InstanceTag(BaseXmlModel):
key: str = element()
value: str = element()
class InstanceTagSet(BaseXmlModel):
item: list[InstanceTag] = element()
from openec2.api.shared import InstanceState, Tag
class InstanceDescription(
BaseXmlModel,
tag="item",
):
instance_id: str = element(tag="instanceId")
image_id: str = element(tag="imageId")
instance_state: InstanceState = element(tag="instanceState")
tag_set: InstanceTagSet = element(tag="tagSet")
class ReservationSetInstancesSet(BaseXmlModel):
item: list[InstanceDescription] = element()
instanceId: str = element()
imageId: str = element()
instanceState: InstanceState = element()
tagSet: list[Tag] = wrapped("tagSet", element(tag="item"))
class ReservationSet(BaseXmlModel):
reservationId: str = element()
instancesSet: ReservationSetInstancesSet = element()
class DescribeInstancesResponseReservationSet(
BaseXmlModel,
tag="reservationSet",
):
item: list[ReservationSet] = element("")
instancesSet: list[InstanceDescription] = wrapped("instancesSet", element(tag="item"))
class DescribeInstancesResponse(
@@ -45,8 +24,8 @@ class DescribeInstancesResponse(
tag="DescribeInstancesResponse",
nsmap={"": "http://ec2.amazonaws.com/doc/2016-11-15/"},
):
request_id: str = element(tag="requestId")
reservation_set: DescribeInstancesResponseReservationSet = element("reservationSet")
requestId: str = element()
reservationSet: list[ReservationSet] = wrapped("reservationSet", element("item"))
def describe_instance_state(domain: libvirt.virDomain) -> InstanceState:
@@ -61,16 +40,14 @@ def describe_instance(
instance: Instance, domain: libvirt.virDomain
) -> InstanceDescription:
return InstanceDescription(
instance_id=instance.id,
image_id=instance.imageId,
instance_state=describe_instance_state(domain),
tag_set=InstanceTagSet(
item=[
InstanceTag(
instanceId=instance.id,
imageId=instance.imageId,
instanceState=describe_instance_state(domain),
tagSet=[
Tag(
key=key,
value=value,
)
for key, value in instance.tags.items()
],
),
)

View File

@@ -0,0 +1,22 @@
from pydantic_xml import BaseXmlModel, wrapped, element
class TagDescription(BaseXmlModel):
key: str | None = element(default=None)
resourceId: str | None = element(default=None)
resourceType: str | None = element(default=None)
value: str | None = element(default=None)
class DescribeTagsResponse(
BaseXmlModel,
nsmap={"": "http://ec2.amazonaws.com/doc/2016-11-15/"}
):
requestId: str = element()
nextToken: str | None = element()
tagSet: list[TagDescription] = wrapped("tagSet", element(tag="item"))

View File

@@ -0,0 +1,20 @@
from pydantic_xml import BaseXmlModel, element
class Volume(BaseXmlModel):
volumeId: str = element()
multiAttachEnabled: bool = element()
class VolumeSet(BaseXmlModel):
item: list[Volume] = element(tag="item")
class DescribeVolumesResponse(
BaseXmlModel,
tag="DescribeVolumesResponse",
nsmap={"": "http://ec2.amazonaws.com/doc/2016-11-15/"},
):
requestId: str = element()
volumeSet: VolumeSet = element()

View File

@@ -0,0 +1,15 @@
from pydantic_xml import BaseXmlModel, element
class DetachVolumeResponse(
BaseXmlModel,
tag="DetachVolumeResponse",
nsmap={"": "http://ec2.amazonaws.com/doc/2016-11-15/"},
):
requestId: str = element()
volumeId: str = element()
instanceId: str = element()
status: str = element()

View File

@@ -0,0 +1,18 @@
from pydantic_xml import BaseXmlModel, element
class GetCallerIdentityResult(BaseXmlModel):
arn: str = element(tag="Arn")
user_id: str = element(tag="UserId")
account: str = element(tag="Account")
class ResponseMetadata(BaseXmlModel):
request_id: str = element(tag="RequestId")
class GetCallerIdentityResponse(
BaseXmlModel,
nsmap={"": "https://sts.amazonaws.com/doc/2011-06-15/"}
):
result: GetCallerIdentityResult = element(tag="GetCallerIdentityResult")
metadata: ResponseMetadata = element()

View File

@@ -10,4 +10,4 @@ class RunInstanceInstanceSet(BaseXmlModel):
class RunInstanceResponse(BaseXmlModel):
request_id: str = element(tag="requestId")
instance_set: RunInstanceInstanceSet = element(tag="instanceSet")
instancesSet: RunInstanceInstanceSet = element(tag="instancesSet")

View File

@@ -14,3 +14,8 @@ class InstanceInfo(BaseXmlModel):
class InstancesSet(BaseXmlModel, tag="instancesSet"):
item: list[InstanceInfo] = element()
class Tag(BaseXmlModel):
key: str = element()
value: str = element()

View File

@@ -12,7 +12,7 @@ class _OpenEC2InstanceType(BaseModel):
class _OpenEC2InstanceConfig(BaseModel):
location: Path
volumes: Path
types: dict[str, _OpenEC2InstanceType]
@@ -45,6 +45,7 @@ def _get_config() -> _OpenEC2Config:
seed=Path("/home/alexander/openec2/seed"),
instances=_OpenEC2InstanceConfig(
location=Path("/home/alexander/openec2/instances"),
volumes=Path("/home/alexander/openec2/volumes"),
types={
"micro": _OpenEC2InstanceType(
memory=1024,
@@ -58,7 +59,7 @@ def _get_config() -> _OpenEC2Config:
insecure=False,
database=_OpenEC2DatabaseConfig(
url="sqlite:////home/alexander/openec2/db2.sqlite3",
debug=True,
debug=False,
),
)

View File

@@ -1,4 +1,4 @@
from sqlmodel import SQLModel, Field
from sqlmodel import SQLModel, Field, Column, JSON
class AMI(SQLModel, table=True):
@@ -16,3 +16,6 @@ class AMI(SQLModel, table=True):
# Owner of the image who created it
owner_id: int = Field(foreign_key="user.id")
# Tags associated with the AMI
tags: dict = Field(sa_column=Column(JSON), default={})

View File

@@ -1,4 +1,36 @@
from sqlmodel import SQLModel, Field, JSON, Column
from pathlib import Path
from datetime import datetime
from sqlmodel import SQLModel, Field, JSON, Column, Relationship
from openec2.config import OpenEC2Config
class EBSVolumeInstanceLink(SQLModel, table=True):
instance_id: str | None = Field(
default=None, foreign_key="instance.id", primary_key=True
)
ebs_volume_id: str | None = Field(
default=None, foreign_key="ebsvolume.id", primary_key=True
)
class EBSVolume(SQLModel, table=True):
id: str = Field(primary_key=True)
availability_zone: str
multi_attach_enabled: bool
instances: list["Instance"] = Relationship(
back_populates="ebs_volumes", link_model=EBSVolumeInstanceLink
)
owner_id: int = Field(foreign_key="user.id")
def path(self, config: OpenEC2Config) -> Path:
"""Compute the path of the volume on disk."""
return config.instances.volumes / self.id
class Instance(SQLModel, table=True):
@@ -7,7 +39,9 @@ class Instance(SQLModel, table=True):
# Tags associated with the VM
tags: dict = Field(sa_column=Column(JSON), default={})
# ImageID of the used AMI
instanceType: str
# ImageID of the used AMI. None only if terminated == True.
imageId: str
# Optional user data associated with the VM
@@ -21,3 +55,14 @@ class Instance(SQLModel, table=True):
# The owner that creatd the resource.
owner_id: int = Field(foreign_key="user.id")
# Attached EBS volumes
ebs_volumes: list[EBSVolume] = Relationship(
back_populates="instances", link_model=EBSVolumeInstanceLink
)
# Is the instance terminated
terminated: bool
# Date at which the instance got terminated
terminationDate: datetime | None

View File

@@ -1,12 +1,12 @@
from sqlmodel import SQLModel, Field, PrimaryKeyConstraint
from sqlmodel import SQLModel, Field
class VPC(SQLModel, table=True):
# ID of the VPC
id: str = Field(default=None, primary_key=True)
# Subnet mask
subnet: str
# Base IPv4
ipv4_base: str
cidr: str
# Owning user
owner_id: int = Field(foreign_key="user.id")

View File

@@ -7,7 +7,7 @@ from openec2.db.image import AMI
def garbage_collect_image(image_id: str, db: DatabaseDep):
instances = db.exec(select(Instance).where(Instance.imageId == image_id)).all()
instances = db.exec(select(Instance).where(Instance.imageId == image_id, Instance.terminated == False)).all()
if instances:
print("Instances sill using AMI. Not cleaning up")
print(instances)

View File

@@ -18,7 +18,17 @@ from openec2.actions.terminate_instances import terminate_instances
from openec2.actions.start_instances import start_instances
from openec2.actions.stop_instances import stop_instances
from openec2.actions.deregister_image import deregister_image
from openec2.actions.create_volume import create_volume
from openec2.actions.attach_volume import attach_volume
from openec2.actions.describe_volumes import describe_volumes
from openec2.actions.detach_volume import detach_volume
from openec2.db.instance import Instance
from openec2.actions.get_caller_identity import get_caller_identity
from openec2.actions.describe_instance_types import describe_instance_types
from openec2.actions.describe_tags import describe_tags
from openec2.actions.create_tags import create_tags
from openec2.actions.describe_instance_attribute import describe_instance_attribute
app = FastAPI()
@@ -39,6 +49,7 @@ def healthz():
def get_action(
request: Request, config: OpenEC2Config, db: DatabaseDep, user: AWSSignature
):
print("GET Action")
return run_action(
request,
config,
@@ -52,9 +63,13 @@ def get_action(
async def post_action(
request: Request, config: OpenEC2Config, db: DatabaseDep, user: AWSSignature
):
print("POST Action")
body = (await request.body()).decode()
print(f"--> {body}")
query_params = {
key: value[0]
for key, value in parse_qs((await request.body()).decode()).items()
for key, value in parse_qs(body).items()
}
return run_action(
request,
@@ -72,9 +87,8 @@ def run_action(
query_params: dict[str, str],
user: User,
):
print(query_params)
action = query_params["Action"]
return {
action_function = {
"ImportImage": import_image,
"DescribeImages": describe_images,
"RunInstances": run_instances,
@@ -83,7 +97,25 @@ def run_action(
"StartInstances": start_instances,
"StopInstances": stop_instances,
"DeregisterImage": deregister_image,
}[action](query_params, config, db, user)
"CreateVolume": create_volume,
"AttachVolume": attach_volume,
"DescribeVolumes": describe_volumes,
"DetachVolume": detach_volume,
"GetCallerIdentity": get_caller_identity,
"DescribeInstanceTypes": describe_instance_types,
"DescribeTags": describe_tags,
"CreateTags": create_tags,
"DescribeInstanceAttribute": describe_instance_attribute,
}.get(action)
if action_function is None:
print(f"Unknown action: '{action}'")
raise HTTPException(
status_code=404,
detail="Unknown action",
)
return action_function(query_params, config, db, user)
@app.get("/private/cloudinit/{instance_id}/{entry}")

View File

@@ -2,7 +2,7 @@ from typing import Annotated, cast
from dataclasses import dataclass
import datetime
from hashlib import sha256
from urllib.parse import quote, parse_qs
from urllib.parse import parse_qs
from sqlmodel import select
from fastapi import Request, HTTPException, Depends
@@ -37,13 +37,20 @@ class AWSRequest:
# The payload, if we used a POST/PUT
payload: str | None
# List of headers that were signed
signed_headers: list[str]
def include_in_canonical_string(self, name: str) -> bool:
return name.lower() in self.signed_headers
# TODO: Probably vulnerable against a replay because we never get the current date ourselves
def sign(
self, secret_access_key: str, region: str, product: str, credential_scope: str
) -> str:
dt = datetime.datetime.fromisoformat(self.headers["X-Amz-Date"])
canonical_header_string_keys = sorted(
[name for name in self.headers.keys() if include_in_canonical_string(name)]
[name for name in self.headers.keys() if self.include_in_canonical_string(name)]
)
canonical_header_string = (
"\n".join(
@@ -67,8 +74,6 @@ class AWSRequest:
hashed_payload,
]
)
print("Canonical request")
print(canonical_request)
hashed_canonical_request = sha256(canonical_request.encode()).hexdigest()
date = dt.strftime("%Y%m%d")
@@ -81,9 +86,6 @@ class AWSRequest:
]
)
print("String to sign")
print(string_to_sign)
date_key = _hmac_sha256(f"AWS4{secret_access_key}".encode(), date.encode())
date_region_key = _hmac_sha256(date_key, region.encode())
date_region_service_key = _hmac_sha256(date_region_key, product.encode())
@@ -99,10 +101,7 @@ class AWSAuthentication:
x_amz_signature: str
def include_in_canonical_string(name: str) -> bool:
lower = name.lower()
return lower in ("host", "content-type") or lower.startswith("x-amz")
signed_headers: str
def get_authentication_info(request: Request) -> AWSAuthentication:
@@ -117,12 +116,14 @@ def get_authentication_info(request: Request) -> AWSAuthentication:
x_amz_algorithm=algorithm,
x_amz_credential=auth["Credential"],
x_amz_signature=auth["Signature"],
signed_headers=auth["SignedHeaders"],
)
return AWSAuthentication(
"",
"",
"",
"",
)
@@ -142,7 +143,7 @@ async def check_request_signature(request: Request, db: DatabaseDep):
if auth_info.x_amz_algorithm != "AWS4-HMAC-SHA256":
raise HTTPException(
status_code=400,
status_code=403,
detail=f"Invalid signature algorithm: {auth_info.x_amz_algorithm}",
)
@@ -154,6 +155,16 @@ async def check_request_signature(request: Request, db: DatabaseDep):
if user is None:
raise HTTPException(status_code=403)
# Validate the signed headers
signed_headers = auth_info.signed_headers.split(";")
if any(header not in signed_headers for header in [
"host",
"content-type",
]):
print("Validation of signed headers failed!")
raise HTTPException(status_code=403)
print(auth_info.x_amz_credential)
x_amz_signature = auth_info.x_amz_signature
signature = AWSRequest(
url=request.url,
@@ -161,6 +172,7 @@ async def check_request_signature(request: Request, db: DatabaseDep):
params=query_params,
headers=request.headers,
payload=body,
signed_headers=signed_headers,
).sign(
user.secret_access_key,
region,
@@ -168,9 +180,11 @@ async def check_request_signature(request: Request, db: DatabaseDep):
"/".join([date, region, service, key]),
)
print(x_amz_signature, signature)
if x_amz_signature != signature:
raise HTTPException(status_code=401)
print("Signature mismatch!")
print(f"Expected: {signature}")
print(f"Got: {x_amz_signature}")
raise HTTPException(status_code=403)
return user

View File

@@ -1,3 +1,6 @@
from typing import Callable
def parse_array_objects(prefix: str, params: dict) -> list[dict[str, str]]:
items: dict[str, dict[str, str]] = {}
for key, value in params.items():
@@ -25,3 +28,18 @@ def parse_array_plain(prefix: str, params: dict[str, str]) -> list[str]:
indices = sorted(list(items.keys()))
return [items[key] for key in indices]
def find[T](l: list[T], pred: Callable[[T], bool]) -> T | None:
for item in l:
if pred(item):
return item
return None
def parse_tag_specification(params: dict) -> dict[str, str]:
tags: dict[str, str] = {}
for spec in parse_array_objects("TagSpecification", params):
for raw_tag in parse_array_objects("Tag", spec):
tags[raw_tag["Key"]] = raw_tag["Value"]
return tags

View File

@@ -0,0 +1,86 @@
from openec2.config import OpenEC2Config
from openec2.db.instance import EBSVolume, Instance
def ebs_volume_to_libvirt_xml(volume: EBSVolume, config: OpenEC2Config) -> str:
# TODO: Honour the attached device name
return f"""
<filesystem type='mount' accessmode='passthrough'>
<driver type='virtiofs' queue='1024' />
<source dir='{config.instances.volumes / volume.id}' />
<target dir='{volume.id}' />
</filesystem>
"""
def instance_to_libvirt_xml(
instance: Instance,
config: OpenEC2Config,
uuid: str | None = None,
) -> str:
instance_type = config.instances.types[instance.instanceType]
ami_path = config.instances.location / instance.id
memory_backing = (
"""
<memoryBacking>
<source type='memfd' />
<access mode='shared' />
</memoryBacking>
"""
if instance.ebs_volumes
else ""
)
volumes = "\n".join(
ebs_volume_to_libvirt_xml(volume, config) for volume in instance.ebs_volumes
)
uuid_element = f"<uuid>{uuid}</uuid>" if uuid is not None else ""
return f"""<domain type='kvm'>
{uuid_element}
<name>{instance.id}</name>
<memory unit='MiB'>{instance_type.memory}</memory>
{memory_backing}
<vcpu placement='static'>{int(instance_type.vcpu)}</vcpu>
<os>
<type arch='x86_64'>hvm</type>
<boot dev='hd' />
<smbios mode='sysinfo' />
</os>
<sysinfo type='smbios'>
<system>
<entry name='serial'>ds=nocloud;s=http://192.168.122.1:8000/private/cloudinit/{instance.id}/</entry>
</system>
</sysinfo>
<features>
<acpi />
<apic />
<vmport state='off' />
</features>
<clock offset='utc'>
<timer name='rtc' tickpolicy='catchup'/>
<timer name='pit' tickpolicy='delay'/>
<timer name='hpet' present='no'/>
</clock>
<pm>
<suspend-to-mem enabled='no'/>
<suspend-to-disk enabled='no'/>
</pm>
<devices>
{volumes}
<disk type='file' device='disk'>
<driver name='qemu' type='qcow2'/>
<source file='{ami_path}'/>
<target dev='vda' bus='virtio'/>
</disk>
<rng model="virtio">
<backend model="random">/dev/urandom</backend>
</rng>
<interface type="network">
<source network="default"/>
<mac address="{instance.interfaceMac}" />
<model type="virtio"/>
</interface>
</devices>
</domain>
"""