Adding a reset_gpu method to the BigWigCollection and Dataset objects… #3

Merged 2 commits on Jan 26, 2024
5 changes: 4 additions & 1 deletion bigwig_loader/bigwig.py
@@ -391,7 +391,10 @@ def _guess_max_rows_per_chunk(
"""

rows_for_chunks = []
sample_leaf_nodes = sample(self.rtree_leaf_nodes, sample_size)
if len(self.rtree_leaf_nodes) < sample_size:
sample_leaf_nodes = self.rtree_leaf_nodes
else:
sample_leaf_nodes = sample(self.rtree_leaf_nodes, sample_size)
for leaf_node in sample_leaf_nodes:
file_object.seek(leaf_node.data_offset, 0) # type: ignore
decoded = zlib.decompress(file_object.read(leaf_node.data_size)) # type: ignore
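The new branch guards against a quirk of Python's random.sample: it raises ValueError when the requested sample size exceeds the population size, so small R-tree indexes now fall back to using every leaf node. A standalone sketch of the same pattern, with illustrative names that are not part of the library:

from random import sample

def sample_at_most(population: list, sample_size: int) -> list:
    # random.sample raises ValueError when sample_size > len(population),
    # so return the whole population for small inputs instead.
    if len(population) < sample_size:
        return list(population)
    return sample(population, sample_size)

print(sample_at_most(list(range(3)), 10))         # [0, 1, 2] (all leaf nodes)
print(len(sample_at_most(list(range(100)), 10)))  # 10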
60 changes: 41 additions & 19 deletions bigwig_loader/collection.py
@@ -1,5 +1,6 @@
import logging
import os
from functools import cached_property
from pathlib import Path
from typing import Any
from typing import Iterable
@@ -54,14 +55,10 @@ def __init__(
self.local_chrom_ids_to_offset_matrix,
) = self.create_global_position_system()

max_rows_per_chunk = max([bigwig.max_rows_per_chunk for bigwig in self.bigwigs])
self.pinned_memory_size = pinned_memory_size
self._memory_bank: Optional[MemoryBank] = None
self._decoder = Decoder(
max_rows_per_chunk=max_rows_per_chunk,
max_uncompressed_chunk_size=max_rows_per_chunk * 12 + 24,
chromosome_offsets=self.local_chrom_ids_to_offset_matrix,
self.max_rows_per_chunk = max(
[bigwig.max_rows_per_chunk for bigwig in self.bigwigs]
)
self.pinned_memory_size = pinned_memory_size
self._out: cp.ndarray = cp.zeros((len(self), 1, 1), dtype=cp.float32)

self.run_indexing()
@@ -75,13 +72,39 @@ def run_indexing(self) -> None:
def __len__(self) -> int:
return len(self.bigwigs)

def _get_memory_bank(self) -> MemoryBank:
if self._memory_bank:
return self._memory_bank
self._memory_bank = MemoryBank(nbytes=self.pinned_memory_size, elastic=True)
return self._memory_bank
def reset_gpu(self) -> None:
Collaborator: Is this ever called on the collection level? It seems you have to repeat the method so it can be delegated from the top levels down.

Collaborator (Author): Yes, because a BigWigDataset object can take a BigWigCollection object as one of its init args. This can be handy when you want to create train, val and test datasets, all based on the same collection of bigwig files, but with different "regions of interest" from which the samples are taken:

https://github.com/pfizer-opensource/bigwig-loader/blob/d84a233a6214162c3e7178684f43d8a198fb88d0/bigwig_loader/dataset.py#L61C1-L62C1
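For illustration, a minimal sketch of that train/val/test setup. The keyword arguments (regions_of_interest, collection) and the region variables are assumptions for this example, not the confirmed constructor signature; see the linked dataset.py for the actual parameters.

from bigwig_loader.collection import BigWigCollection
from bigwig_loader.dataset import BigWigDataset

# One collection, opened once, shared by every split (illustrative path and arguments;
# train_regions / val_regions / test_regions stand in for your own region tables).
collection = BigWigCollection("path/to/bigwig_files")

train_dataset = BigWigDataset(regions_of_interest=train_regions, collection=collection)
val_dataset = BigWigDataset(regions_of_interest=val_regions, collection=collection)
test_dataset = BigWigDataset(regions_of_interest=test_regions, collection=collection)

# Because all three datasets wrap the same BigWigCollection, reset_gpu() on any of
# them delegates down to the one shared collection (see the dataset.py diff below).
train_dataset.reset_gpu()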

"""
Remove all gpu arrays from the previously used device and recreate when necessary on
current device. This is useful when training is done on multiple gpus and the arrays
need to be recreated on the new gpu.
"""

self._out = cp.zeros((len(self), 1, 1), dtype=cp.float32)
del self.__dict__["decoder"]
del self.__dict__["memory_bank"]

@cached_property
def decoder(self) -> Decoder:
return Decoder(
max_rows_per_chunk=self.max_rows_per_chunk,
max_uncompressed_chunk_size=self.max_rows_per_chunk * 12 + 24,
chromosome_offsets=self.local_chrom_ids_to_offset_matrix,
)

@cached_property
def memory_bank(self) -> MemoryBank:
return MemoryBank(nbytes=self.pinned_memory_size, elastic=True)

def _get_out_tensor(self, batch_size: int, sequence_length: int) -> cp.ndarray:
"""Resuses a reserved tensor if possible (when out shape is constant),
otherwise creates a new one.
args:
batch_size: batch size
sequence_length: length of genomic sequence
returns:
tensor of shape (number of bigwig files, batch_size, sequence_length)
"""

shape = (len(self), batch_size, sequence_length)
if self._out.shape != shape:
self._out = cp.zeros(shape, dtype=cp.float32)
@@ -95,8 +118,7 @@ def get_batch(
window_size: int = 1,
out: Optional[cp.ndarray] = None,
) -> cp.ndarray:
memory_bank = self._get_memory_bank()
memory_bank.reset()
self.memory_bank.reset()

if (end[0] - start[0]) % window_size:
raise ValueError(
@@ -119,7 +141,7 @@
n_chunks_per_bigwig.append(len(offsets))
bigwig_ids.extend([bigwig.id] * len(offsets))
# read chunks into preallocated memory
memory_bank.add_many(
self.memory_bank.add_many(
bigwig.store._fh,
offsets,
sizes,
@@ -128,8 +150,8 @@

# bring to the gpu
bigwig_ids = cp.asarray(bigwig_ids, dtype=cp.uint32)
gpu_byte_array, comp_chunks, compressed_chunk_sizes = memory_bank.to_gpu()
_, start, end, value, n_rows_for_chunks = self._decoder.decode(
gpu_byte_array, comp_chunks, compressed_chunk_sizes = self.memory_bank.to_gpu()
_, start, end, value, n_rows_for_chunks = self.decoder.decode(
gpu_byte_array, comp_chunks, compressed_chunk_sizes, bigwig_ids=bigwig_ids
)

@@ -201,8 +223,8 @@ def intervals(
threshold=threshold,
merge=merge,
merge_allow_gap=merge_allow_gap,
memory_bank=self._get_memory_bank(),
decoder=self._decoder,
memory_bank=self.memory_bank,
decoder=self.decoder,
batch_size=batch_size,
)
for bw in self.bigwigs
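The eagerly built _decoder and _memory_bank attributes are replaced by cached_property accessors, so reset_gpu() only has to drop the cached entries from the instance __dict__; the next attribute access rebuilds them on whichever GPU is then current. A generic sketch of that invalidation idiom, with illustrative names rather than the library's classes; dict.pop(key, None) is used here so the reset also works if the property was never accessed:

from functools import cached_property

class GpuResources:
    @cached_property
    def buffer(self) -> object:
        print("allocating on the current device")
        return object()  # stands in for a GPU allocation

    def reset(self) -> None:
        # cached_property stores its value in the instance __dict__;
        # removing the entry forces re-creation on the next access.
        self.__dict__.pop("buffer", None)

resources = GpuResources()
resources.buffer   # prints "allocating on the current device"
resources.buffer   # cached: no new allocation, no message
resources.reset()
resources.buffer   # allocated again, now on the currently selected device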
6 changes: 6 additions & 0 deletions bigwig_loader/dataset.py
@@ -115,6 +115,9 @@ def __init__(
self._n = 0
self._offset = 0

def reset_gpu(self) -> None:
self._super_dataset.reset_gpu()

def __iter__(self) -> Iterator[tuple[Any, cp.ndarray]]:
self._n = 0
self._offset = 0
@@ -237,6 +240,9 @@ def __init__(
self._position_sampler_buffer_size = position_sampler_buffer_size
self._repeat_same_positions = repeat_same_positions

def reset_gpu(self) -> None:
self.bigwig_collection.reset_gpu()

@property
def genome(self) -> Genome:
"""
3 changes: 3 additions & 0 deletions bigwig_loader/pytorch.py
@@ -99,3 +99,6 @@ def __next__(self) -> tuple[torch.FloatTensor, torch.FloatTensor]:
target = torch.as_tensor(target, device="cuda")
sequences = torch.FloatTensor(sequences)
return sequences, target

def reset_gpu(self) -> None:
self._dataset.reset_gpu()
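Taken together, the three reset_gpu methods form a delegation chain from the PyTorch wrapper down to the collection. A sketch of the intended call pattern when moving training to another GPU; make_dataset() is a hypothetical stand-in for your own dataset construction:

import cupy as cp
import torch

dataset = make_dataset()  # e.g. returns a PytorchBigWigDataset

# ... train for a while on GPU 0 ...

# Switch subsequent work to GPU 1 and drop the arrays still living on GPU 0.
torch.cuda.set_device(1)
cp.cuda.Device(1).use()
dataset.reset_gpu()  # pytorch wrapper -> dataset -> BigWigCollection

# The decoder, memory bank and output buffer are recreated lazily on GPU 1
# the next time a batch is produced.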