From fff84cf665941b1cf52a03fcc25c98bc631c0123 Mon Sep 17 00:00:00 2001 From: gouzi <530971494@qq.com> Date: Fri, 7 Jul 2023 01:20:32 +0800 Subject: [PATCH 01/10] [executor] add copy; [tests] add list append, clear, copy --- .../executor/variable_dispatch.py | 6 ++++ tests/test_04_list.py | 30 +++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/sot/opcode_translator/executor/variable_dispatch.py b/sot/opcode_translator/executor/variable_dispatch.py index 5b172e1f0..cc182215d 100644 --- a/sot/opcode_translator/executor/variable_dispatch.py +++ b/sot/opcode_translator/executor/variable_dispatch.py @@ -189,6 +189,12 @@ {}, lambda var: var.reverse(), ) +Dispatcher.register( + list.copy, + ("ListVariable",), + {}, + lambda var: var.copy(), +) Dispatcher.register( operator.add, ("ListVariable", "ListVariable"), diff --git a/tests/test_04_list.py b/tests/test_04_list.py index ad31ad431..88327112b 100644 --- a/tests/test_04_list.py +++ b/tests/test_04_list.py @@ -45,6 +45,32 @@ def list_delitem_tensor(x: int, y: paddle.Tensor): return z +def list_append_int(x: int, y: paddle.Tensor): + z = [x, y] + z.append(3) + return z + + +def list_append_tensor(x: int, y: paddle.Tensor): + z = [x, y] + z.append(y) + return z + + +def list_clear(x: int, y: paddle.Tensor): + z = [x, y] + z.clear() + return z + + +def list_copy(x: int, y: paddle.Tensor): + z = [x, y] + a = z.copy() + z[0] = 3 + z[1] = y + 1 + return (a, z) + + class TestExecutor(TestCaseBase): def test_simple(self): self.assert_results(list_getitem_int, 1, paddle.to_tensor(2)) @@ -53,6 +79,10 @@ def test_simple(self): self.assert_results(list_setitem_tensor, 1, paddle.to_tensor(2)) self.assert_results(list_delitem_int, 1, paddle.to_tensor(2)) self.assert_results(list_delitem_tensor, 1, paddle.to_tensor(2)) + self.assert_results(list_append_int, 1, paddle.to_tensor(2)) + self.assert_results(list_append_tensor, 1, paddle.to_tensor(2)) + self.assert_results(list_clear, 1, paddle.to_tensor(2)) + self.assert_results(list_copy, 1, paddle.to_tensor(2)) if __name__ == "__main__": From f75b7685eb04c39b9b975a9224617d3b0d23aabe Mon Sep 17 00:00:00 2001 From: gouzi <530971494@qq.com> Date: Fri, 7 Jul 2023 16:13:06 +0800 Subject: [PATCH 02/10] support ListVariable count --- sot/opcode_translator/executor/variable_dispatch.py | 6 ++++++ sot/opcode_translator/executor/variables/container.py | 11 ++++++++++- tests/test_04_list.py | 6 ++++++ 3 files changed, 22 insertions(+), 1 deletion(-) diff --git a/sot/opcode_translator/executor/variable_dispatch.py b/sot/opcode_translator/executor/variable_dispatch.py index cc182215d..420ac2b62 100644 --- a/sot/opcode_translator/executor/variable_dispatch.py +++ b/sot/opcode_translator/executor/variable_dispatch.py @@ -195,6 +195,12 @@ {}, lambda var: var.copy(), ) +Dispatcher.register( + list.count, + ("ListVariable", "VariableBase"), + {}, + lambda var, obj: var.count(obj), +) Dispatcher.register( operator.add, ("ListVariable", "ListVariable"), diff --git a/sot/opcode_translator/executor/variables/container.py b/sot/opcode_translator/executor/variables/container.py index e48ab8b20..51161ee38 100644 --- a/sot/opcode_translator/executor/variables/container.py +++ b/sot/opcode_translator/executor/variables/container.py @@ -289,6 +289,14 @@ def reverse(self): self.graph.side_effects.record_variable(self) return ConstantVariable.wrap_literal(None, self.graph) + def count(self, obj: VariableBase): + res = 0 + for i in self: + if i == obj: + res += 1 + + return ConstantVariable.wrap_literal(res, self.graph) + def getattr(self, name): from .callable import BuiltinVariable @@ -302,6 +310,7 @@ def getattr(self, name): "remove": list.remove, "sort": list.sort, "reverse": list.reverse, + "count": list.count, } if name in method_name_to_builtin_fn: @@ -311,7 +320,7 @@ def getattr(self, name): ).bind(self, name) else: raise NotImplementException( - f"attribute {name} for dict is not implemented" + f"attribute {name} for list is not implemented" ) @VariableFactory.register_from_value() diff --git a/tests/test_04_list.py b/tests/test_04_list.py index 88327112b..a62e5704f 100644 --- a/tests/test_04_list.py +++ b/tests/test_04_list.py @@ -71,6 +71,11 @@ def list_copy(x: int, y: paddle.Tensor): return (a, z) +def list_count(x: int, y: paddle.Tensor): + z = [x, x, y, y, y] + return (z.count(x), z.count(y)) + + class TestExecutor(TestCaseBase): def test_simple(self): self.assert_results(list_getitem_int, 1, paddle.to_tensor(2)) @@ -83,6 +88,7 @@ def test_simple(self): self.assert_results(list_append_tensor, 1, paddle.to_tensor(2)) self.assert_results(list_clear, 1, paddle.to_tensor(2)) self.assert_results(list_copy, 1, paddle.to_tensor(2)) + self.assert_results(list_count, 1, paddle.to_tensor(2)) if __name__ == "__main__": From 156ce8a90c8fe51d90d335d93188b62de44448ad Mon Sep 17 00:00:00 2001 From: gouzi <530971494@qq.com> Date: Fri, 7 Jul 2023 16:22:12 +0800 Subject: [PATCH 03/10] add extend tests --- tests/test_04_list.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/test_04_list.py b/tests/test_04_list.py index a62e5704f..7ab89ee5d 100644 --- a/tests/test_04_list.py +++ b/tests/test_04_list.py @@ -76,6 +76,15 @@ def list_count(x: int, y: paddle.Tensor): return (z.count(x), z.count(y)) +def list_extend(x: int, y: paddle.Tensor): + z = [x, y] + a = [y, x] + b = (x, y) + z.extend(a) + z.extend(b) + return z + + class TestExecutor(TestCaseBase): def test_simple(self): self.assert_results(list_getitem_int, 1, paddle.to_tensor(2)) @@ -89,6 +98,7 @@ def test_simple(self): self.assert_results(list_clear, 1, paddle.to_tensor(2)) self.assert_results(list_copy, 1, paddle.to_tensor(2)) self.assert_results(list_count, 1, paddle.to_tensor(2)) + self.assert_results(list_extend, 1, paddle.to_tensor(2)) if __name__ == "__main__": From 80815ab6a682fde7d13177540131926d0a9ec19b Mon Sep 17 00:00:00 2001 From: gouzi <530971494@qq.com> Date: Fri, 7 Jul 2023 16:33:05 +0800 Subject: [PATCH 04/10] support ListVariable index --- sot/opcode_translator/executor/variable_dispatch.py | 6 ++++++ sot/opcode_translator/executor/variables/container.py | 10 ++++++++++ tests/test_04_list.py | 6 ++++++ 3 files changed, 22 insertions(+) diff --git a/sot/opcode_translator/executor/variable_dispatch.py b/sot/opcode_translator/executor/variable_dispatch.py index 420ac2b62..da1828901 100644 --- a/sot/opcode_translator/executor/variable_dispatch.py +++ b/sot/opcode_translator/executor/variable_dispatch.py @@ -201,6 +201,12 @@ {}, lambda var, obj: var.count(obj), ) +Dispatcher.register( + list.index, + ("ListVariable", "VariableBase"), + {}, + lambda var, obj: var.index(obj), +) Dispatcher.register( operator.add, ("ListVariable", "ListVariable"), diff --git a/sot/opcode_translator/executor/variables/container.py b/sot/opcode_translator/executor/variables/container.py index 51161ee38..15287fb26 100644 --- a/sot/opcode_translator/executor/variables/container.py +++ b/sot/opcode_translator/executor/variables/container.py @@ -297,6 +297,15 @@ def count(self, obj: VariableBase): return ConstantVariable.wrap_literal(res, self.graph) + def index(self, obj: VariableBase): + res = 0 + for i in self: + if i == obj: + return ConstantVariable.wrap_literal(res, self.graph) + res += 1 + + return ConstantVariable.wrap_literal(-1, self.graph) + def getattr(self, name): from .callable import BuiltinVariable @@ -311,6 +320,7 @@ def getattr(self, name): "sort": list.sort, "reverse": list.reverse, "count": list.count, + "index": list.index, } if name in method_name_to_builtin_fn: diff --git a/tests/test_04_list.py b/tests/test_04_list.py index 7ab89ee5d..e2fa18abb 100644 --- a/tests/test_04_list.py +++ b/tests/test_04_list.py @@ -85,6 +85,11 @@ def list_extend(x: int, y: paddle.Tensor): return z +def list_index(x: int, y: paddle.Tensor): + z = [y, x, y, y] + return (z.index(x), z.index(y)) + + class TestExecutor(TestCaseBase): def test_simple(self): self.assert_results(list_getitem_int, 1, paddle.to_tensor(2)) @@ -99,6 +104,7 @@ def test_simple(self): self.assert_results(list_copy, 1, paddle.to_tensor(2)) self.assert_results(list_count, 1, paddle.to_tensor(2)) self.assert_results(list_extend, 1, paddle.to_tensor(2)) + self.assert_results(list_index, 1, paddle.to_tensor(2)) if __name__ == "__main__": From 9b41e650450278441afa5915dec2bcc994325c55 Mon Sep 17 00:00:00 2001 From: gouzi <530971494@qq.com> Date: Fri, 7 Jul 2023 17:00:28 +0800 Subject: [PATCH 05/10] add ListVariable insert tests;switch assert_results_with_side_effects --- tests/test_04_list.py | 36 +++++++++++++++++++++++++++++------- 1 file changed, 29 insertions(+), 7 deletions(-) diff --git a/tests/test_04_list.py b/tests/test_04_list.py index e2fa18abb..25bb264f3 100644 --- a/tests/test_04_list.py +++ b/tests/test_04_list.py @@ -90,21 +90,43 @@ def list_index(x: int, y: paddle.Tensor): return (z.index(x), z.index(y)) +def list_insert(x: int, y: paddle.Tensor): + z = [x, y] + z.insert(0, x) + z.insert(3, y) + return z + + class TestExecutor(TestCaseBase): def test_simple(self): self.assert_results(list_getitem_int, 1, paddle.to_tensor(2)) self.assert_results(list_getitem_tensor, 1, paddle.to_tensor(2)) self.assert_results(list_setitem_int, 1, paddle.to_tensor(2)) self.assert_results(list_setitem_tensor, 1, paddle.to_tensor(2)) - self.assert_results(list_delitem_int, 1, paddle.to_tensor(2)) - self.assert_results(list_delitem_tensor, 1, paddle.to_tensor(2)) - self.assert_results(list_append_int, 1, paddle.to_tensor(2)) - self.assert_results(list_append_tensor, 1, paddle.to_tensor(2)) - self.assert_results(list_clear, 1, paddle.to_tensor(2)) - self.assert_results(list_copy, 1, paddle.to_tensor(2)) self.assert_results(list_count, 1, paddle.to_tensor(2)) - self.assert_results(list_extend, 1, paddle.to_tensor(2)) self.assert_results(list_index, 1, paddle.to_tensor(2)) + self.assert_results_with_side_effects( + list_delitem_int, 1, paddle.to_tensor(2) + ) + self.assert_results_with_side_effects( + list_delitem_tensor, 1, paddle.to_tensor(2) + ) + self.assert_results_with_side_effects( + list_append_int, 1, paddle.to_tensor(2) + ) + self.assert_results_with_side_effects( + list_append_tensor, 1, paddle.to_tensor(2) + ) + self.assert_results_with_side_effects( + list_clear, 1, paddle.to_tensor(2) + ) + self.assert_results_with_side_effects(list_copy, 1, paddle.to_tensor(2)) + self.assert_results_with_side_effects( + list_extend, 1, paddle.to_tensor(2) + ) + self.assert_results_with_side_effects( + list_insert, 1, paddle.to_tensor(2) + ) if __name__ == "__main__": From e2c98d5c304f7dc2825632d285525a3744491a70 Mon Sep 17 00:00:00 2001 From: gouzi <530971494@qq.com> Date: Fri, 7 Jul 2023 18:01:53 +0800 Subject: [PATCH 06/10] [tests] add pop remove reverse --- tests/test_04_list.py | 54 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/tests/test_04_list.py b/tests/test_04_list.py index 25bb264f3..14744fe0a 100644 --- a/tests/test_04_list.py +++ b/tests/test_04_list.py @@ -97,6 +97,44 @@ def list_insert(x: int, y: paddle.Tensor): return z +def list_pop(x: int, y: paddle.Tensor): + z = [x, y] + a = z.pop() + b = z.pop() + return (z, a, b) + + +def list_remove(x: int, y: paddle.Tensor): + z = [x, x, y, y] + z.remove(x) + z.remove(y) + return z + + +def list_reverse(x: int, y: paddle.Tensor): + z = [x, x, y, y] + z.reverse() + return z + + +def list_default_sort(x: int, y: paddle.Tensor): + z = [y + 2, y, y + 1] + z.sort() + return z + + +def list_key_sort(x: int, y: paddle.Tensor): + z = [y + 2, y, y + 1] + z.sort(key=len) + return z + + +def list_reverse_sort(x: int, y: paddle.Tensor): + z = [y + 2, y, y + 1] + z.sort(reverse=True) + return z + + class TestExecutor(TestCaseBase): def test_simple(self): self.assert_results(list_getitem_int, 1, paddle.to_tensor(2)) @@ -127,6 +165,22 @@ def test_simple(self): self.assert_results_with_side_effects( list_insert, 1, paddle.to_tensor(2) ) + self.assert_results_with_side_effects(list_pop, 1, paddle.to_tensor(2)) + self.assert_results_with_side_effects( + list_remove, 1, paddle.to_tensor(2) + ) + self.assert_results_with_side_effects( + list_reverse, 1, paddle.to_tensor(2) + ) + self.assert_results_with_side_effects( + list_reverse, 1, paddle.to_tensor(2) + ) + self.assert_results_with_side_effects( + list_default_sort, 1, paddle.to_tensor(2) + ) + # self.assert_results_with_side_effects( + # list_default_sort, 1, paddle.to_tensor(2) + # ) if __name__ == "__main__": From 2e6c7eddd076132ed7fe2aab7932564322855b67 Mon Sep 17 00:00:00 2001 From: gouzi <530971494@qq.com> Date: Fri, 7 Jul 2023 23:29:42 +0800 Subject: [PATCH 07/10] [tests] fix ListVariable sort --- tests/test_04_list.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/tests/test_04_list.py b/tests/test_04_list.py index 14744fe0a..be4b70858 100644 --- a/tests/test_04_list.py +++ b/tests/test_04_list.py @@ -118,23 +118,29 @@ def list_reverse(x: int, y: paddle.Tensor): def list_default_sort(x: int, y: paddle.Tensor): - z = [y + 2, y, y + 1] + z = [x + 2, x, x + 1] z.sort() return z def list_key_sort(x: int, y: paddle.Tensor): - z = [y + 2, y, y + 1] - z.sort(key=len) + z = [x + 2, x, x + 1] + z.sort(lambda x: x) return z def list_reverse_sort(x: int, y: paddle.Tensor): - z = [y + 2, y, y + 1] + z = [x + 2, x, x + 1] z.sort(reverse=True) return z +def list_tensor_sort(x: int, y: paddle.Tensor): + z = [y + 2, y, y + 1] + z.sort() + return z + + class TestExecutor(TestCaseBase): def test_simple(self): self.assert_results(list_getitem_int, 1, paddle.to_tensor(2)) @@ -178,8 +184,15 @@ def test_simple(self): self.assert_results_with_side_effects( list_default_sort, 1, paddle.to_tensor(2) ) + # TODO: Not currently supported + # self.assert_results_with_side_effects( + # list_tensor_sort, 1, paddle.to_tensor(2) + # ) + # self.assert_results_with_side_effects( + # list_key_sort, 1, paddle.to_tensor(2) + # ) # self.assert_results_with_side_effects( - # list_default_sort, 1, paddle.to_tensor(2) + # list_reverse_sort, 1, paddle.to_tensor(2) # ) From dc0eba121bdbbc8fecd29faec97101b233d999b3 Mon Sep 17 00:00:00 2001 From: gouzi <530971494@qq.com> Date: Tue, 11 Jul 2023 14:37:10 +0800 Subject: [PATCH 08/10] fix index and count --- .../executor/variables/container.py | 48 +++++++++++++++---- tests/test_04_list.py | 32 +++++++++---- 2 files changed, 64 insertions(+), 16 deletions(-) diff --git a/sot/opcode_translator/executor/variables/container.py b/sot/opcode_translator/executor/variables/container.py index 15287fb26..d247be792 100644 --- a/sot/opcode_translator/executor/variables/container.py +++ b/sot/opcode_translator/executor/variables/container.py @@ -5,7 +5,11 @@ from typing import TYPE_CHECKING, Any from ....utils import log_do -from ....utils.exceptions import InnerError, NotImplementException +from ....utils.exceptions import ( + BreakGraphError, + InnerError, + NotImplementException, +) from ..guard import StringifyExpression from ..mutable_data import MutableDictLikeData, MutableListLikeData from ..pycode_generator import PyCodeGen @@ -17,7 +21,8 @@ Tracker, ) from .base import ConstTypes, VariableBase, VariableFactory -from .basic import ConstantVariable +from .basic import ConstantVariable, TensorVariable +from .callable import BuiltinVariable if TYPE_CHECKING: from ..function_graph import FunctionGraph @@ -289,22 +294,49 @@ def reverse(self): self.graph.side_effects.record_variable(self) return ConstantVariable.wrap_literal(None, self.graph) - def count(self, obj: VariableBase): + def count(self, value: VariableBase): res = 0 for i in self: - if i == obj: + if i.id == value.id: + res += 1 + continue + eq = BuiltinVariable(operator.eq, self.graph, DanglingTracker())( + i, value + ) + if isinstance(eq, ConstantVariable) and eq.get_value(): res += 1 + continue + + if isinstance(eq, TensorVariable): + raise BreakGraphError( + "TensorVariable Not currently supported bool" + ) return ConstantVariable.wrap_literal(res, self.graph) - def index(self, obj: VariableBase): + def index(self, value: VariableBase): res = 0 for i in self: - if i == obj: - return ConstantVariable.wrap_literal(res, self.graph) + if i.id == value.id: + return VariableFactory.from_value( + res, self.graph, DummyTracker([self, value]) + ) + eq = BuiltinVariable(operator.eq, self.graph, DanglingTracker())( + i, value + ) + if isinstance(eq, ConstantVariable) and eq.get_value() is True: + return VariableFactory.from_value( + res, self.graph, DummyTracker([self, value]) + ) + if isinstance(eq, TensorVariable): + raise BreakGraphError( + "TensorVariable Not currently supported bool" + ) res += 1 - return ConstantVariable.wrap_literal(-1, self.graph) + return VariableFactory.from_value( + -1, self.graph, DummyTracker([self, value]) + ) def getattr(self, name): from .callable import BuiltinVariable diff --git a/tests/test_04_list.py b/tests/test_04_list.py index be4b70858..7d501baaf 100644 --- a/tests/test_04_list.py +++ b/tests/test_04_list.py @@ -71,9 +71,16 @@ def list_copy(x: int, y: paddle.Tensor): return (a, z) -def list_count(x: int, y: paddle.Tensor): - z = [x, x, y, y, y] - return (z.count(x), z.count(y)) +def list_count_int(x: int, y: paddle.Tensor): + z = [x, x, 2, 3, 1] + return z.count(x) + + +def list_count_tensor(x: int, y: paddle.Tensor): + a = paddle.to_tensor(1) + b = paddle.to_tensor(2) + z = [a, b, y, y, y] + return z.count(y) def list_extend(x: int, y: paddle.Tensor): @@ -85,9 +92,16 @@ def list_extend(x: int, y: paddle.Tensor): return z -def list_index(x: int, y: paddle.Tensor): - z = [y, x, y, y] - return (z.index(x), z.index(y)) +def list_index_int(x: int, y: paddle.Tensor): + z = [x, x, 1, 2] + return z.index(x) + + +def list_index_tensor(x: int, y: paddle.Tensor): + a = paddle.to_tensor(1) + b = paddle.to_tensor(2) + z = [a, b, y, y] + return z.index(y) def list_insert(x: int, y: paddle.Tensor): @@ -147,8 +161,8 @@ def test_simple(self): self.assert_results(list_getitem_tensor, 1, paddle.to_tensor(2)) self.assert_results(list_setitem_int, 1, paddle.to_tensor(2)) self.assert_results(list_setitem_tensor, 1, paddle.to_tensor(2)) - self.assert_results(list_count, 1, paddle.to_tensor(2)) - self.assert_results(list_index, 1, paddle.to_tensor(2)) + self.assert_results(list_count_int, 1, paddle.to_tensor(2)) + self.assert_results(list_index_int, 1, paddle.to_tensor(2)) self.assert_results_with_side_effects( list_delitem_int, 1, paddle.to_tensor(2) ) @@ -185,6 +199,8 @@ def test_simple(self): list_default_sort, 1, paddle.to_tensor(2) ) # TODO: Not currently supported + # self.assert_results(list_count_tensor, 1, paddle.to_tensor(2)) + # self.assert_results(list_index_tensor, 1, paddle.to_tensor(2)) # self.assert_results_with_side_effects( # list_tensor_sort, 1, paddle.to_tensor(2) # ) From a88478042248e691c98f59bc9e1bdaf90c9828c4 Mon Sep 17 00:00:00 2001 From: gouzi <530971494@qq.com> Date: Tue, 11 Jul 2023 15:57:17 +0800 Subject: [PATCH 09/10] fix BreakGraphError --- .../executor/variables/container.py | 39 +++++++++---------- tests/test_04_list.py | 20 ++++------ 2 files changed, 26 insertions(+), 33 deletions(-) diff --git a/sot/opcode_translator/executor/variables/container.py b/sot/opcode_translator/executor/variables/container.py index d247be792..accb62849 100644 --- a/sot/opcode_translator/executor/variables/container.py +++ b/sot/opcode_translator/executor/variables/container.py @@ -5,11 +5,7 @@ from typing import TYPE_CHECKING, Any from ....utils import log_do -from ....utils.exceptions import ( - BreakGraphError, - InnerError, - NotImplementException, -) +from ....utils.exceptions import InnerError, NotImplementException from ..guard import StringifyExpression from ..mutable_data import MutableDictLikeData, MutableListLikeData from ..pycode_generator import PyCodeGen @@ -21,7 +17,7 @@ Tracker, ) from .base import ConstTypes, VariableBase, VariableFactory -from .basic import ConstantVariable, TensorVariable +from .basic import ConstantVariable from .callable import BuiltinVariable if TYPE_CHECKING: @@ -295,24 +291,25 @@ def reverse(self): return ConstantVariable.wrap_literal(None, self.graph) def count(self, value: VariableBase): - res = 0 + count: int = 0 for i in self: if i.id == value.id: - res += 1 + count += 1 continue eq = BuiltinVariable(operator.eq, self.graph, DanglingTracker())( i, value ) - if isinstance(eq, ConstantVariable) and eq.get_value(): - res += 1 + eq_bool = BuiltinVariable(bool, self.graph, DanglingTracker())(eq) + assert isinstance( + eq_bool, ConstantVariable + ), "bool should return ConstantVariable" + if eq.get_value() is True: + count += 1 continue - if isinstance(eq, TensorVariable): - raise BreakGraphError( - "TensorVariable Not currently supported bool" - ) - - return ConstantVariable.wrap_literal(res, self.graph) + return VariableFactory.from_value( + count, self.graph, DummyTracker([self, value]) + ) def index(self, value: VariableBase): res = 0 @@ -324,14 +321,14 @@ def index(self, value: VariableBase): eq = BuiltinVariable(operator.eq, self.graph, DanglingTracker())( i, value ) - if isinstance(eq, ConstantVariable) and eq.get_value() is True: + eq_bool = BuiltinVariable(bool, self.graph, DanglingTracker())(eq) + assert isinstance( + eq_bool, ConstantVariable + ), "bool should return ConstantVariable" + if eq.get_value() is True: return VariableFactory.from_value( res, self.graph, DummyTracker([self, value]) ) - if isinstance(eq, TensorVariable): - raise BreakGraphError( - "TensorVariable Not currently supported bool" - ) res += 1 return VariableFactory.from_value( diff --git a/tests/test_04_list.py b/tests/test_04_list.py index 7d501baaf..c92f7f1a7 100644 --- a/tests/test_04_list.py +++ b/tests/test_04_list.py @@ -76,11 +76,8 @@ def list_count_int(x: int, y: paddle.Tensor): return z.count(x) -def list_count_tensor(x: int, y: paddle.Tensor): - a = paddle.to_tensor(1) - b = paddle.to_tensor(2) - z = [a, b, y, y, y] - return z.count(y) +def list_count_tensor(x: paddle.Tensor, y: list[paddle.Tensor]): + return y.count(x) def list_extend(x: int, y: paddle.Tensor): @@ -97,11 +94,8 @@ def list_index_int(x: int, y: paddle.Tensor): return z.index(x) -def list_index_tensor(x: int, y: paddle.Tensor): - a = paddle.to_tensor(1) - b = paddle.to_tensor(2) - z = [a, b, y, y] - return z.index(y) +def list_index_tensor(x: paddle.Tensor, y: list[paddle.Tensor]): + return y.index(x) def list_insert(x: int, y: paddle.Tensor): @@ -163,6 +157,10 @@ def test_simple(self): self.assert_results(list_setitem_tensor, 1, paddle.to_tensor(2)) self.assert_results(list_count_int, 1, paddle.to_tensor(2)) self.assert_results(list_index_int, 1, paddle.to_tensor(2)) + a = paddle.to_tensor(1) + b = paddle.to_tensor(2) + self.assert_results(list_count_tensor, a, [a, b, a, b, a, b]) + self.assert_results(list_index_tensor, b, [a, b, a, b, a, b]) self.assert_results_with_side_effects( list_delitem_int, 1, paddle.to_tensor(2) ) @@ -199,8 +197,6 @@ def test_simple(self): list_default_sort, 1, paddle.to_tensor(2) ) # TODO: Not currently supported - # self.assert_results(list_count_tensor, 1, paddle.to_tensor(2)) - # self.assert_results(list_index_tensor, 1, paddle.to_tensor(2)) # self.assert_results_with_side_effects( # list_tensor_sort, 1, paddle.to_tensor(2) # ) From bb0ee66d0b92f5a618afb34c69d6fd37e6cd5abf Mon Sep 17 00:00:00 2001 From: gouzi <530971494@qq.com> Date: Tue, 11 Jul 2023 16:56:47 +0800 Subject: [PATCH 10/10] fix --- tests/test_04_list.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_04_list.py b/tests/test_04_list.py index c92f7f1a7..c1157b0a7 100644 --- a/tests/test_04_list.py +++ b/tests/test_04_list.py @@ -3,6 +3,7 @@ # BINARY_SUBSCR # DELETE_SUBSCR +from __future__ import annotations import unittest