Skip to content

Commit

Permalink
add path_indices as argument to swnu and seqsignet path constructors
Browse files Browse the repository at this point in the history
  • Loading branch information
rchan26 committed Aug 29, 2023
1 parent 01800ec commit ad8dacd
Showing 1 changed file with 16 additions and 0 deletions.
16 changes: 16 additions & 0 deletions src/nlpsig/data_preparation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1002,6 +1002,7 @@ def get_torch_path_for_SWNUNetwork(
include_features_in_input: bool,
include_embedding_in_input: bool,
reduced_embeddings: bool = False,
path_indices: list | np.array | None = None,
) -> dict[str, dict[str, torch.tensor] | int | None]:
"""
Returns a `torch.tensor` object that can be passed into `nlpsig_networks.SWNUNetwork` model.
Expand All @@ -1025,6 +1026,10 @@ def get_torch_path_for_SWNUNetwork(
Whether or not to concatenate the dimension reduced embeddings, by default False.
This is ignored if we created a path for each id in `.id_column`,
i.e. `.pad_method='id'`.
path_indices : list | np.array | None, optional
If not None, will return the path for the indices specified in `path_indices`.
If None, will return the path for all indices in `.df` (or all ids in `.id_column`
if `pad_by="id"`), by default None.
Returns
-------
Expand Down Expand Up @@ -1148,6 +1153,11 @@ def get_torch_path_for_SWNUNetwork(
self.get_path(include_features=include_features_in_path)
)

if path_indices is not None:
path = path[path_indices, :, :]
if features is not None:
features = features[path_indices, :]

return {
"x_data": {"path": path, "features": features},
"input_channels": path.shape[2],
Expand Down Expand Up @@ -1211,6 +1221,7 @@ def get_torch_path_for_SeqSigNet(
include_features_in_input: bool,
include_embedding_in_input: bool,
reduced_embeddings: bool = False,
path_indices: list | np.array | None = None,
) -> dict[str, dict[str, torch.tensor] | int | None]:
"""
Returns a `torch.tensor` object that can be passed into `nlpsig_networks.SeqSigNet` model.
Expand Down Expand Up @@ -1240,6 +1251,10 @@ def get_torch_path_for_SeqSigNet(
Whether or not to concatenate the dimension reduced embeddings, by default False.
This is ignored if we created a path for each id in `.id_column`,
i.e. `.pad_method='id'`.
path_indices : list | np.array | None, optional
If not None, will return the path for the indices specified in `path_indices`.
If None, will return the path for all indices in `.df` (or all ids in `.id_column`
if `pad_by="id"`), by default None.
Returns
-------
Expand Down Expand Up @@ -1271,6 +1286,7 @@ def get_torch_path_for_SeqSigNet(
include_features_in_input=include_features_in_input,
include_embedding_in_input=include_embedding_in_input,
reduced_embeddings=reduced_embeddings,
path_indices=path_indices,
)

# taking windows of the path created (determined by shift, window_size, n)
Expand Down

0 comments on commit ad8dacd

Please sign in to comment.