diff --git a/src/openec2/actions/deregister_image.py b/src/openec2/actions/deregister_image.py index 7d6c724..31510b4 100644 --- a/src/openec2/actions/deregister_image.py +++ b/src/openec2/actions/deregister_image.py @@ -4,6 +4,7 @@ from sqlmodel import select from openec2.config import OpenEC2Config from openec2.db import DatabaseDep +from openec2.db.user import User from openec2.db.image import AMI from openec2.images import garbage_collect_image @@ -12,6 +13,7 @@ def deregister_image( params: QueryParams, config: OpenEC2Config, db: DatabaseDep, + _: User, ): image_id = params["ImageId"] ami = db.exec(select(AMI).where(AMI.id == image_id)).one() diff --git a/src/openec2/actions/describe_images.py b/src/openec2/actions/describe_images.py index ac367fa..1df32ba 100644 --- a/src/openec2/actions/describe_images.py +++ b/src/openec2/actions/describe_images.py @@ -6,6 +6,7 @@ from sqlmodel import select from openec2.config import OpenEC2Config from openec2.db import DatabaseDep +from openec2.db.user import User from openec2.db.image import AMI from openec2.api.describe_images import DescribeImagesResponse, ImagesSet, Image @@ -14,6 +15,7 @@ def describe_images( params: QueryParams, config: OpenEC2Config, db: DatabaseDep, + _: User, ): images: list[Image] = [] for ami in db.exec(select(AMI)).all(): diff --git a/src/openec2/actions/describe_instances.py b/src/openec2/actions/describe_instances.py index fc3148f..e11f0ac 100644 --- a/src/openec2/actions/describe_instances.py +++ b/src/openec2/actions/describe_instances.py @@ -13,6 +13,7 @@ from openec2.api.describe_instances import ( ReservationSetInstancesSet, describe_instance, ) +from openec2.db.user import User from openec2.api.shared import InstanceState from openec2.config import OpenEC2Config from openec2.db import DatabaseDep @@ -23,13 +24,17 @@ def describe_instances( params: QueryParams, config: OpenEC2Config, db: DatabaseDep, + user: User, ): response_items: list[InstanceDescription] = [] conn = LibvirtSingleton.of().connection for instance in db.exec(select(Instance)).all(): - dom = conn.lookupByName(instance.id) - running = dom.isActive() + # Check for permission issues + if instance.owner_id != user.id: + # TODO: Add the error to the response + continue + dom = conn.lookupByName(instance.id) response_items.append( describe_instance(instance, dom), ) diff --git a/src/openec2/actions/import_image.py b/src/openec2/actions/import_image.py index d8c9c9f..02ed6b5 100644 --- a/src/openec2/actions/import_image.py +++ b/src/openec2/actions/import_image.py @@ -9,6 +9,7 @@ import requests from openec2.config import OpenEC2Config from openec2.db import DatabaseDep +from openec2.db.user import User from openec2.db.image import AMI @@ -16,6 +17,7 @@ def import_image( params: QueryParams, config: OpenEC2Config, db: DatabaseDep, + _: User, ): first_disk_image_url = params["DiskContainer.1.Url"] url = urlparse(first_disk_image_url) diff --git a/src/openec2/actions/run_instances.py b/src/openec2/actions/run_instances.py index d492e8d..605c1ba 100644 --- a/src/openec2/actions/run_instances.py +++ b/src/openec2/actions/run_instances.py @@ -12,6 +12,7 @@ from openec2.utils.qemu import create_cow_copy from openec2.db import DatabaseDep from openec2.db.instance import Instance from openec2.db.image import AMI +from openec2.db.user import User from openec2.api.run_instances import RunInstanceResponse, RunInstanceInstanceSet from openec2.api.describe_instances import describe_instance from openec2.utils.array import parse_array_objects @@ -94,6 +95,7 @@ def run_instances( params: QueryParams, config: OpenEC2Config, db: DatabaseDep, + user: User, ): image_id = params["ImageId"] instance_type = params["InstanceType"] @@ -151,6 +153,7 @@ def run_instances( else None, privateIPv4=private_ipv4, interfaceMac=mac, + owner_id=user.id, ) db.add(instance) print("Inserted new instance") diff --git a/src/openec2/actions/start_instances.py b/src/openec2/actions/start_instances.py index 724fcf7..56d09c9 100644 --- a/src/openec2/actions/start_instances.py +++ b/src/openec2/actions/start_instances.py @@ -8,6 +8,7 @@ from openec2.libvirt import LibvirtSingleton from openec2.config import OpenEC2Config from openec2.db import DatabaseDep from openec2.db.instance import Instance +from openec2.db.user import User from openec2.utils.array import parse_array_plain from openec2.api.shared import InstanceInfo, InstancesSet from openec2.api.describe_instances import describe_instance_state @@ -18,6 +19,7 @@ def start_instances( params: QueryParams, config: OpenEC2Config, db: DatabaseDep, + user: User, ): conn = LibvirtSingleton.of().connection instances: list[InstanceInfo] = [] @@ -26,8 +28,12 @@ def start_instances( if instance is None: raise HTTPException(status_code=404, detail="Unknown instance") - dom = conn.lookupByName(instance_id) + # Check for permission issues + if instance.owner_id != user.id: + # TODO: Add the error to the response + continue + dom = conn.lookupByName(instance_id) prev_state = describe_instance_state(dom) if not dom.isActive(): dom.create() diff --git a/src/openec2/actions/stop_instances.py b/src/openec2/actions/stop_instances.py index 0d3d83a..039f38c 100644 --- a/src/openec2/actions/stop_instances.py +++ b/src/openec2/actions/stop_instances.py @@ -8,6 +8,7 @@ from openec2.libvirt import LibvirtSingleton from openec2.config import OpenEC2Config from openec2.db import DatabaseDep from openec2.db.instance import Instance +from openec2.db.user import User from openec2.utils.array import parse_array_plain from openec2.api.shared import InstanceInfo, InstancesSet from openec2.api.describe_instances import describe_instance_state @@ -18,6 +19,7 @@ def stop_instances( params: QueryParams, config: OpenEC2Config, db: DatabaseDep, + user: User, ): conn = LibvirtSingleton.of().connection instances: list[InstanceInfo] = [] @@ -26,6 +28,11 @@ def stop_instances( if instance is None: raise HTTPException(status_code=404, detail="Unknown instance") + # Check for permission issues + if instance.owner_id != user.id: + # TODO: Add the error to the response + continue + dom = conn.lookupByName(instance_id) running = dom.isActive() prev_state = describe_instance_state(dom) diff --git a/src/openec2/actions/terminate_instances.py b/src/openec2/actions/terminate_instances.py index faa4100..0441c95 100644 --- a/src/openec2/actions/terminate_instances.py +++ b/src/openec2/actions/terminate_instances.py @@ -9,9 +9,11 @@ from openec2.libvirt import LibvirtSingleton from openec2.config import OpenEC2Config from openec2.db import DatabaseDep from openec2.db.instance import Instance +from openec2.db.user import User from openec2.utils.array import parse_array_plain from openec2.images import garbage_collect_image from openec2.ipam import remove_instance_dhcp_mapping +from openec2.api.shared import InstanceInfo from openec2.api.describe_instances import describe_instance_state from openec2.api.terminate_instances import TerminateInstancesResponse, InstancesSet @@ -23,9 +25,9 @@ def terminate_instances( params: QueryParams, config: OpenEC2Config, db: DatabaseDep, + user: User, ): instances: list[InstanceInfo] = [] - conn = LibvirtSingleton.of().connection image_ids: set[str] = set() for instance_id in parse_array_plain("InstanceId", params): @@ -34,6 +36,11 @@ def terminate_instances( continue # raise HTTPException(status_code=404, detail="Unknown instance") + # Check for permission issues + if instance.owner_id != user.id: + # TODO: Add the error to the response + continue + dom = conn.lookupByName(instance_id) prev_state = describe_instance_state(dom) if dom.isActive(): diff --git a/src/openec2/db/instance.py b/src/openec2/db/instance.py index b0a0e33..6dfff87 100644 --- a/src/openec2/db/instance.py +++ b/src/openec2/db/instance.py @@ -18,3 +18,6 @@ class Instance(SQLModel, table=True): # Private IPv4 of the instance privateIPv4: str + + # The owner that creatd the resource. + owner_id: int = Field(foreign_key="user.id") diff --git a/src/openec2/main.py b/src/openec2/main.py index ddf7841..b454aa8 100644 --- a/src/openec2/main.py +++ b/src/openec2/main.py @@ -9,6 +9,7 @@ from openec2.security.aws import AWSSignature from openec2.utils.text import multiline_yaml_response from openec2.config import OpenEC2Config from openec2.db import DatabaseDep, engine +from openec2.db.user import User from openec2.actions.describe_images import describe_images from openec2.actions.import_image import import_image from openec2.actions.describe_instances import describe_instances @@ -35,19 +36,33 @@ def healthz(): @app.get("/Action", response_model=None) -def action(request: Request, config: OpenEC2Config, db: DatabaseDep, _: AWSSignature): - return run_action(request, config, db, cast(dict, request.query_params)) +def get_action( + request: Request, config: OpenEC2Config, db: DatabaseDep, user: AWSSignature +): + return run_action( + request, + config, + db, + cast(dict, request.query_params), + user, + ) @app.post("/Action", response_model=None) -async def test( - request: Request, config: OpenEC2Config, db: DatabaseDep, _: AWSSignature +async def post_action( + request: Request, config: OpenEC2Config, db: DatabaseDep, user: AWSSignature ): query_params = { key: value[0] for key, value in parse_qs((await request.body()).decode()).items() } - return run_action(request, config, db, cast(dict, query_params)) + return run_action( + request, + config, + db, + cast(dict, query_params), + user, + ) def run_action( @@ -55,6 +70,7 @@ def run_action( config: OpenEC2Config, db: DatabaseDep, query_params: dict[str, str], + user: User, ): print(query_params) action = query_params["Action"] @@ -67,7 +83,7 @@ def run_action( "StartInstances": start_instances, "StopInstances": stop_instances, "DeregisterImage": deregister_image, - }[action](query_params, config, db) + }[action](query_params, config, db, user) @app.get("/private/cloudinit/{instance_id}/{entry}") diff --git a/src/openec2/security/aws.py b/src/openec2/security/aws.py index 6012c4a..25128a7 100644 --- a/src/openec2/security/aws.py +++ b/src/openec2/security/aws.py @@ -6,7 +6,7 @@ from urllib.parse import quote, parse_qs from sqlmodel import select from fastapi import Request, HTTPException, Depends -from fastapi.datastructures import QueryParams, URL, Headers +from fastapi.datastructures import URL, Headers from cryptography.hazmat.primitives import hashes, hmac from openec2.config import ConfigSingleton @@ -42,14 +42,6 @@ class AWSRequest: ) -> str: dt = datetime.datetime.fromisoformat(self.headers["X-Amz-Date"]) - canonical_query_string_keys = sorted(self.params.keys()) - canonical_query_string = "&".join( - [ - f"{key}={quote(self.params[key][0])}" - for key in canonical_query_string_keys - if key not in ("X-Amz-Signature",) - ] - ) canonical_header_string_keys = sorted( [name for name in self.headers.keys() if include_in_canonical_string(name)] ) @@ -150,7 +142,8 @@ async def check_request_signature(request: Request, db: DatabaseDep): if auth_info.x_amz_algorithm != "AWS4-HMAC-SHA256": raise HTTPException( - status_code=400, detail=f"Invalid signature algorithm: {x_amz_algorithm}" + status_code=400, + detail=f"Invalid signature algorithm: {auth_info.x_amz_algorithm}", ) x_amz_credential = auth_info.x_amz_credential @@ -179,5 +172,7 @@ async def check_request_signature(request: Request, db: DatabaseDep): if x_amz_signature != signature: raise HTTPException(status_code=401) + return user -AWSSignature = Annotated[None, Depends(check_request_signature)] + +AWSSignature = Annotated[User, Depends(check_request_signature)]