Skip to content

Commit

Permalink
fix: fixes Nx.slice with tensor start_indices which affected defn_whi…
Browse files Browse the repository at this point in the history
…le and Nx.Random PRNG (mimiquate#32)
  • Loading branch information
grzuy authored and xabi committed Nov 21, 2023
1 parent 8a93b5e commit cbc1779
Show file tree
Hide file tree
Showing 5 changed files with 139 additions and 2 deletions.
6 changes: 5 additions & 1 deletion lib/candlex/backend.ex
Original file line number Diff line number Diff line change
Expand Up @@ -754,7 +754,11 @@ defmodule Candlex.Backend do

defp narrow(t, [start | starts], [length | lengths], axis, shape) do
dim = elem(shape, axis)
start = min(start, dim - length)

start =
start
|> Nx.to_number()
|> min(dim - length)

if start == 0 and length == dim do
# Nothing to narrow at this step
Expand Down
5 changes: 4 additions & 1 deletion lib/candlex/native.ex
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ defmodule Candlex.Native do
mix_config = Mix.Project.config()
version = mix_config[:version]
source_url = mix_config[:package][:links]["GitHub"]
mode = if Mix.env() in [:dev, :test], do: :debug, else: :release
# We can't run on :debug mode until we find a workaround to
# ignore integer overflows when running Nx.Random Threefry PRNG.
# mode = if Mix.env() in [:dev, :test], do: :debug, else: :release
mode = :release

use RustlerPrecompiled,
otp_app: :candlex,
Expand Down
12 changes: 12 additions & 0 deletions test/candlex_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -1824,6 +1824,18 @@ defmodule CandlexTest do
# [5, 6]
# ]
# ))

t([0, 1])
|> Nx.slice([t(0)], [1])
|> assert_equal(t([0]))

t([0, 1])
|> Nx.slice([t(1)], [1])
|> assert_equal(t([1]))

t([[1, 2, 3], [4, 5, 6]])
|> Nx.slice([t(0), t(1)], [1, 1])
|> assert_equal(t([[2]]))
end

test "squeeze" do
Expand Down
86 changes: 86 additions & 0 deletions test/defn_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,90 @@ defmodule Candlex.DefnTest do
|> TG.tanh_grad()
|> assert_close(Nx.tensor(0.41997432708740234))
end

describe "while/3" do
defmodule Mod do
import Nx.Defn

defn upto10(x) do
while x, Nx.less(x, 10) do
x + 1
end
end

defn factorial_tuple(x) do
factorial = Nx.tensor(1, type: Nx.type(x))

{factorial, _} =
while {factorial, x}, Nx.greater(x, 1) do
{factorial * x, x - 1}
end

factorial
end

defn factorial_map(x) do
factorial = Nx.tensor(1, type: Nx.type(x))

%{factorial: factorial} =
while map = %{factorial: factorial, x: x}, Nx.greater(map.x, 1) do
%{map | factorial: map.factorial * map.x, x: map.x - 1}
end

factorial
end

defn factorial_map_input(map) do
%{factorial: factorial} =
while map, Nx.greater(map.x, 1) do
%{map | factorial: map.factorial * map.x, x: map.x - 1}
end

factorial
end

defn tensor_generator_sum() do
while x = 0, r <- Nx.tensor([0, 1, 2]) do
x + r
end
end
end

test "simple" do
Mod.upto10(0)
|> assert_equal(Nx.tensor(10))

Mod.upto10(5)
|> assert_equal(Nx.tensor(10))
end

test "factorial tuple" do
Mod.factorial_tuple(5)
|> assert_equal(Nx.tensor(120))

Mod.factorial_tuple(10.0)
|> assert_equal(Nx.tensor(3_628_800.0))
end

test "factorial map" do
Mod.factorial_map(5)
|> assert_equal(Nx.tensor(120))

Mod.factorial_map(10.0)
|> assert_equal(Nx.tensor(3_628_800.0))
end

test "factorial map input" do
Mod.factorial_map_input(%{factorial: 1, x: 5})
|> assert_equal(Nx.tensor(120))

Mod.factorial_map_input(%{factorial: 1.0, x: 10.0})
|> assert_equal(Nx.tensor(3_628_800.0))
end

test "tensor generator sum" do
Mod.tensor_generator_sum()
|> assert_equal(Nx.tensor(3))
end
end
end
32 changes: 32 additions & 0 deletions test/random_test.exs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
defmodule Candlex.RandomTest do
use Nx.Case, async: true

test "key/1" do
Nx.Random.key(42)
|> assert_equal(Nx.tensor([0, 42]))
end

test "uniform/1" do
{normal, new_key} =
Nx.Random.key(42)
|> Nx.Random.uniform()

normal
|> assert_close(Nx.tensor(0.9145736694335938))

new_key
|> assert_equal(Nx.tensor([2_465_931_498, 3_679_230_171]))
end

test "normal/1" do
{normal, new_key} =
Nx.Random.key(42)
|> Nx.Random.normal()

normal
|> assert_close(Nx.tensor(1.3694695234298706))

new_key
|> assert_equal(Nx.tensor([2_465_931_498, 3_679_230_171]))
end
end

0 comments on commit cbc1779

Please sign in to comment.