From 2db959e13b37fedb7c2876c002aa299e1a7bee45 Mon Sep 17 00:00:00 2001 From: Stela IS Date: Mon, 26 Aug 2024 11:09:59 -0400 Subject: [PATCH 1/6] removing infinite values --- src/qusi/internal/light_curve.py | 15 ++++++++++++ src/qusi/internal/light_curve_dataset.py | 2 ++ src/qusi/internal/light_curve_observation.py | 25 +++++++++++++++++++- 3 files changed, 41 insertions(+), 1 deletion(-) diff --git a/src/qusi/internal/light_curve.py b/src/qusi/internal/light_curve.py index d866dda..5a9bf7b 100644 --- a/src/qusi/internal/light_curve.py +++ b/src/qusi/internal/light_curve.py @@ -48,6 +48,21 @@ def remove_nan_flux_data_points_from_light_curve(light_curve: LightCurve) -> Lig return light_curve +def remove_infinite_flux_data_points_from_light_curve(light_curve: LightCurve) -> LightCurve: + """ + Removes infinite values from a light curve. If there is an infinite value in either the times or the + fluxes, both corresponding values are removed. + + :param light_curve: The light curve. + :return: The light curve with infinite values removed. + """ + light_curve = deepcopy(light_curve) + infinite_flux_indexes = np.isinf(light_curve.fluxes) + light_curve.fluxes = light_curve.fluxes[~infinite_flux_indexes] + light_curve.times = light_curve.times[~infinite_flux_indexes] + return light_curve + + def randomly_roll_light_curve(light_curve: LightCurve) -> LightCurve: """ Randomly rolls a light curve. That is, a random position in the light curve is chosen, the light curve diff --git a/src/qusi/internal/light_curve_dataset.py b/src/qusi/internal/light_curve_dataset.py index 85aea45..61a6edd 100644 --- a/src/qusi/internal/light_curve_dataset.py +++ b/src/qusi/internal/light_curve_dataset.py @@ -31,6 +31,7 @@ LightCurveObservation, randomly_roll_light_curve_observation, remove_nan_flux_data_points_from_light_curve_observation, + remove_infinite_flux_data_points_from_light_curve_observation, ) from qusi.internal.light_curve_transforms import ( from_light_curve_observation_to_fluxes_array_and_label_array, @@ -342,6 +343,7 @@ def default_light_curve_observation_post_injection_transform( :param randomize: Whether to have randomization in the transforms. :return: The transformed light curve observation. """ + x = remove_infinite_flux_data_points_from_light_curve_observation(x) x = remove_nan_flux_data_points_from_light_curve_observation(x) if randomize: x = randomly_roll_light_curve_observation(x) diff --git a/src/qusi/internal/light_curve_observation.py b/src/qusi/internal/light_curve_observation.py index aa1cc86..ff43c21 100644 --- a/src/qusi/internal/light_curve_observation.py +++ b/src/qusi/internal/light_curve_observation.py @@ -3,7 +3,9 @@ from typing_extensions import Self -from qusi.internal.light_curve import LightCurve, randomly_roll_light_curve, remove_nan_flux_data_points_from_light_curve +from qusi.internal.light_curve import (LightCurve, randomly_roll_light_curve, + remove_nan_flux_data_points_from_light_curve, + remove_infinite_flux_data_points_from_light_curve) @dataclass @@ -48,6 +50,27 @@ def remove_nan_flux_data_points_from_light_curve_observation( return light_curve_observation +def remove_inf_flux_data_points_from_light_curve(light_curve): + pass + + +def remove_infinite_flux_data_points_from_light_curve_observation( + light_curve_observation: LightCurveObservation, +) -> LightCurveObservation: + """ + Removes the inf values from a light curve in a light curve observation. If there is an inf in either the times or the + fluxes, both corresponding values are removed. + + :param light_curve_observation: The light curve observation. + :return: The light curve observation with inf values removed. + """ + light_curve_observation = deepcopy(light_curve_observation) + light_curve_observation.light_curve = remove_infinite_flux_data_points_from_light_curve( + light_curve_observation.light_curve + ) + return light_curve_observation + + def randomly_roll_light_curve_observation(light_curve_observation: LightCurveObservation) -> LightCurveObservation: """ Randomly rolls a light curve observation. That is, a random position in the light curve is chosen, the light curve From 5c6b902227744c547ba94d915216f635a0e811a1 Mon Sep 17 00:00:00 2001 From: Stela IS Date: Mon, 26 Aug 2024 11:24:25 -0400 Subject: [PATCH 2/6] removing infinite values --- src/qusi/internal/light_curve_observation.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/qusi/internal/light_curve_observation.py b/src/qusi/internal/light_curve_observation.py index ff43c21..53f12ee 100644 --- a/src/qusi/internal/light_curve_observation.py +++ b/src/qusi/internal/light_curve_observation.py @@ -50,10 +50,6 @@ def remove_nan_flux_data_points_from_light_curve_observation( return light_curve_observation -def remove_inf_flux_data_points_from_light_curve(light_curve): - pass - - def remove_infinite_flux_data_points_from_light_curve_observation( light_curve_observation: LightCurveObservation, ) -> LightCurveObservation: From 7512fdb2a97546b16db6ebe7c3874e98c4216e65 Mon Sep 17 00:00:00 2001 From: golmschenk Date: Mon, 26 Aug 2024 23:16:04 -0400 Subject: [PATCH 3/6] Add infinite removal to the light curve only post injection transform --- src/qusi/internal/light_curve_dataset.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/qusi/internal/light_curve_dataset.py b/src/qusi/internal/light_curve_dataset.py index 61a6edd..34ee515 100644 --- a/src/qusi/internal/light_curve_dataset.py +++ b/src/qusi/internal/light_curve_dataset.py @@ -25,7 +25,7 @@ from qusi.internal.light_curve import ( LightCurve, randomly_roll_light_curve, - remove_nan_flux_data_points_from_light_curve, + remove_nan_flux_data_points_from_light_curve, remove_infinite_flux_data_points_from_light_curve, ) from qusi.internal.light_curve_observation import ( LightCurveObservation, @@ -369,6 +369,7 @@ def default_light_curve_post_injection_transform( :param randomize: Whether to have randomization in the transforms. :return: The transformed light curve. """ + x = remove_infinite_flux_data_points_from_light_curve(x) x = remove_nan_flux_data_points_from_light_curve(x) if randomize: x = randomly_roll_light_curve(x) From ea412d83c09d9c4e1045d8e11cefcda05aa2bb23 Mon Sep 17 00:00:00 2001 From: golmschenk Date: Mon, 26 Aug 2024 23:16:32 -0400 Subject: [PATCH 4/6] Add test for light curve infinite flux removal --- tests/unit_tests/test_transform.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) create mode 100644 tests/unit_tests/test_transform.py diff --git a/tests/unit_tests/test_transform.py b/tests/unit_tests/test_transform.py new file mode 100644 index 0000000..aff7130 --- /dev/null +++ b/tests/unit_tests/test_transform.py @@ -0,0 +1,17 @@ +import numpy as np + +from qusi.internal.light_curve import LightCurve, remove_infinite_flux_data_points_from_light_curve + + +def test_remove_infinite_flux_data_points_from_light_curve(): + times = np.array([0.0, 1.0, 2.0]) + fluxes = np.array([0.0, np.inf, 20.0]) + light_curve = LightCurve.new( + times=times, + fluxes=fluxes, + ) + updated_light_curve = remove_infinite_flux_data_points_from_light_curve(light_curve=light_curve) + expected_times = np.array([0.0, 2.0]) + expected_fluxes = np.array([0.0, 20.0]) + assert np.array_equal(updated_light_curve.times, expected_times) + assert np.array_equal(updated_light_curve.fluxes, expected_fluxes) From c11cf7bc9c76afeb11c9cde880b59412fd3cd502 Mon Sep 17 00:00:00 2001 From: golmschenk Date: Mon, 26 Aug 2024 23:28:11 -0400 Subject: [PATCH 5/6] Make the infinite removal functions available in the public interface --- src/qusi/transform.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/qusi/transform.py b/src/qusi/transform.py index ba3b863..dc5a3da 100644 --- a/src/qusi/transform.py +++ b/src/qusi/transform.py @@ -1,11 +1,12 @@ """ Data transform related public interface. """ -from qusi.internal.light_curve import randomly_roll_light_curve, remove_nan_flux_data_points_from_light_curve +from qusi.internal.light_curve import randomly_roll_light_curve, remove_nan_flux_data_points_from_light_curve, \ + remove_infinite_flux_data_points_from_light_curve from qusi.internal.light_curve_dataset import default_light_curve_post_injection_transform, \ default_light_curve_observation_post_injection_transform from qusi.internal.light_curve_observation import remove_nan_flux_data_points_from_light_curve_observation, \ - randomly_roll_light_curve_observation + randomly_roll_light_curve_observation, remove_infinite_flux_data_points_from_light_curve_observation from qusi.internal.light_curve_transforms import from_light_curve_observation_to_fluxes_array_and_label_array, \ pair_array_to_tensor, make_uniform_length, normalize_tensor_by_modified_z_score @@ -20,4 +21,6 @@ 'randomly_roll_light_curve_observation', 'remove_nan_flux_data_points_from_light_curve', 'remove_nan_flux_data_points_from_light_curve_observation', + 'remove_infinite_flux_data_points_from_light_curve', + 'remove_infinite_flux_data_points_from_light_curve_observation', ] From 4ce0da17d744aef0b88c2bc6461d8d75d94e0f1a Mon Sep 17 00:00:00 2001 From: golmschenk Date: Mon, 26 Aug 2024 23:30:24 -0400 Subject: [PATCH 6/6] Update the tutorial to reflect the changes to the default post injection transform --- docs/source/tutorials/crafting_standard_datasets.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/tutorials/crafting_standard_datasets.md b/docs/source/tutorials/crafting_standard_datasets.md index baabe69..ed0ffe7 100644 --- a/docs/source/tutorials/crafting_standard_datasets.md +++ b/docs/source/tutorials/crafting_standard_datasets.md @@ -35,6 +35,7 @@ In the previous section, we only changed the length of that the uniform lengthen ```python def default_light_curve_observation_post_injection_transform(x: LightCurveObservation, length: int, randomize: bool = True) -> (Tensor, Tensor): + x = remove_infinite_flux_data_points_from_light_curve_observation(x) x = remove_nan_flux_data_points_from_light_curve_observation(x) if randomize: x = randomly_roll_light_curve_observation(x)