Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
echarlaix committed Sep 19, 2023
1 parent 1b47093 commit 35caaf2
Showing 1 changed file with 24 additions and 25 deletions.
49 changes: 24 additions & 25 deletions tests/exporters/onnx/test_onnx_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,10 +569,8 @@ def test_custom_export_official_model(self):
assert "decoder_attentions.0" in output_names
assert "cross_attentions.0" in output_names

@parameterized.expand(
grid_parameters({"fn_get_submodels": (None, fn_get_submodels_custom), "legacy": (True, False)})
)
def test_custom_export_trust_remote(self, test_name, fn_get_submodels, legacy):
@parameterized.expand([(None,), (fn_get_submodels_custom,)])
def test_custom_export_trust_remote(self, fn_get_submodels):
model_id = "fxmarty/tiny-mpt-random-remote-code"
config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
onnx_config = CustomMPTOnnxConfig(
Expand All @@ -583,28 +581,29 @@ def test_custom_export_trust_remote(self, test_name, fn_get_submodels, legacy):
)
onnx_config_with_past = CustomMPTOnnxConfig(config, task="text-generation", use_past=True)

if legacy:
custom_onnx_configs = {
"decoder_model": onnx_config,
"decoder_with_past_model": onnx_config_with_past,
}
else:
custom_onnx_configs = {
"model": onnx_config_with_past,
}
for legacy in (True, False):
if legacy:
custom_onnx_configs = {
"decoder_model": onnx_config,
"decoder_with_past_model": onnx_config_with_past,
}
else:
custom_onnx_configs = {
"model": onnx_config_with_past,
}

with TemporaryDirectory() as tmpdirname:
main_export(
model_id,
output=tmpdirname,
task="text-generation-with-past",
trust_remote_code=True,
custom_onnx_configs=custom_onnx_configs,
no_post_process=True,
fn_get_submodels=partial(fn_get_submodels, legacy=legacy) if fn_get_submodels else None,
legacy=legacy,
opset=14,
)
with TemporaryDirectory() as tmpdirname:
main_export(
model_id,
output=tmpdirname,
task="text-generation-with-past",
trust_remote_code=True,
custom_onnx_configs=custom_onnx_configs,
no_post_process=True,
fn_get_submodels=partial(fn_get_submodels, legacy=legacy) if fn_get_submodels else None,
legacy=legacy,
opset=14,
)

def test_custom_export_trust_remote_error(self):
model_id = "mohitsha/tiny-ernie-random-remote-code"
Expand Down

0 comments on commit 35caaf2

Please sign in to comment.