From fd7b44419d4c129aafeb24330d43e4cebb3e991c Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Thu, 9 Nov 2023 17:26:48 -0300 Subject: [PATCH] feat: qr/2 for square matrices --- lib/candlex/backend.ex | 14 +- lib/candlex/native.ex | 1 + native/candlex/Cargo.lock | 347 +++++++++++++++++++++++++++++++++- native/candlex/Cargo.toml | 1 + native/candlex/src/lib.rs | 1 + native/candlex/src/tensors.rs | 34 ++++ test/candlex_test.exs | 170 +++++++++++++++++ 7 files changed, 565 insertions(+), 3 deletions(-) diff --git a/lib/candlex/backend.ex b/lib/candlex/backend.ex index 4e68eba..ab48eb3 100644 --- a/lib/candlex/backend.ex +++ b/lib/candlex/backend.ex @@ -855,6 +855,19 @@ defmodule Candlex.Backend do |> Stream.map(&to_nx(&1, out)) end + # LinAlg + + @impl true + def qr({out_q, out_r}, %T{shape: {n, n}} = tensor, _opts) do + {native_q, native_r} = + tensor + |> from_nx() + |> Native.qr() + |> unwrap!() + + {to_nx(native_q, out_q), to_nx(native_r, out_r)} + end + for op <- [ :cholesky, :conjugate, @@ -876,7 +889,6 @@ defmodule Candlex.Backend do :ifft, :lu, :product, - :qr, :reverse, :sort ] do diff --git a/lib/candlex/native.ex b/lib/candlex/native.ex index b85f0b6..6d0048d 100644 --- a/lib/candlex/native.ex +++ b/lib/candlex/native.ex @@ -75,6 +75,7 @@ defmodule Candlex.Native do :log, :log1p, :negate, + :qr, :round, :rsqrt, :sigmoid, diff --git a/native/candlex/Cargo.lock b/native/candlex/Cargo.lock index 43292a3..3d9ebed 100644 --- a/native/candlex/Cargo.lock +++ b/native/candlex/Cargo.lock @@ -44,6 +44,29 @@ dependencies = [ "num-traits", ] +[[package]] +name = "assert2" +version = "0.3.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eaf98d1183406dcb8f8b545e1f24829d75c1a9d35eec4b86309a22aa8b6d8e95" +dependencies = [ + "assert2-macros", + "is-terminal", + "yansi", +] + +[[package]] +name = "assert2-macros" +version = "0.3.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c55bdf3e6f792f8f1c750bb6886b7ca40fa5a354ddb7a4dee550b93985a9235" +dependencies = [ + "proc-macro2", + "quote", + "rustc_version", + "syn 1.0.109", +] + [[package]] name = "autocfg" version = "1.1.0" @@ -71,6 +94,12 @@ version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" +[[package]] +name = "bitflags" +version = "2.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "327762f6e5a765692301e5bb513e0d9fef63be86bbc14528052b1cd3e6f03e07" + [[package]] name = "bytemuck" version = "1.14.0" @@ -135,6 +164,7 @@ version = "0.1.0" dependencies = [ "anyhow", "candle-core", + "faer", "half", "num-traits", "rustler", @@ -157,6 +187,12 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "coe-rs" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e8f1e641542c07631228b1e0dc04b69ae3c1d58ef65d5691a439711d805c698" + [[package]] name = "crc32fast" version = "1.3.2" @@ -230,6 +266,169 @@ version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" +[[package]] +name = "errno" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c18ee0ed65a5f1f81cac6b1d213b69c35fa47d4252ad41f1486dbd8226fe36e" +dependencies = [ + "libc", + "windows-sys", +] + +[[package]] +name = "faer" +version = "0.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04b7f7dad8a4aeff63096a8db31ddfa544baaff47b1cf7a5db31b56b51819a8d" +dependencies = [ + "assert2", + "coe-rs", + "dyn-stack", + "faer-cholesky", + "faer-core", + "faer-evd", + "faer-lu", + "faer-qr", + "faer-svd", + "matrixcompare", + "num-complex", + "pulp 0.17.1", + "reborrow", +] + +[[package]] +name = "faer-cholesky" +version = "0.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51b4f662b57a31fb743240936602063637623cb871acc7a86479df81f106e99a" +dependencies = [ + "assert2", + "bytemuck", + "dyn-stack", + "faer-core", + "faer-entity", + "num-complex", + "num-traits", + "pulp 0.17.1", + "reborrow", + "seq-macro", +] + +[[package]] +name = "faer-core" +version = "0.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09e9ad4c2acfd0e0c911cbd7276b8fb6a7e22ccbd998b4080e148cf001f2b267" +dependencies = [ + "assert2", + "bytemuck", + "coe-rs", + "dyn-stack", + "faer-entity", + "gemm", + "matrixcompare-core", + "num-complex", + "num-traits", + "paste", + "pulp 0.17.1", + "rayon", + "reborrow", + "seq-macro", +] + +[[package]] +name = "faer-entity" +version = "0.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f5241a0fce43ff8bda51d8ecb4427e152a75b2340edb97b308b63021ed2a1967" +dependencies = [ + "bytemuck", + "coe-rs", + "libm", + "num-complex", + "num-traits", + "pulp 0.17.1", + "reborrow", +] + +[[package]] +name = "faer-evd" +version = "0.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e4f18e2ea44465f15ea22267280d00c9ac1dec8120fff372771ad6de1ef0a1e" +dependencies = [ + "assert2", + "bytemuck", + "coe-rs", + "dyn-stack", + "faer-core", + "faer-entity", + "faer-qr", + "libm", + "num-complex", + "num-traits", + "pulp 0.17.1", + "reborrow", +] + +[[package]] +name = "faer-lu" +version = "0.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "14a7d382978305e2016204cee01ef4846c509db85a6631613aefa1ac5bd540d6" +dependencies = [ + "assert2", + "bytemuck", + "coe-rs", + "dyn-stack", + "faer-core", + "faer-entity", + "num-complex", + "num-traits", + "paste", + "pulp 0.17.1", + "reborrow", +] + +[[package]] +name = "faer-qr" +version = "0.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e80b5e11c7cdc417bc124c92e9c47c02cddc13b6e87a1381bb276a0bf5e07e73" +dependencies = [ + "assert2", + "bytemuck", + "coe-rs", + "dyn-stack", + "faer-core", + "faer-entity", + "num-complex", + "num-traits", + "pulp 0.17.1", + "rayon", + "reborrow", +] + +[[package]] +name = "faer-svd" +version = "0.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41303083ad96da55336b11c667a65761060741d6f5d72c90a2437fb79b3e3510" +dependencies = [ + "assert2", + "bytemuck", + "coe-rs", + "dyn-stack", + "faer-core", + "faer-entity", + "faer-qr", + "num-complex", + "num-traits", + "pulp 0.17.1", + "reborrow", +] + [[package]] name = "gemm" version = "0.16.14" @@ -293,7 +492,7 @@ dependencies = [ "num-traits", "once_cell", "paste", - "pulp", + "pulp 0.16.4", "raw-cpuid", "rayon", "seq-macro", @@ -396,6 +595,17 @@ version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d77f7ec81a6d05a3abb01ab6eb7590f6083d08449fe5a1c8b1e620283546ccb7" +[[package]] +name = "is-terminal" +version = "0.4.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb0889898416213fab133e1d33a0e5858a48177452750691bde3666d0fdbaf8b" +dependencies = [ + "hermit-abi", + "rustix", + "windows-sys", +] + [[package]] name = "itoa" version = "1.0.9" @@ -420,6 +630,28 @@ version = "0.2.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f7012b1bbb0719e1097c47611d3898568c546d597c2e74d66f6087edd5233ff4" +[[package]] +name = "linux-raw-sys" +version = "0.4.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "969488b55f8ac402214f3f5fd243ebb7206cf82de60d3172994707a4bcc2b829" + +[[package]] +name = "matrixcompare" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37832ba820e47c93d66b4360198dccb004b43c74abc3ac1ce1fed54e65a80445" +dependencies = [ + "matrixcompare-core", + "num-traits", +] + +[[package]] +name = "matrixcompare-core" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0bdabb30db18805d5290b3da7ceaccbddba795620b86c02145d688e04900a73" + [[package]] name = "matrixmultiply" version = "0.3.8" @@ -591,6 +823,17 @@ dependencies = [ "num-complex", ] +[[package]] +name = "pulp" +version = "0.17.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b3bef241cc27671d8dc6425796134738f2d76fd4ebc878130e069a9c1d74e8b0" +dependencies = [ + "bytemuck", + "libm", + "num-complex", +] + [[package]] name = "quote" version = "1.0.33" @@ -646,7 +889,7 @@ version = "10.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6c297679cb867470fa8c9f67dbba74a78d78e3e98d7cf2b08d6d71540f797332" dependencies = [ - "bitflags", + "bitflags 1.3.2", ] [[package]] @@ -716,6 +959,28 @@ version = "0.1.23" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" +[[package]] +name = "rustc_version" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366" +dependencies = [ + "semver", +] + +[[package]] +name = "rustix" +version = "0.38.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b426b0506e5d50a7d8dafcf2e81471400deb602392c7dd110815afb4eaf02a3" +dependencies = [ + "bitflags 2.4.1", + "errno", + "libc", + "linux-raw-sys", + "windows-sys", +] + [[package]] name = "rustler" version = "0.30.0" @@ -780,6 +1045,12 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "semver" +version = "1.0.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "836fa6a3e1e547f9a2c4040802ec865b5d85f4014efe00555d7090a3dcaa1090" + [[package]] name = "seq-macro" version = "0.3.5" @@ -952,6 +1223,78 @@ dependencies = [ "safe_arch", ] +[[package]] +name = "windows-sys" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" +dependencies = [ + "windows-targets", +] + +[[package]] +name = "windows-targets" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" + +[[package]] +name = "windows_i686_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" + +[[package]] +name = "windows_i686_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" + +[[package]] +name = "yansi" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09041cd90cf85f7f8b2df60c646f853b7f535ce68f85244eb6731cf89fa498ec" + [[package]] name = "yoke" version = "0.7.2" diff --git a/native/candlex/Cargo.toml b/native/candlex/Cargo.toml index 306760b..a4551fc 100644 --- a/native/candlex/Cargo.toml +++ b/native/candlex/Cargo.toml @@ -11,6 +11,7 @@ crate-type = ["cdylib"] [dependencies] candle-core = { git = "https://github.com/huggingface/candle" } +faer = "0.14.1" half = "2.3.1" num-traits = "0.2.17" rustler = { version = "0.30.0", default-features = false, features = ["derive", "nif_version_2_16"] } diff --git a/native/candlex/src/lib.rs b/native/candlex/src/lib.rs index ef72f43..35ebbad 100644 --- a/native/candlex/src/lib.rs +++ b/native/candlex/src/lib.rs @@ -112,6 +112,7 @@ rustler::init! { tensors::logical_xor, tensors::left_shift, tensors::right_shift, + tensors::qr, tensors::to_device, devices::is_cuda_available ], diff --git a/native/candlex/src/tensors.rs b/native/candlex/src/tensors.rs index 78c94b4..dacfbee 100644 --- a/native/candlex/src/tensors.rs +++ b/native/candlex/src/tensors.rs @@ -270,6 +270,40 @@ pub fn reshape(t: ExTensor, shape: Term) -> Result { Ok(ExTensor::new(t.reshape(tuple_to_vec(shape).unwrap())?)) } +#[rustler::nif(schedule = "DirtyCpu")] +pub fn qr(tensor: ExTensor) -> Result<(ExTensor, ExTensor), CandlexError> { + use faer::Faer; + + let side = tensor.dims()[0]; + let device = tensor.device(); + + let vec = tensor.to_vec2::()?; + let mat = faer::Mat::from_fn(side, side, |row, col| vec[row][col]); + + let qr = mat.qr(); + let q = qr.compute_q(); + let r = qr.compute_r(); + + let mut q_res = vec![]; + let mut r_res = vec![]; + + let transposed_q = q.transpose().to_owned(); + let transposed_r = r.transpose().to_owned(); + + for i in 0..q.ncols() { + q_res.extend_from_slice(transposed_q.col_ref(i)); + } + + for i in 0..r.ncols() { + r_res.extend_from_slice(transposed_r.col_ref(i)); + } + + Ok(( + ExTensor::new(Tensor::new(q_res, &device)?.reshape((side, side))?), + ExTensor::new(Tensor::new(r_res, &device)?.reshape((side, side))?), + )) +} + #[rustler::nif(schedule = "DirtyCpu")] pub fn slice_scatter( t: ExTensor, diff --git a/test/candlex_test.exs b/test/candlex_test.exs index d07eda4..02bce1c 100644 --- a/test/candlex_test.exs +++ b/test/candlex_test.exs @@ -2259,6 +2259,176 @@ defmodule CandlexTest do |> assert_equal(t([[1], [1]])) end + test "qr" do + square = t([[-3.0, 2, 1], [0, 1, 1], [0, 0, -1]]) + + {q, r} = + square + |> Nx.LinAlg.qr() + + q + |> assert_equal( + t([ + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0] + ]) + ) + + r + |> assert_equal( + t([ + [-3.0, 2.0, 1.0], + [0.0, 1.0, 1.0], + [0.0, 0.0, -1.0] + ]) + ) + + Nx.dot(q, r) + |> assert_equal(square) + + # {q, r} = + # t([[3, 2, 1], [0, 1, 1], [0, 0, 1]]) + # |> Nx.LinAlg.qr() + + # q + # |> assert_equal( + # t([ + # [1.0, 0.0, 0.0], + # [0.0, 1.0, 0.0], + # [0.0, 0.0, 1.0] + # ]) + # ) + + # r + # |> assert_equal( + # t([ + # [3.0, 2.0, 1.0], + # [0.0, 1.0, 1.0], + # [0.0, 0.0, 1.0] + # ]) + # ) + + # {qs, rs} = + # t([[[-3, 2, 1], [0, 1, 1], [0, 0, -1]], [[3, 2, 1], [0, 1, 1], [0, 0, 1]]]) + # |> Nx.LinAlg.qr() + + # qs + # |> assert_equal( + # t([ + # [ + # [1.0, 0.0, 0.0], + # [0.0, 1.0, 0.0], + # [0.0, 0.0, 1.0] + # ], + # [ + # [1.0, 0.0, 0.0], + # [0.0, 1.0, 0.0], + # [0.0, 0.0, 1.0] + # ] + # ]) + # ) + + # rs + # |> assert_equal( + # t([ + # [ + # [-3.0, 2.0, 1.0], + # [0.0, 1.0, 1.0], + # [0.0, 0.0, -1.0] + # ], + # [ + # [3.0, 2.0, 1.0], + # [0.0, 1.0, 1.0], + # [0.0, 0.0, 1.0] + # ] + # ]) + # ) + + # {q, r} = + # t([[3, 2, 1], [0, 1, 1], [0, 0, 1], [0, 0, 1]], type: :f32) + # |> Nx.LinAlg.qr(mode: :reduced) + + # q + # |> assert_equal( + # t([ + # [1.0, 0.0, 0.0], + # [0.0, 1.0, 0.0], + # [0.0, 0.0, 0.7071067690849304], + # [0.0, 0.0, 0.7071067690849304] + # ]) + # ) + + # r + # |> assert_equal( + # t([ + # [3.0, 2.0, 1.0], + # [0.0, 1.0, 1.0], + # [0.0, 0.0, 1.4142135381698608] + # ]) + # ) + + # t = Nx.tensor([[3, 2, 1], [0, 1, 1], [0, 0, 1], [0, 0, 0]], type: :f32) + # {q, r} = Nx.LinAlg.qr(t, mode: :complete) + # q + ## Nx.Tensor< + # f32[4][4] + # [ + # [1.0, 0.0, 0.0, 0.0], + # [0.0, 1.0, 0.0, 0.0], + # [0.0, 0.0, 1.0, 0.0], + # [0.0, 0.0, 0.0, 1.0] + # ] + # > + # r + ## Nx.Tensor< + # f32[4][3] + # [ + # [3.0, 2.0, 1.0], + # [0.0, 1.0, 1.0], + # [0.0, 0.0, 1.0], + # [0.0, 0.0, 0.0] + # ] + # > + + # t = Nx.tensor([[[-3, 2, 1], [0, 1, 1], [0, 0, -1]],[[3, 2, 1], [0, 1, 1], [0, 0, 1]]]) |> Nx.vectorize(x: 2) + # {qs, rs} = Nx.LinAlg.qr(t) + # qs + ## Nx.Tensor< + # vectorized[x: 2] + # f32[3][3] + # [ + # [ + # [1.0, 0.0, 0.0], + # [0.0, 1.0, 0.0], + # [0.0, 0.0, 1.0] + # ], + # [ + # [1.0, 0.0, 0.0], + # [0.0, 1.0, 0.0], + # [0.0, 0.0, 1.0] + # ] + # ] + # > + # rs + ## Nx.Tensor< + # vectorized[x: 2] + # f32[3][3] + # [ + # [ + # [-3.0, 2.0, 1.0], + # [0.0, 1.0, 1.0], + # [0.0, 0.0, -1.0] + # ], + # [ + # [3.0, 2.0, 1.0], + # [0.0, 1.0, 1.0], + # [0.0, 0.0, 1.0] + # ] + # ] + # > + end + if Candlex.Backend.cuda_available?() do test "different devices" do t([1, 2, 3], backend: {Candlex.Backend, device: :cpu})