Skip to content

Commit 676510e

Browse files
GenerQAQclaude
andcommitted
fix(api): add ownership checks and encryption threading for security review
- Add disk ownership validation in DownloadArtifact to prevent cross-project access - Thread UserKEK through agent skills Create/CreateFromTemplate for encryption - Validate session_id ownership in DownloadSessionAsset before allowing download Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent c3d66f2 commit 676510e

8 files changed

Lines changed: 256 additions & 8 deletions

File tree

src/server/api/go/internal/bootstrap/container.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,7 @@ func BuildContainer() *do.Injector {
356356
do.Provide(inj, func(i *do.Injector) (*handler.ArtifactHandler, error) {
357357
return handler.NewArtifactHandler(
358358
do.MustInvoke[service.ArtifactService](i),
359+
do.MustInvoke[repo.DiskRepo](i),
359360
do.MustInvoke[*config.Config](i),
360361
do.MustInvoke[*httpclient.CoreClient](i),
361362
do.MustInvoke[*blob.S3Deps](i),

src/server/api/go/internal/modules/handler/agent_skills.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"github.com/gin-gonic/gin"
1212
"github.com/google/uuid"
1313
"github.com/memodb-io/Acontext/internal/infra/httpclient"
14+
"github.com/memodb-io/Acontext/internal/middleware"
1415
"github.com/memodb-io/Acontext/internal/modules/model"
1516
"github.com/memodb-io/Acontext/internal/modules/serializer"
1617
"github.com/memodb-io/Acontext/internal/modules/service"
@@ -97,6 +98,7 @@ func (h *AgentSkillsHandler) CreateAgentSkill(c *gin.Context) {
9798
UserID: userID,
9899
ZipFile: fileHeader,
99100
Meta: meta,
101+
UserKEK: middleware.GetUserKEK(c),
100102
})
101103
if err != nil {
102104
// Check if error is a validation error (SKILL.md related)

src/server/api/go/internal/modules/handler/artifact.go

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
"github.com/memodb-io/Acontext/internal/middleware"
1717
"github.com/memodb-io/Acontext/internal/infra/httpclient"
1818
"github.com/memodb-io/Acontext/internal/modules/model"
19+
"github.com/memodb-io/Acontext/internal/modules/repo"
1920
"github.com/memodb-io/Acontext/internal/modules/serializer"
2021
"github.com/memodb-io/Acontext/internal/modules/service"
2122
"github.com/memodb-io/Acontext/internal/pkg/utils/fileparser"
@@ -24,13 +25,14 @@ import (
2425

2526
type ArtifactHandler struct {
2627
svc service.ArtifactService
28+
diskRepo repo.DiskRepo
2729
config *config.Config
2830
coreClient *httpclient.CoreClient
2931
s3 *blob.S3Deps
3032
}
3133

32-
func NewArtifactHandler(s service.ArtifactService, cfg *config.Config, coreClient *httpclient.CoreClient, s3 *blob.S3Deps) *ArtifactHandler {
33-
return &ArtifactHandler{svc: s, config: cfg, coreClient: coreClient, s3: s3}
34+
func NewArtifactHandler(s service.ArtifactService, diskRepo repo.DiskRepo, cfg *config.Config, coreClient *httpclient.CoreClient, s3 *blob.S3Deps) *ArtifactHandler {
35+
return &ArtifactHandler{svc: s, diskRepo: diskRepo, config: cfg, coreClient: coreClient, s3: s3}
3436
}
3537

3638
type CreateArtifactReq struct {
@@ -296,6 +298,12 @@ type DownloadArtifactReq struct {
296298
// @Success 200 "File content"
297299
// @Router /disk/{disk_id}/artifact/download [get]
298300
func (h *ArtifactHandler) DownloadArtifact(c *gin.Context) {
301+
project, ok := c.MustGet("project").(*model.Project)
302+
if !ok {
303+
c.JSON(http.StatusBadRequest, serializer.ParamErr("", errors.New("project not found")))
304+
return
305+
}
306+
299307
req := DownloadArtifactReq{}
300308
if err := c.ShouldBind(&req); err != nil {
301309
c.JSON(http.StatusBadRequest, serializer.ParamErr("", err))
@@ -308,6 +316,12 @@ func (h *ArtifactHandler) DownloadArtifact(c *gin.Context) {
308316
return
309317
}
310318

319+
// Verify disk belongs to the authenticated project
320+
if _, err := h.diskRepo.GetByProjectAndID(c.Request.Context(), project.ID, diskID); err != nil {
321+
c.JSON(http.StatusForbidden, serializer.Err(http.StatusForbidden, "access denied: disk does not belong to this project", nil))
322+
return
323+
}
324+
311325
filePath, filename := path.SplitFilePath(req.FilePath)
312326
if err := path.ValidatePath(filePath); err != nil {
313327
c.JSON(http.StatusBadRequest, serializer.ParamErr("invalid path", err))

src/server/api/go/internal/modules/handler/artifact_test.go

Lines changed: 112 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"github.com/google/uuid"
1616
"github.com/memodb-io/Acontext/internal/config"
1717
"github.com/memodb-io/Acontext/internal/modules/model"
18+
"github.com/memodb-io/Acontext/internal/modules/repo"
1819
"github.com/memodb-io/Acontext/internal/modules/serializer"
1920
"github.com/memodb-io/Acontext/internal/modules/service"
2021
"github.com/memodb-io/Acontext/internal/pkg/utils/fileparser"
@@ -133,6 +134,40 @@ func (m *MockArtifactService) CreateFromBytes(ctx context.Context, in service.Cr
133134
return args.Get(0).(*model.Artifact), args.Error(1)
134135
}
135136

137+
// MockDiskRepo is a mock implementation of DiskRepo
138+
type MockDiskRepo struct {
139+
mock.Mock
140+
}
141+
142+
func (m *MockDiskRepo) Create(ctx context.Context, d *model.Disk) error {
143+
args := m.Called(ctx, d)
144+
return args.Error(0)
145+
}
146+
147+
func (m *MockDiskRepo) Delete(ctx context.Context, projectID uuid.UUID, diskID uuid.UUID) error {
148+
args := m.Called(ctx, projectID, diskID)
149+
return args.Error(0)
150+
}
151+
152+
func (m *MockDiskRepo) GetByProjectAndID(ctx context.Context, projectID uuid.UUID, diskID uuid.UUID) (*model.Disk, error) {
153+
args := m.Called(ctx, projectID, diskID)
154+
if args.Get(0) == nil {
155+
return nil, args.Error(1)
156+
}
157+
return args.Get(0).(*model.Disk), args.Error(1)
158+
}
159+
160+
func (m *MockDiskRepo) ListWithCursor(ctx context.Context, projectID uuid.UUID, userIdentifier string, afterCreatedAt time.Time, afterID uuid.UUID, limit int, timeDesc bool) ([]*model.Disk, error) {
161+
args := m.Called(ctx, projectID, userIdentifier, afterCreatedAt, afterID, limit, timeDesc)
162+
if args.Get(0) == nil {
163+
return nil, args.Error(1)
164+
}
165+
return args.Get(0).([]*model.Disk), args.Error(1)
166+
}
167+
168+
// Verify MockDiskRepo implements repo.DiskRepo
169+
var _ repo.DiskRepo = (*MockDiskRepo)(nil)
170+
136171
// createTestConfig creates a test config with default artifact settings
137172
func createTestConfig(maxUploadSizeBytes int64) *config.Config {
138173
return &config.Config{
@@ -260,7 +295,7 @@ func TestArtifactHandler_UpsertArtifact(t *testing.T) {
260295
tt.mockSetup(mockService, tt.diskID, projectID)
261296

262297
testConfig := createTestConfig(tt.maxUploadSize)
263-
handler := NewArtifactHandler(mockService, testConfig, nil, nil)
298+
handler := NewArtifactHandler(mockService, nil, testConfig, nil, nil)
264299

265300
// Create multipart form data
266301
body := &bytes.Buffer{}
@@ -356,7 +391,7 @@ func TestArtifactHandler_DeleteArtifact(t *testing.T) {
356391
tt.mockSetup(mockService, tt.diskID, tt.filePath, projectID)
357392

358393
testConfig := createDefaultTestConfig() // Default 16MB
359-
handler := NewArtifactHandler(mockService, testConfig, nil, nil)
394+
handler := NewArtifactHandler(mockService, nil, testConfig, nil, nil)
360395

361396
// Create request with query parameters
362397
req := httptest.NewRequest(http.MethodDelete, fmt.Sprintf("/disk/%s/artifact?file_path=%s", tt.diskID, tt.filePath), nil)
@@ -482,7 +517,7 @@ func TestArtifactHandler_UpdateArtifact(t *testing.T) {
482517
tt.mockSetup(mockService, tt.diskID)
483518

484519
testConfig := createDefaultTestConfig() // Default 16MB
485-
handler := NewArtifactHandler(mockService, testConfig, nil, nil)
520+
handler := NewArtifactHandler(mockService, nil, testConfig, nil, nil)
486521

487522
// Create JSON request body
488523
requestBody := map[string]string{
@@ -629,7 +664,7 @@ func TestArtifactHandler_GetArtifact(t *testing.T) {
629664
tt.mockSetup(mockService, tt.diskID, tt.filePath)
630665

631666
testConfig := createDefaultTestConfig() // Default 16MB
632-
handler := NewArtifactHandler(mockService, testConfig, nil, nil)
667+
handler := NewArtifactHandler(mockService, nil, testConfig, nil, nil)
633668

634669
// Create request with query parameters
635670
url := fmt.Sprintf("/disk/%s/artifact?file_path=%s", tt.diskID, tt.filePath)
@@ -750,7 +785,7 @@ func TestArtifactHandler_GrepArtifacts(t *testing.T) {
750785
mockSvc := new(MockArtifactService)
751786
tt.setupMock(mockSvc)
752787

753-
handler := NewArtifactHandler(mockSvc, createTestConfig(10*1024*1024), nil, nil)
788+
handler := NewArtifactHandler(mockSvc, nil, createTestConfig(10*1024*1024), nil, nil)
754789

755790
w := httptest.NewRecorder()
756791
c, _ := gin.CreateTestContext(w)
@@ -835,7 +870,7 @@ func TestArtifactHandler_GlobArtifacts(t *testing.T) {
835870
mockSvc := new(MockArtifactService)
836871
tt.setupMock(mockSvc)
837872

838-
handler := NewArtifactHandler(mockSvc, createTestConfig(10*1024*1024), nil, nil)
873+
handler := NewArtifactHandler(mockSvc, nil, createTestConfig(10*1024*1024), nil, nil)
839874

840875
w := httptest.NewRecorder()
841876
c, _ := gin.CreateTestContext(w)
@@ -857,3 +892,74 @@ func TestArtifactHandler_GlobArtifacts(t *testing.T) {
857892
})
858893
}
859894
}
895+
896+
func TestArtifactHandler_DownloadArtifact(t *testing.T) {
897+
gin.SetMode(gin.TestMode)
898+
899+
t.Run("returns 403 when disk belongs to different project", func(t *testing.T) {
900+
mockService := new(MockArtifactService)
901+
mockDiskRepo := new(MockDiskRepo)
902+
903+
projectID := uuid.New()
904+
diskID := uuid.New()
905+
906+
// Disk not found for this project
907+
mockDiskRepo.On("GetByProjectAndID", mock.Anything, projectID, diskID).
908+
Return(nil, fmt.Errorf("record not found"))
909+
910+
handler := NewArtifactHandler(mockService, mockDiskRepo, createDefaultTestConfig(), nil, nil)
911+
912+
w := httptest.NewRecorder()
913+
c, _ := gin.CreateTestContext(w)
914+
c.Set("project", &model.Project{ID: projectID})
915+
c.Request = httptest.NewRequest("GET", "/disk/"+diskID.String()+"/artifact/download?file_path=/test/file.txt", nil)
916+
c.Params = gin.Params{{Key: "disk_id", Value: diskID.String()}}
917+
918+
handler.DownloadArtifact(c)
919+
920+
assert.Equal(t, http.StatusForbidden, w.Code)
921+
mockService.AssertNotCalled(t, "GetByPath")
922+
mockDiskRepo.AssertExpectations(t)
923+
})
924+
925+
t.Run("succeeds when disk belongs to authenticated project", func(t *testing.T) {
926+
mockService := new(MockArtifactService)
927+
mockDiskRepo := new(MockDiskRepo)
928+
929+
projectID := uuid.New()
930+
diskID := uuid.New()
931+
932+
// Disk found for this project
933+
mockDiskRepo.On("GetByProjectAndID", mock.Anything, projectID, diskID).
934+
Return(&model.Disk{ID: diskID, ProjectID: projectID}, nil)
935+
936+
artifact := &model.Artifact{
937+
ID: uuid.New(),
938+
DiskID: diskID,
939+
Path: "/test/",
940+
Filename: "file.txt",
941+
AssetMeta: datatypes.NewJSONType(model.Asset{
942+
Bucket: "test-bucket",
943+
S3Key: "test-key",
944+
MIME: "text/plain",
945+
}),
946+
}
947+
mockService.On("GetByPath", mock.Anything, diskID, "/test/", "file.txt").Return(artifact, nil)
948+
mockService.On("DownloadRawContent", mock.Anything, artifact).Return([]byte("content"), "text/plain", nil)
949+
950+
handler := NewArtifactHandler(mockService, mockDiskRepo, createDefaultTestConfig(), nil, nil)
951+
952+
w := httptest.NewRecorder()
953+
c, _ := gin.CreateTestContext(w)
954+
c.Set("project", &model.Project{ID: projectID})
955+
c.Request = httptest.NewRequest("GET", "/disk/"+diskID.String()+"/artifact/download?file_path=/test/file.txt", nil)
956+
c.Params = gin.Params{{Key: "disk_id", Value: diskID.String()}}
957+
958+
handler.DownloadArtifact(c)
959+
960+
assert.Equal(t, http.StatusOK, w.Code)
961+
assert.Equal(t, "content", w.Body.String())
962+
mockDiskRepo.AssertExpectations(t)
963+
mockService.AssertExpectations(t)
964+
})
965+
}

src/server/api/go/internal/modules/handler/session.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -579,6 +579,23 @@ func (h *SessionHandler) DownloadSessionAsset(c *gin.Context) {
579579
return
580580
}
581581

582+
// Validate session_id from path and verify it belongs to the project
583+
sessionID, err := uuid.Parse(c.Param("session_id"))
584+
if err != nil {
585+
c.JSON(http.StatusBadRequest, serializer.ParamErr("invalid session_id", err))
586+
return
587+
}
588+
589+
session, err := h.svc.GetByID(c.Request.Context(), &model.Session{ID: sessionID})
590+
if err != nil {
591+
c.JSON(http.StatusNotFound, serializer.Err(http.StatusNotFound, "session not found", nil))
592+
return
593+
}
594+
if session.ProjectID != project.ID {
595+
c.JSON(http.StatusForbidden, serializer.Err(http.StatusForbidden, "access denied: session does not belong to this project", nil))
596+
return
597+
}
598+
582599
s3Key := c.Query("s3_key")
583600
if s3Key == "" {
584601
c.JSON(http.StatusBadRequest, serializer.ParamErr("", errors.New("s3_key is required")))

src/server/api/go/internal/modules/handler/session_test.go

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4091,3 +4091,98 @@ func TestSessionHandler_CopySession_InternalError(t *testing.T) {
40914091

40924092
mockService.AssertExpectations(t)
40934093
}
4094+
4095+
func TestSessionHandler_DownloadSessionAsset(t *testing.T) {
4096+
gin.SetMode(gin.TestMode)
4097+
4098+
t.Run("returns 400 for invalid session_id", func(t *testing.T) {
4099+
mockService := new(MockSessionService)
4100+
handler := NewSessionHandler(mockService, nil, getMockSessionCoreClient())
4101+
4102+
projectID := uuid.New()
4103+
4104+
w := httptest.NewRecorder()
4105+
c, _ := gin.CreateTestContext(w)
4106+
c.Set("project", &model.Project{ID: projectID})
4107+
c.Request = httptest.NewRequest("GET", "/session/invalid-uuid/asset/download?s3_key=assets/"+projectID.String()+"/file.txt", nil)
4108+
c.Params = gin.Params{{Key: "session_id", Value: "invalid-uuid"}}
4109+
4110+
handler.DownloadSessionAsset(c)
4111+
4112+
assert.Equal(t, http.StatusBadRequest, w.Code)
4113+
})
4114+
4115+
t.Run("returns 404 when session not found", func(t *testing.T) {
4116+
mockService := new(MockSessionService)
4117+
handler := NewSessionHandler(mockService, nil, getMockSessionCoreClient())
4118+
4119+
projectID := uuid.New()
4120+
sessionID := uuid.New()
4121+
4122+
mockService.On("GetByID", mock.Anything, mock.MatchedBy(func(s *model.Session) bool {
4123+
return s.ID == sessionID
4124+
})).Return(nil, errors.New("not found"))
4125+
4126+
w := httptest.NewRecorder()
4127+
c, _ := gin.CreateTestContext(w)
4128+
c.Set("project", &model.Project{ID: projectID})
4129+
c.Request = httptest.NewRequest("GET", "/session/"+sessionID.String()+"/asset/download?s3_key=assets/"+projectID.String()+"/file.txt", nil)
4130+
c.Params = gin.Params{{Key: "session_id", Value: sessionID.String()}}
4131+
4132+
handler.DownloadSessionAsset(c)
4133+
4134+
assert.Equal(t, http.StatusNotFound, w.Code)
4135+
mockService.AssertExpectations(t)
4136+
})
4137+
4138+
t.Run("returns 403 when session belongs to different project", func(t *testing.T) {
4139+
mockService := new(MockSessionService)
4140+
handler := NewSessionHandler(mockService, nil, getMockSessionCoreClient())
4141+
4142+
projectID := uuid.New()
4143+
otherProjectID := uuid.New()
4144+
sessionID := uuid.New()
4145+
4146+
mockService.On("GetByID", mock.Anything, mock.MatchedBy(func(s *model.Session) bool {
4147+
return s.ID == sessionID
4148+
})).Return(&model.Session{ID: sessionID, ProjectID: otherProjectID}, nil)
4149+
4150+
w := httptest.NewRecorder()
4151+
c, _ := gin.CreateTestContext(w)
4152+
c.Set("project", &model.Project{ID: projectID})
4153+
c.Request = httptest.NewRequest("GET", "/session/"+sessionID.String()+"/asset/download?s3_key=assets/"+projectID.String()+"/file.txt", nil)
4154+
c.Params = gin.Params{{Key: "session_id", Value: sessionID.String()}}
4155+
4156+
handler.DownloadSessionAsset(c)
4157+
4158+
assert.Equal(t, http.StatusForbidden, w.Code)
4159+
mockService.AssertExpectations(t)
4160+
})
4161+
4162+
t.Run("succeeds with valid session_id and matching project", func(t *testing.T) {
4163+
mockService := new(MockSessionService)
4164+
handler := NewSessionHandler(mockService, nil, getMockSessionCoreClient())
4165+
4166+
projectID := uuid.New()
4167+
sessionID := uuid.New()
4168+
s3Key := "assets/" + projectID.String() + "/file.txt"
4169+
4170+
mockService.On("GetByID", mock.Anything, mock.MatchedBy(func(s *model.Session) bool {
4171+
return s.ID == sessionID
4172+
})).Return(&model.Session{ID: sessionID, ProjectID: projectID}, nil)
4173+
4174+
mockService.On("DownloadAsset", mock.Anything, s3Key).Return([]byte("file-content"), nil)
4175+
4176+
w := httptest.NewRecorder()
4177+
c, _ := gin.CreateTestContext(w)
4178+
c.Set("project", &model.Project{ID: projectID})
4179+
c.Request = httptest.NewRequest("GET", "/session/"+sessionID.String()+"/asset/download?s3_key="+s3Key, nil)
4180+
c.Params = gin.Params{{Key: "session_id", Value: sessionID.String()}}
4181+
4182+
handler.DownloadSessionAsset(c)
4183+
4184+
assert.Equal(t, http.StatusOK, w.Code)
4185+
assert.Equal(t, "file-content", w.Body.String())
4186+
mockService.AssertExpectations(t)
4187+
})
4188+
}

src/server/api/go/internal/modules/repo/disk.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
type DiskRepo interface {
1414
Create(ctx context.Context, d *model.Disk) error
1515
Delete(ctx context.Context, projectID uuid.UUID, diskID uuid.UUID) error
16+
GetByProjectAndID(ctx context.Context, projectID uuid.UUID, diskID uuid.UUID) (*model.Disk, error)
1617
ListWithCursor(ctx context.Context, projectID uuid.UUID, userIdentifier string, afterCreatedAt time.Time, afterID uuid.UUID, limit int, timeDesc bool) ([]*model.Disk, error)
1718
}
1819

@@ -29,6 +30,14 @@ func (r *diskRepo) Create(ctx context.Context, d *model.Disk) error {
2930
return r.db.WithContext(ctx).Create(d).Error
3031
}
3132

33+
func (r *diskRepo) GetByProjectAndID(ctx context.Context, projectID uuid.UUID, diskID uuid.UUID) (*model.Disk, error) {
34+
var disk model.Disk
35+
if err := r.db.WithContext(ctx).Where("id = ? AND project_id = ?", diskID, projectID).First(&disk).Error; err != nil {
36+
return nil, err
37+
}
38+
return &disk, nil
39+
}
40+
3241
func (r *diskRepo) Delete(ctx context.Context, projectID uuid.UUID, diskID uuid.UUID) error {
3342
// Use transaction to ensure atomicity
3443
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {

0 commit comments

Comments
 (0)