diff --git a/bigwig_loader/bigwig.py b/bigwig_loader/bigwig.py index 8bea5f0..cd9b321 100644 --- a/bigwig_loader/bigwig.py +++ b/bigwig_loader/bigwig.py @@ -14,7 +14,7 @@ import pandas as pd from ncls import NCLS -from bigwig_loader.gpu_decompressor import Decoder +from bigwig_loader.decompressor import Decoder from bigwig_loader.memory_bank import MemoryBank from bigwig_loader.memory_bank import create_memory_bank from bigwig_loader.merge_intervals import merge_intervals diff --git a/bigwig_loader/collection.py b/bigwig_loader/collection.py index f11f02d..bd01c81 100644 --- a/bigwig_loader/collection.py +++ b/bigwig_loader/collection.py @@ -13,13 +13,14 @@ import pandas as pd from bigwig_loader.bigwig import BigWig -from bigwig_loader.gpu_decompressor import Decoder -from bigwig_loader.intervals_to_values_gpu import intervals_to_values +from bigwig_loader.decompressor import Decoder +from bigwig_loader.intervals_to_values import intervals_to_values from bigwig_loader.memory_bank import MemoryBank from bigwig_loader.memory_bank import create_memory_bank from bigwig_loader.merge_intervals import merge_interval_dataframe from bigwig_loader.path import interpret_path from bigwig_loader.path import map_path_to_value +from bigwig_loader.searchsorted import searchsorted from bigwig_loader.subtract_intervals import subtract_interval_dataframe from bigwig_loader.util import chromosome_sort @@ -198,24 +199,36 @@ def get_batch( f"Cupy default memory pool:{ cp.get_default_memory_pool().used_bytes() / 1024} kB" ) - chunk_row_numbers = cp.pad(cp.cumsum(n_rows_for_chunks), (1, 0)) - i = 0 - for n_chunks, partial_out in zip(n_chunks_per_bigwig, out): - bigwig_end = i + n_chunks - - row_number_start = chunk_row_numbers[i] - row_number_end = chunk_row_numbers[bigwig_end] - - intervals_to_values( - track_starts=start[row_number_start:row_number_end], - track_ends=end[row_number_start:row_number_end], - track_values=value[row_number_start:row_number_end], - query_starts=cp.asarray(abs_start, dtype=cp.uint32), - query_ends=cp.asarray(abs_end, dtype=cp.uint32), - window_size=window_size, - out=partial_out, - ) - i = bigwig_end + bigwig_starts = cp.pad( + cp.cumsum(cp.asarray(n_chunks_per_bigwig, dtype=cp.uint32)), (1, 0) + ) + chunk_starts = cp.pad(cp.cumsum(n_rows_for_chunks), (1, 0)) + bigwig_starts = chunk_starts[bigwig_starts] + sizes = bigwig_starts[1:] - bigwig_starts[:-1] + + sizes = sizes.astype(cp.uint32) + abs_end = cp.asarray(abs_end, dtype=cp.uint32) + abs_start = cp.asarray(abs_start, dtype=cp.uint32) + + # n_tracks x n_queries + found_starts = searchsorted( + end, queries=abs_start, sizes=sizes, side="right", absolute_indices=True + ) + found_ends = searchsorted( + start, queries=abs_end, sizes=sizes, side="left", absolute_indices=True + ) + + intervals_to_values( + track_starts=start, + track_ends=end, + track_values=value, + found_starts=found_starts, + found_ends=found_ends, + query_starts=abs_start, + query_ends=abs_end, + window_size=window_size, + out=out, + ) batch = cp.transpose(out, (1, 0, 2)) batch *= self.scaling_factors_cupy return batch diff --git a/bigwig_loader/gpu_decompressor.py b/bigwig_loader/decompressor.py similarity index 100% rename from bigwig_loader/gpu_decompressor.py rename to bigwig_loader/decompressor.py diff --git a/bigwig_loader/intervals_to_values.py b/bigwig_loader/intervals_to_values.py new file mode 100644 index 0000000..bfd1d57 --- /dev/null +++ b/bigwig_loader/intervals_to_values.py @@ -0,0 +1,377 @@ +import logging +import math +from pathlib import Path + +import cupy as cp + +from bigwig_loader.searchsorted import searchsorted + +CUDA_KERNEL_DIR = Path(__file__).parent.parent / "cuda_kernels" + +_zero = cp.asarray(0.0, dtype=cp.float32).item() + + +def get_cuda_kernel() -> str: + with open(CUDA_KERNEL_DIR / "intervals_to_values.cu") as f: + kernel_code = f.read() + return kernel_code + + +cuda_kernel = cp.RawKernel(get_cuda_kernel(), "intervals_to_values") +cuda_kernel.compile() + + +def intervals_to_values( + track_starts: cp.ndarray, + track_ends: cp.ndarray, + track_values: cp.ndarray, + query_starts: cp.ndarray, + query_ends: cp.ndarray, + out: cp.ndarray | None = None, + found_starts: cp.ndarray | None = None, + found_ends: cp.ndarray | None = None, + sizes: cp.ndarray | None = None, + window_size: int = 1, +) -> cp.ndarray: + """ + This function converts intervals to values. It can do this for multiple tracks at once. + When multiple tracks are given, track_starts, track_ends and track_values are expected + to be concatenated arrays of the individual tracks. The sizes array is used to indicate + where the individual tracks start and end. + + When none of found_starts, found_ends or sizes are given, it is assumed that there is only + one track. + + When the sequence length is not a multiple of window_size, the output length will + be sequence_length // window_size, ignoring the last "incomplete" window. + + + Args: + track_starts: array of length sum(sizes) with the start positions of the intervals + track_ends: array of length sum(sizes) with the end positions of the intervals + track_values: array of length sum(sizes) with the value for those intervals + query_starts: array of length batch_size with the (genomic) start positions of each batch element + query_ends: array of length batch_size with the (genomic) end positions of each batch element + out: array of size n_tracks x batch_size x sequence_length to store the output + found_starts: result of searchsorted (if precalculated). Indices into track_starts. + found_ends: result of searchsorted (if precalculated). Indices into track_ends. + sizes: number of elements in track_starts/track_ends/track_values for each track + window_size: size in basepairs to average over (default: 1) + Returns: + out: array of size n_tracks x batch_size x sequence_length + + """ + if cp.unique(query_ends - query_starts).size != 1: + raise ValueError("All queried intervals should have the same length.") + sequence_length = (query_ends[0] - query_starts[0]).item() + + if (found_starts is None or found_ends is None) and sizes is None: + # just one size, which is the length of the entire track_starts/tracks_ends/tracks_values + sizes = cp.asarray([len(track_starts)], dtype=track_starts.dtype) + + if found_starts is None or found_ends is None: + # n_subarrays x n_queries + found_starts = searchsorted( + track_ends, + queries=query_starts, + sizes=sizes, + side="right", + absolute_indices=True, + ) + found_ends = searchsorted( + track_starts, + queries=query_ends, + sizes=sizes, + side="left", + absolute_indices=True, + ) + if out is None: + out = cp.zeros( + (found_starts.shape[0], len(query_starts), sequence_length // window_size), + dtype=cp.float32, + ) + else: + out *= _zero + + max_number_intervals = min( + sequence_length, (found_ends - found_starts).max().item() + ) + batch_size = query_starts.shape[0] + num_tracks = found_starts.shape[0] + + if window_size == 1: + n_threads_needed = batch_size * max_number_intervals * num_tracks + grid_size, block_size = get_grid_and_block_size(n_threads_needed) + else: + n_threads_needed = batch_size * num_tracks + grid_size, block_size = get_grid_and_block_size(n_threads_needed) + + logging.debug( + f"batch_size: {batch_size}\nmax_number_intervals: {max_number_intervals}\ngrid_size: {grid_size}\nblock_size: {block_size}" + ) + + cuda_kernel( + (grid_size,), + (block_size,), + ( + query_starts, + query_ends, + found_starts, + found_ends, + track_starts, + track_ends, + track_values, + num_tracks, + batch_size, + sequence_length, + max_number_intervals, + window_size, + out, + ), + ) + + return out + + +def get_grid_and_block_size(n_threads: int) -> tuple[int, int]: + n_blocks_needed = math.ceil(n_threads / 512) + if n_blocks_needed == 1: + threads_per_block = n_threads + else: + threads_per_block = 512 + return n_blocks_needed, threads_per_block + + +def kernel_in_python_with_window( + grid_size: tuple[int], + block_size: tuple[int], + args: tuple[ + cp.ndarray, + cp.ndarray, + cp.ndarray, + cp.ndarray, + cp.ndarray, + cp.ndarray, + cp.ndarray, + cp.ndarray, + int, + int, + int, + cp.ndarray, + int, + ], +) -> cp.ndarray: + """Equivalent in python to cuda_kernel_with_window. Just for debugging.""" + + ( + query_starts, + query_ends, + found_starts, + found_ends, + track_starts, + track_ends, + track_values, + num_tracks, + batch_size, + sequence_length, + max_number_intervals, + window_size, + out, + ) = args + + _grid_size = grid_size[0] + _block_size = block_size[0] + + query_starts = query_starts.get().tolist() + query_ends = query_ends.get().tolist() + + # flattening this because that's how we get it in cuda + found_starts = found_starts.flatten().get().tolist() + found_ends = found_ends.flatten().get().tolist() + + track_starts = track_starts.get().tolist() + track_ends = track_ends.get().tolist() + track_values = track_values.get().tolist() + + n_threads = _grid_size * _block_size + + print(n_threads) + + # this should be integer + reduced_dim = sequence_length // window_size + print("sequence_length") + print(sequence_length) + print("reduced_dim") + print(reduced_dim) + + out_vector = [0.0] * reduced_dim * batch_size * num_tracks + + for thread in range(n_threads): + batch_index = thread % batch_size + track_index = (thread // batch_size) % num_tracks + i = thread % (batch_size * num_tracks) + + print("\n\n\n######") + print(f"NEW thread {thread}") + print("batch_index", batch_index) + print("track_index", track_index) + print("i", i) + + # if i < batch_size * num_tracks: + found_start_index = found_starts[i] + found_end_index = found_ends[i] + query_start = query_starts[batch_index] + query_end = query_ends[batch_index] + + cursor = found_start_index + window_index = 0 + summation = 0 + + # cursor moves through the rows of the bigwig file + # window_index moves through the sequence + + while cursor < found_end_index and window_index < reduced_dim: + print("-----") + print("cursor:", cursor) + window_start = window_index * window_size + window_end = window_start + window_size + print(f"working on values in output window {window_start} - {window_end}") + print( + f"Corresponding to the genomic loc {query_start + window_start} - {query_start + window_end}" + ) + + interval_start = track_starts[cursor] + interval_end = track_ends[cursor] + + print("bigwig interval_start", "bigwig interval_end", "bigwig value") + print(interval_start, interval_end, track_values[cursor]) + + start_index = max(interval_start - query_start, 0) + end_index = min(interval_end, query_end) - query_start + print("start index", start_index) + + if start_index >= window_end: + print("CONTINUE") + out_vector[i * reduced_dim + window_index] = summation / window_size + summation = 0 + window_index += 1 + continue + + number = min(window_end, end_index) - max(window_start, start_index) + + print( + f"Add {number} x {track_values[cursor]} = {number * track_values[cursor]} to summation" + ) + summation += number * track_values[cursor] + print(f"Summation = {summation}") + + print("end_index", "window_end") + print(end_index, window_end) + + # calculate average, reset summation and move to next window + if end_index >= window_end or cursor + 1 >= found_end_index: + if end_index >= window_end: + print( + "end_index >= window_end \t\t calculate average, reset summation and move to next window" + ) + else: + print( + "cursor + 1 >= found_end_index \t\t calculate average, reset summation and move to next window" + ) + out_vector[i * reduced_dim + window_index] = summation / window_size + summation = 0 + window_index += 1 + # move cursor + if end_index < window_end: + print("move cursor") + cursor += 1 + print("current out state:", out_vector) + print( + cp.reshape( + cp.asarray(out_vector), (num_tracks, batch_size, reduced_dim) + ) + ) + + return cp.reshape(cp.asarray(out_vector), (num_tracks, batch_size, reduced_dim)) + + +def kernel_in_python( + grid_size: int, + block_size: int, + args: tuple[ + cp.ndarray, + cp.ndarray, + cp.ndarray, + cp.ndarray, + cp.ndarray, + cp.ndarray, + cp.ndarray, + int, + int, + int, + cp.ndarray, + int, + ], +) -> cp.ndarray: + """Equivalent in python to cuda_kernel. Just for debugging.""" + + ( + query_starts, + query_ends, + found_starts, + found_ends, + track_starts, + track_ends, + track_values, + batch_size, + sequence_length, + max_number_intervals, + _, + window_size, + ) = args + + query_starts = query_starts.get().tolist() + query_ends = query_ends.get().tolist() + + found_starts = found_starts.get().tolist() + found_ends = found_ends.get().tolist() + track_starts = track_starts.get().tolist() + track_ends = track_ends.get().tolist() + track_values = track_values.get().tolist() + + n_threads = grid_size * block_size + + out = [0.0] * sequence_length * batch_size + + for thread in range(n_threads): + i = thread % batch_size + j = (thread // batch_size) % max_number_intervals + # k = thread // (batch_size * max_number_intervals) + # print("---") + # print(i, j) + + if i < batch_size: + found_start_index = found_starts[i] + found_end_index = found_ends[i] + query_start = query_starts[i] + query_end = query_ends[i] + + cursor = found_start_index + j + # print("cursor", cursor) + + if cursor < found_end_index: + interval_start = track_starts[cursor] + interval_end = track_ends[cursor] + start_index = max(interval_start - query_start, 0) + end_index = ( + (i * sequence_length) + min(interval_end, query_end) - query_start + ) + start_position = (i * sequence_length) + start_index + for position in range(start_position, end_index): + print("position", position, track_values[cursor]) + out[position] = track_values[cursor] + # print(out) + + # print(out) + out = cp.reshape(cp.asarray(out), (batch_size, sequence_length)) + return out diff --git a/bigwig_loader/intervals_to_values_gpu.py b/bigwig_loader/intervals_to_values_gpu.py deleted file mode 100644 index e62e832..0000000 --- a/bigwig_loader/intervals_to_values_gpu.py +++ /dev/null @@ -1,399 +0,0 @@ -import logging -import math - -import cupy as cp - -ROUTE_KERNELS = True - -_zero = cp.asarray(0.0, dtype=cp.float32).item() - -_cuda_kernel = """ -extern "C" __global__ -void intervals_to_values( - const int* query_starts, - const int* query_ends, - const int* found_starts, - const int* found_ends, - const int* track_starts, - const int* track_ends, - const float* track_values, - const int batch_size, - const int sequence_length, - const int max_number_intervals, - float* out -) { - - int thread = blockIdx.x * blockDim.x + threadIdx.x; - - int i = thread % batch_size; - int j = (thread / batch_size)%max_number_intervals; - - int found_start_index = found_starts[i]; - int found_end_index = found_ends[i]; - int query_start = query_starts[i]; - int query_end = query_ends[i]; - - int cursor = found_start_index + j; - - if (cursor < found_end_index){ - int interval_start = track_starts[cursor]; - int interval_end = track_ends[cursor]; - int start_index = max(interval_start - query_start, 0); - int end_index = (i * sequence_length) + min(interval_end, query_end) - query_start; - int start_position = (i * sequence_length) + start_index; - - float value = track_values[cursor]; - - for (int position = start_position; position < end_index; position++){ - out[position] = value; - } - } -} -""" - -_cuda_kernel_with_window = """ -extern "C" __global__ -void intervals_to_values( - const int* query_starts, - const int* query_ends, - const int* found_starts, - const int* found_ends, - const int* track_starts, - const int* track_ends, - const float* track_values, - const int batch_size, - const int sequence_length, - const int max_number_intervals, - const int window_size, - float* out -) { - int i = threadIdx.x + blockIdx.x * blockDim.x; - - if (i < batch_size) { - int found_start_index = found_starts[i]; - int found_end_index = found_ends[i]; - int query_start = query_starts[i]; - int query_end = query_ends[i]; - - int cursor = found_start_index; - int window_index = 0; - float summation = 0.0f; - - int reduced_dim = sequence_length / window_size; - - while (cursor < found_end_index && window_index < reduced_dim) { - int window_start = window_index * window_size; - int window_end = window_start + window_size; - - int interval_start = track_starts[cursor]; - int interval_end = track_ends[cursor]; - - int start_index = max(interval_start - query_start, 0); - int end_index = min(interval_end, query_end) - query_start; - - if (start_index >= window_end) { - window_index += 1; - continue; - } - - int number = min(window_end, end_index) - max(window_start, start_index); - - summation += number * track_values[cursor]; - - if (end_index >= window_end || cursor + 1 >= found_end_index) { - out[i * reduced_dim + window_index] = summation / window_size; - summation = 0.0f; - window_index += 1; - } - - if (end_index < window_end) { - cursor += 1; - } - } - } -} -""" - -cuda_kernel = cp.RawKernel(_cuda_kernel, "intervals_to_values") -cuda_kernel.compile() - -cuda_kernel_with_window = cp.RawKernel(_cuda_kernel_with_window, "intervals_to_values") -cuda_kernel_with_window.compile() - - -def intervals_to_values( - track_starts: cp.ndarray, - track_ends: cp.ndarray, - track_values: cp.ndarray, - query_starts: cp.ndarray, - query_ends: cp.ndarray, - out: cp.ndarray, - window_size: int = 1, -) -> cp.ndarray: - out *= _zero - found_starts = cp.searchsorted(track_ends, query_starts, side="right").astype( - dtype=cp.int32 - ) - found_ends = cp.searchsorted(track_starts, query_ends, side="left").astype( - dtype=cp.int32 - ) - - sequence_length = (query_ends[0] - query_starts[0]).item() - - max_number_intervals = min( - sequence_length, (found_ends - found_starts).max().item() - ) - batch_size = query_starts.shape[0] - - if ROUTE_KERNELS and window_size == 1: - n_threads_needed = batch_size * max_number_intervals - grid_size, block_size = get_grid_and_block_size(n_threads_needed) - - logging.debug( - f"batch_size: {batch_size}\nmax_number_intervals: {max_number_intervals}\ngrid_size: {grid_size}\nblock_size: {block_size}" - ) - - cuda_kernel( - (grid_size,), - (block_size,), - ( - query_starts, - query_ends, - found_starts, - found_ends, - track_starts, - track_ends, - track_values, - batch_size, - sequence_length, - max_number_intervals, - out, - ), - ) - - return out - - else: - n_threads_needed = batch_size - grid_size, block_size = get_grid_and_block_size(n_threads_needed) - - logging.debug( - f"batch_size: {batch_size}\nmax_number_intervals: {max_number_intervals}\ngrid_size: {grid_size}\nblock_size: {block_size}" - ) - - cuda_kernel_with_window( - (grid_size,), - (block_size,), - ( - query_starts, - query_ends, - found_starts, - found_ends, - track_starts, - track_ends, - track_values, - batch_size, - sequence_length, - max_number_intervals, - window_size, - out, - ), - ) - - return out - - -def get_grid_and_block_size(n_threads: int) -> tuple[int, int]: - n_blocks_needed = math.ceil(n_threads / 512) - if n_blocks_needed == 1: - threads_per_block = n_threads - else: - threads_per_block = 512 - return n_blocks_needed, threads_per_block - - -def kernel_in_python_with_window( - grid_size: int, - block_size: int, - args: tuple[ - cp.ndarray, - cp.ndarray, - cp.ndarray, - cp.ndarray, - cp.ndarray, - cp.ndarray, - cp.ndarray, - int, - int, - int, - cp.ndarray, - int, - ], -) -> cp.ndarray: - """Equivalent in python to cuda_kernel_with_window. Just for debugging.""" - - ( - query_starts, - query_ends, - found_starts, - found_ends, - track_starts, - track_ends, - track_values, - batch_size, - sequence_length, - max_number_intervals, - _, - window_size, - ) = args - - query_starts = query_starts.get().tolist() - query_ends = query_ends.get().tolist() - - found_starts = found_starts.get().tolist() - found_ends = found_ends.get().tolist() - track_starts = track_starts.get().tolist() - track_ends = track_ends.get().tolist() - track_values = track_values.get().tolist() - - n_threads = grid_size * block_size - - # this should be integer - reduced_dim = sequence_length // window_size - - out = [0.0] * reduced_dim * batch_size - - for thread in range(n_threads): - i = thread - - if i < batch_size: - found_start_index = found_starts[i] - found_end_index = found_ends[i] - query_start = query_starts[i] - query_end = query_ends[i] - - cursor = found_start_index - window_index = 0 - summation = 0 - - while cursor < found_end_index and window_index < reduced_dim: - window_start = window_index * window_size - window_end = window_start + window_size - - interval_start = track_starts[cursor] - interval_end = track_ends[cursor] - - start_index = max(interval_start - query_start, 0) - end_index = min(interval_end, query_end) - query_start - - if start_index >= window_end: - window_index += 1 - continue - - number = min(window_end, end_index) - max(window_start, start_index) - - summation += number * track_values[cursor] - print("-----") - print("window_index", "number", "summation") - print(window_index, number, summation) - print("interval_start", "interval_end", "value") - print(interval_start, interval_end, track_values[cursor]) - - print("end_index", "window_end") - print(end_index, window_end) - - # calculate average, reset summation and move to next window - if end_index >= window_end or cursor + 1 >= found_end_index: - print("calculate average, reset summation and move to next window") - out[i * reduced_dim + window_index] = summation / window_size - summation = 0 - window_index += 1 - # move cursor - if end_index < window_end: - print("move cursor") - cursor += 1 - - out = cp.reshape(cp.asarray(out), (batch_size, reduced_dim)) - return out - - -def kernel_in_python( - grid_size: int, - block_size: int, - args: tuple[ - cp.ndarray, - cp.ndarray, - cp.ndarray, - cp.ndarray, - cp.ndarray, - cp.ndarray, - cp.ndarray, - int, - int, - int, - cp.ndarray, - int, - ], -) -> cp.ndarray: - """Equivalent in python to cuda_kernel. Just for debugging.""" - - ( - query_starts, - query_ends, - found_starts, - found_ends, - track_starts, - track_ends, - track_values, - batch_size, - sequence_length, - max_number_intervals, - _, - window_size, - ) = args - - query_starts = query_starts.get().tolist() - query_ends = query_ends.get().tolist() - - found_starts = found_starts.get().tolist() - found_ends = found_ends.get().tolist() - track_starts = track_starts.get().tolist() - track_ends = track_ends.get().tolist() - track_values = track_values.get().tolist() - - n_threads = grid_size * block_size - - out = [0.0] * sequence_length * batch_size - - for thread in range(n_threads): - i = thread % batch_size - j = (thread // batch_size) % max_number_intervals - # k = thread // (batch_size * max_number_intervals) - # print("---") - # print(i, j) - - if i < batch_size: - found_start_index = found_starts[i] - found_end_index = found_ends[i] - query_start = query_starts[i] - query_end = query_ends[i] - - cursor = found_start_index + j - # print("cursor", cursor) - - if cursor < found_end_index: - interval_start = track_starts[cursor] - interval_end = track_ends[cursor] - start_index = max(interval_start - query_start, 0) - end_index = ( - (i * sequence_length) + min(interval_end, query_end) - query_start - ) - start_position = (i * sequence_length) + start_index - for position in range(start_position, end_index): - print("position", position, track_values[cursor]) - out[position] = track_values[cursor] - # print(out) - - # print(out) - out = cp.reshape(cp.asarray(out), (batch_size, sequence_length)) - return out diff --git a/bigwig_loader/searchsorted.py b/bigwig_loader/searchsorted.py new file mode 100644 index 0000000..03518ac --- /dev/null +++ b/bigwig_loader/searchsorted.py @@ -0,0 +1,84 @@ +from typing import Literal + +import cupy as cp +from cupy import _core +from cupy._sorting.search import _searchsorted_code + +_preamble = """ +template +__device__ bool _isnan(T val) { + return val != val; +} +""" + +_hip_preamble = r""" +#ifdef __HIP_DEVICE_COMPILE__ + #define no_thread_divergence(do_work, to_return) \ + if (!is_done) { \ + do_work; \ + is_done = true; \ + } +#else + #define no_thread_divergence(do_work, to_return) \ + do_work; \ + if (to_return) { return; } +#endif +""" + + +_searchsorted_kernel = _core.ElementwiseKernel( + "S x, S index, raw uint32 starts, raw T sizes, raw T all_bins, bool side_is_right, " + "bool assume_increasing", + "uint32 y", + """ + int start = starts[index]; + int n_bins = sizes[index]; + const T* bins = &all_bins[start]; + + """ + + _searchsorted_code, + name="cupy_searchsorted_kernel", + preamble=_preamble + _hip_preamble, +) + + +def searchsorted( + array: cp.ndarray, + queries: cp.ndarray, + sizes: cp.ndarray, + side: Literal["left", "right"] = "left", + absolute_indices: bool = True, +) -> cp.ndarray: + """ + This is a version of search sorted does the searchsorted operation on + multiple subarrays at once (for the same queries to find the insertion + points for). Where each subarray starts is indicated by start_indices. + + Args: + array: 1D Input array. Is expected to consist of subarrays that are + sorted in ascending order. + queries: Values to find the insertion indices for in subarrays of array. + start_indices: Indices of the starts of the subarrays in array. + sizes: Sizes of the subarrays. + side: If 'left', the index of the first suitable location found is given. + If 'right', return the last such index. + absolute_indices: whether to give the indices with respect to the entire + array (True) or for the subarrays (False). + Returns: + And array of size n_subarrays x n_queries with insertion indices. + + """ + + start_indices = cp.pad(cp.cumsum(sizes, dtype=cp.uint32), (1, 0))[:-1] + n_subarrays = len(sizes) + n_queries = len(queries) + idx = cp.arange(n_subarrays, dtype=queries.dtype)[:, cp.newaxis] + queries = queries[cp.newaxis, :] + output = cp.zeros((n_subarrays, n_queries), dtype=cp.uint32) + + result = _searchsorted_kernel( + queries, idx, start_indices, sizes, array, side == "right", True, output + ) + if absolute_indices: + return result + start_indices[:, cp.newaxis] + return result diff --git a/cuda_kernels/intervals_to_values.cu b/cuda_kernels/intervals_to_values.cu new file mode 100644 index 0000000..c8173d2 --- /dev/null +++ b/cuda_kernels/intervals_to_values.cu @@ -0,0 +1,97 @@ +extern "C" __global__ +void intervals_to_values( + const unsigned int* query_starts, + const unsigned int* query_ends, + const unsigned int* found_starts, + const unsigned int* found_ends, + const unsigned int* track_starts, + const unsigned int* track_ends, + const float* track_values, + const int n_tracks, + const int batch_size, + const int sequence_length, + const int max_number_intervals, + const int window_size, + float* out +) { + + int thread = blockIdx.x * blockDim.x + threadIdx.x; + +// # out is a 1D array of size batch_size x n_tracks x sequence_length +// +// # n_tracks x n_batch + + int batch_index = thread % batch_size; + int i = thread % (batch_size * n_tracks); + + if (window_size == 1){ + int j = (thread / (batch_size * n_tracks)) % max_number_intervals; + + int found_start_index = found_starts[i]; + int found_end_index = found_ends[i]; + int query_start = query_starts[batch_index]; + int query_end = query_ends[batch_index]; + + int cursor = found_start_index + j; + + if (cursor < found_end_index){ + int interval_start = track_starts[cursor]; + int interval_end = track_ends[cursor]; + int start_index = max(interval_start - query_start, 0); + int end_index = (i * sequence_length) + min(interval_end, query_end) - query_start; + int start_position = (i * sequence_length) + start_index; + + float value = track_values[cursor]; + + for (int position = start_position; position < end_index; position++){ + out[position] = value; + } + } + } else { + + int track_index = i / batch_size; + + int found_start_index = found_starts[i]; + int found_end_index = found_ends[i]; + int query_start = query_starts[batch_index]; + int query_end = query_ends[batch_index]; + + int cursor = found_start_index; + int window_index = 0; + float summation = 0.0f; + + int reduced_dim = sequence_length / window_size; + + while (cursor < found_end_index && window_index < reduced_dim) { + int window_start = window_index * window_size; + int window_end = window_start + window_size; + + int interval_start = track_starts[cursor]; + int interval_end = track_ends[cursor]; + + int start_index = max(interval_start - query_start, 0); + int end_index = min(interval_end, query_end) - query_start; + + if (start_index >= window_end) { + out[i * reduced_dim + window_index] = summation / window_size; + summation = 0.0f; + window_index += 1; + continue; + } + + int number = min(window_end, end_index) - max(window_start, start_index); + + summation += number * track_values[cursor]; + + if (end_index >= window_end || cursor + 1 >= found_end_index) { + out[i * reduced_dim + window_index] = summation / window_size; + summation = 0.0f; + window_index += 1; + } + + if (end_index < window_end) { + cursor += 1; + } + } + } +} diff --git a/tests/test_intervals_to_values_gpu.py b/tests/test_intervals_to_values.py similarity index 62% rename from tests/test_intervals_to_values_gpu.py rename to tests/test_intervals_to_values.py index aba175e..f2e70f5 100644 --- a/tests/test_intervals_to_values_gpu.py +++ b/tests/test_intervals_to_values.py @@ -1,6 +1,24 @@ import cupy as cp +import pytest -from bigwig_loader.intervals_to_values_gpu import intervals_to_values +from bigwig_loader.intervals_to_values import intervals_to_values + + +def test_throw_exception_when_queried_intervals_are_of_different_lengths() -> None: + """All query_ends - query_starts should have the same + length. Otherwise, ValueError should be raised. + """ + track_starts = cp.asarray([1, 2, 3], dtype=cp.int32) + track_ends = cp.asarray([2, 3, 4], dtype=cp.int32) + track_values = cp.asarray([1.0, 1.0, 1.0], dtype=cp.dtype("f4")) + query_starts = cp.asarray([2, 2], dtype=cp.int32) + query_ends = cp.asarray([4, 5], dtype=cp.int32) + reserved = cp.zeros((2, 2), dtype=cp.float32) + + with pytest.raises(ValueError): + intervals_to_values( + track_starts, track_ends, track_values, query_starts, query_ends, reserved + ) def test_get_values_from_intervals() -> None: @@ -137,3 +155,78 @@ def test_get_values_from_intervals_batch_of_2() -> None: print(expected) print(values) assert (values == expected).all() + + +def test_get_values_from_intervals_batch_multiple_tracks() -> None: + """Query end is exactly at end index before "gap".""" + track_starts = cp.asarray( + [5, 10, 12, 18, 8, 9, 10, 18, 25, 10, 100, 1000], dtype=cp.int32 + ) + track_ends = cp.asarray( + [10, 12, 14, 20, 9, 10, 14, 22, 55, 20, 200, 2000], dtype=cp.int32 + ) + track_values = cp.asarray( + [20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0, 110.0, 120.0, 130.0], + dtype=cp.dtype("f4"), + ) + query_starts = cp.asarray([7, 9, 20, 99], dtype=cp.int32) + query_ends = cp.asarray([18, 20, 31, 110], dtype=cp.int32) + reserved = cp.zeros([3, 4, 11], dtype=cp.dtype(" None: - """.""" - track_starts = cp.asarray([1, 3, 10, 12, 16], dtype=cp.int32) - track_ends = cp.asarray([3, 10, 12, 16, 20], dtype=cp.int32) - track_values = cp.asarray([20.0, 15.0, 30.0, 40.0, 50.0], dtype=cp.dtype("f4")) - query_starts = cp.asarray([2], dtype=cp.int32) - query_ends = cp.asarray([17], dtype=cp.int32) - reserved = cp.zeros((1, 3), dtype=cp.float32) - values = intervals_to_values( - track_starts, - track_ends, - track_values, - query_starts, - query_ends, - reserved, - window_size=5, - ) - - expected = cp.asarray([[16.0, 21.0, 42.0]]) - - print(expected) - print(values) - assert (values == expected).all() - - -def test_get_values_from_intervals_edge_case_1() -> None: - """Query start is somewhere in a "gap".""" - track_starts = cp.asarray([1, 10, 12, 16], dtype=cp.int32) - track_ends = cp.asarray([3, 12, 16, 20], dtype=cp.int32) - track_values = cp.asarray([20.0, 30.0, 40.0, 50.0], dtype=cp.dtype("f4")) - query_starts = cp.asarray([6], dtype=cp.int32) - query_ends = cp.asarray([18], dtype=cp.int32) - reserved = cp.zeros((1, 4), dtype=cp.dtype(" None: - """Query start is exactly at start index after "gap".""" - track_starts = cp.asarray([1, 10, 12, 16], dtype=cp.int32) - track_ends = cp.asarray([3, 12, 16, 20], dtype=cp.int32) - track_values = cp.asarray([20.0, 30.0, 40.0, 50.0], dtype=cp.dtype("f4")) - query_starts = cp.asarray([10], dtype=cp.int32) - query_ends = cp.asarray([18], dtype=cp.int32) - reserved = cp.zeros((1, 2), dtype=cp.dtype(" None: - """Query end is somewhere in a "gap".""" - track_starts = cp.asarray([5, 10, 12, 18], dtype=cp.int32) - track_ends = cp.asarray([10, 12, 14, 20], dtype=cp.int32) - track_values = cp.asarray([20.0, 30.0, 40.0, 50.0], dtype=cp.dtype("f4")) - query_starts = cp.asarray([8], dtype=cp.int32) - query_ends = cp.asarray([16], dtype=cp.int32) - reserved = cp.zeros((1, 2), dtype=cp.dtype(" None: - """Query end is exactly at end index before "gap".""" - track_starts = cp.asarray([5, 10, 12, 18], dtype=cp.int32) - track_ends = cp.asarray([10, 12, 14, 20], dtype=cp.int32) - track_values = cp.asarray([20.0, 30.0, 40.0, 50.0], dtype=cp.dtype("f4")) - query_starts = cp.asarray([8], dtype=cp.int32) - query_ends = cp.asarray([14], dtype=cp.int32) - reserved = cp.zeros((1, 2), dtype=cp.dtype(" None: - """Query end is exactly at end index before "gap".""" - track_starts = cp.asarray([5, 10, 12, 18], dtype=cp.uint32) - track_ends = cp.asarray([10, 12, 14, 20], dtype=cp.uint32) - track_values = cp.asarray([20.0, 30.0, 40.0, 50.0], dtype=cp.dtype("f4")) - query_starts = cp.asarray([8], dtype=cp.uint32) - query_ends = cp.asarray([20], dtype=cp.uint32) - reserved = cp.zeros((1, 4), dtype=cp.dtype(" None: - """Query end is exactly at end index before "gap".""" - track_starts = cp.asarray([5, 10, 12, 18], dtype=cp.int32) - track_ends = cp.asarray([10, 12, 14, 20], dtype=cp.int32) - track_values = cp.asarray([20.0, 30.0, 40.0, 50.0], dtype=cp.dtype("f4")) - query_starts = cp.asarray([6, 8], dtype=cp.int32) - query_ends = cp.asarray([18, 20], dtype=cp.int32) - reserved = cp.zeros([2, 4], dtype=cp.dtype(" None: + """.""" + track_starts = cp.asarray([1, 3, 10, 12, 16], dtype=cp.int32) + track_ends = cp.asarray([3, 10, 12, 16, 20], dtype=cp.int32) + track_values = cp.asarray([20.0, 15.0, 30.0, 40.0, 50.0], dtype=cp.dtype("f4")) + query_starts = cp.asarray([2], dtype=cp.int32) + query_ends = cp.asarray([17], dtype=cp.int32) + reserved = cp.zeros((1, 3), dtype=cp.float32) + values = intervals_to_values( + track_starts, + track_ends, + track_values, + query_starts, + query_ends, + reserved, + window_size=5, + ) + + expected = cp.asarray([[16.0, 21.0, 42.0]]) + + print("expected:") + print(expected) + print("actual:") + print(values) + assert (values == expected).all() + + +def test_get_values_from_intervals_edge_case_1() -> None: + """Query start is somewhere in a "gap".""" + track_starts = cp.asarray([1, 10, 12, 16], dtype=cp.int32) + track_ends = cp.asarray([3, 12, 16, 20], dtype=cp.int32) + track_values = cp.asarray([20.0, 30.0, 40.0, 50.0], dtype=cp.dtype("f4")) + query_starts = cp.asarray([6], dtype=cp.int32) + query_ends = cp.asarray([18], dtype=cp.int32) + reserved = cp.zeros((1, 4), dtype=cp.dtype(" None: + """Query start is exactly at start index after "gap".""" + track_starts = cp.asarray([1, 10, 12, 16], dtype=cp.int32) + track_ends = cp.asarray([3, 12, 16, 20], dtype=cp.int32) + track_values = cp.asarray([20.0, 30.0, 40.0, 50.0], dtype=cp.dtype("f4")) + query_starts = cp.asarray([10], dtype=cp.int32) + query_ends = cp.asarray([18], dtype=cp.int32) + reserved = cp.zeros((1, 2), dtype=cp.dtype(" None: + """Query end is somewhere in a "gap".""" + track_starts = cp.asarray([5, 10, 12, 18], dtype=cp.int32) + track_ends = cp.asarray([10, 12, 14, 20], dtype=cp.int32) + track_values = cp.asarray([20.0, 30.0, 40.0, 50.0], dtype=cp.dtype("f4")) + query_starts = cp.asarray([8], dtype=cp.int32) + query_ends = cp.asarray([16], dtype=cp.int32) + reserved = cp.zeros((1, 2), dtype=cp.dtype(" None: + """Query end is exactly at end index before "gap".""" + track_starts = cp.asarray([5, 10, 12, 18], dtype=cp.int32) + track_ends = cp.asarray([10, 12, 14, 20], dtype=cp.int32) + track_values = cp.asarray([20.0, 30.0, 40.0, 50.0], dtype=cp.dtype("f4")) + query_starts = cp.asarray([8], dtype=cp.int32) + query_ends = cp.asarray([14], dtype=cp.int32) + reserved = cp.zeros((1, 2), dtype=cp.dtype(" None: + """Query end is exactly at end index before "gap".""" + track_starts = cp.asarray([5, 10, 12, 18], dtype=cp.uint32) + track_ends = cp.asarray([10, 12, 14, 20], dtype=cp.uint32) + track_values = cp.asarray([20.0, 30.0, 40.0, 50.0], dtype=cp.dtype("f4")) + query_starts = cp.asarray([8], dtype=cp.uint32) + query_ends = cp.asarray([20], dtype=cp.uint32) + reserved = cp.zeros((1, 4), dtype=cp.dtype(" None: + """Query end is exactly at end index before "gap".""" + track_starts = cp.asarray([5, 10, 12, 18], dtype=cp.int32) + track_ends = cp.asarray([10, 12, 14, 20], dtype=cp.int32) + track_values = cp.asarray([20.0, 30.0, 40.0, 50.0], dtype=cp.dtype("f4")) + query_starts = cp.asarray([6, 8], dtype=cp.int32) + query_ends = cp.asarray([18, 20], dtype=cp.int32) + reserved = cp.zeros([2, 4], dtype=cp.dtype(" None: + """ + This tests a specific combination of track and batch index + of the larger test case below: + test_get_values_from_intervals_batch_multiple_tracks + track index = 0 + batch_index = 1 + Included to narrow down source of error in the larger test case. + """ + track_starts = cp.asarray([5, 10, 12, 18], dtype=cp.int32) + track_ends = cp.asarray([10, 12, 14, 20], dtype=cp.int32) + track_values = cp.asarray( + [20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0, 110.0, 120.0, 130.0], + dtype=cp.dtype("f4"), + ) + query_starts = cp.asarray([9], dtype=cp.int32) + query_ends = cp.asarray([20], dtype=cp.int32) + reserved = cp.zeros([1, 1, 3], dtype=cp.dtype(" None: + """ + This tests a specific combination of track and batch index + of the larger test case below: + test_get_values_from_intervals_batch_multiple_tracks + track index = 2 + batch_index = 1 + Included to narrow down source of error in the larger test case. + """ + track_starts = cp.asarray([10, 100, 1000], dtype=cp.int32) + track_ends = cp.asarray([20, 200, 2000], dtype=cp.int32) + track_values = cp.asarray( + [110.0, 120.0, 130.0], + dtype=cp.dtype("f4"), + ) + query_starts = cp.asarray([9], dtype=cp.int32) + query_ends = cp.asarray([20], dtype=cp.int32) + reserved = cp.zeros([1, 1, 3], dtype=cp.dtype(" None: + """Test interval_to_values with 3 tracks and a batch size of 1.""" + track_starts = cp.asarray( + [5, 10, 12, 18, 8, 9, 10, 18, 25, 10, 100, 1000], dtype=cp.int32 + ) + track_ends = cp.asarray( + [10, 12, 14, 20, 9, 10, 14, 22, 55, 20, 200, 2000], dtype=cp.int32 + ) + track_values = cp.asarray( + [20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0, 110.0, 120.0, 130.0], + dtype=cp.dtype("f4"), + ) + query_starts = cp.asarray([9], dtype=cp.int32) + query_ends = cp.asarray([20], dtype=cp.int32) + reserved = cp.zeros([3, 1, 3], dtype=cp.dtype(" None: + """Test intervals_to_values with 3 tracks and a batch size of 4.""" + track_starts = cp.asarray( + [5, 10, 12, 18, 8, 9, 10, 18, 25, 10, 100, 1000], dtype=cp.int32 + ) + track_ends = cp.asarray( + [10, 12, 14, 20, 9, 10, 14, 22, 55, 20, 200, 2000], dtype=cp.int32 + ) + track_values = cp.asarray( + [20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0, 110.0, 120.0, 130.0], + dtype=cp.dtype("f4"), + ) + query_starts = cp.asarray([7, 9, 20, 99], dtype=cp.int32) + query_ends = query_starts + sequence_length + reduced_dim = sequence_length // window_size + + reserved = cp.zeros([3, 4, reduced_dim], dtype=cp.dtype(" None: + """.""" + track_starts = cp.asarray([1, 3, 10, 12, 16] * n_tracks, dtype=cp.int32) + track_ends = cp.asarray([3, 10, 12, 16, 20] * n_tracks, dtype=cp.int32) + track_values = cp.asarray( + [20.0, 15.0, 30.0, 40.0, 50.0] * n_tracks, dtype=cp.dtype("f4") + ) + sizes = cp.asarray([5] * n_tracks, dtype=cp.int32) + sequence_length = 15 + query_starts = cp.asarray([2] * batch_size, dtype=cp.int32) + query_ends = query_starts + sequence_length + values = intervals_to_values( + track_starts, + track_ends, + track_values, + query_starts, + query_ends, + sizes=sizes, + window_size=window_size, + ) + + values_with_window_size_1 = intervals_to_values( + track_starts, + track_ends, + track_values, + query_starts, + query_ends, + sizes=sizes, + window_size=1, + ) + + reduced_dim = sequence_length // window_size + full_matrix = values_with_window_size_1[:, :, : reduced_dim * window_size] + full_matrix = full_matrix.reshape( + full_matrix.shape[0], full_matrix.shape[1], reduced_dim, window_size + ) + expected = cp.mean(full_matrix, axis=-1) + + print("expected:") + print(expected) + print("actual:") + print(values) + + assert cp.allclose(values, expected) diff --git a/tests/test_searchsorted.py b/tests/test_searchsorted.py new file mode 100644 index 0000000..619ca21 --- /dev/null +++ b/tests/test_searchsorted.py @@ -0,0 +1,88 @@ +import cupy as cp +import pytest + +from bigwig_loader.searchsorted import searchsorted + + +@pytest.fixture +def test_data(): + intervals_track1 = [5, 10, 12, 18] + intervals_track2 = [ + 1, + 3, + 5, + 7, + 9, + 10, + ] + + intervals_track3 = [ + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + ] + + intervals_track4 = [4, 100] + + array = cp.asarray( + intervals_track1 + intervals_track2 + intervals_track3 + intervals_track4, + dtype=cp.int32, + ) + queries = cp.asarray([7, 9, 11], dtype=cp.int32) + sizes = cp.asarray( + [ + len(intervals_track1), + len(intervals_track2), + len(intervals_track3), + len(intervals_track4), + ], + cp.int32, + ) + return array, queries, sizes + + +def test_searchsorted_left_relative(test_data) -> None: + array, queries, sizes = test_data + output = searchsorted( + array=array, queries=queries, sizes=sizes, side="left", absolute_indices=False + ) + expected = cp.asarray([[1, 1, 2], [3, 4, 6], [6, 8, 10], [1, 1, 1]]) + assert (output == expected).all() + + +def test_searchsorted_right_relative(test_data) -> None: + array, queries, sizes = test_data + output = searchsorted( + array=array, queries=queries, sizes=sizes, side="right", absolute_indices=False + ) + expected = cp.asarray([[1, 1, 2], [4, 5, 6], [7, 9, 11], [1, 1, 1]]) + assert (output == expected).all() + + +def test_searchsorted_left_absolute(test_data) -> None: + array, queries, sizes = test_data + output = searchsorted( + array=array, queries=queries, sizes=sizes, side="left", absolute_indices=True + ) + expected = cp.asarray([[1, 1, 2], [7, 8, 10], [16, 18, 20], [25, 25, 25]]) + assert (output == expected).all() + + +def test_searchsorted_right_absolute(test_data) -> None: + array, queries, sizes = test_data + output = searchsorted( + array=array, queries=queries, sizes=sizes, side="right", absolute_indices=True + ) + expected = cp.asarray([[1, 1, 2], [8, 9, 10], [17, 19, 21], [25, 25, 25]]) + assert (output == expected).all()