diff --git a/optimum/onnxruntime/utils.py b/optimum/onnxruntime/utils.py index 4152febb2d7..9f5178c46c0 100644 --- a/optimum/onnxruntime/utils.py +++ b/optimum/onnxruntime/utils.py @@ -101,8 +101,8 @@ class ORTConfigManager: "albert": "bert", "bart": "bart", "bert": "bert", - "big_bird": "bert", - # "bigbird_pegasus": None, # bug in `fusion_skiplayernorm.py` + "big-bird": "bert", + # "bigbird-pegasus": None, # bug in `fusion_skiplayernorm.py` "blenderbot": "bert", "bloom": "gpt2", "camembert": "bert", @@ -112,9 +112,9 @@ class ORTConfigManager: "distilbert": "bert", "electra": "bert", "gpt2": "gpt2", - "gpt_bigcode": "gpt2", - "gpt_neo": "gpt2", - "gpt_neox": "gpt2", + "gpt-bigcode": "gpt2", + "gpt-neo": "gpt2", + "gpt-neox": "gpt2", "gptj": "gpt2", # longt5 with O4 results in segmentation fault "longt5": "bert", @@ -122,7 +122,7 @@ class ORTConfigManager: "marian": "bart", "mbart": "bart", "mt5": "bart", - "m2m_100": "bart", + "m2m-100": "bart", "nystromformer": "bert", "pegasus": "bert", "roberta": "bert", @@ -134,6 +134,7 @@ class ORTConfigManager: @classmethod def get_model_ort_type(cls, model_type: str) -> str: + model_type = model_type.replace("_", "-") cls.check_supported_model(model_type) return cls._conf[model_type] @@ -161,7 +162,7 @@ def check_optimization_supported_model(cls, model_type: str, optimization_config "vit", "swin", ] - + model_type = model_type.replace("_", "-") if (model_type not in cls._conf) or (cls._conf[model_type] not in supported_model_types_for_optimization): raise NotImplementedError( f"ONNX Runtime doesn't support the graph optimization of {model_type} yet. Only {list(cls._conf.keys())} are supported. "