Skip to content

Commit

Permalink
uniformize
Browse files Browse the repository at this point in the history
  • Loading branch information
echarlaix committed Oct 2, 2023
1 parent dc657a2 commit 4ad92aa
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions optimum/onnxruntime/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ class ORTConfigManager:
"albert": "bert",
"bart": "bart",
"bert": "bert",
"big_bird": "bert",
# "bigbird_pegasus": None, # bug in `fusion_skiplayernorm.py`
"big-bird": "bert",
# "bigbird-pegasus": None, # bug in `fusion_skiplayernorm.py`
"blenderbot": "bert",
"bloom": "gpt2",
"camembert": "bert",
Expand All @@ -112,17 +112,17 @@ class ORTConfigManager:
"distilbert": "bert",
"electra": "bert",
"gpt2": "gpt2",
"gpt_bigcode": "gpt2",
"gpt_neo": "gpt2",
"gpt_neox": "gpt2",
"gpt-bigcode": "gpt2",
"gpt-neo": "gpt2",
"gpt-neox": "gpt2",
"gptj": "gpt2",
# longt5 with O4 results in segmentation fault
"longt5": "bert",
"llama": "gpt2",
"marian": "bart",
"mbart": "bart",
"mt5": "bart",
"m2m_100": "bart",
"m2m-100": "bart",
"nystromformer": "bert",
"pegasus": "bert",
"roberta": "bert",
Expand All @@ -134,6 +134,7 @@ class ORTConfigManager:

@classmethod
def get_model_ort_type(cls, model_type: str) -> str:
model_type = model_type.replace("_", "-")
cls.check_supported_model(model_type)
return cls._conf[model_type]

Expand Down Expand Up @@ -161,7 +162,7 @@ def check_optimization_supported_model(cls, model_type: str, optimization_config
"vit",
"swin",
]

model_type = model_type.replace("_", "-")
if (model_type not in cls._conf) or (cls._conf[model_type] not in supported_model_types_for_optimization):
raise NotImplementedError(
f"ONNX Runtime doesn't support the graph optimization of {model_type} yet. Only {list(cls._conf.keys())} are supported. "
Expand Down

0 comments on commit 4ad92aa

Please sign in to comment.