Add user checking for instance actions

This commit is contained in:
PapaTutuWawa 2025-03-31 00:09:29 +02:00
parent 9d4867d74e
commit 1128d73bee
11 changed files with 69 additions and 21 deletions

View File

@ -4,6 +4,7 @@ from sqlmodel import select
from openec2.config import OpenEC2Config from openec2.config import OpenEC2Config
from openec2.db import DatabaseDep from openec2.db import DatabaseDep
from openec2.db.user import User
from openec2.db.image import AMI from openec2.db.image import AMI
from openec2.images import garbage_collect_image from openec2.images import garbage_collect_image
@ -12,6 +13,7 @@ def deregister_image(
params: QueryParams, params: QueryParams,
config: OpenEC2Config, config: OpenEC2Config,
db: DatabaseDep, db: DatabaseDep,
_: User,
): ):
image_id = params["ImageId"] image_id = params["ImageId"]
ami = db.exec(select(AMI).where(AMI.id == image_id)).one() ami = db.exec(select(AMI).where(AMI.id == image_id)).one()

View File

@ -6,6 +6,7 @@ from sqlmodel import select
from openec2.config import OpenEC2Config from openec2.config import OpenEC2Config
from openec2.db import DatabaseDep from openec2.db import DatabaseDep
from openec2.db.user import User
from openec2.db.image import AMI from openec2.db.image import AMI
from openec2.api.describe_images import DescribeImagesResponse, ImagesSet, Image from openec2.api.describe_images import DescribeImagesResponse, ImagesSet, Image
@ -14,6 +15,7 @@ def describe_images(
params: QueryParams, params: QueryParams,
config: OpenEC2Config, config: OpenEC2Config,
db: DatabaseDep, db: DatabaseDep,
_: User,
): ):
images: list[Image] = [] images: list[Image] = []
for ami in db.exec(select(AMI)).all(): for ami in db.exec(select(AMI)).all():

View File

@ -13,6 +13,7 @@ from openec2.api.describe_instances import (
ReservationSetInstancesSet, ReservationSetInstancesSet,
describe_instance, describe_instance,
) )
from openec2.db.user import User
from openec2.api.shared import InstanceState from openec2.api.shared import InstanceState
from openec2.config import OpenEC2Config from openec2.config import OpenEC2Config
from openec2.db import DatabaseDep from openec2.db import DatabaseDep
@ -23,13 +24,17 @@ def describe_instances(
params: QueryParams, params: QueryParams,
config: OpenEC2Config, config: OpenEC2Config,
db: DatabaseDep, db: DatabaseDep,
user: User,
): ):
response_items: list[InstanceDescription] = [] response_items: list[InstanceDescription] = []
conn = LibvirtSingleton.of().connection conn = LibvirtSingleton.of().connection
for instance in db.exec(select(Instance)).all(): for instance in db.exec(select(Instance)).all():
dom = conn.lookupByName(instance.id) # Check for permission issues
running = dom.isActive() if instance.owner_id != user.id:
# TODO: Add the error to the response
continue
dom = conn.lookupByName(instance.id)
response_items.append( response_items.append(
describe_instance(instance, dom), describe_instance(instance, dom),
) )

View File

@ -9,6 +9,7 @@ import requests
from openec2.config import OpenEC2Config from openec2.config import OpenEC2Config
from openec2.db import DatabaseDep from openec2.db import DatabaseDep
from openec2.db.user import User
from openec2.db.image import AMI from openec2.db.image import AMI
@ -16,6 +17,7 @@ def import_image(
params: QueryParams, params: QueryParams,
config: OpenEC2Config, config: OpenEC2Config,
db: DatabaseDep, db: DatabaseDep,
_: User,
): ):
first_disk_image_url = params["DiskContainer.1.Url"] first_disk_image_url = params["DiskContainer.1.Url"]
url = urlparse(first_disk_image_url) url = urlparse(first_disk_image_url)

View File

@ -12,6 +12,7 @@ from openec2.utils.qemu import create_cow_copy
from openec2.db import DatabaseDep from openec2.db import DatabaseDep
from openec2.db.instance import Instance from openec2.db.instance import Instance
from openec2.db.image import AMI from openec2.db.image import AMI
from openec2.db.user import User
from openec2.api.run_instances import RunInstanceResponse, RunInstanceInstanceSet from openec2.api.run_instances import RunInstanceResponse, RunInstanceInstanceSet
from openec2.api.describe_instances import describe_instance from openec2.api.describe_instances import describe_instance
from openec2.utils.array import parse_array_objects from openec2.utils.array import parse_array_objects
@ -94,6 +95,7 @@ def run_instances(
params: QueryParams, params: QueryParams,
config: OpenEC2Config, config: OpenEC2Config,
db: DatabaseDep, db: DatabaseDep,
user: User,
): ):
image_id = params["ImageId"] image_id = params["ImageId"]
instance_type = params["InstanceType"] instance_type = params["InstanceType"]
@ -151,6 +153,7 @@ def run_instances(
else None, else None,
privateIPv4=private_ipv4, privateIPv4=private_ipv4,
interfaceMac=mac, interfaceMac=mac,
owner_id=user.id,
) )
db.add(instance) db.add(instance)
print("Inserted new instance") print("Inserted new instance")

View File

@ -8,6 +8,7 @@ from openec2.libvirt import LibvirtSingleton
from openec2.config import OpenEC2Config from openec2.config import OpenEC2Config
from openec2.db import DatabaseDep from openec2.db import DatabaseDep
from openec2.db.instance import Instance from openec2.db.instance import Instance
from openec2.db.user import User
from openec2.utils.array import parse_array_plain from openec2.utils.array import parse_array_plain
from openec2.api.shared import InstanceInfo, InstancesSet from openec2.api.shared import InstanceInfo, InstancesSet
from openec2.api.describe_instances import describe_instance_state from openec2.api.describe_instances import describe_instance_state
@ -18,6 +19,7 @@ def start_instances(
params: QueryParams, params: QueryParams,
config: OpenEC2Config, config: OpenEC2Config,
db: DatabaseDep, db: DatabaseDep,
user: User,
): ):
conn = LibvirtSingleton.of().connection conn = LibvirtSingleton.of().connection
instances: list[InstanceInfo] = [] instances: list[InstanceInfo] = []
@ -26,8 +28,12 @@ def start_instances(
if instance is None: if instance is None:
raise HTTPException(status_code=404, detail="Unknown instance") raise HTTPException(status_code=404, detail="Unknown instance")
dom = conn.lookupByName(instance_id) # Check for permission issues
if instance.owner_id != user.id:
# TODO: Add the error to the response
continue
dom = conn.lookupByName(instance_id)
prev_state = describe_instance_state(dom) prev_state = describe_instance_state(dom)
if not dom.isActive(): if not dom.isActive():
dom.create() dom.create()

View File

@ -8,6 +8,7 @@ from openec2.libvirt import LibvirtSingleton
from openec2.config import OpenEC2Config from openec2.config import OpenEC2Config
from openec2.db import DatabaseDep from openec2.db import DatabaseDep
from openec2.db.instance import Instance from openec2.db.instance import Instance
from openec2.db.user import User
from openec2.utils.array import parse_array_plain from openec2.utils.array import parse_array_plain
from openec2.api.shared import InstanceInfo, InstancesSet from openec2.api.shared import InstanceInfo, InstancesSet
from openec2.api.describe_instances import describe_instance_state from openec2.api.describe_instances import describe_instance_state
@ -18,6 +19,7 @@ def stop_instances(
params: QueryParams, params: QueryParams,
config: OpenEC2Config, config: OpenEC2Config,
db: DatabaseDep, db: DatabaseDep,
user: User,
): ):
conn = LibvirtSingleton.of().connection conn = LibvirtSingleton.of().connection
instances: list[InstanceInfo] = [] instances: list[InstanceInfo] = []
@ -26,6 +28,11 @@ def stop_instances(
if instance is None: if instance is None:
raise HTTPException(status_code=404, detail="Unknown instance") raise HTTPException(status_code=404, detail="Unknown instance")
# Check for permission issues
if instance.owner_id != user.id:
# TODO: Add the error to the response
continue
dom = conn.lookupByName(instance_id) dom = conn.lookupByName(instance_id)
running = dom.isActive() running = dom.isActive()
prev_state = describe_instance_state(dom) prev_state = describe_instance_state(dom)

View File

@ -9,9 +9,11 @@ from openec2.libvirt import LibvirtSingleton
from openec2.config import OpenEC2Config from openec2.config import OpenEC2Config
from openec2.db import DatabaseDep from openec2.db import DatabaseDep
from openec2.db.instance import Instance from openec2.db.instance import Instance
from openec2.db.user import User
from openec2.utils.array import parse_array_plain from openec2.utils.array import parse_array_plain
from openec2.images import garbage_collect_image from openec2.images import garbage_collect_image
from openec2.ipam import remove_instance_dhcp_mapping from openec2.ipam import remove_instance_dhcp_mapping
from openec2.api.shared import InstanceInfo
from openec2.api.describe_instances import describe_instance_state from openec2.api.describe_instances import describe_instance_state
from openec2.api.terminate_instances import TerminateInstancesResponse, InstancesSet from openec2.api.terminate_instances import TerminateInstancesResponse, InstancesSet
@ -23,9 +25,9 @@ def terminate_instances(
params: QueryParams, params: QueryParams,
config: OpenEC2Config, config: OpenEC2Config,
db: DatabaseDep, db: DatabaseDep,
user: User,
): ):
instances: list[InstanceInfo] = [] instances: list[InstanceInfo] = []
conn = LibvirtSingleton.of().connection conn = LibvirtSingleton.of().connection
image_ids: set[str] = set() image_ids: set[str] = set()
for instance_id in parse_array_plain("InstanceId", params): for instance_id in parse_array_plain("InstanceId", params):
@ -34,6 +36,11 @@ def terminate_instances(
continue continue
# raise HTTPException(status_code=404, detail="Unknown instance") # raise HTTPException(status_code=404, detail="Unknown instance")
# Check for permission issues
if instance.owner_id != user.id:
# TODO: Add the error to the response
continue
dom = conn.lookupByName(instance_id) dom = conn.lookupByName(instance_id)
prev_state = describe_instance_state(dom) prev_state = describe_instance_state(dom)
if dom.isActive(): if dom.isActive():

View File

@ -18,3 +18,6 @@ class Instance(SQLModel, table=True):
# Private IPv4 of the instance # Private IPv4 of the instance
privateIPv4: str privateIPv4: str
# The owner that creatd the resource.
owner_id: int = Field(foreign_key="user.id")

View File

@ -9,6 +9,7 @@ from openec2.security.aws import AWSSignature
from openec2.utils.text import multiline_yaml_response from openec2.utils.text import multiline_yaml_response
from openec2.config import OpenEC2Config from openec2.config import OpenEC2Config
from openec2.db import DatabaseDep, engine from openec2.db import DatabaseDep, engine
from openec2.db.user import User
from openec2.actions.describe_images import describe_images from openec2.actions.describe_images import describe_images
from openec2.actions.import_image import import_image from openec2.actions.import_image import import_image
from openec2.actions.describe_instances import describe_instances from openec2.actions.describe_instances import describe_instances
@ -35,19 +36,33 @@ def healthz():
@app.get("/Action", response_model=None) @app.get("/Action", response_model=None)
def action(request: Request, config: OpenEC2Config, db: DatabaseDep, _: AWSSignature): def get_action(
return run_action(request, config, db, cast(dict, request.query_params)) request: Request, config: OpenEC2Config, db: DatabaseDep, user: AWSSignature
):
return run_action(
request,
config,
db,
cast(dict, request.query_params),
user,
)
@app.post("/Action", response_model=None) @app.post("/Action", response_model=None)
async def test( async def post_action(
request: Request, config: OpenEC2Config, db: DatabaseDep, _: AWSSignature request: Request, config: OpenEC2Config, db: DatabaseDep, user: AWSSignature
): ):
query_params = { query_params = {
key: value[0] key: value[0]
for key, value in parse_qs((await request.body()).decode()).items() for key, value in parse_qs((await request.body()).decode()).items()
} }
return run_action(request, config, db, cast(dict, query_params)) return run_action(
request,
config,
db,
cast(dict, query_params),
user,
)
def run_action( def run_action(
@ -55,6 +70,7 @@ def run_action(
config: OpenEC2Config, config: OpenEC2Config,
db: DatabaseDep, db: DatabaseDep,
query_params: dict[str, str], query_params: dict[str, str],
user: User,
): ):
print(query_params) print(query_params)
action = query_params["Action"] action = query_params["Action"]
@ -67,7 +83,7 @@ def run_action(
"StartInstances": start_instances, "StartInstances": start_instances,
"StopInstances": stop_instances, "StopInstances": stop_instances,
"DeregisterImage": deregister_image, "DeregisterImage": deregister_image,
}[action](query_params, config, db) }[action](query_params, config, db, user)
@app.get("/private/cloudinit/{instance_id}/{entry}") @app.get("/private/cloudinit/{instance_id}/{entry}")

View File

@ -6,7 +6,7 @@ from urllib.parse import quote, parse_qs
from sqlmodel import select from sqlmodel import select
from fastapi import Request, HTTPException, Depends from fastapi import Request, HTTPException, Depends
from fastapi.datastructures import QueryParams, URL, Headers from fastapi.datastructures import URL, Headers
from cryptography.hazmat.primitives import hashes, hmac from cryptography.hazmat.primitives import hashes, hmac
from openec2.config import ConfigSingleton from openec2.config import ConfigSingleton
@ -42,14 +42,6 @@ class AWSRequest:
) -> str: ) -> str:
dt = datetime.datetime.fromisoformat(self.headers["X-Amz-Date"]) dt = datetime.datetime.fromisoformat(self.headers["X-Amz-Date"])
canonical_query_string_keys = sorted(self.params.keys())
canonical_query_string = "&".join(
[
f"{key}={quote(self.params[key][0])}"
for key in canonical_query_string_keys
if key not in ("X-Amz-Signature",)
]
)
canonical_header_string_keys = sorted( canonical_header_string_keys = sorted(
[name for name in self.headers.keys() if include_in_canonical_string(name)] [name for name in self.headers.keys() if include_in_canonical_string(name)]
) )
@ -150,7 +142,8 @@ async def check_request_signature(request: Request, db: DatabaseDep):
if auth_info.x_amz_algorithm != "AWS4-HMAC-SHA256": if auth_info.x_amz_algorithm != "AWS4-HMAC-SHA256":
raise HTTPException( raise HTTPException(
status_code=400, detail=f"Invalid signature algorithm: {x_amz_algorithm}" status_code=400,
detail=f"Invalid signature algorithm: {auth_info.x_amz_algorithm}",
) )
x_amz_credential = auth_info.x_amz_credential x_amz_credential = auth_info.x_amz_credential
@ -179,5 +172,7 @@ async def check_request_signature(request: Request, db: DatabaseDep):
if x_amz_signature != signature: if x_amz_signature != signature:
raise HTTPException(status_code=401) raise HTTPException(status_code=401)
return user
AWSSignature = Annotated[None, Depends(check_request_signature)]
AWSSignature = Annotated[User, Depends(check_request_signature)]