diff --git a/lib/candlex/backend.ex b/lib/candlex/backend.ex index 4e68eba..dde5dc6 100644 --- a/lib/candlex/backend.ex +++ b/lib/candlex/backend.ex @@ -378,7 +378,7 @@ defmodule Candlex.Backend do # Indexed @impl true - def gather(%T{} = out, %T{shape: {_}} = tensor, %T{} = indices) do + def gather(%T{} = out, %T{shape: {_}} = tensor, %T{} = indices, _opts) do tensor |> from_nx() |> Native.gather(from_nx(Nx.flatten(indices)), 0) @@ -391,7 +391,7 @@ defmodule Candlex.Backend do end @impl true - def indexed_add(%T{} = out, %T{shape: {_}} = tensor, %T{} = indices, %T{} = updates) do + def indexed_add(%T{} = out, %T{shape: {_}} = tensor, %T{} = indices, %T{} = updates, _opts) do {tensor, updates} = maybe_upcast(tensor, updates) tensor @@ -887,7 +887,6 @@ defmodule Candlex.Backend do end for op <- [ - :indexed_put, :map, :triangular_solve, :window_max, @@ -901,9 +900,14 @@ defmodule Candlex.Backend do end end - @impl true - def reduce(_out, _tensor, _, _, _) do - raise "unsupported Candlex.Backend.reduce function" + for op <- [ + :indexed_put, + :reduce + ] do + @impl true + def unquote(op)(_out, _tensor, _, _, _) do + raise "unsupported Candlex.Backend.#{unquote(op)} function" + end end for op <- [ diff --git a/mix.lock b/mix.lock index 2abdd0f..18d290d 100644 --- a/mix.lock +++ b/mix.lock @@ -1,14 +1,14 @@ %{ "castore": {:hex, :castore, "1.0.4", "ff4d0fb2e6411c0479b1d965a814ea6d00e51eb2f58697446e9c41a97d940b28", [:mix], [], "hexpm", "9418c1b8144e11656f0be99943db4caf04612e3eaecefb5dae9a2a87565584f8"}, "complex": {:hex, :complex, "0.5.0", "af2d2331ff6170b61bb738695e481b27a66780e18763e066ee2cd863d0b1dd92", [:mix], [], "hexpm", "2683bd3c184466cfb94fad74cbfddfaa94b860e27ad4ca1bffe3bff169d91ef1"}, - "earmark_parser": {:hex, :earmark_parser, "1.4.37", "2ad73550e27c8946648b06905a57e4d454e4d7229c2dafa72a0348c99d8be5f7", [:mix], [], "hexpm", "6b19783f2802f039806f375610faa22da130b8edc21209d0bff47918bb48360e"}, + "earmark_parser": {:hex, :earmark_parser, "1.4.38", "b42252eddf63bda05554ba8be93a1262dc0920c721f1aaf989f5de0f73a2e367", [:mix], [], "hexpm", "2cd0907795aaef0c7e8442e376633c5b3bd6edc8dbbdc539b22f095501c1cdb6"}, "ex_doc": {:hex, :ex_doc, "0.30.9", "d691453495c47434c0f2052b08dd91cc32bc4e1a218f86884563448ee2502dd2", [:mix], [{:earmark_parser, "~> 1.4.31", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_elixir, "~> 0.14", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1", [hex: :makeup_erlang, repo: "hexpm", optional: false]}], "hexpm", "d7aaaf21e95dc5cddabf89063327e96867d00013963eadf2c6ad135506a8bc10"}, "jason": {:hex, :jason, "1.4.1", "af1504e35f629ddcdd6addb3513c3853991f694921b1b9368b0bd32beb9f1b63", [:mix], [{:decimal, "~> 1.0 or ~> 2.0", [hex: :decimal, repo: "hexpm", optional: true]}], "hexpm", "fbb01ecdfd565b56261302f7e1fcc27c4fb8f32d56eab74db621fc154604a7a1"}, - "makeup": {:hex, :makeup, "1.1.0", "6b67c8bc2882a6b6a445859952a602afc1a41c2e08379ca057c0f525366fc3ca", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "0a45ed501f4a8897f580eabf99a2e5234ea3e75a4373c8a52824f6e873be57a6"}, + "makeup": {:hex, :makeup, "1.1.1", "fa0bc768698053b2b3869fa8a62616501ff9d11a562f3ce39580d60860c3a55e", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "5dc62fbdd0de44de194898b6710692490be74baa02d9d108bc29f007783b0b48"}, "makeup_elixir": {:hex, :makeup_elixir, "0.16.1", "cc9e3ca312f1cfeccc572b37a09980287e243648108384b97ff2b76e505c3555", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "e127a341ad1b209bd80f7bd1620a15693a9908ed780c3b763bccf7d200c767c6"}, "makeup_erlang": {:hex, :makeup_erlang, "0.1.2", "ad87296a092a46e03b7e9b0be7631ddcf64c790fa68a9ef5323b6cbb36affc72", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "f3f5a1ca93ce6e092d92b6d9c049bcda58a3b617a8d888f8e7231c85630e8108"}, - "nimble_parsec": {:hex, :nimble_parsec, "1.3.1", "2c54013ecf170e249e9291ed0a62e5832f70a476c61da16f6aac6dca0189f2af", [:mix], [], "hexpm", "2682e3c0b2eb58d90c6375fc0cc30bc7be06f365bf72608804fb9cffa5e1b167"}, - "nx": {:git, "https://github.com/elixir-nx/nx", "27e7b5658b6d88ca5e9106ef0f09ad173bb0f154", [sparse: "nx"]}, + "nimble_parsec": {:hex, :nimble_parsec, "1.4.0", "51f9b613ea62cfa97b25ccc2c1b4216e81df970acd8e16e8d1bdc58fef21370d", [:mix], [], "hexpm", "9c565862810fb383e9838c1dd2d7d2c437b3d13b267414ba6af33e50d2d1cf28"}, + "nx": {:git, "https://github.com/elixir-nx/nx", "e1b776ed2a49498cbf2465862b2fba5a0df6f43b", [sparse: "nx"]}, "rustler": {:hex, :rustler, "0.30.0", "cefc49922132b072853fa9b0ca4dc2ffcb452f68fb73b779042b02d545e097fb", [:mix], [{:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: false]}, {:toml, "~> 0.6", [hex: :toml, repo: "hexpm", optional: false]}], "hexpm", "9ef1abb6a7dda35c47cfc649e6a5a61663af6cf842a55814a554a84607dee389"}, "rustler_precompiled": {:hex, :rustler_precompiled, "0.7.0", "5d0834fc06dbc76dd1034482f17b1797df0dba9b491cef8bb045fcaca94bcade", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:rustler, "~> 0.23", [hex: :rustler, repo: "hexpm", optional: true]}], "hexpm", "fdf43a6835f4e4de5bfbc4c019bfb8c46d124bd4635fefa3e20d9a2bbbec1512"}, "telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"}, diff --git a/test/candlex_test.exs b/test/candlex_test.exs index d07eda4..7a30982 100644 --- a/test/candlex_test.exs +++ b/test/candlex_test.exs @@ -1639,6 +1639,38 @@ defmodule CandlexTest do # t([[[1, 2], [11, 12]], [[101, 102], [111, 112]]]) # |> Nx.gather(t([[0, 0, 0], [0, 1, 1], [1, 1, 1]])) # |> assert_equal(t([1, 12, 112])) + + # t([[1, 2, 3], [4, 5, 6]]) + # |> Nx.gather(t([[1], [0], [2], [1]]), axes: [1]) + # |> assert_equal(t( + # [ + # [2, 5], + # [1, 4], + # [3, 6], + # [2, 5] + # ] + # )) + + # Nx.iota({2, 1, 3}) + # |> Nx.gather(t([[[1], [0], [2]]]), axes: [2]) + # |> assert_equal(t( + # [ + # [ + # [ + # [1], + # [4] + # ], + # [ + # [0], + # [3] + # ], + # [ + # [2], + # [5] + # ] + # ] + # ] + # )) end test "indexed_add" do