Skip to content

Commit 7a0221b

Browse files
feat: Support generic categories and MMLU-Pro mapping (#192)
1 parent b7888be commit 7a0221b

File tree

7 files changed

+368
-53
lines changed

7 files changed

+368
-53
lines changed
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Example: Using generic categories with MMLU-Pro mapping
2+
# This file demonstrates how to declare free-style categories and map them to
3+
# MMLU-Pro categories expected by the classifier model.
4+
5+
bert_model:
6+
model_id: sentence-transformers/all-MiniLM-L12-v2
7+
threshold: 0.6
8+
use_cpu: true
9+
10+
classifier:
11+
category_model:
12+
model_id: "models/category_classifier_modernbert-base_model"
13+
use_modernbert: true
14+
threshold: 0.6
15+
use_cpu: true
16+
category_mapping_path: "models/category_classifier_modernbert-base_model/category_mapping.json"
17+
18+
# Define your generic categories and map them to MMLU-Pro categories.
19+
# The classifier will translate predicted MMLU categories into these generic names.
20+
categories:
21+
- name: tech
22+
mmlu_categories: ["computer science", "engineering"]
23+
model_scores:
24+
- model: phi4
25+
score: 0.9
26+
- model: mistral-small3.1
27+
score: 0.7
28+
- name: finance
29+
mmlu_categories: ["economics"]
30+
model_scores:
31+
- model: gemma3:27b
32+
score: 0.8
33+
- name: politics
34+
# If mmlu_categories is omitted, identity mapping applies when this name matches an MMLU-Pro category known to the classifier
35+
model_scores:
36+
- model: gemma3:27b
37+
score: 0.6
38+
39+
# A default model is recommended for fallback
40+
default_model: mistral-small3.1

src/semantic-router/pkg/config/config.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,10 @@ type Category struct {
272272
ReasoningDescription string `yaml:"reasoning_description,omitempty"`
273273
ReasoningEffort string `yaml:"reasoning_effort,omitempty"` // Configurable reasoning effort level (low, medium, high)
274274
ModelScores []ModelScore `yaml:"model_scores"`
275+
// MMLUCategories optionally maps this generic category to one or more MMLU-Pro categories
276+
// used by the classifier model. When provided, classifier outputs will be translated
277+
// from these MMLU categories to this generic category name.
278+
MMLUCategories []string `yaml:"mmlu_categories,omitempty"`
275279
}
276280

277281
// Legacy types - can be removed once migration is complete
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
package config_test
2+
3+
import (
4+
. "github.com/onsi/ginkgo/v2"
5+
. "github.com/onsi/gomega"
6+
7+
"gopkg.in/yaml.v3"
8+
9+
"github.com/vllm-project/semantic-router/src/semantic-router/pkg/config"
10+
)
11+
12+
// Spec: the optional `mmlu_categories` YAML key is decoded into
// config.Category.MMLUCategories when a RouterConfig is unmarshalled.
var _ = Describe("MMLU categories in config YAML", func() {
13+
It("should unmarshal mmlu_categories into Category struct", func() {
14+
// Fixture with three categories: "tech" and "finance" declare mmlu_categories,
// while "politics" leaves the key out entirely.
yamlContent := `
15+
categories:
16+
- name: "tech"
17+
mmlu_categories: ["computer science", "engineering"]
18+
model_scores:
19+
- model: "phi4"
20+
score: 0.9
21+
use_reasoning: false
22+
- name: "finance"
23+
mmlu_categories: ["economics"]
24+
model_scores:
25+
- model: "gemma3:27b"
26+
score: 0.8
27+
use_reasoning: true
28+
- name: "politics"
29+
model_scores:
30+
- model: "gemma3:27b"
31+
score: 0.6
32+
use_reasoning: false
33+
`
34+
35+
// Decode the fixture; unmarshalling itself must succeed.
var cfg config.RouterConfig
36+
Expect(yaml.Unmarshal([]byte(yamlContent), &cfg)).To(Succeed())
37+
38+
// All three categories must be present, in declaration order.
Expect(cfg.Categories).To(HaveLen(3))
39+
40+
// "tech": both listed MMLU categories are captured (order-insensitive check)
// and the model scores are still decoded alongside them.
Expect(cfg.Categories[0].Name).To(Equal("tech"))
41+
Expect(cfg.Categories[0].MMLUCategories).To(ConsistOf("computer science", "engineering"))
42+
Expect(cfg.Categories[0].ModelScores).ToNot(BeEmpty())
43+
44+
// "finance": a single-element list is decoded as a one-item slice.
Expect(cfg.Categories[1].Name).To(Equal("finance"))
45+
Expect(cfg.Categories[1].MMLUCategories).To(ConsistOf("economics"))
46+
47+
// "politics": omitting the key leaves MMLUCategories empty.
Expect(cfg.Categories[2].Name).To(Equal("politics"))
48+
Expect(cfg.Categories[2].MMLUCategories).To(BeEmpty())
49+
})
50+
})

src/semantic-router/pkg/utils/classification/classifier.go

Lines changed: 70 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,12 @@ type Classifier struct {
209209
CategoryMapping *CategoryMapping
210210
PIIMapping *PIIMapping
211211
JailbreakMapping *JailbreakMapping
212+
213+
// Category name mapping layer to support generic categories in config
214+
// Maps MMLU-Pro category names -> generic category names (as defined in config.Categories)
215+
MMLUToGeneric map[string]string
216+
// Maps generic category names -> MMLU-Pro category names
217+
GenericToMMLU map[string][]string
212218
}
213219

214220
type option func(*Classifier)
@@ -272,6 +278,9 @@ func newClassifierWithOptions(cfg *config.RouterConfig, options ...option) (*Cla
272278
option(classifier)
273279
}
274280

281+
// Build category name mappings to support generic categories in config
282+
classifier.buildCategoryNameMappings()
283+
275284
return initModels(classifier)
276285
}
277286

@@ -331,18 +340,21 @@ func (c *Classifier) ClassifyCategory(text string) (string, float64, error) {
331340
return "", float64(result.Confidence), nil
332341
}
333342

334-
// Convert class index to category name
343+
// Convert class index to category name (MMLU-Pro)
335344
categoryName, ok := c.CategoryMapping.GetCategoryFromIndex(result.Class)
336345
if !ok {
337346
observability.Warnf("Class index %d not found in category mapping", result.Class)
338347
return "", float64(result.Confidence), nil
339348
}
340349

341-
// Record the category classification metric
342-
metrics.RecordCategoryClassification(categoryName)
350+
// Translate to generic category if mapping is configured
351+
genericCategory := c.translateMMLUToGeneric(categoryName)
343352

344-
observability.Infof("Classified as category: %s", categoryName)
345-
return categoryName, float64(result.Confidence), nil
353+
// Record the category classification metric using generic name when available
354+
metrics.RecordCategoryClassification(genericCategory)
355+
356+
observability.Infof("Classified as category: %s (mmlu=%s)", genericCategory, categoryName)
357+
return genericCategory, float64(result.Confidence), nil
346358
}
347359

348360
// IsJailbreakEnabled checks if jailbreak detection is enabled and properly configured
@@ -485,11 +497,11 @@ func (c *Classifier) ClassifyCategoryWithEntropy(text string) (string, float64,
485497
observability.Infof("Classification result: class=%d, confidence=%.4f, entropy_available=%t",
486498
result.Class, result.Confidence, len(result.Probabilities) > 0)
487499

488-
// Get category names for all classes
500+
// Get category names for all classes and translate to generic names when configured
489501
categoryNames := make([]string, len(result.Probabilities))
490502
for i := range result.Probabilities {
491503
if name, ok := c.CategoryMapping.GetCategoryFromIndex(i); ok {
492-
categoryNames[i] = name
504+
categoryNames[i] = c.translateMMLUToGeneric(name)
493505
} else {
494506
categoryNames[i] = fmt.Sprintf("unknown_%d", i)
495507
}
@@ -580,20 +592,21 @@ func (c *Classifier) ClassifyCategoryWithEntropy(text string) (string, float64,
580592
return "", float64(result.Confidence), reasoningDecision, nil
581593
}
582594

583-
// Convert class index to category name
595+
// Convert class index to category name and translate to generic
584596
categoryName, ok := c.CategoryMapping.GetCategoryFromIndex(result.Class)
585597
if !ok {
586598
observability.Warnf("Class index %d not found in category mapping", result.Class)
587599
return "", float64(result.Confidence), reasoningDecision, nil
588600
}
601+
genericCategory := c.translateMMLUToGeneric(categoryName)
589602

590603
// Record the category classification metric
591-
metrics.RecordCategoryClassification(categoryName)
604+
metrics.RecordCategoryClassification(genericCategory)
592605

593-
observability.Infof("Classified as category: %s, reasoning_decision: use=%t, confidence=%.3f, reason=%s",
594-
categoryName, reasoningDecision.UseReasoning, reasoningDecision.Confidence, reasoningDecision.DecisionReason)
606+
observability.Infof("Classified as category: %s (mmlu=%s), reasoning_decision: use=%t, confidence=%.3f, reason=%s",
607+
genericCategory, categoryName, reasoningDecision.UseReasoning, reasoningDecision.Confidence, reasoningDecision.DecisionReason)
595608

596-
return categoryName, float64(result.Confidence), reasoningDecision, nil
609+
return genericCategory, float64(result.Confidence), reasoningDecision, nil
597610
}
598611

599612
// ClassifyPII performs PII token classification on the given text and returns detected PII types
@@ -772,6 +785,51 @@ func (c *Classifier) findCategory(categoryName string) *config.Category {
772785
return nil
773786
}
774787

788+
// buildCategoryNameMappings builds translation maps between MMLU-Pro and generic categories
789+
func (c *Classifier) buildCategoryNameMappings() {
790+
c.MMLUToGeneric = make(map[string]string)
791+
c.GenericToMMLU = make(map[string][]string)
792+
793+
// Build set of known MMLU-Pro categories from the model mapping (if available)
794+
knownMMLU := make(map[string]bool)
795+
if c.CategoryMapping != nil {
796+
for _, label := range c.CategoryMapping.IdxToCategory {
797+
knownMMLU[strings.ToLower(label)] = true
798+
}
799+
}
800+
801+
for _, cat := range c.Config.Categories {
802+
if len(cat.MMLUCategories) > 0 {
803+
for _, mmlu := range cat.MMLUCategories {
804+
key := strings.ToLower(mmlu)
805+
c.MMLUToGeneric[key] = cat.Name
806+
c.GenericToMMLU[cat.Name] = append(c.GenericToMMLU[cat.Name], mmlu)
807+
}
808+
} else {
809+
// Fallback: identity mapping when the generic name matches an MMLU category
810+
nameLower := strings.ToLower(cat.Name)
811+
if knownMMLU[nameLower] {
812+
c.MMLUToGeneric[nameLower] = cat.Name
813+
c.GenericToMMLU[cat.Name] = append(c.GenericToMMLU[cat.Name], cat.Name)
814+
}
815+
}
816+
}
817+
}
818+
819+
// translateMMLUToGeneric translates an MMLU-Pro category to a generic category if mapping exists
820+
func (c *Classifier) translateMMLUToGeneric(mmluCategory string) string {
821+
if mmluCategory == "" {
822+
return ""
823+
}
824+
if c.MMLUToGeneric == nil {
825+
return mmluCategory
826+
}
827+
if generic, ok := c.MMLUToGeneric[strings.ToLower(mmluCategory)]; ok {
828+
return generic
829+
}
830+
return mmluCategory
831+
}
832+
775833
// selectBestModelInternal performs the core model selection logic
776834
//
777835
// modelFilter is optional - if provided, only models passing the filter will be considered
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
package classification
2+
3+
import (
4+
. "github.com/onsi/ginkgo/v2"
5+
. "github.com/onsi/gomega"
6+
7+
candle_binding "github.com/vllm-project/semantic-router/candle-binding"
8+
"github.com/vllm-project/semantic-router/src/semantic-router/pkg/config"
9+
)
10+
11+
var _ = Describe("generic category mapping (MMLU-Pro -> generic)", func() {
12+
var (
13+
classifier *Classifier
14+
mockCategoryInitializer *MockCategoryInitializer
15+
mockCategoryModel *MockCategoryInference
16+
)
17+
18+
BeforeEach(func() {
19+
mockCategoryInitializer = &MockCategoryInitializer{InitError: nil}
20+
mockCategoryModel = &MockCategoryInference{}
21+
22+
cfg := &config.RouterConfig{}
23+
cfg.Classifier.CategoryModel.ModelID = "model-id"
24+
cfg.Classifier.CategoryModel.CategoryMappingPath = "category-mapping-path"
25+
cfg.Classifier.CategoryModel.Threshold = 0.5
26+
27+
// Define generic categories with MMLU-Pro mappings
28+
cfg.Categories = []config.Category{
29+
{
30+
Name: "tech",
31+
MMLUCategories: []string{"computer science", "engineering"},
32+
ModelScores: []config.ModelScore{{Model: "phi4", Score: 0.9, UseReasoning: config.BoolPtr(false)}},
33+
ReasoningEffort: "low",
34+
},
35+
{
36+
Name: "finance",
37+
MMLUCategories: []string{"economics"},
38+
ModelScores: []config.ModelScore{{Model: "gemma3:27b", Score: 0.8, UseReasoning: config.BoolPtr(true)}},
39+
},
40+
{
41+
Name: "politics",
42+
// No explicit mmlu_categories -> identity fallback when label exists in mapping
43+
ModelScores: []config.ModelScore{{Model: "gemma3:27b", Score: 0.6, UseReasoning: config.BoolPtr(false)}},
44+
},
45+
}
46+
47+
// Category mapping represents labels coming from the MMLU-Pro model
48+
categoryMapping := &CategoryMapping{
49+
CategoryToIdx: map[string]int{
50+
"computer science": 0,
51+
"economics": 1,
52+
"politics": 2,
53+
},
54+
IdxToCategory: map[string]string{
55+
"0": "Computer Science", // different case to assert case-insensitive mapping
56+
"1": "economics",
57+
"2": "politics",
58+
},
59+
}
60+
61+
var err error
62+
classifier, err = newClassifierWithOptions(
63+
cfg,
64+
withCategory(categoryMapping, mockCategoryInitializer, mockCategoryModel),
65+
)
66+
Expect(err).ToNot(HaveOccurred())
67+
})
68+
69+
It("builds expected MMLU<->generic maps", func() {
70+
Expect(classifier.MMLUToGeneric).To(HaveKeyWithValue("computer science", "tech"))
71+
Expect(classifier.MMLUToGeneric).To(HaveKeyWithValue("engineering", "tech"))
72+
Expect(classifier.MMLUToGeneric).To(HaveKeyWithValue("economics", "finance"))
73+
// identity fallback for a generic name that exists as an MMLU label
74+
Expect(classifier.MMLUToGeneric).To(HaveKeyWithValue("politics", "politics"))
75+
76+
Expect(classifier.GenericToMMLU).To(HaveKey("tech"))
77+
Expect(classifier.GenericToMMLU["tech"]).To(ConsistOf("computer science", "engineering"))
78+
Expect(classifier.GenericToMMLU).To(HaveKeyWithValue("finance", ConsistOf("economics")))
79+
Expect(classifier.GenericToMMLU).To(HaveKeyWithValue("politics", ConsistOf("politics")))
80+
})
81+
82+
It("translates ClassifyCategory result to generic category", func() {
83+
// Model returns class index 0 -> "Computer Science" (MMLU) which maps to generic "tech"
84+
mockCategoryModel.classifyResult = candle_binding.ClassResult{Class: 0, Confidence: 0.92}
85+
86+
category, score, err := classifier.ClassifyCategory("This text is about GPUs and compilers")
87+
Expect(err).ToNot(HaveOccurred())
88+
Expect(category).To(Equal("tech"))
89+
Expect(score).To(BeNumerically("~", 0.92, 0.001))
90+
})
91+
92+
It("translates names in entropy flow and returns generic top category", func() {
93+
// Probabilities favor index 0 -> generic should be "tech"
94+
mockCategoryModel.classifyWithProbsResult = candle_binding.ClassResultWithProbs{
95+
Class: 0,
96+
Confidence: 0.88,
97+
Probabilities: []float32{0.7, 0.2, 0.1},
98+
NumClasses: 3,
99+
}
100+
101+
category, confidence, decision, err := classifier.ClassifyCategoryWithEntropy("Economic policies in computer science education")
102+
Expect(err).ToNot(HaveOccurred())
103+
Expect(category).To(Equal("tech"))
104+
Expect(confidence).To(BeNumerically("~", 0.88, 0.001))
105+
Expect(decision.TopCategories).ToNot(BeEmpty())
106+
Expect(decision.TopCategories[0].Category).To(Equal("tech"))
107+
})
108+
109+
It("falls back to identity when no mapping exists for an MMLU label", func() {
110+
// index 2 -> "politics" (no explicit mapping provided, but present in MMLU set)
111+
mockCategoryModel.classifyResult = candle_binding.ClassResult{Class: 2, Confidence: 0.91}
112+
113+
category, score, err := classifier.ClassifyCategory("This is a political debate")
114+
Expect(err).ToNot(HaveOccurred())
115+
Expect(category).To(Equal("politics"))
116+
Expect(score).To(BeNumerically("~", 0.91, 0.001))
117+
})
118+
})

0 commit comments

Comments
 (0)