Skip to content

Commit

Permalink
Merge pull request #25042 from dfm:ffi-example-input-output-alias
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 712979906
  • Loading branch information
Google-ML-Automation committed Jan 7, 2025
2 parents 64c0f62 + 62656b3 commit f1777d5
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 3 deletions.
30 changes: 28 additions & 2 deletions examples/ffi/src/jax_ffi_example/cpu_examples.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,33 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(
Counter, CounterImpl,
ffi::Ffi::Bind().Attr<int64_t>("index").Ret<ffi::BufferR0<ffi::S32>>());

// --------
// Aliasing
// --------
//
// This example demonstrates how input-output aliasing works. The handler
// doesn't do anything except to check that the input and output pointers
// address the same data.

ffi::Error AliasingImpl(ffi::AnyBuffer input,
ffi::Result<ffi::AnyBuffer> output) {
if (input.element_type() != output->element_type() ||
input.element_count() != output->element_count()) {
return ffi::Error::InvalidArgument(
"The input and output data types and sizes must match.");
}
if (input.untyped_data() != output->untyped_data()) {
return ffi::Error::InvalidArgument(
"When aliased, the input and output buffers should point to the same "
"data.");
}
return ffi::Error::Success();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(
Aliasing, AliasingImpl,
ffi::Ffi::Bind().Arg<ffi::AnyBuffer>().Ret<ffi::AnyBuffer>());

// Boilerplate for exposing handlers to Python
NB_MODULE(_cpu_examples, m) {
m.def("registrations", []() {
Expand All @@ -111,9 +138,8 @@ NB_MODULE(_cpu_examples, m) {
nb::capsule(reinterpret_cast<void *>(ArrayAttr));
registrations["dictionary_attr"] =
nb::capsule(reinterpret_cast<void *>(DictionaryAttr));

registrations["counter"] = nb::capsule(reinterpret_cast<void *>(Counter));

registrations["aliasing"] = nb::capsule(reinterpret_cast<void *>(Aliasing));
return registrations;
});
}
6 changes: 6 additions & 0 deletions examples/ffi/src/jax_ffi_example/cpu_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,9 @@ def dictionary_attr(**kwargs):
def counter(index):
return jax.ffi.ffi_call(
"counter", jax.ShapeDtypeStruct((), jax.numpy.int32))(index=int(index))


def aliasing(x):
return jax.ffi.ffi_call(
"aliasing", jax.ShapeDtypeStruct(x.shape, x.dtype),
input_output_aliases={0: 0})(x)
13 changes: 12 additions & 1 deletion examples/ffi/tests/cpu_examples_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from absl.testing import absltest
from absl.testing import absltest, parameterized

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -91,5 +91,16 @@ def counter_fun(x):
self.assertEqual(counter_fun(0)[1], 3)


class AliasingTests(jtu.JaxTestCase):
def setUp(self):
super().setUp()
if not jtu.test_device_matches(["cpu"]):
self.skipTest("Unsupported platform")

@parameterized.parameters((jnp.linspace(0, 0.5, 10),), (jnp.int32(6),))
def test_basic(self, x):
self.assertAllClose(cpu_examples.aliasing(x), x)


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit f1777d5

Please sign in to comment.