Compare commits

...

2 Commits

Author SHA1 Message Date
38d37a7d5b Make Terraform/tofu work 2025-04-06 14:55:58 +02:00
697c89bb4f Add EBS stuff 2025-04-05 00:47:29 +02:00
35 changed files with 926 additions and 43 deletions

6
.gitignore vendored
View File

@ -8,3 +8,9 @@ wheels/
# Virtual environments
.venv
# Terraform/Tofu
examples/tofu/**/.terraform
examples/tofu/**/.terraform.lock.hcl
examples/tofu/**/.terraform.tfstate
examples/tofu/**/.terraform.tfstate.backup

View File

@ -1 +1,7 @@
Test
# Private Compute Stack (Pieces)
Pieces is a mostly API-compatible implementation of AWS services.
## EC2
A very small subset of EC2 functionality is implemented.

22
examples/tofu/main.tf Normal file
View File

@ -0,0 +1,22 @@
terraform {
required_providers {
aws = {
source = "hashicorp/aws"
version = "~> 5.0"
}
}
}
provider "aws" {
region = "eu-west-1"
}
resource "aws_instance" "test-instance" {
ami = "0c4dcaafb6a14dbb93b402f1fd6a9dfb"
instance_type = "micro"
availability_zone = "az-1"
tags = {
TestTag = "TestValue"
}
}

View File

@ -0,0 +1 @@
{"version":4,"terraform_version":"1.9.0","serial":8,"lineage":"a013da38-6954-7573-33dd-c05f6b0ec61f","outputs":{},"resources":[{"mode":"managed","type":"aws_instance","name":"test-instance","provider":"provider[\"registry.opentofu.org/hashicorp/aws\"]","instances":[{"schema_version":1,"attributes":{"ami":"0c4dcaafb6a14dbb93b402f1fd6a9dfb","arn":"arn:aws:ec2:eu-west-1:1:instance/19b606d3c0b543c991f2b0cd632013a2","associate_public_ip_address":false,"availability_zone":"az-1","capacity_reservation_specification":[],"cpu_core_count":null,"cpu_options":[],"cpu_threads_per_core":null,"credit_specification":[],"disable_api_stop":false,"disable_api_termination":false,"ebs_block_device":[],"ebs_optimized":false,"enable_primary_ipv6":null,"enclave_options":[],"ephemeral_block_device":[],"get_password_data":false,"hibernation":null,"host_id":null,"host_resource_group_arn":null,"iam_instance_profile":"","id":"19b606d3c0b543c991f2b0cd632013a2","instance_initiated_shutdown_behavior":null,"instance_lifecycle":"","instance_market_options":[],"instance_state":"running","instance_type":"","ipv6_address_count":0,"ipv6_addresses":[],"key_name":"","launch_template":[],"maintenance_options":[],"metadata_options":[],"monitoring":null,"network_interface":[],"outpost_arn":"","password_data":"","placement_group":null,"placement_partition_number":null,"primary_network_interface_id":"","private_dns":"","private_dns_name_options":[],"private_ip":"","public_dns":"","public_ip":"","root_block_device":[],"secondary_private_ips":[],"security_groups":[],"source_dest_check":true,"spot_instance_request_id":"","subnet_id":"","tags":{"TestTag":"TestValue"},"tags_all":{"TestTag":"TestValue"},"tenancy":null,"timeouts":null,"user_data":null,"user_data_base64":null,"user_data_replace_on_change":false,"volume_tags":null,"vpc_security_group_ids":[]},"sensitive_attributes":[],"private":"eyJlMmJmYjczMC1lY2FhLTExZTYtOGY4OC0zNDM2M2JjN2M0YzAiOnsiY3JlYXRlIjo2MDAwMDAwMDAwMDAsImRlbGV0ZSI6MTIwMDAwMDAwMDAwMCwicmVhZCI6OTAwMDAwMDAwMDAwLCJ1cGRhdGUiOjYwMDAwMDAwMDAwMH0sInNjaGVtYV92ZXJzaW9uIjoiMSJ9"}]}],"check_results":null}

View File

@ -0,0 +1 @@
{"version":4,"terraform_version":"1.9.0","serial":7,"lineage":"a013da38-6954-7573-33dd-c05f6b0ec61f","outputs":{},"resources":[],"check_results":null}

View File

@ -0,0 +1,90 @@
import uuid
import xml.etree.ElementTree as ET
from fastapi import Response, HTTPException
from fastapi.datastructures import QueryParams
from sqlmodel import select
import libvirt
from openec2.api.attach_volume import AttachVolumeResponse
from openec2.libvirt import LibvirtSingleton
from openec2.utils.libvirt import instance_to_libvirt_xml
from openec2.db.user import User
from openec2.config import OpenEC2Config
from openec2.db import DatabaseDep
from openec2.db.instance import Instance, EBSVolume
from openec2.utils.libvirt import ebs_volume_to_libvirt_xml
def attach_volume(
params: QueryParams,
config: OpenEC2Config,
db: DatabaseDep,
user: User,
):
device = params["Device"]
instance_id = params["InstanceId"]
volume_id = params["VolumeId"]
volume = db.exec(select(EBSVolume).where(EBSVolume.id == volume_id)).first()
if volume is None:
return
instance = db.exec(select(Instance).where(Instance.id == instance_id)).first()
if instance is None:
return
attached_volume_ids = [i.id for i in instance.ebs_volumes]
if volume_id in attached_volume_ids:
print("CANNOT ATTACH THE SAME VOLUME TO THE EC2 TWICE")
return
if not volume.multi_attach_enabled and volume.instances:
print("CANNOT ATTACH NON-MULTIATTACH again")
return
# Add the required data to libvirt
conn = LibvirtSingleton.of().connection
domain = conn.lookupByName(instance_id)
# Add the memory backing if required
running = domain.isActive()
if not instance.ebs_volumes:
if running:
raise HTTPException(
status_code=500,
detail="Instance is running",
)
# Update the instance
volume.instances.append(instance)
domain_xml = domain.XMLDesc()
domain_uuid = ET.fromstring(domain_xml).find("uuid").text
print(f"Updating XML for {instance.id} with {domain_uuid}")
instance_xml = instance_to_libvirt_xml(instance, config, domain_uuid)
print(instance_xml)
conn.defineXML(instance_xml)
else:
# Attach the device
volume.instances.append(instance)
domain.attachDeviceFlags(
ebs_volume_to_libvirt_xml(volume, config),
libvirt.VIR_DOMAIN_DEVICE_MODIFY_LIVE
if running
else libvirt.VIR_DOMAIN_DEVICE_MODIFY_CONFIG,
)
db.add(volume)
db.commit()
return Response(
AttachVolumeResponse(
requestId=uuid.uuid4().hex,
volumeId=volume_id,
instanceId=instance_id,
device=device,
status="attached",
).to_xml(),
media_type="application/xml",
)

View File

@ -0,0 +1,51 @@
from dataclasses import dataclass
from typing import cast
import uuid
from fastapi import Response
from fastapi.datastructures import QueryParams
from sqlmodel import select
from openec2.config import OpenEC2Config
from openec2.db import DatabaseDep
from openec2.db.user import User
from openec2.db.instance import Instance
from openec2.utils.array import parse_array_objects, parse_array_plain
from openec2.api.create_tags import CreateTagsResponse
@dataclass
class Tag:
key: str
value: str
def create_tags(
params: QueryParams,
config: OpenEC2Config,
db: DatabaseDep,
user: User,
):
tags: dict[str, str] = {}
for tag in parse_array_objects("Tag", cast(dict, params)):
tags[tag["Key"]] = tag["Value"]
for instance_id in parse_array_plain("ResourceId", cast(dict, params)):
instance = db.exec(select(Instance).where(Instance.id == instance_id, Instance.owner_id == user.id)).first()
if instance is None:
print(f"Unknown instance {instance_id}")
continue
instance.tags = {
**instance.tags,
**tags,
}
print(instance)
db.commit()
return Response(
CreateTagsResponse(
requestId=uuid.uuid4().hex,
return_=True,
).to_xml(),
media_type="application/xml",
)

View File

@ -0,0 +1,41 @@
import uuid
from fastapi import Response
from fastapi.datastructures import QueryParams
from openec2.config import OpenEC2Config
from openec2.db import DatabaseDep
from openec2.db.user import User
from openec2.db.instance import EBSVolume
from openec2.api.create_volume import CreateVolumeResponse
def create_volume(
params: QueryParams,
config: OpenEC2Config,
db: DatabaseDep,
user: User,
):
availabilityZone = params["AvailabilityZone"]
volume_id = f"vol-{uuid.uuid4().hex}"
volume = EBSVolume(
id=volume_id,
availability_zone=availabilityZone,
multi_attach_enabled=params.get("MultiAttachEnabled", "false") == "true",
owner_id=user.id,
)
volume.path(config).mkdir()
db.add(volume)
db.commit()
return Response(
CreateVolumeResponse(
requestId=uuid.uuid4().hex,
volumeId=volume_id,
availabilityZone=volume.availability_zone,
multiAttachEnabled=volume.multi_attach_enabled,
).to_xml(),
media_type="application/xml",
)

View File

@ -0,0 +1,17 @@
from fastapi import Response, HTTPException
from fastapi.datastructures import QueryParams
from sqlmodel import select
from openec2.db.user import User
from openec2.config import OpenEC2Config
from openec2.db import DatabaseDep
from openec2.db.vpc import VPC
def create_vpc(
params: QueryParams,
config: OpenEC2Config,
db: DatabaseDep,
user: User,
):
cidr_block = params["CidrBlock"]

View File

@ -0,0 +1,43 @@
import uuid
from fastapi import Response, HTTPException
from fastapi.datastructures import QueryParams
from sqlmodel import select
from openec2.config import OpenEC2Config
from openec2.db import DatabaseDep
from openec2.db.user import User
from openec2.db.instance import Instance
from openec2.api.describe_instance_attribute import DescribeInstanceAttributeResponse, InstanceInitiatedShutdownBehaviour, DisableApiTermination, DisableApiStop
def describe_instance_attribute(
params: QueryParams,
config: OpenEC2Config,
db: DatabaseDep,
user: User,
):
instance_id = params["InstanceId"]
attribute = params["Attribute"]
instance = db.exec(select(Instance).where(Instance.id == instance_id, Instance.owner_id == user.id)).first()
if instance is None:
raise HTTPException(
status_code=404,
)
return Response(
DescribeInstanceAttributeResponse(
requestId=uuid.uuid4().hex,
instanceId=instance.id,
instanceInitiatedShutdownBehaviour=InstanceInitiatedShutdownBehaviour(
value="stop",
) if attribute == "instanceInitiatedShutdownBehaviour" else None,
disableApiTermination=DisableApiTermination(
value="false",
) if attribute == "disableApiTermination" else None,
disableApiStop=DisableApiStop(
value="false",
) if attribute == "disableApiStop" else None,
).to_xml(),
media_type="application/xml",
)

View File

@ -0,0 +1,36 @@
import uuid
from fastapi import Response
from fastapi.datastructures import QueryParams
from sqlmodel import select
from openec2.config import OpenEC2Config
from openec2.db import DatabaseDep
from openec2.db.user import User
from openec2.db.instance import Instance
from openec2.api.describe_instance_types import DescribeInstanceTypesResponse, InstanceTypeInfo
def describe_instance_types(
params: QueryParams,
config: OpenEC2Config,
db: DatabaseDep,
user: User,
):
response: list[InstanceTypeInfo] = []
instances = db.exec(select(Instance).where(Instance.owner_id == user.id)).all()
for instance in instances:
response.append(
InstanceTypeInfo(
instanceType=instance.instanceType,
),
)
return Response(
DescribeInstanceTypesResponse(
requestId=uuid.uuid4().hex,
instanceTypeSet=response,
nextToken=None,
).to_xml(),
media_type="application/xml",
)

View File

@ -0,0 +1,99 @@
from typing import Literal, cast
from dataclasses import dataclass
import uuid
from fastapi import Response
from fastapi.datastructures import QueryParams
from sqlmodel import select
from openec2.config import OpenEC2Config
from openec2.db import DatabaseDep
from openec2.db.user import User
from openec2.db.instance import Instance
from openec2.api.describe_tags import DescribeTagsResponse, TagDescription
from openec2.utils.array import parse_array_objects, parse_array_plain, find
@dataclass
class Filter:
name: Literal["resource-id"] | Literal["resource-type"] | Literal["key"] | str
values: list[str]
def match(self, instance: Instance) -> bool:
if self.name not in instance.tags:
return False
value: str | None
if self.name == "resource-type":
value = "instance"
elif self.name == "resource-id":
value = instance.id
elif self.name == "key":
return any(key in instance.tags for key in self.values)
elif self.name.startswith("tag:"):
value = instance.tags.get(self.name.replace("tag:", ""))
else:
raise Exception(f"Unknown filter name {self.name}")
return value in self.values
def value(self, instance: Instance) -> TagDescription:
if self.name == "resource-type":
return TagDescription(resourceType="instance")
elif self.name == "resource-id":
return TagDescription(resourceId=instance.id)
elif self.name == "key":
key_name = find(
self.values,
lambda x: x in instance.tags,
)
assert key_name is not None
key_value = instance.tags[key_name]
return TagDescription(
key=key_name,
value=key_value,
)
elif self.name.startswith("tag:"):
tag = self.name.replace("tag:", "")
return TagDescription(
key=tag,
value=instance.tags[tag],
)
else:
raise Exception(f"Unknown filter name {self.name}")
def describe_tags(
params: QueryParams,
config: OpenEC2Config,
db: DatabaseDep,
user: User,
):
filters: list[Filter] = []
for filter in parse_array_objects("Filter", cast(dict, params)):
filters.append(
Filter(
name=filter["Name"],
values=parse_array_plain("Values", filter),
)
)
response: list[TagDescription] = []
for instance in db.exec(select(Instance).where(Instance.owner_id == user.id)).all():
filter = find(
filters,
lambda f: f.match(instance),
)
if filter is None:
continue
response.append(filter.value(instance))
return Response(
DescribeTagsResponse(
requestId=uuid.uuid4().hex,
nextToken=None,
tagSet=response,
).to_xml(),
media_type="application/xml",
)

View File

@ -0,0 +1,35 @@
import uuid
from sqlmodel import select
from fastapi import Response
from openec2.db import DatabaseDep
from openec2.config import OpenEC2Config
from fastapi.datastructures import QueryParams
from openec2.db.user import User
from openec2.db.instance import EBSVolume
from openec2.api.describe_volumes import DescribeVolumesResponse, VolumeSet, Volume
def describe_volumes(
params: QueryParams,
config: OpenEC2Config,
db: DatabaseDep,
user: User,
):
volumes = db.exec(select(EBSVolume).where(EBSVolume.owner_id == user.id)).all()
return Response(
DescribeVolumesResponse(
requestId=uuid.uuid4().hex,
volumeSet=VolumeSet(
item=[
Volume(
volumeId=volume.id,
multiAttachEnabled=volume.multi_attach_enabled,
)
for volume in volumes
],
),
).to_xml(),
media_type="application/xml",
)

View File

@ -0,0 +1,61 @@
import uuid
import libvirt
from sqlmodel import select
from fastapi import Response
from fastapi.datastructures import QueryParams
from openec2.libvirt import LibvirtSingleton
from openec2.config import OpenEC2Config
from openec2.db import DatabaseDep
from openec2.db.user import User
from openec2.db.instance import Instance, EBSVolume
from openec2.api.detach_volume import DetachVolumeResponse
from openec2.utils.libvirt import ebs_volume_to_libvirt_xml
def detach_volume(
params: QueryParams,
config: OpenEC2Config,
db: DatabaseDep,
user: User,
):
instance_id = params["InstanceId"]
volume_id = params["VolumeId"]
# Find the instance
instance = db.exec(
select(Instance).where(Instance.id == instance_id, Instance.owner_id == user.id)
).first()
if instance is None:
return
# Find the volume
volume = db.exec(
select(EBSVolume).where(
EBSVolume.id == volume_id, EBSVolume.owner_id == user.id
)
).first()
if volume is None:
return
if instance_id not in [i.id for i in volume.instances]:
return
# Remove the volume from the instance
domain = LibvirtSingleton.of().connection.lookupByName(instance_id)
domain.detachDeviceFlags(
ebs_volume_to_libvirt_xml(volume, config),
libvirt.VIR_DOMAIN_DEVICE_MODIFY_LIVE
if domain.isActive()
else libvirt.VIR_DOMAIN_DEVICE_MODIFY_CONFIG,
)
return Response(
DetachVolumeResponse(
requestId=uuid.uuid4().hex,
volumeId=volume_id,
instanceId=instance_id,
status="detached",
).to_xml(),
media_type="application/xml",
)

View File

@ -0,0 +1,30 @@
import uuid
from fastapi import Response
from fastapi.datastructures import QueryParams
from openec2.config import OpenEC2Config
from openec2.db import DatabaseDep
from openec2.db.user import User
from openec2.api.get_caller_identity import GetCallerIdentityResponse, GetCallerIdentityResult, ResponseMetadata
def get_caller_identity(
params: QueryParams,
config: OpenEC2Config,
db: DatabaseDep,
user: User,
):
return Response(
GetCallerIdentityResponse(
result=GetCallerIdentityResult(
arn=f"arn:aws:iam::{user.id}:user/{user.name}",
user_id=str(user.id),
account=str(user.id),
),
metadata=ResponseMetadata(
request_id=uuid.uuid4().hex,
),
).to_xml(),
media_type="application/xml",
)

View File

@ -98,9 +98,9 @@ def run_instances(
user: User,
):
image_id = params["ImageId"]
instance_type = params["InstanceType"]
instance_type_name = params["InstanceType"]
instance_type = config.instances.types.get(params["InstanceType"])
instance_type = config.instances.types.get(instance_type_name)
if instance_type is None:
raise Exception(f"Unknown instance type {params['InstanceType']}")
@ -115,11 +115,11 @@ def run_instances(
)
# Parse tags
# TODO: broken
tags: dict[str, str] = {}
for spec in parse_array_objects("TagSpecification", cast(dict, params)):
for raw_tag in parse_array_objects("Tag", spec):
tags[raw_tag["Key"]] = raw_tag["Value"]
print(f"Creating with tags {tags}")
# Get a private IPv4
instance_id = uuid.uuid4().hex
@ -153,6 +153,7 @@ def run_instances(
else None,
privateIPv4=private_ipv4,
interfaceMac=mac,
instanceType=instance_type_name,
owner_id=user.id,
)
db.add(instance)
@ -175,7 +176,7 @@ def run_instances(
return Response(
RunInstanceResponse(
request_id=uuid.uuid4().hex,
instance_set=RunInstanceInstanceSet(
instancesSet=RunInstanceInstanceSet(
item=[description],
),
).to_xml(),

View File

@ -61,6 +61,9 @@ def terminate_instances(
instance_disk = config.instances.location / instance_id
instance_disk.unlink()
for volume in instance.ebs_volumes:
volume.instances.remove(instance)
image_ids.add(instance.imageId)
remove_instance_dhcp_mapping(
instance.id, instance.interfaceMac, instance.privateIPv4, db

View File

@ -0,0 +1,17 @@
from pydantic_xml import BaseXmlModel, element
class AttachVolumeResponse(
BaseXmlModel,
tag="AttachVolumeResponse",
nsmap={"": "http://ec2.amazonaws.com/doc/2016-11-15/"},
):
requestId: str = element()
volumeId: str = element()
instanceId: str = element()
device: str = element()
status: str = element()

View File

@ -0,0 +1,10 @@
from pydantic_xml import BaseXmlModel, element
class CreateTagsResponse(
BaseXmlModel,
nsmap={"": ""},
):
requestId: str = element()
return_: bool = element(tag="return")

View File

@ -0,0 +1,15 @@
from pydantic_xml import BaseXmlModel, element
class CreateVolumeResponse(
BaseXmlModel,
tag="CreateVolumeResponse",
nsmap={"": "http://ec2.amazonaws.com/doc/2016-11-15/"},
):
requestId: str = element()
volumeId: str = element()
availabilityZone: str = element()
multiAttachEnabled: bool = element()

View File

@ -0,0 +1,26 @@
from pydantic_xml import BaseXmlModel, element
class InstanceInitiatedShutdownBehaviour(BaseXmlModel):
value: str = element()
class DisableApiTermination(BaseXmlModel):
value: str = element()
class DisableApiStop(BaseXmlModel):
value: str = element()
class DescribeInstanceAttributeResponse(
BaseXmlModel,
tag="DescribeInstanceAttributeResponse",
nsmap={"": "http://ec2.amazonaws.com/doc/2016-11-15/"},
):
requestId: str = element()
instanceId: str = element()
instanceInitiatedShutdownBehaviour: InstanceInitiatedShutdownBehaviour | None = element(defaut=None)
disableApiTermination: DisableApiTermination | None = element(default=None)
disableApiStop: DisableApiStop | None = element(default=None)

View File

@ -0,0 +1,18 @@
from pydantic_xml import BaseXmlModel, element, wrapped
class InstanceTypeInfo(BaseXmlModel, tag="item"):
instanceType: str = element()
class DescribeInstanceTypesResponse(
BaseXmlModel,
nsmap={"": "http://ec2.amazonaws.com/doc/2016-11-15/"},
):
requestId: str = element()
instanceTypeSet: list[InstanceTypeInfo] = wrapped(
"instanceTypeSet",
element(tag="item"),
)
nextToken: str | None = element()

View File

@ -1,4 +1,4 @@
from pydantic_xml import BaseXmlModel, element
from pydantic_xml import BaseXmlModel, wrapped, element
import libvirt
from openec2.db.instance import Instance
@ -21,7 +21,7 @@ class InstanceDescription(
instance_id: str = element(tag="instanceId")
image_id: str = element(tag="imageId")
instance_state: InstanceState = element(tag="instanceState")
tag_set: InstanceTagSet = element(tag="tagSet")
tagSet: list[InstanceTag] = wrapped("tagSet", element(tag="item"))
class ReservationSetInstancesSet(BaseXmlModel):
@ -64,13 +64,11 @@ def describe_instance(
instance_id=instance.id,
image_id=instance.imageId,
instance_state=describe_instance_state(domain),
tag_set=InstanceTagSet(
item=[
InstanceTag(
key=key,
value=value,
)
for key, value in instance.tags.items()
],
),
tagSet=[
InstanceTag(
key=key,
value=value,
)
for key, value in instance.tags.items()
],
)

View File

@ -0,0 +1,22 @@
from pydantic_xml import BaseXmlModel, wrapped, element
class TagDescription(BaseXmlModel):
key: str | None = element(default=None)
resourceId: str | None = element(default=None)
resourceType: str | None = element(default=None)
value: str | None = element(default=None)
class DescribeTagsResponse(
BaseXmlModel,
nsmap={"": "http://ec2.amazonaws.com/doc/2016-11-15/"}
):
requestId: str = element()
nextToken: str | None = element()
tagSet: list[TagDescription] = wrapped("tagSet", element(tag="item"))

View File

@ -0,0 +1,20 @@
from pydantic_xml import BaseXmlModel, element
class Volume(BaseXmlModel):
volumeId: str = element()
multiAttachEnabled: bool = element()
class VolumeSet(BaseXmlModel):
item: list[Volume] = element(tag="item")
class DescribeVolumesResponse(
BaseXmlModel,
tag="DescribeVolumesResponse",
nsmap={"": "http://ec2.amazonaws.com/doc/2016-11-15/"},
):
requestId: str = element()
volumeSet: VolumeSet = element()

View File

@ -0,0 +1,15 @@
from pydantic_xml import BaseXmlModel, element
class DetachVolumeResponse(
BaseXmlModel,
tag="DetachVolumeResponse",
nsmap={"": "http://ec2.amazonaws.com/doc/2016-11-15/"},
):
requestId: str = element()
volumeId: str = element()
instanceId: str = element()
status: str = element()

View File

@ -0,0 +1,18 @@
from pydantic_xml import BaseXmlModel, element
class GetCallerIdentityResult(BaseXmlModel):
arn: str = element(tag="Arn")
user_id: str = element(tag="UserId")
account: str = element(tag="Account")
class ResponseMetadata(BaseXmlModel):
request_id: str = element(tag="RequestId")
class GetCallerIdentityResponse(
BaseXmlModel,
nsmap={"": "https://sts.amazonaws.com/doc/2011-06-15/"}
):
result: GetCallerIdentityResult = element(tag="GetCallerIdentityResult")
metadata: ResponseMetadata = element()

View File

@ -10,4 +10,4 @@ class RunInstanceInstanceSet(BaseXmlModel):
class RunInstanceResponse(BaseXmlModel):
request_id: str = element(tag="requestId")
instance_set: RunInstanceInstanceSet = element(tag="instanceSet")
instancesSet: RunInstanceInstanceSet = element(tag="instancesSet")

View File

@ -12,7 +12,7 @@ class _OpenEC2InstanceType(BaseModel):
class _OpenEC2InstanceConfig(BaseModel):
location: Path
volumes: Path
types: dict[str, _OpenEC2InstanceType]
@ -45,6 +45,7 @@ def _get_config() -> _OpenEC2Config:
seed=Path("/home/alexander/openec2/seed"),
instances=_OpenEC2InstanceConfig(
location=Path("/home/alexander/openec2/instances"),
volumes=Path("/home/alexander/openec2/volumes"),
types={
"micro": _OpenEC2InstanceType(
memory=1024,
@ -58,7 +59,7 @@ def _get_config() -> _OpenEC2Config:
insecure=False,
database=_OpenEC2DatabaseConfig(
url="sqlite:////home/alexander/openec2/db2.sqlite3",
debug=True,
debug=False,
),
)

View File

@ -1,4 +1,35 @@
from sqlmodel import SQLModel, Field, JSON, Column
from pathlib import Path
from sqlmodel import SQLModel, Field, JSON, Column, Relationship
from openec2.config import OpenEC2Config
class EBSVolumeInstanceLink(SQLModel, table=True):
instance_id: str | None = Field(
default=None, foreign_key="instance.id", primary_key=True
)
ebs_volume_id: str | None = Field(
default=None, foreign_key="ebsvolume.id", primary_key=True
)
class EBSVolume(SQLModel, table=True):
id: str = Field(primary_key=True)
availability_zone: str
multi_attach_enabled: bool
instances: list["Instance"] = Relationship(
back_populates="ebs_volumes", link_model=EBSVolumeInstanceLink
)
owner_id: int = Field(foreign_key="user.id")
def path(self, config: OpenEC2Config) -> Path:
"""Compute the path of the volume on disk."""
return config.instances.volumes / self.id
class Instance(SQLModel, table=True):
@ -7,6 +38,8 @@ class Instance(SQLModel, table=True):
# Tags associated with the VM
tags: dict = Field(sa_column=Column(JSON), default={})
instanceType: str
# ImageID of the used AMI
imageId: str
@ -21,3 +54,8 @@ class Instance(SQLModel, table=True):
# The owner that creatd the resource.
owner_id: int = Field(foreign_key="user.id")
# Attached EBS volumes
ebs_volumes: list[EBSVolume] = Relationship(
back_populates="instances", link_model=EBSVolumeInstanceLink
)

View File

@ -1,12 +1,12 @@
from sqlmodel import SQLModel, Field, PrimaryKeyConstraint
from sqlmodel import SQLModel, Field
class VPC(SQLModel, table=True):
# ID of the VPC
id: str = Field(default=None, primary_key=True)
# Subnet mask
subnet: str
# Base IPv4
ipv4_base: str
cidr: str
# Owning user
owner_id: int = Field(foreign_key="user.id")

View File

@ -18,7 +18,17 @@ from openec2.actions.terminate_instances import terminate_instances
from openec2.actions.start_instances import start_instances
from openec2.actions.stop_instances import stop_instances
from openec2.actions.deregister_image import deregister_image
from openec2.actions.create_volume import create_volume
from openec2.actions.attach_volume import attach_volume
from openec2.actions.describe_volumes import describe_volumes
from openec2.actions.detach_volume import detach_volume
from openec2.db.instance import Instance
from openec2.actions.get_caller_identity import get_caller_identity
from openec2.actions.describe_instance_types import describe_instance_types
from openec2.actions.describe_tags import describe_tags
from openec2.actions.create_tags import create_tags
from openec2.actions.describe_instance_attribute import describe_instance_attribute
app = FastAPI()
@ -39,6 +49,7 @@ def healthz():
def get_action(
request: Request, config: OpenEC2Config, db: DatabaseDep, user: AWSSignature
):
print("GET Action")
return run_action(
request,
config,
@ -52,9 +63,13 @@ def get_action(
async def post_action(
request: Request, config: OpenEC2Config, db: DatabaseDep, user: AWSSignature
):
print("POST Action")
body = (await request.body()).decode()
print(f"--> {body}")
query_params = {
key: value[0]
for key, value in parse_qs((await request.body()).decode()).items()
for key, value in parse_qs(body).items()
}
return run_action(
request,
@ -72,9 +87,8 @@ def run_action(
query_params: dict[str, str],
user: User,
):
print(query_params)
action = query_params["Action"]
return {
action_function = {
"ImportImage": import_image,
"DescribeImages": describe_images,
"RunInstances": run_instances,
@ -83,7 +97,25 @@ def run_action(
"StartInstances": start_instances,
"StopInstances": stop_instances,
"DeregisterImage": deregister_image,
}[action](query_params, config, db, user)
"CreateVolume": create_volume,
"AttachVolume": attach_volume,
"DescribeVolumes": describe_volumes,
"DetachVolume": detach_volume,
"GetCallerIdentity": get_caller_identity,
"DescribeInstanceTypes": describe_instance_types,
"DescribeTags": describe_tags,
"CreateTags": create_tags,
"DescribeInstanceAttribute": describe_instance_attribute,
}.get(action)
if action_function is None:
print(f"Unknown action: '{action}'")
raise HTTPException(
status_code=404,
detail="Unknown action",
)
return action_function(query_params, config, db, user)
@app.get("/private/cloudinit/{instance_id}/{entry}")

View File

@ -2,7 +2,7 @@ from typing import Annotated, cast
from dataclasses import dataclass
import datetime
from hashlib import sha256
from urllib.parse import quote, parse_qs
from urllib.parse import parse_qs
from sqlmodel import select
from fastapi import Request, HTTPException, Depends
@ -37,13 +37,20 @@ class AWSRequest:
# The payload, if we used a POST/PUT
payload: str | None
# List of headers that were signed
signed_headers: list[str]
def include_in_canonical_string(self, name: str) -> bool:
return name.lower() in self.signed_headers
# TODO: Probably vulnerable against a replay because we never get the current date ourselves
def sign(
self, secret_access_key: str, region: str, product: str, credential_scope: str
) -> str:
dt = datetime.datetime.fromisoformat(self.headers["X-Amz-Date"])
canonical_header_string_keys = sorted(
[name for name in self.headers.keys() if include_in_canonical_string(name)]
[name for name in self.headers.keys() if self.include_in_canonical_string(name)]
)
canonical_header_string = (
"\n".join(
@ -67,8 +74,6 @@ class AWSRequest:
hashed_payload,
]
)
print("Canonical request")
print(canonical_request)
hashed_canonical_request = sha256(canonical_request.encode()).hexdigest()
date = dt.strftime("%Y%m%d")
@ -81,9 +86,6 @@ class AWSRequest:
]
)
print("String to sign")
print(string_to_sign)
date_key = _hmac_sha256(f"AWS4{secret_access_key}".encode(), date.encode())
date_region_key = _hmac_sha256(date_key, region.encode())
date_region_service_key = _hmac_sha256(date_region_key, product.encode())
@ -99,10 +101,7 @@ class AWSAuthentication:
x_amz_signature: str
def include_in_canonical_string(name: str) -> bool:
lower = name.lower()
return lower in ("host", "content-type") or lower.startswith("x-amz")
signed_headers: str
def get_authentication_info(request: Request) -> AWSAuthentication:
@ -117,12 +116,14 @@ def get_authentication_info(request: Request) -> AWSAuthentication:
x_amz_algorithm=algorithm,
x_amz_credential=auth["Credential"],
x_amz_signature=auth["Signature"],
signed_headers=auth["SignedHeaders"],
)
return AWSAuthentication(
"",
"",
"",
"",
)
@ -142,7 +143,7 @@ async def check_request_signature(request: Request, db: DatabaseDep):
if auth_info.x_amz_algorithm != "AWS4-HMAC-SHA256":
raise HTTPException(
status_code=400,
status_code=403,
detail=f"Invalid signature algorithm: {auth_info.x_amz_algorithm}",
)
@ -154,6 +155,16 @@ async def check_request_signature(request: Request, db: DatabaseDep):
if user is None:
raise HTTPException(status_code=403)
# Validate the signed headers
signed_headers = auth_info.signed_headers.split(";")
if any(header not in signed_headers for header in [
"host",
"content-type",
]):
print("Validation of signed headers failed!")
raise HTTPException(status_code=403)
print(auth_info.x_amz_credential)
x_amz_signature = auth_info.x_amz_signature
signature = AWSRequest(
url=request.url,
@ -161,6 +172,7 @@ async def check_request_signature(request: Request, db: DatabaseDep):
params=query_params,
headers=request.headers,
payload=body,
signed_headers=signed_headers,
).sign(
user.secret_access_key,
region,
@ -168,9 +180,11 @@ async def check_request_signature(request: Request, db: DatabaseDep):
"/".join([date, region, service, key]),
)
print(x_amz_signature, signature)
if x_amz_signature != signature:
raise HTTPException(status_code=401)
print("Signature mismatch!")
print(f"Expected: {signature}")
print(f"Got: {x_amz_signature}")
raise HTTPException(status_code=403)
return user

View File

@ -1,3 +1,6 @@
from typing import Callable
def parse_array_objects(prefix: str, params: dict) -> list[dict[str, str]]:
items: dict[str, dict[str, str]] = {}
for key, value in params.items():
@ -25,3 +28,10 @@ def parse_array_plain(prefix: str, params: dict[str, str]) -> list[str]:
indices = sorted(list(items.keys()))
return [items[key] for key in indices]
def find[T](l: list[T], pred: Callable[[T], bool]) -> T | None:
for item in l:
if pred(item):
return item
return None

View File

@ -0,0 +1,86 @@
from openec2.config import OpenEC2Config
from openec2.db.instance import EBSVolume, Instance
def ebs_volume_to_libvirt_xml(volume: EBSVolume, config: OpenEC2Config) -> str:
# TODO: Honour the attached device name
return f"""
<filesystem type='mount' accessmode='passthrough'>
<driver type='virtiofs' queue='1024' />
<source dir='{config.instances.volumes / volume.id}' />
<target dir='{volume.id}' />
</filesystem>
"""
def instance_to_libvirt_xml(
instance: Instance,
config: OpenEC2Config,
uuid: str | None = None,
) -> str:
instance_type = config.instances.types[instance.instanceType]
ami_path = config.instances.location / instance.id
memory_backing = (
"""
<memoryBacking>
<source type='memfd' />
<access mode='shared' />
</memoryBacking>
"""
if instance.ebs_volumes
else ""
)
volumes = "\n".join(
ebs_volume_to_libvirt_xml(volume, config) for volume in instance.ebs_volumes
)
uuid_element = f"<uuid>{uuid}</uuid>" if uuid is not None else ""
return f"""<domain type='kvm'>
{uuid_element}
<name>{instance.id}</name>
<memory unit='MiB'>{instance_type.memory}</memory>
{memory_backing}
<vcpu placement='static'>{int(instance_type.vcpu)}</vcpu>
<os>
<type arch='x86_64'>hvm</type>
<boot dev='hd' />
<smbios mode='sysinfo' />
</os>
<sysinfo type='smbios'>
<system>
<entry name='serial'>ds=nocloud;s=http://192.168.122.1:8000/private/cloudinit/{instance.id}/</entry>
</system>
</sysinfo>
<features>
<acpi />
<apic />
<vmport state='off' />
</features>
<clock offset='utc'>
<timer name='rtc' tickpolicy='catchup'/>
<timer name='pit' tickpolicy='delay'/>
<timer name='hpet' present='no'/>
</clock>
<pm>
<suspend-to-mem enabled='no'/>
<suspend-to-disk enabled='no'/>
</pm>
<devices>
{volumes}
<disk type='file' device='disk'>
<driver name='qemu' type='qcow2'/>
<source file='{ami_path}'/>
<target dev='vda' bus='virtio'/>
</disk>
<rng model="virtio">
<backend model="random">/dev/urandom</backend>
</rng>
<interface type="network">
<source network="default"/>
<mac address="{instance.interfaceMac}" />
<model type="virtio"/>
</interface>
</devices>
</domain>
"""