
Adapting pytorch dataset to new dataset
jorenretel committed Aug 13, 2024
1 parent f80fa28 commit 3f27803
Showing 3 changed files with 24 additions and 14 deletions.
4 changes: 3 additions & 1 deletion bigwig_loader/dataset_new.py
@@ -66,7 +66,9 @@ class Dataset:
         repeat_same_positions: if True the positions sampler does not draw a new random collection
             of positions when the buffer runs out, but repeats the same samples. Can be used to
             check whether network can overfit.
-        sub_sample_tracks: int, if set
+        sub_sample_tracks: int, if set a different random set of tracks is selected in each
+            superbatch from the total number of tracks. The indices corresponding to those tracks
+            are returned in the output.
     """

     def __init__(
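The sub_sample_tracks docstring added above describes per-superbatch sub-sampling: each time the buffer is refilled, a fresh random subset of tracks is drawn and the chosen indices are returned alongside the data. The snippet below is a minimal illustrative sketch of that idea, not code from this repository; n_total_tracks and the superbatch loop are made-up stand-ins.

import random

n_total_tracks = 100       # made-up total number of BigWig tracks
sub_sample_tracks = 10     # value that would be passed as Dataset(sub_sample_tracks=...)

for superbatch in range(3):
    # a new random subset of track indices is drawn for every superbatch
    track_indices = sorted(random.sample(range(n_total_tracks), k=sub_sample_tracks))
    # downstream, values are restricted to these tracks and the indices are
    # returned as an extra element of each batch
    print(superbatch, track_indices)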
32 changes: 21 additions & 11 deletions bigwig_loader/pytorch.py
@@ -12,7 +12,7 @@
 from torch.utils.data import IterableDataset

 from bigwig_loader.collection import BigWigCollection
-from bigwig_loader.dataset import BigWigDataset
+from bigwig_loader.dataset_new import Dataset


 class PytorchBigWigDataset(
@@ -51,6 +51,9 @@ class PytorchBigWigDataset(
         sequence_encoder: encoder to apply to the sequence. Default: bigwig_loader.util.onehot_sequences
         position_samples_buffer_size: number of intervals picked up front by the position sampler.
             When all intervals are used, new intervals are picked.
+        sub_sample_tracks: int, if set a different random set of tracks is selected in each
+            superbatch from the total number of tracks. The indices corresponding to those tracks
+            are returned in the output.
     """

     def __init__(
@@ -74,9 +77,10 @@ def __init__(
         first_n_files: Optional[int] = None,
         position_sampler_buffer_size: int = 100000,
         repeat_same_positions: bool = False,
+        sub_sample_tracks: Optional[int] = None,
     ):
         super().__init__()
-        self._dataset = BigWigDataset(
+        self._dataset = Dataset(
             regions_of_interest=regions_of_interest,
             collection=collection,
             reference_genome_path=reference_genome_path,
@@ -94,17 +98,23 @@ def __init__(
             first_n_files=first_n_files,
             position_sampler_buffer_size=position_sampler_buffer_size,
             repeat_same_positions=repeat_same_positions,
+            sub_sample_tracks=sub_sample_tracks,
         )

-    def __iter__(self) -> Iterator[tuple[torch.FloatTensor, torch.FloatTensor]]:
-        iter(self._dataset)
-        return self
-
-    def __next__(self) -> tuple[torch.FloatTensor, torch.FloatTensor]:
-        sequences, target = next(self._dataset)
-        target = torch.as_tensor(target, device="cuda")
-        sequences = torch.FloatTensor(sequences)
-        return sequences, target
+    def __iter__(
+        self,
+    ) -> Iterator[
+        tuple[torch.FloatTensor, torch.FloatTensor]
+        | tuple[torch.FloatTensor, torch.FloatTensor, torch.IntTensor]
+    ]:
+        for batch in self._dataset:
+            sequences = torch.FloatTensor(batch[0])
+            target = torch.as_tensor(batch[1], device="cuda")
+            if len(batch) >= 3:
+                track_indices = torch.tensor(batch[2], device="cuda", dtype=torch.int)  # type: ignore
+                yield sequences, target, track_indices
+            else:
+                yield sequences, target

     def reset_gpu(self) -> None:
         self._dataset.reset_gpu()
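The rewritten __iter__ above yields batches lazily and changes the batch shape: with sub_sample_tracks set, each batch is a 3-tuple (sequences, targets, track_indices), otherwise a 2-tuple. Below is a hedged consumption sketch; dataset construction is omitted and the fake iterable with arbitrary tensor shapes only stands in for a real PytorchBigWigDataset.

import torch

def consume(dataset) -> None:
    # works for anything yielding the 2- or 3-tuples produced by the new __iter__
    for batch in dataset:
        if len(batch) == 3:
            sequences, targets, track_indices = batch   # sub_sample_tracks was set
            print(sequences.shape, targets.shape, track_indices.tolist())
        else:
            sequences, targets = batch                  # default: all tracks
            print(sequences.shape, targets.shape)

# stand-in for a constructed PytorchBigWigDataset; shapes are arbitrary
fake_batches = [
    (torch.zeros(4, 1000, 4), torch.zeros(4, 10, 1000)),
    (torch.zeros(4, 1000, 4), torch.zeros(4, 3, 1000), torch.tensor([0, 5, 7])),
]
consume(fake_batches)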
2 changes: 0 additions & 2 deletions bigwig_loader/streamed_dataset.py
@@ -202,8 +202,6 @@ def _generate_batches(self) -> Generator[Batch, None, None]:
                 scaling_factors = scaling_factors[
                     :, batch.track_indices, :
                 ]
-                print(scaling_factors)
-                print(scaling_factors.shape)

                 values *= scaling_factors
                 stream.synchronize()
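The removed print statements were debug output around the per-track scaling step shown in the context lines: scaling_factors is first sliced to the sub-sampled tracks (batch.track_indices) and then broadcast-multiplied into values. The NumPy sketch below illustrates that slicing and broadcast only; the array shapes are assumptions, and the real code operates on GPU arrays inside _generate_batches.

import numpy as np

n_total_tracks, batch_size, n_positions = 8, 2, 5
track_indices = [1, 4, 6]                     # stand-in for batch.track_indices

# one scaling factor per track, shaped for broadcasting (assumed layout)
scaling_factors = np.arange(1, n_total_tracks + 1, dtype=float).reshape(1, -1, 1)
values = np.ones((batch_size, len(track_indices), n_positions))

# keep only the factors of the sub-sampled tracks, then broadcast-multiply,
# mirroring scaling_factors[:, batch.track_indices, :] followed by values *= scaling_factors
scaling_factors = scaling_factors[:, track_indices, :]
values *= scaling_factors
print(values[0, :, 0])    # [2. 5. 7.]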
