diff --git a/components/model/claude/claude.go b/components/model/claude/claude.go index a8da5a321..adb70ab29 100644 --- a/components/model/claude/claude.go +++ b/components/model/claude/claude.go @@ -507,6 +507,15 @@ func toAnthropicToolParam(tools []*schema.ToolInfo) ([]anthropic.ToolUnionParam, Properties: s.Properties, Required: s.Required, } + // Pass through additionalProperties from the JSON Schema if present. + // This is required for strict tool use mode on some providers (e.g. Bedrock). + if s.AdditionalProperties != nil { + apBytes, _ := json.Marshal(s.AdditionalProperties) + var ap any + if json.Unmarshal(apBytes, &ap) == nil { + inputSchema.ExtraFields = map[string]any{"additionalProperties": ap} + } + } } toolParam := &anthropic.ToolParam{ @@ -515,6 +524,13 @@ func toAnthropicToolParam(tools []*schema.ToolInfo) ([]anthropic.ToolUnionParam, InputSchema: inputSchema, } + // Support strict tool use via ToolInfo.Extra["strict"] = true. + // When enabled, the API guarantees that tool call inputs conform + // to the declared JSON Schema exactly (no type coercion). + if strict, ok := tool.Extra["strict"].(bool); ok && strict { + toolParam.Strict = param.NewOpt(true) + } + if isBreakpointTool(tool) { toolParam.CacheControl = newCacheControlParam(getToolBreakpointCacheControl(tool)) } diff --git a/components/model/claude/claude_test.go b/components/model/claude/claude_test.go index 7e28cf776..3c93f8a6a 100644 --- a/components/model/claude/claude_test.go +++ b/components/model/claude/claude_test.go @@ -389,6 +389,29 @@ func TestWithTools(t *testing.T) { assert.Equal(t, "test tool name", ncm.(*ChatModel).origTools[0].Name) } +func TestToAnthropicToolParam_Strict(t *testing.T) { + tools := []*schema.ToolInfo{ + { + Name: "strict_tool", + Desc: "a tool with strict mode", + Extra: map[string]any{"strict": true}, + }, + { + Name: "normal_tool", + Desc: "a tool without strict mode", + }, + } + result, err := toAnthropicToolParam(tools) + assert.NoError(t, err) + assert.Len(t, result, 2) + + // First tool should have strict=true + assert.True(t, result[0].OfTool.Strict.Value) + + // Second tool should not have strict set + assert.False(t, result[1].OfTool.Strict.Valid()) +} + func TestPopulateContentBlockBreakPoint(t *testing.T) { block := anthropic.NewTextBlock("input") populateContentBlockBreakPoint(block, nil) diff --git a/libs/acl/openai/chat_model.go b/libs/acl/openai/chat_model.go index 4b3f98f25..c2f930ddb 100644 --- a/libs/acl/openai/chat_model.go +++ b/libs/acl/openai/chat_model.go @@ -666,6 +666,7 @@ func (c *Client) genRequest(ctx context.Context, in []*schema.Message, opts ...m Function: &openai.FunctionDefinition{ Name: t.Function.Name, Description: t.Function.Description, + Strict: dereferenceOrZero(t.Function.Strict), Parameters: t.Function.Parameters, }, } @@ -1066,20 +1067,20 @@ func populateToolChoice(req *openai.ChatCompletionRequest, tc *schema.ToolChoice } var onlyOneToolName string - if len(allowedToolNames) == 1 { - onlyOneToolName = allowedToolNames[0] - } else if len(req.Tools) == 1 { + if len(allowedToolNames) == 0 && len(req.Tools) == 1 { onlyOneToolName = req.Tools[0].Function.Name } if onlyOneToolName != "" { - req.ToolChoice = openai.ToolChoice{ - Type: openai.ToolTypeFunction, - Function: openai.ToolFunction{ - Name: onlyOneToolName, - }, - } - } else if len(allowedToolNames) > 1 { + // Some OpenAI-compatible gateways reject the object form + // {"type":"function","function":{"name":"..."}} + // with "unknown parameter: tool_choice.function", while still + // supporting tool calling with the string form "required". + // + // In the single-tool case, "required" is semantically equivalent + // to forcing that one tool, so prefer the more compatible wire form. + req.ToolChoice = toolChoiceRequired + } else if len(allowedToolNames) > 0 { req.ToolChoice = map[string]any{ "type": "allowed_tools", "allowed_tools": allowedTools{ @@ -1314,13 +1315,18 @@ func toTools(tis []*schema.ToolInfo) ([]tool, error) { sortArrayFields(paramsJSONSchema) - tools[i] = tool{ - Function: &functionDefinition{ - Name: ti.Name, - Description: ti.Desc, - Parameters: paramsJSONSchema, - }, + fd := &functionDefinition{ + Name: ti.Name, + Description: ti.Desc, + Parameters: paramsJSONSchema, + } + // Support strict tool use via ToolInfo.Extra["strict"] = true. + // When enabled, the API guarantees that tool call inputs conform + // to the declared JSON Schema exactly (OpenAI strict mode). + if strict, ok := ti.Extra["strict"].(bool); ok && strict { + fd.Strict = &strict } + tools[i] = tool{Function: fd} } return tools, nil diff --git a/libs/acl/openai/chat_model_test.go b/libs/acl/openai/chat_model_test.go index 098956c49..91a771a4a 100644 --- a/libs/acl/openai/chat_model_test.go +++ b/libs/acl/openai/chat_model_test.go @@ -373,6 +373,31 @@ func TestToTools(t *testing.T) { }) } +func TestGenRequest_PreservesStrictTools(t *testing.T) { + c := &Client{config: &Config{Model: "test-model"}} + + req, _, _, _, err := c.genRequest(t.Context(), + []*schema.Message{{Role: schema.User, Content: "hello"}}, + model.WithTools([]*schema.ToolInfo{ + { + Name: "submit_memory", + Desc: "Save extracted people memory", + Extra: map[string]any{"strict": true}, + ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ + "people_memory": { + Type: schema.Object, + Required: true, + }, + }), + }, + }), + ) + + assert.NoError(t, err) + assert.Len(t, req.Tools, 1) + assert.True(t, req.Tools[0].Function.Strict) +} + func TestBuildMessages(t *testing.T) { t.Run("buildMessageFromAssistantGenMultiContent", func(t *testing.T) { t.Run("success with audio", func(t *testing.T) { @@ -995,13 +1020,7 @@ func TestPopulateToolChoice(t *testing.T) { } err := populateToolChoice(req, options.ToolChoice, options.AllowedToolNames) assert.NoError(t, err) - expected := openai.ToolChoice{ - Type: openai.ToolTypeFunction, - Function: openai.ToolFunction{ - Name: "test-tool", - }, - } - assert.Equal(t, expected, req.ToolChoice) + assert.Equal(t, "required", req.ToolChoice) }) t.Run("tool choice forced with multiple tools", func(t *testing.T) { @@ -1052,14 +1071,19 @@ func TestPopulateToolChoice(t *testing.T) { } err := populateToolChoice(req, options.ToolChoice, options.AllowedToolNames) assert.NoError(t, err) - - expected := openai.ToolChoice{ - Type: openai.ToolTypeFunction, - Function: openai.ToolFunction{ - Name: "test-tool-1", + expected := allowedTools{ + Mode: "required", + Tools: []openai.ToolChoice{ + { + Type: openai.ToolTypeFunction, + Function: openai.ToolFunction{ + Name: "test-tool-1", + }, + }, }, } - assert.Equal(t, expected, req.ToolChoice) + assert.Equal(t, "allowed_tools", req.ToolChoice.(map[string]any)["type"]) + assert.Equal(t, expected, req.ToolChoice.(map[string]any)["allowed_tools"]) }) t.Run("tool choice forced with invalid allowed tool", func(t *testing.T) { diff --git a/libs/acl/openai/tool.go b/libs/acl/openai/tool.go index 38b9dc9b2..58561a00c 100644 --- a/libs/acl/openai/tool.go +++ b/libs/acl/openai/tool.go @@ -28,4 +28,5 @@ type functionDefinition struct { Name string `json:"name"` Description string `json:"description,omitempty"` Parameters *jsonschema.Schema `json:"parameters"` + Strict *bool `json:"strict,omitempty"` }