diff --git a/common/setups/rasr/util/nn.py b/common/setups/rasr/util/nn.py index c4a56085e..f2bfe364a 100644 --- a/common/setups/rasr/util/nn.py +++ b/common/setups/rasr/util/nn.py @@ -148,13 +148,15 @@ def get_crp(self, **kwargs): class OggZipHdfDataInput: def __init__( self, - oggzip_files: tk.Path, + oggzip_files: List[tk.Path], alignments: tk.Path, context_window: Dict, audio: Dict, - targets: str, + targets: Optional[str] = None, partition_epoch: int = 1, seq_ordering: str = "laplace:.1000", + ogg_args: Optional[Dict[str, Any]] = None, + acoustic_mixtures: Optional[tk.Path] = None, ): """ :param oggzip_files: @@ -172,6 +174,8 @@ def __init__( self.partition_epoch = partition_epoch self.seq_ordering = seq_ordering self.targets = targets + self.ogg_args = ogg_args + self.acoustic_mixtures = acoustic_mixtures def get_data_dict(self): return { @@ -188,10 +192,11 @@ def get_data_dict(self): "class": "OggZipDataset", "audio": self.audio, "partition_epoch": self.partition_epoch, - "path": tuple(self.oggzip_files.get_path()), + "path": self.oggzip_files, "seq_ordering": self.seq_ordering, "targets": self.targets, "use_cache_manager": True, + **(self.ogg_args or {}), }, }, "seq_order_control_dataset": "ogg",