-
I am trying to understand the behavior of FFI in JAX 0.4.33, and I saw something I cannot explain. I would really appreciate an explanation, or a pointer to where I am going wrong. Thank you!

I successfully compiled a function in C++ and apparently could get it into Python via … The C++ function that I compiled is supposed to print a line of text every time it's called. However, when I tried to call … I would really appreciate some help explaining what is going on. Thank you again!

My system details are:
My code is as follows:
-
Perhaps try to return the success code from your function:

```cpp
ffi::Error print_text_impl() {
  printf("---> Hello. You are finally in this function\n");
  fflush(stdout);
  return ffi::Error::Success();
}
```
-
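Good question! There are a few issues here that I'm happy to explain. Let me start by saying that if you do want some sort of printing op, you should use `jax.debug.print` rather than trying to write your own, because it handles all the issues that I'll talk about here.

The main point to note is that `ffi_call` assumes that custom calls are pure, which means that we're allowed to reorder them or choose not to call the function at all. With that in mind, I'm actually surprised that you see the printing even once! When I tested the same function, I didn't see any output. The reason for this is that, since your `ffi_call` doesn't return any outputs, there's no need for it to ever be called.

So, you can update your function to take an output, e.g.:

```cpp
ffi::Error PrintImpl(ffi::Result<ffi::BufferR0<ffi::S32>>) {
  printf("---> Hello. You are finally in this function\n");
  fflush(stdout);
  return ffi::Error::Success();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(Print, PrintImpl,
                              ffi::Ffi::Bind().Ret<ffi::BufferR0<ffi::S32>>());
```

and then call it as follows:

```python
import jax
import jax.extend as jex
import numpy as np

def print_text():
    # The scalar int32 result gives the call an output that can be consumed.
    return jex.ffi.ffi_call("print_text", jax.ShapeDtypeStruct((), np.int32))
```

Then, if you run … But you have to remember that we're still going to assume that your function call is pure, and printing breaks that contract. That means that if you then do:

```python
@jax.jit
def print_multiple():
    print_text()
    print_text()

print_multiple()
```

you won't actually see any logging, because the output of …

I hope this clarifies the behavior you're seeing!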
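To see the purity issue without compiling any C++, here is a minimal sketch that swaps in `jax.pure_callback` (a pure host callback, like `ffi_call`) for the FFI handler. It assumes a recent JAX release; `host_fn`, `calls`, and the three jitted functions are hypothetical names for illustration. A pure call whose result is ignored under `jax.jit` may be eliminated by the compiler, a returned result forces the call to run, and `jax.debug.print` is treated as a side effect, so it is kept even though it returns nothing.

```python
import numpy as np
import jax
import jax.numpy as jnp

calls = []  # hypothetical host-side counter: one entry per execution of host_fn

def host_fn():
    calls.append(1)
    return np.int32(0)

result_type = jax.ShapeDtypeStruct((), np.int32)

@jax.jit
def result_used():
    # The callback's result is returned, so it has to be computed.
    return jax.pure_callback(host_fn, result_type)

@jax.jit
def result_unused():
    # Same pure call, but nothing consumes its result, so the compiler is
    # free to drop it, as with an ffi_call whose output is ignored.
    jax.pure_callback(host_fn, result_type)
    return jnp.float32(1.0)

@jax.jit
def debug_printed(x):
    # jax.debug.print carries an effect, so JAX keeps it inside jit.
    jax.debug.print("---> Hello. x = {x}", x=x)
    return x + 1

result_used().block_until_ready()
result_unused().block_until_ready()
debug_printed(jnp.int32(1)).block_until_ready()
jax.effects_barrier()  # make sure any pending debug prints have completed
print(len(calls))
```

The same reasoning applies to the FFI version: returning the `ffi_call` result from the jitted function is what keeps the call alive, but the compiler may still deduplicate or reorder pure calls.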
-
You may have to use `io_callback`.

C++ code:

```cpp
#include <cstdio>

#include "xla/ffi/api/c_api.h"
#include "xla/ffi/api/ffi.h"

namespace ffi = xla::ffi;

// The scalar int32 result is never written; it exists only so the call has an output.
ffi::Error SayHi(ffi::Result<ffi::BufferR0<ffi::S32>> y) {
  printf("u are finally there!\n");
  fflush(stdout);  // flush so the line shows up immediately
  return ffi::Error::Success();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(
    Print, SayHi,
    ffi::Ffi::Bind().Ret<ffi::BufferR0<ffi::S32>>());
```

Python code:

```python
import ctypes
from pathlib import Path

import jax
import jax.extend as jex
from jax import numpy as jnp, jit
from jax import experimental as exper

# Scalar int32 result type shared by the FFI call and the callback.
cDtype = jax.ShapeDtypeStruct((), jnp.int32)

# Load the compiled shared library that exports the `Print` handler symbol.
ffi_build_dir = Path(__file__).parent.parent / "ffi" / "_build"
lib_path = next(ffi_build_dir.glob("sayhi.so"), None)
if lib_path is None:
    raise FileNotFoundError(f"sayhi.so not found in {ffi_build_dir}")
sayhi_lib = ctypes.cdll.LoadLibrary(str(lib_path))

# Register the handler with XLA under the name "sayhi" for the CPU backend.
jex.ffi.register_ffi_target(
    "sayhi",
    jex.ffi.pycapsule(sayhi_lib.Print),
    platform="cpu",
)

call = jex.ffi.ffi_call("sayhi", cDtype, vmap_method="sequential")

def io_call(x):
    # Executed in Python on the host; these calls are not under jit,
    # so each call() is dispatched eagerly.
    call()
    call()
    call()
    return x

@jit  # <----------- JIT the test function
def test(v):
    # io_callback is treated as side-effecting, so JAX keeps it.
    return exper.io_callback(io_call, cDtype, v, ordered=False)

if __name__ == "__main__":
    test(jnp.int32(1))
```

Output: