Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 106 additions & 14 deletions pkg/mcp/tool_filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,36 @@ func (c *toolMiddlewareConfig) getToolListOverride(toolName string) (*toolOverri
// middleware.
type ToolMiddlewareOption func(*toolMiddlewareConfig) error

// SimpleTool represents a minimal tool with name and description.
// This is used by ApplyToolFiltering to work with tools in a generic way.
type SimpleTool struct {
Name string
Description string
}

// ApplyToolFiltering applies filtering and overriding to a list of tools.
// This is the core logic used by both the HTTP middleware and other components
// that need to apply the same filtering/overriding behavior.
//
// Returns the filtered and overridden tools.
func ApplyToolFiltering(opts []ToolMiddlewareOption, tools []SimpleTool) ([]SimpleTool, error) {
config := &toolMiddlewareConfig{
filterTools: make(map[string]struct{}),
actualToUserOverride: make(map[string]toolOverrideEntry),
userToActualOverride: make(map[string]toolOverrideEntry),
}

// Apply options to build config
for _, opt := range opts {
if err := opt(config); err != nil {
return nil, err
}
}

// Use the shared core logic
return applyFilteringAndOverrides(config, tools), nil
}

// WithToolsFilter is a function that can be used to configure the tool
// middleware to use a filter list of tools.
func WithToolsFilter(toolsFilter ...string) ToolMiddlewareOption {
Expand Down Expand Up @@ -448,7 +478,10 @@ func processToolsListResponse(
toolsListResponse toolsListResponse,
w io.Writer,
) error {
filteredTools := []map[string]any{}
// Convert to SimpleTool format for shared processing
simpleTools := make([]SimpleTool, 0, len(*toolsListResponse.Result.Tools))
toolMaps := make([]map[string]any, 0, len(*toolsListResponse.Result.Tools))

for _, tool := range *toolsListResponse.Result.Tools {
// NOTE: the spec does not allow for name to be missing.
toolName, ok := tool["name"].(string)
Expand All @@ -461,31 +494,90 @@ func processToolsListResponse(
return errToolNameNotFound
}

// Get description if present (optional in MCP spec)
description, _ := tool["description"].(string)

simpleTools = append(simpleTools, SimpleTool{
Name: toolName,
Description: description,
})
toolMaps = append(toolMaps, tool)
}

// Apply the shared filtering/override logic
processedTools := applyFilteringAndOverrides(config, simpleTools)

// Build the filtered response by matching processed tools with their original maps
// Note: This is O(n²) complexity, but acceptable because:
// - Tool lists are typically small (< 100 tools per backend)
// - Only runs once during tool list retrieval (not in hot path)
// - Inner loop breaks early on match
filteredTools := make([]map[string]any, 0, len(processedTools))
for _, processed := range processedTools {
// Find the original tool map by matching names
for i, simple := range simpleTools {
if simple.Name == processed.Name || simple.Name == findOriginalName(config, processed.Name) {
// Clone the original map and update name/description
toolCopy := make(map[string]any, len(toolMaps[i]))
for k, v := range toolMaps[i] {
toolCopy[k] = v
}
toolCopy["name"] = processed.Name
if processed.Description != "" {
toolCopy["description"] = processed.Description
}
filteredTools = append(filteredTools, toolCopy)
break
}
}
}

toolsListResponse.Result.Tools = &filteredTools
if err := json.NewEncoder(w).Encode(toolsListResponse); err != nil {
return fmt.Errorf("%w: %v", errBug, err)
}

return nil
}

// applyFilteringAndOverrides is the core logic for filtering and overriding tools.
// This implements the exact same logic as before but is now extracted for reuse.
func applyFilteringAndOverrides(config *toolMiddlewareConfig, tools []SimpleTool) []SimpleTool {
result := make([]SimpleTool, 0, len(tools))
for _, tool := range tools {
description := tool.Description

// If the tool is overridden, we need to use the override name and description.
if entry, ok := config.getToolListOverride(toolName); ok {
if entry, ok := config.getToolListOverride(tool.Name); ok {
if entry.OverrideName != "" {
tool["name"] = entry.OverrideName
tool.Name = entry.OverrideName
}
if entry.OverrideDescription != "" {
tool["description"] = entry.OverrideDescription
description = entry.OverrideDescription
}
toolName = entry.OverrideName
}

// If the tool is in the filter, we add it to the filtered tools list.
// Note that lookup is done using the user-known name, which might be
// different from the actual tool name.
if config.isToolInFilter(toolName) {
filteredTools = append(filteredTools, tool)
// Note that lookup is done using the user-known name (tool.Name after override).
if config.isToolInFilter(tool.Name) {
result = append(result, SimpleTool{
Name: tool.Name,
Description: description,
})
}
}
return result
}

toolsListResponse.Result.Tools = &filteredTools
if err := json.NewEncoder(w).Encode(toolsListResponse); err != nil {
return fmt.Errorf("%w: %v", errBug, err)
// findOriginalName attempts to find the original tool name before override.
func findOriginalName(config *toolMiddlewareConfig, overriddenName string) string {
// Iterate through overrides to find reverse mapping
for actualName, entry := range config.actualToUserOverride {
if entry.OverrideName == overriddenName {
return actualName
}
}

return nil
return overriddenName
}

// toolCallFix mimics a sum type in Go. The actual types represent the
Expand Down
3 changes: 0 additions & 3 deletions pkg/vmcp/aggregator/aggregator.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,6 @@ type AggregationMetadata struct {
// PromptCount is the total number of prompts.
PromptCount int

// ConflictsResolved is the number of conflicts that were resolved.
ConflictsResolved int

// ConflictStrategy is the strategy used for conflict resolution.
ConflictStrategy vmcp.ConflictResolutionStrategy
}
Expand Down
71 changes: 71 additions & 0 deletions pkg/vmcp/aggregator/conflict_resolver.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// Package aggregator provides capability aggregation for Virtual MCP Server.
//
// This file contains the factory function for creating conflict resolvers
// and shared helper functions used by multiple resolver implementations.
package aggregator

import (
"fmt"

"github.com/stacklok/toolhive/pkg/logger"
"github.com/stacklok/toolhive/pkg/vmcp"
"github.com/stacklok/toolhive/pkg/vmcp/config"
)

// NewConflictResolver creates the appropriate conflict resolver based on configuration.
func NewConflictResolver(aggregationConfig *config.AggregationConfig) (ConflictResolver, error) {
if aggregationConfig == nil {
// Default to prefix strategy with default format
logger.Infof("No aggregation config provided, using default prefix strategy")
return NewPrefixConflictResolver("{workload}_"), nil
}

switch aggregationConfig.ConflictResolution {
case vmcp.ConflictStrategyPrefix:
prefixFormat := "{workload}_" // Default
if aggregationConfig.ConflictResolutionConfig != nil &&
aggregationConfig.ConflictResolutionConfig.PrefixFormat != "" {
prefixFormat = aggregationConfig.ConflictResolutionConfig.PrefixFormat
}
logger.Infof("Using prefix conflict resolution strategy (format: %s)", prefixFormat)
return NewPrefixConflictResolver(prefixFormat), nil

case vmcp.ConflictStrategyPriority:
if aggregationConfig.ConflictResolutionConfig == nil ||
len(aggregationConfig.ConflictResolutionConfig.PriorityOrder) == 0 {
return nil, fmt.Errorf("priority strategy requires priority_order in conflict_resolution_config")
}
logger.Infof("Using priority conflict resolution strategy (order: %v)",
aggregationConfig.ConflictResolutionConfig.PriorityOrder)
return NewPriorityConflictResolver(aggregationConfig.ConflictResolutionConfig.PriorityOrder)

case vmcp.ConflictStrategyManual:
logger.Infof("Using manual conflict resolution strategy")
return NewManualConflictResolver(aggregationConfig.Tools)

default:
return nil, fmt.Errorf("%w: %s", ErrInvalidConflictStrategy, aggregationConfig.ConflictResolution)
}
}

// toolWithBackend is a helper struct to track which backend a tool comes from.
// This is shared by multiple conflict resolution strategies.
type toolWithBackend struct {
Tool vmcp.Tool
BackendID string
}

// groupToolsByName groups tools by their names to detect conflicts.
// This is shared by multiple conflict resolution strategies.
func groupToolsByName(toolsByBackend map[string][]vmcp.Tool) map[string][]toolWithBackend {
toolsByName := make(map[string][]toolWithBackend)
for backendID, tools := range toolsByBackend {
for _, tool := range tools {
toolsByName[tool.Name] = append(toolsByName[tool.Name], toolWithBackend{
Tool: tool,
BackendID: backendID,
})
}
}
return toolsByName
}
Loading
Loading