Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Log sample rate in .nam files #284

Merged
merged 4 commits into from
Jun 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions bin/train/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,12 @@ def main_inner(
dataset_validation = init_dataset(data_config, Split.VALIDATION)
train_dataloader = DataLoader(dataset_train, **learning_config["train_dataloader"])
val_dataloader = DataLoader(dataset_validation, **learning_config["val_dataloader"])
if train_dataloader.dataset.sample_rate != val_dataloader.dataset.sample_rate:
raise RuntimeError(
"Train and validation data loaders have different data set sample rates: "
f"{train_dataloader.dataset.sample_rate}, "
f"{val_dataloader.dataset.sample_rate}"
)

# ckpt_path = Path(outdir, "checkpoints")
# ckpt_path.mkdir()
Expand All @@ -204,6 +210,7 @@ def main_inner(
)
model.cpu()
model.eval()
model.net.sample_rate = train_dataloader.dataset.sample_rate
if make_plots:
plot(
model,
Expand Down
41 changes: 34 additions & 7 deletions nam/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Author: Steven Atkinson ([email protected])

import abc
import logging
from collections import namedtuple
from copy import deepcopy
from dataclasses import dataclass
Expand All @@ -19,6 +20,8 @@

from ._core import InitializableFromConfig

logger = logging.getLogger(__name__)

_REQUIRED_SAMPWIDTH = 3
REQUIRED_RATE = 48_000
_REQUIRED_CHANNELS = 1 # Mono
Expand Down Expand Up @@ -94,7 +97,7 @@ def wav_to_np(
required_shape, # Expected
arr_premono.shape, # Actual
f"Mismatched shapes. Expected {required_shape}, but this is "
f"{arr_premono.shape}!"
f"{arr_premono.shape}!",
)
# sampwidth fine--we're just casting to 32-bit float anyways
arr = arr_premono[:, 0]
Expand Down Expand Up @@ -122,8 +125,8 @@ def np_to_wav(
filename: Union[str, Path],
rate: int = 48_000,
sampwidth: int = 3,
scale = None,
**kwargs
scale=None,
**kwargs,
):
if wavio.__version__ <= "0.0.4" and scale is None:
scale = "none"
Expand All @@ -133,7 +136,7 @@ def np_to_wav(
rate,
scale=scale,
sampwidth=sampwidth,
**kwargs
**kwargs,
)


Expand Down Expand Up @@ -235,7 +238,8 @@ def __init__(
x_path: Optional[Union[str, Path]] = None,
y_path: Optional[Union[str, Path]] = None,
input_gain: float = 0.0,
rate: int = REQUIRED_RATE,
sample_rate: Optional[int] = None,
rate: Optional[int] = None,
require_input_pre_silence: Optional[float] = _DEFAULT_REQUIRE_INPUT_PRE_SILENCE,
):
"""
Expand Down Expand Up @@ -272,13 +276,14 @@ def __init__(
"""
self._validate_x_y(x, y)
self._validate_start_stop(x, y, start, stop)
self._sample_rate = self._validate_sample_rate(sample_rate, rate)
if not isinstance(delay_interpolation_method, _DelayInterpolationMethod):
delay_interpolation_method = _DelayInterpolationMethod(
delay_interpolation_method
)
if require_input_pre_silence is not None:
self._validate_preceding_silence(
x, start, int(require_input_pre_silence * rate)
x, start, int(require_input_pre_silence * self._sample_rate)
)
x, y = [z[start:stop] for z in (x, y)]
if delay is not None and delay != 0:
Expand All @@ -293,7 +298,6 @@ def __init__(
self._y = y
self._nx = nx
self._ny = ny if ny is not None else len(x) - nx + 1
self._rate = rate

def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Expand All @@ -317,6 +321,10 @@ def __len__(self) -> int:
def ny(self) -> int:
return self._ny

@property
def sample_rate(self) -> Optional[float]:
    """
    The sample rate stored on this data set as `self._sample_rate` (presumably in
    Hz, given the 48_000 default constant -- TODO confirm units).
    """
    return self._sample_rate

@property
def x(self) -> torch.Tensor:
"""
Expand Down Expand Up @@ -444,6 +452,25 @@ def _apply_delay_float(
y = _interpolate_delay(y, delay, method)
return x, y

@classmethod
def _validate_sample_rate(
    cls, sample_rate: Optional[float], rate: Optional[int]
) -> float:
    """
    Resolve the sample rate from the current `sample_rate` argument and the
    deprecated `rate` argument.

    The result is always cast to float so that every path agrees with the
    annotated return type (previously only the deprecated `rate` path cast).

    :param sample_rate: Preferred way of providing the sample rate.
    :param rate: Deprecated alias. A warning is logged when it is used.
    :return: `sample_rate` if given, else `rate` if given, else `REQUIRED_RATE`;
        as a float.
    :raises ValueError: If both `sample_rate` and `rate` are provided.
    """
    if rate is not None:
        if sample_rate is not None:
            raise ValueError(
                "Provided both sample_rate and rate. Provide only sample_rate!"
            )
        logger.warning(
            "Use of 'rate' is deprecated and will be removed. Use sample_rate instead"
        )
        return float(rate)
    if sample_rate is None:  # Neither provided: default value
        return float(REQUIRED_RATE)
    return float(sample_rate)

@classmethod
def _validate_start_stop(
self,
Expand Down
15 changes: 15 additions & 0 deletions nam/models/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@


class _Base(nn.Module, InitializableFromConfig, Exportable):
def __init__(self, sample_rate: Optional[float] = None):
    """
    :param sample_rate: Sample rate to associate with this model. Stored as-is on
        `self.sample_rate`; defaults to None.
    """
    super().__init__()
    self.sample_rate = sample_rate

@abc.abstractproperty
def pad_start_default(self) -> bool:
pass
Expand All @@ -45,6 +49,17 @@ def _metadata_loudness_x(cls) -> torch.Tensor:
)
)

def _get_export_dict(self):
    """
    Extend the parent's export dict with this model's sample rate.

    :return: The dict from `super()._get_export_dict()` with a "sample_rate" entry
        set to `self.sample_rate`.
    :raises RuntimeError: If the parent's dict already has a "sample_rate" key.
    """
    export = super()._get_export_dict()
    key = "sample_rate"
    if key not in export:
        export[key] = self.sample_rate
        return export
    # Refuse to silently overwrite a value someone else put there.
    raise RuntimeError(
        "Model wants to put 'sample_rate' into model export dict, but the key "
        "is already taken!"
    )

def _metadata_loudness(self, gain: float = 1.0, db: bool = True) -> float:
"""
How loud is this model when given a standardized input?
Expand Down
2 changes: 1 addition & 1 deletion nam/models/_exportable.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

# Model version is independent from package version as of package version 0.5.2 so that
# the API of the package can iterate at a different pace from that of the model files.
_MODEL_VERSION = "0.5.1"
_MODEL_VERSION = "0.5.2"


def _cast_enums(d: Dict[Any, Any]) -> Dict[Any, Any]:
Expand Down
3 changes: 2 additions & 1 deletion nam/models/conv_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,10 @@ def __init__(
*args,
train_strategy: TrainStrategy = default_train_strategy,
ir: Optional[_IR] = None,
sample_rate: Optional[float] = None,
**kwargs,
):
super().__init__()
super().__init__(sample_rate=sample_rate)
self._net = _conv_net(*args, **kwargs)
assert train_strategy == TrainStrategy.DILATE, "Stride no longer supported"
self._train_strategy = train_strategy
Expand Down
4 changes: 2 additions & 2 deletions nam/models/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@


class Linear(BaseNet):
def __init__(self, receptive_field: int, bias: bool = False):
super().__init__()
def __init__(self, receptive_field: int, *args, bias: bool = False, **kwargs):
super().__init__(*args, **kwargs)
self._net = nn.Conv1d(1, 1, receptive_field, bias=bias)

@property
Expand Down
3 changes: 2 additions & 1 deletion nam/models/recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def __init__(
train_burn_in: Optional[int] = None,
train_truncate: Optional[int] = None,
input_size: int = 1,
sample_rate: Optional[float] = None,
**lstm_kwargs,
):
"""
Expand All @@ -144,7 +145,7 @@ def __init__(
:param input_size: Usually 1 (mono input). A catnet extending this might change
it and provide the parametric inputs as additional input dimensions.
"""
super().__init__()
super().__init__(sample_rate=sample_rate)
if "batch_first" in lstm_kwargs:
raise ValueError("batch_first cannot be set.")
self._input_size = input_size
Expand Down
4 changes: 2 additions & 2 deletions nam/models/wavenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,8 +339,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


class WaveNet(BaseNet):
def __init__(self, *args, **kwargs):
super().__init__()
def __init__(self, *args, sample_rate: Optional[float] = None, **kwargs):
super().__init__(sample_rate=sample_rate)
self._net = _WaveNet(*args, **kwargs)

@property
Expand Down
12 changes: 12 additions & 0 deletions nam/train/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -873,10 +873,21 @@ def train(
)

print("Starting training. It's time to kick ass and chew bubblegum!")
# Issue:
# * Model needs sample rate from data, but data set needs nx from model.
# * Model is re-instantiated after training anyways.
# (Hacky) solution: set sample rate in model from dataloader after second
# instantiation from final checkpoint.
model = Model.init_from_config(model_config)
train_dataloader, val_dataloader = _get_dataloaders(
data_config, learning_config, model
)
if train_dataloader.dataset.sample_rate != val_dataloader.dataset.sample_rate:
raise RuntimeError(
"Train and validation data loaders have different data set sample rates: "
f"{train_dataloader.dataset.sample_rate}, "
f"{val_dataloader.dataset.sample_rate}"
)

trainer = pl.Trainer(
callbacks=[
Expand Down Expand Up @@ -904,6 +915,7 @@ def train(
)
model.cpu()
model.eval()
model.net.sample_rate = train_dataloader.dataset.sample_rate

def window_kwargs(version: Version):
if version.major == 1:
Expand Down
15 changes: 12 additions & 3 deletions tests/test_nam/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,14 @@ def test_init(self):
x, y = self._create_xy()
data.Dataset(x, y, 3, None)

def test_init_sample_rate(self):
    """
    A float `sample_rate` given at construction is exposed, unchanged and still a
    float, through the data set's `sample_rate` attribute.
    """
    expected = 48_000.0
    x, y = self._create_xy()
    dataset = data.Dataset(x, y, 3, None, sample_rate=expected)
    assert hasattr(dataset, "sample_rate")
    actual = dataset.sample_rate
    assert isinstance(actual, float)
    assert actual == expected

def test_init_zero_delay(self):
"""
Assert https://github.com/sdatkinson/neural-amp-modeler/issues/15 fixed
Expand Down Expand Up @@ -285,6 +293,7 @@ def test_np_to_wav_to_np_scale_arg(self, tmpdir):
# Check if the two arrays are equal
assert y == pytest.approx(x, abs=self.tolerance)


def test_audio_mismatch_shapes_in_order():
"""
https://github.com/sdatkinson/neural-amp-modeler/issues/257
Expand All @@ -293,12 +302,12 @@ def test_audio_mismatch_shapes_in_order():
num_channels = 1

x, y = [np.zeros((n, num_channels)) for n in (x_samples, y_samples)]

with TemporaryDirectory() as tmpdir:
y_path = Path(tmpdir, "y.wav")
data.np_to_wav(y, y_path)
f = lambda: data.wav_to_np(y_path, required_shape=x.shape)

with pytest.raises(data.AudioShapeMismatchError) as e:
f()

Expand All @@ -309,7 +318,7 @@ def test_audio_mismatch_shapes_in_order():
# x is loaded first; we expect that y matches.
assert e.shape_expected == (x_samples, num_channels)
assert e.shape_actual == (y_samples, num_channels)


# Run pytest when this file is executed directly as a script.
if __name__ == "__main__":
    pytest.main()