Skip to content

Commit

Permalink
tensor.concat shouldn't succeed unless its find fn does
Browse files Browse the repository at this point in the history
  • Loading branch information
tynanbe committed Sep 14, 2023
1 parent 464da52 commit 934e066
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 11 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# Changelog

## Unreleased

- The `tensor` module's `concat` function now returns an `AxisNotFound` error
when the given find function is `False` for every `Axis`.

## v0.4.0 - 2023-09-11

- Argamak now requires Gleam v0.30 or later.
Expand Down
2 changes: 1 addition & 1 deletion gleam.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
name = "argamak"
version = "0.4.0"
version = "0.5.0-dev"
description = "A tensor library for the Gleam programming language"
licences = ["Apache-2.0"]
gleam = "~> 0.30"
Expand Down
16 changes: 7 additions & 9 deletions src/argamak/tensor.gleam
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ pub type Native
/// When a tensor operation cannot succeed.
///
pub type TensorError {
AxisNotFound
CannotBroadcast
IncompatibleAxes
IncompatibleShape
Expand Down Expand Up @@ -3357,8 +3358,8 @@ pub fn in_situ_mean(
/// The first `Axis` for which the given `find` function returns `True` is
/// selected for joining.
///
/// If the `find` function returns `False` for every `Axis`, the tensors will
/// be joined along the first `Axis`.
/// If the `find` function returns `False` for every `Axis`, an `AxisNotFound`
/// error is returned.
///
/// ## Examples
///
Expand Down Expand Up @@ -3390,6 +3391,9 @@ pub fn in_situ_mean(
/// )
/// Nil
///
/// > concat([a, b], with: fn(_) { False })
/// Error(AxisNotFound)
///
/// > concat([a, b], with: fn(a) { axis.name(a) == "X" })
/// Error(IncompatibleShape)
///
Expand Down Expand Up @@ -3419,13 +3423,7 @@ pub fn concat(
|> iterator.index
|> iterator.find(one_that: fn(item) { find(item.1) })
|> result.map(with: fn(x) { x.0 })
|> result.lazy_or(fn() {
case new_axes {
[_, ..] -> Ok(0)
_else -> Error(Nil)
}
})
|> result.replace_error(IncompatibleShape),
|> result.replace_error(AxisNotFound),
)
use new_axes <- result.try({
use new_axes, x <- list.try_fold(over: rest, from: new_axes)
Expand Down
6 changes: 5 additions & 1 deletion test/argamak/tensor_test.gleam
Original file line number Diff line number Diff line change
Expand Up @@ -2003,7 +2003,7 @@ pub fn concat_test() {
let assert Ok(x) =
x
|> list.repeat(times: 3)
|> tensor.concat(with: fn(_) { False })
|> tensor.concat(with: fn(_) { True })
x
|> should_share_native_format
|> tensor.axes
Expand All @@ -2026,6 +2026,10 @@ pub fn concat_test() {
|> tensor.to_floats
|> should.equal([0.0, 1.0, 4.0, 2.0, 3.0, 5.0])

[a, b]
|> tensor.concat(with: fn(_) { False })
|> should.equal(Error(tensor.AxisNotFound))

let error = Error(tensor.IncompatibleShape)

[a, b]
Expand Down

0 comments on commit 934e066

Please sign in to comment.