Add user checking for instance actions

This commit is contained in:
PapaTutuWawa 2025-03-31 00:09:29 +02:00
parent 9d4867d74e
commit 1128d73bee
11 changed files with 69 additions and 21 deletions

View File

@ -4,6 +4,7 @@ from sqlmodel import select
from openec2.config import OpenEC2Config
from openec2.db import DatabaseDep
from openec2.db.user import User
from openec2.db.image import AMI
from openec2.images import garbage_collect_image
@ -12,6 +13,7 @@ def deregister_image(
params: QueryParams,
config: OpenEC2Config,
db: DatabaseDep,
_: User,
):
image_id = params["ImageId"]
ami = db.exec(select(AMI).where(AMI.id == image_id)).one()

View File

@ -6,6 +6,7 @@ from sqlmodel import select
from openec2.config import OpenEC2Config
from openec2.db import DatabaseDep
from openec2.db.user import User
from openec2.db.image import AMI
from openec2.api.describe_images import DescribeImagesResponse, ImagesSet, Image
@ -14,6 +15,7 @@ def describe_images(
params: QueryParams,
config: OpenEC2Config,
db: DatabaseDep,
_: User,
):
images: list[Image] = []
for ami in db.exec(select(AMI)).all():

View File

@ -13,6 +13,7 @@ from openec2.api.describe_instances import (
ReservationSetInstancesSet,
describe_instance,
)
from openec2.db.user import User
from openec2.api.shared import InstanceState
from openec2.config import OpenEC2Config
from openec2.db import DatabaseDep
@ -23,13 +24,17 @@ def describe_instances(
params: QueryParams,
config: OpenEC2Config,
db: DatabaseDep,
user: User,
):
response_items: list[InstanceDescription] = []
conn = LibvirtSingleton.of().connection
for instance in db.exec(select(Instance)).all():
dom = conn.lookupByName(instance.id)
running = dom.isActive()
# Check for permission issues
if instance.owner_id != user.id:
# TODO: Add the error to the response
continue
dom = conn.lookupByName(instance.id)
response_items.append(
describe_instance(instance, dom),
)

View File

@ -9,6 +9,7 @@ import requests
from openec2.config import OpenEC2Config
from openec2.db import DatabaseDep
from openec2.db.user import User
from openec2.db.image import AMI
@ -16,6 +17,7 @@ def import_image(
params: QueryParams,
config: OpenEC2Config,
db: DatabaseDep,
_: User,
):
first_disk_image_url = params["DiskContainer.1.Url"]
url = urlparse(first_disk_image_url)

View File

@ -12,6 +12,7 @@ from openec2.utils.qemu import create_cow_copy
from openec2.db import DatabaseDep
from openec2.db.instance import Instance
from openec2.db.image import AMI
from openec2.db.user import User
from openec2.api.run_instances import RunInstanceResponse, RunInstanceInstanceSet
from openec2.api.describe_instances import describe_instance
from openec2.utils.array import parse_array_objects
@ -94,6 +95,7 @@ def run_instances(
params: QueryParams,
config: OpenEC2Config,
db: DatabaseDep,
user: User,
):
image_id = params["ImageId"]
instance_type = params["InstanceType"]
@ -151,6 +153,7 @@ def run_instances(
else None,
privateIPv4=private_ipv4,
interfaceMac=mac,
owner_id=user.id,
)
db.add(instance)
print("Inserted new instance")

View File

@ -8,6 +8,7 @@ from openec2.libvirt import LibvirtSingleton
from openec2.config import OpenEC2Config
from openec2.db import DatabaseDep
from openec2.db.instance import Instance
from openec2.db.user import User
from openec2.utils.array import parse_array_plain
from openec2.api.shared import InstanceInfo, InstancesSet
from openec2.api.describe_instances import describe_instance_state
@ -18,6 +19,7 @@ def start_instances(
params: QueryParams,
config: OpenEC2Config,
db: DatabaseDep,
user: User,
):
conn = LibvirtSingleton.of().connection
instances: list[InstanceInfo] = []
@ -26,8 +28,12 @@ def start_instances(
if instance is None:
raise HTTPException(status_code=404, detail="Unknown instance")
dom = conn.lookupByName(instance_id)
# Check for permission issues
if instance.owner_id != user.id:
# TODO: Add the error to the response
continue
dom = conn.lookupByName(instance_id)
prev_state = describe_instance_state(dom)
if not dom.isActive():
dom.create()

View File

@ -8,6 +8,7 @@ from openec2.libvirt import LibvirtSingleton
from openec2.config import OpenEC2Config
from openec2.db import DatabaseDep
from openec2.db.instance import Instance
from openec2.db.user import User
from openec2.utils.array import parse_array_plain
from openec2.api.shared import InstanceInfo, InstancesSet
from openec2.api.describe_instances import describe_instance_state
@ -18,6 +19,7 @@ def stop_instances(
params: QueryParams,
config: OpenEC2Config,
db: DatabaseDep,
user: User,
):
conn = LibvirtSingleton.of().connection
instances: list[InstanceInfo] = []
@ -26,6 +28,11 @@ def stop_instances(
if instance is None:
raise HTTPException(status_code=404, detail="Unknown instance")
# Check for permission issues
if instance.owner_id != user.id:
# TODO: Add the error to the response
continue
dom = conn.lookupByName(instance_id)
running = dom.isActive()
prev_state = describe_instance_state(dom)

View File

@ -9,9 +9,11 @@ from openec2.libvirt import LibvirtSingleton
from openec2.config import OpenEC2Config
from openec2.db import DatabaseDep
from openec2.db.instance import Instance
from openec2.db.user import User
from openec2.utils.array import parse_array_plain
from openec2.images import garbage_collect_image
from openec2.ipam import remove_instance_dhcp_mapping
from openec2.api.shared import InstanceInfo
from openec2.api.describe_instances import describe_instance_state
from openec2.api.terminate_instances import TerminateInstancesResponse, InstancesSet
@ -23,9 +25,9 @@ def terminate_instances(
params: QueryParams,
config: OpenEC2Config,
db: DatabaseDep,
user: User,
):
instances: list[InstanceInfo] = []
conn = LibvirtSingleton.of().connection
image_ids: set[str] = set()
for instance_id in parse_array_plain("InstanceId", params):
@ -34,6 +36,11 @@ def terminate_instances(
continue
# raise HTTPException(status_code=404, detail="Unknown instance")
# Check for permission issues
if instance.owner_id != user.id:
# TODO: Add the error to the response
continue
dom = conn.lookupByName(instance_id)
prev_state = describe_instance_state(dom)
if dom.isActive():

View File

@ -18,3 +18,6 @@ class Instance(SQLModel, table=True):
# Private IPv4 of the instance
privateIPv4: str
# The owner that creatd the resource.
owner_id: int = Field(foreign_key="user.id")

View File

@ -9,6 +9,7 @@ from openec2.security.aws import AWSSignature
from openec2.utils.text import multiline_yaml_response
from openec2.config import OpenEC2Config
from openec2.db import DatabaseDep, engine
from openec2.db.user import User
from openec2.actions.describe_images import describe_images
from openec2.actions.import_image import import_image
from openec2.actions.describe_instances import describe_instances
@ -35,19 +36,33 @@ def healthz():
@app.get("/Action", response_model=None)
def action(request: Request, config: OpenEC2Config, db: DatabaseDep, _: AWSSignature):
return run_action(request, config, db, cast(dict, request.query_params))
def get_action(
request: Request, config: OpenEC2Config, db: DatabaseDep, user: AWSSignature
):
return run_action(
request,
config,
db,
cast(dict, request.query_params),
user,
)
@app.post("/Action", response_model=None)
async def test(
request: Request, config: OpenEC2Config, db: DatabaseDep, _: AWSSignature
async def post_action(
request: Request, config: OpenEC2Config, db: DatabaseDep, user: AWSSignature
):
query_params = {
key: value[0]
for key, value in parse_qs((await request.body()).decode()).items()
}
return run_action(request, config, db, cast(dict, query_params))
return run_action(
request,
config,
db,
cast(dict, query_params),
user,
)
def run_action(
@ -55,6 +70,7 @@ def run_action(
config: OpenEC2Config,
db: DatabaseDep,
query_params: dict[str, str],
user: User,
):
print(query_params)
action = query_params["Action"]
@ -67,7 +83,7 @@ def run_action(
"StartInstances": start_instances,
"StopInstances": stop_instances,
"DeregisterImage": deregister_image,
}[action](query_params, config, db)
}[action](query_params, config, db, user)
@app.get("/private/cloudinit/{instance_id}/{entry}")

View File

@ -6,7 +6,7 @@ from urllib.parse import quote, parse_qs
from sqlmodel import select
from fastapi import Request, HTTPException, Depends
from fastapi.datastructures import QueryParams, URL, Headers
from fastapi.datastructures import URL, Headers
from cryptography.hazmat.primitives import hashes, hmac
from openec2.config import ConfigSingleton
@ -42,14 +42,6 @@ class AWSRequest:
) -> str:
dt = datetime.datetime.fromisoformat(self.headers["X-Amz-Date"])
canonical_query_string_keys = sorted(self.params.keys())
canonical_query_string = "&".join(
[
f"{key}={quote(self.params[key][0])}"
for key in canonical_query_string_keys
if key not in ("X-Amz-Signature",)
]
)
canonical_header_string_keys = sorted(
[name for name in self.headers.keys() if include_in_canonical_string(name)]
)
@ -150,7 +142,8 @@ async def check_request_signature(request: Request, db: DatabaseDep):
if auth_info.x_amz_algorithm != "AWS4-HMAC-SHA256":
raise HTTPException(
status_code=400, detail=f"Invalid signature algorithm: {x_amz_algorithm}"
status_code=400,
detail=f"Invalid signature algorithm: {auth_info.x_amz_algorithm}",
)
x_amz_credential = auth_info.x_amz_credential
@ -179,5 +172,7 @@ async def check_request_signature(request: Request, db: DatabaseDep):
if x_amz_signature != signature:
raise HTTPException(status_code=401)
return user
AWSSignature = Annotated[None, Depends(check_request_signature)]
AWSSignature = Annotated[User, Depends(check_request_signature)]