Skip to content

Commit

Permalink
DEBUG: intervals_to_values with window_size > 1 was not produced corr…
Browse files Browse the repository at this point in the history
…ectly in every single case. Added more combinatorial tests.
  • Loading branch information
jorenretel committed Jul 31, 2024
1 parent 6e1c6b2 commit 53cbe86
Show file tree
Hide file tree
Showing 4 changed files with 422 additions and 121 deletions.
186 changes: 123 additions & 63 deletions bigwig_loader/intervals_to_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def intervals_to_values(
track_values: cp.ndarray,
query_starts: cp.ndarray,
query_ends: cp.ndarray,
out: cp.ndarray,
out: cp.ndarray | None = None,
found_starts: cp.ndarray | None = None,
found_ends: cp.ndarray | None = None,
sizes: cp.ndarray | None = None,
Expand All @@ -42,6 +42,9 @@ def intervals_to_values(
When none of found_starts, found_ends or sizes are given, it is assumed that there is only
one track.
When the sequence length is not a multiple of window_size, the output length will
be sequence_length // window_size, ignoring the last "incomplete" window.
Args:
track_starts: array of length sum(sizes) with the start positions of the intervals
Expand All @@ -58,10 +61,16 @@ def intervals_to_values(
out: array of size n_tracks x batch_size x sequence_length
"""
if cp.unique(query_ends - query_starts).size != 1:
raise ValueError("All queried intervals should have the same length.")
sequence_length = (query_ends[0] - query_starts[0]).item()

if (found_starts is None or found_ends is None) and sizes is None:
sizes = cp.asarray([len(track_ends)], dtype=track_starts.dtype)
# just one size, which is the length of the entire track_starts/tracks_ends/tracks_values
sizes = cp.asarray([len(track_starts)], dtype=track_starts.dtype)

if found_starts is None or found_ends is None:
# n_subarrays x n_queries
found_starts = searchsorted(
track_ends,
queries=query_starts,
Expand All @@ -76,10 +85,13 @@ def intervals_to_values(
side="left",
absolute_indices=True,
)

out *= _zero

sequence_length = (query_ends[0] - query_starts[0]).item()
if out is None:
out = cp.zeros(
(found_starts.shape[0], len(query_starts), sequence_length // window_size),
dtype=cp.float32,
)
else:
out *= _zero

max_number_intervals = min(
sequence_length, (found_ends - found_starts).max().item()
Expand Down Expand Up @@ -131,8 +143,8 @@ def get_grid_and_block_size(n_threads: int) -> tuple[int, int]:


def kernel_in_python_with_window(
grid_size: int,
block_size: int,
grid_size: tuple[int],
block_size: tuple[int],
args: tuple[
cp.ndarray,
cp.ndarray,
Expand All @@ -141,6 +153,7 @@ def kernel_in_python_with_window(
cp.ndarray,
cp.ndarray,
cp.ndarray,
cp.ndarray,
int,
int,
int,
Expand All @@ -158,81 +171,128 @@ def kernel_in_python_with_window(
track_starts,
track_ends,
track_values,
num_tracks,
batch_size,
sequence_length,
max_number_intervals,
_,
window_size,
out,
) = args

_grid_size = grid_size[0]
_block_size = block_size[0]

query_starts = query_starts.get().tolist()
query_ends = query_ends.get().tolist()

found_starts = found_starts.get().tolist()
found_ends = found_ends.get().tolist()
# flattening this because that's how we get it in cuda
found_starts = found_starts.flatten().get().tolist()
found_ends = found_ends.flatten().get().tolist()

track_starts = track_starts.get().tolist()
track_ends = track_ends.get().tolist()
track_values = track_values.get().tolist()

n_threads = grid_size * block_size
n_threads = _grid_size * _block_size

print(n_threads)

# this should be integer
reduced_dim = sequence_length // window_size
print("sequence_length")
print(sequence_length)
print("reduced_dim")
print(reduced_dim)

out = [0.0] * reduced_dim * batch_size
out_vector = [0.0] * reduced_dim * batch_size * num_tracks

for thread in range(n_threads):
i = thread

if i < batch_size:
found_start_index = found_starts[i]
found_end_index = found_ends[i]
query_start = query_starts[i]
query_end = query_ends[i]

cursor = found_start_index
window_index = 0
summation = 0

while cursor < found_end_index and window_index < reduced_dim:
window_start = window_index * window_size
window_end = window_start + window_size

interval_start = track_starts[cursor]
interval_end = track_ends[cursor]
batch_index = thread % batch_size
track_index = (thread // batch_size) % num_tracks
i = thread % (batch_size * num_tracks)

print("\n\n\n######")
print(f"NEW thread {thread}")
print("batch_index", batch_index)
print("track_index", track_index)
print("i", i)

# if i < batch_size * num_tracks:
found_start_index = found_starts[i]
found_end_index = found_ends[i]
query_start = query_starts[batch_index]
query_end = query_ends[batch_index]

cursor = found_start_index
window_index = 0
summation = 0

# cursor moves through the rows of the bigwig file
# window_index moves through the sequence

while cursor < found_end_index and window_index < reduced_dim:
print("-----")
print("cursor:", cursor)
window_start = window_index * window_size
window_end = window_start + window_size
print(f"working on values in output window {window_start} - {window_end}")
print(
f"Corresponding to the genomic loc {query_start + window_start} - {query_start + window_end}"
)

interval_start = track_starts[cursor]
interval_end = track_ends[cursor]

print("bigwig interval_start", "bigwig interval_end", "bigwig value")
print(interval_start, interval_end, track_values[cursor])

start_index = max(interval_start - query_start, 0)
end_index = min(interval_end, query_end) - query_start
print("start index", start_index)

if start_index >= window_end:
print("CONTINUE")
out_vector[i * reduced_dim + window_index] = summation / window_size
summation = 0
window_index += 1
continue

number = min(window_end, end_index) - max(window_start, start_index)

print(
f"Add {number} x {track_values[cursor]} = {number * track_values[cursor]} to summation"
)
summation += number * track_values[cursor]
print(f"Summation = {summation}")

print("end_index", "window_end")
print(end_index, window_end)

# calculate average, reset summation and move to next window
if end_index >= window_end or cursor + 1 >= found_end_index:
if end_index >= window_end:
print(
"end_index >= window_end \t\t calculate average, reset summation and move to next window"
)
else:
print(
"cursor + 1 >= found_end_index \t\t calculate average, reset summation and move to next window"
)
out_vector[i * reduced_dim + window_index] = summation / window_size
summation = 0
window_index += 1
# move cursor
if end_index < window_end:
print("move cursor")
cursor += 1
print("current out state:", out_vector)
print(
cp.reshape(
cp.asarray(out_vector), (num_tracks, batch_size, reduced_dim)
)
)

start_index = max(interval_start - query_start, 0)
end_index = min(interval_end, query_end) - query_start

if start_index >= window_end:
window_index += 1
continue

number = min(window_end, end_index) - max(window_start, start_index)

summation += number * track_values[cursor]
print("-----")
print("window_index", "number", "summation")
print(window_index, number, summation)
print("interval_start", "interval_end", "value")
print(interval_start, interval_end, track_values[cursor])

print("end_index", "window_end")
print(end_index, window_end)

# calculate average, reset summation and move to next window
if end_index >= window_end or cursor + 1 >= found_end_index:
print("calculate average, reset summation and move to next window")
out[i * reduced_dim + window_index] = summation / window_size
summation = 0
window_index += 1
# move cursor
if end_index < window_end:
print("move cursor")
cursor += 1

out = cp.reshape(cp.asarray(out), (batch_size, reduced_dim))
return out
return cp.reshape(cp.asarray(out_vector), (num_tracks, batch_size, reduced_dim))


def kernel_in_python(
Expand Down
63 changes: 33 additions & 30 deletions cuda_kernels/intervals_to_values.cu
Original file line number Diff line number Diff line change
Expand Up @@ -48,46 +48,49 @@ void intervals_to_values(
}
}
} else {
if (i < batch_size) {
int found_start_index = found_starts[i];
int found_end_index = found_ends[i];
int query_start = query_starts[batch_index];
int query_end = query_ends[batch_index];

int cursor = found_start_index;
int window_index = 0;
float summation = 0.0f;
int track_index = i / batch_size;

int reduced_dim = sequence_length / window_size;
int found_start_index = found_starts[i];
int found_end_index = found_ends[i];
int query_start = query_starts[batch_index];
int query_end = query_ends[batch_index];

while (cursor < found_end_index && window_index < reduced_dim) {
int window_start = window_index * window_size;
int window_end = window_start + window_size;
int cursor = found_start_index;
int window_index = 0;
float summation = 0.0f;

int interval_start = track_starts[cursor];
int interval_end = track_ends[cursor];
int reduced_dim = sequence_length / window_size;

int start_index = max(interval_start - query_start, 0);
int end_index = min(interval_end, query_end) - query_start;
while (cursor < found_end_index && window_index < reduced_dim) {
int window_start = window_index * window_size;
int window_end = window_start + window_size;

if (start_index >= window_end) {
window_index += 1;
continue;
}
int interval_start = track_starts[cursor];
int interval_end = track_ends[cursor];

int number = min(window_end, end_index) - max(window_start, start_index);
int start_index = max(interval_start - query_start, 0);
int end_index = min(interval_end, query_end) - query_start;

if (start_index >= window_end) {
out[i * reduced_dim + window_index] = summation / window_size;
summation = 0.0f;
window_index += 1;
continue;
}

summation += number * track_values[cursor];
int number = min(window_end, end_index) - max(window_start, start_index);

if (end_index >= window_end || cursor + 1 >= found_end_index) {
out[i * reduced_dim + window_index] = summation / window_size;
summation = 0.0f;
window_index += 1;
}
summation += number * track_values[cursor];

if (end_index >= window_end || cursor + 1 >= found_end_index) {
out[i * reduced_dim + window_index] = summation / window_size;
summation = 0.0f;
window_index += 1;
}

if (end_index < window_end) {
cursor += 1;
}
if (end_index < window_end) {
cursor += 1;
}
}
}
Expand Down
18 changes: 18 additions & 0 deletions tests/test_intervals_to_values.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,26 @@
import cupy as cp
import pytest

from bigwig_loader.intervals_to_values import intervals_to_values


def test_throw_exception_when_queried_intervals_are_of_different_lengths() -> None:
"""All query_ends - query_starts should have the same
length. Otherwise, ValueError should be raised.
"""
track_starts = cp.asarray([1, 2, 3], dtype=cp.int32)
track_ends = cp.asarray([2, 3, 4], dtype=cp.int32)
track_values = cp.asarray([1.0, 1.0, 1.0], dtype=cp.dtype("f4"))
query_starts = cp.asarray([2, 2], dtype=cp.int32)
query_ends = cp.asarray([4, 5], dtype=cp.int32)
reserved = cp.zeros((2, 2), dtype=cp.float32)

with pytest.raises(ValueError):
intervals_to_values(
track_starts, track_ends, track_values, query_starts, query_ends, reserved
)


def test_get_values_from_intervals() -> None:
"""Probably most frequent situation."""
track_starts = cp.asarray([1, 3, 10, 12, 16], dtype=cp.int32)
Expand Down
Loading

0 comments on commit 53cbe86

Please sign in to comment.