Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
Merge remote-tracking branch 'upstream/develop' into fix-enumerate
Browse files Browse the repository at this point in the history
  • Loading branch information
2742195759 committed Jul 6, 2023
2 parents 55d1e68 + 3fcab2c commit 2cf46c2
Show file tree
Hide file tree
Showing 15 changed files with 160 additions and 45 deletions.
1 change: 0 additions & 1 deletion .github/workflows/codestyle.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,5 @@ jobs:
- name: Install dependencies
run: |
pip install pre-commit
- name: Precommit Check
run : pre-commit run --all-files
3 changes: 0 additions & 3 deletions .github/workflows/unittest.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,14 @@ jobs:
run: |
python -m pip install paddlepaddle==0.0.0 -f https://www.paddlepaddle.org.cn/whl/linux/cpu-mkl/develop.html
pip install -e ".[dev]"
- name: Run unit tests
working-directory: ./tests/
run: |
bash run_all.sh
- name: Run examples
working-directory: ./examples/
run: |
bash run_all.sh
- name: Run xdoctest
working-directory: ./
run: |
Expand Down
2 changes: 2 additions & 0 deletions sot/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from .opcode_translator.breakpoint import BM, add_breakpoint, add_event
from .opcode_translator.skip_files import skip_function
from .translate import symbolic_translate
from .utils import psdb_print

__all__ = [
"symbolic_translate",
"add_breakpoint",
"add_event",
"BM",
"skip_function",
"psdb_print",
]
24 changes: 14 additions & 10 deletions sot/opcode_translator/breakpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@
class Breakpoint:
file: str
line: int
co_name: str
offset: int

def __hash__(self):
return hash((self.file, self.line))
return hash((self.file, self.line, self.co_name, self.offset))


@Singleton
Expand All @@ -37,9 +39,9 @@ def add_event(self, event):
"""
self.record_event.append(event)

def add(self, file, line):
def add(self, file, line, coname=None, offset=None):
log(1, f"add breakpoint at {file}:{line}\n")
self.breakpoints.add(Breakpoint(file, line))
self.breakpoints.add(Breakpoint(file, line, coname, offset))

def addn(self, *lines):
"""
Expand All @@ -54,9 +56,10 @@ def addn(self, *lines):
def clear(self):
self.breakpoints.clear()

def hit(self, file, line):
_breakpoint = Breakpoint(file, line)
if _breakpoint in self.breakpoints:
def hit(self, file, line, co_name, offset):
if Breakpoint(file, line, None, None) in self.breakpoints:
return True
if Breakpoint(file, line, co_name, offset) in self.breakpoints:
return True
return False

Expand Down Expand Up @@ -124,11 +127,12 @@ def dis(self, range=5):
"""
print("displaying debug info...")
cur_exe = self.cur_exe
print(f"{cur_exe._code}")
print(self._dis_source_code())

print(f"\n{cur_exe._code}")
lasti = cur_exe._lasti
lines = instrs_info(cur_exe._instructions, lasti - 1, range)
print("\n".join(lines))
print(self._dis_source_code())

@property
def cur_exe(self):
Expand All @@ -150,8 +154,8 @@ def pe(self, e):
print("".join(lines))


def add_breakpoint(file, line):
BM.add(file, line)
def add_breakpoint(file, line, co_name=None, offset=None):
BM.add(file, line, co_name, offset)


def add_event(event):
Expand Down
32 changes: 32 additions & 0 deletions sot/opcode_translator/executor/dispatch_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# This file stores the customed function that will be called by the dispatch mechanism.

from ...utils import BreakGraphError, NotImplementException


def raise_break_graph_fn(*args, **kwarg):
raise BreakGraphError("raise by raise_break_graph_fn.")


def raise_not_implement_fn(*args, **kwarg):
raise NotImplementException("raise by raise_break_graph_fn.")


# just a function for operator.in
def operator_in(left, right):
return left in right


def operator_not_in(left, right):
return left not in right


def operator_exception_match(left, right):
pass


def operator_BAD(left, right):
pass


def tensor_numel(x):
pass
19 changes: 19 additions & 0 deletions sot/opcode_translator/executor/function_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ class FunctionGraph:
"stmt_ir",
"global_guards",
"side_effects_state",
"print_variables",
],
)

Expand All @@ -79,8 +80,12 @@ def __init__(self, frame, **kwargs):
self.side_effects = SideEffects()
self.py_frame = frame
self._global_guarded_variables: list[VariableBase] = []
self._print_variables = []
self.build_strategy = kwargs.get('build_strategy', None)

def add_print_variables(self, variable):
self._print_variables.append(variable)

def need_add_input(self, var):
if var.id in self.inner_out:
return False
Expand All @@ -102,6 +107,7 @@ def save_memo(self):
stmt_ir=saved_stmt_ir,
global_guards=list(self._global_guarded_variables),
side_effects_state=self.side_effects.get_state(),
print_variables=list(self._print_variables),
)

def restore_memo(self, memo):
Expand All @@ -110,6 +116,7 @@ def restore_memo(self, memo):
self.sir_ctx.replace_TOS(memo.stmt_ir)
self._global_guarded_variables = memo.global_guards
self.side_effects.restore_state(memo.side_effects_state)
self._print_variables = memo.print_variables

def collect_input_variables(self, inputs: list[VariableBase]):
for inp in inputs:
Expand Down Expand Up @@ -176,6 +183,7 @@ def start_compile(self, *ret_vars: VariableBase):

# deal side effect
self.restore_side_effects(self.side_effects.variables)
self.restore_print_stmts(self._print_variables)

tracker_output_path = show_trackers()
if tracker_output_path:
Expand Down Expand Up @@ -305,8 +313,19 @@ def _find_tensor_outputs(
var, TensorVariable
):
output_tensors.append(var)
# Find Tensor in print_stmts
for print_stmt in self._print_variables:
for var in print_stmt.flatten_items():
if isinstance(var.tracker, DummyTracker) and isinstance(
var, TensorVariable
):
output_tensors.append(var)
return output_tensors

def restore_print_stmts(self, variables: list[VariableBase]):
for var in variables:
var._reconstruct(self.pycode_gen)

def restore_side_effects(self, variables: list[VariableBase]):
if not variables:
return
Expand Down
18 changes: 11 additions & 7 deletions sot/opcode_translator/executor/opcode_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@
log_do,
)
from ..instruction_utils import Instruction, analysis_inputs, get_instructions
from .dispatch_functions import (
operator_BAD,
operator_exception_match,
operator_in,
operator_not_in,
)
from .function_graph import FunctionGraph
from .guard import Guard
from .instr_flag import FORMAT_VALUE_FLAG as FV
Expand All @@ -35,12 +41,6 @@
GlobalTracker,
LocalTracker,
)
from .variable_dispatch import (
operator_BAD,
operator_exception_match,
operator_in,
operator_not_in,
)
from .variables import (
BuiltinVariable,
CellVariable,
Expand Down Expand Up @@ -576,9 +576,13 @@ def step(self, instr: Instruction):
log(3, log_message)
code_file = self._code.co_filename
code_line = self._current_line
code_name = self._code.co_name
code_offset = instr.offset
from ..breakpoint import BreakpointManager

if BreakpointManager().hit(code_file, code_line):
if BreakpointManager().hit(
code_file, code_line, code_name, code_offset
):
BreakpointManager().locate(self)
print(log_message)
breakpoint() # breakpoint for debug
Expand Down
44 changes: 25 additions & 19 deletions sot/opcode_translator/executor/variable_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@
UNARY_OPS,
magic_method_builtin_dispatch,
)
from .dispatch_functions import (
operator_in,
operator_not_in,
raise_break_graph_fn,
tensor_numel,
)
from .dispatcher import Dispatcher
from .tracker import DummyTracker
from .variables import (
Expand All @@ -24,24 +30,6 @@
if TYPE_CHECKING:
from .variables import DataVariable, NumpyVariable, TensorVariable


# just a function for operator.in
def operator_in(left, right):
return left in right


def operator_not_in(left, right):
return left not in right


def operator_exception_match(left, right):
pass


def operator_BAD(left, right):
pass


# dict
Dispatcher.register(
operator_in,
Expand Down Expand Up @@ -567,11 +555,29 @@ def is_not_func(var: VariableBase, other: VariableBase):
),
)
# Tensor
fallback_tensor_unary_method = {
int,
bool,
operator.truth,
}

Dispatcher.register(tensor_numel, ("TensorVariable",), {}, lambda x: x.numel())

for unary_fn in UNARY_OPS:
# Tensor doesn't support unary +, skip it
# TODO(SigureMo): deal len and bool
if unary_fn in {operator.pos, len, bool, operator.truth}:
if unary_fn in {len}:
continue

if unary_fn in fallback_tensor_unary_method:
Dispatcher.register(
unary_fn,
("TensorVariable",),
{},
raise_break_graph_fn,
)
continue

for magic_method in magic_method_builtin_dispatch(unary_fn):
Dispatcher.register(
unary_fn,
Expand Down
34 changes: 33 additions & 1 deletion sot/opcode_translator/executor/variables/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
paddle_tensor_methods,
)
from ....utils.exceptions import InnerError
from ..dispatch_functions import tensor_numel
from ..guard import (
StringifyExpression,
object_equal_stringify_guard,
Expand Down Expand Up @@ -149,6 +150,29 @@ def wrap_literal(value: Any, graph: FunctionGraph) -> ConstantVariable:
return ConstantVariable(value, graph, ConstTracker(value))


class PrintStmtVariable(VariableBase):
def __init__(self, value: Any, graph: FunctionGraph):
super().__init__(DanglingTracker())
self.args, self.kwargs = value
self.graph = graph

def _reconstruct(self, codegen: PyCodeGen):
# do we need ? may be too strict.
for var in self.args:
self.graph.add_global_guarded_variable(var)
for var in self.kwargs.values():
self.graph.add_global_guarded_variable(var)
# currently dont' consider kwargs
codegen.gen_load_global("print")
for var in self.args:
var.reconstruct(codegen)
codegen.gen_call_function(len(self.args))
codegen.gen_pop_top()

def flatten_items(self):
return self.args


IMPLEMENTED_TENSOR_PROPERTIES = set()


Expand Down Expand Up @@ -206,7 +230,10 @@ def __init__(
super().__init__(tracker)
if isinstance(tensor, paddle.Tensor):
self.value = tensor
self.meta = MetaInfo.from_tensor(tensor)
try:
self.meta = MetaInfo.from_tensor(tensor)
except:
breakpoint()
elif isinstance(tensor, MetaInfo):
self.value = None
self.meta = tensor
Expand Down Expand Up @@ -329,6 +356,7 @@ def size(self):
f"Getting size for a dynamic shape tensor causes graph break. shape = {self.meta.shape}"
)
elements = reduce(operator.mul, self.meta.shape, 1)
self.graph.add_global_guarded_variable(self)
return ConstantVariable.wrap_literal(elements, self.graph)

@tensor_property
Expand All @@ -342,6 +370,9 @@ def shape(self):
self.meta.shape, self.graph, tracker=ConstTracker(self.meta.shape)
)

def numel(self):
return self.size

def is_tensor(self):
return ConstantVariable.wrap_literal(True, self.graph)

Expand All @@ -363,6 +394,7 @@ def is_floating_point(self):
def getattr(self, name: str):
method_name_to_builtin_fn = {
"dim": paddle.rank,
"numel": tensor_numel,
"ndimension": paddle.rank,
"is_tensor": paddle.is_tensor,
"is_complex": paddle.is_complex,
Expand Down
Loading

0 comments on commit 2cf46c2

Please sign in to comment.