Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
4495151
Add opa dependency function to create OpaUserClient
tpoliaw May 26, 2026
f2f02de
test opa dependency function
tpoliaw May 28, 2026
82ff44b
Add can_submit_task auth check method and config
tpoliaw May 26, 2026
84efe6e
feat: add authz dependency injection
shree-iyengar-dls May 15, 2026
06a8bda
feat: add auth check dependency injections to task endpoints
shree-iyengar-dls May 18, 2026
7536275
feat: create new access task permission fns and add as dependencies
shree-iyengar-dls May 20, 2026
f066cbf
refactor: update rest api version
shree-iyengar-dls May 20, 2026
a56893d
comment out dependency addition in set_state
shree-iyengar-dls May 20, 2026
4518825
refactor: add admin check and check to set state function
May 20, 2026
060ec2e
Update dependency names
tpoliaw May 26, 2026
29e2a5a
Add missing admin check
tpoliaw May 26, 2026
3757ac5
Handle missing opa and fix tests
tpoliaw May 26, 2026
b4c61a7
Remove old admin method
tpoliaw May 26, 2026
38fc638
Use starlette statuses directly
tpoliaw May 28, 2026
b14a256
test task submission authz
tpoliaw May 28, 2026
072c36c
Use _config instead of _conf
tpoliaw Jun 5, 2026
f193a5b
Re-use instrument session regex
tpoliaw Jun 5, 2026
a372c17
remove task access check
tpoliaw Jun 5, 2026
cfc22d3
Add match to raises check
tpoliaw Jun 5, 2026
340ef65
Add exception detail
tpoliaw Jun 5, 2026
8ad021f
Let admin see all tasks
tpoliaw Jun 5, 2026
6aa6121
Start of api authz tests
tpoliaw Jun 5, 2026
824b4e8
Make get_tasks async to access authz check
tpoliaw Jun 8, 2026
e6a6917
Parametrise filter test to check with and without admin
tpoliaw Jun 8, 2026
0518099
Add test for deleting tasks
tpoliaw Jun 8, 2026
fa02ed5
Add test for submit without permission
tpoliaw Jun 8, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion helm/blueapi/config_schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -349,10 +349,20 @@
"tiled_service_account_check": {
"title": "Tiled Service Account Check",
"type": "string"
},
"submit_task_check": {
"title": "Submit Task Check",
"type": "string"
},
"admin_check": {
"title": "Admin Check",
"type": "string"
}
},
"required": [
"tiled_service_account_check"
"tiled_service_account_check",
"submit_task_check",
"admin_check"
],
"title": "OpaConfig",
"type": "object",
Expand Down
12 changes: 11 additions & 1 deletion helm/blueapi/values.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -756,9 +756,15 @@
"title": "OpaConfig",
"type": "object",
"required": [
"tiled_service_account_check"
"tiled_service_account_check",
"submit_task_check",
"admin_check"
],
"properties": {
"admin_check": {
"title": "Admin Check",
"type": "string"
},
"audience": {
"title": "Audience",
"default": "account",
Expand All @@ -772,6 +778,10 @@
"maxLength": 2083,
"minLength": 1
},
"submit_task_check": {
"title": "Submit Task Check",
"type": "string"
},
"tiled_service_account_check": {
"title": "Tiled Service Account Check",
"type": "string"
Expand Down
2 changes: 2 additions & 0 deletions src/blueapi/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,8 @@ class OpaConfig(BlueapiBaseModel):
root: HttpUrl = HttpUrl("http://localhost:8181")
audience: str = "account"
tiled_service_account_check: str
submit_task_check: str
admin_check: str
Comment thread
tpoliaw marked this conversation as resolved.


class ApplicationConfig(BlueapiBaseModel):
Expand Down
64 changes: 62 additions & 2 deletions src/blueapi/service/authorization.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
import logging
from collections.abc import Mapping
from contextlib import AbstractAsyncContextManager, aclosing, nullcontext
from typing import Any, Self
from typing import Annotated, Any, Self, cast

from aiohttp import ClientSession
from fastapi import Depends, HTTPException, Request
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN

from blueapi.config import OIDCConfig, OpaConfig, ServiceAccount
from blueapi.service.authentication import TiledAuth
from blueapi.service.authentication import TiledAuth, unchecked_bearer_token
from blueapi.service.model import TaskRequest
from blueapi.utils import INSTRUMENT_SESSION_RE

LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -56,6 +60,41 @@ async def require_tiled_service_account(self, token: str):
f"Tiled service account is not valid for '{self._instrument}'"
)

async def require_submit_task(self, instrument_session: str, token: str):
if not (match := INSTRUMENT_SESSION_RE.match(instrument_session)):
raise ValueError("Invalid instrument session")

if not await self._call_opa(
self._config.submit_task_check,
{
"token": token,
"proposal": int(match["proposal"]),
"visit": int(match["visit"]),
},
):
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Not authorized to submit task"
)

async def is_admin(self, token: str) -> bool:
return await self._call_opa(self._config.admin_check, {"token": token})


class OpaUserClient:
client: OpaClient
token: str

def __init__(self, client: OpaClient, token: str):
self.client = client
self.token = token

async def can_submit_task(self, task: TaskRequest):
LOGGER.info("Checking permissions to run task")
await self.client.require_submit_task(task.instrument_session, self.token)

async def admin(self) -> bool:
return await self.client.is_admin(self.token)


async def validate_tiled_config(
tiled: ServiceAccount | str | None, oidc: OIDCConfig | None, opa: OpaClient | None
Expand All @@ -72,3 +111,24 @@ async def validate_tiled_config(
tiled.token_url = oidc.token_endpoint
auth = TiledAuth(tiled)
await opa.require_tiled_service_account(auth.get_access_token())


async def opa(
request: Request, token: str | None = Depends(unchecked_bearer_token)
) -> OpaUserClient | None:

if opa := cast(OpaClient | None, getattr(request.app.state, "authz", None)):
if not token:
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED, detail="Authentication missing"
)
return OpaUserClient(opa, token)
return None


async def submit_permission(
opa: Annotated[OpaUserClient | None, Depends(opa)],
task_request: TaskRequest,
):
if opa:
await opa.can_submit_task(task_request)
71 changes: 66 additions & 5 deletions src/blueapi/service/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,13 @@
from blueapi.worker import TrackableTask, WorkerState
from blueapi.worker.event import TaskStatusEnum

from .authorization import OpaClient, validate_tiled_config
from .authorization import (
OpaClient,
OpaUserClient,
opa,
submit_permission,
validate_tiled_config,
)
from .model import (
DeviceModel,
DeviceResponse,
Expand Down Expand Up @@ -146,6 +152,37 @@ def get_app(config: ApplicationConfig):
return app


async def access_task_permission(
opa: Annotated[OpaUserClient | None, Depends(opa)],
task_id: str,
fedid: Fedid,
runner: Annotated[WorkerDispatcher, Depends(_runner)],
):
task = runner.run(interface.get_task_by_id, task_id)

if (
opa
and not await opa.admin()
and (task and fedid != task.task.metadata.get("user"))
):
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)


# start_task_permission is used when there is WorkerTask
async def start_task_permission(
task: WorkerTask,
opa: Annotated[OpaUserClient, Depends(opa)],
fedid: Fedid,
runner: Annotated[WorkerDispatcher, Depends(_runner)],
):
if not task.task_id:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
detail="No task id provided",
)
await access_task_permission(opa, task.task_id, fedid, runner)


async def on_key_error_404(_: Request, __: Exception):
return JSONResponse(
status_code=status.HTTP_404_NOT_FOUND,
Expand Down Expand Up @@ -271,13 +308,13 @@ def submit_task(
request: Request,
response: Response,
task_request: Annotated[TaskRequest, Body(..., examples=[example_task_request])],
_: Annotated[None, Depends(submit_permission)],
runner: Annotated[WorkerDispatcher, Depends(_runner)],
user: Fedid,
fedid: Fedid,
) -> TaskResponse:
"""Submit a task to the worker."""
try:
user = user or "Unknown"
task_id: str = runner.run(interface.submit_task, task_request, {"user": user})
task_id: str = runner.run(interface.submit_task, task_request, {"user": fedid})
response.headers["Location"] = f"{request.url}/{task_id}"
return TaskResponse(task_id=task_id)
except ValidationError as e:
Expand Down Expand Up @@ -309,6 +346,7 @@ def submit_task(
@start_as_current_span(TRACER, "task_id")
def delete_submitted_task(
task_id: str,
_: Annotated[None, Depends(access_task_permission)],
runner: Annotated[WorkerDispatcher, Depends(_runner)],
) -> TaskResponse:
return TaskResponse(task_id=runner.run(interface.clear_task, task_id))
Expand All @@ -325,8 +363,10 @@ def validate_task_status(v: str) -> TaskStatusEnum:
@secure_router_v1.get("/tasks", status_code=status.HTTP_200_OK, tags=[Tag.TASK])
@secure_router.get("/tasks", status_code=status.HTTP_200_OK, tags=[Tag.TASK])
@start_as_current_span(TRACER)
def get_tasks(
async def get_tasks(
fedid: Fedid,
runner: Annotated[WorkerDispatcher, Depends(_runner)],
opa: Annotated[OpaUserClient, Depends(opa)],
task_status: str | SkipJsonSchema[None] = None,
) -> TasksListResponse:
"""
Expand All @@ -346,6 +386,10 @@ def get_tasks(
tasks = runner.run(interface.get_tasks_by_status, desired_status)
else:
tasks = runner.run(interface.get_tasks)

if opa and not await opa.admin():
tasks = [t for t in tasks if t.task.metadata.get("user") == fedid]

return TasksListResponse(tasks=tasks)


Expand All @@ -363,6 +407,7 @@ def get_tasks(
def set_active_task(
request: Request,
task: WorkerTask,
_: Annotated[None, Depends(start_task_permission)],
runner: Annotated[WorkerDispatcher, Depends(_runner)],
) -> WorkerTask:
"""Set a task to active status, the worker should begin it as soon as possible.
Expand Down Expand Up @@ -393,6 +438,7 @@ def get_passthrough_headers(request: Request) -> dict[str, str]:
@start_as_current_span(TRACER, "task_id")
def get_task(
task_id: str,
_: Annotated[None, Depends(access_task_permission)],
runner: Annotated[WorkerDispatcher, Depends(_runner)],
) -> TrackableTask:
"""Retrieve a task"""
Expand Down Expand Up @@ -470,6 +516,8 @@ def get_state(runner: Annotated[WorkerDispatcher, Depends(_runner)]) -> WorkerSt
def set_state(
state_change_request: StateChangeRequest,
response: Response,
fedid: Fedid,
opa: Annotated[OpaUserClient, Depends(opa)],
runner: Annotated[WorkerDispatcher, Depends(_runner)],
) -> WorkerState:
"""
Expand All @@ -496,6 +544,19 @@ def set_state(
current_state in _ALLOWED_TRANSITIONS
and new_state in _ALLOWED_TRANSITIONS[current_state]
):
active = runner.run(interface.get_active_task)

if (
opa
and not opa.admin()
and active
and active.task.metadata.get("user") != fedid
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Not authorized to set worker state",
)

if new_state == WorkerState.PAUSED:
runner.run(interface.pause_worker, state_change_request.defer)
elif new_state == WorkerState.RUNNING:
Expand Down
3 changes: 3 additions & 0 deletions src/blueapi/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar
Expand Down Expand Up @@ -31,6 +32,8 @@
Args = ParamSpec("Args")
Return = TypeVar("Return")

INSTRUMENT_SESSION_RE = re.compile(r"^[a-z]{2}(?P<proposal>\d+)-(?P<visit>\d+)$")


def deprecated(alternative):
from warnings import warn
Expand Down
10 changes: 3 additions & 7 deletions src/blueapi/utils/serialization.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import json
import re
from typing import Any

from pydantic import BaseModel

from blueapi import utils


def serialize(obj: Any) -> Any:
"""
Expand All @@ -28,13 +29,8 @@ def serialize(obj: Any) -> Any:
return obj


_INSTRUMENT_SESSION_AUTHZ_REGEX: re.Pattern = re.compile(
r"^[a-zA-Z]{2}(?P<proposal>\d+)-(?P<visit>\d+)$"
)


def access_blob(instrument_session: str, beamline: str) -> str:
m = _INSTRUMENT_SESSION_AUTHZ_REGEX.match(instrument_session)
m = utils.INSTRUMENT_SESSION_RE.match(instrument_session)
if m is None:
raise ValueError(
"Unable to extract proposal and visit from "
Expand Down
Loading
Loading