From 33d715d42220ab59da0730ebeb67c0202848fc66 Mon Sep 17 00:00:00 2001 From: tynanbe Date: Fri, 22 Dec 2023 16:54:15 -0600 Subject: [PATCH] Require Gleam v0.33+ --- .github/workflows/ci.yml | 8 +- CHANGELOG.md | 1 + README.md | 12 +-- gleam.toml | 10 +-- manifest.toml | 29 ++++--- src/argamak/space.gleam | 23 +++--- src/argamak/tensor.gleam | 140 +++++++++++++++------------------ src/argamak_ffi.mjs | 5 +- test/argamak/format_test.gleam | 2 +- test/argamak/space_test.gleam | 52 ++++++------ test/argamak/tensor_test.gleam | 28 +++---- test/argamak_test_ffi.mjs | 3 +- 12 files changed, 146 insertions(+), 167 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 92d4838..0ad6625 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -11,14 +11,14 @@ jobs: test: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: erlef/setup-beam@v1 with: otp-version: "25" rebar3-version: "3" - elixir-version: "1.14" - gleam-version: "0.30" + elixir-version: "1.16" + gleam-version: "0.33" - id: cache-gleam uses: actions/cache@v3 @@ -40,7 +40,7 @@ jobs: - uses: actions/setup-node@v3 with: - node-version: "18" + node-version: "20" - id: cache-node uses: actions/cache@v3 diff --git a/CHANGELOG.md b/CHANGELOG.md index 82504fd..ad0a2ca 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ ## Unreleased +- Argamak now requires Gleam v0.33 or later. - The `tensor` module's `concat` function now returns an `AxisNotFound` error when the given find function is `False` for every `Axis`. diff --git a/README.md b/README.md index 98b4c83..070b809 100644 --- a/README.md +++ b/README.md @@ -52,20 +52,21 @@ end The `@tensorflow/tfjs` package is a runtime requirement for `argamak`; however, its import path in the `argamak_ffi.mjs` module might need adjustment, depending on your use case. It can be used as is in your Node.js project after running -`npm install @tensorflow/tfjs-node` or an equivalent command. +`npm install @tensorflow/tfjs-node` or an equivalent command for your package +manager of choice. ## Usage ```gleam // derby.gleam -import argamak/axis.{Axis, Infer} -import argamak/space -import argamak/tensor.{InvalidData, TensorError} import gleam/function import gleam/io import gleam/list import gleam/result import gleam/string +import argamak/axis.{Axis, Infer} +import argamak/space +import argamak/tensor.{type TensorError, InvalidData} pub fn announce_winner( from horses: List(String), @@ -148,7 +149,8 @@ pub fn announce_winner( // Finally, we make our announcement! // - announce(horse <> " wins the day with a mean time of " <> time <> " minutes!") + { horse <> " wins the day with a mean time of " <> time <> " minutes!" } + |> announce |> Ok } ``` diff --git a/gleam.toml b/gleam.toml index 2a709e4..661dedc 100644 --- a/gleam.toml +++ b/gleam.toml @@ -1,8 +1,8 @@ name = "argamak" -version = "0.5.0-dev" +version = "1.0.0-dev" description = "A tensor library for the Gleam programming language" licences = ["Apache-2.0"] -gleam = "~> 0.30" +gleam = ">= 0.33.0" [repository] repo = "argamak" @@ -14,12 +14,12 @@ href = "https://gleam.run/" title = "Website" [dependencies] -gleam_stdlib = "~> 0.30" +gleam_stdlib = "~> 0.34" nx = "~> 0.5" [dev-dependencies] -gleeunit = "~> 0.11" -rad = "~> 0.4" +gleeunit = "~> 1.0" +rad = "~> 1.0" [rad] targets = ["erlang", "javascript"] diff --git a/manifest.toml b/manifest.toml index b71950d..0be1250 100644 --- a/manifest.toml +++ b/manifest.toml @@ -3,25 +3,24 @@ packages = [ { name = "complex", version = "0.5.0", build_tools = ["mix"], requirements = [], otp_app = "complex", source = "hex", outer_checksum = "2683BD3C184466CFB94FAD74CBFDDFAA94B860E27AD4CA1BFFE3BFF169D91EF1" }, - { name = "gleam_bitwise", version = "1.3.0", build_tools = ["gleam"], requirements = [], otp_app = "gleam_bitwise", source = "hex", outer_checksum = "E2A46EE42E5E9110DAD67E0F71E7358CBE54D5EC22C526DD48CBBA3223025792" }, - { name = "gleam_community_ansi", version = "1.1.0", build_tools = ["gleam"], requirements = ["gleam_community_colour", "gleam_stdlib", "gleam_bitwise"], otp_app = "gleam_community_ansi", source = "hex", outer_checksum = "6E4E0CF2B207C1A7FCD3C21AA43514D67BC7004F21F82045CDCCE6C727A14862" }, - { name = "gleam_community_colour", version = "1.1.0", build_tools = ["gleam"], requirements = ["gleam_bitwise", "gleam_stdlib"], otp_app = "gleam_community_colour", source = "hex", outer_checksum = "D27CE357ECB343929A8CEC3FBA0B499943A47F0EE1F589EE16AFC2DC21C61E5B" }, - { name = "gleam_http", version = "3.5.0", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "gleam_http", source = "hex", outer_checksum = "FAE9AE3EB1CA90C2194615D20FFFD1E28B630E84DACA670B28D959B37BCBB02C" }, - { name = "gleam_json", version = "0.6.0", build_tools = ["gleam"], requirements = ["thoas", "gleam_stdlib"], otp_app = "gleam_json", source = "hex", outer_checksum = "C6CC5BEECA525117E97D0905013AB3F8836537455645DDDD10FE31A511B195EF" }, - { name = "gleam_stdlib", version = "0.30.2", build_tools = ["gleam"], requirements = [], otp_app = "gleam_stdlib", source = "hex", outer_checksum = "8D8BF3790AA31176B1E1C0B517DD74C86DA8235CF3389EA02043EE4FD82AE3DC" }, - { name = "gleeunit", version = "0.11.0", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "gleeunit", source = "hex", outer_checksum = "1397E5C4AC4108769EE979939AC39BF7870659C5AFB714630DEEEE16B8272AD5" }, - { name = "glint", version = "0.11.4", build_tools = ["gleam"], requirements = ["gleam_stdlib", "snag", "gleam_community_ansi", "gleam_community_colour"], otp_app = "glint", source = "hex", outer_checksum = "9508BF037E35F549C51F9F1D2CC4736CEA7F7A49E21CCA9B4540452C7D6CC4C5" }, - { name = "nx", version = "0.6.0", build_tools = ["mix"], requirements = ["telemetry", "complex"], otp_app = "nx", source = "hex", outer_checksum = "E1AD3CC70A5828A1AEDB156B71E90863D9623A2DC9B35A5588F8627A07EE6CB4" }, - { name = "rad", version = "0.4.1", build_tools = ["gleam"], requirements = ["gleam_http", "gleam_stdlib", "thoas", "glint", "shellout", "snag", "gleam_json", "tomerl"], otp_app = "rad", source = "hex", outer_checksum = "1C993A1BF89F46B174ECFBE8A34CB135751517B2F9846D5D9D9A87CF823BFFE0" }, - { name = "shellout", version = "1.4.0", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "shellout", source = "hex", outer_checksum = "995564B69D40146B7A424CA21D32A68D668A882F88BDAD0EFA2C18C7EC412564" }, - { name = "snag", version = "0.2.1", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "snag", source = "hex", outer_checksum = "8FD70D8FB3728E08AC425283BB509BB0F012BE1AE218424A597CDE001B0EE589" }, + { name = "gleam_community_ansi", version = "1.3.0", build_tools = ["gleam"], requirements = ["gleam_stdlib", "gleam_community_colour"], otp_app = "gleam_community_ansi", source = "hex", outer_checksum = "AB7C3CCC894653637E02DC455D5890C8CF3064E83E78CFE61145A4C458D02DE6" }, + { name = "gleam_community_colour", version = "1.3.0", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "gleam_community_colour", source = "hex", outer_checksum = "A49A5E3AE8B637A5ACBA80ECB9B1AFE89FD3D5351FF6410A42B84F666D40D7D5" }, + { name = "gleam_http", version = "3.5.2", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "gleam_http", source = "hex", outer_checksum = "AECDA43AFD523D07A8F09068598A6E271C505278A0CB6F9C7A2E4365EAE8D11E" }, + { name = "gleam_json", version = "0.7.0", build_tools = ["gleam"], requirements = ["gleam_stdlib", "thoas"], otp_app = "gleam_json", source = "hex", outer_checksum = "CB405BD93A8828BCD870463DE29375E7B2D252D9D124C109E5B618AAC00B86FC" }, + { name = "gleam_stdlib", version = "0.34.0", build_tools = ["gleam"], requirements = [], otp_app = "gleam_stdlib", source = "hex", outer_checksum = "1FB8454D2991E9B4C0C804544D8A9AD0F6184725E20D63C3155F0AEB4230B016" }, + { name = "gleeunit", version = "1.0.2", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "gleeunit", source = "hex", outer_checksum = "D364C87AFEB26BDB4FB8A5ABDE67D635DC9FA52D6AB68416044C35B096C6882D" }, + { name = "glint", version = "0.13.0", build_tools = ["gleam"], requirements = ["gleam_stdlib", "gleam_community_ansi", "snag", "gleam_community_colour"], otp_app = "glint", source = "hex", outer_checksum = "46E56049CD370D61F720D319D0AB970408C9336EEB918F08B5DCB1DCE9845FA3" }, + { name = "nx", version = "0.6.4", build_tools = ["mix"], requirements = ["complex", "telemetry"], otp_app = "nx", source = "hex", outer_checksum = "BB9C2E2E3545B5EB4739D69046A988DAAA212D127DBA7D97801C291616AFF6D6" }, + { name = "rad", version = "1.0.0", build_tools = ["gleam"], requirements = ["gleam_stdlib", "thoas", "gleam_json", "glint", "snag", "tomerl", "gleam_http", "shellout"], otp_app = "rad", source = "hex", outer_checksum = "E9EAE1DC9E3F75FFDBB2C8685885646EBE7352D6B2669A993B3F4DAFDF14FF81" }, + { name = "shellout", version = "1.5.0", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "shellout", source = "hex", outer_checksum = "7B5DE499DBB3DDC25051FC1BB3770DD5466938B6A2AFA91A6FB4A4D49F4CB0D4" }, + { name = "snag", version = "0.3.0", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "snag", source = "hex", outer_checksum = "54D32E16E33655346AA3E66CBA7E191DE0A8793D2C05284E3EFB90AD2CE92BCC" }, { name = "telemetry", version = "1.2.1", build_tools = ["rebar3"], requirements = [], otp_app = "telemetry", source = "hex", outer_checksum = "DAD9CE9D8EFFC621708F99EAC538EF1CBE05D6A874DD741DE2E689C47FEAFED5" }, { name = "thoas", version = "0.4.1", build_tools = ["rebar3"], requirements = [], otp_app = "thoas", source = "hex", outer_checksum = "4918D50026C073C4AB1388437132C77A6F6F7C8AC43C60C13758CC0ADCE2134E" }, { name = "tomerl", version = "0.5.0", build_tools = ["rebar3"], requirements = [], otp_app = "tomerl", source = "hex", outer_checksum = "2A7FB62F9EBF0E75561B39255638BC2B805B437C86FEC538657E7C3B576979FA" }, ] [requirements] -gleam_stdlib = { version = "~> 0.30" } -gleeunit = { version = "~> 0.11" } +gleam_stdlib = { version = "~> 0.34" } +gleeunit = { version = "~> 1.0" } nx = { version = "~> 0.5" } -rad = { version = "~> 0.4" } +rad = { version = "~> 1.0" } diff --git a/src/argamak/space.gleam b/src/argamak/space.gleam index 59dda1f..abcfb7c 100644 --- a/src/argamak/space.gleam +++ b/src/argamak/space.gleam @@ -1,9 +1,9 @@ -import argamak/axis.{Axes, Axis, Infer} +import gleam/dict import gleam/int import gleam/list -import gleam/map import gleam/result import gleam/string +import argamak/axis.{type Axes, type Axis, Axis, Infer} /// An n-dimensional `Space` containing `Axes` of various sizes. /// @@ -315,16 +315,16 @@ pub fn merge(a: Space, b: Space) -> SpaceResult { let index = fn(x: Space) { x |> axes - |> list.index_map(with: fn(index, axis) { #(index, axis) }) - |> map.from_list + |> list.index_map(with: fn(axis, index) { #(index, axis) }) + |> dict.from_list } let a_index = index(a) let b_index = index(b) - let a_size = map.size(a_index) - let b_size = map.size(b_index) + let a_size = dict.size(a_index) + let b_size = dict.size(b_index) - let #(x, map) = case a_size < b_size { + let #(x, dict) = case a_size < b_size { True -> #(axes(b), a_index) False -> #(axes(a), b_index) } @@ -332,10 +332,10 @@ pub fn merge(a: Space, b: Space) -> SpaceResult { let #(x, errors) = x - |> list.index_map(with: fn(index, a_axis) { + |> list.index_map(with: fn(a_axis, index) { let b_axis = - map - |> map.get(index - offset) + dict + |> dict.get(index - offset) |> result.unwrap(or: a_axis) let a_name = axis.name(a_axis) let b_name = axis.name(b_axis) @@ -478,7 +478,8 @@ fn validate(space: Space) -> SpaceResult { } ValidateAcc( names: [name, ..acc.names], - inferred: acc.inferred || axis == Infer(name), + inferred: acc.inferred + || axis == Infer(name), results: [result, ..acc.results], ) } diff --git a/src/argamak/tensor.gleam b/src/argamak/tensor.gleam index 41e0a09..575ac7f 100644 --- a/src/argamak/tensor.gleam +++ b/src/argamak/tensor.gleam @@ -1,15 +1,15 @@ -import argamak/axis.{Axes, Axis, Infer} -import argamak/format.{Float32, Format, Int32} -import argamak/space.{Space} import gleam/bool +import gleam/dict import gleam/io import gleam/int import gleam/iterator import gleam/list -import gleam/map import gleam/result import gleam/string -import gleam/string_builder.{StringBuilder} +import gleam/string_builder.{type StringBuilder} +import argamak/axis.{type Axes, type Axis, Axis, Infer} +import argamak/format.{type Float32, type Format, type Int32} +import argamak/space.{type Space} /// A `Tensor` is a generic container for n-dimensional data structures. /// @@ -756,27 +756,29 @@ pub fn broadcast_over( ) -> TensorResult(a) { let new_axes = space.axes(new_space) - use mapped_axes <- result.try(result.all({ - use axis <- list.map(axes(x)) - let name = space_map(axis) - { - use axis <- list.find_map(new_axes) - case axis.name(axis) == name { - True -> - #(name, axis.size(axis)) - |> Ok - False -> Error(Nil) + use mapped_axes <- result.try( + result.all({ + use axis <- list.map(axes(x)) + let name = space_map(axis) + { + use axis <- list.find_map(new_axes) + case axis.name(axis) == name { + True -> + #(name, axis.size(axis)) + |> Ok + False -> Error(Nil) + } } - } - |> result.replace_error(IncompatibleAxes) - })) - let axis_map = map.from_list(mapped_axes) + |> result.replace_error(IncompatibleAxes) + }), + ) + let axis_dict = dict.from_list(mapped_axes) // TODO: use higher level functions? let pre_shape = { use axis <- list.map(new_axes) - axis_map - |> map.get(axis.name(axis)) + axis_dict + |> dict.get(axis.name(axis)) |> result.unwrap(or: 1) } @@ -3411,18 +3413,19 @@ pub fn concat( xs: List(Tensor(a)), with find: fn(Axis) -> Bool, ) -> TensorResult(a) { - use [x, ..rest] <- result.try(case xs { + use xs <- result.try(case xs { [_, ..] -> Ok(xs) _else -> Error(InvalidData) }) + let assert [x, ..rest] = xs let new_axes = axes(x) use index <- result.try( new_axes |> iterator.from_list |> iterator.index - |> iterator.find(one_that: fn(item) { find(item.1) }) - |> result.map(with: fn(x) { x.0 }) + |> iterator.find(one_that: fn(item) { find(item.0) }) + |> result.map(with: fn(x) { x.1 }) |> result.replace_error(AxisNotFound), ) use new_axes <- result.try({ @@ -3437,7 +3440,7 @@ pub fn concat( |> iterator.from_list |> iterator.index use new_axes, pair <- iterator.try_fold(over: pairs, from: []) - let #(i, #(a, b)) = pair + let #(#(a, b), i) = pair case axis.name(a) == axis.name(b) { True if i == index -> [axis.resize(a, axis.size(a) + axis.size(b)), ..new_axes] @@ -3803,69 +3806,61 @@ fn do_to_string(from x: Tensor(a), wrap_at column: Int, with tab: Int) -> String |> string_builder.from_string }) - let [#(_, xs)] = + let assert [#(xs, _)] = iterator.to_list({ use acc, size, i <- list.index_fold(over: shape, from: xs) let should_build = fn(j) { { j + 1 } % size == 0 } let ToStringAcc(built: built, ..) = case i { 0 -> { use acc, item <- iterator.fold(over: acc, from: to_string_acc) - let #(j, x) = item + let #(x, j) = item let builder = string_builder.append_builder(to: acc.builder, suffix: x) let should_build_j = should_build(j) - use <- bool_lazy_guard( - when: should_build_j && rank == 0, + use <- bool.lazy_guard( + when: should_build_j + && rank == 0, return: fn() { ToStringAcc(..acc, built: list.append(acc.built, [builder])) }, ) - use <- bool_lazy_guard( - when: should_build_j, - return: fn() { - let builder = - builder - |> string_builder.prepend(prefix: "[") - |> string_builder.append(suffix: "]") - ToStringAcc( - built: list.append(acc.built, [builder]), - builder: init_builder, - ) - }, - ) - use <- bool_lazy_guard( - when: should_wrap(j), - return: fn() { - let indent = string.repeat(" ", times: tab + rank) - let builder = - builder - |> string_builder.append(suffix: ",\n") - |> string_builder.append(suffix: indent) - ToStringAcc(..acc, builder: builder) - }, - ) + use <- bool.lazy_guard(when: should_build_j, return: fn() { + let builder = + builder + |> string_builder.prepend(prefix: "[") + |> string_builder.append(suffix: "]") + ToStringAcc( + built: list.append(acc.built, [builder]), + builder: init_builder, + ) + }) + use <- bool.lazy_guard(when: should_wrap(j), return: fn() { + let indent = string.repeat(" ", times: tab + rank) + let builder = + builder + |> string_builder.append(suffix: ",\n") + |> string_builder.append(suffix: indent) + ToStringAcc(..acc, builder: builder) + }) // else let builder = string_builder.append(to: builder, suffix: ", ") ToStringAcc(..acc, builder: builder) } _else -> { use acc, item <- iterator.fold(over: acc, from: to_string_acc) - let #(j, x) = item + let #(x, j) = item let builder = string_builder.append_builder(to: acc.builder, suffix: x) - use <- bool_lazy_guard( - when: should_build(j), - return: fn() { - let builder = - builder - |> string_builder.prepend(prefix: "[") - |> string_builder.append(suffix: "]") - ToStringAcc( - built: list.append(acc.built, [builder]), - builder: init_builder, - ) - }, - ) + use <- bool.lazy_guard(when: should_build(j), return: fn() { + let builder = + builder + |> string_builder.prepend(prefix: "[") + |> string_builder.append(suffix: "]") + ToStringAcc( + built: list.append(acc.built, [builder]), + builder: init_builder, + ) + }) // else let indent = string.repeat(" ", times: tab + rank - i) let builder = @@ -4226,14 +4221,3 @@ fn int_to_bool(x) { _else -> True } } - -fn bool_lazy_guard( - when requirement: Bool, - return consequence: fn() -> a, - otherwise alternative: fn() -> a, -) -> a { - case requirement { - True -> consequence() - False -> alternative() - } -} diff --git a/src/argamak_ffi.mjs b/src/argamak_ffi.mjs index 01b5774..890d49e 100644 --- a/src/argamak_ffi.mjs +++ b/src/argamak_ffi.mjs @@ -1,4 +1,5 @@ import * as tf from "@tensorflow/tfjs-node"; +import { inspect } from "../gleam_stdlib/gleam_stdlib.mjs"; import { Error as GleamError, List, Ok, Result, toList } from "./gleam.mjs"; import { CannotBroadcast, @@ -135,7 +136,7 @@ const fn = (f) => new Fn(f); export const tensor = (x, format) => fn(() => { - if (List.isList(x)) { + if (x instanceof List) { x = x.toArray(); if (!x.length) { throw new Error(Nil); @@ -330,7 +331,7 @@ export function columns() { // Format Functions // //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// -export const format_to_native = (x) => x.inspect().toLowerCase(); +export const format_to_native = (x) => inspect(x).toLowerCase(); //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// // Private Functions // diff --git a/test/argamak/format_test.gleam b/test/argamak/format_test.gleam index c82eac7..ec983a6 100644 --- a/test/argamak/format_test.gleam +++ b/test/argamak/format_test.gleam @@ -1,5 +1,5 @@ -import argamak/format import gleeunit/should +import argamak/format pub fn to_string_test() { format.float32() diff --git a/test/argamak/space_test.gleam b/test/argamak/space_test.gleam index 5fbf371..76ce4d5 100644 --- a/test/argamak/space_test.gleam +++ b/test/argamak/space_test.gleam @@ -1,9 +1,9 @@ +import gleam/list +import gleeunit/should import argamak/axis.{A, Axis, B, C, D, E, Infer, Z} import argamak/space.{ CannotInfer, CannotMerge, DuplicateName, InvalidSize, SpaceError, } -import gleam/list -import gleeunit/should //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// // Creation Functions // @@ -63,10 +63,9 @@ pub fn d2_test() { let axis_a = Axis(name: "A", size: 1) space.d2(a, axis_a) - |> should.equal(Error([ - SpaceError(InvalidSize, [a]), - SpaceError(DuplicateName, [axis_a]), - ])) + |> should.equal( + Error([SpaceError(InvalidSize, [a]), SpaceError(DuplicateName, [axis_a])]), + ) space.d2(Infer(name: "A"), infer) |> should.equal(Error([SpaceError(CannotInfer, [infer])])) @@ -94,10 +93,9 @@ pub fn d3_test() { let axis_a = Axis(name: "A", size: 1) space.d3(a, axis_a, axis) - |> should.equal(Error([ - SpaceError(InvalidSize, [a]), - SpaceError(DuplicateName, [axis_a]), - ])) + |> should.equal( + Error([SpaceError(InvalidSize, [a]), SpaceError(DuplicateName, [axis_a])]), + ) space.d3(Infer(name: "A"), infer, axis) |> should.equal(Error([SpaceError(CannotInfer, [infer])])) @@ -126,10 +124,9 @@ pub fn d4_test() { let axis_a = Axis(name: "A", size: 1) space.d4(a, b, axis_a, axis) - |> should.equal(Error([ - SpaceError(InvalidSize, [a]), - SpaceError(DuplicateName, [axis_a]), - ])) + |> should.equal( + Error([SpaceError(InvalidSize, [a]), SpaceError(DuplicateName, [axis_a])]), + ) space.d4(Infer(name: "A"), infer, b, axis) |> should.equal(Error([SpaceError(CannotInfer, [infer])])) @@ -159,10 +156,9 @@ pub fn d5_test() { let axis_a = Axis(name: "A", size: 1) space.d5(a, b, c, axis_a, axis) - |> should.equal(Error([ - SpaceError(InvalidSize, [a]), - SpaceError(DuplicateName, [axis_a]), - ])) + |> should.equal( + Error([SpaceError(InvalidSize, [a]), SpaceError(DuplicateName, [axis_a])]), + ) space.d5(Infer(name: "A"), infer, b, c, axis) |> should.equal(Error([SpaceError(CannotInfer, [infer])])) @@ -193,10 +189,9 @@ pub fn d6_test() { let axis_a = Axis(name: "A", size: 1) space.d6(a, b, c, d, axis_a, axis) - |> should.equal(Error([ - SpaceError(InvalidSize, [a]), - SpaceError(DuplicateName, [axis_a]), - ])) + |> should.equal( + Error([SpaceError(InvalidSize, [a]), SpaceError(DuplicateName, [axis_a])]), + ) space.d6(Infer(name: "A"), infer, b, c, d, axis) |> should.equal(Error([SpaceError(CannotInfer, [infer])])) @@ -233,10 +228,9 @@ pub fn from_list_test() { let axis_a = Axis(name: "A", size: 1) [a, b, c, d, e, axis_a, axis] |> space.from_list - |> should.equal(Error([ - SpaceError(InvalidSize, [a]), - SpaceError(DuplicateName, [axis_a]), - ])) + |> should.equal( + Error([SpaceError(InvalidSize, [a]), SpaceError(DuplicateName, [axis_a])]), + ) [Infer(name: "A"), infer, b, c, d, e, axis] |> space.from_list @@ -441,9 +435,9 @@ pub fn merge_test() { let assert Ok(d1) = space.d1(a) space.merge(d1, d3) - |> should.equal(Error([ - SpaceError(CannotMerge, [Infer(name: "Sparkle"), A(size: 1)]), - ])) + |> should.equal( + Error([SpaceError(CannotMerge, [Infer(name: "Sparkle"), A(size: 1)])]), + ) let assert Ok(d2) = space.d2(Infer(name: "Shine"), axis) space.merge(d3, d2) diff --git a/test/argamak/tensor_test.gleam b/test/argamak/tensor_test.gleam index a4c2c48..572e2bd 100644 --- a/test/argamak/tensor_test.gleam +++ b/test/argamak/tensor_test.gleam @@ -1,13 +1,13 @@ -import argamak/axis.{A, B, C, D, E, F, Infer, Z} -import argamak/format -import argamak/space -import argamak/tensor.{Tensor} -import gleam/dynamic.{Dynamic} +import gleam/dynamic.{type Dynamic} import gleam/float import gleam/int import gleam/list import gleam/order.{Eq} import gleeunit/should +import argamak/axis.{A, B, C, D, E, F, Infer, Z} +import argamak/format +import argamak/space +import argamak/tensor.{type Tensor} //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// // Creation Functions // @@ -525,17 +525,13 @@ pub fn broadcast_over_test() { |> should.equal(list.flat_map(over: xs, with: list.repeat(item: _, times: 2))) let assert Ok(y) = - tensor.broadcast_over( - from: x, - into: d3, - with: fn(axis) { - case axis.name(axis) { - "A" -> "A" - "B" -> "C" - name -> name - } - }, - ) + tensor.broadcast_over(from: x, into: d3, with: fn(axis) { + case axis.name(axis) { + "A" -> "A" + "B" -> "C" + name -> name + } + }) y |> tensor.space |> space.axes diff --git a/test/argamak_test_ffi.mjs b/test/argamak_test_ffi.mjs index ab16370..e5d55a9 100644 --- a/test/argamak_test_ffi.mjs +++ b/test/argamak_test_ffi.mjs @@ -1,6 +1,7 @@ import { tensor as tf_tensor } from "@tensorflow/tfjs-node"; +import { inspect } from "../gleam_stdlib/gleam_stdlib.mjs"; -export const tensor = (x) => tf_tensor(eval(x.inspect())); +export const tensor = (x) => tf_tensor(eval(inspect(x))); export const shape = (x) => x.shape;