Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

[Feature] add breakpoint machanism on SOT #240

Merged
merged 16 commits into from
Jul 5, 2023
6 changes: 6 additions & 0 deletions sot/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
from .opcode_translator.breakpoint import BM, add_breakpoint, add_event
from .opcode_translator.skip_files import skip_function
from .translate import symbolic_translate

__all__ = [
"symbolic_translate",
"add_breakpoint",
"add_event",
"BM",
"skip_function",
]
161 changes: 161 additions & 0 deletions sot/opcode_translator/breakpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
import inspect
import traceback
from dataclasses import dataclass

from ..opcode_translator.instruction_utils import instrs_info
from ..utils import Singleton, log
from .executor.opcode_executor import OpcodeExecutorBase

# this file is a debug utils files for quick debug
# >>> sot.add_breakpoint(file, line)
# >>> sot.remove_breakpoint(file, line)


@dataclass
class Breakpoint:
file: str
line: int

def __hash__(self):
return hash((self.file, self.line))


@Singleton
class BreakpointManager:
def __init__(self):
self.breakpoints = set()
self.executors = OpcodeExecutorBase.call_stack
self.active = 0
self.record_event = []

def clear_event(self, event):
self.record_event.clear()

def add_event(self, event):
"""
event in ['All' ,'NotImplementException', 'BreakGraphError', 'InnerError']
"""
self.record_event.append(event)

def add(self, file, line):
log(1, f"add breakpoint at {file}:{line}\n")
self.breakpoints.add(Breakpoint(file, line))

def addn(self, *lines):
"""
called inside a executor. add a list of line number in current file.
"""
if not isinstance(lines, (list, tuple)):
lines = [lines]
for line in lines:
file = self.cur_exe._code.co_filename
self.add(file, line)

def clear(self):
self.breakpoints.clear()

def hit(self, file, line):
_breakpoint = Breakpoint(file, line)
if _breakpoint in self.breakpoints:
return True
return False

def locate(self, exe):
for i, _e in enumerate(self.executors):
if _e is exe:
self.activate = i
return
raise RuntimeError("Not found executor.")

def up(self):
if self.activate == 0:
return
self.activate -= 1
print("current function is: ", self.cur_exe._code.co_name)

def down(self):
if self.activate >= len(self.executors) - 1:
return
self.activate += 1
print("current function is: ", self.cur_exe._code.co_name)

def opcode(self, cur_exe=None):
if cur_exe is None:
cur_exe = self.cur_exe
instr = cur_exe._instructions[cur_exe._lasti - 1]
message = f"[Translate {cur_exe}]: (line {cur_exe._current_line:>3}) {instr.opname:<12} {instr.argval}, stack is {cur_exe._stack}\n"
return message

def bt(self):
"""
display all inline calls: backtrace.
"""
for exe in self.executors:
lines, _ = inspect.getsourcelines(exe._code)
print(
" "
+ exe._code.co_filename
+ f"({exe._current_line})"
+ f"{exe._code.co_name}()"
)
print(f"-> {lines[0].strip()}")
print(f"-> {self.opcode(exe)}")
pass

def on_event(self, event):
if "All" in self.record_event or event in self.record_event:
print("event captured.")
self.activate = len(self.executors) - 1
breakpoint()

def _dis_source_code(self):
cur_exe = self.executors[self.activate]
lines, start_line = inspect.getsourcelines(cur_exe._code)
cur_line = cur_exe._current_line
lines[
cur_line - start_line + 1 : cur_line - start_line + 1
] = " ^^^^^ HERE \n"
print("\033[31mSource Code is: \033[0m")
print("".join(lines))

def dis(self, range=5):
"""
display all instruction code and source code.
"""
print("displaying debug info...")
cur_exe = self.cur_exe
print(f"{cur_exe._code}")
lasti = cur_exe._lasti
lines = instrs_info(cur_exe._instructions, lasti - 1, range)
print("\n".join(lines))
print(self._dis_source_code())

@property
def cur_exe(self):
exe = self.executors[self.activate]
return exe

def sir(self):
"""
display sir in a page.
"""
print("displaying sir...")
self.cur_exe.print_sir()

def pe(self, e):
"""
print exception.
"""
lines = traceback.format_tb(e.__traceback__)
print("".join(lines))


def add_breakpoint(file, line):
BM.add(file, line)


def add_event(event):
BM.add_event(event)


BM = BreakpointManager()
35 changes: 13 additions & 22 deletions sot/opcode_translator/executor/opcode_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,7 @@
log,
log_do,
)
from ..instruction_utils import (
Instruction,
analysis_inputs,
get_instructions,
instrs_info,
)
from ..instruction_utils import Instruction, analysis_inputs, get_instructions
from .function_graph import FunctionGraph
from .guard import Guard
from .instr_flag import FORMAT_VALUE_FLAG as FV
Expand Down Expand Up @@ -428,14 +423,6 @@ def __init__(self, code: types.CodeType, graph: FunctionGraph):
self._name = "Executor"
self._prepare_virtual_env()

def print_instrs(self):
"""
Prints the instructions in the executor.

"""
print(self._code.co_name)
print(instrs_info(self._instructions, mark=self._lasti))

def print_sir(self):
"""
Prints the Static Instruction Representation (SIR) in the executor.
Expand Down Expand Up @@ -497,6 +484,8 @@ def get_var(self, name: str):
return self._globals[name]
elif name in self._builtins.keys():
return self._builtins[name]
elif name in self._cells.keys(): # in closure
return self._cells[name].get_value()
else:
raise InnerError(f'Can not get var: {name}')

Expand Down Expand Up @@ -580,10 +569,16 @@ def step(self, instr: Instruction):
raise NotImplementException(
f"opcode: {instr.opname} is not supported."
)
log(
3,
f"[Translate {self._name}]: (line {self._current_line:>3}) {instr.opname:<12} {instr.argval}, stack is {self._stack}\n",
)
log_message = f"[Translate {self._name}]: (line {self._current_line:>3}) {instr.opname:<12} {instr.argval}, stack is {self._stack}\n"
log(3, log_message)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里无条件打印和下面的打印是不是重复了?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

log3下会重复的,但是这个BM不一定在log3下使用。所以就默认都打印了一行。

code_file = self._code.co_filename
code_line = self._current_line
from ..breakpoint import BreakpointManager

if BreakpointManager().hit(code_file, code_line):
BreakpointManager().locate(self)
print(log_message)
breakpoint() # breakpoint for debug
return getattr(self, instr.opname)(instr) # run single step.

def indexof(self, instr: Instruction):
Expand Down Expand Up @@ -810,9 +805,6 @@ def STORE_FAST(self, instr: Instruction):
"""
var = self.pop()
var.debug_name = instr.argval
if instr.argval == "__breakpoint__":
print(var.value)
breakpoint()
self._locals[instr.argval] = var

def STORE_GLOBAL(self, instr: Instruction):
Expand Down Expand Up @@ -1760,7 +1752,6 @@ def FOR_ITER(self, instr):
)

self._graph.add_global_guarded_variable(iterator)

# TODO need support TensorIterVariable.next
try:
if not isinstance(
Expand Down
15 changes: 12 additions & 3 deletions sot/opcode_translator/instruction_utils/instruction_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,9 +257,16 @@ def replace_instr(instructions, instr, new_instr):
instructions[idx, idx + 1] = new_instr


def instrs_info(instrs, mark=None):
def instrs_info(instrs, mark=None, range=None):
ret = []
start = -1
end = 1000000
if mark is not None and range is not None:
start = mark - range
end = mark + range + 1
for idx, instr in enumerate(instrs):
if idx < start or idx >= end:
continue
if instr.starts_line is not None:
ret.append("")
ret.append(
Expand All @@ -272,7 +279,9 @@ def instrs_info(instrs, mark=None):
opname=instr.opname,
arg=str(instr.arg) if instr.arg is not None else "",
argval=f"({instr.argval})" if instr.argval else "",
mark=" <--- HERE" if mark == idx else "",
mark="",
)
)
return "\n".join(ret)
if idx == mark:
ret[-1] = "\033[31m" + ret[-1] + "\033[0m"
return ret
15 changes: 15 additions & 0 deletions sot/opcode_translator/skip_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
import google.protobuf
import numpy

from ..utils import log


def _strip_init_py(s):
return re.sub(r"__init__.py$", "", s)
Expand Down Expand Up @@ -105,6 +107,8 @@ def _module_dir(m: types.ModuleType):
f"^({'|'.join(map(re.escape, skip_file_names))})"
)

customed_skip_code = set()


def need_skip_path(filepath: str) -> bool:
"""
Expand All @@ -119,3 +123,14 @@ def need_skip_path(filepath: str) -> bool:
if not filepath.startswith("<"):
filepath = os.path.abspath(filepath)
return bool(skip_file_name_re.match(filepath))


def skip_function(function):
customed_skip_code.add(function.__code__)


def need_skip(pycode):
if pycode in customed_skip_code:
log(3, f"Skip frame by code: {pycode}")
return True
return need_skip_path(pycode.co_filename)
4 changes: 2 additions & 2 deletions sot/opcode_translator/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@

from ..utils import log, log_do
from .executor.opcode_executor import InstructionTranslatorCache
from .skip_files import need_skip_path
from .skip_files import need_skip


def eval_frame_callback(frame, **kwargs):
# is generator
if frame.f_code.co_flags & 0x20 > 0:
return None

if not need_skip_path(frame.f_code.co_filename):
if not need_skip(frame.f_code):
log(
2,
"[eval_frame_callback] start to translate: "
Expand Down
12 changes: 7 additions & 5 deletions sot/symbolic/compile_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@ def clear_eager_tensor_name(output_tensors):


class FallbackWrapper:
def __init__(self, compile_sir):
self.compile_sir = compile_sir
def __init__(self, compiled_fn, SIR):
self.compiled_fn = compiled_fn
self.partial_program = None
self.concrete_program = None
self.SIR = SIR # for debug

def __call__(self, *args, **kwargs):
"""TODO: we disable partial_program cache here because some bugs in ast to_static.
Expand All @@ -29,11 +30,11 @@ def __call__(self, *args, **kwargs):
# TODO(xiongkun): or True is on purpose, we should remove it later after
# dy2static bug is fixed.
if self.partial_program is None or True:
outputs = self.compile_sir(*args, **kwargs)
outputs = self.compiled_fn(*args, **kwargs)
(
self.concrete_program,
self.partial_program,
) = self.compile_sir.get_concrete_program(*args, **kwargs)
) = self.compiled_fn.get_concrete_program(*args, **kwargs)
else:
# Speed up Resnet from 0.0068 --> 0.0057
outputs = self.partial_program(*args, **kwargs)
Expand Down Expand Up @@ -64,5 +65,6 @@ def value_fn(self, context, sir_name, build_strategy):
compile_sir(context, sir_name),
build_strategy=build_strategy,
enable_fallback=False,
)
),
context.get_sir(sir_name),
)
13 changes: 12 additions & 1 deletion sot/utils/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,16 @@
import traceback


class FallbackErrorBase(Exception):
pass
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
from ..opcode_translator.breakpoint import BreakpointManager

BreakpointManager().on_event(f"{self.__class__.__name__}")

def print(self):
lines = traceback.format_tb(self.__traceback__)
print("".join(lines))


class InnerError(FallbackErrorBase):
Expand Down
Loading