diff --git a/nx/lib/nx/defn/compiler.ex b/nx/lib/nx/defn/compiler.ex index bbbd3ddcd4..c25e691470 100644 --- a/nx/lib/nx/defn/compiler.ex +++ b/nx/lib/nx/defn/compiler.ex @@ -112,6 +112,9 @@ defmodule Nx.Defn.Compiler do def __to_backend__(opts) do {compiler, opts} = Keyword.pop(opts, :compiler, Nx.Defn.Evaluator) compiler.__to_backend__(opts) + rescue + e in [UndefinedFunctionError] -> + raise_missing_callback(e, :__to_backend__, 1, __STACKTRACE__) end ## JIT/Stream @@ -120,12 +123,30 @@ defmodule Nx.Defn.Compiler do def __compile__(fun, params, opts) do {compiler, runtime_fun, opts} = prepare_options(fun, opts) compiler.__compile__(fun, params, runtime_fun, opts) + rescue + e in [UndefinedFunctionError] -> + raise_missing_callback(e, :__compile__, 4, __STACKTRACE__) end @doc false def __jit__(fun, params, args_list, opts) do {compiler, runtime_fun, opts} = prepare_options(fun, opts) compiler.__jit__(fun, params, runtime_fun, args_list, opts) + rescue + e in [UndefinedFunctionError] -> + raise_missing_callback(e, :__jit__, 5, __STACKTRACE__) + end + + defp raise_missing_callback(exception, name, arity, stacktrace) do + case exception do + %UndefinedFunctionError{module: compiler, function: ^name, arity: ^arity} -> + raise ArgumentError, + "the expected compiler callback #{name}/#{arity} is missing. Please check that the module #{inspect(compiler)} is an Nx.Defn.Compiler." + + _ -> + # This is not an error that should've been caught by this function, so we pass the exception along + reraise exception, stacktrace + end end defp prepare_options(fun, opts) do diff --git a/nx/lib/nx/serving.ex b/nx/lib/nx/serving.ex index ab0a524bd7..6a4d41ac99 100644 --- a/nx/lib/nx/serving.ex +++ b/nx/lib/nx/serving.ex @@ -1359,6 +1359,17 @@ defmodule Nx.Serving do defp serving_partitions(%Nx.Serving{defn_options: defn_options}, true) do compiler = Keyword.get(defn_options, :compiler, Nx.Defn.Evaluator) compiler.__partitions_options__(defn_options) + rescue + e in [UndefinedFunctionError] -> + case e do + %UndefinedFunctionError{module: compiler, function: :__partitions_options__, arity: 1} -> + raise ArgumentError, + "the expected compiler callback __partitions_options__/1 is missing. Please check that the module #{inspect(compiler)} is an Nx.Defn.Compiler." + + _ -> + # This is not an error that should've been caught by this function, so we pass the exception along + reraise e, __STACKTRACE__ + end end defp serving_partitions(%Nx.Serving{defn_options: defn_options}, false) do diff --git a/nx/test/nx/defn/compiler_test.exs b/nx/test/nx/defn/compiler_test.exs new file mode 100644 index 0000000000..c35f69b075 --- /dev/null +++ b/nx/test/nx/defn/compiler_test.exs @@ -0,0 +1,49 @@ +defmodule Nx.Defn.CompilerTest do + use ExUnit.Case, async: true + + defmodule SomeInvalidServing do + def init(_, _, _) do + :ok + end + end + + test "raises an error if the __compile__ callback is missing" do + msg = + "the expected compiler callback __compile__/4 is missing. Please check that the module SomeInvalidCompiler is an Nx.Defn.Compiler." + + assert_raise ArgumentError, msg, fn -> + Nx.Defn.compile(&Function.identity/1, [Nx.template({}, :f32)], + compiler: SomeInvalidCompiler + ) + end + end + + test "raises an error if the __jit__ callback is missing" do + msg = + "the expected compiler callback __jit__/5 is missing. Please check that the module SomeInvalidCompiler is an Nx.Defn.Compiler." + + assert_raise ArgumentError, msg, fn -> + Nx.Defn.jit(&Function.identity/1, compiler: SomeInvalidCompiler).(1) + end + end + + test "raises an error if the __partitions_options__ callback is missing" do + msg = + "the expected compiler callback __partitions_options__/1 is missing. Please check that the module SomeInvalidCompiler is an Nx.Defn.Compiler." + + serving = Nx.Serving.new(SomeInvalidServing, [], compiler: SomeInvalidCompiler) + + assert_raise ArgumentError, msg, fn -> + Nx.Serving.init({MyName, serving, true, [1], 10, 1000, nil, 1}) + end + end + + test "raises an error if the __to_backend__ callback is missing" do + msg = + "the expected compiler callback __to_backend__/1 is missing. Please check that the module SomeInvalidCompiler is an Nx.Defn.Compiler." + + assert_raise ArgumentError, msg, fn -> + Nx.Defn.to_backend(compiler: SomeInvalidCompiler) + end + end +end