Skip to content

Commit

Permalink
fix sample rate bug
Browse files Browse the repository at this point in the history
  • Loading branch information
faroit committed Nov 24, 2023
1 parent 2277849 commit 925fb9a
Showing 1 changed file with 72 additions and 81 deletions.
153 changes: 72 additions & 81 deletions musdb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import os
import tempfile


class DB(object):
"""
The musdb DB Object
Expand Down Expand Up @@ -71,15 +72,16 @@ class DB(object):
returns ``Track`` objects
"""

def __init__(
self,
root=None,
setup_file=None,
is_wav=False,
download=False,
subsets=['train', 'test'],
subsets=["train", "test"],
split=None,
sample_rate=None
sample_rate=None,
):
if root is None:
if download:
Expand All @@ -95,24 +97,22 @@ def __init__(
if setup_file is not None:
setup_path = op.join(self.root, setup_file)
else:
setup_path = os.path.join(
musdb.__path__[0], 'configs', 'mus.yaml'
)
setup_path = os.path.join(musdb.__path__[0], "configs", "mus.yaml")

with open(setup_path, 'r') as f:
with open(setup_path, "r") as f:
self.setup = yaml.safe_load(f)

if download:
self.url = self.setup['sample-url']
self.url = self.setup["sample-url"]
self.download()
if not self._check_exists():
raise RuntimeError('Dataset not found.' +
'You can use download=True to download a sample version of the dataset')
raise RuntimeError(
"Dataset not found."
+ "You can use download=True to download a sample version of the dataset"
)

if sample_rate != self.setup['sample_rate']:
self.sample_rate = sample_rate
self.sources_names = list(self.setup['sources'].keys())
self.targets_names = list(self.setup['targets'].keys())
self.sources_names = list(self.setup["sources"].keys())
self.targets_names = list(self.setup["targets"].keys())
self.is_wav = is_wav
self.tracks = self.load_mus_tracks(subsets=subsets, split=split)

Expand All @@ -139,14 +139,14 @@ def get_validation_track_indices(self, validation_track_names=None):
return a list of validation track indices
"""
if validation_track_names is None:
validation_track_names = self.setup['validation_tracks']
validation_track_names = self.setup["validation_tracks"]

return self.get_track_indices_by_names(validation_track_names)

def get_track_indices_by_names(self, names):
"""Returns musdb track indices by track name
Can be used to filter the musdb tracks for
Can be used to filter the musdb tracks for
a validation subset by trackname
Parameters
Expand All @@ -161,7 +161,7 @@ def get_track_indices_by_names(self, names):
"""
if isinstance(names, str):
names = [names]

return [[t.name for t in self.tracks].index(name) for name in names]

def load_mus_tracks(self, subsets=None, split=None):
Expand All @@ -187,56 +187,54 @@ def load_mus_tracks(self, subsets=None, split=None):
if isinstance(subsets, str):
subsets = [subsets]
else:
subsets = ['train', 'test']
subsets = ["train", "test"]

if subsets != ['train'] and split is not None:
if subsets != ["train"] and split is not None:
raise RuntimeError("Subset has to set to `train` when split is used")

tracks = []
for subset in subsets:
for subset in subsets:
subset_folder = op.join(self.root, subset)

for _, folders, files in os.walk(subset_folder):
if self.is_wav:
# parse pcm tracks and sort by name
for track_name in sorted(folders):
if subset == 'train':
if split == 'train' and track_name in self.setup['validation_tracks']:
if subset == "train":
if (
split == "train"
and track_name in self.setup["validation_tracks"]
):
continue
elif split == 'valid' and track_name not in self.setup['validation_tracks']:
elif (
split == "valid"
and track_name not in self.setup["validation_tracks"]
):
continue

track_folder = op.join(subset_folder, track_name)
# create new mus track
track = MultiTrack(
name=track_name,
path=op.join(
track_folder,
self.setup['mixture']
),
path=op.join(track_folder, self.setup["mixture"]),
subset=subset,
is_wav=self.is_wav,
stem_id=self.setup['stem_ids']['mixture'],
sample_rate=self.sample_rate
stem_id=self.setup["stem_ids"]["mixture"],
sample_rate=self.sample_rate,
)

# add sources to track
sources = {}
for src, source_file in list(
self.setup['sources'].items()
):
for src, source_file in list(self.setup["sources"].items()):
# create source object
abs_path = op.join(
track_folder,
source_file
)
abs_path = op.join(track_folder, source_file)
if os.path.exists(abs_path):
sources[src] = Source(
track,
name=src,
path=abs_path,
stem_id=self.setup['stem_ids'][src],
sample_rate=self.sample_rate
stem_id=self.setup["stem_ids"][src],
sample_rate=self.sample_rate,
)
track.sources = sources
track.targets = self.create_targets(track)
Expand All @@ -246,40 +244,43 @@ def load_mus_tracks(self, subsets=None, split=None):
else:
# parse stem files
for track_name in sorted(files):
if not track_name.endswith('.stem.mp4'):
if not track_name.endswith(".stem.mp4"):
continue
if subset == 'train':
if split == 'train' and track_name.split('.stem.mp4')[0] in self.setup['validation_tracks']:
if subset == "train":
if (
split == "train"
and track_name.split(".stem.mp4")[0]
in self.setup["validation_tracks"]
):
continue
elif split == 'valid' and track_name.split('.stem.mp4')[0] not in self.setup['validation_tracks']:
elif (
split == "valid"
and track_name.split(".stem.mp4")[0]
not in self.setup["validation_tracks"]
):
continue

# create new mus track
track = MultiTrack(
name=track_name.split('.stem.mp4')[0],
name=track_name.split(".stem.mp4")[0],
path=op.join(subset_folder, track_name),
subset=subset,
stem_id=self.setup['stem_ids']['mixture'],
stem_id=self.setup["stem_ids"]["mixture"],
is_wav=self.is_wav,
sample_rate=self.sample_rate
sample_rate=self.sample_rate,
)
# add sources to track
sources = {}
for src, source_file in list(
self.setup['sources'].items()
):
for src, source_file in list(self.setup["sources"].items()):
# create source object
abs_path = op.join(
subset_folder,
track_name
)
abs_path = op.join(subset_folder, track_name)
if os.path.exists(abs_path):
sources[src] = Source(
track,
name=src,
path=abs_path,
stem_id=self.setup['stem_ids'][src],
sample_rate=self.sample_rate
stem_id=self.setup["stem_ids"][src],
sample_rate=self.sample_rate,
)
track.sources = sources

Expand All @@ -292,9 +293,7 @@ def load_mus_tracks(self, subsets=None, split=None):
def create_targets(self, track):
# add targets to track
targets = collections.OrderedDict()
for name, target_srcs in list(
self.setup['targets'].items()
):
for name, target_srcs in list(self.setup["targets"].items()):
# add a list of target sources
target_sources = []
for source, gain in list(target_srcs.items()):
Expand All @@ -305,21 +304,11 @@ def create_targets(self, track):
target_sources.append(track.sources[source])
# add sources to target
if target_sources:
targets[name] = Target(
track,
sources=target_sources,
name=name
)
targets[name] = Target(track, sources=target_sources, name=name)

return targets

def save_estimates(
self,
user_estimates,
track,
estimates_dir,
write_stems=False
):
def save_estimates(self, user_estimates, track, estimates_dir, write_stems=False):
"""Writes `user_estimates` to disk while recreating the musdb file structure in that folder.
Parameters
Expand All @@ -331,9 +320,7 @@ def save_estimates(
estimates_dir : str,
output folder name where to save the estimates.
"""
track_estimate_dir = op.join(
estimates_dir, track.subset, track.name
)
track_estimate_dir = op.join(estimates_dir, track.subset, track.name)
if not os.path.exists(track_estimate_dir):
os.makedirs(track_estimate_dir)

Expand All @@ -343,11 +330,9 @@ def save_estimates(
# to be implemented
else:
for target, estimate in list(user_estimates.items()):
target_path = op.join(track_estimate_dir, target + '.wav')
target_path = op.join(track_estimate_dir, target + ".wav")
stempeg.write_audio(
path=target_path,
data=estimate,
sample_rate=track.rate
path=target_path, data=estimate, sample_rate=track.rate
)

def _check_exists(self):
Expand All @@ -366,13 +351,13 @@ def download(self, progress: bool = True, suffix: str = ".zip"):
pass
else:
raise
print('Downloading MUSDB 7s Sample Dataset to %s...' % self.root)
print("Downloading MUSDB 7s Sample Dataset to %s..." % self.root)

file_size = None
req = Request(self.url, headers={"User-Agent": "musdb_downloader"})
u = urlopen(req)
meta = u.info()
if hasattr(meta, 'getheaders'):
if hasattr(meta, "getheaders"):
content_length = meta.getheaders("Content-Length")
else:
content_length = meta.get_all("Content-Length")
Expand All @@ -385,7 +370,13 @@ def download(self, progress: bool = True, suffix: str = ".zip"):
f = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)

try:
with tqdm(total=file_size, disable=not progress, unit='B', unit_scale=True, unit_divisor=1024) as pbar:
with tqdm(
total=file_size,
disable=not progress,
unit="B",
unit_scale=True,
unit_divisor=1024,
) as pbar:
while True:
buffer = u.read(8192)
if len(buffer) == 0:
Expand All @@ -394,10 +385,10 @@ def download(self, progress: bool = True, suffix: str = ".zip"):
pbar.update(len(buffer))

f.close()
zip_ref = zipfile.ZipFile(f.name, 'r')
zip_ref = zipfile.ZipFile(f.name, "r")
zip_ref.extractall(os.path.join(self.root))
zip_ref.close()
finally:
f.close()
if os.path.exists(f.name):
os.remove(f.name)
os.remove(f.name)

0 comments on commit 925fb9a

Please sign in to comment.