Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
add assert
Browse files Browse the repository at this point in the history
  • Loading branch information
zrr1999 committed Jul 8, 2023
1 parent a28875e commit ad003f7
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 30 deletions.
6 changes: 4 additions & 2 deletions sot/opcode_translator/executor/mutable_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,8 +237,10 @@ def get(self, key):
write_cache = self.reproduce(self.version)
return write_cache[key]

def get_all(self):
return self.reproduce(self.version)
def get_all(self) -> list[Any]:
items = self.reproduce(self.version)
assert isinstance(items, list)
return items

@record_mutation
def set(self, key: int, value: Any):
Expand Down
1 change: 1 addition & 0 deletions sot/opcode_translator/executor/variables/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,7 @@ def __call__(self, *args, **kwargs):
self.graph,
GetAttrTracker(self, '__class__'),
)
assert class_var is not None
# if __call__ is a method, we should add self to arguments.
if inspect.ismethod(self.get_value().__call__):
args = (self,) + args
Expand Down
23 changes: 12 additions & 11 deletions sot/opcode_translator/executor/variables/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def bool_not(self):

@VariableFactory.register_from_value()
def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker):
if isinstance(value, ConstTypes):
if isinstance(value, ConstTypes) and graph is not None:
return ConstantVariable(value, graph, tracker)
return None

Expand Down Expand Up @@ -200,7 +200,7 @@ def get_value(self):

@VariableFactory.register_from_value()
def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker):
if isinstance(value, (paddle.dtype)):
if isinstance(value, (paddle.dtype)) and graph is not None:
return DataVariable(value, graph, tracker)


Expand Down Expand Up @@ -291,21 +291,21 @@ def main_info(self) -> dict[str, Any]:
}

def getitem(self, key):
return self.graph.call_tensor_method(
'__getitem__',
self,
VariableFactory.from_value(
key, self.graph, tracker=ConstTracker(key)
),
var = VariableFactory.from_value(
key, self.graph, tracker=ConstTracker(key)
)
assert var is not None
return self.graph.call_tensor_method('__getitem__', self, var)

def setitem(self, key, value):
var = VariableFactory.from_value(
key, self.graph, tracker=ConstTracker(key)
)
assert var is not None
return self.graph.call_tensor_method(
'__setitem__',
self,
VariableFactory.from_value(
key, self.graph, tracker=ConstTracker(key)
),
var,
value,
)

Expand All @@ -318,6 +318,7 @@ def T(self):
perm_var = VariableFactory.from_value(
perm, self.graph, tracker=ConstTracker(perm)
)
assert perm_var is not None
out = self.graph.call_paddle_api(paddle.transpose, self, perm_var)
return out

Expand Down
33 changes: 21 additions & 12 deletions sot/opcode_translator/executor/variables/callable.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
Tracker,
)
from .base import VariableBase, VariableFactory
from .basic import ConstantVariable, ObjectVariable, PrintStmtVariable
from .basic import ConstantVariable, PrintStmtVariable

if TYPE_CHECKING:
from ..function_graph import FunctionGraph
Expand Down Expand Up @@ -72,6 +72,7 @@ def bind(self, instance: VariableBase, name: str):
graph=self.graph,
tracker=GetAttrTracker(instance, "__class__"),
)
assert class_var is not None
self.tracker = GetAttrTracker(class_var, name)
return method_var

Expand Down Expand Up @@ -110,7 +111,7 @@ def call_function(self, *args, **kwargs) -> VariableBase:

@VariableFactory.register_from_value()
def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker):
if isinstance(value, (types.FunctionType)):
if isinstance(value, (types.FunctionType)) and graph is not None:
return UserDefinedFunctionVariable(value, graph, tracker)
return None

Expand Down Expand Up @@ -138,7 +139,7 @@ def call_function(self, *args, **kwargs):
successor="UserDefinedFunctionVariable"
)
def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker):
if callable(value) and is_paddle_api(value):
if callable(value) and is_paddle_api(value) and graph is not None:
return PaddleApiVariable(value, graph, tracker)
return None

Expand Down Expand Up @@ -224,7 +225,8 @@ def wrap_method(
value.__func__, graph, DanglingTracker()
)
assert isinstance(instance_var, VariableBase)
assert isinstance(fn_var, (FunctionVariable, ObjectVariable))
assert isinstance(fn_var, FunctionVariable)
assert isinstance(graph, FunctionGraph)
method_var = MethodVariable(
instance_var,
fn_var,
Expand Down Expand Up @@ -301,9 +303,11 @@ def call_function(self, *args, **kwargs):

@VariableFactory.register_from_value(successor="PaddleApiVariable")
def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker):
if isinstance(
value, paddle.nn.Layer
) and not value.__module__.startswith("paddle.nn."):
if (
isinstance(value, paddle.nn.Layer)
and not value.__module__.startswith("paddle.nn.")
and graph is not None
):
return UserDefinedLayerVariable(value, graph, tracker)
return None

Expand Down Expand Up @@ -341,6 +345,7 @@ def call_function(self, *args, **kwargs):
self.graph,
GetAttrTracker(args[0], "__class__"),
)
assert isinstance(class_var, VariableBase)
fn_var = VariableFactory.from_value(
class_fn,
self.graph,
Expand All @@ -356,7 +361,7 @@ def call_function(self, *args, **kwargs):

@VariableFactory.register_from_value()
def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker):
if is_builtin_fn(value):
if is_builtin_fn(value) and graph is not None:
return BuiltinVariable(value, graph, tracker)
return None

Expand All @@ -375,15 +380,17 @@ def __init__(

def call_function(self, *args, **kwargs) -> VariableBase:
iter_ = self.value()
return VariableFactory.from_value(
var = VariableFactory.from_value(
iter_, self.graph, DummyTracker([self])
)
assert var is not None
return var

@VariableFactory.register_from_value(
successor="UserDefinedFunctionVariable"
)
def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker):
if inspect.isgeneratorfunction(value):
if inspect.isgeneratorfunction(value) and graph is not None:
return UserDefinedGeneratorVariable(value, graph, tracker)
return None

Expand Down Expand Up @@ -421,8 +428,10 @@ def call_function(self, *args, **kwargs):
@VariableFactory.register_from_value(successor="UserDefinedLayerVariable")
def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker):
# TODO(SigureMo): Add a more common way to check if a value is a paddle builtin layer.
if isinstance(value, paddle.nn.Layer) and value.__module__.startswith(
"paddle.nn."
if (
isinstance(value, paddle.nn.Layer)
and value.__module__.startswith("paddle.nn.")
and graph is not None
):
return PaddleLayerVariable(value, graph, tracker)
return None
Expand Down
18 changes: 14 additions & 4 deletions sot/opcode_translator/executor/variables/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,8 @@ def proxy_getter(self, data, key):
)

def get_value(self):
return [item.get_value() for item in self.proxy.get_all()]
items = self.proxy.get_all()
return [item.get_value() for item in items]

def get_type(self):
return list
Expand Down Expand Up @@ -142,8 +143,9 @@ def getitem(self, key):
raise InnerError(f"List {self} out of range (index={key})")
return res
elif isinstance(key, slice):
items = self.proxy.get_all()
return VariableFactory.from_value(
self.proxy.get_all()[key],
items[key],
self.graph,
tracker=GetItemTracker(self, key),
)
Expand Down Expand Up @@ -270,6 +272,7 @@ def sort(self, key=None, reverse=None):
key = VariableFactory.from_value(
lambda x: x, self.graph, DanglingTracker()
)
assert key is not None
if reverse is None:
reverse = ConstantVariable.wrap_literal(False, self.graph)

Expand Down Expand Up @@ -325,7 +328,7 @@ def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker):
class TupleVariable(ContainerVariable):
def __init__(
self,
val_tuple: tuple[VariableBase],
val_tuple: tuple[VariableBase, ...],
graph: FunctionGraph,
tracker: Tracker,
):
Expand Down Expand Up @@ -422,7 +425,7 @@ def repeat(self, length):

@VariableFactory.register_from_value()
def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker):
if isinstance(value, tuple):
if isinstance(value, tuple) and graph is not None:
return TupleVariable(value, graph, tracker)
return None

Expand Down Expand Up @@ -519,6 +522,7 @@ def get(self, key, default=None):
if isinstance(self.proxy.get(key), MutableDictLikeData.Empty):
if isinstance(default, VariableBase):
return default
# TODO: VariableFactory.from_value maybe need 3 args?
return VariableFactory.from_value(default)

return self.getitem(key)
Expand Down Expand Up @@ -576,6 +580,7 @@ def keys(self):
key_list = VariableFactory.from_value(
raw_list, self.graph, ConstTracker(raw_list)
)
assert key_list is not None
return SequenceIterVariable(
key_list, self.graph, DummyTracker([key_list])
)
Expand All @@ -587,6 +592,7 @@ def values(self):
value_list = VariableFactory.from_value(
raw_list, self.graph, DummyTracker([self])
)
assert value_list is not None
return SequenceIterVariable(
value_list, self.graph, DummyTracker([value_list])
)
Expand All @@ -603,6 +609,7 @@ def items(self):
item_list = VariableFactory.from_value(
raw_list, self.graph, DummyTracker([self])
)
assert item_list is not None
return SequenceIterVariable(
item_list, self.graph, DummyTracker([item_list])
)
Expand Down Expand Up @@ -633,6 +640,7 @@ def pop(self, key, default=None):
if isinstance(self.proxy.get(key), MutableDictLikeData.Empty):
if isinstance(default, VariableBase):
return default
# TODO: VariableFactory.from_value maybe need 3 args?
return VariableFactory.from_value(default)

# default is not None, or key is in dict
Expand All @@ -643,6 +651,8 @@ def pop(self, key, default=None):
def popitem(self):
key = self.keys().hold.get_value()[-1]
value = self.getitem(key)
assert isinstance(key, VariableBase)
assert isinstance(value, VariableBase)
new_tuple_variable = TupleVariable(
(key, value), self.graph, DummyTracker([self])
)
Expand Down
4 changes: 3 additions & 1 deletion sot/opcode_translator/executor/variables/iter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@ class IterVariable(VariableBase):
"""

def __init__(self, obj, graph, tracker):
from .container import ContainerVariable

super().__init__(tracker)
assert isinstance(obj, VariableBase)
assert isinstance(obj, ContainerVariable)
self.hold = obj
self.graph = graph

Expand Down

0 comments on commit ad003f7

Please sign in to comment.