Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allow complex literals in defn #1572

Merged
merged 1 commit into from
Jan 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions nx/lib/nx/defn/compiler.ex
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,13 @@ defmodule Nx.Defn.Compiler do
{{{:., dot_meta, [Nx, name]}, meta, args}, state}
end

# We also allow specifically Complex.new so that literal complex numbers
# can be written in defn.
defp normalize({{:., dot_meta, [Complex, :new]}, meta, args}, state) do
{args, state} = normalize_list(args, state)
{{{:., dot_meta, [Complex, :new]}, meta, args}, state}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we use Complex.new or should we introduce Nx.c128 and friends exclusively for this? Especially because this will not work in practice:

defn foo(x) do
  Complex.new(0, x)
end

So the usage in practice is quite limited?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I generally use Complex.new instead of %Complex{re: ..., im: ...} to write a given complex literal, especially because Complex.new ensures that the components will be represented as floats. When structs are properly typed, this won't be an issue anymore.

We could just use %Complex{re: x, im: y} instead of Complex.new(x, y), but then we can end up with an invalid complex with integer (or worse, nil) components.

Nx.c64/1 and Nx.c128/1 would require a raw complex number to be given already, so it doesn't really solve the problem.

And we can't use Nx.complex because it will screw up the typing of the components and return a tensor instead of a constant.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can check that the arguments for Complex.new are valid (either numbers or non-finites) and raise otherwise, as this is specifically intended for literals and constants.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we add this check to Complex.new directly, we get a better error message, but as it is, it already fails to compile the defn, albeit with a "bad argument in arithmetic expression" error

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could just use %Complex{re: x, im: y} instead of Complex.new(x, y), but then we can end up with an invalid complex with integer (or worse, nil) components.

To be clear, I am not proposing to use the structs. :D

Nx.c64/1 and Nx.c128/1 would require a raw complex number to be given already, so it doesn't really solve the problem.

I meant we could add a /2 version to them. My concern with this is that now everyone who consumes Nx.Defn.Expr has to deal with Complex as a new member of its AST, no? If we could formalize around tensor literals, it is less for downstream to handle.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Complex.new produces a %Complex{} struct, which then becomes the value for a :constant Expr, which is already valid. This isn't too different from getting a complex value via a keyword list.

I don't see how adding this here changes the Nx.Defn.Expr.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I have no further objections. :) ANd I assume Complex.new already checks the arguments anyway.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will update Complex with a better error message, but it will already fail as it is now

end

defp normalize({{:., dot_meta, [mod, name]}, meta, args}, state) when mod in @allowed_modules do
{args, state} = normalize_list(args, state)
{{{:., dot_meta, [mod, name]}, meta, args}, state}
Expand Down
2 changes: 1 addition & 1 deletion nx/lib/nx/defn/expr.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1271,7 +1271,7 @@ defmodule Nx.Defn.Expr do
"value and inline it inside the defn expression. Got: #{inspect(t)}"
end

defp to_expr(number) when is_number(number),
defp to_expr(number) when is_number(number) or is_struct(number, Complex),
do: constant(%T{shape: {}, names: [], type: Nx.Type.infer(number)}, number)

defp to_expr(other) do
Expand Down
9 changes: 9 additions & 0 deletions nx/test/nx/defn_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ defmodule Nx.DefnTest do
@tensor [1, 2, 3]
defn(list_constant, do: Nx.tensor(@tensor))

defn complex_constant do
Complex.new(1, :infinity)
end

test "from list" do
assert %T{data: %Expr{op: :tensor}} = list_constant()
end
Expand All @@ -35,6 +39,11 @@ defmodule Nx.DefnTest do
test "from binary" do
assert %T{data: %Expr{op: :tensor}} = binary_constant()
end

test "complex literals" do
assert %T{data: %Expr{op: :constant, args: [%Complex{} = c]}} = complex_constant()
assert c == Complex.new(1, :infinity)
end
end

describe "Nx.tensor" do
Expand Down
Loading