diff --git a/sot/infer_meta.py b/sot/infer_meta.py index 599c2f6df..21523d0f7 100644 --- a/sot/infer_meta.py +++ b/sot/infer_meta.py @@ -176,15 +176,14 @@ def infer_meta_for_layer(layer, *args, **kwargs): layer, paddle.nn.Layer ), f"Expect a Layer, but got {layer}." layer = paddle.jit.to_static(layer, enable_fallback=False) + args_, kwargs_ = convert_meta_to_input_spec( args ), convert_meta_to_input_spec(kwargs) - ( concrete_program, partial_program_layer, ) = layer.forward.get_concrete_program(*args_, **kwargs_) - out = partial_program_layer._restore_out( convert_variable_to_meta_info(concrete_program.outputs) ) diff --git a/tests/test_instruction_translator_cache.py b/tests/test_instruction_translator_cache.py index ccc53f2f8..019a6b0e3 100644 --- a/tests/test_instruction_translator_cache.py +++ b/tests/test_instruction_translator_cache.py @@ -65,9 +65,13 @@ def fake_inner_fn_5(): def mock_start_translate(frame: types.FrameType, **kwargs): + true_lambda = lambda frame: True + true_lambda.expr = "lambda frame: True" + false_lambda = lambda frame: False + false_lambda.expr = "lambda frame: False" translate_map = { - FRAME_1: (FRAME_2.f_code, lambda frame: True), - FRAME_3: (FRAME_4.f_code, lambda frame: False), # Always re-compile + FRAME_1: (FRAME_2.f_code, true_lambda), + FRAME_3: (FRAME_4.f_code, false_lambda), # Always re-compile FRAME_5: None, } return translate_map[frame]