Add user checking for instance actions
This commit is contained in:
parent
9d4867d74e
commit
1128d73bee
@ -4,6 +4,7 @@ from sqlmodel import select
|
||||
|
||||
from openec2.config import OpenEC2Config
|
||||
from openec2.db import DatabaseDep
|
||||
from openec2.db.user import User
|
||||
from openec2.db.image import AMI
|
||||
from openec2.images import garbage_collect_image
|
||||
|
||||
@ -12,6 +13,7 @@ def deregister_image(
|
||||
params: QueryParams,
|
||||
config: OpenEC2Config,
|
||||
db: DatabaseDep,
|
||||
_: User,
|
||||
):
|
||||
image_id = params["ImageId"]
|
||||
ami = db.exec(select(AMI).where(AMI.id == image_id)).one()
|
||||
|
@ -6,6 +6,7 @@ from sqlmodel import select
|
||||
|
||||
from openec2.config import OpenEC2Config
|
||||
from openec2.db import DatabaseDep
|
||||
from openec2.db.user import User
|
||||
from openec2.db.image import AMI
|
||||
from openec2.api.describe_images import DescribeImagesResponse, ImagesSet, Image
|
||||
|
||||
@ -14,6 +15,7 @@ def describe_images(
|
||||
params: QueryParams,
|
||||
config: OpenEC2Config,
|
||||
db: DatabaseDep,
|
||||
_: User,
|
||||
):
|
||||
images: list[Image] = []
|
||||
for ami in db.exec(select(AMI)).all():
|
||||
|
@ -13,6 +13,7 @@ from openec2.api.describe_instances import (
|
||||
ReservationSetInstancesSet,
|
||||
describe_instance,
|
||||
)
|
||||
from openec2.db.user import User
|
||||
from openec2.api.shared import InstanceState
|
||||
from openec2.config import OpenEC2Config
|
||||
from openec2.db import DatabaseDep
|
||||
@ -23,13 +24,17 @@ def describe_instances(
|
||||
params: QueryParams,
|
||||
config: OpenEC2Config,
|
||||
db: DatabaseDep,
|
||||
user: User,
|
||||
):
|
||||
response_items: list[InstanceDescription] = []
|
||||
conn = LibvirtSingleton.of().connection
|
||||
for instance in db.exec(select(Instance)).all():
|
||||
dom = conn.lookupByName(instance.id)
|
||||
running = dom.isActive()
|
||||
# Check for permission issues
|
||||
if instance.owner_id != user.id:
|
||||
# TODO: Add the error to the response
|
||||
continue
|
||||
|
||||
dom = conn.lookupByName(instance.id)
|
||||
response_items.append(
|
||||
describe_instance(instance, dom),
|
||||
)
|
||||
|
@ -9,6 +9,7 @@ import requests
|
||||
|
||||
from openec2.config import OpenEC2Config
|
||||
from openec2.db import DatabaseDep
|
||||
from openec2.db.user import User
|
||||
from openec2.db.image import AMI
|
||||
|
||||
|
||||
@ -16,6 +17,7 @@ def import_image(
|
||||
params: QueryParams,
|
||||
config: OpenEC2Config,
|
||||
db: DatabaseDep,
|
||||
_: User,
|
||||
):
|
||||
first_disk_image_url = params["DiskContainer.1.Url"]
|
||||
url = urlparse(first_disk_image_url)
|
||||
|
@ -12,6 +12,7 @@ from openec2.utils.qemu import create_cow_copy
|
||||
from openec2.db import DatabaseDep
|
||||
from openec2.db.instance import Instance
|
||||
from openec2.db.image import AMI
|
||||
from openec2.db.user import User
|
||||
from openec2.api.run_instances import RunInstanceResponse, RunInstanceInstanceSet
|
||||
from openec2.api.describe_instances import describe_instance
|
||||
from openec2.utils.array import parse_array_objects
|
||||
@ -94,6 +95,7 @@ def run_instances(
|
||||
params: QueryParams,
|
||||
config: OpenEC2Config,
|
||||
db: DatabaseDep,
|
||||
user: User,
|
||||
):
|
||||
image_id = params["ImageId"]
|
||||
instance_type = params["InstanceType"]
|
||||
@ -151,6 +153,7 @@ def run_instances(
|
||||
else None,
|
||||
privateIPv4=private_ipv4,
|
||||
interfaceMac=mac,
|
||||
owner_id=user.id,
|
||||
)
|
||||
db.add(instance)
|
||||
print("Inserted new instance")
|
||||
|
@ -8,6 +8,7 @@ from openec2.libvirt import LibvirtSingleton
|
||||
from openec2.config import OpenEC2Config
|
||||
from openec2.db import DatabaseDep
|
||||
from openec2.db.instance import Instance
|
||||
from openec2.db.user import User
|
||||
from openec2.utils.array import parse_array_plain
|
||||
from openec2.api.shared import InstanceInfo, InstancesSet
|
||||
from openec2.api.describe_instances import describe_instance_state
|
||||
@ -18,6 +19,7 @@ def start_instances(
|
||||
params: QueryParams,
|
||||
config: OpenEC2Config,
|
||||
db: DatabaseDep,
|
||||
user: User,
|
||||
):
|
||||
conn = LibvirtSingleton.of().connection
|
||||
instances: list[InstanceInfo] = []
|
||||
@ -26,8 +28,12 @@ def start_instances(
|
||||
if instance is None:
|
||||
raise HTTPException(status_code=404, detail="Unknown instance")
|
||||
|
||||
dom = conn.lookupByName(instance_id)
|
||||
# Check for permission issues
|
||||
if instance.owner_id != user.id:
|
||||
# TODO: Add the error to the response
|
||||
continue
|
||||
|
||||
dom = conn.lookupByName(instance_id)
|
||||
prev_state = describe_instance_state(dom)
|
||||
if not dom.isActive():
|
||||
dom.create()
|
||||
|
@ -8,6 +8,7 @@ from openec2.libvirt import LibvirtSingleton
|
||||
from openec2.config import OpenEC2Config
|
||||
from openec2.db import DatabaseDep
|
||||
from openec2.db.instance import Instance
|
||||
from openec2.db.user import User
|
||||
from openec2.utils.array import parse_array_plain
|
||||
from openec2.api.shared import InstanceInfo, InstancesSet
|
||||
from openec2.api.describe_instances import describe_instance_state
|
||||
@ -18,6 +19,7 @@ def stop_instances(
|
||||
params: QueryParams,
|
||||
config: OpenEC2Config,
|
||||
db: DatabaseDep,
|
||||
user: User,
|
||||
):
|
||||
conn = LibvirtSingleton.of().connection
|
||||
instances: list[InstanceInfo] = []
|
||||
@ -26,6 +28,11 @@ def stop_instances(
|
||||
if instance is None:
|
||||
raise HTTPException(status_code=404, detail="Unknown instance")
|
||||
|
||||
# Check for permission issues
|
||||
if instance.owner_id != user.id:
|
||||
# TODO: Add the error to the response
|
||||
continue
|
||||
|
||||
dom = conn.lookupByName(instance_id)
|
||||
running = dom.isActive()
|
||||
prev_state = describe_instance_state(dom)
|
||||
|
@ -9,9 +9,11 @@ from openec2.libvirt import LibvirtSingleton
|
||||
from openec2.config import OpenEC2Config
|
||||
from openec2.db import DatabaseDep
|
||||
from openec2.db.instance import Instance
|
||||
from openec2.db.user import User
|
||||
from openec2.utils.array import parse_array_plain
|
||||
from openec2.images import garbage_collect_image
|
||||
from openec2.ipam import remove_instance_dhcp_mapping
|
||||
from openec2.api.shared import InstanceInfo
|
||||
from openec2.api.describe_instances import describe_instance_state
|
||||
from openec2.api.terminate_instances import TerminateInstancesResponse, InstancesSet
|
||||
|
||||
@ -23,9 +25,9 @@ def terminate_instances(
|
||||
params: QueryParams,
|
||||
config: OpenEC2Config,
|
||||
db: DatabaseDep,
|
||||
user: User,
|
||||
):
|
||||
instances: list[InstanceInfo] = []
|
||||
|
||||
conn = LibvirtSingleton.of().connection
|
||||
image_ids: set[str] = set()
|
||||
for instance_id in parse_array_plain("InstanceId", params):
|
||||
@ -34,6 +36,11 @@ def terminate_instances(
|
||||
continue
|
||||
# raise HTTPException(status_code=404, detail="Unknown instance")
|
||||
|
||||
# Check for permission issues
|
||||
if instance.owner_id != user.id:
|
||||
# TODO: Add the error to the response
|
||||
continue
|
||||
|
||||
dom = conn.lookupByName(instance_id)
|
||||
prev_state = describe_instance_state(dom)
|
||||
if dom.isActive():
|
||||
|
@ -18,3 +18,6 @@ class Instance(SQLModel, table=True):
|
||||
|
||||
# Private IPv4 of the instance
|
||||
privateIPv4: str
|
||||
|
||||
# The owner that creatd the resource.
|
||||
owner_id: int = Field(foreign_key="user.id")
|
||||
|
@ -9,6 +9,7 @@ from openec2.security.aws import AWSSignature
|
||||
from openec2.utils.text import multiline_yaml_response
|
||||
from openec2.config import OpenEC2Config
|
||||
from openec2.db import DatabaseDep, engine
|
||||
from openec2.db.user import User
|
||||
from openec2.actions.describe_images import describe_images
|
||||
from openec2.actions.import_image import import_image
|
||||
from openec2.actions.describe_instances import describe_instances
|
||||
@ -35,19 +36,33 @@ def healthz():
|
||||
|
||||
|
||||
@app.get("/Action", response_model=None)
|
||||
def action(request: Request, config: OpenEC2Config, db: DatabaseDep, _: AWSSignature):
|
||||
return run_action(request, config, db, cast(dict, request.query_params))
|
||||
def get_action(
|
||||
request: Request, config: OpenEC2Config, db: DatabaseDep, user: AWSSignature
|
||||
):
|
||||
return run_action(
|
||||
request,
|
||||
config,
|
||||
db,
|
||||
cast(dict, request.query_params),
|
||||
user,
|
||||
)
|
||||
|
||||
|
||||
@app.post("/Action", response_model=None)
|
||||
async def test(
|
||||
request: Request, config: OpenEC2Config, db: DatabaseDep, _: AWSSignature
|
||||
async def post_action(
|
||||
request: Request, config: OpenEC2Config, db: DatabaseDep, user: AWSSignature
|
||||
):
|
||||
query_params = {
|
||||
key: value[0]
|
||||
for key, value in parse_qs((await request.body()).decode()).items()
|
||||
}
|
||||
return run_action(request, config, db, cast(dict, query_params))
|
||||
return run_action(
|
||||
request,
|
||||
config,
|
||||
db,
|
||||
cast(dict, query_params),
|
||||
user,
|
||||
)
|
||||
|
||||
|
||||
def run_action(
|
||||
@ -55,6 +70,7 @@ def run_action(
|
||||
config: OpenEC2Config,
|
||||
db: DatabaseDep,
|
||||
query_params: dict[str, str],
|
||||
user: User,
|
||||
):
|
||||
print(query_params)
|
||||
action = query_params["Action"]
|
||||
@ -67,7 +83,7 @@ def run_action(
|
||||
"StartInstances": start_instances,
|
||||
"StopInstances": stop_instances,
|
||||
"DeregisterImage": deregister_image,
|
||||
}[action](query_params, config, db)
|
||||
}[action](query_params, config, db, user)
|
||||
|
||||
|
||||
@app.get("/private/cloudinit/{instance_id}/{entry}")
|
||||
|
@ -6,7 +6,7 @@ from urllib.parse import quote, parse_qs
|
||||
|
||||
from sqlmodel import select
|
||||
from fastapi import Request, HTTPException, Depends
|
||||
from fastapi.datastructures import QueryParams, URL, Headers
|
||||
from fastapi.datastructures import URL, Headers
|
||||
from cryptography.hazmat.primitives import hashes, hmac
|
||||
|
||||
from openec2.config import ConfigSingleton
|
||||
@ -42,14 +42,6 @@ class AWSRequest:
|
||||
) -> str:
|
||||
dt = datetime.datetime.fromisoformat(self.headers["X-Amz-Date"])
|
||||
|
||||
canonical_query_string_keys = sorted(self.params.keys())
|
||||
canonical_query_string = "&".join(
|
||||
[
|
||||
f"{key}={quote(self.params[key][0])}"
|
||||
for key in canonical_query_string_keys
|
||||
if key not in ("X-Amz-Signature",)
|
||||
]
|
||||
)
|
||||
canonical_header_string_keys = sorted(
|
||||
[name for name in self.headers.keys() if include_in_canonical_string(name)]
|
||||
)
|
||||
@ -150,7 +142,8 @@ async def check_request_signature(request: Request, db: DatabaseDep):
|
||||
|
||||
if auth_info.x_amz_algorithm != "AWS4-HMAC-SHA256":
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Invalid signature algorithm: {x_amz_algorithm}"
|
||||
status_code=400,
|
||||
detail=f"Invalid signature algorithm: {auth_info.x_amz_algorithm}",
|
||||
)
|
||||
|
||||
x_amz_credential = auth_info.x_amz_credential
|
||||
@ -179,5 +172,7 @@ async def check_request_signature(request: Request, db: DatabaseDep):
|
||||
if x_amz_signature != signature:
|
||||
raise HTTPException(status_code=401)
|
||||
|
||||
return user
|
||||
|
||||
AWSSignature = Annotated[None, Depends(check_request_signature)]
|
||||
|
||||
AWSSignature = Annotated[User, Depends(check_request_signature)]
|
||||
|
Loading…
Reference in New Issue
Block a user