diff --git a/tests/test_intervals_to_values.py b/tests/test_intervals_to_values.py index 04f4e3d..45b5f9e 100644 --- a/tests/test_intervals_to_values.py +++ b/tests/test_intervals_to_values.py @@ -137,3 +137,78 @@ def test_get_values_from_intervals_batch_of_2() -> None: print(expected) print(values) assert (values == expected).all() + + +def test_get_values_from_intervals_batch_multiple_tracks() -> None: + """Query end is exactly at end index before "gap".""" + track_starts = cp.asarray( + [5, 10, 12, 18, 8, 9, 10, 18, 25, 10, 100, 1000], dtype=cp.int32 + ) + track_ends = cp.asarray( + [10, 12, 14, 20, 9, 10, 14, 22, 55, 20, 200, 2000], dtype=cp.int32 + ) + track_values = cp.asarray( + [20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0, 110.0, 120.0, 130.0], + dtype=cp.dtype("f4"), + ) + query_starts = cp.asarray([7, 9, 20, 99], dtype=cp.int32) + query_ends = cp.asarray([18, 20, 31, 110], dtype=cp.int32) + reserved = cp.zeros([3, 4, 11], dtype=cp.dtype(" None: print(expected) print(values) assert cp.allclose(values, expected) + + +def test_get_values_from_intervals_batch_multiple_tracks() -> None: + """Query end is exactly at end index before "gap".""" + track_starts = cp.asarray( + [5, 10, 12, 18, 8, 9, 10, 18, 25, 10, 100, 1000], dtype=cp.int32 + ) + track_ends = cp.asarray( + [10, 12, 14, 20, 9, 10, 14, 22, 55, 20, 200, 2000], dtype=cp.int32 + ) + track_values = cp.asarray( + [20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0, 110.0, 120.0, 130.0], + dtype=cp.dtype("f4"), + ) + query_starts = cp.asarray([7, 9, 20, 99], dtype=cp.int32) + query_ends = cp.asarray([18, 20, 31, 110], dtype=cp.int32) + reserved = cp.zeros([3, 4, 11], dtype=cp.dtype("