diff --git a/bigwig_loader/sampler/track_sampler.py b/bigwig_loader/sampler/track_sampler.py index ecac830..c21f379 100644 --- a/bigwig_loader/sampler/track_sampler.py +++ b/bigwig_loader/sampler/track_sampler.py @@ -8,4 +8,5 @@ def __init__(self, total_number_of_tracks: int, sample_size: int): self.sample_size = sample_size def __iter__(self) -> Iterator[list[int]]: - yield sorted(sample(range(self.total_number_of_tracks), self.sample_size)) + while True: + yield sorted(sample(range(self.total_number_of_tracks), self.sample_size)) diff --git a/tests/test_new_dataset.py b/tests/test_new_dataset.py index 03a4e98..96df4d7 100644 --- a/tests/test_new_dataset.py +++ b/tests/test_new_dataset.py @@ -114,5 +114,6 @@ def test_batch_return_type(bigwig_path, reference_genome_path, merged_intervals) sub_sample_tracks=1, ) + print("start") test_output_shape_sub_sampled_tracks(ds) print("done") diff --git a/tests/test_track_sampler.py b/tests/test_track_sampler.py new file mode 100644 index 0000000..d8fe583 --- /dev/null +++ b/tests/test_track_sampler.py @@ -0,0 +1,13 @@ +from bigwig_loader.sampler.track_sampler import TrackSampler + + +def test_track_samler(): + sampler = TrackSampler(total_number_of_tracks=40, sample_size=10) + samples = [] + for i, sample in enumerate(sampler): + samples.append(sample) + assert len(sample) == 10 + assert all([0 <= s < 40 for s in sample]) + if i == 100: + break + assert len(samples) == 101