Make Terraform/tofu work

This commit is contained in:
PapaTutuWawa 2025-04-06 14:55:58 +02:00
parent 697c89bb4f
commit 38d37a7d5b
23 changed files with 484 additions and 38 deletions

6
.gitignore vendored
View File

@ -8,3 +8,9 @@ wheels/
# Virtual environments # Virtual environments
.venv .venv
# Terraform/Tofu
examples/tofu/**/.terraform
examples/tofu/**/.terraform.lock.hcl
examples/tofu/**/.terraform.tfstate
examples/tofu/**/.terraform.tfstate.backup

22
examples/tofu/main.tf Normal file
View File

@ -0,0 +1,22 @@
terraform {
required_providers {
aws = {
source = "hashicorp/aws"
version = "~> 5.0"
}
}
}
provider "aws" {
region = "eu-west-1"
}
resource "aws_instance" "test-instance" {
ami = "0c4dcaafb6a14dbb93b402f1fd6a9dfb"
instance_type = "micro"
availability_zone = "az-1"
tags = {
TestTag = "TestValue"
}
}

View File

@ -0,0 +1 @@
{"version":4,"terraform_version":"1.9.0","serial":8,"lineage":"a013da38-6954-7573-33dd-c05f6b0ec61f","outputs":{},"resources":[{"mode":"managed","type":"aws_instance","name":"test-instance","provider":"provider[\"registry.opentofu.org/hashicorp/aws\"]","instances":[{"schema_version":1,"attributes":{"ami":"0c4dcaafb6a14dbb93b402f1fd6a9dfb","arn":"arn:aws:ec2:eu-west-1:1:instance/19b606d3c0b543c991f2b0cd632013a2","associate_public_ip_address":false,"availability_zone":"az-1","capacity_reservation_specification":[],"cpu_core_count":null,"cpu_options":[],"cpu_threads_per_core":null,"credit_specification":[],"disable_api_stop":false,"disable_api_termination":false,"ebs_block_device":[],"ebs_optimized":false,"enable_primary_ipv6":null,"enclave_options":[],"ephemeral_block_device":[],"get_password_data":false,"hibernation":null,"host_id":null,"host_resource_group_arn":null,"iam_instance_profile":"","id":"19b606d3c0b543c991f2b0cd632013a2","instance_initiated_shutdown_behavior":null,"instance_lifecycle":"","instance_market_options":[],"instance_state":"running","instance_type":"","ipv6_address_count":0,"ipv6_addresses":[],"key_name":"","launch_template":[],"maintenance_options":[],"metadata_options":[],"monitoring":null,"network_interface":[],"outpost_arn":"","password_data":"","placement_group":null,"placement_partition_number":null,"primary_network_interface_id":"","private_dns":"","private_dns_name_options":[],"private_ip":"","public_dns":"","public_ip":"","root_block_device":[],"secondary_private_ips":[],"security_groups":[],"source_dest_check":true,"spot_instance_request_id":"","subnet_id":"","tags":{"TestTag":"TestValue"},"tags_all":{"TestTag":"TestValue"},"tenancy":null,"timeouts":null,"user_data":null,"user_data_base64":null,"user_data_replace_on_change":false,"volume_tags":null,"vpc_security_group_ids":[]},"sensitive_attributes":[],"private":"eyJlMmJmYjczMC1lY2FhLTExZTYtOGY4OC0zNDM2M2JjN2M0YzAiOnsiY3JlYXRlIjo2MDAwMDAwMDAwMDAsImRlbGV0ZSI6MTIwMDAwMDAwMDAwMCwicmVhZCI6OTAwMDAwMDAwMDAwLCJ1cGRhdGUiOjYwMDAwMDAwMDAwMH0sInNjaGVtYV92ZXJzaW9uIjoiMSJ9"}]}],"check_results":null}

View File

@ -0,0 +1 @@
{"version":4,"terraform_version":"1.9.0","serial":7,"lineage":"a013da38-6954-7573-33dd-c05f6b0ec61f","outputs":{},"resources":[],"check_results":null}

View File

@ -0,0 +1,51 @@
from dataclasses import dataclass
from typing import cast
import uuid
from fastapi import Response
from fastapi.datastructures import QueryParams
from sqlmodel import select
from openec2.config import OpenEC2Config
from openec2.db import DatabaseDep
from openec2.db.user import User
from openec2.db.instance import Instance
from openec2.utils.array import parse_array_objects, parse_array_plain
from openec2.api.create_tags import CreateTagsResponse
@dataclass
class Tag:
key: str
value: str
def create_tags(
params: QueryParams,
config: OpenEC2Config,
db: DatabaseDep,
user: User,
):
tags: dict[str, str] = {}
for tag in parse_array_objects("Tag", cast(dict, params)):
tags[tag["Key"]] = tag["Value"]
for instance_id in parse_array_plain("ResourceId", cast(dict, params)):
instance = db.exec(select(Instance).where(Instance.id == instance_id, Instance.owner_id == user.id)).first()
if instance is None:
print(f"Unknown instance {instance_id}")
continue
instance.tags = {
**instance.tags,
**tags,
}
print(instance)
db.commit()
return Response(
CreateTagsResponse(
requestId=uuid.uuid4().hex,
return_=True,
).to_xml(),
media_type="application/xml",
)

View File

@ -0,0 +1,17 @@
from fastapi import Response, HTTPException
from fastapi.datastructures import QueryParams
from sqlmodel import select
from openec2.db.user import User
from openec2.config import OpenEC2Config
from openec2.db import DatabaseDep
from openec2.db.vpc import VPC
def create_vpc(
params: QueryParams,
config: OpenEC2Config,
db: DatabaseDep,
user: User,
):
cidr_block = params["CidrBlock"]

View File

@ -0,0 +1,43 @@
import uuid
from fastapi import Response, HTTPException
from fastapi.datastructures import QueryParams
from sqlmodel import select
from openec2.config import OpenEC2Config
from openec2.db import DatabaseDep
from openec2.db.user import User
from openec2.db.instance import Instance
from openec2.api.describe_instance_attribute import DescribeInstanceAttributeResponse, InstanceInitiatedShutdownBehaviour, DisableApiTermination, DisableApiStop
def describe_instance_attribute(
params: QueryParams,
config: OpenEC2Config,
db: DatabaseDep,
user: User,
):
instance_id = params["InstanceId"]
attribute = params["Attribute"]
instance = db.exec(select(Instance).where(Instance.id == instance_id, Instance.owner_id == user.id)).first()
if instance is None:
raise HTTPException(
status_code=404,
)
return Response(
DescribeInstanceAttributeResponse(
requestId=uuid.uuid4().hex,
instanceId=instance.id,
instanceInitiatedShutdownBehaviour=InstanceInitiatedShutdownBehaviour(
value="stop",
) if attribute == "instanceInitiatedShutdownBehaviour" else None,
disableApiTermination=DisableApiTermination(
value="false",
) if attribute == "disableApiTermination" else None,
disableApiStop=DisableApiStop(
value="false",
) if attribute == "disableApiStop" else None,
).to_xml(),
media_type="application/xml",
)

View File

@ -0,0 +1,36 @@
import uuid
from fastapi import Response
from fastapi.datastructures import QueryParams
from sqlmodel import select
from openec2.config import OpenEC2Config
from openec2.db import DatabaseDep
from openec2.db.user import User
from openec2.db.instance import Instance
from openec2.api.describe_instance_types import DescribeInstanceTypesResponse, InstanceTypeInfo
def describe_instance_types(
params: QueryParams,
config: OpenEC2Config,
db: DatabaseDep,
user: User,
):
response: list[InstanceTypeInfo] = []
instances = db.exec(select(Instance).where(Instance.owner_id == user.id)).all()
for instance in instances:
response.append(
InstanceTypeInfo(
instanceType=instance.instanceType,
),
)
return Response(
DescribeInstanceTypesResponse(
requestId=uuid.uuid4().hex,
instanceTypeSet=response,
nextToken=None,
).to_xml(),
media_type="application/xml",
)

View File

@ -0,0 +1,99 @@
from typing import Literal, cast
from dataclasses import dataclass
import uuid
from fastapi import Response
from fastapi.datastructures import QueryParams
from sqlmodel import select
from openec2.config import OpenEC2Config
from openec2.db import DatabaseDep
from openec2.db.user import User
from openec2.db.instance import Instance
from openec2.api.describe_tags import DescribeTagsResponse, TagDescription
from openec2.utils.array import parse_array_objects, parse_array_plain, find
@dataclass
class Filter:
name: Literal["resource-id"] | Literal["resource-type"] | Literal["key"] | str
values: list[str]
def match(self, instance: Instance) -> bool:
if self.name not in instance.tags:
return False
value: str | None
if self.name == "resource-type":
value = "instance"
elif self.name == "resource-id":
value = instance.id
elif self.name == "key":
return any(key in instance.tags for key in self.values)
elif self.name.startswith("tag:"):
value = instance.tags.get(self.name.replace("tag:", ""))
else:
raise Exception(f"Unknown filter name {self.name}")
return value in self.values
def value(self, instance: Instance) -> TagDescription:
if self.name == "resource-type":
return TagDescription(resourceType="instance")
elif self.name == "resource-id":
return TagDescription(resourceId=instance.id)
elif self.name == "key":
key_name = find(
self.values,
lambda x: x in instance.tags,
)
assert key_name is not None
key_value = instance.tags[key_name]
return TagDescription(
key=key_name,
value=key_value,
)
elif self.name.startswith("tag:"):
tag = self.name.replace("tag:", "")
return TagDescription(
key=tag,
value=instance.tags[tag],
)
else:
raise Exception(f"Unknown filter name {self.name}")
def describe_tags(
params: QueryParams,
config: OpenEC2Config,
db: DatabaseDep,
user: User,
):
filters: list[Filter] = []
for filter in parse_array_objects("Filter", cast(dict, params)):
filters.append(
Filter(
name=filter["Name"],
values=parse_array_plain("Values", filter),
)
)
response: list[TagDescription] = []
for instance in db.exec(select(Instance).where(Instance.owner_id == user.id)).all():
filter = find(
filters,
lambda f: f.match(instance),
)
if filter is None:
continue
response.append(filter.value(instance))
return Response(
DescribeTagsResponse(
requestId=uuid.uuid4().hex,
nextToken=None,
tagSet=response,
).to_xml(),
media_type="application/xml",
)

View File

@ -0,0 +1,30 @@
import uuid
from fastapi import Response
from fastapi.datastructures import QueryParams
from openec2.config import OpenEC2Config
from openec2.db import DatabaseDep
from openec2.db.user import User
from openec2.api.get_caller_identity import GetCallerIdentityResponse, GetCallerIdentityResult, ResponseMetadata
def get_caller_identity(
params: QueryParams,
config: OpenEC2Config,
db: DatabaseDep,
user: User,
):
return Response(
GetCallerIdentityResponse(
result=GetCallerIdentityResult(
arn=f"arn:aws:iam::{user.id}:user/{user.name}",
user_id=str(user.id),
account=str(user.id),
),
metadata=ResponseMetadata(
request_id=uuid.uuid4().hex,
),
).to_xml(),
media_type="application/xml",
)

View File

@ -115,11 +115,11 @@ def run_instances(
) )
# Parse tags # Parse tags
# TODO: broken
tags: dict[str, str] = {} tags: dict[str, str] = {}
for spec in parse_array_objects("TagSpecification", cast(dict, params)): for spec in parse_array_objects("TagSpecification", cast(dict, params)):
for raw_tag in parse_array_objects("Tag", spec): for raw_tag in parse_array_objects("Tag", spec):
tags[raw_tag["Key"]] = raw_tag["Value"] tags[raw_tag["Key"]] = raw_tag["Value"]
print(f"Creating with tags {tags}")
# Get a private IPv4 # Get a private IPv4
instance_id = uuid.uuid4().hex instance_id = uuid.uuid4().hex
@ -176,7 +176,7 @@ def run_instances(
return Response( return Response(
RunInstanceResponse( RunInstanceResponse(
request_id=uuid.uuid4().hex, request_id=uuid.uuid4().hex,
instance_set=RunInstanceInstanceSet( instancesSet=RunInstanceInstanceSet(
item=[description], item=[description],
), ),
).to_xml(), ).to_xml(),

View File

@ -0,0 +1,10 @@
from pydantic_xml import BaseXmlModel, element
class CreateTagsResponse(
BaseXmlModel,
nsmap={"": ""},
):
requestId: str = element()
return_: bool = element(tag="return")

View File

@ -0,0 +1,26 @@
from pydantic_xml import BaseXmlModel, element
class InstanceInitiatedShutdownBehaviour(BaseXmlModel):
value: str = element()
class DisableApiTermination(BaseXmlModel):
value: str = element()
class DisableApiStop(BaseXmlModel):
value: str = element()
class DescribeInstanceAttributeResponse(
BaseXmlModel,
tag="DescribeInstanceAttributeResponse",
nsmap={"": "http://ec2.amazonaws.com/doc/2016-11-15/"},
):
requestId: str = element()
instanceId: str = element()
instanceInitiatedShutdownBehaviour: InstanceInitiatedShutdownBehaviour | None = element(defaut=None)
disableApiTermination: DisableApiTermination | None = element(default=None)
disableApiStop: DisableApiStop | None = element(default=None)

View File

@ -0,0 +1,18 @@
from pydantic_xml import BaseXmlModel, element, wrapped
class InstanceTypeInfo(BaseXmlModel, tag="item"):
instanceType: str = element()
class DescribeInstanceTypesResponse(
BaseXmlModel,
nsmap={"": "http://ec2.amazonaws.com/doc/2016-11-15/"},
):
requestId: str = element()
instanceTypeSet: list[InstanceTypeInfo] = wrapped(
"instanceTypeSet",
element(tag="item"),
)
nextToken: str | None = element()

View File

@ -1,4 +1,4 @@
from pydantic_xml import BaseXmlModel, element from pydantic_xml import BaseXmlModel, wrapped, element
import libvirt import libvirt
from openec2.db.instance import Instance from openec2.db.instance import Instance
@ -21,7 +21,7 @@ class InstanceDescription(
instance_id: str = element(tag="instanceId") instance_id: str = element(tag="instanceId")
image_id: str = element(tag="imageId") image_id: str = element(tag="imageId")
instance_state: InstanceState = element(tag="instanceState") instance_state: InstanceState = element(tag="instanceState")
tag_set: InstanceTagSet = element(tag="tagSet") tagSet: list[InstanceTag] = wrapped("tagSet", element(tag="item"))
class ReservationSetInstancesSet(BaseXmlModel): class ReservationSetInstancesSet(BaseXmlModel):
@ -64,13 +64,11 @@ def describe_instance(
instance_id=instance.id, instance_id=instance.id,
image_id=instance.imageId, image_id=instance.imageId,
instance_state=describe_instance_state(domain), instance_state=describe_instance_state(domain),
tag_set=InstanceTagSet( tagSet=[
item=[
InstanceTag( InstanceTag(
key=key, key=key,
value=value, value=value,
) )
for key, value in instance.tags.items() for key, value in instance.tags.items()
], ],
),
) )

View File

@ -0,0 +1,22 @@
from pydantic_xml import BaseXmlModel, wrapped, element
class TagDescription(BaseXmlModel):
key: str | None = element(default=None)
resourceId: str | None = element(default=None)
resourceType: str | None = element(default=None)
value: str | None = element(default=None)
class DescribeTagsResponse(
BaseXmlModel,
nsmap={"": "http://ec2.amazonaws.com/doc/2016-11-15/"}
):
requestId: str = element()
nextToken: str | None = element()
tagSet: list[TagDescription] = wrapped("tagSet", element(tag="item"))

View File

@ -0,0 +1,18 @@
from pydantic_xml import BaseXmlModel, element
class GetCallerIdentityResult(BaseXmlModel):
arn: str = element(tag="Arn")
user_id: str = element(tag="UserId")
account: str = element(tag="Account")
class ResponseMetadata(BaseXmlModel):
request_id: str = element(tag="RequestId")
class GetCallerIdentityResponse(
BaseXmlModel,
nsmap={"": "https://sts.amazonaws.com/doc/2011-06-15/"}
):
result: GetCallerIdentityResult = element(tag="GetCallerIdentityResult")
metadata: ResponseMetadata = element()

View File

@ -10,4 +10,4 @@ class RunInstanceInstanceSet(BaseXmlModel):
class RunInstanceResponse(BaseXmlModel): class RunInstanceResponse(BaseXmlModel):
request_id: str = element(tag="requestId") request_id: str = element(tag="requestId")
instance_set: RunInstanceInstanceSet = element(tag="instanceSet") instancesSet: RunInstanceInstanceSet = element(tag="instancesSet")

View File

@ -59,7 +59,7 @@ def _get_config() -> _OpenEC2Config:
insecure=False, insecure=False,
database=_OpenEC2DatabaseConfig( database=_OpenEC2DatabaseConfig(
url="sqlite:////home/alexander/openec2/db2.sqlite3", url="sqlite:////home/alexander/openec2/db2.sqlite3",
debug=True, debug=False,
), ),
) )

View File

@ -1,12 +1,12 @@
from sqlmodel import SQLModel, Field, PrimaryKeyConstraint from sqlmodel import SQLModel, Field
class VPC(SQLModel, table=True): class VPC(SQLModel, table=True):
# ID of the VPC # ID of the VPC
id: str = Field(default=None, primary_key=True) id: str = Field(default=None, primary_key=True)
# Subnet mask
subnet: str
# Base IPv4 # Base IPv4
ipv4_base: str cidr: str
# Owning user
owner_id: int = Field(foreign_key="user.id")

View File

@ -23,6 +23,12 @@ from openec2.actions.attach_volume import attach_volume
from openec2.actions.describe_volumes import describe_volumes from openec2.actions.describe_volumes import describe_volumes
from openec2.actions.detach_volume import detach_volume from openec2.actions.detach_volume import detach_volume
from openec2.db.instance import Instance from openec2.db.instance import Instance
from openec2.actions.get_caller_identity import get_caller_identity
from openec2.actions.describe_instance_types import describe_instance_types
from openec2.actions.describe_tags import describe_tags
from openec2.actions.create_tags import create_tags
from openec2.actions.describe_instance_attribute import describe_instance_attribute
app = FastAPI() app = FastAPI()
@ -43,6 +49,7 @@ def healthz():
def get_action( def get_action(
request: Request, config: OpenEC2Config, db: DatabaseDep, user: AWSSignature request: Request, config: OpenEC2Config, db: DatabaseDep, user: AWSSignature
): ):
print("GET Action")
return run_action( return run_action(
request, request,
config, config,
@ -56,9 +63,13 @@ def get_action(
async def post_action( async def post_action(
request: Request, config: OpenEC2Config, db: DatabaseDep, user: AWSSignature request: Request, config: OpenEC2Config, db: DatabaseDep, user: AWSSignature
): ):
print("POST Action")
body = (await request.body()).decode()
print(f"--> {body}")
query_params = { query_params = {
key: value[0] key: value[0]
for key, value in parse_qs((await request.body()).decode()).items() for key, value in parse_qs(body).items()
} }
return run_action( return run_action(
request, request,
@ -76,9 +87,8 @@ def run_action(
query_params: dict[str, str], query_params: dict[str, str],
user: User, user: User,
): ):
print(query_params)
action = query_params["Action"] action = query_params["Action"]
return { action_function = {
"ImportImage": import_image, "ImportImage": import_image,
"DescribeImages": describe_images, "DescribeImages": describe_images,
"RunInstances": run_instances, "RunInstances": run_instances,
@ -91,7 +101,21 @@ def run_action(
"AttachVolume": attach_volume, "AttachVolume": attach_volume,
"DescribeVolumes": describe_volumes, "DescribeVolumes": describe_volumes,
"DetachVolume": detach_volume, "DetachVolume": detach_volume,
}[action](query_params, config, db, user) "GetCallerIdentity": get_caller_identity,
"DescribeInstanceTypes": describe_instance_types,
"DescribeTags": describe_tags,
"CreateTags": create_tags,
"DescribeInstanceAttribute": describe_instance_attribute,
}.get(action)
if action_function is None:
print(f"Unknown action: '{action}'")
raise HTTPException(
status_code=404,
detail="Unknown action",
)
return action_function(query_params, config, db, user)
@app.get("/private/cloudinit/{instance_id}/{entry}") @app.get("/private/cloudinit/{instance_id}/{entry}")

View File

@ -2,7 +2,7 @@ from typing import Annotated, cast
from dataclasses import dataclass from dataclasses import dataclass
import datetime import datetime
from hashlib import sha256 from hashlib import sha256
from urllib.parse import quote, parse_qs from urllib.parse import parse_qs
from sqlmodel import select from sqlmodel import select
from fastapi import Request, HTTPException, Depends from fastapi import Request, HTTPException, Depends
@ -37,13 +37,20 @@ class AWSRequest:
# The payload, if we used a POST/PUT # The payload, if we used a POST/PUT
payload: str | None payload: str | None
# List of headers that were signed
signed_headers: list[str]
def include_in_canonical_string(self, name: str) -> bool:
return name.lower() in self.signed_headers
# TODO: Probably vulnerable against a replay because we never get the current date ourselves
def sign( def sign(
self, secret_access_key: str, region: str, product: str, credential_scope: str self, secret_access_key: str, region: str, product: str, credential_scope: str
) -> str: ) -> str:
dt = datetime.datetime.fromisoformat(self.headers["X-Amz-Date"]) dt = datetime.datetime.fromisoformat(self.headers["X-Amz-Date"])
canonical_header_string_keys = sorted( canonical_header_string_keys = sorted(
[name for name in self.headers.keys() if include_in_canonical_string(name)] [name for name in self.headers.keys() if self.include_in_canonical_string(name)]
) )
canonical_header_string = ( canonical_header_string = (
"\n".join( "\n".join(
@ -67,8 +74,6 @@ class AWSRequest:
hashed_payload, hashed_payload,
] ]
) )
print("Canonical request")
print(canonical_request)
hashed_canonical_request = sha256(canonical_request.encode()).hexdigest() hashed_canonical_request = sha256(canonical_request.encode()).hexdigest()
date = dt.strftime("%Y%m%d") date = dt.strftime("%Y%m%d")
@ -81,9 +86,6 @@ class AWSRequest:
] ]
) )
print("String to sign")
print(string_to_sign)
date_key = _hmac_sha256(f"AWS4{secret_access_key}".encode(), date.encode()) date_key = _hmac_sha256(f"AWS4{secret_access_key}".encode(), date.encode())
date_region_key = _hmac_sha256(date_key, region.encode()) date_region_key = _hmac_sha256(date_key, region.encode())
date_region_service_key = _hmac_sha256(date_region_key, product.encode()) date_region_service_key = _hmac_sha256(date_region_key, product.encode())
@ -99,10 +101,7 @@ class AWSAuthentication:
x_amz_signature: str x_amz_signature: str
signed_headers: str
def include_in_canonical_string(name: str) -> bool:
lower = name.lower()
return lower in ("host", "content-type") or lower.startswith("x-amz")
def get_authentication_info(request: Request) -> AWSAuthentication: def get_authentication_info(request: Request) -> AWSAuthentication:
@ -117,12 +116,14 @@ def get_authentication_info(request: Request) -> AWSAuthentication:
x_amz_algorithm=algorithm, x_amz_algorithm=algorithm,
x_amz_credential=auth["Credential"], x_amz_credential=auth["Credential"],
x_amz_signature=auth["Signature"], x_amz_signature=auth["Signature"],
signed_headers=auth["SignedHeaders"],
) )
return AWSAuthentication( return AWSAuthentication(
"", "",
"", "",
"", "",
"",
) )
@ -142,7 +143,7 @@ async def check_request_signature(request: Request, db: DatabaseDep):
if auth_info.x_amz_algorithm != "AWS4-HMAC-SHA256": if auth_info.x_amz_algorithm != "AWS4-HMAC-SHA256":
raise HTTPException( raise HTTPException(
status_code=400, status_code=403,
detail=f"Invalid signature algorithm: {auth_info.x_amz_algorithm}", detail=f"Invalid signature algorithm: {auth_info.x_amz_algorithm}",
) )
@ -154,6 +155,16 @@ async def check_request_signature(request: Request, db: DatabaseDep):
if user is None: if user is None:
raise HTTPException(status_code=403) raise HTTPException(status_code=403)
# Validate the signed headers
signed_headers = auth_info.signed_headers.split(";")
if any(header not in signed_headers for header in [
"host",
"content-type",
]):
print("Validation of signed headers failed!")
raise HTTPException(status_code=403)
print(auth_info.x_amz_credential)
x_amz_signature = auth_info.x_amz_signature x_amz_signature = auth_info.x_amz_signature
signature = AWSRequest( signature = AWSRequest(
url=request.url, url=request.url,
@ -161,6 +172,7 @@ async def check_request_signature(request: Request, db: DatabaseDep):
params=query_params, params=query_params,
headers=request.headers, headers=request.headers,
payload=body, payload=body,
signed_headers=signed_headers,
).sign( ).sign(
user.secret_access_key, user.secret_access_key,
region, region,
@ -168,9 +180,11 @@ async def check_request_signature(request: Request, db: DatabaseDep):
"/".join([date, region, service, key]), "/".join([date, region, service, key]),
) )
print(x_amz_signature, signature)
if x_amz_signature != signature: if x_amz_signature != signature:
raise HTTPException(status_code=401) print("Signature mismatch!")
print(f"Expected: {signature}")
print(f"Got: {x_amz_signature}")
raise HTTPException(status_code=403)
return user return user

View File

@ -1,3 +1,6 @@
from typing import Callable
def parse_array_objects(prefix: str, params: dict) -> list[dict[str, str]]: def parse_array_objects(prefix: str, params: dict) -> list[dict[str, str]]:
items: dict[str, dict[str, str]] = {} items: dict[str, dict[str, str]] = {}
for key, value in params.items(): for key, value in params.items():
@ -25,3 +28,10 @@ def parse_array_plain(prefix: str, params: dict[str, str]) -> list[str]:
indices = sorted(list(items.keys())) indices = sorted(list(items.keys()))
return [items[key] for key in indices] return [items[key] for key in indices]
def find[T](l: list[T], pred: Callable[[T], bool]) -> T | None:
for item in l:
if pred(item):
return item
return None