Skip to content

Commit

Permalink
Merge pull request #8 from kyutai-labs/pyo3-0.23
Browse files Browse the repository at this point in the history
Update to pyo3 0.23.
  • Loading branch information
LaurentMazare authored Dec 20, 2024
2 parents 40a02cb + 9706666 commit e7d309d
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 31 deletions.
6 changes: 3 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "sphn"
version = "0.1.4"
version = "0.1.5"
edition = "2021"
license = "MIT/Apache-2.0"
description = "pyo3 wrappers to read/write audio files"
Expand All @@ -15,10 +15,10 @@ crate-type = ["cdylib"]
[dependencies]
anyhow = "1.0.79"
byteorder = "1.5.0"
numpy = "0.21.0"
numpy = "0.23.0"
ogg = "0.9.1"
opus = "0.3.0"
pyo3 = "0.21.0"
pyo3 = "0.23.0"
rayon = "1.8.1"
rubato = "0.15.0"
serde = { version = "1.0", features = ["derive"] }
Expand Down
58 changes: 30 additions & 28 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ mod opus;
mod wav;

use pyo3::prelude::*;
use std::sync::Mutex;

trait PyRes<R> {
#[allow(unused)]
Expand Down Expand Up @@ -80,30 +81,30 @@ impl FileReader {
fn decode(&mut self, start_sec: f64, duration_sec: f64, py: Python) -> PyResult<PyObject> {
let (data, _unpadded_len) =
self.inner.decode(start_sec, duration_sec, false).w_f(&self.path)?;
Ok(numpy::PyArray2::from_vec2_bound(py, &data)?.into_py(py))
Ok(numpy::PyArray2::from_vec2(py, &data)?.into_any().unbind())
}

/// Decodes the audio data from `start_sec` to `start_sec + duration_sec` and returns the PCM
/// data as a two dimensional numpy array. The first dimension is the channel, the second one
/// is time.
/// If the end of the file is reached, the array is padded with zeros so that its length is
/// still matching `duration_sec`.
fn decode_with_padding(
fn decode_with_padding<'a>(
&mut self,
start_sec: f64,
duration_sec: f64,
py: Python,
) -> PyResult<(PyObject, usize)> {
py: Python<'a>,
) -> PyResult<(Bound<'a, PyAny>, usize)> {
let (data, unpadded_len) =
self.inner.decode(start_sec, duration_sec, true).w_f(&self.path)?;
let data = numpy::PyArray2::from_vec2_bound(py, &data)?.into_py(py);
let data = numpy::PyArray2::from_vec2(py, &data)?.into_any();
Ok((data, unpadded_len))
}

/// Decodes the audio data for the whole file and returns it as a two dimensional numpy array.
fn decode_all(&mut self, py: Python) -> PyResult<PyObject> {
fn decode_all<'a>(&mut self, py: Python<'a>) -> PyResult<Bound<'a, PyAny>> {
let data = self.inner.decode_all().w_f(&self.path)?;
Ok(numpy::PyArray2::from_vec2_bound(py, &data)?.into_py(py))
Ok(numpy::PyArray2::from_vec2(py, &data)?.into_any())
}
}

Expand Down Expand Up @@ -158,7 +159,7 @@ fn read(
}
};
let data = Python::with_gil(|py| {
Ok::<_, PyErr>(numpy::PyArray2::from_vec2_bound(py, &data)?.into_py(py))
Ok::<_, PyErr>(numpy::PyArray2::from_vec2(py, &data)?.into_any().unbind())
})
.w_f(&filename)?;
Ok((data, sample_rate))
Expand Down Expand Up @@ -188,12 +189,12 @@ fn write_wav(
let data = data.into_dimensionality::<numpy::Ix2>().w()?;
match data.shape() {
[1, l] => {
let data = data.into_shape((*l,)).w()?;
let data = data.into_shape_with_order((*l,)).w()?;
let data = to_cow(&data);
wav::write_mono(&mut w, &data, sample_rate).w_f(&filename)?;
}
[2, l] => {
let data = data.into_shape((2 * *l,)).w()?;
let data = data.into_shape_with_order((2 * *l,)).w()?;
let data = to_cow(&data);
let (pcm1, pcm2) = (&data[..*l], &data[*l..]);
let data = pcm1
Expand Down Expand Up @@ -241,11 +242,11 @@ fn write_opus(
let data = data.into_dimensionality::<numpy::Ix2>().w()?;
match data.shape() {
[1, l] => {
let data = data.into_shape((*l,)).w()?;
let data = data.into_shape_with_order((*l,)).w()?;
write_mono(w, data)?
}
[2, l] => {
let data = data.into_shape((*l * 2,)).w()?;
let data = data.into_shape_with_order((*l * 2,)).w()?;
let data = to_cow(&data);
let (pcm1, pcm2) = (&data[..*l], &data[*l..]);
opus::write_ogg_stereo(&mut w, pcm1, pcm2, sample_rate).w_f(&filename)?
Expand Down Expand Up @@ -285,20 +286,20 @@ fn resample(
let pcm = to_cow(&pcm);
let pcm = audio::resample(&pcm[..], src_sample_rate, dst_sample_rate).w()?;
Python::with_gil(|py| {
Ok::<_, PyErr>(numpy::PyArray1::from_vec_bound(py, pcm).into_py(py))
Ok::<_, PyErr>(numpy::PyArray1::from_vec(py, pcm).into_any().unbind())
})
}
2 => {
let pcm = pcm.into_dimensionality::<numpy::Ix2>().w()?;
let (channels, l) = pcm.dim();
let pcm = pcm.into_shape((channels * l,)).w()?;
let pcm = pcm.into_shape_with_order((channels * l,)).w()?;
let pcm = to_cow(&pcm)
.chunks(l)
.map(|pcm| audio::resample(pcm, src_sample_rate, dst_sample_rate))
.collect::<anyhow::Result<Vec<_>>>()
.w()?;
Python::with_gil(|py| {
Ok::<_, PyErr>(numpy::PyArray2::from_vec2_bound(py, &pcm)?.into_py(py))
Ok::<_, PyErr>(numpy::PyArray2::from_vec2(py, &pcm)?.into_any().unbind())
})
}
_ => py_bail!("expected one or two dimensions, got shape {:?}", pcm.shape()),
Expand All @@ -315,7 +316,7 @@ fn read_opus(filename: std::path::PathBuf, py: Python) -> PyResult<(PyObject, u3
let file = std::fs::File::open(&filename)?;
let file = std::io::BufReader::new(file);
let (data, sample_rate) = opus::read_ogg(file).w_f(&filename)?;
let data = numpy::PyArray2::from_vec2_bound(py, &data)?.into_py(py);
let data = numpy::PyArray2::from_vec2(py, &data)?.into_any().unbind();
Ok((data, sample_rate))
}

Expand All @@ -328,13 +329,13 @@ fn read_opus(filename: std::path::PathBuf, py: Python) -> PyResult<(PyObject, u3
fn read_opus_bytes(bytes: Vec<u8>, py: Python) -> PyResult<(PyObject, u32)> {
let bytes = std::io::Cursor::new(bytes);
let (data, sample_rate) = opus::read_ogg(bytes).w()?;
let data = numpy::PyArray2::from_vec2_bound(py, &data)?.into_py(py);
let data = numpy::PyArray2::from_vec2(py, &data)?.into_any().unbind();
Ok((data, sample_rate))
}

#[pyclass]
struct OpusStreamWriter {
inner: opus::StreamWriter,
inner: Mutex<opus::StreamWriter>,
sample_rate: u32,
}

Expand All @@ -343,7 +344,7 @@ impl OpusStreamWriter {
#[new]
fn new(sample_rate: u32) -> PyResult<Self> {
let inner = opus::StreamWriter::new(sample_rate).w()?;
Ok(Self { inner, sample_rate })
Ok(Self { inner: Mutex::new(inner), sample_rate })
}

fn __str__(&self) -> String {
Expand All @@ -355,22 +356,23 @@ impl OpusStreamWriter {
fn append_pcm(&mut self, pcm: numpy::PyReadonlyArray1<f32>) -> PyResult<()> {
let pcm = pcm.as_array();
let pcm = to_cow(&pcm);
self.inner.append_pcm(&pcm).w()?;
self.inner.lock().unwrap().append_pcm(&pcm).w()?;
Ok(())
}

/// Gets the pending opus bytes from the stream. An empty bytes object is returned if no data
/// is currently available.
fn read_bytes(&mut self) -> PyResult<PyObject> {
let bytes = self.inner.read_bytes().w()?;
let bytes = Python::with_gil(|py| pyo3::types::PyBytes::new_bound(py, &bytes).into_py(py));
let bytes = self.inner.lock().unwrap().read_bytes().w()?;
let bytes =
Python::with_gil(|py| pyo3::types::PyBytes::new(py, &bytes).into_any().unbind());
Ok(bytes)
}
}

#[pyclass]
struct OpusStreamReader {
inner: opus::StreamReader,
inner: Mutex<opus::StreamReader>,
sample_rate: u32,
}

Expand All @@ -379,7 +381,7 @@ impl OpusStreamReader {
#[new]
fn new(sample_rate: u32) -> PyResult<Self> {
let inner = opus::StreamReader::new(sample_rate).w()?;
Ok(Self { inner, sample_rate })
Ok(Self { inner: Mutex::new(inner), sample_rate })
}

fn __str__(&self) -> String {
Expand All @@ -388,18 +390,18 @@ impl OpusStreamReader {

/// Writes some ogg/opus bytes to the current stream.
fn append_bytes(&mut self, data: &[u8]) -> PyResult<()> {
self.inner.append(data.to_vec()).w()
self.inner.lock().unwrap().append(data.to_vec()).w()
}

// TODO(laurent): maybe we should also have a pyo3_async api here.
/// Gets the pcm data decoded by the stream. This returns a 1d numpy array, or None if the
/// stream has been closed. The array is empty if no data is currently available.
fn read_pcm(&mut self) -> PyResult<PyObject> {
let pcm_data = self.inner.read_pcm().w()?;
let pcm_data = self.inner.lock().unwrap().read_pcm().w()?;
Python::with_gil(|py| match pcm_data {
None => Ok(py.None()),
Some(data) => {
let data = numpy::PyArray1::from_vec_bound(py, data.to_vec()).into_py(py);
let data = numpy::PyArray1::from_vec(py, data.to_vec()).into_any().unbind();
Ok(data)
}
})
Expand All @@ -408,7 +410,7 @@ impl OpusStreamReader {
/// Closes the stream. This results in the worker thread exiting, and follow-up calls to
/// `read_pcm` will return None once all the pcm data has been returned.
fn close(&mut self) {
self.inner.close()
self.inner.lock().unwrap().close()
}
}

Expand Down

0 comments on commit e7d309d

Please sign in to comment.