From 4c490ca0ecb0e7b86ac9fe0d019240f3db2cec54 Mon Sep 17 00:00:00 2001 From: "A. R. Shajii" Date: Sat, 24 Feb 2024 15:46:15 -0500 Subject: [PATCH] Fix list optimization for new slices; update tests --- codon/cir/transform/pythonic/list.cpp | 2 +- stdlib/internal/types/slice.codon | 29 +++++---------- test/core/containers.codon | 53 ++++++++++++++++++++++----- 3 files changed, 54 insertions(+), 30 deletions(-) diff --git a/codon/cir/transform/pythonic/list.cpp b/codon/cir/transform/pythonic/list.cpp index 38c785dc..aee57a2e 100644 --- a/codon/cir/transform/pythonic/list.cpp +++ b/codon/cir/transform/pythonic/list.cpp @@ -14,7 +14,7 @@ namespace pythonic { namespace { static const std::string LIST = "std.internal.types.ptr.List"; -static const std::string SLICE = "std.internal.types.slice.Slice"; +static const std::string SLICE = "std.internal.types.slice.Slice[int,int,int]"; bool isList(Value *v) { return v->getType()->getName().rfind(LIST + "[", 0) == 0; } bool isSlice(Value *v) { return v->getType()->getName() == SLICE; } diff --git a/stdlib/internal/types/slice.codon b/stdlib/internal/types/slice.codon index 9b20b16a..8c0affbd 100644 --- a/stdlib/internal/types/slice.codon +++ b/stdlib/internal/types/slice.codon @@ -28,30 +28,21 @@ class Slice: V: type = int) -> Slice[T, U, V]: return (start, stop, step) - def _ensure_int_or_none(self, method_name: str): - if not ( - (T is int or T is Optional[int] or T is Optional[NoneType]) and - (U is int or U is Optional[int] or U is Optional[NoneType]) and - (V is int or V is Optional[int] or V is Optional[NoneType]) - ): - raise TypeError(f"{method_name} requires all fields to be ints. Fields types are: " - f"start: {T.__name__}, stop: {U.__name__}, step: {V.__name__}") - def adjust_indices(self, length: int) -> Tuple[int, int, int, int]: - self._ensure_int_or_none("slice.adjust_indices") - def has_int_value(v): - return isinstance(v, int) or (isinstance(v, Optional[int]) and v is not None) - step: int = self.step if has_int_value(self.step) else 1 + if not (T is int and U is int and V is int): + compile_error("slice indices must be integers or None") + + step: int = self.step if self.step is not None else 1 start: int = 0 stop: int = 0 if step == 0: raise ValueError("slice step cannot be zero") if step > 0: - start = self.start if has_int_value(self.start) else 0 - stop = self.stop if has_int_value(self.stop) else length + start = self.start if self.start is not None else 0 + stop = self.stop if self.stop is not None else length else: - start = self.start if has_int_value(self.start) else length - 1 - stop = self.stop if has_int_value(self.stop) else -(length + 1) + start = self.start if self.start is not None else length - 1 + stop = self.stop if self.stop is not None else -(length + 1) return Slice.adjust_indices_helper(length, start, stop, step) @@ -89,10 +80,10 @@ class Slice: def __repr__(self): return f"slice({self.start}, {self.stop}, {self.step})" - def __eq__(self, other): + def __eq__(self, other: Slice): return self.start == other.start and self.step == other.step and self.stop == other.stop - def __ne__(self, other): + def __ne__(self, other: Slice): return not self.__eq__(other) slice = Slice diff --git a/test/core/containers.codon b/test/core/containers.codon index 67f808cc..f85d1177 100644 --- a/test/core/containers.codon +++ b/test/core/containers.codon @@ -1,4 +1,5 @@ from copy import copy, deepcopy + @tuple class A: a: int @@ -600,16 +601,14 @@ def test_dict(): assert repr(Dict[int,int]()) == '{}' test_dict() -def slice_indices(slc, length): +def slice_indices(slice, length): """ Reference implementation for the slice.indices method. """ - def has_int_value(v): - return isinstance(v, int) or (isinstance(v, Optional[int]) and v is not None) # Compute step and length as integers. #length = operator.index(length) - step: int = slc.step if has_int_value(slc.step) else 1 + step: int = 1 if slice.step is None else slice.step # Raise ValueError for negative length or zero step. if length < 0: @@ -622,17 +621,17 @@ def slice_indices(slc, length): upper = length - 1 if step < 0 else length # Compute start. - if not has_int_value(slc.start): + if slice.start is None: start = upper if step < 0 else lower else: - start = slc.start + start = slice.start start = max(start + length, lower) if start < 0 else min(start, upper) # Compute stop. - if not has_int_value(slc.stop): + if slice.stop is None: stop = lower if step < 0 else upper else: - stop = slc.stop + stop = slice.stop stop = max(stop + length, lower) if stop < 0 else min(stop, upper) return start, stop, step @@ -733,7 +732,7 @@ def test_slice(): s = slice(*slice_args) for length in lengths: assert check_indices(s, length) - # assert check_indices(slice(0, 10, 1), -3) + assert check_indices(slice(0, 10, 1), -3) # Negative length should raise ValueError try: @@ -764,7 +763,41 @@ def test_slice(): x = X(tmp) x[1:2] = 42 - assert tmp == [(slice(1, 2, None), 42)] + assert tmp == [(slice(1, 2), 42)] + + # Non-int elements + def check_types(s, T: type, U: type, V: type): + return (type(s.start) is Optional[T] and + type(s.stop) is Optional[U] and + type(s.step) is Optional[V]) + assert check_types(slice(1j, 'x', 3.14), complex, str, float) + assert check_types(slice(None, 'x', 3.14), int, str, float) + assert check_types(slice(1j, None, 3.14), complex, int, float) + assert check_types(slice(1j, 'x', None), complex, str, int) + assert check_types(slice(1j, None, None), complex, int, int) + assert check_types(slice(None, 'x', None), int, str, int) + assert check_types(slice(None, None, 3.14), int, int, float) + assert check_types(slice(None, None, None), int, int, int) + assert check_types(slice(1j, 'x'), complex, str, int) + assert check_types(slice(None, 'x'), int, str, int) + assert check_types(slice(1j, None), complex, int, int) + assert check_types(slice(None, None), int, int, int) + assert check_types(slice(1j), int, complex, int) + assert check_types(slice(None), int, int, int) + + # eq / ne + assert slice(1, 2, 3) == slice(1, 2, 3) + assert slice(0, 2, 3) != slice(1, 2, 3) + assert slice(1, 0, 3) != slice(1, 2, 3) + assert slice(1, 2, 0) != slice(1, 2, 3) + assert slice(None, None, None) == slice(None, None, None) + assert slice(None, 42, None) == slice(None, 42, None) + assert slice(None, 42, None) != slice(None, 43, None) + assert slice(1, None, 3) == slice(1, None, 3) + assert slice(1, None, 3) != slice(1, None, 0) + assert slice(1, None, 3) != slice(0, None, 3) + assert slice(1) == slice(1) + assert slice(1) != slice(2) test_slice() @test