@@ -209,6 +209,12 @@ type Classifier struct {
209
209
CategoryMapping * CategoryMapping
210
210
PIIMapping * PIIMapping
211
211
JailbreakMapping * JailbreakMapping
212
+
213
+ // Category name mapping layer to support generic categories in config
214
+ // Maps MMLU-Pro category names -> generic category names (as defined in config.Categories)
215
+ MMLUToGeneric map [string ]string
216
+ // Maps generic category names -> MMLU-Pro category names
217
+ GenericToMMLU map [string ][]string
212
218
}
213
219
214
220
type option func (* Classifier )
@@ -272,6 +278,9 @@ func newClassifierWithOptions(cfg *config.RouterConfig, options ...option) (*Cla
272
278
option (classifier )
273
279
}
274
280
281
+ // Build category name mappings to support generic categories in config
282
+ classifier .buildCategoryNameMappings ()
283
+
275
284
return initModels (classifier )
276
285
}
277
286
@@ -331,18 +340,21 @@ func (c *Classifier) ClassifyCategory(text string) (string, float64, error) {
331
340
return "" , float64 (result .Confidence ), nil
332
341
}
333
342
334
- // Convert class index to category name
343
+ // Convert class index to category name (MMLU-Pro)
335
344
categoryName , ok := c .CategoryMapping .GetCategoryFromIndex (result .Class )
336
345
if ! ok {
337
346
observability .Warnf ("Class index %d not found in category mapping" , result .Class )
338
347
return "" , float64 (result .Confidence ), nil
339
348
}
340
349
341
- // Record the category classification metric
342
- metrics . RecordCategoryClassification (categoryName )
350
+ // Translate to generic category if mapping is configured
351
+ genericCategory := c . translateMMLUToGeneric (categoryName )
343
352
344
- observability .Infof ("Classified as category: %s" , categoryName )
345
- return categoryName , float64 (result .Confidence ), nil
353
+ // Record the category classification metric using generic name when available
354
+ metrics .RecordCategoryClassification (genericCategory )
355
+
356
+ observability .Infof ("Classified as category: %s (mmlu=%s)" , genericCategory , categoryName )
357
+ return genericCategory , float64 (result .Confidence ), nil
346
358
}
347
359
348
360
// IsJailbreakEnabled checks if jailbreak detection is enabled and properly configured
@@ -485,11 +497,11 @@ func (c *Classifier) ClassifyCategoryWithEntropy(text string) (string, float64,
485
497
observability .Infof ("Classification result: class=%d, confidence=%.4f, entropy_available=%t" ,
486
498
result .Class , result .Confidence , len (result .Probabilities ) > 0 )
487
499
488
- // Get category names for all classes
500
+ // Get category names for all classes and translate to generic names when configured
489
501
categoryNames := make ([]string , len (result .Probabilities ))
490
502
for i := range result .Probabilities {
491
503
if name , ok := c .CategoryMapping .GetCategoryFromIndex (i ); ok {
492
- categoryNames [i ] = name
504
+ categoryNames [i ] = c . translateMMLUToGeneric ( name )
493
505
} else {
494
506
categoryNames [i ] = fmt .Sprintf ("unknown_%d" , i )
495
507
}
@@ -580,20 +592,21 @@ func (c *Classifier) ClassifyCategoryWithEntropy(text string) (string, float64,
580
592
return "" , float64 (result .Confidence ), reasoningDecision , nil
581
593
}
582
594
583
- // Convert class index to category name
595
+ // Convert class index to category name and translate to generic
584
596
categoryName , ok := c .CategoryMapping .GetCategoryFromIndex (result .Class )
585
597
if ! ok {
586
598
observability .Warnf ("Class index %d not found in category mapping" , result .Class )
587
599
return "" , float64 (result .Confidence ), reasoningDecision , nil
588
600
}
601
+ genericCategory := c .translateMMLUToGeneric (categoryName )
589
602
590
603
// Record the category classification metric
591
- metrics .RecordCategoryClassification (categoryName )
604
+ metrics .RecordCategoryClassification (genericCategory )
592
605
593
- observability .Infof ("Classified as category: %s, reasoning_decision: use=%t, confidence=%.3f, reason=%s" ,
594
- categoryName , reasoningDecision .UseReasoning , reasoningDecision .Confidence , reasoningDecision .DecisionReason )
606
+ observability .Infof ("Classified as category: %s (mmlu=%s) , reasoning_decision: use=%t, confidence=%.3f, reason=%s" ,
607
+ genericCategory , categoryName , reasoningDecision .UseReasoning , reasoningDecision .Confidence , reasoningDecision .DecisionReason )
595
608
596
- return categoryName , float64 (result .Confidence ), reasoningDecision , nil
609
+ return genericCategory , float64 (result .Confidence ), reasoningDecision , nil
597
610
}
598
611
599
612
// ClassifyPII performs PII token classification on the given text and returns detected PII types
@@ -772,6 +785,51 @@ func (c *Classifier) findCategory(categoryName string) *config.Category {
772
785
return nil
773
786
}
774
787
788
+ // buildCategoryNameMappings builds translation maps between MMLU-Pro and generic categories
789
+ func (c * Classifier ) buildCategoryNameMappings () {
790
+ c .MMLUToGeneric = make (map [string ]string )
791
+ c .GenericToMMLU = make (map [string ][]string )
792
+
793
+ // Build set of known MMLU-Pro categories from the model mapping (if available)
794
+ knownMMLU := make (map [string ]bool )
795
+ if c .CategoryMapping != nil {
796
+ for _ , label := range c .CategoryMapping .IdxToCategory {
797
+ knownMMLU [strings .ToLower (label )] = true
798
+ }
799
+ }
800
+
801
+ for _ , cat := range c .Config .Categories {
802
+ if len (cat .MMLUCategories ) > 0 {
803
+ for _ , mmlu := range cat .MMLUCategories {
804
+ key := strings .ToLower (mmlu )
805
+ c .MMLUToGeneric [key ] = cat .Name
806
+ c .GenericToMMLU [cat .Name ] = append (c .GenericToMMLU [cat .Name ], mmlu )
807
+ }
808
+ } else {
809
+ // Fallback: identity mapping when the generic name matches an MMLU category
810
+ nameLower := strings .ToLower (cat .Name )
811
+ if knownMMLU [nameLower ] {
812
+ c .MMLUToGeneric [nameLower ] = cat .Name
813
+ c .GenericToMMLU [cat .Name ] = append (c .GenericToMMLU [cat .Name ], cat .Name )
814
+ }
815
+ }
816
+ }
817
+ }
818
+
819
+ // translateMMLUToGeneric translates an MMLU-Pro category to a generic category if mapping exists
820
+ func (c * Classifier ) translateMMLUToGeneric (mmluCategory string ) string {
821
+ if mmluCategory == "" {
822
+ return ""
823
+ }
824
+ if c .MMLUToGeneric == nil {
825
+ return mmluCategory
826
+ }
827
+ if generic , ok := c .MMLUToGeneric [strings .ToLower (mmluCategory )]; ok {
828
+ return generic
829
+ }
830
+ return mmluCategory
831
+ }
832
+
775
833
// selectBestModelInternal performs the core model selection logic
776
834
//
777
835
// modelFilter is optional - if provided, only models passing the filter will be considered
0 commit comments