From 315b1906413dc6f9f9f318450b1f8cac8a89a4f3 Mon Sep 17 00:00:00 2001 From: WangZhen <23097963+0x45f@users.noreply.github.com> Date: Tue, 18 Apr 2023 20:36:39 +0800 Subject: [PATCH] Fix SIR cache when call func many times (#21) * Fix SIR cache when call func many times * Rename ut * Polish code * Remove bind * Rename var and fun --- symbolic_trace/trace_cache_entrance.py | 44 ++++++++++++++----- tests/error_test_sir_call.py | 8 ++-- ...est_trace_cache.py => test_trace_cache.py} | 0 3 files changed, 36 insertions(+), 16 deletions(-) rename tests/{error_test_trace_cache.py => test_trace_cache.py} (100%) diff --git a/symbolic_trace/trace_cache_entrance.py b/symbolic_trace/trace_cache_entrance.py index 359d3524c..c1648014b 100644 --- a/symbolic_trace/trace_cache_entrance.py +++ b/symbolic_trace/trace_cache_entrance.py @@ -9,16 +9,42 @@ def trace_cache(func): + @no_eval_frame def call_with_cache(*args, **kwargs): - args, kwargs = convert_arguments(args), convert_arguments(kwargs) + args, kwargs = convert_arguments(args), convert_arguments(kwargs) + args, kwargs, outter_names = construct_inner_proxy_tensor(func.__name__, *args, **kwargs) + if frame_enter(func.__name__, args): - return cache_and_return(func.__name__, args) + return cache_and_return(func.__name__, outter_names) ret = func(*args) - frame_leave(func.__name__, ret) + frame_leave(func.__name__, outter_names, ret) return ret return call_with_cache +def construct_inner_proxy_tensor(func_name, *args, **kwargs): + flat_args = paddle.utils.flatten(args) + flat_kwargs = paddle.utils.flatten(kwargs) + outter_names = [] + name_i = 0 + for i, v in enumerate(flat_args): + if isinstance(v, ProxyTensor): + name = '{}_input_{}'.format(func_name, name_i) + outter_names.append(v.name) + flat_args[i] = ProxyTensor(name, v.meta) + name_i = name_i + 1 + for i, v in enumerate(flat_kwargs): + if isinstance(v, ProxyTensor): + name = '{}_input_{}'.format(func_name, name_i) + outter_names.append(v.name) + flat_kwargs[i] = ProxyTensor(name, v.meta) + name_i = name_i + 1 + + args = paddle.utils.pack_sequence_as(args, flat_args) + kwargs = paddle.utils.pack_sequence_as(kwargs, flat_kwargs) + + return args, kwargs, outter_names + @no_eval_frame # should generate a unique name for every function def frame_enter(name, inputs): @@ -52,7 +78,7 @@ def frame_enter(name, inputs): @no_eval_frame -def frame_leave(name, outputs): +def frame_leave(name, outter_names, outputs): key_name = SymbolicTraceContext().sir_key_stack[-1] SymbolicTraceContext().sir_key_stack.pop() @@ -88,13 +114,13 @@ def frame_leave(name, outputs): return # at the first time, the inputs and outputs need not change - SymbolicTraceContext().call_SIR(cur_sir.name, cur_sir.inputs, cur_sir.outputs) + SymbolicTraceContext().call_SIR(cur_sir.name, [Symbol(name) for name in outter_names], cur_sir.outputs) log(1, cur_sir, "\n") return @no_eval_frame -def cache_and_return(name, inputs): +def cache_and_return(name, outter_names): key_name = SymbolicTraceContext().sir_key_stack[-1] SymbolicTraceContext().sir_key_stack.pop() @@ -102,10 +128,6 @@ def cache_and_return(name, inputs): cached_sir = SymbolicTraceContext().statement_factory[key_name] origin_outputs = SIRRuntimeCache().get_origin_outputs(key_name) - # gen call_SIR inputs - flat_inputs = paddle.utils.flatten(inputs) - symbol_inputs = [Symbol(x.name) for x in flat_inputs if isinstance(x, ProxyTensor)] - # create return value outputs = gen_new_proxy_tensor_output(origin_outputs) @@ -114,7 +136,7 @@ def cache_and_return(name, inputs): symbol_outputs = [Symbol(x.name) for x in flat_outputs if isinstance(x, ProxyTensor)] # add call_SIR - SymbolicTraceContext().call_SIR(cached_sir.name, symbol_inputs, symbol_outputs) + SymbolicTraceContext().call_SIR(cached_sir.name, [Symbol(name) for name in outter_names], symbol_outputs) return outputs diff --git a/tests/error_test_sir_call.py b/tests/error_test_sir_call.py index 42ffea9a6..6402fa6a5 100644 --- a/tests/error_test_sir_call.py +++ b/tests/error_test_sir_call.py @@ -1,14 +1,12 @@ import unittest import paddle from symbolic_trace import symbolic_trace -from symbolic_trace.trace_cache_entrance import frame_enter, frame_leave, cache_and_return +from symbolic_trace.trace_cache_entrance import trace_cache +@trace_cache def sum(x, y): - if frame_enter("sum", (x, y)): - return cache_and_return("sum", (x, y)) ret = x + y - frame_leave("sum", (ret)) return ret def main(x, y): @@ -21,7 +19,7 @@ def test_return_callable(self): x = paddle.to_tensor([1.0]) y = paddle.to_tensor([2.0]) ret = symbolic_trace(main)(x, y) - assert (ret.item() == 3.0), "Should be 4.0" + assert (ret.item() == 3.0), "Should be 3.0" if __name__ == "__main__": unittest.main() diff --git a/tests/error_test_trace_cache.py b/tests/test_trace_cache.py similarity index 100% rename from tests/error_test_trace_cache.py rename to tests/test_trace_cache.py