diff --git a/adk/middlewares/automemory/automemory.go b/adk/middlewares/automemory/automemory.go
index a0d085385..6a33c7283 100644
--- a/adk/middlewares/automemory/automemory.go
+++ b/adk/middlewares/automemory/automemory.go
@@ -20,20 +20,16 @@ package automemory
import (
"context"
- "encoding/json"
"fmt"
"path/filepath"
"sort"
"strings"
"sync"
- "time"
"github.com/slongfield/pyfmt"
- "gopkg.in/yaml.v3"
"github.com/cloudwego/eino/adk"
ainternal "github.com/cloudwego/eino/adk/middlewares/automemory/internal"
- adkfs "github.com/cloudwego/eino/adk/middlewares/filesystem"
fsmw "github.com/cloudwego/eino/adk/middlewares/filesystem"
"github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/compose"
@@ -45,12 +41,23 @@ func init() {
}
type Config[M adk.MessageType] struct {
- MemoryDirectory string
+ // MemoryStores defines the persistent memory stores exposed to automemory.
+ // Required. At least one store must be configured.
+ MemoryStores []MemoryStore
+ // MemoryBackend is the storage backend used by all MemoryStores.
+ // Required. Store paths are resolved against this backend and bounded per store.
MemoryBackend Backend
+ // GenInstruction returns the auto memory policy block appended to the system prompt.
+ // Use it to customize memory read/write strength and criteria. The framework always
+ // appends the memory store manifest and memory indexes after this block.
+ // Optional. Defaults to the built-in auto memory instruction.
+ GenInstruction func(ctx context.Context) (string, error)
+
// Model is the default model used by topic selection and memory extraction.
// Per-read/per-write overrides can be configured in Read.Model / Write.Model.
+ // Optional. Defaults to nil; topic selection and extraction must then provide their own models.
Model model.BaseModel[M]
// Read controls how memories are loaded and injected.
@@ -71,6 +78,20 @@ type Config[M adk.MessageType] struct {
OnError func(ctx context.Context, stage ErrorStage, err error)
}
+type MemoryStore struct {
+ // Path is the root path of this memory store.
+ // Required. Relative paths are resolved against the process working directory.
+ Path string
+
+ // Name is the display name and relative path prefix used to disambiguate this store.
+ // Optional. Defaults to the base name of Path.
+ Name string
+
+ // Description describes the purpose of this memory store in the system prompt manifest.
+ // Optional. Defaults to empty.
+ Description string
+}
+
type ReadMode string
const (
@@ -84,12 +105,8 @@ type ReadConfig[M adk.MessageType] struct {
// Model is used for topic selection. Defaults to Config.Model.
Model model.BaseModel[M]
- // Instruction overrides the default auto memory instruction block appended to system prompt.
- // Optional.
- Instruction *string
-
- // Index controls how MEMORY.md is loaded into system prompt.
- // Optional.
+ // Index controls whether and how MEMORY.md is loaded into system prompt.
+ // Optional. Defaults to enabled with MEMORY.md as the index file.
Index *IndexConfig
// TopicSelection controls the "LLM select topics" path.
@@ -99,13 +116,25 @@ type ReadConfig[M adk.MessageType] struct {
}
type IndexConfig struct {
+ // EnableMemoryIndex controls whether MEMORY.md is used as a memory index.
+ // Optional. Defaults to true when nil.
+ EnableMemoryIndex *bool
+
+ // FileName is the index file name under each memory store.
+ // Optional. Defaults to MEMORY.md.
FileName string
+
+ // MaxLines caps index content injected into system prompt.
+ // Optional. Defaults to package default.
MaxLines int
+
+ // MaxBytes caps index content injected into system prompt.
+ // Optional. Defaults to package default.
MaxBytes int
}
type TopicSelectionConfig struct {
- // CandidateGlob is matched against the RELATIVE path under MemoryDirectory.
+ // CandidateGlob is matched against the RELATIVE path under each memory store.
// Example: "**/*.md"
CandidateGlob string
CandidateLimit int
@@ -114,8 +143,12 @@ type TopicSelectionConfig struct {
TopK int
+ // MaxLines caps single topic memory file read lines.
MaxLines int
+ // MaxBytes caps single topic memory file read bytes.
MaxBytes int
+ // MaxTotalBytes caps the total rendered topic memory reminder across all stores.
+ MaxTotalBytes int
}
type WriteMode string
@@ -152,8 +185,7 @@ type middleware[M adk.MessageType] struct {
cfg *Config[M]
- resolvedMemoryDirectory string
- boundedMemoryBackend Backend
+ memoryStores []runtimeMemoryStore
topicSelectionModel model.BaseModel[M]
extractionHandler adk.TypedChatModelAgentMiddleware[M]
@@ -174,8 +206,7 @@ type selectionFuture struct {
type ctxKeySelectionFuture struct{}
const (
- memoryExtraKey = "__eino_automemory__"
- instructionMarker = ""
+ memoryExtraKey = "__eino_automemory__"
)
type memoryExtra struct {
@@ -183,6 +214,13 @@ type memoryExtra struct {
Cursor int
}
+type runtimeMemoryStore struct {
+ MemoryStore
+
+ Path string
+ Backend *ainternal.FSBackend
+}
+
// New creates an automemory middleware from the provided configuration.
func New[M adk.MessageType](ctx context.Context, config *Config[M]) (adk.TypedChatModelAgentMiddleware[M], error) {
if config == nil {
@@ -190,19 +228,11 @@ func New[M adk.MessageType](ctx context.Context, config *Config[M]) (adk.TypedCh
}
cfg := cloneConfig(config)
- if cfg.MemoryDirectory == "" || cfg.MemoryBackend == nil {
+ if cfg.MemoryBackend == nil {
return nil, fmt.Errorf("auto memory config: invalid")
}
- resolvedMemoryDir, err := ainternal.ResolveMemoryDir(cfg.MemoryDirectory)
- if err != nil {
- return nil, fmt.Errorf("auto memory config: resolve memory directory: %w", err)
- }
- boundedMemoryBackend, err := ainternal.NewFSBackend(cfg.MemoryBackend, ainternal.FSBackendConfig{
- BaseDir: resolvedMemoryDir,
- NotFoundAsContent: true,
- ErrorPrefix: "memory backend",
- })
+ stores, err := buildRuntimeMemoryStores(cfg)
if err != nil {
return nil, err
}
@@ -214,8 +244,7 @@ func New[M adk.MessageType](ctx context.Context, config *Config[M]) (adk.TypedCh
m := &middleware[M]{
TypedBaseChatModelAgentMiddleware: adk.TypedBaseChatModelAgentMiddleware[M]{},
cfg: cfg,
- resolvedMemoryDirectory: resolvedMemoryDir,
- boundedMemoryBackend: boundedMemoryBackend,
+ memoryStores: stores,
coordination: cfg.Coordination,
}
@@ -228,12 +257,8 @@ func New[M adk.MessageType](ctx context.Context, config *Config[M]) (adk.TypedCh
}
if cfg.Write.Mode != WriteModeDisabled && cfg.Write.Model != nil {
- writeFSBackend, err := newFSBackend(cfg.MemoryBackend, resolvedMemoryDir)
- if err != nil {
- return nil, err
- }
fileSystemMiddleware, err := fsmw.NewTyped[M](ctx, &fsmw.MiddlewareConfig{
- Backend: writeFSBackend,
+ Backend: newMultiStoreBackend(stores),
LsToolConfig: &fsmw.ToolConfig{Disable: true},
GrepToolConfig: &fsmw.ToolConfig{Disable: true},
})
@@ -257,7 +282,8 @@ func (m *middleware[M]) BeforeAgent(ctx context.Context, runCtx *adk.ChatModelAg
if nRunCtx.AgentInput != nil && len(nRunCtx.AgentInput.Messages) > 0 && m.coordination != nil && m.coordination.Coordinator != nil {
if sessionID, err := m.resolveSessionID(ctx, &adk.TypedChatModelAgentState[M]{Messages: nRunCtx.AgentInput.Messages}); err == nil && sessionID != "" {
localCursor := getWriteCursorFromMessages(nRunCtx.AgentInput.Messages)
- if remoteCursor, ok, err := m.coordination.Coordinator.GetCursor(ctx, sessionID); err == nil && ok && remoteCursor > localCursor {
+ coordKey := m.coordinatorKey(sessionID)
+ if remoteCursor, ok, err := getCoordinatorCursor(ctx, m.coordination.Coordinator, coordKey); err == nil && ok && remoteCursor > localCursor {
st := markWriteCursor(&adk.TypedChatModelAgentState[M]{Messages: nRunCtx.AgentInput.Messages}, remoteCursor)
if st != nil {
nRunCtx.AgentInput = &adk.TypedAgentInput[M]{
@@ -269,37 +295,51 @@ func (m *middleware[M]) BeforeAgent(ctx context.Context, runCtx *adk.ChatModelAg
}
}
- // System-prompt injection and transcript-memory injection are idempotent,
- // but they are independent concerns: instruction should be rebuilt each run
- // unless this exact instruction already carries the marker, while transcript
- // memory messages should only be skipped when a real automemory reminder is
- // already present in the message list.
- // 1) System prompt: inject auto memory instruction + MEMORY.md content (best-effort).
- if !hasInstructionInjected(nRunCtx.Instruction) {
- nRunCtx.Instruction = m.injectIndexIntoInstruction(ctx, nRunCtx.Instruction)
+ // 1) System prompt: inject stable auto memory instruction and store manifest (best-effort).
+ instruction, err := m.renderInstruction(ctx, nRunCtx.Instruction)
+ if err != nil {
+ m.onErr(ctx, OnErrorStageRenderInstruction, err)
+ } else {
+ nRunCtx.Instruction = instruction
}
- // Skip topic memories injection if they already exist.
- if nRunCtx.AgentInput == nil || alreadyInjected(nRunCtx.AgentInput.Messages) {
+ if nRunCtx.AgentInput == nil || len(nRunCtx.AgentInput.Messages) == 0 {
return ctx, &nRunCtx, nil
}
- // 2) Topic memories: sync mode injects before the user's query.
- if m.cfg.Read.Mode == ReadModeSync && m.cfg.Read.TopicSelection != nil && m.topicSelectionModel != nil {
+ var reminders []M
+
+ // 2) Memory index reminder: inject dynamic MEMORY.md content before the user's query.
+ if !hasMemoryIndexInjected(nRunCtx.AgentInput.Messages) {
+ indexMsg, err := m.buildMemoryIndexMessage(ctx)
+ if err != nil {
+ m.onErr(ctx, OnErrorStageRenderInstruction, err)
+ } else if !isNilMessage(indexMsg) {
+ m.sendTopicMemoryEvent(ctx, nRunCtx.AgentInput.Messages, indexMsg)
+ reminders = append(reminders, indexMsg)
+ }
+ }
+
+ // 3) Topic memories: sync mode selects from the original user query.
+ if !hasTopicMemoryInjected(nRunCtx.AgentInput.Messages) &&
+ m.cfg.Read.Mode == ReadModeSync && m.cfg.Read.TopicSelection != nil && m.topicSelectionModel != nil {
memMsg, err := m.selectAndBuildTopicMemoryMessage(ctx, nRunCtx.AgentInput)
if err != nil {
m.onErr(ctx, OnErrorStageTopicSelectionSync, err)
- } else if memMsg != nil && nRunCtx.AgentInput != nil && len(nRunCtx.AgentInput.Messages) > 0 {
+ } else if !isNilMessage(memMsg) {
m.sendTopicMemoryEvent(ctx, nRunCtx.AgentInput.Messages, memMsg)
- msgs := append([]M{}, nRunCtx.AgentInput.Messages...)
- msgs = append(msgs, memMsg)
- nRunCtx.AgentInput = &adk.TypedAgentInput[M]{Messages: msgs, EnableStreaming: nRunCtx.AgentInput.EnableStreaming}
-
+ reminders = append(reminders, memMsg)
}
}
- // 3) Topic memories: async mode starts selection here (cannot use RunLocalValue in BeforeAgent).
- if m.cfg.Read.Mode == ReadModeAsync && m.cfg.Read.TopicSelection != nil && m.topicSelectionModel != nil {
+ if len(reminders) > 0 {
+ msgs := insertMessagesBeforeLastUserQuery(nRunCtx.AgentInput.Messages, reminders)
+ nRunCtx.AgentInput = &adk.TypedAgentInput[M]{Messages: msgs, EnableStreaming: nRunCtx.AgentInput.EnableStreaming}
+ }
+
+ // 4) Topic memories: async mode starts selection here (cannot use RunLocalValue in BeforeAgent).
+ if !hasTopicMemoryInjected(nRunCtx.AgentInput.Messages) &&
+ m.cfg.Read.Mode == ReadModeAsync && m.cfg.Read.TopicSelection != nil && m.topicSelectionModel != nil {
if existing, _ := ctx.Value(ctxKeySelectionFuture{}).(*selectionFuture); existing == nil {
fut := &selectionFuture{done: make(chan struct{})}
ctx = context.WithValue(ctx, ctxKeySelectionFuture{}, fut)
@@ -382,207 +422,72 @@ func (m *middleware[M]) BeforeModelRewriteState(ctx context.Context, state *adk.
return ctx, &adk.TypedChatModelAgentState[M]{Messages: msgs}, nil
}
-func applyReadDefaults[M adk.MessageType](cfg *Config[M]) {
- if cfg.Read.Mode == "" {
- cfg.Read.Mode = ReadModeSync
- }
- if cfg.Read.Index == nil {
- cfg.Read.Index = &IndexConfig{}
- }
- if cfg.Read.Index.FileName == "" {
- cfg.Read.Index.FileName = memoryIndexFileName
- }
- if cfg.Read.Index.MaxLines <= 0 {
- cfg.Read.Index.MaxLines = defaultIndexMaxLines
- }
- if cfg.Read.Index.MaxBytes <= 0 {
- cfg.Read.Index.MaxBytes = defaultIndexMaxBytes
- }
- if cfg.Read.Model == nil {
- cfg.Read.Model = cfg.Model
- }
- if cfg.Read.TopicSelection == nil {
- cfg.Read.TopicSelection = &TopicSelectionConfig{}
- }
- if cfg.Read.TopicSelection.TopK <= 0 {
- cfg.Read.TopicSelection.TopK = defaultTopicTopK
- }
- if cfg.Read.TopicSelection.CandidateGlob == "" {
- cfg.Read.TopicSelection.CandidateGlob = CandidateGlobPattern
- }
- if cfg.Read.TopicSelection.CandidateLimit <= 0 {
- cfg.Read.TopicSelection.CandidateLimit = defaultCandidateLimit
- }
- if cfg.Read.TopicSelection.CandidatePreviewLines <= 0 {
- cfg.Read.TopicSelection.CandidatePreviewLines = defaultCandidatePreviewLine
- }
- if cfg.Read.TopicSelection.MaxLines <= 0 {
- cfg.Read.TopicSelection.MaxLines = defaultTopicMaxLines
- }
- if cfg.Read.TopicSelection.MaxBytes <= 0 {
- cfg.Read.TopicSelection.MaxBytes = defaultTopicMaxBytes
- }
-
- if cfg.Write == nil {
- cfg.Write = &WriteConfig[M]{Mode: WriteModeDisabled}
- }
- if cfg.Write.Mode == "" {
- cfg.Write.Mode = WriteModeDisabled
- }
- if cfg.Write.Model == nil {
- cfg.Write.Model = cfg.Model
- }
- if cfg.Write.MaxTurns <= 0 {
- cfg.Write.MaxTurns = defaultMemoryWriteMaxTurns
- }
-
- if cfg.Coordination == nil {
- cfg.Coordination = &CoordinationConfig[M]{}
- }
- if cfg.Coordination.Coordinator == nil {
- cfg.Coordination.Coordinator = NewLocalCoordinator()
- }
- if cfg.Coordination.LockTTL <= 0 {
- cfg.Coordination.LockTTL = 2 * time.Minute
- }
-}
-
-func cloneConfig[M adk.MessageType](cfg *Config[M]) *Config[M] {
- if cfg == nil {
- return nil
- }
-
- cp := *cfg
- if cfg.Read != nil {
- readCopy := *cfg.Read
- cp.Read = &readCopy
- if cfg.Read.Instruction != nil {
- instructionCopy := *cfg.Read.Instruction
- cp.Read.Instruction = &instructionCopy
- }
- if cfg.Read.Index != nil {
- indexCopy := *cfg.Read.Index
- cp.Read.Index = &indexCopy
- }
- if cfg.Read.TopicSelection != nil {
- topicSelectionCopy := *cfg.Read.TopicSelection
- cp.Read.TopicSelection = &topicSelectionCopy
- }
- }
- if cfg.Write != nil {
- writeCopy := *cfg.Write
- cp.Write = &writeCopy
- }
- if cfg.Coordination != nil {
- coordinationCopy := *cfg.Coordination
- cp.Coordination = &coordinationCopy
- }
- return &cp
-}
-
type topicSelectionResp struct {
SelectedMemories []string `json:"selected_memories"`
}
-func (m *middleware[M]) injectIndexIntoInstruction(ctx context.Context, baseInstruction string) string {
- memDir := m.resolvedMemoryDirectory
-
- var memDesc string
- if m.cfg.Read.Instruction != nil {
- memDesc = *m.cfg.Read.Instruction
- } else {
- s, err := pyfmt.Fmt(getDefaultMemoryInstruction(), map[string]any{"memory_dir": memDir})
+func (m *middleware[M]) renderInstruction(ctx context.Context, baseInstruction string) (string, error) {
+ enableIndex := m.memoryIndexEnabled()
+ memDesc := getDefaultMemoryInstruction(enableIndex)
+ if m.cfg.GenInstruction != nil {
+ custom, err := m.cfg.GenInstruction(ctx)
if err != nil {
- m.onErr(ctx, OnErrorStageRenderInstruction, err)
- return baseInstruction
+ return "", err
}
- memDesc = s
- }
-
- indexPath := filepath.Join(m.resolvedMemoryDirectory, m.cfg.Read.Index.FileName)
- indexContent := ""
- totalLines := 0
-
- fc, err := m.boundedMemoryBackend.Read(ctx, &ReadRequest{FilePath: indexPath})
- if err == nil && fc != nil {
- if isFileNotFoundContent(fc.Content) {
- indexContent = ""
- } else {
- indexContent = fc.Content
- totalLines = strings.Count(indexContent, "\n") + 1
+ if strings.TrimSpace(custom) != "" {
+ memDesc = custom
}
- } else {
- // Missing index is not fatal; keep empty.
- indexContent = ""
}
- sb := make([]string, 0, 5)
- sb = append(sb, memDesc)
- sb = append(sb, "## "+m.cfg.Read.Index.FileName)
- if strings.TrimSpace(indexContent) == "" {
- sb = append(sb, getAppendEmptyIndexTemplate())
- } else {
- truncatedMemoryIndex, _, truncated := linesOrSizeTrunc(indexContent, m.cfg.Read.Index.MaxLines, m.cfg.Read.Index.MaxBytes)
- sb = append(sb, truncatedMemoryIndex)
- if truncated {
- notify, err := pyfmt.Fmt(getAppendCurrentIndexTruncNotify(), map[string]any{
- "memory_lines": totalLines,
- })
- if err == nil {
- sb = append(sb, notify)
- }
- }
+ stores := make([]memoryStorePromptInfo, 0, len(m.memoryStores))
+ for _, store := range m.memoryStores {
+ stores = append(stores, memoryStorePromptInfo{
+ Name: store.displayName(),
+ Mount: store.Path,
+ Description: strings.TrimSpace(store.Description),
+ })
}
- return baseInstruction + "\n" + instructionMarker + "\n" + strings.Join(sb, "\n")
+ return buildSystemMemoryInstruction(baseInstruction, memDesc, stores)
}
-func linesOrSizeTrunc(content string, lines, size int) (newContent string, reason string, truncated bool) {
- linesTrunc := func(content string, lines int) {
- sp := strings.Split(content, "\n")
- if len(sp) > lines {
- newContent = strings.Join(sp[:lines], "\n")
- reason = fmt.Sprintf("first %d lines", lines)
- truncated = true
- } else {
- newContent = content
- }
+func (m *middleware[M]) buildMemoryIndexMessage(ctx context.Context) (M, error) {
+ if !m.memoryIndexEnabled() {
+ return nil, nil
}
+ stores := make([]memoryStorePromptInfo, 0, len(m.memoryStores))
+ hasIndex := false
+ for _, store := range m.memoryStores {
+ indexPath := filepath.Join(store.Path, m.cfg.Read.Index.FileName)
+ indexContent := ""
+ totalLines := 0
- sizeTrunc := func(content string, size int) {
- if len(content) > size {
- newContent = content[:size]
- reason = fmt.Sprintf("%d byte limit", size)
- truncated = true
- } else {
- newContent = content
+ fc, err := store.Backend.Read(ctx, &ReadRequest{FilePath: indexPath})
+ if err == nil && fc != nil && !isFileNotFoundContent(fc.Content) {
+ indexContent = fc.Content
+ totalLines = strings.Count(indexContent, "\n") + 1
}
+ truncatedMemoryIndex, _, truncated := linesOrSizeTrunc(indexContent, m.cfg.Read.Index.MaxLines, m.cfg.Read.Index.MaxBytes)
+ stores = append(stores, memoryStorePromptInfo{
+ Name: store.displayName(),
+ Mount: store.Path,
+ Description: strings.TrimSpace(store.Description),
+ Index: &memoryIndexPromptInfo{
+ FileName: m.cfg.Read.Index.FileName,
+ Path: indexPath,
+ Content: truncatedMemoryIndex,
+ Empty: strings.TrimSpace(indexContent) == "",
+ Truncated: truncated,
+ Lines: totalLines,
+ IncludeContent: true,
+ },
+ })
+ hasIndex = true
}
-
- if lines == 0 && size == 0 {
- return content, "", false
- } else if lines == 0 {
- sizeTrunc(content, size)
- } else if size == 0 {
- linesTrunc(content, lines)
- } else {
- linesTrunc(content, lines)
- sizeTrunc(newContent, size)
- }
- return
-}
-
-func isFileNotFoundContent(content string) bool {
- return strings.HasPrefix(strings.TrimSpace(content), "File not found: ")
-}
-
-func (m *middleware[M]) onErr(ctx context.Context, stage ErrorStage, err error) {
- if err == nil {
- return
- }
- if m.cfg != nil && m.cfg.OnError != nil {
- m.cfg.OnError(ctx, stage, err)
+ if !hasIndex {
+ return nil, nil
}
+ return newMemoryIndexMessage[M](buildMemoryIndexReminder(stores)), nil
}
type topicFrontmatter struct {
@@ -592,27 +497,21 @@ type topicFrontmatter struct {
}
type topicCandidateBundle struct {
- AbsPath string
- RelPath string
- Info FileInfo
+ StoreName string
+ StorePath string
+ Backend Backend
+ Key string
+ AbsPath string
+ RelPath string
+ Info FileInfo
}
-func parseFrontmatter(md string) (fm topicFrontmatter, ok bool) {
- // Only consider YAML frontmatter at the beginning.
- s := strings.TrimLeft(md, "\ufeff \t\r\n")
- if !strings.HasPrefix(s, "---\n") && !strings.HasPrefix(s, "---\r\n") {
- return topicFrontmatter{}, false
- }
- // Find the next delimiter.
- parts := strings.SplitN(s, "\n---", 2)
- if len(parts) != 2 {
- return topicFrontmatter{}, false
- }
- yml := strings.TrimPrefix(parts[0], "---\n")
- if err := yaml.Unmarshal([]byte(yml), &fm); err != nil {
- return topicFrontmatter{}, false
- }
- return fm, true
+type topicMemoryPromptInfo struct {
+ StoreName string
+ StorePath string
+ Path string
+ Saved string
+ Content string
}
func (m *middleware[M]) selectAndBuildTopicMemoryMessage(ctx context.Context, agentIn *adk.TypedAgentInput[M]) (M, error) {
@@ -632,26 +531,12 @@ func (m *middleware[M]) selectAndBuildTopicMemoryMessage(ctx context.Context, ag
return nil, err
}
- rendered := m.renderTopicMemories(ctx, selected, relToBundle, topK)
- if len(rendered) == 0 {
+ topics := m.renderTopicMemories(ctx, selected, relToBundle, topK)
+ if len(topics) == 0 {
return nil, nil
}
- return newMemoryMessage[M]("\n" + strings.Join(rendered, "\n\n")), nil
-}
-
-func (m *middleware[M]) lastUserMessage(agentIn *adk.TypedAgentInput[M]) (M, bool) {
- if agentIn == nil || len(agentIn.Messages) == 0 {
- return nil, false
- }
- if m.cfg.Read.TopicSelection == nil || m.topicSelectionModel == nil {
- return nil, false
- }
- last := agentIn.Messages[len(agentIn.Messages)-1]
- if isNilMessage(last) || !isUserRole(last) {
- return nil, false
- }
- return last, true
+ return newMemoryMessage[M]("\n" + buildTopicMemoryReminder(topics)), nil
}
func (m *middleware[M]) listTopicCandidates(ctx context.Context) (map[string]topicCandidateBundle, []string, []string, error) {
@@ -669,37 +554,52 @@ func (m *middleware[M]) listTopicCandidates(ctx context.Context) (map[string]top
if !ok {
continue
}
- relToBundle[bundle.RelPath] = bundle
+ relToBundle[bundle.Key] = bundle
available = append(available, manifestLine)
- orderedRel = append(orderedRel, bundle.RelPath)
+ orderedRel = append(orderedRel, bundle.Key)
}
return relToBundle, available, orderedRel, nil
}
-func (m *middleware[M]) topicSelectionCandidates(ctx context.Context) ([]FileInfo, error) {
- files, err := m.boundedMemoryBackend.GlobInfo(ctx, &GlobInfoRequest{
- Pattern: m.cfg.Read.TopicSelection.CandidateGlob,
- Path: m.resolvedMemoryDirectory,
- })
- if err != nil || len(files) == 0 {
- return nil, err
- }
-
- indexAbs := filepath.Join(m.resolvedMemoryDirectory, m.cfg.Read.Index.FileName)
- candidates := make([]FileInfo, 0, len(files))
- for _, fi := range files {
- if filepath.Clean(fi.Path) == filepath.Clean(indexAbs) {
- continue
+func (m *middleware[M]) topicSelectionCandidates(ctx context.Context) ([]topicCandidateBundle, error) {
+ var candidates []topicCandidateBundle
+ for _, store := range m.memoryStores {
+ files, err := store.Backend.GlobInfo(ctx, &GlobInfoRequest{
+ Pattern: m.cfg.Read.TopicSelection.CandidateGlob,
+ Path: store.Path,
+ })
+ if err != nil {
+ return nil, err
+ }
+ indexAbs := filepath.Join(store.Path, m.cfg.Read.Index.FileName)
+ for _, fi := range files {
+ if filepath.Clean(fi.Path) == filepath.Clean(indexAbs) {
+ continue
+ }
+ rel, relErr := filepath.Rel(store.Path, fi.Path)
+ if relErr != nil {
+ rel = filepath.Base(fi.Path)
+ }
+ rel = filepath.ToSlash(rel)
+ key := filepath.ToSlash(filepath.Join(store.displayName(), rel))
+ candidates = append(candidates, topicCandidateBundle{
+ StoreName: store.displayName(),
+ StorePath: store.Path,
+ Backend: store.Backend,
+ Key: key,
+ AbsPath: fi.Path,
+ RelPath: rel,
+ Info: fi,
+ })
}
- candidates = append(candidates, fi)
}
if len(candidates) == 0 {
return nil, nil
}
sort.Slice(candidates, func(i, j int) bool {
- return parseRFC3339NanoBestEffort(candidates[i].ModifiedAt).After(parseRFC3339NanoBestEffort(candidates[j].ModifiedAt))
+ return parseRFC3339NanoBestEffort(candidates[i].Info.ModifiedAt).After(parseRFC3339NanoBestEffort(candidates[j].Info.ModifiedAt))
})
if len(candidates) > m.cfg.Read.TopicSelection.CandidateLimit {
candidates = candidates[:m.cfg.Read.TopicSelection.CandidateLimit]
@@ -707,15 +607,9 @@ func (m *middleware[M]) topicSelectionCandidates(ctx context.Context) ([]FileInf
return candidates, nil
}
-func (m *middleware[M]) buildTopicCandidateBundle(ctx context.Context, fi FileInfo) (topicCandidateBundle, string, bool) {
- rel, relErr := filepath.Rel(m.resolvedMemoryDirectory, fi.Path)
- if relErr != nil {
- rel = filepath.Base(fi.Path)
- }
- rel = filepath.ToSlash(rel)
-
- preview, err := m.boundedMemoryBackend.Read(ctx, &ReadRequest{
- FilePath: fi.Path,
+func (m *middleware[M]) buildTopicCandidateBundle(ctx context.Context, bundle topicCandidateBundle) (topicCandidateBundle, string, bool) {
+ preview, err := bundle.Backend.Read(ctx, &ReadRequest{
+ FilePath: bundle.AbsPath,
Limit: m.cfg.Read.TopicSelection.CandidatePreviewLines,
})
if err != nil || preview == nil || isFileNotFoundContent(preview.Content) {
@@ -723,40 +617,8 @@ func (m *middleware[M]) buildTopicCandidateBundle(ctx context.Context, fi FileIn
}
desc := describeTopicCandidate(preview.Content)
- manifestLine := fmt.Sprintf("- %s (saved %s): %s", rel, fi.ModifiedAt, desc)
- return topicCandidateBundle{AbsPath: fi.Path, RelPath: rel, Info: fi}, manifestLine, true
-}
-
-func describeTopicCandidate(content string) string {
- desc := ""
- if fm, ok := parseFrontmatter(content); ok {
- switch {
- case strings.TrimSpace(fm.Description) != "":
- desc = strings.TrimSpace(fm.Description)
- case strings.TrimSpace(fm.Name) != "":
- desc = strings.TrimSpace(fm.Name)
- }
- if strings.TrimSpace(fm.Type) != "" {
- if desc == "" {
- desc = "type=" + strings.TrimSpace(fm.Type)
- } else {
- desc = desc + " (type=" + strings.TrimSpace(fm.Type) + ")"
- }
- }
- }
- if desc == "" {
- snippet, _, _ := linesOrSizeTrunc(content, 3, 256)
- desc = strings.TrimSpace(snippet)
- }
- return desc
-}
-
-func (m *middleware[M]) topicSelectionTopK() int {
- topK := m.cfg.Read.TopicSelection.TopK
- if topK <= 0 {
- return defaultTopicTopK
- }
- return topK
+ manifestLine := fmt.Sprintf("- %s (store: %s, saved %s): %s", bundle.Key, bundle.StoreName, bundle.Info.ModifiedAt, desc)
+ return bundle, manifestLine, true
}
func (m *middleware[M]) selectTopicCandidates(
@@ -768,12 +630,10 @@ func (m *middleware[M]) selectTopicCandidates(
relToBundle map[string]topicCandidateBundle,
) ([]string, error) {
topK := m.topicSelectionTopK()
- if len(orderedRel) <= topK {
- return orderedRel, nil
- }
userMsg, err := pyfmt.Fmt(getTopicSelectionUserPrompt(), map[string]any{
"user_query": userQuery,
+ "top_k": topK,
"available_memories": strings.Join(available, "\n"),
"tools": strings.Join(collectToolNames(agentIn.Messages), ", "),
})
@@ -798,22 +658,14 @@ func (m *middleware[M]) selectTopicCandidates(
for k := range relToBundle {
valid[k] = struct{}{}
}
- return parseTopicSelectionFromToolCall(resp, valid)
-}
-
-func collectToolNames[M adk.MessageType](msgs []M) []string {
- dedupTools := make(map[string]struct{})
- for _, msg := range msgs {
- for _, name := range messageToolNames(msg) {
- dedupTools[name] = struct{}{}
- }
+ selected, err := parseTopicSelectionFromToolCall(resp, valid)
+ if err != nil {
+ return nil, err
}
- tools := make([]string, 0, len(dedupTools))
- for t := range dedupTools {
- tools = append(tools, t)
+ if len(selected) > topK {
+ return selected[:topK], nil
}
- sort.Strings(tools)
- return tools
+ return selected, nil
}
func (m *middleware[M]) renderTopicMemories(
@@ -821,12 +673,14 @@ func (m *middleware[M]) renderTopicMemories(
selected []string,
relToBundle map[string]topicCandidateBundle,
topK int,
-) []string {
+) []topicMemoryPromptInfo {
capHint := topK
if capHint > len(selected) {
capHint = len(selected)
}
- rendered := make([]string, 0, capHint)
+ rendered := make([]topicMemoryPromptInfo, 0, capHint)
+ totalBytes := 0
+ maxTotalBytes := m.cfg.Read.TopicSelection.MaxTotalBytes
for _, rel := range selected {
if len(rendered) >= topK {
break
@@ -835,19 +689,30 @@ func (m *middleware[M]) renderTopicMemories(
if !ok {
continue
}
- renderedContent, ok := m.renderTopicMemory(ctx, bundle)
+ topic, ok := m.renderTopicMemory(ctx, bundle)
if !ok {
continue
}
- rendered = append(rendered, renderedContent)
+ topicBytes := len(topic.Content) + len(topic.StoreName) + len(topic.StorePath) + len(topic.Path)
+ if maxTotalBytes > 0 && totalBytes+topicBytes > maxTotalBytes {
+ if len(rendered) == 0 {
+ if len(topic.Content) > maxTotalBytes {
+ topic.Content = topic.Content[:maxTotalBytes]
+ }
+ rendered = append(rendered, topic)
+ }
+ break
+ }
+ rendered = append(rendered, topic)
+ totalBytes += topicBytes
}
return rendered
}
-func (m *middleware[M]) renderTopicMemory(ctx context.Context, bundle topicCandidateBundle) (string, bool) {
- full, err := m.boundedMemoryBackend.Read(ctx, &ReadRequest{FilePath: bundle.AbsPath})
+func (m *middleware[M]) renderTopicMemory(ctx context.Context, bundle topicCandidateBundle) (topicMemoryPromptInfo, bool) {
+ full, err := bundle.Backend.Read(ctx, &ReadRequest{FilePath: bundle.AbsPath})
if err != nil || full == nil || isFileNotFoundContent(full.Content) {
- return "", false
+ return topicMemoryPromptInfo{}, false
}
content, truncReason, truncated := linesOrSizeTrunc(full.Content, m.cfg.Read.TopicSelection.MaxLines, m.cfg.Read.TopicSelection.MaxBytes)
@@ -861,402 +726,13 @@ func (m *middleware[M]) renderTopicMemory(ctx context.Context, bundle topicCandi
}
}
- return fmt.Sprintf(
- "\nContents of %s (saved %s):\n\n%s\n",
- bundle.AbsPath,
- bundle.Info.ModifiedAt,
- content,
- ), true
-}
-
-func topicSelectionToolInfo() *schema.ToolInfo {
- return &schema.ToolInfo{
- Name: topicSelectionToolName,
- Desc: "Select which memory files to surface for the current query. Return selected_memories as RELATIVE paths (relative to the memory directory).",
- ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{
- "selected_memories": {
- Type: schema.Array,
- Desc: "Relative paths of selected memory files, e.g. \"debugging.md\" or \"notes/patterns.md\".",
- Required: true,
- ElemInfo: &schema.ParameterInfo{Type: schema.String},
- },
- }),
- }
-}
-
-func parseTopicSelectionFromToolCall[M adk.MessageType](msg M, valid map[string]struct{}) ([]string, error) {
- toolCalls := messageToolCalls(msg)
- if len(toolCalls) == 0 {
- return nil, fmt.Errorf("no tool calls")
- }
- tc := toolCalls[0]
- if tc.Function.Name != topicSelectionToolName {
- return nil, fmt.Errorf("unexpected tool call: %s", tc.Function.Name)
- }
- var parsed topicSelectionResp
- if err := json.Unmarshal([]byte(tc.Function.Arguments), &parsed); err != nil {
- return nil, err
- }
- out := normalizeSelected(parsed.SelectedMemories)
- // Filter to known candidates to avoid hallucinated paths.
- filtered := make([]string, 0, len(out))
- for _, p := range out {
- if _, ok := valid[p]; ok {
- filtered = append(filtered, p)
- }
- }
- return filtered, nil
-}
-
-func normalizeSelected(in []string) []string {
- out := make([]string, 0, len(in))
- seen := make(map[string]struct{}, len(in))
- for _, s := range in {
- s = strings.TrimSpace(s)
- s = strings.TrimPrefix(s, "./")
- s = filepath.ToSlash(s)
- if s == "" {
- continue
- }
- if _, ok := seen[s]; ok {
- continue
- }
- seen[s] = struct{}{}
- out = append(out, s)
- }
- return out
-}
-
-func isNilMessage[M adk.MessageType](msg M) bool {
- var zero M
- return any(msg) == any(zero)
-}
-
-func isUserRole[M adk.MessageType](msg M) bool {
- switch m := any(msg).(type) {
- case *schema.Message:
- return m != nil && m.Role == schema.User
- case *schema.AgenticMessage:
- return m != nil && m.Role == schema.AgenticRoleTypeUser
- default:
- panic("unreachable")
- }
-}
-
-func isAssistantRole[M adk.MessageType](msg M) bool {
- switch m := any(msg).(type) {
- case *schema.Message:
- return m != nil && m.Role == schema.Assistant
- case *schema.AgenticMessage:
- return m != nil && m.Role == schema.AgenticRoleTypeAssistant
- default:
- panic("unreachable")
- }
-}
-
-func userMessageTextContent[M adk.MessageType](msg M) string {
- switch m := any(msg).(type) {
- case *schema.Message:
- if m == nil {
- return ""
- }
- if len(m.UserInputMultiContent) == 0 {
- return m.Content
- }
- parts := make([]string, 0, len(m.UserInputMultiContent))
- for _, part := range m.UserInputMultiContent {
- if part.Type == schema.ChatMessagePartTypeText && part.Text != "" {
- parts = append(parts, part.Text)
- }
- }
- if len(parts) > 0 {
- return strings.Join(parts, "\n")
- }
- return m.Content
- case *schema.AgenticMessage:
- if m == nil {
- return ""
- }
- parts := make([]string, 0, len(m.ContentBlocks))
- for _, block := range m.ContentBlocks {
- if block != nil && block.UserInputText != nil {
- parts = append(parts, block.UserInputText.Text)
- }
- }
- return strings.Join(parts, "\n")
- default:
- panic("unreachable")
- }
-}
-
-func getMsgExtra[M adk.MessageType](msg M) map[string]any {
- switch m := any(msg).(type) {
- case *schema.Message:
- if m == nil {
- return nil
- }
- return m.Extra
- case *schema.AgenticMessage:
- if m == nil {
- return nil
- }
- return m.Extra
- default:
- panic("unreachable")
- }
-}
-
-func copyAndSetMsgExtra[M adk.MessageType](msg M, key string, value any) {
- existing := getMsgExtra(msg)
- newExtra := make(map[string]any, len(existing)+1)
- for k, v := range existing {
- newExtra[k] = v
- }
- newExtra[key] = value
-
- switch m := any(msg).(type) {
- case *schema.Message:
- m.Extra = newExtra
- case *schema.AgenticMessage:
- m.Extra = newExtra
- default:
- panic("unreachable")
- }
-}
-
-func makeUserMsg[M adk.MessageType](text string) M {
- var zero M
- switch any(zero).(type) {
- case *schema.Message:
- return any(schema.UserMessage(text)).(M)
- case *schema.AgenticMessage:
- return any(schema.UserAgenticMessage(text)).(M)
- default:
- panic("unreachable")
- }
-}
-
-func makeSystemMsg[M adk.MessageType](text string) M {
- var zero M
- switch any(zero).(type) {
- case *schema.Message:
- return any(schema.SystemMessage(text)).(M)
- case *schema.AgenticMessage:
- return any(schema.SystemAgenticMessage(text)).(M)
- default:
- panic("unreachable")
- }
-}
-
-func makeToolChoiceForced[M adk.MessageType](name string) model.Option {
- var zero M
- switch any(zero).(type) {
- case *schema.Message:
- return model.WithToolChoice(schema.ToolChoiceForced, name)
- case *schema.AgenticMessage:
- return model.WithAgenticToolChoice(&schema.AgenticToolChoice{
- Type: schema.ToolChoiceForced,
- Forced: &schema.AgenticForcedToolChoice{
- Tools: []*schema.AllowedTool{{FunctionName: name}},
- },
- })
- default:
- panic("unreachable")
- }
-}
-
-func messageToolCalls[M adk.MessageType](msg M) []schema.ToolCall {
- switch m := any(msg).(type) {
- case *schema.Message:
- if m == nil {
- return nil
- }
- return m.ToolCalls
- case *schema.AgenticMessage:
- if m == nil {
- return nil
- }
- out := make([]schema.ToolCall, 0, len(m.ContentBlocks))
- for _, block := range m.ContentBlocks {
- if block == nil || block.FunctionToolCall == nil {
- continue
- }
- out = append(out, schema.ToolCall{
- ID: block.FunctionToolCall.CallID,
- Type: "function",
- Function: schema.FunctionCall{
- Name: block.FunctionToolCall.Name,
- Arguments: block.FunctionToolCall.Arguments,
- },
- })
- }
- return out
- default:
- panic("unreachable")
- }
-}
-
-func messageToolNames[M adk.MessageType](msg M) []string {
- switch m := any(msg).(type) {
- case *schema.Message:
- if m == nil || m.Role != schema.Tool || m.ToolName == "" {
- return nil
- }
- return []string{m.ToolName}
- case *schema.AgenticMessage:
- if m == nil {
- return nil
- }
- var out []string
- for _, block := range m.ContentBlocks {
- if block == nil || block.FunctionToolResult == nil || block.FunctionToolResult.Name == "" {
- continue
- }
- out = append(out, block.FunctionToolResult.Name)
- }
- return out
- default:
- panic("unreachable")
- }
-}
-
-func projectMessagesToSchema[M adk.MessageType](msgs []M) []adk.Message {
- out := make([]adk.Message, 0, len(msgs))
- for _, msg := range msgs {
- if projected := projectMessageToSchema(msg); projected != nil {
- out = append(out, projected)
- }
- }
- return out
-}
-
-func projectMessageToSchema[M adk.MessageType](msg M) adk.Message {
- switch m := any(msg).(type) {
- case *schema.Message:
- return m
- case *schema.AgenticMessage:
- if m == nil {
- return nil
- }
- text := m.String()
- switch m.Role {
- case schema.AgenticRoleTypeSystem:
- return schema.SystemMessage(text)
- case schema.AgenticRoleTypeAssistant:
- return schema.AssistantMessage(text, messageToolCalls(msg))
- case schema.AgenticRoleTypeUser:
- return schema.UserMessage(text)
- default:
- return schema.UserMessage(text)
- }
- default:
- panic("unreachable")
- }
-}
-
-func alreadyInjected[M adk.MessageType](msgs []M) bool {
- for _, m := range msgs {
- if isMemoryMessage(m) {
- return true
- }
- }
- return false
-}
-
-func isMemoryMessage[M adk.MessageType](m M) bool {
- if isNilMessage(m) || !isUserRole(m) {
- return false
- }
- if extra := getMsgExtra(m); extra != nil {
- if v, ok := extra[memoryExtraKey]; ok {
- if isAutomemoryMemoryExtra(v) {
- return true
- }
- }
- }
- // Backward compatible marker (older versions).
- return strings.Contains(userMessageTextContent(m), "")
-}
-
-func isAutomemoryMemoryExtra(v any) bool {
- switch meta := v.(type) {
- case *memoryExtra:
- return meta != nil && meta.Type == "memory"
- case map[string]any:
- typ, _ := meta["type"].(string)
- return typ == "memory"
- default:
- return false
- }
-}
-
-func hasInstructionInjected(instruction string) bool {
- return strings.Contains(instruction, instructionMarker)
-}
-
-func newMemoryMessage[M adk.MessageType](content string) M {
- msg := makeUserMsg[M](content)
- copyAndSetMsgExtra(msg, memoryExtraKey, &memoryExtra{Type: "memory"})
- return msg
-}
-
-func ensureMemoryMsgUnchanged[M adk.MessageType](state *adk.TypedChatModelAgentState[M], expectedContent string) *adk.TypedChatModelAgentState[M] {
- if state == nil || strings.TrimSpace(expectedContent) == "" {
- return state
- }
- changed := false
- out := *state
- out.Messages = append([]M{}, state.Messages...)
-
- for i, m := range out.Messages {
- if !isMemoryMessage(m) {
- continue
- }
- extra := getMsgExtra(m)
- if userMessageTextContent(m) != expectedContent || extra == nil || extra[memoryExtraKey] == nil {
- out.Messages[i] = newMemoryMessage[M](expectedContent)
- changed = true
- }
- }
- if !changed {
- return state
- }
- return &out
-}
-
-func extractFilePath(args string) (string, bool) {
- var m map[string]any
- if err := json.Unmarshal([]byte(args), &m); err != nil {
- return "", false
- }
- if v, ok := m["file_path"]; ok {
- if s, ok := v.(string); ok && s != "" {
- return s, true
- }
- }
- if v, ok := m["filePath"]; ok { // tolerate camelCase
- if s, ok := v.(string); ok && s != "" {
- return s, true
- }
- }
- return "", false
-}
-
-func isPathWithinMemoryDir(memDir string, filePath string) bool {
- if memDir == "" || filePath == "" {
- return false
- }
- md := filepath.Clean(memDir)
- fp := filepath.Clean(filePath)
- if !filepath.IsAbs(fp) {
- fp = filepath.Join(md, fp)
- fp = filepath.Clean(fp)
- }
- if fp == md {
- return true
- }
- sep := string(filepath.Separator)
- return strings.HasPrefix(fp, md+sep)
+ return topicMemoryPromptInfo{
+ StoreName: bundle.StoreName,
+ StorePath: bundle.StorePath,
+ Path: bundle.RelPath,
+ Saved: bundle.Info.ModifiedAt,
+ Content: content,
+ }, true
}
func (m *middleware[M]) AfterAgent(ctx context.Context, state *adk.TypedChatModelAgentState[M]) (context.Context, error) {
@@ -1275,10 +751,11 @@ func (m *middleware[M]) AfterAgent(ctx context.Context, state *adk.TypedChatMode
m.onErr(ctx, OnErrorStageResolveSessionID, err)
return ctx, nil
}
+ coordKey := m.coordinatorKey(sessionID)
cursor := getWriteCursorFromMessages(state.Messages)
- if sessionID != "" {
- if remoteCursor, ok, err := m.coordination.Coordinator.GetCursor(ctx, sessionID); err == nil && ok && remoteCursor > cursor {
+ if coordKey != "" {
+ if remoteCursor, ok, err := getCoordinatorCursor(ctx, m.coordination.Coordinator, coordKey); err == nil && ok && remoteCursor > cursor {
cursor = remoteCursor
state = markWriteCursor(state, cursor)
}
@@ -1288,10 +765,10 @@ func (m *middleware[M]) AfterAgent(ctx context.Context, state *adk.TypedChatMode
}
// Skip background extraction if the main agent already wrote memory files in this range.
- if hasMemoryWritesSince(state.Messages, cursor, m.resolvedMemoryDirectory) {
+ if hasMemoryWritesSince(state.Messages, cursor, m.memoryStores) {
end := len(state.Messages)
- if sessionID != "" {
- _ = m.coordination.Coordinator.SetCursor(ctx, sessionID, end)
+ if coordKey != "" {
+ _ = setCoordinatorCursor(ctx, m.coordination.Coordinator, coordKey, end)
}
state = markWriteCursor(state, end)
return ctx, nil
@@ -1299,8 +776,8 @@ func (m *middleware[M]) AfterAgent(ctx context.Context, state *adk.TypedChatMode
if countModelVisibleMessages(state.Messages[cursor:]) == 0 {
end := len(state.Messages)
- if sessionID != "" {
- _ = m.coordination.Coordinator.SetCursor(ctx, sessionID, end)
+ if coordKey != "" {
+ _ = setCoordinatorCursor(ctx, m.coordination.Coordinator, coordKey, end)
}
state = markWriteCursor(state, end)
return ctx, nil
@@ -1317,8 +794,8 @@ func (m *middleware[M]) AfterAgent(ctx context.Context, state *adk.TypedChatMode
m.onErr(ctx, OnErrorStageMemoryWriteSync, err)
return ctx, nil
}
- if sessionID != "" {
- _ = m.coordination.Coordinator.SetCursor(ctx, sessionID, end)
+ if coordKey != "" {
+ _ = setCoordinatorCursor(ctx, m.coordination.Coordinator, coordKey, end)
}
state = markWriteCursor(state, end)
return ctx, nil
@@ -1326,24 +803,25 @@ func (m *middleware[M]) AfterAgent(ctx context.Context, state *adk.TypedChatMode
case WriteModeAsync:
if sessionID == "" {
sessionID = getOrInitWriteSessionID(ctx)
+ coordKey = m.coordinatorKey(sessionID)
}
snap, err := buildPendingSnapshot(state.Messages, cursor, state.ToolInfos)
if err != nil {
m.onErr(ctx, OnErrorStageSnapshotMarshal, err)
return ctx, nil
}
- unlock, ok, err := m.coordination.Coordinator.AcquireLock(ctx, sessionID, m.coordination.LockTTL)
+ unlock, ok, err := m.coordination.Coordinator.AcquireLock(ctx, coordKey, m.coordination.LockTTL)
if err != nil {
m.onErr(ctx, OnErrorStageAcquireExtractionLock, err)
return ctx, nil
}
if !ok {
- if err := m.coordination.Coordinator.SetPendingSnapshot(ctx, sessionID, snap); err != nil {
+ if err := setCoordinatorPendingSnapshot(ctx, m.coordination.Coordinator, coordKey, snap, m.coordination.LockTTL); err != nil {
m.onErr(ctx, OnErrorStageStashPendingSnapshot, err)
}
return ctx, nil
}
- go m.runExtractionDrain(ctx, sessionID, unlock, snap)
+ go m.runExtractionDrain(ctx, coordKey, unlock, snap)
return ctx, nil
default:
@@ -1351,122 +829,7 @@ func (m *middleware[M]) AfterAgent(ctx context.Context, state *adk.TypedChatMode
}
}
-func getWriteCursorFromMessages[M adk.MessageType](msgs []M) int {
- for i := len(msgs) - 1; i >= 0; i-- {
- m := msgs[i]
- extra := getMsgExtra(m)
- if isNilMessage(m) || extra == nil {
- continue
- }
- v, ok := extra[memoryExtraKey]
- if !ok {
- continue
- }
- switch meta := v.(type) {
- case *memoryExtra:
- if meta != nil && meta.Type == "write_cursor" {
- return meta.Cursor
- }
- case map[string]any:
- if typ, _ := meta["type"].(string); typ != "write_cursor" {
- continue
- }
- switch c := meta["cursor"].(type) {
- case int:
- return c
- case int64:
- return int(c)
- case float64:
- return int(c)
- }
- }
- }
- return 0
-}
-
-func markWriteCursor[M adk.MessageType](state *adk.TypedChatModelAgentState[M], cursor int) *adk.TypedChatModelAgentState[M] {
- if state == nil || len(state.Messages) == 0 {
- return state
- }
- last := state.Messages[len(state.Messages)-1]
- if isNilMessage(last) {
- return state
- }
-
- copyAndSetMsgExtra(last, memoryExtraKey, &memoryExtra{
- Type: "write_cursor",
- Cursor: cursor,
- })
-
- return state
-}
-
-func countModelVisibleMessages[M adk.MessageType](msgs []M) int {
- n := 0
- for _, m := range msgs {
- if isNilMessage(m) {
- continue
- }
- if isUserRole(m) || isAssistantRole(m) {
- n++
- }
- }
- return n
-}
-
-func getOrInitWriteSessionID(ctx context.Context) string {
- const key = "__automemory_write_session_id__"
- if v, ok := adk.GetSessionValue(ctx, key); ok {
- if s, ok := v.(string); ok && s != "" {
- return s
- }
- }
- // Stable enough for in-process session identity.
- s := fmt.Sprintf("%d", time.Now().UnixNano())
- adk.AddSessionValue(ctx, key, s)
- return s
-}
-
-func (m *middleware[M]) resolveSessionID(ctx context.Context, state *adk.TypedChatModelAgentState[M]) (string, error) {
- if m.coordination != nil && m.coordination.SessionIDFunc != nil {
- return m.coordination.SessionIDFunc(ctx, state)
- }
- return getOrInitWriteSessionID(ctx), nil
-}
-
-func buildPendingSnapshot[M adk.MessageType](messages []M, cursor int, toolInfos []*schema.ToolInfo) (*PendingSnapshot, error) {
- raw, err := json.Marshal(messages)
- if err != nil {
- return nil, err
- }
- var rawToolInfos json.RawMessage
- if toolInfos != nil {
- rawToolInfos, err = json.Marshal(toolInfos)
- if err != nil {
- return nil, err
- }
- }
- return &PendingSnapshot{Cursor: cursor, Messages: raw, ToolInfos: rawToolInfos}, nil
-}
-
-func decodePendingSnapshot[M adk.MessageType](snapshot *PendingSnapshot) ([]M, int, []*schema.ToolInfo, error) {
- if snapshot == nil {
- return nil, 0, nil, nil
- }
- var msgs []M
- if err := json.Unmarshal(snapshot.Messages, &msgs); err != nil {
- return nil, 0, nil, err
- }
- var toolInfos []*schema.ToolInfo
- if len(snapshot.ToolInfos) > 0 {
- if err := json.Unmarshal(snapshot.ToolInfos, &toolInfos); err != nil {
- return nil, 0, nil, err
- }
- }
- return msgs, snapshot.Cursor, toolInfos, nil
-}
-
-func (m *middleware[M]) runExtractionDrain(ctx context.Context, sessionID string, unlock func(context.Context) error, initial *PendingSnapshot) {
+func (m *middleware[M]) runExtractionDrain(ctx context.Context, coordKey string, unlock func(context.Context) error, initial *PendingSnapshot) {
defer func() {
if unlock == nil {
return
@@ -1484,10 +847,10 @@ func (m *middleware[M]) runExtractionDrain(ctx context.Context, sessionID string
} else if err := m.runMemoryExtractionAgent(ctx, msgs, cursor, toolInfos); err != nil {
m.onErr(ctx, OnErrorStageMemoryWriteAsync, err)
} else {
- _ = m.coordination.Coordinator.SetCursor(ctx, sessionID, len(msgs))
+ _ = setCoordinatorCursor(ctx, m.coordination.Coordinator, coordKey, len(msgs))
}
- next, loadErr := m.coordination.Coordinator.PopPendingSnapshot(ctx, sessionID)
+ next, loadErr := popCoordinatorPendingSnapshot(ctx, m.coordination.Coordinator, coordKey)
if loadErr != nil {
m.onErr(ctx, OnErrorStageLoadPendingSnapshot, loadErr)
return
@@ -1496,36 +859,6 @@ func (m *middleware[M]) runExtractionDrain(ctx context.Context, sessionID string
}
}
-func hasMemoryWritesSince[M adk.MessageType](msgs []M, cursor int, memoryDir string) bool {
- if cursor < 0 {
- cursor = 0
- }
- for _, msg := range msgs[cursor:] {
- if isNilMessage(msg) || !isAssistantRole(msg) {
- continue
- }
- for _, tc := range messageToolCalls(msg) {
- if tc.Function.Name != adkfs.ToolNameWriteFile && tc.Function.Name != adkfs.ToolNameEditFile {
- continue
- }
- if fp, ok := extractFilePath(tc.Function.Arguments); ok && isPathWithinMemoryDir(memoryDir, fp) {
- return true
- }
- }
- }
- return false
-}
-
-func countModelVisibleMessagesSince[M adk.MessageType](msgs []M, cursor int) int {
- if cursor < 0 {
- cursor = 0
- }
- if cursor >= len(msgs) {
- return 0
- }
- return countModelVisibleMessages(msgs[cursor:])
-}
-
func (m *middleware[M]) newExtractionAgent(ctx context.Context, toolInfos []*schema.ToolInfo) (*adk.TypedChatModelAgent[M], error) {
if m.cfg == nil || m.cfg.Write == nil || m.cfg.Write.Model == nil {
return nil, fmt.Errorf("auto memory extraction agent init failed: missing write model")
@@ -1566,7 +899,8 @@ func (m *middleware[M]) runMemoryExtractionAgent(ctx context.Context, snapshot [
return err
}
newMessageCount := countModelVisibleMessagesSince(snapshot, cursor)
- userPrompt := buildExtractAutoOnlyPrompt(m.resolvedMemoryDirectory, newMessageCount, manifest, m.cfg.Write.SkipIndex)
+ enableMemoryIndex := m.memoryIndexEnabled() && !m.cfg.Write.SkipIndex
+ userPrompt := buildExtractAutoOnlyPrompt(m.extractionMemoryStoresPrompt(), newMessageCount, manifest, enableMemoryIndex)
msgs := append(append([]M{}, snapshot...), makeUserMsg[M](userPrompt))
extractionAgent, err := m.newExtractionAgent(ctx, toolInfos)
if err != nil {
@@ -1596,52 +930,73 @@ func (m *middleware[M]) runMemoryExtractionAgent(ctx context.Context, snapshot [
}
}
-func (m *middleware[M]) buildMemoryManifest(ctx context.Context) (string, error) {
- files, err := m.boundedMemoryBackend.GlobInfo(ctx, &GlobInfoRequest{
- Pattern: CandidateGlobPattern,
- Path: m.resolvedMemoryDirectory,
- })
- if err != nil {
- return "", err
- }
- indexAbs := filepath.Join(m.resolvedMemoryDirectory, m.cfg.Read.Index.FileName)
- lines := make([]string, 0, len(files))
- for _, fi := range files {
- rel, relErr := filepath.Rel(m.resolvedMemoryDirectory, fi.Path)
- if relErr != nil {
- rel = filepath.Base(fi.Path)
- }
- rel = filepath.ToSlash(rel)
- if filepath.Clean(fi.Path) == filepath.Clean(indexAbs) {
- rel = m.cfg.Read.Index.FileName
+func (m *middleware[M]) extractionMemoryStoresPrompt() string {
+ stores := make([]memoryStorePromptInfo, 0, len(m.memoryStores))
+ for _, store := range m.memoryStores {
+ info := memoryStorePromptInfo{
+ Name: store.displayName(),
+ Mount: store.Path,
+ Description: strings.TrimSpace(store.Description),
}
- desc := ""
- preview, rerr := m.boundedMemoryBackend.Read(ctx, &ReadRequest{FilePath: fi.Path, Limit: defaultCandidatePreviewLine})
- if rerr == nil && preview != nil && !isFileNotFoundContent(preview.Content) {
- if fm, ok := parseFrontmatter(preview.Content); ok {
- desc = strings.TrimSpace(fm.Description)
+ if m.memoryIndexEnabled() {
+ info.Index = &memoryIndexPromptInfo{
+ FileName: m.cfg.Read.Index.FileName,
+ Path: filepath.Join(store.Path, m.cfg.Read.Index.FileName),
}
}
- if desc != "" {
- lines = append(lines, fmt.Sprintf("- %s (saved %s): %s", rel, fi.ModifiedAt, desc))
- } else {
- lines = append(lines, fmt.Sprintf("- %s (saved %s)", rel, fi.ModifiedAt))
- }
+ stores = append(stores, info)
}
- return strings.Join(lines, "\n"), nil
+ return buildMemoryStoresManifest(stores)
}
-func parseRFC3339NanoBestEffort(s string) time.Time {
- if s == "" {
- return time.Time{}
- }
- if t, err := time.Parse(time.RFC3339Nano, s); err == nil {
- return t
- }
- if t, err := time.Parse(time.RFC3339, s); err == nil {
- return t
+func (m *middleware[M]) buildMemoryManifest(ctx context.Context) (string, error) {
+ var stores []memoryManifestStorePromptInfo
+ for _, store := range m.memoryStores {
+ files, err := store.Backend.GlobInfo(ctx, &GlobInfoRequest{
+ Pattern: CandidateGlobPattern,
+ Path: store.Path,
+ })
+ if err != nil {
+ return "", err
+ }
+ storeInfo := memoryManifestStorePromptInfo{
+ Name: store.displayName(),
+ Mount: store.Path,
+ }
+ indexAbs := filepath.Join(store.Path, m.cfg.Read.Index.FileName)
+ if len(files) == 0 {
+ stores = append(stores, storeInfo)
+ continue
+ }
+ for _, fi := range files {
+ rel, relErr := filepath.Rel(store.Path, fi.Path)
+ if relErr != nil {
+ rel = filepath.Base(fi.Path)
+ }
+ rel = filepath.ToSlash(rel)
+ if filepath.Clean(fi.Path) == filepath.Clean(indexAbs) {
+ if !m.memoryIndexEnabled() {
+ continue
+ }
+ rel = m.cfg.Read.Index.FileName
+ }
+ desc := ""
+ preview, rerr := store.Backend.Read(ctx, &ReadRequest{FilePath: fi.Path, Limit: defaultCandidatePreviewLine})
+ if rerr == nil && preview != nil && !isFileNotFoundContent(preview.Content) {
+ if fm, ok := parseFrontmatter(preview.Content); ok {
+ desc = strings.TrimSpace(fm.Description)
+ }
+ }
+ storeInfo.Files = append(storeInfo.Files, memoryManifestFilePromptInfo{
+ MemoryPath: filepath.ToSlash(filepath.Join(store.displayName(), rel)),
+ AbsPath: fi.Path,
+ Saved: fi.ModifiedAt,
+ Description: desc,
+ })
+ }
+ stores = append(stores, storeInfo)
}
- return time.Time{}
+ return buildExtractionMemoryManifest(stores), nil
}
type toolInfoOverrideMiddleware[M adk.MessageType] struct {
@@ -1687,19 +1042,3 @@ func (m *modelWithTools[M]) Stream(ctx context.Context, input []M, opts ...model
newOpts[len(opts)] = model.WithTools(m.tools)
return m.base.Stream(ctx, input, newOpts...)
}
-
-func (m *middleware[M]) sendTopicMemoryEvent(ctx context.Context, msgs []M, memMsg M) {
- var beforeID string
- if len(msgs) > 0 && !isNilMessage(msgs[len(msgs)-1]) {
- beforeID = adk.GetMessageID(msgs[len(msgs)-1])
- }
- if sendEventErr := adk.TypedSendEvent(ctx, &adk.TypedAgentEvent[M]{SessionEvent: &adk.SessionEvent[M]{
- Kind: adk.SessionEventMessageInserted,
- MessageInserted: &adk.MessageInsertedEvent[M]{
- Message: memMsg,
- BeforeMessageID: beforeID,
- },
- }}); sendEventErr != nil {
- m.onErr(ctx, OnErrorStageSendSessionEvent, sendEventErr)
- }
-}
diff --git a/adk/middlewares/automemory/automemory_test.go b/adk/middlewares/automemory/automemory_test.go
index 4be199c48..6bab0f2f0 100644
--- a/adk/middlewares/automemory/automemory_test.go
+++ b/adk/middlewares/automemory/automemory_test.go
@@ -30,7 +30,6 @@ import (
"github.com/stretchr/testify/require"
"github.com/cloudwego/eino/adk"
- adksession "github.com/cloudwego/eino/adk/session"
"github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/schema"
)
@@ -40,7 +39,16 @@ type fixedModel struct {
}
func (m *fixedModel) Generate(ctx context.Context, input []*schema.Message, _ ...model.Option) (*schema.Message, error) {
- return schema.AssistantMessage(m.out, nil), nil
+ return schema.AssistantMessage("", []schema.ToolCall{
+ {
+ ID: "select-fixed",
+ Type: "function",
+ Function: schema.FunctionCall{
+ Name: topicSelectionToolName,
+ Arguments: m.out,
+ },
+ },
+ }), nil
}
func (m *fixedModel) Stream(ctx context.Context, input []*schema.Message, _ ...model.Option) (*schema.StreamReader[*schema.Message], error) {
@@ -52,13 +60,82 @@ func (m *fixedModel) WithTools(_ []*schema.ToolInfo) (model.ToolCallingChatModel
return m, nil
}
+func requireMemoryIndexMessage(t *testing.T, msg *schema.Message, contains ...string) {
+ t.Helper()
+ require.True(t, isMemoryIndexMessage(msg))
+ require.NotNil(t, msg.Extra)
+ require.NotNil(t, msg.Extra[memoryExtraKey])
+ require.Contains(t, msg.Content, "")
+ require.Contains(t, msg.Content, "")
+ require.Contains(t, msg.Content, "")
+ require.Contains(t, msg.Content, "")
+ require.Contains(t, msg.Content, "")
+ require.NotContains(t, msg.Content, "### 1. Name:")
+ require.NotContains(t, msg.Content, "#### Index file content:")
+ for _, s := range contains {
+ require.Contains(t, msg.Content, s)
+ }
+}
+
+func requireTopicMemoryMessage(t *testing.T, msg *schema.Message, contains ...string) {
+ t.Helper()
+ require.True(t, isTopicMemoryMessage(msg))
+ require.NotNil(t, msg.Extra)
+ require.NotNil(t, msg.Extra[memoryExtraKey])
+ require.Contains(t, msg.Content, "")
+ require.Contains(t, msg.Content, "")
+ require.Contains(t, msg.Content, "Topic memories are long-term memory files selected as relevant to the current query")
+ require.Contains(t, msg.Content, "")
+ require.Contains(t, msg.Content, "")
+ require.Contains(t, msg.Content, "")
+ require.Contains(t, msg.Content, "")
+ for _, s := range contains {
+ require.Contains(t, msg.Content, s)
+ }
+}
+
+func requireWriteCursor(t *testing.T, msgs []*schema.Message, cursor int) {
+ t.Helper()
+ for _, msg := range msgs {
+ if msg == nil || msg.Extra == nil {
+ continue
+ }
+ meta, ok := msg.Extra[memoryExtraKey].(*memoryExtra)
+ if ok && meta != nil && meta.Type == "write_cursor" {
+ require.EqualValues(t, cursor, meta.Cursor)
+ return
+ }
+ }
+ require.Fail(t, "write cursor not found")
+}
+
+func countMemoryIndexMessages(msgs []*schema.Message) int {
+ count := 0
+ for _, msg := range msgs {
+ if isMemoryIndexMessage(msg) {
+ count++
+ }
+ }
+ return count
+}
+
+func countTopicMemoryMessages(msgs []*schema.Message) int {
+ count := 0
+ for _, msg := range msgs {
+ if isTopicMemoryMessage(msg) {
+ count++
+ }
+ }
+ return count
+}
+
func TestMiddleware_IndexInjection_Empty(t *testing.T) {
ctx := context.Background()
b := NewInMemoryBackend()
mw, err := New(ctx, &Config[*schema.Message]{
- MemoryDirectory: "/mem",
- MemoryBackend: b,
+ MemoryStores: []MemoryStore{{Path: "/mem"}},
+ MemoryBackend: b,
// Model nil => topic selection disabled.
})
require.NoError(t, err)
@@ -70,9 +147,16 @@ func TestMiddleware_IndexInjection_Empty(t *testing.T) {
_, out, err := mw.BeforeAgent(ctx, runCtx)
require.NoError(t, err)
- require.Contains(t, out.Instruction, "# auto memory")
- require.Contains(t, out.Instruction, "## MEMORY.md")
- require.Contains(t, out.Instruction, "currently empty")
+ require.Contains(t, out.Instruction, "# Auto memory")
+ require.Contains(t, out.Instruction, "## Memory stores")
+ require.Contains(t, out.Instruction, "1. Name: mem")
+ require.Contains(t, out.Instruction, "Path: /mem")
+ require.NotContains(t, out.Instruction, "Index file path: /mem/MEMORY.md")
+ require.NotContains(t, out.Instruction, "#### Index file content: MEMORY.md")
+ require.NotContains(t, out.Instruction, "Rules:")
+ require.Len(t, out.AgentInput.Messages, 2)
+ requireMemoryIndexMessage(t, out.AgentInput.Messages[0], "Memory indexes are the high-level table of contents", "Index Memory File Path: /mem/MEMORY.md", "currently empty")
+ require.Contains(t, out.AgentInput.Messages[1].Content, "hi")
}
func TestMiddleware_IndexInjection_ChineseInstruction(t *testing.T) {
@@ -85,8 +169,8 @@ func TestMiddleware_IndexInjection_ChineseInstruction(t *testing.T) {
b := NewInMemoryBackend()
mw, err := New(ctx, &Config[*schema.Message]{
- MemoryDirectory: "/mem",
- MemoryBackend: b,
+ MemoryStores: []MemoryStore{{Path: "/mem"}},
+ MemoryBackend: b,
})
require.NoError(t, err)
@@ -98,7 +182,72 @@ func TestMiddleware_IndexInjection_ChineseInstruction(t *testing.T) {
_, out, err := mw.BeforeAgent(ctx, runCtx)
require.NoError(t, err)
require.Contains(t, out.Instruction, "# 自动记忆")
- require.Contains(t, out.Instruction, "你的 MEMORY.md 当前为空")
+ require.NotContains(t, out.Instruction, "你的 MEMORY.md 当前为空")
+ require.Len(t, out.AgentInput.Messages, 2)
+ requireMemoryIndexMessage(t, out.AgentInput.Messages[0], "记忆索引是每个记忆存储的高层目录", "索引文件当前为空")
+ require.Contains(t, out.AgentInput.Messages[1].Content, "hi")
+}
+
+func TestMiddleware_IndexInjection_CustomInstructionKeepsStoreManifest(t *testing.T) {
+ ctx := context.Background()
+ b := NewInMemoryBackend()
+ custom := "custom memory header"
+
+ mw, err := New(ctx, &Config[*schema.Message]{
+ MemoryStores: []MemoryStore{
+ {Path: "/mem", Name: "profile", Description: "User profile."},
+ },
+ MemoryBackend: b,
+ GenInstruction: func(ctx context.Context) (string, error) {
+ return custom, nil
+ },
+ })
+ require.NoError(t, err)
+
+ runCtx := &adk.ChatModelAgentContext[*schema.Message]{
+ Instruction: "base",
+ AgentInput: &adk.AgentInput{Messages: []adk.Message{schema.UserMessage("hi")}},
+ }
+
+ _, out, err := mw.BeforeAgent(ctx, runCtx)
+ require.NoError(t, err)
+ require.Contains(t, out.Instruction, "custom memory header")
+ require.Contains(t, out.Instruction, "## Memory stores")
+ require.Contains(t, out.Instruction, "1. Name: profile")
+ require.Contains(t, out.Instruction, "Path: /mem")
+ require.Contains(t, out.Instruction, "Description: User profile.")
+ require.NotContains(t, out.Instruction, "Index file path")
+ require.Len(t, out.AgentInput.Messages, 2)
+ requireMemoryIndexMessage(t, out.AgentInput.Messages[0], "Index Memory File Path: /mem/MEMORY.md")
+ require.Contains(t, out.AgentInput.Messages[1].Content, "hi")
+}
+
+func TestMiddleware_IndexInjection_CustomInstructionErrorReportsRenderStage(t *testing.T) {
+ ctx := context.Background()
+ b := NewInMemoryBackend()
+ var stages []ErrorStage
+
+ mw, err := New(ctx, &Config[*schema.Message]{
+ MemoryStores: []MemoryStore{{Path: "/mem"}},
+ MemoryBackend: b,
+ GenInstruction: func(ctx context.Context) (string, error) {
+ return "", fmt.Errorf("custom instruction failed")
+ },
+ OnError: func(ctx context.Context, stage ErrorStage, err error) {
+ stages = append(stages, stage)
+ },
+ })
+ require.NoError(t, err)
+
+ runCtx := &adk.ChatModelAgentContext[*schema.Message]{
+ Instruction: "base",
+ AgentInput: &adk.AgentInput{Messages: []adk.Message{schema.UserMessage("hi")}},
+ }
+
+ _, out, err := mw.BeforeAgent(ctx, runCtx)
+ require.NoError(t, err)
+ require.Equal(t, "base", out.Instruction)
+ require.Equal(t, []ErrorStage{OnErrorStageRenderInstruction}, stages)
}
func TestNew_DoesNotMutateConfig(t *testing.T) {
@@ -106,9 +255,9 @@ func TestNew_DoesNotMutateConfig(t *testing.T) {
b := NewInMemoryBackend()
cfgNilNested := &Config[*schema.Message]{
- MemoryDirectory: "/mem",
- MemoryBackend: b,
- Model: &fixedModel{out: `{"selected_memories":["debugging.md"]}`},
+ MemoryStores: []MemoryStore{{Path: "/mem"}},
+ MemoryBackend: b,
+ Model: &fixedModel{out: `{"selected_memories":["mem/debugging.md"]}`},
}
_, err := New(ctx, cfgNilNested)
require.NoError(t, err)
@@ -117,12 +266,12 @@ func TestNew_DoesNotMutateConfig(t *testing.T) {
require.Nil(t, cfgNilNested.Coordination)
cfgExplicitNested := &Config[*schema.Message]{
- MemoryDirectory: "/mem",
- MemoryBackend: b,
- Model: &fixedModel{out: `{"selected_memories":["debugging.md"]}`},
- Read: &ReadConfig[*schema.Message]{},
- Write: &WriteConfig[*schema.Message]{},
- Coordination: &CoordinationConfig[*schema.Message]{},
+ MemoryStores: []MemoryStore{{Path: "/mem"}},
+ MemoryBackend: b,
+ Model: &fixedModel{out: `{"selected_memories":["mem/debugging.md"]}`},
+ Read: &ReadConfig[*schema.Message]{},
+ Write: &WriteConfig[*schema.Message]{},
+ Coordination: &CoordinationConfig[*schema.Message]{},
}
_, err = New(ctx, cfgExplicitNested)
require.NoError(t, err)
@@ -147,9 +296,9 @@ func TestMiddleware_TopicSelection_InsertsMemoryMessage(t *testing.T) {
b.put("/mem/other.md", "---\nname: Other\ndescription: unrelated\ntype: misc\n---\n", now.Add(-time.Hour))
mw, err := New(ctx, &Config[*schema.Message]{
- MemoryDirectory: "/mem",
- MemoryBackend: b,
- Model: &fixedModel{out: `{"selected_memories":["debugging.md"]}`},
+ MemoryStores: []MemoryStore{{Path: "/mem"}},
+ MemoryBackend: b,
+ Model: &fixedModel{out: `{"selected_memories":["mem/debugging.md"]}`},
})
require.NoError(t, err)
@@ -162,69 +311,69 @@ func TestMiddleware_TopicSelection_InsertsMemoryMessage(t *testing.T) {
_, out, err := mw.BeforeAgent(ctx, runCtx)
require.NoError(t, err)
require.NotNil(t, out.AgentInput)
- require.Len(t, out.AgentInput.Messages, 2)
- require.Equal(t, schema.User, out.AgentInput.Messages[0].Role)
- require.Contains(t, out.AgentInput.Messages[0].Content, "How to run tests?")
- require.Contains(t, out.AgentInput.Messages[1].Content, "")
- require.NotNil(t, out.AgentInput.Messages[1].Extra)
- require.NotNil(t, out.AgentInput.Messages[1].Extra["__eino_automemory__"])
- require.Contains(t, out.AgentInput.Messages[1].Content, "Contents of /mem/debugging.md")
+ require.Len(t, out.AgentInput.Messages, 3)
+ requireMemoryIndexMessage(t, out.AgentInput.Messages[0], "Index Memory File Path: /mem/MEMORY.md")
+ requireTopicMemoryMessage(t, out.AgentInput.Messages[1], "Memory Store Name: mem", "Topic Memory File Path: /mem/debugging.md")
+ require.Equal(t, schema.User, out.AgentInput.Messages[2].Role)
+ require.Contains(t, out.AgentInput.Messages[2].Content, "How to run tests?")
}
-func TestMiddleware_BeforeAgent_MessageInsertedEventPersistsToSessionStore(t *testing.T) {
+func TestMiddleware_MultipleMemoryStores_IndexAndTopicSelection(t *testing.T) {
ctx := context.Background()
b := NewInMemoryBackend()
now := time.Now()
- b.put("/mem/MEMORY.md", "- [debugging.md](debugging.md) - notes\n", now)
- b.put("/mem/debugging.md", "---\nname: Debugging\ndescription: build and test commands\ntype: project\n---\n\n# Debugging\npnpm test\n", now)
+ b.put("/user/MEMORY.md", "- [prefs.md](prefs.md) - user preferences\n", now)
+ b.put("/user/prefs.md", "---\ndescription: editor preferences\n---\n\nUse concise answers.\n", now)
+ b.put("/project/MEMORY.md", "- [debugging.md](debugging.md) - project debugging\n", now)
+ b.put("/project/debugging.md", "---\ndescription: test commands\n---\n\nRun go test ./...\n", now)
mw, err := New(ctx, &Config[*schema.Message]{
- MemoryDirectory: "/mem",
- MemoryBackend: b,
- Model: &fixedModel{out: "ok"},
+ MemoryStores: []MemoryStore{
+ {Path: "/user", Name: "user_profile", Description: "User preferences."},
+ {Path: "/project", Name: "project_context", Description: "Project conventions."},
+ },
+ MemoryBackend: b,
+ Model: &fixedModel{out: `{"selected_memories":["project_context/debugging.md"]}`},
+ Read: &ReadConfig[*schema.Message]{
+ Index: &IndexConfig{EnableMemoryIndex: boolPtr(true)},
+ },
})
require.NoError(t, err)
- agent, err := adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{
- Name: "automemory-session-event-agent",
+ runCtx := &adk.ChatModelAgentContext[*schema.Message]{
Instruction: "base",
- Model: &fixedModel{out: "ok"},
- Handlers: []adk.ChatModelAgentMiddleware{mw},
- })
- require.NoError(t, err)
-
- const sessionID = "automemory-message-inserted-session"
- store := adksession.NewInMemoryStore[*schema.Message](nil)
- runner := adk.NewRunner(ctx, adk.RunnerConfig{
- Agent: agent,
- SessionID: sessionID,
- SessionStore: store,
- })
-
- iter := runner.Query(ctx, "How to run tests?")
- for {
- event, ok := iter.Next()
- if !ok {
- break
- }
- require.NoError(t, event.Err)
+ AgentInput: &adk.AgentInput{Messages: []adk.Message{schema.UserMessage("How should I run tests?")}},
}
- loaded, err := store.LoadEvents(ctx, &adk.LoadSessionEventsRequest{
- SessionID: sessionID,
- Kinds: []adk.SessionEventKind{adk.SessionEventMessageInserted},
- })
+ _, out, err := mw.BeforeAgent(ctx, runCtx)
require.NoError(t, err)
- require.Len(t, loaded.Events, 1, "AutoMemory BeforeAgent MessageInserted event should be persisted in SessionStore")
-
- inserted := loaded.Events[0].MessageInserted
- require.NotNil(t, inserted)
- require.NotEmpty(t, inserted.BeforeMessageID)
- require.NotNil(t, inserted.Message)
- require.Contains(t, inserted.Message.Content, "")
- require.Contains(t, inserted.Message.Content, "Contents of /mem/debugging.md")
- require.NotNil(t, inserted.Message.Extra[memoryExtraKey])
+ require.Contains(t, out.Instruction, "## Memory stores")
+ require.Contains(t, out.Instruction, "1. Name: user_profile")
+ require.Contains(t, out.Instruction, "Path: /user")
+ require.Contains(t, out.Instruction, "Description: User preferences.")
+ require.Contains(t, out.Instruction, "2. Name: project_context")
+ require.Contains(t, out.Instruction, "Path: /project")
+ require.NotContains(t, out.Instruction, "Index file path: /user/MEMORY.md")
+ require.NotContains(t, out.Instruction, "Index file path: /project/MEMORY.md")
+ require.NotContains(t, out.Instruction, "#### Index file content: MEMORY.md")
+ require.Len(t, out.AgentInput.Messages, 3)
+ requireMemoryIndexMessage(t, out.AgentInput.Messages[0],
+ "Index Memory File Path: /user/MEMORY.md",
+ "Index Memory File Path: /project/MEMORY.md",
+ "- [prefs.md](prefs.md) - user preferences",
+ "- [debugging.md](debugging.md) - project debugging",
+ )
+ indexReminder := out.AgentInput.Messages[0].Content
+ userStorePos := strings.Index(indexReminder, "")
+ userIndexPos := strings.Index(indexReminder, "- [prefs.md](prefs.md) - user preferences")
+ projectStorePos := strings.Index(indexReminder, "")
+ projectIndexPos := strings.Index(indexReminder, "- [debugging.md](debugging.md) - project debugging")
+ require.True(t, userStorePos >= 0 && userIndexPos > userStorePos && userIndexPos < projectStorePos)
+ require.True(t, projectStorePos >= 0 && projectIndexPos > projectStorePos)
+ requireTopicMemoryMessage(t, out.AgentInput.Messages[1], "Memory Store Name: project_context", "Topic Memory File Path: /project/debugging.md", "Run go test ./...")
+ require.NotContains(t, out.AgentInput.Messages[1].Content, "Use concise answers.")
+ require.Contains(t, out.AgentInput.Messages[2].Content, "How should I run tests?")
}
func TestMiddleware_TopicSelection_AsyncInjectsInBeforeModel(t *testing.T) {
@@ -236,10 +385,10 @@ func TestMiddleware_TopicSelection_AsyncInjectsInBeforeModel(t *testing.T) {
b.put("/mem/debugging.md", "---\nname: Debugging\ndescription: build and test commands\ntype: project\n---\n\n# Debugging\npnpm test\n", now)
mw, err := New(ctx, &Config[*schema.Message]{
- MemoryDirectory: "/mem",
- MemoryBackend: b,
- Model: &fixedModel{out: `{"selected_memories":["debugging.md"]}`},
- Read: &ReadConfig[*schema.Message]{Mode: ReadModeAsync},
+ MemoryStores: []MemoryStore{{Path: "/mem"}},
+ MemoryBackend: b,
+ Model: &fixedModel{out: `{"selected_memories":["mem/debugging.md"]}`},
+ Read: &ReadConfig[*schema.Message]{Mode: ReadModeAsync},
})
require.NoError(t, err)
@@ -249,7 +398,9 @@ func TestMiddleware_TopicSelection_AsyncInjectsInBeforeModel(t *testing.T) {
}
ctx2, out, err := mw.BeforeAgent(ctx, runCtx)
require.NoError(t, err)
- require.Len(t, out.AgentInput.Messages, 1) // async doesn't inject here
+ require.Len(t, out.AgentInput.Messages, 2) // async doesn't inject topic memory here
+ requireMemoryIndexMessage(t, out.AgentInput.Messages[0], "Index Memory File Path: /mem/MEMORY.md")
+ require.Contains(t, out.AgentInput.Messages[1].Content, "How to run tests?")
st := &adk.ChatModelAgentState{Messages: []adk.Message{schema.UserMessage("How to run tests?")}}
@@ -288,7 +439,7 @@ func (m *toolCallSelectionModel) Generate(_ context.Context, _ []*schema.Message
Type: "function",
Function: schema.FunctionCall{
Name: topicSelectionToolName,
- Arguments: `{"selected_memories":["debugging.md","hallucinated.md"]}`,
+ Arguments: `{"selected_memories":["mem/debugging.md","hallucinated.md"]}`,
},
},
}), nil
@@ -310,6 +461,8 @@ type extractionModel struct {
mu sync.Mutex
promptSeen []string
boundToolCalls [][]string
+ topicPath string
+ indexPath string
blockFirstRun chan struct{}
firstRunStarted chan struct{}
blockedOnce uint32 // atomic (0/1)
@@ -387,13 +540,21 @@ func (m *extractionModel) Generate(_ context.Context, input []*schema.Message, _
}
payload := lastBusinessUserBeforePrompt(input, promptIdx)
+ topicPath := m.topicPath
+ if topicPath == "" {
+ topicPath = "topic.md"
+ }
+ indexPath := m.indexPath
+ if indexPath == "" {
+ indexPath = "MEMORY.md"
+ }
return schema.AssistantMessage("", []schema.ToolCall{
{
ID: "write-topic",
Type: "function",
Function: schema.FunctionCall{
Name: "write_file",
- Arguments: fmt.Sprintf(`{"file_path":"topic.md","content":%q}`, payload),
+ Arguments: fmt.Sprintf(`{"file_path":%q,"content":%q}`, topicPath, payload),
},
},
{
@@ -401,7 +562,7 @@ func (m *extractionModel) Generate(_ context.Context, input []*schema.Message, _
Type: "function",
Function: schema.FunctionCall{
Name: "write_file",
- Arguments: `{"file_path":"MEMORY.md","content":"- [topic.md](topic.md)\n"}`,
+ Arguments: fmt.Sprintf(`{"file_path":%q,"content":"- [topic.md](topic.md)\n"}`, indexPath),
},
},
}), nil
@@ -431,7 +592,8 @@ func (m *extractionModel) WithTools(tools []*schema.ToolInfo) (model.ToolCalling
func findExtractionPromptIndex(input []*schema.Message) int {
for i := len(input) - 1; i >= 0; i-- {
- if input[i] != nil && input[i].Role == schema.User && strings.Contains(input[i].Content, "memory extraction subagent") {
+ if input[i] != nil && input[i].Role == schema.User &&
+ (strings.Contains(input[i].Content, "memory extraction subagent") || strings.Contains(input[i].Content, "记忆提取子智能体")) {
return i
}
}
@@ -464,7 +626,7 @@ func lastBusinessUserBeforePrompt(input []*schema.Message, promptIdx int) string
return "unknown"
}
-func TestMiddleware_TopicSelection_SmallCandidateSetBypassesModel(t *testing.T) {
+func TestMiddleware_TopicSelection_SmallCandidateSetUsesModel(t *testing.T) {
ctx := context.Background()
b := NewInMemoryBackend()
now := time.Now()
@@ -472,11 +634,12 @@ func TestMiddleware_TopicSelection_SmallCandidateSetBypassesModel(t *testing.T)
b.put("/mem/MEMORY.md", "- [debugging.md](debugging.md)\n- [patterns.md](patterns.md)\n", now)
b.put("/mem/debugging.md", "---\ndescription: debug notes\n---\nbody\n", now)
b.put("/mem/patterns.md", "---\ndescription: patterns\n---\nbody\n", now)
+ model := &toolCallSelectionModel{}
mw, err := New(ctx, &Config[*schema.Message]{
- MemoryDirectory: "/mem",
- MemoryBackend: b,
- Model: &panicModel{},
+ MemoryStores: []MemoryStore{{Path: "/mem"}},
+ MemoryBackend: b,
+ Model: model,
Read: &ReadConfig[*schema.Message]{
Mode: ReadModeSync,
TopicSelection: &TopicSelectionConfig{
@@ -493,9 +656,12 @@ func TestMiddleware_TopicSelection_SmallCandidateSetBypassesModel(t *testing.T)
_, out, err := mw.BeforeAgent(ctx, runCtx)
require.NoError(t, err)
- require.Len(t, out.AgentInput.Messages, 2)
- require.Contains(t, out.AgentInput.Messages[1].Content, "debugging.md")
- require.Contains(t, out.AgentInput.Messages[1].Content, "patterns.md")
+ require.Equal(t, int32(1), atomic.LoadInt32(&model.calls))
+ require.Len(t, out.AgentInput.Messages, 3)
+ requireMemoryIndexMessage(t, out.AgentInput.Messages[0], "Index Memory File Path: /mem/MEMORY.md")
+ requireTopicMemoryMessage(t, out.AgentInput.Messages[1], "debugging.md")
+ require.NotContains(t, out.AgentInput.Messages[1].Content, "patterns.md")
+ require.Contains(t, out.AgentInput.Messages[2].Content, "How to run tests?")
}
func TestMiddleware_AfterAgent_SyncExtractionWritesMemoryFiles(t *testing.T) {
@@ -507,8 +673,8 @@ func TestMiddleware_AfterAgent_SyncExtractionWritesMemoryFiles(t *testing.T) {
extModel := &extractionModel{}
var onErrStages []ErrorStage
mw, err := New(ctx, &Config[*schema.Message]{
- MemoryDirectory: "/mem",
- MemoryBackend: b,
+ MemoryStores: []MemoryStore{{Path: "/mem"}},
+ MemoryBackend: b,
Write: &WriteConfig[*schema.Message]{
Mode: WriteModeSync,
Model: extModel,
@@ -557,7 +723,93 @@ func TestMiddleware_AfterAgent_SyncExtractionWritesMemoryFiles(t *testing.T) {
defer extModel.mu.Unlock()
require.NotEmpty(t, extModel.promptSeen)
require.Contains(t, extModel.promptSeen[0], "memory extraction subagent")
- require.Contains(t, extModel.promptSeen[0], "Memory directory: /mem")
+ require.Contains(t, extModel.promptSeen[0], "## Memory stores")
+ require.Contains(t, extModel.promptSeen[0], "Path: /mem")
+}
+
+func TestMiddleware_AfterAgent_SyncExtractionWritesNonPrimaryMemoryStore(t *testing.T) {
+ ctx := context.Background()
+ b := &countingBackend{InMemoryBackend: NewInMemoryBackend()}
+ now := time.Now()
+ b.put("/user/MEMORY.md", "", now)
+ b.put("/project/MEMORY.md", "", now)
+
+ extModel := &extractionModel{
+ topicPath: "project/topic.md",
+ indexPath: "project/MEMORY.md",
+ }
+ var onErrStages []ErrorStage
+ mw, err := New(ctx, &Config[*schema.Message]{
+ MemoryStores: []MemoryStore{
+ {Path: "/user", Name: "user"},
+ {Path: "/project", Name: "project"},
+ },
+ MemoryBackend: b,
+ Write: &WriteConfig[*schema.Message]{
+ Mode: WriteModeSync,
+ Model: extModel,
+ },
+ OnError: func(ctx context.Context, stage ErrorStage, err error) {
+ onErrStages = append(onErrStages, stage)
+ },
+ })
+ require.NoError(t, err)
+
+ state := &adk.ChatModelAgentState{
+ Messages: []adk.Message{
+ schema.UserMessage("remember project convention"),
+ schema.AssistantMessage("ack", nil),
+ },
+ }
+
+ _, err = mw.AfterAgent(ctx, &adk.TypedChatModelAgentState[*schema.Message]{
+ Messages: state.Messages,
+ })
+ require.NoError(t, err)
+ require.Empty(t, onErrStages)
+
+ topic, err := b.Read(ctx, &ReadRequest{FilePath: "/project/topic.md"})
+ require.NoError(t, err)
+ require.Equal(t, "remember project convention", topic.Content)
+
+ _, err = b.Read(ctx, &ReadRequest{FilePath: "/user/topic.md"})
+ require.Error(t, err)
+
+ b.mu.Lock()
+ paths := append([]string(nil), b.paths...)
+ b.mu.Unlock()
+ require.Contains(t, paths, "/project/topic.md")
+ require.Contains(t, paths, "/project/MEMORY.md")
+ require.NotContains(t, paths, "/user/topic.md")
+}
+
+func TestMultiStoreBackend_RoutesStoresWithSharedRoot(t *testing.T) {
+ ctx := context.Background()
+ b := NewInMemoryBackend()
+
+ stores, err := buildRuntimeMemoryStores(&Config[*schema.Message]{
+ MemoryStores: []MemoryStore{
+ {Path: "/mnt/mem/a", Name: "a"},
+ {Path: "/mnt/mem/b", Name: "b"},
+ },
+ MemoryBackend: b,
+ })
+ require.NoError(t, err)
+
+ fs := newMultiStoreBackend(stores)
+ require.NoError(t, fs.Write(ctx, &WriteRequest{FilePath: "/mnt/mem/b/topic.md", Content: "from absolute"}))
+ require.NoError(t, fs.Write(ctx, &WriteRequest{FilePath: "a/topic.md", Content: "from qualified"}))
+
+ gotB, err := b.Read(ctx, &ReadRequest{FilePath: "/mnt/mem/b/topic.md"})
+ require.NoError(t, err)
+ require.Equal(t, "from absolute", gotB.Content)
+
+ gotA, err := b.Read(ctx, &ReadRequest{FilePath: "/mnt/mem/a/topic.md"})
+ require.NoError(t, err)
+ require.Equal(t, "from qualified", gotA.Content)
+
+ _, err = b.Read(ctx, &ReadRequest{FilePath: "/mnt/mem/topic.md"})
+ require.Error(t, err)
}
func TestMiddleware_AfterAgent_SyncExtraction_IteratorHandlerCanDrain(t *testing.T) {
@@ -569,8 +821,8 @@ func TestMiddleware_AfterAgent_SyncExtraction_IteratorHandlerCanDrain(t *testing
extModel := &extractionModel{}
var seen int32
mw, err := New(ctx, &Config[*schema.Message]{
- MemoryDirectory: "/mem",
- MemoryBackend: b,
+ MemoryStores: []MemoryStore{{Path: "/mem"}},
+ MemoryBackend: b,
Write: &WriteConfig[*schema.Message]{
Mode: WriteModeSync,
Model: extModel,
@@ -622,8 +874,8 @@ func TestMiddleware_AfterAgent_SkipsExtractionWhenMainAgentAlreadyWroteMemory(t
extModel := &extractionModel{}
mw, err := New(ctx, &Config[*schema.Message]{
- MemoryDirectory: "/mem",
- MemoryBackend: b,
+ MemoryStores: []MemoryStore{{Path: "/mem"}},
+ MemoryBackend: b,
Write: &WriteConfig[*schema.Message]{
Mode: WriteModeSync,
Model: extModel,
@@ -678,8 +930,8 @@ func TestMiddleware_AfterAgent_AsyncExtractionKeepsLatestPendingSnapshot(t *test
}
mw, err := New(ctx, &Config[*schema.Message]{
- MemoryDirectory: "/mem",
- MemoryBackend: b,
+ MemoryStores: []MemoryStore{{Path: "/mem"}},
+ MemoryBackend: b,
Write: &WriteConfig[*schema.Message]{
Mode: WriteModeAsync,
Model: extModel,
@@ -728,7 +980,7 @@ func TestMiddleware_AfterAgent_AsyncExtractionKeepsLatestPendingSnapshot(t *test
if readErr != nil || topic == nil || topic.Content != "remember two" {
return false
}
- cursor, ok, cursorErr := coord.Coordinator.GetCursor(ctx, "session-1")
+ cursor, ok, cursorErr := getCoordinatorCursor(ctx, coord.Coordinator, "/mem::session-1")
if cursorErr != nil || !ok {
return false
}
@@ -736,15 +988,20 @@ func TestMiddleware_AfterAgent_AsyncExtractionKeepsLatestPendingSnapshot(t *test
}, 2*time.Second, 10*time.Millisecond)
}
-func TestMiddleware_BeforeAgent_InstructionIdempotent_NoTopicMemory(t *testing.T) {
+func TestMiddleware_BeforeAgent_GenInstructionRendersAndIndexInjectedOnce(t *testing.T) {
ctx := context.Background()
b := NewInMemoryBackend()
now := time.Now()
b.put("/mem/MEMORY.md", "line1\nline2\n", now)
+ var instructionCalls int32
mw, err := New(ctx, &Config[*schema.Message]{
- MemoryDirectory: "/mem",
- MemoryBackend: b,
+ MemoryStores: []MemoryStore{{Path: "/mem"}},
+ MemoryBackend: b,
+ GenInstruction: func(ctx context.Context) (string, error) {
+ atomic.AddInt32(&instructionCalls, 1)
+ return "custom memory policy", nil
+ },
// No topic selection model.
})
require.NoError(t, err)
@@ -756,15 +1013,93 @@ func TestMiddleware_BeforeAgent_InstructionIdempotent_NoTopicMemory(t *testing.T
_, out1, err := mw.BeforeAgent(ctx, runCtx)
require.NoError(t, err)
- require.Contains(t, out1.Instruction, instructionMarker)
+ require.Contains(t, out1.Instruction, "custom memory policy")
+ require.EqualValues(t, 1, atomic.LoadInt32(&instructionCalls))
+ require.Equal(t, 1, countMemoryIndexMessages(out1.AgentInput.Messages))
- // Call again with the already-injected instruction; should not duplicate.
+ // Same turn with already-injected index reminder should not duplicate the reminder.
_, out2, err := mw.BeforeAgent(ctx, &adk.ChatModelAgentContext[*schema.Message]{
Instruction: out1.Instruction,
- AgentInput: &adk.AgentInput{Messages: []adk.Message{schema.UserMessage("hi again")}},
+ AgentInput: &adk.AgentInput{Messages: out1.AgentInput.Messages},
+ })
+ require.NoError(t, err)
+ require.Contains(t, out2.Instruction, "custom memory policy")
+ require.EqualValues(t, 2, atomic.LoadInt32(&instructionCalls))
+ require.Equal(t, 1, countMemoryIndexMessages(out2.AgentInput.Messages))
+
+ // A later business user message in the same session should not get another MEMORY.md reminder.
+ nextMessages := append([]*schema.Message{}, out2.AgentInput.Messages...)
+ nextMessages = append(nextMessages, schema.AssistantMessage("ack", nil), schema.UserMessage("next turn"))
+ _, out3, err := mw.BeforeAgent(ctx, &adk.ChatModelAgentContext[*schema.Message]{
+ Instruction: out2.Instruction,
+ AgentInput: &adk.AgentInput{Messages: nextMessages},
})
require.NoError(t, err)
- require.Equal(t, 1, strings.Count(out2.Instruction, instructionMarker))
+ require.Contains(t, out3.Instruction, "custom memory policy")
+ require.EqualValues(t, 3, atomic.LoadInt32(&instructionCalls))
+ require.Equal(t, 1, countMemoryIndexMessages(out3.AgentInput.Messages))
+}
+
+func TestMiddleware_BeforeAgent_TopicMemoryInjectedOncePerSession(t *testing.T) {
+ ctx := context.Background()
+ b := NewInMemoryBackend()
+ now := time.Now()
+ b.put("/mem/MEMORY.md", "- [debugging.md](debugging.md)\n", now)
+ b.put("/mem/debugging.md", "---\ndescription: debug notes\n---\nbody\n", now)
+
+ selModel := &toolCallSelectionModel{}
+ mw, err := New(ctx, &Config[*schema.Message]{
+ MemoryStores: []MemoryStore{{Path: "/mem"}},
+ MemoryBackend: b,
+ Model: selModel,
+ Read: &ReadConfig[*schema.Message]{
+ Mode: ReadModeSync,
+ TopicSelection: &TopicSelectionConfig{
+ TopK: 1,
+ },
+ },
+ })
+ require.NoError(t, err)
+
+ _, out1, err := mw.BeforeAgent(ctx, &adk.ChatModelAgentContext[*schema.Message]{
+ Instruction: "base",
+ AgentInput: &adk.AgentInput{Messages: []adk.Message{schema.UserMessage("How to debug?")}},
+ })
+ require.NoError(t, err)
+ require.EqualValues(t, 1, atomic.LoadInt32(&selModel.calls))
+ require.Equal(t, 1, countMemoryIndexMessages(out1.AgentInput.Messages))
+ require.Equal(t, 1, countTopicMemoryMessages(out1.AgentInput.Messages))
+
+ nextMessages := append([]*schema.Message{}, out1.AgentInput.Messages...)
+ nextMessages = append(nextMessages, schema.AssistantMessage("ack", nil), schema.UserMessage("How to debug again?"))
+ _, out2, err := mw.BeforeAgent(ctx, &adk.ChatModelAgentContext[*schema.Message]{
+ Instruction: out1.Instruction,
+ AgentInput: &adk.AgentInput{Messages: nextMessages},
+ })
+ require.NoError(t, err)
+ require.EqualValues(t, 1, atomic.LoadInt32(&selModel.calls))
+ require.Equal(t, 1, countMemoryIndexMessages(out2.AgentInput.Messages))
+ require.Equal(t, 1, countTopicMemoryMessages(out2.AgentInput.Messages))
+}
+
+func TestMiddleware_LastUserMessageSkipsSystemReminderPrefix(t *testing.T) {
+ ctx := context.Background()
+ b := NewInMemoryBackend()
+ mw, err := New(ctx, &Config[*schema.Message]{
+ MemoryStores: []MemoryStore{{Path: "/mem"}},
+ MemoryBackend: b,
+ Model: &fixedModel{out: `{"selected_memories":[]}`},
+ })
+ require.NoError(t, err)
+
+ last, ok := mw.(*middleware[*schema.Message]).lastUserMessage(&adk.AgentInput{
+ Messages: []adk.Message{
+ schema.UserMessage("real user query"),
+ schema.UserMessage("\nInjected by another middleware.\n"),
+ },
+ })
+ require.True(t, ok)
+ require.Equal(t, "real user query", last.Content)
}
func TestMiddleware_BeforeAgent_InjectsInstructionWhenMessagesAlreadyContainMemory(t *testing.T) {
@@ -772,12 +1107,12 @@ func TestMiddleware_BeforeAgent_InjectsInstructionWhenMessagesAlreadyContainMemo
b := NewInMemoryBackend()
mw, err := New(ctx, &Config[*schema.Message]{
- MemoryDirectory: "/mem",
- MemoryBackend: b,
+ MemoryStores: []MemoryStore{{Path: "/mem"}},
+ MemoryBackend: b,
})
require.NoError(t, err)
- memMsg := newMemoryMessage[*schema.Message]("\npreloaded")
+ memMsg := newMemoryMessage[*schema.Message]("\n\nTopic memories are long-term memory files selected as relevant to the current query.\n\n\nMemory Store: mem\nMemory Store Path: /mem\nTopic File Path: preloaded.md\nSaved: now\nTopic Memory Content:\n\npreloaded\n\n\n")
runCtx := &adk.ChatModelAgentContext[*schema.Message]{
Instruction: "base",
AgentInput: &adk.AgentInput{Messages: []adk.Message{schema.UserMessage("hi"), memMsg}},
@@ -785,8 +1120,11 @@ func TestMiddleware_BeforeAgent_InjectsInstructionWhenMessagesAlreadyContainMemo
_, out, err := mw.BeforeAgent(ctx, runCtx)
require.NoError(t, err)
- require.Contains(t, out.Instruction, instructionMarker)
- require.Len(t, out.AgentInput.Messages, 2)
+ require.Contains(t, out.Instruction, "# Auto memory")
+ require.Len(t, out.AgentInput.Messages, 3)
+ requireMemoryIndexMessage(t, out.AgentInput.Messages[0], "Index Memory File Path: /mem/MEMORY.md")
+ require.Contains(t, out.AgentInput.Messages[1].Content, "hi")
+ requireTopicMemoryMessage(t, out.AgentInput.Messages[2], "preloaded")
}
func TestMiddleware_BeforeAgent_DistributedCursorSyncIntoMessageExtra(t *testing.T) {
@@ -799,12 +1137,12 @@ func TestMiddleware_BeforeAgent_DistributedCursorSyncIntoMessageExtra(t *testing
Coordinator: NewLocalCoordinator(),
LockTTL: time.Minute,
}
- require.NoError(t, coord.Coordinator.SetCursor(ctx, "sess-cursor", 5))
+ require.NoError(t, setCoordinatorCursor(ctx, coord.Coordinator, "/mem::sess-cursor", 5))
mw, err := New(ctx, &Config[*schema.Message]{
- MemoryDirectory: "/mem",
- MemoryBackend: b,
- Coordination: coord,
+ MemoryStores: []MemoryStore{{Path: "/mem"}},
+ MemoryBackend: b,
+ Coordination: coord,
})
require.NoError(t, err)
@@ -818,12 +1156,7 @@ func TestMiddleware_BeforeAgent_DistributedCursorSyncIntoMessageExtra(t *testing
_, out, err := mw.BeforeAgent(ctx, runCtx)
require.NoError(t, err)
- last := out.AgentInput.Messages[len(out.AgentInput.Messages)-1]
- require.NotNil(t, last.Extra)
- meta, ok := last.Extra[memoryExtraKey].(*memoryExtra)
- require.True(t, ok)
- require.Equal(t, "write_cursor", meta.Type)
- require.EqualValues(t, 5, meta.Cursor)
+ requireWriteCursor(t, out.AgentInput.Messages, 5)
}
func TestMiddleware_BeforeAgent_WriteCursorDoesNotBlockInstructionInjection(t *testing.T) {
@@ -839,12 +1172,12 @@ func TestMiddleware_BeforeAgent_WriteCursorDoesNotBlockInstructionInjection(t *t
Coordinator: NewLocalCoordinator(),
LockTTL: time.Minute,
}
- require.NoError(t, coord.Coordinator.SetCursor(ctx, "sess-cursor", 5))
+ require.NoError(t, setCoordinatorCursor(ctx, coord.Coordinator, "/mem::sess-cursor", 5))
mw, err := New(ctx, &Config[*schema.Message]{
- MemoryDirectory: "/mem",
- MemoryBackend: b,
- Coordination: coord,
+ MemoryStores: []MemoryStore{{Path: "/mem"}},
+ MemoryBackend: b,
+ Coordination: coord,
})
require.NoError(t, err)
@@ -858,15 +1191,11 @@ func TestMiddleware_BeforeAgent_WriteCursorDoesNotBlockInstructionInjection(t *t
_, out, err := mw.BeforeAgent(ctx, runCtx)
require.NoError(t, err)
- require.Contains(t, out.Instruction, instructionMarker)
- require.Contains(t, out.Instruction, "remembered")
+ require.Contains(t, out.Instruction, "# Auto memory")
+ require.NotContains(t, out.Instruction, "remembered")
+ requireMemoryIndexMessage(t, out.AgentInput.Messages[1], "remembered")
- last := out.AgentInput.Messages[len(out.AgentInput.Messages)-1]
- require.NotNil(t, last.Extra)
- meta, ok := last.Extra[memoryExtraKey].(*memoryExtra)
- require.True(t, ok)
- require.Equal(t, "write_cursor", meta.Type)
- require.EqualValues(t, 5, meta.Cursor)
+ requireWriteCursor(t, out.AgentInput.Messages, 5)
}
func TestMiddleware_TopicSelection_ToolCallParsingAndFiltering(t *testing.T) {
@@ -879,9 +1208,9 @@ func TestMiddleware_TopicSelection_ToolCallParsingAndFiltering(t *testing.T) {
selModel := &toolCallSelectionModel{}
mw, err := New(ctx, &Config[*schema.Message]{
- MemoryDirectory: "/mem",
- MemoryBackend: b,
- Model: selModel,
+ MemoryStores: []MemoryStore{{Path: "/mem"}},
+ MemoryBackend: b,
+ Model: selModel,
Read: &ReadConfig[*schema.Message]{
Mode: ReadModeSync,
TopicSelection: &TopicSelectionConfig{
@@ -897,10 +1226,13 @@ func TestMiddleware_TopicSelection_ToolCallParsingAndFiltering(t *testing.T) {
}
_, out, err := mw.BeforeAgent(ctx, runCtx)
require.NoError(t, err)
- require.Len(t, out.AgentInput.Messages, 2)
+ require.Len(t, out.AgentInput.Messages, 3)
+ requireMemoryIndexMessage(t, out.AgentInput.Messages[0], "Index Memory File Path: /mem/MEMORY.md")
mem := out.AgentInput.Messages[1]
- require.Contains(t, mem.Content, "Contents of /mem/debugging.md")
+ require.Contains(t, mem.Content, "Memory Store Name: mem")
+ require.Contains(t, mem.Content, "Topic Memory File Path: /mem/debugging.md")
require.NotContains(t, mem.Content, "hallucinated.md")
+ require.Contains(t, out.AgentInput.Messages[2].Content, "How to debug?")
require.EqualValues(t, 1, atomic.LoadInt32(&selModel.calls))
}
@@ -912,10 +1244,10 @@ func TestMiddleware_TopicSelection_AsyncProtectsMemoryMessageFromMutation(t *tes
b.put("/mem/debugging.md", "---\ndescription: debug notes\n---\nbody\n", now)
mw, err := New(ctx, &Config[*schema.Message]{
- MemoryDirectory: "/mem",
- MemoryBackend: b,
- Model: &fixedModel{out: `{"selected_memories":["debugging.md"]}`},
- Read: &ReadConfig[*schema.Message]{Mode: ReadModeAsync},
+ MemoryStores: []MemoryStore{{Path: "/mem"}},
+ MemoryBackend: b,
+ Model: &fixedModel{out: `{"selected_memories":["mem/debugging.md"]}`},
+ Read: &ReadConfig[*schema.Message]{Mode: ReadModeAsync},
})
require.NoError(t, err)
@@ -955,8 +1287,8 @@ func TestMiddleware_AfterAgent_SyncExtraction_SkipIndexPrompt(t *testing.T) {
extModel := &extractionModel{}
mw, err := New(ctx, &Config[*schema.Message]{
- MemoryDirectory: "/mem",
- MemoryBackend: b,
+ MemoryStores: []MemoryStore{{Path: "/mem"}},
+ MemoryBackend: b,
Write: &WriteConfig[*schema.Message]{
Mode: WriteModeSync,
Model: extModel,
@@ -980,6 +1312,58 @@ func TestMiddleware_AfterAgent_SyncExtraction_SkipIndexPrompt(t *testing.T) {
require.NotContains(t, extModel.promptSeen[0], "Step 2")
}
+func TestMiddleware_IndexDisabled_HidesMemoryIndexPrompt(t *testing.T) {
+ ctx := context.Background()
+ b := NewInMemoryBackend()
+ now := time.Now()
+ b.put("/mem/MEMORY.md", "should not be injected\n", now)
+ b.put("/mem/topic.md", "existing topic\n", now)
+ enableIndex := false
+ extModel := &extractionModel{}
+
+ mw, err := New(ctx, &Config[*schema.Message]{
+ MemoryStores: []MemoryStore{{Path: "/mem"}},
+ MemoryBackend: b,
+ Read: &ReadConfig[*schema.Message]{
+ Index: &IndexConfig{EnableMemoryIndex: &enableIndex},
+ },
+ Write: &WriteConfig[*schema.Message]{
+ Mode: WriteModeSync,
+ Model: extModel,
+ },
+ })
+ require.NoError(t, err)
+
+ runCtx := &adk.ChatModelAgentContext[*schema.Message]{
+ Instruction: "base",
+ AgentInput: &adk.AgentInput{Messages: []adk.Message{schema.UserMessage("hi")}},
+ }
+ _, out, err := mw.BeforeAgent(ctx, runCtx)
+ require.NoError(t, err)
+ require.NotContains(t, out.Instruction, "MEMORY.md")
+ require.NotContains(t, out.Instruction, "should not be injected")
+ require.Contains(t, out.Instruction, "## Memory stores")
+ require.Contains(t, out.Instruction, "Path: /mem")
+ require.Len(t, out.AgentInput.Messages, 1)
+
+ state := &adk.ChatModelAgentState{
+ Messages: []adk.Message{
+ schema.UserMessage("remember delta"),
+ schema.AssistantMessage("ack", nil),
+ },
+ }
+ _, err = mw.AfterAgent(ctx, &adk.TypedChatModelAgentState[*schema.Message]{Messages: state.Messages})
+ require.NoError(t, err)
+
+ extModel.mu.Lock()
+ defer extModel.mu.Unlock()
+ require.NotEmpty(t, extModel.promptSeen)
+ require.NotContains(t, extModel.promptSeen[0], "MEMORY.md")
+ require.NotContains(t, extModel.promptSeen[0], "should not be injected")
+ require.Contains(t, extModel.promptSeen[0], "## Memory stores")
+ require.Contains(t, extModel.promptSeen[0], "Path: /mem")
+}
+
func TestMiddleware_AfterAgent_SyncExtraction_ChinesePrompt(t *testing.T) {
require.NoError(t, adk.SetLanguage(adk.LanguageChinese))
defer func() {
@@ -993,8 +1377,8 @@ func TestMiddleware_AfterAgent_SyncExtraction_ChinesePrompt(t *testing.T) {
extModel := &extractionModel{}
mw, err := New(ctx, &Config[*schema.Message]{
- MemoryDirectory: "/mem",
- MemoryBackend: b,
+ MemoryStores: []MemoryStore{{Path: "/mem"}},
+ MemoryBackend: b,
Write: &WriteConfig[*schema.Message]{
Mode: WriteModeSync,
Model: extModel,
@@ -1014,8 +1398,9 @@ func TestMiddleware_AfterAgent_SyncExtraction_ChinesePrompt(t *testing.T) {
extModel.mu.Lock()
defer extModel.mu.Unlock()
require.NotEmpty(t, extModel.promptSeen)
- require.Contains(t, extModel.promptSeen[0], "你现在扮演 memory extraction subagent")
- require.Contains(t, extModel.promptSeen[0], "记忆目录:/mem")
+ require.Contains(t, extModel.promptSeen[0], "你现在扮演记忆提取子智能体")
+ require.Contains(t, extModel.promptSeen[0], "## 记忆存储")
+ require.Contains(t, extModel.promptSeen[0], "存储路径:/mem")
}
func TestMiddleware_AfterAgent_RelativeMemoryDirRendersAbsolutePath(t *testing.T) {
@@ -1034,8 +1419,8 @@ func TestMiddleware_AfterAgent_RelativeMemoryDirRendersAbsolutePath(t *testing.T
extModel := &extractionModel{}
mw, err := New(ctx, &Config[*schema.Message]{
- MemoryDirectory: ".",
- MemoryBackend: NewLocalBackend(),
+ MemoryStores: []MemoryStore{{Path: "."}},
+ MemoryBackend: NewLocalBackend(),
Write: &WriteConfig[*schema.Message]{
Mode: WriteModeSync,
Model: extModel,
@@ -1054,7 +1439,7 @@ func TestMiddleware_AfterAgent_RelativeMemoryDirRendersAbsolutePath(t *testing.T
extModel.mu.Lock()
require.NotEmpty(t, extModel.promptSeen)
- require.Contains(t, extModel.promptSeen[0], "Memory directory: "+expectedDir)
+ require.Contains(t, extModel.promptSeen[0], "Path: "+expectedDir)
extModel.mu.Unlock()
raw, err := os.ReadFile(filepath.Join(expectedDir, "topic.md"))
@@ -1075,8 +1460,8 @@ func TestMiddleware_BeforeAgent_RelativeMemoryDirReadsResolvedDirectoryAfterCWDC
require.NoError(t, os.WriteFile(filepath.Join(tmp, "MEMORY.md"), []byte("persisted index\n"), 0o644))
mw, err := New(ctx, &Config[*schema.Message]{
- MemoryDirectory: ".",
- MemoryBackend: NewLocalBackend(),
+ MemoryStores: []MemoryStore{{Path: "."}},
+ MemoryBackend: NewLocalBackend(),
})
require.NoError(t, err)
@@ -1089,7 +1474,10 @@ func TestMiddleware_BeforeAgent_RelativeMemoryDirReadsResolvedDirectoryAfterCWDC
}
_, out, err := mw.BeforeAgent(ctx, runCtx)
require.NoError(t, err)
- require.Contains(t, out.Instruction, "persisted index")
+ require.NotContains(t, out.Instruction, "persisted index")
+ require.Len(t, out.AgentInput.Messages, 2)
+ requireMemoryIndexMessage(t, out.AgentInput.Messages[0], "persisted index")
+ require.Contains(t, out.AgentInput.Messages[1].Content, "hi")
}
func TestFSBackend_ReadMissingFileReturnsContentInsteadOfError(t *testing.T) {
@@ -1111,9 +1499,9 @@ func TestMiddleware_TopicSelection_IgnoresOutOfBoundsCandidatePaths(t *testing.T
backend := &outOfBoundsCandidateBackend{}
mw, err := New(ctx, &Config[*schema.Message]{
- MemoryDirectory: "/mem",
- MemoryBackend: backend,
- Model: &panicModel{},
+ MemoryStores: []MemoryStore{{Path: "/mem"}},
+ MemoryBackend: backend,
+ Model: &panicModel{},
})
require.NoError(t, err)
@@ -1123,7 +1511,9 @@ func TestMiddleware_TopicSelection_IgnoresOutOfBoundsCandidatePaths(t *testing.T
}
_, out, err := mw.BeforeAgent(ctx, runCtx)
require.NoError(t, err)
- require.Len(t, out.AgentInput.Messages, 1)
+ require.Len(t, out.AgentInput.Messages, 2)
+ requireMemoryIndexMessage(t, out.AgentInput.Messages[0], "Index Memory File Path: /mem/MEMORY.md")
+ require.Contains(t, out.AgentInput.Messages[1].Content, "show memories")
require.Equal(t, int32(0), atomic.LoadInt32(&backend.outsideReadCalled))
}
@@ -1142,13 +1532,14 @@ func TestMiddleware_AfterAgent_AsyncSetsPendingSnapshotWhenLockHeld(t *testing.T
LockTTL: time.Minute,
}
// Hold the lock.
- unlock, ok, err := coord.Coordinator.AcquireLock(ctx, "sess-pending", time.Minute)
+ coordKey := "/mem::sess-pending"
+ unlock, ok, err := coord.Coordinator.AcquireLock(ctx, coordKey, time.Minute)
require.NoError(t, err)
require.True(t, ok)
mwI, err := New(ctx, &Config[*schema.Message]{
- MemoryDirectory: "/mem",
- MemoryBackend: b,
+ MemoryStores: []MemoryStore{{Path: "/mem"}},
+ MemoryBackend: b,
Write: &WriteConfig[*schema.Message]{
Mode: WriteModeAsync,
Model: extModel,
@@ -1172,16 +1563,16 @@ func TestMiddleware_AfterAgent_AsyncSetsPendingSnapshotWhenLockHeld(t *testing.T
})
require.NoError(t, err)
- pending, err := coord.Coordinator.PopPendingSnapshot(ctx, "sess-pending")
+ pending, err := popCoordinatorPendingSnapshot(ctx, coord.Coordinator, coordKey)
require.NoError(t, err)
require.NotNil(t, pending)
// Release and drain manually to complete write synchronously in test.
require.NoError(t, unlock(ctx))
- unlock2, ok, err := coord.Coordinator.AcquireLock(ctx, "sess-pending", time.Minute)
+ unlock2, ok, err := coord.Coordinator.AcquireLock(ctx, coordKey, time.Minute)
require.NoError(t, err)
require.True(t, ok)
- mw.runExtractionDrain(ctx, "sess-pending", unlock2, pending)
+ mw.runExtractionDrain(ctx, coordKey, unlock2, pending)
topic, err := b.Read(ctx, &ReadRequest{FilePath: "/mem/topic.md"})
require.NoError(t, err)
diff --git a/adk/middlewares/automemory/consts.go b/adk/middlewares/automemory/consts.go
index 5b38e0e9f..f5ee3aa43 100644
--- a/adk/middlewares/automemory/consts.go
+++ b/adk/middlewares/automemory/consts.go
@@ -28,9 +28,10 @@ const (
defaultCandidateLimit = 200
defaultCandidatePreviewLine = 30
- defaultTopicTopK = 5
- defaultTopicMaxLines = 200
- defaultTopicMaxBytes = 4 * 1024
+ defaultTopicTopK = 5
+ defaultTopicMaxLines = 200
+ defaultTopicMaxBytes = 4 * 1024
+ defaultTopicMaxTotalBytes = 16 * 1024
defaultMemoryWriteMaxTurns = 5
diff --git a/adk/middlewares/automemory/coordinator.go b/adk/middlewares/automemory/coordinator.go
index c784c50c4..2a6e4ab44 100644
--- a/adk/middlewares/automemory/coordinator.go
+++ b/adk/middlewares/automemory/coordinator.go
@@ -31,40 +31,50 @@ import (
type SessionIDFunc[M adk.MessageType] func(ctx context.Context, state *adk.TypedChatModelAgentState[M]) (string, error)
// Coordinator abstracts distributed coordination for async memory extraction.
-// A Redis-backed implementation can map these methods to SETNX + TTL and plain KV get/set.
+// A Redis-backed implementation can map AcquireLock to SETNX + TTL, Set to SET,
+// Get to GET, and GetAndDelete to GETDEL.
type Coordinator interface {
- // AcquireLock tries to acquire a lock for a given session. When ok==true,
+ // AcquireLock tries to acquire a lock for key. When ok==true,
// it returns an unlock function that must be called exactly once.
- AcquireLock(ctx context.Context, sessionID string, ttl time.Duration) (unlock func(context.Context) error, ok bool, err error)
+ AcquireLock(ctx context.Context, key string, ttl time.Duration) (unlock func(context.Context) error, ok bool, err error)
- // PopPendingSnapshot returns and deletes the pending snapshot for a session.
- // If there is no pending snapshot, it returns (nil, nil).
- PopPendingSnapshot(ctx context.Context, sessionID string) (*PendingSnapshot, error)
- SetPendingSnapshot(ctx context.Context, sessionID string, snapshot *PendingSnapshot) error
+ // Get returns the value for key. When the key does not exist, ok is false.
+ Get(ctx context.Context, key string) (value []byte, ok bool, err error)
- GetCursor(ctx context.Context, sessionID string) (cursor int, ok bool, err error)
- SetCursor(ctx context.Context, sessionID string, cursor int) error
+ // Set stores value for key. ttl<=0 means no expiration.
+ Set(ctx context.Context, key string, value []byte, ttl time.Duration) error
+
+ // GetAndDelete returns the value for key and deletes it atomically.
+ // When the key does not exist, ok is false.
+ GetAndDelete(ctx context.Context, key string) (value []byte, ok bool, err error)
}
type PendingSnapshot struct {
- Cursor int `json:"cursor"`
- Messages json.RawMessage `json:"messages"`
- ToolInfos json.RawMessage `json:"tool_infos,omitempty"`
+ Cursor int `json:"cursor"`
+ Messages []byte `json:"messages"`
+ ToolInfos []byte `json:"tool_infos,omitempty"`
}
type CoordinationConfig[M adk.MessageType] struct {
+ // SessionIDFunc returns the logical session ID used to build the coordinator key.
+ // Optional. Defaults to an internal context-scoped session ID for write extraction.
SessionIDFunc SessionIDFunc[M]
- Coordinator Coordinator
- LockTTL time.Duration
+
+ // Coordinator stores cursor/pending state and coordinates async extraction locks.
+ // Optional. Defaults to NewLocalCoordinator().
+ Coordinator Coordinator
+
+ // LockTTL is the expiration duration for extraction locks and pending snapshots.
+ // Optional. Defaults to the package default lock TTL.
+ LockTTL time.Duration
}
// LocalCoordinator is the default in-process coordinator used in tests and single-instance deployments.
// For distributed deployments, provide a Coordinator backed by Redis or another shared KV.
type LocalCoordinator struct {
- mu sync.Mutex
- locks map[string]localLock
- pending map[string]*PendingSnapshot
- cursor map[string]int
+ mu sync.Mutex
+ locks map[string]localLock
+ kv map[string]localValue
}
type localLock struct {
@@ -72,87 +82,124 @@ type localLock struct {
expiry time.Time
}
+type localValue struct {
+ value []byte
+ expiry time.Time
+}
+
// NewLocalCoordinator returns the default in-process Coordinator implementation.
func NewLocalCoordinator() *LocalCoordinator {
return &LocalCoordinator{
- locks: map[string]localLock{},
- pending: map[string]*PendingSnapshot{},
- cursor: map[string]int{},
+ locks: map[string]localLock{},
+ kv: map[string]localValue{},
}
}
-func (c *LocalCoordinator) AcquireLock(_ context.Context, sessionID string, ttl time.Duration) (func(context.Context) error, bool, error) {
+func (c *LocalCoordinator) AcquireLock(_ context.Context, key string, ttl time.Duration) (func(context.Context) error, bool, error) {
c.mu.Lock()
defer c.mu.Unlock()
now := time.Now()
- if l, ok := c.locks[sessionID]; ok && now.Before(l.expiry) {
+ if l, ok := c.locks[key]; ok && now.Before(l.expiry) {
return nil, false, nil
}
token := randToken()
- c.locks[sessionID] = localLock{token: token, expiry: now.Add(ttl)}
+ c.locks[key] = localLock{token: token, expiry: now.Add(ttl)}
return func(_ context.Context) error {
c.mu.Lock()
defer c.mu.Unlock()
- l, ok := c.locks[sessionID]
+ l, ok := c.locks[key]
if !ok {
return nil
}
if l.token != token {
return fmt.Errorf("lock token mismatch")
}
- delete(c.locks, sessionID)
+ delete(c.locks, key)
return nil
}, true, nil
}
-func (c *LocalCoordinator) PopPendingSnapshot(_ context.Context, sessionID string) (*PendingSnapshot, error) {
+func (c *LocalCoordinator) Get(_ context.Context, key string) ([]byte, bool, error) {
c.mu.Lock()
defer c.mu.Unlock()
- s, ok := c.pending[sessionID]
- if !ok || s == nil {
- return nil, nil
- }
- cp := *s
- if s.Messages != nil {
- cp.Messages = append([]byte(nil), s.Messages...)
+ v, ok := c.kv[key]
+ if !ok {
+ return nil, false, nil
}
- if s.ToolInfos != nil {
- cp.ToolInfos = append([]byte(nil), s.ToolInfos...)
+ if !v.expiry.IsZero() && time.Now().After(v.expiry) {
+ delete(c.kv, key)
+ return nil, false, nil
}
- delete(c.pending, sessionID)
- return &cp, nil
+ return append([]byte(nil), v.value...), true, nil
}
-func (c *LocalCoordinator) SetPendingSnapshot(_ context.Context, sessionID string, snapshot *PendingSnapshot) error {
+func (c *LocalCoordinator) Set(_ context.Context, key string, value []byte, ttl time.Duration) error {
c.mu.Lock()
defer c.mu.Unlock()
- if snapshot == nil {
- delete(c.pending, sessionID)
- return nil
- }
- cp := *snapshot
- if snapshot.Messages != nil {
- cp.Messages = append([]byte(nil), snapshot.Messages...)
+ var expiry time.Time
+ if ttl > 0 {
+ expiry = time.Now().Add(ttl)
}
- if snapshot.ToolInfos != nil {
- cp.ToolInfos = append([]byte(nil), snapshot.ToolInfos...)
- }
- c.pending[sessionID] = &cp
+ c.kv[key] = localValue{value: append([]byte(nil), value...), expiry: expiry}
return nil
}
-func (c *LocalCoordinator) GetCursor(_ context.Context, sessionID string) (int, bool, error) {
+func (c *LocalCoordinator) GetAndDelete(_ context.Context, key string) ([]byte, bool, error) {
c.mu.Lock()
defer c.mu.Unlock()
- v, ok := c.cursor[sessionID]
- return v, ok, nil
+ v, ok := c.kv[key]
+ if !ok {
+ return nil, false, nil
+ }
+ delete(c.kv, key)
+ if !v.expiry.IsZero() && time.Now().After(v.expiry) {
+ return nil, false, nil
+ }
+ return append([]byte(nil), v.value...), true, nil
}
-func (c *LocalCoordinator) SetCursor(_ context.Context, sessionID string, cursor int) error {
- c.mu.Lock()
- defer c.mu.Unlock()
- c.cursor[sessionID] = cursor
- return nil
+func coordinatorCursorKey(key string) string {
+ return key + "::cursor"
+}
+
+func coordinatorPendingSnapshotKey(key string) string {
+ return key + "::pending_snapshot"
+}
+
+func getCoordinatorCursor(ctx context.Context, c Coordinator, key string) (int, bool, error) {
+ raw, ok, err := c.Get(ctx, coordinatorCursorKey(key))
+ if err != nil || !ok {
+ return 0, ok, err
+ }
+ var cursor int
+ if _, err := fmt.Sscanf(string(raw), "%d", &cursor); err != nil {
+ return 0, false, err
+ }
+ return cursor, true, nil
+}
+
+func setCoordinatorCursor(ctx context.Context, c Coordinator, key string, cursor int) error {
+ return c.Set(ctx, coordinatorCursorKey(key), []byte(fmt.Sprintf("%d", cursor)), 0)
+}
+
+func popCoordinatorPendingSnapshot(ctx context.Context, c Coordinator, key string) (*PendingSnapshot, error) {
+ raw, ok, err := c.GetAndDelete(ctx, coordinatorPendingSnapshotKey(key))
+ if err != nil || !ok {
+ return nil, err
+ }
+ var snapshot PendingSnapshot
+ if err := json.Unmarshal(raw, &snapshot); err != nil {
+ return nil, err
+ }
+ return &snapshot, nil
+}
+
+func setCoordinatorPendingSnapshot(ctx context.Context, c Coordinator, key string, snapshot *PendingSnapshot, ttl time.Duration) error {
+ raw, err := json.Marshal(snapshot)
+ if err != nil {
+ return err
+ }
+ return c.Set(ctx, coordinatorPendingSnapshotKey(key), raw, ttl)
}
func randToken() string {
diff --git a/adk/middlewares/automemory/multistore_backend.go b/adk/middlewares/automemory/multistore_backend.go
new file mode 100644
index 000000000..2d9e1b7ef
--- /dev/null
+++ b/adk/middlewares/automemory/multistore_backend.go
@@ -0,0 +1,182 @@
+/*
+ * Copyright 2026 CloudWeGo Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package automemory
+
+import (
+ "context"
+ "fmt"
+ "path/filepath"
+ "strings"
+
+ adkfs "github.com/cloudwego/eino/adk/middlewares/filesystem"
+)
+
+type multiStoreBackend struct {
+ stores []runtimeMemoryStore
+}
+
+func newMultiStoreBackend(stores []runtimeMemoryStore) *multiStoreBackend {
+ cp := append([]runtimeMemoryStore{}, stores...)
+ return &multiStoreBackend{stores: cp}
+}
+
+func (b *multiStoreBackend) routeFilePath(p string) (runtimeMemoryStore, string, error) {
+ if p == "" {
+ return runtimeMemoryStore{}, "", fmt.Errorf("memory backend: empty path")
+ }
+ if filepath.IsAbs(p) {
+ var selected runtimeMemoryStore
+ ok := false
+ for _, store := range b.stores {
+ if !isPathWithinMemoryDir(store.Path, p) {
+ continue
+ }
+ if !ok || len(store.Path) > len(selected.Path) {
+ selected = store
+ ok = true
+ }
+ }
+ if !ok {
+ return runtimeMemoryStore{}, "", fmt.Errorf("memory backend: path out of bounds: %s", p)
+ }
+ return selected, p, nil
+ }
+
+ if store, rel, ok := b.routeStoreQualifiedPath(p); ok {
+ return store, rel, nil
+ }
+ if len(b.stores) == 1 {
+ return b.stores[0], p, nil
+ }
+ return runtimeMemoryStore{}, "", fmt.Errorf("memory backend: relative path is ambiguous across %d memory stores; use an absolute path or prefix it with the memory store name", len(b.stores))
+}
+
+func (b *multiStoreBackend) routeDirPath(p string) (runtimeMemoryStore, string, error) {
+ if p == "" {
+ if len(b.stores) == 1 {
+ return b.stores[0], b.stores[0].Path, nil
+ }
+ return runtimeMemoryStore{}, "", fmt.Errorf("memory backend: directory path is ambiguous across %d memory stores; use an absolute path or prefix it with the memory store name", len(b.stores))
+ }
+ return b.routeFilePath(p)
+}
+
+func (b *multiStoreBackend) routeStoreQualifiedPath(p string) (runtimeMemoryStore, string, bool) {
+ clean := filepath.ToSlash(filepath.Clean(p))
+ for _, store := range b.stores {
+ name := filepath.ToSlash(store.displayName())
+ if clean == name {
+ return store, ".", true
+ }
+ prefix := name + "/"
+ if strings.HasPrefix(clean, prefix) {
+ return store, strings.TrimPrefix(clean, prefix), true
+ }
+ }
+ return runtimeMemoryStore{}, "", false
+}
+
+func (b *multiStoreBackend) Read(ctx context.Context, req *adkfs.ReadRequest) (*adkfs.FileContent, error) {
+ if req == nil {
+ return nil, fmt.Errorf("read: invalid request")
+ }
+ store, filePath, err := b.routeFilePath(req.FilePath)
+ if err != nil {
+ return nil, err
+ }
+ n := *req
+ n.FilePath = filePath
+ return store.Backend.Read(ctx, &n)
+}
+
+func (b *multiStoreBackend) Write(ctx context.Context, req *adkfs.WriteRequest) error {
+ if req == nil {
+ return fmt.Errorf("write: invalid request")
+ }
+ store, filePath, err := b.routeFilePath(req.FilePath)
+ if err != nil {
+ return err
+ }
+ n := *req
+ n.FilePath = filePath
+ return store.Backend.Write(ctx, &n)
+}
+
+func (b *multiStoreBackend) Edit(ctx context.Context, req *adkfs.EditRequest) error {
+ if req == nil {
+ return fmt.Errorf("edit: invalid request")
+ }
+ store, filePath, err := b.routeFilePath(req.FilePath)
+ if err != nil {
+ return err
+ }
+ n := *req
+ n.FilePath = filePath
+ return store.Backend.Edit(ctx, &n)
+}
+
+func (b *multiStoreBackend) GlobInfo(ctx context.Context, req *adkfs.GlobInfoRequest) ([]adkfs.FileInfo, error) {
+ if req == nil || req.Pattern == "" {
+ return nil, fmt.Errorf("glob: invalid request")
+ }
+ if req.Path == "" && len(b.stores) > 1 {
+ var out []adkfs.FileInfo
+ for _, store := range b.stores {
+ n := *req
+ n.Path = store.Path
+ files, err := store.Backend.GlobInfo(ctx, &n)
+ if err != nil {
+ return nil, err
+ }
+ out = append(out, files...)
+ }
+ return out, nil
+ }
+ store, path, err := b.routeDirPath(req.Path)
+ if err != nil {
+ return nil, err
+ }
+ n := *req
+ n.Path = path
+ return store.Backend.GlobInfo(ctx, &n)
+}
+
+func (b *multiStoreBackend) LsInfo(ctx context.Context, req *adkfs.LsInfoRequest) ([]adkfs.FileInfo, error) {
+ if req == nil {
+ return nil, fmt.Errorf("ls: invalid request")
+ }
+ store, path, err := b.routeDirPath(req.Path)
+ if err != nil {
+ return nil, err
+ }
+ n := *req
+ n.Path = path
+ return store.Backend.LsInfo(ctx, &n)
+}
+
+func (b *multiStoreBackend) GrepRaw(ctx context.Context, req *adkfs.GrepRequest) ([]adkfs.GrepMatch, error) {
+ if req == nil {
+ return nil, fmt.Errorf("grep: invalid request")
+ }
+ store, path, err := b.routeDirPath(req.Path)
+ if err != nil {
+ return nil, err
+ }
+ n := *req
+ n.Path = path
+ return store.Backend.GrepRaw(ctx, &n)
+}
diff --git a/adk/middlewares/automemory/prompt.go b/adk/middlewares/automemory/prompt.go
index 2f5d9130c..ab55803db 100644
--- a/adk/middlewares/automemory/prompt.go
+++ b/adk/middlewares/automemory/prompt.go
@@ -18,22 +18,23 @@ package automemory
import (
"fmt"
+ "path/filepath"
"strings"
"github.com/cloudwego/eino/adk/internal"
)
const (
- defaultMemoryInstruction = `# auto memory
+ defaultMemoryInstructionWithIndex = `# Auto memory
-You have a persistent auto memory directory at "{memory_dir}". Its contents persist across conversations.
+You have access to persistent memory stores. Their contents persist across conversations.
As you work, consult your memory files to build on previous experience.
## How to save memories:
- Organize memory semantically by topic, not chronologically
- Use the Write and Edit tools to update your memory files
-- 'MEMORY.md' is always loaded into your conversation context — content is truncated after 200 lines or 4KB, so keep it concise
+- When a store has MEMORY.md enabled, it is loaded into your system prompt context — content is truncated after configured line and byte limits, so keep it concise
- Create separate topic files (e.g., 'debugging.md'', 'patterns.md'') for detailed notes and link to them from MEMORY.md
- Update or remove memories that turn out to be wrong or outdated
- Do not write duplicate memories. First check if there is an existing memory you can update before writing a new one.
@@ -56,7 +57,43 @@ As you work, consult your memory files to build on previous experience.
- When the user corrects you on something you stated from memory, you MUST update or remove the incorrect entry. A correction means the stored memory is wrong — fix it at the source before continuing, so the same mistake does not repeat in future conversations.
## Searching past context
-- Search topic files in your memory directory: Grep with pattern="" path="{memory_dir}" glob="*.md"
+- Search topic files inside the relevant memory store.
+- Use narrow search terms (error messages, file paths, function names) rather than broad keywords.
+
+`
+
+ defaultMemoryInstructionWithoutIndex = `# Auto memory
+
+You have access to persistent memory stores. Their contents persist across conversations.
+
+As you work, consult your memory files to build on previous experience.
+
+## How to save memories:
+- Organize memory semantically by topic, not chronologically
+- Use the Write and Edit tools to update your memory files
+- Create separate topic files (e.g., 'debugging.md'', 'patterns.md'') for detailed notes
+- Update or remove memories that turn out to be wrong or outdated
+- Do not write duplicate memories. First check if there is an existing memory you can update before writing a new one.
+
+## What to save:
+- Stable patterns and conventions confirmed across multiple interactions
+- Key architectural decisions, important file paths, and project structure
+- User preferences for workflow, tools, and communication style
+- Solutions to recurring problems and debugging insights
+
+## What NOT to save:
+- Session-specific context (current task details, in-progress work, temporary state)
+- Information that might be incomplete — verify against project docs before writing
+- Anything that duplicates or contradicts existing AGENTS.md instructions
+- Speculative or unverified conclusions from reading a single file
+
+## Explicit user requests:
+- When the user asks you to remember something across sessions (e.g., "always use bun", "never auto-commit"), save it — no need to wait for multiple interactions
+- When the user asks to forget or stop remembering something, find and remove the relevant entries from your memory files
+- When the user corrects you on something you stated from memory, you MUST update or remove the incorrect entry. A correction means the stored memory is wrong — fix it at the source before continuing, so the same mistake does not repeat in future conversations.
+
+## Searching past context
+- Search topic files inside the relevant memory store.
- Use narrow search terms (error messages, file paths, function names) rather than broad keywords.
`
@@ -65,15 +102,17 @@ As you work, consult your memory files to build on previous experience.
defaultAppendEmptyIndexTemplate = `Your MEMORY.md is currently empty. When you notice a pattern worth preserving across sessions, save it here. Anything in MEMORY.md will be included in your system prompt next time.`
- defaultTopicSelectionSystemPrompt = `You are selecting memories that will be useful to the agent as it processes a user's query. You will be given the user's query and a list of available memory files with their filenames and descriptions.
+ defaultTopicSelectionSystemPrompt = `You are selecting memories that will be useful to the agent as it processes a user's query. You will be given the user's query and a list of available memory files across one or more memory stores, with their displayed memory paths and descriptions.
-Return a list of RELATIVE FILE PATHS (relative to the memory directory) for the memories that will clearly be useful to the agent as it processes the user's query (up to 5). Only include memories that you are certain will be helpful based on their name/description/type.
+Return a list of memory paths exactly as shown in the available memories list, for the memories that will clearly be useful to the agent as it processes the user's query, up to the selection limit provided by the user message. Only include memories that you are certain will be helpful based on their store, name, description, or type.
- If you are unsure if a memory will be useful in processing the user's query, then do not include it in your list. Be selective and discerning.
- If there are no memories in the list that would clearly be useful, feel free to return an empty list.
- If a list of recently-used tools is provided, do not select memories that are usage reference or API documentation for those tools (the agent is already exercising them). DO still select memories containing warnings, gotchas, or known issues about those tools — active use is exactly when those matter.`
defaultTopicSelectionUserPrompt = `Query: {user_query}
+Selection limit: {top_k}
+
Available memories:
{available_memories}
@@ -83,16 +122,16 @@ Recently used tools:
defaultTopicMemoryTruncNotify = `
> This memory file was truncated ({reason}). Use the Read tool to view the complete file at: {abs_path}`
- defaultMemoryInstructionChinese = `# 自动记忆
+ defaultMemoryInstructionChineseWithIndex = `# 自动记忆
-你有一个持久化的自动记忆目录 "{memory_dir}"。其中的内容会在不同会话之间保留。
+你可以访问持久化的记忆存储。其中的内容会在不同会话之间保留。
在工作过程中,请查阅这些记忆文件,以便基于过去的经验继续推进。
## 如何保存记忆:
- 按主题组织记忆,而不是按时间顺序堆叠
- 使用 Write 和 Edit 工具更新你的记忆文件
-- 'MEMORY.md' 会始终被加载到对话上下文中,其内容在超过 200 行或 4KB 时会被截断,因此请保持简洁
+- 当某个记忆存储启用 MEMORY.md 时,它会被加载进系统提示词,其内容会按配置的行数和字节数限制截断,因此请保持简洁
- 将详细内容写入单独的主题文件(例如 'debugging.md'、'patterns.md'),并在 MEMORY.md 中链接它们
- 当某条记忆被证明错误或过时时,请更新或删除它
- 不要写入重复记忆。创建新记忆前,先检查是否已有可更新的现有文件
@@ -115,24 +154,62 @@ Recently used tools:
- 当用户指出你基于记忆给出的内容有误时,你必须更新或删除错误条目。纠正意味着原有记忆已经错误,必须先从源头修正,避免今后重复犯错
## 如何检索历史上下文
-- 在记忆目录中搜索主题文件:使用 Grep,pattern="<搜索词>" path="{memory_dir}" glob="*.md"
+- 在相关记忆存储中搜索主题文件
+- 尽量使用更窄的检索词,例如报错信息、文件路径、函数名,而不是宽泛关键词
+
+`
+
+ defaultMemoryInstructionChineseWithoutIndex = `# 自动记忆
+
+你可以访问持久化的记忆存储。其中的内容会在不同会话之间保留。
+
+在工作过程中,请查阅这些记忆文件,以便基于过去的经验继续推进。
+
+## 如何保存记忆:
+- 按主题组织记忆,而不是按时间顺序堆叠
+- 使用 Write 和 Edit 工具更新你的记忆文件
+- 将详细内容写入单独的主题文件(例如 'debugging.md'、'patterns.md')
+- 当某条记忆被证明错误或过时时,请更新或删除它
+- 不要写入重复记忆。创建新记忆前,先检查是否已有可更新的现有文件
+
+## 应该保存什么:
+- 已在多次交互中得到确认的稳定模式和约定
+- 关键架构决策、重要文件路径和项目结构
+- 用户在工作流、工具使用和沟通方式上的偏好
+- 可复用的问题解决经验与调试结论
+
+## 不应保存什么:
+- 仅属于当前会话的上下文(当前任务细节、进行中的工作、临时状态)
+- 可能不完整的信息,在写入前应先根据项目文档核实
+- 与现有 AGENTS.md 指令重复或冲突的内容
+- 仅基于阅读单个文件得到的猜测性或未经验证的结论
+
+## 用户的明确要求:
+- 当用户明确要求你跨会话记住某件事时(例如“始终使用 bun”“不要自动提交”),应立即保存,无需等待多轮交互确认
+- 当用户要求你遗忘某件事或停止记忆时,找到对应条目并从记忆文件中删除
+- 当用户指出你基于记忆给出的内容有误时,你必须更新或删除错误条目。纠正意味着原有记忆已经错误,必须先从源头修正,避免今后重复犯错
+
+## 如何检索历史上下文
+- 在相关记忆存储中搜索主题文件
- 尽量使用更窄的检索词,例如报错信息、文件路径、函数名,而不是宽泛关键词
`
defaultAppendCurrentIndexTruncNotifyChinese = `警告:MEMORY.md 已被截断(总行数:{memory_lines},限制:200 行;字节限制:4096)。请将详细内容迁移到独立的主题文件中,并让 MEMORY.md 只保留简洁索引。`
- defaultAppendEmptyIndexTemplateChinese = `你的 MEMORY.md 当前为空。当你发现值得跨会话保留的模式时,请把它写在这里。下一次对话中,MEMORY.md 的内容会被自动加入 system prompt。`
+ defaultAppendEmptyIndexTemplateChinese = `你的 MEMORY.md 当前为空。当你发现值得跨会话保留的模式时,请把它写在这里。下一次对话中,MEMORY.md 的内容会被自动加入系统提示词。`
- defaultTopicSelectionSystemPromptChinese = `你需要从记忆列表中选择对当前用户问题真正有帮助的记忆。你会拿到用户问题,以及一组可用记忆文件的文件名和描述。
+ defaultTopicSelectionSystemPromptChinese = `你需要从记忆列表中选择对当前用户问题真正有帮助的记忆。你会拿到用户问题,以及来自一个或多个记忆存储的可用记忆文件列表,列表中包含展示给你的记忆路径和描述。
-请返回一个 RELATIVE FILE PATHS 列表(相对于 memory directory),列出那些在处理当前用户问题时显然有帮助的记忆文件(最多 5 个)。只有在你能够基于名称、描述或类型确认其确实有帮助时才选择。
+请返回一个记忆路径列表,必须与可用记忆列表中展示的路径完全一致,列出那些在处理当前用户问题时显然有帮助的记忆文件,数量不能超过用户消息中给出的选择上限。只有在你能够基于存储、名称、描述或类型确认其确实有帮助时才选择。
- 如果你不能确定某条记忆是否有帮助,就不要选它。请保持克制和甄别。
- 如果列表中没有任何明显有帮助的记忆,可以返回空列表。
-- 如果提供了最近使用过的工具列表,不要选择那些仅包含这些工具使用说明或 API 文档的记忆(agent 已经在使用它们)。但如果记忆中包含这些工具的警告、坑点或已知问题,仍然应该选择,因为这些内容在实际调用时尤其重要。`
+- 如果提供了最近使用过的工具列表,不要选择那些仅包含这些工具使用说明或 API 文档的记忆(智能体已经在使用它们)。但如果记忆中包含这些工具的警告、坑点或已知问题,仍然应该选择,因为这些内容在实际调用时尤其重要。`
defaultTopicSelectionUserPromptChinese = `问题:{user_query}
+选择上限:{top_k}
+
可用记忆:
{available_memories}
@@ -143,13 +220,338 @@ Recently used tools:
> 该记忆文件已被截断({reason})。请使用 Read 工具查看完整文件:{abs_path}`
)
-func buildExtractAutoOnlyPrompt(memoryDir string, newMessageCount int, existingMemories string, skipIndex bool) string {
+type memoryStorePromptInfo struct {
+ Name string
+ Mount string
+ Description string
+ Index *memoryIndexPromptInfo
+}
+
+type memoryIndexPromptInfo struct {
+ FileName string
+ Path string
+ Content string
+ Empty bool
+ Truncated bool
+ Lines int
+ IncludeContent bool
+}
+
+type memoryManifestStorePromptInfo struct {
+ Name string
+ Mount string
+ Files []memoryManifestFilePromptInfo
+}
+
+type memoryManifestFilePromptInfo struct {
+ MemoryPath string
+ AbsPath string
+ Saved string
+ Description string
+}
+
+func buildSystemMemoryInstruction(baseInstruction, memoryInstruction string, stores []memoryStorePromptInfo) (string, error) {
+ return baseInstruction + "\n" + internal.SelectPrompt(internal.I18nPrompts{
+ English: buildSystemMemoryInstructionEnglish(memoryInstruction, stores),
+ Chinese: buildSystemMemoryInstructionChinese(memoryInstruction, stores),
+ }), nil
+}
+
+func buildSystemMemoryInstructionEnglish(memoryInstruction string, stores []memoryStorePromptInfo) string {
+ return strings.Join([]string{memoryInstruction, buildMemoryStoresManifestEnglish(stores)}, "\n")
+}
+
+func buildSystemMemoryInstructionChinese(memoryInstruction string, stores []memoryStorePromptInfo) string {
+ return strings.Join([]string{memoryInstruction, buildMemoryStoresManifestChinese(stores)}, "\n")
+}
+
+func buildMemoryStoresManifestEnglish(stores []memoryStorePromptInfo) string {
+ lines := []string{
+ "## Memory stores",
+ "",
+ "Available memory stores (each is a directory):",
+ "",
+ }
+ for i, store := range stores {
+ lines = append(lines,
+ fmt.Sprintf("### %d. Name: %s", i+1, store.Name),
+ fmt.Sprintf("Path: %s", store.Mount),
+ )
+ if strings.TrimSpace(store.Description) != "" {
+ lines = append(lines, fmt.Sprintf("Description: %s", strings.TrimSpace(store.Description)))
+ }
+ if store.Index != nil {
+ lines = append(lines, fmt.Sprintf("Index file path: %s", store.Index.Path), "")
+ if block := buildMemoryIndexBlockEnglish(*store.Index); block != "" {
+ lines = append(lines, block)
+ }
+ }
+ lines = append(lines, "")
+ }
+ return strings.Join(lines, "\n")
+}
+
+func buildMemoryStoresManifestChinese(stores []memoryStorePromptInfo) string {
+ lines := []string{
+ "## 记忆存储",
+ "",
+ "可用记忆存储 (每一条是一个目录):",
+ "",
+ }
+ for i, store := range stores {
+ lines = append(lines,
+ fmt.Sprintf("### %d. 名称: %s", i+1, store.Name),
+ fmt.Sprintf("存储路径:%s", store.Mount),
+ )
+ if strings.TrimSpace(store.Description) != "" {
+ lines = append(lines, fmt.Sprintf("功能描述:%s", strings.TrimSpace(store.Description)))
+ }
+ if store.Index != nil {
+ lines = append(lines, fmt.Sprintf("索引文件路径:%s", store.Index.Path), "")
+ if block := buildMemoryIndexBlockChinese(*store.Index); block != "" {
+ lines = append(lines, block)
+ }
+ }
+ lines = append(lines, "")
+ }
+ return strings.Join(lines, "\n")
+}
+
+func buildMemoryIndexBlockEnglish(index memoryIndexPromptInfo) string {
+ if !index.IncludeContent {
+ return ""
+ }
+ lines := []string{fmt.Sprintf("#### Index file content: %s", index.FileName)}
+ if index.Empty {
+ lines = append(lines, getAppendEmptyIndexTemplate())
+ } else {
+ lines = append(lines, index.Content)
+ if index.Truncated {
+ lines = append(lines, strings.ReplaceAll(getAppendCurrentIndexTruncNotify(), "{memory_lines}", fmt.Sprintf("%d", index.Lines)))
+ }
+ }
+ return strings.Join(lines, "\n")
+}
+
+func buildMemoryIndexBlockChinese(index memoryIndexPromptInfo) string {
+ if !index.IncludeContent {
+ return ""
+ }
+ lines := []string{fmt.Sprintf("#### 索引文件内容:%s", index.FileName)}
+ if index.Empty {
+ lines = append(lines, getAppendEmptyIndexTemplate())
+ } else {
+ lines = append(lines, index.Content)
+ if index.Truncated {
+ lines = append(lines, strings.ReplaceAll(getAppendCurrentIndexTruncNotify(), "{memory_lines}", fmt.Sprintf("%d", index.Lines)))
+ }
+ }
+ return strings.Join(lines, "\n")
+}
+
+func buildExtractAutoOnlyPrompt(memoryStores string, newMessageCount int, existingMemories string, enableMemoryIndex bool) string {
+ return internal.SelectPrompt(internal.I18nPrompts{
+ English: buildExtractAutoOnlyPromptEnglish(memoryStores, newMessageCount, existingMemories, enableMemoryIndex),
+ Chinese: buildExtractAutoOnlyPromptChinese(memoryStores, newMessageCount, existingMemories, enableMemoryIndex),
+ })
+}
+
+func buildMemoryStoresManifest(stores []memoryStorePromptInfo) string {
+ return internal.SelectPrompt(internal.I18nPrompts{
+ English: buildMemoryStoresManifestEnglish(stores),
+ Chinese: buildMemoryStoresManifestChinese(stores),
+ })
+}
+
+func buildMemoryIndexReminder(stores []memoryStorePromptInfo) string {
+ return "\n" + internal.SelectPrompt(internal.I18nPrompts{
+ English: buildMemoryIndexReminderEnglish(stores),
+ Chinese: buildMemoryIndexReminderChinese(stores),
+ })
+}
+
+func buildTopicMemoryReminder(topics []topicMemoryPromptInfo) string {
return internal.SelectPrompt(internal.I18nPrompts{
- English: buildExtractAutoOnlyPromptEnglish(memoryDir, newMessageCount, existingMemories, skipIndex),
- Chinese: buildExtractAutoOnlyPromptChinese(memoryDir, newMessageCount, existingMemories, skipIndex),
+ English: buildTopicMemoryReminderEnglish(topics),
+ Chinese: buildTopicMemoryReminderChinese(topics),
})
}
+func buildTopicMemoryReminderEnglish(topics []topicMemoryPromptInfo) string {
+ lines := []string{
+ "",
+ "Topic memories are long-term memory files selected as relevant to the current query. Use them as supporting context for this turn. They may contain durable user preferences, project conventions, or previously saved facts; do not treat them as a replacement for the current user request.",
+ "",
+ }
+ for i, topic := range topics {
+ lines = append(lines,
+ fmt.Sprintf("", i+1),
+ fmt.Sprintf("1. Memory Store Name: %s", topic.StoreName),
+ fmt.Sprintf("2. Topic Memory File Path: %s", filepath.Join(topic.StorePath, topic.Path)),
+ fmt.Sprintf("3. Topic Memory Modified at: %s", topic.Saved),
+ "4. Topic Memory Content:",
+ "",
+ topic.Content,
+ "",
+ fmt.Sprintf("", i+1),
+ "",
+ )
+ }
+ lines = append(lines, "")
+ return strings.Join(lines, "\n")
+}
+
+func buildTopicMemoryReminderChinese(topics []topicMemoryPromptInfo) string {
+ lines := []string{
+ "",
+ "主题记忆是本次查询相关的长期记忆文件。请将它们作为当前轮次的辅助上下文使用,其中可能包含稳定的用户偏好、项目约定或此前保存的事实;不要用它们替代当前用户请求。",
+ "",
+ }
+ for i, topic := range topics {
+ lines = append(lines,
+ fmt.Sprintf("", i+1),
+ fmt.Sprintf("1. 记忆存储名称:%s", topic.StoreName),
+ fmt.Sprintf("2. 主题文件路径:%s", filepath.Join(topic.StorePath, topic.Path)),
+ fmt.Sprintf("3. 更新时间:%s", topic.Saved),
+ "4. 主题记忆内容:",
+ "",
+ topic.Content,
+ "",
+ fmt.Sprintf("", i+1),
+ "",
+ )
+ }
+ lines = append(lines, "")
+ return strings.Join(lines, "\n")
+}
+
+func buildMemoryIndexReminderEnglish(stores []memoryStorePromptInfo) string {
+ lines := []string{
+ "",
+ "Memory indexes are the high-level table of contents for your memory stores. Use them to understand what long-term memories may exist and decide which memory files to inspect with tools. They are not the full memory content; detailed notes usually live in the linked topic files.",
+ "",
+ }
+ for i, store := range stores {
+ lines = append(lines,
+ fmt.Sprintf("", i+1),
+ fmt.Sprintf("1. Memory Store Name: %s", store.Name),
+ )
+ if strings.TrimSpace(store.Description) != "" {
+ lines = append(lines, fmt.Sprintf("2. Description: %s", strings.TrimSpace(store.Description)))
+ }
+ if store.Index != nil {
+ lines = append(lines,
+ fmt.Sprintf("3. Index Memory File Path: %s", store.Index.Path),
+ "4. Index Memory File Content:",
+ "",
+ renderMemoryIndexContentEnglish(*store.Index),
+ "",
+ )
+ }
+ lines = append(lines, fmt.Sprintf("", i+1), "")
+ }
+ lines = append(lines, "")
+ return strings.Join(lines, "\n")
+}
+
+func buildMemoryIndexReminderChinese(stores []memoryStorePromptInfo) string {
+ lines := []string{
+ "",
+ "记忆索引是每个记忆存储的高层目录。请用它判断当前可能有哪些长期记忆,以及需要通过工具进一步查看哪些记忆文件。它不是完整记忆内容,详细信息通常保存在索引中链接的主题文件里。",
+ "",
+ }
+ for i, store := range stores {
+ lines = append(lines,
+ fmt.Sprintf("", i+1),
+ fmt.Sprintf("1. 记忆存储名称:%s", store.Name),
+ )
+ if strings.TrimSpace(store.Description) != "" {
+ lines = append(lines, fmt.Sprintf("2. 功能描述:%s", strings.TrimSpace(store.Description)))
+ }
+ if store.Index != nil {
+ lines = append(lines,
+ fmt.Sprintf("3. 索引记忆文件路径:%s", store.Index.Path),
+ "4. 索引记忆文件内容:",
+ "",
+ renderMemoryIndexContentChinese(*store.Index),
+ "",
+ )
+ }
+ lines = append(lines, fmt.Sprintf("", i+1), "")
+ }
+ lines = append(lines, "")
+ return strings.Join(lines, "\n")
+}
+
+func renderMemoryIndexContentEnglish(index memoryIndexPromptInfo) string {
+ if index.Empty {
+ return "The index file is currently empty."
+ }
+ lines := []string{index.Content}
+ if index.Truncated {
+ lines = append(lines, strings.ReplaceAll(getAppendCurrentIndexTruncNotify(), "{memory_lines}", fmt.Sprintf("%d", index.Lines)))
+ }
+ return strings.Join(lines, "\n")
+}
+
+func renderMemoryIndexContentChinese(index memoryIndexPromptInfo) string {
+ if index.Empty {
+ return "索引文件当前为空。"
+ }
+ lines := []string{index.Content}
+ if index.Truncated {
+ lines = append(lines, strings.ReplaceAll(getAppendCurrentIndexTruncNotify(), "{memory_lines}", fmt.Sprintf("%d", index.Lines)))
+ }
+ return strings.Join(lines, "\n")
+}
+
+func buildExtractionMemoryManifest(stores []memoryManifestStorePromptInfo) string {
+ return internal.SelectPrompt(internal.I18nPrompts{
+ English: buildExtractionMemoryManifestEnglish(stores),
+ Chinese: buildExtractionMemoryManifestChinese(stores),
+ })
+}
+
+func buildExtractionMemoryManifestEnglish(stores []memoryManifestStorePromptInfo) string {
+ var lines []string
+ for _, store := range stores {
+ lines = append(lines, fmt.Sprintf("### %s", store.Name))
+ lines = append(lines, fmt.Sprintf("Store path: %s", store.Mount))
+ if len(store.Files) == 0 {
+ lines = append(lines, "- No existing memory files.")
+ continue
+ }
+ for _, file := range store.Files {
+ if file.Description != "" {
+ lines = append(lines, fmt.Sprintf("- %s (path: %s, saved %s): %s", file.MemoryPath, file.AbsPath, file.Saved, file.Description))
+ } else {
+ lines = append(lines, fmt.Sprintf("- %s (path: %s, saved %s)", file.MemoryPath, file.AbsPath, file.Saved))
+ }
+ }
+ }
+ return strings.Join(lines, "\n")
+}
+
+func buildExtractionMemoryManifestChinese(stores []memoryManifestStorePromptInfo) string {
+ var lines []string
+ for _, store := range stores {
+ lines = append(lines, fmt.Sprintf("### %s", store.Name))
+ lines = append(lines, fmt.Sprintf("存储路径:%s", store.Mount))
+ if len(store.Files) == 0 {
+ lines = append(lines, "- 暂无已有 memory 文件。")
+ continue
+ }
+ for _, file := range store.Files {
+ if file.Description != "" {
+ lines = append(lines, fmt.Sprintf("- %s(路径:%s,保存时间:%s):%s", file.MemoryPath, file.AbsPath, file.Saved, file.Description))
+ } else {
+ lines = append(lines, fmt.Sprintf("- %s(路径:%s,保存时间:%s)", file.MemoryPath, file.AbsPath, file.Saved))
+ }
+ }
+ }
+ return strings.Join(lines, "\n")
+}
+
func joinLines(lines []string) string {
if len(lines) == 0 {
return ""
@@ -163,10 +565,16 @@ func joinLines(lines []string) string {
return b.String()
}
-func getDefaultMemoryInstruction() string {
+func getDefaultMemoryInstruction(enableIndex bool) string {
+ english := defaultMemoryInstructionWithoutIndex
+ chinese := defaultMemoryInstructionChineseWithoutIndex
+ if enableIndex {
+ english = defaultMemoryInstructionWithIndex
+ chinese = defaultMemoryInstructionChineseWithIndex
+ }
return internal.SelectPrompt(internal.I18nPrompts{
- English: defaultMemoryInstruction,
- Chinese: defaultMemoryInstructionChinese,
+ English: english,
+ Chinese: chinese,
})
}
@@ -205,13 +613,19 @@ func getTopicMemoryTruncNotify() string {
})
}
-func buildExtractAutoOnlyPromptEnglish(memoryDir string, newMessageCount int, existingMemories string, skipIndex bool) string {
- manifest := ""
- if existingMemories != "" {
- manifest = fmt.Sprintf("\n\n## Existing memory files\n\n%s\n\nCheck this list before writing — update an existing file rather than creating a duplicate.", existingMemories)
+func buildExtractHowToSaveEnglish(enableMemoryIndex bool) []string {
+ if !enableMemoryIndex {
+ return []string{
+ "## How to save memories",
+ "",
+ "Write each memory to its own file. Do not create duplicate files.",
+ "",
+ "- Organize memory semantically by topic, not chronologically.",
+ "- Update or remove memories that turn out to be wrong or outdated.",
+ "- Do not write duplicate memories.",
+ }
}
-
- howToSave := []string{
+ return []string{
"## How to save memories",
"",
"Saving a memory is a two-step process:",
@@ -224,20 +638,49 @@ func buildExtractAutoOnlyPromptEnglish(memoryDir string, newMessageCount int, ex
"- Update or remove memories that turn out to be wrong or outdated.",
"- Do not write duplicate memories.",
}
- if skipIndex {
- howToSave = []string{
- "## How to save memories",
+}
+
+func buildExtractHowToSaveChinese(enableMemoryIndex bool) []string {
+ if !enableMemoryIndex {
+ return []string{
+ "## 如何保存记忆",
"",
- "Write each memory to its own file. Do not create duplicate files.",
+ "将每条记忆写入各自独立的文件中,不要创建重复文件。",
+ "",
+ "- 按主题组织记忆,而不是按时间顺序堆叠。",
+ "- 当记忆被证明错误或过时时,要及时更新或删除。",
+ "- 不要写入重复记忆。",
}
}
+ return []string{
+ "## 如何保存记忆",
+ "",
+ "保存记忆分为两步:",
+ "",
+ "第 1 步:将记忆写入独立文件。",
+ "第 2 步:在 MEMORY.md 中添加指向该文件的索引。MEMORY.md 只是索引,不应存放记忆正文。",
+ "",
+ "- 保持 MEMORY.md 简洁,因为它会被加载进系统提示词。",
+ "- 按主题组织记忆,而不是按时间顺序堆叠。",
+ "- 当记忆被证明错误或过时时,要及时更新或删除。",
+ "- 不要写入重复记忆。",
+ }
+}
+
+func buildExtractAutoOnlyPromptEnglish(memoryStores string, newMessageCount int, existingMemories string, enableMemoryIndex bool) string {
+ manifest := ""
+ if existingMemories != "" {
+ manifest = fmt.Sprintf("\n\n## Existing memory files\n\n%s\n\nCheck this list before writing — update an existing file rather than creating a duplicate.", existingMemories)
+ }
+
+ howToSave := buildExtractHowToSaveEnglish(enableMemoryIndex)
parts := []string{
fmt.Sprintf("You are now acting as the memory extraction subagent. Analyze only the most recent ~%d messages above and use them to update persistent memory.", newMessageCount),
"",
- fmt.Sprintf("Memory directory: %s", memoryDir),
+ memoryStores,
"",
- "Available tools: read_file, glob, write_file, edit_file. Only paths inside the memory directory are allowed. All other tools are denied.",
+ "Available tools: read_file, glob, write_file, edit_file. Only paths inside the memory stores are allowed. Use absolute paths or the listed relative path prefixes when reading or writing memory files. All other tools are denied.",
"",
"You have a limited turn budget. read_file should happen first for every file you may update, then write_file/edit_file should happen after that. Do not interleave read and write across many turns.",
"",
@@ -260,39 +703,20 @@ func buildExtractAutoOnlyPromptEnglish(memoryDir string, newMessageCount int, ex
return joinLines(parts)
}
-func buildExtractAutoOnlyPromptChinese(memoryDir string, newMessageCount int, existingMemories string, skipIndex bool) string {
+func buildExtractAutoOnlyPromptChinese(memoryStores string, newMessageCount int, existingMemories string, enableMemoryIndex bool) string {
manifest := ""
if existingMemories != "" {
manifest = fmt.Sprintf("\n\n## 现有记忆文件\n\n%s\n\n写入前请先检查这份列表,优先更新已有文件,而不是创建重复记忆。", existingMemories)
}
- howToSave := []string{
- "## 如何保存记忆",
- "",
- "保存记忆分为两步:",
- "",
- "第 1 步:将记忆写入独立文件。",
- "第 2 步:在 MEMORY.md 中添加指向该文件的索引。MEMORY.md 只是索引,不应存放记忆正文。",
- "",
- "- 保持 MEMORY.md 简洁,因为它会被加载进 system prompt。",
- "- 按主题组织记忆,而不是按时间顺序堆叠。",
- "- 当记忆被证明错误或过时时,要及时更新或删除。",
- "- 不要写入重复记忆。",
- }
- if skipIndex {
- howToSave = []string{
- "## 如何保存记忆",
- "",
- "将每条记忆写入各自独立的文件中,不要创建重复文件。",
- }
- }
+ howToSave := buildExtractHowToSaveChinese(enableMemoryIndex)
parts := []string{
- fmt.Sprintf("你现在扮演 memory extraction subagent。只分析上方最近约 %d 条消息,并用它们来更新持久化记忆。", newMessageCount),
+ fmt.Sprintf("你现在扮演记忆提取子智能体。只分析上方最近约 %d 条消息,并用它们来更新持久化记忆。", newMessageCount),
"",
- fmt.Sprintf("记忆目录:%s", memoryDir),
+ memoryStores,
"",
- "可用工具:read_file、glob、write_file、edit_file。只允许访问记忆目录内的路径,其他工具均禁止使用。",
+ "可用工具:read_file、glob、write_file、edit_file。只允许访问记忆存储内的路径。读写记忆文件时请使用绝对路径,或使用上方列出的相对路径前缀。其他工具均禁止使用。",
"",
"你的轮次预算有限。对于每个可能更新的文件,应先 read_file,再进行 write_file/edit_file;不要在多轮里交叉读写大量文件。",
"",
diff --git a/adk/middlewares/automemory/utils.go b/adk/middlewares/automemory/utils.go
new file mode 100644
index 000000000..c9cc94ac1
--- /dev/null
+++ b/adk/middlewares/automemory/utils.go
@@ -0,0 +1,979 @@
+/*
+ * Copyright 2026 CloudWeGo Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package automemory
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "path/filepath"
+ "sort"
+ "strings"
+ "time"
+
+ "gopkg.in/yaml.v3"
+
+ "github.com/cloudwego/eino/adk"
+ ainternal "github.com/cloudwego/eino/adk/middlewares/automemory/internal"
+ adkfs "github.com/cloudwego/eino/adk/middlewares/filesystem"
+ "github.com/cloudwego/eino/components/model"
+ "github.com/cloudwego/eino/schema"
+)
+
+func buildRuntimeMemoryStores[M adk.MessageType](cfg *Config[M]) ([]runtimeMemoryStore, error) {
+ stores := append([]MemoryStore{}, cfg.MemoryStores...)
+ if len(stores) == 0 {
+ return nil, fmt.Errorf("auto memory config: no memory stores")
+ }
+
+ out := make([]runtimeMemoryStore, 0, len(stores))
+ seenName := make(map[string]struct{}, len(stores))
+ seenPath := make(map[string]struct{}, len(stores))
+ for i, store := range stores {
+ if strings.TrimSpace(store.Path) == "" {
+ return nil, fmt.Errorf("auto memory config: memory store %d has empty path", i)
+ }
+ resolvedPath, err := ainternal.ResolveMemoryDir(store.Path)
+ if err != nil {
+ return nil, fmt.Errorf("auto memory config: resolve memory store %d: %w", i, err)
+ }
+ if _, ok := seenPath[resolvedPath]; ok {
+ return nil, fmt.Errorf("auto memory config: duplicate memory store path: %s", resolvedPath)
+ }
+ seenPath[resolvedPath] = struct{}{}
+
+ name := strings.TrimSpace(store.Name)
+ if name == "" {
+ name = filepath.Base(resolvedPath)
+ if name == "." || name == string(filepath.Separator) || name == "" {
+ name = fmt.Sprintf("memory_%d", i+1)
+ }
+ store.Name = name
+ }
+ if strings.ContainsAny(name, `/\`) {
+ return nil, fmt.Errorf("auto memory config: memory store name must not contain path separators: %s", name)
+ }
+ if _, ok := seenName[name]; ok {
+ return nil, fmt.Errorf("auto memory config: duplicate memory store name: %s", name)
+ }
+ seenName[name] = struct{}{}
+
+ bounded, err := ainternal.NewFSBackend(cfg.MemoryBackend, ainternal.FSBackendConfig{
+ BaseDir: resolvedPath,
+ NotFoundAsContent: true,
+ ErrorPrefix: "memory backend",
+ })
+ if err != nil {
+ return nil, err
+ }
+ store.Path = resolvedPath
+ out = append(out, runtimeMemoryStore{
+ MemoryStore: store,
+ Path: resolvedPath,
+ Backend: bounded,
+ })
+ }
+ return out, nil
+}
+
+func applyReadDefaults[M adk.MessageType](cfg *Config[M]) {
+ if cfg.Read.Mode == "" {
+ cfg.Read.Mode = ReadModeSync
+ }
+ if cfg.Read.Index == nil {
+ cfg.Read.Index = &IndexConfig{}
+ }
+ if cfg.Read.Index.EnableMemoryIndex == nil {
+ cfg.Read.Index.EnableMemoryIndex = boolPtr(true)
+ }
+ if cfg.Read.Index.FileName == "" {
+ cfg.Read.Index.FileName = memoryIndexFileName
+ }
+ if cfg.Read.Index.MaxLines <= 0 {
+ cfg.Read.Index.MaxLines = defaultIndexMaxLines
+ }
+ if cfg.Read.Index.MaxBytes <= 0 {
+ cfg.Read.Index.MaxBytes = defaultIndexMaxBytes
+ }
+ if cfg.Read.Model == nil {
+ cfg.Read.Model = cfg.Model
+ }
+ if cfg.Read.TopicSelection == nil {
+ cfg.Read.TopicSelection = &TopicSelectionConfig{}
+ }
+ if cfg.Read.TopicSelection.TopK <= 0 {
+ cfg.Read.TopicSelection.TopK = defaultTopicTopK
+ }
+ if cfg.Read.TopicSelection.CandidateGlob == "" {
+ cfg.Read.TopicSelection.CandidateGlob = CandidateGlobPattern
+ }
+ if cfg.Read.TopicSelection.CandidateLimit <= 0 {
+ cfg.Read.TopicSelection.CandidateLimit = defaultCandidateLimit
+ }
+ if cfg.Read.TopicSelection.CandidatePreviewLines <= 0 {
+ cfg.Read.TopicSelection.CandidatePreviewLines = defaultCandidatePreviewLine
+ }
+ if cfg.Read.TopicSelection.MaxLines <= 0 {
+ cfg.Read.TopicSelection.MaxLines = defaultTopicMaxLines
+ }
+ if cfg.Read.TopicSelection.MaxBytes <= 0 {
+ cfg.Read.TopicSelection.MaxBytes = defaultTopicMaxBytes
+ }
+ if cfg.Read.TopicSelection.MaxTotalBytes <= 0 {
+ cfg.Read.TopicSelection.MaxTotalBytes = defaultTopicMaxTotalBytes
+ }
+
+ if cfg.Write == nil {
+ cfg.Write = &WriteConfig[M]{Mode: WriteModeDisabled}
+ }
+ if cfg.Write.Mode == "" {
+ cfg.Write.Mode = WriteModeDisabled
+ }
+ if cfg.Write.Model == nil {
+ cfg.Write.Model = cfg.Model
+ }
+ if cfg.Write.MaxTurns <= 0 {
+ cfg.Write.MaxTurns = defaultMemoryWriteMaxTurns
+ }
+
+ if cfg.Coordination == nil {
+ cfg.Coordination = &CoordinationConfig[M]{}
+ }
+ if cfg.Coordination.Coordinator == nil {
+ cfg.Coordination.Coordinator = NewLocalCoordinator()
+ }
+ if cfg.Coordination.LockTTL <= 0 {
+ cfg.Coordination.LockTTL = 2 * time.Minute
+ }
+}
+
+func cloneConfig[M adk.MessageType](cfg *Config[M]) *Config[M] {
+ if cfg == nil {
+ return nil
+ }
+
+ cp := *cfg
+ if cfg.Read != nil {
+ readCopy := *cfg.Read
+ cp.Read = &readCopy
+ if cfg.Read.Index != nil {
+ indexCopy := *cfg.Read.Index
+ cp.Read.Index = &indexCopy
+ }
+ if cfg.Read.TopicSelection != nil {
+ topicSelectionCopy := *cfg.Read.TopicSelection
+ cp.Read.TopicSelection = &topicSelectionCopy
+ }
+ }
+ if cfg.Write != nil {
+ writeCopy := *cfg.Write
+ cp.Write = &writeCopy
+ }
+ if cfg.Coordination != nil {
+ coordinationCopy := *cfg.Coordination
+ cp.Coordination = &coordinationCopy
+ }
+ return &cp
+}
+
+func linesOrSizeTrunc(content string, lines, size int) (newContent string, reason string, truncated bool) {
+ linesTrunc := func(content string, lines int) {
+ sp := strings.Split(content, "\n")
+ if len(sp) > lines {
+ newContent = strings.Join(sp[:lines], "\n")
+ reason = fmt.Sprintf("first %d lines", lines)
+ truncated = true
+ } else {
+ newContent = content
+ }
+ }
+
+ sizeTrunc := func(content string, size int) {
+ if len(content) > size {
+ newContent = content[:size]
+ reason = fmt.Sprintf("%d byte limit", size)
+ truncated = true
+ } else {
+ newContent = content
+ }
+ }
+
+ if lines == 0 && size == 0 {
+ return content, "", false
+ } else if lines == 0 {
+ sizeTrunc(content, size)
+ } else if size == 0 {
+ linesTrunc(content, lines)
+ } else {
+ linesTrunc(content, lines)
+ sizeTrunc(newContent, size)
+ }
+ return
+}
+
+func isFileNotFoundContent(content string) bool {
+ return strings.HasPrefix(strings.TrimSpace(content), "File not found: ")
+}
+
+func boolPtr(v bool) *bool {
+ return &v
+}
+
+func parseFrontmatter(md string) (fm topicFrontmatter, ok bool) {
+ s := strings.TrimLeft(md, "\ufeff \t\r\n")
+ if !strings.HasPrefix(s, "---\n") && !strings.HasPrefix(s, "---\r\n") {
+ return topicFrontmatter{}, false
+ }
+ parts := strings.SplitN(s, "\n---", 2)
+ if len(parts) != 2 {
+ return topicFrontmatter{}, false
+ }
+ yml := strings.TrimPrefix(parts[0], "---\n")
+ if err := yaml.Unmarshal([]byte(yml), &fm); err != nil {
+ return topicFrontmatter{}, false
+ }
+ return fm, true
+}
+
+func describeTopicCandidate(content string) string {
+ desc := ""
+ if fm, ok := parseFrontmatter(content); ok {
+ switch {
+ case strings.TrimSpace(fm.Description) != "":
+ desc = strings.TrimSpace(fm.Description)
+ case strings.TrimSpace(fm.Name) != "":
+ desc = strings.TrimSpace(fm.Name)
+ }
+ if strings.TrimSpace(fm.Type) != "" {
+ if desc == "" {
+ desc = "type=" + strings.TrimSpace(fm.Type)
+ } else {
+ desc = desc + " (type=" + strings.TrimSpace(fm.Type) + ")"
+ }
+ }
+ }
+ if desc == "" {
+ snippet, _, _ := linesOrSizeTrunc(content, 3, 256)
+ desc = strings.TrimSpace(snippet)
+ }
+ return desc
+}
+
+func collectToolNames[M adk.MessageType](msgs []M) []string {
+ dedupTools := make(map[string]struct{})
+ for _, msg := range msgs {
+ for _, name := range messageToolNames(msg) {
+ dedupTools[name] = struct{}{}
+ }
+ }
+ tools := make([]string, 0, len(dedupTools))
+ for t := range dedupTools {
+ tools = append(tools, t)
+ }
+ sort.Strings(tools)
+ return tools
+}
+
+func topicSelectionToolInfo() *schema.ToolInfo {
+ return &schema.ToolInfo{
+ Name: topicSelectionToolName,
+ Desc: "Select which memory files to surface for the current query. Return selected_memories as memory paths exactly as shown in the available memories list.",
+ ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{
+ "selected_memories": {
+ Type: schema.Array,
+ Desc: "Memory paths exactly as shown in the available memories list, e.g. \"user_profile/preferences.md\" or \"project_context/notes/patterns.md\".",
+ Required: true,
+ ElemInfo: &schema.ParameterInfo{Type: schema.String},
+ },
+ }),
+ }
+}
+
+func parseTopicSelectionFromToolCall[M adk.MessageType](msg M, valid map[string]struct{}) ([]string, error) {
+ toolCalls := messageToolCalls(msg)
+ if len(toolCalls) == 0 {
+ return nil, fmt.Errorf("no tool calls")
+ }
+ tc := toolCalls[0]
+ if tc.Function.Name != topicSelectionToolName {
+ return nil, fmt.Errorf("unexpected tool call: %s", tc.Function.Name)
+ }
+ var parsed topicSelectionResp
+ if err := json.Unmarshal([]byte(tc.Function.Arguments), &parsed); err != nil {
+ return nil, err
+ }
+ out := normalizeSelected(parsed.SelectedMemories)
+ filtered := make([]string, 0, len(out))
+ for _, p := range out {
+ if _, ok := valid[p]; ok {
+ filtered = append(filtered, p)
+ }
+ }
+ return filtered, nil
+}
+
+func normalizeSelected(in []string) []string {
+ out := make([]string, 0, len(in))
+ seen := make(map[string]struct{}, len(in))
+ for _, s := range in {
+ s = strings.TrimSpace(s)
+ s = strings.TrimPrefix(s, "./")
+ s = filepath.ToSlash(s)
+ if s == "" {
+ continue
+ }
+ if _, ok := seen[s]; ok {
+ continue
+ }
+ seen[s] = struct{}{}
+ out = append(out, s)
+ }
+ return out
+}
+
+func isNilMessage[M adk.MessageType](msg M) bool {
+ var zero M
+ return any(msg) == any(zero)
+}
+
+func isUserRole[M adk.MessageType](msg M) bool {
+ switch m := any(msg).(type) {
+ case *schema.Message:
+ return m != nil && m.Role == schema.User
+ case *schema.AgenticMessage:
+ return m != nil && m.Role == schema.AgenticRoleTypeUser
+ default:
+ panic("unreachable")
+ }
+}
+
+func isAssistantRole[M adk.MessageType](msg M) bool {
+ switch m := any(msg).(type) {
+ case *schema.Message:
+ return m != nil && m.Role == schema.Assistant
+ case *schema.AgenticMessage:
+ return m != nil && m.Role == schema.AgenticRoleTypeAssistant
+ default:
+ panic("unreachable")
+ }
+}
+
+func userMessageTextContent[M adk.MessageType](msg M) string {
+ switch m := any(msg).(type) {
+ case *schema.Message:
+ if m == nil {
+ return ""
+ }
+ if len(m.UserInputMultiContent) == 0 {
+ return m.Content
+ }
+ parts := make([]string, 0, len(m.UserInputMultiContent))
+ for _, part := range m.UserInputMultiContent {
+ if part.Type == schema.ChatMessagePartTypeText && part.Text != "" {
+ parts = append(parts, part.Text)
+ }
+ }
+ if len(parts) > 0 {
+ return strings.Join(parts, "\n")
+ }
+ return m.Content
+ case *schema.AgenticMessage:
+ if m == nil {
+ return ""
+ }
+ parts := make([]string, 0, len(m.ContentBlocks))
+ for _, block := range m.ContentBlocks {
+ if block != nil && block.UserInputText != nil {
+ parts = append(parts, block.UserInputText.Text)
+ }
+ }
+ return strings.Join(parts, "\n")
+ default:
+ panic("unreachable")
+ }
+}
+
+func getMsgExtra[M adk.MessageType](msg M) map[string]any {
+ switch m := any(msg).(type) {
+ case *schema.Message:
+ if m == nil {
+ return nil
+ }
+ return m.Extra
+ case *schema.AgenticMessage:
+ if m == nil {
+ return nil
+ }
+ return m.Extra
+ default:
+ panic("unreachable")
+ }
+}
+
+func copyAndSetMsgExtra[M adk.MessageType](msg M, key string, value any) {
+ existing := getMsgExtra(msg)
+ newExtra := make(map[string]any, len(existing)+1)
+ for k, v := range existing {
+ newExtra[k] = v
+ }
+ newExtra[key] = value
+
+ switch m := any(msg).(type) {
+ case *schema.Message:
+ m.Extra = newExtra
+ case *schema.AgenticMessage:
+ m.Extra = newExtra
+ default:
+ panic("unreachable")
+ }
+}
+
+func makeUserMsg[M adk.MessageType](text string) M {
+ var zero M
+ switch any(zero).(type) {
+ case *schema.Message:
+ return any(schema.UserMessage(text)).(M)
+ case *schema.AgenticMessage:
+ return any(schema.UserAgenticMessage(text)).(M)
+ default:
+ panic("unreachable")
+ }
+}
+
+func makeSystemMsg[M adk.MessageType](text string) M {
+ var zero M
+ switch any(zero).(type) {
+ case *schema.Message:
+ return any(schema.SystemMessage(text)).(M)
+ case *schema.AgenticMessage:
+ return any(schema.SystemAgenticMessage(text)).(M)
+ default:
+ panic("unreachable")
+ }
+}
+
+func makeToolChoiceForced[M adk.MessageType](name string) model.Option {
+ var zero M
+ switch any(zero).(type) {
+ case *schema.Message:
+ return model.WithToolChoice(schema.ToolChoiceForced, name)
+ case *schema.AgenticMessage:
+ return model.WithAgenticToolChoice(&schema.AgenticToolChoice{
+ Type: schema.ToolChoiceForced,
+ Forced: &schema.AgenticForcedToolChoice{
+ Tools: []*schema.AllowedTool{{FunctionName: name}},
+ },
+ })
+ default:
+ panic("unreachable")
+ }
+}
+
+func messageToolCalls[M adk.MessageType](msg M) []schema.ToolCall {
+ switch m := any(msg).(type) {
+ case *schema.Message:
+ if m == nil {
+ return nil
+ }
+ return m.ToolCalls
+ case *schema.AgenticMessage:
+ if m == nil {
+ return nil
+ }
+ out := make([]schema.ToolCall, 0, len(m.ContentBlocks))
+ for _, block := range m.ContentBlocks {
+ if block == nil || block.FunctionToolCall == nil {
+ continue
+ }
+ out = append(out, schema.ToolCall{
+ ID: block.FunctionToolCall.CallID,
+ Type: "function",
+ Function: schema.FunctionCall{
+ Name: block.FunctionToolCall.Name,
+ Arguments: block.FunctionToolCall.Arguments,
+ },
+ })
+ }
+ return out
+ default:
+ panic("unreachable")
+ }
+}
+
+func messageToolNames[M adk.MessageType](msg M) []string {
+ switch m := any(msg).(type) {
+ case *schema.Message:
+ if m == nil || m.Role != schema.Tool || m.ToolName == "" {
+ return nil
+ }
+ return []string{m.ToolName}
+ case *schema.AgenticMessage:
+ if m == nil {
+ return nil
+ }
+ var out []string
+ for _, block := range m.ContentBlocks {
+ if block == nil || block.FunctionToolResult == nil || block.FunctionToolResult.Name == "" {
+ continue
+ }
+ out = append(out, block.FunctionToolResult.Name)
+ }
+ return out
+ default:
+ panic("unreachable")
+ }
+}
+
+func hasTopicMemoryInjected[M adk.MessageType](msgs []M) bool {
+ for _, msg := range msgs {
+ if isTopicMemoryMessage(msg) {
+ return true
+ }
+ }
+ return false
+}
+
+func hasMemoryIndexInjected[M adk.MessageType](msgs []M) bool {
+ for _, msg := range msgs {
+ if isMemoryIndexMessage(msg) {
+ return true
+ }
+ }
+ return false
+}
+
+func insertMessagesBeforeLastUserQuery[M adk.MessageType](msgs []M, inserts []M) []M {
+ if len(inserts) == 0 {
+ return msgs
+ }
+ idx := lastUserQueryMessageIndex(msgs)
+ if idx < 0 {
+ idx = len(msgs)
+ }
+ out := make([]M, 0, len(msgs)+len(inserts))
+ out = append(out, msgs[:idx]...)
+ out = append(out, inserts...)
+ out = append(out, msgs[idx:]...)
+ return out
+}
+
+func lastUserQueryMessageIndex[M adk.MessageType](msgs []M) int {
+ for i := len(msgs) - 1; i >= 0; i-- {
+ msg := msgs[i]
+ if isNilMessage(msg) || !isUserRole(msg) || isAutomemoryReminderMessage(msg) {
+ continue
+ }
+ return i
+ }
+ return -1
+}
+
+func isAutomemoryReminderMessage[M adk.MessageType](m M) bool {
+ if isTopicMemoryMessage(m) || isMemoryIndexMessage(m) {
+ return true
+ }
+ if isNilMessage(m) || !isUserRole(m) {
+ return false
+ }
+ return strings.HasPrefix(strings.TrimSpace(userMessageTextContent(m)), "")
+}
+
+func isTopicMemoryMessage[M adk.MessageType](m M) bool {
+ if isNilMessage(m) || !isUserRole(m) {
+ return false
+ }
+ if extra := getMsgExtra(m); extra != nil {
+ if v, ok := extra[memoryExtraKey]; ok {
+ if isTopicMemoryExtra(v) {
+ return true
+ }
+ }
+ }
+ content := userMessageTextContent(m)
+ return strings.Contains(content, "") && !strings.Contains(content, "")
+}
+
+func isMemoryIndexMessage[M adk.MessageType](m M) bool {
+ if isNilMessage(m) || !isUserRole(m) {
+ return false
+ }
+ if extra := getMsgExtra(m); extra != nil {
+ if v, ok := extra[memoryExtraKey]; ok {
+ if isMemoryIndexExtra(v) {
+ return true
+ }
+ }
+ }
+ return strings.Contains(userMessageTextContent(m), "")
+}
+
+func isTopicMemoryExtra(v any) bool {
+ switch meta := v.(type) {
+ case *memoryExtra:
+ return meta != nil && (meta.Type == "memory" || meta.Type == "topic_memory")
+ case map[string]any:
+ typ, _ := meta["type"].(string)
+ return typ == "memory" || typ == "topic_memory"
+ default:
+ return false
+ }
+}
+
+func isMemoryIndexExtra(v any) bool {
+ switch meta := v.(type) {
+ case *memoryExtra:
+ return meta != nil && meta.Type == "memory_index"
+ case map[string]any:
+ typ, _ := meta["type"].(string)
+ return typ == "memory_index"
+ default:
+ return false
+ }
+}
+
+func newMemoryMessage[M adk.MessageType](content string) M {
+ msg := makeUserMsg[M](content)
+ copyAndSetMsgExtra(msg, memoryExtraKey, &memoryExtra{Type: "memory"})
+ return msg
+}
+
+func newMemoryIndexMessage[M adk.MessageType](content string) M {
+ msg := makeUserMsg[M](content)
+ copyAndSetMsgExtra(msg, memoryExtraKey, &memoryExtra{Type: "memory_index"})
+ return msg
+}
+
+func ensureMemoryMsgUnchanged[M adk.MessageType](state *adk.TypedChatModelAgentState[M], expectedContent string) *adk.TypedChatModelAgentState[M] {
+ if state == nil || strings.TrimSpace(expectedContent) == "" {
+ return state
+ }
+ changed := false
+ out := *state
+ out.Messages = append([]M{}, state.Messages...)
+
+ for i, m := range out.Messages {
+ if !isTopicMemoryMessage(m) {
+ continue
+ }
+ extra := getMsgExtra(m)
+ if userMessageTextContent(m) != expectedContent || extra == nil || extra[memoryExtraKey] == nil {
+ out.Messages[i] = newMemoryMessage[M](expectedContent)
+ changed = true
+ }
+ }
+ if !changed {
+ return state
+ }
+ return &out
+}
+
+func extractFilePath(args string) (string, bool) {
+ var m map[string]any
+ if err := json.Unmarshal([]byte(args), &m); err != nil {
+ return "", false
+ }
+ if v, ok := m["file_path"]; ok {
+ if s, ok := v.(string); ok && s != "" {
+ return s, true
+ }
+ }
+ if v, ok := m["filePath"]; ok {
+ if s, ok := v.(string); ok && s != "" {
+ return s, true
+ }
+ }
+ return "", false
+}
+
+func isPathWithinMemoryDir(memDir string, filePath string) bool {
+ if memDir == "" || filePath == "" {
+ return false
+ }
+ md := filepath.Clean(memDir)
+ fp := filepath.Clean(filePath)
+ if !filepath.IsAbs(fp) {
+ fp = filepath.Join(md, fp)
+ fp = filepath.Clean(fp)
+ }
+ if fp == md {
+ return true
+ }
+ sep := string(filepath.Separator)
+ return strings.HasPrefix(fp, md+sep)
+}
+
+func getWriteCursorFromMessages[M adk.MessageType](msgs []M) int {
+ for i := len(msgs) - 1; i >= 0; i-- {
+ m := msgs[i]
+ extra := getMsgExtra(m)
+ if isNilMessage(m) || extra == nil {
+ continue
+ }
+ v, ok := extra[memoryExtraKey]
+ if !ok {
+ continue
+ }
+ switch meta := v.(type) {
+ case *memoryExtra:
+ if meta != nil && meta.Type == "write_cursor" {
+ return meta.Cursor
+ }
+ case map[string]any:
+ if typ, _ := meta["type"].(string); typ != "write_cursor" {
+ continue
+ }
+ switch c := meta["cursor"].(type) {
+ case int:
+ return c
+ case int64:
+ return int(c)
+ case float64:
+ return int(c)
+ }
+ }
+ }
+ return 0
+}
+
+func markWriteCursor[M adk.MessageType](state *adk.TypedChatModelAgentState[M], cursor int) *adk.TypedChatModelAgentState[M] {
+ if state == nil || len(state.Messages) == 0 {
+ return state
+ }
+ last := state.Messages[len(state.Messages)-1]
+ if isNilMessage(last) {
+ return state
+ }
+
+ copyAndSetMsgExtra(last, memoryExtraKey, &memoryExtra{
+ Type: "write_cursor",
+ Cursor: cursor,
+ })
+
+ return state
+}
+
+func countModelVisibleMessages[M adk.MessageType](msgs []M) int {
+ n := 0
+ for _, m := range msgs {
+ if isNilMessage(m) {
+ continue
+ }
+ if isUserRole(m) || isAssistantRole(m) {
+ n++
+ }
+ }
+ return n
+}
+
+func getOrInitWriteSessionID(ctx context.Context) string {
+ const key = "__automemory_write_session_id__"
+ if v, ok := adk.GetSessionValue(ctx, key); ok {
+ if s, ok := v.(string); ok && s != "" {
+ return s
+ }
+ }
+ s := fmt.Sprintf("%d", time.Now().UnixNano())
+ adk.AddSessionValue(ctx, key, s)
+ return s
+}
+
+func buildPendingSnapshot[M adk.MessageType](messages []M, cursor int, toolInfos []*schema.ToolInfo) (*PendingSnapshot, error) {
+ raw, err := json.Marshal(messages)
+ if err != nil {
+ return nil, err
+ }
+ var rawToolInfos json.RawMessage
+ if toolInfos != nil {
+ rawToolInfos, err = json.Marshal(toolInfos)
+ if err != nil {
+ return nil, err
+ }
+ }
+ return &PendingSnapshot{Cursor: cursor, Messages: raw, ToolInfos: rawToolInfos}, nil
+}
+
+func decodePendingSnapshot[M adk.MessageType](snapshot *PendingSnapshot) ([]M, int, []*schema.ToolInfo, error) {
+ if snapshot == nil {
+ return nil, 0, nil, nil
+ }
+ var msgs []M
+ if err := json.Unmarshal(snapshot.Messages, &msgs); err != nil {
+ return nil, 0, nil, err
+ }
+ var toolInfos []*schema.ToolInfo
+ if len(snapshot.ToolInfos) > 0 {
+ if err := json.Unmarshal(snapshot.ToolInfos, &toolInfos); err != nil {
+ return nil, 0, nil, err
+ }
+ }
+ return msgs, snapshot.Cursor, toolInfos, nil
+}
+
+func hasMemoryWritesSince[M adk.MessageType](msgs []M, cursor int, stores []runtimeMemoryStore) bool {
+ if cursor < 0 {
+ cursor = 0
+ }
+ for _, msg := range msgs[cursor:] {
+ if isNilMessage(msg) || !isAssistantRole(msg) {
+ continue
+ }
+ for _, tc := range messageToolCalls(msg) {
+ if tc.Function.Name != adkfs.ToolNameWriteFile && tc.Function.Name != adkfs.ToolNameEditFile {
+ continue
+ }
+ if fp, ok := extractFilePath(tc.Function.Arguments); ok && isPathWithinMemoryStores(stores, fp) {
+ return true
+ }
+ }
+ }
+ return false
+}
+
+func isPathWithinMemoryStores(stores []runtimeMemoryStore, filePath string) bool {
+ if filePath == "" {
+ return false
+ }
+ if filepath.IsAbs(filePath) {
+ for _, store := range stores {
+ if isPathWithinMemoryDir(store.Path, filePath) {
+ return true
+ }
+ }
+ return false
+ }
+
+ clean := filepath.ToSlash(filepath.Clean(filePath))
+ for _, store := range stores {
+ name := filepath.ToSlash(store.displayName())
+ if clean == name || strings.HasPrefix(clean, name+"/") {
+ return true
+ }
+ }
+ return len(stores) == 1 && isPathWithinMemoryDir(stores[0].Path, filePath)
+}
+
+func countModelVisibleMessagesSince[M adk.MessageType](msgs []M, cursor int) int {
+ if cursor < 0 {
+ cursor = 0
+ }
+ if cursor >= len(msgs) {
+ return 0
+ }
+ return countModelVisibleMessages(msgs[cursor:])
+}
+
+func parseRFC3339NanoBestEffort(s string) time.Time {
+ if s == "" {
+ return time.Time{}
+ }
+ if t, err := time.Parse(time.RFC3339Nano, s); err == nil {
+ return t
+ }
+ if t, err := time.Parse(time.RFC3339, s); err == nil {
+ return t
+ }
+ return time.Time{}
+}
+
+func (s runtimeMemoryStore) displayName() string {
+ if strings.TrimSpace(s.Name) != "" {
+ return strings.TrimSpace(s.Name)
+ }
+ return s.Path
+}
+
+func (m *middleware[M]) coordinatorKey(sessionID string) string {
+ if sessionID == "" || m == nil || len(m.memoryStores) == 0 {
+ return ""
+ }
+ paths := m.memoryStorePaths()
+ sort.Strings(paths)
+ return strings.Join(paths, "\n") + "::" + sessionID
+}
+
+func (m *middleware[M]) memoryStorePaths() []string {
+ if m == nil || len(m.memoryStores) == 0 {
+ return nil
+ }
+ paths := make([]string, 0, len(m.memoryStores))
+ for _, store := range m.memoryStores {
+ paths = append(paths, store.Path)
+ }
+ return paths
+}
+
+func (m *middleware[M]) memoryIndexEnabled() bool {
+ return m != nil && m.cfg != nil && m.cfg.Read != nil && m.cfg.Read.Index != nil &&
+ m.cfg.Read.Index.EnableMemoryIndex != nil && *m.cfg.Read.Index.EnableMemoryIndex
+}
+
+func (m *middleware[M]) onErr(ctx context.Context, stage ErrorStage, err error) {
+ if err == nil {
+ return
+ }
+ if m.cfg != nil && m.cfg.OnError != nil {
+ m.cfg.OnError(ctx, stage, err)
+ }
+}
+
+func (m *middleware[M]) lastUserMessage(agentIn *adk.TypedAgentInput[M]) (M, bool) {
+ if agentIn == nil || len(agentIn.Messages) == 0 {
+ return nil, false
+ }
+ if m.cfg.Read.TopicSelection == nil || m.topicSelectionModel == nil {
+ return nil, false
+ }
+ for i := len(agentIn.Messages) - 1; i >= 0; i-- {
+ msg := agentIn.Messages[i]
+ if isNilMessage(msg) || !isUserRole(msg) || isAutomemoryReminderMessage(msg) {
+ continue
+ }
+ return msg, true
+ }
+ return nil, false
+}
+
+func (m *middleware[M]) topicSelectionTopK() int {
+ topK := m.cfg.Read.TopicSelection.TopK
+ if topK <= 0 {
+ return defaultTopicTopK
+ }
+ return topK
+}
+
+func (m *middleware[M]) resolveSessionID(ctx context.Context, state *adk.TypedChatModelAgentState[M]) (string, error) {
+ if m.coordination != nil && m.coordination.SessionIDFunc != nil {
+ return m.coordination.SessionIDFunc(ctx, state)
+ }
+ return getOrInitWriteSessionID(ctx), nil
+}
+
+func (m *middleware[M]) sendTopicMemoryEvent(ctx context.Context, msgs []M, memMsg M) {
+ var beforeID string
+ if len(msgs) > 0 && !isNilMessage(msgs[len(msgs)-1]) {
+ beforeID = adk.GetMessageID(msgs[len(msgs)-1])
+ }
+ if sendEventErr := adk.TypedSendEvent(ctx, &adk.TypedAgentEvent[M]{SessionEvent: &adk.SessionEvent[M]{
+ Kind: adk.SessionEventMessageInserted,
+ MessageInserted: &adk.MessageInsertedEvent[M]{
+ Message: memMsg,
+ BeforeMessageID: beforeID,
+ },
+ }}); sendEventErr != nil {
+ m.onErr(ctx, OnErrorStageSendSessionEvent, sendEventErr)
+ }
+}