diff --git a/src/qusi/internal/light_curve_collection.py b/src/qusi/internal/light_curve_collection.py index 6bbb479..c3bdd70 100644 --- a/src/qusi/internal/light_curve_collection.py +++ b/src/qusi/internal/light_curve_collection.py @@ -61,7 +61,7 @@ def __getitem__(self, indexes: int | tuple[int]) -> Path | tuple[Path]: class PathGetterBase(PathIterableBase, PathIndexableBase): - pass + random_number_generator: Random @dataclass @@ -265,17 +265,26 @@ def observation_iter(self) -> Iterator[LightCurveObservation]: :return: The iterable of the light curves. """ + light_curve_paths = self.path_iter() + for light_curve_path in light_curve_paths: + light_curve_observation = self.observation_from_path(light_curve_path) + yield light_curve_observation + + def observation_from_path(self, light_curve_path: Path) -> LightCurveObservation: + times, fluxes = self.light_curve_collection.load_times_and_fluxes_from_path( + light_curve_path + ) + label = self.load_label_from_path_function(light_curve_path) + light_curve = LightCurve.new(times, fluxes) + light_curve_observation = LightCurveObservation.new(light_curve, label) + light_curve_observation.path = light_curve_path # TODO: Quick debug hack. + return light_curve_observation + + def path_iter(self) -> Iterable[Path]: light_curve_paths = self.path_getter.get_shuffled_paths() if len(light_curve_paths) == 0: raise ValueError('LightCurveObservationCollection returned no paths.') - for light_curve_path in light_curve_paths: - times, fluxes = self.light_curve_collection.load_times_and_fluxes_from_path( - light_curve_path - ) - label = self.load_label_from_path_function(light_curve_path) - light_curve = LightCurve.new(times, fluxes) - light_curve_observation = LightCurveObservation.new(light_curve, label) - yield light_curve_observation + return light_curve_paths def __getitem__(self, index: int) -> LightCurveObservation: light_curve_path = self.path_getter[index] diff --git a/src/qusi/internal/light_curve_dataset.py b/src/qusi/internal/light_curve_dataset.py index 0f45fdd..85aea45 100644 --- a/src/qusi/internal/light_curve_dataset.py +++ b/src/qusi/internal/light_curve_dataset.py @@ -1,6 +1,7 @@ from __future__ import annotations import copy +import itertools import math import re import shutil @@ -8,6 +9,7 @@ from enum import Enum from functools import partial from pathlib import Path +from random import Random from typing import TYPE_CHECKING, Any, Callable, TypeVar import numpy as np @@ -75,17 +77,24 @@ def __init__( ) raise ValueError(error_message) self.post_injection_transform: Callable[[Any], Any] = post_injection_transform + self.worker_randomizing_set: bool = False def __iter__(self): + if not self.worker_randomizing_set: + worker_info = torch.utils.data.get_worker_info() + if worker_info is not None: + self.seed_random(worker_info.id) + self.worker_randomizing_set = True base_light_curve_collection_iter_and_type_pairs: list[ - tuple[Iterator[LightCurveObservation], LightCurveCollectionType] + tuple[Iterator[Path], Callable[[Path], LightCurveObservation], LightCurveCollectionType] ] = [] injectee_collections = copy.copy(self.injectee_light_curve_collections) for standard_collection in self.standard_light_curve_collections: if standard_collection in injectee_collections: base_light_curve_collection_iter_and_type_pairs.append( ( - loop_iter_function(standard_collection.observation_iter), + loop_iter_function(standard_collection.path_iter), + standard_collection.observation_from_path, LightCurveCollectionType.STANDARD_AND_INJECTEE, ) ) @@ -93,34 +102,39 @@ def __iter__(self): else: base_light_curve_collection_iter_and_type_pairs.append( ( - loop_iter_function(standard_collection.observation_iter), + loop_iter_function(standard_collection.path_iter), + standard_collection.observation_from_path, LightCurveCollectionType.STANDARD, ) ) for injectee_collection in injectee_collections: base_light_curve_collection_iter_and_type_pair = ( - loop_iter_function(injectee_collection.observation_iter), + loop_iter_function(injectee_collection.path_iter), + injectee_collection.observation_from_path, LightCurveCollectionType.INJECTEE, ) base_light_curve_collection_iter_and_type_pairs.append(base_light_curve_collection_iter_and_type_pair) injectable_light_curve_collection_iters: list[ - Iterator[LightCurveObservation] + tuple[Iterator[Path], Callable[[Path], LightCurveObservation]] ] = [] for injectable_collection in self.injectable_light_curve_collections: - injectable_light_curve_collection_iter = loop_iter_function(injectable_collection.observation_iter) - injectable_light_curve_collection_iters.append(injectable_light_curve_collection_iter) + injectable_light_curve_collection_iter = loop_iter_function(injectable_collection.path_iter) + injectable_light_curve_collection_iters.append( + (injectable_light_curve_collection_iter, injectable_collection.observation_from_path)) while True: for ( base_light_curve_collection_iter_and_type_pair ) in base_light_curve_collection_iter_and_type_pairs: - (base_collection_iter, collection_type) = base_light_curve_collection_iter_and_type_pair + (base_collection_iter, observation_from_path_function, + collection_type) = base_light_curve_collection_iter_and_type_pair if collection_type in [ LightCurveCollectionType.STANDARD, LightCurveCollectionType.STANDARD_AND_INJECTEE, ]: # TODO: Preprocessing step should be here. Or maybe that should all be on the light curve collection # as well? Or passed in somewhere else? - standard_light_curve = next(base_collection_iter) + standard_path = next(base_collection_iter) + standard_light_curve = observation_from_path_function(standard_path) transformed_standard_light_curve = self.post_injection_transform( standard_light_curve ) @@ -129,10 +143,12 @@ def __iter__(self): LightCurveCollectionType.INJECTEE, LightCurveCollectionType.STANDARD_AND_INJECTEE, ]: - for (injectable_light_curve_collection_iter) in injectable_light_curve_collection_iters: - injectable_light_curve = next( + for (injectable_light_curve_collection_iter, + injectable_observation_from_path_function) in injectable_light_curve_collection_iters: + injectable_light_path = next( injectable_light_curve_collection_iter ) + injectable_light_curve = injectable_observation_from_path_function(injectable_light_path) injectee_light_curve = next(base_collection_iter) injected_light_curve = inject_light_curve( injectee_light_curve, injectable_light_curve @@ -188,6 +204,12 @@ def new( ) return instance + def seed_random(self, seed: int): + for collection_group in [self.standard_light_curve_collections, self.injectee_light_curve_collections, + self.injectable_light_curve_collections]: + for collection in collection_group: + collection.path_getter.random_number_generator = Random(seed) + def inject_light_curve( injectee_observation: LightCurveObservation, diff --git a/tests/integration_tests/__init__.py b/tests/integration_tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integration_tests/test_light_curve_dataset.py b/tests/integration_tests/test_light_curve_dataset.py new file mode 100644 index 0000000..39dfb15 --- /dev/null +++ b/tests/integration_tests/test_light_curve_dataset.py @@ -0,0 +1,42 @@ +from pathlib import Path + +import numpy as np +import numpy.typing as npt +import torch +from torch.utils.data import DataLoader + +from qusi.internal.light_curve_collection import LightCurveObservationCollection +from qusi.internal.light_curve_dataset import LightCurveDataset +from qusi.internal.light_curve_transforms import from_light_curve_observation_to_fluxes_array_and_label_array, \ + pair_array_to_tensor + + +def get_paths() -> list[Path]: + return [Path('1'), Path('2'), Path('3'), Path('4'), Path('5'), Path('6'), Path('7'), Path('8')] + +def load_times_and_fluxes_from_path(path: Path) -> [npt.NDArray, npt.NDArray]: + value = float(str(path)) + return np.array([value]), np.array([value]) + +def load_label_from_path_function(path: Path) -> int: + value = int(str(path)) + return value * 10 + +def post_injection_transform(x): + x = from_light_curve_observation_to_fluxes_array_and_label_array(x) + x = pair_array_to_tensor(x) + return x + + +def test_light_curve_dataset_with_and_without_multiple_workers_gives_same_batch_order(): + light_curve_collection = LightCurveObservationCollection.new( + get_paths_function=get_paths, + load_times_and_fluxes_from_path_function=load_times_and_fluxes_from_path, + load_label_from_path_function=load_label_from_path_function) + light_curve_dataset = LightCurveDataset.new(standard_light_curve_collections=[light_curve_collection], + post_injection_transform=post_injection_transform) + multi_worker_dataloader = DataLoader(light_curve_dataset, batch_size=4, num_workers=2, prefetch_factor=1) + multi_worker_dataloader_iter = iter(multi_worker_dataloader) + multi_worker_batch0 = next(multi_worker_dataloader_iter)[0].numpy()[:, 0] + multi_worker_batch1 = next(multi_worker_dataloader_iter)[0].numpy()[:, 0] + assert not np.array_equal(multi_worker_batch0, multi_worker_batch1)