diff --git a/examples/encoding/audio_encoding.py b/examples/encoding/audio_encoding.py index 10a08ffd..8bcc1e30 100644 --- a/examples/encoding/audio_encoding.py +++ b/examples/encoding/audio_encoding.py @@ -78,13 +78,15 @@ def make_sinewave() -> tuple[torch.Tensor, int]: # %% # The encoder supports some encoding options that allow you to change how to # data is encoded. For example, we can decide to encode our mono data (1 -# channel) into stereo data (2 channels): -encoded_samples = encoder.to_tensor(format="wav", num_channels=2) +# channel) into stereo data (2 channels), and to specify an output sample rate: + +desired_sample_rate = 32000 +encoded_samples = encoder.to_tensor(format="wav", num_channels=2, sample_rate=desired_sample_rate) stereo_samples_back = AudioDecoder(encoded_samples).get_all_samples() print(stereo_samples_back) -play_audio(stereo_samples_back.data, rate=stereo_samples_back.sample_rate) +play_audio(stereo_samples_back.data, rate=desired_sample_rate) # %% # Check the docstring of the encoding methods to learn about the different diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index 90579432..9f370c6f 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -105,7 +105,7 @@ AudioEncoder::AudioEncoder( int sampleRate, std::string_view fileName, const AudioStreamOptions& audioStreamOptions) - : samples_(validateSamples(samples)) { + : samples_(validateSamples(samples)), inSampleRate_(sampleRate) { setFFmpegLogLevel(); AVFormatContext* avFormatContext = nullptr; int status = avformat_alloc_output_context2( @@ -128,7 +128,7 @@ AudioEncoder::AudioEncoder( ", make sure it's a valid path? ", getFFMPEGErrorStringFromErrorCode(status)); - initializeEncoder(sampleRate, audioStreamOptions); + initializeEncoder(audioStreamOptions); } AudioEncoder::AudioEncoder( @@ -138,6 +138,7 @@ AudioEncoder::AudioEncoder( std::unique_ptr avioContextHolder, const AudioStreamOptions& audioStreamOptions) : samples_(validateSamples(samples)), + inSampleRate_(sampleRate), avioContextHolder_(std::move(avioContextHolder)) { setFFmpegLogLevel(); AVFormatContext* avFormatContext = nullptr; @@ -155,11 +156,10 @@ AudioEncoder::AudioEncoder( avFormatContext_->pb = avioContextHolder_->getAVIOContext(); - initializeEncoder(sampleRate, audioStreamOptions); + initializeEncoder(audioStreamOptions); } void AudioEncoder::initializeEncoder( - int sampleRate, const AudioStreamOptions& audioStreamOptions) { // We use the AVFormatContext's default codec for that // specific format/container. @@ -187,8 +187,9 @@ void AudioEncoder::initializeEncoder( // not related to the input sampes. setDefaultChannelLayout(avCodecContext_, outNumChannels_); - validateSampleRate(*avCodec, sampleRate); - avCodecContext_->sample_rate = sampleRate; + outSampleRate_ = audioStreamOptions.sampleRate.value_or(inSampleRate_); + validateSampleRate(*avCodec, outSampleRate_); + avCodecContext_->sample_rate = outSampleRate_; // Input samples are expected to be FLTP. Not all encoders support FLTP, so we // may need to convert the samples into a supported output sample format, @@ -213,6 +214,21 @@ void AudioEncoder::initializeEncoder( "avcodec_parameters_from_context failed: ", getFFMPEGErrorStringFromErrorCode(status)); streamIndex_ = avStream->index; + + // If sample rate conversion is needed and the encoder doesn't support + // variable frame size, we need to create an intermediate FIFO. See + // [Encoding loop, sample rate conversion and FIFO]. + if (((avCodec->capabilities & AV_CODEC_CAP_VARIABLE_FRAME_SIZE) == 0) && + (inSampleRate_ != outSampleRate_)) { + // frame_size * 2 is a decent default size. FFmpeg automatically + // re-allocates the fifo if more space is needed. + auto avAudioFifo = av_audio_fifo_alloc( + avCodecContext_->sample_fmt, + outNumChannels_, + avCodecContext_->frame_size * 2); + TORCH_CHECK(avAudioFifo != nullptr, "Couldn't create AVAudioFifo."); + avAudioFifo_.reset(avAudioFifo); + } } torch::Tensor AudioEncoder::encodeToTensor() { @@ -230,24 +246,15 @@ void AudioEncoder::encode() { TORCH_CHECK(!encodeWasCalled_, "Cannot call encode() twice."); encodeWasCalled_ = true; - UniqueAVFrame avFrame(av_frame_alloc()); - TORCH_CHECK(avFrame != nullptr, "Couldn't allocate AVFrame."); // Default to 256 like in torchaudio int numSamplesAllocatedPerFrame = avCodecContext_->frame_size > 0 ? avCodecContext_->frame_size : 256; - avFrame->nb_samples = numSamplesAllocatedPerFrame; - avFrame->format = AV_SAMPLE_FMT_FLTP; - avFrame->sample_rate = avCodecContext_->sample_rate; + UniqueAVFrame avFrame = allocateAVFrame( + numSamplesAllocatedPerFrame, + inSampleRate_, + static_cast(samples_.sizes()[0]), + AV_SAMPLE_FMT_FLTP); avFrame->pts = 0; - // We set the channel layout of the frame to the default layout corresponding - // to the input samples' number of channels - setDefaultChannelLayout(avFrame, static_cast(samples_.sizes()[0])); - - auto status = av_frame_get_buffer(avFrame.get(), 0); - TORCH_CHECK( - status == AVSUCCESS, - "Couldn't allocate avFrame's buffers: ", - getFFMPEGErrorStringFromErrorCode(status)); AutoAVPacket autoAVPacket; @@ -257,19 +264,13 @@ void AudioEncoder::encode() { int numBytesPerSample = static_cast(samples_.element_size()); int numBytesPerChannel = numSamples * numBytesPerSample; - status = avformat_write_header(avFormatContext_.get(), nullptr); + auto status = avformat_write_header(avFormatContext_.get(), nullptr); TORCH_CHECK( status == AVSUCCESS, "Error in avformat_write_header: ", getFFMPEGErrorStringFromErrorCode(status)); while (numEncodedSamples < numSamples) { - status = av_frame_make_writable(avFrame.get()); - TORCH_CHECK( - status == AVSUCCESS, - "Couldn't make AVFrame writable: ", - getFFMPEGErrorStringFromErrorCode(status)); - int numSamplesToEncode = std::min(numSamplesAllocatedPerFrame, numSamples - numEncodedSamples); int numBytesToEncode = numSamplesToEncode * numBytesPerSample; @@ -290,10 +291,9 @@ void AudioEncoder::encode() { avFrame->nb_samples = numSamplesToEncode; UniqueAVFrame convertedAVFrame = maybeConvertAVFrame(avFrame); - encodeInnerLoop(autoAVPacket, convertedAVFrame); + encodeFrameThroughFifo(autoAVPacket, convertedAVFrame); numEncodedSamples += numSamplesToEncode; - avFrame->pts += static_cast(numSamplesToEncode); } TORCH_CHECK(numEncodedSamples == numSamples, "Hmmmmmm something went wrong."); @@ -309,7 +309,8 @@ void AudioEncoder::encode() { UniqueAVFrame AudioEncoder::maybeConvertAVFrame(const UniqueAVFrame& avFrame) { if (static_cast(avFrame->format) == avCodecContext_->sample_fmt && - getNumChannels(avFrame) == outNumChannels_) { + getNumChannels(avFrame) == outNumChannels_ && + avFrame->sample_rate == outSampleRate_) { // Note: the clone references the same underlying data, it's a cheap copy. return UniqueAVFrame(av_frame_clone(avFrame.get())); } @@ -318,8 +319,8 @@ UniqueAVFrame AudioEncoder::maybeConvertAVFrame(const UniqueAVFrame& avFrame) { swrContext_.reset(createSwrContext( static_cast(avFrame->format), avCodecContext_->sample_fmt, - avFrame->sample_rate, // No sample rate conversion avFrame->sample_rate, + outSampleRate_, avFrame, outNumChannels_)); } @@ -327,22 +328,75 @@ UniqueAVFrame AudioEncoder::maybeConvertAVFrame(const UniqueAVFrame& avFrame) { swrContext_, avFrame, avCodecContext_->sample_fmt, - avFrame->sample_rate, // No sample rate conversion + outSampleRate_, outNumChannels_); + + if (avFrame->sample_rate == outSampleRate_) { + TORCH_CHECK( + convertedAVFrame->nb_samples == avFrame->nb_samples, + "convertedAVFrame->nb_samples=", + convertedAVFrame->nb_samples, + " differs from ", + "avFrame->nb_samples=", + avFrame->nb_samples, + "This is unexpected, please report on the TorchCodec bug tracker."); + } + return convertedAVFrame; +} + +void AudioEncoder::encodeFrameThroughFifo( + AutoAVPacket& autoAVPacket, + const UniqueAVFrame& avFrame, + bool andFlushFifo) { + if (avAudioFifo_ == nullptr) { + encodeFrame(autoAVPacket, avFrame); + return; + } + int numSamplesWritten = av_audio_fifo_write( + avAudioFifo_.get(), + reinterpret_cast(avFrame->data), + avFrame->nb_samples); TORCH_CHECK( - convertedAVFrame->nb_samples == avFrame->nb_samples, - "convertedAVFrame->nb_samples=", - convertedAVFrame->nb_samples, - " differs from ", - "avFrame->nb_samples=", + numSamplesWritten == avFrame->nb_samples, + "Tried to write ", avFrame->nb_samples, - "This is unexpected, please report on the TorchCodec bug tracker."); - return convertedAVFrame; + " samples, but only wrote ", + numSamplesWritten); + + UniqueAVFrame newavFrame = allocateAVFrame( + avCodecContext_->frame_size, + outSampleRate_, + outNumChannels_, + avCodecContext_->sample_fmt); + + while (av_audio_fifo_size(avAudioFifo_.get()) >= + (andFlushFifo ? 1 : avCodecContext_->frame_size)) { + int samplesToRead = std::min( + av_audio_fifo_size(avAudioFifo_.get()), newavFrame->nb_samples); + int numSamplesRead = av_audio_fifo_read( + avAudioFifo_.get(), + reinterpret_cast(newavFrame->data), + samplesToRead); + TORCH_CHECK( + numSamplesRead == samplesToRead, + "Tried to read ", + samplesToRead, + " samples, but only read ", + numSamplesRead); + + newavFrame->nb_samples = numSamplesRead; + encodeFrame(autoAVPacket, newavFrame); + } } -void AudioEncoder::encodeInnerLoop( +void AudioEncoder::encodeFrame( AutoAVPacket& autoAVPacket, const UniqueAVFrame& avFrame) { + if (avFrame != nullptr) { + avFrame->pts = lastEncodedAVFramePts_; + lastEncodedAVFramePts_ += avFrame->nb_samples; + } + auto status = avcodec_send_frame(avCodecContext_.get(), avFrame.get()); TORCH_CHECK( status == AVSUCCESS, @@ -381,11 +435,39 @@ void AudioEncoder::encodeInnerLoop( } } +void AudioEncoder::maybeFlushSwrBuffers(AutoAVPacket& autoAVPacket) { + // Similar to the decoder's method with the same name, but for encoding this + // time. That is, when sample conversion is involved, libswresample may have + // buffered some samples that we now need to flush and send to the encoder. + if (swrContext_ == nullptr && inSampleRate_ == outSampleRate_) { + return; + } + TORCH_CHECK( + swrContext_ != nullptr, + "swrContext is null, but sample rate conversion is needed. ", + "This is unexpected, please report on the TorchCodec bug tracker."); + + int numRemainingSamples = // this is an upper bound + swr_get_out_samples(swrContext_.get(), 0); + if (numRemainingSamples == 0) { + return; + } + + UniqueAVFrame avFrame = allocateAVFrame( + numRemainingSamples, + outSampleRate_, + outNumChannels_, + avCodecContext_->sample_fmt); + int actualNumRemainingSamples = swr_convert( + swrContext_.get(), avFrame->data, avFrame->nb_samples, NULL, 0); + avFrame->nb_samples = actualNumRemainingSamples; + + encodeFrameThroughFifo(autoAVPacket, avFrame, /*andFlushFifo=*/true); +} + void AudioEncoder::flushBuffers() { - // We flush the main FFmpeg buffers, but not swresample buffers. Flushing - // swresample is only necessary when converting sample rates, which we don't - // do for encoding. AutoAVPacket autoAVPacket; - encodeInnerLoop(autoAVPacket, UniqueAVFrame(nullptr)); + maybeFlushSwrBuffers(autoAVPacket); + encodeFrame(autoAVPacket, UniqueAVFrame(nullptr)); } } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/Encoder.h b/src/torchcodec/_core/Encoder.h index ea500901..e15d5ac3 100644 --- a/src/torchcodec/_core/Encoder.h +++ b/src/torchcodec/_core/Encoder.h @@ -11,12 +11,6 @@ class AudioEncoder { AudioEncoder( const torch::Tensor& samples, - // TODO-ENCODING: update this comment when we support an output sample - // rate. This will become the input sample rate. - // The *output* sample rate. We can't really decide for the user what it - // should be. Particularly, the sample rate of the input samples should - // match this, and that's up to the user. If sample rates don't match, - // encoding will still work but audio will be distorted. int sampleRate, std::string_view fileName, const AudioStreamOptions& audioStreamOptions); @@ -30,13 +24,14 @@ class AudioEncoder { torch::Tensor encodeToTensor(); private: - void initializeEncoder( - int sampleRate, - const AudioStreamOptions& audioStreamOptions); + void initializeEncoder(const AudioStreamOptions& audioStreamOptions); UniqueAVFrame maybeConvertAVFrame(const UniqueAVFrame& avFrame); - void encodeInnerLoop( + void encodeFrameThroughFifo( AutoAVPacket& autoAVPacket, - const UniqueAVFrame& srcAVFrame); + const UniqueAVFrame& avFrame, + bool andFlushFifo = false); + void encodeFrame(AutoAVPacket& autoAVPacket, const UniqueAVFrame& avFrame); + void maybeFlushSwrBuffers(AutoAVPacket& autoAVPacket); void flushBuffers(); UniqueEncodingAVFormatContext avFormatContext_; @@ -45,13 +40,72 @@ class AudioEncoder { UniqueSwrContext swrContext_; AudioStreamOptions audioStreamOptions; + const torch::Tensor samples_; + int outNumChannels_ = -1; + int outSampleRate_ = -1; + int inSampleRate_ = -1; - const torch::Tensor samples_; + UniqueAVAudioFifo avAudioFifo_; // Stores the AVIOContext for the output tensor buffer. std::unique_ptr avioContextHolder_; bool encodeWasCalled_ = false; + int64_t lastEncodedAVFramePts_ = 0; }; } // namespace facebook::torchcodec + +/* clang-format off */ +// +// Note: [Encoding loop, sample rate conversion and FIFO] +// +// The input samples are in a given format, sample rate, and number of channels. +// We may want to change these properties before encoding. The conversion is +// done in maybeConvertAVFrame() and we rely on libswresample. When sample rate +// conversion is needed, this means two things: +// - swr will be storing samples in its internal buffers, which we'll need to +// flush at the very end of the encoding process. +// - the converted AVFrame we get back from maybeConvertAVFrame() typically +// won't have the same number of samples as the original AVFrame. And that's +// a problem, because some encoders expect AVFrames with a specific and +// constant number of samples. If we were to send it as-is, we'd get an error +// in avcodec_send_frame(). In order to feed the encoder with AVFrames +// with the expected number of samples, we go through an intermediate FIFO +// from which we can pull the exact number of samples that we need. Note that +// this involves at least 2 additional copies. +// +// To be clear, the FIFO is only used if BOTH the following conditions are met: +// - sample rate conversion is needed (inSampleRate_ != outSampleRate_) +// - the encoder expects a specific number of samples per AVFrame (fixed frame size) +// This is not the case for all encoders, e.g. WAV doesn't care about frame size. +// +// ┌─One─iteration─of─main─encoding─loop─(encode())───────────────────────────────────────────┐ +// │ │ +// │ Converts: │ +// │ - num channels │ +// │ - format │ +// │ - sample rate │ +// │ If sample rate, │ +// │ stores data in │ +// │ swr buffers │ +// │ which will need │ +// │ to be flushed │ +// │ │ +// │ ▲ │ +// │ │ ┌─EncodeFrameThroughFifo()──────────────┐│ +// │ │ │ ││ +// │ AVFrame ──────► MaybeConvertAVFrame()───▲──│─┬──────────────┬──▲────►encodeFrame() ││ +// │ with │ │ │ │ │ ││ +// │ input │ │ │ │ │ ││ +// │ samples │ │ │ │ │ ││ +// │ │ │ │ │ │ ││ +// │ │ │ └────► FIFO ───┘ │ ││ +// │ │ └───────────────────┼───────────────────┘│ +// └──────────────────────────────────────────────┼──────────────────────┼────────────────────┘ +// │ │ +// AVFrame from maybeFlushSwrBuffers() ───┘ │ +// Only if sample rate conversion was needed +// nullptr, to flush +// FFmpeg buffers +/* clang-format on */ diff --git a/src/torchcodec/_core/FFMPEGCommon.cpp b/src/torchcodec/_core/FFMPEGCommon.cpp index f47b4275..4f663f10 100644 --- a/src/torchcodec/_core/FFMPEGCommon.cpp +++ b/src/torchcodec/_core/FFMPEGCommon.cpp @@ -210,6 +210,33 @@ void setChannelLayout( #endif } +UniqueAVFrame allocateAVFrame( + int numSamples, + int sampleRate, + int numChannels, + AVSampleFormat sampleFormat) { + auto avFrame = UniqueAVFrame(av_frame_alloc()); + TORCH_CHECK(avFrame != nullptr, "Couldn't allocate AVFrame."); + + avFrame->nb_samples = numSamples; + avFrame->sample_rate = sampleRate; + setDefaultChannelLayout(avFrame, numChannels); + avFrame->format = sampleFormat; + auto status = av_frame_get_buffer(avFrame.get(), 0); + + TORCH_CHECK( + status == AVSUCCESS, + "Couldn't allocate avFrame's buffers: ", + getFFMPEGErrorStringFromErrorCode(status)); + + status = av_frame_make_writable(avFrame.get()); + TORCH_CHECK( + status == AVSUCCESS, + "Couldn't make AVFrame writable: ", + getFFMPEGErrorStringFromErrorCode(status)); + return avFrame; +} + SwrContext* createSwrContext( AVSampleFormat srcSampleFormat, AVSampleFormat outSampleFormat, diff --git a/src/torchcodec/_core/FFMPEGCommon.h b/src/torchcodec/_core/FFMPEGCommon.h index 07b7443e..38a0c099 100644 --- a/src/torchcodec/_core/FFMPEGCommon.h +++ b/src/torchcodec/_core/FFMPEGCommon.h @@ -15,6 +15,7 @@ extern "C" { #include #include #include +#include #include #include #include @@ -73,6 +74,8 @@ using UniqueSwsContext = std::unique_ptr>; using UniqueSwrContext = std::unique_ptr>; +using UniqueAVAudioFifo = std:: + unique_ptr>; // These 2 classes share the same underlying AVPacket object. They are meant to // be used in tandem, like so: @@ -160,6 +163,12 @@ void setChannelLayout( const UniqueAVFrame& srcAVFrame, int desiredNumChannels); +UniqueAVFrame allocateAVFrame( + int numSamples, + int sampleRate, + int numChannels, + AVSampleFormat sampleFormat); + SwrContext* createSwrContext( AVSampleFormat srcSampleFormat, AVSampleFormat desiredSampleFormat, diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index 4aa68a3b..b8e6a259 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -29,9 +29,9 @@ TORCH_LIBRARY(torchcodec_ns, m) { "torchcodec._core.ops", "//pytorch/torchcodec:torchcodec"); m.def("create_from_file(str filename, str? seek_mode=None) -> Tensor"); m.def( - "encode_audio_to_file(Tensor samples, int sample_rate, str filename, int? bit_rate=None, int? num_channels=None) -> ()"); + "encode_audio_to_file(Tensor samples, int sample_rate, str filename, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> ()"); m.def( - "encode_audio_to_tensor(Tensor samples, int sample_rate, str format, int? bit_rate=None, int? num_channels=None) -> Tensor"); + "encode_audio_to_tensor(Tensor samples, int sample_rate, str format, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> Tensor"); m.def( "create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor"); m.def("_convert_to_tensor(int decoder_ptr) -> Tensor"); @@ -392,12 +392,14 @@ void encode_audio_to_file( int64_t sample_rate, std::string_view file_name, std::optional bit_rate = std::nullopt, - std::optional num_channels = std::nullopt) { + std::optional num_channels = std::nullopt, + std::optional desired_sample_rate = std::nullopt) { // TODO Fix implicit int conversion: // https://github.com/pytorch/torchcodec/issues/679 AudioStreamOptions audioStreamOptions; audioStreamOptions.bitRate = bit_rate; audioStreamOptions.numChannels = num_channels; + audioStreamOptions.sampleRate = desired_sample_rate; AudioEncoder( samples, validateSampleRate(sample_rate), file_name, audioStreamOptions) .encode(); @@ -408,13 +410,15 @@ at::Tensor encode_audio_to_tensor( int64_t sample_rate, std::string_view format, std::optional bit_rate = std::nullopt, - std::optional num_channels = std::nullopt) { + std::optional num_channels = std::nullopt, + std::optional desired_sample_rate = std::nullopt) { auto avioContextHolder = std::make_unique(); // TODO Fix implicit int conversion: // https://github.com/pytorch/torchcodec/issues/679 AudioStreamOptions audioStreamOptions; audioStreamOptions.bitRate = bit_rate; audioStreamOptions.numChannels = num_channels; + audioStreamOptions.sampleRate = desired_sample_rate; return AudioEncoder( samples, validateSampleRate(sample_rate), diff --git a/src/torchcodec/_core/ops.py b/src/torchcodec/_core/ops.py index a68b51e2..3c9fad43 100644 --- a/src/torchcodec/_core/ops.py +++ b/src/torchcodec/_core/ops.py @@ -168,6 +168,7 @@ def encode_audio_to_file_abstract( filename: str, bit_rate: Optional[int] = None, num_channels: Optional[int] = None, + desired_sample_rate: Optional[int] = None, ) -> None: return @@ -179,6 +180,7 @@ def encode_audio_to_tensor_abstract( format: str, bit_rate: Optional[int] = None, num_channels: Optional[int] = None, + desired_sample_rate: Optional[int] = None, ) -> torch.Tensor: return torch.empty([], dtype=torch.long) diff --git a/src/torchcodec/encoders/_audio_encoder.py b/src/torchcodec/encoders/_audio_encoder.py index 742ea908..fb622cbd 100644 --- a/src/torchcodec/encoders/_audio_encoder.py +++ b/src/torchcodec/encoders/_audio_encoder.py @@ -15,7 +15,9 @@ class AudioEncoder: tensor of shape ``(num_channels, num_samples)``, or a 1D tensor in which case ``num_channels = 1`` is assumed. Values must be float values in ``[-1, 1]``. - sample_rate (int): The sample rate of the **input** ``samples``. + sample_rate (int): The sample rate of the **input** ``samples``. The + sample rate of the necoded output can be specified using the + encoding methods (``to_file``, etc.). """ def __init__(self, samples: Tensor, *, sample_rate: int): @@ -45,6 +47,7 @@ def to_file( *, bit_rate: Optional[int] = None, num_channels: Optional[int] = None, + sample_rate: Optional[int] = None, ) -> None: """Encode samples into a file. @@ -59,6 +62,8 @@ def to_file( num_channels (int, optional): The number of channels of the encoded output samples. By default, the number of channels of the input ``samples`` is used. + sample_rate (int, optional): The sample rate of the encoded output. + By default, the sample rate of the input ``samples`` is used. """ _core.encode_audio_to_file( samples=self._samples, @@ -66,6 +71,7 @@ def to_file( filename=dest, bit_rate=bit_rate, num_channels=num_channels, + desired_sample_rate=sample_rate, ) def to_tensor( @@ -74,6 +80,7 @@ def to_tensor( *, bit_rate: Optional[int] = None, num_channels: Optional[int] = None, + sample_rate: Optional[int] = None, ) -> Tensor: """Encode samples into raw bytes, as a 1D uint8 Tensor. @@ -87,6 +94,8 @@ def to_tensor( num_channels (int, optional): The number of channels of the encoded output samples. By default, the number of channels of the input ``samples`` is used. + sample_rate (int, optional): The sample rate of the encoded output. + By default, the sample rate of the input ``samples`` is used. Returns: Tensor: The raw encoded bytes as 1D uint8 Tensor. @@ -97,4 +106,5 @@ def to_tensor( format=format, bit_rate=bit_rate, num_channels=num_channels, + desired_sample_rate=sample_rate, ) diff --git a/test/test_encoders.py b/test/test_encoders.py index 284d053e..d432263a 100644 --- a/test/test_encoders.py +++ b/test/test_encoders.py @@ -2,6 +2,8 @@ import os import re import subprocess +import sys +from functools import partial from pathlib import Path import pytest @@ -11,6 +13,7 @@ from torchcodec.encoders import AudioEncoder from .utils import ( + assert_tensor_close_on_at_least, get_ffmpeg_major_version, in_fbcode, NASA_AUDIO_MP3, @@ -150,6 +153,10 @@ def test_bad_input_parametrized(self, method, tmp_path): decoder = AudioEncoder( self.decode(NASA_AUDIO_MP3).data, sample_rate=NASA_AUDIO_MP3.sample_rate ) + with pytest.raises(RuntimeError, match="invalid sample rate=10"): + getattr(decoder, method)(sample_rate=10, **valid_params) + with pytest.raises(RuntimeError, match="invalid sample rate=99999999"): + getattr(decoder, method)(sample_rate=99999999, **valid_params) with pytest.raises(RuntimeError, match="bit_rate=-1 must be >= 0"): getattr(decoder, method)(**valid_params, bit_rate=-1) @@ -203,6 +210,7 @@ def test_round_trip(self, method, format, tmp_path): @pytest.mark.parametrize("asset", (NASA_AUDIO_MP3, SINE_MONO_S32)) @pytest.mark.parametrize("bit_rate", (None, 0, 44_100, 999_999_999)) @pytest.mark.parametrize("num_channels", (None, 1, 2)) + @pytest.mark.parametrize("sample_rate", (8_000, 32_000)) @pytest.mark.parametrize("format", ("mp3", "wav", "flac")) @pytest.mark.parametrize("method", ("to_file", "to_tensor")) def test_against_cli( @@ -210,6 +218,7 @@ def test_against_cli( asset, bit_rate, num_channels, + sample_rate, format, method, tmp_path, @@ -227,6 +236,7 @@ def test_against_cli( ["ffmpeg", "-i", str(asset.path)] + (["-b:a", f"{bit_rate}"] if bit_rate is not None else []) + (["-ac", f"{num_channels}"] if num_channels is not None else []) + + ["-ar", f"{sample_rate}"] + [ str(encoded_by_ffmpeg), ], @@ -235,8 +245,9 @@ def test_against_cli( ) encoder = AudioEncoder(self.decode(asset).data, sample_rate=asset.sample_rate) - - params = dict(bit_rate=bit_rate, num_channels=num_channels) + params = dict( + bit_rate=bit_rate, num_channels=num_channels, sample_rate=sample_rate + ) if method == "to_file": encoded_by_us = tmp_path / f"output.{format}" encoder.to_file(dest=str(encoded_by_us), **params) @@ -253,7 +264,12 @@ def test_against_cli( if format in ("flac", "mp3"): assert "Application provided invalid" not in captured.err - if format == "wav": + assert_close = torch.testing.assert_close + if sample_rate != asset.sample_rate: + rtol, atol = 0, 1e-3 + if sys.platform == "darwin": + assert_close = partial(assert_tensor_close_on_at_least, percentage=99) + elif format == "wav": rtol, atol = 0, 1e-4 elif format == "mp3" and asset is SINE_MONO_S32 and num_channels == 2: # Not sure why, this one needs slightly higher tol. With default @@ -265,7 +281,7 @@ def test_against_cli( rtol, atol = None, None samples_by_us = self.decode(encoded_by_us) samples_by_ffmpeg = self.decode(encoded_by_ffmpeg) - torch.testing.assert_close( + assert_close( samples_by_us.data, samples_by_ffmpeg.data, rtol=rtol, diff --git a/test/utils.py b/test/utils.py index e7ce12e5..e3368e3f 100644 --- a/test/utils.py +++ b/test/utils.py @@ -53,7 +53,9 @@ def assert_frames_equal(*args, **kwargs): # Asserts that at least `percentage`% of the values are within the absolute tolerance. # Percentage is expected in [0, 100] (actually, [60, 100]) -def assert_tensor_close_on_at_least(actual_tensor, ref_tensor, *, percentage, atol): +def assert_tensor_close_on_at_least( + actual_tensor, ref_tensor, *, percentage, atol, **kwargs +): # In theory lower bound should be 0, but we want to make sure we don't # mistakenly pass percentage in [0, 1] assert 60 < percentage <= 100, (