diff --git a/docs/docs.json b/docs/docs.json index 714d6cbdd..c62e67846 100644 --- a/docs/docs.json +++ b/docs/docs.json @@ -77,6 +77,7 @@ "observe/agent_tasks", "observe/buffer", "observe/disable_tasks", + "observe/disable_task_status_change", "observe/dashboard", "observe/traces" ] diff --git a/docs/observe/agent_tasks.mdx b/docs/observe/agent_tasks.mdx index f83881f9f..3d0678ef8 100644 --- a/docs/observe/agent_tasks.mdx +++ b/docs/observe/agent_tasks.mdx @@ -117,7 +117,10 @@ for (const task of tasks.items) { ## Next Steps - + + +Manually control task outcomes for benchmarking or human-in-the-loop workflows + Understand task extraction timing diff --git a/docs/observe/disable_task_status_change.mdx b/docs/observe/disable_task_status_change.mdx new file mode 100644 index 000000000..af8f770e7 --- /dev/null +++ b/docs/observe/disable_task_status_change.mdx @@ -0,0 +1,110 @@ +--- +title: "Control the Task Status" +description: "Disable automatic status updates and manually control task outcomes" +--- + +Keep task extraction running while preventing the task agent from automatically marking tasks as `success` or `failed`. You decide when to finalize task outcomes via `update_task_status`, which triggers the [self-learning pipeline](/learn/self-learning). + + +```python Python +session = client.sessions.create(disable_task_status_change=True) +``` + +```typescript TypeScript +const session = await client.sessions.create({ disableTaskStatusChange: true }); +``` + + +**What happens:** +- Tasks are still created with descriptions and progress +- Messages are still linked to tasks +- Status remains at `running` — the agent cannot set `success` or `failed` +- Learning is only triggered when you call `update_task_status` + +## Typical Flow + + + + + +```python Python +import os +from acontext import AcontextClient + +client = AcontextClient(api_key=os.getenv("ACONTEXT_API_KEY")) +session = client.sessions.create(disable_task_status_change=True) +``` + +```typescript TypeScript +import { AcontextClient } from '@acontext/acontext'; + +const client = new AcontextClient({ apiKey: process.env.ACONTEXT_API_KEY }); +const session = await client.sessions.create({ disableTaskStatusChange: true }); +``` + + + + + + +```python Python +for msg in conversation_messages: + client.sessions.store_message(session_id=session.id, blob=msg, format="openai") +client.sessions.flush(session.id) + +tasks = client.sessions.get_tasks(session.id) +``` + +```typescript TypeScript +for (const msg of conversationMessages) { + await client.sessions.storeMessage(session.id, msg, { format: "openai" }); +} +await client.sessions.flush(session.id); + +const tasks = await client.sessions.getTasks(session.id); +``` + + + + + +Setting status to `success` or `failed` triggers the self-learning pipeline. + + +```python Python +client.sessions.update_task_status( + session_id=session.id, + task_id=tasks.items[0].id, + status="success", +) +``` + +```typescript TypeScript +await client.sessions.updateTaskStatus(session.id, tasks.items[0].id, { + status: "success", +}); +``` + + + + +## How It Differs from `disable_task_tracking` + +| | `disable_task_tracking` | `disable_task_status_change` | +|---|---|---| +| **What's blocked** | Entire task extraction | Only status transitions to `success`/`failed` | +| **Tasks created?** | No | Yes | +| **Progress tracked?** | No | Yes | +| **Messages linked?** | No | Yes | +| **Learning triggered?** | Never | Only via manual `update_task_status` | + +## Next Steps + + + +How task extraction works + + +Learn agent skills from past sessions + + diff --git a/src/client/acontext-py/src/acontext/resources/async_sessions.py b/src/client/acontext-py/src/acontext/resources/async_sessions.py index ddefdaa8c..55770236c 100644 --- a/src/client/acontext-py/src/acontext/resources/async_sessions.py +++ b/src/client/acontext-py/src/acontext/resources/async_sessions.py @@ -17,6 +17,7 @@ Message, MessageObservingStatus, Session, + Task, TokenCounts, ) from ..uploads import FileUpload, normalize_file_upload @@ -84,6 +85,7 @@ async def create( *, user: str | None = None, disable_task_tracking: bool | None = None, + disable_task_status_change: bool | None = None, configs: Mapping[str, Any] | None = None, use_uuid: str | None = None, ) -> Session: @@ -92,6 +94,8 @@ async def create( Args: user: Optional user identifier string. Defaults to None. disable_task_tracking: Whether to disable task tracking for this session. Defaults to None (server default: False). + disable_task_status_change: Whether to disable automatic task status changes. When True, + the task agent will not set tasks to success/failed automatically. Defaults to None (server default: False). configs: Optional session configuration dictionary. Defaults to None. use_uuid: Optional UUID string to use as the session ID. If not provided, a UUID will be auto-generated. If a session with this UUID already exists, a 409 Conflict error will be raised. @@ -107,6 +111,8 @@ async def create( payload["user"] = user if disable_task_tracking is not None: payload["disable_task_tracking"] = disable_task_tracking + if disable_task_status_change is not None: + payload["disable_task_status_change"] = disable_task_status_change if configs is not None: payload["configs"] = configs if use_uuid is not None: @@ -178,6 +184,33 @@ async def get_tasks( ) return GetTasksOutput.model_validate(data) + async def update_task_status( + self, + session_id: str, + task_id: str, + *, + status: str, + ) -> Task: + """Update a task's status. + + Setting status to "success" or "failed" triggers the skill learning pipeline. + + Args: + session_id: The UUID of the session. + task_id: The UUID of the task. + status: New status for the task. Must be one of: "success", "failed", "running", "pending". + + Returns: + The updated Task object. + """ + payload = {"status": status} + data = await self._requester.request( + "PATCH", + f"/session/{session_id}/task/{task_id}/status", + json_data=payload, + ) + return Task.model_validate(data) + async def get_session_summary( self, session_id: str, diff --git a/src/client/acontext-py/src/acontext/resources/sessions.py b/src/client/acontext-py/src/acontext/resources/sessions.py index 790a3b768..b5196e704 100644 --- a/src/client/acontext-py/src/acontext/resources/sessions.py +++ b/src/client/acontext-py/src/acontext/resources/sessions.py @@ -17,6 +17,7 @@ Message, MessageObservingStatus, Session, + Task, TokenCounts, ) from ..uploads import FileUpload, normalize_file_upload @@ -84,6 +85,7 @@ def create( *, user: str | None = None, disable_task_tracking: bool | None = None, + disable_task_status_change: bool | None = None, configs: Mapping[str, Any] | None = None, use_uuid: str | None = None, ) -> Session: @@ -92,6 +94,8 @@ def create( Args: user: Optional user identifier string. Defaults to None. disable_task_tracking: Whether to disable task tracking for this session. Defaults to None (server default: False). + disable_task_status_change: Whether to disable automatic task status changes. When True, + the task agent will not set tasks to success/failed automatically. Defaults to None (server default: False). configs: Optional session configuration dictionary. Defaults to None. use_uuid: Optional UUID string to use as the session ID. If not provided, a UUID will be auto-generated. If a session with this UUID already exists, a 409 Conflict error will be raised. @@ -107,6 +111,8 @@ def create( payload["user"] = user if disable_task_tracking is not None: payload["disable_task_tracking"] = disable_task_tracking + if disable_task_status_change is not None: + payload["disable_task_status_change"] = disable_task_status_change if configs is not None: payload["configs"] = configs if use_uuid is not None: @@ -178,6 +184,33 @@ def get_tasks( ) return GetTasksOutput.model_validate(data) + def update_task_status( + self, + session_id: str, + task_id: str, + *, + status: str, + ) -> Task: + """Update a task's status. + + Setting status to "success" or "failed" triggers the skill learning pipeline. + + Args: + session_id: The UUID of the session. + task_id: The UUID of the task. + status: New status for the task. Must be one of: "success", "failed", "running", "pending". + + Returns: + The updated Task object. + """ + payload = {"status": status} + data = self._requester.request( + "PATCH", + f"/session/{session_id}/task/{task_id}/status", + json_data=payload, + ) + return Task.model_validate(data) + def get_session_summary( self, session_id: str, diff --git a/src/client/acontext-py/src/acontext/types/session.py b/src/client/acontext-py/src/acontext/types/session.py index b600200aa..630925560 100644 --- a/src/client/acontext-py/src/acontext/types/session.py +++ b/src/client/acontext-py/src/acontext/types/session.py @@ -177,6 +177,10 @@ class Session(BaseModel): disable_task_tracking: bool = Field( False, description="Whether task tracking is disabled for this session" ) + disable_task_status_change: bool = Field( + False, + description="Whether automatic task status changes are disabled for this session", + ) configs: dict[str, Any] | None = Field( None, description="Session configuration dictionary" ) diff --git a/src/client/acontext-py/tests/test_async_client.py b/src/client/acontext-py/tests/test_async_client.py index b08d8cb10..82acab8ea 100644 --- a/src/client/acontext-py/tests/test_async_client.py +++ b/src/client/acontext-py/tests/test_async_client.py @@ -560,6 +560,153 @@ async def test_async_sessions_get_token_counts( assert result.total_tokens == 1234 +@patch("acontext.async_client.AcontextAsyncClient.request", new_callable=AsyncMock) +@pytest.mark.asyncio +async def test_async_sessions_create_with_disable_task_status_change( + mock_request, async_client: AcontextAsyncClient +) -> None: + """Test that disable_task_status_change is sent to API when provided.""" + mock_request.return_value = { + "id": "session-id", + "project_id": "project-id", + "disable_task_tracking": False, + "disable_task_status_change": True, + "configs": {}, + "created_at": "2024-01-01T00:00:00Z", + "updated_at": "2024-01-01T00:00:00Z", + } + + result = await async_client.sessions.create(disable_task_status_change=True) + + mock_request.assert_called_once() + args, kwargs = mock_request.call_args + method, path = args + assert method == "POST" + assert path == "/session" + assert kwargs["json_data"]["disable_task_status_change"] is True + assert result.disable_task_status_change is True + + +@patch("acontext.async_client.AcontextAsyncClient.request", new_callable=AsyncMock) +@pytest.mark.asyncio +async def test_async_sessions_create_without_disable_task_status_change( + mock_request, async_client: AcontextAsyncClient +) -> None: + """Test that disable_task_status_change is not sent when not provided.""" + mock_request.return_value = { + "id": "session-id", + "project_id": "project-id", + "disable_task_tracking": False, + "disable_task_status_change": False, + "configs": {}, + "created_at": "2024-01-01T00:00:00Z", + "updated_at": "2024-01-01T00:00:00Z", + } + + await async_client.sessions.create() + + mock_request.assert_called_once() + args, kwargs = mock_request.call_args + assert "disable_task_status_change" not in (kwargs.get("json_data") or {}) + + +@patch("acontext.async_client.AcontextAsyncClient.request", new_callable=AsyncMock) +@pytest.mark.asyncio +async def test_async_sessions_update_task_status_success( + mock_request, async_client: AcontextAsyncClient +) -> None: + """Test update_task_status sends correct PATCH request and returns Task.""" + mock_request.return_value = { + "id": "task-uuid", + "session_id": "session-uuid", + "project_id": "project-uuid", + "order": 1, + "data": { + "task_description": "Implement auth", + "progresses": ["Created login form"], + }, + "status": "success", + "is_planning": False, + "created_at": "2024-01-01T00:00:00Z", + "updated_at": "2024-01-01T00:00:00Z", + } + + result = await async_client.sessions.update_task_status( + session_id="session-uuid", + task_id="task-uuid", + status="success", + ) + + mock_request.assert_called_once() + args, kwargs = mock_request.call_args + method, path = args + assert method == "PATCH" + assert path == "/session/session-uuid/task/task-uuid/status" + assert kwargs["json_data"] == {"status": "success"} + assert result.id == "task-uuid" + assert result.status == "success" + assert result.data.task_description == "Implement auth" + + +@patch("acontext.async_client.AcontextAsyncClient.request", new_callable=AsyncMock) +@pytest.mark.asyncio +async def test_async_sessions_update_task_status_failed( + mock_request, async_client: AcontextAsyncClient +) -> None: + """Test update_task_status with failed status.""" + mock_request.return_value = { + "id": "task-uuid", + "session_id": "session-uuid", + "project_id": "project-uuid", + "order": 1, + "data": {"task_description": "Fix bug"}, + "status": "failed", + "is_planning": False, + "created_at": "2024-01-01T00:00:00Z", + "updated_at": "2024-01-01T00:00:00Z", + } + + result = await async_client.sessions.update_task_status( + session_id="session-uuid", + task_id="task-uuid", + status="failed", + ) + + mock_request.assert_called_once() + args, kwargs = mock_request.call_args + assert kwargs["json_data"] == {"status": "failed"} + assert result.status == "failed" + + +@patch("acontext.async_client.AcontextAsyncClient.request", new_callable=AsyncMock) +@pytest.mark.asyncio +async def test_async_sessions_update_task_status_running( + mock_request, async_client: AcontextAsyncClient +) -> None: + """Test update_task_status with running status.""" + mock_request.return_value = { + "id": "task-uuid", + "session_id": "session-uuid", + "project_id": "project-uuid", + "order": 1, + "data": {"task_description": "Process data"}, + "status": "running", + "is_planning": False, + "created_at": "2024-01-01T00:00:00Z", + "updated_at": "2024-01-01T00:00:00Z", + } + + result = await async_client.sessions.update_task_status( + session_id="session-uuid", + task_id="task-uuid", + status="running", + ) + + assert result.status == "running" + args, kwargs = mock_request.call_args + assert kwargs["json_data"] == {"status": "running"} + + @patch("acontext.async_client.AcontextAsyncClient.request", new_callable=AsyncMock) @pytest.mark.asyncio async def test_async_disks_create_hits_disk_endpoint( diff --git a/src/client/acontext-py/tests/test_client.py b/src/client/acontext-py/tests/test_client.py index c157c0c38..b4acdac4e 100644 --- a/src/client/acontext-py/tests/test_client.py +++ b/src/client/acontext-py/tests/test_client.py @@ -727,6 +727,151 @@ def test_sessions_get_token_counts(mock_request, client: AcontextClient) -> None assert result.total_tokens == 1234 +@patch("acontext.client.AcontextClient.request") +def test_sessions_create_with_disable_task_status_change( + mock_request, client: AcontextClient +) -> None: + """Test that disable_task_status_change is sent to API when provided.""" + mock_request.return_value = { + "id": "session-id", + "project_id": "project-id", + "disable_task_tracking": False, + "disable_task_status_change": True, + "configs": {}, + "created_at": "2024-01-01T00:00:00Z", + "updated_at": "2024-01-01T00:00:00Z", + } + + result = client.sessions.create(disable_task_status_change=True) + + mock_request.assert_called_once() + args, kwargs = mock_request.call_args + method, path = args + assert method == "POST" + assert path == "/session" + assert kwargs["json_data"]["disable_task_status_change"] is True + assert result.disable_task_status_change is True + + +@patch("acontext.client.AcontextClient.request") +def test_sessions_create_without_disable_task_status_change( + mock_request, client: AcontextClient +) -> None: + """Test that disable_task_status_change is not sent when not provided.""" + mock_request.return_value = { + "id": "session-id", + "project_id": "project-id", + "disable_task_tracking": False, + "disable_task_status_change": False, + "configs": {}, + "created_at": "2024-01-01T00:00:00Z", + "updated_at": "2024-01-01T00:00:00Z", + } + + client.sessions.create() + + mock_request.assert_called_once() + args, kwargs = mock_request.call_args + assert "disable_task_status_change" not in (kwargs.get("json_data") or {}) + + +@patch("acontext.client.AcontextClient.request") +def test_sessions_update_task_status_success( + mock_request, client: AcontextClient +) -> None: + """Test update_task_status sends correct PATCH request and returns Task.""" + mock_request.return_value = { + "id": "task-uuid", + "session_id": "session-uuid", + "project_id": "project-uuid", + "order": 1, + "data": { + "task_description": "Implement auth", + "progresses": ["Created login form"], + }, + "status": "success", + "is_planning": False, + "created_at": "2024-01-01T00:00:00Z", + "updated_at": "2024-01-01T00:00:00Z", + } + + result = client.sessions.update_task_status( + session_id="session-uuid", + task_id="task-uuid", + status="success", + ) + + mock_request.assert_called_once() + args, kwargs = mock_request.call_args + method, path = args + assert method == "PATCH" + assert path == "/session/session-uuid/task/task-uuid/status" + assert kwargs["json_data"] == {"status": "success"} + assert result.id == "task-uuid" + assert result.status == "success" + assert result.data.task_description == "Implement auth" + + +@patch("acontext.client.AcontextClient.request") +def test_sessions_update_task_status_failed( + mock_request, client: AcontextClient +) -> None: + """Test update_task_status with failed status.""" + mock_request.return_value = { + "id": "task-uuid", + "session_id": "session-uuid", + "project_id": "project-uuid", + "order": 1, + "data": {"task_description": "Fix bug"}, + "status": "failed", + "is_planning": False, + "created_at": "2024-01-01T00:00:00Z", + "updated_at": "2024-01-01T00:00:00Z", + } + + result = client.sessions.update_task_status( + session_id="session-uuid", + task_id="task-uuid", + status="failed", + ) + + mock_request.assert_called_once() + args, kwargs = mock_request.call_args + method, path = args + assert method == "PATCH" + assert path == "/session/session-uuid/task/task-uuid/status" + assert kwargs["json_data"] == {"status": "failed"} + assert result.status == "failed" + + +@patch("acontext.client.AcontextClient.request") +def test_sessions_update_task_status_running( + mock_request, client: AcontextClient +) -> None: + """Test update_task_status with running status (no learning triggered).""" + mock_request.return_value = { + "id": "task-uuid", + "session_id": "session-uuid", + "project_id": "project-uuid", + "order": 1, + "data": {"task_description": "Process data"}, + "status": "running", + "is_planning": False, + "created_at": "2024-01-01T00:00:00Z", + "updated_at": "2024-01-01T00:00:00Z", + } + + result = client.sessions.update_task_status( + session_id="session-uuid", + task_id="task-uuid", + status="running", + ) + + assert result.status == "running" + args, kwargs = mock_request.call_args + assert kwargs["json_data"] == {"status": "running"} + + @patch("acontext.client.AcontextClient.request") def test_disks_create_hits_disk_endpoint(mock_request, client: AcontextClient) -> None: mock_request.return_value = { diff --git a/src/client/acontext-ts/src/resources/sessions.ts b/src/client/acontext-ts/src/resources/sessions.ts index 4c16f6a53..acf14ff78 100644 --- a/src/client/acontext-ts/src/resources/sessions.ts +++ b/src/client/acontext-ts/src/resources/sessions.ts @@ -23,6 +23,8 @@ import { MessageSchema, Session, SessionSchema, + Task, + TaskSchema, TokenCounts, TokenCountsSchema, } from '../types'; @@ -81,6 +83,8 @@ export class SessionsAPI { * @param options - Options for creating a session. * @param options.user - Optional user identifier string. * @param options.disableTaskTracking - Whether to disable task tracking for this session. + * @param options.disableTaskStatusChange - Whether to disable automatic task status changes. + * When true, the task agent will not set tasks to success/failed automatically. * @param options.configs - Optional session configuration dictionary. * @param options.useUuid - Optional UUID string to use as the session ID. If not provided, a UUID will be auto-generated. * If a session with this UUID already exists, a 409 Conflict error will be raised. @@ -89,6 +93,7 @@ export class SessionsAPI { async create(options?: { user?: string | null; disableTaskTracking?: boolean | null; + disableTaskStatusChange?: boolean | null; configs?: Record; useUuid?: string | null; }): Promise { @@ -99,6 +104,9 @@ export class SessionsAPI { if (options?.disableTaskTracking !== undefined && options?.disableTaskTracking !== null) { payload.disable_task_tracking = options.disableTaskTracking; } + if (options?.disableTaskStatusChange !== undefined && options?.disableTaskStatusChange !== null) { + payload.disable_task_status_change = options.disableTaskStatusChange; + } if (options?.configs !== undefined) { payload.configs = options.configs; } @@ -151,6 +159,29 @@ export class SessionsAPI { return GetTasksOutputSchema.parse(data); } + /** + * Update a task's status. + * + * Setting status to "success" or "failed" triggers the skill learning pipeline. + * + * @param sessionId - The UUID of the session. + * @param taskId - The UUID of the task. + * @param options - Options containing the new status. + * @param options.status - New status: "success", "failed", "running", or "pending". + * @returns The updated Task object. + */ + async updateTaskStatus( + sessionId: string, + taskId: string, + options: { status: string } + ): Promise { + const payload = { status: options.status }; + const data = await this.requester.request('PATCH', `/session/${sessionId}/task/${taskId}/status`, { + jsonData: payload, + }); + return TaskSchema.parse(data); + } + /** * Get a summary of all tasks in a session as a formatted string. * diff --git a/src/client/acontext-ts/src/types/session.ts b/src/client/acontext-ts/src/types/session.ts index 50bfaebfe..2b4e69c87 100644 --- a/src/client/acontext-ts/src/types/session.ts +++ b/src/client/acontext-ts/src/types/session.ts @@ -49,6 +49,7 @@ export const SessionSchema = z.object({ project_id: z.string(), user_id: z.string().nullable().optional(), disable_task_tracking: z.boolean(), + disable_task_status_change: z.boolean(), configs: z.record(z.string(), z.unknown()).nullable(), created_at: z.string(), updated_at: z.string(), diff --git a/src/client/acontext-ts/tests/client.test.ts b/src/client/acontext-ts/tests/client.test.ts index 740f346ec..f448906eb 100644 --- a/src/client/acontext-ts/tests/client.test.ts +++ b/src/client/acontext-ts/tests/client.test.ts @@ -272,6 +272,96 @@ describe('AcontextClient Unit Tests', () => { expect(result.has_more).toBeDefined(); }); + test('should create a session with disableTaskStatusChange', async () => { + const createdSession = mockSession({ + disable_task_status_change: true, + }); + client.mock().onPost('/session', (options) => { + const data = options?.jsonData as Record; + expect(data?.disable_task_status_change).toBe(true); + return createdSession; + }); + + const session = await client.sessions.create({ + disableTaskStatusChange: true, + }); + expect(session).toBeDefined(); + expect(session.disable_task_status_change).toBe(true); + }); + + test('should not send disableTaskStatusChange when not provided', async () => { + const createdSession = mockSession(); + client.mock().onPost('/session', (options) => { + const data = options?.jsonData as Record; + expect(data?.disable_task_status_change).toBeUndefined(); + return createdSession; + }); + + const session = await client.sessions.create(); + expect(session).toBeDefined(); + expect(session.disable_task_status_change).toBe(false); + }); + + test('should update task status to success', async () => { + const sessionId = 'test-session-id'; + const taskId = 'test-task-id'; + const updatedTask = mockTask({ + id: taskId, + session_id: sessionId, + status: 'success', + }); + client.mock().onPatch(`/session/${sessionId}/task/${taskId}/status`, (options) => { + expect(options?.jsonData).toEqual({ status: 'success' }); + return updatedTask; + }); + + const result = await client.sessions.updateTaskStatus(sessionId, taskId, { + status: 'success', + }); + expect(result).toBeDefined(); + expect(result.status).toBe('success'); + expect(result.id).toBe(taskId); + }); + + test('should update task status to failed', async () => { + const sessionId = 'test-session-id'; + const taskId = 'test-task-id'; + const updatedTask = mockTask({ + id: taskId, + session_id: sessionId, + status: 'failed', + }); + client.mock().onPatch(`/session/${sessionId}/task/${taskId}/status`, (options) => { + expect(options?.jsonData).toEqual({ status: 'failed' }); + return updatedTask; + }); + + const result = await client.sessions.updateTaskStatus(sessionId, taskId, { + status: 'failed', + }); + expect(result).toBeDefined(); + expect(result.status).toBe('failed'); + }); + + test('should update task status to running', async () => { + const sessionId = 'test-session-id'; + const taskId = 'test-task-id'; + const updatedTask = mockTask({ + id: taskId, + session_id: sessionId, + status: 'running', + }); + client.mock().onPatch(`/session/${sessionId}/task/${taskId}/status`, (options) => { + expect(options?.jsonData).toEqual({ status: 'running' }); + return updatedTask; + }); + + const result = await client.sessions.updateTaskStatus(sessionId, taskId, { + status: 'running', + }); + expect(result.status).toBe('running'); + }); + test('should get token counts', async () => { const sessionId = 'test-session-id'; const tokenCounts = { total_tokens: 1234 }; diff --git a/src/client/acontext-ts/tests/mocks.ts b/src/client/acontext-ts/tests/mocks.ts index e2b48d254..2c98117b3 100644 --- a/src/client/acontext-ts/tests/mocks.ts +++ b/src/client/acontext-ts/tests/mocks.ts @@ -199,6 +199,7 @@ export function mockSession(overrides?: Partial<{ project_id: string; user_id: string | null; disable_task_tracking: boolean; + disable_task_status_change: boolean; configs: Record | null; created_at: string; updated_at: string; @@ -209,6 +210,7 @@ export function mockSession(overrides?: Partial<{ project_id: overrides?.project_id ?? mockId(), user_id: overrides?.user_id ?? null, disable_task_tracking: overrides?.disable_task_tracking ?? false, + disable_task_status_change: overrides?.disable_task_status_change ?? false, configs: overrides?.configs ?? {}, created_at: overrides?.created_at ?? now, updated_at: overrides?.updated_at ?? now, diff --git a/src/server/api/go/internal/bootstrap/container.go b/src/server/api/go/internal/bootstrap/container.go index 0d60af8b4..c286765a6 100644 --- a/src/server/api/go/internal/bootstrap/container.go +++ b/src/server/api/go/internal/bootstrap/container.go @@ -282,6 +282,9 @@ func BuildContainer() *do.Injector { do.Provide(inj, func(i *do.Injector) (service.TaskService, error) { return service.NewTaskService( do.MustInvoke[repo.TaskRepo](i), + do.MustInvoke[repo.LearningSpaceSessionRepo](i), + do.MustInvoke[*mq.Publisher](i), + do.MustInvoke[*config.Config](i), do.MustInvoke[*zap.Logger](i), ), nil }) diff --git a/src/server/api/go/internal/config/config.go b/src/server/api/go/internal/config/config.go index bd39371e5..c79b19c19 100644 --- a/src/server/api/go/internal/config/config.go +++ b/src/server/api/go/internal/config/config.go @@ -46,10 +46,12 @@ type RedisCfg struct { type MQExchangeName struct { SessionMessage string + LearningSkill string } type MQRoutingKey struct { SessionMessageInsert string + LearningSkillDistill string } type MQCfg struct { URL string @@ -121,7 +123,9 @@ func setDefaults(v *viper.Viper) { v.SetDefault("rabbitmq.url", "amqp://acontext:helloworld@127.0.0.1:15672/%2F") v.SetDefault("rabbitmq.enableTLS", false) v.SetDefault("rabbitmq.exchangeName.sessionMessage", "session.message") + v.SetDefault("rabbitmq.exchangeName.learningSkill", "learning.skill") v.SetDefault("rabbitmq.routingKey.sessionMessageInsert", "session.message.insert") + v.SetDefault("rabbitmq.routingKey.learningSkillDistill", "learning.skill.distill") v.SetDefault("core.baseURL", "http://127.0.0.1:8019") v.SetDefault("telemetry.otlpEndpoint", "http://127.0.0.1:4317") v.SetDefault("telemetry.enabled", true) diff --git a/src/server/api/go/internal/modules/handler/session.go b/src/server/api/go/internal/modules/handler/session.go index dcbc69a6e..c2bb9d000 100644 --- a/src/server/api/go/internal/modules/handler/session.go +++ b/src/server/api/go/internal/modules/handler/session.go @@ -45,10 +45,11 @@ func NewSessionHandler(s service.SessionService, userSvc service.UserService, co } type CreateSessionReq struct { - User string `form:"user" json:"user" example:"alice@acontext.io"` - DisableTaskTracking *bool `form:"disable_task_tracking" json:"disable_task_tracking" example:"false"` - Configs map[string]interface{} `form:"configs" json:"configs"` - UseUUID *string `form:"use_uuid" json:"use_uuid" example:"123e4567-e89b-12d3-a456-426614174000"` + User string `form:"user" json:"user" example:"alice@acontext.io"` + DisableTaskTracking *bool `form:"disable_task_tracking" json:"disable_task_tracking" example:"false"` + DisableTaskStatusChange *bool `form:"disable_task_status_change" json:"disable_task_status_change" example:"false"` + Configs map[string]interface{} `form:"configs" json:"configs"` + UseUUID *string `form:"use_uuid" json:"use_uuid" example:"123e4567-e89b-12d3-a456-426614174000"` } type GetSessionsReq struct { @@ -173,6 +174,9 @@ func (h *SessionHandler) CreateSession(c *gin.Context) { if req.DisableTaskTracking != nil { session.DisableTaskTracking = *req.DisableTaskTracking } + if req.DisableTaskStatusChange != nil { + session.DisableTaskStatusChange = *req.DisableTaskStatusChange + } if err := h.svc.Create(c.Request.Context(), &session); err != nil { // Check for duplicate key error (PostgreSQL unique violation) if strings.Contains(err.Error(), "duplicate key") || strings.Contains(err.Error(), "23505") { diff --git a/src/server/api/go/internal/modules/handler/task.go b/src/server/api/go/internal/modules/handler/task.go index befbff249..3a9eeabc9 100644 --- a/src/server/api/go/internal/modules/handler/task.go +++ b/src/server/api/go/internal/modules/handler/task.go @@ -2,9 +2,11 @@ package handler import ( "net/http" + "strings" "github.com/gin-gonic/gin" "github.com/google/uuid" + "github.com/memodb-io/Acontext/internal/modules/model" "github.com/memodb-io/Acontext/internal/modules/serializer" "github.com/memodb-io/Acontext/internal/modules/service" ) @@ -64,3 +66,65 @@ func (h *TaskHandler) GetTasks(c *gin.Context) { c.JSON(http.StatusOK, serializer.Response{Data: out}) } + +type UpdateTaskStatusReq struct { + Status string `json:"status" binding:"required,oneof=success failed running pending"` +} + +// UpdateTaskStatus godoc +// +// @Summary Update task status +// @Description Update a task's status. Setting status to "success" or "failed" triggers the skill learning pipeline. +// @Tags task +// @Accept json +// @Produce json +// @Param session_id path string true "Session ID" format(uuid) +// @Param task_id path string true "Task ID" format(uuid) +// @Param body body UpdateTaskStatusReq true "Status update" +// @Security BearerAuth +// @Success 200 {object} serializer.Response{data=model.Task} +// @Failure 400 {object} serializer.Response +// @Failure 404 {object} serializer.Response +// @Router /session/{session_id}/task/{task_id}/status [patch] +func (h *TaskHandler) UpdateTaskStatus(c *gin.Context) { + req := UpdateTaskStatusReq{} + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, serializer.ParamErr("invalid status value, must be one of: success, failed, running, pending", err)) + return + } + + project, ok := c.MustGet("project").(*model.Project) + if !ok { + c.JSON(http.StatusBadRequest, serializer.ParamErr("project not found", nil)) + return + } + + sessionID, err := uuid.Parse(c.Param("session_id")) + if err != nil { + c.JSON(http.StatusBadRequest, serializer.ParamErr("invalid session_id", err)) + return + } + + taskID, err := uuid.Parse(c.Param("task_id")) + if err != nil { + c.JSON(http.StatusBadRequest, serializer.ParamErr("invalid task_id", err)) + return + } + + task, err := h.svc.UpdateTaskStatus(c.Request.Context(), service.UpdateTaskStatusInput{ + ProjectID: project.ID, + SessionID: sessionID, + TaskID: taskID, + Status: req.Status, + }) + if err != nil { + if strings.Contains(err.Error(), "not found") { + c.JSON(http.StatusNotFound, serializer.Err(http.StatusNotFound, err.Error(), nil)) + return + } + c.JSON(http.StatusBadRequest, serializer.DBErr("", err)) + return + } + + c.JSON(http.StatusOK, serializer.Response{Data: task}) +} diff --git a/src/server/api/go/internal/modules/handler/task_test.go b/src/server/api/go/internal/modules/handler/task_test.go index 83c5b883f..a00dfe30c 100644 --- a/src/server/api/go/internal/modules/handler/task_test.go +++ b/src/server/api/go/internal/modules/handler/task_test.go @@ -3,8 +3,10 @@ package handler import ( "context" "encoding/json" + "fmt" "net/http" "net/http/httptest" + "strings" "testing" "github.com/gin-gonic/gin" @@ -29,6 +31,14 @@ func (m *MockTaskService) GetTasks(ctx context.Context, in service.GetTasksInput return args.Get(0).(*service.GetTasksOutput), args.Error(1) } +func (m *MockTaskService) UpdateTaskStatus(ctx context.Context, in service.UpdateTaskStatusInput) (*model.Task, error) { + args := m.Called(ctx, in) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*model.Task), args.Error(1) +} + func TestTaskHandler_GetTasks(t *testing.T) { gin.SetMode(gin.TestMode) serializer.SetLogger(zap.NewNop()) @@ -173,3 +183,202 @@ func TestTaskHandler_GetTasks(t *testing.T) { }) } } + +func TestTaskHandler_UpdateTaskStatus(t *testing.T) { + gin.SetMode(gin.TestMode) + serializer.SetLogger(zap.NewNop()) + + projectID := uuid.New() + sessionID := uuid.New() + taskID := uuid.New() + + project := &model.Project{ID: projectID} + + tests := []struct { + name string + sessionIDParam string + taskIDParam string + body string + setProject bool + setup func(*MockTaskService) + expectedStatus int + checkResponse func(*testing.T, *httptest.ResponseRecorder) + }{ + { + name: "success - set status to success", + sessionIDParam: sessionID.String(), + taskIDParam: taskID.String(), + body: `{"status":"success"}`, + setProject: true, + setup: func(svc *MockTaskService) { + svc.On("UpdateTaskStatus", mock.Anything, mock.MatchedBy(func(in service.UpdateTaskStatusInput) bool { + return in.ProjectID == projectID && in.SessionID == sessionID && in.TaskID == taskID && in.Status == "success" + })).Return(&model.Task{ + ID: taskID, + SessionID: sessionID, + ProjectID: projectID, + Status: "success", + }, nil) + }, + expectedStatus: http.StatusOK, + checkResponse: func(t *testing.T, rec *httptest.ResponseRecorder) { + var resp serializer.Response + err := json.Unmarshal(rec.Body.Bytes(), &resp) + assert.NoError(t, err) + assert.Equal(t, 0, resp.Code) + + data, ok := resp.Data.(map[string]interface{}) + assert.True(t, ok) + assert.Equal(t, "success", data["status"]) + }, + }, + { + name: "success - set status to failed", + sessionIDParam: sessionID.String(), + taskIDParam: taskID.String(), + body: `{"status":"failed"}`, + setProject: true, + setup: func(svc *MockTaskService) { + svc.On("UpdateTaskStatus", mock.Anything, mock.MatchedBy(func(in service.UpdateTaskStatusInput) bool { + return in.Status == "failed" + })).Return(&model.Task{ + ID: taskID, + SessionID: sessionID, + ProjectID: projectID, + Status: "failed", + }, nil) + }, + expectedStatus: http.StatusOK, + }, + { + name: "success - set status to running", + sessionIDParam: sessionID.String(), + taskIDParam: taskID.String(), + body: `{"status":"running"}`, + setProject: true, + setup: func(svc *MockTaskService) { + svc.On("UpdateTaskStatus", mock.Anything, mock.MatchedBy(func(in service.UpdateTaskStatusInput) bool { + return in.Status == "running" + })).Return(&model.Task{ + ID: taskID, + SessionID: sessionID, + ProjectID: projectID, + Status: "running", + }, nil) + }, + expectedStatus: http.StatusOK, + }, + { + name: "success - set status to pending", + sessionIDParam: sessionID.String(), + taskIDParam: taskID.String(), + body: `{"status":"pending"}`, + setProject: true, + setup: func(svc *MockTaskService) { + svc.On("UpdateTaskStatus", mock.Anything, mock.MatchedBy(func(in service.UpdateTaskStatusInput) bool { + return in.Status == "pending" + })).Return(&model.Task{ + ID: taskID, + SessionID: sessionID, + ProjectID: projectID, + Status: "pending", + }, nil) + }, + expectedStatus: http.StatusOK, + }, + { + name: "error - invalid status value", + sessionIDParam: sessionID.String(), + taskIDParam: taskID.String(), + body: `{"status":"completed"}`, + setProject: true, + setup: func(svc *MockTaskService) {}, + expectedStatus: http.StatusBadRequest, + }, + { + name: "error - missing status", + sessionIDParam: sessionID.String(), + taskIDParam: taskID.String(), + body: `{}`, + setProject: true, + setup: func(svc *MockTaskService) {}, + expectedStatus: http.StatusBadRequest, + }, + { + name: "error - invalid session_id", + sessionIDParam: "not-a-uuid", + taskIDParam: taskID.String(), + body: `{"status":"success"}`, + setProject: true, + setup: func(svc *MockTaskService) {}, + expectedStatus: http.StatusBadRequest, + }, + { + name: "error - invalid task_id", + sessionIDParam: sessionID.String(), + taskIDParam: "not-a-uuid", + body: `{"status":"success"}`, + setProject: true, + setup: func(svc *MockTaskService) {}, + expectedStatus: http.StatusBadRequest, + }, + { + name: "error - task not found", + sessionIDParam: sessionID.String(), + taskIDParam: taskID.String(), + body: `{"status":"success"}`, + setProject: true, + setup: func(svc *MockTaskService) { + svc.On("UpdateTaskStatus", mock.Anything, mock.Anything). + Return(nil, assert.AnError) + }, + expectedStatus: http.StatusBadRequest, + }, + { + name: "error - task not found returns 404", + sessionIDParam: sessionID.String(), + taskIDParam: taskID.String(), + body: `{"status":"success"}`, + setProject: true, + setup: func(svc *MockTaskService) { + svc.On("UpdateTaskStatus", mock.Anything, mock.Anything). + Return(nil, assert.AnError).Run(func(args mock.Arguments) {}) + svc.ExpectedCalls[0].ReturnArguments = mock.Arguments{(*model.Task)(nil), fmt.Errorf("task not found or does not belong to this session")} + }, + expectedStatus: http.StatusNotFound, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + svc := &MockTaskService{} + tt.setup(svc) + + handler := NewTaskHandler(svc) + + w := httptest.NewRecorder() + _, r := gin.CreateTestContext(w) + + r.PATCH("/session/:session_id/task/:task_id/status", func(c *gin.Context) { + if tt.setProject { + c.Set("project", project) + } + handler.UpdateTaskStatus(c) + }) + + req := httptest.NewRequest(http.MethodPatch, "/session/"+tt.sessionIDParam+"/task/"+tt.taskIDParam+"/status", + strings.NewReader(tt.body)) + req.Header.Set("Content-Type", "application/json") + + r.ServeHTTP(w, req) + + assert.Equal(t, tt.expectedStatus, w.Code) + + if tt.checkResponse != nil { + tt.checkResponse(t, w) + } + + svc.AssertExpectations(t) + }) + } +} diff --git a/src/server/api/go/internal/modules/model/session.go b/src/server/api/go/internal/modules/model/session.go index 093bebdcd..15a49c5e2 100644 --- a/src/server/api/go/internal/modules/model/session.go +++ b/src/server/api/go/internal/modules/model/session.go @@ -11,7 +11,8 @@ type Session struct { ID uuid.UUID `gorm:"type:uuid;default:gen_random_uuid();primaryKey" json:"id"` ProjectID uuid.UUID `gorm:"type:uuid;not null;index" json:"project_id"` UserID *uuid.UUID `gorm:"type:uuid;index" json:"user_id"` - DisableTaskTracking bool `gorm:"not null;default:false" json:"disable_task_tracking"` + DisableTaskTracking bool `gorm:"not null;default:false" json:"disable_task_tracking"` + DisableTaskStatusChange bool `gorm:"not null;default:false" json:"disable_task_status_change"` Configs datatypes.JSONMap `gorm:"type:jsonb;index:idx_sessions_configs,type:gin" swaggertype:"object" json:"configs"` CreatedAt time.Time `gorm:"autoCreateTime;not null;default:CURRENT_TIMESTAMP" json:"created_at"` diff --git a/src/server/api/go/internal/modules/repo/task.go b/src/server/api/go/internal/modules/repo/task.go index 123fa6db2..c515eaa0e 100644 --- a/src/server/api/go/internal/modules/repo/task.go +++ b/src/server/api/go/internal/modules/repo/task.go @@ -2,6 +2,7 @@ package repo import ( "context" + "fmt" "time" "github.com/google/uuid" @@ -11,6 +12,7 @@ import ( type TaskRepo interface { ListBySessionWithCursor(ctx context.Context, sessionID uuid.UUID, afterCreatedAt time.Time, afterID uuid.UUID, limit int, timeDesc bool) ([]model.Task, error) + UpdateStatus(ctx context.Context, projectID uuid.UUID, sessionID uuid.UUID, taskID uuid.UUID, status string) (*model.Task, error) } type taskRepo struct{ db *gorm.DB } @@ -44,3 +46,24 @@ func (r *taskRepo) ListBySessionWithCursor(ctx context.Context, sessionID uuid.U var items []model.Task return items, q.Order(orderBy).Limit(limit).Find(&items).Error } + +func (r *taskRepo) UpdateStatus(ctx context.Context, projectID uuid.UUID, sessionID uuid.UUID, taskID uuid.UUID, status string) (*model.Task, error) { + validStatuses := map[string]bool{"success": true, "failed": true, "running": true, "pending": true} + if !validStatuses[status] { + return nil, fmt.Errorf("invalid status: %s", status) + } + + var task model.Task + result := r.db.WithContext(ctx). + Where("id = ? AND session_id = ? AND project_id = ?", taskID, sessionID, projectID). + First(&task) + if result.Error != nil { + return nil, result.Error + } + + task.Status = status + if err := r.db.WithContext(ctx).Save(&task).Error; err != nil { + return nil, err + } + return &task, nil +} diff --git a/src/server/api/go/internal/modules/service/learning_space_test.go b/src/server/api/go/internal/modules/service/learning_space_test.go index e53b9952b..b1494f1b4 100644 --- a/src/server/api/go/internal/modules/service/learning_space_test.go +++ b/src/server/api/go/internal/modules/service/learning_space_test.go @@ -221,8 +221,8 @@ description: "Capture and recall general facts about the user" func newTestTemplateFS() fstest.MapFS { return fstest.MapFS{ - "skill_templates/daily-logs/SKILL.md": &fstest.MapFile{Data: []byte(testDailyLogsTemplate)}, - "skill_templates/user-general-facts/SKILL.md": &fstest.MapFile{Data: []byte(testUserFactsTemplate)}, + "skill_templates/daily-logs/SKILL.md": &fstest.MapFile{Data: []byte(testDailyLogsTemplate)}, + "skill_templates/user-general-facts/SKILL.md": &fstest.MapFile{Data: []byte(testUserFactsTemplate)}, } } diff --git a/src/server/api/go/internal/modules/service/task.go b/src/server/api/go/internal/modules/service/task.go index 673dd53c8..41aed6bdb 100644 --- a/src/server/api/go/internal/modules/service/task.go +++ b/src/server/api/go/internal/modules/service/task.go @@ -2,28 +2,40 @@ package service import ( "context" + "errors" + "fmt" "time" "github.com/google/uuid" + "github.com/memodb-io/Acontext/internal/config" + mq "github.com/memodb-io/Acontext/internal/infra/queue" "github.com/memodb-io/Acontext/internal/modules/model" "github.com/memodb-io/Acontext/internal/modules/repo" "github.com/memodb-io/Acontext/internal/pkg/paging" "go.uber.org/zap" + "gorm.io/gorm" ) type TaskService interface { GetTasks(ctx context.Context, in GetTasksInput) (*GetTasksOutput, error) + UpdateTaskStatus(ctx context.Context, in UpdateTaskStatusInput) (*model.Task, error) } type taskService struct { - r repo.TaskRepo - log *zap.Logger + r repo.TaskRepo + lssRepo repo.LearningSpaceSessionRepo + publisher *mq.Publisher + cfg *config.Config + log *zap.Logger } -func NewTaskService(r repo.TaskRepo, log *zap.Logger) TaskService { +func NewTaskService(r repo.TaskRepo, lssRepo repo.LearningSpaceSessionRepo, publisher *mq.Publisher, cfg *config.Config, log *zap.Logger) TaskService { return &taskService{ - r: r, - log: log, + r: r, + lssRepo: lssRepo, + publisher: publisher, + cfg: cfg, + log: log, } } @@ -40,8 +52,20 @@ type GetTasksOutput struct { HasMore bool `json:"has_more"` } +type UpdateTaskStatusInput struct { + ProjectID uuid.UUID `json:"project_id"` + SessionID uuid.UUID `json:"session_id"` + TaskID uuid.UUID `json:"task_id"` + Status string `json:"status"` +} + +type SkillLearnTaskMQ struct { + ProjectID uuid.UUID `json:"project_id"` + SessionID uuid.UUID `json:"session_id"` + TaskID uuid.UUID `json:"task_id"` +} + func (s *taskService) GetTasks(ctx context.Context, in GetTasksInput) (*GetTasksOutput, error) { - // Parse cursor (createdAt, id); an empty cursor indicates starting from the latest var afterT time.Time var afterID uuid.UUID var err error @@ -52,7 +76,6 @@ func (s *taskService) GetTasks(ctx context.Context, in GetTasksInput) (*GetTasks } } - // Query limit+1 is used to determine has_more tasks, err := s.r.ListBySessionWithCursor(ctx, in.SessionID, afterT, afterID, in.Limit+1, in.TimeDesc) if err != nil { return nil, err @@ -71,3 +94,32 @@ func (s *taskService) GetTasks(ctx context.Context, in GetTasksInput) (*GetTasks return out, nil } + +func (s *taskService) UpdateTaskStatus(ctx context.Context, in UpdateTaskStatusInput) (*model.Task, error) { + task, err := s.r.UpdateStatus(ctx, in.ProjectID, in.SessionID, in.TaskID, in.Status) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, fmt.Errorf("task not found or does not belong to this session") + } + return nil, fmt.Errorf("failed to update task status: %w", err) + } + + if in.Status == "success" || in.Status == "failed" { + exists, err := s.lssRepo.ExistsBySessionID(ctx, in.SessionID) + if err != nil { + s.log.Warn("failed to check learning space for session, skipping skill learning publish", zap.Error(err), zap.String("session_id", in.SessionID.String())) + } else if !exists { + s.log.Debug("no learning space found for session, skipping skill learning publish", zap.String("session_id", in.SessionID.String())) + } else if s.publisher != nil { + if pubErr := s.publisher.PublishJSON(ctx, s.cfg.RabbitMQ.ExchangeName.LearningSkill, s.cfg.RabbitMQ.RoutingKey.LearningSkillDistill, SkillLearnTaskMQ{ + ProjectID: task.ProjectID, + SessionID: in.SessionID, + TaskID: in.TaskID, + }); pubErr != nil { + s.log.Error("failed to publish skill learning task", zap.Error(pubErr), zap.String("session_id", in.SessionID.String())) + } + } + } + + return task, nil +} diff --git a/src/server/api/go/internal/modules/service/task_test.go b/src/server/api/go/internal/modules/service/task_test.go new file mode 100644 index 000000000..3f1fb7700 --- /dev/null +++ b/src/server/api/go/internal/modules/service/task_test.go @@ -0,0 +1,235 @@ +package service + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/google/uuid" + "github.com/memodb-io/Acontext/internal/modules/model" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "go.uber.org/zap" + "gorm.io/gorm" +) + +type MockTaskRepo struct { + mock.Mock +} + +func (m *MockTaskRepo) ListBySessionWithCursor(ctx context.Context, sessionID uuid.UUID, afterCreatedAt time.Time, afterID uuid.UUID, limit int, timeDesc bool) ([]model.Task, error) { + args := m.Called(ctx, sessionID, afterCreatedAt, afterID, limit, timeDesc) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]model.Task), args.Error(1) +} + +func (m *MockTaskRepo) UpdateStatus(ctx context.Context, projectID uuid.UUID, sessionID uuid.UUID, taskID uuid.UUID, status string) (*model.Task, error) { + args := m.Called(ctx, projectID, sessionID, taskID, status) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*model.Task), args.Error(1) +} + +// MockLearningSpaceSessionRepo is defined in learning_space_test.go + +func TestTaskService_UpdateTaskStatus(t *testing.T) { + projectID := uuid.New() + sessionID := uuid.New() + taskID := uuid.New() + + tests := []struct { + name string + input UpdateTaskStatusInput + setupRepo func(*MockTaskRepo) + setupLSS func(*MockLearningSpaceSessionRepo) + expectErr bool + errContains string + checkResult func(*testing.T, *model.Task) + }{ + { + name: "success - update to success, learning space exists", + input: UpdateTaskStatusInput{ + ProjectID: projectID, + SessionID: sessionID, + TaskID: taskID, + Status: "success", + }, + setupRepo: func(r *MockTaskRepo) { + r.On("UpdateStatus", mock.Anything, projectID, sessionID, taskID, "success"). + Return(&model.Task{ + ID: taskID, + SessionID: sessionID, + ProjectID: projectID, + Status: "success", + }, nil) + }, + setupLSS: func(lss *MockLearningSpaceSessionRepo) { + lss.On("ExistsBySessionID", mock.Anything, sessionID).Return(true, nil) + }, + checkResult: func(t *testing.T, task *model.Task) { + assert.Equal(t, "success", task.Status) + assert.Equal(t, taskID, task.ID) + }, + }, + { + name: "success - update to failed, no learning space", + input: UpdateTaskStatusInput{ + ProjectID: projectID, + SessionID: sessionID, + TaskID: taskID, + Status: "failed", + }, + setupRepo: func(r *MockTaskRepo) { + r.On("UpdateStatus", mock.Anything, projectID, sessionID, taskID, "failed"). + Return(&model.Task{ + ID: taskID, + SessionID: sessionID, + ProjectID: projectID, + Status: "failed", + }, nil) + }, + setupLSS: func(lss *MockLearningSpaceSessionRepo) { + lss.On("ExistsBySessionID", mock.Anything, sessionID).Return(false, nil) + }, + checkResult: func(t *testing.T, task *model.Task) { + assert.Equal(t, "failed", task.Status) + }, + }, + { + name: "success - update to running, no learning space check", + input: UpdateTaskStatusInput{ + ProjectID: projectID, + SessionID: sessionID, + TaskID: taskID, + Status: "running", + }, + setupRepo: func(r *MockTaskRepo) { + r.On("UpdateStatus", mock.Anything, projectID, sessionID, taskID, "running"). + Return(&model.Task{ + ID: taskID, + SessionID: sessionID, + ProjectID: projectID, + Status: "running", + }, nil) + }, + setupLSS: func(lss *MockLearningSpaceSessionRepo) { + // ExistsBySessionID should NOT be called for running/pending + }, + checkResult: func(t *testing.T, task *model.Task) { + assert.Equal(t, "running", task.Status) + }, + }, + { + name: "success - update to pending, no learning space check", + input: UpdateTaskStatusInput{ + ProjectID: projectID, + SessionID: sessionID, + TaskID: taskID, + Status: "pending", + }, + setupRepo: func(r *MockTaskRepo) { + r.On("UpdateStatus", mock.Anything, projectID, sessionID, taskID, "pending"). + Return(&model.Task{ + ID: taskID, + SessionID: sessionID, + ProjectID: projectID, + Status: "pending", + }, nil) + }, + setupLSS: func(lss *MockLearningSpaceSessionRepo) {}, + checkResult: func(t *testing.T, task *model.Task) { + assert.Equal(t, "pending", task.Status) + }, + }, + { + name: "error - task not found (gorm.ErrRecordNotFound)", + input: UpdateTaskStatusInput{ + ProjectID: projectID, + SessionID: sessionID, + TaskID: taskID, + Status: "success", + }, + setupRepo: func(r *MockTaskRepo) { + r.On("UpdateStatus", mock.Anything, projectID, sessionID, taskID, "success"). + Return(nil, gorm.ErrRecordNotFound) + }, + setupLSS: func(lss *MockLearningSpaceSessionRepo) {}, + expectErr: true, + errContains: "not found", + }, + { + name: "error - repo returns generic error", + input: UpdateTaskStatusInput{ + ProjectID: projectID, + SessionID: sessionID, + TaskID: taskID, + Status: "success", + }, + setupRepo: func(r *MockTaskRepo) { + r.On("UpdateStatus", mock.Anything, projectID, sessionID, taskID, "success"). + Return(nil, fmt.Errorf("database connection error")) + }, + setupLSS: func(lss *MockLearningSpaceSessionRepo) {}, + expectErr: true, + errContains: "failed to update task status", + }, + { + name: "success - learning space check fails, status still updated", + input: UpdateTaskStatusInput{ + ProjectID: projectID, + SessionID: sessionID, + TaskID: taskID, + Status: "success", + }, + setupRepo: func(r *MockTaskRepo) { + r.On("UpdateStatus", mock.Anything, projectID, sessionID, taskID, "success"). + Return(&model.Task{ + ID: taskID, + SessionID: sessionID, + ProjectID: projectID, + Status: "success", + }, nil) + }, + setupLSS: func(lss *MockLearningSpaceSessionRepo) { + lss.On("ExistsBySessionID", mock.Anything, sessionID).Return(false, fmt.Errorf("db error")) + }, + checkResult: func(t *testing.T, task *model.Task) { + assert.Equal(t, "success", task.Status) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockRepo := &MockTaskRepo{} + mockLSS := &MockLearningSpaceSessionRepo{} + tt.setupRepo(mockRepo) + tt.setupLSS(mockLSS) + + svc := NewTaskService(mockRepo, mockLSS, nil, nil, zap.NewNop()) + + result, err := svc.UpdateTaskStatus(context.Background(), tt.input) + + if tt.expectErr { + assert.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + assert.Nil(t, result) + } else { + assert.NoError(t, err) + assert.NotNil(t, result) + if tt.checkResult != nil { + tt.checkResult(t, result) + } + } + + mockRepo.AssertExpectations(t) + mockLSS.AssertExpectations(t) + }) + } +} diff --git a/src/server/api/go/internal/router/router.go b/src/server/api/go/internal/router/router.go index 501fed5e7..d39f92251 100644 --- a/src/server/api/go/internal/router/router.go +++ b/src/server/api/go/internal/router/router.go @@ -87,6 +87,7 @@ func NewRouter(d RouterDeps) *gin.Engine { task := session.Group("/:session_id/task") { task.GET("", d.TaskHandler.GetTasks) + task.PATCH("/:task_id/status", d.TaskHandler.UpdateTaskStatus) } } diff --git a/src/server/core/acontext_core/llm/agent/task.py b/src/server/core/acontext_core/llm/agent/task.py index 80390dc79..c714b0043 100644 --- a/src/server/core/acontext_core/llm/agent/task.py +++ b/src/server/core/acontext_core/llm/agent/task.py @@ -96,6 +96,7 @@ async def build_task_ctx( session_id: asUUID, messages: list[MessageBlob], before_use_ctx: TaskCtx = None, + disable_task_status_change: bool = False, ) -> TaskCtx: if before_use_ctx is not None: before_use_ctx.db_session = db_session @@ -115,6 +116,7 @@ async def build_task_ctx( task_ids_index=[t.id for t in current_tasks], task_index=current_tasks, message_ids_index=[m.message_id for m in messages], + disable_task_status_change=disable_task_status_change, ) return use_ctx @@ -127,6 +129,7 @@ async def task_agent_curd( max_iterations=3, # task curd agent only receive one turn of actions previous_progress_num: int = 6, learning_space_id: Optional[asUUID] = None, + disable_task_status_change: bool = False, ) -> Result[None]: async with DB_CLIENT.get_session_context() as db_session: r = await TD.fetch_current_tasks(db_session, session_id) @@ -201,6 +204,7 @@ async def task_agent_curd( session_id, messages, before_use_ctx=USE_CTX, + disable_task_status_change=disable_task_status_change, ) r = await tool.handler(USE_CTX, tool_arguments) t, eil = r.unpack() diff --git a/src/server/core/acontext_core/llm/tool/task_lib/ctx.py b/src/server/core/acontext_core/llm/tool/task_lib/ctx.py index 1297f41eb..1b4be6ce5 100644 --- a/src/server/core/acontext_core/llm/tool/task_lib/ctx.py +++ b/src/server/core/acontext_core/llm/tool/task_lib/ctx.py @@ -14,3 +14,4 @@ class TaskCtx: message_ids_index: list[asUUID] learning_task_ids: list[asUUID] = field(default_factory=list) pending_preferences: list[str] = field(default_factory=list) + disable_task_status_change: bool = False diff --git a/src/server/core/acontext_core/llm/tool/task_lib/update.py b/src/server/core/acontext_core/llm/tool/task_lib/update.py index 788c65741..ec4ea47fc 100644 --- a/src/server/core/acontext_core/llm/tool/task_lib/update.py +++ b/src/server/core/acontext_core/llm/tool/task_lib/update.py @@ -21,6 +21,12 @@ async def update_task_handler( actually_task_id = ctx.task_ids_index[task_order - 1] task_status = llm_arguments.get("task_status", None) task_description = llm_arguments.get("task_description", None) + + status_skipped = False + if ctx.disable_task_status_change and task_status in ("success", "failed"): + task_status = None + status_skipped = True + r = await TD.update_task( ctx.db_session, actually_task_id, @@ -36,7 +42,7 @@ async def update_task_handler( t, eil = r.unpack() if eil: return r - if task_status in ("success", "failed"): + if not status_skipped and task_status in ("success", "failed"): ctx.learning_task_ids.append(actually_task_id) return Result.resolve(f"Task {t.order} updated") diff --git a/src/server/core/acontext_core/schema/orm/session.py b/src/server/core/acontext_core/schema/orm/session.py index 2c77d8beb..ffe0178df 100644 --- a/src/server/core/acontext_core/schema/orm/session.py +++ b/src/server/core/acontext_core/schema/orm/session.py @@ -51,6 +51,13 @@ class Session(CommonMixin): }, ) + disable_task_status_change: bool = field( + default=False, + metadata={ + "db": Column(Boolean, nullable=False, default=False, server_default="false") + }, + ) + configs: Optional[dict] = field( default=None, metadata={"db": Column(JSONB, nullable=True)} ) diff --git a/src/server/core/acontext_core/service/controller/message.py b/src/server/core/acontext_core/service/controller/message.py index c0507f7ba..67f8bd8d4 100644 --- a/src/server/core/acontext_core/service/controller/message.py +++ b/src/server/core/acontext_core/service/controller/message.py @@ -1,5 +1,6 @@ from ..data import message as MD from ..data import learning_space as LS +from ..data import session as SD from ...infra.db import DB_CLIENT from ...schema.session.task import TaskStatus from ...schema.session.message import MessageBlob @@ -72,6 +73,14 @@ async def process_session_pending_message( if eil: ls_session = None + r_sess = await SD.fetch_session(session, session_id) + sess_obj, eil = r_sess.unpack() + if eil: + LOG.warning(f"Failed to fetch session {session_id} for disable_task_status_change check: {eil}") + disable_task_status_change = ( + sess_obj.disable_task_status_change if sess_obj else False + ) + r = await AT.task_agent_curd( project_id, session_id, @@ -79,6 +88,7 @@ async def process_session_pending_message( max_iterations=project_config.default_task_agent_max_iterations, previous_progress_num=project_config.default_task_agent_previous_progress_num, learning_space_id=ls_session.learning_space_id if ls_session is not None else None, + disable_task_status_change=disable_task_status_change, ) after_status = TaskStatus.SUCCESS diff --git a/src/server/core/tests/llm/test_skill_learner_trigger.py b/src/server/core/tests/llm/test_skill_learner_trigger.py index 1a48b10ae..0f83ccc2a 100644 --- a/src/server/core/tests/llm/test_skill_learner_trigger.py +++ b/src/server/core/tests/llm/test_skill_learner_trigger.py @@ -183,6 +183,155 @@ async def test_multiple_updates_collect_all(self): assert task2.id in ctx.learning_task_ids +# ============================================================================= +# disable_task_status_change guard tests +# ============================================================================= + + +class TestDisableTaskStatusChange: + @pytest.mark.asyncio + async def test_flag_blocks_success_status(self): + """With disable_task_status_change=True, status='success' is stripped.""" + task = _make_task() + ctx = _make_ctx(tasks=[task]) + ctx.disable_task_status_change = True + + mock_updated = MagicMock() + mock_updated.order = 1 + + with patch( + "acontext_core.llm.tool.task_lib.update.TD.update_task", + new_callable=AsyncMock, + return_value=Result.resolve(mock_updated), + ) as mock_update: + result = await update_task_handler( + ctx, {"task_order": 1, "task_status": "success"} + ) + assert result.ok() + # status should have been passed as None + call_kwargs = mock_update.call_args + assert call_kwargs[1]["status"] is None or call_kwargs[0][2] is None + assert len(ctx.learning_task_ids) == 0 + + @pytest.mark.asyncio + async def test_flag_blocks_failed_status(self): + """With disable_task_status_change=True, status='failed' is stripped.""" + task = _make_task() + ctx = _make_ctx(tasks=[task]) + ctx.disable_task_status_change = True + + mock_updated = MagicMock() + mock_updated.order = 1 + + with patch( + "acontext_core.llm.tool.task_lib.update.TD.update_task", + new_callable=AsyncMock, + return_value=Result.resolve(mock_updated), + ): + result = await update_task_handler( + ctx, {"task_order": 1, "task_status": "failed"} + ) + assert result.ok() + assert len(ctx.learning_task_ids) == 0 + + @pytest.mark.asyncio + async def test_flag_allows_running_status(self): + """With disable_task_status_change=True, status='running' is NOT blocked.""" + task = _make_task() + ctx = _make_ctx(tasks=[task]) + ctx.disable_task_status_change = True + + mock_updated = MagicMock() + mock_updated.order = 1 + + with patch( + "acontext_core.llm.tool.task_lib.update.TD.update_task", + new_callable=AsyncMock, + return_value=Result.resolve(mock_updated), + ) as mock_update: + result = await update_task_handler( + ctx, {"task_order": 1, "task_status": "running"} + ) + assert result.ok() + _, kwargs = mock_update.call_args + assert kwargs.get("status") == "running" or mock_update.call_args[0][2] == "running" + + @pytest.mark.asyncio + async def test_flag_allows_description_update(self): + """With disable_task_status_change=True, description updates still work.""" + task = _make_task() + ctx = _make_ctx(tasks=[task]) + ctx.disable_task_status_change = True + + mock_updated = MagicMock() + mock_updated.order = 1 + + with patch( + "acontext_core.llm.tool.task_lib.update.TD.update_task", + new_callable=AsyncMock, + return_value=Result.resolve(mock_updated), + ) as mock_update: + result = await update_task_handler( + ctx, + {"task_order": 1, "task_status": "success", "task_description": "New desc"}, + ) + assert result.ok() + _, kwargs = mock_update.call_args + assert kwargs.get("patch_data") == {"task_description": "New desc"} + assert kwargs.get("status") is None + assert len(ctx.learning_task_ids) == 0 + + @pytest.mark.asyncio + async def test_flag_false_allows_success(self): + """With disable_task_status_change=False (default), success works normally.""" + task = _make_task() + ctx = _make_ctx(tasks=[task]) + assert ctx.disable_task_status_change is False + + mock_updated = MagicMock() + mock_updated.order = 1 + + with patch( + "acontext_core.llm.tool.task_lib.update.TD.update_task", + new_callable=AsyncMock, + return_value=Result.resolve(mock_updated), + ): + result = await update_task_handler( + ctx, {"task_order": 1, "task_status": "success"} + ) + assert result.ok() + assert task.id in ctx.learning_task_ids + + @pytest.mark.asyncio + async def test_flag_blocks_success_but_no_learning_ids(self): + """With flag enabled, learning_task_ids stays empty even for success/failed.""" + task1 = _make_task(order=1) + task2 = _make_task(order=2) + ctx = _make_ctx(tasks=[task1, task2]) + ctx.disable_task_status_change = True + + mock_updated1 = MagicMock() + mock_updated1.order = 1 + mock_updated2 = MagicMock() + mock_updated2.order = 2 + + with patch( + "acontext_core.llm.tool.task_lib.update.TD.update_task", + new_callable=AsyncMock, + side_effect=[ + Result.resolve(mock_updated1), + Result.resolve(mock_updated2), + ], + ): + await update_task_handler( + ctx, {"task_order": 1, "task_status": "success"} + ) + await update_task_handler( + ctx, {"task_order": 2, "task_status": "failed"} + ) + assert len(ctx.learning_task_ids) == 0 + + # ============================================================================= # TaskCtx default tests # ============================================================================= diff --git a/src/server/core/tests/llm/test_task_agent_atomicity.py b/src/server/core/tests/llm/test_task_agent_atomicity.py index 18c4a6b9f..8132147c2 100644 --- a/src/server/core/tests/llm/test_task_agent_atomicity.py +++ b/src/server/core/tests/llm/test_task_agent_atomicity.py @@ -227,7 +227,7 @@ async def test_rebuild_sees_flushed_writes(self): # 2nd call: rebuild after insert_task sets USE_CTX=None (before_use_ctx=None) build_ctx_db_sessions = [] - async def fake_build_task_ctx(db_session, proj_id, sess_id, msgs, before_use_ctx=None): + async def fake_build_task_ctx(db_session, proj_id, sess_id, msgs, before_use_ctx=None, **kwargs): build_ctx_db_sessions.append(db_session) return TaskCtx( db_session=db_session,