Skip to content

Commit

Permalink
Fix list optimization for new slices; update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
arshajii committed Feb 24, 2024
1 parent 8619eb6 commit 4c490ca
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 30 deletions.
2 changes: 1 addition & 1 deletion codon/cir/transform/pythonic/list.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ namespace pythonic {
namespace {

static const std::string LIST = "std.internal.types.ptr.List";
static const std::string SLICE = "std.internal.types.slice.Slice";
static const std::string SLICE = "std.internal.types.slice.Slice[int,int,int]";

bool isList(Value *v) { return v->getType()->getName().rfind(LIST + "[", 0) == 0; }
bool isSlice(Value *v) { return v->getType()->getName() == SLICE; }
Expand Down
29 changes: 10 additions & 19 deletions stdlib/internal/types/slice.codon
Original file line number Diff line number Diff line change
Expand Up @@ -28,30 +28,21 @@ class Slice:
V: type = int) -> Slice[T, U, V]:
return (start, stop, step)

def _ensure_int_or_none(self, method_name: str):
if not (
(T is int or T is Optional[int] or T is Optional[NoneType]) and
(U is int or U is Optional[int] or U is Optional[NoneType]) and
(V is int or V is Optional[int] or V is Optional[NoneType])
):
raise TypeError(f"{method_name} requires all fields to be ints. Fields types are: "
f"start: {T.__name__}, stop: {U.__name__}, step: {V.__name__}")

def adjust_indices(self, length: int) -> Tuple[int, int, int, int]:
self._ensure_int_or_none("slice.adjust_indices")
def has_int_value(v):
return isinstance(v, int) or (isinstance(v, Optional[int]) and v is not None)
step: int = self.step if has_int_value(self.step) else 1
if not (T is int and U is int and V is int):
compile_error("slice indices must be integers or None")

step: int = self.step if self.step is not None else 1
start: int = 0
stop: int = 0
if step == 0:
raise ValueError("slice step cannot be zero")
if step > 0:
start = self.start if has_int_value(self.start) else 0
stop = self.stop if has_int_value(self.stop) else length
start = self.start if self.start is not None else 0
stop = self.stop if self.stop is not None else length
else:
start = self.start if has_int_value(self.start) else length - 1
stop = self.stop if has_int_value(self.stop) else -(length + 1)
start = self.start if self.start is not None else length - 1
stop = self.stop if self.stop is not None else -(length + 1)

return Slice.adjust_indices_helper(length, start, stop, step)

Expand Down Expand Up @@ -89,10 +80,10 @@ class Slice:
def __repr__(self):
return f"slice({self.start}, {self.stop}, {self.step})"

def __eq__(self, other):
def __eq__(self, other: Slice):
return self.start == other.start and self.step == other.step and self.stop == other.stop

def __ne__(self, other):
def __ne__(self, other: Slice):
return not self.__eq__(other)

slice = Slice
53 changes: 43 additions & 10 deletions test/core/containers.codon
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from copy import copy, deepcopy

@tuple
class A:
a: int
Expand Down Expand Up @@ -600,16 +601,14 @@ def test_dict():
assert repr(Dict[int,int]()) == '{}'
test_dict()

def slice_indices(slc, length):
def slice_indices(slice, length):
"""
Reference implementation for the slice.indices method.
"""
def has_int_value(v):
return isinstance(v, int) or (isinstance(v, Optional[int]) and v is not None)
# Compute step and length as integers.
#length = operator.index(length)
step: int = slc.step if has_int_value(slc.step) else 1
step: int = 1 if slice.step is None else slice.step

# Raise ValueError for negative length or zero step.
if length < 0:
Expand All @@ -622,17 +621,17 @@ def slice_indices(slc, length):
upper = length - 1 if step < 0 else length

# Compute start.
if not has_int_value(slc.start):
if slice.start is None:
start = upper if step < 0 else lower
else:
start = slc.start
start = slice.start
start = max(start + length, lower) if start < 0 else min(start, upper)

# Compute stop.
if not has_int_value(slc.stop):
if slice.stop is None:
stop = lower if step < 0 else upper
else:
stop = slc.stop
stop = slice.stop
stop = max(stop + length, lower) if stop < 0 else min(stop, upper)

return start, stop, step
Expand Down Expand Up @@ -733,7 +732,7 @@ def test_slice():
s = slice(*slice_args)
for length in lengths:
assert check_indices(s, length)
# assert check_indices(slice(0, 10, 1), -3)
assert check_indices(slice(0, 10, 1), -3)

# Negative length should raise ValueError
try:
Expand Down Expand Up @@ -764,7 +763,41 @@ def test_slice():

x = X(tmp)
x[1:2] = 42
assert tmp == [(slice(1, 2, None), 42)]
assert tmp == [(slice(1, 2), 42)]

# Non-int elements
def check_types(s, T: type, U: type, V: type):
return (type(s.start) is Optional[T] and
type(s.stop) is Optional[U] and
type(s.step) is Optional[V])
assert check_types(slice(1j, 'x', 3.14), complex, str, float)
assert check_types(slice(None, 'x', 3.14), int, str, float)
assert check_types(slice(1j, None, 3.14), complex, int, float)
assert check_types(slice(1j, 'x', None), complex, str, int)
assert check_types(slice(1j, None, None), complex, int, int)
assert check_types(slice(None, 'x', None), int, str, int)
assert check_types(slice(None, None, 3.14), int, int, float)
assert check_types(slice(None, None, None), int, int, int)
assert check_types(slice(1j, 'x'), complex, str, int)
assert check_types(slice(None, 'x'), int, str, int)
assert check_types(slice(1j, None), complex, int, int)
assert check_types(slice(None, None), int, int, int)
assert check_types(slice(1j), int, complex, int)
assert check_types(slice(None), int, int, int)

# eq / ne
assert slice(1, 2, 3) == slice(1, 2, 3)
assert slice(0, 2, 3) != slice(1, 2, 3)
assert slice(1, 0, 3) != slice(1, 2, 3)
assert slice(1, 2, 0) != slice(1, 2, 3)
assert slice(None, None, None) == slice(None, None, None)
assert slice(None, 42, None) == slice(None, 42, None)
assert slice(None, 42, None) != slice(None, 43, None)
assert slice(1, None, 3) == slice(1, None, 3)
assert slice(1, None, 3) != slice(1, None, 0)
assert slice(1, None, 3) != slice(0, None, 3)
assert slice(1) == slice(1)
assert slice(1) != slice(2)
test_slice()

@test
Expand Down

0 comments on commit 4c490ca

Please sign in to comment.