From 88bca7332956d761bfe33e3e0e155f63da743647 Mon Sep 17 00:00:00 2001 From: laurent Date: Thu, 5 Sep 2024 10:54:23 +0200 Subject: [PATCH] Support writing stereo wav files. --- .gitignore | 2 +- Cargo.toml | 2 +- src/lib.rs | 35 ++++++++++++++++++++++++++++++++--- src/wav.rs | 27 +++++++++++++++++++++++---- test/basic.py | 4 +++- 5 files changed, 60 insertions(+), 10 deletions(-) diff --git a/.gitignore b/.gitignore index e0325a4..a7a0633 100644 --- a/.gitignore +++ b/.gitignore @@ -16,5 +16,5 @@ Cargo.lock __pycache__ bria.mp3 -bria.wav +bria*.wav bria.opus diff --git a/Cargo.toml b/Cargo.toml index 1ad99e1..f4f0993 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "sphn" -version = "0.1.2" +version = "0.1.3" edition = "2021" license = "MIT/Apache-2.0" description = "pyo3 wrappers to read/write audio files" diff --git a/src/lib.rs b/src/lib.rs index fc70250..2e6aa01 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -172,14 +172,43 @@ fn read( #[pyo3(signature = (filename, data, sample_rate))] fn write_wav( filename: std::path::PathBuf, - data: numpy::PyReadonlyArray1, + data: numpy::PyReadonlyArrayDyn, sample_rate: u32, ) -> PyResult<()> { let w = std::fs::File::create(&filename).w_f(&filename)?; let mut w = std::io::BufWriter::new(w); let data = data.as_array(); - let data = to_cow(&data); - wav::write(&mut w, &data, sample_rate).w_f(&filename)?; + match data.ndim() { + 1 => { + let data = data.into_dimensionality::().w()?; + let data = to_cow(&data); + wav::write_mono(&mut w, &data, sample_rate).w_f(&filename)?; + } + 2 => { + let data = data.into_dimensionality::().w()?; + match data.shape() { + [1, l] => { + let data = data.into_shape((*l,)).w()?; + let data = to_cow(&data); + wav::write_mono(&mut w, &data, sample_rate).w_f(&filename)?; + } + [2, l] => { + let data = data.into_shape((2 * *l,)).w()?; + let data = to_cow(&data); + let (pcm1, pcm2) = (&data[..*l], &data[*l..]); + let data = pcm1 + .iter() + .zip(pcm2.iter()) + .flat_map(|(s1, s2)| [*s1, *s2]) + .collect::>(); + println!("{:?}", &data[..20]); + wav::write_stereo(&mut w, &data, sample_rate).w_f(&filename)? + } + _ => py_bail!("expected one or two channels, got shape {:?}", data.shape()), + } + } + _ => py_bail!("expected one or two dimensions, got shape {:?}", data.shape()), + } Ok(()) } diff --git a/src/wav.rs b/src/wav.rs index c1cb272..8a5cf79 100644 --- a/src/wav.rs +++ b/src/wav.rs @@ -22,15 +22,18 @@ impl Sample for i16 { } } -pub fn write( +/// The samples are copied as is in the resulting wav files so are assumed to be interleaved by +/// channel. +pub fn write_multi( w: &mut W, samples: &[S], + n_channels: u16, sample_rate: u32, ) -> std::io::Result<()> { + // https://en.wikipedia.org/wiki/WAV#WAV_file_header let len = 12u32; // header let len = len + 24u32; // fmt let len = len + samples.len() as u32 * 2 + 8; // data - let n_channels = 1u16; let bytes_per_second = sample_rate * 2 * n_channels as u32; w.write_all(b"RIFF")?; w.write_all(&(len - 8).to_le_bytes())?; // total length minus 8 bytes @@ -40,10 +43,10 @@ pub fn write( w.write_all(b"fmt ")?; w.write_all(&16u32.to_le_bytes())?; // block len minus 8 bytes w.write_all(&1u16.to_le_bytes())?; // PCM - w.write_all(&n_channels.to_le_bytes())?; // one channel + w.write_all(&n_channels.to_le_bytes())?; w.write_all(&sample_rate.to_le_bytes())?; w.write_all(&bytes_per_second.to_le_bytes())?; - w.write_all(&2u16.to_le_bytes())?; // 2 bytes of data per sample + w.write_all(&(n_channels * 2).to_le_bytes())?; // 2 bytes of data per sample and channel w.write_all(&16u16.to_le_bytes())?; // bits per sample // Data block @@ -54,3 +57,19 @@ pub fn write( } Ok(()) } + +pub fn write_mono( + w: &mut W, + samples: &[S], + sample_rate: u32, +) -> std::io::Result<()> { + write_multi(w, samples, 1, sample_rate) +} + +pub fn write_stereo( + w: &mut W, + samples: &[S], + sample_rate: u32, +) -> std::io::Result<()> { + write_multi(w, samples, 2, sample_rate) +} diff --git a/test/basic.py b/test/basic.py index 4d1ebc6..58d5a6c 100644 --- a/test/basic.py +++ b/test/basic.py @@ -1,3 +1,4 @@ +import numpy as np import sphn filename = "bria.mp3" @@ -13,7 +14,8 @@ data, sr = sphn.read(filename) print(data.shape, sr) -sphn.write_wav("bria.wav", data[0], sr) +sphn.write_wav("bria_mono.wav", data[0], sr) +sphn.write_wav("bria_stereo.wav", np.concatenate([data, data]), sr) sphn.write_opus("bria.opus", data, sr) data_roundtrip, sr_roundtrip = sphn.read_opus("bria.opus")