Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add an example demonstrating input-output aliasing with the FFI #25042

Merged
merged 1 commit into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 28 additions & 2 deletions examples/ffi/src/jax_ffi_example/cpu_examples.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,33 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(
Counter, CounterImpl,
ffi::Ffi::Bind().Attr<int64_t>("index").Ret<ffi::BufferR0<ffi::S32>>());

// --------
// Aliasing
// --------
//
// This example demonstrates how input-output aliasing works. The handler
// doesn't do anything except to check that the input and output pointers
// address the same data.

ffi::Error AliasingImpl(ffi::AnyBuffer input,
ffi::Result<ffi::AnyBuffer> output) {
if (input.element_type() != output->element_type() ||
input.element_count() != output->element_count()) {
return ffi::Error::InvalidArgument(
"The input and output data types and sizes must match.");
}
if (input.untyped_data() != output->untyped_data()) {
return ffi::Error::InvalidArgument(
"When aliased, the input and output buffers should point to the same "
"data.");
}
return ffi::Error::Success();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(
Aliasing, AliasingImpl,
ffi::Ffi::Bind().Arg<ffi::AnyBuffer>().Ret<ffi::AnyBuffer>());

// Boilerplate for exposing handlers to Python
NB_MODULE(_cpu_examples, m) {
m.def("registrations", []() {
Expand All @@ -111,9 +138,8 @@ NB_MODULE(_cpu_examples, m) {
nb::capsule(reinterpret_cast<void *>(ArrayAttr));
registrations["dictionary_attr"] =
nb::capsule(reinterpret_cast<void *>(DictionaryAttr));

registrations["counter"] = nb::capsule(reinterpret_cast<void *>(Counter));

registrations["aliasing"] = nb::capsule(reinterpret_cast<void *>(Aliasing));
return registrations;
});
}
6 changes: 6 additions & 0 deletions examples/ffi/src/jax_ffi_example/cpu_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,9 @@ def dictionary_attr(**kwargs):
def counter(index):
return jax.ffi.ffi_call(
"counter", jax.ShapeDtypeStruct((), jax.numpy.int32))(index=int(index))


def aliasing(x):
return jax.ffi.ffi_call(
"aliasing", jax.ShapeDtypeStruct(x.shape, x.dtype),
input_output_aliases={0: 0})(x)
13 changes: 12 additions & 1 deletion examples/ffi/tests/cpu_examples_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from absl.testing import absltest
from absl.testing import absltest, parameterized

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -91,5 +91,16 @@ def counter_fun(x):
self.assertEqual(counter_fun(0)[1], 3)


class AliasingTests(jtu.JaxTestCase):
def setUp(self):
super().setUp()
if not jtu.test_device_matches(["cpu"]):
self.skipTest("Unsupported platform")

@parameterized.parameters((jnp.linspace(0, 0.5, 10),), (jnp.int32(6),))
def test_basic(self, x):
self.assertAllClose(cpu_examples.aliasing(x), x)


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())
Loading