From 934e06660b04f7bdfc20bb845bb9ae420cd688a7 Mon Sep 17 00:00:00 2001 From: tynanbe Date: Thu, 14 Sep 2023 08:54:12 -0500 Subject: [PATCH] `tensor.concat` shouldn't succeed unless its `find` fn does --- CHANGELOG.md | 5 +++++ gleam.toml | 2 +- src/argamak/tensor.gleam | 16 +++++++--------- test/argamak/tensor_test.gleam | 6 +++++- 4 files changed, 18 insertions(+), 11 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b1b988a..82504fd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,10 @@ # Changelog +## Unreleased + +- The `tensor` module's `concat` function now returns an `AxisNotFound` error + when the given find function is `False` for every `Axis`. + ## v0.4.0 - 2023-09-11 - Argamak now requires Gleam v0.30 or later. diff --git a/gleam.toml b/gleam.toml index 2c5ed71..2a709e4 100644 --- a/gleam.toml +++ b/gleam.toml @@ -1,5 +1,5 @@ name = "argamak" -version = "0.4.0" +version = "0.5.0-dev" description = "A tensor library for the Gleam programming language" licences = ["Apache-2.0"] gleam = "~> 0.30" diff --git a/src/argamak/tensor.gleam b/src/argamak/tensor.gleam index 71a1b1c..41e0a09 100644 --- a/src/argamak/tensor.gleam +++ b/src/argamak/tensor.gleam @@ -24,6 +24,7 @@ pub type Native /// When a tensor operation cannot succeed. /// pub type TensorError { + AxisNotFound CannotBroadcast IncompatibleAxes IncompatibleShape @@ -3357,8 +3358,8 @@ pub fn in_situ_mean( /// The first `Axis` for which the given `find` function returns `True` is /// selected for joining. /// -/// If the `find` function returns `False` for every `Axis`, the tensors will -/// be joined along the first `Axis`. +/// If the `find` function returns `False` for every `Axis`, an `AxisNotFound` +/// error is returned. /// /// ## Examples /// @@ -3390,6 +3391,9 @@ pub fn in_situ_mean( /// ) /// Nil /// +/// > concat([a, b], with: fn(_) { False }) +/// Error(AxisNotFound) +/// /// > concat([a, b], with: fn(a) { axis.name(a) == "X" }) /// Error(IncompatibleShape) /// @@ -3419,13 +3423,7 @@ pub fn concat( |> iterator.index |> iterator.find(one_that: fn(item) { find(item.1) }) |> result.map(with: fn(x) { x.0 }) - |> result.lazy_or(fn() { - case new_axes { - [_, ..] -> Ok(0) - _else -> Error(Nil) - } - }) - |> result.replace_error(IncompatibleShape), + |> result.replace_error(AxisNotFound), ) use new_axes <- result.try({ use new_axes, x <- list.try_fold(over: rest, from: new_axes) diff --git a/test/argamak/tensor_test.gleam b/test/argamak/tensor_test.gleam index 8cd79cf..a4c2c48 100644 --- a/test/argamak/tensor_test.gleam +++ b/test/argamak/tensor_test.gleam @@ -2003,7 +2003,7 @@ pub fn concat_test() { let assert Ok(x) = x |> list.repeat(times: 3) - |> tensor.concat(with: fn(_) { False }) + |> tensor.concat(with: fn(_) { True }) x |> should_share_native_format |> tensor.axes @@ -2026,6 +2026,10 @@ pub fn concat_test() { |> tensor.to_floats |> should.equal([0.0, 1.0, 4.0, 2.0, 3.0, 5.0]) + [a, b] + |> tensor.concat(with: fn(_) { False }) + |> should.equal(Error(tensor.AxisNotFound)) + let error = Error(tensor.IncompatibleShape) [a, b]