Skip to content

Commit

Permalink
feat: add better errors for invalid compiler configurations (#1575)
Browse files Browse the repository at this point in the history
Co-authored-by: José Valim <[email protected]>
  • Loading branch information
polvalente and josevalim authored Jan 30, 2025
1 parent c83c4a0 commit aa366da
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 0 deletions.
21 changes: 21 additions & 0 deletions nx/lib/nx/defn/compiler.ex
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ defmodule Nx.Defn.Compiler do
def __to_backend__(opts) do
{compiler, opts} = Keyword.pop(opts, :compiler, Nx.Defn.Evaluator)
compiler.__to_backend__(opts)
rescue
e in [UndefinedFunctionError] ->
raise_missing_callback(e, :__to_backend__, 1, __STACKTRACE__)
end

## JIT/Stream
Expand All @@ -120,12 +123,30 @@ defmodule Nx.Defn.Compiler do
def __compile__(fun, params, opts) do
{compiler, runtime_fun, opts} = prepare_options(fun, opts)
compiler.__compile__(fun, params, runtime_fun, opts)
rescue
e in [UndefinedFunctionError] ->
raise_missing_callback(e, :__compile__, 4, __STACKTRACE__)
end

@doc false
def __jit__(fun, params, args_list, opts) do
{compiler, runtime_fun, opts} = prepare_options(fun, opts)
compiler.__jit__(fun, params, runtime_fun, args_list, opts)
rescue
e in [UndefinedFunctionError] ->
raise_missing_callback(e, :__jit__, 5, __STACKTRACE__)
end

defp raise_missing_callback(exception, name, arity, stacktrace) do
case exception do
%UndefinedFunctionError{module: compiler, function: ^name, arity: ^arity} ->
raise ArgumentError,
"the expected compiler callback #{name}/#{arity} is missing. Please check that the module #{inspect(compiler)} is an Nx.Defn.Compiler."

_ ->
# This is not an error that should've been caught by this function, so we pass the exception along
reraise exception, stacktrace
end
end

defp prepare_options(fun, opts) do
Expand Down
11 changes: 11 additions & 0 deletions nx/lib/nx/serving.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1359,6 +1359,17 @@ defmodule Nx.Serving do
defp serving_partitions(%Nx.Serving{defn_options: defn_options}, true) do
compiler = Keyword.get(defn_options, :compiler, Nx.Defn.Evaluator)
compiler.__partitions_options__(defn_options)
rescue
e in [UndefinedFunctionError] ->
case e do
%UndefinedFunctionError{module: compiler, function: :__partitions_options__, arity: 1} ->
raise ArgumentError,
"the expected compiler callback __partitions_options__/1 is missing. Please check that the module #{inspect(compiler)} is an Nx.Defn.Compiler."

_ ->
# This is not an error that should've been caught by this function, so we pass the exception along
reraise e, __STACKTRACE__
end
end

defp serving_partitions(%Nx.Serving{defn_options: defn_options}, false) do
Expand Down
49 changes: 49 additions & 0 deletions nx/test/nx/defn/compiler_test.exs
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
defmodule Nx.Defn.CompilerTest do
use ExUnit.Case, async: true

defmodule SomeInvalidServing do
def init(_, _, _) do
:ok
end
end

test "raises an error if the __compile__ callback is missing" do
msg =
"the expected compiler callback __compile__/4 is missing. Please check that the module SomeInvalidCompiler is an Nx.Defn.Compiler."

assert_raise ArgumentError, msg, fn ->
Nx.Defn.compile(&Function.identity/1, [Nx.template({}, :f32)],
compiler: SomeInvalidCompiler
)
end
end

test "raises an error if the __jit__ callback is missing" do
msg =
"the expected compiler callback __jit__/5 is missing. Please check that the module SomeInvalidCompiler is an Nx.Defn.Compiler."

assert_raise ArgumentError, msg, fn ->
Nx.Defn.jit(&Function.identity/1, compiler: SomeInvalidCompiler).(1)
end
end

test "raises an error if the __partitions_options__ callback is missing" do
msg =
"the expected compiler callback __partitions_options__/1 is missing. Please check that the module SomeInvalidCompiler is an Nx.Defn.Compiler."

serving = Nx.Serving.new(SomeInvalidServing, [], compiler: SomeInvalidCompiler)

assert_raise ArgumentError, msg, fn ->
Nx.Serving.init({MyName, serving, true, [1], 10, 1000, nil, 1})
end
end

test "raises an error if the __to_backend__ callback is missing" do
msg =
"the expected compiler callback __to_backend__/1 is missing. Please check that the module SomeInvalidCompiler is an Nx.Defn.Compiler."

assert_raise ArgumentError, msg, fn ->
Nx.Defn.to_backend(compiler: SomeInvalidCompiler)
end
end
end

0 comments on commit aa366da

Please sign in to comment.