Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

[Dispatcher] Support more ListVariable operations #255

Merged
merged 10 commits into from
Jul 11, 2023
18 changes: 18 additions & 0 deletions sot/opcode_translator/executor/variable_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,24 @@
{},
lambda var: var.reverse(),
)
Dispatcher.register(
list.copy,
("ListVariable",),
{},
lambda var: var.copy(),
)
Dispatcher.register(
list.count,
("ListVariable", "VariableBase"),
{},
lambda var, obj: var.count(obj),
)
Dispatcher.register(
list.index,
("ListVariable", "VariableBase"),
{},
lambda var, obj: var.index(obj),
)
Dispatcher.register(
operator.add,
("ListVariable", "ListVariable"),
Expand Down
50 changes: 49 additions & 1 deletion sot/opcode_translator/executor/variables/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
)
from .base import ConstTypes, VariableBase, VariableFactory
from .basic import ConstantVariable
from .callable import BuiltinVariable

if TYPE_CHECKING:
from ..function_graph import FunctionGraph
Expand Down Expand Up @@ -289,6 +290,51 @@ def reverse(self):
self.graph.side_effects.record_variable(self)
return ConstantVariable.wrap_literal(None, self.graph)

def count(self, value: VariableBase):
count: int = 0
for i in self:
if i.id == value.id:
count += 1
continue
eq = BuiltinVariable(operator.eq, self.graph, DanglingTracker())(
i, value
)
eq_bool = BuiltinVariable(bool, self.graph, DanglingTracker())(eq)
assert isinstance(
eq_bool, ConstantVariable
), "bool should return ConstantVariable"
if eq.get_value() is True:
count += 1
continue

return VariableFactory.from_value(
count, self.graph, DummyTracker([self, value])
)

def index(self, value: VariableBase):
res = 0
for i in self:
if i.id == value.id:
return VariableFactory.from_value(
res, self.graph, DummyTracker([self, value])
)
eq = BuiltinVariable(operator.eq, self.graph, DanglingTracker())(
i, value
)
eq_bool = BuiltinVariable(bool, self.graph, DanglingTracker())(eq)
assert isinstance(
eq_bool, ConstantVariable
), "bool should return ConstantVariable"
if eq.get_value() is True:
return VariableFactory.from_value(
res, self.graph, DummyTracker([self, value])
)
res += 1

return VariableFactory.from_value(
-1, self.graph, DummyTracker([self, value])
)

def getattr(self, name):
from .callable import BuiltinVariable

Expand All @@ -302,6 +348,8 @@ def getattr(self, name):
"remove": list.remove,
"sort": list.sort,
"reverse": list.reverse,
"count": list.count,
"index": list.index,
}

if name in method_name_to_builtin_fn:
Expand All @@ -311,7 +359,7 @@ def getattr(self, name):
).bind(self, name)
else:
raise NotImplementException(
f"attribute {name} for dict is not implemented"
f"attribute {name} for list is not implemented"
)

@VariableFactory.register_from_value()
Expand Down
158 changes: 156 additions & 2 deletions tests/test_04_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# BINARY_SUBSCR
# DELETE_SUBSCR

from __future__ import annotations

import unittest

Expand Down Expand Up @@ -45,14 +46,167 @@ def list_delitem_tensor(x: int, y: paddle.Tensor):
return z


def list_append_int(x: int, y: paddle.Tensor):
z = [x, y]
z.append(3)
return z


def list_append_tensor(x: int, y: paddle.Tensor):
z = [x, y]
z.append(y)
return z


def list_clear(x: int, y: paddle.Tensor):
z = [x, y]
z.clear()
return z


def list_copy(x: int, y: paddle.Tensor):
z = [x, y]
a = z.copy()
z[0] = 3
z[1] = y + 1
return (a, z)


def list_count_int(x: int, y: paddle.Tensor):
z = [x, x, 2, 3, 1]
return z.count(x)


def list_count_tensor(x: paddle.Tensor, y: list[paddle.Tensor]):
return y.count(x)


def list_extend(x: int, y: paddle.Tensor):
z = [x, y]
a = [y, x]
b = (x, y)
z.extend(a)
z.extend(b)
return z


def list_index_int(x: int, y: paddle.Tensor):
z = [x, x, 1, 2]
return z.index(x)


def list_index_tensor(x: paddle.Tensor, y: list[paddle.Tensor]):
return y.index(x)


def list_insert(x: int, y: paddle.Tensor):
z = [x, y]
z.insert(0, x)
z.insert(3, y)
return z


def list_pop(x: int, y: paddle.Tensor):
z = [x, y]
a = z.pop()
b = z.pop()
return (z, a, b)


def list_remove(x: int, y: paddle.Tensor):
z = [x, x, y, y]
z.remove(x)
z.remove(y)
return z


def list_reverse(x: int, y: paddle.Tensor):
z = [x, x, y, y]
z.reverse()
return z


def list_default_sort(x: int, y: paddle.Tensor):
z = [x + 2, x, x + 1]
z.sort()
return z


def list_key_sort(x: int, y: paddle.Tensor):
z = [x + 2, x, x + 1]
z.sort(lambda x: x)
return z


def list_reverse_sort(x: int, y: paddle.Tensor):
z = [x + 2, x, x + 1]
z.sort(reverse=True)
return z


def list_tensor_sort(x: int, y: paddle.Tensor):
z = [y + 2, y, y + 1]
z.sort()
return z


class TestExecutor(TestCaseBase):
def test_simple(self):
self.assert_results(list_getitem_int, 1, paddle.to_tensor(2))
self.assert_results(list_getitem_tensor, 1, paddle.to_tensor(2))
self.assert_results(list_setitem_int, 1, paddle.to_tensor(2))
self.assert_results(list_setitem_tensor, 1, paddle.to_tensor(2))
self.assert_results(list_delitem_int, 1, paddle.to_tensor(2))
self.assert_results(list_delitem_tensor, 1, paddle.to_tensor(2))
self.assert_results(list_count_int, 1, paddle.to_tensor(2))
self.assert_results(list_index_int, 1, paddle.to_tensor(2))
a = paddle.to_tensor(1)
b = paddle.to_tensor(2)
self.assert_results(list_count_tensor, a, [a, b, a, b, a, b])
self.assert_results(list_index_tensor, b, [a, b, a, b, a, b])
self.assert_results_with_side_effects(
list_delitem_int, 1, paddle.to_tensor(2)
)
self.assert_results_with_side_effects(
list_delitem_tensor, 1, paddle.to_tensor(2)
)
self.assert_results_with_side_effects(
list_append_int, 1, paddle.to_tensor(2)
)
self.assert_results_with_side_effects(
list_append_tensor, 1, paddle.to_tensor(2)
)
self.assert_results_with_side_effects(
list_clear, 1, paddle.to_tensor(2)
)
self.assert_results_with_side_effects(list_copy, 1, paddle.to_tensor(2))
self.assert_results_with_side_effects(
list_extend, 1, paddle.to_tensor(2)
)
self.assert_results_with_side_effects(
list_insert, 1, paddle.to_tensor(2)
)
self.assert_results_with_side_effects(list_pop, 1, paddle.to_tensor(2))
self.assert_results_with_side_effects(
list_remove, 1, paddle.to_tensor(2)
)
self.assert_results_with_side_effects(
list_reverse, 1, paddle.to_tensor(2)
)
self.assert_results_with_side_effects(
list_reverse, 1, paddle.to_tensor(2)
)
self.assert_results_with_side_effects(
list_default_sort, 1, paddle.to_tensor(2)
)
# TODO: Not currently supported
# self.assert_results_with_side_effects(
# list_tensor_sort, 1, paddle.to_tensor(2)
# )
# self.assert_results_with_side_effects(
# list_key_sort, 1, paddle.to_tensor(2)
# )
# self.assert_results_with_side_effects(
# list_reverse_sort, 1, paddle.to_tensor(2)
# )


if __name__ == "__main__":
Expand Down