Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
feifei-111 committed Jul 4, 2023
1 parent 71bc3ba commit 2bce56a
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
3 changes: 1 addition & 2 deletions sot/infer_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,15 +176,14 @@ def infer_meta_for_layer(layer, *args, **kwargs):
layer, paddle.nn.Layer
), f"Expect a Layer, but got {layer}."
layer = paddle.jit.to_static(layer, enable_fallback=False)

args_, kwargs_ = convert_meta_to_input_spec(
args
), convert_meta_to_input_spec(kwargs)

(
concrete_program,
partial_program_layer,
) = layer.forward.get_concrete_program(*args_, **kwargs_)

out = partial_program_layer._restore_out(
convert_variable_to_meta_info(concrete_program.outputs)
)
Expand Down
8 changes: 6 additions & 2 deletions tests/test_instruction_translator_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,13 @@ def fake_inner_fn_5():


def mock_start_translate(frame: types.FrameType, **kwargs):
true_lambda = lambda frame: True
true_lambda.expr = "lambda frame: True"
false_lambda = lambda frame: False
false_lambda.expr = "lambda frame: False"
translate_map = {
FRAME_1: (FRAME_2.f_code, lambda frame: True),
FRAME_3: (FRAME_4.f_code, lambda frame: False), # Always re-compile
FRAME_1: (FRAME_2.f_code, true_lambda),
FRAME_3: (FRAME_4.f_code, false_lambda), # Always re-compile
FRAME_5: None,
}
return translate_map[frame]
Expand Down

0 comments on commit 2bce56a

Please sign in to comment.