diff --git a/ddsp/training/data.py b/ddsp/training/data.py index fb0c4d37..52a56bac 100644 --- a/ddsp/training/data.py +++ b/ddsp/training/data.py @@ -411,7 +411,11 @@ def features_dict(self): class Urmp(TFRecordProvider): """Urmp training set.""" - def __init__(self, base_dir, instrument_key='tpt', split='train'): + def __init__(self, + base_dir, + instrument_key='tpt', + split='train', + suffix=None): """URMP dataset for either a specific instrument or all instruments. Args: @@ -420,19 +424,30 @@ def __init__(self, base_dir, instrument_key='tpt', split='train'): ['all', 'bn', 'cl', 'db', 'fl', 'hn', 'ob', 'sax', 'tba', 'tbn', 'tpt', 'va', 'vc', 'vn']. split: Choices include ['train', 'test']. + suffix: Choices include [None, 'batched', 'unbatched'], but broadly + applies to any suffix added to the file pattern. + When suffix is not None, "_{suffix}" is inserted after the split name. + This option is used in gs://magentadata/datasets/urmp/urmp_20210324. + With the "batched" suffix, the dataloader will load tfrecords + containing audio segmented into 4-second samples. With the "unbatched" + suffix, the dataloader will load tfrecords containing unsegmented + samples, which can be used to learn note sequences in the URMP dataset. + """ self.instrument_key = instrument_key self.split = split self.base_dir = base_dir + self.suffix = '' if suffix is None else '_' + suffix super().__init__() @property def default_file_pattern(self): if self.instrument_key == 'all': - file_pattern = 'all_instruments_{}.tfrecord*'.format(self.split) + file_pattern = 'all_instruments_{}{}.tfrecord*'.format( + self.split, self.suffix) else: - file_pattern = 'urmp_{}_solo_ddsp_conditioning_{}.tfrecord*'.format( - self.instrument_key, self.split) + file_pattern = 'urmp_{}_solo_ddsp_conditioning_{}{}.tfrecord*'.format( + self.instrument_key, self.split, self.suffix) return os.path.join(self.base_dir, file_pattern)