From ad003f7449b8dce579c4e2a5ff3e40dda7f4305f Mon Sep 17 00:00:00 2001 From: Zhan Rongrui <2742392377@qq.com> Date: Sat, 8 Jul 2023 08:21:08 +0000 Subject: [PATCH] add assert --- .../executor/mutable_data.py | 6 ++-- .../executor/variables/base.py | 1 + .../executor/variables/basic.py | 23 ++++++------- .../executor/variables/callable.py | 33 ++++++++++++------- .../executor/variables/container.py | 18 +++++++--- .../executor/variables/iter.py | 4 ++- 6 files changed, 55 insertions(+), 30 deletions(-) diff --git a/sot/opcode_translator/executor/mutable_data.py b/sot/opcode_translator/executor/mutable_data.py index 7ec385b31..cfc6c82cf 100644 --- a/sot/opcode_translator/executor/mutable_data.py +++ b/sot/opcode_translator/executor/mutable_data.py @@ -237,8 +237,10 @@ def get(self, key): write_cache = self.reproduce(self.version) return write_cache[key] - def get_all(self): - return self.reproduce(self.version) + def get_all(self) -> list[Any]: + items = self.reproduce(self.version) + assert isinstance(items, list) + return items @record_mutation def set(self, key: int, value: Any): diff --git a/sot/opcode_translator/executor/variables/base.py b/sot/opcode_translator/executor/variables/base.py index 2f01896dc..55cc6736f 100644 --- a/sot/opcode_translator/executor/variables/base.py +++ b/sot/opcode_translator/executor/variables/base.py @@ -532,6 +532,7 @@ def __call__(self, *args, **kwargs): self.graph, GetAttrTracker(self, '__class__'), ) + assert class_var is not None # if __call__ is a method, we should add self to arguments. if inspect.ismethod(self.get_value().__call__): args = (self,) + args diff --git a/sot/opcode_translator/executor/variables/basic.py b/sot/opcode_translator/executor/variables/basic.py index 37bce98dc..c453e6335 100644 --- a/sot/opcode_translator/executor/variables/basic.py +++ b/sot/opcode_translator/executor/variables/basic.py @@ -122,7 +122,7 @@ def bool_not(self): @VariableFactory.register_from_value() def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker): - if isinstance(value, ConstTypes): + if isinstance(value, ConstTypes) and graph is not None: return ConstantVariable(value, graph, tracker) return None @@ -200,7 +200,7 @@ def get_value(self): @VariableFactory.register_from_value() def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker): - if isinstance(value, (paddle.dtype)): + if isinstance(value, (paddle.dtype)) and graph is not None: return DataVariable(value, graph, tracker) @@ -291,21 +291,21 @@ def main_info(self) -> dict[str, Any]: } def getitem(self, key): - return self.graph.call_tensor_method( - '__getitem__', - self, - VariableFactory.from_value( - key, self.graph, tracker=ConstTracker(key) - ), + var = VariableFactory.from_value( + key, self.graph, tracker=ConstTracker(key) ) + assert var is not None + return self.graph.call_tensor_method('__getitem__', self, var) def setitem(self, key, value): + var = VariableFactory.from_value( + key, self.graph, tracker=ConstTracker(key) + ) + assert var is not None return self.graph.call_tensor_method( '__setitem__', self, - VariableFactory.from_value( - key, self.graph, tracker=ConstTracker(key) - ), + var, value, ) @@ -318,6 +318,7 @@ def T(self): perm_var = VariableFactory.from_value( perm, self.graph, tracker=ConstTracker(perm) ) + assert perm_var is not None out = self.graph.call_paddle_api(paddle.transpose, self, perm_var) return out diff --git a/sot/opcode_translator/executor/variables/callable.py b/sot/opcode_translator/executor/variables/callable.py index 954faaf7c..952cb2613 100644 --- a/sot/opcode_translator/executor/variables/callable.py +++ b/sot/opcode_translator/executor/variables/callable.py @@ -29,7 +29,7 @@ Tracker, ) from .base import VariableBase, VariableFactory -from .basic import ConstantVariable, ObjectVariable, PrintStmtVariable +from .basic import ConstantVariable, PrintStmtVariable if TYPE_CHECKING: from ..function_graph import FunctionGraph @@ -72,6 +72,7 @@ def bind(self, instance: VariableBase, name: str): graph=self.graph, tracker=GetAttrTracker(instance, "__class__"), ) + assert class_var is not None self.tracker = GetAttrTracker(class_var, name) return method_var @@ -110,7 +111,7 @@ def call_function(self, *args, **kwargs) -> VariableBase: @VariableFactory.register_from_value() def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker): - if isinstance(value, (types.FunctionType)): + if isinstance(value, (types.FunctionType)) and graph is not None: return UserDefinedFunctionVariable(value, graph, tracker) return None @@ -138,7 +139,7 @@ def call_function(self, *args, **kwargs): successor="UserDefinedFunctionVariable" ) def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker): - if callable(value) and is_paddle_api(value): + if callable(value) and is_paddle_api(value) and graph is not None: return PaddleApiVariable(value, graph, tracker) return None @@ -224,7 +225,8 @@ def wrap_method( value.__func__, graph, DanglingTracker() ) assert isinstance(instance_var, VariableBase) - assert isinstance(fn_var, (FunctionVariable, ObjectVariable)) + assert isinstance(fn_var, FunctionVariable) + assert isinstance(graph, FunctionGraph) method_var = MethodVariable( instance_var, fn_var, @@ -301,9 +303,11 @@ def call_function(self, *args, **kwargs): @VariableFactory.register_from_value(successor="PaddleApiVariable") def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker): - if isinstance( - value, paddle.nn.Layer - ) and not value.__module__.startswith("paddle.nn."): + if ( + isinstance(value, paddle.nn.Layer) + and not value.__module__.startswith("paddle.nn.") + and graph is not None + ): return UserDefinedLayerVariable(value, graph, tracker) return None @@ -341,6 +345,7 @@ def call_function(self, *args, **kwargs): self.graph, GetAttrTracker(args[0], "__class__"), ) + assert isinstance(class_var, VariableBase) fn_var = VariableFactory.from_value( class_fn, self.graph, @@ -356,7 +361,7 @@ def call_function(self, *args, **kwargs): @VariableFactory.register_from_value() def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker): - if is_builtin_fn(value): + if is_builtin_fn(value) and graph is not None: return BuiltinVariable(value, graph, tracker) return None @@ -375,15 +380,17 @@ def __init__( def call_function(self, *args, **kwargs) -> VariableBase: iter_ = self.value() - return VariableFactory.from_value( + var = VariableFactory.from_value( iter_, self.graph, DummyTracker([self]) ) + assert var is not None + return var @VariableFactory.register_from_value( successor="UserDefinedFunctionVariable" ) def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker): - if inspect.isgeneratorfunction(value): + if inspect.isgeneratorfunction(value) and graph is not None: return UserDefinedGeneratorVariable(value, graph, tracker) return None @@ -421,8 +428,10 @@ def call_function(self, *args, **kwargs): @VariableFactory.register_from_value(successor="UserDefinedLayerVariable") def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker): # TODO(SigureMo): Add a more common way to check if a value is a paddle builtin layer. - if isinstance(value, paddle.nn.Layer) and value.__module__.startswith( - "paddle.nn." + if ( + isinstance(value, paddle.nn.Layer) + and value.__module__.startswith("paddle.nn.") + and graph is not None ): return PaddleLayerVariable(value, graph, tracker) return None diff --git a/sot/opcode_translator/executor/variables/container.py b/sot/opcode_translator/executor/variables/container.py index e48ab8b20..8ee5cdb76 100644 --- a/sot/opcode_translator/executor/variables/container.py +++ b/sot/opcode_translator/executor/variables/container.py @@ -107,7 +107,8 @@ def proxy_getter(self, data, key): ) def get_value(self): - return [item.get_value() for item in self.proxy.get_all()] + items = self.proxy.get_all() + return [item.get_value() for item in items] def get_type(self): return list @@ -142,8 +143,9 @@ def getitem(self, key): raise InnerError(f"List {self} out of range (index={key})") return res elif isinstance(key, slice): + items = self.proxy.get_all() return VariableFactory.from_value( - self.proxy.get_all()[key], + items[key], self.graph, tracker=GetItemTracker(self, key), ) @@ -270,6 +272,7 @@ def sort(self, key=None, reverse=None): key = VariableFactory.from_value( lambda x: x, self.graph, DanglingTracker() ) + assert key is not None if reverse is None: reverse = ConstantVariable.wrap_literal(False, self.graph) @@ -325,7 +328,7 @@ def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker): class TupleVariable(ContainerVariable): def __init__( self, - val_tuple: tuple[VariableBase], + val_tuple: tuple[VariableBase, ...], graph: FunctionGraph, tracker: Tracker, ): @@ -422,7 +425,7 @@ def repeat(self, length): @VariableFactory.register_from_value() def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker): - if isinstance(value, tuple): + if isinstance(value, tuple) and graph is not None: return TupleVariable(value, graph, tracker) return None @@ -519,6 +522,7 @@ def get(self, key, default=None): if isinstance(self.proxy.get(key), MutableDictLikeData.Empty): if isinstance(default, VariableBase): return default + # TODO: VariableFactory.from_value maybe need 3 args? return VariableFactory.from_value(default) return self.getitem(key) @@ -576,6 +580,7 @@ def keys(self): key_list = VariableFactory.from_value( raw_list, self.graph, ConstTracker(raw_list) ) + assert key_list is not None return SequenceIterVariable( key_list, self.graph, DummyTracker([key_list]) ) @@ -587,6 +592,7 @@ def values(self): value_list = VariableFactory.from_value( raw_list, self.graph, DummyTracker([self]) ) + assert value_list is not None return SequenceIterVariable( value_list, self.graph, DummyTracker([value_list]) ) @@ -603,6 +609,7 @@ def items(self): item_list = VariableFactory.from_value( raw_list, self.graph, DummyTracker([self]) ) + assert item_list is not None return SequenceIterVariable( item_list, self.graph, DummyTracker([item_list]) ) @@ -633,6 +640,7 @@ def pop(self, key, default=None): if isinstance(self.proxy.get(key), MutableDictLikeData.Empty): if isinstance(default, VariableBase): return default + # TODO: VariableFactory.from_value maybe need 3 args? return VariableFactory.from_value(default) # default is not None, or key is in dict @@ -643,6 +651,8 @@ def pop(self, key, default=None): def popitem(self): key = self.keys().hold.get_value()[-1] value = self.getitem(key) + assert isinstance(key, VariableBase) + assert isinstance(value, VariableBase) new_tuple_variable = TupleVariable( (key, value), self.graph, DummyTracker([self]) ) diff --git a/sot/opcode_translator/executor/variables/iter.py b/sot/opcode_translator/executor/variables/iter.py index f2b7ac0fc..75467e9cf 100644 --- a/sot/opcode_translator/executor/variables/iter.py +++ b/sot/opcode_translator/executor/variables/iter.py @@ -13,8 +13,10 @@ class IterVariable(VariableBase): """ def __init__(self, obj, graph, tracker): + from .container import ContainerVariable + super().__init__(tracker) - assert isinstance(obj, VariableBase) + assert isinstance(obj, ContainerVariable) self.hold = obj self.graph = graph