From 62656b32db928a12052be1fe57f451e791d75a55 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Thu, 21 Nov 2024 13:08:50 -0500 Subject: [PATCH] Add an example demonstrating input-output aliasing with the FFI. --- .../ffi/src/jax_ffi_example/cpu_examples.cc | 30 +++++++++++++++++-- .../ffi/src/jax_ffi_example/cpu_examples.py | 6 ++++ examples/ffi/tests/cpu_examples_test.py | 13 +++++++- 3 files changed, 46 insertions(+), 3 deletions(-) diff --git a/examples/ffi/src/jax_ffi_example/cpu_examples.cc b/examples/ffi/src/jax_ffi_example/cpu_examples.cc index 3832c86b29b2..8d808ecd8e30 100644 --- a/examples/ffi/src/jax_ffi_example/cpu_examples.cc +++ b/examples/ffi/src/jax_ffi_example/cpu_examples.cc @@ -103,6 +103,33 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( Counter, CounterImpl, ffi::Ffi::Bind().Attr("index").Ret>()); +// -------- +// Aliasing +// -------- +// +// This example demonstrates how input-output aliasing works. The handler +// doesn't do anything except to check that the input and output pointers +// address the same data. + +ffi::Error AliasingImpl(ffi::AnyBuffer input, + ffi::Result output) { + if (input.element_type() != output->element_type() || + input.element_count() != output->element_count()) { + return ffi::Error::InvalidArgument( + "The input and output data types and sizes must match."); + } + if (input.untyped_data() != output->untyped_data()) { + return ffi::Error::InvalidArgument( + "When aliased, the input and output buffers should point to the same " + "data."); + } + return ffi::Error::Success(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + Aliasing, AliasingImpl, + ffi::Ffi::Bind().Arg().Ret()); + // Boilerplate for exposing handlers to Python NB_MODULE(_cpu_examples, m) { m.def("registrations", []() { @@ -111,9 +138,8 @@ NB_MODULE(_cpu_examples, m) { nb::capsule(reinterpret_cast(ArrayAttr)); registrations["dictionary_attr"] = nb::capsule(reinterpret_cast(DictionaryAttr)); - registrations["counter"] = nb::capsule(reinterpret_cast(Counter)); - + registrations["aliasing"] = nb::capsule(reinterpret_cast(Aliasing)); return registrations; }); } diff --git a/examples/ffi/src/jax_ffi_example/cpu_examples.py b/examples/ffi/src/jax_ffi_example/cpu_examples.py index 563e5a911b99..155e100dcd77 100644 --- a/examples/ffi/src/jax_ffi_example/cpu_examples.py +++ b/examples/ffi/src/jax_ffi_example/cpu_examples.py @@ -39,3 +39,9 @@ def dictionary_attr(**kwargs): def counter(index): return jax.ffi.ffi_call( "counter", jax.ShapeDtypeStruct((), jax.numpy.int32))(index=int(index)) + + +def aliasing(x): + return jax.ffi.ffi_call( + "aliasing", jax.ShapeDtypeStruct(x.shape, x.dtype), + input_output_aliases={0: 0})(x) diff --git a/examples/ffi/tests/cpu_examples_test.py b/examples/ffi/tests/cpu_examples_test.py index 0e2cfde02db6..8db524f6264b 100644 --- a/examples/ffi/tests/cpu_examples_test.py +++ b/examples/ffi/tests/cpu_examples_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from absl.testing import absltest +from absl.testing import absltest, parameterized import jax import jax.numpy as jnp @@ -91,5 +91,16 @@ def counter_fun(x): self.assertEqual(counter_fun(0)[1], 3) +class AliasingTests(jtu.JaxTestCase): + def setUp(self): + super().setUp() + if not jtu.test_device_matches(["cpu"]): + self.skipTest("Unsupported platform") + + @parameterized.parameters((jnp.linspace(0, 0.5, 10),), (jnp.int32(6),)) + def test_basic(self, x): + self.assertAllClose(cpu_examples.aliasing(x), x) + + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())