Skip to content

Commit

Permalink
tests for get_time_feature
Browse files Browse the repository at this point in the history
  • Loading branch information
rchan26 committed Aug 16, 2023
1 parent 2af8171 commit eb56b34
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 3 deletions.
8 changes: 5 additions & 3 deletions src/nlpsig/data_preparation.py
Original file line number Diff line number Diff line change
Expand Up @@ -875,7 +875,7 @@ def pad(
def get_time_feature(
self,
time_feature: str = "timeline_index",
standardise_method: str = "z_score",
standardise_method: str | None = None,
) -> dict[str, np.array | Callable | None]:
"""
Returns a `np.array` object of the time_feature that is requested
Expand Down Expand Up @@ -906,7 +906,9 @@ def get_time_feature(
(can be found in `._feature_list` attribute).
"""
if time_feature not in self._feature_list:
raise ValueError(f"`time_feature` should be in {self._feature_list}.")
raise ValueError(
f"`time_feature` '{time_feature}' should be in {self._feature_list}."
)

if not self.time_features_added:
self.set_time_features()
Expand Down Expand Up @@ -1034,7 +1036,7 @@ def get_torch_path_for_SWNUNetwork(
if self.array_padded is None:
raise ValueError("Need to first call to create the path `.pad()`.")

# obtains a torch tensor which can be inputted into deepsignet
# obtains a torch tensor which can be inputted into SWNUNetwork
# computes how many features there are currently
# (which occur in the first n_features columns)
n_features = len(
Expand Down
59 changes: 59 additions & 0 deletions tests/test_data_preparation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import numpy as np
import pandas as pd
import pytest
import regex as re
Expand Down Expand Up @@ -735,3 +736,61 @@ def test_standardise_pd_wrong_method(vec_to_standardise, test_df_no_time, emb):
match=re.escape(f"`method`: {incorrect_method} must be in {implemented}."),
):
obj._standardise_pd(vec=vec_to_standardise, method=incorrect_method)


def test_get_time_feature(test_df_with_datetime, emb):
    # test get_time_feature function with default arguments:
    # the result should be a dict holding the raw `timeline_index`
    # column as a numpy array and no transform
    # default initialisation
    obj = PrepareData(original_df=test_df_with_datetime, embeddings=emb)
    time_feature = obj.get_time_feature()

    # check it returns a dict
    # (isinstance rather than `type(...) ==` is the idiomatic type check)
    assert isinstance(time_feature, dict)
    # time_feature should be a numpy array
    assert isinstance(time_feature["time_feature"], np.ndarray)
    # should just be the timeline_index column
    np.testing.assert_array_equal(
        time_feature["time_feature"], np.array(obj.df["timeline_index"])
    )
    # no standardisation applied by default
    assert time_feature["transform"] is None


def test_get_time_feature_incorrect_time_feature(test_df_with_datetime, emb):
    # asking for a time feature that is absent from the feature list
    # must raise a ValueError carrying the expected message
    obj = PrepareData(original_df=test_df_with_datetime, embeddings=emb)

    incorrect_time_feature = "fake_time_feature"
    # build the expected message up front so the `raises` block stays compact
    expected_message = (
        f"`time_feature` '{incorrect_time_feature}' should be in {obj._feature_list}."
    )

    # escape the message since it contains regex metacharacters
    with pytest.raises(ValueError, match=re.escape(expected_message)):
        obj.get_time_feature(time_feature=incorrect_time_feature)


def test_get_time_feature_with_standardisation(test_df_with_datetime, emb):
    # test get_time_feature function with requested standardisation (using z_score)
    # default initialisation
    obj = PrepareData(original_df=test_df_with_datetime, embeddings=emb)
    time_feature = obj.get_time_feature(
        time_feature="timeline_index", standardise_method="z_score"
    )
    # reference result: standardise the same column directly
    standardised = obj._standardise_pd(vec=obj.df["timeline_index"], method="z_score")
    expected = np.array(standardised["standardised_pd"])

    # check it returns a dict
    # (isinstance rather than `type(...) ==` is the idiomatic type check)
    assert isinstance(time_feature, dict)
    # time_feature should be a numpy array
    assert isinstance(time_feature["time_feature"], np.ndarray)
    # should equal the standardised array using z_score
    np.testing.assert_equal(time_feature["time_feature"], expected)
    # the transform applied to the time feature should be
    # the same as the standardised array using z_score
    np.testing.assert_equal(
        np.array(time_feature["transform"](obj.df["timeline_index"].values)),
        expected,
    )

0 comments on commit eb56b34

Please sign in to comment.