Allowing bigwig_loader.batch.Batch as return type of the dataloader
jorenretel committed Aug 13, 2024
1 parent eb9f962 commit dcf7e7c
Showing 4 changed files with 190 additions and 14 deletions.
14 changes: 14 additions & 0 deletions bigwig_loader/batch.py
@@ -96,3 +96,17 @@ def __getitem__(self, item: slice) -> "Batch":

def __len__(self) -> int:
return len(self.starts)

def __repr__(self) -> str:
n_chromosomes = len(self.chromosomes) if self.chromosomes is not None else 0
n_starts = len(self.starts) if self.starts is not None else 0
n_ends = len(self.ends) if self.ends is not None else 0
n_sequences = len(self.sequences) if self.sequences is not None else 0
value_shape = self.values.shape if self.values is not None else 0
n_track_indices = (
len(self.track_indices) if self.track_indices is not None else 0
)
return (
f"Batch(chromosomes={n_chromosomes}, starts={n_starts}, ends={n_ends}, "
f"sequences={n_sequences}, values={value_shape}, track_indices={n_track_indices})"
)
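
For illustration, a minimal sketch of what the new __repr__ yields. It assumes, as the Batch docstring states, that a Batch can be constructed from just chromosomes, starts and ends; the values below are hypothetical:

# Hypothetical example of the new Batch.__repr__ (not part of this commit).
from bigwig_loader.batch import Batch

batch = Batch(
    chromosomes=["chr1", "chr2"],
    starts=[0, 42],
    ends=[1000, 1042],
)
print(repr(batch))
# Fields that are None are reported as 0:
# Batch(chromosomes=2, starts=2, ends=2, sequences=0, values=0, track_indices=0)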
12 changes: 11 additions & 1 deletion bigwig_loader/dataset_new.py
Expand Up @@ -10,6 +10,7 @@
import cupy as cp
import pandas as pd

from bigwig_loader.batch import Batch
from bigwig_loader.batch import IntSequenceType
from bigwig_loader.collection import BigWigCollection
from bigwig_loader.cupy_functions import moving_average
@@ -69,6 +70,8 @@ class Dataset:
        sub_sample_tracks: int; if set, a different random set of tracks is selected in
            each superbatch from the total number of tracks. The indices corresponding
            to those tracks are returned in the output.
        return_batch_objects: if True, batches are returned as instances of
            bigwig_loader.batch.Batch instead of tuples.
"""

def __init__(
@@ -94,6 +97,7 @@ def __init__(
position_sampler_buffer_size: int = 100000,
repeat_same_positions: bool = False,
sub_sample_tracks: Optional[int] = None,
return_batch_objects: bool = False,
):
super().__init__()

@@ -136,6 +140,7 @@ def __init__(
self._repeat_same_positions = repeat_same_positions
self._moving_average_window_size = moving_average_window_size
self._sub_sample_tracks = sub_sample_tracks
self._return_batch_objects = return_batch_objects

def _create_dataloader(self) -> StreamedDataloader:
position_sampler = RandomPositionSampler(
@@ -178,6 +183,7 @@ def __iter__(
) -> Iterator[
tuple[cp.ndarray | list[str] | None, cp.ndarray]
| tuple[cp.ndarray | list[str] | None, cp.ndarray, IntSequenceType]
| Batch
]:
dataloader = self._create_dataloader()
for i, batch in enumerate(dataloader):
@@ -186,7 +192,11 @@ def __iter__(
                sequences = self._sequence_encoder(batch.sequences)
            else:
                sequences = batch.sequences
            if self._return_batch_objects:
                batch.sequences = sequences
                batch.values = values
                yield batch
            elif batch.track_indices is not None:
                yield sequences, values, batch.track_indices
            else:
                yield sequences, values
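
Putting it together, a minimal usage sketch of the new flag on the cupy-based Dataset. The constructor arguments below are illustrative placeholders, not values taken from this commit:

# Hypothetical usage sketch of return_batch_objects on Dataset.
from bigwig_loader.dataset_new import Dataset

dataset = Dataset(
    regions_of_interest=merged_intervals,  # placeholder: intervals to sample from
    collection="path/to/bigwig/files",  # placeholder path
    reference_genome_path="path/to/genome.fasta",  # placeholder path
    sequence_length=2000,
    batch_size=32,
    batches_per_epoch=10,
    return_batch_objects=True,
)

for batch in dataset:
    # Each item is a bigwig_loader.batch.Batch whose .sequences are already
    # encoded and whose .values hold the output matrix.
    print(batch)  # uses the new __repr__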
115 changes: 102 additions & 13 deletions bigwig_loader/pytorch.py
Original file line number Diff line number Diff line change
@@ -7,17 +7,104 @@
from typing import Sequence
from typing import Union

import cupy as cp
import numpy as np
import pandas as pd
import torch
from torch.utils.data import IterableDataset

from bigwig_loader.batch import Batch
from bigwig_loader.collection import BigWigCollection
from bigwig_loader.dataset_new import Dataset


class PytorchBatch:
"""Batch
This is a simple container object to hold on to a set of arrays
that represent a batch of data. It serves as the query to bigwig_loader,
therefore the minimum init args are chromosomes, starts and ends, which
is really everything bigwig_loader really needs.
It is used to pass data through the input and output queues of the
dataloader that works with threads and cuda streams. At some point it
gets cumbersome to keep track of the order of the arrays, so this
object is used to make that a bit simpler.
Args:
chromosomes: 1D array of size batch_size ["chr1", "chr2", ...],
starts: 1D array of size batch_size [0, 42, ...],
ends: 1D array of size batch_size [100, 142, ...],
track_indices (Optional): which tracks to include in this batch
(index should correspond to the order in
bigwig_loader.collection.BigWigCollection). When None, all
tracks are included.
sequences (Optional): ["ACTAGANTG", "CCTTGAGT", ...].
values (cp.ndarray | None): The values of the batch: the output matrix
bigwig_loader produces. size: (batch_size, n_tracks, n_values)
other_batched (list of Arrays): other arrays that share
the batch_dimension with chromosomes, starts, ends, sequences and
values. Here for convenience. When creating a slice of Batch,
these arrays are sliced in the same way the previously mentioned
arrays are sliced.
other (Any): Any other data to hold on to for the batch. Can be anything
No slicing is performed on this object when the Batch is sliced.
"""

def __init__(
self,
chromosomes: Any,
starts: Any,
ends: Any,
values: torch.Tensor,
track_indices: torch.Tensor | None,
sequences: torch.Tensor | list[str] | None,
other_batched: Any | None,
other: Any,
):
self.chromosomes = chromosomes
self.starts = starts
self.ends = ends
self.sequences = sequences
self.track_indices = track_indices
self.values = values
self.other_batched = other_batched
self.other = other

@classmethod
def from_batch(cls, batch: Batch) -> "PytorchBatch":
        if batch.other_batched is not None:
            other_batched = [
                cls._convert_if_possible(tensor) for tensor in batch.other_batched
            ]
        else:
            other_batched = None
return PytorchBatch(
chromosomes=cls._convert_if_possible(batch.chromosomes),
starts=cls._convert_if_possible(batch.starts),
ends=cls._convert_if_possible(batch.ends),
values=cls._convert_if_possible(batch.values),
track_indices=cls._convert_if_possible(batch.track_indices),
sequences=cls._convert_if_possible(batch.sequences),
other_batched=other_batched,
other=cls._convert_if_possible(batch.other),
)

@staticmethod
def _convert_if_possible(tensor: Any) -> Any:
        if isinstance(tensor, (cp.ndarray, np.ndarray)):
            return torch.as_tensor(tensor)
return tensor


GENOMIC_SEQUENCE_TYPE = Union[torch.Tensor, list[str], None]
BATCH_TYPE = Union[
tuple[GENOMIC_SEQUENCE_TYPE, torch.Tensor],
tuple[GENOMIC_SEQUENCE_TYPE, torch.Tensor, torch.Tensor],
PytorchBatch,
]


class PytorchBigWigDataset(IterableDataset[BATCH_TYPE]):

"""
Pytorch IterableDataset over FASTA files and BigWig profiles.
@@ -54,6 +141,8 @@ class PytorchBigWigDataset(
        sub_sample_tracks: int; if set, a different random set of tracks is selected in
            each superbatch from the total number of tracks. The indices corresponding
            to those tracks are returned in the output.
        return_batch_objects: if True, batches are returned as instances of
            bigwig_loader.pytorch.PytorchBatch instead of tuples.
"""

def __init__(
@@ -78,6 +167,7 @@ def __init__(
position_sampler_buffer_size: int = 100000,
repeat_same_positions: bool = False,
sub_sample_tracks: Optional[int] = None,
return_batch_objects: bool = False,
):
super().__init__()
self._dataset = Dataset(
@@ -99,22 +189,21 @@ def __init__(
position_sampler_buffer_size=position_sampler_buffer_size,
repeat_same_positions=repeat_same_positions,
sub_sample_tracks=sub_sample_tracks,
return_batch_objects=True,
)
self._return_batch_objects = return_batch_objects

    def __iter__(
        self,
    ) -> Iterator[BATCH_TYPE]:
        for batch in self._dataset:
            pytorch_batch = PytorchBatch.from_batch(batch)  # type: ignore
            if self._return_batch_objects:
                yield pytorch_batch
            elif pytorch_batch.track_indices is None:
                yield pytorch_batch.sequences, pytorch_batch.values
            else:
                yield (
                    pytorch_batch.sequences,
                    pytorch_batch.values,
                    pytorch_batch.track_indices,
                )

def reset_gpu(self) -> None:
self._dataset.reset_gpu()
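
The same flag on the PyTorch wrapper, which now routes every batch through PytorchBatch.from_batch before yielding. Again, the arguments are illustrative placeholders:

# Hypothetical usage sketch of return_batch_objects on PytorchBigWigDataset.
import torch
from bigwig_loader.pytorch import PytorchBigWigDataset

dataset = PytorchBigWigDataset(
    regions_of_interest=merged_intervals,  # placeholder: intervals to sample from
    collection="path/to/bigwig/files",  # placeholder path
    reference_genome_path="path/to/genome.fasta",  # placeholder path
    sequence_length=2000,
    batch_size=32,
    batches_per_epoch=10,
    return_batch_objects=True,
)

for batch in dataset:
    # batch is a PytorchBatch; cupy and numpy arrays have been converted
    # with torch.as_tensor, so cupy-backed values stay on the GPU.
    assert isinstance(batch.values, torch.Tensor)
    print(batch.values.shape)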
63 changes: 63 additions & 0 deletions tests/test_new_dataset.py
@@ -20,13 +20,60 @@ def dataset(bigwig_path, reference_genome_path, merged_intervals):
return dataset


@pytest.fixture
def dataset_with_track_sampling(bigwig_path, reference_genome_path, merged_intervals):
dataset = Dataset(
regions_of_interest=merged_intervals,
collection=bigwig_path,
reference_genome_path=reference_genome_path,
sequence_length=2000,
center_bin_to_predict=1000,
window_size=4,
batch_size=265,
batches_per_epoch=10,
maximum_unknown_bases_fraction=0.1,
first_n_files=2,
sub_sample_tracks=1,
)
return dataset


# @pytest.mark.timeout(10)
def test_output_shape(dataset):
for i, (sequence, values) in enumerate(dataset):
print(i, "---", flush=True)
assert values.shape == (265, 2, 250)


def test_output_shape_sub_sampled_tracks(dataset_with_track_sampling):
for i, (sequence, values, track_indices) in enumerate(dataset_with_track_sampling):
print(i, "---", flush=True)
assert len(track_indices) == 1
assert values.shape == (265, 1, 250)


def test_batch_return_type(bigwig_path, reference_genome_path, merged_intervals):
from bigwig_loader.batch import Batch

dataset = Dataset(
regions_of_interest=merged_intervals,
collection=bigwig_path,
reference_genome_path=reference_genome_path,
sequence_length=2000,
center_bin_to_predict=1000,
window_size=4,
batch_size=265,
batches_per_epoch=10,
maximum_unknown_bases_fraction=0.1,
first_n_files=2,
sub_sample_tracks=1,
return_batch_objects=True,
)
for i, batch in enumerate(dataset):
assert isinstance(batch, Batch)
assert batch.track_indices is not None


if __name__ == "__main__":
from bigwig_loader.collection import BigWigCollection
from bigwig_loader.download_example_data import get_example_bigwigs_files
@@ -52,4 +99,20 @@ def test_output_shape(ds):
)

test_output_shape(ds)

ds = Dataset(
regions_of_interest=merged_intervals,
collection=bigwig_path,
reference_genome_path=get_reference_genome(),
sequence_length=2000,
center_bin_to_predict=1000,
window_size=4,
batch_size=265,
batches_per_epoch=10,
maximum_unknown_bases_fraction=0.1,
first_n_files=2,
sub_sample_tracks=1,
)

test_output_shape_sub_sampled_tracks(ds)
print("done")
