From 0ffafce5968c9f56d705b5bc2edcbb15e35867f0 Mon Sep 17 00:00:00 2001 From: Masao-Someki Date: Sun, 18 Sep 2022 19:38:27 +0900 Subject: [PATCH] Add test script for RelPosAttention --- tests/test_optimize.py | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/tests/test_optimize.py b/tests/test_optimize.py index 75a8f66..db153b5 100644 --- a/tests/test_optimize.py +++ b/tests/test_optimize.py @@ -12,17 +12,18 @@ from .op_test_utils import check_op_type_count test_cases = [ - ['encoder', 'transformer', 4, 256, 3, 0, False], + ['encoder', 'transformer', 4, 256, 'Attention', 3, False], # ['encoder', 'contextual_block_transformer', 4, 256, 3, 0, False], - ['encoder', 'transformer', 4, 256, 3, 0, True], + ['encoder', 'transformer', 4, 256, 'Attention', 3, True], + ['encoder', 'conformer_rpe_latest', 4, 256, 'RelPosAttention', 3, True], # ['encoder', 'contextual_block_transformer', 4, 256, 3, 0, True], - ['decoder', 'transformer', 4, 256, 0, 6, True], - ['lm', 'transformer_pe', 4, 256, 0, 3, True], - ['lm', 'transformer', 4, 256, 0, 2, True], + ['decoder', 'transformer', 4, 256, 'CrossAttention', 6, True], + ['lm', 'transformer_pe', 4, 256, 'CrossAttention', 3, True], + ['lm', 'transformer', 4, 256, 'CrossAttention', 2, True], ] -@pytest.mark.parametrize('model_type, model_name, n_head, h_size, n_att, n_cross_att, use_custom_ort', test_cases) -def test_optimize(model_type, model_name, n_head, h_size, n_att, n_cross_att, use_custom_ort, model_export): +@pytest.mark.parametrize('model_type, model_name, n_head, h_size, node_name, node_num, use_custom_ort', test_cases) +def test_optimize(model_type, model_name, n_head, h_size, node_name, node_num, use_custom_ort, model_export): export_dir = model_export.cache_dir / 'test' / \ model_type / f'cache_{model_name}' output_dir = model_export.cache_dir / 'test' / \ @@ -45,11 +46,6 @@ def test_optimize(model_type, model_name, n_head, h_size, n_att, n_cross_att, us ) # load the optimized model and check if the number of fused nodes is correct. - nodes = {} - if n_att > 0: - nodes['Attention'] = n_att - if n_cross_att > 0: - nodes['CrossAttention'] = n_cross_att - + nodes = {node_name : node_num} check_op_type_count(str(output_dir / model_name), **nodes) \ No newline at end of file