From ac39c4521337a146f0668951aadd872c7437b61e Mon Sep 17 00:00:00 2001 From: tangchaojun Date: Mon, 1 Jun 2026 17:08:41 +0800 Subject: [PATCH 1/9] fix: improve langfuse trace io and token usage --- callbacks/cozeloop/data_parser.go | 72 ++++++++++--- callbacks/cozeloop/data_parser_test.go | 110 ++++++++++++++++++++ callbacks/langfuse/README.md | 10 ++ callbacks/langfuse/README_zh.md | 8 ++ callbacks/langfuse/langfuse.go | 124 +++++++++++++++++++--- callbacks/langfuse/langfuse_test.go | 133 ++++++++++++++++++++++-- callbacks/langfuse/trace.go | 5 +- components/model/claude/claude.go | 77 +++++++++++++- components/model/claude/claude_test.go | 73 +++++++++++++ devops/go.mod | 11 +- devops/go.sum | 26 ++--- devops/internal/model/container.go | 23 +++- devops/internal/model/container_test.go | 81 +++++++++++++++ devops/internal/model/types.go | 3 + 14 files changed, 700 insertions(+), 56 deletions(-) diff --git a/callbacks/cozeloop/data_parser.go b/callbacks/cozeloop/data_parser.go index f1a752300..411248e1c 100644 --- a/callbacks/cozeloop/data_parser.go +++ b/callbacks/cozeloop/data_parser.go @@ -458,13 +458,7 @@ func (d defaultDataParser) ParseChatModelStreamOutput(ctx context.Context, outpu chunks = append(chunks, cbOutput.Message) } - if cbOutput.TokenUsage != nil { - usage = &model.TokenUsage{ - PromptTokens: cbOutput.TokenUsage.PromptTokens, - CompletionTokens: cbOutput.TokenUsage.CompletionTokens, - TotalTokens: cbOutput.TokenUsage.TotalTokens, - } - } + usage = mergeCumulativeTokenUsage(usage, cbOutput.TokenUsage) if cbOutput.Config != nil && !onceSet { onceSet = true @@ -532,14 +526,9 @@ func (d defaultDataParser) ParseAgenticModelStreamOutput(ctx context.Context, ou chunks = append(chunks, cbOutput.Message) } - if cbOutput.TokenUsage != nil { - usage = &model.TokenUsage{ - PromptTokens: cbOutput.TokenUsage.PromptTokens, - CompletionTokens: cbOutput.TokenUsage.CompletionTokens, - TotalTokens: cbOutput.TokenUsage.TotalTokens, - PromptTokenDetails: cbOutput.TokenUsage.PromptTokenDetails, - CompletionTokensDetails: cbOutput.TokenUsage.CompletionTokensDetails, - } + usage = mergeCumulativeTokenUsage(usage, cbOutput.TokenUsage) + if cbOutput.Message != nil && cbOutput.Message.ResponseMeta != nil { + usage = mergeCumulativeTokenUsage(usage, schemaTokenUsageToModelTokenUsage(cbOutput.Message.ResponseMeta.TokenUsage)) } if cbOutput.Config != nil && !onceSet { @@ -586,6 +575,59 @@ func (d defaultDataParser) ParseAgenticModelStreamOutput(ctx context.Context, ou return tags } +// mergeCumulativeTokenUsage keeps the final request-level token usage from a +// stream. Streaming callbacks may carry partial cumulative snapshots, and +// TokenUsage does not preserve field-presence metadata, so each monotonically +// increasing counter is merged by its largest observed value. +func mergeCumulativeTokenUsage(dst, src *model.TokenUsage) *model.TokenUsage { + if src == nil { + return dst + } + if dst == nil { + dst = &model.TokenUsage{} + } + + if src.PromptTokens > dst.PromptTokens { + dst.PromptTokens = src.PromptTokens + } + if src.CompletionTokens > dst.CompletionTokens { + dst.CompletionTokens = src.CompletionTokens + } + if src.TotalTokens > dst.TotalTokens { + dst.TotalTokens = src.TotalTokens + } + if src.PromptTokenDetails.CachedTokens > dst.PromptTokenDetails.CachedTokens { + dst.PromptTokenDetails.CachedTokens = src.PromptTokenDetails.CachedTokens + } + if src.CompletionTokensDetails.ReasoningTokens > dst.CompletionTokensDetails.ReasoningTokens { + dst.CompletionTokensDetails.ReasoningTokens = src.CompletionTokensDetails.ReasoningTokens + } + + if total := dst.PromptTokens + dst.CompletionTokens; total > dst.TotalTokens { + dst.TotalTokens = total + } + + return dst +} + +func schemaTokenUsageToModelTokenUsage(usage *schema.TokenUsage) *model.TokenUsage { + if usage == nil { + return nil + } + + return &model.TokenUsage{ + PromptTokens: usage.PromptTokens, + PromptTokenDetails: model.PromptTokenDetails{ + CachedTokens: usage.PromptTokenDetails.CachedTokens, + }, + CompletionTokens: usage.CompletionTokens, + CompletionTokensDetails: model.CompletionTokensDetails{ + ReasoningTokens: usage.CompletionTokensDetails.ReasoningTokens, + }, + TotalTokens: usage.TotalTokens, + } +} + func (d defaultDataParser) ParseDefaultStreamInput(ctx context.Context, input *schema.StreamReader[callbacks.CallbackInput]) (chunks []any, err error) { for { item, recvErr := input.Recv() diff --git a/callbacks/cozeloop/data_parser_test.go b/callbacks/cozeloop/data_parser_test.go index 1af88f4db..d922b5965 100644 --- a/callbacks/cozeloop/data_parser_test.go +++ b/callbacks/cozeloop/data_parser_test.go @@ -376,6 +376,116 @@ func Test_defaultDataParser_ParseOutput(t *testing.T) { }) } +func Test_defaultDataParser_ParseChatModelStreamOutput_MergeCumulativeTokenUsage(t *testing.T) { + mockey.PatchConvey("测试 ChatModel 流式输出合并 token usage", t, func() { + outsr, outsw := schema.Pipe[callbacks.CallbackOutput](3) + outsw.Send(&model.CallbackOutput{ + Message: &schema.Message{Role: schema.Assistant, Content: "assistant"}, + TokenUsage: &model.TokenUsage{ + PromptTokens: 6900, + PromptTokenDetails: model.PromptTokenDetails{ + CachedTokens: 3265, + }, + CompletionTokens: 1, + TotalTokens: 6901, + }, + }, nil) + outsw.Send(&model.CallbackOutput{ + Message: &schema.Message{Role: schema.Assistant, Content: " message"}, + TokenUsage: &model.TokenUsage{ + CompletionTokens: 69, + }, + }, nil) + outsw.Close() + + d := defaultDataParser{} + result := d.ParseChatModelStreamOutput(context.Background(), outsr) + + convey.So(result[tracespec.InputTokens], convey.ShouldEqual, 6900) + convey.So(result[tracespec.InputCachedTokens], convey.ShouldEqual, 3265) + convey.So(result[tracespec.OutputTokens], convey.ShouldEqual, 69) + convey.So(result[tracespec.Tokens], convey.ShouldEqual, 6969) + }) + + mockey.PatchConvey("测试 ChatModel 流式输出使用最终累计 token usage", t, func() { + outsr, outsw := schema.Pipe[callbacks.CallbackOutput](3) + outsw.Send(&model.CallbackOutput{ + Message: &schema.Message{Role: schema.Assistant, Content: "assistant"}, + TokenUsage: &model.TokenUsage{ + PromptTokens: 2679, + CompletionTokens: 3, + TotalTokens: 2682, + }, + }, nil) + outsw.Send(&model.CallbackOutput{ + Message: &schema.Message{Role: schema.Assistant, Content: " message"}, + TokenUsage: &model.TokenUsage{ + PromptTokens: 10682, + CompletionTokens: 510, + TotalTokens: 11192, + }, + }, nil) + outsw.Close() + + d := defaultDataParser{} + result := d.ParseChatModelStreamOutput(context.Background(), outsr) + + convey.So(result[tracespec.InputTokens], convey.ShouldEqual, 10682) + convey.So(result[tracespec.OutputTokens], convey.ShouldEqual, 510) + convey.So(result[tracespec.Tokens], convey.ShouldEqual, 11192) + }) +} + +func Test_defaultDataParser_ParseAgenticModelStreamOutput_MergeMessageMetaTokenUsage(t *testing.T) { + mockey.PatchConvey("测试 AgenticModel 流式输出合并 message meta token usage", t, func() { + outsr, outsw := schema.Pipe[callbacks.CallbackOutput](3) + outsw.Send(&model.AgenticCallbackOutput{ + Message: &schema.AgenticMessage{ + Role: schema.AgenticRoleTypeAssistant, + ContentBlocks: []*schema.ContentBlock{ + schema.NewContentBlock(&schema.AssistantGenText{Text: "assistant"}), + }, + ResponseMeta: &schema.AgenticResponseMeta{ + TokenUsage: &schema.TokenUsage{ + PromptTokens: 6900, + PromptTokenDetails: schema.PromptTokenDetails{ + CachedTokens: 3265, + }, + CompletionTokens: 1, + TotalTokens: 6901, + }, + }, + }, + }, nil) + outsw.Send(&model.AgenticCallbackOutput{ + Message: &schema.AgenticMessage{ + Role: schema.AgenticRoleTypeAssistant, + ContentBlocks: []*schema.ContentBlock{ + schema.NewContentBlock(&schema.AssistantGenText{Text: " message"}), + }, + ResponseMeta: &schema.AgenticResponseMeta{ + TokenUsage: &schema.TokenUsage{ + CompletionTokens: 69, + CompletionTokensDetails: schema.CompletionTokensDetails{ + ReasoningTokens: 12, + }, + }, + }, + }, + }, nil) + outsw.Close() + + d := defaultDataParser{} + result := d.ParseAgenticModelStreamOutput(context.Background(), outsr) + + convey.So(result[tracespec.InputTokens], convey.ShouldEqual, 6900) + convey.So(result[tracespec.InputCachedTokens], convey.ShouldEqual, 3265) + convey.So(result[tracespec.OutputTokens], convey.ShouldEqual, 69) + convey.So(result[tracespec.ReasoningTokens], convey.ShouldEqual, 12) + convey.So(result[tracespec.Tokens], convey.ShouldEqual, 6969) + }) +} + // Test_defaultDataParser_tryConcatChunks 为 defaultDataParser 的 tryConcatChunks 方法编写单元测试 func Test_defaultDataParser_tryConcatChunks(t *testing.T) { mockey.PatchConvey("测试 defaultDataParser 的 tryConcatChunks 方法", t, func() { diff --git a/callbacks/langfuse/README.md b/callbacks/langfuse/README.md index d83a6ee16..065293a14 100644 --- a/callbacks/langfuse/README.md +++ b/callbacks/langfuse/README.md @@ -129,6 +129,10 @@ type Config struct { // Public determines if traces are publicly accessible (Optional) Public bool + + // DisableTraceIO disables automatically writing root run input/output + // to trace input/output (Optional) + DisableTraceIO bool } ``` @@ -136,6 +140,12 @@ type Config struct { You can customize individual traces using the `SetTrace` function: +By default, the callback writes the root run input/output to the trace +input/output fields. Use `WithInput` when you want to provide the trace input +manually; use `UpdateTraceOutput` after the run when you want to override the +trace output manually. Set `DisableTraceIO` to `true` to turn off automatic +trace input/output updates. + ```go ctx = langfuse.SetTrace(ctx, langfuse.WithID("trace-id"), diff --git a/callbacks/langfuse/README_zh.md b/callbacks/langfuse/README_zh.md index 6d762236f..46503a8a2 100644 --- a/callbacks/langfuse/README_zh.md +++ b/callbacks/langfuse/README_zh.md @@ -129,6 +129,9 @@ type Config struct { // 是否公开可访问 (选填) Public bool + + // 禁用自动将根运行的输入/输出写入 trace input/output (选填) + DisableTraceIO bool } ``` @@ -136,6 +139,11 @@ type Config struct { 您可以使用 `SetTrace` 函数自定义单个追踪: +默认情况下,回调会将根运行的输入/输出写入 trace input/output 字段。使用 +`WithInput` 可以手动指定 trace input;运行结束后可调用 +`UpdateTraceOutput` 手动覆盖 trace output。将 `DisableTraceIO` 设为 `true` +可关闭自动 trace input/output 更新。 + ```go ctx = langfuse.SetTrace(ctx, langfuse.WithID("trace-id"), diff --git a/callbacks/langfuse/langfuse.go b/callbacks/langfuse/langfuse.go index 0b3edc940..ea56c3bdf 100644 --- a/callbacks/langfuse/langfuse.go +++ b/callbacks/langfuse/langfuse.go @@ -118,6 +118,11 @@ type Config struct { // Default: false // Example: true Public bool + + // DisableTraceIO disables automatically writing root run input/output to the + // Langfuse trace input/output fields (Optional) + // Default: false + DisableTraceIO bool } func NewLangfuseHandler(cfg *Config) (handler *CallbackHandler, flusher func()) { @@ -166,6 +171,8 @@ func NewLangfuseHandler(cfg *Config) (handler *CallbackHandler, flusher func()) release: cfg.Release, tags: cfg.Tags, public: cfg.Public, + + disableTraceIO: cfg.DisableTraceIO, }, cli.Flush } @@ -178,12 +185,16 @@ type CallbackHandler struct { release string tags []string public bool + + disableTraceIO bool } type langfuseStateKey struct{} type langfuseState struct { traceID string observationID string + traceInputSet bool + isRoot bool } func (c *CallbackHandler) OnStart(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context { @@ -195,6 +206,7 @@ func (c *CallbackHandler) OnStart(ctx context.Context, info *callbacks.RunInfo, if state == nil { return ctx } + isRoot := state.observationID == "" if info.Component == components.ComponentOfChatModel { mcbi := model.ConvCallbackInput(input) @@ -219,10 +231,21 @@ func (c *CallbackHandler) OnStart(ctx context.Context, info *callbacks.RunInfo, log.Printf("create generation error: %v, runinfo: %+v", err, info) return ctx } - return context.WithValue(ctx, langfuseStateKey{}, &langfuseState{ + nState := &langfuseState{ traceID: state.traceID, observationID: generationID, - }) + traceInputSet: state.traceInputSet, + isRoot: isRoot, + } + if c.shouldUpdateTraceInput(nState) { + in, err_ := sonic.MarshalString(input) + if err_ != nil { + log.Printf("marshal trace input error: %v, runinfo: %+v", err_, info) + } else { + c.updateTraceInput(ctx, state.traceID, in) + } + } + return context.WithValue(ctx, langfuseStateKey{}, nState) } in, err := sonic.MarshalString(input) @@ -245,10 +268,16 @@ func (c *CallbackHandler) OnStart(ctx context.Context, info *callbacks.RunInfo, log.Printf("create span error: %v", err) return ctx } - return context.WithValue(ctx, langfuseStateKey{}, &langfuseState{ + nState := &langfuseState{ traceID: state.traceID, observationID: spanID, - }) + traceInputSet: state.traceInputSet, + isRoot: isRoot, + } + if c.shouldUpdateTraceInput(nState) { + c.updateTraceInput(ctx, state.traceID, in) + } + return context.WithValue(ctx, langfuseStateKey{}, nState) } func (c *CallbackHandler) OnEnd(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context { @@ -283,6 +312,14 @@ func (c *CallbackHandler) OnEnd(ctx context.Context, info *callbacks.RunInfo, ou if err != nil { log.Printf("end generation error: %v, runinfo: %+v", err, info) } + if c.shouldUpdateTraceOutput(state) { + out, err_ := sonic.MarshalString(output) + if err_ != nil { + log.Printf("marshal trace output error: %v, runinfo: %+v", err_, info) + } else { + c.UpdateTraceOutput(ctx, state.traceID, out) + } + } return ctx } @@ -303,6 +340,9 @@ func (c *CallbackHandler) OnEnd(ctx context.Context, info *callbacks.RunInfo, ou if err != nil { log.Printf("end span fail: %v, runinfo: %+v", err, info) } + if c.shouldUpdateTraceOutput(state) { + c.UpdateTraceOutput(ctx, state.traceID, out) + } return ctx } @@ -334,6 +374,9 @@ func (c *CallbackHandler) OnError(ctx context.Context, info *callbacks.RunInfo, if reportErr != nil { log.Printf("end generation fail: %v, runinfo: %+v, execute error: %v", reportErr, info, err) } + if c.shouldUpdateTraceOutput(state) { + c.UpdateTraceOutput(ctx, state.traceID, err.Error()) + } return ctx } @@ -350,6 +393,9 @@ func (c *CallbackHandler) OnError(ctx context.Context, info *callbacks.RunInfo, if reportErr != nil { log.Printf("end span fail: %v, runinfo: %+v, execute error: %v", reportErr, info, err) } + if c.shouldUpdateTraceOutput(state) { + c.UpdateTraceOutput(ctx, state.traceID, err.Error()) + } return ctx } @@ -362,6 +408,7 @@ func (c *CallbackHandler) OnStartWithStreamInput(ctx context.Context, info *call if state == nil { return ctx } + isRoot := state.observationID == "" if info.Component == components.ComponentOfChatModel { generationID, err := c.cli.CreateGeneration(&langfuse.GenerationEventBody{ @@ -378,6 +425,12 @@ func (c *CallbackHandler) OnStartWithStreamInput(ctx context.Context, info *call log.Printf("create generation error: %v, runinfo: %+v", err, info) return ctx } + nState := &langfuseState{ + traceID: state.traceID, + observationID: generationID, + traceInputSet: state.traceInputSet, + isRoot: isRoot, + } go func() { defer func() { @@ -420,12 +473,17 @@ func (c *CallbackHandler) OnStartWithStreamInput(ctx context.Context, info *call if err != nil { log.Printf("update stream generation fail: %v, runinfo: %+v", err, info) } + if c.shouldUpdateTraceInput(nState) { + in, err__ := sonic.MarshalString(ins) + if err__ != nil { + log.Printf("marshal trace stream input error: %v, runinfo: %+v", err__, info) + } else { + c.updateTraceInput(ctx, nState.traceID, in) + } + } }() - return context.WithValue(ctx, langfuseStateKey{}, &langfuseState{ - traceID: state.traceID, - observationID: generationID, - }) + return context.WithValue(ctx, langfuseStateKey{}, nState) } spanID, err := c.cli.CreateSpan(&langfuse.SpanEventBody{ @@ -442,6 +500,12 @@ func (c *CallbackHandler) OnStartWithStreamInput(ctx context.Context, info *call log.Printf("create span error: %v", err) return ctx } + nState := &langfuseState{ + traceID: state.traceID, + observationID: spanID, + traceInputSet: state.traceInputSet, + isRoot: isRoot, + } go func() { defer func() { @@ -480,12 +544,12 @@ func (c *CallbackHandler) OnStartWithStreamInput(ctx context.Context, info *call if err != nil { log.Printf("update stream span error: %v", err) } + if c.shouldUpdateTraceInput(nState) { + c.updateTraceInput(ctx, nState.traceID, in) + } }() - return context.WithValue(ctx, langfuseStateKey{}, &langfuseState{ - traceID: state.traceID, - observationID: spanID, - }) + return context.WithValue(ctx, langfuseStateKey{}, nState) } func (c *CallbackHandler) OnEndWithStreamOutput(ctx context.Context, info *callbacks.RunInfo, output *schema.StreamReader[callbacks.CallbackOutput]) context.Context { @@ -545,6 +609,14 @@ func (c *CallbackHandler) OnEndWithStreamOutput(ctx context.Context, info *callb if err != nil { log.Printf("end stream generation error: %v, runinfo: %+v", err, info) } + if c.shouldUpdateTraceOutput(state) { + out, err_ := sonic.MarshalString(outs) + if err_ != nil { + log.Printf("marshal trace stream output error: %v, runinfo: %+v", err_, info) + } else { + c.UpdateTraceOutput(ctx, state.traceID, out) + } + } }() return ctx } @@ -585,11 +657,39 @@ func (c *CallbackHandler) OnEndWithStreamOutput(ctx context.Context, info *callb if err != nil { log.Printf("end stream span fail: %v, runinfo: %+v", err, info) } + if c.shouldUpdateTraceOutput(state) { + c.UpdateTraceOutput(ctx, state.traceID, out) + } }() return ctx } +func (c *CallbackHandler) updateTraceInput(ctx context.Context, traceID string, input string) { + _ = ctx + err := c.cli.EndTrace(&langfuse.TraceEventBody{ + BaseEventBody: langfuse.BaseEventBody{ + ID: traceID, + }, + Input: input, + }) + if err != nil { + log.Printf("input end trace fail: %v, traceID: %s", err, traceID) + } +} + +func (c *CallbackHandler) shouldUpdateTraceInput(state *langfuseState) bool { + return c.shouldUpdateTraceIO(state) && !state.traceInputSet +} + +func (c *CallbackHandler) shouldUpdateTraceOutput(state *langfuseState) bool { + return c.shouldUpdateTraceIO(state) +} + +func (c *CallbackHandler) shouldUpdateTraceIO(state *langfuseState) bool { + return state != nil && state.isRoot && !c.disableTraceIO && len(state.traceID) > 0 +} + // UpdateTraceOutput pushes final trace output to Langfuse (via ACL EndTrace). // ctx is reserved for future cancellation / deadline propagation; callers may pass context.Background() for now. func (c *CallbackHandler) UpdateTraceOutput(ctx context.Context, traceID string, output string) { diff --git a/callbacks/langfuse/langfuse_test.go b/callbacks/langfuse/langfuse_test.go index b0c0a1675..37984c7fb 100644 --- a/callbacks/langfuse/langfuse_test.go +++ b/callbacks/langfuse/langfuse_test.go @@ -42,12 +42,13 @@ func TestLangfuseCallback(t *testing.T) { mockLangfuse := mock.NewMockLangfuse(ctrl) defer mockey.Mock(langfuse.NewLangfuse).Return(mockLangfuse).Build().UnPatch() cbh, _ := NewLangfuseHandler(&Config{ - Name: "MyTrace", - UserID: "user id", - SessionID: "session", - Release: "release", - Tags: []string{"tag1", "tag2"}, - Public: true, + Name: "MyTrace", + UserID: "user id", + SessionID: "session", + Release: "release", + Tags: []string{"tag1", "tag2"}, + Public: true, + DisableTraceIO: true, }) callbacks.InitCallbackHandlers([]callbacks.Handler{cbh}) ctx := context.Background() @@ -303,14 +304,88 @@ func TestLangfuseCallback(t *testing.T) { WithPublic(true), WithEnvironment("development"), WithVersion("version"), + WithInput("manual input"), ) assert.Equal(t, "traceid", ctx.Value(langfuseTraceOptionKey{}).(*traceOptions).ID) assert.Equal(t, "development", ctx.Value(langfuseTraceOptionKey{}).(*traceOptions).Environment) assert.Equal(t, "version", ctx.Value(langfuseTraceOptionKey{}).(*traceOptions).Version) + assert.Equal(t, "manual input", ctx.Value(langfuseTraceOptionKey{}).(*traceOptions).Input) + assert.True(t, ctx.Value(langfuseTraceOptionKey{}).(*traceOptions).inputSet) }) } -func TestAttack_NilMessageInOnEnd(t *testing.T) { +func TestTraceIOAutoPromotion(t *testing.T) { + ctrl := gomock.NewController(t) + mockLangfuse := mock.NewMockLangfuse(ctrl) + defer mockey.Mock(langfuse.NewLangfuse).Return(mockLangfuse).Build().UnPatch() + + cbh, _ := NewLangfuseHandler(&Config{ + Host: "http://localhost", PublicKey: "pk", SecretKey: "sk", + Name: "trace", + }) + + mockLangfuse.EXPECT().CreateTrace(gomock.Any()).Return("trace-id", nil).Times(1) + mockLangfuse.EXPECT().CreateSpan(gomock.Any()).DoAndReturn(func(body *langfuse.SpanEventBody) (string, error) { + assert.Equal(t, "trace-id", body.TraceID) + assert.Empty(t, body.ParentObservationID) + assert.Equal(t, "\"input\"", body.Input) + return "span-id", nil + }).Times(1) + mockLangfuse.EXPECT().EndSpan(gomock.Any()).DoAndReturn(func(body *langfuse.SpanEventBody) error { + assert.Equal(t, "span-id", body.ID) + assert.Equal(t, "\"output\"", body.Output) + return nil + }).Times(1) + + var traceInputs, traceOutputs []string + mockLangfuse.EXPECT().EndTrace(gomock.Any()).DoAndReturn(func(body *langfuse.TraceEventBody) error { + assert.Equal(t, "trace-id", body.ID) + if body.Input != "" { + traceInputs = append(traceInputs, body.Input) + } + if body.Output != "" { + traceOutputs = append(traceOutputs, body.Output) + } + return nil + }).Times(2) + + ctx := cbh.OnStart(context.Background(), &callbacks.RunInfo{Name: "root"}, callbacks.CallbackInput("input")) + cbh.OnEnd(ctx, &callbacks.RunInfo{Name: "root"}, callbacks.CallbackOutput("output")) + + assert.Equal(t, []string{"\"input\""}, traceInputs) + assert.Equal(t, []string{"\"output\""}, traceOutputs) +} + +func TestTraceIOManualInputNotOverwritten(t *testing.T) { + ctrl := gomock.NewController(t) + mockLangfuse := mock.NewMockLangfuse(ctrl) + defer mockey.Mock(langfuse.NewLangfuse).Return(mockLangfuse).Build().UnPatch() + + cbh, _ := NewLangfuseHandler(&Config{ + Host: "http://localhost", PublicKey: "pk", SecretKey: "sk", + Name: "trace", + }) + + mockLangfuse.EXPECT().CreateTrace(gomock.Any()).DoAndReturn(func(body *langfuse.TraceEventBody) (string, error) { + assert.Equal(t, "trace-id", body.ID) + assert.Equal(t, "manual input", body.Input) + return "trace-id", nil + }).Times(1) + mockLangfuse.EXPECT().CreateSpan(gomock.Any()).Return("span-id", nil).Times(1) + mockLangfuse.EXPECT().EndSpan(gomock.Any()).Return(nil).Times(1) + mockLangfuse.EXPECT().EndTrace(gomock.Any()).DoAndReturn(func(body *langfuse.TraceEventBody) error { + assert.Equal(t, "trace-id", body.ID) + assert.Empty(t, body.Input) + assert.Equal(t, "\"output\"", body.Output) + return nil + }).Times(1) + + ctx := SetTrace(context.Background(), WithID("trace-id"), WithInput("manual input")) + ctx = cbh.OnStart(ctx, &callbacks.RunInfo{Name: "root"}, callbacks.CallbackInput("auto input")) + cbh.OnEnd(ctx, &callbacks.RunInfo{Name: "root"}, callbacks.CallbackOutput("output")) +} + +func TestTraceIOAutoPromotionStream(t *testing.T) { ctrl := gomock.NewController(t) mockLangfuse := mock.NewMockLangfuse(ctrl) defer mockey.Mock(langfuse.NewLangfuse).Return(mockLangfuse).Build().UnPatch() @@ -319,6 +394,49 @@ func TestAttack_NilMessageInOnEnd(t *testing.T) { Host: "http://localhost", PublicKey: "pk", SecretKey: "sk", Name: "trace", }) + + mockLangfuse.EXPECT().CreateTrace(gomock.Any()).Return("trace-id", nil).Times(1) + mockLangfuse.EXPECT().CreateSpan(gomock.Any()).Return("span-id", nil).Times(1) + mockLangfuse.EXPECT().EndSpan(gomock.Any()).Return(nil).Times(2) + + var traceInputs, traceOutputs []string + mockLangfuse.EXPECT().EndTrace(gomock.Any()).DoAndReturn(func(body *langfuse.TraceEventBody) error { + assert.Equal(t, "trace-id", body.ID) + if body.Input != "" { + traceInputs = append(traceInputs, body.Input) + } + if body.Output != "" { + traceOutputs = append(traceOutputs, body.Output) + } + return nil + }).Times(2) + + insr, insw := schema.Pipe[callbacks.CallbackInput](1) + insw.Send(callbacks.CallbackInput("stream input"), nil) + insw.Close() + ctx := cbh.OnStartWithStreamInput(context.Background(), &callbacks.RunInfo{Name: "root"}, insr) + + outsr, outsw := schema.Pipe[callbacks.CallbackOutput](1) + outsw.Send(callbacks.CallbackOutput("stream output"), nil) + outsw.Close() + cbh.OnEndWithStreamOutput(ctx, &callbacks.RunInfo{Name: "root"}, outsr) + + time.Sleep(100 * time.Millisecond) + + assert.Equal(t, []string{"[\"stream input\"]"}, traceInputs) + assert.Equal(t, []string{"[\"stream output\"]"}, traceOutputs) +} + +func TestAttack_NilMessageInOnEnd(t *testing.T) { + ctrl := gomock.NewController(t) + mockLangfuse := mock.NewMockLangfuse(ctrl) + defer mockey.Mock(langfuse.NewLangfuse).Return(mockLangfuse).Build().UnPatch() + + cbh, _ := NewLangfuseHandler(&Config{ + Host: "http://localhost", PublicKey: "pk", SecretKey: "sk", + Name: "trace", + DisableTraceIO: true, + }) mockLangfuse.EXPECT().CreateTrace(gomock.Any()).Return("trace-id", nil).Times(1) mockLangfuse.EXPECT().CreateGeneration(gomock.Any()).Return("generation-id", nil).Times(1) mockLangfuse.EXPECT().EndGeneration(gomock.Any()).DoAndReturn(func(body *langfuse.GenerationEventBody) error { @@ -344,6 +462,7 @@ func TestAttack_ExtractModelOutputErrorIgnored(t *testing.T) { cbh, _ := NewLangfuseHandler(&Config{ Host: "http://localhost", PublicKey: "pk", SecretKey: "sk", + DisableTraceIO: true, }) mockLangfuse.EXPECT().CreateTrace(gomock.Any()).Return("trace-id", nil).Times(1) mockLangfuse.EXPECT().CreateGeneration(gomock.Any()).Return("generation-id", nil).Times(1) diff --git a/callbacks/langfuse/trace.go b/callbacks/langfuse/trace.go index a1a9350ba..2100c434a 100644 --- a/callbacks/langfuse/trace.go +++ b/callbacks/langfuse/trace.go @@ -54,6 +54,7 @@ func WithUserID(userID string) TraceOption { func WithInput(input string) TraceOption { return func(o *traceOptions) { o.Input = input + o.inputSet = true } } func WithSessionID(sessionID string) TraceOption { @@ -97,6 +98,7 @@ type traceOptions struct { Name string UserID string Input string + inputSet bool SessionID string Release string Tags []string @@ -127,7 +129,8 @@ func initState(_ context.Context, cli langfuse.Langfuse, options *traceOptions) return nil, fmt.Errorf("create trace error: %v", err) } s := &langfuseState{ - traceID: traceID, + traceID: traceID, + traceInputSet: options.inputSet, } return s, nil } diff --git a/components/model/claude/claude.go b/components/model/claude/claude.go index a8da5a321..9f00840b7 100644 --- a/components/model/claude/claude.go +++ b/components/model/claude/claude.go @@ -374,6 +374,7 @@ func (cm *ChatModel) Stream(ctx context.Context, input []*schema.Message, opts . }() var waitList []*schema.Message streamCtx := &streamContext{} + var usage *schema.TokenUsage for stream.Next() { message, err_ := convStreamEvent(stream.Current(), streamCtx) if err_ != nil { @@ -396,6 +397,7 @@ func (cm *ChatModel) Stream(ctx context.Context, input []*schema.Message, opts . } waitList = []*schema.Message{} } + usage = applyClaudeStreamUsage(message, usage) closed := sw.Send(cm.getCallbackOutput(message), nil) if closed { @@ -409,6 +411,7 @@ func (cm *ChatModel) Stream(ctx context.Context, input []*schema.Message, opts . _ = sw.Send(nil, fmt.Errorf("concat empty message fail: %w", err_)) return } + usage = applyClaudeStreamUsage(message, usage) closed := sw.Send(cm.getCallbackOutput(message), nil) if closed { @@ -1356,11 +1359,22 @@ func convStreamEvent(event anthropic.MessageStreamEventUnion, streamCtx *streamC case anthropic.MessageStartEvent: return convOutputMessage(&e.Message) case anthropic.MessageDeltaEvent: + completionTokens := int(e.Usage.OutputTokens) + usage := &schema.TokenUsage{ + CompletionTokens: completionTokens, + } + if hasMessageDeltaPromptUsage(e.Usage) { + promptTokens := int(e.Usage.InputTokens + e.Usage.CacheReadInputTokens + e.Usage.CacheCreationInputTokens) + usage.PromptTokens = promptTokens + usage.PromptTokenDetails = schema.PromptTokenDetails{ + CachedTokens: int(e.Usage.CacheReadInputTokens), + } + usage.TotalTokens = promptTokens + completionTokens + } + result.ResponseMeta = &schema.ResponseMeta{ FinishReason: string(e.Delta.StopReason), - Usage: &schema.TokenUsage{ - CompletionTokens: int(e.Usage.OutputTokens), - }, + Usage: usage, } return result, nil @@ -1409,6 +1423,63 @@ func convStreamEvent(event anthropic.MessageStreamEventUnion, streamCtx *streamC } } +func hasMessageDeltaPromptUsage(usage anthropic.MessageDeltaUsage) bool { + return usage.InputTokens != 0 || + usage.CacheReadInputTokens != 0 || + usage.CacheCreationInputTokens != 0 || + usage.JSON.InputTokens.Valid() || + usage.JSON.CacheReadInputTokens.Valid() || + usage.JSON.CacheCreationInputTokens.Valid() +} + +func applyClaudeStreamUsage(message *schema.Message, usage *schema.TokenUsage) *schema.TokenUsage { + if message == nil || message.ResponseMeta == nil || message.ResponseMeta.Usage == nil { + return usage + } + + usage = mergeClaudeStreamTokenUsage(usage, message.ResponseMeta.Usage) + message.ResponseMeta.Usage = cloneClaudeStreamTokenUsage(usage) + return usage +} + +func mergeClaudeStreamTokenUsage(dst, src *schema.TokenUsage) *schema.TokenUsage { + if src == nil { + return dst + } + if dst == nil { + dst = &schema.TokenUsage{} + } + + if src.PromptTokens > dst.PromptTokens { + dst.PromptTokens = src.PromptTokens + } + if src.CompletionTokens > dst.CompletionTokens { + dst.CompletionTokens = src.CompletionTokens + } + if src.TotalTokens > dst.TotalTokens { + dst.TotalTokens = src.TotalTokens + } + if src.PromptTokenDetails.CachedTokens > dst.PromptTokenDetails.CachedTokens { + dst.PromptTokenDetails.CachedTokens = src.PromptTokenDetails.CachedTokens + } + if src.CompletionTokensDetails.ReasoningTokens > dst.CompletionTokensDetails.ReasoningTokens { + dst.CompletionTokensDetails.ReasoningTokens = src.CompletionTokensDetails.ReasoningTokens + } + + if total := dst.PromptTokens + dst.CompletionTokens; total > dst.TotalTokens { + dst.TotalTokens = total + } + return dst +} + +func cloneClaudeStreamTokenUsage(usage *schema.TokenUsage) *schema.TokenUsage { + if usage == nil { + return nil + } + cloned := *usage + return &cloned +} + func convImageBase64(data string) (string, string, error) { if !strings.HasPrefix(data, "data:") { return "", "", fmt.Errorf("invalid base64 image: %s", data) diff --git a/components/model/claude/claude_test.go b/components/model/claude/claude_test.go index 7e28cf776..442876791 100644 --- a/components/model/claude/claude_test.go +++ b/components/model/claude/claude_test.go @@ -354,6 +354,32 @@ func TestConvStreamEvent(t *testing.T) { assert.NoError(t, err) assert.Equal(t, "end_turn", message.ResponseMeta.FinishReason) assert.Equal(t, 10, message.ResponseMeta.Usage.CompletionTokens) + assert.Equal(t, 0, message.ResponseMeta.Usage.PromptTokens) + assert.Equal(t, 0, message.ResponseMeta.Usage.PromptTokenDetails.CachedTokens) + assert.Equal(t, 0, message.ResponseMeta.Usage.TotalTokens) + }) + + mockey.PatchConvey("message delta event with prompt usage", t, func() { + event := anthropic.MessageStreamEventUnion{} + defer mockey.Mock(anthropic.MessageStreamEventUnion.AsAny).Return(anthropic.MessageDeltaEvent{ + Delta: anthropic.MessageDeltaEventDelta{ + StopReason: "end_turn", + }, + Usage: anthropic.MessageDeltaUsage{ + InputTokens: 5, + CacheReadInputTokens: 3, + CacheCreationInputTokens: 2, + OutputTokens: 10, + }, + }).Build().UnPatch() + + message, err := convStreamEvent(event, streamCtx) + assert.NoError(t, err) + assert.Equal(t, "end_turn", message.ResponseMeta.FinishReason) + assert.Equal(t, 10, message.ResponseMeta.Usage.PromptTokens) + assert.Equal(t, 3, message.ResponseMeta.Usage.PromptTokenDetails.CachedTokens) + assert.Equal(t, 10, message.ResponseMeta.Usage.CompletionTokens) + assert.Equal(t, 20, message.ResponseMeta.Usage.TotalTokens) }) mockey.PatchConvey("content block start event", t, func() { @@ -376,6 +402,53 @@ func TestConvStreamEvent(t *testing.T) { }) } +func TestMergeClaudeStreamTokenUsage(t *testing.T) { + usage := mergeClaudeStreamTokenUsage(nil, &schema.TokenUsage{ + PromptTokens: 6900, + PromptTokenDetails: schema.PromptTokenDetails{ + CachedTokens: 3265, + }, + CompletionTokens: 1, + TotalTokens: 6901, + }) + usage = mergeClaudeStreamTokenUsage(usage, &schema.TokenUsage{ + CompletionTokens: 69, + }) + + assert.Equal(t, 6900, usage.PromptTokens) + assert.Equal(t, 3265, usage.PromptTokenDetails.CachedTokens) + assert.Equal(t, 69, usage.CompletionTokens) + assert.Equal(t, 6969, usage.TotalTokens) +} + +func TestApplyClaudeStreamUsageClonesCumulativeUsage(t *testing.T) { + var usage *schema.TokenUsage + first := &schema.Message{ + ResponseMeta: &schema.ResponseMeta{ + Usage: &schema.TokenUsage{ + PromptTokens: 10, + CompletionTokens: 1, + TotalTokens: 11, + }, + }, + } + usage = applyClaudeStreamUsage(first, usage) + + second := &schema.Message{ + ResponseMeta: &schema.ResponseMeta{ + Usage: &schema.TokenUsage{ + CompletionTokens: 5, + }, + }, + } + usage = applyClaudeStreamUsage(second, usage) + + assert.Equal(t, 11, first.ResponseMeta.Usage.TotalTokens) + assert.Equal(t, 10, second.ResponseMeta.Usage.PromptTokens) + assert.Equal(t, 5, second.ResponseMeta.Usage.CompletionTokens) + assert.Equal(t, 15, second.ResponseMeta.Usage.TotalTokens) +} + func TestPanicErr(t *testing.T) { err := newPanicErr("info", []byte("stack")) assert.Equal(t, "panic error: info, \nstack: stack", err.Error()) diff --git a/devops/go.mod b/devops/go.mod index 6b44d8376..2ea9eee26 100644 --- a/devops/go.mod +++ b/devops/go.mod @@ -4,7 +4,7 @@ go 1.18 require ( github.com/bytedance/mockey v1.2.12 - github.com/cloudwego/eino v0.6.0 + github.com/cloudwego/eino v0.9.1 github.com/gorilla/mux v1.8.1 github.com/matoous/go-nanoid v1.5.1 github.com/stretchr/testify v1.10.0 @@ -15,12 +15,13 @@ require ( github.com/bahlo/generic-list-go v0.2.0 // indirect github.com/buger/jsonparser v1.1.1 // indirect github.com/bytedance/gopkg v0.1.3 // indirect - github.com/bytedance/sonic v1.14.1 // indirect - github.com/bytedance/sonic/loader v0.3.0 // indirect + github.com/bytedance/sonic v1.15.0 // indirect + github.com/bytedance/sonic/loader v0.5.0 // indirect github.com/cloudwego/base64x v0.1.6 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/dustin/go-humanize v1.0.1 // indirect - github.com/eino-contrib/jsonschema v1.0.2 // indirect + github.com/eino-contrib/jsonschema v1.0.3 // indirect + github.com/google/uuid v1.6.0 // indirect github.com/goph/emperror v0.17.2 // indirect github.com/gopherjs/gopherjs v1.17.2 // indirect github.com/json-iterator/go v1.1.12 // indirect @@ -42,6 +43,6 @@ require ( github.com/yargevad/filepathx v1.0.0 // indirect golang.org/x/arch v0.11.0 // indirect golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 // indirect - golang.org/x/sys v0.26.0 // indirect + golang.org/x/sys v0.29.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/devops/go.sum b/devops/go.sum index 1d1626bb6..e8df98b76 100644 --- a/devops/go.sum +++ b/devops/go.sum @@ -11,28 +11,30 @@ github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M github.com/bytedance/gopkg v0.1.3/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM= github.com/bytedance/mockey v1.2.12 h1:aeszOmGw8CPX8CRx1DZ/Glzb1yXvhjDh6jdFBNZjsU4= github.com/bytedance/mockey v1.2.12/go.mod h1:3ZA4MQasmqC87Tw0w7Ygdy7eHIc2xgpZ8Pona5rsYIk= -github.com/bytedance/sonic v1.14.1 h1:FBMC0zVz5XUmE4z9wF4Jey0An5FueFvOsTKKKtwIl7w= -github.com/bytedance/sonic v1.14.1/go.mod h1:gi6uhQLMbTdeP0muCnrjHLeCUPyb70ujhnNlhOylAFc= -github.com/bytedance/sonic/loader v0.3.0 h1:dskwH8edlzNMctoruo8FPTJDF3vLtDT0sXZwvZJyqeA= -github.com/bytedance/sonic/loader v0.3.0/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFosQU6FxH2JmUe6VI= +github.com/bytedance/sonic v1.15.0 h1:/PXeWFaR5ElNcVE84U0dOHjiMHQOwNIx3K4ymzh/uSE= +github.com/bytedance/sonic v1.15.0/go.mod h1:tFkWrPz0/CUCLEF4ri4UkHekCIcdnkqXw9VduqpJh0k= +github.com/bytedance/sonic/loader v0.5.0 h1:gXH3KVnatgY7loH5/TkeVyXPfESoqSBSBEiDd5VjlgE= +github.com/bytedance/sonic/loader v0.5.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo= github.com/certifi/gocertifi v0.0.0-20190105021004-abcd57078448/go.mod h1:GJKEexRPVJrBSOjoqN5VNOIKJ5Q3RViH6eu3puDRwx4= github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= -github.com/cloudwego/eino v0.6.0 h1:pobGKMOfcQHVNhD9UT/HrvO0eYG6FC2ML/NKY2Eb9+Q= -github.com/cloudwego/eino v0.6.0/go.mod h1:JNapfU+QUrFFpboNDrNOFvmz0m9wjBFHHCr77RH6a50= +github.com/cloudwego/eino v0.9.1 h1:eSwgXfsaxmgTXsTgWi9OMBcm8hKvVhb1q0PPk58p6f8= +github.com/cloudwego/eino v0.9.1/go.mod h1:OBD1mrkfkt/pJa4rkg1P0VnaMeOVl7l8IAdEqY//3IQ= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= -github.com/eino-contrib/jsonschema v1.0.2 h1:HaxruBMUdnXa7Lg/lX8g0Hk71ZIfdTZXmBQz0e3esr8= -github.com/eino-contrib/jsonschema v1.0.2/go.mod h1:cpnX4SyKjWjGC7iN2EbhxaTdLqGjCi0e9DxpLYxddD4= +github.com/eino-contrib/jsonschema v1.0.3 h1:2Kfsm1xlMV0ssY2nuxshS4AwbLFuqmPmzIjLVJ1Fsp0= +github.com/eino-contrib/jsonschema v1.0.3/go.mod h1:cpnX4SyKjWjGC7iN2EbhxaTdLqGjCi0e9DxpLYxddD4= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/getsentry/raven-go v0.2.0/go.mod h1:KungGk8q33+aIAZUIVWZDr2OfAEBsO49PX4NzFV5kcQ= github.com/go-check/check v0.0.0-20180628173108-788fd7840127 h1:0gkP6mzaMqkmpcJYCFOLkIBwI7xFExG03bbkOkCvUPI= github.com/gofrs/uuid v3.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/goph/emperror v0.17.2 h1:yLapQcmEsO0ipe9p5TaN22djm3OFV/TfM/fcYP0/J18= github.com/goph/emperror v0.17.2/go.mod h1:+ZbQ+fUNO/6FNiUo0ujtMjhgad9Xa6fQL9KhH4LNHic= github.com/gopherjs/gopherjs v1.17.2 h1:fQnZVsXk8uxXIStYb0N4bGk7jeyTalG/wsZjQ25dO0g= @@ -92,12 +94,12 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= @@ -121,9 +123,9 @@ golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= -golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/term v0.10.0 h1:3R7pNqamzBraeqj/Tj8qt1aQ2HpmlC+Cx/qL/7hn4/c= +golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= +golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.28.0 h1:/Ts8HFuMR2E6IP/jlo7QVLZHggjKQbhu/7H0LJFr3Gg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= diff --git a/devops/internal/model/container.go b/devops/internal/model/container.go index 8445c51ec..c50935933 100644 --- a/devops/internal/model/container.go +++ b/devops/internal/model/container.go @@ -383,7 +383,7 @@ func (g *Graph) addNode(node string, gni compose.GraphNodeInfo, opts ...compose. return fmt.Errorf("component is %s, but get unexpected instance=%v", gni.Component, reflect.TypeOf(gni.Instance)) } return g.AddEmbeddingNode(node, ins, newOpts...) - + case components.ComponentOfRetriever: ins, ok := gni.Instance.(retriever.Retriever) if !ok { @@ -405,6 +405,13 @@ func (g *Graph) addNode(node string, gni compose.GraphNodeInfo, opts ...compose. } return g.AddChatModelNode(node, ins, newOpts...) + case components.ComponentOfAgenticModel: + ins, ok := gni.Instance.(model.AgenticModel) + if !ok { + return fmt.Errorf("component is %s, but get unexpected instance=%v", gni.Component, reflect.TypeOf(gni.Instance)) + } + return g.AddAgenticModelNode(node, ins, newOpts...) + case components.ComponentOfPrompt: ins, ok := gni.Instance.(prompt.ChatTemplate) if !ok { @@ -412,6 +419,13 @@ func (g *Graph) addNode(node string, gni compose.GraphNodeInfo, opts ...compose. } return g.AddChatTemplateNode(node, ins, newOpts...) + case components.ComponentOfAgenticPrompt: + ins, ok := gni.Instance.(prompt.AgenticChatTemplate) + if !ok { + return fmt.Errorf("component is %s, but get unexpected instance=%v", gni.Component, reflect.TypeOf(gni.Instance)) + } + return g.AddAgenticChatTemplateNode(node, ins, newOpts...) + case compose.ComponentOfToolsNode: ins, ok := gni.Instance.(*compose.ToolsNode) if !ok { @@ -419,6 +433,13 @@ func (g *Graph) addNode(node string, gni compose.GraphNodeInfo, opts ...compose. } return g.AddToolsNode(node, ins, newOpts...) + case compose.ComponentOfAgenticToolsNode: + ins, ok := gni.Instance.(*compose.AgenticToolsNode) + if !ok { + return fmt.Errorf("component is %s, but get unexpected instance=%v", gni.Component, reflect.TypeOf(gni.Instance)) + } + return g.AddAgenticToolsNode(node, ins, newOpts...) + case compose.ComponentOfLambda: ins, ok := gni.Instance.(*compose.Lambda) if !ok { diff --git a/devops/internal/model/container_test.go b/devops/internal/model/container_test.go index 6db3dfed8..17a92a2df 100644 --- a/devops/internal/model/container_test.go +++ b/devops/internal/model/container_test.go @@ -30,6 +30,7 @@ import ( devmodel "github.com/cloudwego/eino-ext/devops/model" "github.com/cloudwego/eino/components" + componentmodel "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/components/prompt" "github.com/cloudwego/eino/components/retriever" "github.com/cloudwego/eino/compose" @@ -64,6 +65,21 @@ func (m mockContainerImplV2) Name() string { return m.NN } +type mockAgenticModel struct{} + +func (m *mockAgenticModel) Generate(ctx context.Context, input []*schema.AgenticMessage, + opts ...componentmodel.Option) (*schema.AgenticMessage, error) { + if len(input) == 0 { + return schema.UserAgenticMessage("mock"), nil + } + return input[0], nil +} + +func (m *mockAgenticModel) Stream(ctx context.Context, input []*schema.AgenticMessage, + opts ...componentmodel.Option) (*schema.StreamReader[*schema.AgenticMessage], error) { + return nil, nil +} + type testCtxKey struct{} type testCallback struct { @@ -78,6 +94,41 @@ func (tt *testCallback) OnFinish(ctx context.Context, graphInfo *compose.GraphIn } func Test_GraphInfo_BuildDevGraph(t *testing.T) { + t.Run("graph: agentic prompt and model", func(t *testing.T) { + g := compose.NewGraph[map[string]any, *schema.AgenticMessage]() + err := g.AddAgenticChatTemplateNode("prompt", + prompt.FromAgenticMessages(schema.FString, schema.UserAgenticMessage("{query}"))) + assert.NoError(t, err) + err = g.AddAgenticModelNode("model", &mockAgenticModel{}) + assert.NoError(t, err) + err = g.AddEdge(compose.START, "prompt") + assert.NoError(t, err) + err = g.AddEdge("prompt", "model") + assert.NoError(t, err) + err = g.AddEdge("model", compose.END) + assert.NoError(t, err) + + tc := &testCallback{} + ctx := context.Background() + _, err = g.Compile(ctx, compose.WithGraphCompileCallbacks(tc)) + assert.NoError(t, err) + + ng, err := BuildDevGraph(tc.gi, compose.START) + assert.NoError(t, err) + + r, err := ng.Compile() + assert.NoError(t, err) + + input, err := UnmarshalJson([]byte(`{"query":{"_eino_go_type":"string","_value":"hello"}}`), ng.GraphInfo.InputType) + assert.NoError(t, err) + resp, err := r.Invoke(ctx, input) + assert.NoError(t, err) + + msg, ok := resp.(*schema.AgenticMessage) + assert.True(t, ok) + assert.Equal(t, schema.AgenticRoleTypeUser, msg.Role) + }) + t.Run("graph-chain: add chain, stateGraph,graph node", func(t *testing.T) { type mockInputType struct { Input string `json:"input"` @@ -1387,6 +1438,26 @@ func Test_Graph_addNode(t *testing.T) { assert.NoError(t, err) }) + t.Run("AgenticModel", func(t *testing.T) { + g := &Graph{Graph: compose.NewGraph[any, any](compose.WithGenLocalState(genState))} + gni := compose.GraphNodeInfo{ + Component: components.ComponentOfAgenticModel, + Instance: &mockAgenticModel{}, + } + err := g.addNode("node_1", gni) + assert.NoError(t, err) + }) + + t.Run("AgenticPrompt", func(t *testing.T) { + g := &Graph{Graph: compose.NewGraph[any, any](compose.WithGenLocalState(genState))} + gni := compose.GraphNodeInfo{ + Component: components.ComponentOfAgenticPrompt, + Instance: prompt.FromAgenticMessages(schema.FString, schema.UserAgenticMessage("hi")), + } + err := g.addNode("node_1", gni) + assert.NoError(t, err) + }) + t.Run("ToolsNode", func(t *testing.T) { g := &Graph{Graph: compose.NewGraph[any, any](compose.WithGenLocalState(genState))} gni := compose.GraphNodeInfo{ @@ -1397,6 +1468,16 @@ func Test_Graph_addNode(t *testing.T) { assert.NoError(t, err) }) + t.Run("AgenticToolsNode", func(t *testing.T) { + g := &Graph{Graph: compose.NewGraph[any, any](compose.WithGenLocalState(genState))} + gni := compose.GraphNodeInfo{ + Component: compose.ComponentOfAgenticToolsNode, + Instance: &compose.AgenticToolsNode{}, + } + err := g.addNode("node_1", gni) + assert.NoError(t, err) + }) + t.Run("Graph", func(t *testing.T) { g := &Graph{Graph: compose.NewGraph[any, any](compose.WithGenLocalState(genState))} gni := compose.GraphNodeInfo{ diff --git a/devops/internal/model/types.go b/devops/internal/model/types.go index c72b0dd35..f85df038f 100644 --- a/devops/internal/model/types.go +++ b/devops/internal/model/types.go @@ -53,6 +53,9 @@ var registeredTypes = []RegisteredType{ {Identifier: "*schema.Message", Type: generic.TypeOf[*schema.Message]()}, {Identifier: "schema.Message", Type: generic.TypeOf[schema.Message]()}, {Identifier: "[]*schema.Message", Type: generic.TypeOf[[]*schema.Message]()}, + {Identifier: "*schema.AgenticMessage", Type: generic.TypeOf[*schema.AgenticMessage]()}, + {Identifier: "schema.AgenticMessage", Type: generic.TypeOf[schema.AgenticMessage]()}, + {Identifier: "[]*schema.AgenticMessage", Type: generic.TypeOf[[]*schema.AgenticMessage]()}, {Identifier: "*schema.Document", Type: generic.TypeOf[*schema.Document]()}, {Identifier: "schema.Document", Type: generic.TypeOf[schema.Document]()}, {Identifier: "[]*schema.Document", Type: generic.TypeOf[[]*schema.Document]()}, From 9b0150bb03494781915b0d9df1577cd5ca465fc8 Mon Sep 17 00:00:00 2001 From: tangchaojun Date: Mon, 1 Jun 2026 19:20:35 +0800 Subject: [PATCH 2/9] fix: synchronize langfuse stream trace io test --- callbacks/langfuse/langfuse_test.go | 33 ++++++++++++++++++++++------- 1 file changed, 25 insertions(+), 8 deletions(-) diff --git a/callbacks/langfuse/langfuse_test.go b/callbacks/langfuse/langfuse_test.go index 37984c7fb..c9f8b6366 100644 --- a/callbacks/langfuse/langfuse_test.go +++ b/callbacks/langfuse/langfuse_test.go @@ -399,14 +399,17 @@ func TestTraceIOAutoPromotionStream(t *testing.T) { mockLangfuse.EXPECT().CreateSpan(gomock.Any()).Return("span-id", nil).Times(1) mockLangfuse.EXPECT().EndSpan(gomock.Any()).Return(nil).Times(2) - var traceInputs, traceOutputs []string + type traceIOEvent struct { + id string + input string + output string + } + traceEvents := make(chan traceIOEvent, 2) mockLangfuse.EXPECT().EndTrace(gomock.Any()).DoAndReturn(func(body *langfuse.TraceEventBody) error { - assert.Equal(t, "trace-id", body.ID) - if body.Input != "" { - traceInputs = append(traceInputs, body.Input) - } - if body.Output != "" { - traceOutputs = append(traceOutputs, body.Output) + traceEvents <- traceIOEvent{ + id: body.ID, + input: body.Input, + output: body.Output, } return nil }).Times(2) @@ -421,7 +424,21 @@ func TestTraceIOAutoPromotionStream(t *testing.T) { outsw.Close() cbh.OnEndWithStreamOutput(ctx, &callbacks.RunInfo{Name: "root"}, outsr) - time.Sleep(100 * time.Millisecond) + var traceInputs, traceOutputs []string + for i := 0; i < 2; i++ { + select { + case event := <-traceEvents: + assert.Equal(t, "trace-id", event.id) + if event.input != "" { + traceInputs = append(traceInputs, event.input) + } + if event.output != "" { + traceOutputs = append(traceOutputs, event.output) + } + case <-time.After(time.Second): + t.Fatal("timed out waiting for trace IO update") + } + } assert.Equal(t, []string{"[\"stream input\"]"}, traceInputs) assert.Equal(t, []string{"[\"stream output\"]"}, traceOutputs) From bdd0903086ed156b907b8a98c106e6ce9de43044 Mon Sep 17 00:00:00 2001 From: tangchaojun Date: Wed, 3 Jun 2026 10:27:37 +0800 Subject: [PATCH 3/9] fix: remove automatic langfuse trace io updates --- callbacks/langfuse/README.md | 10 -- callbacks/langfuse/README_zh.md | 8 -- callbacks/langfuse/langfuse.go | 124 +++-------------------- callbacks/langfuse/langfuse_test.go | 150 ++-------------------------- callbacks/langfuse/trace.go | 5 +- 5 files changed, 20 insertions(+), 277 deletions(-) diff --git a/callbacks/langfuse/README.md b/callbacks/langfuse/README.md index 065293a14..d83a6ee16 100644 --- a/callbacks/langfuse/README.md +++ b/callbacks/langfuse/README.md @@ -129,10 +129,6 @@ type Config struct { // Public determines if traces are publicly accessible (Optional) Public bool - - // DisableTraceIO disables automatically writing root run input/output - // to trace input/output (Optional) - DisableTraceIO bool } ``` @@ -140,12 +136,6 @@ type Config struct { You can customize individual traces using the `SetTrace` function: -By default, the callback writes the root run input/output to the trace -input/output fields. Use `WithInput` when you want to provide the trace input -manually; use `UpdateTraceOutput` after the run when you want to override the -trace output manually. Set `DisableTraceIO` to `true` to turn off automatic -trace input/output updates. - ```go ctx = langfuse.SetTrace(ctx, langfuse.WithID("trace-id"), diff --git a/callbacks/langfuse/README_zh.md b/callbacks/langfuse/README_zh.md index 46503a8a2..6d762236f 100644 --- a/callbacks/langfuse/README_zh.md +++ b/callbacks/langfuse/README_zh.md @@ -129,9 +129,6 @@ type Config struct { // 是否公开可访问 (选填) Public bool - - // 禁用自动将根运行的输入/输出写入 trace input/output (选填) - DisableTraceIO bool } ``` @@ -139,11 +136,6 @@ type Config struct { 您可以使用 `SetTrace` 函数自定义单个追踪: -默认情况下,回调会将根运行的输入/输出写入 trace input/output 字段。使用 -`WithInput` 可以手动指定 trace input;运行结束后可调用 -`UpdateTraceOutput` 手动覆盖 trace output。将 `DisableTraceIO` 设为 `true` -可关闭自动 trace input/output 更新。 - ```go ctx = langfuse.SetTrace(ctx, langfuse.WithID("trace-id"), diff --git a/callbacks/langfuse/langfuse.go b/callbacks/langfuse/langfuse.go index ea56c3bdf..0b3edc940 100644 --- a/callbacks/langfuse/langfuse.go +++ b/callbacks/langfuse/langfuse.go @@ -118,11 +118,6 @@ type Config struct { // Default: false // Example: true Public bool - - // DisableTraceIO disables automatically writing root run input/output to the - // Langfuse trace input/output fields (Optional) - // Default: false - DisableTraceIO bool } func NewLangfuseHandler(cfg *Config) (handler *CallbackHandler, flusher func()) { @@ -171,8 +166,6 @@ func NewLangfuseHandler(cfg *Config) (handler *CallbackHandler, flusher func()) release: cfg.Release, tags: cfg.Tags, public: cfg.Public, - - disableTraceIO: cfg.DisableTraceIO, }, cli.Flush } @@ -185,16 +178,12 @@ type CallbackHandler struct { release string tags []string public bool - - disableTraceIO bool } type langfuseStateKey struct{} type langfuseState struct { traceID string observationID string - traceInputSet bool - isRoot bool } func (c *CallbackHandler) OnStart(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context { @@ -206,7 +195,6 @@ func (c *CallbackHandler) OnStart(ctx context.Context, info *callbacks.RunInfo, if state == nil { return ctx } - isRoot := state.observationID == "" if info.Component == components.ComponentOfChatModel { mcbi := model.ConvCallbackInput(input) @@ -231,21 +219,10 @@ func (c *CallbackHandler) OnStart(ctx context.Context, info *callbacks.RunInfo, log.Printf("create generation error: %v, runinfo: %+v", err, info) return ctx } - nState := &langfuseState{ + return context.WithValue(ctx, langfuseStateKey{}, &langfuseState{ traceID: state.traceID, observationID: generationID, - traceInputSet: state.traceInputSet, - isRoot: isRoot, - } - if c.shouldUpdateTraceInput(nState) { - in, err_ := sonic.MarshalString(input) - if err_ != nil { - log.Printf("marshal trace input error: %v, runinfo: %+v", err_, info) - } else { - c.updateTraceInput(ctx, state.traceID, in) - } - } - return context.WithValue(ctx, langfuseStateKey{}, nState) + }) } in, err := sonic.MarshalString(input) @@ -268,16 +245,10 @@ func (c *CallbackHandler) OnStart(ctx context.Context, info *callbacks.RunInfo, log.Printf("create span error: %v", err) return ctx } - nState := &langfuseState{ + return context.WithValue(ctx, langfuseStateKey{}, &langfuseState{ traceID: state.traceID, observationID: spanID, - traceInputSet: state.traceInputSet, - isRoot: isRoot, - } - if c.shouldUpdateTraceInput(nState) { - c.updateTraceInput(ctx, state.traceID, in) - } - return context.WithValue(ctx, langfuseStateKey{}, nState) + }) } func (c *CallbackHandler) OnEnd(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context { @@ -312,14 +283,6 @@ func (c *CallbackHandler) OnEnd(ctx context.Context, info *callbacks.RunInfo, ou if err != nil { log.Printf("end generation error: %v, runinfo: %+v", err, info) } - if c.shouldUpdateTraceOutput(state) { - out, err_ := sonic.MarshalString(output) - if err_ != nil { - log.Printf("marshal trace output error: %v, runinfo: %+v", err_, info) - } else { - c.UpdateTraceOutput(ctx, state.traceID, out) - } - } return ctx } @@ -340,9 +303,6 @@ func (c *CallbackHandler) OnEnd(ctx context.Context, info *callbacks.RunInfo, ou if err != nil { log.Printf("end span fail: %v, runinfo: %+v", err, info) } - if c.shouldUpdateTraceOutput(state) { - c.UpdateTraceOutput(ctx, state.traceID, out) - } return ctx } @@ -374,9 +334,6 @@ func (c *CallbackHandler) OnError(ctx context.Context, info *callbacks.RunInfo, if reportErr != nil { log.Printf("end generation fail: %v, runinfo: %+v, execute error: %v", reportErr, info, err) } - if c.shouldUpdateTraceOutput(state) { - c.UpdateTraceOutput(ctx, state.traceID, err.Error()) - } return ctx } @@ -393,9 +350,6 @@ func (c *CallbackHandler) OnError(ctx context.Context, info *callbacks.RunInfo, if reportErr != nil { log.Printf("end span fail: %v, runinfo: %+v, execute error: %v", reportErr, info, err) } - if c.shouldUpdateTraceOutput(state) { - c.UpdateTraceOutput(ctx, state.traceID, err.Error()) - } return ctx } @@ -408,7 +362,6 @@ func (c *CallbackHandler) OnStartWithStreamInput(ctx context.Context, info *call if state == nil { return ctx } - isRoot := state.observationID == "" if info.Component == components.ComponentOfChatModel { generationID, err := c.cli.CreateGeneration(&langfuse.GenerationEventBody{ @@ -425,12 +378,6 @@ func (c *CallbackHandler) OnStartWithStreamInput(ctx context.Context, info *call log.Printf("create generation error: %v, runinfo: %+v", err, info) return ctx } - nState := &langfuseState{ - traceID: state.traceID, - observationID: generationID, - traceInputSet: state.traceInputSet, - isRoot: isRoot, - } go func() { defer func() { @@ -473,17 +420,12 @@ func (c *CallbackHandler) OnStartWithStreamInput(ctx context.Context, info *call if err != nil { log.Printf("update stream generation fail: %v, runinfo: %+v", err, info) } - if c.shouldUpdateTraceInput(nState) { - in, err__ := sonic.MarshalString(ins) - if err__ != nil { - log.Printf("marshal trace stream input error: %v, runinfo: %+v", err__, info) - } else { - c.updateTraceInput(ctx, nState.traceID, in) - } - } }() - return context.WithValue(ctx, langfuseStateKey{}, nState) + return context.WithValue(ctx, langfuseStateKey{}, &langfuseState{ + traceID: state.traceID, + observationID: generationID, + }) } spanID, err := c.cli.CreateSpan(&langfuse.SpanEventBody{ @@ -500,12 +442,6 @@ func (c *CallbackHandler) OnStartWithStreamInput(ctx context.Context, info *call log.Printf("create span error: %v", err) return ctx } - nState := &langfuseState{ - traceID: state.traceID, - observationID: spanID, - traceInputSet: state.traceInputSet, - isRoot: isRoot, - } go func() { defer func() { @@ -544,12 +480,12 @@ func (c *CallbackHandler) OnStartWithStreamInput(ctx context.Context, info *call if err != nil { log.Printf("update stream span error: %v", err) } - if c.shouldUpdateTraceInput(nState) { - c.updateTraceInput(ctx, nState.traceID, in) - } }() - return context.WithValue(ctx, langfuseStateKey{}, nState) + return context.WithValue(ctx, langfuseStateKey{}, &langfuseState{ + traceID: state.traceID, + observationID: spanID, + }) } func (c *CallbackHandler) OnEndWithStreamOutput(ctx context.Context, info *callbacks.RunInfo, output *schema.StreamReader[callbacks.CallbackOutput]) context.Context { @@ -609,14 +545,6 @@ func (c *CallbackHandler) OnEndWithStreamOutput(ctx context.Context, info *callb if err != nil { log.Printf("end stream generation error: %v, runinfo: %+v", err, info) } - if c.shouldUpdateTraceOutput(state) { - out, err_ := sonic.MarshalString(outs) - if err_ != nil { - log.Printf("marshal trace stream output error: %v, runinfo: %+v", err_, info) - } else { - c.UpdateTraceOutput(ctx, state.traceID, out) - } - } }() return ctx } @@ -657,39 +585,11 @@ func (c *CallbackHandler) OnEndWithStreamOutput(ctx context.Context, info *callb if err != nil { log.Printf("end stream span fail: %v, runinfo: %+v", err, info) } - if c.shouldUpdateTraceOutput(state) { - c.UpdateTraceOutput(ctx, state.traceID, out) - } }() return ctx } -func (c *CallbackHandler) updateTraceInput(ctx context.Context, traceID string, input string) { - _ = ctx - err := c.cli.EndTrace(&langfuse.TraceEventBody{ - BaseEventBody: langfuse.BaseEventBody{ - ID: traceID, - }, - Input: input, - }) - if err != nil { - log.Printf("input end trace fail: %v, traceID: %s", err, traceID) - } -} - -func (c *CallbackHandler) shouldUpdateTraceInput(state *langfuseState) bool { - return c.shouldUpdateTraceIO(state) && !state.traceInputSet -} - -func (c *CallbackHandler) shouldUpdateTraceOutput(state *langfuseState) bool { - return c.shouldUpdateTraceIO(state) -} - -func (c *CallbackHandler) shouldUpdateTraceIO(state *langfuseState) bool { - return state != nil && state.isRoot && !c.disableTraceIO && len(state.traceID) > 0 -} - // UpdateTraceOutput pushes final trace output to Langfuse (via ACL EndTrace). // ctx is reserved for future cancellation / deadline propagation; callers may pass context.Background() for now. func (c *CallbackHandler) UpdateTraceOutput(ctx context.Context, traceID string, output string) { diff --git a/callbacks/langfuse/langfuse_test.go b/callbacks/langfuse/langfuse_test.go index c9f8b6366..b0c0a1675 100644 --- a/callbacks/langfuse/langfuse_test.go +++ b/callbacks/langfuse/langfuse_test.go @@ -42,13 +42,12 @@ func TestLangfuseCallback(t *testing.T) { mockLangfuse := mock.NewMockLangfuse(ctrl) defer mockey.Mock(langfuse.NewLangfuse).Return(mockLangfuse).Build().UnPatch() cbh, _ := NewLangfuseHandler(&Config{ - Name: "MyTrace", - UserID: "user id", - SessionID: "session", - Release: "release", - Tags: []string{"tag1", "tag2"}, - Public: true, - DisableTraceIO: true, + Name: "MyTrace", + UserID: "user id", + SessionID: "session", + Release: "release", + Tags: []string{"tag1", "tag2"}, + Public: true, }) callbacks.InitCallbackHandlers([]callbacks.Handler{cbh}) ctx := context.Background() @@ -304,146 +303,13 @@ func TestLangfuseCallback(t *testing.T) { WithPublic(true), WithEnvironment("development"), WithVersion("version"), - WithInput("manual input"), ) assert.Equal(t, "traceid", ctx.Value(langfuseTraceOptionKey{}).(*traceOptions).ID) assert.Equal(t, "development", ctx.Value(langfuseTraceOptionKey{}).(*traceOptions).Environment) assert.Equal(t, "version", ctx.Value(langfuseTraceOptionKey{}).(*traceOptions).Version) - assert.Equal(t, "manual input", ctx.Value(langfuseTraceOptionKey{}).(*traceOptions).Input) - assert.True(t, ctx.Value(langfuseTraceOptionKey{}).(*traceOptions).inputSet) }) } -func TestTraceIOAutoPromotion(t *testing.T) { - ctrl := gomock.NewController(t) - mockLangfuse := mock.NewMockLangfuse(ctrl) - defer mockey.Mock(langfuse.NewLangfuse).Return(mockLangfuse).Build().UnPatch() - - cbh, _ := NewLangfuseHandler(&Config{ - Host: "http://localhost", PublicKey: "pk", SecretKey: "sk", - Name: "trace", - }) - - mockLangfuse.EXPECT().CreateTrace(gomock.Any()).Return("trace-id", nil).Times(1) - mockLangfuse.EXPECT().CreateSpan(gomock.Any()).DoAndReturn(func(body *langfuse.SpanEventBody) (string, error) { - assert.Equal(t, "trace-id", body.TraceID) - assert.Empty(t, body.ParentObservationID) - assert.Equal(t, "\"input\"", body.Input) - return "span-id", nil - }).Times(1) - mockLangfuse.EXPECT().EndSpan(gomock.Any()).DoAndReturn(func(body *langfuse.SpanEventBody) error { - assert.Equal(t, "span-id", body.ID) - assert.Equal(t, "\"output\"", body.Output) - return nil - }).Times(1) - - var traceInputs, traceOutputs []string - mockLangfuse.EXPECT().EndTrace(gomock.Any()).DoAndReturn(func(body *langfuse.TraceEventBody) error { - assert.Equal(t, "trace-id", body.ID) - if body.Input != "" { - traceInputs = append(traceInputs, body.Input) - } - if body.Output != "" { - traceOutputs = append(traceOutputs, body.Output) - } - return nil - }).Times(2) - - ctx := cbh.OnStart(context.Background(), &callbacks.RunInfo{Name: "root"}, callbacks.CallbackInput("input")) - cbh.OnEnd(ctx, &callbacks.RunInfo{Name: "root"}, callbacks.CallbackOutput("output")) - - assert.Equal(t, []string{"\"input\""}, traceInputs) - assert.Equal(t, []string{"\"output\""}, traceOutputs) -} - -func TestTraceIOManualInputNotOverwritten(t *testing.T) { - ctrl := gomock.NewController(t) - mockLangfuse := mock.NewMockLangfuse(ctrl) - defer mockey.Mock(langfuse.NewLangfuse).Return(mockLangfuse).Build().UnPatch() - - cbh, _ := NewLangfuseHandler(&Config{ - Host: "http://localhost", PublicKey: "pk", SecretKey: "sk", - Name: "trace", - }) - - mockLangfuse.EXPECT().CreateTrace(gomock.Any()).DoAndReturn(func(body *langfuse.TraceEventBody) (string, error) { - assert.Equal(t, "trace-id", body.ID) - assert.Equal(t, "manual input", body.Input) - return "trace-id", nil - }).Times(1) - mockLangfuse.EXPECT().CreateSpan(gomock.Any()).Return("span-id", nil).Times(1) - mockLangfuse.EXPECT().EndSpan(gomock.Any()).Return(nil).Times(1) - mockLangfuse.EXPECT().EndTrace(gomock.Any()).DoAndReturn(func(body *langfuse.TraceEventBody) error { - assert.Equal(t, "trace-id", body.ID) - assert.Empty(t, body.Input) - assert.Equal(t, "\"output\"", body.Output) - return nil - }).Times(1) - - ctx := SetTrace(context.Background(), WithID("trace-id"), WithInput("manual input")) - ctx = cbh.OnStart(ctx, &callbacks.RunInfo{Name: "root"}, callbacks.CallbackInput("auto input")) - cbh.OnEnd(ctx, &callbacks.RunInfo{Name: "root"}, callbacks.CallbackOutput("output")) -} - -func TestTraceIOAutoPromotionStream(t *testing.T) { - ctrl := gomock.NewController(t) - mockLangfuse := mock.NewMockLangfuse(ctrl) - defer mockey.Mock(langfuse.NewLangfuse).Return(mockLangfuse).Build().UnPatch() - - cbh, _ := NewLangfuseHandler(&Config{ - Host: "http://localhost", PublicKey: "pk", SecretKey: "sk", - Name: "trace", - }) - - mockLangfuse.EXPECT().CreateTrace(gomock.Any()).Return("trace-id", nil).Times(1) - mockLangfuse.EXPECT().CreateSpan(gomock.Any()).Return("span-id", nil).Times(1) - mockLangfuse.EXPECT().EndSpan(gomock.Any()).Return(nil).Times(2) - - type traceIOEvent struct { - id string - input string - output string - } - traceEvents := make(chan traceIOEvent, 2) - mockLangfuse.EXPECT().EndTrace(gomock.Any()).DoAndReturn(func(body *langfuse.TraceEventBody) error { - traceEvents <- traceIOEvent{ - id: body.ID, - input: body.Input, - output: body.Output, - } - return nil - }).Times(2) - - insr, insw := schema.Pipe[callbacks.CallbackInput](1) - insw.Send(callbacks.CallbackInput("stream input"), nil) - insw.Close() - ctx := cbh.OnStartWithStreamInput(context.Background(), &callbacks.RunInfo{Name: "root"}, insr) - - outsr, outsw := schema.Pipe[callbacks.CallbackOutput](1) - outsw.Send(callbacks.CallbackOutput("stream output"), nil) - outsw.Close() - cbh.OnEndWithStreamOutput(ctx, &callbacks.RunInfo{Name: "root"}, outsr) - - var traceInputs, traceOutputs []string - for i := 0; i < 2; i++ { - select { - case event := <-traceEvents: - assert.Equal(t, "trace-id", event.id) - if event.input != "" { - traceInputs = append(traceInputs, event.input) - } - if event.output != "" { - traceOutputs = append(traceOutputs, event.output) - } - case <-time.After(time.Second): - t.Fatal("timed out waiting for trace IO update") - } - } - - assert.Equal(t, []string{"[\"stream input\"]"}, traceInputs) - assert.Equal(t, []string{"[\"stream output\"]"}, traceOutputs) -} - func TestAttack_NilMessageInOnEnd(t *testing.T) { ctrl := gomock.NewController(t) mockLangfuse := mock.NewMockLangfuse(ctrl) @@ -451,8 +317,7 @@ func TestAttack_NilMessageInOnEnd(t *testing.T) { cbh, _ := NewLangfuseHandler(&Config{ Host: "http://localhost", PublicKey: "pk", SecretKey: "sk", - Name: "trace", - DisableTraceIO: true, + Name: "trace", }) mockLangfuse.EXPECT().CreateTrace(gomock.Any()).Return("trace-id", nil).Times(1) mockLangfuse.EXPECT().CreateGeneration(gomock.Any()).Return("generation-id", nil).Times(1) @@ -479,7 +344,6 @@ func TestAttack_ExtractModelOutputErrorIgnored(t *testing.T) { cbh, _ := NewLangfuseHandler(&Config{ Host: "http://localhost", PublicKey: "pk", SecretKey: "sk", - DisableTraceIO: true, }) mockLangfuse.EXPECT().CreateTrace(gomock.Any()).Return("trace-id", nil).Times(1) mockLangfuse.EXPECT().CreateGeneration(gomock.Any()).Return("generation-id", nil).Times(1) diff --git a/callbacks/langfuse/trace.go b/callbacks/langfuse/trace.go index 2100c434a..a1a9350ba 100644 --- a/callbacks/langfuse/trace.go +++ b/callbacks/langfuse/trace.go @@ -54,7 +54,6 @@ func WithUserID(userID string) TraceOption { func WithInput(input string) TraceOption { return func(o *traceOptions) { o.Input = input - o.inputSet = true } } func WithSessionID(sessionID string) TraceOption { @@ -98,7 +97,6 @@ type traceOptions struct { Name string UserID string Input string - inputSet bool SessionID string Release string Tags []string @@ -129,8 +127,7 @@ func initState(_ context.Context, cli langfuse.Langfuse, options *traceOptions) return nil, fmt.Errorf("create trace error: %v", err) } s := &langfuseState{ - traceID: traceID, - traceInputSet: options.inputSet, + traceID: traceID, } return s, nil } From 4e802c5f7789d04b5ce9736d8e8565f56de50f11 Mon Sep 17 00:00:00 2001 From: tangchaojun Date: Wed, 3 Jun 2026 11:19:32 +0800 Subject: [PATCH 4/9] fix: keep token usage aggregation scoped --- callbacks/cozeloop/data_parser.go | 21 ------- callbacks/cozeloop/data_parser_test.go | 50 ----------------- components/model/claude/claude.go | 77 +------------------------- components/model/claude/claude_test.go | 73 ------------------------ 4 files changed, 3 insertions(+), 218 deletions(-) diff --git a/callbacks/cozeloop/data_parser.go b/callbacks/cozeloop/data_parser.go index 411248e1c..13ca8ecd4 100644 --- a/callbacks/cozeloop/data_parser.go +++ b/callbacks/cozeloop/data_parser.go @@ -527,9 +527,6 @@ func (d defaultDataParser) ParseAgenticModelStreamOutput(ctx context.Context, ou } usage = mergeCumulativeTokenUsage(usage, cbOutput.TokenUsage) - if cbOutput.Message != nil && cbOutput.Message.ResponseMeta != nil { - usage = mergeCumulativeTokenUsage(usage, schemaTokenUsageToModelTokenUsage(cbOutput.Message.ResponseMeta.TokenUsage)) - } if cbOutput.Config != nil && !onceSet { onceSet = true @@ -610,24 +607,6 @@ func mergeCumulativeTokenUsage(dst, src *model.TokenUsage) *model.TokenUsage { return dst } -func schemaTokenUsageToModelTokenUsage(usage *schema.TokenUsage) *model.TokenUsage { - if usage == nil { - return nil - } - - return &model.TokenUsage{ - PromptTokens: usage.PromptTokens, - PromptTokenDetails: model.PromptTokenDetails{ - CachedTokens: usage.PromptTokenDetails.CachedTokens, - }, - CompletionTokens: usage.CompletionTokens, - CompletionTokensDetails: model.CompletionTokensDetails{ - ReasoningTokens: usage.CompletionTokensDetails.ReasoningTokens, - }, - TotalTokens: usage.TotalTokens, - } -} - func (d defaultDataParser) ParseDefaultStreamInput(ctx context.Context, input *schema.StreamReader[callbacks.CallbackInput]) (chunks []any, err error) { for { item, recvErr := input.Recv() diff --git a/callbacks/cozeloop/data_parser_test.go b/callbacks/cozeloop/data_parser_test.go index d922b5965..3c679a6e3 100644 --- a/callbacks/cozeloop/data_parser_test.go +++ b/callbacks/cozeloop/data_parser_test.go @@ -436,56 +436,6 @@ func Test_defaultDataParser_ParseChatModelStreamOutput_MergeCumulativeTokenUsage }) } -func Test_defaultDataParser_ParseAgenticModelStreamOutput_MergeMessageMetaTokenUsage(t *testing.T) { - mockey.PatchConvey("测试 AgenticModel 流式输出合并 message meta token usage", t, func() { - outsr, outsw := schema.Pipe[callbacks.CallbackOutput](3) - outsw.Send(&model.AgenticCallbackOutput{ - Message: &schema.AgenticMessage{ - Role: schema.AgenticRoleTypeAssistant, - ContentBlocks: []*schema.ContentBlock{ - schema.NewContentBlock(&schema.AssistantGenText{Text: "assistant"}), - }, - ResponseMeta: &schema.AgenticResponseMeta{ - TokenUsage: &schema.TokenUsage{ - PromptTokens: 6900, - PromptTokenDetails: schema.PromptTokenDetails{ - CachedTokens: 3265, - }, - CompletionTokens: 1, - TotalTokens: 6901, - }, - }, - }, - }, nil) - outsw.Send(&model.AgenticCallbackOutput{ - Message: &schema.AgenticMessage{ - Role: schema.AgenticRoleTypeAssistant, - ContentBlocks: []*schema.ContentBlock{ - schema.NewContentBlock(&schema.AssistantGenText{Text: " message"}), - }, - ResponseMeta: &schema.AgenticResponseMeta{ - TokenUsage: &schema.TokenUsage{ - CompletionTokens: 69, - CompletionTokensDetails: schema.CompletionTokensDetails{ - ReasoningTokens: 12, - }, - }, - }, - }, - }, nil) - outsw.Close() - - d := defaultDataParser{} - result := d.ParseAgenticModelStreamOutput(context.Background(), outsr) - - convey.So(result[tracespec.InputTokens], convey.ShouldEqual, 6900) - convey.So(result[tracespec.InputCachedTokens], convey.ShouldEqual, 3265) - convey.So(result[tracespec.OutputTokens], convey.ShouldEqual, 69) - convey.So(result[tracespec.ReasoningTokens], convey.ShouldEqual, 12) - convey.So(result[tracespec.Tokens], convey.ShouldEqual, 6969) - }) -} - // Test_defaultDataParser_tryConcatChunks 为 defaultDataParser 的 tryConcatChunks 方法编写单元测试 func Test_defaultDataParser_tryConcatChunks(t *testing.T) { mockey.PatchConvey("测试 defaultDataParser 的 tryConcatChunks 方法", t, func() { diff --git a/components/model/claude/claude.go b/components/model/claude/claude.go index 9f00840b7..a8da5a321 100644 --- a/components/model/claude/claude.go +++ b/components/model/claude/claude.go @@ -374,7 +374,6 @@ func (cm *ChatModel) Stream(ctx context.Context, input []*schema.Message, opts . }() var waitList []*schema.Message streamCtx := &streamContext{} - var usage *schema.TokenUsage for stream.Next() { message, err_ := convStreamEvent(stream.Current(), streamCtx) if err_ != nil { @@ -397,7 +396,6 @@ func (cm *ChatModel) Stream(ctx context.Context, input []*schema.Message, opts . } waitList = []*schema.Message{} } - usage = applyClaudeStreamUsage(message, usage) closed := sw.Send(cm.getCallbackOutput(message), nil) if closed { @@ -411,7 +409,6 @@ func (cm *ChatModel) Stream(ctx context.Context, input []*schema.Message, opts . _ = sw.Send(nil, fmt.Errorf("concat empty message fail: %w", err_)) return } - usage = applyClaudeStreamUsage(message, usage) closed := sw.Send(cm.getCallbackOutput(message), nil) if closed { @@ -1359,22 +1356,11 @@ func convStreamEvent(event anthropic.MessageStreamEventUnion, streamCtx *streamC case anthropic.MessageStartEvent: return convOutputMessage(&e.Message) case anthropic.MessageDeltaEvent: - completionTokens := int(e.Usage.OutputTokens) - usage := &schema.TokenUsage{ - CompletionTokens: completionTokens, - } - if hasMessageDeltaPromptUsage(e.Usage) { - promptTokens := int(e.Usage.InputTokens + e.Usage.CacheReadInputTokens + e.Usage.CacheCreationInputTokens) - usage.PromptTokens = promptTokens - usage.PromptTokenDetails = schema.PromptTokenDetails{ - CachedTokens: int(e.Usage.CacheReadInputTokens), - } - usage.TotalTokens = promptTokens + completionTokens - } - result.ResponseMeta = &schema.ResponseMeta{ FinishReason: string(e.Delta.StopReason), - Usage: usage, + Usage: &schema.TokenUsage{ + CompletionTokens: int(e.Usage.OutputTokens), + }, } return result, nil @@ -1423,63 +1409,6 @@ func convStreamEvent(event anthropic.MessageStreamEventUnion, streamCtx *streamC } } -func hasMessageDeltaPromptUsage(usage anthropic.MessageDeltaUsage) bool { - return usage.InputTokens != 0 || - usage.CacheReadInputTokens != 0 || - usage.CacheCreationInputTokens != 0 || - usage.JSON.InputTokens.Valid() || - usage.JSON.CacheReadInputTokens.Valid() || - usage.JSON.CacheCreationInputTokens.Valid() -} - -func applyClaudeStreamUsage(message *schema.Message, usage *schema.TokenUsage) *schema.TokenUsage { - if message == nil || message.ResponseMeta == nil || message.ResponseMeta.Usage == nil { - return usage - } - - usage = mergeClaudeStreamTokenUsage(usage, message.ResponseMeta.Usage) - message.ResponseMeta.Usage = cloneClaudeStreamTokenUsage(usage) - return usage -} - -func mergeClaudeStreamTokenUsage(dst, src *schema.TokenUsage) *schema.TokenUsage { - if src == nil { - return dst - } - if dst == nil { - dst = &schema.TokenUsage{} - } - - if src.PromptTokens > dst.PromptTokens { - dst.PromptTokens = src.PromptTokens - } - if src.CompletionTokens > dst.CompletionTokens { - dst.CompletionTokens = src.CompletionTokens - } - if src.TotalTokens > dst.TotalTokens { - dst.TotalTokens = src.TotalTokens - } - if src.PromptTokenDetails.CachedTokens > dst.PromptTokenDetails.CachedTokens { - dst.PromptTokenDetails.CachedTokens = src.PromptTokenDetails.CachedTokens - } - if src.CompletionTokensDetails.ReasoningTokens > dst.CompletionTokensDetails.ReasoningTokens { - dst.CompletionTokensDetails.ReasoningTokens = src.CompletionTokensDetails.ReasoningTokens - } - - if total := dst.PromptTokens + dst.CompletionTokens; total > dst.TotalTokens { - dst.TotalTokens = total - } - return dst -} - -func cloneClaudeStreamTokenUsage(usage *schema.TokenUsage) *schema.TokenUsage { - if usage == nil { - return nil - } - cloned := *usage - return &cloned -} - func convImageBase64(data string) (string, string, error) { if !strings.HasPrefix(data, "data:") { return "", "", fmt.Errorf("invalid base64 image: %s", data) diff --git a/components/model/claude/claude_test.go b/components/model/claude/claude_test.go index 442876791..7e28cf776 100644 --- a/components/model/claude/claude_test.go +++ b/components/model/claude/claude_test.go @@ -354,32 +354,6 @@ func TestConvStreamEvent(t *testing.T) { assert.NoError(t, err) assert.Equal(t, "end_turn", message.ResponseMeta.FinishReason) assert.Equal(t, 10, message.ResponseMeta.Usage.CompletionTokens) - assert.Equal(t, 0, message.ResponseMeta.Usage.PromptTokens) - assert.Equal(t, 0, message.ResponseMeta.Usage.PromptTokenDetails.CachedTokens) - assert.Equal(t, 0, message.ResponseMeta.Usage.TotalTokens) - }) - - mockey.PatchConvey("message delta event with prompt usage", t, func() { - event := anthropic.MessageStreamEventUnion{} - defer mockey.Mock(anthropic.MessageStreamEventUnion.AsAny).Return(anthropic.MessageDeltaEvent{ - Delta: anthropic.MessageDeltaEventDelta{ - StopReason: "end_turn", - }, - Usage: anthropic.MessageDeltaUsage{ - InputTokens: 5, - CacheReadInputTokens: 3, - CacheCreationInputTokens: 2, - OutputTokens: 10, - }, - }).Build().UnPatch() - - message, err := convStreamEvent(event, streamCtx) - assert.NoError(t, err) - assert.Equal(t, "end_turn", message.ResponseMeta.FinishReason) - assert.Equal(t, 10, message.ResponseMeta.Usage.PromptTokens) - assert.Equal(t, 3, message.ResponseMeta.Usage.PromptTokenDetails.CachedTokens) - assert.Equal(t, 10, message.ResponseMeta.Usage.CompletionTokens) - assert.Equal(t, 20, message.ResponseMeta.Usage.TotalTokens) }) mockey.PatchConvey("content block start event", t, func() { @@ -402,53 +376,6 @@ func TestConvStreamEvent(t *testing.T) { }) } -func TestMergeClaudeStreamTokenUsage(t *testing.T) { - usage := mergeClaudeStreamTokenUsage(nil, &schema.TokenUsage{ - PromptTokens: 6900, - PromptTokenDetails: schema.PromptTokenDetails{ - CachedTokens: 3265, - }, - CompletionTokens: 1, - TotalTokens: 6901, - }) - usage = mergeClaudeStreamTokenUsage(usage, &schema.TokenUsage{ - CompletionTokens: 69, - }) - - assert.Equal(t, 6900, usage.PromptTokens) - assert.Equal(t, 3265, usage.PromptTokenDetails.CachedTokens) - assert.Equal(t, 69, usage.CompletionTokens) - assert.Equal(t, 6969, usage.TotalTokens) -} - -func TestApplyClaudeStreamUsageClonesCumulativeUsage(t *testing.T) { - var usage *schema.TokenUsage - first := &schema.Message{ - ResponseMeta: &schema.ResponseMeta{ - Usage: &schema.TokenUsage{ - PromptTokens: 10, - CompletionTokens: 1, - TotalTokens: 11, - }, - }, - } - usage = applyClaudeStreamUsage(first, usage) - - second := &schema.Message{ - ResponseMeta: &schema.ResponseMeta{ - Usage: &schema.TokenUsage{ - CompletionTokens: 5, - }, - }, - } - usage = applyClaudeStreamUsage(second, usage) - - assert.Equal(t, 11, first.ResponseMeta.Usage.TotalTokens) - assert.Equal(t, 10, second.ResponseMeta.Usage.PromptTokens) - assert.Equal(t, 5, second.ResponseMeta.Usage.CompletionTokens) - assert.Equal(t, 15, second.ResponseMeta.Usage.TotalTokens) -} - func TestPanicErr(t *testing.T) { err := newPanicErr("info", []byte("stack")) assert.Equal(t, "panic error: info, \nstack: stack", err.Error()) From d92096e47322af43d8f195ed20bbff35bd0a35d6 Mon Sep 17 00:00:00 2001 From: tangchaojun Date: Wed, 3 Jun 2026 14:29:02 +0800 Subject: [PATCH 5/9] fix: align cozeloop token usage merge --- callbacks/cozeloop/data_parser.go | 4 ---- callbacks/cozeloop/data_parser_test.go | 2 +- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/callbacks/cozeloop/data_parser.go b/callbacks/cozeloop/data_parser.go index 13ca8ecd4..3e40d80b1 100644 --- a/callbacks/cozeloop/data_parser.go +++ b/callbacks/cozeloop/data_parser.go @@ -600,10 +600,6 @@ func mergeCumulativeTokenUsage(dst, src *model.TokenUsage) *model.TokenUsage { dst.CompletionTokensDetails.ReasoningTokens = src.CompletionTokensDetails.ReasoningTokens } - if total := dst.PromptTokens + dst.CompletionTokens; total > dst.TotalTokens { - dst.TotalTokens = total - } - return dst } diff --git a/callbacks/cozeloop/data_parser_test.go b/callbacks/cozeloop/data_parser_test.go index 3c679a6e3..cd7263f9b 100644 --- a/callbacks/cozeloop/data_parser_test.go +++ b/callbacks/cozeloop/data_parser_test.go @@ -404,7 +404,7 @@ func Test_defaultDataParser_ParseChatModelStreamOutput_MergeCumulativeTokenUsage convey.So(result[tracespec.InputTokens], convey.ShouldEqual, 6900) convey.So(result[tracespec.InputCachedTokens], convey.ShouldEqual, 3265) convey.So(result[tracespec.OutputTokens], convey.ShouldEqual, 69) - convey.So(result[tracespec.Tokens], convey.ShouldEqual, 6969) + convey.So(result[tracespec.Tokens], convey.ShouldEqual, 6901) }) mockey.PatchConvey("测试 ChatModel 流式输出使用最终累计 token usage", t, func() { From c7711f44eef993dc5772b03f740dafaf8437e66f Mon Sep 17 00:00:00 2001 From: tangchaojun Date: Wed, 3 Jun 2026 14:46:50 +0800 Subject: [PATCH 6/9] fix: emit chat model reasoning token tags --- callbacks/cozeloop/data_parser.go | 6 ++++-- callbacks/cozeloop/data_parser_test.go | 20 ++++++++++++++++++++ 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/callbacks/cozeloop/data_parser.go b/callbacks/cozeloop/data_parser.go index 3e40d80b1..107298d77 100644 --- a/callbacks/cozeloop/data_parser.go +++ b/callbacks/cozeloop/data_parser.go @@ -208,7 +208,8 @@ func (d defaultDataParser) ParseOutput(ctx context.Context, info *callbacks.RunI tags.set(tracespec.Tokens, cbOutput.TokenUsage.TotalTokens). set(tracespec.InputTokens, cbOutput.TokenUsage.PromptTokens). set(tracespec.OutputTokens, cbOutput.TokenUsage.CompletionTokens). - set(tracespec.InputCachedTokens, cbOutput.TokenUsage.PromptTokenDetails.CachedTokens) + set(tracespec.InputCachedTokens, cbOutput.TokenUsage.PromptTokenDetails.CachedTokens). + set(tracespec.ReasoningTokens, cbOutput.TokenUsage.CompletionTokensDetails.ReasoningTokens) } } @@ -489,7 +490,8 @@ func (d defaultDataParser) ParseChatModelStreamOutput(ctx context.Context, outpu tags.set(tracespec.Tokens, usage.TotalTokens). set(tracespec.InputTokens, usage.PromptTokens). set(tracespec.OutputTokens, usage.CompletionTokens). - set(tracespec.InputCachedTokens, usage.PromptTokenDetails.CachedTokens) + set(tracespec.InputCachedTokens, usage.PromptTokenDetails.CachedTokens). + set(tracespec.ReasoningTokens, usage.CompletionTokensDetails.ReasoningTokens) } return tags diff --git a/callbacks/cozeloop/data_parser_test.go b/callbacks/cozeloop/data_parser_test.go index cd7263f9b..b77a03b96 100644 --- a/callbacks/cozeloop/data_parser_test.go +++ b/callbacks/cozeloop/data_parser_test.go @@ -120,6 +120,17 @@ func Test_defaultDataParser_ParseOutput(t *testing.T) { Role: schema.Assistant, Content: "Hello, how can I assist you today?", }, + TokenUsage: &model.TokenUsage{ + PromptTokens: 1, + CompletionTokens: 2, + TotalTokens: 3, + PromptTokenDetails: model.PromptTokenDetails{ + CachedTokens: 4, + }, + CompletionTokensDetails: model.CompletionTokensDetails{ + ReasoningTokens: 5, + }, + }, } var outputs callbacks.CallbackOutput = []*schema.Message{ { @@ -148,6 +159,11 @@ func Test_defaultDataParser_ParseOutput(t *testing.T) { convey.So(result, convey.ShouldNotBeNil) convey.So(result, convey.ShouldContainKey, tracespec.Output) + convey.So(result[tracespec.Tokens], convey.ShouldEqual, 3) + convey.So(result[tracespec.InputTokens], convey.ShouldEqual, 1) + convey.So(result[tracespec.OutputTokens], convey.ShouldEqual, 2) + convey.So(result[tracespec.InputCachedTokens], convey.ShouldEqual, 4) + convey.So(result[tracespec.ReasoningTokens], convey.ShouldEqual, 5) mockConvertModelOutput.UnPatch() mockGetTraceVariablesValue.UnPatch() @@ -394,6 +410,9 @@ func Test_defaultDataParser_ParseChatModelStreamOutput_MergeCumulativeTokenUsage Message: &schema.Message{Role: schema.Assistant, Content: " message"}, TokenUsage: &model.TokenUsage{ CompletionTokens: 69, + CompletionTokensDetails: model.CompletionTokensDetails{ + ReasoningTokens: 12, + }, }, }, nil) outsw.Close() @@ -404,6 +423,7 @@ func Test_defaultDataParser_ParseChatModelStreamOutput_MergeCumulativeTokenUsage convey.So(result[tracespec.InputTokens], convey.ShouldEqual, 6900) convey.So(result[tracespec.InputCachedTokens], convey.ShouldEqual, 3265) convey.So(result[tracespec.OutputTokens], convey.ShouldEqual, 69) + convey.So(result[tracespec.ReasoningTokens], convey.ShouldEqual, 12) convey.So(result[tracespec.Tokens], convey.ShouldEqual, 6901) }) From c95514e561aebab48a91464306b0bb600bf4c403 Mon Sep 17 00:00:00 2001 From: tangchaojun Date: Wed, 3 Jun 2026 17:40:28 +0800 Subject: [PATCH 7/9] fix: prefer message meta token usage --- callbacks/cozeloop/data_parser.go | 78 ++++++++---- callbacks/cozeloop/data_parser_test.go | 163 ++++++++++++++++++++++++ devops/go.mod | 11 +- devops/go.sum | 26 ++-- devops/internal/model/container.go | 23 +--- devops/internal/model/container_test.go | 81 ------------ devops/internal/model/types.go | 3 - 7 files changed, 231 insertions(+), 154 deletions(-) diff --git a/callbacks/cozeloop/data_parser.go b/callbacks/cozeloop/data_parser.go index 107298d77..f9b66df8c 100644 --- a/callbacks/cozeloop/data_parser.go +++ b/callbacks/cozeloop/data_parser.go @@ -204,13 +204,7 @@ func (d defaultDataParser) ParseOutput(ctx context.Context, info *callbacks.RunI tags.set(tracespec.Output, finalOutput) tags.set(consts.CustomSpanTagKeyExtra, cbOutput.Extra) - if cbOutput.TokenUsage != nil { - tags.set(tracespec.Tokens, cbOutput.TokenUsage.TotalTokens). - set(tracespec.InputTokens, cbOutput.TokenUsage.PromptTokens). - set(tracespec.OutputTokens, cbOutput.TokenUsage.CompletionTokens). - set(tracespec.InputCachedTokens, cbOutput.TokenUsage.PromptTokenDetails.CachedTokens). - set(tracespec.ReasoningTokens, cbOutput.TokenUsage.CompletionTokensDetails.ReasoningTokens) - } + setTokenUsageTags(tags, getMessageTokenUsage(cbOutput.Message, cbOutput.TokenUsage)) } tags.set(tracespec.Stream, false) @@ -231,13 +225,7 @@ func (d defaultDataParser) ParseOutput(ctx context.Context, info *callbacks.RunI tags.set(tracespec.Output, finalOutput) tags.set(consts.CustomSpanTagKeyExtra, cbOutput.Extra) - if cbOutput.TokenUsage != nil { - tags.set(tracespec.Tokens, cbOutput.TokenUsage.TotalTokens). - set(tracespec.InputTokens, cbOutput.TokenUsage.PromptTokens). - set(tracespec.OutputTokens, cbOutput.TokenUsage.CompletionTokens). - set(tracespec.InputCachedTokens, cbOutput.TokenUsage.PromptTokenDetails.CachedTokens). - set(tracespec.ReasoningTokens, cbOutput.TokenUsage.CompletionTokensDetails.ReasoningTokens) - } + setTokenUsageTags(tags, getAgenticMessageTokenUsage(cbOutput.Message, cbOutput.TokenUsage)) if cbOutput.Config != nil { if cbOutput.Config.Model != "" { @@ -470,6 +458,7 @@ func (d defaultDataParser) ParseChatModelStreamOutput(ctx context.Context, outpu } } + finalUsage := usage if msg, concatErr := schema.ConcatMessages(chunks); concatErr != nil { // unexpected finalOutput := parseAny(ctx, chunks, true) tags.set(tracespec.Output, finalOutput) @@ -481,18 +470,13 @@ func (d defaultDataParser) ParseChatModelStreamOutput(ctx context.Context, outpu } } else { tags.set(tracespec.Output, convertModelOutput(&model.CallbackOutput{Message: msg})) + finalUsage = getMessageTokenUsage(msg, usage) if level == 2 { collectOutput.addMessages(convertModelMessage(msg)) } } - if usage != nil { - tags.set(tracespec.Tokens, usage.TotalTokens). - set(tracespec.InputTokens, usage.PromptTokens). - set(tracespec.OutputTokens, usage.CompletionTokens). - set(tracespec.InputCachedTokens, usage.PromptTokenDetails.CachedTokens). - set(tracespec.ReasoningTokens, usage.CompletionTokensDetails.ReasoningTokens) - } + setTokenUsageTags(tags, finalUsage) return tags } @@ -547,6 +531,7 @@ func (d defaultDataParser) ParseAgenticModelStreamOutput(ctx context.Context, ou tags.set(tracespec.ModelName, modelName) } + finalUsage := usage if msg, concatErr := schema.ConcatAgenticMessages(chunks); concatErr != nil { finalOutput := parseAny(ctx, chunks, true) tags.set(tracespec.Output, finalOutput) @@ -558,22 +543,59 @@ func (d defaultDataParser) ParseAgenticModelStreamOutput(ctx context.Context, ou } } else { tags.set(tracespec.Output, convertAgenticModelOutput(&model.AgenticCallbackOutput{Message: msg})) + finalUsage = getAgenticMessageTokenUsage(msg, usage) if level == 2 { collectOutput.addMessages(expandAgenticModelMessage(msg)...) } } - if usage != nil { - tags.set(tracespec.Tokens, usage.TotalTokens). - set(tracespec.InputTokens, usage.PromptTokens). - set(tracespec.OutputTokens, usage.CompletionTokens). - set(tracespec.InputCachedTokens, usage.PromptTokenDetails.CachedTokens). - set(tracespec.ReasoningTokens, usage.CompletionTokensDetails.ReasoningTokens) - } + setTokenUsageTags(tags, finalUsage) return tags } +func getMessageTokenUsage(msg *schema.Message, fallback *model.TokenUsage) *model.TokenUsage { + if msg == nil || msg.ResponseMeta == nil { + return fallback + } + return mergeCumulativeTokenUsage(schemaTokenUsageToModel(msg.ResponseMeta.Usage), fallback) +} + +func getAgenticMessageTokenUsage(msg *schema.AgenticMessage, fallback *model.TokenUsage) *model.TokenUsage { + if msg == nil || msg.ResponseMeta == nil { + return fallback + } + return mergeCumulativeTokenUsage(schemaTokenUsageToModel(msg.ResponseMeta.TokenUsage), fallback) +} + +func schemaTokenUsageToModel(usage *schema.TokenUsage) *model.TokenUsage { + if usage == nil { + return nil + } + return &model.TokenUsage{ + PromptTokens: usage.PromptTokens, + PromptTokenDetails: model.PromptTokenDetails{ + CachedTokens: usage.PromptTokenDetails.CachedTokens, + }, + CompletionTokens: usage.CompletionTokens, + TotalTokens: usage.TotalTokens, + CompletionTokensDetails: model.CompletionTokensDetails{ + ReasoningTokens: usage.CompletionTokensDetails.ReasoningTokens, + }, + } +} + +func setTokenUsageTags(tags spanTags, usage *model.TokenUsage) { + if usage == nil { + return + } + tags.set(tracespec.Tokens, usage.TotalTokens). + set(tracespec.InputTokens, usage.PromptTokens). + set(tracespec.OutputTokens, usage.CompletionTokens). + set(tracespec.InputCachedTokens, usage.PromptTokenDetails.CachedTokens). + set(tracespec.ReasoningTokens, usage.CompletionTokensDetails.ReasoningTokens) +} + // mergeCumulativeTokenUsage keeps the final request-level token usage from a // stream. Streaming callbacks may carry partial cumulative snapshots, and // TokenUsage does not preserve field-presence metadata, so each monotonically diff --git a/callbacks/cozeloop/data_parser_test.go b/callbacks/cozeloop/data_parser_test.go index b77a03b96..972cc78ae 100644 --- a/callbacks/cozeloop/data_parser_test.go +++ b/callbacks/cozeloop/data_parser_test.go @@ -392,7 +392,118 @@ func Test_defaultDataParser_ParseOutput(t *testing.T) { }) } +func Test_defaultDataParser_ParseOutput_MessageMetaTokenUsage(t *testing.T) { + mockey.PatchConvey("测试非流式模型输出优先使用 message meta usage", t, func() { + d := defaultDataParser{} + ctx := context.Background() + + mockey.PatchConvey("ChatModel 从 Message.ResponseMeta.Usage 读取 token usage", func() { + info := &callbacks.RunInfo{Component: components.ComponentOfChatModel} + output := &model.CallbackOutput{ + Message: &schema.Message{ + Role: schema.Assistant, + ResponseMeta: &schema.ResponseMeta{ + Usage: &schema.TokenUsage{ + PromptTokens: 11, + PromptTokenDetails: schema.PromptTokenDetails{ + CachedTokens: 7, + }, + CompletionTokens: 13, + TotalTokens: 24, + CompletionTokensDetails: schema.CompletionTokensDetails{ + ReasoningTokens: 5, + }, + }, + }, + }, + } + + result := d.ParseOutput(ctx, info, output) + + convey.So(result[tracespec.InputTokens], convey.ShouldEqual, 11) + convey.So(result[tracespec.InputCachedTokens], convey.ShouldEqual, 7) + convey.So(result[tracespec.OutputTokens], convey.ShouldEqual, 13) + convey.So(result[tracespec.ReasoningTokens], convey.ShouldEqual, 5) + convey.So(result[tracespec.Tokens], convey.ShouldEqual, 24) + }) + + mockey.PatchConvey("AgenticModel 从 AgenticMessage.ResponseMeta.TokenUsage 读取 token usage", func() { + info := &callbacks.RunInfo{Component: components.ComponentOfAgenticModel} + output := &model.AgenticCallbackOutput{ + Message: &schema.AgenticMessage{ + Role: schema.AgenticRoleTypeAssistant, + ResponseMeta: &schema.AgenticResponseMeta{ + TokenUsage: &schema.TokenUsage{ + PromptTokens: 17, + PromptTokenDetails: schema.PromptTokenDetails{ + CachedTokens: 3, + }, + CompletionTokens: 19, + TotalTokens: 36, + CompletionTokensDetails: schema.CompletionTokensDetails{ + ReasoningTokens: 2, + }, + }, + }, + }, + } + + result := d.ParseOutput(ctx, info, output) + + convey.So(result[tracespec.InputTokens], convey.ShouldEqual, 17) + convey.So(result[tracespec.InputCachedTokens], convey.ShouldEqual, 3) + convey.So(result[tracespec.OutputTokens], convey.ShouldEqual, 19) + convey.So(result[tracespec.ReasoningTokens], convey.ShouldEqual, 2) + convey.So(result[tracespec.Tokens], convey.ShouldEqual, 36) + }) + }) +} + func Test_defaultDataParser_ParseChatModelStreamOutput_MergeCumulativeTokenUsage(t *testing.T) { + mockey.PatchConvey("测试 ChatModel 流式输出优先使用 concat 后的 message meta usage", t, func() { + outsr, outsw := schema.Pipe[callbacks.CallbackOutput](3) + outsw.Send(&model.CallbackOutput{ + Message: &schema.Message{ + Role: schema.Assistant, + Content: "assistant", + ResponseMeta: &schema.ResponseMeta{ + Usage: &schema.TokenUsage{ + PromptTokens: 6900, + PromptTokenDetails: schema.PromptTokenDetails{ + CachedTokens: 3265, + }, + CompletionTokens: 1, + TotalTokens: 6901, + }, + }, + }, + }, nil) + outsw.Send(&model.CallbackOutput{ + Message: &schema.Message{ + Role: schema.Assistant, + Content: " message", + ResponseMeta: &schema.ResponseMeta{ + Usage: &schema.TokenUsage{ + CompletionTokens: 69, + CompletionTokensDetails: schema.CompletionTokensDetails{ + ReasoningTokens: 12, + }, + }, + }, + }, + }, nil) + outsw.Close() + + d := defaultDataParser{} + result := d.ParseChatModelStreamOutput(context.Background(), outsr) + + convey.So(result[tracespec.InputTokens], convey.ShouldEqual, 6900) + convey.So(result[tracespec.InputCachedTokens], convey.ShouldEqual, 3265) + convey.So(result[tracespec.OutputTokens], convey.ShouldEqual, 69) + convey.So(result[tracespec.ReasoningTokens], convey.ShouldEqual, 12) + convey.So(result[tracespec.Tokens], convey.ShouldEqual, 6901) + }) + mockey.PatchConvey("测试 ChatModel 流式输出合并 token usage", t, func() { outsr, outsw := schema.Pipe[callbacks.CallbackOutput](3) outsw.Send(&model.CallbackOutput{ @@ -456,6 +567,58 @@ func Test_defaultDataParser_ParseChatModelStreamOutput_MergeCumulativeTokenUsage }) } +func Test_defaultDataParser_ParseAgenticModelStreamOutput_MessageMetaTokenUsage(t *testing.T) { + mockey.PatchConvey("测试 AgenticModel 流式输出优先使用 concat 后的 message meta usage", t, func() { + outsr, outsw := schema.Pipe[callbacks.CallbackOutput](3) + outsw.Send(&model.AgenticCallbackOutput{ + Message: &schema.AgenticMessage{ + Role: schema.AgenticRoleTypeAssistant, + ContentBlocks: []*schema.ContentBlock{{ + Type: schema.ContentBlockTypeAssistantGenText, + AssistantGenText: &schema.AssistantGenText{Text: "assistant"}, + }}, + ResponseMeta: &schema.AgenticResponseMeta{ + TokenUsage: &schema.TokenUsage{ + PromptTokens: 6900, + PromptTokenDetails: schema.PromptTokenDetails{ + CachedTokens: 3265, + }, + CompletionTokens: 1, + TotalTokens: 6901, + }, + }, + }, + }, nil) + outsw.Send(&model.AgenticCallbackOutput{ + Message: &schema.AgenticMessage{ + Role: schema.AgenticRoleTypeAssistant, + ContentBlocks: []*schema.ContentBlock{{ + Type: schema.ContentBlockTypeAssistantGenText, + AssistantGenText: &schema.AssistantGenText{Text: " message"}, + }}, + ResponseMeta: &schema.AgenticResponseMeta{ + TokenUsage: &schema.TokenUsage{ + CompletionTokens: 69, + CompletionTokensDetails: schema.CompletionTokensDetails{ + ReasoningTokens: 12, + }, + }, + }, + }, + }, nil) + outsw.Close() + + d := defaultDataParser{} + result := d.ParseAgenticModelStreamOutput(context.Background(), outsr) + + convey.So(result[tracespec.InputTokens], convey.ShouldEqual, 6900) + convey.So(result[tracespec.InputCachedTokens], convey.ShouldEqual, 3265) + convey.So(result[tracespec.OutputTokens], convey.ShouldEqual, 69) + convey.So(result[tracespec.ReasoningTokens], convey.ShouldEqual, 12) + convey.So(result[tracespec.Tokens], convey.ShouldEqual, 6901) + }) +} + // Test_defaultDataParser_tryConcatChunks 为 defaultDataParser 的 tryConcatChunks 方法编写单元测试 func Test_defaultDataParser_tryConcatChunks(t *testing.T) { mockey.PatchConvey("测试 defaultDataParser 的 tryConcatChunks 方法", t, func() { diff --git a/devops/go.mod b/devops/go.mod index 2ea9eee26..6b44d8376 100644 --- a/devops/go.mod +++ b/devops/go.mod @@ -4,7 +4,7 @@ go 1.18 require ( github.com/bytedance/mockey v1.2.12 - github.com/cloudwego/eino v0.9.1 + github.com/cloudwego/eino v0.6.0 github.com/gorilla/mux v1.8.1 github.com/matoous/go-nanoid v1.5.1 github.com/stretchr/testify v1.10.0 @@ -15,13 +15,12 @@ require ( github.com/bahlo/generic-list-go v0.2.0 // indirect github.com/buger/jsonparser v1.1.1 // indirect github.com/bytedance/gopkg v0.1.3 // indirect - github.com/bytedance/sonic v1.15.0 // indirect - github.com/bytedance/sonic/loader v0.5.0 // indirect + github.com/bytedance/sonic v1.14.1 // indirect + github.com/bytedance/sonic/loader v0.3.0 // indirect github.com/cloudwego/base64x v0.1.6 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/dustin/go-humanize v1.0.1 // indirect - github.com/eino-contrib/jsonschema v1.0.3 // indirect - github.com/google/uuid v1.6.0 // indirect + github.com/eino-contrib/jsonschema v1.0.2 // indirect github.com/goph/emperror v0.17.2 // indirect github.com/gopherjs/gopherjs v1.17.2 // indirect github.com/json-iterator/go v1.1.12 // indirect @@ -43,6 +42,6 @@ require ( github.com/yargevad/filepathx v1.0.0 // indirect golang.org/x/arch v0.11.0 // indirect golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 // indirect - golang.org/x/sys v0.29.0 // indirect + golang.org/x/sys v0.26.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/devops/go.sum b/devops/go.sum index e8df98b76..1d1626bb6 100644 --- a/devops/go.sum +++ b/devops/go.sum @@ -11,30 +11,28 @@ github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M github.com/bytedance/gopkg v0.1.3/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM= github.com/bytedance/mockey v1.2.12 h1:aeszOmGw8CPX8CRx1DZ/Glzb1yXvhjDh6jdFBNZjsU4= github.com/bytedance/mockey v1.2.12/go.mod h1:3ZA4MQasmqC87Tw0w7Ygdy7eHIc2xgpZ8Pona5rsYIk= -github.com/bytedance/sonic v1.15.0 h1:/PXeWFaR5ElNcVE84U0dOHjiMHQOwNIx3K4ymzh/uSE= -github.com/bytedance/sonic v1.15.0/go.mod h1:tFkWrPz0/CUCLEF4ri4UkHekCIcdnkqXw9VduqpJh0k= -github.com/bytedance/sonic/loader v0.5.0 h1:gXH3KVnatgY7loH5/TkeVyXPfESoqSBSBEiDd5VjlgE= -github.com/bytedance/sonic/loader v0.5.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo= +github.com/bytedance/sonic v1.14.1 h1:FBMC0zVz5XUmE4z9wF4Jey0An5FueFvOsTKKKtwIl7w= +github.com/bytedance/sonic v1.14.1/go.mod h1:gi6uhQLMbTdeP0muCnrjHLeCUPyb70ujhnNlhOylAFc= +github.com/bytedance/sonic/loader v0.3.0 h1:dskwH8edlzNMctoruo8FPTJDF3vLtDT0sXZwvZJyqeA= +github.com/bytedance/sonic/loader v0.3.0/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFosQU6FxH2JmUe6VI= github.com/certifi/gocertifi v0.0.0-20190105021004-abcd57078448/go.mod h1:GJKEexRPVJrBSOjoqN5VNOIKJ5Q3RViH6eu3puDRwx4= github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= -github.com/cloudwego/eino v0.9.1 h1:eSwgXfsaxmgTXsTgWi9OMBcm8hKvVhb1q0PPk58p6f8= -github.com/cloudwego/eino v0.9.1/go.mod h1:OBD1mrkfkt/pJa4rkg1P0VnaMeOVl7l8IAdEqY//3IQ= +github.com/cloudwego/eino v0.6.0 h1:pobGKMOfcQHVNhD9UT/HrvO0eYG6FC2ML/NKY2Eb9+Q= +github.com/cloudwego/eino v0.6.0/go.mod h1:JNapfU+QUrFFpboNDrNOFvmz0m9wjBFHHCr77RH6a50= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= -github.com/eino-contrib/jsonschema v1.0.3 h1:2Kfsm1xlMV0ssY2nuxshS4AwbLFuqmPmzIjLVJ1Fsp0= -github.com/eino-contrib/jsonschema v1.0.3/go.mod h1:cpnX4SyKjWjGC7iN2EbhxaTdLqGjCi0e9DxpLYxddD4= +github.com/eino-contrib/jsonschema v1.0.2 h1:HaxruBMUdnXa7Lg/lX8g0Hk71ZIfdTZXmBQz0e3esr8= +github.com/eino-contrib/jsonschema v1.0.2/go.mod h1:cpnX4SyKjWjGC7iN2EbhxaTdLqGjCi0e9DxpLYxddD4= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/getsentry/raven-go v0.2.0/go.mod h1:KungGk8q33+aIAZUIVWZDr2OfAEBsO49PX4NzFV5kcQ= github.com/go-check/check v0.0.0-20180628173108-788fd7840127 h1:0gkP6mzaMqkmpcJYCFOLkIBwI7xFExG03bbkOkCvUPI= github.com/gofrs/uuid v3.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= -github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= -github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/goph/emperror v0.17.2 h1:yLapQcmEsO0ipe9p5TaN22djm3OFV/TfM/fcYP0/J18= github.com/goph/emperror v0.17.2/go.mod h1:+ZbQ+fUNO/6FNiUo0ujtMjhgad9Xa6fQL9KhH4LNHic= github.com/gopherjs/gopherjs v1.17.2 h1:fQnZVsXk8uxXIStYb0N4bGk7jeyTalG/wsZjQ25dO0g= @@ -94,12 +92,12 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= -github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= @@ -123,9 +121,9 @@ golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= -golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/term v0.28.0 h1:/Ts8HFuMR2E6IP/jlo7QVLZHggjKQbhu/7H0LJFr3Gg= +golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= +golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.10.0 h1:3R7pNqamzBraeqj/Tj8qt1aQ2HpmlC+Cx/qL/7hn4/c= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= diff --git a/devops/internal/model/container.go b/devops/internal/model/container.go index c50935933..8445c51ec 100644 --- a/devops/internal/model/container.go +++ b/devops/internal/model/container.go @@ -383,7 +383,7 @@ func (g *Graph) addNode(node string, gni compose.GraphNodeInfo, opts ...compose. return fmt.Errorf("component is %s, but get unexpected instance=%v", gni.Component, reflect.TypeOf(gni.Instance)) } return g.AddEmbeddingNode(node, ins, newOpts...) - + case components.ComponentOfRetriever: ins, ok := gni.Instance.(retriever.Retriever) if !ok { @@ -405,13 +405,6 @@ func (g *Graph) addNode(node string, gni compose.GraphNodeInfo, opts ...compose. } return g.AddChatModelNode(node, ins, newOpts...) - case components.ComponentOfAgenticModel: - ins, ok := gni.Instance.(model.AgenticModel) - if !ok { - return fmt.Errorf("component is %s, but get unexpected instance=%v", gni.Component, reflect.TypeOf(gni.Instance)) - } - return g.AddAgenticModelNode(node, ins, newOpts...) - case components.ComponentOfPrompt: ins, ok := gni.Instance.(prompt.ChatTemplate) if !ok { @@ -419,13 +412,6 @@ func (g *Graph) addNode(node string, gni compose.GraphNodeInfo, opts ...compose. } return g.AddChatTemplateNode(node, ins, newOpts...) - case components.ComponentOfAgenticPrompt: - ins, ok := gni.Instance.(prompt.AgenticChatTemplate) - if !ok { - return fmt.Errorf("component is %s, but get unexpected instance=%v", gni.Component, reflect.TypeOf(gni.Instance)) - } - return g.AddAgenticChatTemplateNode(node, ins, newOpts...) - case compose.ComponentOfToolsNode: ins, ok := gni.Instance.(*compose.ToolsNode) if !ok { @@ -433,13 +419,6 @@ func (g *Graph) addNode(node string, gni compose.GraphNodeInfo, opts ...compose. } return g.AddToolsNode(node, ins, newOpts...) - case compose.ComponentOfAgenticToolsNode: - ins, ok := gni.Instance.(*compose.AgenticToolsNode) - if !ok { - return fmt.Errorf("component is %s, but get unexpected instance=%v", gni.Component, reflect.TypeOf(gni.Instance)) - } - return g.AddAgenticToolsNode(node, ins, newOpts...) - case compose.ComponentOfLambda: ins, ok := gni.Instance.(*compose.Lambda) if !ok { diff --git a/devops/internal/model/container_test.go b/devops/internal/model/container_test.go index 17a92a2df..6db3dfed8 100644 --- a/devops/internal/model/container_test.go +++ b/devops/internal/model/container_test.go @@ -30,7 +30,6 @@ import ( devmodel "github.com/cloudwego/eino-ext/devops/model" "github.com/cloudwego/eino/components" - componentmodel "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/components/prompt" "github.com/cloudwego/eino/components/retriever" "github.com/cloudwego/eino/compose" @@ -65,21 +64,6 @@ func (m mockContainerImplV2) Name() string { return m.NN } -type mockAgenticModel struct{} - -func (m *mockAgenticModel) Generate(ctx context.Context, input []*schema.AgenticMessage, - opts ...componentmodel.Option) (*schema.AgenticMessage, error) { - if len(input) == 0 { - return schema.UserAgenticMessage("mock"), nil - } - return input[0], nil -} - -func (m *mockAgenticModel) Stream(ctx context.Context, input []*schema.AgenticMessage, - opts ...componentmodel.Option) (*schema.StreamReader[*schema.AgenticMessage], error) { - return nil, nil -} - type testCtxKey struct{} type testCallback struct { @@ -94,41 +78,6 @@ func (tt *testCallback) OnFinish(ctx context.Context, graphInfo *compose.GraphIn } func Test_GraphInfo_BuildDevGraph(t *testing.T) { - t.Run("graph: agentic prompt and model", func(t *testing.T) { - g := compose.NewGraph[map[string]any, *schema.AgenticMessage]() - err := g.AddAgenticChatTemplateNode("prompt", - prompt.FromAgenticMessages(schema.FString, schema.UserAgenticMessage("{query}"))) - assert.NoError(t, err) - err = g.AddAgenticModelNode("model", &mockAgenticModel{}) - assert.NoError(t, err) - err = g.AddEdge(compose.START, "prompt") - assert.NoError(t, err) - err = g.AddEdge("prompt", "model") - assert.NoError(t, err) - err = g.AddEdge("model", compose.END) - assert.NoError(t, err) - - tc := &testCallback{} - ctx := context.Background() - _, err = g.Compile(ctx, compose.WithGraphCompileCallbacks(tc)) - assert.NoError(t, err) - - ng, err := BuildDevGraph(tc.gi, compose.START) - assert.NoError(t, err) - - r, err := ng.Compile() - assert.NoError(t, err) - - input, err := UnmarshalJson([]byte(`{"query":{"_eino_go_type":"string","_value":"hello"}}`), ng.GraphInfo.InputType) - assert.NoError(t, err) - resp, err := r.Invoke(ctx, input) - assert.NoError(t, err) - - msg, ok := resp.(*schema.AgenticMessage) - assert.True(t, ok) - assert.Equal(t, schema.AgenticRoleTypeUser, msg.Role) - }) - t.Run("graph-chain: add chain, stateGraph,graph node", func(t *testing.T) { type mockInputType struct { Input string `json:"input"` @@ -1438,26 +1387,6 @@ func Test_Graph_addNode(t *testing.T) { assert.NoError(t, err) }) - t.Run("AgenticModel", func(t *testing.T) { - g := &Graph{Graph: compose.NewGraph[any, any](compose.WithGenLocalState(genState))} - gni := compose.GraphNodeInfo{ - Component: components.ComponentOfAgenticModel, - Instance: &mockAgenticModel{}, - } - err := g.addNode("node_1", gni) - assert.NoError(t, err) - }) - - t.Run("AgenticPrompt", func(t *testing.T) { - g := &Graph{Graph: compose.NewGraph[any, any](compose.WithGenLocalState(genState))} - gni := compose.GraphNodeInfo{ - Component: components.ComponentOfAgenticPrompt, - Instance: prompt.FromAgenticMessages(schema.FString, schema.UserAgenticMessage("hi")), - } - err := g.addNode("node_1", gni) - assert.NoError(t, err) - }) - t.Run("ToolsNode", func(t *testing.T) { g := &Graph{Graph: compose.NewGraph[any, any](compose.WithGenLocalState(genState))} gni := compose.GraphNodeInfo{ @@ -1468,16 +1397,6 @@ func Test_Graph_addNode(t *testing.T) { assert.NoError(t, err) }) - t.Run("AgenticToolsNode", func(t *testing.T) { - g := &Graph{Graph: compose.NewGraph[any, any](compose.WithGenLocalState(genState))} - gni := compose.GraphNodeInfo{ - Component: compose.ComponentOfAgenticToolsNode, - Instance: &compose.AgenticToolsNode{}, - } - err := g.addNode("node_1", gni) - assert.NoError(t, err) - }) - t.Run("Graph", func(t *testing.T) { g := &Graph{Graph: compose.NewGraph[any, any](compose.WithGenLocalState(genState))} gni := compose.GraphNodeInfo{ diff --git a/devops/internal/model/types.go b/devops/internal/model/types.go index f85df038f..c72b0dd35 100644 --- a/devops/internal/model/types.go +++ b/devops/internal/model/types.go @@ -53,9 +53,6 @@ var registeredTypes = []RegisteredType{ {Identifier: "*schema.Message", Type: generic.TypeOf[*schema.Message]()}, {Identifier: "schema.Message", Type: generic.TypeOf[schema.Message]()}, {Identifier: "[]*schema.Message", Type: generic.TypeOf[[]*schema.Message]()}, - {Identifier: "*schema.AgenticMessage", Type: generic.TypeOf[*schema.AgenticMessage]()}, - {Identifier: "schema.AgenticMessage", Type: generic.TypeOf[schema.AgenticMessage]()}, - {Identifier: "[]*schema.AgenticMessage", Type: generic.TypeOf[[]*schema.AgenticMessage]()}, {Identifier: "*schema.Document", Type: generic.TypeOf[*schema.Document]()}, {Identifier: "schema.Document", Type: generic.TypeOf[schema.Document]()}, {Identifier: "[]*schema.Document", Type: generic.TypeOf[[]*schema.Document]()}, From 5025bf82328a42fb7cc3d41b5c7bb39935ad6ce7 Mon Sep 17 00:00:00 2001 From: tangchaojun Date: Wed, 3 Jun 2026 17:59:16 +0800 Subject: [PATCH 8/9] fix: read cozeloop stream usage from message meta --- callbacks/cozeloop/data_parser.go | 63 +++----------- callbacks/cozeloop/data_parser_test.go | 109 ++----------------------- 2 files changed, 17 insertions(+), 155 deletions(-) diff --git a/callbacks/cozeloop/data_parser.go b/callbacks/cozeloop/data_parser.go index f9b66df8c..f4013070d 100644 --- a/callbacks/cozeloop/data_parser.go +++ b/callbacks/cozeloop/data_parser.go @@ -204,7 +204,7 @@ func (d defaultDataParser) ParseOutput(ctx context.Context, info *callbacks.RunI tags.set(tracespec.Output, finalOutput) tags.set(consts.CustomSpanTagKeyExtra, cbOutput.Extra) - setTokenUsageTags(tags, getMessageTokenUsage(cbOutput.Message, cbOutput.TokenUsage)) + setTokenUsageTags(tags, cbOutput.TokenUsage) } tags.set(tracespec.Stream, false) @@ -225,7 +225,7 @@ func (d defaultDataParser) ParseOutput(ctx context.Context, info *callbacks.RunI tags.set(tracespec.Output, finalOutput) tags.set(consts.CustomSpanTagKeyExtra, cbOutput.Extra) - setTokenUsageTags(tags, getAgenticMessageTokenUsage(cbOutput.Message, cbOutput.TokenUsage)) + setTokenUsageTags(tags, cbOutput.TokenUsage) if cbOutput.Config != nil { if cbOutput.Config.Model != "" { @@ -422,7 +422,6 @@ func (d defaultDataParser) ParseChatModelStreamOutput(ctx context.Context, outpu chunks []*schema.Message onceSet bool tags = make(spanTags) - usage *model.TokenUsage ) level := getGraphNodeLevelFromCtx(ctx) @@ -447,8 +446,6 @@ func (d defaultDataParser) ParseChatModelStreamOutput(ctx context.Context, outpu chunks = append(chunks, cbOutput.Message) } - usage = mergeCumulativeTokenUsage(usage, cbOutput.TokenUsage) - if cbOutput.Config != nil && !onceSet { onceSet = true @@ -458,7 +455,6 @@ func (d defaultDataParser) ParseChatModelStreamOutput(ctx context.Context, outpu } } - finalUsage := usage if msg, concatErr := schema.ConcatMessages(chunks); concatErr != nil { // unexpected finalOutput := parseAny(ctx, chunks, true) tags.set(tracespec.Output, finalOutput) @@ -470,14 +466,12 @@ func (d defaultDataParser) ParseChatModelStreamOutput(ctx context.Context, outpu } } else { tags.set(tracespec.Output, convertModelOutput(&model.CallbackOutput{Message: msg})) - finalUsage = getMessageTokenUsage(msg, usage) + setTokenUsageTags(tags, getMessageMetaTokenUsage(msg)) if level == 2 { collectOutput.addMessages(convertModelMessage(msg)) } } - setTokenUsageTags(tags, finalUsage) - return tags } @@ -486,7 +480,6 @@ func (d defaultDataParser) ParseAgenticModelStreamOutput(ctx context.Context, ou chunks []*schema.AgenticMessage onceSet bool tags = make(spanTags) - usage *model.TokenUsage modelName string ) @@ -512,8 +505,6 @@ func (d defaultDataParser) ParseAgenticModelStreamOutput(ctx context.Context, ou chunks = append(chunks, cbOutput.Message) } - usage = mergeCumulativeTokenUsage(usage, cbOutput.TokenUsage) - if cbOutput.Config != nil && !onceSet { onceSet = true @@ -531,7 +522,6 @@ func (d defaultDataParser) ParseAgenticModelStreamOutput(ctx context.Context, ou tags.set(tracespec.ModelName, modelName) } - finalUsage := usage if msg, concatErr := schema.ConcatAgenticMessages(chunks); concatErr != nil { finalOutput := parseAny(ctx, chunks, true) tags.set(tracespec.Output, finalOutput) @@ -543,29 +533,27 @@ func (d defaultDataParser) ParseAgenticModelStreamOutput(ctx context.Context, ou } } else { tags.set(tracespec.Output, convertAgenticModelOutput(&model.AgenticCallbackOutput{Message: msg})) - finalUsage = getAgenticMessageTokenUsage(msg, usage) + setTokenUsageTags(tags, getAgenticMessageMetaTokenUsage(msg)) if level == 2 { collectOutput.addMessages(expandAgenticModelMessage(msg)...) } } - setTokenUsageTags(tags, finalUsage) - return tags } -func getMessageTokenUsage(msg *schema.Message, fallback *model.TokenUsage) *model.TokenUsage { +func getMessageMetaTokenUsage(msg *schema.Message) *model.TokenUsage { if msg == nil || msg.ResponseMeta == nil { - return fallback + return nil } - return mergeCumulativeTokenUsage(schemaTokenUsageToModel(msg.ResponseMeta.Usage), fallback) + return schemaTokenUsageToModel(msg.ResponseMeta.Usage) } -func getAgenticMessageTokenUsage(msg *schema.AgenticMessage, fallback *model.TokenUsage) *model.TokenUsage { +func getAgenticMessageMetaTokenUsage(msg *schema.AgenticMessage) *model.TokenUsage { if msg == nil || msg.ResponseMeta == nil { - return fallback + return nil } - return mergeCumulativeTokenUsage(schemaTokenUsageToModel(msg.ResponseMeta.TokenUsage), fallback) + return schemaTokenUsageToModel(msg.ResponseMeta.TokenUsage) } func schemaTokenUsageToModel(usage *schema.TokenUsage) *model.TokenUsage { @@ -596,37 +584,6 @@ func setTokenUsageTags(tags spanTags, usage *model.TokenUsage) { set(tracespec.ReasoningTokens, usage.CompletionTokensDetails.ReasoningTokens) } -// mergeCumulativeTokenUsage keeps the final request-level token usage from a -// stream. Streaming callbacks may carry partial cumulative snapshots, and -// TokenUsage does not preserve field-presence metadata, so each monotonically -// increasing counter is merged by its largest observed value. -func mergeCumulativeTokenUsage(dst, src *model.TokenUsage) *model.TokenUsage { - if src == nil { - return dst - } - if dst == nil { - dst = &model.TokenUsage{} - } - - if src.PromptTokens > dst.PromptTokens { - dst.PromptTokens = src.PromptTokens - } - if src.CompletionTokens > dst.CompletionTokens { - dst.CompletionTokens = src.CompletionTokens - } - if src.TotalTokens > dst.TotalTokens { - dst.TotalTokens = src.TotalTokens - } - if src.PromptTokenDetails.CachedTokens > dst.PromptTokenDetails.CachedTokens { - dst.PromptTokenDetails.CachedTokens = src.PromptTokenDetails.CachedTokens - } - if src.CompletionTokensDetails.ReasoningTokens > dst.CompletionTokensDetails.ReasoningTokens { - dst.CompletionTokensDetails.ReasoningTokens = src.CompletionTokensDetails.ReasoningTokens - } - - return dst -} - func (d defaultDataParser) ParseDefaultStreamInput(ctx context.Context, input *schema.StreamReader[callbacks.CallbackInput]) (chunks []any, err error) { for { item, recvErr := input.Recv() diff --git a/callbacks/cozeloop/data_parser_test.go b/callbacks/cozeloop/data_parser_test.go index 972cc78ae..f7a0563c2 100644 --- a/callbacks/cozeloop/data_parser_test.go +++ b/callbacks/cozeloop/data_parser_test.go @@ -392,74 +392,7 @@ func Test_defaultDataParser_ParseOutput(t *testing.T) { }) } -func Test_defaultDataParser_ParseOutput_MessageMetaTokenUsage(t *testing.T) { - mockey.PatchConvey("测试非流式模型输出优先使用 message meta usage", t, func() { - d := defaultDataParser{} - ctx := context.Background() - - mockey.PatchConvey("ChatModel 从 Message.ResponseMeta.Usage 读取 token usage", func() { - info := &callbacks.RunInfo{Component: components.ComponentOfChatModel} - output := &model.CallbackOutput{ - Message: &schema.Message{ - Role: schema.Assistant, - ResponseMeta: &schema.ResponseMeta{ - Usage: &schema.TokenUsage{ - PromptTokens: 11, - PromptTokenDetails: schema.PromptTokenDetails{ - CachedTokens: 7, - }, - CompletionTokens: 13, - TotalTokens: 24, - CompletionTokensDetails: schema.CompletionTokensDetails{ - ReasoningTokens: 5, - }, - }, - }, - }, - } - - result := d.ParseOutput(ctx, info, output) - - convey.So(result[tracespec.InputTokens], convey.ShouldEqual, 11) - convey.So(result[tracespec.InputCachedTokens], convey.ShouldEqual, 7) - convey.So(result[tracespec.OutputTokens], convey.ShouldEqual, 13) - convey.So(result[tracespec.ReasoningTokens], convey.ShouldEqual, 5) - convey.So(result[tracespec.Tokens], convey.ShouldEqual, 24) - }) - - mockey.PatchConvey("AgenticModel 从 AgenticMessage.ResponseMeta.TokenUsage 读取 token usage", func() { - info := &callbacks.RunInfo{Component: components.ComponentOfAgenticModel} - output := &model.AgenticCallbackOutput{ - Message: &schema.AgenticMessage{ - Role: schema.AgenticRoleTypeAssistant, - ResponseMeta: &schema.AgenticResponseMeta{ - TokenUsage: &schema.TokenUsage{ - PromptTokens: 17, - PromptTokenDetails: schema.PromptTokenDetails{ - CachedTokens: 3, - }, - CompletionTokens: 19, - TotalTokens: 36, - CompletionTokensDetails: schema.CompletionTokensDetails{ - ReasoningTokens: 2, - }, - }, - }, - }, - } - - result := d.ParseOutput(ctx, info, output) - - convey.So(result[tracespec.InputTokens], convey.ShouldEqual, 17) - convey.So(result[tracespec.InputCachedTokens], convey.ShouldEqual, 3) - convey.So(result[tracespec.OutputTokens], convey.ShouldEqual, 19) - convey.So(result[tracespec.ReasoningTokens], convey.ShouldEqual, 2) - convey.So(result[tracespec.Tokens], convey.ShouldEqual, 36) - }) - }) -} - -func Test_defaultDataParser_ParseChatModelStreamOutput_MergeCumulativeTokenUsage(t *testing.T) { +func Test_defaultDataParser_ParseChatModelStreamOutput_MessageMetaTokenUsage(t *testing.T) { mockey.PatchConvey("测试 ChatModel 流式输出优先使用 concat 后的 message meta usage", t, func() { outsr, outsw := schema.Pipe[callbacks.CallbackOutput](3) outsw.Send(&model.CallbackOutput{ @@ -504,7 +437,7 @@ func Test_defaultDataParser_ParseChatModelStreamOutput_MergeCumulativeTokenUsage convey.So(result[tracespec.Tokens], convey.ShouldEqual, 6901) }) - mockey.PatchConvey("测试 ChatModel 流式输出合并 token usage", t, func() { + mockey.PatchConvey("测试 ChatModel 流式输出不读取 CallbackOutput.TokenUsage", t, func() { outsr, outsw := schema.Pipe[callbacks.CallbackOutput](3) outsw.Send(&model.CallbackOutput{ Message: &schema.Message{Role: schema.Assistant, Content: "assistant"}, @@ -531,39 +464,11 @@ func Test_defaultDataParser_ParseChatModelStreamOutput_MergeCumulativeTokenUsage d := defaultDataParser{} result := d.ParseChatModelStreamOutput(context.Background(), outsr) - convey.So(result[tracespec.InputTokens], convey.ShouldEqual, 6900) - convey.So(result[tracespec.InputCachedTokens], convey.ShouldEqual, 3265) - convey.So(result[tracespec.OutputTokens], convey.ShouldEqual, 69) - convey.So(result[tracespec.ReasoningTokens], convey.ShouldEqual, 12) - convey.So(result[tracespec.Tokens], convey.ShouldEqual, 6901) - }) - - mockey.PatchConvey("测试 ChatModel 流式输出使用最终累计 token usage", t, func() { - outsr, outsw := schema.Pipe[callbacks.CallbackOutput](3) - outsw.Send(&model.CallbackOutput{ - Message: &schema.Message{Role: schema.Assistant, Content: "assistant"}, - TokenUsage: &model.TokenUsage{ - PromptTokens: 2679, - CompletionTokens: 3, - TotalTokens: 2682, - }, - }, nil) - outsw.Send(&model.CallbackOutput{ - Message: &schema.Message{Role: schema.Assistant, Content: " message"}, - TokenUsage: &model.TokenUsage{ - PromptTokens: 10682, - CompletionTokens: 510, - TotalTokens: 11192, - }, - }, nil) - outsw.Close() - - d := defaultDataParser{} - result := d.ParseChatModelStreamOutput(context.Background(), outsr) - - convey.So(result[tracespec.InputTokens], convey.ShouldEqual, 10682) - convey.So(result[tracespec.OutputTokens], convey.ShouldEqual, 510) - convey.So(result[tracespec.Tokens], convey.ShouldEqual, 11192) + convey.So(result, convey.ShouldNotContainKey, tracespec.InputTokens) + convey.So(result, convey.ShouldNotContainKey, tracespec.InputCachedTokens) + convey.So(result, convey.ShouldNotContainKey, tracespec.OutputTokens) + convey.So(result, convey.ShouldNotContainKey, tracespec.ReasoningTokens) + convey.So(result, convey.ShouldNotContainKey, tracespec.Tokens) }) } From 409447a9159b7a4831dd800f61cba91e4de9dc91 Mon Sep 17 00:00:00 2001 From: tangchaojun Date: Thu, 4 Jun 2026 17:17:04 +0800 Subject: [PATCH 9/9] fix: map claude stream input token usage --- components/model/claude/claude.go | 18 +++++++++++++++--- components/model/claude/claude_test.go | 8 +++++++- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/components/model/claude/claude.go b/components/model/claude/claude.go index a8da5a321..5baf98f68 100644 --- a/components/model/claude/claude.go +++ b/components/model/claude/claude.go @@ -1270,6 +1270,20 @@ func convOutputMessage(resp *anthropic.Message) (*schema.Message, error) { return message, nil } +func convMessageDeltaUsage(usage anthropic.MessageDeltaUsage) *schema.TokenUsage { + promptTokens := int(usage.InputTokens + usage.CacheReadInputTokens + usage.CacheCreationInputTokens) + completionTokens := int(usage.OutputTokens) + + return &schema.TokenUsage{ + PromptTokens: promptTokens, + PromptTokenDetails: schema.PromptTokenDetails{ + CachedTokens: int(usage.CacheReadInputTokens), + }, + CompletionTokens: completionTokens, + TotalTokens: promptTokens + completionTokens, + } +} + type streamContext struct { toolIndex *int } @@ -1358,9 +1372,7 @@ func convStreamEvent(event anthropic.MessageStreamEventUnion, streamCtx *streamC case anthropic.MessageDeltaEvent: result.ResponseMeta = &schema.ResponseMeta{ FinishReason: string(e.Delta.StopReason), - Usage: &schema.TokenUsage{ - CompletionTokens: int(e.Usage.OutputTokens), - }, + Usage: convMessageDeltaUsage(e.Usage), } return result, nil diff --git a/components/model/claude/claude_test.go b/components/model/claude/claude_test.go index 7e28cf776..f565b2e78 100644 --- a/components/model/claude/claude_test.go +++ b/components/model/claude/claude_test.go @@ -346,14 +346,20 @@ func TestConvStreamEvent(t *testing.T) { StopReason: "end_turn", }, Usage: anthropic.MessageDeltaUsage{ - OutputTokens: 10, + InputTokens: 8, + CacheReadInputTokens: 3, + CacheCreationInputTokens: 2, + OutputTokens: 10, }, }).Build().UnPatch() message, err := convStreamEvent(event, streamCtx) assert.NoError(t, err) assert.Equal(t, "end_turn", message.ResponseMeta.FinishReason) + assert.Equal(t, 13, message.ResponseMeta.Usage.PromptTokens) + assert.Equal(t, 3, message.ResponseMeta.Usage.PromptTokenDetails.CachedTokens) assert.Equal(t, 10, message.ResponseMeta.Usage.CompletionTokens) + assert.Equal(t, 23, message.ResponseMeta.Usage.TotalTokens) }) mockey.PatchConvey("content block start event", t, func() {