diff --git a/src/nlpsig/data_preparation.py b/src/nlpsig/data_preparation.py index 58adeb2..6404082 100644 --- a/src/nlpsig/data_preparation.py +++ b/src/nlpsig/data_preparation.py @@ -1211,7 +1211,7 @@ def get_torch_path_for_SeqSigNet( include_features_in_input: bool, include_embedding_in_input: bool, reduced_embeddings: bool = False, - ) -> dict[str, torch.tensor | int]: + ) -> dict[str, dict[str, torch.tensor] | int | None]: """ Returns a `torch.tensor` object that can be passed into `nlpsig_networks.SeqSigNet` model.