Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
Fix guard errs (#237)
Browse files Browse the repository at this point in the history
  • Loading branch information
feifei-111 authored Jul 4, 2023
1 parent 6f03436 commit 605fa46
Show file tree
Hide file tree
Showing 9 changed files with 59 additions and 24 deletions.
32 changes: 21 additions & 11 deletions sot/infer_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,9 @@ def infer_meta(self, func, *args, **kwargs):
with paddle.fluid.framework._dygraph_guard(None), UniqueNameGuard(
self.var_name_generator
):
args, kwargs = convert_to_variable(args), convert_to_variable(
kwargs
)
args, kwargs = convert_meta_to_variable(
args
), convert_meta_to_variable(kwargs)

with paddle.static.program_guard(
self.main_program, self.startup_program
Expand All @@ -125,10 +125,10 @@ def infer_meta(self, func, *args, **kwargs):
else:
out = func(*args, **kwargs)

return variable_to_meta_info(out)
return convert_variable_to_meta_info(out)


def convert_to_variable(args):
def convert_meta_to_variable(args):
return map_if(
args,
pred=lambda x: isinstance(x, MetaInfo),
Expand All @@ -137,7 +137,7 @@ def convert_to_variable(args):
)


def convert_to_input_spec(args):
def convert_meta_to_input_spec(args):
return map_if(
args,
pred=lambda x: isinstance(x, MetaInfo),
Expand All @@ -146,7 +146,7 @@ def convert_to_input_spec(args):
)


def variable_to_meta_info(args):
def convert_variable_to_meta_info(args):
return map_if(
args,
pred=lambda x: isinstance(x, paddle.static.Variable),
Expand All @@ -168,10 +168,20 @@ def infer_meta_for_layer(layer, *args, **kwargs):
), f"Expect a Layer, but got {layer}."
layer = paddle.jit.to_static(layer, enable_fallback=False)

args, kwargs = convert_to_input_spec(args), convert_to_input_spec(kwargs)
concrete_program = layer.forward.get_concrete_program(*args, **kwargs)[0]
out = concrete_program.outputs[0]
out = MetaInfo.from_tensor(out)
args_, kwargs_ = convert_meta_to_input_spec(
args
), convert_meta_to_input_spec(kwargs)

(
concrete_program,
partial_program_layer,
) = layer.forward.get_concrete_program(*args_, **kwargs_)

out = partial_program_layer._restore_out(
paddle.utils.flatten(
convert_variable_to_meta_info(concrete_program.outputs)
)
)
layer.forward.rollback()
return out

Expand Down
2 changes: 1 addition & 1 deletion sot/opcode_translator/executor/function_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def symbolic_call(self, infer_meta_fn, compute_fn, func, *args, **kwargs):
) # symbolic only contain symbols.
self._put_inner(outputs)
return VariableFactory.from_value(
outputs, self, DummyTracker(outputs)
outputs, self, DummyTracker(list(args) + list(kwargs.values()))
)
else:
return None
Expand Down
6 changes: 5 additions & 1 deletion sot/opcode_translator/executor/guard.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,18 @@ def union_free_vars(*free_vars: dict[str, Any]):
def make_guard(stringify_guards: list[StringifyExpression]) -> Guard:
num_guards = len(stringify_guards)
if not num_guards:
return lambda frame: True
guard = lambda frame: True
guard.expr = "lambda frame: True"
return guard

union_guard_expr = reduce(lambda x, y: x & y, stringify_guards)
guard_string = f"lambda frame: {union_guard_expr.expr}"
guard = eval(
guard_string,
union_guard_expr.free_vars,
)
log(3, f"[Guard]: {guard_string}\n")
guard.expr = guard_string
assert callable(guard), "guard must be callable."

return guard
Expand Down
15 changes: 12 additions & 3 deletions sot/opcode_translator/executor/opcode_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,10 @@ def impl(
for code, guard_fn in guarded_fns:
try:
if guard_fn(frame):
log(3, "[Cache]: Cache hit\n")
log(
3,
f"[Cache]: Cache hit, Guard is {guard_fn.expr if hasattr(guard_fn, 'expr') else 'None'}\n",
)
return CustomCode(code, False)
except Exception as e:
log(3, f"[Cache]: Guard function error: {e}\n")
Expand Down Expand Up @@ -430,7 +433,8 @@ def print_instrs(self):
Prints the instructions in the executor.
"""
print(instrs_info(self._instructions))
print(self._code.co_name)
print(instrs_info(self._instructions, mark=self._lasti))

def print_sir(self):
"""
Expand Down Expand Up @@ -806,6 +810,9 @@ def STORE_FAST(self, instr: Instruction):
"""
var = self.pop()
var.debug_name = instr.argval
if instr.argval == "__breakpoint__":
print(var.value)
breakpoint()
self._locals[instr.argval] = var

def STORE_GLOBAL(self, instr: Instruction):
Expand Down Expand Up @@ -1752,6 +1759,8 @@ def FOR_ITER(self, instr):
"Found RETURN_VALUE in for loop body."
)

self._graph.add_global_guarded_variable(iterator)

# TODO need support TensorIterVariable.next
try:
if not isinstance(
Expand All @@ -1761,7 +1770,7 @@ def FOR_ITER(self, instr):
backup_iter_idx = iterator.idx
self._inline_call_for_loop(iterator, instr)
self._lasti = self.indexof(instr.jump_to)
except BreakGraphError:
except BreakGraphError as e:
if backup_iter_idx:
iterator.idx = backup_iter_idx
self._break_graph_in_for_loop(iterator, instr)
Expand Down
2 changes: 2 additions & 0 deletions sot/opcode_translator/executor/opcode_inline_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,8 @@ def FOR_ITER(self, instr):
iterator = self.peek()
assert isinstance(iterator, IterVariable)

self._graph.add_global_guarded_variable(iterator)

# simplely get next
if isinstance(iterator, (SequenceIterVariable, DictIterVariable)):
try:
Expand Down
12 changes: 6 additions & 6 deletions sot/opcode_translator/executor/variable_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ def operator_BAD(left, right):
{},
lambda var, other: VariableFactory.from_value(
var.get_symbol() == other.get_symbol(),
None,
var.graph,
tracker=DummyTracker([var, other]),
),
)
Expand All @@ -338,7 +338,7 @@ def operator_BAD(left, right):
{},
lambda var, other: VariableFactory.from_value(
False,
None,
var.graph,
tracker=DummyTracker([var, other]),
),
)
Expand All @@ -349,7 +349,7 @@ def operator_BAD(left, right):
{},
lambda var, other: VariableFactory.from_value(
False,
None,
var.graph,
tracker=DummyTracker([var, other]),
),
)
Expand All @@ -361,7 +361,7 @@ def operator_BAD(left, right):
{},
lambda var, other: VariableFactory.from_value(
var.get_value() is other.get_value(),
None,
var.graph,
tracker=DummyTracker([var, other]),
),
)
Expand Down Expand Up @@ -403,7 +403,7 @@ def is_not_func(var: VariableBase, other: VariableBase):
{},
partial(
lambda fn, var: VariableFactory.from_value(
fn(var.get_value()), None, tracker=DummyTracker([var])
fn(var.get_value()), var.graph, tracker=DummyTracker([var])
),
unary_fn,
),
Expand All @@ -417,7 +417,7 @@ def is_not_func(var: VariableBase, other: VariableBase):
partial(
lambda fn, var, other: VariableFactory.from_value(
fn(var.get_value(), other.get_value()),
None,
var.graph,
tracker=DummyTracker([var, other]),
),
binary_fn,
Expand Down
1 change: 1 addition & 0 deletions sot/opcode_translator/executor/variables/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,7 @@ def reconstruct(self, codegen: PyCodeGen):
):
self.tracker.gen_instructions(codegen)
else:
self.graph.add_global_guarded_variable(self)
self._reconstruct(codegen)

def _reconstruct(self, codegen: PyCodeGen):
Expand Down
8 changes: 8 additions & 0 deletions sot/opcode_translator/executor/variables/iter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,19 @@


class IterVariable(VariableBase):
"""
This Variable (include subclasses) should be generated only when simulate GET_ITER opcode
"""

def __init__(self, obj, graph, tracker):
super().__init__(tracker)
assert isinstance(obj, VariableBase)
self.hold = obj
self.graph = graph

def make_stringify_guard(self):
return self.hold.make_stringify_guard()


class SequenceIterVariable(IterVariable):
def __init__(self, obj, graph, tracker):
Expand Down
5 changes: 3 additions & 2 deletions sot/opcode_translator/instruction_utils/instruction_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,13 +257,13 @@ def replace_instr(instructions, instr, new_instr):
instructions[idx, idx + 1] = new_instr


def instrs_info(instrs):
def instrs_info(instrs, mark=None):
ret = []
for idx, instr in enumerate(instrs):
if instr.starts_line is not None:
ret.append("")
ret.append(
"{line:<8s}{is_jump_target:>2s}{offset:>4d} {opname:<30s}{arg:<4s}{argval}".format(
"{line:<8s}{is_jump_target:>2s}{offset:>4d} {opname:<30s}{arg:<4s}{argval:<40s}{mark}".format(
line=str(instr.starts_line) if instr.starts_line else "",
is_jump_target=">>" if instr.is_jump_target else " ",
offset=instr.offset
Expand All @@ -272,6 +272,7 @@ def instrs_info(instrs):
opname=instr.opname,
arg=str(instr.arg) if instr.arg is not None else "",
argval=f"({instr.argval})" if instr.argval else "",
mark=" <--- HERE" if mark == idx else "",
)
)
return "\n".join(ret)

0 comments on commit 605fa46

Please sign in to comment.