diff --git a/amt/audio.py b/amt/audio.py index 599885c..8ce6f4c 100644 --- a/amt/audio.py +++ b/amt/audio.py @@ -23,6 +23,7 @@ FRAMES_PER_SECOND = SAMPLE_RATE // HOP_LENGTH # 10ms per audio frame TOKENS_PER_SECOND = SAMPLE_RATE // N_SAMPLES_PER_TOKEN # 20ms per audio token + def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1): """ Pad or trim the audio array to N_SAMPLES, as expected by the encoder. @@ -50,6 +51,7 @@ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1): return array + # Refactor default params are stored in config.json class AudioTransform(torch.nn.Module): def __init__(