diff --git a/src/pymovements/dataset/dataset.py b/src/pymovements/dataset/dataset.py index 9c851341..755892fd 100644 --- a/src/pymovements/dataset/dataset.py +++ b/src/pymovements/dataset/dataset.py @@ -231,6 +231,40 @@ def load_precomputed_reading_measures(self) -> None: self.paths, ) + def _split_gaze_data( + self, + by: list[str] | str, + ) -> None: + """Split gaze data into seperated GazeDataFrame's. + + Parameters + ---------- + by: list[str] | str + Column's to split dataframe by. + """ + if isinstance(by, str): + by = [by] + new_data = [ + ( + GazeDataFrame( + new_frame, + experiment=_frame.experiment, + trial_columns=self.definition.trial_columns, + time_column=self.definition.time_column, + time_unit=self.definition.time_unit, + position_columns=self.definition.position_columns, + velocity_columns=self.definition.velocity_columns, + acceleration_columns=self.definition.acceleration_columns, + distance_column=self.definition.distance_column, + ), + fileinfo_row, + ) + for (_frame, fileinfo_row) in zip(self.gaze, self.fileinfo['gaze'].to_dicts()) + for new_frame in _frame.frame.partition_by(by=by) + ] + self.gaze = [data[0] for data in new_data] + self.fileinfo['gaze'] = pl.concat([pl.from_dict(data[1]) for data in new_data]) + def split_precomputed_events( self, by: list[str] | str, diff --git a/tests/unit/dataset/dataset_test.py b/tests/unit/dataset/dataset_test.py index 7bb14470..17f3e97a 100644 --- a/tests/unit/dataset/dataset_test.py +++ b/tests/unit/dataset/dataset_test.py @@ -146,6 +146,8 @@ def mock_toy( 'y_left_pix': np.zeros(1000), 'x_right_pix': np.zeros(1000), 'y_right_pix': np.zeros(1000), + 'trial_id_1': np.concatenate([np.zeros(500), np.ones(500)]), + 'trial_id_2': ['a'] * 200 + ['b'] * 200 + ['c'] * 600, }, schema={ 'subject_id': pl.Int64, @@ -154,6 +156,8 @@ def mock_toy( 'y_left_pix': pl.Float64, 'x_right_pix': pl.Float64, 'y_right_pix': pl.Float64, + 'trial_id_1': pl.Float64, + 'trial_id_2': pl.Utf8, }, ) pixel_columns = ['x_left_pix', 'y_left_pix', 'x_right_pix', 'y_right_pix'] @@ -169,6 +173,8 @@ def mock_toy( 'y_right_pix': np.zeros(1000), 'x_avg_pix': np.zeros(1000), 'y_avg_pix': np.zeros(1000), + 'trial_id_1': np.concatenate([np.zeros(500), np.ones(500)]), + 'trial_id_2': ['a'] * 200 + ['b'] * 200 + ['c'] * 600, }, schema={ 'subject_id': pl.Int64, @@ -179,6 +185,8 @@ def mock_toy( 'y_right_pix': pl.Float64, 'x_avg_pix': pl.Float64, 'y_avg_pix': pl.Float64, + 'trial_id_1': pl.Float64, + 'trial_id_2': pl.Utf8, }, ) pixel_columns = [ @@ -192,12 +200,16 @@ def mock_toy( 'time': np.arange(1000), 'x_left_pix': np.zeros(1000), 'y_left_pix': np.zeros(1000), + 'trial_id_1': np.concatenate([np.zeros(500), np.ones(500)]), + 'trial_id_2': ['a'] * 200 + ['b'] * 200 + ['c'] * 600, }, schema={ 'subject_id': pl.Int64, 'time': pl.Int64, 'x_left_pix': pl.Float64, 'y_left_pix': pl.Float64, + 'trial_id_1': pl.Float64, + 'trial_id_2': pl.Utf8, }, ) pixel_columns = ['x_left_pix', 'y_left_pix'] @@ -208,12 +220,16 @@ def mock_toy( 'time': np.arange(1000), 'x_right_pix': np.zeros(1000), 'y_right_pix': np.zeros(1000), + 'trial_id_1': np.concatenate([np.zeros(500), np.ones(500)]), + 'trial_id_2': ['a'] * 200 + ['b'] * 200 + ['c'] * 600, }, schema={ 'subject_id': pl.Int64, 'time': pl.Int64, 'x_right_pix': pl.Float64, 'y_right_pix': pl.Float64, + 'trial_id_1': pl.Float64, + 'trial_id_2': pl.Utf8, }, ) pixel_columns = ['x_right_pix', 'y_right_pix'] @@ -224,12 +240,16 @@ def mock_toy( 'time': np.arange(1000), 'x_pix': np.zeros(1000), 'y_pix': np.zeros(1000), + 'trial_id_1': np.concatenate([np.zeros(500), np.ones(500)]), + 'trial_id_2': ['a'] * 200 + ['b'] * 200 + ['c'] * 600, }, schema={ 'subject_id': pl.Int64, 'time': pl.Int64, 'x_pix': pl.Float64, 'y_pix': pl.Float64, + 'trial_id_1': pl.Float64, + 'trial_id_2': pl.Utf8, }, ) pixel_columns = ['x_pix', 'y_pix'] @@ -1000,7 +1020,8 @@ def test_detect_events_attribute_error(gaze_dataset_configuration): }, ( "Column 'position' not found. Available columns are: " - "['time', 'subject_id', 'pixel', 'custom_position', 'velocity']" + "['time', 'trial_id_1', 'trial_id_2', 'subject_id', " + "'pixel', 'custom_position', 'velocity']" ), id='no_position', ), @@ -1012,7 +1033,8 @@ def test_detect_events_attribute_error(gaze_dataset_configuration): }, ( "Column 'velocity' not found. Available columns are: " - "['time', 'subject_id', 'pixel', 'position', 'custom_velocity']" + "['time', 'trial_id_1', 'trial_id_2', 'subject_id', " + "'pixel', 'position', 'custom_velocity']" ), id='no_velocity', ), @@ -1930,3 +1952,30 @@ def test_load_split_precomputed_events(precomputed_dataset_configuration, by, ex dataset.load() dataset.split_precomputed_events(by) assert len(dataset.precomputed_events) == expected_len + + +@pytest.mark.parametrize( + ('by', 'expected_len'), + [ + pytest.param( + 'trial_id_1', + 40, + id='subset_int', + ), + pytest.param( + 'trial_id_2', + 60, + id='subset_int', + ), + pytest.param( + ['trial_id_1', 'trial_id_2'], + 80, + id='subset_int', + ), + ], +) +def test_load_split_gaze(gaze_dataset_configuration, by, expected_len): + dataset = pm.Dataset(**gaze_dataset_configuration['init_kwargs']) + dataset.load() + dataset._split_gaze_data(by) + assert len(dataset.gaze) == expected_len