From 2028195889f182dacf49f7ff3a0d5eaa439d59bb Mon Sep 17 00:00:00 2001 From: Zhan Rongrui <46243324+zrr1999@users.noreply.github.com> Date: Mon, 5 Jun 2023 20:48:39 +0800 Subject: [PATCH] add priority mechanism in `VariableFactory.register_from_value` and split variables (#110) --- .gitignore | 2 + docs/design/tracker-and-guard.md | 2 +- .../executor/opcode_executor.py | 2 +- .../opcode_translator/executor/variables.py | 1183 ----------------- .../executor/variables/__init__.py | 44 + .../executor/variables/base.py | 241 ++++ .../executor/variables/basic.py | 270 ++++ .../executor/variables/callable.py | 373 ++++++ .../executor/variables/container.py | 338 +++++ .../executor/variables/iter.py | 65 + 10 files changed, 1335 insertions(+), 1185 deletions(-) delete mode 100644 symbolic_trace/opcode_translator/executor/variables.py create mode 100644 symbolic_trace/opcode_translator/executor/variables/__init__.py create mode 100644 symbolic_trace/opcode_translator/executor/variables/base.py create mode 100644 symbolic_trace/opcode_translator/executor/variables/basic.py create mode 100644 symbolic_trace/opcode_translator/executor/variables/callable.py create mode 100644 symbolic_trace/opcode_translator/executor/variables/container.py create mode 100644 symbolic_trace/opcode_translator/executor/variables/iter.py diff --git a/.gitignore b/.gitignore index 1d0744ae1..3919d0cb2 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,8 @@ __pycache__ *.sw* user_tag +build/ +symbolic_trace.egg-info/ # Editor config .vscode diff --git a/docs/design/tracker-and-guard.md b/docs/design/tracker-and-guard.md index 4be1d9267..029cb93d0 100644 --- a/docs/design/tracker-and-guard.md +++ b/docs/design/tracker-and-guard.md @@ -54,7 +54,7 @@ def foo(a: list[Tensor], b: int, c: int): - 在生成函数的字节码前,需要将输入 LOAD 到栈上,我们需要根据 Tracker 来生成 LOAD 这些输入的字节码 - 在调用 Guard 时,需要根据 Tracker 来索引到新的 Frame 里的相同变量的值,这样才能进行 Guard 的判断(`new_value == old_value`) -我们可以将这种索引机制成为 Source,而大多数中间结点是经过计算得到的,我们并不需要去还原这些中间结点,比如 `c 
= a + b`,`c` 是由 `BINARY_ADD` 构建得到的,我们的 Source 只需要分别索引 `a` 和 `b` 的来源,而我们的 Guard 也只需要分别 Guard 住 `a` 和 `b` 即可。 +我们可以将这种索引机制称为 Source,而大多数中间结点是经过计算得到的,我们并不需要去还原这些中间结点,比如 `c = a + b`,`c` 是由 `BINARY_ADD` 构建得到的,我们的 Source 只需要分别索引 `a` 和 `b` 的来源,而我们的 Guard 也只需要分别 Guard 住 `a` 和 `b` 即可。 因此对于这种中间结点,我们只需要知道它是由什么构建得到即可,即只需要知道 inputs 是什么,对于这些结点,我们使用 DummyTracker 来作为连接结点,DummyTracker 不会承担 Source 的索引功能,只会承担 DAG 的连接功能,以便 Guard 的收集。 diff --git a/symbolic_trace/opcode_translator/executor/opcode_executor.py b/symbolic_trace/opcode_translator/executor/opcode_executor.py index 3b42fb21e..ab301b12d 100644 --- a/symbolic_trace/opcode_translator/executor/opcode_executor.py +++ b/symbolic_trace/opcode_translator/executor/opcode_executor.py @@ -29,6 +29,7 @@ from .pycode_generator import PyCodeGen from .tracker import ( BuiltinTracker, + ConstTracker, DummyTracker, GetItemTracker, GetIterTracker, @@ -38,7 +39,6 @@ from .variables import ( CallableVariable, ConstantVariable, - ConstTracker, ContainerVariable, DictIterVariable, DictVariable, diff --git a/symbolic_trace/opcode_translator/executor/variables.py b/symbolic_trace/opcode_translator/executor/variables.py deleted file mode 100644 index 1cb939e71..000000000 --- a/symbolic_trace/opcode_translator/executor/variables.py +++ /dev/null @@ -1,1183 +0,0 @@ -from __future__ import annotations - -import collections -import inspect -import types -from queue import Queue -from typing import TYPE_CHECKING, Any, Callable - -import paddle - -from ...infer_meta import MetaInfo -from ...symbolic.statement_ir import Symbol -from ...utils import ( - ASSERT, - NameGenerator, - is_break_graph_api, - is_paddle_api, - log_do, - paddle_tensor_methods, -) -from ...utils.exceptions import BreakGraphError, FallbackErrorBase, InnerError -from .guard import StringifyExpression, union_free_vars -from .pycode_generator import PyCodeGen -from .tracker import ( - ConstTracker, - DummyTracker, - GetAttrTracker, - GetItemTracker, - Tracker, -) - -if TYPE_CHECKING: - 
from .function_graph import FunctionGraph - - -ConstTypes = (int, float, str, bool, type(None)) - - -def get_zero_degree_vars( - variables: set[VariableBase], visited_vars: list[VariableBase] -) -> list[VariableBase]: - return [ - var - for var in variables - if var not in visited_vars - and len(set(var.get_traceable_inputs()) - set(visited_vars)) == 0 - ] - - -def topo_sort_vars( - root_vars: list[VariableBase], -) -> list[VariableBase]: - unique_vars = set() - - for var in root_vars: - unique_vars.add(var) - unique_vars |= set(var.flatten_traceable_inputs()) - - topo_ordered_vars = [] - topo_queue = Queue() - for var in get_zero_degree_vars(unique_vars, topo_ordered_vars): - topo_queue.put(var) - - while not topo_queue.empty(): - var = topo_queue.get() - topo_ordered_vars.append(var) - for zero_degree_var in get_zero_degree_vars( - unique_vars, topo_ordered_vars - ): - if ( - zero_degree_var in topo_queue.queue - or zero_degree_var in topo_ordered_vars - ): - continue - topo_queue.put(zero_degree_var) - return topo_ordered_vars - - -def map_variables(map_func, variables): - def _map_variable(variable): - assert isinstance( - variable, VariableBase - ), f"variable must be VariableBase, got {variable}" - if isinstance(variable, ContainerVariable): - return paddle.utils.map_structure( - _map_variable, variable.get_wrapped_items() - ) - return map_func(variable) - - return paddle.utils.map_structure(_map_variable, variables) - - -class VariableFactory: - registered_funcs: list[Callable] = [] - - @staticmethod - def default_from_value(value, graph, tracker): - return ObjectVariable(value, graph, tracker) - - @staticmethod - def register_from_value(from_value_func: Callable): - VariableFactory.registered_funcs.append(from_value_func) - - @staticmethod - def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker): - for func in VariableFactory.registered_funcs: - var = func(value, graph, tracker) - if var is not None: - return var - return 
VariableFactory.default_from_value(value, graph, tracker) - - -class VariableBase: - """ - VariableBase is a basic concept and each symbols in VM stack is regarded as - an Variable Object in symblic tracing process. - """ - - tracker: Tracker - name_generator = NameGenerator("object_") - - def __init__(self, tracker: Tracker): - self.tracker = tracker - self.id = VariableBase.name_generator.next() - - def __hash__(self): - return hash(self.id) - - def make_stringify_guard(self) -> StringifyExpression: - assert not isinstance( - self.tracker, DummyTracker - ), "Can not make guard from dummy tracker" - - frame_value_tracer = self.tracker.trace_value_from_frame() - log_do( - 4, - lambda: print( - f"[Guard]: guard_fn for {self}, tracker={self.tracker.__class__.__name__}, value={frame_value_tracer.expr}" - ), - ) - return StringifyExpression( - f"{frame_value_tracer.expr} == {self.get_value()}", - union_free_vars(frame_value_tracer.free_vars), - ) - - def get_value(self) -> Any: - raise NotImplementedError() - - def reconstruct(self, codegen: PyCodeGen): - """ - Contruct an opcode and append it into codegen.instructions. 
- """ - if ( - not isinstance(self.tracker, DummyTracker) - and self.tracker.is_traceable() - ): - self.tracker.gen_instructions(codegen) - else: - self._reconstruct(codegen) - - def _reconstruct(self, codegen: PyCodeGen): - raise NotImplementedError() - - def flatten_items(self) -> list[VariableBase]: - if not isinstance(self, ContainerVariable): - return [self] - flattened_items = [] - for item in self.get_items(): - flattened_items.extend(item.flatten_items()) - return flattened_items - - def get_inputs(self) -> list[VariableBase]: - return self.tracker.inputs - - def get_traceable_inputs(self) -> list[VariableBase]: - if self.tracker.is_traceable(): - return [] - - return list( - filter(lambda x: x.tracker.is_traceable(), self.tracker.inputs) - ) - - def flatten_traceable_inputs(self) -> list[VariableBase]: - flattened_traceable_inputs: list[VariableBase] = [self] - if self.tracker.is_traceable(): - return flattened_traceable_inputs - - for input in self.get_inputs(): - flattened_traceable_inputs.extend(input.flatten_traceable_inputs()) - return flattened_traceable_inputs - - def call_function(self, *args, **kwargs): - pass - - def __getattr__(self, name: str): - if not hasattr(self.value, name): - raise InnerError( - f"{self.__class__.__name__} {self} has no attribute {name}" - ) - attr = getattr(self.value, name) - if inspect.ismethod(attr): - return UserDefinedMethodVariable( - self, - attr.__func__, - graph=self.graph, - tracker=GetAttrTracker(self, name), - ) - return VariableFactory.from_value( - attr, self.graph, tracker=GetAttrTracker(self, name) - ) - - def getitem(self, *args, **kwargs): - pass - - @VariableFactory.register_from_value - def from_value( - value: Any, - graph: FunctionGraph | None, - tracker: Tracker, - ): - if isinstance(value, VariableBase): - return value - return None - - -class ConstantVariable(VariableBase): - def __init__( - self, - value: Any, - tracker: Tracker, - ): - super().__init__(tracker) - self.value = value - - def 
get_value(self): - return self.value - - def _reconstruct(self, codegen: PyCodeGen): - codegen.gen_load_const(self.value) - - def __repr__(self) -> str: - return f"ConstantVariable({self.value})" - - def __bool__(self) -> bool: - return bool(self.value) - - def apply_unary_operator(self, magic_name): - operator = getattr(self.value, magic_name) - var = VariableFactory.from_value( - operator(), - None, - tracker=DummyTracker( - [ - self, - ] - ), - ) - return var - - def apply_binary_operator(self, other, magic_name): - if not isinstance(other, ConstantVariable): - return NotImplemented - operator = getattr(self.value, magic_name) - var = VariableFactory.from_value( - operator(other.value), None, tracker=DummyTracker([self, other]) - ) - return var - - @VariableFactory.register_from_value - def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker): - if isinstance(value, ConstTypes): - return ConstantVariable(value, tracker) - return None - - @staticmethod - def wrap_literal(value: Any) -> ConstantVariable: - if isinstance(value, ConstantVariable): - return value - assert isinstance( - value, ConstTypes - ), f"value: {value},type: {type(value)}" - return ConstantVariable(value, ConstTracker(value)) - - -class TensorVariable(VariableBase): - var_name_generator = NameGenerator("var_") - - def __init__( - self, - tensor: paddle.Tensor | MetaInfo, - graph: FunctionGraph, - tracker: Tracker, - ): - super().__init__(tracker) - if isinstance(tensor, paddle.Tensor): - self.value = tensor - self.meta = MetaInfo.from_tensor(tensor) - elif isinstance(tensor, MetaInfo): - self.value = None - self.meta = tensor - else: - raise InnerError( - "Required type(tensor) is paddle.Tensor or ProxyTensor, but received {}.".format( - type(tensor).__name__ - ) - ) - self.var_name = TensorVariable.var_name_generator.next() - self.graph = graph - - def get_value(self): - if self.value is None: - raise InnerError("Can not get value from a inner tensor variable.") - return 
self.value - - def get_symbol(self) -> Symbol: - return Symbol(self.var_name) - - @property - def out_var_name(self): - return f"{self.graph.out_var_prefix}{self.var_name}" - - def _reconstruct(self, codegen: PyCodeGen): - codegen.gen_load_fast(self.out_var_name) - - def make_stringify_guard(self) -> StringifyExpression: - assert not isinstance( - self.tracker, DummyTracker - ), "Can not make guard from dummy tracker" - - frame_value_tracer = self.tracker.trace_value_from_frame() - log_do( - 4, - lambda: print( - f"[Guard]: guard_fn for {self}, tracker={self.tracker.__class__.__name__}, value={frame_value_tracer.expr}" - ), - ) - return StringifyExpression( - f"str(MetaInfo.from_tensor({frame_value_tracer.expr})) == '{self.meta}'", - union_free_vars( - {"MetaInfo": MetaInfo}, - frame_value_tracer.free_vars, - ), - ) - - def __repr__(self) -> str: - return f"TensorVariable{self.meta}" - - def __getitem__(self, key): - return self.graph.call_tensor_method( - '__getitem__', - self, - VariableFactory.from_value( - key, self.graph, tracker=ConstTracker(key) - ), - ) - - @property - def T(self): - perm = list(range(len(self.meta.shape) - 1, -1, -1)) - perm_var = VariableFactory.from_value( - perm, self.graph, tracker=ConstTracker(perm) - ) - out = self.graph.call_paddle_api(paddle.transpose, self, perm_var) - return out - - @property - def ndim(self): - return ConstantVariable.wrap_literal(len(self.meta.shape)) - - def __getattr__(self, name: str): - if name in paddle_tensor_methods: - return TensorMethodVariable( - self, name, self.graph, tracker=GetAttrTracker(self, name) - ) - elif name in ["shape", "dtype", "stop_gradient"]: - return VariableFactory.from_value( - getattr(self.meta, name), - self.graph, - tracker=GetAttrTracker(self, name), - ) - elif name in ["T", "ndim"]: - return getattr(self, name) - else: - raise InnerError(f"Unknown Tensor attribute: {name}") - - @VariableFactory.register_from_value - def from_value(value: Any, graph: FunctionGraph | None, 
tracker: Tracker): - if isinstance(value, (paddle.Tensor, MetaInfo)): - assert graph is not None - return TensorVariable(value, graph, tracker) - return None - - -class ContainerVariable(VariableBase): - def get_items(self) -> list[VariableBase]: - raise NotImplementedError() - - def __len__(self): - raise NotImplementedError() - - def __bool__(self): - return len(self) > 0 - - -class ListVariable(ContainerVariable): - def __init__( - self, - val_list: list[VariableBase], - graph: FunctionGraph, - tracker: Tracker, - ): - super().__init__(tracker) - self.graph = graph - # everything in stack is VariableBase, so just accept the input list is ok - self.value = val_list - - def get_value(self): - return [self[i].get_value() for i in range(len(self))] - - def _reconstruct(self, codegen: PyCodeGen): - size = len(self) - for idx in range(size): - self[idx].reconstruct(codegen) - codegen.gen_build_list(size) - - def get_items(self): - size = len(self) - return [self[idx] for idx in range(size)] - - def get_wrapped_items(self): - return self.get_items() - - def __repr__(self) -> str: - return f"ListVariable(len={len(self)})" - - def __len__(self): - return len(self.value) - - def __getitem__(self, key): - ''' - we need to make sure that: - before an inplace change happens to ListVariable, - the related items should already be wrapped as VariableBase - - if not, tracker might be set to a wrong elem - ''' - if isinstance(key, VariableBase): - raise InnerError( - f"[{self.__class__.__name__}]: recieved {key} as key." - ) - - retval = self.value[key] - - # if list is an input of funciton, we need make sure __getitem__ returns a VariableBase - retval = VariableFactory.from_value( - retval, self.graph, tracker=GetItemTracker(self, key) - ) - - return retval - - def __setitem__(self, key, value): - ''' - why __setitem__ is ok: - - case: - def f(x = [t0, t1]) - ... - x[0] = 0 - ... - - 1. 
if setitem happens after get t0: t0 is a VariableBase (transformed at getitem), so it is ok - 2. if setitem happens before get t0: t0 will not be used - ''' - if isinstance(key, VariableBase): - raise InnerError( - f"[{self.__class__.__name__}]: received {key} as key." - ) - - if not isinstance(value, VariableBase): - raise InnerError( - f"[{self.__class__.__name__}]: received {value} to set value." - ) - self.value[key] = value - - def __delitem__(self, key): - if isinstance(key, VariableBase): - raise InnerError( - f"[{self.__class__.__name__}]: received {key} as key to delete." - ) - del self.value[key] - - @VariableFactory.register_from_value - def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker): - if isinstance(value, list): - assert graph is not None - return ListVariable(value, graph=graph, tracker=tracker) - return None - - -class TupleVariable(ContainerVariable): - def __init__( - self, - val_tuple: list[VariableBase], - graph: FunctionGraph, - tracker: Tracker, - ): - super().__init__(tracker) - self.graph = graph - # exactly it is a list (need replace item with VariableBase) - self.value = list(val_tuple) - - def get_value(self): - return tuple(self[i].get_value() for i in range(len(self))) - - def _reconstruct(self, codegen: PyCodeGen): - size = len(self) - for idx in range(size): - self[idx].reconstruct(codegen) - codegen.gen_build_tuple(size) - - def get_items(self): - size = len(self) - return [self[idx] for idx in range(size)] - - def get_wrapped_items(self): - return self.get_items() - - def __repr__(self) -> str: - return f"TupleVariable(len={len(self)})" - - def __len__(self): - return len(self.value) - - def __getitem__(self, key): - if isinstance(key, VariableBase): - raise InnerError( - f"[{self.__class__.__name__}]: recieved {key} as key." 
- ) - retval = self.value[key] - - return VariableFactory.from_value( - retval, graph=self.graph, tracker=GetItemTracker(self, key) - ) - - def __setitem__(self, key, value): - raise InnerError( - f"[{self.__class__.__name__}]: setitem is not allowed." - ) - - def __delitem__(self, key): - raise InnerError( - f"[{self.__class__.__name__}]: delitem is not allowed." - ) - - @VariableFactory.register_from_value - def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker): - if isinstance(value, tuple): - return TupleVariable(value, graph, tracker) - return None - - -class DictVariable(ContainerVariable): - def __init__( - self, - val_dict: dict[object, VariableBase], - graph: FunctionGraph, - tracker: Tracker, - ): - super().__init__(tracker) - self.graph = graph - self.value = val_dict - - def get_value(self): - return {key: self[key].get_value() for key in self.value} - - def _reconstruct(self, codegen: PyCodeGen): - size = len(self) - for key in self.value.keys(): - if not isinstance(key, ConstTypes): - raise InnerError( - f"[{self.__class__.__name__}]: recieved {key} as key." - ) - key_var = ConstantVariable.wrap_literal(key) - value_var = self[key] - key_var.reconstruct(codegen) - value_var.reconstruct(codegen) - codegen.gen_build_map(size) - - def get_items(self): - items = [] - for key in self.value.keys(): - if not isinstance(key, ConstTypes): - raise InnerError( - f"[{self.__class__.__name__}]: recieved {key} as key." - ) - key_var = VariableFactory.from_value( - key, self.graph, tracker=ConstTracker(key) - ) - value_var = self[key] - items.extend([key_var, value_var]) - return items - - def get_wrapped_items(self): - items = {} - for key in self.value.keys(): - if not isinstance(key, ConstTypes): - raise InnerError( - f"[{self.__class__.__name__}]: recieved {key} as key." 
- ) - items[key] = self[key] - return items - - def __repr__(self) -> str: - return f"DictVariable(len={len(self)})" - - def __len__(self): - return len(self.value) - - def __getitem__(self, key): - if isinstance(key, VariableBase): - raise InnerError( - f"[{self.__class__.__name__}]: recieved {key} as key." - ) - - retval = self.value[key] - - return VariableFactory.from_value( - retval, self.graph, tracker=GetItemTracker(self, key) - ) - - def __setitem__(self, key, value): - if isinstance(key, VariableBase): - raise InnerError( - f"[{self.__class__.__name__}]: recieved {key} as key." - ) - - if not isinstance(value, ConstantVariable): - raise InnerError( - f"[{self.__class__.__name__}]: recieved {value} to set value." - ) - - self.value[key] = value - - def __delitem__(self, key): - if isinstance(key, VariableBase): - raise InnerError( - f"[{self.__class__.__name__}]: recieved {key} as key to delete." - ) - del self.value[key] - - def override_method_keys(self): - raw_list = [ - ConstantVariable(x, ConstTracker(x)) for x in self.value.keys() - ] - key_list = VariableFactory.from_value( - raw_list, self.graph, ConstTracker(raw_list) - ) - return SequenceIterVariable( - key_list, self.graph, DummyTracker([key_list]) - ) - - def override_method_values(self): - raw_list = list(self.get_wrapped_items().values()) - value_list = VariableFactory.from_value( - raw_list, self.graph, DummyTracker([self]) - ) - return SequenceIterVariable( - value_list, self.graph, DummyTracker([value_list]) - ) - - def override_method_items(self): - keys = [ConstantVariable(x, ConstTracker(x)) for x in self.value.keys()] - values = list(self.get_wrapped_items().values()) - raw_list = list(zip(keys, values)) - item_list = VariableFactory.from_value( - raw_list, self.graph, DummyTracker([self]) - ) - return SequenceIterVariable( - item_list, self.graph, DummyTracker([item_list]) - ) - - def __getattr__(self, name): - name_ = "override_method_" + name - if hasattr(self, name_): - method = 
getattr(self, name_) - return DirectlyCallMethodVariable( - self, - method.__func__, - self.graph, - GetAttrTracker(self, name), - ) - else: - raise NotImplementedError( - f"attribute {name} for dict is not implemented" - ) - - @VariableFactory.register_from_value - def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker): - if isinstance(value, dict): - assert graph is not None - return DictVariable(value, graph=graph, tracker=tracker) - - -class CallableVariable(VariableBase): - def __init__(self, graph: FunctionGraph, tracker: Tracker): - super().__init__(tracker) - self.graph = graph - - def __call__(self, *args, **kwargs) -> VariableBase: - return self.call_function(*args, **kwargs) - - def call_function(self, *args, **kwargs): - raise NotImplementedError("call_function is not implemented.") - - -class FunctionVariable(CallableVariable): - def __init__( - self, fn: Callable[..., Any], graph: FunctionGraph, tracker: Tracker - ): - super().__init__(graph, tracker) - self.value = fn - - def get_value(self): - return self.value - - def get_code(self) -> types.CodeType: - return self.value.__code__ - - -class PaddleApiVariable(FunctionVariable): - def __init__( - self, fn: Callable[..., Any], graph: FunctionGraph, tracker: Tracker - ): - super().__init__(fn, graph, tracker) - - def call_function(self, *args, **kwargs): - return self.graph.call_paddle_api(self.value, *args, **kwargs) - - @VariableFactory.register_from_value - def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker): - # This should be front of FunctionVariable to avoid conflict. 
- if callable(value) and is_paddle_api(value): - return PaddleApiVariable(value, graph, tracker) - return None - - def __repr__(self) -> str: - return f"PaddleApiVariable({self.value.__name__})" - - -class UserDefinedGeneratorVariable(FunctionVariable): - def __init__( - self, fn: Callable[..., Any], graph: FunctionGraph, tracker: Tracker - ): - super().__init__(fn, graph, tracker) - - def call_function(self, *args, **kwargs) -> VariableBase: - - iter_ = self.value() - return VariableFactory.from_value( - iter_, self.graph, DummyTracker([self]) - ) - - @VariableFactory.register_from_value - def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker): - if inspect.isgeneratorfunction(value): - return UserDefinedGeneratorVariable(value, graph, tracker) - return None - - def __repr__(self) -> str: - return f"UserDefinedGeneratorVariable({self.value.__name__})" - - -class UserDefinedFunctionVariable(FunctionVariable): - def __init__( - self, fn: Callable[..., Any], graph: FunctionGraph, tracker: Tracker - ): - super().__init__(fn, graph, tracker) - - def call_function(self, *args, **kwargs) -> VariableBase: - from .opcode_inline_executor import OpcodeInlineExecutor - - if self.value is ASSERT: - return self.value(args[0].value) - - checkpoint = self.graph.save_memo() - try: - inline_executor = OpcodeInlineExecutor(self, *args, **kwargs) - output = inline_executor.inline_call() - except FallbackErrorBase as e: - self.graph.restore_memo(checkpoint) - raise BreakGraphError( - f"{self.value} is raise a inline call error. 
{e}" - ) - return output - - @VariableFactory.register_from_value - def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker): - if isinstance(value, (types.FunctionType)): - return UserDefinedFunctionVariable(value, graph, tracker) - return None - - def __repr__(self) -> str: - return f"UserDefinedFunctionVariable({self.value.__name__})" - - -class MethodVariable(CallableVariable): - def __init__( - self, - bound_instance: VariableBase, - graph: FunctionGraph, - tracker: Tracker, - ): - super().__init__(graph, tracker) - self.bound_instance = bound_instance - - -class TensorMethodVariable(MethodVariable): - def __init__( - self, - tensor: TensorVariable, - method_name: str, - graph: FunctionGraph, - tracker: Tracker, - ): - super().__init__(tensor, graph, tracker) - self.tensor = tensor - self.method_name = method_name - - def get_value(self): - return getattr(self.tensor, self.method_name) - - def call_function(self, *args, **kwargs): - return self.graph.call_tensor_method( - self.method_name, self.tensor, *args, **kwargs - ) - - @VariableFactory.register_from_value - def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker): - if inspect.ismethod(value) and isinstance( - value.__self__, paddle.Tensor - ): - # NOTE(SigureMo): Since the method_self need method_var as the obj - # of the tracker, we need to temporarily set the tracker of method_self - # to DummyTracker, and set it to GetAttrTracker after method_var is created. 
- method_self = TensorVariable( - value.__self__, graph, DummyTracker([]) - ) - method_var = TensorMethodVariable( - method_self, - value.__name__, - graph, - tracker, - ) - method_self.tracker = GetAttrTracker(method_var, "__self__") - return method_var - return None - - def __repr__(self) -> str: - return f"TensorMethodVariable({self.method_name})" - - -class UserDefinedMethodVariable(MethodVariable): - def __init__( - self, bound_instance, fn, graph: FunctionGraph, tracker: Tracker - ): - super().__init__(bound_instance, graph, tracker) - self.bound_instance = bound_instance - self.fn = fn - - def get_value(self): - return self.fn.__get__( - self.bound_instance, self.bound_instance.__class__ - ) - - def call_function(self, *args, **kwargs): - fn_var = UserDefinedFunctionVariable( - self.fn, self.graph, GetAttrTracker(self, "__func__") - ) - - return fn_var(*(self.bound_instance, *args), **kwargs) - - @VariableFactory.register_from_value - def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker): - if inspect.ismethod(value): - method_self = VariableFactory.from_value( - value.__self__, graph, DummyTracker([]) - ) - method_var = UserDefinedMethodVariable( - method_self, - value.__func__, - graph, - tracker, - ) - method_self.tracker = GetAttrTracker(method_var, "__self__") - return method_var - return None - - def __repr__(self) -> str: - return f"UserDefinedMethodVariable({self.fn.__name__})" - - -class DirectlyCallMethodVariable(MethodVariable): - def __init__( - self, bound_instance, fn, graph: FunctionGraph, tracker: Tracker - ): - super().__init__(bound_instance, graph, tracker) - self.bound_instance = bound_instance - self.fn = fn - - def get_value(self): - return self.fn.__get__( - self.bound_instance, self.bound_instance.__class__ - ) - - def call_function(self, *args, **kwargs): - return self.fn(*(self.bound_instance, *args), **kwargs) - - -class LayerVariable(CallableVariable): - def __init__( - self, layer: paddle.nn.Layer, graph: 
FunctionGraph, tracker: Tracker - ): - super().__init__(graph, tracker) - self.value = layer - - def get_value(self): - return self.value - - def make_stringify_guard(self) -> StringifyExpression: - assert not isinstance( - self.tracker, DummyTracker - ), "Can not make guard from dummy tracker" - - frame_value_tracer = self.tracker.trace_value_from_frame() - log_do( - 4, - lambda: print( - f"[Guard]: guard_fn for {self}, tracker={self.tracker.__class__.__name__}, value={frame_value_tracer.expr}" - ), - ) - return StringifyExpression( - f"id({frame_value_tracer.expr}) == {id(self.get_value())}", - union_free_vars(frame_value_tracer.free_vars), - ) & StringifyExpression( - f"{frame_value_tracer.expr}.training == {self.get_value().training}", - union_free_vars(frame_value_tracer.free_vars), - ) - - -class PaddleLayerVariable(LayerVariable): - layer_name_generator = NameGenerator("layer_") - - def __init__( - self, layer: paddle.nn.Layer, graph: FunctionGraph, tracker: Tracker - ): - super().__init__(layer, graph, tracker) - self.name = self.layer_name_generator.next() - - def get_symbol(self) -> Symbol: - return Symbol(self.name) - - def call_function(self, *args, **kwargs): - # TODO: Remove this trick after we support for-loop. - if isinstance(self.value, paddle.nn.Sequential): - assert len(args) == 1, "Sequential only accept one input" - input = args[0] - for i, layer in enumerate(self.value._sub_layers.values()): - layer_var = VariableFactory.from_value( - layer, self.graph, tracker=GetItemTracker(self, i) - ) - assert isinstance(layer_var, LayerVariable) - input = layer_var(input) - return input - return self.graph.call_layer(self, *args, **kwargs) - - @VariableFactory.register_from_value - def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker): - # TODO(SigureMo): Add a more common way to check if a value is a paddle builtin layer. - if isinstance(value, paddle.nn.Layer) and value.__module__.startswith( - "paddle.nn." 
- ): - return PaddleLayerVariable(value, graph, tracker) - return None - - def __repr__(self) -> str: - return f"PaddleLayerVariable({self.value.__class__.__name__})" - - -class UserDefinedLayerVariable(LayerVariable): - def __init__( - self, layer: paddle.nn.Layer, graph: FunctionGraph, tracker: Tracker - ): - super().__init__(layer, graph, tracker) - - def call_function(self, *args, **kwargs): - fn_var = UserDefinedFunctionVariable( - self.value.__class__.__call__, - self.graph, - GetAttrTracker(self, "__call__"), - ) - - return fn_var(*(self, *args), **kwargs) - - @VariableFactory.register_from_value - def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker): - if isinstance( - value, paddle.nn.Layer - ) and not value.__module__.startswith("paddle.nn."): - return UserDefinedLayerVariable(value, graph, tracker) - return None - - def __repr__(self) -> str: - return f"UserDefinedLayerVariable({self.value.__class__.__name__})" - - -class BuiltinVariable(CallableVariable): - def __init__( - self, fn: Callable[..., Any], graph: FunctionGraph, tracker: Tracker - ): - super().__init__(graph, tracker) - self.value = fn - - def call_function(self, *args, **kwargs): - # TODO(0x45f): For builtin functions, may have 3 different ways to process as below: - # 1. Simulation execution: ensure correct simulation execution and handle trackers with care - # 2. Trigger the paddle api call - # 3. 
Trigger fallback - if is_break_graph_api(self.value): - raise BreakGraphError() - args = [ - arg.value if isinstance(arg, ConstantVariable) else arg - for arg in args - ] - kwargs = { - k: (v.value if isinstance(v, ConstantVariable) else v) - for k, v in kwargs.items() - } - return self.value(*args, **kwargs) - - @VariableFactory.register_from_value - def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker): - if isinstance(value, (types.BuiltinFunctionType)): - return BuiltinVariable(value, graph, tracker) - return None - - def __repr__(self) -> str: - return f"BuiltinVariable({self.value.__name__})" - - -class SliceVariable(VariableBase): - def __init__(self, slice_, graph, tracker): - super().__init__(tracker) - self.value = slice_ - self.graph = graph - - def __repr__(self) -> str: - return f"SliceVariable({self.value})" - - def get_value(self): - return self.value - - @VariableFactory.register_from_value - def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker): - if isinstance(value, slice): - return SliceVariable(value, graph, tracker) - return None - - -class ModuleVariable(VariableBase): - def __init__(self, func, graph, tracker): - super().__init__(tracker) - self.value = func - self.graph = graph - - def get_value(self): - return self.value - - def __repr__(self) -> str: - return f"ModuleVariable({self.value})" - - @VariableFactory.register_from_value - def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker): - if isinstance(value, types.ModuleType): - return ModuleVariable(value, graph, tracker) - return None - - -class DygraphTracerVariable(VariableBase): - # TODO(SigureMo): Remove this trick after we add CompareTracker - def __init__(self, value, graph, tracker): - super().__init__(tracker) - self.value = value - self.graph = graph - - def get_value(self): - return self.value - - def make_stringify_guard(self) -> StringifyExpression: - assert not isinstance( - self.tracker, DummyTracker - ), "Can 
not make guard from dummy tracker" - - frame_value_tracer = self.tracker.trace_value_from_frame() - log_do( - 4, - lambda: print( - f"[Guard]: guard_fn for {self}, tracker={self.tracker.__class__.__name__}, value={frame_value_tracer.expr}" - ), - ) - return StringifyExpression("True", {}) - - def __repr__(self) -> str: - return f"DygraphTracerVariable(is_none={self.value is None})" - - @VariableFactory.register_from_value - def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker): - if isinstance(value, paddle.fluid.dygraph.tracer.Tracer): - return DygraphTracerVariable(value, graph, tracker) - return None - - -class ObjectVariable(VariableBase): - def __init__(self, obj, graph, tracker): - super().__init__(tracker) - self.value = obj - self.graph = graph - - def __repr__(self) -> str: - return f"ObjectVariable({self.value})" - - -class IterVariable(VariableBase): - def __init__(self, obj, graph, tracker): - super().__init__(tracker) - self.hold = obj - self.graph = graph - - @VariableFactory.register_from_value - def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker): - if isinstance(value, collections.abc.Iterable): - return UserDefinedIterVariable(value, graph, tracker) - return None - - -class SequenceIterVariable(IterVariable): - def __init__(self, obj, graph, tracker): - super().__init__(obj, graph, tracker) - self.idx = 0 - - def next(self): - if self.idx < len(self.hold): - val = self.hold[self.idx] - self.idx += 1 - return val - else: - raise StopIteration() - - -class DictIterVariable(IterVariable): - def __init__(self, obj, graph, tracker): - super().__init__(obj, graph, tracker) - self.key_list = [ - ConstantVariable(x, ConstTracker(x)) for x in self.hold - ] - self.idx = 0 - - def next(self): - if self.idx < len(self.key_list): - val = self.key_list[self.idx] - return val - else: - raise StopIteration() - - -class TensorIterVariable(IterVariable): - def __init__(self, obj, graph, tracker): - super().__init__(obj, 
graph, tracker) - - -# what UserDefinedIterVariable holds doesn't matter, because use user defined iterator will trigger break graph -class UserDefinedIterVariable(IterVariable): - def __init__(self, obj, graph, tracker): - super().__init__(obj, graph, tracker) diff --git a/symbolic_trace/opcode_translator/executor/variables/__init__.py b/symbolic_trace/opcode_translator/executor/variables/__init__.py new file mode 100644 index 000000000..e71fd0b31 --- /dev/null +++ b/symbolic_trace/opcode_translator/executor/variables/__init__.py @@ -0,0 +1,44 @@ +from .base import ( # noqa: F401 + ConstTypes, + VariableBase, + VariableFactory, + get_zero_degree_vars, + map_variables, + topo_sort_vars, +) +from .basic import ( # noqa: F401 + ConstantVariable, + DygraphTracerVariable, + ModuleVariable, + ObjectVariable, + SliceVariable, + TensorVariable, +) +from .callable import ( # noqa: F401 + BuiltinVariable, + CallableVariable, + DirectlyCallMethodVariable, + FunctionVariable, + LayerVariable, + MethodVariable, + PaddleApiVariable, + PaddleLayerVariable, + TensorMethodVariable, + UserDefinedFunctionVariable, + UserDefinedGeneratorVariable, + UserDefinedLayerVariable, + UserDefinedMethodVariable, +) +from .container import ( # noqa: F401 + ContainerVariable, + DictVariable, + ListVariable, + TupleVariable, +) +from .iter import ( # noqa: F401 + DictIterVariable, + IterVariable, + SequenceIterVariable, + TensorIterVariable, + UserDefinedIterVariable, +) diff --git a/symbolic_trace/opcode_translator/executor/variables/base.py b/symbolic_trace/opcode_translator/executor/variables/base.py new file mode 100644 index 000000000..8856d3482 --- /dev/null +++ b/symbolic_trace/opcode_translator/executor/variables/base.py @@ -0,0 +1,241 @@ +from __future__ import annotations + +import inspect +from queue import Queue +from typing import TYPE_CHECKING, Any, Callable + +import paddle + +from ....utils import NameGenerator, log, log_do +from ....utils.exceptions import InnerError +from 
..guard import StringifyExpression, union_free_vars +from ..pycode_generator import PyCodeGen +from ..tracker import DummyTracker, GetAttrTracker, Tracker + +if TYPE_CHECKING: + from ..function_graph import FunctionGraph + + +ConstTypes = (int, float, str, bool, type(None)) + + +def get_zero_degree_vars( + variables: set[VariableBase], visited_vars: list[VariableBase] +) -> list[VariableBase]: + return [ + var + for var in variables + if var not in visited_vars + and len(set(var.get_traceable_inputs()) - set(visited_vars)) == 0 + ] + + +def topo_sort_vars( + root_vars: list[VariableBase], +) -> list[VariableBase]: + unique_vars = set() + + for var in root_vars: + unique_vars.add(var) + unique_vars |= set(var.flatten_traceable_inputs()) + + topo_ordered_vars = [] + topo_queue = Queue() + for var in get_zero_degree_vars(unique_vars, topo_ordered_vars): + topo_queue.put(var) + + while not topo_queue.empty(): + var = topo_queue.get() + topo_ordered_vars.append(var) + for zero_degree_var in get_zero_degree_vars( + unique_vars, topo_ordered_vars + ): + if ( + zero_degree_var in topo_queue.queue + or zero_degree_var in topo_ordered_vars + ): + continue + topo_queue.put(zero_degree_var) + return topo_ordered_vars + + +def map_variables(map_func, variables): + def _map_variable(variable): + assert isinstance( + variable, VariableBase + ), f"variable must be VariableBase, got {variable}" + from .container import ContainerVariable + + if isinstance(variable, ContainerVariable): + return paddle.utils.map_structure( + _map_variable, variable.get_wrapped_items() + ) + return map_func(variable) + + return paddle.utils.map_structure(_map_variable, variables) + + +class VariableFactory: + registered_funcs: dict[str, list[str]] = {"default": []} + mapping_str_func: dict[str, Callable] = {} + + @staticmethod + def default_from_value(value, graph, tracker): + from .basic import ObjectVariable + + return ObjectVariable(value, graph, tracker) + + @staticmethod + def 
register_from_value(*, successor: str | None = None): + registered_funcs = VariableFactory.registered_funcs + mapping_str_func = VariableFactory.mapping_str_func + + def _register_from_value(func: Callable): + name = func.__qualname__.split(".")[0] + mapping_str_func[name] = func + if successor is None: + registered_funcs["default"].append(name) + elif successor not in registered_funcs.keys(): + registered_funcs[successor] = [name] + else: + registered_funcs[successor].append(name) + + log(4, VariableFactory.registered_funcs) + return _register_from_value + + @staticmethod + def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker): + registered_funcs = VariableFactory.registered_funcs + + def _find_var(key: str = "default"): + for name in registered_funcs[key]: + if name in registered_funcs.keys(): + var = _find_var(name) + if var is not None: + return var + func = VariableFactory.mapping_str_func[name] + var = func(value, graph, tracker) + if var is not None: + return var + + var = _find_var() + if var is not None: + return var + return VariableFactory.default_from_value(value, graph, tracker) + + +class VariableBase: + """ + VariableBase is a basic concept and each symbols in VM stack is regarded as + an Variable Object in symblic tracing process. 
+ """ + + tracker: Tracker + name_generator = NameGenerator("object_") + + def __init__(self, tracker: Tracker): + self.tracker = tracker + self.id = VariableBase.name_generator.next() + + def __hash__(self): + return hash(self.id) + + def make_stringify_guard(self) -> StringifyExpression: + assert not isinstance( + self.tracker, DummyTracker + ), "Can not make guard from dummy tracker" + + frame_value_tracer = self.tracker.trace_value_from_frame() + log_do( + 4, + lambda: print( + f"[Guard]: guard_fn for {self}, tracker={self.tracker.__class__.__name__}, value={frame_value_tracer.expr}" + ), + ) + return StringifyExpression( + f"{frame_value_tracer.expr} == {self.get_value()}", + union_free_vars(frame_value_tracer.free_vars), + ) + + def get_value(self) -> Any: + raise NotImplementedError() + + def reconstruct(self, codegen: PyCodeGen): + """ + Contruct an opcode and append it into codegen.instructions. + """ + if ( + not isinstance(self.tracker, DummyTracker) + and self.tracker.is_traceable() + ): + self.tracker.gen_instructions(codegen) + else: + self._reconstruct(codegen) + + def _reconstruct(self, codegen: PyCodeGen): + raise NotImplementedError() + + def flatten_items(self) -> list[VariableBase]: + from .container import ContainerVariable + + if not isinstance(self, ContainerVariable): + return [self] + flattened_items = [] + for item in self.get_items(): + flattened_items.extend(item.flatten_items()) + return flattened_items + + def get_inputs(self) -> list[VariableBase]: + return self.tracker.inputs + + def get_traceable_inputs(self) -> list[VariableBase]: + if self.tracker.is_traceable(): + return [] + + return list( + filter(lambda x: x.tracker.is_traceable(), self.tracker.inputs) + ) + + def flatten_traceable_inputs(self) -> list[VariableBase]: + flattened_traceable_inputs: list[VariableBase] = [self] + if self.tracker.is_traceable(): + return flattened_traceable_inputs + + for input in self.get_inputs(): + 
flattened_traceable_inputs.extend(input.flatten_traceable_inputs()) + return flattened_traceable_inputs + + def call_function(self, *args, **kwargs): + pass + + def __getattr__(self, name: str): + if not hasattr(self.value, name): + raise InnerError( + f"{self.__class__.__name__} {self} has no attribute {name}" + ) + attr = getattr(self.value, name) + if inspect.ismethod(attr): + from .callable import UserDefinedMethodVariable + + return UserDefinedMethodVariable( + self, + attr.__func__, + graph=self.graph, + tracker=GetAttrTracker(self, name), + ) + return VariableFactory.from_value( + attr, self.graph, tracker=GetAttrTracker(self, name) + ) + + def getitem(self, *args, **kwargs): + pass + + @VariableFactory.register_from_value() + def from_value( + value: Any, + graph: FunctionGraph | None, + tracker: Tracker, + ): + if isinstance(value, VariableBase): + return value + return None diff --git a/symbolic_trace/opcode_translator/executor/variables/basic.py b/symbolic_trace/opcode_translator/executor/variables/basic.py new file mode 100644 index 000000000..e09a3a454 --- /dev/null +++ b/symbolic_trace/opcode_translator/executor/variables/basic.py @@ -0,0 +1,270 @@ +from __future__ import annotations + +import types +from typing import TYPE_CHECKING, Any + +import paddle + +from ....infer_meta import MetaInfo +from ....symbolic.statement_ir import Symbol +from ....utils import NameGenerator, log_do, paddle_tensor_methods +from ....utils.exceptions import InnerError +from ..guard import StringifyExpression, union_free_vars +from ..pycode_generator import PyCodeGen +from ..tracker import ConstTracker, DummyTracker, GetAttrTracker, Tracker +from .base import ConstTypes, VariableBase, VariableFactory + +if TYPE_CHECKING: + from ..function_graph import FunctionGraph + + +class ConstantVariable(VariableBase): + def __init__( + self, + value: Any, + tracker: Tracker, + ): + super().__init__(tracker) + self.value = value + + def get_value(self): + return self.value + + def 
_reconstruct(self, codegen: PyCodeGen): + codegen.gen_load_const(self.value) + + def __repr__(self) -> str: + return f"ConstantVariable({self.value})" + + def __bool__(self) -> bool: + return bool(self.value) + + def apply_unary_operator(self, magic_name): + operator = getattr(self.value, magic_name) + var = VariableFactory.from_value( + operator(), + None, + tracker=DummyTracker( + [ + self, + ] + ), + ) + return var + + def apply_binary_operator(self, other, magic_name): + if not isinstance(other, ConstantVariable): + return NotImplemented + operator = getattr(self.value, magic_name) + var = VariableFactory.from_value( + operator(other.value), None, tracker=DummyTracker([self, other]) + ) + return var + + @VariableFactory.register_from_value() + def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker): + if isinstance(value, ConstTypes): + return ConstantVariable(value, tracker) + return None + + @staticmethod + def wrap_literal(value: Any) -> ConstantVariable: + if isinstance(value, ConstantVariable): + return value + assert isinstance( + value, ConstTypes + ), f"value: {value},type: {type(value)}" + return ConstantVariable(value, ConstTracker(value)) + + +class TensorVariable(VariableBase): + var_name_generator = NameGenerator("var_") + + def __init__( + self, + tensor: paddle.Tensor | MetaInfo, + graph: FunctionGraph, + tracker: Tracker, + ): + super().__init__(tracker) + if isinstance(tensor, paddle.Tensor): + self.value = tensor + self.meta = MetaInfo.from_tensor(tensor) + elif isinstance(tensor, MetaInfo): + self.value = None + self.meta = tensor + else: + raise InnerError( + "Required type(tensor) is paddle.Tensor or ProxyTensor, but received {}.".format( + type(tensor).__name__ + ) + ) + self.var_name = TensorVariable.var_name_generator.next() + self.graph = graph + + def get_value(self): + if self.value is None: + raise InnerError("Can not get value from a inner tensor variable.") + return self.value + + def get_symbol(self) -> Symbol: + 
return Symbol(self.var_name) + + @property + def out_var_name(self): + return f"{self.graph.out_var_prefix}{self.var_name}" + + def _reconstruct(self, codegen: PyCodeGen): + codegen.gen_load_fast(self.out_var_name) + + def make_stringify_guard(self) -> StringifyExpression: + assert not isinstance( + self.tracker, DummyTracker + ), "Can not make guard from dummy tracker" + + frame_value_tracer = self.tracker.trace_value_from_frame() + log_do( + 4, + lambda: print( + f"[Guard]: guard_fn for {self}, tracker={self.tracker.__class__.__name__}, value={frame_value_tracer.expr}" + ), + ) + return StringifyExpression( + f"str(MetaInfo.from_tensor({frame_value_tracer.expr})) == '{self.meta}'", + union_free_vars( + {"MetaInfo": MetaInfo}, + frame_value_tracer.free_vars, + ), + ) + + def __repr__(self) -> str: + return f"TensorVariable{self.meta}" + + def __getitem__(self, key): + return self.graph.call_tensor_method( + '__getitem__', + self, + VariableFactory.from_value( + key, self.graph, tracker=ConstTracker(key) + ), + ) + + @property + def T(self): + perm = list(range(len(self.meta.shape) - 1, -1, -1)) + perm_var = VariableFactory.from_value( + perm, self.graph, tracker=ConstTracker(perm) + ) + out = self.graph.call_paddle_api(paddle.transpose, self, perm_var) + return out + + @property + def ndim(self): + return ConstantVariable.wrap_literal(len(self.meta.shape)) + + def __getattr__(self, name: str): + if name in paddle_tensor_methods: + from .callable import TensorMethodVariable + + return TensorMethodVariable( + self, name, self.graph, tracker=GetAttrTracker(self, name) + ) + elif name in ["shape", "dtype", "stop_gradient"]: + return VariableFactory.from_value( + getattr(self.meta, name), + self.graph, + tracker=GetAttrTracker(self, name), + ) + elif name in ["T", "ndim"]: + return getattr(self, name) + else: + raise InnerError(f"Unknown Tensor attribute: {name}") + + @VariableFactory.register_from_value() + def from_value(value: Any, graph: FunctionGraph | None, 
tracker: Tracker): + if isinstance(value, (paddle.Tensor, MetaInfo)): + assert graph is not None + return TensorVariable(value, graph, tracker) + return None + + +class ObjectVariable(VariableBase): + def __init__(self, obj, graph, tracker): + super().__init__(tracker) + self.value = obj + self.graph = graph + + def __repr__(self) -> str: + return f"ObjectVariable({self.value})" + + +class SliceVariable(VariableBase): + def __init__(self, slice_, graph, tracker): + super().__init__(tracker) + self.value = slice_ + self.graph = graph + + def __repr__(self) -> str: + return f"SliceVariable({self.value})" + + def get_value(self): + return self.value + + @VariableFactory.register_from_value() + def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker): + if isinstance(value, slice): + return SliceVariable(value, graph, tracker) + return None + + +class ModuleVariable(VariableBase): + def __init__(self, func, graph, tracker): + super().__init__(tracker) + self.value = func + self.graph = graph + + def get_value(self): + return self.value + + def __repr__(self) -> str: + return f"ModuleVariable({self.value})" + + @VariableFactory.register_from_value() + def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker): + if isinstance(value, types.ModuleType): + return ModuleVariable(value, graph, tracker) + return None + + +class DygraphTracerVariable(VariableBase): + # TODO(SigureMo): Remove this trick after we add CompareTracker + def __init__(self, value, graph, tracker): + super().__init__(tracker) + self.value = value + self.graph = graph + + def get_value(self): + return self.value + + def make_stringify_guard(self) -> StringifyExpression: + assert not isinstance( + self.tracker, DummyTracker + ), "Can not make guard from dummy tracker" + + frame_value_tracer = self.tracker.trace_value_from_frame() + log_do( + 4, + lambda: print( + f"[Guard]: guard_fn for {self}, tracker={self.tracker.__class__.__name__}, value={frame_value_tracer.expr}" + 
), + ) + return StringifyExpression("True", {}) + + def __repr__(self) -> str: + return f"DygraphTracerVariable(is_none={self.value is None})" + + @VariableFactory.register_from_value() + def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker): + if isinstance(value, paddle.fluid.dygraph.tracer.Tracer): + return DygraphTracerVariable(value, graph, tracker) + return None diff --git a/symbolic_trace/opcode_translator/executor/variables/callable.py b/symbolic_trace/opcode_translator/executor/variables/callable.py new file mode 100644 index 000000000..becf12d1d --- /dev/null +++ b/symbolic_trace/opcode_translator/executor/variables/callable.py @@ -0,0 +1,373 @@ +from __future__ import annotations + +import inspect +import types +from typing import TYPE_CHECKING, Any, Callable + +import paddle + +from ....symbolic.statement_ir import Symbol +from ....utils import ( + ASSERT, + NameGenerator, + is_break_graph_api, + is_paddle_api, + log_do, +) +from ....utils.exceptions import BreakGraphError, FallbackErrorBase +from ..guard import StringifyExpression, union_free_vars +from ..tracker import DummyTracker, GetAttrTracker, GetItemTracker, Tracker +from .base import VariableBase, VariableFactory +from .basic import ConstantVariable, TensorVariable + +if TYPE_CHECKING: + from ..function_graph import FunctionGraph + + +class CallableVariable(VariableBase): + def __init__(self, graph: FunctionGraph, tracker: Tracker): + super().__init__(tracker) + self.graph = graph + + def __call__(self, *args, **kwargs) -> VariableBase: + return self.call_function(*args, **kwargs) + + def call_function(self, *args, **kwargs): + raise NotImplementedError("call_function is not implemented.") + + +class FunctionVariable(CallableVariable): + def __init__( + self, fn: Callable[..., Any], graph: FunctionGraph, tracker: Tracker + ): + super().__init__(graph, tracker) + self.value = fn + + def get_value(self): + return self.value + + def get_code(self) -> types.CodeType: + return 
self.value.__code__ + + +class UserDefinedFunctionVariable(FunctionVariable): + def __init__( + self, fn: Callable[..., Any], graph: FunctionGraph, tracker: Tracker + ): + super().__init__(fn, graph, tracker) + + def call_function(self, *args, **kwargs) -> VariableBase: + from ..opcode_inline_executor import OpcodeInlineExecutor + + if self.value is ASSERT: + return self.value(args[0].value) + + checkpoint = self.graph.save_memo() + try: + inline_executor = OpcodeInlineExecutor(self, *args, **kwargs) + output = inline_executor.inline_call() + except FallbackErrorBase as e: + self.graph.restore_memo(checkpoint) + raise BreakGraphError( + f"{self.value} is raise a inline call error. {e}" + ) + return output + + @VariableFactory.register_from_value() + def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker): + if isinstance(value, (types.FunctionType)): + return UserDefinedFunctionVariable(value, graph, tracker) + return None + + def __repr__(self) -> str: + return f"UserDefinedFunctionVariable({self.value.__name__})" + + +class PaddleApiVariable(FunctionVariable): + def __init__( + self, fn: Callable[..., Any], graph: FunctionGraph, tracker: Tracker + ): + super().__init__(fn, graph, tracker) + + def call_function(self, *args, **kwargs): + return self.graph.call_paddle_api(self.value, *args, **kwargs) + + @VariableFactory.register_from_value( + successor="UserDefinedFunctionVariable" + ) + def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker): + if callable(value) and is_paddle_api(value): + return PaddleApiVariable(value, graph, tracker) + return None + + def __repr__(self) -> str: + return f"PaddleApiVariable({self.value.__name__})" + + +class MethodVariable(CallableVariable): + def __init__( + self, + bound_instance: VariableBase, + graph: FunctionGraph, + tracker: Tracker, + ): + super().__init__(graph, tracker) + self.bound_instance = bound_instance + + +class UserDefinedMethodVariable(MethodVariable): + def __init__( + 
self, bound_instance, fn, graph: FunctionGraph, tracker: Tracker + ): + super().__init__(bound_instance, graph, tracker) + self.bound_instance = bound_instance + self.fn = fn + + def get_value(self): + return self.fn.__get__( + self.bound_instance, self.bound_instance.__class__ + ) + + def call_function(self, *args, **kwargs): + fn_var = UserDefinedFunctionVariable( + self.fn, self.graph, GetAttrTracker(self, "__func__") + ) + + return fn_var(*(self.bound_instance, *args), **kwargs) + + @VariableFactory.register_from_value() + def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker): + if inspect.ismethod(value): + method_self = VariableFactory.from_value( + value.__self__, graph, DummyTracker([]) + ) + method_var = UserDefinedMethodVariable( + method_self, + value.__func__, + graph, + tracker, + ) + method_self.tracker = GetAttrTracker(method_var, "__self__") + return method_var + return None + + def __repr__(self) -> str: + return f"UserDefinedMethodVariable({self.fn.__name__})" + + +class TensorMethodVariable(MethodVariable): + def __init__( + self, + tensor: TensorVariable, + method_name: str, + graph: FunctionGraph, + tracker: Tracker, + ): + super().__init__(tensor, graph, tracker) + self.tensor = tensor + self.method_name = method_name + + def get_value(self): + return getattr(self.tensor, self.method_name) + + def call_function(self, *args, **kwargs): + return self.graph.call_tensor_method( + self.method_name, self.tensor, *args, **kwargs + ) + + @VariableFactory.register_from_value(successor="UserDefinedMethodVariable") + def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker): + if inspect.ismethod(value) and isinstance( + value.__self__, paddle.Tensor + ): + # NOTE(SigureMo): Since the method_self need method_var as the obj + # of the tracker, we need to temporarily set the tracker of method_self + # to DummyTracker, and set it to GetAttrTracker after method_var is created. 
+ method_self = TensorVariable( + value.__self__, graph, DummyTracker([]) + ) + method_var = TensorMethodVariable( + method_self, + value.__name__, + graph, + tracker, + ) + method_self.tracker = GetAttrTracker(method_var, "__self__") + return method_var + return None + + def __repr__(self) -> str: + return f"TensorMethodVariable({self.method_name})" + + +class DirectlyCallMethodVariable(MethodVariable): + def __init__( + self, bound_instance, fn, graph: FunctionGraph, tracker: Tracker + ): + super().__init__(bound_instance, graph, tracker) + self.bound_instance = bound_instance + self.fn = fn + + def get_value(self): + return self.fn.__get__( + self.bound_instance, self.bound_instance.__class__ + ) + + def call_function(self, *args, **kwargs): + return self.fn(*(self.bound_instance, *args), **kwargs) + + +class LayerVariable(CallableVariable): + def __init__( + self, layer: paddle.nn.Layer, graph: FunctionGraph, tracker: Tracker + ): + super().__init__(graph, tracker) + self.value = layer + + def get_value(self): + return self.value + + def make_stringify_guard(self) -> StringifyExpression: + assert not isinstance( + self.tracker, DummyTracker + ), "Can not make guard from dummy tracker" + + frame_value_tracer = self.tracker.trace_value_from_frame() + log_do( + 4, + lambda: print( + f"[Guard]: guard_fn for {self}, tracker={self.tracker.__class__.__name__}, value={frame_value_tracer.expr}" + ), + ) + return StringifyExpression( + f"id({frame_value_tracer.expr}) == {id(self.get_value())}", + union_free_vars(frame_value_tracer.free_vars), + ) & StringifyExpression( + f"{frame_value_tracer.expr}.training == {self.get_value().training}", + union_free_vars(frame_value_tracer.free_vars), + ) + + +class UserDefinedLayerVariable(LayerVariable): + def __init__( + self, layer: paddle.nn.Layer, graph: FunctionGraph, tracker: Tracker + ): + super().__init__(layer, graph, tracker) + + def call_function(self, *args, **kwargs): + fn_var = UserDefinedFunctionVariable( + 
self.value.__class__.__call__, + self.graph, + GetAttrTracker(self, "__call__"), + ) + + return fn_var(*(self, *args), **kwargs) + + @VariableFactory.register_from_value() + def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker): + if isinstance( + value, paddle.nn.Layer + ) and not value.__module__.startswith("paddle.nn."): + return UserDefinedLayerVariable(value, graph, tracker) + return None + + def __repr__(self) -> str: + return f"UserDefinedLayerVariable({self.value.__class__.__name__})" + + +class BuiltinVariable(CallableVariable): + def __init__( + self, fn: Callable[..., Any], graph: FunctionGraph, tracker: Tracker + ): + super().__init__(graph, tracker) + self.value = fn + + def call_function(self, *args, **kwargs): + # TODO(0x45f): For builtin functions, may have 3 different ways to process as below: + # 1. Simulation execution: ensure correct simulation execution and handle trackers with care + # 2. Trigger the paddle api call + # 3. Trigger fallback + if is_break_graph_api(self.value): + raise BreakGraphError() + args = [ + arg.value if isinstance(arg, ConstantVariable) else arg + for arg in args + ] + kwargs = { + k: (v.value if isinstance(v, ConstantVariable) else v) + for k, v in kwargs.items() + } + return self.value(*args, **kwargs) + + @VariableFactory.register_from_value() + def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker): + if isinstance(value, (types.BuiltinFunctionType)): + return BuiltinVariable(value, graph, tracker) + return None + + def __repr__(self) -> str: + return f"BuiltinVariable({self.value.__name__})" + + +class UserDefinedGeneratorVariable(FunctionVariable): + def __init__( + self, fn: Callable[..., Any], graph: FunctionGraph, tracker: Tracker + ): + super().__init__(fn, graph, tracker) + + def call_function(self, *args, **kwargs) -> VariableBase: + + iter_ = self.value() + return VariableFactory.from_value( + iter_, self.graph, DummyTracker([self]) + ) + + 
@VariableFactory.register_from_value( + successor="UserDefinedFunctionVariable" + ) + def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker): + if inspect.isgeneratorfunction(value): + return UserDefinedGeneratorVariable(value, graph, tracker) + return None + + def __repr__(self) -> str: + return f"UserDefinedGeneratorVariable({self.value.__name__})" + + +class PaddleLayerVariable(LayerVariable): + layer_name_generator = NameGenerator("layer_") + + def __init__( + self, layer: paddle.nn.Layer, graph: FunctionGraph, tracker: Tracker + ): + super().__init__(layer, graph, tracker) + self.name = self.layer_name_generator.next() + + def get_symbol(self) -> Symbol: + return Symbol(self.name) + + def call_function(self, *args, **kwargs): + # TODO: Remove this trick after we support for-loop. + if isinstance(self.value, paddle.nn.Sequential): + assert len(args) == 1, "Sequential only accept one input" + input = args[0] + for i, layer in enumerate(self.value._sub_layers.values()): + layer_var = VariableFactory.from_value( + layer, self.graph, tracker=GetItemTracker(self, i) + ) + assert isinstance(layer_var, LayerVariable) + input = layer_var(input) + return input + return self.graph.call_layer(self, *args, **kwargs) + + @VariableFactory.register_from_value(successor="UserDefinedLayerVariable") + def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker): + # TODO(SigureMo): Add a more common way to check if a value is a paddle builtin layer. + if isinstance(value, paddle.nn.Layer) and value.__module__.startswith( + "paddle.nn." 
+ ): + return PaddleLayerVariable(value, graph, tracker) + return None + + def __repr__(self) -> str: + return f"PaddleLayerVariable({self.value.__class__.__name__})" diff --git a/symbolic_trace/opcode_translator/executor/variables/container.py b/symbolic_trace/opcode_translator/executor/variables/container.py new file mode 100644 index 000000000..5c4a7b063 --- /dev/null +++ b/symbolic_trace/opcode_translator/executor/variables/container.py @@ -0,0 +1,338 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from ....utils.exceptions import InnerError +from ..pycode_generator import PyCodeGen +from ..tracker import ( + ConstTracker, + DummyTracker, + GetAttrTracker, + GetItemTracker, + Tracker, +) +from .base import ConstTypes, VariableBase, VariableFactory +from .basic import ConstantVariable + +if TYPE_CHECKING: + from ..function_graph import FunctionGraph + + +class ContainerVariable(VariableBase): + def get_items(self) -> list[VariableBase]: + raise NotImplementedError() + + def __len__(self): + raise NotImplementedError() + + def __bool__(self): + return len(self) > 0 + + +class ListVariable(ContainerVariable): + def __init__( + self, + val_list: list[VariableBase], + graph: FunctionGraph, + tracker: Tracker, + ): + super().__init__(tracker) + self.graph = graph + # everything in stack is VariableBase, so just accept the input list is ok + self.value = val_list + + def get_value(self): + return [self[i].get_value() for i in range(len(self))] + + def _reconstruct(self, codegen: PyCodeGen): + size = len(self) + for idx in range(size): + self[idx].reconstruct(codegen) + codegen.gen_build_list(size) + + def get_items(self): + size = len(self) + return [self[idx] for idx in range(size)] + + def get_wrapped_items(self): + return self.get_items() + + def __repr__(self) -> str: + return f"ListVariable(len={len(self)})" + + def __len__(self): + return len(self.value) + + def __getitem__(self, key): + ''' + we need to make sure that: + 
before an inplace change happens to ListVariable, + the related items should already be wrapped as VariableBase + + if not, tracker might be set to a wrong elem + ''' + if isinstance(key, VariableBase): + raise InnerError( + f"[{self.__class__.__name__}]: recieved {key} as key." + ) + + retval = self.value[key] + + # if list is an input of funciton, we need make sure __getitem__ returns a VariableBase + retval = VariableFactory.from_value( + retval, self.graph, tracker=GetItemTracker(self, key) + ) + + return retval + + def __setitem__(self, key, value): + ''' + why __setitem__ is ok: + + case: + def f(x = [t0, t1]) + ... + x[0] = 0 + ... + + 1. if setitem happens after get t0: t0 is a VariableBase (transformed at getitem), so it is ok + 2. if setitem happens before get t0: t0 will not be used + ''' + if isinstance(key, VariableBase): + raise InnerError( + f"[{self.__class__.__name__}]: received {key} as key." + ) + + if not isinstance(value, VariableBase): + raise InnerError( + f"[{self.__class__.__name__}]: received {value} to set value." + ) + self.value[key] = value + + def __delitem__(self, key): + if isinstance(key, VariableBase): + raise InnerError( + f"[{self.__class__.__name__}]: received {key} as key to delete." 
+ ) + del self.value[key] + + @VariableFactory.register_from_value() + def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker): + if isinstance(value, list): + assert graph is not None + return ListVariable(value, graph=graph, tracker=tracker) + return None + + +class TupleVariable(ContainerVariable): + def __init__( + self, + val_tuple: list[VariableBase], + graph: FunctionGraph, + tracker: Tracker, + ): + super().__init__(tracker) + self.graph = graph + # exactly it is a list (need replace item with VariableBase) + self.value = list(val_tuple) + + def get_value(self): + return tuple(self[i].get_value() for i in range(len(self))) + + def _reconstruct(self, codegen: PyCodeGen): + size = len(self) + for idx in range(size): + self[idx].reconstruct(codegen) + codegen.gen_build_tuple(size) + + def get_items(self): + size = len(self) + return [self[idx] for idx in range(size)] + + def get_wrapped_items(self): + return self.get_items() + + def __repr__(self) -> str: + return f"TupleVariable(len={len(self)})" + + def __len__(self): + return len(self.value) + + def __getitem__(self, key): + if isinstance(key, VariableBase): + raise InnerError( + f"[{self.__class__.__name__}]: recieved {key} as key." + ) + retval = self.value[key] + + return VariableFactory.from_value( + retval, graph=self.graph, tracker=GetItemTracker(self, key) + ) + + def __setitem__(self, key, value): + raise InnerError( + f"[{self.__class__.__name__}]: setitem is not allowed." + ) + + def __delitem__(self, key): + raise InnerError( + f"[{self.__class__.__name__}]: delitem is not allowed." 
class DictVariable(ContainerVariable):
    '''
    Container variable wrapping a Python dict.

    Keys are kept as raw Python objects; values are wrapped on access
    through ``VariableFactory.from_value`` with a ``GetItemTracker``.
    ``keys`` / ``values`` / ``items`` are served by the matching
    ``override_method_*`` methods via ``__getattr__``.
    '''

    def __init__(
        self,
        val_dict: dict[object, VariableBase],
        graph: FunctionGraph,
        tracker: Tracker,
    ):
        super().__init__(tracker)
        self.graph = graph
        self.value = val_dict

    def _check_const_key(self, key):
        '''Raise InnerError unless ``key`` is an instance of ``ConstTypes``.'''
        if not isinstance(key, ConstTypes):
            raise InnerError(
                f"[{self.__class__.__name__}]: recieved {key} as key."
            )

    def get_value(self):
        '''Return a plain dict mapping each key to its item's ``get_value()``.'''
        return {key: self[key].get_value() for key in self.value}

    def _reconstruct(self, codegen: PyCodeGen):
        '''
        Reconstruct every key/value pair in order, then emit a build-map.

        Raises InnerError when a key is not an instance of ``ConstTypes``.
        '''
        from .basic import ConstantVariable

        size = len(self)
        for key in self.value:
            self._check_const_key(key)
            key_var = ConstantVariable.wrap_literal(key)
            value_var = self[key]
            key_var.reconstruct(codegen)
            value_var.reconstruct(codegen)
        codegen.gen_build_map(size)

    def get_items(self):
        '''
        Return a flat list ``[key_var, value_var, key_var, value_var, ...]``.

        Raises InnerError when a key is not an instance of ``ConstTypes``.
        '''
        items = []
        for key in self.value:
            self._check_const_key(key)
            key_var = VariableFactory.from_value(
                key, self.graph, tracker=ConstTracker(key)
            )
            value_var = self[key]
            items.extend([key_var, value_var])
        return items

    def get_wrapped_items(self):
        '''
        Return a dict mapping each raw key to its wrapped value.

        Raises InnerError when a key is not an instance of ``ConstTypes``.
        '''
        items = {}
        for key in self.value:
            self._check_const_key(key)
            items[key] = self[key]
        return items

    def __repr__(self) -> str:
        return f"DictVariable(len={len(self)})"

    def __len__(self):
        return len(self.value)

    def __getitem__(self, key):
        '''
        Fetch ``self.value[key]`` and wrap it via ``VariableFactory``.

        Raises InnerError when ``key`` is itself a VariableBase.
        '''
        if isinstance(key, VariableBase):
            raise InnerError(
                f"[{self.__class__.__name__}]: recieved {key} as key."
            )

        retval = self.value[key]

        return VariableFactory.from_value(
            retval, self.graph, tracker=GetItemTracker(self, key)
        )

    def __setitem__(self, key, value):
        '''
        Store ``value`` under the raw key ``key``.

        Raises InnerError when ``key`` is a VariableBase or when ``value``
        is not a ConstantVariable.
        '''
        if isinstance(key, VariableBase):
            raise InnerError(
                f"[{self.__class__.__name__}]: recieved {key} as key."
            )

        # NOTE(review): ListVariable.__setitem__ accepts any VariableBase;
        # confirm that restricting dict values to ConstantVariable is
        # intended.
        if not isinstance(value, ConstantVariable):
            raise InnerError(
                f"[{self.__class__.__name__}]: recieved {value} to set value."
            )

        self.value[key] = value

    def __delitem__(self, key):
        '''Delete ``key``; raises InnerError when ``key`` is a VariableBase.'''
        if isinstance(key, VariableBase):
            raise InnerError(
                f"[{self.__class__.__name__}]: recieved {key} as key to delete."
            )
        del self.value[key]

    def override_method_keys(self):
        '''Return a SequenceIterVariable over the keys, each a ConstantVariable.'''
        from .iter import SequenceIterVariable

        raw_list = [ConstantVariable(x, ConstTracker(x)) for x in self.value]
        key_list = VariableFactory.from_value(
            raw_list, self.graph, ConstTracker(raw_list)
        )
        return SequenceIterVariable(
            key_list, self.graph, DummyTracker([key_list])
        )

    def override_method_values(self):
        '''Return a SequenceIterVariable over the wrapped values.'''
        from .iter import SequenceIterVariable

        raw_list = list(self.get_wrapped_items().values())
        value_list = VariableFactory.from_value(
            raw_list, self.graph, DummyTracker([self])
        )
        return SequenceIterVariable(
            value_list, self.graph, DummyTracker([value_list])
        )

    def override_method_items(self):
        '''Return a SequenceIterVariable over ``(key_var, value_var)`` pairs.'''
        from .iter import SequenceIterVariable

        keys = [ConstantVariable(x, ConstTracker(x)) for x in self.value]
        values = list(self.get_wrapped_items().values())
        raw_list = list(zip(keys, values))
        item_list = VariableFactory.from_value(
            raw_list, self.graph, DummyTracker([self])
        )
        return SequenceIterVariable(
            item_list, self.graph, DummyTracker([item_list])
        )

    def __getattr__(self, name):
        '''
        Resolve dict methods that have an ``override_method_<name>``.

        Returns a DirectlyCallMethodVariable bound to the override.
        Raises NotImplementedError when no override exists for ``name``.
        '''
        name_ = "override_method_" + name
        # Look the override up on the class, not the instance: a failed
        # instance lookup (``hasattr(self, name_)``) re-enters __getattr__
        # with an ever longer name and ends in RecursionError instead of
        # the intended NotImplementedError.
        if not hasattr(type(self), name_):
            raise NotImplementedError(
                f"attribute {name} for dict is not implemented"
            )

        from .callable import DirectlyCallMethodVariable

        method = getattr(self, name_)
        return DirectlyCallMethodVariable(
            self,
            method.__func__,
            self.graph,
            GetAttrTracker(self, name),
        )

    @VariableFactory.register_from_value()
    def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker):
        # Factory hook: only dicts are handled here.
        if isinstance(value, dict):
            assert graph is not None
            return DictVariable(value, graph=graph, tracker=tracker)
        # Explicit "not mine" result, matching the list/tuple factories.
        return None
class DictIterVariable(IterVariable):
    """Iterator variable that walks the keys of the held object.

    Every element obtained by iterating ``self.hold`` (presumably a dict,
    so its keys -- verify against callers) is wrapped once, at
    construction time, in a ``ConstantVariable`` tracked by a
    ``ConstTracker``. ``next`` then steps through that snapshot.
    """

    def __init__(self, obj, graph, tracker):
        super().__init__(obj, graph, tracker)
        self.key_list = [
            ConstantVariable(x, ConstTracker(x)) for x in self.hold
        ]
        # Cursor into ``key_list``: index of the next key to hand out.
        self.idx = 0

    def next(self):
        """Return the next wrapped key and advance the cursor.

        Raises:
            StopIteration: once every key has been returned.
        """
        if self.idx >= len(self.key_list):
            raise StopIteration()
        val = self.key_list[self.idx]
        # Advance, as SequenceIterVariable.next does; without this the
        # first key would be returned forever and iteration never ends.
        self.idx += 1
        return val
# What a UserDefinedIterVariable holds does not matter: using a
# user-defined iterator triggers a graph break.
class UserDefinedIterVariable(IterVariable):
    """Iterator variable for user-defined iterables."""

    def __init__(self, obj, graph, tracker):
        # Nothing extra to set up; state lives entirely in IterVariable.
        super().__init__(obj, graph=graph, tracker=tracker)