diff --git a/src/nlpsig/data_preparation.py b/src/nlpsig/data_preparation.py index 3aeaa2a..9024fd9 100644 --- a/src/nlpsig/data_preparation.py +++ b/src/nlpsig/data_preparation.py @@ -328,9 +328,9 @@ def _check_feature_exists(self, feature: str) -> bool: # not in ._feature_list, but is a valid column name in self.df, # so add to feature list self._feature_list += [feature] - + return feature in self._feature_list - + def _obtain_feature_columns( self, features: list[str] | str | None, @@ -369,8 +369,8 @@ def _obtain_feature_columns( # convert to list of strings if isinstance(features, str): features = [features] - - if isinstance(features, list): + + if isinstance(features, list): # check each item in features is in self._feature_list # if it isn't, but is a column in self.df, it will add # it to self._feature_list @@ -777,9 +777,7 @@ def pad( raise ValueError("`method` must be either 'k_last' or 'max'.") # obtain feature colnames - feature_colnames = self._obtain_feature_columns( - features=features - ) + feature_colnames = self._obtain_feature_columns(features=features) if len(feature_colnames) > 0: if isinstance(standardise_method, str): standardise_method = [standardise_method] * len(feature_colnames) @@ -881,9 +879,7 @@ def get_time_feature( (can be found in `._feature_list` attribute). """ if time_feature not in self._feature_list: - raise ValueError( - f"`time_feature` should be in {self._feature_list}." - ) + raise ValueError(f"`time_feature` should be in {self._feature_list}.") if not self.time_features_added: self.set_time_features() diff --git a/tests/conftest.py b/tests/conftest.py index 5d3b0cc..cda224c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -29,7 +29,7 @@ def test_df_with_datetime(): return pd.DataFrame( { "text": [f"text_{i}" for i in range(n_entries)], - "binary_var": [rng.choice([0,1]) for i in range(n_entries)], + "binary_var": [rng.choice([0, 1]) for i in range(n_entries)], "continuous_var": rng.random(n_entries), "id_col": [0 for i in range(100)] + [rng.integers(1, 5) for i in range(n_entries - 100)], @@ -45,7 +45,7 @@ def test_df_no_time(): return pd.DataFrame( { "text": [f"text_{i}" for i in range(n_entries)], - "binary_var": [rng.choice([0,1]) for i in range(n_entries)], + "binary_var": [rng.choice([0, 1]) for i in range(n_entries)], "continuous_var": rng.random(n_entries), "id_col": [0 for i in range(100)] + [rng.integers(1, 5) for i in range(n_entries - 100)], @@ -60,7 +60,7 @@ def test_df_to_pad(): return pd.DataFrame( { "text": [f"text_{i}" for i in range(n_entries)], - "binary_var": [rng.choice([0,1]) for i in range(n_entries)], + "binary_var": [rng.choice([0, 1]) for i in range(n_entries)], "continuous_var": rng.random(n_entries), "id_col": 0, "label_col": [rng.integers(0, 4) for i in range(n_entries)], diff --git a/tests/test_data_preparation.py b/tests/test_data_preparation.py index 3f43c68..35d7746 100644 --- a/tests/test_data_preparation.py +++ b/tests/test_data_preparation.py @@ -28,11 +28,7 @@ def test_default_initialisation_datetime( # 1 dummy id column assert obj.df.shape == ( len(obj.original_df.index), - 1 - + len(obj.original_df.columns) - + emb.shape[1] - + len(obj._feature_list) - + 1, + 1 + len(obj.original_df.columns) + emb.shape[1] + len(obj._feature_list) + 1, ) assert obj.pooled_embeddings is None assert set(obj._feature_list) == { @@ -67,10 +63,7 @@ def test_default_initialisation_no_time( # 1 dummy id column assert obj.df.shape == ( len(obj.original_df.index), - len(obj.original_df.columns) - + emb.shape[1] - + len(obj._feature_list) - + 1, + len(obj.original_df.columns) + emb.shape[1] + len(obj._feature_list) + 1, ) assert obj.pooled_embeddings is None assert obj._feature_list == ["timeline_index"] @@ -105,10 +98,7 @@ def test_initialisation_with_id_and_label_datetime( # 3 time features assert obj.df.shape == ( len(obj.original_df.index), - 1 - + len(obj.original_df.columns) - + emb.shape[1] - + len(obj._feature_list), + 1 + len(obj.original_df.columns) + emb.shape[1] + len(obj._feature_list), ) assert obj.pooled_embeddings is None assert set(obj._feature_list) == { @@ -516,7 +506,10 @@ def test_obtain_colnames_both(test_df_with_datetime, emb, emb_reduced): ) assert obj._obtain_embedding_colnames(embeddings="full") == emb_names assert obj._obtain_embedding_colnames(embeddings="dim_reduced") == emb_reduced_names - assert obj._obtain_embedding_colnames(embeddings="both") == emb_reduced_names + emb_names + assert ( + obj._obtain_embedding_colnames(embeddings="both") + == emb_reduced_names + emb_names + ) def test_obtain_feature_columns_string(test_df_with_datetime, emb): @@ -548,9 +541,11 @@ def test_obtain_feature_columns_string_additional_binary(test_df_with_datetime, "timeline_index", "binary_var", } - -def test_obtain_feature_columns_string_additional_continuous(test_df_with_datetime, emb): + +def test_obtain_feature_columns_string_additional_continuous( + test_df_with_datetime, emb +): # default initialisation obj = PrepareData(original_df=test_df_with_datetime, embeddings=emb) # originally only have the time features @@ -592,7 +587,9 @@ def test_obtain_feature_columns_list_additional(test_df_with_datetime, emb): "time_diff", "timeline_index", } - assert obj._obtain_feature_columns(["time_encoding", "timeline_index", "binary_var", "continuous_var"]) == [ + assert obj._obtain_feature_columns( + ["time_encoding", "timeline_index", "binary_var", "continuous_var"] + ) == [ "time_encoding", "timeline_index", "binary_var", @@ -734,6 +731,7 @@ def test_standardise_pd_wrong_method(vec_to_standardise, test_df_no_time, emb): obj = PrepareData(original_df=test_df_no_time, embeddings=emb) incorrect_method = "fake_method" with pytest.raises( - ValueError, match=re.escape(f"`method`: {incorrect_method} must be in {implemented}.") + ValueError, + match=re.escape(f"`method`: {incorrect_method} must be in {implemented}."), ): obj._standardise_pd(vec=vec_to_standardise, method=incorrect_method) diff --git a/tests/test_padding.py b/tests/test_padding.py index deb7b0d..e2612a3 100644 --- a/tests/test_padding.py +++ b/tests/test_padding.py @@ -1250,7 +1250,7 @@ def test_pad_by_id_k_last_additional(test_df_with_datetime, emb): assert type(obj.array_padded) == np.ndarray assert np.array_equal(padded_array, obj.array_padded) assert obj.array_padded.shape == (len(obj.original_df["id_col"].unique()), k, ncol) - + def test_pad_by_id_max(test_df_with_datetime, emb): obj = PrepareData( @@ -1278,7 +1278,7 @@ def test_pad_by_id_max(test_df_with_datetime, emb): assert type(obj.array_padded) == np.ndarray assert np.array_equal(padded_array, obj.array_padded) assert obj.array_padded.shape == (len(obj.original_df["id_col"].unique()), k, ncol) - + def test_pad_by_id_max_additional(test_df_with_datetime, emb): obj = PrepareData(