diff --git a/sot/opcode_translator/executor/variable_dispatch.py b/sot/opcode_translator/executor/variable_dispatch.py index 5b172e1f0..da1828901 100644 --- a/sot/opcode_translator/executor/variable_dispatch.py +++ b/sot/opcode_translator/executor/variable_dispatch.py @@ -189,6 +189,24 @@ {}, lambda var: var.reverse(), ) +Dispatcher.register( + list.copy, + ("ListVariable",), + {}, + lambda var: var.copy(), +) +Dispatcher.register( + list.count, + ("ListVariable", "VariableBase"), + {}, + lambda var, obj: var.count(obj), +) +Dispatcher.register( + list.index, + ("ListVariable", "VariableBase"), + {}, + lambda var, obj: var.index(obj), +) Dispatcher.register( operator.add, ("ListVariable", "ListVariable"), diff --git a/sot/opcode_translator/executor/variables/container.py b/sot/opcode_translator/executor/variables/container.py index e48ab8b20..accb62849 100644 --- a/sot/opcode_translator/executor/variables/container.py +++ b/sot/opcode_translator/executor/variables/container.py @@ -18,6 +18,7 @@ ) from .base import ConstTypes, VariableBase, VariableFactory from .basic import ConstantVariable +from .callable import BuiltinVariable if TYPE_CHECKING: from ..function_graph import FunctionGraph @@ -289,6 +290,51 @@ def reverse(self): self.graph.side_effects.record_variable(self) return ConstantVariable.wrap_literal(None, self.graph) + def count(self, value: VariableBase): + count: int = 0 + for i in self: + if i.id == value.id: + count += 1 + continue + eq = BuiltinVariable(operator.eq, self.graph, DanglingTracker())( + i, value + ) + eq_bool = BuiltinVariable(bool, self.graph, DanglingTracker())(eq) + assert isinstance( + eq_bool, ConstantVariable + ), "bool should return ConstantVariable" + if eq.get_value() is True: + count += 1 + continue + + return VariableFactory.from_value( + count, self.graph, DummyTracker([self, value]) + ) + + def index(self, value: VariableBase): + res = 0 + for i in self: + if i.id == value.id: + return VariableFactory.from_value( + res, self.graph, DummyTracker([self, value]) + ) + eq = BuiltinVariable(operator.eq, self.graph, DanglingTracker())( + i, value + ) + eq_bool = BuiltinVariable(bool, self.graph, DanglingTracker())(eq) + assert isinstance( + eq_bool, ConstantVariable + ), "bool should return ConstantVariable" + if eq.get_value() is True: + return VariableFactory.from_value( + res, self.graph, DummyTracker([self, value]) + ) + res += 1 + + return VariableFactory.from_value( + -1, self.graph, DummyTracker([self, value]) + ) + def getattr(self, name): from .callable import BuiltinVariable @@ -302,6 +348,8 @@ def getattr(self, name): "remove": list.remove, "sort": list.sort, "reverse": list.reverse, + "count": list.count, + "index": list.index, } if name in method_name_to_builtin_fn: @@ -311,7 +359,7 @@ def getattr(self, name): ).bind(self, name) else: raise NotImplementException( - f"attribute {name} for dict is not implemented" + f"attribute {name} for list is not implemented" ) @VariableFactory.register_from_value() diff --git a/tests/test_04_list.py b/tests/test_04_list.py index ad31ad431..c1157b0a7 100644 --- a/tests/test_04_list.py +++ b/tests/test_04_list.py @@ -3,6 +3,7 @@ # BINARY_SUBSCR # DELETE_SUBSCR +from __future__ import annotations import unittest @@ -45,14 +46,167 @@ def list_delitem_tensor(x: int, y: paddle.Tensor): return z +def list_append_int(x: int, y: paddle.Tensor): + z = [x, y] + z.append(3) + return z + + +def list_append_tensor(x: int, y: paddle.Tensor): + z = [x, y] + z.append(y) + return z + + +def list_clear(x: int, y: paddle.Tensor): + z = [x, y] + z.clear() + return z + + +def list_copy(x: int, y: paddle.Tensor): + z = [x, y] + a = z.copy() + z[0] = 3 + z[1] = y + 1 + return (a, z) + + +def list_count_int(x: int, y: paddle.Tensor): + z = [x, x, 2, 3, 1] + return z.count(x) + + +def list_count_tensor(x: paddle.Tensor, y: list[paddle.Tensor]): + return y.count(x) + + +def list_extend(x: int, y: paddle.Tensor): + z = [x, y] + a = [y, x] + b = (x, y) + z.extend(a) + z.extend(b) + return z + + +def list_index_int(x: int, y: paddle.Tensor): + z = [x, x, 1, 2] + return z.index(x) + + +def list_index_tensor(x: paddle.Tensor, y: list[paddle.Tensor]): + return y.index(x) + + +def list_insert(x: int, y: paddle.Tensor): + z = [x, y] + z.insert(0, x) + z.insert(3, y) + return z + + +def list_pop(x: int, y: paddle.Tensor): + z = [x, y] + a = z.pop() + b = z.pop() + return (z, a, b) + + +def list_remove(x: int, y: paddle.Tensor): + z = [x, x, y, y] + z.remove(x) + z.remove(y) + return z + + +def list_reverse(x: int, y: paddle.Tensor): + z = [x, x, y, y] + z.reverse() + return z + + +def list_default_sort(x: int, y: paddle.Tensor): + z = [x + 2, x, x + 1] + z.sort() + return z + + +def list_key_sort(x: int, y: paddle.Tensor): + z = [x + 2, x, x + 1] + z.sort(lambda x: x) + return z + + +def list_reverse_sort(x: int, y: paddle.Tensor): + z = [x + 2, x, x + 1] + z.sort(reverse=True) + return z + + +def list_tensor_sort(x: int, y: paddle.Tensor): + z = [y + 2, y, y + 1] + z.sort() + return z + + class TestExecutor(TestCaseBase): def test_simple(self): self.assert_results(list_getitem_int, 1, paddle.to_tensor(2)) self.assert_results(list_getitem_tensor, 1, paddle.to_tensor(2)) self.assert_results(list_setitem_int, 1, paddle.to_tensor(2)) self.assert_results(list_setitem_tensor, 1, paddle.to_tensor(2)) - self.assert_results(list_delitem_int, 1, paddle.to_tensor(2)) - self.assert_results(list_delitem_tensor, 1, paddle.to_tensor(2)) + self.assert_results(list_count_int, 1, paddle.to_tensor(2)) + self.assert_results(list_index_int, 1, paddle.to_tensor(2)) + a = paddle.to_tensor(1) + b = paddle.to_tensor(2) + self.assert_results(list_count_tensor, a, [a, b, a, b, a, b]) + self.assert_results(list_index_tensor, b, [a, b, a, b, a, b]) + self.assert_results_with_side_effects( + list_delitem_int, 1, paddle.to_tensor(2) + ) + self.assert_results_with_side_effects( + list_delitem_tensor, 1, paddle.to_tensor(2) + ) + self.assert_results_with_side_effects( + list_append_int, 1, paddle.to_tensor(2) + ) + self.assert_results_with_side_effects( + list_append_tensor, 1, paddle.to_tensor(2) + ) + self.assert_results_with_side_effects( + list_clear, 1, paddle.to_tensor(2) + ) + self.assert_results_with_side_effects(list_copy, 1, paddle.to_tensor(2)) + self.assert_results_with_side_effects( + list_extend, 1, paddle.to_tensor(2) + ) + self.assert_results_with_side_effects( + list_insert, 1, paddle.to_tensor(2) + ) + self.assert_results_with_side_effects(list_pop, 1, paddle.to_tensor(2)) + self.assert_results_with_side_effects( + list_remove, 1, paddle.to_tensor(2) + ) + self.assert_results_with_side_effects( + list_reverse, 1, paddle.to_tensor(2) + ) + self.assert_results_with_side_effects( + list_reverse, 1, paddle.to_tensor(2) + ) + self.assert_results_with_side_effects( + list_default_sort, 1, paddle.to_tensor(2) + ) + # TODO: Not currently supported + # self.assert_results_with_side_effects( + # list_tensor_sort, 1, paddle.to_tensor(2) + # ) + # self.assert_results_with_side_effects( + # list_key_sort, 1, paddle.to_tensor(2) + # ) + # self.assert_results_with_side_effects( + # list_reverse_sort, 1, paddle.to_tensor(2) + # ) if __name__ == "__main__":