From a85e6ae3b7da168d66c7cfa93889ac623c4b78c4 Mon Sep 17 00:00:00 2001 From: Sean Moriarity Date: Fri, 2 Aug 2024 12:34:32 -0400 Subject: [PATCH 1/3] Add NIF for loading custom plugins --- exla/c_src/exla/exla.cc | 82 +++++++++++++++++++++++- exla/lib/exla/nif.ex | 2 + exla/lib/exla/plugin.ex | 23 +++++++ exla/test/exla/plugin_test.exs | 21 ++++++ exla/test/support/c/custom_plugin.c | 27 ++++++++ exla/test/support/c/libcustom_plugin.so | Bin 0 -> 33440 bytes 6 files changed, 153 insertions(+), 2 deletions(-) create mode 100644 exla/lib/exla/plugin.ex create mode 100644 exla/test/exla/plugin_test.exs create mode 100644 exla/test/support/c/custom_plugin.c create mode 100755 exla/test/support/c/libcustom_plugin.so diff --git a/exla/c_src/exla/exla.cc b/exla/c_src/exla/exla.cc index 04202405d6c..924f9482101 100644 --- a/exla/c_src/exla/exla.cc +++ b/exla/c_src/exla/exla.cc @@ -1,4 +1,5 @@ #include +#include #include "exla_client.h" #include "exla_cuda.h" @@ -11,11 +12,36 @@ #include "stablehlo/dialect/StablehloOps.h" #include "xla/pjrt/pjrt_api.h" #include "xla/service/platform_util.h" +#include "xla/service/custom_call_target_registry.h" // All of these are created with calls to `new` and subsequently // passed to the VM as pointers-to-pointers so we balance it out // with calls to delete rather than just using the default destructor. +// We need to hold a reference to the `dlopen` handle for as long +// as EXLA is running, so we have this resource which holds the handle, +// then we define a custom free which calls `dlclose`. Then it's up to +// the caller to keep this resource in scope so it's not garbage collected +typedef struct { + void * handle; +} ExlaPlugin; + +typedef void (*ExlaCustomCallFunction)(void *out[], const void *in[]); + +typedef struct { + const char* name; + ExlaCustomCallFunction func; +} ExlaPluginCustomCall; + +static ErlNifResourceType* exla_plugin_resource_type; + +void free_exla_plugin(ErlNifEnv* env, void* obj) { + ExlaPlugin* plugin = reinterpret_cast(obj); + if (plugin != nullptr) { + dlclose(plugin->handle); + } +} + void free_exla_executable(ErlNifEnv* env, void* obj) { exla::ExlaExecutable** executable = reinterpret_cast(obj); if (*executable != nullptr) { @@ -65,10 +91,17 @@ static int open_resources(ErlNifEnv* env) { if (!exla::nif::open_resource(env, mod, "ExlaMLIRModule")) { return -1; } - if (!exla::nif::open_resource(env, mod, "MLIRContext")) { return -1; } + + // Just a C Resource + ErlNifResourceFlags flags = ErlNifResourceFlags(ERL_NIF_RT_CREATE | ERL_NIF_RT_TAKEOVER); + exla_plugin_resource_type = enif_open_resource_type(env, mod, "ExlaPlugin", free_exla_plugin, flags, NULL); + if (!exla_plugin_resource_type) { + return -1; + } + return 1; } @@ -911,6 +944,48 @@ ERL_NIF_TERM start_log_sink(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) return exla::nif::ok(env); } +// Plugins + +ERL_NIF_TERM load_custom_call_plugin_library(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { + if (argc != 1) { + return exla::nif::error(env, "Bad argument count."); + } + + std::string library_path; + + if (!exla::nif::get(env, argv[0], library_path)) { + return exla::nif::error(env, "Unable to get library path."); + } + + void* handle = dlopen(library_path.c_str(), RTLD_NOW); + if (!handle) { + return exla::nif::error(env, "Unable to open library."); + } + + const ExlaPluginCustomCall* custom_calls = (ExlaPluginCustomCall*) dlsym(handle, "exla_custom_calls"); + + if(!custom_calls) { + dlclose(handle); + return exla::nif::error(env, "Unable to find exla_custom_calls"); + } + + int i = 0; + ExlaPluginCustomCall func = custom_calls[i]; + while (func.name != NULL) { + // TODO: GPU flags + XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(func.name, func.func); + func = custom_calls[++i]; + } + + ExlaPlugin* plugin = (ExlaPlugin*) enif_alloc_resource(exla_plugin_resource_type, sizeof(ExlaPlugin)); + plugin->handle = handle; + + ERL_NIF_TERM result = enif_make_resource(env, plugin); + enif_release_resource(plugin); + + return exla::nif::ok(env, result); +} + static ErlNifFunc exla_funcs[] = { // MLIR Builder {"mlir_new_context", 0, mlir_new_context}, @@ -947,6 +1022,9 @@ static ErlNifFunc exla_funcs[] = { {"start_log_sink", 1, start_log_sink}, // Serialization {"serialize_executable", 1, serialize_executable}, - {"deserialize_executable", 2, deserialize_executable}}; + {"deserialize_executable", 2, deserialize_executable}, + // Plugins + {"load_custom_call_plugin_library", 1, load_custom_call_plugin_library} + }; ERL_NIF_INIT(Elixir.EXLA.NIF, exla_funcs, &load, NULL, &upgrade, NULL); diff --git a/exla/lib/exla/nif.ex b/exla/lib/exla/nif.ex index 6830df726c9..907398d699a 100644 --- a/exla/lib/exla/nif.ex +++ b/exla/lib/exla/nif.ex @@ -112,4 +112,6 @@ defmodule EXLA.NIF do def get_c_api_client(_device_type), do: :erlang.nif_error(:undef) def load_pjrt_plugin(_device_type, _library_path), do: :erlang.nif_error(:undef) + + def load_custom_call_plugin_library(_library_path), do: :erlang.nif_error(:undef) end diff --git a/exla/lib/exla/plugin.ex b/exla/lib/exla/plugin.ex new file mode 100644 index 00000000000..5df3a42d4c1 --- /dev/null +++ b/exla/lib/exla/plugin.ex @@ -0,0 +1,23 @@ +defmodule EXLA.Plugin do + @moduledoc """ + Plugin system for registering custom calls. + """ + + def register(library_path) do + unless File.exists?(library_path) do + raise ArgumentError, "#{library_path} does not exist" + end + + ref = + library_path + |> EXLA.NIF.load_custom_call_plugin_library() + |> unwrap!() + + # we need to keep the ref from getting garbage collected so + # we can use the symbols within it at anytime + :persistent_term.put({__MODULE__, library_path}, ref) + end + + defp unwrap!({:ok, ref}), do: ref + defp unwrap!({:error, reason}), do: raise "#{reason}" +end \ No newline at end of file diff --git a/exla/test/exla/plugin_test.exs b/exla/test/exla/plugin_test.exs new file mode 100644 index 00000000000..e7f7d0d7f35 --- /dev/null +++ b/exla/test/exla/plugin_test.exs @@ -0,0 +1,21 @@ +defmodule EXLA.PluginTest do + use ExUnit.Case + + describe "register/1" do + test "raises if file does not exist" do + assert_raise ArgumentError, ~r/does not exist/, fn -> + EXLA.Plugin.register("test/support/c/doesnotexist.so") + end + end + + test "does not crash on invalid files" do + assert_raise RuntimeError, ~r/Unable to open/, fn -> + EXLA.Plugin.register(__ENV__.file) + end + end + + test "registers a plugin" do + assert :ok = EXLA.Plugin.register("test/support/c/libcustom_plugin.so") + end + end +end \ No newline at end of file diff --git a/exla/test/support/c/custom_plugin.c b/exla/test/support/c/custom_plugin.c new file mode 100644 index 00000000000..085b64f6be3 --- /dev/null +++ b/exla/test/support/c/custom_plugin.c @@ -0,0 +1,27 @@ +#include +#include + +typedef void (*ExlaCustomCallFunction)(void *out[], const void *in[]); + +typedef struct { + const char* name; + ExlaCustomCallFunction func; +} ExlaPluginCustomCall; + +void custom_increment(void *out[], const void *in[]) { + int64_t *operand = (int64_t *)in[0]; + int64_t *dim_sizes = (int64_t *)in[1]; + + int64_t *out_buffer = (int64_t *)out[0]; + + int64_t n = dim_sizes[0]; + + for (int64_t i = 0; i < n; i++) { + out_buffer[i] = operand[i] + 1; + } +} + +extern "C" ExlaPluginCustomCall exla_custom_calls[] = { + {"custom_increment", custom_increment}, + {NULL, NULL} +}; \ No newline at end of file diff --git a/exla/test/support/c/libcustom_plugin.so b/exla/test/support/c/libcustom_plugin.so new file mode 100755 index 0000000000000000000000000000000000000000..cfc8eb7b93b039e36f53325916d0da668e128752 GIT binary patch literal 33440 zcmeI*eQ29S90%}U-dsy-TUjw`R~n|s*3Iliv1Nijl?_`qX0WA386GjoN@$vlCaLY5 znxYfMp;n6-iGUZr|Ag?l)tOjbpMiGAEG znWxf8scd{WYXUAVnD3Ty-?%-5W4B$G(o1<~s8lqQ9ZaP9Ok?GI<7U3gOurr5b1ayf z8)ngiEDj zk!-}=qvOAyFJKR`{xY|loTaR*le1D=y0&iKw6V+nO!H=xWl9rIn3t6G*lUszmvUaq z`3la?_GIGTXfBgY52ziehGz>u-#j(h?_K%%S7T%Axh=rC{4RC;<*n0o{r#-7(J@Xur6 z`0kNRHa-x%)!f1(iLT`=&)t_Ns^W5`iN#4$GiSZ7QMPk_5V{u^2tWV=5P$##AOHaf zKmY;|fB*y_009U<00Izz00bZa0SG_<0uX=z1Rwwb2tWV=5P$##AOHafKmY;|fB*y_ z009U<00Izz00bZa0SNrR1WF#)%-I$ye%(aHLX%aja#_WK-_(6vU&{3s(_YW@rKVm( z#nKWg@|N4i!V;@ETO;k&=L>HfzUn;}wdr<0w{u_BNMf$aM`N?~R`FYoSB$jI+4o6% zf#>V43>7v17M$@V^XsXLSGW3J^{ z8uUyYSb7_>8P7Y2&Hi@hJ*Yzf0uX=z1Rwwb2tWV=5P$##AOHafKmY;|fB*y_009U< z00Izz00bZa0SG_<0uX=z1Rwwb2tWV=5P$##AOHafKmY;|fB*y_009U<00Izz00bcL z-w6cxQE~>mWJ&VB2JEEIKj*+c`&`a1CyS3dx1>@|&hg|@)F{m!%0Eb6%ICS)t?piT zW95G}2!)11q0K{79Uo3cl0b{;dVid3f8??ESMLdbGuFNJ@Rz4Pe%iXOe(#;l`BUd+Z;N$Kv_A90 v2N#?F-}dkS<;eWWq00|G*KxeR;n@A>-h92cZFKLW54?7I+1yk2WY7EoQ9^b? literal 0 HcmV?d00001 From 5c438dc42b203a69b0645dacb0585b9bf4c05990 Mon Sep 17 00:00:00 2001 From: Sean Moriarity Date: Fri, 2 Aug 2024 16:51:17 -0400 Subject: [PATCH 2/3] Add MLIR interface --- exla/lib/exla/mlir/value.ex | 21 +++++++++++++++++++++ exla/lib/exla/plugin.ex | 24 +++++++++++++++--------- 2 files changed, 36 insertions(+), 9 deletions(-) diff --git a/exla/lib/exla/mlir/value.ex b/exla/lib/exla/mlir/value.ex index b53795055d6..dd54727b3d3 100644 --- a/exla/lib/exla/mlir/value.ex +++ b/exla/lib/exla/mlir/value.ex @@ -815,6 +815,27 @@ defmodule EXLA.MLIR.Value do {q, r} end + def plugin_custom_call(registered_name, [%Value{function: func} | _] = args, result_typespec) do + operand_shapes = + Enum.map(args, fn %Value{function: ^func} = value -> + %{shape: op_shape} = get_typespec(value) + constant(func, Tuple.to_list(op_shape), Typespec.tensor({:s, 64}, {length(op_shape)})) + end) + + operands = + args + |> Enum.zip_with(operand_shapes, fn val, shape -> [val, shape] end) + |> List.flatten() + + # TODO: GPU + attributes = [ + call_target_name: attr_string(registered_name), + backend_config: attr_string("Host") + ] + + op(func, "stablehlo.custom_call", operands, result_typespec, attributes: attributes) + end + def get_tuple_element(%Value{function: func} = operand, index, typespec) do result_types = typespecs_to_mlir_types([typespec]) attributes = [index: attr_i32(index)] diff --git a/exla/lib/exla/plugin.ex b/exla/lib/exla/plugin.ex index 5df3a42d4c1..0a16f1d662a 100644 --- a/exla/lib/exla/plugin.ex +++ b/exla/lib/exla/plugin.ex @@ -8,16 +8,22 @@ defmodule EXLA.Plugin do raise ArgumentError, "#{library_path} does not exist" end - ref = - library_path - |> EXLA.NIF.load_custom_call_plugin_library() - |> unwrap!() + case :persistent_term.get({__MODULE__, library_path}, nil) do + nil -> + ref = + library_path + |> EXLA.NIF.load_custom_call_plugin_library() + |> unwrap!() - # we need to keep the ref from getting garbage collected so - # we can use the symbols within it at anytime - :persistent_term.put({__MODULE__, library_path}, ref) + # we need to keep the ref from getting garbage collected so + # we can use the symbols within it at anytime + :persistent_term.put({__MODULE__, library_path}, ref) + + _ref -> + :ok + end end defp unwrap!({:ok, ref}), do: ref - defp unwrap!({:error, reason}), do: raise "#{reason}" -end \ No newline at end of file + defp unwrap!({:error, reason}), do: raise("#{reason}") +end From 66618626db280badc4161776fc4505616b4dea65 Mon Sep 17 00:00:00 2001 From: Sean Moriarity Date: Sat, 3 Aug 2024 21:16:11 -0400 Subject: [PATCH 3/3] Capture dimensions --- exla/c_src/exla/exla.cc | 67 ++++++++++++++++++------ exla/c_src/exla/exla_nif_util.cc | 19 +++++++ exla/c_src/exla/exla_nif_util.h | 2 + exla/lib/exla/application.ex | 1 + exla/lib/exla/nif.ex | 2 + exla/lib/exla/plugin.ex | 49 +++++++++++++---- exla/test/exla/plugin_test.exs | 16 +----- exla/test/support/c/custom_plugin.c | 13 ++--- exla/test/support/c/libcustom_plugin.so | Bin 33440 -> 16816 bytes 9 files changed, 118 insertions(+), 51 deletions(-) diff --git a/exla/c_src/exla/exla.cc b/exla/c_src/exla/exla.cc index 924f9482101..eb48c6ce038 100644 --- a/exla/c_src/exla/exla.cc +++ b/exla/c_src/exla/exla.cc @@ -26,7 +26,7 @@ typedef struct { void * handle; } ExlaPlugin; -typedef void (*ExlaCustomCallFunction)(void *out[], const void *in[]); +typedef void (*ExlaCustomCallFunction)(void *out[], const void *in[], int **dims); typedef struct { const char* name; @@ -962,21 +962,6 @@ ERL_NIF_TERM load_custom_call_plugin_library(ErlNifEnv* env, int argc, const ERL return exla::nif::error(env, "Unable to open library."); } - const ExlaPluginCustomCall* custom_calls = (ExlaPluginCustomCall*) dlsym(handle, "exla_custom_calls"); - - if(!custom_calls) { - dlclose(handle); - return exla::nif::error(env, "Unable to find exla_custom_calls"); - } - - int i = 0; - ExlaPluginCustomCall func = custom_calls[i]; - while (func.name != NULL) { - // TODO: GPU flags - XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(func.name, func.func); - func = custom_calls[++i]; - } - ExlaPlugin* plugin = (ExlaPlugin*) enif_alloc_resource(exla_plugin_resource_type, sizeof(ExlaPlugin)); plugin->handle = handle; @@ -986,6 +971,53 @@ ERL_NIF_TERM load_custom_call_plugin_library(ErlNifEnv* env, int argc, const ERL return exla::nif::ok(env, result); } +ERL_NIF_TERM register_custom_call_symbol(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { + if (argc != 3) { + return exla::nif::error(env, "Bad argument count."); + } + + ExlaPlugin* plugin; + std::string symbol; + std::vector> dimensions; + + if (!enif_get_resource(env, argv[0], exla_plugin_resource_type, (void **) &plugin)) { + return exla::nif::error(env, "Unable to get plugin."); + } + if (!exla::nif::get(env, argv[1], symbol)) { + return exla::nif::error(env, "Unable to get symbol."); + } + if (!exla::nif::get_list(env, argv[2], dimensions)) { + return exla::nif::error(env, "Unable to get dimensions."); + } + + ExlaCustomCallFunction function = (ExlaCustomCallFunction) dlsym(plugin->handle, symbol.c_str()); + + if (!function) { + return exla::nif::error(env, "Could not find symbol."); + } + + auto lambda = [&dimensions, function](void *in[], const void *out[]) { + std::vector> int_dims(dimensions.size()); + for (size_t i = 0; i < dimensions.size(); ++i) { + int_dims[i].resize(dimensions[i].size()); + std::transform(dimensions[i].begin(), dimensions[i].end(), int_dims[i].begin(), + [](exla::int64 x) { return static_cast(x); }); + } + + std::vector dims_ptrs; + for (auto& d : int_dims) { + dims_ptrs.push_back(d.data()); + } + + function(in, out, dims_ptrs.data()); + }; + + // TODO: GPU/Client flag + XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(symbol.c_str(), function); + + return exla::nif::ok(env); +} + static ErlNifFunc exla_funcs[] = { // MLIR Builder {"mlir_new_context", 0, mlir_new_context}, @@ -1024,7 +1056,8 @@ static ErlNifFunc exla_funcs[] = { {"serialize_executable", 1, serialize_executable}, {"deserialize_executable", 2, deserialize_executable}, // Plugins - {"load_custom_call_plugin_library", 1, load_custom_call_plugin_library} + {"load_custom_call_plugin_library", 1, load_custom_call_plugin_library}, + {"register_custom_call_symbol", 3, register_custom_call_symbol} }; ERL_NIF_INIT(Elixir.EXLA.NIF, exla_funcs, &load, NULL, &upgrade, NULL); diff --git a/exla/c_src/exla/exla_nif_util.cc b/exla/c_src/exla/exla_nif_util.cc index d38785f6ed9..d802f2a55d1 100644 --- a/exla/c_src/exla/exla_nif_util.cc +++ b/exla/c_src/exla/exla_nif_util.cc @@ -248,6 +248,25 @@ int get_list(ErlNifEnv* env, ERL_NIF_TERM list, std::vector& var) { return 1; } +int get_list(ErlNifEnv* env, ERL_NIF_TERM list, std::vector>& var) { + unsigned int length; + if (!enif_get_list_length(env, list, &length)) { + return 0; + } + var.reserve(length); + ERL_NIF_TERM head, tail; + + while (enif_get_list_cell(env, list, &head, &tail)) { + std::vector elem; + if (!get_list(env, head, elem)) { + return 0; + } + var.push_back(elem); + list = tail; + } + return 1; +} + int get_binary(ErlNifEnv* env, ERL_NIF_TERM term, ErlNifBinary* var) { return enif_inspect_binary(env, term, var); } diff --git a/exla/c_src/exla/exla_nif_util.h b/exla/c_src/exla/exla_nif_util.h index 5abf7e3cdaf..82445111741 100644 --- a/exla/c_src/exla/exla_nif_util.h +++ b/exla/c_src/exla/exla_nif_util.h @@ -247,6 +247,8 @@ int get_list(ErlNifEnv* env, ERL_NIF_TERM list, std::vector& var); int get_list(ErlNifEnv* env, ERL_NIF_TERM list, std::vector& var); +int get_list(ErlNifEnv* env, ERL_NIF_TERM list, std::vector>& var); + template int get_list(ErlNifEnv* env, ERL_NIF_TERM list, std::vector& var) { unsigned int length; diff --git a/exla/lib/exla/application.ex b/exla/lib/exla/application.ex index 3bdfa30d0c3..03dc4b7e4cc 100644 --- a/exla/lib/exla/application.ex +++ b/exla/lib/exla/application.ex @@ -18,6 +18,7 @@ defmodule EXLA.Application do name: EXLA.MLIR.ContextPool, lazy: true}, EXLA.Client, + EXLA.Plugin, EXLA.Defn.Lock, EXLA.Defn.LockedCache, {Task.Supervisor, name: EXLA.Defn.TaskSupervisor} diff --git a/exla/lib/exla/nif.ex b/exla/lib/exla/nif.ex index 907398d699a..dd90ced0165 100644 --- a/exla/lib/exla/nif.ex +++ b/exla/lib/exla/nif.ex @@ -114,4 +114,6 @@ defmodule EXLA.NIF do def load_pjrt_plugin(_device_type, _library_path), do: :erlang.nif_error(:undef) def load_custom_call_plugin_library(_library_path), do: :erlang.nif_error(:undef) + + def register_custom_call_symbol(_plugin, _symbol, _dimensions), do: :erlang.nif_error(:undef) end diff --git a/exla/lib/exla/plugin.ex b/exla/lib/exla/plugin.ex index 0a16f1d662a..b7f682867ee 100644 --- a/exla/lib/exla/plugin.ex +++ b/exla/lib/exla/plugin.ex @@ -2,28 +2,55 @@ defmodule EXLA.Plugin do @moduledoc """ Plugin system for registering custom calls. """ + use GenServer - def register(library_path) do - unless File.exists?(library_path) do - raise ArgumentError, "#{library_path} does not exist" + # TODO: Register and lookup per client + + def start_link(_opts) do + GenServer.start_link(__MODULE__, %{}, name: __MODULE__) + end + + def register(key, library_path) do + GenServer.cast(__MODULE__, {:register, key, library_path}) + end + + def lookup(key) do + GenServer.call(__MODULE__, {:lookup, key}) + end + + def register_symbol(key, symbol, dimensions) do + if ref = lookup(key) do + EXLA.NIF.register_custom_call_symbol(ref, symbol, dimensions) end + end + + @impl true + def init(_opts) do + {:ok, %{}} + end + + @impl true + def handle_cast({:register, key, library_path}, state) do + case state do + %{^key => _ref} -> + {:noreply, state} - case :persistent_term.get({__MODULE__, library_path}, nil) do - nil -> + %{} -> ref = library_path |> EXLA.NIF.load_custom_call_plugin_library() |> unwrap!() - # we need to keep the ref from getting garbage collected so - # we can use the symbols within it at anytime - :persistent_term.put({__MODULE__, library_path}, ref) - - _ref -> - :ok + {:noreply, Map.put(state, key, ref)} end end + @impl true + def handle_call({:lookup, key}, _from, state) do + value = Map.get(state, key) + {:reply, value, state} + end + defp unwrap!({:ok, ref}), do: ref defp unwrap!({:error, reason}), do: raise("#{reason}") end diff --git a/exla/test/exla/plugin_test.exs b/exla/test/exla/plugin_test.exs index e7f7d0d7f35..ca9f9dfce15 100644 --- a/exla/test/exla/plugin_test.exs +++ b/exla/test/exla/plugin_test.exs @@ -2,20 +2,8 @@ defmodule EXLA.PluginTest do use ExUnit.Case describe "register/1" do - test "raises if file does not exist" do - assert_raise ArgumentError, ~r/does not exist/, fn -> - EXLA.Plugin.register("test/support/c/doesnotexist.so") - end - end - - test "does not crash on invalid files" do - assert_raise RuntimeError, ~r/Unable to open/, fn -> - EXLA.Plugin.register(__ENV__.file) - end - end - test "registers a plugin" do - assert :ok = EXLA.Plugin.register("test/support/c/libcustom_plugin.so") + assert :ok = EXLA.Plugin.register(:custom_plugin, "test/support/c/libcustom_plugin.so") end end -end \ No newline at end of file +end diff --git a/exla/test/support/c/custom_plugin.c b/exla/test/support/c/custom_plugin.c index 085b64f6be3..b3c70f09502 100644 --- a/exla/test/support/c/custom_plugin.c +++ b/exla/test/support/c/custom_plugin.c @@ -1,16 +1,16 @@ #include #include -typedef void (*ExlaCustomCallFunction)(void *out[], const void *in[]); +typedef void (*ExlaCustomCallFunction)(void *out[], const void *in[], int **dims); typedef struct { const char* name; ExlaCustomCallFunction func; } ExlaPluginCustomCall; -void custom_increment(void *out[], const void *in[]) { +extern "C" void custom_increment(void *out[], const void *in[], int **dims) { int64_t *operand = (int64_t *)in[0]; - int64_t *dim_sizes = (int64_t *)in[1]; + int64_t *dim_sizes = (int64_t *)dims[0]; int64_t *out_buffer = (int64_t *)out[0]; @@ -19,9 +19,4 @@ void custom_increment(void *out[], const void *in[]) { for (int64_t i = 0; i < n; i++) { out_buffer[i] = operand[i] + 1; } -} - -extern "C" ExlaPluginCustomCall exla_custom_calls[] = { - {"custom_increment", custom_increment}, - {NULL, NULL} -}; \ No newline at end of file +} \ No newline at end of file diff --git a/exla/test/support/c/libcustom_plugin.so b/exla/test/support/c/libcustom_plugin.so index cfc8eb7b93b039e36f53325916d0da668e128752..90e7640043d51487ffbf62f74e8f1f211c470bb5 100755 GIT binary patch delta 694 zcmZWmT}TvB6h8M(n`4g5#IB8|=2Vmw(dPD`5H+)uVDu2y2HT5cESoH-{TXCY7P5ga zvC}oyif}O~-!d!{0#o|Iw_1wqNkn@Xi)DLQ!ZC6?cX#cj1K<75IrqEg9PUX(=}0fx z^IBw+6Tyq}BHXGGJdPRIgp4Ky?|wEX{Jx`0J@YC! zcBSEn#Bgjk`f9C+#QAPvTy2?OEOFB?F5en^Rvmwp-Ux=PpuNrs*M?A?2ETnN+CB)T8H_0L$F{2++ zO|ztlkJU>0A*A-AUn%`b=z#=L6xI)kHHol@oGseU-I*;Z=o`K{=X~ehbM9-bgZMj{ z+-Jd$KMeo`5G#-- z=xF#eMH6_{a#{!&v+@i{k#!^~Nz_yn!w^;DXJi}MSB#Mpg{=T0-Mh2#`o+7Q|MF0#U>)BOKbY?e9f8oy|^XNi*o%YEF1uf|ho> zJf0?q!%J>ei_<$Z(hk~7wa_>Ue>@Vw;iLOcb>ZLQ>b%H5IBPRZ;-38|L^hHboHbxy z(gh55V(E+jZ&3O;j58O1Dej&rv zyR97=6YNPj&%yJR595aQm4$uEo~-lzeLoi79V{3h_dK2cKHw_C~kQhrP4E0QEEiVgLXD