diff --git a/pkg/mcp/tool_filter.go b/pkg/mcp/tool_filter.go index a5bad1b83..deea68a57 100644 --- a/pkg/mcp/tool_filter.go +++ b/pkg/mcp/tool_filter.go @@ -72,6 +72,36 @@ func (c *toolMiddlewareConfig) getToolListOverride(toolName string) (*toolOverri // middleware. type ToolMiddlewareOption func(*toolMiddlewareConfig) error +// SimpleTool represents a minimal tool with name and description. +// This is used by ApplyToolFiltering to work with tools in a generic way. +type SimpleTool struct { + Name string + Description string +} + +// ApplyToolFiltering applies filtering and overriding to a list of tools. +// This is the core logic used by both the HTTP middleware and other components +// that need to apply the same filtering/overriding behavior. +// +// Returns the filtered and overridden tools. +func ApplyToolFiltering(opts []ToolMiddlewareOption, tools []SimpleTool) ([]SimpleTool, error) { + config := &toolMiddlewareConfig{ + filterTools: make(map[string]struct{}), + actualToUserOverride: make(map[string]toolOverrideEntry), + userToActualOverride: make(map[string]toolOverrideEntry), + } + + // Apply options to build config + for _, opt := range opts { + if err := opt(config); err != nil { + return nil, err + } + } + + // Use the shared core logic + return applyFilteringAndOverrides(config, tools), nil +} + // WithToolsFilter is a function that can be used to configure the tool // middleware to use a filter list of tools. func WithToolsFilter(toolsFilter ...string) ToolMiddlewareOption { @@ -448,7 +478,10 @@ func processToolsListResponse( toolsListResponse toolsListResponse, w io.Writer, ) error { - filteredTools := []map[string]any{} + // Convert to SimpleTool format for shared processing + simpleTools := make([]SimpleTool, 0, len(*toolsListResponse.Result.Tools)) + toolMaps := make([]map[string]any, 0, len(*toolsListResponse.Result.Tools)) + for _, tool := range *toolsListResponse.Result.Tools { // NOTE: the spec does not allow for name to be missing. toolName, ok := tool["name"].(string) @@ -461,31 +494,90 @@ func processToolsListResponse( return errToolNameNotFound } + // Get description if present (optional in MCP spec) + description, _ := tool["description"].(string) + + simpleTools = append(simpleTools, SimpleTool{ + Name: toolName, + Description: description, + }) + toolMaps = append(toolMaps, tool) + } + + // Apply the shared filtering/override logic + processedTools := applyFilteringAndOverrides(config, simpleTools) + + // Build the filtered response by matching processed tools with their original maps + // Note: This is O(n²) complexity, but acceptable because: + // - Tool lists are typically small (< 100 tools per backend) + // - Only runs once during tool list retrieval (not in hot path) + // - Inner loop breaks early on match + filteredTools := make([]map[string]any, 0, len(processedTools)) + for _, processed := range processedTools { + // Find the original tool map by matching names + for i, simple := range simpleTools { + if simple.Name == processed.Name || simple.Name == findOriginalName(config, processed.Name) { + // Clone the original map and update name/description + toolCopy := make(map[string]any, len(toolMaps[i])) + for k, v := range toolMaps[i] { + toolCopy[k] = v + } + toolCopy["name"] = processed.Name + if processed.Description != "" { + toolCopy["description"] = processed.Description + } + filteredTools = append(filteredTools, toolCopy) + break + } + } + } + + toolsListResponse.Result.Tools = &filteredTools + if err := json.NewEncoder(w).Encode(toolsListResponse); err != nil { + return fmt.Errorf("%w: %v", errBug, err) + } + + return nil +} + +// applyFilteringAndOverrides is the core logic for filtering and overriding tools. +// This implements the exact same logic as before but is now extracted for reuse. +func applyFilteringAndOverrides(config *toolMiddlewareConfig, tools []SimpleTool) []SimpleTool { + result := make([]SimpleTool, 0, len(tools)) + for _, tool := range tools { + description := tool.Description + // If the tool is overridden, we need to use the override name and description. - if entry, ok := config.getToolListOverride(toolName); ok { + if entry, ok := config.getToolListOverride(tool.Name); ok { if entry.OverrideName != "" { - tool["name"] = entry.OverrideName + tool.Name = entry.OverrideName } if entry.OverrideDescription != "" { - tool["description"] = entry.OverrideDescription + description = entry.OverrideDescription } - toolName = entry.OverrideName } // If the tool is in the filter, we add it to the filtered tools list. - // Note that lookup is done using the user-known name, which might be - // different from the actual tool name. - if config.isToolInFilter(toolName) { - filteredTools = append(filteredTools, tool) + // Note that lookup is done using the user-known name (tool.Name after override). + if config.isToolInFilter(tool.Name) { + result = append(result, SimpleTool{ + Name: tool.Name, + Description: description, + }) } } + return result +} - toolsListResponse.Result.Tools = &filteredTools - if err := json.NewEncoder(w).Encode(toolsListResponse); err != nil { - return fmt.Errorf("%w: %v", errBug, err) +// findOriginalName attempts to find the original tool name before override. +func findOriginalName(config *toolMiddlewareConfig, overriddenName string) string { + // Iterate through overrides to find reverse mapping + for actualName, entry := range config.actualToUserOverride { + if entry.OverrideName == overriddenName { + return actualName + } } - - return nil + return overriddenName } // toolCallFix mimics a sum type in Go. The actual types represent the diff --git a/pkg/vmcp/aggregator/aggregator.go b/pkg/vmcp/aggregator/aggregator.go index b919e291b..61b143e8d 100644 --- a/pkg/vmcp/aggregator/aggregator.go +++ b/pkg/vmcp/aggregator/aggregator.go @@ -160,9 +160,6 @@ type AggregationMetadata struct { // PromptCount is the total number of prompts. PromptCount int - // ConflictsResolved is the number of conflicts that were resolved. - ConflictsResolved int - // ConflictStrategy is the strategy used for conflict resolution. ConflictStrategy vmcp.ConflictResolutionStrategy } diff --git a/pkg/vmcp/aggregator/conflict_resolver.go b/pkg/vmcp/aggregator/conflict_resolver.go new file mode 100644 index 000000000..332941963 --- /dev/null +++ b/pkg/vmcp/aggregator/conflict_resolver.go @@ -0,0 +1,71 @@ +// Package aggregator provides capability aggregation for Virtual MCP Server. +// +// This file contains the factory function for creating conflict resolvers +// and shared helper functions used by multiple resolver implementations. +package aggregator + +import ( + "fmt" + + "github.com/stacklok/toolhive/pkg/logger" + "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/config" +) + +// NewConflictResolver creates the appropriate conflict resolver based on configuration. +func NewConflictResolver(aggregationConfig *config.AggregationConfig) (ConflictResolver, error) { + if aggregationConfig == nil { + // Default to prefix strategy with default format + logger.Infof("No aggregation config provided, using default prefix strategy") + return NewPrefixConflictResolver("{workload}_"), nil + } + + switch aggregationConfig.ConflictResolution { + case vmcp.ConflictStrategyPrefix: + prefixFormat := "{workload}_" // Default + if aggregationConfig.ConflictResolutionConfig != nil && + aggregationConfig.ConflictResolutionConfig.PrefixFormat != "" { + prefixFormat = aggregationConfig.ConflictResolutionConfig.PrefixFormat + } + logger.Infof("Using prefix conflict resolution strategy (format: %s)", prefixFormat) + return NewPrefixConflictResolver(prefixFormat), nil + + case vmcp.ConflictStrategyPriority: + if aggregationConfig.ConflictResolutionConfig == nil || + len(aggregationConfig.ConflictResolutionConfig.PriorityOrder) == 0 { + return nil, fmt.Errorf("priority strategy requires priority_order in conflict_resolution_config") + } + logger.Infof("Using priority conflict resolution strategy (order: %v)", + aggregationConfig.ConflictResolutionConfig.PriorityOrder) + return NewPriorityConflictResolver(aggregationConfig.ConflictResolutionConfig.PriorityOrder) + + case vmcp.ConflictStrategyManual: + logger.Infof("Using manual conflict resolution strategy") + return NewManualConflictResolver(aggregationConfig.Tools) + + default: + return nil, fmt.Errorf("%w: %s", ErrInvalidConflictStrategy, aggregationConfig.ConflictResolution) + } +} + +// toolWithBackend is a helper struct to track which backend a tool comes from. +// This is shared by multiple conflict resolution strategies. +type toolWithBackend struct { + Tool vmcp.Tool + BackendID string +} + +// groupToolsByName groups tools by their names to detect conflicts. +// This is shared by multiple conflict resolution strategies. +func groupToolsByName(toolsByBackend map[string][]vmcp.Tool) map[string][]toolWithBackend { + toolsByName := make(map[string][]toolWithBackend) + for backendID, tools := range toolsByBackend { + for _, tool := range tools { + toolsByName[tool.Name] = append(toolsByName[tool.Name], toolWithBackend{ + Tool: tool, + BackendID: backendID, + }) + } + } + return toolsByName +} diff --git a/pkg/vmcp/aggregator/conflict_resolver_test.go b/pkg/vmcp/aggregator/conflict_resolver_test.go new file mode 100644 index 000000000..b5e283e79 --- /dev/null +++ b/pkg/vmcp/aggregator/conflict_resolver_test.go @@ -0,0 +1,466 @@ +package aggregator + +import ( + "context" + "strings" + "testing" + + "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/config" +) + +func TestPrefixConflictResolver(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + prefixFormat string + toolsByBackend map[string][]vmcp.Tool + wantCount int + checkNames map[string]string // resolved name -> expected backend ID + }{ + { + name: "default prefix format with conflicts", + prefixFormat: "{workload}_", + toolsByBackend: map[string][]vmcp.Tool{ + "github": { + {Name: "create_issue", Description: "Create GitHub issue"}, + {Name: "list_issues", Description: "List GitHub issues"}, + }, + "jira": { + {Name: "create_issue", Description: "Create Jira issue"}, + {Name: "list_projects", Description: "List Jira projects"}, + }, + }, + wantCount: 4, + checkNames: map[string]string{ + "github_create_issue": "github", + "github_list_issues": "github", + "jira_create_issue": "jira", + "jira_list_projects": "jira", + }, + }, + { + name: "dot separator prefix", + prefixFormat: "{workload}.", + toolsByBackend: map[string][]vmcp.Tool{ + "backend1": { + {Name: "tool1", Description: "Tool 1"}, + }, + "backend2": { + {Name: "tool1", Description: "Tool 1 from backend2"}, + }, + }, + wantCount: 2, + checkNames: map[string]string{ + "backend1.tool1": "backend1", + "backend2.tool1": "backend2", + }, + }, + { + name: "no conflicts", + prefixFormat: "{workload}_", + toolsByBackend: map[string][]vmcp.Tool{ + "github": { + {Name: "create_pr", Description: "Create PR"}, + }, + "jira": { + {Name: "create_ticket", Description: "Create ticket"}, + }, + }, + wantCount: 2, + checkNames: map[string]string{ + "github_create_pr": "github", + "jira_create_ticket": "jira", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + resolver := NewPrefixConflictResolver(tt.prefixFormat) + resolved, err := resolver.ResolveToolConflicts(context.Background(), tt.toolsByBackend) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(resolved) != tt.wantCount { + t.Errorf("got %d resolved tools, want %d", len(resolved), tt.wantCount) + } + + for resolvedName, expectedBackendID := range tt.checkNames { + tool, exists := resolved[resolvedName] + if !exists { + t.Errorf("expected tool %q not found in resolved tools", resolvedName) + continue + } + + if tool.BackendID != expectedBackendID { + t.Errorf("tool %q has backend %q, want %q", resolvedName, tool.BackendID, expectedBackendID) + } + + if tool.ConflictResolutionApplied != vmcp.ConflictStrategyPrefix { + t.Errorf("tool %q has wrong strategy %q, want %q", resolvedName, tool.ConflictResolutionApplied, vmcp.ConflictStrategyPrefix) + } + } + }) + } +} + +func TestPriorityConflictResolver(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + priorityOrder []string + toolsByBackend map[string][]vmcp.Tool + wantCount int + wantWinners map[string]string // tool name -> expected backend ID + wantStrategies map[string]vmcp.ConflictResolutionStrategy // tool name -> expected strategy (optional) + wantErr bool + }{ + { + name: "basic priority resolution", + priorityOrder: []string{"github", "jira"}, + toolsByBackend: map[string][]vmcp.Tool{ + "github": { + {Name: "create_issue", Description: "GitHub issue"}, + {Name: "list_repos", Description: "List repos"}, + }, + "jira": { + {Name: "create_issue", Description: "Jira issue"}, + {Name: "list_projects", Description: "List projects"}, + }, + }, + wantCount: 3, + wantWinners: map[string]string{ + "create_issue": "github", // github wins + "list_repos": "github", + "list_projects": "jira", + }, + }, + { + name: "three-way conflict", + priorityOrder: []string{"primary", "secondary", "tertiary"}, + toolsByBackend: map[string][]vmcp.Tool{ + "primary": { + {Name: "shared_tool", Description: "Primary version"}, + }, + "secondary": { + {Name: "shared_tool", Description: "Secondary version"}, + }, + "tertiary": { + {Name: "shared_tool", Description: "Tertiary version"}, + }, + }, + wantCount: 1, + wantWinners: map[string]string{ + "shared_tool": "primary", + }, + }, + { + name: "backends not in priority list are skipped", + priorityOrder: []string{"github"}, + toolsByBackend: map[string][]vmcp.Tool{ + "github": { + {Name: "tool1", Description: "GitHub tool"}, + }, + "unknown_backend": { + {Name: "tool2", Description: "Unknown tool"}, + }, + }, + wantCount: 2, // Both tools included (no conflict) + wantWinners: map[string]string{ + "tool1": "github", + "tool2": "unknown_backend", + }, + }, + { + name: "backends not in priority with conflict use prefix fallback", + priorityOrder: []string{"github"}, + toolsByBackend: map[string][]vmcp.Tool{ + "github": { + {Name: "create_issue", Description: "GitHub issue"}, + }, + "slack": { + {Name: "send_message", Description: "Slack message"}, + }, + "teams": { + {Name: "send_message", Description: "Teams message"}, + }, + }, + wantCount: 3, // All tools included, conflicting ones prefixed + wantWinners: map[string]string{ + "create_issue": "github", // In priority list + "slack_send_message": "slack", // Not in priority, prefixed + "teams_send_message": "teams", // Not in priority, prefixed + }, + wantStrategies: map[string]vmcp.ConflictResolutionStrategy{ + "create_issue": vmcp.ConflictStrategyPriority, // Priority strategy used + "slack_send_message": vmcp.ConflictStrategyPrefix, // Prefix fallback used + "teams_send_message": vmcp.ConflictStrategyPrefix, // Prefix fallback used + }, + }, + { + name: "empty priority order", + priorityOrder: []string{}, + toolsByBackend: map[string][]vmcp.Tool{ + "github": {{Name: "tool1"}}, + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + resolver, err := NewPriorityConflictResolver(tt.priorityOrder) + if tt.wantErr { + if err == nil { + t.Fatal("expected error, got nil") + } + return + } + + if err != nil { + t.Fatalf("unexpected error creating resolver: %v", err) + } + + resolved, err := resolver.ResolveToolConflicts(context.Background(), tt.toolsByBackend) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(resolved) != tt.wantCount { + t.Errorf("got %d resolved tools, want %d", len(resolved), tt.wantCount) + } + + for toolName, expectedBackendID := range tt.wantWinners { + tool, exists := resolved[toolName] + if !exists { + t.Errorf("expected tool %q not found", toolName) + continue + } + + if tool.BackendID != expectedBackendID { + t.Errorf("tool %q from %q, want %q", toolName, tool.BackendID, expectedBackendID) + } + + // Check strategy if specified + if tt.wantStrategies != nil { + if expectedStrategy, hasExpectedStrategy := tt.wantStrategies[toolName]; hasExpectedStrategy { + if tool.ConflictResolutionApplied != expectedStrategy { + t.Errorf("tool %q has strategy %q, want %q", toolName, tool.ConflictResolutionApplied, expectedStrategy) + } + } + } else { + // Default: expect priority strategy + if tool.ConflictResolutionApplied != vmcp.ConflictStrategyPriority { + t.Errorf("tool %q has wrong strategy %q, want %q", toolName, tool.ConflictResolutionApplied, vmcp.ConflictStrategyPriority) + } + } + } + }) + } +} + +func TestManualConflictResolver(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + workloadConfigs []*config.WorkloadToolConfig + toolsByBackend map[string][]vmcp.Tool + wantCount int + wantNames []string // Expected resolved names + wantErr bool + errContains string + }{ + { + name: "all conflicts resolved with overrides", + workloadConfigs: []*config.WorkloadToolConfig{ + { + Workload: "github", + Overrides: map[string]*config.ToolOverride{ + "create_issue": {Name: "gh_create_issue"}, + }, + }, + { + Workload: "jira", + Overrides: map[string]*config.ToolOverride{ + "create_issue": {Name: "jira_create_issue"}, + }, + }, + }, + toolsByBackend: map[string][]vmcp.Tool{ + "github": {{Name: "create_issue", Description: "GitHub"}}, + "jira": {{Name: "create_issue", Description: "Jira"}}, + }, + wantCount: 2, + wantNames: []string{"gh_create_issue", "jira_create_issue"}, + }, + { + name: "unresolved conflict fails validation", + workloadConfigs: []*config.WorkloadToolConfig{ + { + Workload: "github", + Overrides: map[string]*config.ToolOverride{ + "create_issue": {Name: "gh_create_issue"}, + }, + }, + // jira has no override for create_issue + { + Workload: "jira", + }, + }, + toolsByBackend: map[string][]vmcp.Tool{ + "github": {{Name: "create_issue"}}, + "jira": {{Name: "create_issue"}}, + }, + wantErr: true, + errContains: "unresolved tool name conflicts", + }, + { + name: "no conflicts - no overrides needed", + workloadConfigs: []*config.WorkloadToolConfig{ + {Workload: "github"}, + {Workload: "jira"}, + }, + toolsByBackend: map[string][]vmcp.Tool{ + "github": {{Name: "create_pr"}}, + "jira": {{Name: "create_ticket"}}, + }, + wantCount: 2, + wantNames: []string{"create_pr", "create_ticket"}, + }, + { + name: "override description only", + workloadConfigs: []*config.WorkloadToolConfig{ + { + Workload: "github", + Overrides: map[string]*config.ToolOverride{ + "create_pr": {Description: "Updated description"}, + }, + }, + }, + toolsByBackend: map[string][]vmcp.Tool{ + "github": {{Name: "create_pr", Description: "Original"}}, + }, + wantCount: 1, + wantNames: []string{"create_pr"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + resolver, err := NewManualConflictResolver(tt.workloadConfigs) + if err != nil { + t.Fatalf("unexpected error creating resolver: %v", err) + } + + resolved, err := resolver.ResolveToolConflicts(context.Background(), tt.toolsByBackend) + + if tt.wantErr { + if err == nil { + t.Fatal("expected error, got nil") + } + if tt.errContains != "" && !strings.Contains(err.Error(), tt.errContains) { + t.Errorf("error %q does not contain %q", err.Error(), tt.errContains) + } + return + } + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(resolved) != tt.wantCount { + t.Errorf("got %d resolved tools, want %d", len(resolved), tt.wantCount) + } + + for _, name := range tt.wantNames { + if _, exists := resolved[name]; !exists { + t.Errorf("expected tool %q not found", name) + } + } + }) + } +} + +func TestNewConflictResolver(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + config *config.AggregationConfig + wantErr bool + }{ + { + name: "prefix strategy", + config: &config.AggregationConfig{ + ConflictResolution: vmcp.ConflictStrategyPrefix, + ConflictResolutionConfig: &config.ConflictResolutionConfig{ + PrefixFormat: "{workload}_", + }, + }, + }, + { + name: "priority strategy", + config: &config.AggregationConfig{ + ConflictResolution: vmcp.ConflictStrategyPriority, + ConflictResolutionConfig: &config.ConflictResolutionConfig{ + PriorityOrder: []string{"backend1", "backend2"}, + }, + }, + }, + { + name: "manual strategy", + config: &config.AggregationConfig{ + ConflictResolution: vmcp.ConflictStrategyManual, + Tools: []*config.WorkloadToolConfig{ + {Workload: "github"}, + }, + }, + }, + { + name: "priority without priority order fails", + config: &config.AggregationConfig{ + ConflictResolution: vmcp.ConflictStrategyPriority, + }, + wantErr: true, + }, + { + name: "nil config defaults to prefix", + config: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + resolver, err := NewConflictResolver(tt.config) + + if tt.wantErr { + if err == nil { + t.Fatal("expected error, got nil") + } + return + } + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if resolver == nil { + t.Fatal("got nil resolver") + } + }) + } +} diff --git a/pkg/vmcp/aggregator/default_aggregator.go b/pkg/vmcp/aggregator/default_aggregator.go index c00c8e885..a39340d0f 100644 --- a/pkg/vmcp/aggregator/default_aggregator.go +++ b/pkg/vmcp/aggregator/default_aggregator.go @@ -9,19 +9,37 @@ import ( "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/config" ) // defaultAggregator implements the Aggregator interface for capability aggregation. // It queries backends in parallel, handles failures gracefully, and merges capabilities. type defaultAggregator struct { - backendClient vmcp.BackendClient - // TODO: Add conflict resolver, tool filter, tool override + backendClient vmcp.BackendClient + conflictResolver ConflictResolver + toolConfigMap map[string]*config.WorkloadToolConfig // Maps backend ID to tool config } // NewDefaultAggregator creates a new default aggregator implementation. -func NewDefaultAggregator(backendClient vmcp.BackendClient) Aggregator { +// conflictResolver handles tool name conflicts across backends. +// workloadConfigs specifies per-backend tool filtering and overrides. +func NewDefaultAggregator( + backendClient vmcp.BackendClient, + conflictResolver ConflictResolver, + workloadConfigs []*config.WorkloadToolConfig, +) Aggregator { + // Build tool config map for quick lookup by backend ID + toolConfigMap := make(map[string]*config.WorkloadToolConfig) + for _, wlConfig := range workloadConfigs { + if wlConfig != nil { + toolConfigMap[wlConfig.Workload] = wlConfig + } + } + return &defaultAggregator{ - backendClient: backendClient, + backendClient: backendClient, + conflictResolver: conflictResolver, + toolConfigMap: toolConfigMap, } } @@ -46,17 +64,20 @@ func (a *defaultAggregator) QueryCapabilities(ctx context.Context, backend vmcp. return nil, fmt.Errorf("%w: %s: %v", ErrBackendQueryFailed, backend.ID, err) } + // Apply per-backend tool filtering and overrides (before conflict resolution) + processedTools := processBackendTools(ctx, backend.ID, capabilities.Tools, a.toolConfigMap[backend.ID]) + // Convert to BackendCapabilities result := &BackendCapabilities{ BackendID: backend.ID, - Tools: capabilities.Tools, + Tools: processedTools, Resources: capabilities.Resources, Prompts: capabilities.Prompts, SupportsLogging: capabilities.SupportsLogging, SupportsSampling: capabilities.SupportsSampling, } - logger.Debugf("Backend %s: %d tools, %d resources, %d prompts", + logger.Debugf("Backend %s: %d tools (after filtering/overrides), %d resources, %d prompts", backend.ID, len(result.Tools), len(result.Resources), len(result.Prompts)) return result, nil @@ -113,51 +134,59 @@ func (a *defaultAggregator) QueryAllCapabilities( // ResolveConflicts applies conflict resolution strategy to handle // duplicate capability names across backends. -func (*defaultAggregator) ResolveConflicts( - _ context.Context, +func (a *defaultAggregator) ResolveConflicts( + ctx context.Context, capabilities map[string]*BackendCapabilities, ) (*ResolvedCapabilities, error) { logger.Debugf("Resolving conflicts across %d backends", len(capabilities)) - // For Phase 1 (Issue #148), we'll implement basic conflict resolution - // Just collect all capabilities without resolving conflicts yet - // Conflict resolution will be implemented in a future phase - - resolved := &ResolvedCapabilities{ - Tools: make(map[string]*ResolvedTool), - Resources: []vmcp.Resource{}, - Prompts: []vmcp.Prompt{}, + // Group tools by backend for conflict resolution + toolsByBackend := make(map[string][]vmcp.Tool) + for backendID, caps := range capabilities { + toolsByBackend[backendID] = caps.Tools } - // Collect all tools (for now, without conflict resolution) - // Later, we'll add prefix/priority/manual strategies - for backendID, caps := range capabilities { - for _, tool := range caps.Tools { - // For now, just use the tool name as-is - // In future phases, we'll apply prefixing or priority rules - resolvedName := tool.Name - - // If there's a conflict, log a warning (but don't fail) - if existing, exists := resolved.Tools[resolvedName]; exists { - logger.Warnf("Tool name conflict: %s exists in both %s and %s (keeping first)", - resolvedName, existing.BackendID, backendID) - continue - } + // Use the configured conflict resolver to resolve tool conflicts + var resolvedTools map[string]*ResolvedTool + var err error - resolved.Tools[resolvedName] = &ResolvedTool{ - ResolvedName: resolvedName, - OriginalName: tool.Name, - Description: tool.Description, - InputSchema: tool.InputSchema, - BackendID: tool.BackendID, - // ConflictResolutionApplied will be set in future phases + if a.conflictResolver != nil { + resolvedTools, err = a.conflictResolver.ResolveToolConflicts(ctx, toolsByBackend) + if err != nil { + return nil, fmt.Errorf("conflict resolution failed: %w", err) + } + } else { + // Fallback: no conflict resolution (first wins, log warnings) + logger.Warnf("No conflict resolver configured, using fallback (first wins)") + resolvedTools = make(map[string]*ResolvedTool) + for backendID, tools := range toolsByBackend { + for _, tool := range tools { + if existing, exists := resolvedTools[tool.Name]; exists { + logger.Warnf("Tool name conflict: %s exists in both %s and %s (keeping first)", + tool.Name, existing.BackendID, backendID) + continue + } + resolvedTools[tool.Name] = &ResolvedTool{ + ResolvedName: tool.Name, + OriginalName: tool.Name, + Description: tool.Description, + InputSchema: tool.InputSchema, + BackendID: backendID, + } } } + } - // Collect resources (URIs should be globally unique) - resolved.Resources = append(resolved.Resources, caps.Resources...) + // Build resolved capabilities + resolved := &ResolvedCapabilities{ + Tools: resolvedTools, + Resources: []vmcp.Resource{}, + Prompts: []vmcp.Prompt{}, + } - // Collect prompts + // Collect resources and prompts (no conflict resolution for these yet) + for _, caps := range capabilities { + resolved.Resources = append(resolved.Resources, caps.Resources...) resolved.Prompts = append(resolved.Prompts, caps.Prompts...) // Aggregate logging/sampling support (OR logic - enabled if any backend supports) @@ -239,6 +268,16 @@ func (*defaultAggregator) MergeCapabilities( } } + // Determine conflict strategy used + conflictStrategy := vmcp.ConflictStrategyPrefix // Default + if len(resolved.Tools) > 0 { + // Get strategy from first tool (all tools use same strategy) + for _, tool := range resolved.Tools { + conflictStrategy = tool.ConflictResolutionApplied + break + } + } + // Create final aggregated view aggregated := &AggregatedCapabilities{ Tools: tools, @@ -248,11 +287,11 @@ func (*defaultAggregator) MergeCapabilities( SupportsSampling: resolved.SupportsSampling, RoutingTable: routingTable, Metadata: &AggregationMetadata{ - BackendCount: 0, // Will be set by caller - ToolCount: len(tools), - ResourceCount: len(resolved.Resources), - PromptCount: len(resolved.Prompts), - ConflictsResolved: 0, // Will be tracked in future phases + BackendCount: 0, // Will be set by caller + ToolCount: len(tools), + ResourceCount: len(resolved.Resources), + PromptCount: len(resolved.Prompts), + ConflictStrategy: conflictStrategy, }, } diff --git a/pkg/vmcp/aggregator/default_aggregator_test.go b/pkg/vmcp/aggregator/default_aggregator_test.go index 94045df1e..590c81de6 100644 --- a/pkg/vmcp/aggregator/default_aggregator_test.go +++ b/pkg/vmcp/aggregator/default_aggregator_test.go @@ -32,7 +32,7 @@ func TestDefaultAggregator_QueryCapabilities(t *testing.T) { mockClient.EXPECT().ListCapabilities(gomock.Any(), gomock.Any()).Return(expectedCaps, nil) - agg := NewDefaultAggregator(mockClient) + agg := NewDefaultAggregator(mockClient, nil, nil) result, err := agg.QueryCapabilities(context.Background(), backend) require.NoError(t, err) @@ -56,7 +56,7 @@ func TestDefaultAggregator_QueryCapabilities(t *testing.T) { mockClient.EXPECT().ListCapabilities(gomock.Any(), gomock.Any()). Return(nil, errors.New("connection failed")) - agg := NewDefaultAggregator(mockClient) + agg := NewDefaultAggregator(mockClient, nil, nil) result, err := agg.QueryCapabilities(context.Background(), backend) require.Error(t, err) @@ -87,7 +87,7 @@ func TestDefaultAggregator_QueryAllCapabilities(t *testing.T) { mockClient.EXPECT().ListCapabilities(gomock.Any(), gomock.Any()).Return(caps1, nil) mockClient.EXPECT().ListCapabilities(gomock.Any(), gomock.Any()).Return(caps2, nil) - agg := NewDefaultAggregator(mockClient) + agg := NewDefaultAggregator(mockClient, nil, nil) result, err := agg.QueryAllCapabilities(context.Background(), backends) require.NoError(t, err) @@ -117,7 +117,7 @@ func TestDefaultAggregator_QueryAllCapabilities(t *testing.T) { return nil, errors.New("connection timeout") }).Times(2) - agg := NewDefaultAggregator(mockClient) + agg := NewDefaultAggregator(mockClient, nil, nil) result, err := agg.QueryAllCapabilities(context.Background(), backends) require.NoError(t, err) @@ -137,7 +137,7 @@ func TestDefaultAggregator_QueryAllCapabilities(t *testing.T) { mockClient.EXPECT().ListCapabilities(gomock.Any(), gomock.Any()). Return(nil, errors.New("connection failed")) - agg := NewDefaultAggregator(mockClient) + agg := NewDefaultAggregator(mockClient, nil, nil) result, err := agg.QueryAllCapabilities(context.Background(), backends) require.Error(t, err) @@ -168,7 +168,7 @@ func TestDefaultAggregator_ResolveConflicts(t *testing.T) { }, } - agg := NewDefaultAggregator(nil) + agg := NewDefaultAggregator(nil, nil, nil) resolved, err := agg.ResolveConflicts(context.Background(), capabilities) require.NoError(t, err) @@ -201,7 +201,7 @@ func TestDefaultAggregator_ResolveConflicts(t *testing.T) { }, } - agg := NewDefaultAggregator(nil) + agg := NewDefaultAggregator(nil, nil, nil) resolved, err := agg.ResolveConflicts(context.Background(), capabilities) require.NoError(t, err) @@ -260,7 +260,7 @@ func TestDefaultAggregator_MergeCapabilities(t *testing.T) { } registry := vmcp.NewImmutableRegistry(backends) - agg := NewDefaultAggregator(nil) + agg := NewDefaultAggregator(nil, nil, nil) aggregated, err := agg.MergeCapabilities(context.Background(), resolved, registry) require.NoError(t, err) @@ -328,7 +328,7 @@ func TestDefaultAggregator_AggregateCapabilities(t *testing.T) { mockClient.EXPECT().ListCapabilities(gomock.Any(), gomock.Any()).Return(caps1, nil) mockClient.EXPECT().ListCapabilities(gomock.Any(), gomock.Any()).Return(caps2, nil) - agg := NewDefaultAggregator(mockClient) + agg := NewDefaultAggregator(mockClient, nil, nil) result, err := agg.AggregateCapabilities(context.Background(), backends) require.NoError(t, err) diff --git a/pkg/vmcp/aggregator/manual_resolver.go b/pkg/vmcp/aggregator/manual_resolver.go new file mode 100644 index 000000000..47898d5a9 --- /dev/null +++ b/pkg/vmcp/aggregator/manual_resolver.go @@ -0,0 +1,160 @@ +package aggregator + +import ( + "context" + "fmt" + "strings" + + "github.com/stacklok/toolhive/pkg/logger" + "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/config" +) + +// ManualConflictResolver implements manual conflict resolution. +// It requires explicit overrides for ALL conflicts and fails startup if any are unresolved. +type ManualConflictResolver struct { + // Overrides maps (backendID, originalToolName) to the resolved configuration. + // Key format: "backendID:toolName" + Overrides map[string]*config.ToolOverride +} + +// NewManualConflictResolver creates a new manual conflict resolver. +// Note: This resolver validates that overrides don't create NEW conflicts. +// If two tools are both overridden to the same name, ResolveToolConflicts +// will return an error ("collision after override"). +func NewManualConflictResolver(workloadConfigs []*config.WorkloadToolConfig) (*ManualConflictResolver, error) { + overrides := make(map[string]*config.ToolOverride) + + // Build override map from configuration + for _, wlConfig := range workloadConfigs { + for toolName, override := range wlConfig.Overrides { + if override == nil { + continue + } + key := fmt.Sprintf("%s:%s", wlConfig.Workload, toolName) + overrides[key] = override + } + } + + return &ManualConflictResolver{ + Overrides: overrides, + }, nil +} + +// ResolveToolConflicts applies manual conflict resolution with validation. +// Returns an error if any conflicts exist without explicit overrides. +func (r *ManualConflictResolver) ResolveToolConflicts( + _ context.Context, + toolsByBackend map[string][]vmcp.Tool, +) (map[string]*ResolvedTool, error) { + logger.Debugf("Resolving conflicts using manual strategy with %d overrides", len(r.Overrides)) + + // Group tools by name to detect conflicts + toolsByName := groupToolsByName(toolsByBackend) + + // Check for unresolved conflicts + if unresolvedConflicts := r.findUnresolvedConflicts(toolsByName); len(unresolvedConflicts) > 0 { + return nil, r.formatConflictError(unresolvedConflicts) + } + + // Apply overrides and build resolved map + resolved, err := r.applyOverridesAndResolve(toolsByBackend) + if err != nil { + return nil, err + } + + logger.Infof("Manual strategy: %d unique tools after applying overrides", len(resolved)) + return resolved, nil +} + +// findUnresolvedConflicts checks for conflicts without explicit overrides. +func (r *ManualConflictResolver) findUnresolvedConflicts(toolsByName map[string][]toolWithBackend) map[string][]string { + unresolvedConflicts := make(map[string][]string) + for toolName, candidates := range toolsByName { + if len(candidates) <= 1 { + continue // No conflict + } + + // Check if all conflicting tools have overrides + if !r.allCandidatesHaveOverrides(toolName, candidates) { + backendIDs := make([]string, len(candidates)) + for i, candidate := range candidates { + backendIDs[i] = candidate.BackendID + } + unresolvedConflicts[toolName] = backendIDs + } + } + return unresolvedConflicts +} + +// allCandidatesHaveOverrides checks if all candidates for a tool have overrides configured. +func (r *ManualConflictResolver) allCandidatesHaveOverrides(toolName string, candidates []toolWithBackend) bool { + for _, candidate := range candidates { + key := fmt.Sprintf("%s:%s", candidate.BackendID, toolName) + if _, hasOverride := r.Overrides[key]; !hasOverride { + return false + } + } + return true +} + +// applyOverridesAndResolve applies overrides and builds the resolved tool map. +func (r *ManualConflictResolver) applyOverridesAndResolve( + toolsByBackend map[string][]vmcp.Tool, +) (map[string]*ResolvedTool, error) { + resolved := make(map[string]*ResolvedTool) + for backendID, tools := range toolsByBackend { + for _, tool := range tools { + resolvedTool := r.resolveToolWithOverride(backendID, tool) + + // Check for collision after override + if existing, exists := resolved[resolvedTool.ResolvedName]; exists { + return nil, fmt.Errorf("collision after override: tool %s from backend %s conflicts with tool from backend %s", + resolvedTool.ResolvedName, backendID, existing.BackendID) + } + + resolved[resolvedTool.ResolvedName] = resolvedTool + } + } + return resolved, nil +} + +// resolveToolWithOverride applies overrides to a single tool. +func (r *ManualConflictResolver) resolveToolWithOverride(backendID string, tool vmcp.Tool) *ResolvedTool { + resolvedName := tool.Name + description := tool.Description + + // Check if there's an override for this tool + key := fmt.Sprintf("%s:%s", backendID, tool.Name) + if override, exists := r.Overrides[key]; exists { + if override.Name != "" { + resolvedName = override.Name + } + if override.Description != "" { + description = override.Description + } + } + + return &ResolvedTool{ + ResolvedName: resolvedName, + OriginalName: tool.Name, + Description: description, + InputSchema: tool.InputSchema, + BackendID: backendID, + ConflictResolutionApplied: vmcp.ConflictStrategyManual, + } +} + +// formatConflictError creates a detailed error message for unresolved conflicts. +func (*ManualConflictResolver) formatConflictError(conflicts map[string][]string) error { + var sb strings.Builder + sb.WriteString("unresolved tool name conflicts detected:\n") + + for toolName, backendIDs := range conflicts { + sb.WriteString(fmt.Sprintf(" - %s: [%s]\n", toolName, strings.Join(backendIDs, ", "))) + } + + sb.WriteString("\nUse 'overrides' in aggregation config to resolve these conflicts when using conflict_resolution: manual") + + return fmt.Errorf("%w: %s", ErrUnresolvedConflicts, sb.String()) +} diff --git a/pkg/vmcp/aggregator/prefix_resolver.go b/pkg/vmcp/aggregator/prefix_resolver.go new file mode 100644 index 000000000..719cae746 --- /dev/null +++ b/pkg/vmcp/aggregator/prefix_resolver.go @@ -0,0 +1,80 @@ +package aggregator + +import ( + "context" + "strings" + + "github.com/stacklok/toolhive/pkg/logger" + "github.com/stacklok/toolhive/pkg/vmcp" +) + +// PrefixConflictResolver implements automatic tool name prefixing to resolve conflicts. +// All tools are prefixed with their workload identifier according to a configurable format. +type PrefixConflictResolver struct { + // PrefixFormat defines how to format the prefix. + // Supported placeholders: + // {workload} - just the workload name + // {workload}_ - workload with underscore + // {workload}. - workload with dot + // Can also be a custom static prefix like "backend_" + PrefixFormat string +} + +// NewPrefixConflictResolver creates a new prefix-based conflict resolver. +func NewPrefixConflictResolver(prefixFormat string) *PrefixConflictResolver { + if prefixFormat == "" { + prefixFormat = "{workload}_" // Default format + } + return &PrefixConflictResolver{ + PrefixFormat: prefixFormat, + } +} + +// ResolveToolConflicts applies prefix strategy to all tools. +// Returns a map of resolved tool names to ResolvedTool structs. +func (r *PrefixConflictResolver) ResolveToolConflicts( + _ context.Context, + toolsByBackend map[string][]vmcp.Tool, +) (map[string]*ResolvedTool, error) { + logger.Debugf("Resolving conflicts using prefix strategy (format: %s)", r.PrefixFormat) + + resolved := make(map[string]*ResolvedTool) + + for backendID, tools := range toolsByBackend { + for _, tool := range tools { + // Apply prefix to create resolved name + resolvedName := r.applyPrefix(backendID, tool.Name) + + // Check if this resolved name is unique + if existing, exists := resolved[resolvedName]; exists { + // This should be extremely rare with prefixing, but handle it + logger.Warnf("Collision after prefixing: %s from %s conflicts with %s from %s", + resolvedName, backendID, existing.ResolvedName, existing.BackendID) + continue + } + + resolved[resolvedName] = &ResolvedTool{ + ResolvedName: resolvedName, + OriginalName: tool.Name, + Description: tool.Description, + InputSchema: tool.InputSchema, + BackendID: backendID, + ConflictResolutionApplied: vmcp.ConflictStrategyPrefix, + } + } + } + + logger.Infof("Prefix strategy created %d unique tools", len(resolved)) + + return resolved, nil +} + +// applyPrefix applies the configured prefix format to a tool name. +func (r *PrefixConflictResolver) applyPrefix(backendID, toolName string) string { + prefix := r.PrefixFormat + + // Replace {workload} placeholder with actual backend ID + prefix = strings.ReplaceAll(prefix, "{workload}", backendID) + + return prefix + toolName +} diff --git a/pkg/vmcp/aggregator/priority_resolver.go b/pkg/vmcp/aggregator/priority_resolver.go new file mode 100644 index 000000000..d78bdd1ee --- /dev/null +++ b/pkg/vmcp/aggregator/priority_resolver.go @@ -0,0 +1,158 @@ +package aggregator + +import ( + "context" + "fmt" + + "github.com/stacklok/toolhive/pkg/logger" + "github.com/stacklok/toolhive/pkg/vmcp" +) + +// PriorityConflictResolver implements priority-based conflict resolution. +// The first backend in the priority order wins; conflicting tools from +// lower-priority backends are dropped. +// +// For backends not in the priority list, conflicts are resolved using +// prefix strategy as a fallback (prevents data loss). +type PriorityConflictResolver struct { + // PriorityOrder defines the priority of backends (first has highest priority). + PriorityOrder []string + + // priorityMap is a map from backend ID to its priority index. + priorityMap map[string]int + + // prefixResolver is used as fallback for backends not in priority list. + prefixResolver *PrefixConflictResolver +} + +// NewPriorityConflictResolver creates a new priority-based conflict resolver. +func NewPriorityConflictResolver(priorityOrder []string) (*PriorityConflictResolver, error) { + if len(priorityOrder) == 0 { + return nil, fmt.Errorf("priority order cannot be empty") + } + + // Build priority map for O(1) lookups + priorityMap := make(map[string]int, len(priorityOrder)) + for i, backendID := range priorityOrder { + if backendID == "" { + return nil, fmt.Errorf("priority order contains empty backend ID at index %d", i) + } + priorityMap[backendID] = i + } + + return &PriorityConflictResolver{ + PriorityOrder: priorityOrder, + priorityMap: priorityMap, + prefixResolver: NewPrefixConflictResolver("{workload}_"), // Fallback for unmapped backends + }, nil +} + +// ResolveToolConflicts applies priority strategy to resolve conflicts. +// Returns a map of resolved tool names to ResolvedTool structs. +func (r *PriorityConflictResolver) ResolveToolConflicts( + _ context.Context, + toolsByBackend map[string][]vmcp.Tool, +) (map[string]*ResolvedTool, error) { + logger.Debugf("Resolving conflicts using priority strategy (order: %v)", r.PriorityOrder) + + resolved := make(map[string]*ResolvedTool) + droppedTools := 0 + + // First pass: collect all tools grouped by name + toolsByName := groupToolsByName(toolsByBackend) + + // Second pass: resolve conflicts using priority + for toolName, candidates := range toolsByName { + if len(candidates) == 1 { + // No conflict - include the tool as-is + candidate := candidates[0] + resolved[toolName] = &ResolvedTool{ + ResolvedName: toolName, + OriginalName: toolName, + Description: candidate.Tool.Description, + InputSchema: candidate.Tool.InputSchema, + BackendID: candidate.BackendID, + ConflictResolutionApplied: vmcp.ConflictStrategyPriority, + } + continue + } + + // Conflict detected - choose the highest priority backend + winner := r.selectWinner(candidates) + if winner == nil { + // All candidates are from backends not in priority list + // Use prefix strategy as fallback to avoid data loss + backendIDs := make([]string, len(candidates)) + for i, c := range candidates { + backendIDs[i] = c.BackendID + } + logger.Debugf("Tool %s exists in backends %v not in priority order, using prefix fallback", + toolName, backendIDs) + + // Apply prefix strategy to these unmapped backends + for _, candidate := range candidates { + prefixedName := r.prefixResolver.applyPrefix(candidate.BackendID, toolName) + resolved[prefixedName] = &ResolvedTool{ + ResolvedName: prefixedName, + OriginalName: toolName, + Description: candidate.Tool.Description, + InputSchema: candidate.Tool.InputSchema, + BackendID: candidate.BackendID, + ConflictResolutionApplied: vmcp.ConflictStrategyPrefix, // Fallback used prefix + } + } + continue + } + + resolved[toolName] = &ResolvedTool{ + ResolvedName: toolName, + OriginalName: toolName, + Description: winner.Tool.Description, + InputSchema: winner.Tool.InputSchema, + BackendID: winner.BackendID, + ConflictResolutionApplied: vmcp.ConflictStrategyPriority, + } + + // Log dropped tools + for _, candidate := range candidates { + if candidate.BackendID != winner.BackendID { + logger.Warnf("Dropped tool %s from backend %s (lower priority than %s)", + toolName, candidate.BackendID, winner.BackendID) + droppedTools++ + } + } + } + + if droppedTools > 0 { + logger.Infof("Priority strategy: %d unique tools, %d conflicting tools dropped", + len(resolved), droppedTools) + } else { + logger.Infof("Priority strategy: %d unique tools", len(resolved)) + } + + return resolved, nil +} + +// selectWinner chooses the tool from the highest-priority backend. +// Returns nil if none of the candidates are in the priority list. +func (r *PriorityConflictResolver) selectWinner(candidates []toolWithBackend) *toolWithBackend { + var winner *toolWithBackend + winnerPriority := -1 + + for i := range candidates { + candidate := &candidates[i] + priority, exists := r.priorityMap[candidate.BackendID] + if !exists { + // Backend not in priority list - skip + continue + } + + // Lower index = higher priority + if winnerPriority == -1 || priority < winnerPriority { + winner = candidate + winnerPriority = priority + } + } + + return winner +} diff --git a/pkg/vmcp/aggregator/tool_adapter.go b/pkg/vmcp/aggregator/tool_adapter.go new file mode 100644 index 000000000..1ecc39891 --- /dev/null +++ b/pkg/vmcp/aggregator/tool_adapter.go @@ -0,0 +1,104 @@ +// Package aggregator provides capability aggregation for Virtual MCP Server. +package aggregator + +import ( + "context" + + "github.com/stacklok/toolhive/pkg/logger" + "github.com/stacklok/toolhive/pkg/mcp" + "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/config" +) + +// processBackendTools applies per-backend filtering and overrides to tools. +// This is called during capability discovery, before conflict resolution. +// +// This function reuses the battle-tested logic from pkg/mcp/tool_filter.go +// by converting vmcp.Tool to mcp.SimpleTool, applying the middleware logic, +// and converting back. +func processBackendTools( + _ context.Context, + backendID string, + tools []vmcp.Tool, + workloadConfig *config.WorkloadToolConfig, +) []vmcp.Tool { + if workloadConfig == nil { + return tools // No configuration for this backend + } + + // If no filter or overrides configured, return tools as-is + if len(workloadConfig.Filter) == 0 && len(workloadConfig.Overrides) == 0 { + return tools + } + + // Build middleware options from workload config + var opts []mcp.ToolMiddlewareOption + + // Add filter if configured + if len(workloadConfig.Filter) > 0 { + opts = append(opts, mcp.WithToolsFilter(workloadConfig.Filter...)) + } + + // Build reverse map: overridden name -> original name (for lookup after processing) + reverseOverrideMap := make(map[string]string) + + // Add overrides if configured + if len(workloadConfig.Overrides) > 0 { + for originalName, override := range workloadConfig.Overrides { + if override != nil { + opts = append(opts, mcp.WithToolsOverride(originalName, override.Name, override.Description)) + // Track the mapping from overridden name back to original name + if override.Name != "" { + reverseOverrideMap[override.Name] = originalName + } + } + } + } + + // Convert vmcp.Tool to mcp.SimpleTool + simpleTools := make([]mcp.SimpleTool, len(tools)) + originalToolsByName := make(map[string]vmcp.Tool, len(tools)) + for i, tool := range tools { + simpleTools[i] = mcp.SimpleTool{ + Name: tool.Name, + Description: tool.Description, + } + originalToolsByName[tool.Name] = tool + } + + // Apply the shared filtering/override logic from pkg/mcp + processed, err := mcp.ApplyToolFiltering(opts, simpleTools) + if err != nil { + logger.Warnf("Failed to apply tool filtering for backend %s: %v", backendID, err) + return tools // Return original tools if processing fails + } + + // Convert back to vmcp.Tool, preserving InputSchema and BackendID + result := make([]vmcp.Tool, 0, len(processed)) + for _, simpleTool := range processed { + // Find the original tool name (before any override) + originalName := simpleTool.Name + if revName, wasOverridden := reverseOverrideMap[simpleTool.Name]; wasOverridden { + originalName = revName + } + + // Look up the original tool to preserve InputSchema and BackendID + originalTool, exists := originalToolsByName[originalName] + if !exists { + // This should not happen unless there's a bug in the filtering logic, + // but skip the tool rather than panicking + logger.Warnf("Tool %s not found in original tools map for backend %s, skipping", originalName, backendID) + continue + } + + // Construct the result tool with processed name/description but original schema + result = append(result, vmcp.Tool{ + Name: simpleTool.Name, // Use the processed (potentially overridden) name + Description: simpleTool.Description, // Use the processed (potentially overridden) description + InputSchema: originalTool.InputSchema, + BackendID: backendID, // Use the backendID parameter (source of truth) + }) + } + + return result +} diff --git a/pkg/vmcp/aggregator/tool_adapter_test.go b/pkg/vmcp/aggregator/tool_adapter_test.go new file mode 100644 index 000000000..a1ff4ac79 --- /dev/null +++ b/pkg/vmcp/aggregator/tool_adapter_test.go @@ -0,0 +1,162 @@ +package aggregator + +import ( + "context" + "testing" + + "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/config" +) + +func TestProcessBackendTools(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + backendID string + tools []vmcp.Tool + workloadConfig *config.WorkloadToolConfig + wantCount int + wantNames []string + }{ + { + name: "no configuration - all tools pass through", + backendID: "github", + tools: []vmcp.Tool{ + {Name: "create_pr", Description: "Create PR", InputSchema: map[string]any{"type": "object"}, BackendID: "github"}, + {Name: "merge_pr", Description: "Merge PR", InputSchema: map[string]any{"type": "object"}, BackendID: "github"}, + }, + workloadConfig: nil, + wantCount: 2, + wantNames: []string{"create_pr", "merge_pr"}, + }, + { + name: "filter only specific tools", + backendID: "github", + tools: []vmcp.Tool{ + {Name: "create_pr", Description: "Create PR", BackendID: "github"}, + {Name: "merge_pr", Description: "Merge PR", BackendID: "github"}, + {Name: "list_prs", Description: "List PRs", BackendID: "github"}, + }, + workloadConfig: &config.WorkloadToolConfig{ + Workload: "github", + Filter: []string{"create_pr", "merge_pr"}, + }, + wantCount: 2, + wantNames: []string{"create_pr", "merge_pr"}, + }, + { + name: "override tool names", + backendID: "github", + tools: []vmcp.Tool{ + {Name: "create_issue", Description: "Create issue", InputSchema: map[string]any{"type": "object"}, BackendID: "github"}, + {Name: "list_repos", Description: "List repos", BackendID: "github"}, + }, + workloadConfig: &config.WorkloadToolConfig{ + Workload: "github", + Overrides: map[string]*config.ToolOverride{ + "create_issue": {Name: "gh_create_issue", Description: "Create GitHub issue"}, + }, + }, + wantCount: 2, + wantNames: []string{"gh_create_issue", "list_repos"}, + }, + { + name: "filter and override combined", + backendID: "github", + tools: []vmcp.Tool{ + {Name: "create_pr", Description: "Create PR", BackendID: "github"}, + {Name: "merge_pr", Description: "Merge PR", BackendID: "github"}, + {Name: "delete_pr", Description: "Delete PR", BackendID: "github"}, + }, + workloadConfig: &config.WorkloadToolConfig{ + Workload: "github", + // Filter uses user-facing names (after override) + Filter: []string{"gh_create_pr", "merge_pr"}, + Overrides: map[string]*config.ToolOverride{ + "create_pr": {Name: "gh_create_pr"}, + }, + }, + wantCount: 2, + wantNames: []string{"gh_create_pr", "merge_pr"}, + }, + { + name: "description override only", + backendID: "github", + tools: []vmcp.Tool{ + {Name: "create_pr", Description: "Original description", BackendID: "github"}, + }, + workloadConfig: &config.WorkloadToolConfig{ + Workload: "github", + Overrides: map[string]*config.ToolOverride{ + "create_pr": {Description: "Updated description"}, + }, + }, + wantCount: 1, + wantNames: []string{"create_pr"}, + }, + { + name: "preserves InputSchema and BackendID", + backendID: "backend1", + tools: []vmcp.Tool{ + { + Name: "tool1", + Description: "Tool 1", + InputSchema: map[string]any{"type": "object", "properties": map[string]any{"param": map[string]any{"type": "string"}}}, + BackendID: "backend1", + }, + }, + workloadConfig: &config.WorkloadToolConfig{ + Workload: "backend1", + Overrides: map[string]*config.ToolOverride{ + "tool1": {Name: "renamed_tool1"}, + }, + }, + wantCount: 1, + wantNames: []string{"renamed_tool1"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + result := processBackendTools(context.Background(), tt.backendID, tt.tools, tt.workloadConfig) + + if len(result) != tt.wantCount { + t.Errorf("got %d tools, want %d", len(result), tt.wantCount) + } + + // Check expected tool names are present + resultNames := make(map[string]bool) + for _, tool := range result { + resultNames[tool.Name] = true + } + + for _, wantName := range tt.wantNames { + if !resultNames[wantName] { + t.Errorf("expected tool %q not found in results", wantName) + } + } + + // Verify InputSchema and BackendID are preserved + for i, resultTool := range result { + if resultTool.InputSchema != nil { + // Find original tool to verify schema preservation + for _, origTool := range tt.tools { + if origTool.InputSchema != nil { + // Schema should be preserved (same reference) + if len(resultTool.InputSchema) == 0 && len(origTool.InputSchema) > 0 { + t.Errorf("tool %d lost InputSchema", i) + } + } + } + } + + if resultTool.BackendID != tt.backendID { + t.Errorf("tool %d has BackendID %q, want %q", i, resultTool.BackendID, tt.backendID) + } + } + }) + } +}