Skip to content

Commit

Permalink
added tests intervals_to_values with multiple tracks
Browse files Browse the repository at this point in the history
  • Loading branch information
jorenretel committed Jul 30, 2024
1 parent 338dfbe commit 6e1c6b2
Show file tree
Hide file tree
Showing 2 changed files with 166 additions and 0 deletions.
75 changes: 75 additions & 0 deletions tests/test_intervals_to_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,78 @@ def test_get_values_from_intervals_batch_of_2() -> None:
print(expected)
print(values)
assert (values == expected).all()


def test_get_values_from_intervals_batch_multiple_tracks() -> None:
"""Query end is exactly at end index before "gap"."""
track_starts = cp.asarray(
[5, 10, 12, 18, 8, 9, 10, 18, 25, 10, 100, 1000], dtype=cp.int32
)
track_ends = cp.asarray(
[10, 12, 14, 20, 9, 10, 14, 22, 55, 20, 200, 2000], dtype=cp.int32
)
track_values = cp.asarray(
[20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0, 110.0, 120.0, 130.0],
dtype=cp.dtype("f4"),
)
query_starts = cp.asarray([7, 9, 20, 99], dtype=cp.int32)
query_ends = cp.asarray([18, 20, 31, 110], dtype=cp.int32)
reserved = cp.zeros([3, 4, 11], dtype=cp.dtype("<f4"))
values = intervals_to_values(
track_starts,
track_ends,
track_values,
query_starts,
query_ends,
reserved,
sizes=cp.asarray([4, 5, 3], dtype=cp.int32),
)
expected = cp.asarray(
[
[
[20.0, 20.0, 20.0, 30.0, 30.0, 40.0, 40.0, 0.0, 0.0, 0.0, 0.0],
[20.0, 30.0, 30.0, 40.0, 40.0, 0.0, 0.0, 0.0, 0.0, 50.0, 50.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
],
[
[0.0, 60.0, 70.0, 80.0, 80.0, 80.0, 80.0, 0.0, 0.0, 0.0, 0.0],
[70.0, 80.0, 80.0, 80.0, 80.0, 0.0, 0.0, 0.0, 0.0, 90.0, 90.0],
[90.0, 90.0, 0.0, 0.0, 0.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
],
[
[0.0, 0.0, 0.0, 110.0, 110.0, 110.0, 110.0, 110.0, 110.0, 110.0, 110.0],
[
0.0,
110.0,
110.0,
110.0,
110.0,
110.0,
110.0,
110.0,
110.0,
110.0,
110.0,
],
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[
0.0,
120.0,
120.0,
120.0,
120.0,
120.0,
120.0,
120.0,
120.0,
120.0,
120.0,
],
],
]
)
print(expected)
print(values)
assert (values == expected).all()
91 changes: 91 additions & 0 deletions tests/test_intervals_to_values_window_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,3 +189,94 @@ def test_get_values_from_intervals_batch_of_2() -> None:
print(expected)
print(values)
assert cp.allclose(values, expected)


def test_get_values_from_intervals_batch_multiple_tracks() -> None:
"""Query end is exactly at end index before "gap"."""
track_starts = cp.asarray(
[5, 10, 12, 18, 8, 9, 10, 18, 25, 10, 100, 1000], dtype=cp.int32
)
track_ends = cp.asarray(
[10, 12, 14, 20, 9, 10, 14, 22, 55, 20, 200, 2000], dtype=cp.int32
)
track_values = cp.asarray(
[20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0, 110.0, 120.0, 130.0],
dtype=cp.dtype("f4"),
)
query_starts = cp.asarray([7, 9, 20, 99], dtype=cp.int32)
query_ends = cp.asarray([18, 20, 31, 110], dtype=cp.int32)
reserved = cp.zeros([3, 4, 11], dtype=cp.dtype("<f4"))
values = intervals_to_values(
track_starts,
track_ends,
track_values,
query_starts,
query_ends,
reserved,
sizes=cp.asarray([4, 5, 3], dtype=cp.int32),
window_size=3,
)

expected = cp.asarray(
[
[
[20.0, 20.0, 20.0, 30.0, 30.0, 40.0, 40.0, 0.0, 0.0, 0.0, 0.0],
[20.0, 30.0, 30.0, 40.0, 40.0, 0.0, 0.0, 0.0, 0.0, 50.0, 50.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
],
[
[0.0, 60.0, 70.0, 80.0, 80.0, 80.0, 80.0, 0.0, 0.0, 0.0, 0.0],
[70.0, 80.0, 80.0, 80.0, 80.0, 0.0, 0.0, 0.0, 0.0, 90.0, 90.0],
[90.0, 90.0, 0.0, 0.0, 0.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
],
[
[0.0, 0.0, 0.0, 110.0, 110.0, 110.0, 110.0, 110.0, 110.0, 110.0, 110.0],
[
0.0,
110.0,
110.0,
110.0,
110.0,
110.0,
110.0,
110.0,
110.0,
110.0,
110.0,
],
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[
0.0,
120.0,
120.0,
120.0,
120.0,
120.0,
120.0,
120.0,
120.0,
120.0,
120.0,
],
],
]
)

def apply_window(full_matrix):
return cp.stack(
[
cp.mean(full_matrix[:, :, :3], axis=2),
cp.mean(full_matrix[:, :, 3:6], axis=2),
cp.mean(full_matrix[:, :, 6:9], axis=2),
cp.mean(full_matrix[:, :, 9:], axis=2),
],
axis=-1,
)

expected = apply_window(expected)

print(expected)
print(values)
assert (values == expected).all()

0 comments on commit 6e1c6b2

Please sign in to comment.