Skip to content

Commit

Permalink
fix: upcast floats in cond
Browse files Browse the repository at this point in the history
  • Loading branch information
polvalente committed Jan 23, 2025
1 parent 32bf7e8 commit 6b9d819
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 4 deletions.
2 changes: 1 addition & 1 deletion nx/lib/nx.ex
Original file line number Diff line number Diff line change
Expand Up @@ -943,7 +943,7 @@ defmodule Nx do

for t <-
[:u2, :u4, :u8, :u16, :u32, :u64, :s2, :s4, :s8, :s16, :s32, :s64] ++
[:f8, :bf16, :f16, :f32, :f64] do
[:f8, :bf16, :f16, :f32, :f64, :c64, :c128] do
@doc """
Short-hand function for creating tensor of type `#{t}`.
Expand Down
10 changes: 7 additions & 3 deletions nx/lib/nx/defn/expr.ex
Original file line number Diff line number Diff line change
Expand Up @@ -249,9 +249,13 @@ defmodule Nx.Defn.Expr do

result =
for expr <- [last | exprs] do
expr
|> Nx.as_type(type)
|> Nx.broadcast(shape, names: names)
typed_expr =
case expr do
%T{data: %Expr{op: :constant}} -> maybe_upcast_float_constant(expr, type)
expr -> Nx.as_type(expr, type)
end

Nx.broadcast(typed_expr, shape, names: names)
end

{result, vectorized_axes}
Expand Down
21 changes: 21 additions & 0 deletions nx/test/nx/defn_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -1176,6 +1176,27 @@ defmodule Nx.DefnTest do
)
end

defn cond_upcast_float_literals(n) do
cond do
n == 1 -> 1.4
n == 2 -> 2
true -> n
end
end

test "upcasts float literals based on the accumulated clause type" do
for input_type <- [f: 32, f: 64, c: 64, c: 128] do
assert %T{
type: ^input_type,
data: %Expr{op: :cond, args: [[clause1, clause2], _last]}
} =
cond_upcast_float_literals(Nx.tensor(10.0, type: input_type))

assert {_, %T{type: ^input_type, data: %Expr{op: :constant, args: [1.4]}}} = clause1
assert {_, %T{type: {:s, 32}, data: %Expr{op: :constant, args: [2]}}} = clause2
end
end

defn cond_list(a) do
if Nx.any(a), do: 1, else: -1
end
Expand Down

0 comments on commit 6b9d819

Please sign in to comment.