From 53cbe8697c773ceec91fbeeb74eaf58625a9652f Mon Sep 17 00:00:00 2001 From: Joren Retel Date: Wed, 31 Jul 2024 17:35:31 +0200 Subject: [PATCH] DEBUG: intervals_to_values with window_size > 1 was not produced correctly in every single case. Added more combinatoral tests. --- bigwig_loader/intervals_to_values.py | 186 ++++++++---- cuda_kernels/intervals_to_values.cu | 63 ++-- tests/test_intervals_to_values.py | 18 ++ ...est_intervals_to_values_window_function.py | 276 ++++++++++++++++-- 4 files changed, 422 insertions(+), 121 deletions(-) diff --git a/bigwig_loader/intervals_to_values.py b/bigwig_loader/intervals_to_values.py index a8d9c5c..bfd1d57 100644 --- a/bigwig_loader/intervals_to_values.py +++ b/bigwig_loader/intervals_to_values.py @@ -27,7 +27,7 @@ def intervals_to_values( track_values: cp.ndarray, query_starts: cp.ndarray, query_ends: cp.ndarray, - out: cp.ndarray, + out: cp.ndarray | None = None, found_starts: cp.ndarray | None = None, found_ends: cp.ndarray | None = None, sizes: cp.ndarray | None = None, @@ -42,6 +42,9 @@ def intervals_to_values( When none of found_starts, found_ends or sizes are given, it is assumed that there is only one track. + When the sequence length is not a multiple of window_size, the output length will + be sequence_length // window_size, ignoring the last "incomplete" window. + Args: track_starts: array of length sum(sizes) with the start positions of the intervals @@ -58,10 +61,16 @@ def intervals_to_values( out: array of size n_tracks x batch_size x sequence_length """ + if cp.unique(query_ends - query_starts).size != 1: + raise ValueError("All queried intervals should have the same length.") + sequence_length = (query_ends[0] - query_starts[0]).item() + if (found_starts is None or found_ends is None) and sizes is None: - sizes = cp.asarray([len(track_ends)], dtype=track_starts.dtype) + # just one size, which is the length of the entire track_starts/tracks_ends/tracks_values + sizes = cp.asarray([len(track_starts)], dtype=track_starts.dtype) if found_starts is None or found_ends is None: + # n_subarrays x n_queries found_starts = searchsorted( track_ends, queries=query_starts, @@ -76,10 +85,13 @@ def intervals_to_values( side="left", absolute_indices=True, ) - - out *= _zero - - sequence_length = (query_ends[0] - query_starts[0]).item() + if out is None: + out = cp.zeros( + (found_starts.shape[0], len(query_starts), sequence_length // window_size), + dtype=cp.float32, + ) + else: + out *= _zero max_number_intervals = min( sequence_length, (found_ends - found_starts).max().item() @@ -131,8 +143,8 @@ def get_grid_and_block_size(n_threads: int) -> tuple[int, int]: def kernel_in_python_with_window( - grid_size: int, - block_size: int, + grid_size: tuple[int], + block_size: tuple[int], args: tuple[ cp.ndarray, cp.ndarray, @@ -141,6 +153,7 @@ def kernel_in_python_with_window( cp.ndarray, cp.ndarray, cp.ndarray, + cp.ndarray, int, int, int, @@ -158,81 +171,128 @@ def kernel_in_python_with_window( track_starts, track_ends, track_values, + num_tracks, batch_size, sequence_length, max_number_intervals, - _, window_size, + out, ) = args + _grid_size = grid_size[0] + _block_size = block_size[0] + query_starts = query_starts.get().tolist() query_ends = query_ends.get().tolist() - found_starts = found_starts.get().tolist() - found_ends = found_ends.get().tolist() + # flattening this because that's how we get it in cuda + found_starts = found_starts.flatten().get().tolist() + found_ends = found_ends.flatten().get().tolist() + track_starts = track_starts.get().tolist() track_ends = track_ends.get().tolist() track_values = track_values.get().tolist() - n_threads = grid_size * block_size + n_threads = _grid_size * _block_size + + print(n_threads) # this should be integer reduced_dim = sequence_length // window_size + print("sequence_length") + print(sequence_length) + print("reduced_dim") + print(reduced_dim) - out = [0.0] * reduced_dim * batch_size + out_vector = [0.0] * reduced_dim * batch_size * num_tracks for thread in range(n_threads): - i = thread - - if i < batch_size: - found_start_index = found_starts[i] - found_end_index = found_ends[i] - query_start = query_starts[i] - query_end = query_ends[i] - - cursor = found_start_index - window_index = 0 - summation = 0 - - while cursor < found_end_index and window_index < reduced_dim: - window_start = window_index * window_size - window_end = window_start + window_size - - interval_start = track_starts[cursor] - interval_end = track_ends[cursor] + batch_index = thread % batch_size + track_index = (thread // batch_size) % num_tracks + i = thread % (batch_size * num_tracks) + + print("\n\n\n######") + print(f"NEW thread {thread}") + print("batch_index", batch_index) + print("track_index", track_index) + print("i", i) + + # if i < batch_size * num_tracks: + found_start_index = found_starts[i] + found_end_index = found_ends[i] + query_start = query_starts[batch_index] + query_end = query_ends[batch_index] + + cursor = found_start_index + window_index = 0 + summation = 0 + + # cursor moves through the rows of the bigwig file + # window_index moves through the sequence + + while cursor < found_end_index and window_index < reduced_dim: + print("-----") + print("cursor:", cursor) + window_start = window_index * window_size + window_end = window_start + window_size + print(f"working on values in output window {window_start} - {window_end}") + print( + f"Corresponding to the genomic loc {query_start + window_start} - {query_start + window_end}" + ) + + interval_start = track_starts[cursor] + interval_end = track_ends[cursor] + + print("bigwig interval_start", "bigwig interval_end", "bigwig value") + print(interval_start, interval_end, track_values[cursor]) + + start_index = max(interval_start - query_start, 0) + end_index = min(interval_end, query_end) - query_start + print("start index", start_index) + + if start_index >= window_end: + print("CONTINUE") + out_vector[i * reduced_dim + window_index] = summation / window_size + summation = 0 + window_index += 1 + continue + + number = min(window_end, end_index) - max(window_start, start_index) + + print( + f"Add {number} x {track_values[cursor]} = {number * track_values[cursor]} to summation" + ) + summation += number * track_values[cursor] + print(f"Summation = {summation}") + + print("end_index", "window_end") + print(end_index, window_end) + + # calculate average, reset summation and move to next window + if end_index >= window_end or cursor + 1 >= found_end_index: + if end_index >= window_end: + print( + "end_index >= window_end \t\t calculate average, reset summation and move to next window" + ) + else: + print( + "cursor + 1 >= found_end_index \t\t calculate average, reset summation and move to next window" + ) + out_vector[i * reduced_dim + window_index] = summation / window_size + summation = 0 + window_index += 1 + # move cursor + if end_index < window_end: + print("move cursor") + cursor += 1 + print("current out state:", out_vector) + print( + cp.reshape( + cp.asarray(out_vector), (num_tracks, batch_size, reduced_dim) + ) + ) - start_index = max(interval_start - query_start, 0) - end_index = min(interval_end, query_end) - query_start - - if start_index >= window_end: - window_index += 1 - continue - - number = min(window_end, end_index) - max(window_start, start_index) - - summation += number * track_values[cursor] - print("-----") - print("window_index", "number", "summation") - print(window_index, number, summation) - print("interval_start", "interval_end", "value") - print(interval_start, interval_end, track_values[cursor]) - - print("end_index", "window_end") - print(end_index, window_end) - - # calculate average, reset summation and move to next window - if end_index >= window_end or cursor + 1 >= found_end_index: - print("calculate average, reset summation and move to next window") - out[i * reduced_dim + window_index] = summation / window_size - summation = 0 - window_index += 1 - # move cursor - if end_index < window_end: - print("move cursor") - cursor += 1 - - out = cp.reshape(cp.asarray(out), (batch_size, reduced_dim)) - return out + return cp.reshape(cp.asarray(out_vector), (num_tracks, batch_size, reduced_dim)) def kernel_in_python( diff --git a/cuda_kernels/intervals_to_values.cu b/cuda_kernels/intervals_to_values.cu index 17cd02c..c8173d2 100644 --- a/cuda_kernels/intervals_to_values.cu +++ b/cuda_kernels/intervals_to_values.cu @@ -48,46 +48,49 @@ void intervals_to_values( } } } else { - if (i < batch_size) { - int found_start_index = found_starts[i]; - int found_end_index = found_ends[i]; - int query_start = query_starts[batch_index]; - int query_end = query_ends[batch_index]; - int cursor = found_start_index; - int window_index = 0; - float summation = 0.0f; + int track_index = i / batch_size; - int reduced_dim = sequence_length / window_size; + int found_start_index = found_starts[i]; + int found_end_index = found_ends[i]; + int query_start = query_starts[batch_index]; + int query_end = query_ends[batch_index]; - while (cursor < found_end_index && window_index < reduced_dim) { - int window_start = window_index * window_size; - int window_end = window_start + window_size; + int cursor = found_start_index; + int window_index = 0; + float summation = 0.0f; - int interval_start = track_starts[cursor]; - int interval_end = track_ends[cursor]; + int reduced_dim = sequence_length / window_size; - int start_index = max(interval_start - query_start, 0); - int end_index = min(interval_end, query_end) - query_start; + while (cursor < found_end_index && window_index < reduced_dim) { + int window_start = window_index * window_size; + int window_end = window_start + window_size; - if (start_index >= window_end) { - window_index += 1; - continue; - } + int interval_start = track_starts[cursor]; + int interval_end = track_ends[cursor]; - int number = min(window_end, end_index) - max(window_start, start_index); + int start_index = max(interval_start - query_start, 0); + int end_index = min(interval_end, query_end) - query_start; + + if (start_index >= window_end) { + out[i * reduced_dim + window_index] = summation / window_size; + summation = 0.0f; + window_index += 1; + continue; + } - summation += number * track_values[cursor]; + int number = min(window_end, end_index) - max(window_start, start_index); - if (end_index >= window_end || cursor + 1 >= found_end_index) { - out[i * reduced_dim + window_index] = summation / window_size; - summation = 0.0f; - window_index += 1; - } + summation += number * track_values[cursor]; + + if (end_index >= window_end || cursor + 1 >= found_end_index) { + out[i * reduced_dim + window_index] = summation / window_size; + summation = 0.0f; + window_index += 1; + } - if (end_index < window_end) { - cursor += 1; - } + if (end_index < window_end) { + cursor += 1; } } } diff --git a/tests/test_intervals_to_values.py b/tests/test_intervals_to_values.py index 45b5f9e..f2e70f5 100644 --- a/tests/test_intervals_to_values.py +++ b/tests/test_intervals_to_values.py @@ -1,8 +1,26 @@ import cupy as cp +import pytest from bigwig_loader.intervals_to_values import intervals_to_values +def test_throw_exception_when_queried_intervals_are_of_different_lengths() -> None: + """All query_ends - query_starts should have the same + length. Otherwise, ValueError should be raised. + """ + track_starts = cp.asarray([1, 2, 3], dtype=cp.int32) + track_ends = cp.asarray([2, 3, 4], dtype=cp.int32) + track_values = cp.asarray([1.0, 1.0, 1.0], dtype=cp.dtype("f4")) + query_starts = cp.asarray([2, 2], dtype=cp.int32) + query_ends = cp.asarray([4, 5], dtype=cp.int32) + reserved = cp.zeros((2, 2), dtype=cp.float32) + + with pytest.raises(ValueError): + intervals_to_values( + track_starts, track_ends, track_values, query_starts, query_ends, reserved + ) + + def test_get_values_from_intervals() -> None: """Probably most frequent situation.""" track_starts = cp.asarray([1, 3, 10, 12, 16], dtype=cp.int32) diff --git a/tests/test_intervals_to_values_window_function.py b/tests/test_intervals_to_values_window_function.py index 288b3f8..33fd5e2 100644 --- a/tests/test_intervals_to_values_window_function.py +++ b/tests/test_intervals_to_values_window_function.py @@ -1,4 +1,7 @@ +from itertools import product + import cupy as cp +import pytest from bigwig_loader.intervals_to_values import intervals_to_values @@ -23,7 +26,9 @@ def test_get_values_from_intervals_window() -> None: expected = cp.asarray([[16.0, 21.0, 42.0]]) + print("expected:") print(expected) + print("actual:") print(values) assert (values == expected).all() @@ -191,8 +196,134 @@ def test_get_values_from_intervals_batch_of_2() -> None: assert cp.allclose(values, expected) -def test_get_values_from_intervals_batch_multiple_tracks() -> None: - """Query end is exactly at end index before "gap".""" +def test_one_track_one_sample() -> None: + """ + This tests a specific combination of track and batch index + of the larger test case below: + test_get_values_from_intervals_batch_multiple_tracks + track index = 0 + batch_index = 1 + Included to narrow down source of error in the larger test case. + """ + track_starts = cp.asarray([5, 10, 12, 18], dtype=cp.int32) + track_ends = cp.asarray([10, 12, 14, 20], dtype=cp.int32) + track_values = cp.asarray( + [20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0, 110.0, 120.0, 130.0], + dtype=cp.dtype("f4"), + ) + query_starts = cp.asarray([9], dtype=cp.int32) + query_ends = cp.asarray([20], dtype=cp.int32) + reserved = cp.zeros([1, 1, 3], dtype=cp.dtype(" None: + """ + This tests a specific combination of track and batch index + of the larger test case below: + test_get_values_from_intervals_batch_multiple_tracks + track index = 2 + batch_index = 1 + Included to narrow down source of error in the larger test case. + """ + track_starts = cp.asarray([10, 100, 1000], dtype=cp.int32) + track_ends = cp.asarray([20, 200, 2000], dtype=cp.int32) + track_values = cp.asarray( + [110.0, 120.0, 130.0], + dtype=cp.dtype("f4"), + ) + query_starts = cp.asarray([9], dtype=cp.int32) + query_ends = cp.asarray([20], dtype=cp.int32) + reserved = cp.zeros([1, 1, 3], dtype=cp.dtype(" None: + """Test interval_to_values with 3 tracks and a batch size of 1.""" track_starts = cp.asarray( [5, 10, 12, 18, 8, 9, 10, 18, 25, 10, 100, 1000], dtype=cp.int32 ) @@ -203,9 +334,9 @@ def test_get_values_from_intervals_batch_multiple_tracks() -> None: [20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0, 110.0, 120.0, 130.0], dtype=cp.dtype("f4"), ) - query_starts = cp.asarray([7, 9, 20, 99], dtype=cp.int32) - query_ends = cp.asarray([18, 20, 31, 110], dtype=cp.int32) - reserved = cp.zeros([3, 4, 11], dtype=cp.dtype(" None: expected = cp.asarray( [ [ - [20.0, 20.0, 20.0, 30.0, 30.0, 40.0, 40.0, 0.0, 0.0, 0.0, 0.0], [20.0, 30.0, 30.0, 40.0, 40.0, 0.0, 0.0, 0.0, 0.0, 50.0, 50.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], ], [ - [0.0, 60.0, 70.0, 80.0, 80.0, 80.0, 80.0, 0.0, 0.0, 0.0, 0.0], [70.0, 80.0, 80.0, 80.0, 80.0, 0.0, 0.0, 0.0, 0.0, 90.0, 90.0], - [90.0, 90.0, 0.0, 0.0, 0.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], ], [ - [0.0, 0.0, 0.0, 110.0, 110.0, 110.0, 110.0, 110.0, 110.0, 110.0, 110.0], [ 0.0, 110.0, @@ -246,20 +370,6 @@ def test_get_values_from_intervals_batch_multiple_tracks() -> None: 110.0, 110.0, ], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [ - 0.0, - 120.0, - 120.0, - 120.0, - 120.0, - 120.0, - 120.0, - 120.0, - 120.0, - 120.0, - 120.0, - ], ], ] ) @@ -270,13 +380,123 @@ def apply_window(full_matrix): cp.mean(full_matrix[:, :, :3], axis=2), cp.mean(full_matrix[:, :, 3:6], axis=2), cp.mean(full_matrix[:, :, 6:9], axis=2), - cp.mean(full_matrix[:, :, 9:], axis=2), ], axis=-1, ) expected = apply_window(expected) + print("expected:") print(expected) + print("actual:") print(values) - assert (values == expected).all() + print("difference") + print(values - expected) + assert cp.allclose(values, expected, atol=1e-2, rtol=1e-2) + + +@pytest.mark.parametrize( + "sequence_length, window_size", product([8, 9, 10, 11, 13, 23], [2, 3, 4, 10, 11]) +) +def test_combinations_sequence_length_window_size(sequence_length, window_size) -> None: + """Test intervals_to_values with 3 tracks and a batch size of 4.""" + track_starts = cp.asarray( + [5, 10, 12, 18, 8, 9, 10, 18, 25, 10, 100, 1000], dtype=cp.int32 + ) + track_ends = cp.asarray( + [10, 12, 14, 20, 9, 10, 14, 22, 55, 20, 200, 2000], dtype=cp.int32 + ) + track_values = cp.asarray( + [20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0, 110.0, 120.0, 130.0], + dtype=cp.dtype("f4"), + ) + query_starts = cp.asarray([7, 9, 20, 99], dtype=cp.int32) + query_ends = query_starts + sequence_length + reduced_dim = sequence_length // window_size + + reserved = cp.zeros([3, 4, reduced_dim], dtype=cp.dtype(" None: + """.""" + track_starts = cp.asarray([1, 3, 10, 12, 16] * n_tracks, dtype=cp.int32) + track_ends = cp.asarray([3, 10, 12, 16, 20] * n_tracks, dtype=cp.int32) + track_values = cp.asarray( + [20.0, 15.0, 30.0, 40.0, 50.0] * n_tracks, dtype=cp.dtype("f4") + ) + sizes = cp.asarray([5] * n_tracks, dtype=cp.int32) + sequence_length = 15 + query_starts = cp.asarray([2] * batch_size, dtype=cp.int32) + query_ends = query_starts + sequence_length + values = intervals_to_values( + track_starts, + track_ends, + track_values, + query_starts, + query_ends, + sizes=sizes, + window_size=window_size, + ) + + values_with_window_size_1 = intervals_to_values( + track_starts, + track_ends, + track_values, + query_starts, + query_ends, + sizes=sizes, + window_size=1, + ) + + reduced_dim = sequence_length // window_size + full_matrix = values_with_window_size_1[:, :, : reduced_dim * window_size] + full_matrix = full_matrix.reshape( + full_matrix.shape[0], full_matrix.shape[1], reduced_dim, window_size + ) + expected = cp.mean(full_matrix, axis=-1) + + print("expected:") + print(expected) + print("actual:") + print(values) + + assert cp.allclose(values, expected)