Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions components/model/claude/claude.go
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,15 @@ func toAnthropicToolParam(tools []*schema.ToolInfo) ([]anthropic.ToolUnionParam,
Properties: s.Properties,
Required: s.Required,
}
// Pass through additionalProperties from the JSON Schema if present.
// This is required for strict tool use mode on some providers (e.g. Bedrock).
if s.AdditionalProperties != nil {
apBytes, _ := json.Marshal(s.AdditionalProperties)
var ap any
if json.Unmarshal(apBytes, &ap) == nil {
inputSchema.ExtraFields = map[string]any{"additionalProperties": ap}
}
}
}

toolParam := &anthropic.ToolParam{
Expand All @@ -515,6 +524,13 @@ func toAnthropicToolParam(tools []*schema.ToolInfo) ([]anthropic.ToolUnionParam,
InputSchema: inputSchema,
}

// Support strict tool use via ToolInfo.Extra["strict"] = true.
// When enabled, the API guarantees that tool call inputs conform
// to the declared JSON Schema exactly (no type coercion).
if strict, ok := tool.Extra["strict"].(bool); ok && strict {
toolParam.Strict = param.NewOpt(true)
}

if isBreakpointTool(tool) {
toolParam.CacheControl = newCacheControlParam(getToolBreakpointCacheControl(tool))
}
Expand Down
23 changes: 23 additions & 0 deletions components/model/claude/claude_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,29 @@ func TestWithTools(t *testing.T) {
assert.Equal(t, "test tool name", ncm.(*ChatModel).origTools[0].Name)
}

func TestToAnthropicToolParam_Strict(t *testing.T) {
tools := []*schema.ToolInfo{
{
Name: "strict_tool",
Desc: "a tool with strict mode",
Extra: map[string]any{"strict": true},
},
{
Name: "normal_tool",
Desc: "a tool without strict mode",
},
}
result, err := toAnthropicToolParam(tools)
assert.NoError(t, err)
assert.Len(t, result, 2)

// First tool should have strict=true
assert.True(t, result[0].OfTool.Strict.Value)

// Second tool should not have strict set
assert.False(t, result[1].OfTool.Strict.Valid())
}

func TestPopulateContentBlockBreakPoint(t *testing.T) {
block := anthropic.NewTextBlock("input")
populateContentBlockBreakPoint(block, nil)
Expand Down
38 changes: 22 additions & 16 deletions libs/acl/openai/chat_model.go
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,7 @@ func (c *Client) genRequest(ctx context.Context, in []*schema.Message, opts ...m
Function: &openai.FunctionDefinition{
Name: t.Function.Name,
Description: t.Function.Description,
Strict: dereferenceOrZero(t.Function.Strict),
Parameters: t.Function.Parameters,
},
}
Expand Down Expand Up @@ -1066,20 +1067,20 @@ func populateToolChoice(req *openai.ChatCompletionRequest, tc *schema.ToolChoice
}

var onlyOneToolName string
if len(allowedToolNames) == 1 {
onlyOneToolName = allowedToolNames[0]
} else if len(req.Tools) == 1 {
if len(allowedToolNames) == 0 && len(req.Tools) == 1 {
onlyOneToolName = req.Tools[0].Function.Name
}

if onlyOneToolName != "" {
req.ToolChoice = openai.ToolChoice{
Type: openai.ToolTypeFunction,
Function: openai.ToolFunction{
Name: onlyOneToolName,
},
}
} else if len(allowedToolNames) > 1 {
// Some OpenAI-compatible gateways reject the object form
// {"type":"function","function":{"name":"..."}}
// with "unknown parameter: tool_choice.function", while still
// supporting tool calling with the string form "required".
//
// In the single-tool case, "required" is semantically equivalent
// to forcing that one tool, so prefer the more compatible wire form.
req.ToolChoice = toolChoiceRequired
} else if len(allowedToolNames) > 0 {
req.ToolChoice = map[string]any{
"type": "allowed_tools",
"allowed_tools": allowedTools{
Expand Down Expand Up @@ -1314,13 +1315,18 @@ func toTools(tis []*schema.ToolInfo) ([]tool, error) {

sortArrayFields(paramsJSONSchema)

tools[i] = tool{
Function: &functionDefinition{
Name: ti.Name,
Description: ti.Desc,
Parameters: paramsJSONSchema,
},
fd := &functionDefinition{
Name: ti.Name,
Description: ti.Desc,
Parameters: paramsJSONSchema,
}
// Support strict tool use via ToolInfo.Extra["strict"] = true.
// When enabled, the API guarantees that tool call inputs conform
// to the declared JSON Schema exactly (OpenAI strict mode).
if strict, ok := ti.Extra["strict"].(bool); ok && strict {
fd.Strict = &strict
}
tools[i] = tool{Function: fd}
}

return tools, nil
Expand Down
50 changes: 37 additions & 13 deletions libs/acl/openai/chat_model_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,31 @@ func TestToTools(t *testing.T) {
})
}

func TestGenRequest_PreservesStrictTools(t *testing.T) {
c := &Client{config: &Config{Model: "test-model"}}

req, _, _, _, err := c.genRequest(t.Context(),
[]*schema.Message{{Role: schema.User, Content: "hello"}},
model.WithTools([]*schema.ToolInfo{
{
Name: "submit_memory",
Desc: "Save extracted people memory",
Extra: map[string]any{"strict": true},
ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{
"people_memory": {
Type: schema.Object,
Required: true,
},
}),
},
}),
)

assert.NoError(t, err)
assert.Len(t, req.Tools, 1)
assert.True(t, req.Tools[0].Function.Strict)
}

func TestBuildMessages(t *testing.T) {
t.Run("buildMessageFromAssistantGenMultiContent", func(t *testing.T) {
t.Run("success with audio", func(t *testing.T) {
Expand Down Expand Up @@ -995,13 +1020,7 @@ func TestPopulateToolChoice(t *testing.T) {
}
err := populateToolChoice(req, options.ToolChoice, options.AllowedToolNames)
assert.NoError(t, err)
expected := openai.ToolChoice{
Type: openai.ToolTypeFunction,
Function: openai.ToolFunction{
Name: "test-tool",
},
}
assert.Equal(t, expected, req.ToolChoice)
assert.Equal(t, "required", req.ToolChoice)
})

t.Run("tool choice forced with multiple tools", func(t *testing.T) {
Expand Down Expand Up @@ -1052,14 +1071,19 @@ func TestPopulateToolChoice(t *testing.T) {
}
err := populateToolChoice(req, options.ToolChoice, options.AllowedToolNames)
assert.NoError(t, err)

expected := openai.ToolChoice{
Type: openai.ToolTypeFunction,
Function: openai.ToolFunction{
Name: "test-tool-1",
expected := allowedTools{
Mode: "required",
Tools: []openai.ToolChoice{
{
Type: openai.ToolTypeFunction,
Function: openai.ToolFunction{
Name: "test-tool-1",
},
},
},
}
assert.Equal(t, expected, req.ToolChoice)
assert.Equal(t, "allowed_tools", req.ToolChoice.(map[string]any)["type"])
assert.Equal(t, expected, req.ToolChoice.(map[string]any)["allowed_tools"])
})

t.Run("tool choice forced with invalid allowed tool", func(t *testing.T) {
Expand Down
1 change: 1 addition & 0 deletions libs/acl/openai/tool.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,5 @@ type functionDefinition struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
Parameters *jsonschema.Schema `json:"parameters"`
Strict *bool `json:"strict,omitempty"`
}
Loading