Skip to content

Commit b154fd6

Browse files
authored
Merge pull request #542 from lbedner/rbac-4
RBAC - 4
2 parents 1bb0f09 + 4905640 commit b154fd6

11 files changed

Lines changed: 432 additions & 5 deletions

File tree

aegis/cli/interactive.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -638,6 +638,10 @@ def interactive_auth_service_config(
638638
title="With Roles - + role-based access control",
639639
value=AuthLevels.RBAC,
640640
),
641+
questionary.Choice(
642+
title="With Organizations - + multi-tenant support",
643+
value=AuthLevels.ORG,
644+
),
641645
]
642646

643647
result = questionary.select(

aegis/constants.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,9 @@ class AuthLevels:
102102

103103
BASIC = "basic"
104104
RBAC = "rbac"
105+
ORG = "org"
105106

106-
ALL = [BASIC, RBAC]
107+
ALL = [BASIC, RBAC, ORG]
107108

108109

109110
class AnswerKeys:

aegis/core/migration_generator.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,55 @@ class ServiceMigrationSpec:
101101
],
102102
)
103103

104+
ORG_MIGRATION = ServiceMigrationSpec(
105+
service_name="auth_org",
106+
description="Organization and membership tables",
107+
tables=[
108+
TableSpec(
109+
name="organization",
110+
columns=[
111+
ColumnSpec("id", "sa.Integer()", nullable=False, primary_key=True),
112+
ColumnSpec("name", "sa.String()", nullable=False),
113+
ColumnSpec("slug", "sa.String()", nullable=False),
114+
ColumnSpec("description", "sa.String()", nullable=True),
115+
ColumnSpec("is_active", "sa.Boolean()", nullable=False, default="True"),
116+
ColumnSpec("created_at", "sa.DateTime()", nullable=False),
117+
ColumnSpec("updated_at", "sa.DateTime()", nullable=True),
118+
],
119+
indexes=[IndexSpec("ix_organization_slug", ["slug"], unique=True)],
120+
),
121+
TableSpec(
122+
name="organization_member",
123+
columns=[
124+
ColumnSpec("id", "sa.Integer()", nullable=False, primary_key=True),
125+
ColumnSpec("organization_id", "sa.Integer()", nullable=False),
126+
ColumnSpec("user_id", "sa.Integer()", nullable=False),
127+
ColumnSpec("role", "sa.String()", nullable=False, default="'member'"),
128+
ColumnSpec("joined_at", "sa.DateTime()", nullable=False),
129+
],
130+
indexes=[
131+
IndexSpec(
132+
"ix_org_member_org_user",
133+
["organization_id", "user_id"],
134+
unique=True,
135+
),
136+
IndexSpec(
137+
"ix_org_member_organization_id",
138+
["organization_id"],
139+
),
140+
IndexSpec(
141+
"ix_org_member_user_id",
142+
["user_id"],
143+
),
144+
],
145+
foreign_keys=[
146+
ForeignKeySpec(["organization_id"], "organization", ["id"]),
147+
ForeignKeySpec(["user_id"], "user", ["id"]),
148+
],
149+
),
150+
],
151+
)
152+
104153
AI_MIGRATION = ServiceMigrationSpec(
105154
service_name="ai",
106155
description="AI service tables (LLM catalog, usage tracking, conversations)",
@@ -369,6 +418,7 @@ class ServiceMigrationSpec:
369418
# Registry of all service migrations
370419
MIGRATION_SPECS: dict[str, ServiceMigrationSpec] = {
371420
"auth": AUTH_MIGRATION,
421+
"auth_org": ORG_MIGRATION,
372422
"ai": AI_MIGRATION,
373423
"ai_voice": VOICE_MIGRATION,
374424
}
@@ -648,6 +698,17 @@ def get_services_needing_migrations(context: dict[str, Any]) -> list[str]:
648698
if include_auth == "yes" or include_auth is True:
649699
services.append("auth")
650700

701+
# Auth org tables (only with org-level auth)
702+
include_auth_org = context.get("include_auth_org")
703+
auth_level = context.get("auth_level")
704+
org_enabled = (
705+
include_auth_org == "yes"
706+
or include_auth_org is True
707+
or (isinstance(auth_level, str) and auth_level.lower() == "org")
708+
)
709+
if (include_auth == "yes" or include_auth is True) and org_enabled:
710+
services.append("auth_org")
711+
651712
# AI service (only with persistence backend)
652713
include_ai = context.get("include_ai")
653714
ai_backend = context.get("ai_backend", "memory")

aegis/core/services.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ class ServiceSpec:
5858
template_files=[
5959
"app/components/backend/api/auth/",
6060
"app/models/user.py",
61+
"app/models/org.py",
6162
"app/services/auth/",
6263
"app/core/security.py",
6364
],

aegis/core/template_generator.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -212,13 +212,14 @@ def get_template_context(self) -> dict[str, Any]:
212212
for s in self.selected_services
213213
)
214214
else "no",
215-
# Auth level selection (basic or rbac)
216-
AnswerKeys.AUTH_LEVEL: self._get_auth_level(),
215+
# Auth level selection (basic, rbac, or org)
216+
AnswerKeys.AUTH_LEVEL: (auth_level := self._get_auth_level()),
217217
# Derived auth level flags for template conditionals
218+
# Org level implies RBAC (org gets both roles and orgs)
218219
AnswerKeys.AUTH_RBAC: "yes"
219-
if self._get_auth_level() == AuthLevels.RBAC
220+
if auth_level in (AuthLevels.RBAC, AuthLevels.ORG)
220221
else "no",
221-
AnswerKeys.AUTH_ORG: "no", # Reserved for future org-level auth
222+
AnswerKeys.AUTH_ORG: "yes" if auth_level == AuthLevels.ORG else "no",
222223
AnswerKeys.AI: "yes"
223224
if any(
224225
extract_base_service_name(s) == AnswerKeys.SERVICE_AI
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
"""Organization and membership models."""
2+
{% if include_auth_org %}
3+
4+
from datetime import UTC, datetime
5+
6+
from sqlmodel import Field, SQLModel
7+
8+
9+
# Org membership role constants
10+
ORG_ROLE_OWNER = "owner"
11+
ORG_ROLE_ADMIN = "admin"
12+
ORG_ROLE_MEMBER = "member"
13+
VALID_ORG_ROLES = {ORG_ROLE_OWNER, ORG_ROLE_ADMIN, ORG_ROLE_MEMBER}
14+
15+
16+
class OrganizationBase(SQLModel):
17+
"""Base organization model with shared fields."""
18+
19+
name: str = Field(index=True)
20+
slug: str = Field(unique=True, index=True)
21+
description: str | None = None
22+
is_active: bool = Field(default=True)
23+
24+
25+
class Organization(OrganizationBase, table=True):
26+
"""Organization database model."""
27+
28+
id: int | None = Field(default=None, primary_key=True)
29+
created_at: datetime = Field(
30+
default_factory=lambda: datetime.now(UTC).replace(tzinfo=None)
31+
)
32+
updated_at: datetime | None = None
33+
34+
35+
class OrganizationMember(SQLModel, table=True):
36+
"""Organization membership database model."""
37+
38+
__tablename__ = "organization_member"
39+
40+
id: int | None = Field(default=None, primary_key=True)
41+
organization_id: int = Field(foreign_key="organization.id", index=True)
42+
user_id: int = Field(foreign_key="user.id", index=True)
43+
role: str = Field(default=ORG_ROLE_MEMBER)
44+
joined_at: datetime = Field(
45+
default_factory=lambda: datetime.now(UTC).replace(tzinfo=None)
46+
)
47+
48+
49+
class OrgCreate(OrganizationBase):
50+
"""Organization creation model."""
51+
52+
pass
53+
54+
55+
class OrgResponse(OrganizationBase):
56+
"""Organization response model."""
57+
58+
id: int
59+
created_at: datetime
60+
updated_at: datetime | None = None
61+
62+
63+
class MemberResponse(SQLModel):
64+
"""Organization member response model."""
65+
66+
id: int
67+
organization_id: int
68+
user_id: int
69+
role: str
70+
joined_at: datetime
71+
{% else %}
72+
# Organization models not included (auth_level != org)
73+
{% endif %}
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
"""Membership service for managing organization members."""
2+
{% if include_auth_org %}
3+
4+
from datetime import UTC, datetime
5+
6+
from sqlmodel import select
7+
from sqlmodel.ext.asyncio.session import AsyncSession
8+
9+
from app.models.org import (
10+
VALID_ORG_ROLES,
11+
Organization,
12+
OrganizationMember,
13+
)
14+
15+
16+
class MembershipService:
17+
"""Service for managing organization memberships."""
18+
19+
def __init__(self, db: AsyncSession) -> None:
20+
self.db = db
21+
22+
async def add_member(
23+
self, org_id: int, user_id: int, role: str = "member"
24+
) -> OrganizationMember:
25+
"""Add a user to an organization."""
26+
if role not in VALID_ORG_ROLES:
27+
raise ValueError(f"Invalid org role: {role}. Valid: {VALID_ORG_ROLES}")
28+
member = OrganizationMember(
29+
organization_id=org_id,
30+
user_id=user_id,
31+
role=role,
32+
joined_at=datetime.now(UTC).replace(tzinfo=None),
33+
)
34+
self.db.add(member)
35+
await self.db.commit()
36+
await self.db.refresh(member)
37+
return member
38+
39+
async def remove_member(self, org_id: int, user_id: int) -> bool:
40+
"""Remove a user from an organization."""
41+
member = await self.get_member(org_id, user_id)
42+
if not member:
43+
return False
44+
await self.db.delete(member)
45+
await self.db.commit()
46+
return True
47+
48+
async def get_member(
49+
self, org_id: int, user_id: int
50+
) -> OrganizationMember | None:
51+
"""Get a specific membership."""
52+
statement = select(OrganizationMember).where(
53+
OrganizationMember.organization_id == org_id,
54+
OrganizationMember.user_id == user_id,
55+
)
56+
result = await self.db.exec(statement)
57+
return result.first()
58+
59+
async def update_member_role(
60+
self, org_id: int, user_id: int, role: str
61+
) -> OrganizationMember | None:
62+
"""Update a member's role within an organization."""
63+
if role not in VALID_ORG_ROLES:
64+
raise ValueError(f"Invalid org role: {role}. Valid: {VALID_ORG_ROLES}")
65+
member = await self.get_member(org_id, user_id)
66+
if not member:
67+
return None
68+
member.role = role
69+
self.db.add(member)
70+
await self.db.commit()
71+
await self.db.refresh(member)
72+
return member
73+
74+
async def list_org_members(self, org_id: int) -> list[OrganizationMember]:
75+
"""List all members of an organization."""
76+
statement = select(OrganizationMember).where(
77+
OrganizationMember.organization_id == org_id
78+
)
79+
result = await self.db.exec(statement)
80+
return list(result.all())
81+
82+
async def list_user_orgs(self, user_id: int) -> list[Organization]:
83+
"""List all organizations a user belongs to."""
84+
statement = (
85+
select(Organization)
86+
.join(
87+
OrganizationMember,
88+
OrganizationMember.organization_id == Organization.id,
89+
)
90+
.where(OrganizationMember.user_id == user_id)
91+
)
92+
result = await self.db.exec(statement)
93+
return list(result.all())
94+
{% else %}
95+
# Membership service not included (auth_level != org)
96+
{% endif %}
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
"""Organization service for CRUD operations."""
2+
{% if include_auth_org %}
3+
4+
from datetime import UTC, datetime
5+
6+
from sqlmodel import select
7+
from sqlmodel.ext.asyncio.session import AsyncSession
8+
9+
from app.models.org import OrgCreate, Organization
10+
11+
12+
class OrgService:
13+
"""Service for managing organizations."""
14+
15+
def __init__(self, db: AsyncSession) -> None:
16+
self.db = db
17+
18+
async def create_org(self, org_data: OrgCreate) -> Organization:
19+
"""Create a new organization."""
20+
org = Organization.model_validate(org_data)
21+
self.db.add(org)
22+
await self.db.commit()
23+
await self.db.refresh(org)
24+
return org
25+
26+
async def get_org_by_id(self, org_id: int) -> Organization | None:
27+
"""Get an organization by ID."""
28+
return await self.db.get(Organization, org_id)
29+
30+
async def get_org_by_slug(self, slug: str) -> Organization | None:
31+
"""Get an organization by slug."""
32+
statement = select(Organization).where(Organization.slug == slug)
33+
result = await self.db.exec(statement)
34+
return result.first()
35+
36+
async def update_org(self, org_id: int, **updates: str) -> Organization | None:
37+
"""Update an organization's fields."""
38+
org = await self.get_org_by_id(org_id)
39+
if not org:
40+
return None
41+
for field, value in updates.items():
42+
if hasattr(org, field):
43+
setattr(org, field, value)
44+
org.updated_at = datetime.now(UTC).replace(tzinfo=None)
45+
self.db.add(org)
46+
await self.db.commit()
47+
await self.db.refresh(org)
48+
return org
49+
50+
async def delete_org(self, org_id: int) -> bool:
51+
"""Delete an organization."""
52+
org = await self.get_org_by_id(org_id)
53+
if not org:
54+
return False
55+
await self.db.delete(org)
56+
await self.db.commit()
57+
return True
58+
59+
async def list_orgs(self) -> list[Organization]:
60+
"""List all organizations."""
61+
statement = select(Organization).order_by(Organization.created_at.desc())
62+
result = await self.db.exec(statement)
63+
return list(result.all())
64+
{% else %}
65+
# Organization service not included (auth_level != org)
66+
{% endif %}

tests/core/test_auth_service_parser.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,16 @@ def test_level_case_insensitive_basic_uppercase(self) -> None:
5858
result = parse_auth_service_config("auth[BASIC]")
5959
assert result.level == "basic"
6060

61+
def test_org_level(self) -> None:
62+
"""auth[org] → org"""
63+
result = parse_auth_service_config("auth[org]")
64+
assert result.level == "org"
65+
66+
def test_org_level_case_insensitive(self) -> None:
67+
"""auth[ORG] → org (case insensitive)"""
68+
result = parse_auth_service_config("auth[ORG]")
69+
assert result.level == "org"
70+
6171

6272
class TestAuthServiceParserWhitespace:
6373
"""Test whitespace handling."""

0 commit comments

Comments
 (0)