Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 45 additions & 47 deletions callbacks/cozeloop/data_parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,12 +204,7 @@ func (d defaultDataParser) ParseOutput(ctx context.Context, info *callbacks.RunI
tags.set(tracespec.Output, finalOutput)
tags.set(consts.CustomSpanTagKeyExtra, cbOutput.Extra)

if cbOutput.TokenUsage != nil {
tags.set(tracespec.Tokens, cbOutput.TokenUsage.TotalTokens).
set(tracespec.InputTokens, cbOutput.TokenUsage.PromptTokens).
set(tracespec.OutputTokens, cbOutput.TokenUsage.CompletionTokens).
set(tracespec.InputCachedTokens, cbOutput.TokenUsage.PromptTokenDetails.CachedTokens)
}
setTokenUsageTags(tags, cbOutput.TokenUsage)
}

tags.set(tracespec.Stream, false)
Expand All @@ -230,13 +225,7 @@ func (d defaultDataParser) ParseOutput(ctx context.Context, info *callbacks.RunI
tags.set(tracespec.Output, finalOutput)
tags.set(consts.CustomSpanTagKeyExtra, cbOutput.Extra)

if cbOutput.TokenUsage != nil {
tags.set(tracespec.Tokens, cbOutput.TokenUsage.TotalTokens).
set(tracespec.InputTokens, cbOutput.TokenUsage.PromptTokens).
set(tracespec.OutputTokens, cbOutput.TokenUsage.CompletionTokens).
set(tracespec.InputCachedTokens, cbOutput.TokenUsage.PromptTokenDetails.CachedTokens).
set(tracespec.ReasoningTokens, cbOutput.TokenUsage.CompletionTokensDetails.ReasoningTokens)
}
setTokenUsageTags(tags, cbOutput.TokenUsage)

if cbOutput.Config != nil {
if cbOutput.Config.Model != "" {
Expand Down Expand Up @@ -433,7 +422,6 @@ func (d defaultDataParser) ParseChatModelStreamOutput(ctx context.Context, outpu
chunks []*schema.Message
onceSet bool
tags = make(spanTags)
usage *model.TokenUsage
)

level := getGraphNodeLevelFromCtx(ctx)
Expand All @@ -458,14 +446,6 @@ func (d defaultDataParser) ParseChatModelStreamOutput(ctx context.Context, outpu
chunks = append(chunks, cbOutput.Message)
}

if cbOutput.TokenUsage != nil {
usage = &model.TokenUsage{
PromptTokens: cbOutput.TokenUsage.PromptTokens,
CompletionTokens: cbOutput.TokenUsage.CompletionTokens,
TotalTokens: cbOutput.TokenUsage.TotalTokens,
}
}

if cbOutput.Config != nil && !onceSet {
onceSet = true

Expand All @@ -486,18 +466,12 @@ func (d defaultDataParser) ParseChatModelStreamOutput(ctx context.Context, outpu
}
} else {
tags.set(tracespec.Output, convertModelOutput(&model.CallbackOutput{Message: msg}))
setTokenUsageTags(tags, getMessageMetaTokenUsage(msg))
if level == 2 {
collectOutput.addMessages(convertModelMessage(msg))
}
}

if usage != nil {
tags.set(tracespec.Tokens, usage.TotalTokens).
set(tracespec.InputTokens, usage.PromptTokens).
set(tracespec.OutputTokens, usage.CompletionTokens).
set(tracespec.InputCachedTokens, usage.PromptTokenDetails.CachedTokens)
}

return tags
}

Expand All @@ -506,7 +480,6 @@ func (d defaultDataParser) ParseAgenticModelStreamOutput(ctx context.Context, ou
chunks []*schema.AgenticMessage
onceSet bool
tags = make(spanTags)
usage *model.TokenUsage
modelName string
)

Expand All @@ -532,16 +505,6 @@ func (d defaultDataParser) ParseAgenticModelStreamOutput(ctx context.Context, ou
chunks = append(chunks, cbOutput.Message)
}

if cbOutput.TokenUsage != nil {
usage = &model.TokenUsage{
PromptTokens: cbOutput.TokenUsage.PromptTokens,
CompletionTokens: cbOutput.TokenUsage.CompletionTokens,
TotalTokens: cbOutput.TokenUsage.TotalTokens,
PromptTokenDetails: cbOutput.TokenUsage.PromptTokenDetails,
CompletionTokensDetails: cbOutput.TokenUsage.CompletionTokensDetails,
}
}

if cbOutput.Config != nil && !onceSet {
onceSet = true

Expand Down Expand Up @@ -570,20 +533,55 @@ func (d defaultDataParser) ParseAgenticModelStreamOutput(ctx context.Context, ou
}
} else {
tags.set(tracespec.Output, convertAgenticModelOutput(&model.AgenticCallbackOutput{Message: msg}))
setTokenUsageTags(tags, getAgenticMessageMetaTokenUsage(msg))
if level == 2 {
collectOutput.addMessages(expandAgenticModelMessage(msg)...)
}
}

if usage != nil {
tags.set(tracespec.Tokens, usage.TotalTokens).
set(tracespec.InputTokens, usage.PromptTokens).
set(tracespec.OutputTokens, usage.CompletionTokens).
set(tracespec.InputCachedTokens, usage.PromptTokenDetails.CachedTokens).
set(tracespec.ReasoningTokens, usage.CompletionTokensDetails.ReasoningTokens)
return tags
}

func getMessageMetaTokenUsage(msg *schema.Message) *model.TokenUsage {
if msg == nil || msg.ResponseMeta == nil {
return nil
}
return schemaTokenUsageToModel(msg.ResponseMeta.Usage)
}

return tags
func getAgenticMessageMetaTokenUsage(msg *schema.AgenticMessage) *model.TokenUsage {
if msg == nil || msg.ResponseMeta == nil {
return nil
}
return schemaTokenUsageToModel(msg.ResponseMeta.TokenUsage)
}

func schemaTokenUsageToModel(usage *schema.TokenUsage) *model.TokenUsage {
if usage == nil {
return nil
}
return &model.TokenUsage{
PromptTokens: usage.PromptTokens,
PromptTokenDetails: model.PromptTokenDetails{
CachedTokens: usage.PromptTokenDetails.CachedTokens,
},
CompletionTokens: usage.CompletionTokens,
TotalTokens: usage.TotalTokens,
CompletionTokensDetails: model.CompletionTokensDetails{
ReasoningTokens: usage.CompletionTokensDetails.ReasoningTokens,
},
}
}

func setTokenUsageTags(tags spanTags, usage *model.TokenUsage) {
if usage == nil {
return
}
tags.set(tracespec.Tokens, usage.TotalTokens).
set(tracespec.InputTokens, usage.PromptTokens).
set(tracespec.OutputTokens, usage.CompletionTokens).
set(tracespec.InputCachedTokens, usage.PromptTokenDetails.CachedTokens).
set(tracespec.ReasoningTokens, usage.CompletionTokensDetails.ReasoningTokens)
}

func (d defaultDataParser) ParseDefaultStreamInput(ctx context.Context, input *schema.StreamReader[callbacks.CallbackInput]) (chunks []any, err error) {
Expand Down
148 changes: 148 additions & 0 deletions callbacks/cozeloop/data_parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,17 @@ func Test_defaultDataParser_ParseOutput(t *testing.T) {
Role: schema.Assistant,
Content: "Hello, how can I assist you today?",
},
TokenUsage: &model.TokenUsage{
PromptTokens: 1,
CompletionTokens: 2,
TotalTokens: 3,
PromptTokenDetails: model.PromptTokenDetails{
CachedTokens: 4,
},
CompletionTokensDetails: model.CompletionTokensDetails{
ReasoningTokens: 5,
},
},
}
var outputs callbacks.CallbackOutput = []*schema.Message{
{
Expand Down Expand Up @@ -148,6 +159,11 @@ func Test_defaultDataParser_ParseOutput(t *testing.T) {

convey.So(result, convey.ShouldNotBeNil)
convey.So(result, convey.ShouldContainKey, tracespec.Output)
convey.So(result[tracespec.Tokens], convey.ShouldEqual, 3)
convey.So(result[tracespec.InputTokens], convey.ShouldEqual, 1)
convey.So(result[tracespec.OutputTokens], convey.ShouldEqual, 2)
convey.So(result[tracespec.InputCachedTokens], convey.ShouldEqual, 4)
convey.So(result[tracespec.ReasoningTokens], convey.ShouldEqual, 5)

mockConvertModelOutput.UnPatch()
mockGetTraceVariablesValue.UnPatch()
Expand Down Expand Up @@ -376,6 +392,138 @@ func Test_defaultDataParser_ParseOutput(t *testing.T) {
})
}

func Test_defaultDataParser_ParseChatModelStreamOutput_MessageMetaTokenUsage(t *testing.T) {
mockey.PatchConvey("测试 ChatModel 流式输出优先使用 concat 后的 message meta usage", t, func() {
outsr, outsw := schema.Pipe[callbacks.CallbackOutput](3)
outsw.Send(&model.CallbackOutput{
Message: &schema.Message{
Role: schema.Assistant,
Content: "assistant",
ResponseMeta: &schema.ResponseMeta{
Usage: &schema.TokenUsage{
PromptTokens: 6900,
PromptTokenDetails: schema.PromptTokenDetails{
CachedTokens: 3265,
},
CompletionTokens: 1,
TotalTokens: 6901,
},
},
},
}, nil)
outsw.Send(&model.CallbackOutput{
Message: &schema.Message{
Role: schema.Assistant,
Content: " message",
ResponseMeta: &schema.ResponseMeta{
Usage: &schema.TokenUsage{
CompletionTokens: 69,
CompletionTokensDetails: schema.CompletionTokensDetails{
ReasoningTokens: 12,
},
},
},
},
}, nil)
outsw.Close()

d := defaultDataParser{}
result := d.ParseChatModelStreamOutput(context.Background(), outsr)

convey.So(result[tracespec.InputTokens], convey.ShouldEqual, 6900)
convey.So(result[tracespec.InputCachedTokens], convey.ShouldEqual, 3265)
convey.So(result[tracespec.OutputTokens], convey.ShouldEqual, 69)
convey.So(result[tracespec.ReasoningTokens], convey.ShouldEqual, 12)
convey.So(result[tracespec.Tokens], convey.ShouldEqual, 6901)
})

mockey.PatchConvey("测试 ChatModel 流式输出不读取 CallbackOutput.TokenUsage", t, func() {
outsr, outsw := schema.Pipe[callbacks.CallbackOutput](3)
outsw.Send(&model.CallbackOutput{
Message: &schema.Message{Role: schema.Assistant, Content: "assistant"},
TokenUsage: &model.TokenUsage{
PromptTokens: 6900,
PromptTokenDetails: model.PromptTokenDetails{
CachedTokens: 3265,
},
CompletionTokens: 1,
TotalTokens: 6901,
},
}, nil)
outsw.Send(&model.CallbackOutput{
Message: &schema.Message{Role: schema.Assistant, Content: " message"},
TokenUsage: &model.TokenUsage{
CompletionTokens: 69,
CompletionTokensDetails: model.CompletionTokensDetails{
ReasoningTokens: 12,
},
},
}, nil)
outsw.Close()

d := defaultDataParser{}
result := d.ParseChatModelStreamOutput(context.Background(), outsr)

convey.So(result, convey.ShouldNotContainKey, tracespec.InputTokens)
convey.So(result, convey.ShouldNotContainKey, tracespec.InputCachedTokens)
convey.So(result, convey.ShouldNotContainKey, tracespec.OutputTokens)
convey.So(result, convey.ShouldNotContainKey, tracespec.ReasoningTokens)
convey.So(result, convey.ShouldNotContainKey, tracespec.Tokens)
})
}

func Test_defaultDataParser_ParseAgenticModelStreamOutput_MessageMetaTokenUsage(t *testing.T) {
mockey.PatchConvey("测试 AgenticModel 流式输出优先使用 concat 后的 message meta usage", t, func() {
outsr, outsw := schema.Pipe[callbacks.CallbackOutput](3)
outsw.Send(&model.AgenticCallbackOutput{
Message: &schema.AgenticMessage{
Role: schema.AgenticRoleTypeAssistant,
ContentBlocks: []*schema.ContentBlock{{
Type: schema.ContentBlockTypeAssistantGenText,
AssistantGenText: &schema.AssistantGenText{Text: "assistant"},
}},
ResponseMeta: &schema.AgenticResponseMeta{
TokenUsage: &schema.TokenUsage{
PromptTokens: 6900,
PromptTokenDetails: schema.PromptTokenDetails{
CachedTokens: 3265,
},
CompletionTokens: 1,
TotalTokens: 6901,
},
},
},
}, nil)
outsw.Send(&model.AgenticCallbackOutput{
Message: &schema.AgenticMessage{
Role: schema.AgenticRoleTypeAssistant,
ContentBlocks: []*schema.ContentBlock{{
Type: schema.ContentBlockTypeAssistantGenText,
AssistantGenText: &schema.AssistantGenText{Text: " message"},
}},
ResponseMeta: &schema.AgenticResponseMeta{
TokenUsage: &schema.TokenUsage{
CompletionTokens: 69,
CompletionTokensDetails: schema.CompletionTokensDetails{
ReasoningTokens: 12,
},
},
},
},
}, nil)
outsw.Close()

d := defaultDataParser{}
result := d.ParseAgenticModelStreamOutput(context.Background(), outsr)

convey.So(result[tracespec.InputTokens], convey.ShouldEqual, 6900)
convey.So(result[tracespec.InputCachedTokens], convey.ShouldEqual, 3265)
convey.So(result[tracespec.OutputTokens], convey.ShouldEqual, 69)
convey.So(result[tracespec.ReasoningTokens], convey.ShouldEqual, 12)
convey.So(result[tracespec.Tokens], convey.ShouldEqual, 6901)
})
}

// Test_defaultDataParser_tryConcatChunks 为 defaultDataParser 的 tryConcatChunks 方法编写单元测试
func Test_defaultDataParser_tryConcatChunks(t *testing.T) {
mockey.PatchConvey("测试 defaultDataParser 的 tryConcatChunks 方法", t, func() {
Expand Down
8 changes: 7 additions & 1 deletion components/model/claude/claude_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -346,14 +346,20 @@ func TestConvStreamEvent(t *testing.T) {
StopReason: "end_turn",
},
Usage: anthropic.MessageDeltaUsage{
OutputTokens: 10,
InputTokens: 8,
CacheReadInputTokens: 3,
CacheCreationInputTokens: 2,
OutputTokens: 10,
},
}).Build().UnPatch()

message, err := convStreamEvent(event, streamCtx)
assert.NoError(t, err)
assert.Equal(t, "end_turn", message.ResponseMeta.FinishReason)
assert.Equal(t, 13, message.ResponseMeta.Usage.PromptTokens)
assert.Equal(t, 3, message.ResponseMeta.Usage.PromptTokenDetails.CachedTokens)
assert.Equal(t, 10, message.ResponseMeta.Usage.CompletionTokens)
assert.Equal(t, 23, message.ResponseMeta.Usage.TotalTokens)
})

mockey.PatchConvey("content block start event", t, func() {
Expand Down
Loading