diff --git a/.gitignore b/.gitignore index 505a3b1..ddaec25 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,9 @@ wheels/ # Virtual environments .venv + +# Terraform/Tofu +examples/tofu/**/.terraform +examples/tofu/**/.terraform.lock.hcl +examples/tofu/**/.terraform.tfstate +examples/tofu/**/.terraform.tfstate.backup diff --git a/examples/tofu/main.tf b/examples/tofu/main.tf new file mode 100644 index 0000000..e6e772f --- /dev/null +++ b/examples/tofu/main.tf @@ -0,0 +1,22 @@ +terraform { + required_providers { + aws = { + source = "hashicorp/aws" + version = "~> 5.0" + } + } +} + +provider "aws" { + region = "eu-west-1" +} + +resource "aws_instance" "test-instance" { + ami = "0c4dcaafb6a14dbb93b402f1fd6a9dfb" + instance_type = "micro" + availability_zone = "az-1" + + tags = { + TestTag = "TestValue" + } +} diff --git a/examples/tofu/terraform.tfstate b/examples/tofu/terraform.tfstate new file mode 100644 index 0000000..da837b1 --- /dev/null +++ b/examples/tofu/terraform.tfstate @@ -0,0 +1 @@ +{"version":4,"terraform_version":"1.9.0","serial":8,"lineage":"a013da38-6954-7573-33dd-c05f6b0ec61f","outputs":{},"resources":[{"mode":"managed","type":"aws_instance","name":"test-instance","provider":"provider[\"registry.opentofu.org/hashicorp/aws\"]","instances":[{"schema_version":1,"attributes":{"ami":"0c4dcaafb6a14dbb93b402f1fd6a9dfb","arn":"arn:aws:ec2:eu-west-1:1:instance/19b606d3c0b543c991f2b0cd632013a2","associate_public_ip_address":false,"availability_zone":"az-1","capacity_reservation_specification":[],"cpu_core_count":null,"cpu_options":[],"cpu_threads_per_core":null,"credit_specification":[],"disable_api_stop":false,"disable_api_termination":false,"ebs_block_device":[],"ebs_optimized":false,"enable_primary_ipv6":null,"enclave_options":[],"ephemeral_block_device":[],"get_password_data":false,"hibernation":null,"host_id":null,"host_resource_group_arn":null,"iam_instance_profile":"","id":"19b606d3c0b543c991f2b0cd632013a2","instance_initiated_shutdown_behavior":null,"instance_lifecycle":"","instance_market_options":[],"instance_state":"running","instance_type":"","ipv6_address_count":0,"ipv6_addresses":[],"key_name":"","launch_template":[],"maintenance_options":[],"metadata_options":[],"monitoring":null,"network_interface":[],"outpost_arn":"","password_data":"","placement_group":null,"placement_partition_number":null,"primary_network_interface_id":"","private_dns":"","private_dns_name_options":[],"private_ip":"","public_dns":"","public_ip":"","root_block_device":[],"secondary_private_ips":[],"security_groups":[],"source_dest_check":true,"spot_instance_request_id":"","subnet_id":"","tags":{"TestTag":"TestValue"},"tags_all":{"TestTag":"TestValue"},"tenancy":null,"timeouts":null,"user_data":null,"user_data_base64":null,"user_data_replace_on_change":false,"volume_tags":null,"vpc_security_group_ids":[]},"sensitive_attributes":[],"private":"eyJlMmJmYjczMC1lY2FhLTExZTYtOGY4OC0zNDM2M2JjN2M0YzAiOnsiY3JlYXRlIjo2MDAwMDAwMDAwMDAsImRlbGV0ZSI6MTIwMDAwMDAwMDAwMCwicmVhZCI6OTAwMDAwMDAwMDAwLCJ1cGRhdGUiOjYwMDAwMDAwMDAwMH0sInNjaGVtYV92ZXJzaW9uIjoiMSJ9"}]}],"check_results":null} diff --git a/examples/tofu/terraform.tfstate.backup b/examples/tofu/terraform.tfstate.backup new file mode 100644 index 0000000..a71968e --- /dev/null +++ b/examples/tofu/terraform.tfstate.backup @@ -0,0 +1 @@ +{"version":4,"terraform_version":"1.9.0","serial":7,"lineage":"a013da38-6954-7573-33dd-c05f6b0ec61f","outputs":{},"resources":[],"check_results":null} diff --git a/src/openec2/actions/create_tags.py b/src/openec2/actions/create_tags.py new file mode 100644 index 0000000..decf99e --- /dev/null +++ b/src/openec2/actions/create_tags.py @@ -0,0 +1,51 @@ +from dataclasses import dataclass +from typing import cast +import uuid + +from fastapi import Response +from fastapi.datastructures import QueryParams +from sqlmodel import select + +from openec2.config import OpenEC2Config +from openec2.db import DatabaseDep +from openec2.db.user import User +from openec2.db.instance import Instance +from openec2.utils.array import parse_array_objects, parse_array_plain +from openec2.api.create_tags import CreateTagsResponse + + +@dataclass +class Tag: + key: str + value: str + + +def create_tags( + params: QueryParams, + config: OpenEC2Config, + db: DatabaseDep, + user: User, +): + tags: dict[str, str] = {} + for tag in parse_array_objects("Tag", cast(dict, params)): + tags[tag["Key"]] = tag["Value"] + + for instance_id in parse_array_plain("ResourceId", cast(dict, params)): + instance = db.exec(select(Instance).where(Instance.id == instance_id, Instance.owner_id == user.id)).first() + if instance is None: + print(f"Unknown instance {instance_id}") + continue + instance.tags = { + **instance.tags, + **tags, + } + print(instance) + + db.commit() + return Response( + CreateTagsResponse( + requestId=uuid.uuid4().hex, + return_=True, + ).to_xml(), + media_type="application/xml", + ) diff --git a/src/openec2/actions/create_vpc.py b/src/openec2/actions/create_vpc.py new file mode 100644 index 0000000..f4b5935 --- /dev/null +++ b/src/openec2/actions/create_vpc.py @@ -0,0 +1,17 @@ +from fastapi import Response, HTTPException +from fastapi.datastructures import QueryParams +from sqlmodel import select + +from openec2.db.user import User +from openec2.config import OpenEC2Config +from openec2.db import DatabaseDep +from openec2.db.vpc import VPC + + +def create_vpc( + params: QueryParams, + config: OpenEC2Config, + db: DatabaseDep, + user: User, +): + cidr_block = params["CidrBlock"] diff --git a/src/openec2/actions/describe_instance_attribute.py b/src/openec2/actions/describe_instance_attribute.py new file mode 100644 index 0000000..03dca47 --- /dev/null +++ b/src/openec2/actions/describe_instance_attribute.py @@ -0,0 +1,43 @@ +import uuid + +from fastapi import Response, HTTPException +from fastapi.datastructures import QueryParams +from sqlmodel import select + +from openec2.config import OpenEC2Config +from openec2.db import DatabaseDep +from openec2.db.user import User +from openec2.db.instance import Instance +from openec2.api.describe_instance_attribute import DescribeInstanceAttributeResponse, InstanceInitiatedShutdownBehaviour, DisableApiTermination, DisableApiStop + + +def describe_instance_attribute( + params: QueryParams, + config: OpenEC2Config, + db: DatabaseDep, + user: User, +): + instance_id = params["InstanceId"] + attribute = params["Attribute"] + instance = db.exec(select(Instance).where(Instance.id == instance_id, Instance.owner_id == user.id)).first() + if instance is None: + raise HTTPException( + status_code=404, + ) + + return Response( + DescribeInstanceAttributeResponse( + requestId=uuid.uuid4().hex, + instanceId=instance.id, + instanceInitiatedShutdownBehaviour=InstanceInitiatedShutdownBehaviour( + value="stop", + ) if attribute == "instanceInitiatedShutdownBehaviour" else None, + disableApiTermination=DisableApiTermination( + value="false", + ) if attribute == "disableApiTermination" else None, + disableApiStop=DisableApiStop( + value="false", + ) if attribute == "disableApiStop" else None, + ).to_xml(), + media_type="application/xml", + ) diff --git a/src/openec2/actions/describe_instance_types.py b/src/openec2/actions/describe_instance_types.py new file mode 100644 index 0000000..33820cc --- /dev/null +++ b/src/openec2/actions/describe_instance_types.py @@ -0,0 +1,36 @@ +import uuid + +from fastapi import Response +from fastapi.datastructures import QueryParams +from sqlmodel import select + +from openec2.config import OpenEC2Config +from openec2.db import DatabaseDep +from openec2.db.user import User +from openec2.db.instance import Instance +from openec2.api.describe_instance_types import DescribeInstanceTypesResponse, InstanceTypeInfo + + +def describe_instance_types( + params: QueryParams, + config: OpenEC2Config, + db: DatabaseDep, + user: User, +): + response: list[InstanceTypeInfo] = [] + instances = db.exec(select(Instance).where(Instance.owner_id == user.id)).all() + for instance in instances: + response.append( + InstanceTypeInfo( + instanceType=instance.instanceType, + ), + ) + + return Response( + DescribeInstanceTypesResponse( + requestId=uuid.uuid4().hex, + instanceTypeSet=response, + nextToken=None, + ).to_xml(), + media_type="application/xml", + ) diff --git a/src/openec2/actions/describe_tags.py b/src/openec2/actions/describe_tags.py new file mode 100644 index 0000000..2dc8d44 --- /dev/null +++ b/src/openec2/actions/describe_tags.py @@ -0,0 +1,99 @@ +from typing import Literal, cast +from dataclasses import dataclass +import uuid + +from fastapi import Response +from fastapi.datastructures import QueryParams +from sqlmodel import select + +from openec2.config import OpenEC2Config +from openec2.db import DatabaseDep +from openec2.db.user import User +from openec2.db.instance import Instance +from openec2.api.describe_tags import DescribeTagsResponse, TagDescription +from openec2.utils.array import parse_array_objects, parse_array_plain, find + + +@dataclass +class Filter: + name: Literal["resource-id"] | Literal["resource-type"] | Literal["key"] | str + values: list[str] + + def match(self, instance: Instance) -> bool: + if self.name not in instance.tags: + return False + + value: str | None + if self.name == "resource-type": + value = "instance" + elif self.name == "resource-id": + value = instance.id + elif self.name == "key": + return any(key in instance.tags for key in self.values) + elif self.name.startswith("tag:"): + value = instance.tags.get(self.name.replace("tag:", "")) + else: + raise Exception(f"Unknown filter name {self.name}") + + return value in self.values + + def value(self, instance: Instance) -> TagDescription: + if self.name == "resource-type": + return TagDescription(resourceType="instance") + elif self.name == "resource-id": + return TagDescription(resourceId=instance.id) + elif self.name == "key": + key_name = find( + self.values, + lambda x: x in instance.tags, + ) + assert key_name is not None + + key_value = instance.tags[key_name] + + return TagDescription( + key=key_name, + value=key_value, + ) + elif self.name.startswith("tag:"): + tag = self.name.replace("tag:", "") + return TagDescription( + key=tag, + value=instance.tags[tag], + ) + else: + raise Exception(f"Unknown filter name {self.name}") + +def describe_tags( + params: QueryParams, + config: OpenEC2Config, + db: DatabaseDep, + user: User, +): + filters: list[Filter] = [] + for filter in parse_array_objects("Filter", cast(dict, params)): + filters.append( + Filter( + name=filter["Name"], + values=parse_array_plain("Values", filter), + ) + ) + + response: list[TagDescription] = [] + for instance in db.exec(select(Instance).where(Instance.owner_id == user.id)).all(): + filter = find( + filters, + lambda f: f.match(instance), + ) + if filter is None: + continue + response.append(filter.value(instance)) + + return Response( + DescribeTagsResponse( + requestId=uuid.uuid4().hex, + nextToken=None, + tagSet=response, + ).to_xml(), + media_type="application/xml", + ) diff --git a/src/openec2/actions/get_caller_identity.py b/src/openec2/actions/get_caller_identity.py new file mode 100644 index 0000000..976e438 --- /dev/null +++ b/src/openec2/actions/get_caller_identity.py @@ -0,0 +1,30 @@ +import uuid + +from fastapi import Response +from fastapi.datastructures import QueryParams + +from openec2.config import OpenEC2Config +from openec2.db import DatabaseDep +from openec2.db.user import User +from openec2.api.get_caller_identity import GetCallerIdentityResponse, GetCallerIdentityResult, ResponseMetadata + + +def get_caller_identity( + params: QueryParams, + config: OpenEC2Config, + db: DatabaseDep, + user: User, +): + return Response( + GetCallerIdentityResponse( + result=GetCallerIdentityResult( + arn=f"arn:aws:iam::{user.id}:user/{user.name}", + user_id=str(user.id), + account=str(user.id), + ), + metadata=ResponseMetadata( + request_id=uuid.uuid4().hex, + ), + ).to_xml(), + media_type="application/xml", + ) diff --git a/src/openec2/actions/run_instances.py b/src/openec2/actions/run_instances.py index 6c23824..14c30f3 100644 --- a/src/openec2/actions/run_instances.py +++ b/src/openec2/actions/run_instances.py @@ -115,11 +115,11 @@ def run_instances( ) # Parse tags - # TODO: broken tags: dict[str, str] = {} for spec in parse_array_objects("TagSpecification", cast(dict, params)): for raw_tag in parse_array_objects("Tag", spec): tags[raw_tag["Key"]] = raw_tag["Value"] + print(f"Creating with tags {tags}") # Get a private IPv4 instance_id = uuid.uuid4().hex @@ -176,7 +176,7 @@ def run_instances( return Response( RunInstanceResponse( request_id=uuid.uuid4().hex, - instance_set=RunInstanceInstanceSet( + instancesSet=RunInstanceInstanceSet( item=[description], ), ).to_xml(), diff --git a/src/openec2/api/create_tags.py b/src/openec2/api/create_tags.py new file mode 100644 index 0000000..4597325 --- /dev/null +++ b/src/openec2/api/create_tags.py @@ -0,0 +1,10 @@ +from pydantic_xml import BaseXmlModel, element + + +class CreateTagsResponse( + BaseXmlModel, + nsmap={"": ""}, +): + requestId: str = element() + + return_: bool = element(tag="return") diff --git a/src/openec2/api/describe_instance_attribute.py b/src/openec2/api/describe_instance_attribute.py new file mode 100644 index 0000000..7c4ef40 --- /dev/null +++ b/src/openec2/api/describe_instance_attribute.py @@ -0,0 +1,26 @@ +from pydantic_xml import BaseXmlModel, element + + +class InstanceInitiatedShutdownBehaviour(BaseXmlModel): + value: str = element() + +class DisableApiTermination(BaseXmlModel): + value: str = element() + +class DisableApiStop(BaseXmlModel): + value: str = element() + +class DescribeInstanceAttributeResponse( + BaseXmlModel, + tag="DescribeInstanceAttributeResponse", + nsmap={"": "http://ec2.amazonaws.com/doc/2016-11-15/"}, +): + requestId: str = element() + + instanceId: str = element() + + instanceInitiatedShutdownBehaviour: InstanceInitiatedShutdownBehaviour | None = element(defaut=None) + + disableApiTermination: DisableApiTermination | None = element(default=None) + + disableApiStop: DisableApiStop | None = element(default=None) diff --git a/src/openec2/api/describe_instance_types.py b/src/openec2/api/describe_instance_types.py new file mode 100644 index 0000000..d4f5508 --- /dev/null +++ b/src/openec2/api/describe_instance_types.py @@ -0,0 +1,18 @@ +from pydantic_xml import BaseXmlModel, element, wrapped + + +class InstanceTypeInfo(BaseXmlModel, tag="item"): + instanceType: str = element() + +class DescribeInstanceTypesResponse( + BaseXmlModel, + nsmap={"": "http://ec2.amazonaws.com/doc/2016-11-15/"}, +): + requestId: str = element() + + instanceTypeSet: list[InstanceTypeInfo] = wrapped( + "instanceTypeSet", + element(tag="item"), + ) + + nextToken: str | None = element() diff --git a/src/openec2/api/describe_instances.py b/src/openec2/api/describe_instances.py index 3633664..e259a5e 100644 --- a/src/openec2/api/describe_instances.py +++ b/src/openec2/api/describe_instances.py @@ -1,4 +1,4 @@ -from pydantic_xml import BaseXmlModel, element +from pydantic_xml import BaseXmlModel, wrapped, element import libvirt from openec2.db.instance import Instance @@ -21,7 +21,7 @@ class InstanceDescription( instance_id: str = element(tag="instanceId") image_id: str = element(tag="imageId") instance_state: InstanceState = element(tag="instanceState") - tag_set: InstanceTagSet = element(tag="tagSet") + tagSet: list[InstanceTag] = wrapped("tagSet", element(tag="item")) class ReservationSetInstancesSet(BaseXmlModel): @@ -64,13 +64,11 @@ def describe_instance( instance_id=instance.id, image_id=instance.imageId, instance_state=describe_instance_state(domain), - tag_set=InstanceTagSet( - item=[ - InstanceTag( - key=key, - value=value, - ) - for key, value in instance.tags.items() - ], - ), + tagSet=[ + InstanceTag( + key=key, + value=value, + ) + for key, value in instance.tags.items() + ], ) diff --git a/src/openec2/api/describe_tags.py b/src/openec2/api/describe_tags.py new file mode 100644 index 0000000..4d02a96 --- /dev/null +++ b/src/openec2/api/describe_tags.py @@ -0,0 +1,22 @@ +from pydantic_xml import BaseXmlModel, wrapped, element + + +class TagDescription(BaseXmlModel): + key: str | None = element(default=None) + + resourceId: str | None = element(default=None) + + resourceType: str | None = element(default=None) + + value: str | None = element(default=None) + + +class DescribeTagsResponse( + BaseXmlModel, + nsmap={"": "http://ec2.amazonaws.com/doc/2016-11-15/"} +): + requestId: str = element() + + nextToken: str | None = element() + + tagSet: list[TagDescription] = wrapped("tagSet", element(tag="item")) diff --git a/src/openec2/api/get_caller_identity.py b/src/openec2/api/get_caller_identity.py new file mode 100644 index 0000000..ea92364 --- /dev/null +++ b/src/openec2/api/get_caller_identity.py @@ -0,0 +1,18 @@ +from pydantic_xml import BaseXmlModel, element + + +class GetCallerIdentityResult(BaseXmlModel): + arn: str = element(tag="Arn") + user_id: str = element(tag="UserId") + account: str = element(tag="Account") + +class ResponseMetadata(BaseXmlModel): + request_id: str = element(tag="RequestId") + +class GetCallerIdentityResponse( + BaseXmlModel, + nsmap={"": "https://sts.amazonaws.com/doc/2011-06-15/"} +): + result: GetCallerIdentityResult = element(tag="GetCallerIdentityResult") + + metadata: ResponseMetadata = element() diff --git a/src/openec2/api/run_instances.py b/src/openec2/api/run_instances.py index 5e37066..97ae48f 100644 --- a/src/openec2/api/run_instances.py +++ b/src/openec2/api/run_instances.py @@ -10,4 +10,4 @@ class RunInstanceInstanceSet(BaseXmlModel): class RunInstanceResponse(BaseXmlModel): request_id: str = element(tag="requestId") - instance_set: RunInstanceInstanceSet = element(tag="instanceSet") + instancesSet: RunInstanceInstanceSet = element(tag="instancesSet") diff --git a/src/openec2/config.py b/src/openec2/config.py index 0557413..092ff85 100644 --- a/src/openec2/config.py +++ b/src/openec2/config.py @@ -59,7 +59,7 @@ def _get_config() -> _OpenEC2Config: insecure=False, database=_OpenEC2DatabaseConfig( url="sqlite:////home/alexander/openec2/db2.sqlite3", - debug=True, + debug=False, ), ) diff --git a/src/openec2/db/vpc.py b/src/openec2/db/vpc.py index 478f60b..b913e90 100644 --- a/src/openec2/db/vpc.py +++ b/src/openec2/db/vpc.py @@ -1,12 +1,12 @@ -from sqlmodel import SQLModel, Field, PrimaryKeyConstraint +from sqlmodel import SQLModel, Field class VPC(SQLModel, table=True): # ID of the VPC id: str = Field(default=None, primary_key=True) - # Subnet mask - subnet: str - # Base IPv4 - ipv4_base: str + cidr: str + + # Owning user + owner_id: int = Field(foreign_key="user.id") diff --git a/src/openec2/main.py b/src/openec2/main.py index 6838532..0af39a2 100644 --- a/src/openec2/main.py +++ b/src/openec2/main.py @@ -23,6 +23,12 @@ from openec2.actions.attach_volume import attach_volume from openec2.actions.describe_volumes import describe_volumes from openec2.actions.detach_volume import detach_volume from openec2.db.instance import Instance +from openec2.actions.get_caller_identity import get_caller_identity +from openec2.actions.describe_instance_types import describe_instance_types +from openec2.actions.describe_tags import describe_tags +from openec2.actions.create_tags import create_tags +from openec2.actions.describe_instance_attribute import describe_instance_attribute + app = FastAPI() @@ -43,6 +49,7 @@ def healthz(): def get_action( request: Request, config: OpenEC2Config, db: DatabaseDep, user: AWSSignature ): + print("GET Action") return run_action( request, config, @@ -56,9 +63,13 @@ def get_action( async def post_action( request: Request, config: OpenEC2Config, db: DatabaseDep, user: AWSSignature ): + print("POST Action") + + body = (await request.body()).decode() + print(f"--> {body}") query_params = { key: value[0] - for key, value in parse_qs((await request.body()).decode()).items() + for key, value in parse_qs(body).items() } return run_action( request, @@ -76,9 +87,8 @@ def run_action( query_params: dict[str, str], user: User, ): - print(query_params) action = query_params["Action"] - return { + action_function = { "ImportImage": import_image, "DescribeImages": describe_images, "RunInstances": run_instances, @@ -91,7 +101,21 @@ def run_action( "AttachVolume": attach_volume, "DescribeVolumes": describe_volumes, "DetachVolume": detach_volume, - }[action](query_params, config, db, user) + "GetCallerIdentity": get_caller_identity, + "DescribeInstanceTypes": describe_instance_types, + "DescribeTags": describe_tags, + "CreateTags": create_tags, + "DescribeInstanceAttribute": describe_instance_attribute, + }.get(action) + if action_function is None: + print(f"Unknown action: '{action}'") + raise HTTPException( + status_code=404, + detail="Unknown action", + ) + + return action_function(query_params, config, db, user) + @app.get("/private/cloudinit/{instance_id}/{entry}") diff --git a/src/openec2/security/aws.py b/src/openec2/security/aws.py index 25128a7..034b50f 100644 --- a/src/openec2/security/aws.py +++ b/src/openec2/security/aws.py @@ -2,7 +2,7 @@ from typing import Annotated, cast from dataclasses import dataclass import datetime from hashlib import sha256 -from urllib.parse import quote, parse_qs +from urllib.parse import parse_qs from sqlmodel import select from fastapi import Request, HTTPException, Depends @@ -37,13 +37,20 @@ class AWSRequest: # The payload, if we used a POST/PUT payload: str | None + # List of headers that were signed + signed_headers: list[str] + + def include_in_canonical_string(self, name: str) -> bool: + return name.lower() in self.signed_headers + + # TODO: Probably vulnerable against a replay because we never get the current date ourselves def sign( self, secret_access_key: str, region: str, product: str, credential_scope: str ) -> str: dt = datetime.datetime.fromisoformat(self.headers["X-Amz-Date"]) canonical_header_string_keys = sorted( - [name for name in self.headers.keys() if include_in_canonical_string(name)] + [name for name in self.headers.keys() if self.include_in_canonical_string(name)] ) canonical_header_string = ( "\n".join( @@ -67,8 +74,6 @@ class AWSRequest: hashed_payload, ] ) - print("Canonical request") - print(canonical_request) hashed_canonical_request = sha256(canonical_request.encode()).hexdigest() date = dt.strftime("%Y%m%d") @@ -81,9 +86,6 @@ class AWSRequest: ] ) - print("String to sign") - print(string_to_sign) - date_key = _hmac_sha256(f"AWS4{secret_access_key}".encode(), date.encode()) date_region_key = _hmac_sha256(date_key, region.encode()) date_region_service_key = _hmac_sha256(date_region_key, product.encode()) @@ -99,10 +101,7 @@ class AWSAuthentication: x_amz_signature: str - -def include_in_canonical_string(name: str) -> bool: - lower = name.lower() - return lower in ("host", "content-type") or lower.startswith("x-amz") + signed_headers: str def get_authentication_info(request: Request) -> AWSAuthentication: @@ -117,12 +116,14 @@ def get_authentication_info(request: Request) -> AWSAuthentication: x_amz_algorithm=algorithm, x_amz_credential=auth["Credential"], x_amz_signature=auth["Signature"], + signed_headers=auth["SignedHeaders"], ) return AWSAuthentication( "", "", "", + "", ) @@ -142,7 +143,7 @@ async def check_request_signature(request: Request, db: DatabaseDep): if auth_info.x_amz_algorithm != "AWS4-HMAC-SHA256": raise HTTPException( - status_code=400, + status_code=403, detail=f"Invalid signature algorithm: {auth_info.x_amz_algorithm}", ) @@ -154,6 +155,16 @@ async def check_request_signature(request: Request, db: DatabaseDep): if user is None: raise HTTPException(status_code=403) + # Validate the signed headers + signed_headers = auth_info.signed_headers.split(";") + if any(header not in signed_headers for header in [ + "host", + "content-type", + ]): + print("Validation of signed headers failed!") + raise HTTPException(status_code=403) + + print(auth_info.x_amz_credential) x_amz_signature = auth_info.x_amz_signature signature = AWSRequest( url=request.url, @@ -161,6 +172,7 @@ async def check_request_signature(request: Request, db: DatabaseDep): params=query_params, headers=request.headers, payload=body, + signed_headers=signed_headers, ).sign( user.secret_access_key, region, @@ -168,9 +180,11 @@ async def check_request_signature(request: Request, db: DatabaseDep): "/".join([date, region, service, key]), ) - print(x_amz_signature, signature) if x_amz_signature != signature: - raise HTTPException(status_code=401) + print("Signature mismatch!") + print(f"Expected: {signature}") + print(f"Got: {x_amz_signature}") + raise HTTPException(status_code=403) return user diff --git a/src/openec2/utils/array.py b/src/openec2/utils/array.py index 60ab58a..cab6a6a 100644 --- a/src/openec2/utils/array.py +++ b/src/openec2/utils/array.py @@ -1,3 +1,6 @@ +from typing import Callable + + def parse_array_objects(prefix: str, params: dict) -> list[dict[str, str]]: items: dict[str, dict[str, str]] = {} for key, value in params.items(): @@ -25,3 +28,10 @@ def parse_array_plain(prefix: str, params: dict[str, str]) -> list[str]: indices = sorted(list(items.keys())) return [items[key] for key in indices] + + +def find[T](l: list[T], pred: Callable[[T], bool]) -> T | None: + for item in l: + if pred(item): + return item + return None