diff --git a/src/nlpsig/data_preparation.py b/src/nlpsig/data_preparation.py index 6404082..9bbfec8 100644 --- a/src/nlpsig/data_preparation.py +++ b/src/nlpsig/data_preparation.py @@ -1002,6 +1002,7 @@ def get_torch_path_for_SWNUNetwork( include_features_in_input: bool, include_embedding_in_input: bool, reduced_embeddings: bool = False, + path_indices: list | np.array | None = None, ) -> dict[str, dict[str, torch.tensor] | int | None]: """ Returns a `torch.tensor` object that can be passed into `nlpsig_networks.SWNUNetwork` model. @@ -1025,6 +1026,10 @@ def get_torch_path_for_SWNUNetwork( Whether or not to concatenate the dimension reduced embeddings, by default False. This is ignored if we created a path for each if in `.id_column`, i.e. `.pad_method='id'`. + path_indices : list | np.array | None, optional + If not None, will return the path for the indices specified in `path_indices`. + If None, will return the path for all indices in `.df` (or all ids in `.id_column` + if `pad_by="id"`), by default None. Returns ------- @@ -1148,6 +1153,11 @@ def get_torch_path_for_SWNUNetwork( self.get_path(include_features=include_features_in_path) ) + if path_indices is not None: + path = path[path_indices, :, :] + if features is not None: + features = features[path_indices, :] + return { "x_data": {"path": path, "features": features}, "input_channels": path.shape[2], @@ -1211,6 +1221,7 @@ def get_torch_path_for_SeqSigNet( include_features_in_input: bool, include_embedding_in_input: bool, reduced_embeddings: bool = False, + path_indices: list | np.array | None = None, ) -> dict[str, dict[str, torch.tensor] | int | None]: """ Returns a `torch.tensor` object that can be passed into `nlpsig_networks.SeqSigNet` model. @@ -1240,6 +1251,10 @@ def get_torch_path_for_SeqSigNet( Whether or not to concatenate the dimension reduced embeddings, by default False. This is ignored if we created a path for each if in `.id_column`, i.e. `.pad_method='id'`. + path_indices : list | np.array | None, optional + If not None, will return the path for the indices specified in `path_indices`. + If None, will return the path for all indices in `.df` (or all ids in `.id_column` + if `pad_by="id"`), by default None. Returns ------- @@ -1271,6 +1286,7 @@ def get_torch_path_for_SeqSigNet( include_features_in_input=include_features_in_input, include_embedding_in_input=include_embedding_in_input, reduced_embeddings=reduced_embeddings, + path_indices=path_indices, ) # taking windows of the path created (determined by shift, window_size, n)