diff --git a/sot/__init__.py b/sot/__init__.py
index 383cb3017..49c814b5b 100644
--- a/sot/__init__.py
+++ b/sot/__init__.py
@@ -1,5 +1,11 @@
+from .opcode_translator.breakpoint import BM, add_breakpoint, add_event
+from .opcode_translator.skip_files import skip_function
 from .translate import symbolic_translate
 
 __all__ = [
     "symbolic_translate",
+    "add_breakpoint",
+    "add_event",
+    "BM",
+    "skip_function",
 ]
diff --git a/sot/opcode_translator/breakpoint.py b/sot/opcode_translator/breakpoint.py
new file mode 100644
index 000000000..ebc25d002
--- /dev/null
+++ b/sot/opcode_translator/breakpoint.py
@@ -0,0 +1,222 @@
+import inspect
+import traceback
+from dataclasses import dataclass
+
+from ..opcode_translator.instruction_utils import instrs_info
+from ..utils import Singleton, log
+from .executor.opcode_executor import OpcodeExecutorBase
+
+# this file is a debug utils files for quick debug
+# >>> sot.add_breakpoint(file, line)
+# >>> sot.add_event(event)
+# >>> sot.BM.clear()
+
+
+@dataclass
+class Breakpoint:
+    """
+    a source location: file name plus line number.
+    """
+
+    file: str
+    line: int
+
+    def __hash__(self):
+        # hash on both fields so that breakpoints can be kept in a set.
+        return hash((self.file, self.line))
+
+
+@Singleton
+class BreakpointManager:
+    """
+    keeps the registered breakpoints and events, and offers helpers meant to
+    be called from the debugger prompt (up / down / bt / dis / sir / pe).
+    """
+
+    def __init__(self):
+        self.breakpoints = set()
+        self.executors = OpcodeExecutorBase.call_stack
+        # index into self.executors of the executor being inspected. It must
+        # be named `activate`, the attribute every other method reads; when
+        # it was initialised as `active`, cur_exe raised AttributeError until
+        # locate() or on_event() had run.
+        self.activate = 0
+        self.record_event = []
+
+    def clear_event(self, event=None):
+        """
+        forget all recorded events. `event` is ignored; it is optional now
+        but still accepted so existing callers keep working.
+        """
+        self.record_event.clear()
+
+    def add_event(self, event):
+        """
+        event in ['All' ,'NotImplementException', 'BreakGraphError', 'InnerError']
+        """
+        self.record_event.append(event)
+
+    def add(self, file, line):
+        """
+        register a breakpoint at file:line.
+        """
+        log(1, f"add breakpoint at {file}:{line}\n")
+        self.breakpoints.add(Breakpoint(file, line))
+
+    def addn(self, *lines):
+        """
+        called inside a executor. add a list of line number in current file.
+        both addn(3, 7) and addn([3, 7]) are accepted.
+        """
+        # *lines is always a tuple, so a list passed as the only argument has
+        # to be unwrapped here; otherwise the list itself would be used as a
+        # line number and fail to hash.
+        if len(lines) == 1 and isinstance(lines[0], (list, tuple)):
+            lines = lines[0]
+        for line in lines:
+            file = self.cur_exe._code.co_filename
+            self.add(file, line)
+
+    def clear(self):
+        """
+        remove all breakpoints.
+        """
+        self.breakpoints.clear()
+
+    def hit(self, file, line):
+        """
+        return True if a breakpoint is registered at file:line.
+        """
+        return Breakpoint(file, line) in self.breakpoints
+
+    def locate(self, exe):
+        """
+        make `exe` the active executor. raise RuntimeError if it is not in
+        the executor stack.
+        """
+        for i, _e in enumerate(self.executors):
+            if _e is exe:
+                self.activate = i
+                return
+        raise RuntimeError("Not found executor.")
+
+    def up(self):
+        """
+        activate the previous executor in the stack, if there is one.
+        """
+        if self.activate <= 0:
+            return
+        self.activate -= 1
+        print("current function is: ", self.cur_exe._code.co_name)
+
+    def down(self):
+        """
+        activate the next executor in the stack, if there is one.
+        """
+        if self.activate >= len(self.executors) - 1:
+            return
+        self.activate += 1
+        print("current function is: ", self.cur_exe._code.co_name)
+
+    def opcode(self, cur_exe=None):
+        """
+        describe the instruction at index `_lasti - 1` of `cur_exe` (the
+        active executor by default).
+        """
+        if cur_exe is None:
+            cur_exe = self.cur_exe
+        instr = cur_exe._instructions[cur_exe._lasti - 1]
+        message = f"[Translate {cur_exe}]: (line {cur_exe._current_line:>3}) {instr.opname:<12} {instr.argval}, stack is {cur_exe._stack}\n"
+        return message
+
+    def bt(self):
+        """
+        display all inline calls: backtrace.
+        """
+        for exe in self.executors:
+            lines, _ = inspect.getsourcelines(exe._code)
+            print(
+                " "
+                + exe._code.co_filename
+                + f"({exe._current_line})"
+                + f"{exe._code.co_name}()"
+            )
+            print(f"-> {lines[0].strip()}")
+            print(f"-> {self.opcode(exe)}")
+
+    def on_event(self, event):
+        """
+        enter the debugger if `event`, or 'All', has been recorded.
+        """
+        if "All" in self.record_event or event in self.record_event:
+            print("event captured.")
+            # keep the index non-negative even when the stack is empty.
+            self.activate = max(len(self.executors) - 1, 0)
+            breakpoint()
+
+    def _dis_source_code(self):
+        """
+        print the source of the active executor with a marker after the
+        current line.
+        """
+        cur_exe = self.executors[self.activate]
+        lines, start_line = inspect.getsourcelines(cur_exe._code)
+        cur_line = cur_exe._current_line
+        # insert the marker as one list item; slice-assigning a bare string
+        # would splice it in one character at a time.
+        lines.insert(cur_line - start_line + 1, " ^^^^^ HERE \n")
+        print("\033[31mSource Code is: \033[0m")
+        print("".join(lines))
+
+    def dis(self, range=5):
+        """
+        display all instruction code and source code.
+        """
+        print("displaying debug info...")
+        cur_exe = self.cur_exe
+        print(f"{cur_exe._code}")
+        lasti = cur_exe._lasti
+        lines = instrs_info(cur_exe._instructions, lasti - 1, range)
+        print("\n".join(lines))
+        # _dis_source_code prints by itself and returns None, so it must not
+        # be wrapped in print().
+        self._dis_source_code()
+
+    @property
+    def cur_exe(self):
+        """
+        the executor selected by `activate`.
+        """
+        exe = self.executors[self.activate]
+        return exe
+
+    def sir(self):
+        """
+        display sir in a page.
+        """
+        print("displaying sir...")
+        self.cur_exe.print_sir()
+
+    def pe(self, e):
+        """
+        print exception.
+        """
+        lines = traceback.format_tb(e.__traceback__)
+        print("".join(lines))
+
+
+def add_breakpoint(file, line):
+    """
+    register a breakpoint at file:line on the global manager.
+    """
+    BM.add(file, line)
+
+
+def add_event(event):
+    """
+    record `event` on the global manager (see BreakpointManager.add_event).
+    """
+    BM.add_event(event)
+
+
+BM = BreakpointManager()
diff --git a/sot/opcode_translator/executor/opcode_executor.py b/sot/opcode_translator/executor/opcode_executor.py
index bc82f051c..f9d1d70ab 100644
--- a/sot/opcode_translator/executor/opcode_executor.py
+++ b/sot/opcode_translator/executor/opcode_executor.py
@@ -18,12 +18,7 @@
     log,
     log_do,
 )
-from ..instruction_utils import (
-    Instruction,
-    analysis_inputs,
-    get_instructions,
-    instrs_info,
-)
+from ..instruction_utils import Instruction, analysis_inputs, get_instructions
 from .function_graph import FunctionGraph
 from .guard import Guard
 from .instr_flag import FORMAT_VALUE_FLAG as FV
@@ -428,14 +423,6 @@ def __init__(self, code: types.CodeType, graph: FunctionGraph):
         self._name = "Executor"
         self._prepare_virtual_env()
 
-    def print_instrs(self):
-        """
-        Prints the instructions in the executor.
-
-        """
-        print(self._code.co_name)
-        print(instrs_info(self._instructions, mark=self._lasti))
-
     def print_sir(self):
         """
         Prints the Static Instruction Representation (SIR) in the executor.
@@ -497,6 +484,8 @@ def get_var(self, name: str): return self._globals[name] elif name in self._builtins.keys(): return self._builtins[name] + elif name in self._cells.keys(): # in closure + return self._cells[name].get_value() else: raise InnerError(f'Can not get var: {name}') @@ -580,10 +569,16 @@ def step(self, instr: Instruction): raise NotImplementException( f"opcode: {instr.opname} is not supported." ) - log( - 3, - f"[Translate {self._name}]: (line {self._current_line:>3}) {instr.opname:<12} {instr.argval}, stack is {self._stack}\n", - ) + log_message = f"[Translate {self._name}]: (line {self._current_line:>3}) {instr.opname:<12} {instr.argval}, stack is {self._stack}\n" + log(3, log_message) + code_file = self._code.co_filename + code_line = self._current_line + from ..breakpoint import BreakpointManager + + if BreakpointManager().hit(code_file, code_line): + BreakpointManager().locate(self) + print(log_message) + breakpoint() # breakpoint for debug return getattr(self, instr.opname)(instr) # run single step. 
def indexof(self, instr: Instruction): @@ -810,9 +805,6 @@ def STORE_FAST(self, instr: Instruction): """ var = self.pop() var.debug_name = instr.argval - if instr.argval == "__breakpoint__": - print(var.value) - breakpoint() self._locals[instr.argval] = var def STORE_GLOBAL(self, instr: Instruction): @@ -1760,7 +1752,6 @@ def FOR_ITER(self, instr): ) self._graph.add_global_guarded_variable(iterator) - # TODO need support TensorIterVariable.next try: if not isinstance( diff --git a/sot/opcode_translator/instruction_utils/instruction_utils.py b/sot/opcode_translator/instruction_utils/instruction_utils.py index e1aa77b86..bd8074765 100644 --- a/sot/opcode_translator/instruction_utils/instruction_utils.py +++ b/sot/opcode_translator/instruction_utils/instruction_utils.py @@ -257,9 +257,16 @@ def replace_instr(instructions, instr, new_instr): instructions[idx, idx + 1] = new_instr -def instrs_info(instrs, mark=None): +def instrs_info(instrs, mark=None, range=None): ret = [] + start = -1 + end = 1000000 + if mark is not None and range is not None: + start = mark - range + end = mark + range + 1 for idx, instr in enumerate(instrs): + if idx < start or idx >= end: + continue if instr.starts_line is not None: ret.append("") ret.append( @@ -272,7 +279,9 @@ def instrs_info(instrs, mark=None): opname=instr.opname, arg=str(instr.arg) if instr.arg is not None else "", argval=f"({instr.argval})" if instr.argval else "", - mark=" <--- HERE" if mark == idx else "", + mark="", ) ) - return "\n".join(ret) + if idx == mark: + ret[-1] = "\033[31m" + ret[-1] + "\033[0m" + return ret diff --git a/sot/opcode_translator/skip_files.py b/sot/opcode_translator/skip_files.py index 6f2d7cce8..ecbac3514 100644 --- a/sot/opcode_translator/skip_files.py +++ b/sot/opcode_translator/skip_files.py @@ -38,6 +38,8 @@ import google.protobuf import numpy +from ..utils import log + def _strip_init_py(s): return re.sub(r"__init__.py$", "", s) @@ -105,6 +107,8 @@ def _module_dir(m: types.ModuleType): 
     f"^({'|'.join(map(re.escape, skip_file_names))})"
 )
 
+customed_skip_code = set()
+
 
 def need_skip_path(filepath: str) -> bool:
     """
@@ -119,3 +123,21 @@ def need_skip_path(filepath: str) -> bool:
     if not filepath.startswith("<"):
         filepath = os.path.abspath(filepath)
     return bool(skip_file_name_re.match(filepath))
+
+
+def skip_function(function):
+    """
+    Add ``function.__code__`` to ``customed_skip_code``.
+    """
+    customed_skip_code.add(function.__code__)
+
+
+def need_skip(pycode):
+    """
+    Return True when ``pycode`` is in ``customed_skip_code``; otherwise
+    return ``need_skip_path(pycode.co_filename)``.
+    """
+    if pycode in customed_skip_code:
+        log(3, f"Skip frame by code: {pycode}")
+        return True
+    return need_skip_path(pycode.co_filename)
diff --git a/sot/opcode_translator/transform.py b/sot/opcode_translator/transform.py
index 7f3776453..65a7aff07 100644
--- a/sot/opcode_translator/transform.py
+++ b/sot/opcode_translator/transform.py
@@ -2,7 +2,7 @@
 
 from ..utils import log, log_do
 from .executor.opcode_executor import InstructionTranslatorCache
-from .skip_files import need_skip_path
+from .skip_files import need_skip
 
 
 def eval_frame_callback(frame, **kwargs):
@@ -10,7 +10,7 @@ def eval_frame_callback(frame, **kwargs):
     if frame.f_code.co_flags & 0x20 > 0:
         return None
 
-    if not need_skip_path(frame.f_code.co_filename):
+    if not need_skip(frame.f_code):
         log(
             2,
             "[eval_frame_callback] start to translate: "
diff --git a/sot/symbolic/compile_cache.py b/sot/symbolic/compile_cache.py
index 51569c5b1..9b2897925 100644
--- a/sot/symbolic/compile_cache.py
+++ b/sot/symbolic/compile_cache.py
@@ -10,10 +10,11 @@ def clear_eager_tensor_name(output_tensors):
 
 
 class FallbackWrapper:
-    def __init__(self, compile_sir):
-        self.compile_sir = compile_sir
+    def __init__(self, compiled_fn, SIR):
+        self.compiled_fn = compiled_fn
         self.partial_program = None
         self.concrete_program = None
+        self.SIR = SIR  # for debug
 
     def __call__(self, *args, **kwargs):
         """TODO: we disable partial_program cache here because some bugs in ast to_static.
@@ -29,11 +30,11 @@ def __call__(self, *args, **kwargs):
         # TODO(xiongkun): or True is on purpose, we should remove it later after
         # dy2static bug is fixed.
         if self.partial_program is None or True:
-            outputs = self.compile_sir(*args, **kwargs)
+            outputs = self.compiled_fn(*args, **kwargs)
             (
                 self.concrete_program,
                 self.partial_program,
-            ) = self.compile_sir.get_concrete_program(*args, **kwargs)
+            ) = self.compiled_fn.get_concrete_program(*args, **kwargs)
         else:
             # Speed up Resnet from 0.0068 --> 0.0057
             outputs = self.partial_program(*args, **kwargs)
@@ -64,5 +65,6 @@ def value_fn(self, context, sir_name, build_strategy):
                 compile_sir(context, sir_name),
                 build_strategy=build_strategy,
                 enable_fallback=False,
-            )
+            ),
+            context.get_sir(sir_name),
         )
diff --git a/sot/utils/exceptions.py b/sot/utils/exceptions.py
index f5d55bd6a..20f660fe3 100644
--- a/sot/utils/exceptions.py
+++ b/sot/utils/exceptions.py
@@ -1,5 +1,25 @@
+import traceback
+
+
 class FallbackErrorBase(Exception):
-    pass
+    def __init__(self, *args, **kwargs):
+        """
+        Build the exception, then pass its class name to
+        ``BreakpointManager().on_event``.
+        """
+        super().__init__(*args, **kwargs)
+        # NOTE(review): imported here rather than at module level, presumably
+        # to avoid a circular import -- confirm.
+        from ..opcode_translator.breakpoint import BreakpointManager
+
+        BreakpointManager().on_event(f"{self.__class__.__name__}")
+
+    def print(self):
+        """
+        Print the formatted traceback entries of this exception.
+        """
+        lines = traceback.format_tb(self.__traceback__)
+        print("".join(lines))
 
 
 class InnerError(FallbackErrorBase):
diff --git a/tests/test_segment_linear.py b/tests/test_segment_linear.py
new file mode 100644
index 000000000..79034c434
--- /dev/null
+++ b/tests/test_segment_linear.py
@@ -0,0 +1,60 @@
+import unittest
+
+from test_case_base import TestCaseBase
+
+import paddle
+import sot
+from paddle import nn
+
+
+class Head(nn.Layer):
+    """
+    Linear layer whose output is reshaped to ``(1, h, w, c)`` and then
+    transposed with ``(0, 3, 1, 2)``.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.head = nn.Linear(10, 150)
+
+    def forward(self, x, patch_embed_size):
+        masks = self.head(x)
+        # [b, (h w), c] -> [b, c, h, w]
+        h, w = patch_embed_size[0], patch_embed_size[1]
+        masks = masks.reshape((1, h, w, paddle.shape(masks)[-1]))
+        masks = masks.transpose((0, 3, 1, 2))
+        return masks
+
+
+class SimpleNet(nn.Layer):  # hands a slice of a paddle.shape() result to Head
+    def __init__(self):
+        super().__init__()
+        self.tmp = nn.Linear(1, 1024 * 10)
+        self.tmp2 = nn.Linear(1, 1 * 10 * 32 * 32)
+        self.head = Head()
+
+    def getshape(self, x):  # returns paddle.shape() of a [1, 10, 32, 32] reshape
+        x = self.tmp2(x.mean().reshape([1])).reshape([1, 10, 32, 32])
+        x = paddle.shape(x)
+        return x
+
+    def forward(self, x):
+        shape = self.getshape(x)
+        feat = self.tmp(x.mean().reshape([1])).reshape([1, 1024, 10])
+        logits = self.head(feat, shape[2:])  # shape[2:] becomes patch_embed_size
+        return logits
+
+
+class TestExecutor(TestCaseBase):
+    def test_simple(self):  # runs one forward/backward pass; asserts no values
+        sot.skip_function(SimpleNet.forward)
+        x = paddle.randn((1, 8, 8))
+        net = SimpleNet()
+        net = paddle.jit.to_static(net)
+        loss = net(x)
+        loss = loss.sum()
+        loss.backward()
+
+
+if __name__ == "__main__":
+    unittest.main()