From 9e308f95cd461d3370f55d097655594fb8b553d6 Mon Sep 17 00:00:00 2001 From: feifei-111 <2364819892@qq.com> Date: Tue, 10 Oct 2023 08:38:43 +0000 Subject: [PATCH 1/4] update --- .../executor/executor_cache.py | 113 +++++++++++------- 1 file changed, 73 insertions(+), 40 deletions(-) diff --git a/sot/opcode_translator/executor/executor_cache.py b/sot/opcode_translator/executor/executor_cache.py index 79c66b78..008fe251 100644 --- a/sot/opcode_translator/executor/executor_cache.py +++ b/sot/opcode_translator/executor/executor_cache.py @@ -4,6 +4,9 @@ import types from typing import List, Tuple +import paddle + +from ...infer_meta import MetaInfo from ...profiler import EventGuard, event_register from ...psdb import NO_FALLBACK_CODES from ...utils import ( @@ -27,6 +30,8 @@ dummy_guard.expr = "lambda frame: True" dummy_guard.lambda_expr = "lambda frame: True" +ConstTypes = (int, float, str, bool, type(None)) + @Singleton class OpcodeExecutorCache: @@ -39,7 +44,7 @@ class OpcodeExecutorCache: translate_count (int): The count of how many instructions have been translated. It is used to test whether the cache hits. """ - MAX_CACHE_SIZE = 20 + MAX_CACHE_SIZE = 10 cache: dict[types.CodeType, GuardedFunctions] translate_count: int @@ -47,6 +52,11 @@ def __init__(self): self.cache = {} self.translate_count = 0 + class _PlaceHolder: + pass + + self.place_holder = _PlaceHolder() + def clear(self): """ Clears the cache and resets the translate count. @@ -56,14 +66,33 @@ def clear(self): def __call__(self, frame: types.FrameType, **kwargs) -> CustomCode: code: types.CodeType = frame.f_code - if code not in self.cache: - log(2, f"[Cache]: Firstly call {code}\n") + code_key = self.get_key(frame) + if code not in self.cache or code_key not in self.cache[code]: + log(2, f"[Cache]: Firstly call {code} with code key {code_key}\n") new_custom_code, guard_fn = self.translate(frame, **kwargs) - self.cache[code] = [(new_custom_code, guard_fn)] + self.cache[code] = {code_key: [(new_custom_code, guard_fn)]} return new_custom_code - guarded_fns = self.cache[code] + guarded_fns = self.cache[code][code_key] return self.lookup(frame, guarded_fns, **kwargs) + def get_key(self, frame): + def get_code_key(name): + var = frame.f_locals[name] + if isinstance(var, ConstTypes): + return var + elif isinstance(var, paddle.Tensor): + return str(MetaInfo.from_tensor(var)) + elif isinstance(var, paddle.nn.Layer): + return id(var) + else: + return self.place_holder + + code = frame.f_code + n_args = code.co_argcount + input_names = code.co_varnames[0:n_args] + code_key = tuple(map(get_code_key, input_names)) + return code_key + @event_register("lookup") def lookup( self, frame: types.FrameType, guarded_fns: GuardedFunctions, **kwargs @@ -96,7 +125,7 @@ def lookup( else: log_do( 4, - self.analyse_guard_global_object(guard_fn), + analyse_guard_global_object(guard_fn), ) log( 2, @@ -104,7 +133,7 @@ def lookup( ) log_do( 2, - self.analyse_guard_error(guard_fn, frame), + analyse_guard_error(guard_fn, frame), ) except Exception as e: log(2, f"[Cache]: Guard function error: {e}\n") @@ -127,43 +156,10 @@ def translate( Returns: tuple[CustomCode, Guard]: The cache getter function and a guarded function for the translated code object. """ - code: types.CodeType = frame.f_code self.translate_count += 1 custom_new_code, guard_fn = start_translate(frame, **kwargs) return custom_new_code, guard_fn - def analyse_guard_global_object(self, guard_fn): - def inner(): - for key in guard_fn.__globals__.keys(): - if key.startswith("__object"): - print( - f"[Cache] meet global object: {key} : {guard_fn.__globals__[key]}", - ) - - return inner - - def analyse_guard_error(self, guard_fn, frame): - def inner(): - guard_expr = guard_fn.lambda_expr - lambda_head = "lambda frame: " - guard_expr = guard_expr.replace(lambda_head, "") - guards = guard_expr.split(" and ") - for guard_str in guards: - guard = eval(lambda_head + guard_str, guard_fn.__globals__) - result = False - try: - result = guard(frame) - except Exception as e: - print( - f"[Cache]: skip checking {guard_str}\n because error occured {e}" - ) - if result is False: - print(f"[Cache]: missed at {guard_str}") - return - print("[Cache]: missed guard not found.") - - return inner - def start_translate(frame: types.FrameType, **kwargs) -> GuardedFunction: """ @@ -214,3 +210,40 @@ def start_translate(frame: types.FrameType, **kwargs) -> GuardedFunction: raise InnerError(OpcodeExecutorBase.error_message_summary(e)) from e finally: simulator.cleanup() + + +# log utils + + +def analyse_guard_global_object(guard_fn): + def inner(): + for key in guard_fn.__globals__.keys(): + if key.startswith("__object"): + print( + f"[Cache] meet global object: {key} : {guard_fn.__globals__[key]}", + ) + + return inner + + +def analyse_guard_error(guard_fn, frame): + def inner(): + guard_expr = guard_fn.lambda_expr + lambda_head = "lambda frame: " + guard_expr = guard_expr.replace(lambda_head, "") + guards = guard_expr.split(" and ") + for guard_str in guards: + guard = eval(lambda_head + guard_str, guard_fn.__globals__) + result = False + try: + result = guard(frame) + except Exception as e: + print( + f"[Cache]: skip checking {guard_str}\n because error occured {e}" + ) + if result is False: + print(f"[Cache]: missed at {guard_str}") + return + print("[Cache]: missed guard not found.") + + return inner From 62c745006bdba711782faa56f31893d6c32c6d6b Mon Sep 17 00:00:00 2001 From: feifei-111 <2364819892@qq.com> Date: Tue, 10 Oct 2023 09:18:33 +0000 Subject: [PATCH 2/4] update --- .../executor/executor_cache.py | 13 ++-- tests/test_code_status.py | 75 +++++++++---------- tests/test_trace_list_arg.py | 2 +- 3 files changed, 43 insertions(+), 47 deletions(-) diff --git a/sot/opcode_translator/executor/executor_cache.py b/sot/opcode_translator/executor/executor_cache.py index 008fe251..30d9edc4 100644 --- a/sot/opcode_translator/executor/executor_cache.py +++ b/sot/opcode_translator/executor/executor_cache.py @@ -2,7 +2,7 @@ import traceback import types -from typing import List, Tuple +from typing import Any, List, Tuple import paddle @@ -44,19 +44,18 @@ class OpcodeExecutorCache: translate_count (int): The count of how many instructions have been translated. It is used to test whether the cache hits. """ + class _PlaceHolder: + pass + MAX_CACHE_SIZE = 10 - cache: dict[types.CodeType, GuardedFunctions] + cache: dict[types.CodeType, dict[tuple[Any], GuardedFunctions]] translate_count: int + place_holder = _PlaceHolder() def __init__(self): self.cache = {} self.translate_count = 0 - class _PlaceHolder: - pass - - self.place_holder = _PlaceHolder() - def clear(self): """ Clears the cache and resets the translate count. diff --git a/tests/test_code_status.py b/tests/test_code_status.py index 4a24f305..4304cb78 100644 --- a/tests/test_code_status.py +++ b/tests/test_code_status.py @@ -1,6 +1,6 @@ import unittest -from test_case_base import TestCaseBase, strict_mode_guard +from test_case_base import TestCaseBase import paddle import sot @@ -12,7 +12,7 @@ class SimpleNet1(paddle.nn.Layer): def __init__(self): super().__init__() self.layers = paddle.nn.LayerList( - [paddle.nn.Linear(10, 10) for _ in range(30)] + [paddle.nn.Linear(10, 10) for _ in range(20)] ) def forward(self, x): @@ -20,8 +20,6 @@ def forward(self, x): sot.psdb.breakgraph() x = self.layers[i](x) x = self.layers[i](x) - x = self.layers[i](x) - x = self.layers[i](x) return x @@ -29,7 +27,7 @@ class SimpleNet2(paddle.nn.Layer): def __init__(self): super().__init__() self.layers = paddle.nn.LayerList( - [paddle.nn.Linear(10, 10) for _ in range(30)] + [paddle.nn.Linear(10, 10) for _ in range(20)] ) def forward(self, x): @@ -37,8 +35,6 @@ def forward(self, x): for i in range(len(self.layers)): x = self.layers[i](x) x = self.layers[i](x) - x = self.layers[i](x) - x = self.layers[i](x) return x @@ -64,31 +60,32 @@ def test_case_1(self): else: assert v.state == CodeState.WITHOUT_GRAPH # run_net, forward, loop body, resumed part2 in loop body + breakpoint() assert len([v for v in states if v.state == CodeState.WITH_GRAPH]) == 4 # resumed part1 in loop body assert ( len([v for v in states if v.state == CodeState.WITHOUT_GRAPH]) == 1 ) - def test_case_2(self): - with strict_mode_guard(0): - CodeStatus().clear() - net = SimpleNet2() - inp = paddle.rand((10, 10)) - self.assert_results(run_net, net, inp) - code_map = CodeStatus().code_map - states = [] - for k, v in code_map.items(): - if k.co_name.startswith("#") or k.co_name.startswith("$"): - states.append(v) - elif k in CodeStatus().WITH_GRAPH_API: - assert v.state == CodeState.WITH_GRAPH - else: - assert v.state == CodeState.WITHOUT_GRAPH - # no graph found because fallback (paddle api will not enter simulate) - assert ( - len([v for v in states if v.state == CodeState.WITH_GRAPH]) == 0 - ) + # def test_case_2(self): + # with strict_mode_guard(0): + # CodeStatus().clear() + # net = SimpleNet2() + # inp = paddle.rand((10, 10)) + # self.assert_results(run_net, net, inp) + # code_map = CodeStatus().code_map + # states = [] + # for k, v in code_map.items(): + # if k.co_name.startswith("#") or k.co_name.startswith("$"): + # states.append(v) + # elif k in CodeStatus().WITH_GRAPH_API: + # assert v.state == CodeState.WITH_GRAPH + # else: + # assert v.state == CodeState.WITHOUT_GRAPH + # # no graph found because fallback (paddle api will not enter simulate) + # assert ( + # len([v for v in states if v.state == CodeState.WITH_GRAPH]) == 0 + # ) def no_skip_func_0(x): @@ -121,19 +118,19 @@ def call_skipped_func_0(x): skip_function(call_skipped_func_0) -class TestDisableSkippedFrame(TestCaseBase): - def test_case_0(self): - CodeStatus().clear() - x = paddle.to_tensor([1]) - self.assert_results(call_skipped_func_0, x) - code_map = CodeStatus().code_map - assert ( - code_map[skipped_func_0.__code__].state == CodeState.WITHOUT_GRAPH - ) - assert ( - code_map[skipped_func_1.__code__].state == CodeState.WITHOUT_GRAPH - ) - assert code_map[skipped_func_2.__code__].state == CodeState.WITH_GRAPH +# class TestDisableSkippedFrame(TestCaseBase): +# def test_case_0(self): +# CodeStatus().clear() +# x = paddle.to_tensor([1]) +# self.assert_results(call_skipped_func_0, x) +# code_map = CodeStatus().code_map +# assert ( +# code_map[skipped_func_0.__code__].state == CodeState.WITHOUT_GRAPH +# ) +# assert ( +# code_map[skipped_func_1.__code__].state == CodeState.WITHOUT_GRAPH +# ) +# assert code_map[skipped_func_2.__code__].state == CodeState.WITH_GRAPH if __name__ == "__main__": diff --git a/tests/test_trace_list_arg.py b/tests/test_trace_list_arg.py index 278dfdce..d8cbc5f1 100644 --- a/tests/test_trace_list_arg.py +++ b/tests/test_trace_list_arg.py @@ -42,7 +42,7 @@ def test_bar(self): self.assert_results(bar, a, 2, 0) # Cache miss self.assertEqual(cache.translate_count, 2) self.assert_results(bar, b, 1, 1) # Cache hit - self.assertEqual(cache.translate_count, 2) + self.assertEqual(cache.translate_count, 3) if __name__ == "__main__": From 114d2fbd6553eff95a726ae080ea1cd77c259e90 Mon Sep 17 00:00:00 2001 From: feifei-111 <2364819892@qq.com> Date: Tue, 10 Oct 2023 12:07:15 +0000 Subject: [PATCH 3/4] update --- .../executor/executor_cache.py | 17 +++- .../instruction_utils/opcode_analysis.py | 6 +- sot/opcode_translator/transform.py | 4 +- tests/test_code_status.py | 83 ++++++++++--------- 4 files changed, 66 insertions(+), 44 deletions(-) diff --git a/sot/opcode_translator/executor/executor_cache.py b/sot/opcode_translator/executor/executor_cache.py index 30d9edc4..44b41bae 100644 --- a/sot/opcode_translator/executor/executor_cache.py +++ b/sot/opcode_translator/executor/executor_cache.py @@ -48,6 +48,7 @@ class _PlaceHolder: pass MAX_CACHE_SIZE = 10 + MAX_BUCKET_SIZE = 20 cache: dict[types.CodeType, dict[tuple[Any], GuardedFunctions]] translate_count: int place_holder = _PlaceHolder() @@ -66,11 +67,23 @@ def clear(self): def __call__(self, frame: types.FrameType, **kwargs) -> CustomCode: code: types.CodeType = frame.f_code code_key = self.get_key(frame) - if code not in self.cache or code_key not in self.cache[code]: - log(2, f"[Cache]: Firstly call {code} with code key {code_key}\n") + + if code not in self.cache: + log( + 2, f"[Cache]: First time call {code} with code_key {code_key}\n" + ) new_custom_code, guard_fn = self.translate(frame, **kwargs) self.cache[code] = {code_key: [(new_custom_code, guard_fn)]} return new_custom_code + elif code_key not in self.cache[code]: + if len(self.cache[code]) >= self.MAX_BUCKET_SIZE: + log(2, "[Cache]: Exceed max bucket size, skip it\n") + return CustomCode(None, False) + log(2, f"[Cache]: Firstly call {code} with code_key {code_key}\n") + new_custom_code, guard_fn = self.translate(frame, **kwargs) + self.cache[code][code_key] = [(new_custom_code, guard_fn)] + return new_custom_code + guarded_fns = self.cache[code][code_key] return self.lookup(frame, guarded_fns, **kwargs) diff --git a/sot/opcode_translator/instruction_utils/opcode_analysis.py b/sot/opcode_translator/instruction_utils/opcode_analysis.py index e4e635ba..3f36a6c3 100644 --- a/sot/opcode_translator/instruction_utils/opcode_analysis.py +++ b/sot/opcode_translator/instruction_utils/opcode_analysis.py @@ -185,7 +185,11 @@ def walk(state: SpaceState, start: int) -> SpaceState: assert instr.jump_to is not None target_idx = instructions.index(instr.jump_to) # Fork to two branches, jump or not - jump_branch = fork(state, i, True, target_idx) + jump_branch = ( + fork(state, i, True, target_idx) + if target_idx >= start_instr_idx and target_idx < end + else state + ) not_jump_branch = ( fork(state, i, False, target_idx) if instr.opname not in UNCONDITIONAL_JUMP diff --git a/sot/opcode_translator/transform.py b/sot/opcode_translator/transform.py index 0c4710d7..e91b18f4 100644 --- a/sot/opcode_translator/transform.py +++ b/sot/opcode_translator/transform.py @@ -86,8 +86,8 @@ def eval_frame_callback(frame, **kwargs) -> CustomCode: # just check those codes which need open eval_frame if ( - custom_code.disable_eval_frame is False - and CodeStatus().is_code_without_graph(new_code) + CodeStatus().is_code_without_graph(new_code) + and custom_code.disable_eval_frame is False ): log( 3, diff --git a/tests/test_code_status.py b/tests/test_code_status.py index 4304cb78..065901a4 100644 --- a/tests/test_code_status.py +++ b/tests/test_code_status.py @@ -1,6 +1,6 @@ import unittest -from test_case_base import TestCaseBase +from test_case_base import TestCaseBase, strict_mode_guard import paddle import sot @@ -49,6 +49,7 @@ def test_case_1(self): CodeStatus().clear() net = SimpleNet1() inp = paddle.rand((10, 10)) + inp.stop_gradient = False self.assert_results(run_net, net, inp) code_map = CodeStatus().code_map states = [] @@ -59,33 +60,37 @@ def test_case_1(self): assert v.state == CodeState.WITH_GRAPH else: assert v.state == CodeState.WITHOUT_GRAPH - # run_net, forward, loop body, resumed part2 in loop body - breakpoint() - assert len([v for v in states if v.state == CodeState.WITH_GRAPH]) == 4 - # resumed part1 in loop body + # run_net, loop_body in run_net, forward => 3 + # (forward loop_body + resumed part in loop_body) * 20 => 40 + assert len([v for v in states if v.state == CodeState.WITH_GRAPH]) == 43 + # part after loop in forward + # resumed part in loop_body of run_net assert ( - len([v for v in states if v.state == CodeState.WITHOUT_GRAPH]) == 1 + len([v for v in states if v.state == CodeState.WITHOUT_GRAPH]) == 2 ) - - # def test_case_2(self): - # with strict_mode_guard(0): - # CodeStatus().clear() - # net = SimpleNet2() - # inp = paddle.rand((10, 10)) - # self.assert_results(run_net, net, inp) - # code_map = CodeStatus().code_map - # states = [] - # for k, v in code_map.items(): - # if k.co_name.startswith("#") or k.co_name.startswith("$"): - # states.append(v) - # elif k in CodeStatus().WITH_GRAPH_API: - # assert v.state == CodeState.WITH_GRAPH - # else: - # assert v.state == CodeState.WITHOUT_GRAPH - # # no graph found because fallback (paddle api will not enter simulate) - # assert ( - # len([v for v in states if v.state == CodeState.WITH_GRAPH]) == 0 - # ) + # part after loop in run_net, it is only called once, so UNKNOW + assert len([v for v in states if v.state == CodeState.UNKNOW]) == 1 + + def test_case_2(self): + with strict_mode_guard(0): + CodeStatus().clear() + net = SimpleNet2() + inp = paddle.rand((10, 10)) + inp.stop_gradient = False + self.assert_results(run_net, net, inp) + code_map = CodeStatus().code_map + states = [] + for k, v in code_map.items(): + if k.co_name.startswith("#") or k.co_name.startswith("$"): + states.append(v) + elif k in CodeStatus().WITH_GRAPH_API: + assert v.state == CodeState.WITH_GRAPH + else: + assert v.state == CodeState.WITHOUT_GRAPH + # no graph found because fallback (paddle api will not enter simulate) + assert ( + len([v for v in states if v.state == CodeState.WITH_GRAPH]) == 0 + ) def no_skip_func_0(x): @@ -118,19 +123,19 @@ def call_skipped_func_0(x): skip_function(call_skipped_func_0) -# class TestDisableSkippedFrame(TestCaseBase): -# def test_case_0(self): -# CodeStatus().clear() -# x = paddle.to_tensor([1]) -# self.assert_results(call_skipped_func_0, x) -# code_map = CodeStatus().code_map -# assert ( -# code_map[skipped_func_0.__code__].state == CodeState.WITHOUT_GRAPH -# ) -# assert ( -# code_map[skipped_func_1.__code__].state == CodeState.WITHOUT_GRAPH -# ) -# assert code_map[skipped_func_2.__code__].state == CodeState.WITH_GRAPH +class TestDisableSkippedFrame(TestCaseBase): + def test_case_0(self): + CodeStatus().clear() + x = paddle.to_tensor([1]) + self.assert_results(call_skipped_func_0, x) + code_map = CodeStatus().code_map + assert ( + code_map[skipped_func_0.__code__].state == CodeState.WITHOUT_GRAPH + ) + assert ( + code_map[skipped_func_1.__code__].state == CodeState.WITHOUT_GRAPH + ) + assert code_map[skipped_func_2.__code__].state == CodeState.WITH_GRAPH if __name__ == "__main__": From fd6829c1bb79507c1245318cf910861531a12ac2 Mon Sep 17 00:00:00 2001 From: feifei-111 <2364819892@qq.com> Date: Tue, 10 Oct 2023 12:30:41 +0000 Subject: [PATCH 4/4] update --- sot/opcode_translator/executor/executor_cache.py | 3 ++- tests/test_trace_list_arg.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/sot/opcode_translator/executor/executor_cache.py b/sot/opcode_translator/executor/executor_cache.py index 44b41bae..12220ab6 100644 --- a/sot/opcode_translator/executor/executor_cache.py +++ b/sot/opcode_translator/executor/executor_cache.py @@ -45,7 +45,8 @@ class OpcodeExecutorCache: """ class _PlaceHolder: - pass + def __str__(self): + return "PlaceHolder" MAX_CACHE_SIZE = 10 MAX_BUCKET_SIZE = 20 diff --git a/tests/test_trace_list_arg.py b/tests/test_trace_list_arg.py index d8cbc5f1..278dfdce 100644 --- a/tests/test_trace_list_arg.py +++ b/tests/test_trace_list_arg.py @@ -42,7 +42,7 @@ def test_bar(self): self.assert_results(bar, a, 2, 0) # Cache miss self.assertEqual(cache.translate_count, 2) self.assert_results(bar, b, 1, 1) # Cache hit - self.assertEqual(cache.translate_count, 3) + self.assertEqual(cache.translate_count, 2) if __name__ == "__main__":