Make Terraform/tofu work

This commit is contained in:
PapaTutuWawa 2025-04-06 14:55:58 +02:00
parent 697c89bb4f
commit 38d37a7d5b
23 changed files with 484 additions and 38 deletions

6
.gitignore vendored
View File

@ -8,3 +8,9 @@ wheels/
# Virtual environments
.venv
# Terraform/Tofu
examples/tofu/**/.terraform
examples/tofu/**/.terraform.lock.hcl
examples/tofu/**/.terraform.tfstate
examples/tofu/**/.terraform.tfstate.backup

22
examples/tofu/main.tf Normal file
View File

@ -0,0 +1,22 @@
terraform {
required_providers {
aws = {
source = "hashicorp/aws"
version = "~> 5.0"
}
}
}
provider "aws" {
region = "eu-west-1"
}
resource "aws_instance" "test-instance" {
ami = "0c4dcaafb6a14dbb93b402f1fd6a9dfb"
instance_type = "micro"
availability_zone = "az-1"
tags = {
TestTag = "TestValue"
}
}

View File

@ -0,0 +1 @@
{"version":4,"terraform_version":"1.9.0","serial":8,"lineage":"a013da38-6954-7573-33dd-c05f6b0ec61f","outputs":{},"resources":[{"mode":"managed","type":"aws_instance","name":"test-instance","provider":"provider[\"registry.opentofu.org/hashicorp/aws\"]","instances":[{"schema_version":1,"attributes":{"ami":"0c4dcaafb6a14dbb93b402f1fd6a9dfb","arn":"arn:aws:ec2:eu-west-1:1:instance/19b606d3c0b543c991f2b0cd632013a2","associate_public_ip_address":false,"availability_zone":"az-1","capacity_reservation_specification":[],"cpu_core_count":null,"cpu_options":[],"cpu_threads_per_core":null,"credit_specification":[],"disable_api_stop":false,"disable_api_termination":false,"ebs_block_device":[],"ebs_optimized":false,"enable_primary_ipv6":null,"enclave_options":[],"ephemeral_block_device":[],"get_password_data":false,"hibernation":null,"host_id":null,"host_resource_group_arn":null,"iam_instance_profile":"","id":"19b606d3c0b543c991f2b0cd632013a2","instance_initiated_shutdown_behavior":null,"instance_lifecycle":"","instance_market_options":[],"instance_state":"running","instance_type":"","ipv6_address_count":0,"ipv6_addresses":[],"key_name":"","launch_template":[],"maintenance_options":[],"metadata_options":[],"monitoring":null,"network_interface":[],"outpost_arn":"","password_data":"","placement_group":null,"placement_partition_number":null,"primary_network_interface_id":"","private_dns":"","private_dns_name_options":[],"private_ip":"","public_dns":"","public_ip":"","root_block_device":[],"secondary_private_ips":[],"security_groups":[],"source_dest_check":true,"spot_instance_request_id":"","subnet_id":"","tags":{"TestTag":"TestValue"},"tags_all":{"TestTag":"TestValue"},"tenancy":null,"timeouts":null,"user_data":null,"user_data_base64":null,"user_data_replace_on_change":false,"volume_tags":null,"vpc_security_group_ids":[]},"sensitive_attributes":[],"private":"eyJlMmJmYjczMC1lY2FhLTExZTYtOGY4OC0zNDM2M2JjN2M0YzAiOnsiY3JlYXRlIjo2MDAwMDAwMDAwMDAsImRlbGV0ZSI6MTIwMDAwMDAwMDAwMCwicmVhZCI6OTAwMDAwMDAwMDAwLCJ1cGRhdGUiOjYwMDAwMDAwMDAwMH0sInNjaGVtYV92ZXJzaW9uIjoiMSJ9"}]}],"check_results":null}

View File

@ -0,0 +1 @@
{"version":4,"terraform_version":"1.9.0","serial":7,"lineage":"a013da38-6954-7573-33dd-c05f6b0ec61f","outputs":{},"resources":[],"check_results":null}

View File

@ -0,0 +1,51 @@
from dataclasses import dataclass
from typing import cast
import uuid
from fastapi import Response
from fastapi.datastructures import QueryParams
from sqlmodel import select
from openec2.config import OpenEC2Config
from openec2.db import DatabaseDep
from openec2.db.user import User
from openec2.db.instance import Instance
from openec2.utils.array import parse_array_objects, parse_array_plain
from openec2.api.create_tags import CreateTagsResponse
@dataclass
class Tag:
key: str
value: str
def create_tags(
params: QueryParams,
config: OpenEC2Config,
db: DatabaseDep,
user: User,
):
tags: dict[str, str] = {}
for tag in parse_array_objects("Tag", cast(dict, params)):
tags[tag["Key"]] = tag["Value"]
for instance_id in parse_array_plain("ResourceId", cast(dict, params)):
instance = db.exec(select(Instance).where(Instance.id == instance_id, Instance.owner_id == user.id)).first()
if instance is None:
print(f"Unknown instance {instance_id}")
continue
instance.tags = {
**instance.tags,
**tags,
}
print(instance)
db.commit()
return Response(
CreateTagsResponse(
requestId=uuid.uuid4().hex,
return_=True,
).to_xml(),
media_type="application/xml",
)

View File

@ -0,0 +1,17 @@
from fastapi import Response, HTTPException
from fastapi.datastructures import QueryParams
from sqlmodel import select
from openec2.db.user import User
from openec2.config import OpenEC2Config
from openec2.db import DatabaseDep
from openec2.db.vpc import VPC
def create_vpc(
params: QueryParams,
config: OpenEC2Config,
db: DatabaseDep,
user: User,
):
cidr_block = params["CidrBlock"]

View File

@ -0,0 +1,43 @@
import uuid
from fastapi import Response, HTTPException
from fastapi.datastructures import QueryParams
from sqlmodel import select
from openec2.config import OpenEC2Config
from openec2.db import DatabaseDep
from openec2.db.user import User
from openec2.db.instance import Instance
from openec2.api.describe_instance_attribute import DescribeInstanceAttributeResponse, InstanceInitiatedShutdownBehaviour, DisableApiTermination, DisableApiStop
def describe_instance_attribute(
params: QueryParams,
config: OpenEC2Config,
db: DatabaseDep,
user: User,
):
instance_id = params["InstanceId"]
attribute = params["Attribute"]
instance = db.exec(select(Instance).where(Instance.id == instance_id, Instance.owner_id == user.id)).first()
if instance is None:
raise HTTPException(
status_code=404,
)
return Response(
DescribeInstanceAttributeResponse(
requestId=uuid.uuid4().hex,
instanceId=instance.id,
instanceInitiatedShutdownBehaviour=InstanceInitiatedShutdownBehaviour(
value="stop",
) if attribute == "instanceInitiatedShutdownBehaviour" else None,
disableApiTermination=DisableApiTermination(
value="false",
) if attribute == "disableApiTermination" else None,
disableApiStop=DisableApiStop(
value="false",
) if attribute == "disableApiStop" else None,
).to_xml(),
media_type="application/xml",
)

View File

@ -0,0 +1,36 @@
import uuid
from fastapi import Response
from fastapi.datastructures import QueryParams
from sqlmodel import select
from openec2.config import OpenEC2Config
from openec2.db import DatabaseDep
from openec2.db.user import User
from openec2.db.instance import Instance
from openec2.api.describe_instance_types import DescribeInstanceTypesResponse, InstanceTypeInfo
def describe_instance_types(
params: QueryParams,
config: OpenEC2Config,
db: DatabaseDep,
user: User,
):
response: list[InstanceTypeInfo] = []
instances = db.exec(select(Instance).where(Instance.owner_id == user.id)).all()
for instance in instances:
response.append(
InstanceTypeInfo(
instanceType=instance.instanceType,
),
)
return Response(
DescribeInstanceTypesResponse(
requestId=uuid.uuid4().hex,
instanceTypeSet=response,
nextToken=None,
).to_xml(),
media_type="application/xml",
)

View File

@ -0,0 +1,99 @@
from typing import Literal, cast
from dataclasses import dataclass
import uuid
from fastapi import Response
from fastapi.datastructures import QueryParams
from sqlmodel import select
from openec2.config import OpenEC2Config
from openec2.db import DatabaseDep
from openec2.db.user import User
from openec2.db.instance import Instance
from openec2.api.describe_tags import DescribeTagsResponse, TagDescription
from openec2.utils.array import parse_array_objects, parse_array_plain, find
@dataclass
class Filter:
name: Literal["resource-id"] | Literal["resource-type"] | Literal["key"] | str
values: list[str]
def match(self, instance: Instance) -> bool:
if self.name not in instance.tags:
return False
value: str | None
if self.name == "resource-type":
value = "instance"
elif self.name == "resource-id":
value = instance.id
elif self.name == "key":
return any(key in instance.tags for key in self.values)
elif self.name.startswith("tag:"):
value = instance.tags.get(self.name.replace("tag:", ""))
else:
raise Exception(f"Unknown filter name {self.name}")
return value in self.values
def value(self, instance: Instance) -> TagDescription:
if self.name == "resource-type":
return TagDescription(resourceType="instance")
elif self.name == "resource-id":
return TagDescription(resourceId=instance.id)
elif self.name == "key":
key_name = find(
self.values,
lambda x: x in instance.tags,
)
assert key_name is not None
key_value = instance.tags[key_name]
return TagDescription(
key=key_name,
value=key_value,
)
elif self.name.startswith("tag:"):
tag = self.name.replace("tag:", "")
return TagDescription(
key=tag,
value=instance.tags[tag],
)
else:
raise Exception(f"Unknown filter name {self.name}")
def describe_tags(
params: QueryParams,
config: OpenEC2Config,
db: DatabaseDep,
user: User,
):
filters: list[Filter] = []
for filter in parse_array_objects("Filter", cast(dict, params)):
filters.append(
Filter(
name=filter["Name"],
values=parse_array_plain("Values", filter),
)
)
response: list[TagDescription] = []
for instance in db.exec(select(Instance).where(Instance.owner_id == user.id)).all():
filter = find(
filters,
lambda f: f.match(instance),
)
if filter is None:
continue
response.append(filter.value(instance))
return Response(
DescribeTagsResponse(
requestId=uuid.uuid4().hex,
nextToken=None,
tagSet=response,
).to_xml(),
media_type="application/xml",
)

View File

@ -0,0 +1,30 @@
import uuid
from fastapi import Response
from fastapi.datastructures import QueryParams
from openec2.config import OpenEC2Config
from openec2.db import DatabaseDep
from openec2.db.user import User
from openec2.api.get_caller_identity import GetCallerIdentityResponse, GetCallerIdentityResult, ResponseMetadata
def get_caller_identity(
params: QueryParams,
config: OpenEC2Config,
db: DatabaseDep,
user: User,
):
return Response(
GetCallerIdentityResponse(
result=GetCallerIdentityResult(
arn=f"arn:aws:iam::{user.id}:user/{user.name}",
user_id=str(user.id),
account=str(user.id),
),
metadata=ResponseMetadata(
request_id=uuid.uuid4().hex,
),
).to_xml(),
media_type="application/xml",
)

View File

@ -115,11 +115,11 @@ def run_instances(
)
# Parse tags
# TODO: broken
tags: dict[str, str] = {}
for spec in parse_array_objects("TagSpecification", cast(dict, params)):
for raw_tag in parse_array_objects("Tag", spec):
tags[raw_tag["Key"]] = raw_tag["Value"]
print(f"Creating with tags {tags}")
# Get a private IPv4
instance_id = uuid.uuid4().hex
@ -176,7 +176,7 @@ def run_instances(
return Response(
RunInstanceResponse(
request_id=uuid.uuid4().hex,
instance_set=RunInstanceInstanceSet(
instancesSet=RunInstanceInstanceSet(
item=[description],
),
).to_xml(),

View File

@ -0,0 +1,10 @@
from pydantic_xml import BaseXmlModel, element
class CreateTagsResponse(
BaseXmlModel,
nsmap={"": ""},
):
requestId: str = element()
return_: bool = element(tag="return")

View File

@ -0,0 +1,26 @@
from pydantic_xml import BaseXmlModel, element
class InstanceInitiatedShutdownBehaviour(BaseXmlModel):
value: str = element()
class DisableApiTermination(BaseXmlModel):
value: str = element()
class DisableApiStop(BaseXmlModel):
value: str = element()
class DescribeInstanceAttributeResponse(
BaseXmlModel,
tag="DescribeInstanceAttributeResponse",
nsmap={"": "http://ec2.amazonaws.com/doc/2016-11-15/"},
):
requestId: str = element()
instanceId: str = element()
instanceInitiatedShutdownBehaviour: InstanceInitiatedShutdownBehaviour | None = element(defaut=None)
disableApiTermination: DisableApiTermination | None = element(default=None)
disableApiStop: DisableApiStop | None = element(default=None)

View File

@ -0,0 +1,18 @@
from pydantic_xml import BaseXmlModel, element, wrapped
class InstanceTypeInfo(BaseXmlModel, tag="item"):
instanceType: str = element()
class DescribeInstanceTypesResponse(
BaseXmlModel,
nsmap={"": "http://ec2.amazonaws.com/doc/2016-11-15/"},
):
requestId: str = element()
instanceTypeSet: list[InstanceTypeInfo] = wrapped(
"instanceTypeSet",
element(tag="item"),
)
nextToken: str | None = element()

View File

@ -1,4 +1,4 @@
from pydantic_xml import BaseXmlModel, element
from pydantic_xml import BaseXmlModel, wrapped, element
import libvirt
from openec2.db.instance import Instance
@ -21,7 +21,7 @@ class InstanceDescription(
instance_id: str = element(tag="instanceId")
image_id: str = element(tag="imageId")
instance_state: InstanceState = element(tag="instanceState")
tag_set: InstanceTagSet = element(tag="tagSet")
tagSet: list[InstanceTag] = wrapped("tagSet", element(tag="item"))
class ReservationSetInstancesSet(BaseXmlModel):
@ -64,13 +64,11 @@ def describe_instance(
instance_id=instance.id,
image_id=instance.imageId,
instance_state=describe_instance_state(domain),
tag_set=InstanceTagSet(
item=[
tagSet=[
InstanceTag(
key=key,
value=value,
)
for key, value in instance.tags.items()
],
),
)

View File

@ -0,0 +1,22 @@
from pydantic_xml import BaseXmlModel, wrapped, element
class TagDescription(BaseXmlModel):
key: str | None = element(default=None)
resourceId: str | None = element(default=None)
resourceType: str | None = element(default=None)
value: str | None = element(default=None)
class DescribeTagsResponse(
BaseXmlModel,
nsmap={"": "http://ec2.amazonaws.com/doc/2016-11-15/"}
):
requestId: str = element()
nextToken: str | None = element()
tagSet: list[TagDescription] = wrapped("tagSet", element(tag="item"))

View File

@ -0,0 +1,18 @@
from pydantic_xml import BaseXmlModel, element
class GetCallerIdentityResult(BaseXmlModel):
arn: str = element(tag="Arn")
user_id: str = element(tag="UserId")
account: str = element(tag="Account")
class ResponseMetadata(BaseXmlModel):
request_id: str = element(tag="RequestId")
class GetCallerIdentityResponse(
BaseXmlModel,
nsmap={"": "https://sts.amazonaws.com/doc/2011-06-15/"}
):
result: GetCallerIdentityResult = element(tag="GetCallerIdentityResult")
metadata: ResponseMetadata = element()

View File

@ -10,4 +10,4 @@ class RunInstanceInstanceSet(BaseXmlModel):
class RunInstanceResponse(BaseXmlModel):
request_id: str = element(tag="requestId")
instance_set: RunInstanceInstanceSet = element(tag="instanceSet")
instancesSet: RunInstanceInstanceSet = element(tag="instancesSet")

View File

@ -59,7 +59,7 @@ def _get_config() -> _OpenEC2Config:
insecure=False,
database=_OpenEC2DatabaseConfig(
url="sqlite:////home/alexander/openec2/db2.sqlite3",
debug=True,
debug=False,
),
)

View File

@ -1,12 +1,12 @@
from sqlmodel import SQLModel, Field, PrimaryKeyConstraint
from sqlmodel import SQLModel, Field
class VPC(SQLModel, table=True):
# ID of the VPC
id: str = Field(default=None, primary_key=True)
# Subnet mask
subnet: str
# Base IPv4
ipv4_base: str
cidr: str
# Owning user
owner_id: int = Field(foreign_key="user.id")

View File

@ -23,6 +23,12 @@ from openec2.actions.attach_volume import attach_volume
from openec2.actions.describe_volumes import describe_volumes
from openec2.actions.detach_volume import detach_volume
from openec2.db.instance import Instance
from openec2.actions.get_caller_identity import get_caller_identity
from openec2.actions.describe_instance_types import describe_instance_types
from openec2.actions.describe_tags import describe_tags
from openec2.actions.create_tags import create_tags
from openec2.actions.describe_instance_attribute import describe_instance_attribute
app = FastAPI()
@ -43,6 +49,7 @@ def healthz():
def get_action(
request: Request, config: OpenEC2Config, db: DatabaseDep, user: AWSSignature
):
print("GET Action")
return run_action(
request,
config,
@ -56,9 +63,13 @@ def get_action(
async def post_action(
request: Request, config: OpenEC2Config, db: DatabaseDep, user: AWSSignature
):
print("POST Action")
body = (await request.body()).decode()
print(f"--> {body}")
query_params = {
key: value[0]
for key, value in parse_qs((await request.body()).decode()).items()
for key, value in parse_qs(body).items()
}
return run_action(
request,
@ -76,9 +87,8 @@ def run_action(
query_params: dict[str, str],
user: User,
):
print(query_params)
action = query_params["Action"]
return {
action_function = {
"ImportImage": import_image,
"DescribeImages": describe_images,
"RunInstances": run_instances,
@ -91,7 +101,21 @@ def run_action(
"AttachVolume": attach_volume,
"DescribeVolumes": describe_volumes,
"DetachVolume": detach_volume,
}[action](query_params, config, db, user)
"GetCallerIdentity": get_caller_identity,
"DescribeInstanceTypes": describe_instance_types,
"DescribeTags": describe_tags,
"CreateTags": create_tags,
"DescribeInstanceAttribute": describe_instance_attribute,
}.get(action)
if action_function is None:
print(f"Unknown action: '{action}'")
raise HTTPException(
status_code=404,
detail="Unknown action",
)
return action_function(query_params, config, db, user)
@app.get("/private/cloudinit/{instance_id}/{entry}")

View File

@ -2,7 +2,7 @@ from typing import Annotated, cast
from dataclasses import dataclass
import datetime
from hashlib import sha256
from urllib.parse import quote, parse_qs
from urllib.parse import parse_qs
from sqlmodel import select
from fastapi import Request, HTTPException, Depends
@ -37,13 +37,20 @@ class AWSRequest:
# The payload, if we used a POST/PUT
payload: str | None
# List of headers that were signed
signed_headers: list[str]
def include_in_canonical_string(self, name: str) -> bool:
return name.lower() in self.signed_headers
# TODO: Probably vulnerable against a replay because we never get the current date ourselves
def sign(
self, secret_access_key: str, region: str, product: str, credential_scope: str
) -> str:
dt = datetime.datetime.fromisoformat(self.headers["X-Amz-Date"])
canonical_header_string_keys = sorted(
[name for name in self.headers.keys() if include_in_canonical_string(name)]
[name for name in self.headers.keys() if self.include_in_canonical_string(name)]
)
canonical_header_string = (
"\n".join(
@ -67,8 +74,6 @@ class AWSRequest:
hashed_payload,
]
)
print("Canonical request")
print(canonical_request)
hashed_canonical_request = sha256(canonical_request.encode()).hexdigest()
date = dt.strftime("%Y%m%d")
@ -81,9 +86,6 @@ class AWSRequest:
]
)
print("String to sign")
print(string_to_sign)
date_key = _hmac_sha256(f"AWS4{secret_access_key}".encode(), date.encode())
date_region_key = _hmac_sha256(date_key, region.encode())
date_region_service_key = _hmac_sha256(date_region_key, product.encode())
@ -99,10 +101,7 @@ class AWSAuthentication:
x_amz_signature: str
def include_in_canonical_string(name: str) -> bool:
lower = name.lower()
return lower in ("host", "content-type") or lower.startswith("x-amz")
signed_headers: str
def get_authentication_info(request: Request) -> AWSAuthentication:
@ -117,12 +116,14 @@ def get_authentication_info(request: Request) -> AWSAuthentication:
x_amz_algorithm=algorithm,
x_amz_credential=auth["Credential"],
x_amz_signature=auth["Signature"],
signed_headers=auth["SignedHeaders"],
)
return AWSAuthentication(
"",
"",
"",
"",
)
@ -142,7 +143,7 @@ async def check_request_signature(request: Request, db: DatabaseDep):
if auth_info.x_amz_algorithm != "AWS4-HMAC-SHA256":
raise HTTPException(
status_code=400,
status_code=403,
detail=f"Invalid signature algorithm: {auth_info.x_amz_algorithm}",
)
@ -154,6 +155,16 @@ async def check_request_signature(request: Request, db: DatabaseDep):
if user is None:
raise HTTPException(status_code=403)
# Validate the signed headers
signed_headers = auth_info.signed_headers.split(";")
if any(header not in signed_headers for header in [
"host",
"content-type",
]):
print("Validation of signed headers failed!")
raise HTTPException(status_code=403)
print(auth_info.x_amz_credential)
x_amz_signature = auth_info.x_amz_signature
signature = AWSRequest(
url=request.url,
@ -161,6 +172,7 @@ async def check_request_signature(request: Request, db: DatabaseDep):
params=query_params,
headers=request.headers,
payload=body,
signed_headers=signed_headers,
).sign(
user.secret_access_key,
region,
@ -168,9 +180,11 @@ async def check_request_signature(request: Request, db: DatabaseDep):
"/".join([date, region, service, key]),
)
print(x_amz_signature, signature)
if x_amz_signature != signature:
raise HTTPException(status_code=401)
print("Signature mismatch!")
print(f"Expected: {signature}")
print(f"Got: {x_amz_signature}")
raise HTTPException(status_code=403)
return user

View File

@ -1,3 +1,6 @@
from typing import Callable
def parse_array_objects(prefix: str, params: dict) -> list[dict[str, str]]:
items: dict[str, dict[str, str]] = {}
for key, value in params.items():
@ -25,3 +28,10 @@ def parse_array_plain(prefix: str, params: dict[str, str]) -> list[str]:
indices = sorted(list(items.keys()))
return [items[key] for key in indices]
def find[T](l: list[T], pred: Callable[[T], bool]) -> T | None:
for item in l:
if pred(item):
return item
return None