- 
                Notifications
    You must be signed in to change notification settings 
- Fork 128
BitGenerator support #499
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
BitGenerator support #499
Changes from 25 commits
06d6ce1
              07e2416
              b611943
              d93a264
              f52b2fa
              05814d6
              37d360e
              eed5b19
              6c1a89b
              d1909d3
              bde2553
              a0b9ec5
              ee32246
              1be6838
              2aa3d90
              0258e6d
              876001b
              71ce8be
              2de7072
              016eb7a
              1f7f37f
              1d01c7a
              c90176a
              f49d3fa
              a16846d
              573d890
              06bb693
              663fa29
              c6105c9
              3a0aa92
              a92861a
              6dbb6dc
              b102d20
              e5e440e
              e73e3a2
              c6493df
              2327f36
              e5c6458
              e8cd5e8
              0868405
              8667203
              1fd7bb5
              3913171
              7bc0be8
              8caf054
              d8b62ac
              43e2d97
              File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,7 @@ | ||
| { | ||
| "[rust]": { | ||
| "editor.defaultFormatter": "rust-lang.rust-analyzer", | ||
| "editor.formatOnSave": true, | ||
| }, | ||
| "rust-analyzer.cargo.features": "all", | ||
| } | ||
|         
                  flying-sheep marked this conversation as resolved.
              Show resolved
            Hide resolved | ||
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,11 @@ | ||
| use std::ffi::c_void; | ||
|  | ||
| #[repr(C)] | ||
| #[derive(Debug, Clone, Copy)] // TODO: can it be Clone and/or Copy? | ||
| pub struct npy_bitgen { | ||
|         
                  flying-sheep marked this conversation as resolved.
              Outdated
          
            Show resolved
            Hide resolved | ||
| pub state: *mut c_void, | ||
| pub next_uint64: unsafe extern "C" fn(*mut c_void) -> super::npy_uint64, //nogil | ||
| pub next_uint32: unsafe extern "C" fn(*mut c_void) -> super::npy_uint32, //nogil | ||
| pub next_double: unsafe extern "C" fn(*mut c_void) -> libc::c_double, //nogil | ||
| pub next_raw: unsafe extern "C" fn(*mut c_void) -> super::npy_uint64, //nogil | ||
|         
                  flying-sheep marked this conversation as resolved.
              Show resolved
            Hide resolved | ||
| } | ||
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,264 @@ | ||
| //! Safe interface for NumPy's random [`BitGenerator`][bg]. | ||
| //! | ||
| //! Using the patterns described in [“Extending `numpy.random`”][ext], | ||
| //! you can generate random numbers without holding the GIL, | ||
| //! by [acquiring][`PyBitGeneratorMethods::lock`] a lock [guard][`PyBitGeneratorGuard`] for the [`PyBitGenerator`]: | ||
| //! | ||
| //! ``` | ||
| //! use pyo3::prelude::*; | ||
| //! use numpy::random::{PyBitGenerator, PyBitGeneratorMethods as _}; | ||
| //! | ||
| //! fn default_bit_gen<'py>(py: Python<'py>) -> PyResult<Bound<'py, PyBitGenerator>> { | ||
| //! let default_rng = py.import("numpy.random")?.call_method0("default_rng")?; | ||
| //! let bit_generator = default_rng.getattr("bit_generator")?.downcast_into()?; | ||
| //! Ok(bit_generator) | ||
| //! } | ||
| //! | ||
| //! let random_number = Python::with_gil(|py| -> PyResult<_> { | ||
| //! let mut bitgen = default_bit_gen(py)?.lock()?; | ||
| //! // use bitgen without holding the GIL | ||
| //! Ok(py.allow_threads(|| bitgen.next_uint64())) | ||
| //! })?; | ||
| //! # Ok::<(), PyErr>(()) | ||
| //! ``` | ||
| //! | ||
| //! With the [`rand`] crate installed, you can also use the [`rand::Rng`] APIs from the [`PyBitGeneratorGuard`]: | ||
| //! | ||
| //! ``` | ||
| //! # use pyo3::prelude::*; | ||
| //! use rand::Rng as _; | ||
| //! # use numpy::random::{PyBitGenerator, PyBitGeneratorMethods as _}; | ||
| //! # // TODO: reuse function definition from above? | ||
| There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It feels like there should be a convenient way to get this. I'm thinking about something like impl PyBitGenerator {
     fn new(py: Python<'_>) -> PyResult<Bound<..>>;
}There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. there are many implementations, we’d have to cover all of them. I’d rather leave this minimal until this PR is mostly done. | ||
| //! # fn default_bit_gen<'py>(py: Python<'py>) -> PyResult<Bound<'py, PyBitGenerator>> { | ||
| //! # let default_rng = py.import("numpy.random")?.call_method0("default_rng")?; | ||
| //! # let bit_generator = default_rng.getattr("bit_generator")?.downcast_into()?; | ||
| //! # Ok(bit_generator) | ||
| //! # } | ||
| //! | ||
| //! Python::with_gil(|py| -> PyResult<_> { | ||
| //! let mut bitgen = default_bit_gen(py)?.lock()?; | ||
| //! if bitgen.random_ratio(1, 1_000_000) { | ||
| //! println!("a sure thing"); | ||
| //! } | ||
| //! Ok(()) | ||
| //! })?; | ||
| //! # Ok::<(), PyErr>(()) | ||
| //! ``` | ||
| //! | ||
| //! [bg]: https://numpy.org/doc/stable//reference/random/bit_generators/generated/numpy.random.BitGenerator.html | ||
| //! [ext]: https://numpy.org/doc/stable/reference/random/extending.html | ||
| use std::ptr::NonNull; | ||
|  | ||
| use pyo3::{ | ||
| exceptions::PyRuntimeError, | ||
| ffi, | ||
| prelude::*, | ||
| sync::GILOnceCell, | ||
| types::{DerefToPyAny, PyCapsule, PyType}, | ||
| PyTypeInfo, | ||
| }; | ||
|  | ||
| use crate::npyffi::npy_bitgen; | ||
|  | ||
| /// Wrapper for [`np.random.BitGenerator`][bg]. | ||
| /// | ||
| /// See also [`PyBitGeneratorMethods`]. | ||
| /// | ||
| /// [bg]: https://numpy.org/doc/stable//reference/random/bit_generators/generated/numpy.random.BitGenerator.html | ||
| #[repr(transparent)] | ||
| pub struct PyBitGenerator(PyAny); | ||
|  | ||
| impl DerefToPyAny for PyBitGenerator {} | ||
|  | ||
| unsafe impl PyTypeInfo for PyBitGenerator { | ||
| const NAME: &'static str = "PyBitGenerator"; | ||
| const MODULE: Option<&'static str> = Some("numpy.random"); | ||
|  | ||
| fn type_object_raw<'py>(py: Python<'py>) -> *mut ffi::PyTypeObject { | ||
| const CLS: GILOnceCell<Py<PyType>> = GILOnceCell::new(); | ||
|         
                  flying-sheep marked this conversation as resolved.
              Outdated
          
            Show resolved
            Hide resolved | ||
| let cls = CLS | ||
| .get_or_try_init::<_, PyErr>(py, || { | ||
| Ok(py | ||
| .import("numpy.random")? | ||
| .getattr("BitGenerator")? | ||
| .downcast_into::<PyType>()? | ||
| .unbind()) | ||
| }) | ||
|         
                  flying-sheep marked this conversation as resolved.
              Outdated
          
            Show resolved
            Hide resolved | ||
| .expect("Failed to get BitGenerator type object") | ||
| .clone_ref(py) | ||
| .into_bound(py); | ||
| cls.as_type_ptr() | ||
| } | ||
| } | ||
|  | ||
| /// Methods for [`PyBitGenerator`]. | ||
| pub trait PyBitGeneratorMethods<'py> { | ||
|         
                  flying-sheep marked this conversation as resolved.
              Outdated
          
            Show resolved
            Hide resolved | ||
| /// Acquire a lock on the BitGenerator to allow calling its methods in. | ||
| fn lock(&self) -> PyResult<PyBitGeneratorGuard<'py>>; | ||
| } | ||
|  | ||
| impl<'py> PyBitGeneratorMethods<'py> for Bound<'py, PyBitGenerator> { | ||
| fn lock(&self) -> PyResult<PyBitGeneratorGuard<'py>> { | ||
| let capsule = self.getattr("capsule")?.downcast_into::<PyCapsule>()?; | ||
|         
                  flying-sheep marked this conversation as resolved.
              Outdated
          
            Show resolved
            Hide resolved | ||
| let lock = self.getattr("lock")?; | ||
| if lock.call_method0("locked")?.extract()? { | ||
| return Err(PyRuntimeError::new_err("BitGenerator is already locked")); | ||
| } | ||
| lock.call_method0("acquire")?; | ||
|  | ||
| assert_eq!(capsule.name()?, Some(c"BitGenerator")); | ||
|         
                  flying-sheep marked this conversation as resolved.
              Outdated
          
            Show resolved
            Hide resolved | ||
| let ptr = capsule.pointer() as *mut npy_bitgen; | ||
| let non_null = match NonNull::new(ptr) { | ||
| Some(non_null) => non_null, | ||
| None => { | ||
| lock.call_method0("release")?; | ||
| return Err(PyRuntimeError::new_err("Invalid BitGenerator capsule")); | ||
| } | ||
| }; | ||
| Ok(PyBitGeneratorGuard { | ||
| raw_bitgen: non_null, | ||
|         
                  flying-sheep marked this conversation as resolved.
              Outdated
          
            Show resolved
            Hide resolved | ||
| _capsule: capsule.unbind(), | ||
| lock: lock.unbind(), | ||
| py: self.py(), | ||
| }) | ||
| } | ||
| } | ||
|  | ||
| impl<'py> TryFrom<&Bound<'py, PyBitGenerator>> for PyBitGeneratorGuard<'py> { | ||
| type Error = PyErr; | ||
| fn try_from(value: &Bound<'py, PyBitGenerator>) -> Result<Self, Self::Error> { | ||
| value.lock() | ||
| } | ||
| } | ||
|  | ||
| /// [`PyBitGenerator`] lock allowing to access its methods without holding the GIL. | ||
|         
                  flying-sheep marked this conversation as resolved.
              Show resolved
            Hide resolved | ||
| pub struct PyBitGeneratorGuard<'py> { | ||
| raw_bitgen: NonNull<npy_bitgen>, | ||
| /// This field makes sure the `raw_bitgen` inside the capsule doesn’t get deallocated. | ||
| _capsule: Py<PyCapsule>, | ||
| /// This lock makes sure no other threads try to use the BitGenerator while we do. | ||
| lock: Py<PyAny>, | ||
| /// This should be an unsafe field (https://github.com/rust-lang/rust/issues/132922) | ||
| /// | ||
| /// SAFETY: only use this in `Drop::drop` (when we are sure the GIL is held). | ||
| py: Python<'py>, | ||
| } | ||
|  | ||
| // SAFETY: we can’t have public APIs that access the Python objects, | ||
| // only the `raw_bitgen` pointer. | ||
| unsafe impl Send for PyBitGeneratorGuard<'_> {} | ||
|         
                  flying-sheep marked this conversation as resolved.
              Outdated
          
            Show resolved
            Hide resolved | ||
|  | ||
| impl Drop for PyBitGeneratorGuard<'_> { | ||
| fn drop(&mut self) { | ||
| // ignore errors. This includes when `try_drop` was called manually | ||
| let _ = self.lock.bind(self.py).call_method0("release"); | ||
| } | ||
| } | ||
|  | ||
| // SAFETY: We hold the `BitGenerator.lock`, | ||
| // so nothing apart from us is allowed to change its state. | ||
| impl<'py> PyBitGeneratorGuard<'py> { | ||
| /// Drop the lock manually before `Drop::drop` tries to do it (used for testing). | ||
| #[allow(dead_code)] | ||
| fn try_drop(self, py: Python<'py>) -> PyResult<()> { | ||
| self.lock.bind(py).call_method0("release")?; | ||
| Ok(()) | ||
| } | ||
|  | ||
| /// Returns the next random unsigned 64 bit integer. | ||
| pub fn next_uint64(&mut self) -> u64 { | ||
|         
                  flying-sheep marked this conversation as resolved.
              Outdated
          
            Show resolved
            Hide resolved | ||
| unsafe { | ||
| let bitgen = *self.raw_bitgen.as_ptr(); | ||
| (bitgen.next_uint64)(bitgen.state) | ||
| } | ||
| } | ||
| /// Returns the next random unsigned 32 bit integer. | ||
| pub fn next_uint32(&mut self) -> u32 { | ||
| unsafe { | ||
| let bitgen = *self.raw_bitgen.as_ptr(); | ||
| (bitgen.next_uint32)(bitgen.state) | ||
| } | ||
| } | ||
| /// Returns the next random double. | ||
| pub fn next_double(&mut self) -> libc::c_double { | ||
|         
                  flying-sheep marked this conversation as resolved.
              Outdated
          
            Show resolved
            Hide resolved | ||
| unsafe { | ||
| let bitgen = *self.raw_bitgen.as_ptr(); | ||
| (bitgen.next_double)(bitgen.state) | ||
| } | ||
| } | ||
| /// Returns the next raw value (can be used for testing). | ||
| pub fn next_raw(&mut self) -> u64 { | ||
| unsafe { | ||
| let bitgen = *self.raw_bitgen.as_ptr(); | ||
| (bitgen.next_raw)(bitgen.state) | ||
| } | ||
| } | ||
| } | ||
|  | ||
| #[cfg(feature = "rand")] | ||
| impl rand::RngCore for PyBitGeneratorGuard<'_> { | ||
| fn next_u32(&mut self) -> u32 { | ||
| self.next_uint32() | ||
| } | ||
| fn next_u64(&mut self) -> u64 { | ||
| self.next_uint64() | ||
| } | ||
| fn fill_bytes(&mut self, dst: &mut [u8]) { | ||
| rand::rand_core::impls::fill_bytes_via_next(self, dst) | ||
| } | ||
| } | ||
|  | ||
| #[cfg(test)] | ||
| mod tests { | ||
| use super::*; | ||
|  | ||
| fn get_bit_generator<'py>(py: Python<'py>) -> PyResult<Bound<'py, PyBitGenerator>> { | ||
| let default_rng = py.import("numpy.random")?.call_method0("default_rng")?; | ||
| let bit_generator = default_rng | ||
| .getattr("bit_generator")? | ||
| .downcast_into::<PyBitGenerator>()?; | ||
| Ok(bit_generator) | ||
| } | ||
|  | ||
| /// Test the primary use case: acquire the lock, release the GIL, then use the lock | ||
| #[test] | ||
| fn use_outside_gil() -> PyResult<()> { | ||
| Python::with_gil(|py| { | ||
| let mut bitgen = get_bit_generator(py)?.lock()?; | ||
| py.allow_threads(|| { | ||
| let _ = bitgen.next_raw(); | ||
| }); | ||
| assert!(bitgen.try_drop(py).is_ok()); | ||
| Ok(()) | ||
| }) | ||
| } | ||
|  | ||
| /// Test that the `rand::Rng` APIs work | ||
| #[cfg(feature = "rand")] | ||
| #[test] | ||
| fn rand() -> PyResult<()> { | ||
| use rand::Rng as _; | ||
|  | ||
| Python::with_gil(|py| { | ||
| let mut bitgen = get_bit_generator(py)?.lock()?; | ||
| py.allow_threads(|| { | ||
| assert!(bitgen.random_ratio(1, 1)); | ||
| assert!(!bitgen.random_ratio(0, 1)); | ||
| }); | ||
| assert!(bitgen.try_drop(py).is_ok()); | ||
| Ok(()) | ||
| }) | ||
| } | ||
|  | ||
| #[test] | ||
| fn double_lock_fails() -> PyResult<()> { | ||
| Python::with_gil(|py| { | ||
| let generator = get_bit_generator(py)?; | ||
| let bitgen = generator.lock()?; | ||
| assert!(generator.lock().is_err()); | ||
| assert!(bitgen.try_drop(py).is_ok()); | ||
| Ok(()) | ||
| }) | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be removed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sure, will do when I’m done. I like working on multiple machines, and I don’t like re-doing settings for individual projects