diff --git a/README.md b/README.md index 66d6372..98108c8 100644 --- a/README.md +++ b/README.md @@ -5,16 +5,16 @@ Package `tensor` is a package that provides efficient, generic (by some definiti The main purpose of this package is to support the operations required by [Gorgonia](https://gorgonia.org/gorgonia). ## Introduction ## -In the data analysis world, [Numpy](http://http://www.numpy.org/) and [Matlab](https://www.mathworks.com/products/matlab.html) currently reign supreme. Both tools rely heavily on having performant n-dimensional arrays, or tensors. **There is an obvious need for multidimensional arrays in Go**. +In the data analysis world, [Numpy](http://http://www.numpy.org/) and [Matlab](https://www.mathworks.com/products/matlab.html) currently reign supreme. Both tools rely heavily on having performant n-dimensional arrays, or tensors. **There is an obvious need for multidimensional arrays in Go**. While slices are cool, a large majority of scientific and numeric computing work relies heavily on matrices (two-dimensional arrays), three dimensional arrays and so on. In Go, the typical way of getting multidimensional arrays is to use something like `[][]T`. Applications that are more math heavy may opt to use the very excellent Gonum [`matrix` package](https://github.com/gonum/matrix). What then if we want to go beyond having a `float64` matrix? What if we wanted a 3-dimensional `float32` array? -It comes to reason then there should be a data structure that handles these things. The `tensor` package fits in that niche. +It comes to reason then there should be a data structure that handles these things. The `tensor` package fits in that niche. ### Basic Idea: Tensor ### A tensor is a multidimensional array. It's like a slice, but works in multiple dimensions. -With slices, there are usage patterns that are repeated enough that warrant abstraction - `append`, `len`, `cap`, `range` are abstractions used to manipulate and query slices. Additionally slicing operations (`a[:1]` for example) are also abstractions provided by the language. Andrew Gerrand wrote a very good write up on [Go's slice usage and internals](https://blog.golang.org/go-slices-usage-and-internals). +With slices, there are usage patterns that are repeated enough that warrant abstraction - `append`, `len`, `cap`, `range` are abstractions used to manipulate and query slices. Additionally slicing operations (`a[:1]` for example) are also abstractions provided by the language. Andrew Gerrand wrote a very good write up on [Go's slice usage and internals](https://blog.golang.org/go-slices-usage-and-internals). Tensors come with their own set of usage patterns and abstractions. Most of these have analogues in slices, enumerated below (do note that certain slice operation will have more than one tensor analogue - this is due to the number of options available): @@ -26,7 +26,7 @@ Tensors come with their own set of usage patterns and abstractions. Most of thes | `a[0]` | `T.At(x,y)` | | `append(a, ...)`| `T.Stack(...)`, `T.Concat(...)` | | `copy(dest, src)`| `T.CopyTo(dest)`, `tensor.Copy(dest, src)` | -| `for _, v := range a` | `for i, err := iterator.Next(); err == nil; i, err = iterator.Next()` | +| `for _, v := range a` | `for i, err := iterator.Next(); err == nil; i, err = iterator.Next()` | Some operations for a tensor does not have direct analogues to slice operations. However, they stem from the same idea, and can be considered a superset of all operations common to slices. They're enumerated below: @@ -77,7 +77,7 @@ fmt.Printf("a:\n%v\n", a) To create a 3-Tensor is just as easy - just put the correct shape and you're good to go: -```go +```go // Creating a (2,3,4) 3-Tensor of float32 b := New(WithBacking(Range(Float32, 0, 24)), WithShape(2, 3, 4)) fmt.Printf("b:\n%1.1f\n", b) @@ -133,6 +133,12 @@ fmt.Printf("b:\n%v", b) There is a whole laundry list of methods and functions available at the [godoc](https://godoc.org/gorgonia.org/tensor) page +## API Notes ## + +This package has a notion of "layers" in its API. This section clarifies the different patterns seen in the API. + + + ## Design of `*Dense` ## @@ -142,7 +148,7 @@ The design of the `*Dense` tensor is quite simple in concept. However, let's sta The data structure for `*Dense` is similar, but a lot more complex. Much of the complexity comes from the need to do accounting work on the data structure as well as preserving references to memory locations. This is how the `*Dense` is defined: -```go +```go type Dense struct { *AP array @@ -168,7 +174,7 @@ type array struct { } ``` -`*storage.Header` is the same structure as `reflect.SliceHeader`, except it stores a `unsafe.Pointer` instead of a `uintptr`. This is done so that eventually when more tests are done to determine how the garbage collector marks data, the `v` field may be removed. +`*storage.Header` is the same structure as `reflect.SliceHeader`, except it stores a `unsafe.Pointer` instead of a `uintptr`. This is done so that eventually when more tests are done to determine how the garbage collector marks data, the `v` field may be removed. The `storage.Header` field of the `array` (and hence `*Dense`) is there to provide a quick and easy way to translate back into a slice for operations that use familiar slice semantics, of which much of the operations are dependent upon. @@ -205,17 +211,17 @@ The alternative designs can be seen in the [ALTERNATIVE DESIGNS document](https: Example: -```go +```go x := New(WithBacking([]string{"hello", "world", "hello", "world"}), WithShape(2,2)) x = New(WithBacking([]int{1,2,3,4}), WithShape(2,2)) ``` -The above code will not cause a compile error, because the structure holding the underlying array (of `string`s and then of `int`s) is a `*Dense`. +The above code will not cause a compile error, because the structure holding the underlying array (of `string`s and then of `int`s) is a `*Dense`. One could argue that this sidesteps the compiler's type checking system, deferring it to runtime (which a number of people consider dangerous). However, tools are being developed to type check these things, and until Go does support typechecked generics, unfortunately this will be the way it has to be. -Currently, the tensor package supports limited type of genericity - limited to a tensor of any primitive type. +Currently, the tensor package supports limited type of genericity - limited to a tensor of any primitive type. # How This Package is Developed # Much of the code in this package is generated. The code to generate them is in the directory `genlib2`. `genlib2` requires [`goimports`](https://godoc.org/golang.org/x/tools/cmd/goimports) binary to be available in the $PATH. @@ -246,7 +252,7 @@ See also: CONTRIBUTING.md ## Contributors and Significant Contributors ## -All contributions are welcome. However, there is a new class of contributor, called Significant Contributors. +All contributions are welcome. However, there is a new class of contributor, called Significant Contributors. A Significant Contributor is one who has shown *deep understanding* of how the library works and/or its environs. Here are examples of what constitutes a Significant Contribution: diff --git a/ap.go b/ap.go index 410ec40..85bf9a7 100644 --- a/ap.go +++ b/ap.go @@ -365,9 +365,9 @@ func (ap *AP) unlock() { ap.fin = false } func (ap *AP) calcStrides() []int { switch { case ap.o.IsRowMajor(): - return ap.shape.CalcStrides() + return CalcStrides(ap.shape) case ap.o.IsColMajor(): - return ap.shape.CalcStridesColMajor() + return CalcStridesColMajor(ap.shape) } panic("unreachable") } diff --git a/ap_test.go b/ap_test.go index 8314546..3a0e5bb 100644 --- a/ap_test.go +++ b/ap_test.go @@ -203,7 +203,7 @@ func TestAccessPatternS(t *testing.T) { var err error for _, sts := range sliceTests { - ap = MakeAP(sts.shape, sts.shape.CalcStrides(), 0, 0) + ap = MakeAP(sts.shape, CalcStrides(sts.shape), 0, 0) if apS, ndStart, ndEnd, err = ap.S(sts.shape.TotalSize(), sts.slices...); err != nil { t.Errorf("%v errored: %v", sts.name, err) continue diff --git a/api_arith.go b/api_arith.go index 4e86ffa..13ccd05 100644 --- a/api_arith.go +++ b/api_arith.go @@ -19,18 +19,18 @@ import ( // If the Unsafe flag is passed in, the data of the first tensor will be overwritten func Add(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { var adder Adder - var oe standardEngine + var oe StandardEngine var ok bool switch at := a.(type) { case Tensor: - oe = at.standardEngine() + oe, _ = at.Engine().(StandardEngine) switch bt := b.(type) { case Tensor: if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor addition if oe != nil { return oe.Add(at, bt, opts...) } - if oe = bt.standardEngine(); oe != nil { + if oe, ok = bt.Engine().(StandardEngine); ok { return oe.Add(at, bt, opts...) } if adder, ok = at.Engine().(Adder); ok { @@ -55,7 +55,7 @@ func Add(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { if oe != nil { return oe.AddScalar(at, bt, leftTensor, opts...) } - if oe = bt.standardEngine(); oe != nil { + if oe, ok = bt.Engine().(StandardEngine); ok { return oe.AddScalar(at, bt, leftTensor, opts...) } if adder, ok = at.Engine().(Adder); ok { @@ -79,7 +79,7 @@ func Add(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { default: switch bt := b.(type) { case Tensor: - if oe = bt.standardEngine(); oe != nil { + if oe, ok = bt.Engine().(StandardEngine); ok { return oe.AddScalar(bt, at, false, opts...) } if adder, ok = bt.Engine().(Adder); ok { @@ -100,18 +100,18 @@ func Add(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { // If the Unsafe flag is passed in, the data of the first tensor will be overwritten func Sub(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { var suber Suber - var oe standardEngine + var oe StandardEngine var ok bool switch at := a.(type) { case Tensor: - oe = at.standardEngine() + oe, _ = at.Engine().(StandardEngine) switch bt := b.(type) { case Tensor: if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor substraction if oe != nil { return oe.Sub(at, bt, opts...) } - if oe = bt.standardEngine(); oe != nil { + if oe, ok = bt.Engine().(StandardEngine); ok { return oe.Sub(at, bt, opts...) } if suber, ok = at.Engine().(Suber); ok { @@ -136,7 +136,7 @@ func Sub(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { if oe != nil { return oe.SubScalar(at, bt, leftTensor, opts...) } - if oe = bt.standardEngine(); oe != nil { + if oe, ok = bt.Engine().(StandardEngine); ok { return oe.SubScalar(at, bt, leftTensor, opts...) } if suber, ok = at.Engine().(Suber); ok { @@ -160,7 +160,7 @@ func Sub(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { default: switch bt := b.(type) { case Tensor: - if oe = bt.standardEngine(); oe != nil { + if oe, ok = bt.Engine().(StandardEngine); ok { return oe.SubScalar(bt, at, false, opts...) } if suber, ok = bt.Engine().(Suber); ok { @@ -181,18 +181,18 @@ func Sub(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { // If the Unsafe flag is passed in, the data of the first tensor will be overwritten func Mul(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { var muler Muler - var oe standardEngine + var oe StandardEngine var ok bool switch at := a.(type) { case Tensor: - oe = at.standardEngine() + oe, _ = at.Engine().(StandardEngine) switch bt := b.(type) { case Tensor: if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor multiplication if oe != nil { return oe.Mul(at, bt, opts...) } - if oe = bt.standardEngine(); oe != nil { + if oe, ok = bt.Engine().(StandardEngine); ok { return oe.Mul(at, bt, opts...) } if muler, ok = at.Engine().(Muler); ok { @@ -217,7 +217,7 @@ func Mul(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { if oe != nil { return oe.MulScalar(at, bt, leftTensor, opts...) } - if oe = bt.standardEngine(); oe != nil { + if oe, ok = bt.Engine().(StandardEngine); ok { return oe.MulScalar(at, bt, leftTensor, opts...) } if muler, ok = at.Engine().(Muler); ok { @@ -242,7 +242,7 @@ func Mul(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { default: switch bt := b.(type) { case Tensor: // b Tensor * a interface - if oe = bt.standardEngine(); oe != nil { + if oe, ok = bt.Engine().(StandardEngine); ok { return oe.MulScalar(bt, at, false, opts...) } if muler, ok = bt.Engine().(Muler); ok { @@ -264,18 +264,18 @@ func Mul(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { // If the Unsafe flag is passed in, the data of the first tensor will be overwritten func Div(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { var diver Diver - var oe standardEngine + var oe StandardEngine var ok bool switch at := a.(type) { case Tensor: - oe = at.standardEngine() + oe, _ = at.Engine().(StandardEngine) switch bt := b.(type) { case Tensor: if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor division if oe != nil { return oe.Div(at, bt, opts...) } - if oe = bt.standardEngine(); oe != nil { + if oe, ok = bt.Engine().(StandardEngine); ok { return oe.Div(at, bt, opts...) } if diver, ok = at.Engine().(Diver); ok { @@ -300,7 +300,7 @@ func Div(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { if oe != nil { return oe.DivScalar(at, bt, leftTensor, opts...) } - if oe = bt.standardEngine(); oe != nil { + if oe, ok = bt.Engine().(StandardEngine); ok { return oe.DivScalar(at, bt, leftTensor, opts...) } if diver, ok = at.Engine().(Diver); ok { @@ -324,7 +324,7 @@ func Div(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { default: switch bt := b.(type) { case Tensor: - if oe = bt.standardEngine(); oe != nil { + if oe, ok = bt.Engine().(StandardEngine); ok { return oe.DivScalar(bt, at, false, opts...) } if diver, ok = bt.Engine().(Diver); ok { @@ -345,18 +345,18 @@ func Div(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { // If the Unsafe flag is passed in, the data of the first tensor will be overwritten func Pow(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { var power Power - var oe standardEngine + var oe StandardEngine var ok bool switch at := a.(type) { case Tensor: - oe = at.standardEngine() + oe, _ = at.Engine().(StandardEngine) switch bt := b.(type) { case Tensor: if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor exponentiation if oe != nil { return oe.Pow(at, bt, opts...) } - if oe = bt.standardEngine(); oe != nil { + if oe, ok = bt.Engine().(StandardEngine); ok { return oe.Pow(at, bt, opts...) } if power, ok = at.Engine().(Power); ok { @@ -381,7 +381,7 @@ func Pow(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { if oe != nil { return oe.PowScalar(at, bt, leftTensor, opts...) } - if oe = bt.standardEngine(); oe != nil { + if oe, ok = bt.Engine().(StandardEngine); ok { return oe.PowScalar(at, bt, leftTensor, opts...) } if power, ok = at.Engine().(Power); ok { @@ -405,7 +405,7 @@ func Pow(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { default: switch bt := b.(type) { case Tensor: - if oe = bt.standardEngine(); oe != nil { + if oe, ok = bt.Engine().(StandardEngine); ok { return oe.PowScalar(bt, at, false, opts...) } if power, ok = bt.Engine().(Power); ok { @@ -426,18 +426,18 @@ func Pow(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { // If the Unsafe flag is passed in, the data of the first tensor will be overwritten func Mod(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { var moder Moder - var oe standardEngine + var oe StandardEngine var ok bool switch at := a.(type) { case Tensor: - oe = at.standardEngine() + oe, _ = at.Engine().(StandardEngine) switch bt := b.(type) { case Tensor: if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor modulo if oe != nil { return oe.Mod(at, bt, opts...) } - if oe = bt.standardEngine(); oe != nil { + if oe, ok = bt.Engine().(StandardEngine); ok { return oe.Mod(at, bt, opts...) } if moder, ok = at.Engine().(Moder); ok { @@ -462,7 +462,7 @@ func Mod(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { if oe != nil { return oe.ModScalar(at, bt, leftTensor, opts...) } - if oe = bt.standardEngine(); oe != nil { + if oe, ok = bt.Engine().(StandardEngine); ok { return oe.ModScalar(at, bt, leftTensor, opts...) } if moder, ok = at.Engine().(Moder); ok { @@ -486,7 +486,7 @@ func Mod(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { default: switch bt := b.(type) { case Tensor: - if oe = bt.standardEngine(); oe != nil { + if oe, ok = bt.Engine().(StandardEngine); ok { return oe.ModScalar(bt, at, false, opts...) } if moder, ok = bt.Engine().(Moder); ok { @@ -526,41 +526,28 @@ func Dot(x, y Tensor, opts ...FuncOpt) (retVal Tensor, err error) { // FMA performs Y = A * X + Y. func FMA(a Tensor, x interface{}, y Tensor) (retVal Tensor, err error) { - if xTensor, ok := x.(Tensor); ok { - if oe := a.standardEngine(); oe != nil { - return oe.FMA(a, xTensor, y) - } - if oe := xTensor.standardEngine(); oe != nil { - return oe.FMA(a, xTensor, y) - } - if oe := y.standardEngine(); oe != nil { - return oe.FMA(a, xTensor, y) - } + var fm FMAer - if e, ok := a.Engine().(FMAer); ok { - return e.FMA(a, xTensor, y) - } - if e, ok := xTensor.Engine().(FMAer); ok { - return e.FMA(a, xTensor, y) - } - if e, ok := y.Engine().(FMAer); ok { - return e.FMA(a, xTensor, y) + if xTensor, ok := x.(Tensor); ok { + for _, T := range [3]Tensor{a, xTensor, y} { + e := T.Engine() + ctx := ctxFromEngine(e) + fm, ok = e.(FMAer) + if ok { + return fm.FMA(ctx, a, xTensor, y) + } } } else { - if oe := a.standardEngine(); oe != nil { - return oe.FMAScalar(a, x, y) - } - if oe := y.standardEngine(); oe != nil { - return oe.FMAScalar(a, x, y) - } - - if e, ok := a.Engine().(FMAer); ok { - return e.FMAScalar(a, x, y) - } - if e, ok := y.Engine().(FMAer); ok { - return e.FMAScalar(a, x, y) + for _, T := range [2]Tensor{a, y} { + e := T.Engine() + ctx := ctxFromEngine(e) + fm, ok = e.(FMAer) + if ok { + return fm.FMAScalar(ctx, a, x, y) + } } } + return Mul(a, x, WithIncr(y)) } @@ -570,13 +557,66 @@ func MatMul(a, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { err = errors.Errorf(dtypeMismatch, a.Dtype(), b.Dtype()) return } + ad, aok := a.(*Dense) + _, bok := b.(*Dense) + if aok && bok { + // fast path + return ad.MatMul(b, opts...) + } - switch at := a.(type) { - case *Dense: - bt := b.(*Dense) - return at.MatMul(bt, opts...) + // check that both are matrices + if !a.Shape().IsMatrix() || !b.Shape().IsMatrix() { + err = errors.Errorf("MatMul requires both operands to be matrices. Got t's shape: %v, other's shape: %v", a.Shape(), b.Shape()) + return } - panic("Unreachable") + + // checks that t is mxk matrix + var m, n, k int + m = a.Shape()[0] + k = a.Shape()[1] + n = b.Shape()[1] + + // check shape + if k != b.Shape()[0] { + err = errors.Errorf(shapeMismatch, a.Shape(), b.Shape()) + return + } + + // check whether retVal has the same size as the resulting matrix would be: mxn + expectedShape := Shape{m, n} + + eng := a.Engine() + mm, ok := eng.(MatMuler) + if !ok { + eng = b.Engine() + mm, ok = eng.(MatMuler) + } + if !ok { + return nil, errors.Errorf("Neither a or b have an engine that is a MatMuler. a: %T, b: %T", a.Engine(), b.Engine()) + } + + var reuse Tensor + fo := ParseFuncOpts(opts...) + defer returnOpOpt(fo) + ctx := fo.Context() + reuse = fo.Reuse() + if reuse == nil { + return nil, errors.Errorf("MatMul requires passing in of a reuse Tensor for now.") + } + + if err := checkFixShape(reuse, expectedShape); err != nil { + return nil, errors.Wrapf(err, opFail, "MatMul") + } + if err = mm.MatMul(ctx, a, b, reuse); err != nil { + return nil, errors.Wrapf(err, opFail, "MatMul") + } + + incr := fo.Incr() + if incr != nil { + return Add(incr, reuse, UseUnsafe()) + } + return reuse, nil + } // MatVecMul performs matrix-vector multiplication between two Tensors. `a` is expected to be a matrix, and `b` is expected to be a vector @@ -595,7 +635,7 @@ func MatVecMul(a, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } // Inner finds the inner products of two vector Tensors. Both arguments to the functions are eexpected to be vectors. -func Inner(a, b Tensor) (retVal interface{}, err error) { +func Inner(a, b Tensor, opts ...FuncOpt) (retVal interface{}, err error) { if a.Dtype() != b.Dtype() { err = errors.Errorf(dtypeMismatch, a.Dtype(), b.Dtype()) return diff --git a/api_arith_generated_test.go b/api_arith_generated_test.go index ce08af9..f26b7f5 100644 --- a/api_arith_generated_test.go +++ b/api_arith_generated_test.go @@ -1,17 +1,21 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor import ( + "context" "testing" "testing/quick" + "time" + + "gorgonia.org/dtype" ) +// Code generated by genlib2. DO NOT EDIT. + func TestAdd(t *testing.T) { iden := func(a *Dense) bool { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Adder) we = we || !ok @@ -37,7 +41,7 @@ func TestSub(t *testing.T) { inv := func(a *Dense) bool { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Suber) we = we || !ok @@ -64,7 +68,7 @@ func TestMul(t *testing.T) { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) b.Memset(identityVal(1, a.t)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Muler) we = we || !ok @@ -91,7 +95,7 @@ func TestDiv(t *testing.T) { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) b.Memset(identityVal(1, a.t)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Diver) we = we || !ok @@ -118,7 +122,7 @@ func TestPow(t *testing.T) { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) b.Memset(identityVal(1, a.t)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatcmplxTypes, complexTypes) + we, willFailEq := willerr(a, dtype.FloatComplex, dtype.Complexes) _, ok := a.Engine().(Power) we = we || !ok @@ -144,7 +148,7 @@ func TestAdd_unsafe(t *testing.T) { iden := func(a *Dense) bool { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Adder) we = we || !ok @@ -163,7 +167,6 @@ func TestAdd_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } if err := quick.Check(iden, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { @@ -175,7 +178,7 @@ func TestSub_unsafe(t *testing.T) { inv := func(a *Dense) bool { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Suber) we = we || !ok @@ -195,7 +198,6 @@ func TestSub_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } if err := quick.Check(inv, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { @@ -207,7 +209,7 @@ func TestMul_unsafe(t *testing.T) { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) b.Memset(identityVal(1, a.t)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Muler) we = we || !ok @@ -226,7 +228,6 @@ func TestMul_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } if err := quick.Check(iden, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { @@ -239,7 +240,7 @@ func TestDiv_unsafe(t *testing.T) { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) b.Memset(identityVal(1, a.t)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Diver) we = we || !ok @@ -259,7 +260,6 @@ func TestDiv_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } if err := quick.Check(inv, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { @@ -271,7 +271,7 @@ func TestPow_unsafe(t *testing.T) { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) b.Memset(identityVal(1, a.t)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatcmplxTypes, complexTypes) + we, willFailEq := willerr(a, dtype.FloatComplex, dtype.Complexes) _, ok := a.Engine().(Power) we = we || !ok @@ -290,7 +290,6 @@ func TestPow_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } if err := quick.Check(iden, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { @@ -303,7 +302,7 @@ func TestAdd_reuse(t *testing.T) { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Adder) we = we || !ok @@ -335,7 +334,7 @@ func TestSub_reuse(t *testing.T) { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Suber) we = we || !ok @@ -368,7 +367,7 @@ func TestMul_reuse(t *testing.T) { b.Memset(identityVal(1, a.t)) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Muler) we = we || !ok @@ -401,7 +400,7 @@ func TestDiv_reuse(t *testing.T) { b.Memset(identityVal(1, a.t)) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Diver) we = we || !ok @@ -434,7 +433,7 @@ func TestPow_reuse(t *testing.T) { b.Memset(identityVal(1, a.t)) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatcmplxTypes, complexTypes) + we, willFailEq := willerr(a, dtype.FloatComplex, dtype.Complexes) _, ok := a.Engine().(Power) we = we || !ok @@ -468,7 +467,7 @@ func TestAdd_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Adder) we = we || !ok @@ -497,7 +496,7 @@ func TestSub_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Suber) we = we || !ok @@ -527,7 +526,7 @@ func TestMul_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Muler) we = we || !ok @@ -557,7 +556,7 @@ func TestDiv_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Diver) we = we || !ok @@ -587,7 +586,7 @@ func TestPow_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, floatcmplxTypes, complexTypes) + we, willFailEq := willerr(a, dtype.FloatComplex, dtype.Complexes) _, ok := a.Engine().(Power) we = we || !ok @@ -608,6 +607,204 @@ func TestPow_incr(t *testing.T) { t.Errorf("Identity test for Pow failed: %v", err) } +} +func TestAdd_context(t *testing.T) { + iden := func(a *Dense) bool { + b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) + rng := newRand() + r := rng.Intn(10) + var ctx context.Context + var cancel context.CancelFunc + if r < 5 { + ctx, cancel = context.WithTimeout(context.Background(), 1*time.Microsecond) + } else { + ctx, cancel = context.WithTimeout(context.Background(), time.Duration(r*100)*time.Second) + } + defer cancel() + correct := a.Clone().(*Dense) + we, willFailEq := willerr(a, dtype.Number, nilTC) + _, ok := a.Engine().(Adder) + we = we || !ok + + ret, err := Add(a, b, WithContext(ctx)) + if _, ok := err.(NoOpError); ok && r < 5 { + return true // short circuit + } + if err, retEarly := qcErrCheck(t, "Add", a, b, we, err); retEarly { + if err != nil { + return false + } + return true + } + + if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { + return false + } + return true + } + if err := quick.Check(iden, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Identity test for Add failed: %v", err) + } + +} +func TestSub_context(t *testing.T) { + inv := func(a *Dense) bool { + b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) + rng := newRand() + r := rng.Intn(10) + var ctx context.Context + var cancel context.CancelFunc + if r < 5 { + ctx, cancel = context.WithTimeout(context.Background(), 1*time.Microsecond) + } else { + ctx, cancel = context.WithTimeout(context.Background(), time.Duration(r*100)*time.Second) + } + defer cancel() + correct := a.Clone().(*Dense) + we, willFailEq := willerr(a, dtype.Number, nilTC) + _, ok := a.Engine().(Suber) + we = we || !ok + + ret, err := Sub(a, b, WithContext(ctx)) + if _, ok := err.(NoOpError); ok && r < 5 { + return true // short circuit + } + if err, retEarly := qcErrCheck(t, "Sub", a, b, we, err); retEarly { + if err != nil { + return false + } + return true + } + ret, err = Add(ret, b, UseUnsafe()) + + if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { + return false + } + return true + } + if err := quick.Check(inv, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Inv test for Sub failed: %v", err) + } +} +func TestMul_context(t *testing.T) { + iden := func(a *Dense) bool { + b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) + b.Memset(identityVal(1, a.t)) + rng := newRand() + r := rng.Intn(10) + var ctx context.Context + var cancel context.CancelFunc + if r < 5 { + ctx, cancel = context.WithTimeout(context.Background(), 1*time.Microsecond) + } else { + ctx, cancel = context.WithTimeout(context.Background(), time.Duration(r*100)*time.Second) + } + defer cancel() + correct := a.Clone().(*Dense) + we, willFailEq := willerr(a, dtype.Number, nilTC) + _, ok := a.Engine().(Muler) + we = we || !ok + + ret, err := Mul(a, b, WithContext(ctx)) + if _, ok := err.(NoOpError); ok && r < 5 { + return true // short circuit + } + if err, retEarly := qcErrCheck(t, "Mul", a, b, we, err); retEarly { + if err != nil { + return false + } + return true + } + + if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { + return false + } + return true + } + if err := quick.Check(iden, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Identity test for Mul failed: %v", err) + } + +} +func TestDiv_context(t *testing.T) { + inv := func(a *Dense) bool { + b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) + b.Memset(identityVal(1, a.t)) + rng := newRand() + r := rng.Intn(10) + var ctx context.Context + var cancel context.CancelFunc + if r < 5 { + ctx, cancel = context.WithTimeout(context.Background(), 1*time.Microsecond) + } else { + ctx, cancel = context.WithTimeout(context.Background(), time.Duration(r*100)*time.Second) + } + defer cancel() + correct := a.Clone().(*Dense) + we, willFailEq := willerr(a, dtype.Number, nilTC) + _, ok := a.Engine().(Diver) + we = we || !ok + + ret, err := Div(a, b, WithContext(ctx)) + if _, ok := err.(NoOpError); ok && r < 5 { + return true // short circuit + } + if err, retEarly := qcErrCheck(t, "Div", a, b, we, err); retEarly { + if err != nil { + return false + } + return true + } + ret, err = Mul(ret, b, UseUnsafe()) + + if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { + return false + } + return true + } + if err := quick.Check(inv, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Inv test for Div failed: %v", err) + } +} +func TestPow_context(t *testing.T) { + iden := func(a *Dense) bool { + b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) + b.Memset(identityVal(1, a.t)) + rng := newRand() + r := rng.Intn(10) + var ctx context.Context + var cancel context.CancelFunc + if r < 5 { + ctx, cancel = context.WithTimeout(context.Background(), 1*time.Microsecond) + } else { + ctx, cancel = context.WithTimeout(context.Background(), time.Duration(r*100)*time.Second) + } + defer cancel() + correct := a.Clone().(*Dense) + we, willFailEq := willerr(a, dtype.FloatComplex, dtype.Complexes) + _, ok := a.Engine().(Power) + we = we || !ok + + ret, err := Pow(a, b, WithContext(ctx)) + if _, ok := err.(NoOpError); ok && r < 5 { + return true // short circuit + } + if err, retEarly := qcErrCheck(t, "Pow", a, b, we, err); retEarly { + if err != nil { + return false + } + return true + } + + if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { + return false + } + return true + } + if err := quick.Check(iden, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Identity test for Pow failed: %v", err) + } + } func TestAddScalar(t *testing.T) { iden1 := func(q *Dense) bool { @@ -615,7 +812,7 @@ func TestAddScalar(t *testing.T) { b := identityVal(0, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Adder) we = we || !ok @@ -641,7 +838,7 @@ func TestAddScalar(t *testing.T) { a := q.Clone().(*Dense) b := identityVal(0, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Adder) we = we || !ok @@ -695,7 +892,7 @@ func TestSubScalar(t *testing.T) { b := identityVal(0, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, unsignedTypes) + we, willFailEq := willerr(a, dtype.Number, dtype.Unsigned) _, ok := q.Engine().(Suber) we = we || !ok @@ -721,7 +918,7 @@ func TestSubScalar(t *testing.T) { a := q.Clone().(*Dense) b := identityVal(0, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, unsignedTypes) + we, willFailEq := willerr(a, dtype.Number, dtype.Unsigned) _, ok := q.Engine().(Suber) we = we || !ok @@ -775,7 +972,7 @@ func TestMulScalar(t *testing.T) { b := identityVal(1, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Muler) we = we || !ok @@ -801,7 +998,7 @@ func TestMulScalar(t *testing.T) { a := q.Clone().(*Dense) b := identityVal(1, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Muler) we = we || !ok @@ -855,7 +1052,7 @@ func TestDivScalar(t *testing.T) { b := identityVal(1, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Diver) we = we || !ok @@ -910,7 +1107,7 @@ func TestPowScalar(t *testing.T) { b := identityVal(1, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatcmplxTypes, complexTypes) + we, willFailEq := willerr(a, dtype.FloatComplex, dtype.Complexes) _, ok := q.Engine().(Power) we = we || !ok @@ -965,7 +1162,7 @@ func TestAddScalar_unsafe(t *testing.T) { b := identityVal(0, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Adder) we = we || !ok @@ -984,7 +1181,6 @@ func TestAddScalar_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } @@ -996,7 +1192,7 @@ func TestAddScalar_unsafe(t *testing.T) { a := q.Clone().(*Dense) b := identityVal(0, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Adder) we = we || !ok @@ -1015,7 +1211,6 @@ func TestAddScalar_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } if err := quick.Check(iden2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { @@ -1029,7 +1224,7 @@ func TestSubScalar_unsafe(t *testing.T) { b := identityVal(0, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, unsignedTypes) + we, willFailEq := willerr(a, dtype.Number, dtype.Unsigned) _, ok := q.Engine().(Suber) we = we || !ok @@ -1049,7 +1244,6 @@ func TestSubScalar_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } if err := quick.Check(inv1, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { @@ -1060,7 +1254,7 @@ func TestSubScalar_unsafe(t *testing.T) { a := q.Clone().(*Dense) b := identityVal(0, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, unsignedTypes) + we, willFailEq := willerr(a, dtype.Number, dtype.Unsigned) _, ok := q.Engine().(Suber) we = we || !ok @@ -1080,7 +1274,6 @@ func TestSubScalar_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } if err := quick.Check(inv2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { @@ -1093,7 +1286,7 @@ func TestMulScalar_unsafe(t *testing.T) { b := identityVal(1, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Muler) we = we || !ok @@ -1112,7 +1305,6 @@ func TestMulScalar_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } @@ -1124,7 +1316,7 @@ func TestMulScalar_unsafe(t *testing.T) { a := q.Clone().(*Dense) b := identityVal(1, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Muler) we = we || !ok @@ -1143,7 +1335,6 @@ func TestMulScalar_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } if err := quick.Check(iden2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { @@ -1157,7 +1348,7 @@ func TestDivScalar_unsafe(t *testing.T) { b := identityVal(1, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Diver) we = we || !ok @@ -1177,7 +1368,6 @@ func TestDivScalar_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } if err := quick.Check(inv1, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { @@ -1191,7 +1381,7 @@ func TestPowScalar_unsafe(t *testing.T) { b := identityVal(1, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatcmplxTypes, complexTypes) + we, willFailEq := willerr(a, dtype.FloatComplex, dtype.Complexes) _, ok := q.Engine().(Power) we = we || !ok @@ -1210,7 +1400,6 @@ func TestPowScalar_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } @@ -1226,7 +1415,7 @@ func TestAddScalar_reuse(t *testing.T) { reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Adder) we = we || !ok @@ -1258,7 +1447,7 @@ func TestAddScalar_reuse(t *testing.T) { b := identityVal(0, q.t) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Adder) we = we || !ok @@ -1292,7 +1481,7 @@ func TestSubScalar_reuse(t *testing.T) { reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, unsignedTypes) + we, willFailEq := willerr(a, dtype.Number, dtype.Unsigned) _, ok := q.Engine().(Suber) we = we || !ok @@ -1324,7 +1513,7 @@ func TestSubScalar_reuse(t *testing.T) { b := identityVal(0, q.t) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, unsignedTypes) + we, willFailEq := willerr(a, dtype.Number, dtype.Unsigned) _, ok := q.Engine().(Suber) we = we || !ok @@ -1358,7 +1547,7 @@ func TestMulScalar_reuse(t *testing.T) { reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Muler) we = we || !ok @@ -1390,7 +1579,7 @@ func TestMulScalar_reuse(t *testing.T) { b := identityVal(1, q.t) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Muler) we = we || !ok @@ -1424,7 +1613,7 @@ func TestDivScalar_reuse(t *testing.T) { reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Diver) we = we || !ok @@ -1459,7 +1648,7 @@ func TestPowScalar_reuse(t *testing.T) { reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatcmplxTypes, complexTypes) + we, willFailEq := willerr(a, dtype.FloatComplex, dtype.Complexes) _, ok := q.Engine().(Power) we = we || !ok @@ -1496,7 +1685,7 @@ func TestAddScalar_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Adder) we = we || !ok @@ -1525,7 +1714,7 @@ func TestAddScalar_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Adder) we = we || !ok @@ -1556,7 +1745,7 @@ func TestSubScalar_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, unsignedTypes) + we, willFailEq := willerr(a, dtype.Number, dtype.Unsigned) _, ok := q.Engine().(Suber) we = we || !ok @@ -1588,7 +1777,7 @@ func TestMulScalar_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Muler) we = we || !ok @@ -1617,7 +1806,7 @@ func TestMulScalar_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Muler) we = we || !ok @@ -1648,7 +1837,7 @@ func TestDivScalar_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Diver) we = we || !ok @@ -1680,7 +1869,7 @@ func TestPowScalar_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, floatcmplxTypes, complexTypes) + we, willFailEq := willerr(a, dtype.FloatComplex, dtype.Complexes) _, ok := q.Engine().(Power) we = we || !ok @@ -1703,3 +1892,327 @@ func TestPowScalar_incr(t *testing.T) { } } +func TestAddScalar_context(t *testing.T) { + iden1 := func(q *Dense) bool { + a := q.Clone().(*Dense) + b := identityVal(0, q.t) + rng := newRand() + r := rng.Intn(10) + var ctx context.Context + var cancel context.CancelFunc + if r < 5 { + ctx, cancel = context.WithTimeout(context.Background(), 1*time.Microsecond) + } else { + ctx, cancel = context.WithTimeout(context.Background(), time.Duration(r*100)*time.Second) + } + defer cancel() + + correct := a.Clone().(*Dense) + we, willFailEq := willerr(a, dtype.Number, nilTC) + _, ok := q.Engine().(Adder) + we = we || !ok + + ret, err := Add(a, b, WithContext(ctx)) + if _, ok := err.(NoOpError); ok && r < 5 { + return true // short circuit + } + if err, retEarly := qcErrCheck(t, "Add", a, b, we, err); retEarly { + if err != nil { + return false + } + return true + } + + if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { + return false + } + return true + } + + if err := quick.Check(iden1, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Identity test for Add (tensor as left, scalar as right) failed: %v", err) + } + + iden2 := func(q *Dense) bool { + a := q.Clone().(*Dense) + b := identityVal(0, q.t) + rng := newRand() + r := rng.Intn(10) + var ctx context.Context + var cancel context.CancelFunc + if r < 5 { + ctx, cancel = context.WithTimeout(context.Background(), 1*time.Microsecond) + } else { + ctx, cancel = context.WithTimeout(context.Background(), time.Duration(r*100)*time.Second) + } + defer cancel() + correct := a.Clone().(*Dense) + we, willFailEq := willerr(a, dtype.Number, nilTC) + _, ok := q.Engine().(Adder) + we = we || !ok + + ret, err := Add(b, a, WithContext(ctx)) + if _, ok := err.(NoOpError); ok && r < 5 { + return true // short circuit + } + if err, retEarly := qcErrCheck(t, "Add", a, b, we, err); retEarly { + if err != nil { + return false + } + return true + } + + if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { + return false + } + return true + } + if err := quick.Check(iden2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Identity test for Add (scalar as left, tensor as right) failed: %v", err) + } + +} +func TestSubScalar_context(t *testing.T) { + inv1 := func(q *Dense) bool { + a := q.Clone().(*Dense) + b := identityVal(0, q.t) + rng := newRand() + r := rng.Intn(10) + var ctx context.Context + var cancel context.CancelFunc + if r < 5 { + ctx, cancel = context.WithTimeout(context.Background(), 1*time.Microsecond) + } else { + ctx, cancel = context.WithTimeout(context.Background(), time.Duration(r*100)*time.Second) + } + defer cancel() + + correct := a.Clone().(*Dense) + we, willFailEq := willerr(a, dtype.Number, dtype.Unsigned) + _, ok := q.Engine().(Suber) + we = we || !ok + + ret, err := Sub(a, b, WithContext(ctx)) + if _, ok := err.(NoOpError); ok && r < 5 { + return true // short circuit + } + if err, retEarly := qcErrCheck(t, "SubVS", a, b, we, err); retEarly { + if err != nil { + return false + } + return true + } + ret, err = Add(ret, b, UseUnsafe()) + + if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { + return false + } + return true + } + if err := quick.Check(inv1, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Inv test for Sub (tensor as left, scalar as right) failed: %v", err) + } + + inv2 := func(q *Dense) bool { + a := q.Clone().(*Dense) + b := identityVal(0, q.t) + rng := newRand() + r := rng.Intn(10) + var ctx context.Context + var cancel context.CancelFunc + if r < 5 { + ctx, cancel = context.WithTimeout(context.Background(), 1*time.Microsecond) + } else { + ctx, cancel = context.WithTimeout(context.Background(), time.Duration(r*100)*time.Second) + } + defer cancel() + correct := a.Clone().(*Dense) + we, willFailEq := willerr(a, dtype.Number, dtype.Unsigned) + _, ok := q.Engine().(Suber) + we = we || !ok + + ret, err := Sub(b, a, WithContext(ctx)) + if _, ok := err.(NoOpError); ok && r < 5 { + return true // short circuit + } + if err, retEarly := qcErrCheck(t, "SubSV", a, b, we, err); retEarly { + if err != nil { + return false + } + return true + } + ret, err = Sub(b, ret, UseUnsafe()) + + if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { + return false + } + return true + } + if err := quick.Check(inv2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Inv test for Sub (scalar as left, tensor as right) failed: %v", err) + } +} +func TestMulScalar_context(t *testing.T) { + iden1 := func(q *Dense) bool { + a := q.Clone().(*Dense) + b := identityVal(1, q.t) + rng := newRand() + r := rng.Intn(10) + var ctx context.Context + var cancel context.CancelFunc + if r < 5 { + ctx, cancel = context.WithTimeout(context.Background(), 1*time.Microsecond) + } else { + ctx, cancel = context.WithTimeout(context.Background(), time.Duration(r*100)*time.Second) + } + defer cancel() + + correct := a.Clone().(*Dense) + we, willFailEq := willerr(a, dtype.Number, nilTC) + _, ok := q.Engine().(Muler) + we = we || !ok + + ret, err := Mul(a, b, WithContext(ctx)) + if _, ok := err.(NoOpError); ok && r < 5 { + return true // short circuit + } + if err, retEarly := qcErrCheck(t, "Mul", a, b, we, err); retEarly { + if err != nil { + return false + } + return true + } + + if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { + return false + } + return true + } + + if err := quick.Check(iden1, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Identity test for Mul (tensor as left, scalar as right) failed: %v", err) + } + + iden2 := func(q *Dense) bool { + a := q.Clone().(*Dense) + b := identityVal(1, q.t) + rng := newRand() + r := rng.Intn(10) + var ctx context.Context + var cancel context.CancelFunc + if r < 5 { + ctx, cancel = context.WithTimeout(context.Background(), 1*time.Microsecond) + } else { + ctx, cancel = context.WithTimeout(context.Background(), time.Duration(r*100)*time.Second) + } + defer cancel() + correct := a.Clone().(*Dense) + we, willFailEq := willerr(a, dtype.Number, nilTC) + _, ok := q.Engine().(Muler) + we = we || !ok + + ret, err := Mul(b, a, WithContext(ctx)) + if _, ok := err.(NoOpError); ok && r < 5 { + return true // short circuit + } + if err, retEarly := qcErrCheck(t, "Mul", a, b, we, err); retEarly { + if err != nil { + return false + } + return true + } + + if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { + return false + } + return true + } + if err := quick.Check(iden2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Identity test for Mul (scalar as left, tensor as right) failed: %v", err) + } + +} +func TestDivScalar_context(t *testing.T) { + inv1 := func(q *Dense) bool { + a := q.Clone().(*Dense) + b := identityVal(1, q.t) + rng := newRand() + r := rng.Intn(10) + var ctx context.Context + var cancel context.CancelFunc + if r < 5 { + ctx, cancel = context.WithTimeout(context.Background(), 1*time.Microsecond) + } else { + ctx, cancel = context.WithTimeout(context.Background(), time.Duration(r*100)*time.Second) + } + defer cancel() + + correct := a.Clone().(*Dense) + we, willFailEq := willerr(a, dtype.Number, nilTC) + _, ok := q.Engine().(Diver) + we = we || !ok + + ret, err := Div(a, b, WithContext(ctx)) + if _, ok := err.(NoOpError); ok && r < 5 { + return true // short circuit + } + if err, retEarly := qcErrCheck(t, "DivVS", a, b, we, err); retEarly { + if err != nil { + return false + } + return true + } + ret, err = Mul(ret, b, UseUnsafe()) + + if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { + return false + } + return true + } + if err := quick.Check(inv1, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Inv test for Div (tensor as left, scalar as right) failed: %v", err) + } + +} +func TestPowScalar_context(t *testing.T) { + iden1 := func(q *Dense) bool { + a := q.Clone().(*Dense) + b := identityVal(1, q.t) + rng := newRand() + r := rng.Intn(10) + var ctx context.Context + var cancel context.CancelFunc + if r < 5 { + ctx, cancel = context.WithTimeout(context.Background(), 1*time.Microsecond) + } else { + ctx, cancel = context.WithTimeout(context.Background(), time.Duration(r*100)*time.Second) + } + defer cancel() + + correct := a.Clone().(*Dense) + we, willFailEq := willerr(a, dtype.FloatComplex, dtype.Complexes) + _, ok := q.Engine().(Power) + we = we || !ok + + ret, err := Pow(a, b, WithContext(ctx)) + if _, ok := err.(NoOpError); ok && r < 5 { + return true // short circuit + } + if err, retEarly := qcErrCheck(t, "Pow", a, b, we, err); retEarly { + if err != nil { + return false + } + return true + } + + if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { + return false + } + return true + } + + if err := quick.Check(iden1, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Identity test for Pow (tensor as left, scalar as right) failed: %v", err) + } + +} diff --git a/api_arith_test.go b/api_arith_test.go index 75a4838..3a3cf67 100644 --- a/api_arith_test.go +++ b/api_arith_test.go @@ -8,6 +8,7 @@ import ( "time" "github.com/stretchr/testify/assert" + "gorgonia.org/dtype" ) // This file contains the tests for API functions that aren't generated by genlib @@ -40,7 +41,7 @@ func TestFMA(t *testing.T) { WithEngine(q.Engine())(y) y2 := y.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok1 := q.Engine().(FMAer) _, ok2 := q.Engine().(Muler) _, ok3 := q.Engine().(Adder) @@ -55,7 +56,7 @@ func TestFMA(t *testing.T) { return true } - we, _ = willerr(a, numberTypes, nil) + we, _ = willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Muler) we = we || !ok wi, err := Mul(a, x, WithIncr(y2)) diff --git a/api_cmp.go b/api_cmp.go index ffb602d..b2ac050 100644 --- a/api_cmp.go +++ b/api_cmp.go @@ -1,6 +1,8 @@ package tensor -import "github.com/pkg/errors" +import ( + "github.com/pkg/errors" +) // public API for comparison ops @@ -295,12 +297,26 @@ func ElNe(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { eleqer, ok = at.Engine().(ElEqer) switch bt := b.(type) { case Tensor: - if !ok { - if eleqer, ok = bt.Engine().(ElEqer); !ok { - return nil, errors.Errorf("Neither operands have engines that support ElEq") + if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor comparison + if !ok { + if eleqer, ok = bt.Engine().(ElEqer); !ok { + return nil, errors.Errorf("Neither operands have engines that support ElEq") + } + } + return eleqer.ElNe(at, bt, opts...) + } else { + var leftTensor bool + if !bt.Shape().IsScalar() { + leftTensor = false + at, bt = bt, at + } else { + leftTensor = true + } + if !ok { + return nil, errors.Errorf("Engine does not support ElNE") } + return eleqer.NeScalar(at, bt, leftTensor, opts...) } - return eleqer.ElNe(at, bt, opts...) default: if !ok { return nil, errors.Errorf("Engine does not support ElEq") diff --git a/api_cmp_generated_test.go b/api_cmp_generated_test.go index 002587b..4a612d8 100644 --- a/api_cmp_generated_test.go +++ b/api_cmp_generated_test.go @@ -1,16 +1,18 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor import ( "reflect" "testing" "testing/quick" + + "gorgonia.org/dtype" ) +// Code generated by genlib2. DO NOT EDIT. + func TestGt(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, ordTypes, nil) + we, _ := willerr(q, dtype.Ord, nilTC) _, ok := q.Engine().(Gter) we = we || !ok @@ -68,7 +70,7 @@ func TestGt(t *testing.T) { } func TestGte(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, ordTypes, nil) + we, _ := willerr(q, dtype.Ord, nilTC) _, ok := q.Engine().(Gteer) we = we || !ok @@ -126,7 +128,7 @@ func TestGte(t *testing.T) { } func TestLt(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, ordTypes, nil) + we, _ := willerr(q, dtype.Ord, nilTC) _, ok := q.Engine().(Lter) we = we || !ok @@ -184,7 +186,7 @@ func TestLt(t *testing.T) { } func TestLte(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, ordTypes, nil) + we, _ := willerr(q, dtype.Ord, nilTC) _, ok := q.Engine().(Lteer) we = we || !ok @@ -242,7 +244,7 @@ func TestLte(t *testing.T) { } func TestEq(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, eqTypes, nil) + we, _ := willerr(q, dtype.Eq, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok @@ -298,7 +300,7 @@ func TestEq(t *testing.T) { } symFn := func(q *Dense) bool { - we, _ := willerr(q, eqTypes, nil) + we, _ := willerr(q, dtype.Eq, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok @@ -333,7 +335,7 @@ func TestEq(t *testing.T) { } func TestNe(t *testing.T) { symFn := func(q *Dense) bool { - we, _ := willerr(q, eqTypes, nil) + we, _ := willerr(q, dtype.Eq, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok @@ -368,11 +370,11 @@ func TestNe(t *testing.T) { } func TestGt_assame(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(Gter) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -428,11 +430,11 @@ func TestGt_assame(t *testing.T) { } func TestGte_assame(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(Gteer) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -488,11 +490,11 @@ func TestGte_assame(t *testing.T) { } func TestLt_assame(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(Lter) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -548,11 +550,11 @@ func TestLt_assame(t *testing.T) { } func TestLte_assame(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(Lteer) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -608,11 +610,11 @@ func TestLte_assame(t *testing.T) { } func TestEq_assame(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -666,11 +668,11 @@ func TestEq_assame(t *testing.T) { } symFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -704,11 +706,11 @@ func TestEq_assame(t *testing.T) { } func TestNe_assame(t *testing.T) { symFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -742,7 +744,7 @@ func TestNe_assame(t *testing.T) { } func TestGtScalar(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, ordTypes, nil) + we, _ := willerr(q, dtype.Ord, nilTC) _, ok := q.Engine().(Gter) we = we || !ok @@ -798,7 +800,7 @@ func TestGtScalar(t *testing.T) { } func TestGteScalar(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, ordTypes, nil) + we, _ := willerr(q, dtype.Ord, nilTC) _, ok := q.Engine().(Gteer) we = we || !ok @@ -854,7 +856,7 @@ func TestGteScalar(t *testing.T) { } func TestLtScalar(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, ordTypes, nil) + we, _ := willerr(q, dtype.Ord, nilTC) _, ok := q.Engine().(Lter) we = we || !ok @@ -910,7 +912,7 @@ func TestLtScalar(t *testing.T) { } func TestLteScalar(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, ordTypes, nil) + we, _ := willerr(q, dtype.Ord, nilTC) _, ok := q.Engine().(Lteer) we = we || !ok @@ -966,7 +968,7 @@ func TestLteScalar(t *testing.T) { } func TestEqScalar(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, eqTypes, nil) + we, _ := willerr(q, dtype.Eq, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok @@ -1020,7 +1022,7 @@ func TestEqScalar(t *testing.T) { } symFn := func(q *Dense) bool { - we, _ := willerr(q, eqTypes, nil) + we, _ := willerr(q, dtype.Eq, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok @@ -1053,7 +1055,7 @@ func TestEqScalar(t *testing.T) { } func TestNeScalar(t *testing.T) { symFn := func(q *Dense) bool { - we, _ := willerr(q, eqTypes, nil) + we, _ := willerr(q, dtype.Eq, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok @@ -1086,11 +1088,11 @@ func TestNeScalar(t *testing.T) { } func TestGtScalar_assame(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(Gter) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -1144,11 +1146,11 @@ func TestGtScalar_assame(t *testing.T) { } func TestGteScalar_assame(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(Gteer) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -1202,11 +1204,11 @@ func TestGteScalar_assame(t *testing.T) { } func TestLtScalar_assame(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(Lter) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -1260,11 +1262,11 @@ func TestLtScalar_assame(t *testing.T) { } func TestLteScalar_assame(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(Lteer) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -1318,11 +1320,11 @@ func TestLteScalar_assame(t *testing.T) { } func TestEqScalar_assame(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -1374,11 +1376,11 @@ func TestEqScalar_assame(t *testing.T) { } symFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -1410,11 +1412,11 @@ func TestEqScalar_assame(t *testing.T) { } func TestNeScalar_assame(t *testing.T) { symFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() diff --git a/api_matop.go b/api_matop.go index a9797c3..4d98479 100644 --- a/api_matop.go +++ b/api_matop.go @@ -19,16 +19,20 @@ func Narrow(t Tensor, dim, start, length int) (View, error) { // Repeat repeats a Tensor along the axis and given the number of repeats. func Repeat(t Tensor, axis int, repeats ...int) (retVal Tensor, err error) { - if r, ok := t.Engine().(Repeater); ok { - return r.Repeat(t, axis, repeats...) + e := t.Engine() + ctx := ctxFromEngine(e) + if r, ok := e.(Repeater); ok { + return r.Repeat(ctx, t, axis, repeats...) } return nil, errors.New("Engine does not support Repeat") } // RepeatReuse repeats a Tensor along the axis and the given number of repeats, and puts the results in the provided reuse tensor. If the reuse tensor is not correctly sized, then an error will be given, but the results will still be valid. func RepeatReuse(t, reuse Tensor, axis int, repeats ...int) (retval Tensor, err error) { - if r, ok := t.Engine().(Repeater); ok { - return r.RepeatReuse(t, reuse, axis, repeats...) + e := t.Engine() + ctx := ctxFromEngine(e) + if r, ok := e.(Repeater); ok { + return r.RepeatReuse(ctx, t, reuse, axis, repeats...) } return nil, errors.New("Engine does not support Repeat") } @@ -38,6 +42,14 @@ func T(t Tensor, axes ...int) (retVal Tensor, err error) { switch tt := t.(type) { case *Dense: return tt.SafeT(axes...) + case DenseView: + var ret *Dense + if ret, err = tt.SafeT(axes...); err != nil { + return nil, errors.Wrap(err, "T() off a DenseView") + } + return DenseView{ret}, nil + default: + return nil, nyierr(typeNYI, t) } panic("Unreachable") } @@ -48,11 +60,20 @@ func Transpose(t Tensor, axes ...int) (retVal Tensor, err error) { case *Dense: var ret *Dense if ret, err = tt.SafeT(axes...); err != nil { - return + return nil, errors.Wrap(err, "Unable to perform .SafeT() on a *Dense") } ret.Transpose() retVal = ret return + case DenseView: + var ret *Dense + if ret, err = tt.SafeT(axes...); err != nil { + return nil, errors.Wrap(err, "Unable to perform .SafeT() on a DenseView") + } + ret.Transpose() + return DenseView{ret}, nil + default: + return nil, nyierr(typeNYI, t) } panic("Unreachable") } @@ -65,15 +86,30 @@ func Concat(axis int, t Tensor, others ...Tensor) (retVal Tensor, err error) { } switch T := t.(type) { case *Dense: + // IF YOU UPDATE THIS, UPDATE THE DENSE VIEW CASE TOO. + ts := make([]*Dense, len(others)) + for i, o := range others { + ot, err := assertDense(o) + if err == nil { + ts[i] = ot + continue + } + return nil, errors.Wrapf(err, "Expected all Tensors to be *Dense. Got %T instead", o) + } + return T.Concat(axis, ts...) + case DenseView: ts := make([]*Dense, len(others)) for i, o := range others { - if ot, ok := o.(*Dense); ok { + ot, err := assertDense(o) + if err == nil { ts[i] = ot continue } - return nil, errors.Errorf("Expected all Tensors to be *Dense") + return nil, errors.Wrapf(err, "Expected all Tensors to be *Dense. Got %T instead", o) } return T.Concat(axis, ts...) + default: + return nil, nyierr(typeNYI, t) } panic("Unreachable") } @@ -96,7 +132,7 @@ func Copy(dst, src Tensor) error { copyDense(dt, st) return nil default: - return errors.Errorf("NYI for Copy %T", src) + return nyierr(typeNYI, src) } panic("Unreachable") } @@ -130,8 +166,10 @@ func Materialize(t Tensor) Tensor { } func Diag(t Tensor) (retVal Tensor, err error) { + e := t.Engine() + ctx := ctxFromEngine(e) if d, ok := t.Engine().(Diager); ok { - return d.Diag(t) + return d.Diag(ctx, t) } return nil, errors.Errorf("Unable to perform diagonalization of tensor ") } @@ -194,3 +232,10 @@ func LogSoftMaxB(output, grad Tensor, axis int, opts ...FuncOpt) (retVal Tensor, return nil, errors.Errorf("Unable to apply SoftMaxB. Engine %T does not support that.", output.Engine()) } + +func Scatter(a, indices Tensor, opts ...FuncOpt) (retVal Tensor, err error) { + if sc, ok := a.Engine().(Scatterer); ok { + return sc.Scatter(a, indices, opts...) + } + return nil, errors.Errorf("Unable to scatter. Engine %T does not support Scattering.", a.Engine()) +} diff --git a/api_minmax.go b/api_minmax.go index 964df7d..e8a7de1 100644 --- a/api_minmax.go +++ b/api_minmax.go @@ -4,18 +4,18 @@ import "github.com/pkg/errors" func MinBetween(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { var minbetweener MinBetweener - var oe standardEngine + var oe StandardEngine var ok bool switch at := a.(type) { case Tensor: - oe = at.standardEngine() + oe, _ = at.Engine().(StandardEngine) switch bt := b.(type) { case Tensor: if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor addition if oe != nil { return oe.MinBetween(at, bt, opts...) } - if oe = bt.standardEngine(); oe != nil { + if oe, ok = bt.Engine().(StandardEngine); ok { return oe.MinBetween(at, bt, opts...) } if minbetweener, ok = at.Engine().(MinBetweener); ok { @@ -40,7 +40,7 @@ func MinBetween(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { if oe != nil { return oe.MinBetweenScalar(at, bt, leftTensor, opts...) } - if oe = bt.standardEngine(); oe != nil { + if oe, ok = bt.Engine().(StandardEngine); ok { return oe.MinBetweenScalar(at, bt, leftTensor, opts...) } if minbetweener, ok = at.Engine().(MinBetweener); ok { @@ -64,7 +64,7 @@ func MinBetween(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { default: switch bt := b.(type) { case Tensor: - if oe = bt.standardEngine(); oe != nil { + if oe, ok = bt.Engine().(StandardEngine); ok { return oe.MinBetweenScalar(bt, at, false, opts...) } if minbetweener, ok = bt.Engine().(MinBetweener); ok { @@ -80,18 +80,18 @@ func MinBetween(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { func MaxBetween(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { var maxbetweener MaxBetweener - var oe standardEngine + var oe StandardEngine var ok bool switch at := a.(type) { case Tensor: - oe = at.standardEngine() + oe, _ = at.Engine().(StandardEngine) switch bt := b.(type) { case Tensor: if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor addition if oe != nil { return oe.MaxBetween(at, bt, opts...) } - if oe = bt.standardEngine(); oe != nil { + if oe, ok = bt.Engine().(StandardEngine); ok { return oe.MaxBetween(at, bt, opts...) } if maxbetweener, ok = at.Engine().(MaxBetweener); ok { @@ -116,7 +116,7 @@ func MaxBetween(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { if oe != nil { return oe.MaxBetweenScalar(at, bt, leftTensor, opts...) } - if oe = bt.standardEngine(); oe != nil { + if oe, ok = bt.Engine().(StandardEngine); ok { return oe.MaxBetweenScalar(at, bt, leftTensor, opts...) } if maxbetweener, ok = at.Engine().(MaxBetweener); ok { @@ -140,7 +140,7 @@ func MaxBetween(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { default: switch bt := b.(type) { case Tensor: - if oe = bt.standardEngine(); oe != nil { + if oe, ok = bt.Engine().(StandardEngine); ok { return oe.MaxBetweenScalar(bt, at, false, opts...) } if maxbetweener, ok = bt.Engine().(MaxBetweener); ok { diff --git a/api_reduction.go b/api_reduction.go index f146972..63c2257 100644 --- a/api_reduction.go +++ b/api_reduction.go @@ -2,26 +2,52 @@ package tensor import "github.com/pkg/errors" -// Sum sums a Tensor along the given axes +// Sum sums a Tensor along the given axes. func Sum(t Tensor, along ...int) (retVal Tensor, err error) { - if sumer, ok := t.Engine().(Sumer); ok { - return sumer.Sum(t, along...) + e := t.Engine() + ctx := ctxFromEngine(e) + if sumer, ok := e.(Sumer); ok { + return sumer.Sum(ctx, t, along...) } return nil, errors.New("Engine does not support Sum()") } +// Prod sums a Tensor along the given axes. +func Prod(t Tensor, along ...int) (retVal Tensor, err error) { + e := t.Engine() + ctx := ctxFromEngine(e) + if sumer, ok := e.(Proder); ok { + return sumer.Prod(ctx, t, along...) + } + return nil, errors.New("Engine does not support Prod()") +} + +// Max finds the maximum value along the given axes. +func Max(t Tensor, along ...int) (retVal Tensor, err error) { + e := t.Engine() + ctx := ctxFromEngine(e) + if maxer, ok := e.(Maxer); ok { + return maxer.Max(ctx, t, along...) + } + return nil, errors.New("Engine does not support Max()") +} + // Argmax finds the index of the max value along the axis provided func Argmax(t Tensor, axis int) (retVal Tensor, err error) { - if argmaxer, ok := t.Engine().(Argmaxer); ok { - return argmaxer.Argmax(t, axis) + e := t.Engine() + ctx := ctxFromEngine(e) + if argmaxer, ok := e.(Argmaxer); ok { + return argmaxer.Argmax(ctx, t, axis) } return nil, errors.New("Engine does not support Argmax()") } // Argmin finds the index of the min value along the axis provided func Argmin(t Tensor, axis int) (retVal Tensor, err error) { - if argminer, ok := t.Engine().(Argminer); ok { - return argminer.Argmin(t, axis) + e := t.Engine() + ctx := ctxFromEngine(e) + if argminer, ok := e.(Argminer); ok { + return argminer.Argmin(ctx, t, axis) } return nil, errors.New("Engine does not support Argmax()") } diff --git a/api_unary.go b/api_unary.go index b1afe71..4c81e33 100644 --- a/api_unary.go +++ b/api_unary.go @@ -1,9 +1,9 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor import "github.com/pkg/errors" +// Code generated by genlib2. DO NOT EDIT. + func Neg(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { e := a.Engine() if neger, ok := e.(Neger); ok { diff --git a/api_unary_generated_test.go b/api_unary_generated_test.go index 31a23f2..64813ae 100644 --- a/api_unary_generated_test.go +++ b/api_unary_generated_test.go @@ -1,17 +1,19 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor import ( "testing" "testing/quick" + + "gorgonia.org/dtype" ) +// Code generated by genlib2. DO NOT EDIT. + func TestNeg(t *testing.T) { invFn := func(q *Dense) bool { a := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Neger) we = we || !ok @@ -37,7 +39,7 @@ func TestSquare(t *testing.T) { invFn := func(q *Dense) bool { a := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Squarer) we = we || !ok @@ -48,7 +50,7 @@ func TestSquare(t *testing.T) { } return true } - if err := typeclassCheck(a.Dtype(), floatcmplxTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.FloatComplex); err != nil { return true // uninvertible due to type class implementation issues } Sqrt(ret, UseUnsafe()) @@ -66,7 +68,7 @@ func TestCube(t *testing.T) { invFn := func(q *Dense) bool { a := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Cuber) we = we || !ok @@ -77,7 +79,7 @@ func TestCube(t *testing.T) { } return true } - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true // uninvertible due to type class implementation issues } Cbrt(ret, UseUnsafe()) @@ -95,7 +97,7 @@ func TestExp(t *testing.T) { invFn := func(q *Dense) bool { a := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatcmplxTypes, nil) + we, willFailEq := willerr(a, dtype.FloatComplex, nilTC) _, ok := q.Engine().(Exper) we = we || !ok @@ -121,7 +123,7 @@ func TestLog(t *testing.T) { invFn := func(q *Dense) bool { a := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatcmplxTypes, nil) + we, willFailEq := willerr(a, dtype.FloatComplex, nilTC) _, ok := q.Engine().(Loger) we = we || !ok @@ -147,7 +149,7 @@ func TestSqrt(t *testing.T) { invFn := func(q *Dense) bool { a := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatcmplxTypes, nil) + we, willFailEq := willerr(a, dtype.FloatComplex, nilTC) _, ok := q.Engine().(Sqrter) we = we || !ok @@ -173,7 +175,7 @@ func TestCbrt(t *testing.T) { invFn := func(q *Dense) bool { a := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(Cbrter) we = we || !ok @@ -199,7 +201,7 @@ func TestNeg_unsafe(t *testing.T) { invFn := func(q *Dense) bool { a := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Neger) we = we || !ok @@ -218,7 +220,6 @@ func TestNeg_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } @@ -230,7 +231,7 @@ func TestSquare_unsafe(t *testing.T) { invFn := func(q *Dense) bool { a := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Squarer) we = we || !ok @@ -241,7 +242,7 @@ func TestSquare_unsafe(t *testing.T) { } return true } - if err := typeclassCheck(a.Dtype(), floatcmplxTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.FloatComplex); err != nil { return true // uninvertible due to type class implementation issues } Sqrt(ret, UseUnsafe()) @@ -252,7 +253,6 @@ func TestSquare_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } @@ -264,7 +264,7 @@ func TestCube_unsafe(t *testing.T) { invFn := func(q *Dense) bool { a := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Cuber) we = we || !ok @@ -275,7 +275,7 @@ func TestCube_unsafe(t *testing.T) { } return true } - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true // uninvertible due to type class implementation issues } Cbrt(ret, UseUnsafe()) @@ -286,7 +286,6 @@ func TestCube_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } @@ -298,7 +297,7 @@ func TestExp_unsafe(t *testing.T) { invFn := func(q *Dense) bool { a := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatcmplxTypes, nil) + we, willFailEq := willerr(a, dtype.FloatComplex, nilTC) _, ok := q.Engine().(Exper) we = we || !ok @@ -317,7 +316,6 @@ func TestExp_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } @@ -329,7 +327,7 @@ func TestLog_unsafe(t *testing.T) { invFn := func(q *Dense) bool { a := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatcmplxTypes, nil) + we, willFailEq := willerr(a, dtype.FloatComplex, nilTC) _, ok := q.Engine().(Loger) we = we || !ok @@ -348,7 +346,6 @@ func TestLog_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } @@ -360,7 +357,7 @@ func TestSqrt_unsafe(t *testing.T) { invFn := func(q *Dense) bool { a := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatcmplxTypes, nil) + we, willFailEq := willerr(a, dtype.FloatComplex, nilTC) _, ok := q.Engine().(Sqrter) we = we || !ok @@ -379,7 +376,6 @@ func TestSqrt_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } @@ -391,7 +387,7 @@ func TestCbrt_unsafe(t *testing.T) { invFn := func(q *Dense) bool { a := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(Cbrter) we = we || !ok @@ -410,7 +406,6 @@ func TestCbrt_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } @@ -423,7 +418,7 @@ func TestNeg_reuse(t *testing.T) { a := q.Clone().(*Dense) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Neger) we = we || !ok @@ -455,7 +450,7 @@ func TestSquare_reuse(t *testing.T) { a := q.Clone().(*Dense) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Squarer) we = we || !ok @@ -466,7 +461,7 @@ func TestSquare_reuse(t *testing.T) { } return true } - if err := typeclassCheck(a.Dtype(), floatcmplxTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.FloatComplex); err != nil { return true // uninvertible due to type class implementation issues } Sqrt(ret, UseUnsafe()) @@ -490,7 +485,7 @@ func TestCube_reuse(t *testing.T) { a := q.Clone().(*Dense) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Cuber) we = we || !ok @@ -501,7 +496,7 @@ func TestCube_reuse(t *testing.T) { } return true } - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true // uninvertible due to type class implementation issues } Cbrt(ret, UseUnsafe()) @@ -525,7 +520,7 @@ func TestExp_reuse(t *testing.T) { a := q.Clone().(*Dense) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatcmplxTypes, nil) + we, willFailEq := willerr(a, dtype.FloatComplex, nilTC) _, ok := q.Engine().(Exper) we = we || !ok @@ -557,7 +552,7 @@ func TestLog_reuse(t *testing.T) { a := q.Clone().(*Dense) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatcmplxTypes, nil) + we, willFailEq := willerr(a, dtype.FloatComplex, nilTC) _, ok := q.Engine().(Loger) we = we || !ok @@ -589,7 +584,7 @@ func TestSqrt_reuse(t *testing.T) { a := q.Clone().(*Dense) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatcmplxTypes, nil) + we, willFailEq := willerr(a, dtype.FloatComplex, nilTC) _, ok := q.Engine().(Sqrter) we = we || !ok @@ -621,7 +616,7 @@ func TestCbrt_reuse(t *testing.T) { a := q.Clone().(*Dense) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(Cbrter) we = we || !ok @@ -655,7 +650,7 @@ func TestNeg_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Neger) we = we || !ok @@ -688,7 +683,7 @@ func TestSquare_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Squarer) we = we || !ok @@ -699,7 +694,7 @@ func TestSquare_incr(t *testing.T) { } return true } - if err := typeclassCheck(a.Dtype(), floatcmplxTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.FloatComplex); err != nil { return true // uninvertible due to type class implementation issues } if ret, err = Sub(ret, identityVal(100, a.Dtype()), UseUnsafe()); err != nil { @@ -724,7 +719,7 @@ func TestCube_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Cuber) we = we || !ok @@ -735,7 +730,7 @@ func TestCube_incr(t *testing.T) { } return true } - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true // uninvertible due to type class implementation issues } if ret, err = Sub(ret, identityVal(100, a.Dtype()), UseUnsafe()); err != nil { @@ -760,7 +755,7 @@ func TestExp_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, floatcmplxTypes, nil) + we, willFailEq := willerr(a, dtype.FloatComplex, nilTC) _, ok := q.Engine().(Exper) we = we || !ok @@ -793,7 +788,7 @@ func TestLog_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, floatcmplxTypes, nil) + we, willFailEq := willerr(a, dtype.FloatComplex, nilTC) _, ok := q.Engine().(Loger) we = we || !ok @@ -826,7 +821,7 @@ func TestSqrt_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, floatcmplxTypes, nil) + we, willFailEq := willerr(a, dtype.FloatComplex, nilTC) _, ok := q.Engine().(Sqrter) we = we || !ok @@ -859,7 +854,7 @@ func TestCbrt_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(Cbrter) we = we || !ok diff --git a/api_unary_test.go b/api_unary_test.go index 9c735e6..25b68f7 100644 --- a/api_unary_test.go +++ b/api_unary_test.go @@ -1,14 +1,15 @@ package tensor import ( + "math" "math/rand" "testing" "testing/quick" "time" - "math" - "github.com/stretchr/testify/assert" "github.com/chewxy/math32" + "github.com/stretchr/testify/assert" + "gorgonia.org/dtype" ) /* @@ -354,12 +355,12 @@ func TestInvSqrt(t *testing.T) { a := q.Clone().(*Dense) b := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(InvSqrter) we = we || !ok // we'll exclude everything other than floats - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true } ret, err := InvSqrt(a) @@ -387,12 +388,12 @@ func TestInvSqrt(t *testing.T) { a := q.Clone().(*Dense) b := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(InvSqrter) we = we || !ok // we'll exclude everything other than floats - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true } ret, err := InvSqrt(a, UseUnsafe()) @@ -426,12 +427,12 @@ func TestInvSqrt(t *testing.T) { reuse := q.Clone().(*Dense) reuse.Zero() correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(InvSqrter) we = we || !ok // we'll exclude everything other than floats - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true } ret, err := InvSqrt(a, WithReuse(reuse)) @@ -466,12 +467,12 @@ func TestInvSqrt(t *testing.T) { incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(InvSqrter) we = we || !ok // we'll exclude everything other than floats - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true } ret, err := InvSqrt(a, WithIncr(incr)) @@ -509,12 +510,12 @@ func TestInv(t *testing.T) { invFn := func(q *Dense) bool { a := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(Inver) we = we || !ok // we'll exclude everything other than floats - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true } ret, err := Inv(a) @@ -541,12 +542,12 @@ func TestInv(t *testing.T) { a := q.Clone().(*Dense) b := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(Inver) we = we || !ok // we'll exclude everything other than floats - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true } ret, err := Inv(a, UseUnsafe()) @@ -577,12 +578,12 @@ func TestInv(t *testing.T) { correct := a.Clone().(*Dense) reuse := a.Clone().(*Dense) reuse.Zero() - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(Inver) we = we || !ok // we'll exclude everything other than floats - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true } ret, err := Inv(a, WithReuse(reuse)) @@ -613,12 +614,12 @@ func TestInv(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(Inver) we = we || !ok // we'll exclude everything other than floats - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true } ret, err := Inv(a, WithIncr(incr)) @@ -654,12 +655,12 @@ func TestLog10(t *testing.T) { invFn := func(q *Dense) bool { a := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(Log10er) we = we || !ok // we'll exclude everything other than floats - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true } ret, err := Log10(a) @@ -683,18 +684,17 @@ func TestLog10(t *testing.T) { t.Errorf("Inv tests for Log10 failed: %v", err) } - // unsafe invFn = func(q *Dense) bool { a := q.Clone().(*Dense) b := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(Log10er) we = we || !ok // we'll exclude everything other than floats - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true } ret, err := Log10(a, UseUnsafe()) @@ -720,19 +720,18 @@ func TestLog10(t *testing.T) { t.Errorf("Inv tests using unsafe for Log10 failed: %v", err) } - // reuse invFn = func(q *Dense) bool { a := q.Clone().(*Dense) correct := a.Clone().(*Dense) reuse := a.Clone().(*Dense) reuse.Zero() - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(Log10er) we = we || !ok // we'll exclude everything other than floats - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true } ret, err := Log10(a, WithReuse(reuse)) @@ -764,12 +763,12 @@ func TestLog10(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(Log10er) we = we || !ok // we'll exclude everything other than floats - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true } ret, err := Log10(a, WithIncr(incr)) @@ -808,10 +807,10 @@ func TestAbs(t *testing.T) { correct := New(Of(Bool), WithShape(q.Shape().Clone()...)) correct.Memset(true) // we'll exclude everything other than ordtypes because complex numbers cannot be abs'd - if err := typeclassCheck(a.Dtype(), ordTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Ord); err != nil { return true } - we, willFailEq := willerr(a, signedTypes, nil) + we, willFailEq := willerr(a, dtype.Signed, nilTC) _, ok := q.Engine().(Abser) we = we || !ok @@ -836,19 +835,18 @@ func TestAbs(t *testing.T) { } } - func TestTanh(t *testing.T) { var r *rand.Rand // default invFn := func(q *Dense) bool { a := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(Tanher) we = we || !ok // we'll exclude everything other than floats - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true } ret, err := Tanh(a) @@ -885,12 +883,12 @@ func TestTanh(t *testing.T) { invFn = func(q *Dense) bool { a := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(Tanher) we = we || !ok // we'll exclude everything other than floats - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true } ret, err := Tanh(a, UseUnsafe()) @@ -926,19 +924,18 @@ func TestTanh(t *testing.T) { t.Errorf("Inv tests using unsafe for Tanh failed: %v", err) } - // reuse invFn = func(q *Dense) bool { a := q.Clone().(*Dense) correct := a.Clone().(*Dense) reuse := a.Clone().(*Dense) reuse.Zero() - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(Tanher) we = we || !ok // we'll exclude everything other than floats - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true } ret, err := Tanh(a, WithReuse(reuse)) @@ -973,7 +970,6 @@ func TestTanh(t *testing.T) { t.Errorf("Inv tests using unsafe for Tanh failed: %v", err) } - // incr invFn = func(q *Dense) bool { a := q.Clone().(*Dense) @@ -981,12 +977,12 @@ func TestTanh(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(Tanher) we = we || !ok // we'll exclude everything other than floats - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true } ret, err := Tanh(a, WithIncr(incr)) @@ -1033,12 +1029,12 @@ func TestLog2(t *testing.T) { invFn := func(q *Dense) bool { a := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(Log2er) we = we || !ok // we'll exclude everything other than floats - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true } ret, err := Log2(a) @@ -1062,18 +1058,17 @@ func TestLog2(t *testing.T) { t.Errorf("Inv tests for Log2 failed: %v", err) } - // unsafe invFn = func(q *Dense) bool { a := q.Clone().(*Dense) b := q.Clone().(*Dense) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(Log2er) we = we || !ok // we'll exclude everything other than floats - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true } ret, err := Log2(a, UseUnsafe()) @@ -1099,19 +1094,18 @@ func TestLog2(t *testing.T) { t.Errorf("Inv tests using unsafe for Log2 failed: %v", err) } - // reuse invFn = func(q *Dense) bool { a := q.Clone().(*Dense) correct := a.Clone().(*Dense) reuse := a.Clone().(*Dense) reuse.Zero() - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(Log2er) we = we || !ok // we'll exclude everything other than floats - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true } ret, err := Log2(a, WithReuse(reuse)) @@ -1143,12 +1137,12 @@ func TestLog2(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, floatTypes, nil) + we, willFailEq := willerr(a, dtype.Floats, nilTC) _, ok := q.Engine().(Log2er) we = we || !ok // we'll exclude everything other than floats - if err := typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return true } ret, err := Log2(a, WithIncr(incr)) @@ -1177,4 +1171,4 @@ func TestLog2(t *testing.T) { t.Errorf("Inv tests using unsafe for Log2 failed: %v", err) } -} \ No newline at end of file +} diff --git a/array.go b/array.go index ca948d6..e805405 100644 --- a/array.go +++ b/array.go @@ -7,17 +7,18 @@ import ( "unsafe" "github.com/pkg/errors" + "gorgonia.org/dtype" "gorgonia.org/tensor/internal/storage" ) // array is the underlying generic array. type array struct { - storage.Header // the header - the Go representation (a slice) - t Dtype // the element type + storage.Header // the header - the Go representation (a slice) + t dtype.Dtype // the element type } // makeArray makes an array. The memory allocation is handled by Go -func makeArray(t Dtype, length int) array { +func makeArray(t dtype.Dtype, length int) array { v := malloc(t, length) hdr := storage.Header{ Raw: v, @@ -41,7 +42,7 @@ func arrayFromSlice(x interface{}) array { Header: storage.Header{ Raw: storage.AsByteSlice(x), }, - t: Dtype{elT}, + t: dtype.Dtype{elT}, } } @@ -57,7 +58,7 @@ func (a *array) fromSlice(x interface{}) { } elT := xT.Elem() a.Raw = storage.AsByteSlice(x) - a.t = Dtype{elT} + a.t = dtype.Dtype{elT} } // fromSliceOrTensor populates the value from a slice or anything that can form an array @@ -206,13 +207,13 @@ func (a *array) rtype() reflect.Type { return a.t.Type } /* MEMORY MOVEMENT STUFF */ // malloc is standard Go allocation of a block of memory - the plus side is that Go manages the memory -func malloc(t Dtype, length int) []byte { +func malloc(t dtype.Dtype, length int) []byte { size := int(calcMemSize(t, length)) return make([]byte, size) } // calcMemSize calulates the memory size of an array (given its size) -func calcMemSize(dt Dtype, size int) int64 { +func calcMemSize(dt dtype.Dtype, size int) int64 { return int64(dt.Size()) * int64(size) } diff --git a/array_getset.go b/array_getset.go index c19fe68..69bcf95 100644 --- a/array_getset.go +++ b/array_getset.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor import ( @@ -10,6 +8,8 @@ import ( "gorgonia.org/tensor/internal/storage" ) +// Code generated by genlib2. DO NOT EDIT. + // Set sets the value of the underlying array at the index i. func (a *array) Set(i int, x interface{}) { switch a.t.Kind() { diff --git a/collections.go b/collections.go index 5f4d075..34e8284 100644 --- a/collections.go +++ b/collections.go @@ -1,30 +1,30 @@ -package tensor - -import "github.com/pkg/errors" - -func densesToTensors(a []*Dense) []Tensor { - retVal := make([]Tensor, len(a)) - for i, t := range a { - retVal[i] = t - } - return retVal -} - -func densesToDenseTensors(a []*Dense) []DenseTensor { - retVal := make([]DenseTensor, len(a)) - for i, t := range a { - retVal[i] = t - } - return retVal -} - -func tensorsToDenseTensors(a []Tensor) ([]DenseTensor, error) { - retVal := make([]DenseTensor, len(a)) - var ok bool - for i, t := range a { - if retVal[i], ok = t.(DenseTensor); !ok { - return nil, errors.Errorf("can only convert Tensors of the same type to DenseTensors. Trying to convert %T (#%d in slice)", t, i) - } - } - return retVal, nil -} +package tensor + +import "github.com/pkg/errors" + +func densesToTensors(a []*Dense) []Tensor { + retVal := make([]Tensor, len(a)) + for i, t := range a { + retVal[i] = t + } + return retVal +} + +func densesToDenseTensors(a []*Dense) []DenseTensor { + retVal := make([]DenseTensor, len(a)) + for i, t := range a { + retVal[i] = t + } + return retVal +} + +func tensorsToDenseTensors(a []Tensor) ([]DenseTensor, error) { + retVal := make([]DenseTensor, len(a)) + var ok bool + for i, t := range a { + if retVal[i], ok = t.(DenseTensor); !ok { + return nil, errors.Errorf("can only convert Tensors of the same type to DenseTensors. Trying to convert %T (#%d in slice)", t, i) + } + } + return retVal, nil +} diff --git a/consopt.go b/consopt.go index 25c157a..4b2de84 100644 --- a/consopt.go +++ b/consopt.go @@ -3,6 +3,7 @@ package tensor import ( "reflect" + "gorgonia.org/dtype" "gorgonia.org/tensor/internal/storage" ) @@ -10,8 +11,8 @@ import ( type ConsOpt func(Tensor) // Of is a construction option for a Tensor. -func Of(a Dtype) ConsOpt { - Register(a) +func Of(a dtype.Dtype) ConsOpt { + dtype.Register(a) f := func(t Tensor) { switch tt := t.(type) { case *Dense: @@ -113,7 +114,7 @@ func FromScalar(x interface{}, argMask ...[]bool) ConsOpt { xv0 := xv.Index(0) // xv[0] xv0.Set(reflect.ValueOf(x)) tt.array.Header.Raw = storage.AsByteSlice(xv.Interface()) - tt.t = Dtype{xT} + tt.t = dtype.Dtype{xT} tt.mask = mask default: @@ -163,7 +164,7 @@ func WithEngine(e Engine) ConsOpt { } tt.oe = nil - if oe, ok := e.(standardEngine); ok { + if oe, ok := e.(StandardEngine); ok { tt.oe = oe } case *CS: @@ -234,7 +235,7 @@ func AsDenseDiag(backing interface{}) ConsOpt { sli := reflect.MakeSlice(xT, l*l, l*l) shape := Shape{l, l} - strides := shape.CalcStrides() + strides := CalcStrides(shape) for i := 0; i < l; i++ { idx, err := Ltoi(shape, strides, i, i) if err != nil { diff --git a/defaultengine.go b/defaultengine.go index d9138ae..5338391 100644 --- a/defaultengine.go +++ b/defaultengine.go @@ -2,22 +2,29 @@ package tensor import ( "github.com/pkg/errors" + "gorgonia.org/dtype" "gorgonia.org/tensor/internal/execution" ) +// stdDenseEng is the default execution engine for dense tensor operations. +type stdDenseEng struct { + execution.E +} + // StdEng is the default execution engine that comes with the tensors. To use other execution engines, use the WithEngine construction option. type StdEng struct { - execution.E + stdDenseEng } // makeArray allocates a slice for the array -func (e StdEng) makeArray(arr *array, t Dtype, size int) { +func (e StdEng) makeArray(arr *array, t dtype.Dtype, size int) { arr.Raw = malloc(t, size) arr.t = t } -func (e StdEng) AllocAccessible() bool { return true } -func (e StdEng) Alloc(size int64) (Memory, error) { return nil, noopError{} } +func (e StdEng) AllocAccessible() bool { return true } +func (e StdEng) Alloc(size int64) (Memory, error) { return nil, noopError{} } + func (e StdEng) Free(mem Memory, size int64) error { return nil } func (e StdEng) Memset(mem Memory, val interface{}) error { if ms, ok := mem.(MemSetter); ok { diff --git a/defaultengine_argmethods.go b/defaultengine_argmethods.go index 5632fa6..0bb1707 100644 --- a/defaultengine_argmethods.go +++ b/defaultengine_argmethods.go @@ -1,19 +1,27 @@ package tensor -import "github.com/pkg/errors" +import ( + "context" -func (e StdEng) Argmax(t Tensor, axis int) (retVal Tensor, err error) { + "github.com/pkg/errors" + "gorgonia.org/dtype" +) + +func (e StdEng) Argmax(ctx context.Context, t Tensor, axis int) (retVal Tensor, err error) { switch tt := t.(type) { case DenseTensor: - return e.argmaxDenseTensor(tt, axis) + return e.argmaxDenseTensor(ctx, tt, axis) default: - return nil, errors.Errorf(typeNYI, "StdEng.Argmax", t) + return nil, nyierr(typeNYI, t) } } -func (e StdEng) argmaxDenseTensor(t DenseTensor, axis int) (retVal *Dense, err error) { - if err = unaryCheck(t, ordTypes); err != nil { +func (e StdEng) argmaxDenseTensor(ctx context.Context, t DenseTensor, axis int) (retVal *Dense, err error) { + if err = handleCtx(ctx); err != nil { + return nil, err + } + if err = unaryCheck(t, dtype.Ord); err != nil { return nil, errors.Wrapf(err, opFail, "Argmax") } @@ -89,18 +97,21 @@ func (e StdEng) argmaxDenseTensor(t DenseTensor, axis int) (retVal *Dense, err e return New(WithShape(newShape...), WithBacking(indices)), nil } -func (e StdEng) Argmin(t Tensor, axis int) (retVal Tensor, err error) { +func (e StdEng) Argmin(ctx context.Context, t Tensor, axis int) (retVal Tensor, err error) { switch tt := t.(type) { case DenseTensor: - return e.argminDenseTensor(tt, axis) + return e.argminDenseTensor(ctx, tt, axis) default: - return nil, errors.Errorf(typeNYI, "StdEng.Argmin", t) + return nil, nyierr(typeNYI, t) } } -func (e StdEng) argminDenseTensor(t DenseTensor, axis int) (retVal *Dense, err error) { - if err = unaryCheck(t, ordTypes); err != nil { +func (e StdEng) argminDenseTensor(ctx context.Context, t DenseTensor, axis int) (retVal *Dense, err error) { + if err = handleCtx(ctx); err != nil { + return nil, err + } + if err = unaryCheck(t, dtype.Ord); err != nil { return nil, errors.Wrapf(err, opFail, "Argmin") } diff --git a/defaultengine_arith.go b/defaultengine_arith.go index 131ea33..369475d 100644 --- a/defaultengine_arith.go +++ b/defaultengine_arith.go @@ -1,24 +1,32 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor import ( + "context" + "github.com/pkg/errors" + "gorgonia.org/dtype" "gorgonia.org/tensor/internal/storage" ) +// Code generated by genlib2. DO NOT EDIT. + // Add performs a + b elementwise. Both a and b must have the same shape. // Acceptable FuncOpts are: UseUnsafe(), WithReuse(T), WithIncr(T) func (e StdEng) Add(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = binaryCheck(a, b, numberTypes); err != nil { + if err = binaryCheck(a, b, dtype.Number); err != nil { + return nil, errors.Wrapf(err, "Add failed") } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } typ := a.Dtype().Type var dataA, dataB, dataReuse *storage.Header var ait, bit, iit Iterator @@ -74,15 +82,20 @@ func (e StdEng) Add(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err // Sub performs a - b elementwise. Both a and b must have the same shape. // Acceptable FuncOpts are: UseUnsafe(), WithReuse(T), WithIncr(T) func (e StdEng) Sub(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = binaryCheck(a, b, numberTypes); err != nil { + if err = binaryCheck(a, b, dtype.Number); err != nil { + return nil, errors.Wrapf(err, "Sub failed") } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } typ := a.Dtype().Type var dataA, dataB, dataReuse *storage.Header var ait, bit, iit Iterator @@ -138,15 +151,20 @@ func (e StdEng) Sub(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err // Mul performs a × b elementwise. Both a and b must have the same shape. // Acceptable FuncOpts are: UseUnsafe(), WithReuse(T), WithIncr(T) func (e StdEng) Mul(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = binaryCheck(a, b, numberTypes); err != nil { + if err = binaryCheck(a, b, dtype.Number); err != nil { + return nil, errors.Wrapf(err, "Mul failed") } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } typ := a.Dtype().Type var dataA, dataB, dataReuse *storage.Header var ait, bit, iit Iterator @@ -202,15 +220,20 @@ func (e StdEng) Mul(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err // Div performs a ÷ b elementwise. Both a and b must have the same shape. // Acceptable FuncOpts are: UseUnsafe(), WithReuse(T), WithIncr(T) func (e StdEng) Div(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = binaryCheck(a, b, numberTypes); err != nil { + if err = binaryCheck(a, b, dtype.Number); err != nil { + return nil, errors.Wrapf(err, "Div failed") } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } typ := a.Dtype().Type var dataA, dataB, dataReuse *storage.Header var ait, bit, iit Iterator @@ -266,15 +289,20 @@ func (e StdEng) Div(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err // Pow performs a ^ b elementwise. Both a and b must have the same shape. // Acceptable FuncOpts are: UseUnsafe(), WithReuse(T), WithIncr(T) func (e StdEng) Pow(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = binaryCheck(a, b, numberTypes); err != nil { + if err = binaryCheck(a, b, dtype.Number); err != nil { + return nil, errors.Wrapf(err, "Pow failed") } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } typ := a.Dtype().Type var dataA, dataB, dataReuse *storage.Header var ait, bit, iit Iterator @@ -330,15 +358,20 @@ func (e StdEng) Pow(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err // Mod performs a % b elementwise. Both a and b must have the same shape. // Acceptable FuncOpts are: UseUnsafe(), WithReuse(T), WithIncr(T) func (e StdEng) Mod(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = binaryCheck(a, b, numberTypes); err != nil { + if err = binaryCheck(a, b, dtype.Number); err != nil { + return nil, errors.Wrapf(err, "Mod failed") } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } typ := a.Dtype().Type var dataA, dataB, dataReuse *storage.Header var ait, bit, iit Iterator @@ -394,7 +427,7 @@ func (e StdEng) Mod(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err // AddScalar performs t + s elementwise. The leftTensor parameter indicates if the tensor is the left operand. Only scalar types are accepted in s. // Acceptable FuncOpts are: UseUnsafe(), WithReuse(T), WithIncr(T) func (e StdEng) AddScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(t, numberTypes); err != nil { + if err = unaryCheck(t, dtype.Number); err != nil { return nil, errors.Wrapf(err, "Add failed") } @@ -404,9 +437,13 @@ func (e StdEng) AddScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } a := t typ := t.Dtype().Type var ait, bit, iit Iterator @@ -497,7 +534,7 @@ func (e StdEng) AddScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func // SubScalar performs t - s elementwise. The leftTensor parameter indicates if the tensor is the left operand. Only scalar types are accepted in s. // Acceptable FuncOpts are: UseUnsafe(), WithReuse(T), WithIncr(T) func (e StdEng) SubScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(t, numberTypes); err != nil { + if err = unaryCheck(t, dtype.Number); err != nil { return nil, errors.Wrapf(err, "Sub failed") } @@ -507,9 +544,13 @@ func (e StdEng) SubScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } a := t typ := t.Dtype().Type var ait, bit, iit Iterator @@ -600,7 +641,7 @@ func (e StdEng) SubScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func // MulScalar performs t × s elementwise. The leftTensor parameter indicates if the tensor is the left operand. Only scalar types are accepted in s. // Acceptable FuncOpts are: UseUnsafe(), WithReuse(T), WithIncr(T) func (e StdEng) MulScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(t, numberTypes); err != nil { + if err = unaryCheck(t, dtype.Number); err != nil { return nil, errors.Wrapf(err, "Mul failed") } @@ -610,9 +651,13 @@ func (e StdEng) MulScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } a := t typ := t.Dtype().Type var ait, bit, iit Iterator @@ -703,7 +748,7 @@ func (e StdEng) MulScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func // DivScalar performs t ÷ s elementwise. The leftTensor parameter indicates if the tensor is the left operand. Only scalar types are accepted in s. // Acceptable FuncOpts are: UseUnsafe(), WithReuse(T), WithIncr(T) func (e StdEng) DivScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(t, numberTypes); err != nil { + if err = unaryCheck(t, dtype.Number); err != nil { return nil, errors.Wrapf(err, "Div failed") } @@ -713,9 +758,13 @@ func (e StdEng) DivScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } a := t typ := t.Dtype().Type var ait, bit, iit Iterator @@ -806,7 +855,7 @@ func (e StdEng) DivScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func // PowScalar performs t ^ s elementwise. The leftTensor parameter indicates if the tensor is the left operand. Only scalar types are accepted in s. // Acceptable FuncOpts are: UseUnsafe(), WithReuse(T), WithIncr(T) func (e StdEng) PowScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(t, numberTypes); err != nil { + if err = unaryCheck(t, dtype.Number); err != nil { return nil, errors.Wrapf(err, "Pow failed") } @@ -816,9 +865,13 @@ func (e StdEng) PowScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } a := t typ := t.Dtype().Type var ait, bit, iit Iterator @@ -909,7 +962,7 @@ func (e StdEng) PowScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func // ModScalar performs t % s elementwise. The leftTensor parameter indicates if the tensor is the left operand. Only scalar types are accepted in s. // Acceptable FuncOpts are: UseUnsafe(), WithReuse(T), WithIncr(T) func (e StdEng) ModScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(t, numberTypes); err != nil { + if err = unaryCheck(t, dtype.Number); err != nil { return nil, errors.Wrapf(err, "Mod failed") } @@ -919,9 +972,13 @@ func (e StdEng) ModScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } a := t typ := t.Dtype().Type var ait, bit, iit Iterator diff --git a/defaultengine_cmp.go b/defaultengine_cmp.go index 1d6ff48..6a986d3 100644 --- a/defaultengine_cmp.go +++ b/defaultengine_cmp.go @@ -1,29 +1,37 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor import ( + "context" + "github.com/pkg/errors" + "gorgonia.org/dtype" "gorgonia.org/tensor/internal/storage" ) +// Code generated by genlib2. DO NOT EDIT. + // Gt performs a > b elementwise. Both a and b must have the same shape. // Acceptable FuncOpts are: UseUnsafe(), AsSameType(), WithReuse(). //UseUnsafe() will ensure that the same type is returned. // Tensors used in WithReuse has to have the same Dtype as the return value's Dtype. func (e StdEng) Gt(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = binaryCheck(a, b, ordTypes); err != nil { + if err = binaryCheck(a, b, dtype.Ord); err != nil { + return nil, errors.Wrapf(err, "Gt failed") } var reuse DenseTensor var safe, same bool - if reuse, safe, _, _, same, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), false, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, _, _, same, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), false, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !safe { same = true } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } typ := a.Dtype().Type var dataA, dataB, dataReuse *storage.Header var ait, bit, iit Iterator @@ -90,18 +98,23 @@ func (e StdEng) Gt(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err erro //UseUnsafe() will ensure that the same type is returned. // Tensors used in WithReuse has to have the same Dtype as the return value's Dtype. func (e StdEng) Gte(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = binaryCheck(a, b, ordTypes); err != nil { + if err = binaryCheck(a, b, dtype.Ord); err != nil { + return nil, errors.Wrapf(err, "Gte failed") } var reuse DenseTensor var safe, same bool - if reuse, safe, _, _, same, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), false, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, _, _, same, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), false, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !safe { same = true } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } typ := a.Dtype().Type var dataA, dataB, dataReuse *storage.Header var ait, bit, iit Iterator @@ -168,18 +181,23 @@ func (e StdEng) Gte(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err //UseUnsafe() will ensure that the same type is returned. // Tensors used in WithReuse has to have the same Dtype as the return value's Dtype. func (e StdEng) Lt(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = binaryCheck(a, b, ordTypes); err != nil { + if err = binaryCheck(a, b, dtype.Ord); err != nil { + return nil, errors.Wrapf(err, "Lt failed") } var reuse DenseTensor var safe, same bool - if reuse, safe, _, _, same, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), false, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, _, _, same, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), false, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !safe { same = true } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } typ := a.Dtype().Type var dataA, dataB, dataReuse *storage.Header var ait, bit, iit Iterator @@ -246,18 +264,23 @@ func (e StdEng) Lt(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err erro //UseUnsafe() will ensure that the same type is returned. // Tensors used in WithReuse has to have the same Dtype as the return value's Dtype. func (e StdEng) Lte(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = binaryCheck(a, b, ordTypes); err != nil { + if err = binaryCheck(a, b, dtype.Ord); err != nil { + return nil, errors.Wrapf(err, "Lte failed") } var reuse DenseTensor var safe, same bool - if reuse, safe, _, _, same, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), false, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, _, _, same, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), false, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !safe { same = true } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } typ := a.Dtype().Type var dataA, dataB, dataReuse *storage.Header var ait, bit, iit Iterator @@ -324,18 +347,23 @@ func (e StdEng) Lte(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err //UseUnsafe() will ensure that the same type is returned. // Tensors used in WithReuse has to have the same Dtype as the return value's Dtype. func (e StdEng) ElEq(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = binaryCheck(a, b, eqTypes); err != nil { + if err = binaryCheck(a, b, dtype.Eq); err != nil { + return nil, errors.Wrapf(err, "Eq failed") } var reuse DenseTensor var safe, same bool - if reuse, safe, _, _, same, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), false, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, _, _, same, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), false, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !safe { same = true } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } typ := a.Dtype().Type var dataA, dataB, dataReuse *storage.Header var ait, bit, iit Iterator @@ -402,18 +430,23 @@ func (e StdEng) ElEq(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err er //UseUnsafe() will ensure that the same type is returned. // Tensors used in WithReuse has to have the same Dtype as the return value's Dtype. func (e StdEng) ElNe(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = binaryCheck(a, b, eqTypes); err != nil { + if err = binaryCheck(a, b, dtype.Eq); err != nil { + return nil, errors.Wrapf(err, "Ne failed") } var reuse DenseTensor var safe, same bool - if reuse, safe, _, _, same, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), false, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, _, _, same, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), false, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !safe { same = true } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } typ := a.Dtype().Type var dataA, dataB, dataReuse *storage.Header var ait, bit, iit Iterator @@ -480,7 +513,7 @@ func (e StdEng) ElNe(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err er // UseUnsafe() will ensure that the same type is returned. // Tensors used in WithReuse has to have the same Dtype as the return value's Dtype. func (e StdEng) GtScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(t, ordTypes); err != nil { + if err = unaryCheck(t, dtype.Ord); err != nil { return nil, errors.Wrapf(err, "Gt failed") } @@ -490,12 +523,16 @@ func (e StdEng) GtScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncO var reuse DenseTensor var safe, same bool - if reuse, safe, _, _, same, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), false, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, _, _, same, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), false, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !safe { same = true } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } a := t typ := t.Dtype().Type var ait, bit, iit Iterator @@ -602,7 +639,7 @@ func (e StdEng) GtScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncO // UseUnsafe() will ensure that the same type is returned. // Tensors used in WithReuse has to have the same Dtype as the return value's Dtype. func (e StdEng) GteScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(t, ordTypes); err != nil { + if err = unaryCheck(t, dtype.Ord); err != nil { return nil, errors.Wrapf(err, "Gte failed") } @@ -612,12 +649,16 @@ func (e StdEng) GteScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func var reuse DenseTensor var safe, same bool - if reuse, safe, _, _, same, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), false, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, _, _, same, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), false, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !safe { same = true } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } a := t typ := t.Dtype().Type var ait, bit, iit Iterator @@ -724,7 +765,7 @@ func (e StdEng) GteScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func // UseUnsafe() will ensure that the same type is returned. // Tensors used in WithReuse has to have the same Dtype as the return value's Dtype. func (e StdEng) LtScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(t, ordTypes); err != nil { + if err = unaryCheck(t, dtype.Ord); err != nil { return nil, errors.Wrapf(err, "Lt failed") } @@ -734,12 +775,16 @@ func (e StdEng) LtScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncO var reuse DenseTensor var safe, same bool - if reuse, safe, _, _, same, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), false, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, _, _, same, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), false, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !safe { same = true } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } a := t typ := t.Dtype().Type var ait, bit, iit Iterator @@ -846,7 +891,7 @@ func (e StdEng) LtScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncO // UseUnsafe() will ensure that the same type is returned. // Tensors used in WithReuse has to have the same Dtype as the return value's Dtype. func (e StdEng) LteScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(t, ordTypes); err != nil { + if err = unaryCheck(t, dtype.Ord); err != nil { return nil, errors.Wrapf(err, "Lte failed") } @@ -856,12 +901,16 @@ func (e StdEng) LteScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func var reuse DenseTensor var safe, same bool - if reuse, safe, _, _, same, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), false, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, _, _, same, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), false, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !safe { same = true } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } a := t typ := t.Dtype().Type var ait, bit, iit Iterator @@ -964,7 +1013,7 @@ func (e StdEng) LteScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func } func (e StdEng) EqScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(t, eqTypes); err != nil { + if err = unaryCheck(t, dtype.Eq); err != nil { return nil, errors.Wrapf(err, "Eq failed") } @@ -974,12 +1023,16 @@ func (e StdEng) EqScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncO var reuse DenseTensor var safe, same bool - if reuse, safe, _, _, same, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), false, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, _, _, same, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), false, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !safe { same = true } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } a := t typ := t.Dtype().Type var ait, bit, iit Iterator @@ -1082,7 +1135,7 @@ func (e StdEng) EqScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncO } func (e StdEng) NeScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(t, eqTypes); err != nil { + if err = unaryCheck(t, dtype.Eq); err != nil { return nil, errors.Wrapf(err, "Ne failed") } @@ -1092,12 +1145,16 @@ func (e StdEng) NeScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncO var reuse DenseTensor var safe, same bool - if reuse, safe, _, _, same, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), false, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, _, _, same, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), false, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !safe { same = true } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } a := t typ := t.Dtype().Type var ait, bit, iit Iterator diff --git a/defaultengine_linalg.go b/defaultengine_linalg.go index d9a16aa..ca114fd 100644 --- a/defaultengine_linalg.go +++ b/defaultengine_linalg.go @@ -1,21 +1,27 @@ package tensor import ( + "context" "reflect" "github.com/pkg/errors" "gonum.org/v1/gonum/blas" "gonum.org/v1/gonum/mat" + "gorgonia.org/dtype" ) -// Trace returns the trace of a matrix (i.e. the sum of the diagonal elements). If the Tensor provided is not a matrix, it will return an error -func (e StdEng) Trace(t Tensor) (retVal interface{}, err error) { +// Trace returns the trace of a matrix (i.e. the sum of the diagonal elements). If the Tensor provided is not a matrix, it will return an error +func (e StdEng) Trace(ctx context.Context, t Tensor) (retVal interface{}, err error) { + if err := handleCtx(ctx); err != nil { + return nil, err + } + if t.Dims() != 2 { err = errors.Errorf(dimMismatch, 2, t.Dims()) return } - if err = typeclassCheck(t.Dtype(), numberTypes); err != nil { + if err = dtype.TypeClassCheck(t.Dtype(), dtype.Number); err != nil { return nil, errors.Wrap(err, "Trace") } @@ -118,6 +124,12 @@ func (e StdEng) Trace(t Tensor) (retVal interface{}, err error) { } func (e StdEng) Dot(x, y Tensor, opts ...FuncOpt) (retVal Tensor, err error) { + fo := ParseFuncOpts(opts...) + ctx := fo.Context() + if err = handleCtx(ctx); err != nil { + return nil, err + } + if _, ok := x.(DenseTensor); !ok { err = errors.Errorf("Engine only supports working on x that is a DenseTensor. Got %T instead", x) return @@ -138,8 +150,6 @@ func (e StdEng) Dot(x, y Tensor, opts ...FuncOpt) (retVal Tensor, err error) { return } - fo := ParseFuncOpts(opts...) - var reuse, incr DenseTensor if reuse, err = getFloatDenseTensor(fo.reuse); err != nil { err = errors.Wrapf(err, opFail, "Dot - reuse") @@ -211,7 +221,7 @@ func (e StdEng) Dot(x, y Tensor, opts ...FuncOpt) (retVal Tensor, err error) { return } var ret interface{} - if ret, err = e.Inner(a, b); err != nil { + if ret, err = e.Inner(ctx, a, b); err != nil { return nil, errors.Wrapf(err, opFail, "Dot") } return New(FromScalar(ret)), nil @@ -308,7 +318,11 @@ func (e StdEng) Dot(x, y Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } // TODO: make it take DenseTensor -func (e StdEng) SVD(a Tensor, uv, full bool) (s, u, v Tensor, err error) { +func (e StdEng) SVD(ctx context.Context, a Tensor, uv, full bool) (s, u, v Tensor, err error) { + if err = handleCtx(ctx); err != nil { + return nil, nil, nil, err + } + var t *Dense var ok bool if err = e.checkAccessible(a); err != nil { @@ -317,7 +331,7 @@ func (e StdEng) SVD(a Tensor, uv, full bool) (s, u, v Tensor, err error) { if t, ok = a.(*Dense); !ok { return nil, nil, nil, errors.Errorf("StdEng only performs SVDs for DenseTensors. Got %T instead", a) } - if err = typeclassCheck(a.Dtype(), floatTypes); err != nil { + if err = dtype.TypeClassCheck(a.Dtype(), dtype.Floats); err != nil { return nil, nil, nil, errors.Errorf("StdEng can only perform SVDs for float64 and float32 type. Got tensor of %v instead", t.Dtype()) } @@ -371,7 +385,11 @@ func (e StdEng) SVD(a Tensor, uv, full bool) (s, u, v Tensor, err error) { // Inner is a thin layer over BLAS's D/Sdot. // It returns a scalar value, wrapped in an interface{}, which is not quite nice. -func (e StdEng) Inner(a, b Tensor) (retVal interface{}, err error) { +func (e StdEng) Inner(ctx context.Context, a, b Tensor) (retVal interface{}, err error) { + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } + var ad, bd DenseTensor if ad, bd, err = e.checkTwoFloatComplexTensors(a, b); err != nil { return nil, errors.Wrapf(err, opFail, "StdEng.Inner") @@ -398,7 +416,11 @@ func (e StdEng) Inner(a, b Tensor) (retVal interface{}, err error) { // Because DGEMV computes: // y = αA * x + βy // we set beta to 0, so we don't have to manually zero out the reused/retval tensor data -func (e StdEng) MatVecMul(a, b, prealloc Tensor) (err error) { +func (e StdEng) MatVecMul(ctx context.Context, a, b, prealloc Tensor) (err error) { + if err := handleCtx(ctx); err != nil { + return err + } + // check all are DenseTensors var ad, bd, pd DenseTensor if ad, bd, pd, err = e.checkThreeFloatComplexTensors(a, b, prealloc); err != nil { @@ -460,7 +482,7 @@ func (e StdEng) MatVecMul(a, b, prealloc Tensor) (err error) { var alpha, beta complex128 = complex(1, 0), complex(0, 0) whichblas.Zgemv(tA, m, n, alpha, A, lda, x, incX, beta, y, incY) default: - return errors.Errorf(typeNYI, "matVecMul", bd.Data()) + return nyierr(typeNYI, bd.Data()) } return nil @@ -470,7 +492,11 @@ func (e StdEng) MatVecMul(a, b, prealloc Tensor) (err error) { // DGEMM computes: // C = αA * B + βC // To prevent needless zeroing out of the slice, we just set β to 0 -func (e StdEng) MatMul(a, b, prealloc Tensor) (err error) { +func (e StdEng) MatMul(ctx context.Context, a, b, prealloc Tensor) (err error) { + if err := handleCtx(ctx); err != nil { + return err + } + // check all are DenseTensors var ad, bd, pd DenseTensor if ad, bd, pd, err = e.checkThreeFloatComplexTensors(a, b, prealloc); err != nil { @@ -572,13 +598,18 @@ func (e StdEng) MatMul(a, b, prealloc Tensor) (err error) { whichblas.Zgemm(tA, tB, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc) } default: - return errors.Errorf(typeNYI, "matMul", ad.Data()) + return nyierr(typeNYI, ad.Data()) + } return } // Outer is a thin wrapper over S/Dger -func (e StdEng) Outer(a, b, prealloc Tensor) (err error) { +func (e StdEng) Outer(ctx context.Context, a, b, prealloc Tensor) (err error) { + if err = handleCtx(ctx); err != nil { + return err + } + // check all are DenseTensors var ad, bd, pd DenseTensor if ad, bd, pd, err = e.checkThreeFloatComplexTensors(a, b, prealloc); err != nil { @@ -606,7 +637,7 @@ func (e StdEng) Outer(a, b, prealloc Tensor) (err error) { return err } - if err = e.MatMul(a, b, prealloc); err != nil { + if err = e.MatMul(ctx, a, b, prealloc); err != nil { return err } @@ -644,7 +675,7 @@ func (e StdEng) Outer(a, b, prealloc Tensor) (err error) { var alpha complex128 = complex(1, 0) whichblas.Zgeru(m, n, alpha, x, incX, y, incY, A, lda) default: - return errors.Errorf(typeNYI, "outer", b.Data()) + return nyierr(typeNYI, b.Data()) } return nil } diff --git a/defaultengine_mapreduce.go b/defaultengine_mapreduce.go index 9c1443c..2c7ccd4 100644 --- a/defaultengine_mapreduce.go +++ b/defaultengine_mapreduce.go @@ -1,6 +1,7 @@ package tensor import ( + "context" "reflect" "sort" @@ -11,27 +12,30 @@ import ( ) func (e StdEng) Map(fn interface{}, a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(a, nil); err != nil { + if err = unaryCheck(a, nilTC); err != nil { err = errors.Wrap(err, "Failed Map()") return } + if _, ok := a.(DenseTensor); !ok { + return nil, errors.Errorf("StdEng's Map method only supports dense tensors for now. Please put in a Pull Request to support other forms of Tensors. The file is: defaultengine_mapreduce.go") + } var reuse DenseTensor var safe, _, incr bool - if reuse, safe, _, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, _, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return } + if err = handleCtx(ctx); err != nil { + return nil, err // will be noopError{}, no need to wrap. + } switch { case safe && reuse == nil: // create reuse - if v, ok := a.(View); ok { - if v.IsMaterializable() { - reuse = v.Materialize().(DenseTensor) - } else { - reuse = v.Clone().(DenseTensor) - } + if v, ok := a.(View); ok && v.IsMaterializable() { + reuse = v.Materialize().(DenseTensor) } else { - reuse = New(Of(a.Dtype()), WithShape(a.Shape().Clone()...)) + reuse = a.Clone().(DenseTensor) } case reuse != nil: if !reuse.IsNativelyAccessible() { @@ -75,7 +79,7 @@ func (e StdEng) Map(fn interface{}, a Tensor, opts ...FuncOpt) (retVal Tensor, e // SET RETVAL switch { case reuse != nil: - if err = reuseCheckShape(reuse, a.Shape()); err != nil { + if err = checkFixShape(reuse, a.Shape()); err != nil { err = errors.Wrapf(err, "Reuse shape check failed") return } @@ -177,7 +181,11 @@ func (e StdEng) OptimizedReduce(a Tensor, axis int, firstFn, lastFn, defaultFn, return } -func (e StdEng) Sum(a Tensor, along ...int) (retVal Tensor, err error) { +func (e StdEng) Sum(ctx context.Context, a Tensor, along ...int) (retVal Tensor, err error) { + if err = handleCtx(ctx); err != nil { + return nil, err + } + a2 := a if v, ok := a.(View); ok && v.IsMaterializable() { a2 = v.Materialize() @@ -185,7 +193,11 @@ func (e StdEng) Sum(a Tensor, along ...int) (retVal Tensor, err error) { return e.reduce("Sum", execution.MonotonicSum, execution.SumMethods, a2, along...) } -func (e StdEng) Min(a Tensor, along ...int) (retVal Tensor, err error) { +func (e StdEng) Min(ctx context.Context, a Tensor, along ...int) (retVal Tensor, err error) { + if err = handleCtx(ctx); err != nil { + return nil, err + } + a2 := a if v, ok := a.(View); ok && v.IsMaterializable() { a2 = v.Materialize() @@ -193,7 +205,11 @@ func (e StdEng) Min(a Tensor, along ...int) (retVal Tensor, err error) { return e.reduce("Min", execution.MonotonicMin, execution.MinMethods, a2, along...) } -func (e StdEng) Max(a Tensor, along ...int) (retVal Tensor, err error) { +func (e StdEng) Max(ctx context.Context, a Tensor, along ...int) (retVal Tensor, err error) { + if err = handleCtx(ctx); err != nil { + return nil, err + } + a2 := a if v, ok := a.(View); ok && v.IsMaterializable() { a2 = v.Materialize() @@ -255,17 +271,21 @@ func (StdEng) prepReduce(a Tensor, axis int, opts ...FuncOpt) (at, reuse DenseTe return } - if err = unaryCheck(a, nil); err != nil { + if err = unaryCheck(a, nilTC); err != nil { err = errors.Wrap(err, "prepReduce failed") return } // FUNC PREP var safe bool - if reuse, safe, _, _, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), false, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, _, _, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), false, opts...); err != nil { err = errors.Wrap(err, "Unable to prep unary tensor") return } + if err = handleCtx(ctx); err != nil { + return + } var newShape Shape for i, s := range a.Shape() { diff --git a/defaultengine_matop_gatherscatter.go b/defaultengine_matop_gatherscatter.go new file mode 100644 index 0000000..b28d4ea --- /dev/null +++ b/defaultengine_matop_gatherscatter.go @@ -0,0 +1,200 @@ +package tensor + +import ( + "sync" + + "github.com/pkg/errors" +) + +func (e StdEng) Scatter(a, indices Tensor, opts ...FuncOpt) (retVal Tensor, err error) { + fo := ParseFuncOpts(opts...) + reuse := fo.Reuse() + + maxT, err := Max(indices) + if err != nil { + return nil, errors.Wrapf(err, "Cannot find the max of the indices") + } + max, ok := maxT.Data().(int) + if !ok { + return nil, errors.Errorf("Indices must be of ints. Got %v of %T instead", maxT.Data(), maxT.Data()) + } + + // expected shape + shp := indices.Shape().Clone() + shp[len(shp)-1] = max + 1 + + switch { + case reuse == nil && fo.Safe(): + // create reuse + reuse = New(WithShape(shp...), Of(a.Dtype())) + case reuse == nil && !fo.Safe(): + // check shape of `a` - the last dim of a must be at least max+1 + if a.Shape()[a.Dims()-1] < max+1 { + return nil, errors.Errorf("Cannot Scatter - the last dim of `a` %v must be at least %v, which is the maximum value of the indices + 1", a.Shape(), max+1) + } + reuse = a + case reuse != nil: + // check shape of `reuse` - last dim of `reuse` must at least be as large as max+1 + if reuse.Shape()[reuse.Dims()-1] < max+1 { + return nil, errors.Errorf("Cannot Scatter. The last dim of `reuse` %v must be at least %v, which is the maximum value off the indices + 1", reuse.Shape(), max+1) + } + } + + oldShape := a.Shape().Clone() + oldIndicesShape := a.Shape().Clone() + reuseOldShape := reuse.Shape().Clone() + defer func() { a.Reshape(oldShape...); indices.Reshape(oldIndicesShape...); reuse.Reshape(reuseOldShape...) }() + + switch { + case indices.Shape().IsVectorLike(): + idx := indices.Data().([]int) + _ = idx + // TODO + default: + // THIS IS ROW MAJOR ONLY + // THIS IS DENSE TENSOR ONLY + + a := a.(DenseTensor) + indices := indices.(DenseTensor) + reuse := reuse.(DenseTensor) + + // reshape everything into a matrix + a.Reshape(asMat(a.Shape(), a.Dims()-1, true)...) + indices.Reshape(asMat(indices.Shape(), indices.Dims()-1, true)...) + reuse.Reshape(asMat(reuse.Shape(), reuse.Dims()-1, true)...) + + // check that indices' shape[0] is <= a.Shape[0] + if indices.Shape()[0] > a.Shape()[0] { + // something is wrong + return nil, errors.Errorf("Cannot scatter") + } + + // now they are all matrices, we can iterate thru them + var ps []iteratorPair + for i := 0; i < indices.Shape()[0]; i++ { + ait := AxialIteratorFromDense(a, 0, i, true) + iit := AxialIteratorFromDense(indices, 0, i, true) + + ps = append(ps, iteratorPair{ait, iit, i}) + } + + errChan := make(chan error, len(ps)) + var wg sync.WaitGroup + for i := range ps { + wg.Add(1) + // note: be careful not to use `for i, p := range ps` + // and then use `go p.coiter`. + // This is because `p` is would not be captured by `go`, + // thus every `p` would be `ps[len(ps)-1]`. + go ps[i].coiter(a, indices, reuse, errChan, &wg) + } + wg.Wait() + close(errChan) + err = <-errChan // maybe get ALL the errors from errChan? + return reuse, err + + } + + panic("unreachable") +} + +type iteratorPair struct { + a *AxialIterator + idx *AxialIterator + axis int +} + +func (it *iteratorPair) coiter(a, indices, reuse DenseTensor, errChan chan error, wg *sync.WaitGroup) { + defer wg.Done() + ii, err := it.idx.Start() + if err != nil { + if err = handleNoOp(err); err != nil { + errChan <- err + } + return + } + + iData := indices.Data().([]int) + retStride := reuse.Strides()[0] + switch { + case a.Dtype() == Float64 && reuse.Dtype() == Float64: + aData := a.Data().([]float64) + rData := reuse.Data().([]float64) + + var ai, ii int + if ai, err = it.a.Start(); err != nil { + goto reterr + } + if ii, err = it.idx.Start(); err != nil { + goto reterr + } + for { + + idx := iData[ii] + v := aData[ai] + + rData[it.axis*retStride+idx] = v + + if it.a.Done() || it.idx.Done() { + break + } + if ai, err = it.a.Next(); err != nil { + break + } + if ii, err = it.idx.Next(); err != nil { + break + } + } + case a.Dtype() == Float32 && reuse.Dtype() == Float32: + aData := a.Data().([]float32) + rData := reuse.Data().([]float32) + + var ai, ii int + if ai, err = it.a.Start(); err != nil { + goto reterr + } + if ii, err = it.idx.Start(); err != nil { + goto reterr + } + for { + + idx := iData[ii] + v := aData[ai] + + rData[it.axis*retStride+idx] = v + + if it.a.Done() || it.idx.Done() { + break + } + if ai, err = it.a.Next(); err != nil { + break + } + if ii, err = it.idx.Next(); err != nil { + break + } + } + + default: + + // generic + for ai, err := it.a.Start(); err == nil; ai, err = it.a.Next() { + if it.idx.Done() { + break + } + idx := iData[ii] + v := a.arrPtr().Get(ai) + reuse.Set(it.axis*retStride+idx, v) + + if ii, err = it.idx.Next(); err != nil { + break + } + } + } + +reterr: + if err = handleNoOp(err); err != nil { + errChan <- err + return + } + +} diff --git a/defaultengine_matop_misc.go b/defaultengine_matop_misc.go index 0ab392a..56641d3 100644 --- a/defaultengine_matop_misc.go +++ b/defaultengine_matop_misc.go @@ -1,8 +1,13 @@ package tensor import ( + "context" + "github.com/pkg/errors" + "gorgonia.org/dtype" "gorgonia.org/tensor/internal/storage" + + "gorgonia.org/shapes" ) var ( @@ -14,7 +19,11 @@ type fastcopier interface { } // Repeat ... -func (e StdEng) Repeat(t Tensor, axis int, repeats ...int) (Tensor, error) { +func (e StdEng) Repeat(ctx context.Context, t Tensor, axis int, repeats ...int) (Tensor, error) { + if err := handleCtx(ctx); err != nil { + return nil, err + } + switch tt := t.(type) { case DenseTensor: newShape, newRepeats, newAxis, size, err := e.denseRepeatCheck(t, axis, repeats) @@ -24,12 +33,16 @@ func (e StdEng) Repeat(t Tensor, axis int, repeats ...int) (Tensor, error) { rr := recycledDense(t.Dtype(), newShape, WithEngine(StdEng{})) return e.denseRepeat(tt, rr, newShape, newAxis, size, newRepeats) default: - return nil, errors.Errorf("NYI") + return nil, nyierr(typeNYI, t) } } // RepeatReuse is like Repeat, but with a provided reuse Tensor. The reuseTensor must be of the same type as the input t. -func (e StdEng) RepeatReuse(t Tensor, reuse Tensor, axis int, repeats ...int) (Tensor, error) { +func (e StdEng) RepeatReuse(ctx context.Context, t Tensor, reuse Tensor, axis int, repeats ...int) (Tensor, error) { + if err := handleCtx(ctx); err != nil { + return nil, err + } + switch tt := t.(type) { case DenseTensor: newShape, newRepeats, newAxis, size, err := e.denseRepeatCheck(t, axis, repeats) @@ -46,14 +59,16 @@ func (e StdEng) RepeatReuse(t Tensor, reuse Tensor, axis int, repeats ...int) (T } return e.denseRepeat(tt, rr, newShape, newAxis, size, newRepeats) default: - return nil, errors.Errorf("NYI") + return nil, nyierr(typeNYI, t) } } func (StdEng) denseRepeatCheck(t Tensor, axis int, repeats []int) (newShape Shape, newRepeats []int, newAxis, size int, err error) { - if newShape, newRepeats, size, err = t.Shape().Repeat(axis, repeats...); err != nil { + var newShapelike shapes.Shapelike + if newShapelike, newRepeats, size, err = t.Shape().Repeat(shapes.Axis(axis), repeats...); err != nil { return nil, nil, -1, -1, errors.Wrap(err, "Unable to get repeated shape") } + newShape = newShapelike.(Shape) newAxis = axis if axis == AllAxes { newAxis = 0 @@ -198,7 +213,6 @@ func (e StdEng) fastCopyDenseRepeat(src DenseTensor, dest *Dense, outers, size, } // we can straightaway broadcast - continue } @@ -228,7 +242,11 @@ func (e StdEng) fastCopyDenseRepeat(src DenseTensor, dest *Dense, outers, size, } // Concat tensors -func (e StdEng) Concat(t Tensor, axis int, others ...Tensor) (retVal Tensor, err error) { +func (e StdEng) Concat(ctx context.Context, t Tensor, axis int, others ...Tensor) (retVal Tensor, err error) { + if err := handleCtx(ctx); err != nil { + return nil, err + } + switch tt := t.(type) { case DenseTensor: var denses []DenseTensor @@ -237,7 +255,7 @@ func (e StdEng) Concat(t Tensor, axis int, others ...Tensor) (retVal Tensor, err } return e.denseConcat(tt, axis, denses) default: - return nil, errors.Errorf("NYI") + return nil, nyierr(typeNYI, t) } } @@ -252,10 +270,11 @@ func (e StdEng) denseConcat(a DenseTensor, axis int, Ts []DenseTensor) (DenseTen } } - var newShape Shape - if newShape, err = a.Shape().Concat(axis, ss...); err != nil { + var newShapelike shapes.Shapelike + if newShapelike, err = a.Shape().Concat(shapes.Axis(axis), shapes.ShapesToShapelikes(ss)...); err != nil { return nil, errors.Wrap(err, "Unable to find new shape that results from concatenation") } + newShape := newShapelike.(Shape) retVal := recycledDense(a.Dtype(), newShape, WithEngine(e)) if isMasked { @@ -359,7 +378,11 @@ func (e StdEng) denseConcat(a DenseTensor, axis int, Ts []DenseTensor) (DenseTen } // Diag ... -func (e StdEng) Diag(t Tensor) (retVal Tensor, err error) { +func (e StdEng) Diag(ctx context.Context, t Tensor) (retVal Tensor, err error) { + if err := handleCtx(ctx); err != nil { + return nil, err + } + a, ok := t.(DenseTensor) if !ok { return nil, errors.Errorf("StdEng only works with DenseTensor for Diagonal()") @@ -370,7 +393,7 @@ func (e StdEng) Diag(t Tensor) (retVal Tensor, err error) { return } - if err = typeclassCheck(a.Dtype(), numberTypes); err != nil { + if err = dtype.TypeClassCheck(a.Dtype(), dtype.Number); err != nil { return nil, errors.Wrap(err, "Diagonal") } @@ -412,7 +435,7 @@ func (e StdEng) Diag(t Tensor) (retVal Tensor, err error) { bdata[i] = adata[i*stride] } default: - return nil, errors.Errorf(typeNYI, "Arbitrary sized diag", t) + return nil, nyierr(typeNYI, "Arbitrary-sized .Diag()", t) } return b, nil } diff --git a/defaultengine_matop_stack.go b/defaultengine_matop_stack.go index 879ca28..33c148d 100644 --- a/defaultengine_matop_stack.go +++ b/defaultengine_matop_stack.go @@ -1,12 +1,23 @@ package tensor import ( + "context" + "github.com/pkg/errors" ) // This file contains code for the execution engine to stack tensors -func (e StdEng) StackDense(t DenseTensor, axis int, others ...DenseTensor) (retVal DenseTensor, err error) { +var ( + // _ Stacker = StdEng{} + _ DenseStacker = StdEng{} +) + +func (e StdEng) StackDense(ctx context.Context, t DenseTensor, axis int, others ...DenseTensor) (retVal DenseTensor, err error) { + if err := handleCtx(ctx); err != nil { + return nil, err + } + opdims := t.Dims() if axis >= opdims+1 { err = errors.Errorf(dimMismatch, opdims+1, axis) @@ -28,9 +39,9 @@ func (e StdEng) StackDense(t DenseTensor, axis int, others ...DenseTensor) (retV info := t.Info() var newStrides []int if info.o.IsColMajor() { - newStrides = newShape.CalcStridesColMajor() + newStrides = CalcStridesColMajor(newShape) } else { - newStrides = newShape.CalcStrides() + newStrides = CalcStrides(newShape) } ap := MakeAP(newShape, newStrides, info.o, info.Δ) diff --git a/defaultengine_matop_transpose.go b/defaultengine_matop_transpose.go index cef220e..7ca63cb 100644 --- a/defaultengine_matop_transpose.go +++ b/defaultengine_matop_transpose.go @@ -3,10 +3,16 @@ package tensor import ( + "context" + "github.com/pkg/errors" ) -func (e StdEng) Transpose(a Tensor, expStrides []int) error { +func (e StdEng) Transpose(ctx context.Context, a Tensor, expStrides []int) error { + if err := handleCtx(ctx); err != nil { + return err + } + if !a.IsNativelyAccessible() { return errors.Errorf("Cannot Transpose() on non-natively accessible tensor") } diff --git a/defaultengine_matop_transpose_inplace.go b/defaultengine_matop_transpose_inplace.go index 8627927..8d1d5f3 100644 --- a/defaultengine_matop_transpose_inplace.go +++ b/defaultengine_matop_transpose_inplace.go @@ -3,10 +3,16 @@ package tensor import ( + "context" + "github.com/pkg/errors" ) -func (e StdEng) Transpose(a Tensor, expStrides []int) error { +func (e StdEng) Transpose(ctx context.Context, a Tensor, expStrides []int) error { + if err := handleCtx(ctx); err != nil { + return err + } + if !a.IsNativelyAccessible() { return errors.Errorf("Cannot Transpose() on non-natively accessible tensor") } diff --git a/defaultengine_minmax.go b/defaultengine_minmax.go index 56ac432..a16cbf0 100644 --- a/defaultengine_minmax.go +++ b/defaultengine_minmax.go @@ -1,27 +1,35 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor import ( + "context" + "github.com/pkg/errors" + "gorgonia.org/dtype" "gorgonia.org/tensor/internal/storage" ) +// Code generated by genlib2. DO NOT EDIT. + var ( _ MinBetweener = StdEng{} _ MaxBetweener = StdEng{} ) func (e StdEng) MinBetween(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = binaryCheck(a, b, ordTypes); err != nil { + if err = binaryCheck(a, b, dtype.Ord); err != nil { + return nil, errors.Wrapf(err, "MinBetween failed") } var reuse DenseTensor var safe bool - if reuse, safe, _, _, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, _, _, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } typ := a.Dtype().Type var dataA, dataB, dataReuse *storage.Header var ait, bit, iit Iterator @@ -75,15 +83,20 @@ func (e StdEng) MinBetween(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, } func (e StdEng) MaxBetween(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = binaryCheck(a, b, ordTypes); err != nil { + if err = binaryCheck(a, b, dtype.Ord); err != nil { + return nil, errors.Wrapf(err, "MaxBetween failed") } var reuse DenseTensor var safe bool - if reuse, safe, _, _, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, _, _, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } typ := a.Dtype().Type var dataA, dataB, dataReuse *storage.Header var ait, bit, iit Iterator @@ -137,7 +150,7 @@ func (e StdEng) MaxBetween(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, } func (e StdEng) MinBetweenScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(t, ordTypes); err != nil { + if err = unaryCheck(t, dtype.Ord); err != nil { return nil, errors.Wrapf(err, "MinBetween failed") } @@ -147,9 +160,13 @@ func (e StdEng) MinBetweenScalar(t Tensor, s interface{}, leftTensor bool, opts var reuse DenseTensor var safe bool - if reuse, safe, _, _, _, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, _, _, _, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } a := t typ := t.Dtype().Type var ait, bit, iit Iterator @@ -243,7 +260,7 @@ func (e StdEng) MinBetweenScalar(t Tensor, s interface{}, leftTensor bool, opts } func (e StdEng) MaxBetweenScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(t, ordTypes); err != nil { + if err = unaryCheck(t, dtype.Ord); err != nil { return nil, errors.Wrapf(err, "MaxBetween failed") } @@ -253,9 +270,13 @@ func (e StdEng) MaxBetweenScalar(t Tensor, s interface{}, leftTensor bool, opts var reuse DenseTensor var safe bool - if reuse, safe, _, _, _, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, _, _, _, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } a := t typ := t.Dtype().Type var ait, bit, iit Iterator diff --git a/defaultengine_misc.go b/defaultengine_misc.go index bb70e57..6a0c570 100644 --- a/defaultengine_misc.go +++ b/defaultengine_misc.go @@ -1,81 +1,94 @@ -package tensor - -import ( - "github.com/pkg/errors" - "gorgonia.org/tensor/internal/storage" -) - -func (e StdEng) Clamp(a Tensor, min, max interface{}, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(a, nonComplexNumberTypes); err != nil { - return nil, errors.Wrap(err, "Clamp failed") - } - - var reuse DenseTensor - var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), false, opts...); err != nil { - return nil, errors.Wrap(err, "Unable to handle funcOpts") - } - - typ := a.Dtype().Type - var ait, rit Iterator - var dataA, dataReuse *storage.Header - var useIter bool - - if dataA, dataReuse, ait, rit, useIter, err = prepDataUnary(a, reuse); err != nil { - return nil, errors.Wrapf(err, opFail, "StdEng.Neg") - } - - if useIter { - switch { - case incr: - cloned := a.Clone().(Tensor) - if err = e.E.ClampIter(typ, cloned.hdr(), ait, min, max); err != nil { - return nil, errors.Wrapf(err, "Unable to perform Clamp") - } - ait.Reset() - err = e.E.AddIter(typ, dataReuse, cloned.hdr(), rit, ait) - retVal = reuse - case toReuse: - storage.CopyIter(typ, dataReuse, dataA, rit, ait) - rit.Reset() - err = e.E.ClampIter(typ, dataReuse, rit, min, max) - retVal = reuse - case !safe: - err = e.E.ClampIter(typ, dataA, ait, min, max) - retVal = a - default: - cloned := a.Clone().(Tensor) - err = e.E.ClampIter(typ, cloned.hdr(), ait, min, max) - retVal = cloned - } - return - } - switch { - case incr: - cloned := a.Clone().(Tensor) - if err = e.E.Clamp(typ, cloned.hdr(), min, max); err != nil { - return nil, errors.Wrapf(err, "Unable to perform Clamp") - } - err = e.E.Add(typ, dataReuse, cloned.hdr()) - retVal = reuse - case toReuse: - storage.Copy(typ, dataReuse, dataA) - err = e.E.Clamp(typ, dataReuse, min, max) - retVal = reuse - case !safe: - err = e.E.Clamp(typ, dataA, min, max) - retVal = a - default: - cloned := a.Clone().(Tensor) - err = e.E.Clamp(typ, cloned.hdr(), min, max) - retVal = cloned - } - return -} - -func (e StdEng) FMA(a, x, y Tensor) (Tensor, error) { - return e.Mul(a, x, WithIncr(y)) -} -func (e StdEng) FMAScalar(a Tensor, x interface{}, y Tensor) (Tensor, error) { - return e.MulScalar(a, x, true, WithIncr(y)) -} +package tensor + +import ( + "context" + + "github.com/pkg/errors" + "gorgonia.org/dtype" + "gorgonia.org/tensor/internal/storage" +) + +func (e StdEng) Clamp(a Tensor, min, max interface{}, opts ...FuncOpt) (retVal Tensor, err error) { + if err = unaryCheck(a, dtype.NonComplexNumber); err != nil { + return nil, errors.Wrap(err, "Clamp failed") + } + + var reuse DenseTensor + var safe, toReuse, incr bool + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), false, opts...); err != nil { + return nil, errors.Wrap(err, "Unable to handle funcOpts") + } + if err = handleCtx(ctx); err != nil { + return nil, err // will be noopError{}, no need to wrap.s + } + + typ := a.Dtype().Type + var ait, rit Iterator + var dataA, dataReuse *storage.Header + var useIter bool + + if dataA, dataReuse, ait, rit, useIter, err = prepDataUnary(a, reuse); err != nil { + return nil, errors.Wrapf(err, opFail, "StdEng.Neg") + } + + if useIter { + switch { + case incr: + cloned := a.Clone().(Tensor) + if err = e.E.ClampIter(typ, cloned.hdr(), ait, min, max); err != nil { + return nil, errors.Wrapf(err, "Unable to perform Clamp") + } + ait.Reset() + err = e.E.AddIter(typ, dataReuse, cloned.hdr(), rit, ait) + retVal = reuse + case toReuse: + storage.CopyIter(typ, dataReuse, dataA, rit, ait) + rit.Reset() + err = e.E.ClampIter(typ, dataReuse, rit, min, max) + retVal = reuse + case !safe: + err = e.E.ClampIter(typ, dataA, ait, min, max) + retVal = a + default: + cloned := a.Clone().(Tensor) + err = e.E.ClampIter(typ, cloned.hdr(), ait, min, max) + retVal = cloned + } + return + } + switch { + case incr: + cloned := a.Clone().(Tensor) + if err = e.E.Clamp(typ, cloned.hdr(), min, max); err != nil { + return nil, errors.Wrapf(err, "Unable to perform Clamp") + } + err = e.E.Add(typ, dataReuse, cloned.hdr()) + retVal = reuse + case toReuse: + storage.Copy(typ, dataReuse, dataA) + err = e.E.Clamp(typ, dataReuse, min, max) + retVal = reuse + case !safe: + err = e.E.Clamp(typ, dataA, min, max) + retVal = a + default: + cloned := a.Clone().(Tensor) + err = e.E.Clamp(typ, cloned.hdr(), min, max) + retVal = cloned + } + return +} + +func (e StdEng) FMA(ctx context.Context, a, x, y Tensor) (Tensor, error) { + if err := handleCtx(ctx); err != nil { + return nil, err + } + return e.Mul(a, x, WithIncr(y)) +} +func (e StdEng) FMAScalar(ctx context.Context, a Tensor, x interface{}, y Tensor) (Tensor, error) { + if err := handleCtx(ctx); err != nil { + return nil, err + } + return e.MulScalar(a, x, true, WithIncr(y)) +} diff --git a/defaultengine_prep.go b/defaultengine_prep.go index 261367a..6f6927a 100644 --- a/defaultengine_prep.go +++ b/defaultengine_prep.go @@ -1,15 +1,17 @@ package tensor import ( + "context" "reflect" "github.com/pkg/errors" + "gorgonia.org/dtype" "gorgonia.org/tensor/internal/storage" - // "log" ) -func handleFuncOpts(expShape Shape, expType Dtype, o DataOrder, strict bool, opts ...FuncOpt) (reuse DenseTensor, safe, toReuse, incr, same bool, err error) { +func handleFuncOpts(expShape Shape, expType dtype.Dtype, o DataOrder, strict bool, opts ...FuncOpt) (ctx context.Context, reuse DenseTensor, safe, toReuse, incr, same bool, err error) { fo := ParseFuncOpts(opts...) + ctx = fo.Context() reuseT, incr := fo.IncrReuse() safe = fo.Safe() @@ -61,7 +63,16 @@ func handleFuncOpts(expShape Shape, expType Dtype, o DataOrder, strict bool, opt return } -func binaryCheck(a, b Tensor, tc *typeclass) (err error) { +func handleCtx(ctx context.Context) error { + select { + case <-ctx.Done(): + return noopError{} + default: + } + return nil +} + +func binaryCheck(a, b Tensor, tc dtype.TypeClass) (err error) { // check if the tensors are accessible if !a.IsNativelyAccessible() { return errors.Errorf(inaccessibleData, a) @@ -73,11 +84,11 @@ func binaryCheck(a, b Tensor, tc *typeclass) (err error) { at := a.Dtype() bt := b.Dtype() - if tc != nil { - if err = typeclassCheck(at, tc); err != nil { + if tc != nilTC { + if err = dtype.TypeClassCheck(at, tc); err != nil { return errors.Wrapf(err, typeclassMismatch, "a") } - if err = typeclassCheck(bt, tc); err != nil { + if err = dtype.TypeClassCheck(bt, tc); err != nil { return errors.Wrapf(err, typeclassMismatch, "b") } } @@ -91,13 +102,13 @@ func binaryCheck(a, b Tensor, tc *typeclass) (err error) { return nil } -func unaryCheck(a Tensor, tc *typeclass) error { +func unaryCheck(a Tensor, tc dtype.TypeClass) error { if !a.IsNativelyAccessible() { return errors.Errorf(inaccessibleData, a) } at := a.Dtype() - if tc != nil { - if err := typeclassCheck(at, tc); err != nil { + if tc != nilTC { + if err := dtype.TypeClassCheck(at, tc); err != nil { return errors.Wrapf(err, typeclassMismatch, "a") } } @@ -106,13 +117,13 @@ func unaryCheck(a Tensor, tc *typeclass) error { // scalarDtypeCheck checks that a scalar value has the same dtype as the dtype of a given tensor. func scalarDtypeCheck(a Tensor, b interface{}) error { - var dt Dtype + var dt dtype.Dtype switch bt := b.(type) { case Dtyper: dt = bt.Dtype() default: t := reflect.TypeOf(b) - dt = Dtype{t} + dt = dtype.Dtype{t} } if a.Dtype() != dt { diff --git a/defaultengine_selbyidx.go b/defaultengine_selbyidx.go index e0564e6..58e3e42 100644 --- a/defaultengine_selbyidx.go +++ b/defaultengine_selbyidx.go @@ -1,6 +1,8 @@ package tensor import ( + "context" + "github.com/pkg/errors" "gorgonia.org/tensor/internal/storage" @@ -18,7 +20,6 @@ func (e StdEng) SelectByIndices(a, indices Tensor, axis int, opts ...FuncOpt) (r if indices.Dtype() != Int { return nil, errors.Errorf("Expected indices to be a vector of ints. Got %v instead", indices.Dtype()) } - // if b is a scalar, then use Slice if a.Shape().IsScalarEquiv() { slices := make([]Slice, a.Shape().Dims()) @@ -31,9 +32,13 @@ func (e StdEng) SelectByIndices(a, indices Tensor, axis int, opts ...FuncOpt) (r var reuse DenseTensor var safe, toReuse, _ bool - if reuse, safe, toReuse, _, _, err = handleFuncOpts(expectedShape, a.Dtype(), a.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, _, _, err = handleFuncOpts(expectedShape, a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // will be noopError{}, no need to wrap. + } if safe || !toReuse && reuse == nil && safe { // create reuse reuse = New(WithShape(expectedShape...), Of(a.Dtype())) @@ -105,7 +110,6 @@ func (e StdEng) selectByIdx(axis int, indices []int, typ reflect.Type, dataA, da for o := 0; o < outer; o++ { end := start + axStride dstEnd := dstStart + retStride - storage.CopySliced(typ, dataRetVal, dstStart, dstEnd, dataA, start, end) start += prevStride @@ -157,9 +161,13 @@ func (e StdEng) SelectByIndicesB(input, outGrad, indices Tensor, axis int, opts var reuse DenseTensor var _, toReuse, _ bool - if reuse, _, toReuse, _, _, err = handleFuncOpts(input.Shape(), input.Dtype(), input.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, _, toReuse, _, _, err = handleFuncOpts(input.Shape(), input.Dtype(), input.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // will be noopError{}, no need to wrap. + } if !toReuse && reuse == nil { // create reuse reuse = New(WithShape(expectedShape...), Of(input.Dtype())) diff --git a/defaultengine_softmax.go b/defaultengine_softmax.go index ffc5a06..8a7dc3e 100644 --- a/defaultengine_softmax.go +++ b/defaultengine_softmax.go @@ -1,6 +1,7 @@ package tensor import ( + "context" "fmt" "math" "sync" @@ -30,9 +31,14 @@ func (e StdEng) SoftMax(x Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err var reuse DenseTensor var safe, toReuse, _ bool - if reuse, safe, toReuse, _, _, err = handleFuncOpts(expectedShape, x.Dtype(), x.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, _, _, err = handleFuncOpts(expectedShape, x.Dtype(), x.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } + if safe || !toReuse && reuse == nil && safe { // create reuse reuse = New(WithShape(expectedShape...), Of(x.Dtype())) @@ -74,9 +80,14 @@ func (e StdEng) SoftMaxB(output, grad Tensor, axis int, opts ...FuncOpt) (retVal var reuse DenseTensor var safe, toReuse, _ bool - if reuse, safe, toReuse, _, _, err = handleFuncOpts(expectedShape, output.Dtype(), output.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, _, _, err = handleFuncOpts(expectedShape, output.Dtype(), output.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } + if safe || !toReuse && reuse == nil && safe { // create reuse reuse = New(WithShape(expectedShape...), Of(output.Dtype())) @@ -112,9 +123,14 @@ func (e StdEng) LogSoftMax(x Tensor, axis int, opts ...FuncOpt) (retVal Tensor, var reuse DenseTensor var safe, toReuse, _ bool - if reuse, safe, toReuse, _, _, err = handleFuncOpts(expectedShape, x.Dtype(), x.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, _, _, err = handleFuncOpts(expectedShape, x.Dtype(), x.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } + if safe || !toReuse && reuse == nil && safe { // create reuse reuse = New(WithShape(expectedShape...), Of(x.Dtype())) @@ -157,9 +173,14 @@ func (e StdEng) LogSoftMaxB(output, grad Tensor, axis int, opts ...FuncOpt) (ret var reuse DenseTensor var safe, toReuse, _ bool - if reuse, safe, toReuse, _, _, err = handleFuncOpts(expectedShape, output.Dtype(), output.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, _, _, err = handleFuncOpts(expectedShape, output.Dtype(), output.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } + if safe || !toReuse && reuse == nil && safe { // create reuse reuse = New(WithShape(expectedShape...), Of(output.Dtype())) diff --git a/defaultengine_unary.go b/defaultengine_unary.go index 986e246..8efe589 100644 --- a/defaultengine_unary.go +++ b/defaultengine_unary.go @@ -1,22 +1,29 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor import ( + "context" + "github.com/pkg/errors" + "gorgonia.org/dtype" "gorgonia.org/tensor/internal/storage" ) +// Code generated by genlib2. DO NOT EDIT. + func (e StdEng) Neg(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(a, numberTypes); err != nil { + if err = unaryCheck(a, dtype.Number); err != nil { err = errors.Wrapf(err, "Neg failed") return } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be a noopError{}, no need to wrap. + } typ := a.Dtype().Type var ait, rit Iterator @@ -76,15 +83,19 @@ func (e StdEng) Neg(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } func (e StdEng) Inv(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(a, numberTypes); err != nil { + if err = unaryCheck(a, dtype.Number); err != nil { err = errors.Wrapf(err, "Inv failed") return } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be a noopError{}, no need to wrap. + } typ := a.Dtype().Type var ait, rit Iterator @@ -144,15 +155,19 @@ func (e StdEng) Inv(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } func (e StdEng) Square(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(a, numberTypes); err != nil { + if err = unaryCheck(a, dtype.Number); err != nil { err = errors.Wrapf(err, "Square failed") return } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be a noopError{}, no need to wrap. + } typ := a.Dtype().Type var ait, rit Iterator @@ -212,15 +227,19 @@ func (e StdEng) Square(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } func (e StdEng) Cube(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(a, numberTypes); err != nil { + if err = unaryCheck(a, dtype.Number); err != nil { err = errors.Wrapf(err, "Cube failed") return } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be a noopError{}, no need to wrap. + } typ := a.Dtype().Type var ait, rit Iterator @@ -280,15 +299,19 @@ func (e StdEng) Cube(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } func (e StdEng) Exp(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(a, floatcmplxTypes); err != nil { + if err = unaryCheck(a, dtype.FloatComplex); err != nil { err = errors.Wrapf(err, "Exp failed") return } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be a noopError{}, no need to wrap. + } typ := a.Dtype().Type var ait, rit Iterator @@ -348,15 +371,19 @@ func (e StdEng) Exp(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } func (e StdEng) Tanh(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(a, floatcmplxTypes); err != nil { + if err = unaryCheck(a, dtype.FloatComplex); err != nil { err = errors.Wrapf(err, "Tanh failed") return } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be a noopError{}, no need to wrap. + } typ := a.Dtype().Type var ait, rit Iterator @@ -416,15 +443,19 @@ func (e StdEng) Tanh(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } func (e StdEng) Log(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(a, floatcmplxTypes); err != nil { + if err = unaryCheck(a, dtype.FloatComplex); err != nil { err = errors.Wrapf(err, "Log failed") return } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be a noopError{}, no need to wrap. + } typ := a.Dtype().Type var ait, rit Iterator @@ -484,15 +515,19 @@ func (e StdEng) Log(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } func (e StdEng) Log2(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(a, floatTypes); err != nil { + if err = unaryCheck(a, dtype.Floats); err != nil { err = errors.Wrapf(err, "Log2 failed") return } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be a noopError{}, no need to wrap. + } typ := a.Dtype().Type var ait, rit Iterator @@ -552,15 +587,19 @@ func (e StdEng) Log2(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } func (e StdEng) Log10(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(a, floatcmplxTypes); err != nil { + if err = unaryCheck(a, dtype.FloatComplex); err != nil { err = errors.Wrapf(err, "Log10 failed") return } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be a noopError{}, no need to wrap. + } typ := a.Dtype().Type var ait, rit Iterator @@ -620,15 +659,19 @@ func (e StdEng) Log10(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } func (e StdEng) Sqrt(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(a, floatcmplxTypes); err != nil { + if err = unaryCheck(a, dtype.FloatComplex); err != nil { err = errors.Wrapf(err, "Sqrt failed") return } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be a noopError{}, no need to wrap. + } typ := a.Dtype().Type var ait, rit Iterator @@ -688,15 +731,19 @@ func (e StdEng) Sqrt(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } func (e StdEng) Cbrt(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(a, floatTypes); err != nil { + if err = unaryCheck(a, dtype.Floats); err != nil { err = errors.Wrapf(err, "Cbrt failed") return } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be a noopError{}, no need to wrap. + } typ := a.Dtype().Type var ait, rit Iterator @@ -756,15 +803,19 @@ func (e StdEng) Cbrt(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } func (e StdEng) InvSqrt(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(a, floatTypes); err != nil { + if err = unaryCheck(a, dtype.Floats); err != nil { err = errors.Wrapf(err, "InvSqrt failed") return } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be a noopError{}, no need to wrap. + } typ := a.Dtype().Type var ait, rit Iterator @@ -824,15 +875,19 @@ func (e StdEng) InvSqrt(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } func (e StdEng) Abs(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(a, signedTypes); err != nil { + if err = unaryCheck(a, dtype.Signed); err != nil { err = errors.Wrapf(err, "Abs failed") return } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be a noopError{}, no need to wrap. + } typ := a.Dtype().Type var ait, rit Iterator @@ -892,15 +947,19 @@ func (e StdEng) Abs(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } func (e StdEng) Sign(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { - if err = unaryCheck(a, signedTypes); err != nil { + if err = unaryCheck(a, dtype.Signed); err != nil { err = errors.Wrapf(err, "Sign failed") return } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be a noopError{}, no need to wrap. + } typ := a.Dtype().Type var ait, rit Iterator diff --git a/defaultenginefloat32.go b/defaultenginefloat32.go index 45859a4..2b78aad 100644 --- a/defaultenginefloat32.go +++ b/defaultenginefloat32.go @@ -1,16 +1,20 @@ package tensor import ( + "context" + "github.com/pkg/errors" + "gorgonia.org/dtype" "gorgonia.org/tensor/internal/execution" "gorgonia.org/tensor/internal/storage" "gorgonia.org/vecf32" ) -func handleFuncOptsF32(expShape Shape, o DataOrder, opts ...FuncOpt) (reuse DenseTensor, safe, toReuse, incr bool, err error) { +func handleFuncOptsF32(expShape Shape, o DataOrder, opts ...FuncOpt) (ctx context.Context, reuse DenseTensor, safe, toReuse, incr bool, err error) { fo := ParseFuncOpts(opts...) + ctx = fo.Context() reuseT, incr := fo.IncrReuse() safe = fo.Safe() toReuse = reuseT != nil @@ -112,7 +116,7 @@ type Float32Engine struct { } // makeArray allocates a slice for the array -func (e Float32Engine) makeArray(arr *array, t Dtype, size int) { +func (e Float32Engine) makeArray(arr *array, t dtype.Dtype, size int) { if t != Float32 { panic("Float32Engine only creates float32s") } @@ -123,7 +127,11 @@ func (e Float32Engine) makeArray(arr *array, t Dtype, size int) { arr.t = t } -func (e Float32Engine) FMA(a, x, y Tensor) (retVal Tensor, err error) { +func (e Float32Engine) FMA(ctx context.Context, a, x, y Tensor) (retVal Tensor, err error) { + if err = handleCtx(ctx); err != nil { + return nil, err + } + reuse := y if err = e.checkThree(a, x, reuse); err != nil { return nil, errors.Wrap(err, "Failed checks") @@ -146,7 +154,11 @@ func (e Float32Engine) FMA(a, x, y Tensor) (retVal Tensor, err error) { return } -func (e Float32Engine) FMAScalar(a Tensor, x interface{}, y Tensor) (retVal Tensor, err error) { +func (e Float32Engine) FMAScalar(ctx context.Context, a Tensor, x interface{}, y Tensor) (retVal Tensor, err error) { + if err = handleCtx(ctx); err != nil { + return nil, err + } + reuse := y if err = e.checkTwo(a, reuse); err != nil { return nil, errors.Wrap(err, "Failed checks") @@ -178,9 +190,14 @@ func (e Float32Engine) Add(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, err = handleFuncOptsF32(a.Shape(), a.DataOrder(), opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, err = handleFuncOptsF32(a.Shape(), a.DataOrder(), opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } + if err = e.checkThree(a, b, reuse); err != nil { return nil, errors.Wrap(err, "Failed checks") } @@ -209,14 +226,21 @@ func (e Float32Engine) Add(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, vecf32.Add(dataA, dataB) retVal = a default: - ret := a.Clone().(headerer) - vecf32.Add(ret.hdr().Float32s(), dataB) + ret, ok := a.Clone().(float32ser) + if !ok { + return nil, errors.Errorf("Unable to get the Float32 data from `a`, of %T", a) + } + vecf32.Add(ret.Float32s(), dataB) retVal = ret.(Tensor) } return } -func (e Float32Engine) Inner(a, b Tensor) (retVal float32, err error) { +func (e Float32Engine) Inner(ctx context.Context, a, b Tensor) (retVal float32, err error) { + if err = handleCtx(ctx); err != nil { + return 0, err // this err will be noopError{}, no need to wrap. + } + var A, B []float32 var AD, BD *Dense var ok bool diff --git a/defaultenginefloat64.go b/defaultenginefloat64.go index 21bba43..85c59b2 100644 --- a/defaultenginefloat64.go +++ b/defaultenginefloat64.go @@ -1,16 +1,19 @@ package tensor import ( + "context" + "github.com/pkg/errors" + "gorgonia.org/dtype" "gorgonia.org/tensor/internal/execution" "gorgonia.org/tensor/internal/storage" "gorgonia.org/vecf64" ) -func handleFuncOptsF64(expShape Shape, o DataOrder, opts ...FuncOpt) (reuse DenseTensor, safe, toReuse, incr bool, err error) { +func handleFuncOptsF64(expShape Shape, o DataOrder, opts ...FuncOpt) (ctx context.Context, reuse DenseTensor, safe, toReuse, incr bool, err error) { fo := ParseFuncOpts(opts...) - + ctx = fo.Context() reuseT, incr := fo.IncrReuse() safe = fo.Safe() toReuse = reuseT != nil @@ -112,7 +115,7 @@ type Float64Engine struct { } // makeArray allocates a slice for the array -func (e Float64Engine) makeArray(arr *array, t Dtype, size int) { +func (e Float64Engine) makeArray(arr *array, t dtype.Dtype, size int) { if t != Float64 { panic("Float64Engine only creates float64s") } @@ -120,7 +123,11 @@ func (e Float64Engine) makeArray(arr *array, t Dtype, size int) { arr.t = t } -func (e Float64Engine) FMA(a, x, y Tensor) (retVal Tensor, err error) { +func (e Float64Engine) FMA(ctx context.Context, a, x, y Tensor) (retVal Tensor, err error) { + if err = handleCtx(ctx); err != nil { + return nil, err + } + reuse := y if err = e.checkThree(a, x, reuse); err != nil { return nil, errors.Wrap(err, "Failed checks") @@ -143,7 +150,10 @@ func (e Float64Engine) FMA(a, x, y Tensor) (retVal Tensor, err error) { return } -func (e Float64Engine) FMAScalar(a Tensor, x interface{}, y Tensor) (retVal Tensor, err error) { +func (e Float64Engine) FMAScalar(ctx context.Context, a Tensor, x interface{}, y Tensor) (retVal Tensor, err error) { + if err = handleCtx(ctx); err != nil { + return nil, err + } reuse := y if err = e.checkTwo(a, reuse); err != nil { return nil, errors.Wrap(err, "Failed checks") @@ -175,9 +185,14 @@ func (e Float64Engine) Add(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, err = handleFuncOptsF64(a.Shape(), a.DataOrder(), opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, err = handleFuncOptsF64(a.Shape(), a.DataOrder(), opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } + if err = e.checkThree(a, b, reuse); err != nil { return nil, errors.Wrap(err, "Failed checks") } @@ -206,14 +221,21 @@ func (e Float64Engine) Add(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, vecf64.Add(dataA, dataB) retVal = a default: - ret := a.Clone().(headerer) - vecf64.Add(ret.hdr().Float64s(), dataB) + ret, ok := a.Clone().(float64ser) + if !ok { + return nil, errors.Errorf("Unable to get the Float64 data from `a`, of %T", a) + } + vecf64.Add(ret.Float64s(), dataB) retVal = ret.(Tensor) } return } -func (e Float64Engine) Inner(a, b Tensor) (retVal float64, err error) { +func (e Float64Engine) Inner(ctx context.Context, a, b Tensor, opts ...FuncOpt) (retVal float64, err error) { + if err = handleCtx(ctx); err != nil { + return 0, err // this err will be noopError{}, no need to wrap. + } + var A, B []float64 var AD, BD *Dense var ok bool diff --git a/dense.go b/dense.go index 0e8e684..1623eee 100644 --- a/dense.go +++ b/dense.go @@ -6,6 +6,7 @@ import ( "unsafe" "github.com/pkg/errors" + "gorgonia.org/dtype" "gorgonia.org/tensor/internal/storage" ) @@ -20,7 +21,7 @@ type Dense struct { flag MemoryFlag e Engine // execution engine for the *Dense - oe standardEngine // optimized engine + oe StandardEngine // optimized engine // backup AP. When a transpose is done, the old *AP is backed up here, for easy untransposes old AP @@ -34,11 +35,11 @@ type Dense struct { } // NewDense creates a new *Dense. It tries its best to get from the tensor pool. -func NewDense(dt Dtype, shape Shape, opts ...ConsOpt) *Dense { +func NewDense(dt dtype.Dtype, shape Shape, opts ...ConsOpt) *Dense { return recycledDense(dt, shape, opts...) } -func recycledDense(dt Dtype, shape Shape, opts ...ConsOpt) (retVal *Dense) { +func recycledDense(dt dtype.Dtype, shape Shape, opts ...ConsOpt) (retVal *Dense) { retVal = recycledDenseNoFix(dt, shape, opts...) retVal.fix() if err := retVal.sanity(); err != nil { @@ -47,7 +48,7 @@ func recycledDense(dt Dtype, shape Shape, opts ...ConsOpt) (retVal *Dense) { return } -func recycledDenseNoFix(dt Dtype, shape Shape, opts ...ConsOpt) (retVal *Dense) { +func recycledDenseNoFix(dt dtype.Dtype, shape Shape, opts ...ConsOpt) (retVal *Dense) { // size := shape.TotalSize() //if shape.IsScalar() { // size = 1 @@ -83,7 +84,9 @@ func (t *Dense) makeArray(size int) { case arrayMaker: te.makeArray(&t.array, t.t, size) return + case StandardEngine2: default: + } memsize := calcMemSize(t.t, size) @@ -100,7 +103,7 @@ func (t *Dense) makeArray(size int) { func (t *Dense) Info() *AP { return &t.AP } // Dtype returns the data type of the *Dense tensor. -func (t *Dense) Dtype() Dtype { return t.t } +func (t *Dense) Dtype() dtype.Dtype { return t.t } // Data returns the underlying array. If the *Dense represents a scalar value, the scalar value is returned instead func (t *Dense) Data() interface{} { @@ -138,7 +141,7 @@ func (t *Dense) Reshape(dims ...int) error { } if t.viewOf != 0 && t.o.IsNotContiguous() { - return errors.Errorf(methodNYI, "Reshape", "non-contiguous views") + return nyierr(methodNYI, "non-contiguous views") } if !t.old.IsZero() { @@ -180,16 +183,6 @@ func (t *Dense) ScalarValue() interface{} { return t.Get(0) } -// IsView indicates if the Tensor is a view of another (typically from slicing) -func (t *Dense) IsView() bool { - return t.viewOf != 0 -} - -// IsMaterializeable indicates if the Tensor is materializable - if it has either gone through some transforms or slicing -func (t *Dense) IsMaterializable() bool { - return t.viewOf != 0 || !t.old.IsZero() -} - // IsManuallyManaged returns true if the memory associated with this *Dense is manually managed (by the user) func (t *Dense) IsManuallyManaged() bool { return t.flag.manuallyManaged() } @@ -213,7 +206,6 @@ func (t *Dense) Clone() interface{} { } copyDense(retVal, t) retVal.lock() - return retVal } panic("Unreachable: No engine") @@ -282,11 +274,32 @@ func (t *Dense) fix() { t.e = StdEng{} } - if oe, ok := t.e.(standardEngine); ok { + if oe, ok := t.e.(StandardEngine); ok { t.oe = oe } + _, isNonStdEng := t.e.(NonStdEngine) + switch { + case isNonStdEng && t.Shape() != nil: + // if there is already data in the array, we should back it up now + raw := t.array.Header.Raw + + // make the array + size := t.Shape().TotalSize() + if t.Shape().IsScalar() { + size = 1 + } + t.makeArray(size) + + if len(raw) != 0 { + // copy over if natively accessible + if t.IsNativelyAccessible() { + bs := t.byteSlice() + copy(bs, raw) + } + } + case t.IsScalar() && t.array.Header.Raw == nil: t.makeArray(1) case t.Shape() == nil && t.array.Header.Raw != nil: @@ -296,7 +309,7 @@ func (t *Dense) fix() { } else { t.SetShape(size) // vector } - case t.array.Header.Raw == nil && t.t != Dtype{}: + case t.array.Header.Raw == nil && t.t != dtype.Dtype{}: size := t.Shape().TotalSize() t.makeArray(size) @@ -573,7 +586,7 @@ func (t *Dense) Memset(x interface{}) error { if !t.IsNativelyAccessible() { return errors.Errorf(inaccessibleData, t) } - if t.IsMaterializable() { + if t.RequiresIterator() { it := newFlatIterator(&t.AP) return t.array.memsetIter(x, it) } @@ -592,11 +605,19 @@ func (t *Dense) Eq(other interface{}) bool { return t.array.Eq(&ot.array) } + if ot, ok := other.(DenseTensor); ok { + if !t.Shape().Eq(ot.Shape()) { + return false + } + + return t.array.Eq(ot.arrPtr()) + } + return false } func (t *Dense) Zero() { - if t.IsMaterializable() { + if t.RequiresIterator() { it := newFlatIterator(&t.AP) if err := t.zeroIter(it); err != nil { panic(err) @@ -635,4 +656,4 @@ func (t *Dense) RequiresIterator() bool { func (t *Dense) Iterator() Iterator { return IteratorFromDense(t) } -func (t *Dense) standardEngine() standardEngine { return t.oe } +func (t *Dense) standardEngine() StandardEngine { return t.oe } diff --git a/dense_apply_test.go b/dense_apply_test.go index 5e8c23d..793f2c5 100644 --- a/dense_apply_test.go +++ b/dense_apply_test.go @@ -1,222 +1,224 @@ -package tensor - -import ( - "math/rand" - "testing" - "testing/quick" - "time" - "unsafe" -) - -func getMutateVal(dt Dtype) interface{} { - switch dt { - case Int: - return int(1) - case Int8: - return int8(1) - case Int16: - return int16(1) - case Int32: - return int32(1) - case Int64: - return int64(1) - case Uint: - return uint(1) - case Uint8: - return uint8(1) - case Uint16: - return uint16(1) - case Uint32: - return uint32(1) - case Uint64: - return uint64(1) - case Float32: - return float32(1) - case Float64: - return float64(1) - case Complex64: - var c complex64 = 1 - return c - case Complex128: - var c complex128 = 1 - return c - case Bool: - return true - case String: - return "Hello World" - case Uintptr: - return uintptr(0xdeadbeef) - case UnsafePointer: - return unsafe.Pointer(uintptr(0xdeadbeef)) - } - return nil -} - -func getMutateFn(dt Dtype) interface{} { - switch dt { - case Int: - return mutateI - case Int8: - return mutateI8 - case Int16: - return mutateI16 - case Int32: - return mutateI32 - case Int64: - return mutateI64 - case Uint: - return mutateU - case Uint8: - return mutateU8 - case Uint16: - return mutateU16 - case Uint32: - return mutateU32 - case Uint64: - return mutateU64 - case Float32: - return mutateF32 - case Float64: - return mutateF64 - case Complex64: - return mutateC64 - case Complex128: - return mutateC128 - case Bool: - return mutateB - case String: - return mutateStr - case Uintptr: - return mutateUintptr - case UnsafePointer: - return mutateUnsafePointer - } - return nil -} - -func TestDense_Apply(t *testing.T) { - var r *rand.Rand - mut := func(q *Dense) bool { - var mutVal interface{} - if mutVal = getMutateVal(q.Dtype()); mutVal == nil { - return true // we'll temporarily skip those we cannot mutate/get a mutation value - } - var fn interface{} - if fn = getMutateFn(q.Dtype()); fn == nil { - return true // we'll skip those that we cannot mutate - } - - we, eqFail := willerr(q, nil, nil) - _, ok := q.Engine().(Mapper) - we = we || !ok - - a := q.Clone().(*Dense) - correct := q.Clone().(*Dense) - correct.Memset(mutVal) - ret, err := a.Apply(fn) - if err, retEarly := qcErrCheck(t, "Apply", a, nil, we, err); retEarly { - if err != nil { - return false - } - return true - } - if !qcEqCheck(t, a.Dtype(), eqFail, correct.Data(), ret.Data()) { - return false - } - - // wrong fn type/illogical values - if _, err = a.Apply(getMutateFn); err == nil { - t.Error("Expected an error") - return false - } - return true - } - r = rand.New(rand.NewSource(time.Now().UnixNano())) - if err := quick.Check(mut, &quick.Config{Rand: r}); err != nil { - t.Errorf("Applying mutation function failed %v", err) - } -} - -func TestDense_Apply_unsafe(t *testing.T) { - var r *rand.Rand - mut := func(q *Dense) bool { - var mutVal interface{} - if mutVal = getMutateVal(q.Dtype()); mutVal == nil { - return true // we'll temporarily skip those we cannot mutate/get a mutation value - } - var fn interface{} - if fn = getMutateFn(q.Dtype()); fn == nil { - return true // we'll skip those that we cannot mutate - } - - we, eqFail := willerr(q, nil, nil) - _, ok := q.Engine().(Mapper) - we = we || !ok - - a := q.Clone().(*Dense) - correct := q.Clone().(*Dense) - correct.Memset(mutVal) - ret, err := a.Apply(fn, UseUnsafe()) - if err, retEarly := qcErrCheck(t, "Apply", a, nil, we, err); retEarly { - if err != nil { - return false - } - return true - } - if !qcEqCheck(t, a.Dtype(), eqFail, correct.Data(), ret.Data()) { - return false - } - if ret != a { - t.Error("Expected ret == correct (Unsafe option was used)") - return false - } - return true - } - r = rand.New(rand.NewSource(time.Now().UnixNano())) - if err := quick.Check(mut, &quick.Config{Rand: r}); err != nil { - t.Errorf("Applying mutation function failed %v", err) - } -} - -func TestDense_Apply_reuse(t *testing.T) { - var r *rand.Rand - mut := func(q *Dense) bool { - var mutVal interface{} - if mutVal = getMutateVal(q.Dtype()); mutVal == nil { - return true // we'll temporarily skip those we cannot mutate/get a mutation value - } - var fn interface{} - if fn = getMutateFn(q.Dtype()); fn == nil { - return true // we'll skip those that we cannot mutate - } - - we, eqFail := willerr(q, nil, nil) - _, ok := q.Engine().(Mapper) - we = we || !ok - - a := q.Clone().(*Dense) - reuse := q.Clone().(*Dense) - reuse.Zero() - correct := q.Clone().(*Dense) - correct.Memset(mutVal) - ret, err := a.Apply(fn, WithReuse(reuse)) - if err, retEarly := qcErrCheck(t, "Apply", a, nil, we, err); retEarly { - if err != nil { - return false - } - return true - } - if !qcEqCheck(t, a.Dtype(), eqFail, correct.Data(), ret.Data()) { - return false - } - if ret != reuse { - t.Error("Expected ret == correct (Unsafe option was used)") - return false - } - return true - } - r = rand.New(rand.NewSource(time.Now().UnixNano())) - if err := quick.Check(mut, &quick.Config{Rand: r}); err != nil { - t.Errorf("Applying mutation function failed %v", err) - } -} +package tensor + +import ( + "math/rand" + "testing" + "testing/quick" + "time" + "unsafe" + + "gorgonia.org/dtype" +) + +func getMutateVal(dt dtype.Dtype) interface{} { + switch dt { + case Int: + return int(1) + case Int8: + return int8(1) + case Int16: + return int16(1) + case Int32: + return int32(1) + case Int64: + return int64(1) + case Uint: + return uint(1) + case Uint8: + return uint8(1) + case Uint16: + return uint16(1) + case Uint32: + return uint32(1) + case Uint64: + return uint64(1) + case Float32: + return float32(1) + case Float64: + return float64(1) + case Complex64: + var c complex64 = 1 + return c + case Complex128: + var c complex128 = 1 + return c + case Bool: + return true + case String: + return "Hello World" + case Uintptr: + return uintptr(0xdeadbeef) + case UnsafePointer: + return unsafe.Pointer(uintptr(0xdeadbeef)) + } + return nil +} + +func getMutateFn(dt dtype.Dtype) interface{} { + switch dt { + case Int: + return mutateI + case Int8: + return mutateI8 + case Int16: + return mutateI16 + case Int32: + return mutateI32 + case Int64: + return mutateI64 + case Uint: + return mutateU + case Uint8: + return mutateU8 + case Uint16: + return mutateU16 + case Uint32: + return mutateU32 + case Uint64: + return mutateU64 + case Float32: + return mutateF32 + case Float64: + return mutateF64 + case Complex64: + return mutateC64 + case Complex128: + return mutateC128 + case Bool: + return mutateB + case String: + return mutateStr + case Uintptr: + return mutateUintptr + case UnsafePointer: + return mutateUnsafePointer + } + return nil +} + +func TestDense_Apply(t *testing.T) { + var r *rand.Rand + mut := func(q *Dense) bool { + var mutVal interface{} + if mutVal = getMutateVal(q.Dtype()); mutVal == nil { + return true // we'll temporarily skip those we cannot mutate/get a mutation value + } + var fn interface{} + if fn = getMutateFn(q.Dtype()); fn == nil { + return true // we'll skip those that we cannot mutate + } + + we, eqFail := willerr(q, nilTC, nilTC) + _, ok := q.Engine().(Mapper) + we = we || !ok + + a := q.Clone().(*Dense) + correct := q.Clone().(*Dense) + correct.Memset(mutVal) + ret, err := a.Apply(fn) + if err, retEarly := qcErrCheck(t, "Apply", a, nil, we, err); retEarly { + if err != nil { + return false + } + return true + } + if !qcEqCheck(t, a.Dtype(), eqFail, correct.Data(), ret.Data()) { + return false + } + + // wrong fn type/illogical values + if _, err = a.Apply(getMutateFn); err == nil { + t.Error("Expected an error") + return false + } + return true + } + r = rand.New(rand.NewSource(time.Now().UnixNano())) + if err := quick.Check(mut, &quick.Config{Rand: r}); err != nil { + t.Errorf("Applying mutation function failed %v", err) + } +} + +func TestDense_Apply_unsafe(t *testing.T) { + var r *rand.Rand + mut := func(q *Dense) bool { + var mutVal interface{} + if mutVal = getMutateVal(q.Dtype()); mutVal == nil { + return true // we'll temporarily skip those we cannot mutate/get a mutation value + } + var fn interface{} + if fn = getMutateFn(q.Dtype()); fn == nil { + return true // we'll skip those that we cannot mutate + } + + we, eqFail := willerr(q, nilTC, nilTC) + _, ok := q.Engine().(Mapper) + we = we || !ok + + a := q.Clone().(*Dense) + correct := q.Clone().(*Dense) + correct.Memset(mutVal) + ret, err := a.Apply(fn, UseUnsafe()) + if err, retEarly := qcErrCheck(t, "Apply", a, nil, we, err); retEarly { + if err != nil { + return false + } + return true + } + if !qcEqCheck(t, a.Dtype(), eqFail, correct.Data(), ret.Data()) { + return false + } + if ret != a { + t.Error("Expected ret == correct (Unsafe option was used)") + return false + } + return true + } + r = rand.New(rand.NewSource(time.Now().UnixNano())) + if err := quick.Check(mut, &quick.Config{Rand: r}); err != nil { + t.Errorf("Applying mutation function failed %v", err) + } +} + +func TestDense_Apply_reuse(t *testing.T) { + var r *rand.Rand + mut := func(q *Dense) bool { + var mutVal interface{} + if mutVal = getMutateVal(q.Dtype()); mutVal == nil { + return true // we'll temporarily skip those we cannot mutate/get a mutation value + } + var fn interface{} + if fn = getMutateFn(q.Dtype()); fn == nil { + return true // we'll skip those that we cannot mutate + } + + we, eqFail := willerr(q, nilTC, nilTC) + _, ok := q.Engine().(Mapper) + we = we || !ok + + a := q.Clone().(*Dense) + reuse := q.Clone().(*Dense) + reuse.Zero() + correct := q.Clone().(*Dense) + correct.Memset(mutVal) + ret, err := a.Apply(fn, WithReuse(reuse)) + if err, retEarly := qcErrCheck(t, "Apply", a, nil, we, err); retEarly { + if err != nil { + return false + } + return true + } + if !qcEqCheck(t, a.Dtype(), eqFail, correct.Data(), ret.Data()) { + return false + } + if ret != reuse { + t.Error("Expected ret == correct (Unsafe option was used)") + return false + } + return true + } + r = rand.New(rand.NewSource(time.Now().UnixNano())) + if err := quick.Check(mut, &quick.Config{Rand: r}); err != nil { + t.Errorf("Applying mutation function failed %v", err) + } +} diff --git a/dense_argmethods.go b/dense_argmethods.go index bfdc0d7..fdace5f 100644 --- a/dense_argmethods.go +++ b/dense_argmethods.go @@ -7,13 +7,14 @@ import "github.com/pkg/errors" // Argmax finds the index of the max value along the axis provided func (t *Dense) Argmax(axis int) (retVal *Dense, err error) { e := t.e + ctx := ctxFromEngine(e) switch am := e.(type) { case denseArgmaxer: - return am.argmaxDenseTensor(t, axis) + return am.argmaxDenseTensor(ctx, t, axis) case Argmaxer: var ret Tensor var ok bool - if ret, err = am.Argmax(t, axis); err != nil { + if ret, err = am.Argmax(ctx, t, axis); err != nil { return nil, errors.Wrapf(err, opFail, "Argmax") } if retVal, ok = ret.(*Dense); !ok { @@ -29,13 +30,14 @@ func (t *Dense) Argmax(axis int) (retVal *Dense, err error) { // Argmin finds the index of the min value along the axis provided func (t *Dense) Argmin(axis int) (retVal *Dense, err error) { e := t.e + ctx := ctxFromEngine(e) switch am := e.(type) { case denseArgminer: - return am.argminDenseTensor(t, axis) + return am.argminDenseTensor(ctx, t, axis) case Argminer: var ret Tensor var ok bool - if ret, err = am.Argmin(t, axis); err != nil { + if ret, err = am.Argmin(ctx, t, axis); err != nil { return nil, errors.Wrapf(err, opFail, "Argmax") } if retVal, ok = ret.(*Dense); !ok { diff --git a/dense_argmethods_test.go b/dense_argmethods_test.go index a4b03bd..a90b957 100644 --- a/dense_argmethods_test.go +++ b/dense_argmethods_test.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor import ( @@ -10,6 +8,8 @@ import ( "github.com/stretchr/testify/assert" ) +// Code generated by genlib2. DO NOT EDIT. + /* Test data */ var basicDenseI = New(WithShape(2, 3, 4, 5, 2), WithBacking([]int{3, 4, 2, 4, 3, 8, 3, 9, 7, 4, 3, 0, 3, 9, 9, 0, 6, 7, 3, 9, 4, 8, 5, 1, 1, 9, 4, 0, 4, 1, 6, 6, 4, 9, 3, 8, 1, 7, 0, 7, 4, 0, 6, 8, 2, 8, 0, 6, 1, 6, 2, 3, 7, 5, 7, 3, 0, 8, 6, 5, 6, 9, 7, 5, 6, 8, 7, 9, 5, 0, 8, 1, 4, 0, 6, 6, 3, 3, 8, 1, 1, 3, 2, 5, 9, 0, 4, 5, 3, 1, 9, 1, 9, 3, 9, 3, 3, 4, 5, 9, 4, 2, 2, 7, 9, 8, 1, 6, 9, 4, 4, 1, 8, 9, 8, 0, 9, 9, 4, 6, 7, 5, 9, 9, 4, 8, 5, 8, 2, 4, 8, 2, 7, 2, 8, 7, 2, 3, 7, 0, 9, 9, 8, 9, 2, 1, 7, 0, 7, 9, 0, 2, 4, 8, 7, 9, 6, 8, 3, 3, 7, 2, 9, 2, 8, 2, 3, 6, 0, 8, 7, 7, 0, 9, 0, 9, 3, 2, 6, 9, 5, 8, 6, 9, 5, 6, 1, 8, 7, 8, 1, 9, 9, 3, 7, 7, 6, 8, 2, 1, 1, 5, 1, 4, 0, 5, 1, 7, 9, 5, 6, 6, 8, 7, 5, 1, 3, 4, 0, 1, 8, 0, 2, 6, 9, 1, 4, 8, 0, 5, 6, 2, 9, 4, 4, 2, 4, 4, 4, 3})) diff --git a/dense_arith.go b/dense_arith.go index 7218d37..5c4eba9 100644 --- a/dense_arith.go +++ b/dense_arith.go @@ -1,9 +1,9 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor import "github.com/pkg/errors" +// Code generated by genlib2. DO NOT EDIT. + // Add performs t + other elementwise. Both t and other must have the same shape. // Acceptable FuncOpts are: UseUnsafe(), WithReuse(T), WithIncr(T) func (t *Dense) Add(other *Dense, opts ...FuncOpt) (retVal *Dense, err error) { diff --git a/dense_arith_test.go b/dense_arith_test.go index 8d791fd..423fc85 100644 --- a/dense_arith_test.go +++ b/dense_arith_test.go @@ -1,17 +1,19 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor import ( "testing" "testing/quick" + + "gorgonia.org/dtype" ) +// Code generated by genlib2. DO NOT EDIT. + func TestDense_Add(t *testing.T) { iden := func(a *Dense) bool { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Adder) we = we || !ok @@ -37,7 +39,7 @@ func TestDense_Sub(t *testing.T) { inv := func(a *Dense) bool { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Suber) we = we || !ok @@ -64,7 +66,7 @@ func TestDense_Mul(t *testing.T) { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) b.Memset(identityVal(1, a.t)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Muler) we = we || !ok @@ -91,7 +93,7 @@ func TestDense_Div(t *testing.T) { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) b.Memset(identityVal(1, a.t)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Diver) we = we || !ok @@ -118,7 +120,7 @@ func TestDense_Pow(t *testing.T) { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) b.Memset(identityVal(1, a.t)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatcmplxTypes, complexTypes) + we, willFailEq := willerr(a, dtype.FloatComplex, dtype.Complexes) _, ok := a.Engine().(Power) we = we || !ok @@ -144,7 +146,7 @@ func TestDense_Add_unsafe(t *testing.T) { iden := func(a *Dense) bool { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Adder) we = we || !ok @@ -163,7 +165,6 @@ func TestDense_Add_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } if err := quick.Check(iden, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { @@ -175,7 +176,7 @@ func TestDense_Sub_unsafe(t *testing.T) { inv := func(a *Dense) bool { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Suber) we = we || !ok @@ -195,7 +196,6 @@ func TestDense_Sub_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } if err := quick.Check(inv, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { @@ -207,7 +207,7 @@ func TestDense_Mul_unsafe(t *testing.T) { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) b.Memset(identityVal(1, a.t)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Muler) we = we || !ok @@ -226,7 +226,6 @@ func TestDense_Mul_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } if err := quick.Check(iden, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { @@ -239,7 +238,7 @@ func TestDense_Div_unsafe(t *testing.T) { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) b.Memset(identityVal(1, a.t)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Diver) we = we || !ok @@ -259,7 +258,6 @@ func TestDense_Div_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } if err := quick.Check(inv, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { @@ -271,7 +269,7 @@ func TestDense_Pow_unsafe(t *testing.T) { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) b.Memset(identityVal(1, a.t)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatcmplxTypes, complexTypes) + we, willFailEq := willerr(a, dtype.FloatComplex, dtype.Complexes) _, ok := a.Engine().(Power) we = we || !ok @@ -290,7 +288,6 @@ func TestDense_Pow_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } if err := quick.Check(iden, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { @@ -303,7 +300,7 @@ func TestDense_Add_reuse(t *testing.T) { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Adder) we = we || !ok @@ -346,7 +343,7 @@ func TestDense_Add_reuse(t *testing.T) { } correct, err := a.Add(b) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Adder) we = we || !ok @@ -387,7 +384,7 @@ func TestDense_Sub_reuse(t *testing.T) { b := New(Of(a.t), WithShape(a.Shape().Clone()...), WithEngine(a.Engine())) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Suber) we = we || !ok @@ -430,7 +427,7 @@ func TestDense_Sub_reuse(t *testing.T) { } correct, err := a.Sub(b) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Suber) we = we || !ok @@ -472,7 +469,7 @@ func TestDense_Mul_reuse(t *testing.T) { b.Memset(identityVal(1, a.t)) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Muler) we = we || !ok @@ -515,7 +512,7 @@ func TestDense_Mul_reuse(t *testing.T) { } correct, err := a.Mul(b) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Muler) we = we || !ok @@ -557,7 +554,7 @@ func TestDense_Div_reuse(t *testing.T) { b.Memset(identityVal(1, a.t)) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Diver) we = we || !ok @@ -600,7 +597,7 @@ func TestDense_Div_reuse(t *testing.T) { } correct, err := a.Div(b) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Diver) we = we || !ok @@ -642,7 +639,7 @@ func TestDense_Pow_reuse(t *testing.T) { b.Memset(identityVal(1, a.t)) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatcmplxTypes, complexTypes) + we, willFailEq := willerr(a, dtype.FloatComplex, dtype.Complexes) _, ok := a.Engine().(Power) we = we || !ok @@ -676,7 +673,7 @@ func TestDense_Add_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Adder) we = we || !ok @@ -705,7 +702,7 @@ func TestDense_Sub_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Suber) we = we || !ok @@ -735,7 +732,7 @@ func TestDense_Mul_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Muler) we = we || !ok @@ -765,7 +762,7 @@ func TestDense_Div_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Diver) we = we || !ok @@ -795,7 +792,7 @@ func TestDense_Pow_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, floatcmplxTypes, complexTypes) + we, willFailEq := willerr(a, dtype.FloatComplex, dtype.Complexes) _, ok := a.Engine().(Power) we = we || !ok @@ -823,7 +820,7 @@ func TestDense_AddScalar(t *testing.T) { b := identityVal(0, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Adder) we = we || !ok @@ -849,7 +846,7 @@ func TestDense_AddScalar(t *testing.T) { a := q.Clone().(*Dense) b := identityVal(0, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Adder) we = we || !ok @@ -903,7 +900,7 @@ func TestDense_SubScalar(t *testing.T) { b := identityVal(0, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, unsignedTypes) + we, willFailEq := willerr(a, dtype.Number, dtype.Unsigned) _, ok := q.Engine().(Suber) we = we || !ok @@ -929,7 +926,7 @@ func TestDense_SubScalar(t *testing.T) { a := q.Clone().(*Dense) b := identityVal(0, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, unsignedTypes) + we, willFailEq := willerr(a, dtype.Number, dtype.Unsigned) _, ok := q.Engine().(Suber) we = we || !ok @@ -983,7 +980,7 @@ func TestDense_MulScalar(t *testing.T) { b := identityVal(1, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Muler) we = we || !ok @@ -1009,7 +1006,7 @@ func TestDense_MulScalar(t *testing.T) { a := q.Clone().(*Dense) b := identityVal(1, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Muler) we = we || !ok @@ -1063,7 +1060,7 @@ func TestDense_DivScalar(t *testing.T) { b := identityVal(1, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Diver) we = we || !ok @@ -1118,7 +1115,7 @@ func TestDense_PowScalar(t *testing.T) { b := identityVal(1, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatcmplxTypes, complexTypes) + we, willFailEq := willerr(a, dtype.FloatComplex, dtype.Complexes) _, ok := q.Engine().(Power) we = we || !ok @@ -1173,7 +1170,7 @@ func TestDense_AddScalar_unsafe(t *testing.T) { b := identityVal(0, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Adder) we = we || !ok @@ -1192,7 +1189,6 @@ func TestDense_AddScalar_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } @@ -1204,7 +1200,7 @@ func TestDense_AddScalar_unsafe(t *testing.T) { a := q.Clone().(*Dense) b := identityVal(0, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Adder) we = we || !ok @@ -1223,7 +1219,6 @@ func TestDense_AddScalar_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } if err := quick.Check(iden2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { @@ -1237,7 +1232,7 @@ func TestDense_SubScalar_unsafe(t *testing.T) { b := identityVal(0, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, unsignedTypes) + we, willFailEq := willerr(a, dtype.Number, dtype.Unsigned) _, ok := q.Engine().(Suber) we = we || !ok @@ -1257,7 +1252,6 @@ func TestDense_SubScalar_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } if err := quick.Check(inv1, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { @@ -1268,7 +1262,7 @@ func TestDense_SubScalar_unsafe(t *testing.T) { a := q.Clone().(*Dense) b := identityVal(0, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, unsignedTypes) + we, willFailEq := willerr(a, dtype.Number, dtype.Unsigned) _, ok := q.Engine().(Suber) we = we || !ok @@ -1288,7 +1282,6 @@ func TestDense_SubScalar_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } if err := quick.Check(inv2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { @@ -1301,7 +1294,7 @@ func TestDense_MulScalar_unsafe(t *testing.T) { b := identityVal(1, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Muler) we = we || !ok @@ -1320,7 +1313,6 @@ func TestDense_MulScalar_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } @@ -1332,7 +1324,7 @@ func TestDense_MulScalar_unsafe(t *testing.T) { a := q.Clone().(*Dense) b := identityVal(1, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Muler) we = we || !ok @@ -1351,7 +1343,6 @@ func TestDense_MulScalar_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } if err := quick.Check(iden2, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { @@ -1365,7 +1356,7 @@ func TestDense_DivScalar_unsafe(t *testing.T) { b := identityVal(1, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Diver) we = we || !ok @@ -1385,7 +1376,6 @@ func TestDense_DivScalar_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } if err := quick.Check(inv1, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { @@ -1399,7 +1389,7 @@ func TestDense_PowScalar_unsafe(t *testing.T) { b := identityVal(1, q.t) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatcmplxTypes, complexTypes) + we, willFailEq := willerr(a, dtype.FloatComplex, dtype.Complexes) _, ok := q.Engine().(Power) we = we || !ok @@ -1418,7 +1408,6 @@ func TestDense_PowScalar_unsafe(t *testing.T) { t.Errorf("Expected ret to be the same as a") return false } - return true } @@ -1434,7 +1423,7 @@ func TestDense_AddScalar_reuse(t *testing.T) { reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Adder) we = we || !ok @@ -1466,7 +1455,7 @@ func TestDense_AddScalar_reuse(t *testing.T) { b := identityVal(0, q.t) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Adder) we = we || !ok @@ -1509,7 +1498,7 @@ func TestDense_AddScalar_reuse(t *testing.T) { } correct, err := a.Add(b) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Adder) we = we || !ok @@ -1552,7 +1541,7 @@ func TestDense_SubScalar_reuse(t *testing.T) { reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, unsignedTypes) + we, willFailEq := willerr(a, dtype.Number, dtype.Unsigned) _, ok := q.Engine().(Suber) we = we || !ok @@ -1584,7 +1573,7 @@ func TestDense_SubScalar_reuse(t *testing.T) { b := identityVal(0, q.t) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, unsignedTypes) + we, willFailEq := willerr(a, dtype.Number, dtype.Unsigned) _, ok := q.Engine().(Suber) we = we || !ok @@ -1627,7 +1616,7 @@ func TestDense_SubScalar_reuse(t *testing.T) { } correct, err := a.Sub(b) - we, willFailEq := willerr(a, numberTypes, unsignedTypes) + we, willFailEq := willerr(a, dtype.Number, dtype.Unsigned) _, ok := a.Engine().(Suber) we = we || !ok @@ -1670,7 +1659,7 @@ func TestDense_MulScalar_reuse(t *testing.T) { reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Muler) we = we || !ok @@ -1702,7 +1691,7 @@ func TestDense_MulScalar_reuse(t *testing.T) { b := identityVal(1, q.t) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Muler) we = we || !ok @@ -1745,7 +1734,7 @@ func TestDense_MulScalar_reuse(t *testing.T) { } correct, err := a.Mul(b) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Muler) we = we || !ok @@ -1788,7 +1777,7 @@ func TestDense_DivScalar_reuse(t *testing.T) { reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Diver) we = we || !ok @@ -1832,7 +1821,7 @@ func TestDense_DivScalar_reuse(t *testing.T) { } correct, err := a.Div(b) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := a.Engine().(Diver) we = we || !ok @@ -1875,7 +1864,7 @@ func TestDense_PowScalar_reuse(t *testing.T) { reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, floatcmplxTypes, complexTypes) + we, willFailEq := willerr(a, dtype.FloatComplex, dtype.Complexes) _, ok := q.Engine().(Power) we = we || !ok @@ -1912,7 +1901,7 @@ func TestDense_AddScalar_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Adder) we = we || !ok @@ -1941,7 +1930,7 @@ func TestDense_AddScalar_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Adder) we = we || !ok @@ -1972,7 +1961,7 @@ func TestDense_SubScalar_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, unsignedTypes) + we, willFailEq := willerr(a, dtype.Number, dtype.Unsigned) _, ok := q.Engine().(Suber) we = we || !ok @@ -2004,7 +1993,7 @@ func TestDense_MulScalar_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Muler) we = we || !ok @@ -2033,7 +2022,7 @@ func TestDense_MulScalar_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Muler) we = we || !ok @@ -2064,7 +2053,7 @@ func TestDense_DivScalar_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, numberTypes, nil) + we, willFailEq := willerr(a, dtype.Number, nilTC) _, ok := q.Engine().(Diver) we = we || !ok @@ -2096,7 +2085,7 @@ func TestDense_PowScalar_incr(t *testing.T) { correct := a.Clone().(*Dense) incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) - we, willFailEq := willerr(a, floatcmplxTypes, complexTypes) + we, willFailEq := willerr(a, dtype.FloatComplex, dtype.Complexes) _, ok := q.Engine().(Power) we = we || !ok diff --git a/dense_cmp.go b/dense_cmp.go index 4ffaadf..d7770ac 100644 --- a/dense_cmp.go +++ b/dense_cmp.go @@ -1,9 +1,9 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor import "github.com/pkg/errors" +// Code generated by genlib2. DO NOT EDIT. + // Gt performs t > other elementwise. Both t and other must have the same shape. // Acceptable FuncOpts are: UseUnsafe(), AsSameType(), WithReuse(). //UseUnsafe() will ensure that the same type is returned. diff --git a/dense_cmp_test.go b/dense_cmp_test.go index a0bc5b6..82e8518 100644 --- a/dense_cmp_test.go +++ b/dense_cmp_test.go @@ -1,16 +1,18 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor import ( "reflect" "testing" "testing/quick" + + "gorgonia.org/dtype" ) +// Code generated by genlib2. DO NOT EDIT. + func TestDense_Gt(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, ordTypes, nil) + we, _ := willerr(q, dtype.Ord, nilTC) _, ok := q.Engine().(Gter) we = we || !ok @@ -68,7 +70,7 @@ func TestDense_Gt(t *testing.T) { } func TestDense_Gte(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, ordTypes, nil) + we, _ := willerr(q, dtype.Ord, nilTC) _, ok := q.Engine().(Gteer) we = we || !ok @@ -126,7 +128,7 @@ func TestDense_Gte(t *testing.T) { } func TestDense_Lt(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, ordTypes, nil) + we, _ := willerr(q, dtype.Ord, nilTC) _, ok := q.Engine().(Lter) we = we || !ok @@ -184,7 +186,7 @@ func TestDense_Lt(t *testing.T) { } func TestDense_Lte(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, ordTypes, nil) + we, _ := willerr(q, dtype.Ord, nilTC) _, ok := q.Engine().(Lteer) we = we || !ok @@ -242,7 +244,7 @@ func TestDense_Lte(t *testing.T) { } func TestDense_ElEq(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, eqTypes, nil) + we, _ := willerr(q, dtype.Eq, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok @@ -298,7 +300,7 @@ func TestDense_ElEq(t *testing.T) { } symFn := func(q *Dense) bool { - we, _ := willerr(q, eqTypes, nil) + we, _ := willerr(q, dtype.Eq, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok @@ -333,7 +335,7 @@ func TestDense_ElEq(t *testing.T) { } func TestDense_ElNe(t *testing.T) { symFn := func(q *Dense) bool { - we, _ := willerr(q, eqTypes, nil) + we, _ := willerr(q, dtype.Eq, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok @@ -368,11 +370,11 @@ func TestDense_ElNe(t *testing.T) { } func TestDense_Gt_assame(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(Gter) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -428,11 +430,11 @@ func TestDense_Gt_assame(t *testing.T) { } func TestDense_Gte_assame(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(Gteer) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -488,11 +490,11 @@ func TestDense_Gte_assame(t *testing.T) { } func TestDense_Lt_assame(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(Lter) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -548,11 +550,11 @@ func TestDense_Lt_assame(t *testing.T) { } func TestDense_Lte_assame(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(Lteer) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -608,11 +610,11 @@ func TestDense_Lte_assame(t *testing.T) { } func TestDense_ElEq_assame(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -666,11 +668,11 @@ func TestDense_ElEq_assame(t *testing.T) { } symFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -704,11 +706,11 @@ func TestDense_ElEq_assame(t *testing.T) { } func TestDense_ElNe_assame(t *testing.T) { symFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -742,7 +744,7 @@ func TestDense_ElNe_assame(t *testing.T) { } func TestDense_GtScalar(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, ordTypes, nil) + we, _ := willerr(q, dtype.Ord, nilTC) _, ok := q.Engine().(Gter) we = we || !ok @@ -798,7 +800,7 @@ func TestDense_GtScalar(t *testing.T) { } func TestDense_GteScalar(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, ordTypes, nil) + we, _ := willerr(q, dtype.Ord, nilTC) _, ok := q.Engine().(Gteer) we = we || !ok @@ -854,7 +856,7 @@ func TestDense_GteScalar(t *testing.T) { } func TestDense_LtScalar(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, ordTypes, nil) + we, _ := willerr(q, dtype.Ord, nilTC) _, ok := q.Engine().(Lter) we = we || !ok @@ -910,7 +912,7 @@ func TestDense_LtScalar(t *testing.T) { } func TestDense_LteScalar(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, ordTypes, nil) + we, _ := willerr(q, dtype.Ord, nilTC) _, ok := q.Engine().(Lteer) we = we || !ok @@ -966,7 +968,7 @@ func TestDense_LteScalar(t *testing.T) { } func TestDense_ElEqScalar(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, eqTypes, nil) + we, _ := willerr(q, dtype.Eq, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok @@ -1020,7 +1022,7 @@ func TestDense_ElEqScalar(t *testing.T) { } symFn := func(q *Dense) bool { - we, _ := willerr(q, eqTypes, nil) + we, _ := willerr(q, dtype.Eq, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok @@ -1053,7 +1055,7 @@ func TestDense_ElEqScalar(t *testing.T) { } func TestDense_ElNeScalar(t *testing.T) { symFn := func(q *Dense) bool { - we, _ := willerr(q, eqTypes, nil) + we, _ := willerr(q, dtype.Eq, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok @@ -1086,11 +1088,11 @@ func TestDense_ElNeScalar(t *testing.T) { } func TestDense_GtScalar_assame(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(Gter) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -1144,11 +1146,11 @@ func TestDense_GtScalar_assame(t *testing.T) { } func TestDense_GteScalar_assame(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(Gteer) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -1202,11 +1204,11 @@ func TestDense_GteScalar_assame(t *testing.T) { } func TestDense_LtScalar_assame(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(Lter) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -1260,11 +1262,11 @@ func TestDense_LtScalar_assame(t *testing.T) { } func TestDense_LteScalar_assame(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(Lteer) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -1318,11 +1320,11 @@ func TestDense_LteScalar_assame(t *testing.T) { } func TestDense_ElEqScalar_assame(t *testing.T) { transFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -1374,11 +1376,11 @@ func TestDense_ElEqScalar_assame(t *testing.T) { } symFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() @@ -1410,11 +1412,11 @@ func TestDense_ElEqScalar_assame(t *testing.T) { } func TestDense_ElNeScalar_assame(t *testing.T) { symFn := func(q *Dense) bool { - we, _ := willerr(q, nonComplexNumberTypes, nil) + we, _ := willerr(q, dtype.NonComplexNumber, nilTC) _, ok := q.Engine().(ElEqer) we = we || !ok - if err := typeclassCheck(q.Dtype(), nonComplexNumberTypes); err != nil { + if err := dtype.TypeClassCheck(q.Dtype(), dtype.NonComplexNumber); err != nil { return true // we exit early if the generated type is not something we can handle } r := newRand() diff --git a/dense_compat.go b/dense_compat.go index dcbefa2..cf6764d 100644 --- a/dense_compat.go +++ b/dense_compat.go @@ -1,7 +1,7 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor +// Code generated by genlib2. DO NOT EDIT. + import ( "fmt" "math" @@ -15,9 +15,10 @@ import ( "github.com/chewxy/math32" "github.com/pkg/errors" "gonum.org/v1/gonum/mat" + "gorgonia.org/dtype" ) -func convFromFloat64s(to Dtype, data []float64) interface{} { +func convFromFloat64s(to dtype.Dtype, data []float64) interface{} { switch to { case Int: retVal := make([]int, len(data)) @@ -416,10 +417,10 @@ func ToMat64(t *Dense, opts ...FuncOpt) (retVal *mat.Dense, err error) { var data []float64 switch { - case t.t == Float64 && toCopy && !t.IsMaterializable(): + case t.t == Float64 && toCopy && !t.RequiresIterator() && t.viewOf == 0: data = make([]float64, t.len()) copy(data, t.Float64s()) - case !t.IsMaterializable(): + case !t.RequiresIterator() && t.viewOf == 0: data = convToFloat64s(t) default: it := newFlatIterator(&t.AP) diff --git a/dense_compat_test.go b/dense_compat_test.go index c641203..442b7e7 100644 --- a/dense_compat_test.go +++ b/dense_compat_test.go @@ -1,7 +1,7 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor +// Code generated by genlib2. DO NOT EDIT. + import ( "testing" @@ -11,13 +11,14 @@ import ( arrowTensor "github.com/apache/arrow/go/arrow/tensor" "github.com/stretchr/testify/assert" "gonum.org/v1/gonum/mat" + "gorgonia.org/dtype" ) var toMat64Tests = []struct { data interface{} sliced interface{} shape Shape - dt Dtype + dt dtype.Dtype }{ {Range(Int, 0, 6), []int{0, 1, 3, 4}, Shape{2, 3}, Int}, {Range(Int8, 0, 6), []int8{0, 1, 3, 4}, Shape{2, 3}, Int8}, diff --git a/dense_generated.go b/dense_generated.go index 6349bfb..c9158fa 100644 --- a/dense_generated.go +++ b/dense_generated.go @@ -1,11 +1,15 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor -import "reflect" +import ( + "reflect" + + "gorgonia.org/dtype" +) + +// Code generated by genlib2. DO NOT EDIT. // Ones creates a *Dense with the provided shape and type -func Ones(dt Dtype, shape ...int) *Dense { +func Ones(dt dtype.Dtype, shape ...int) *Dense { d := recycledDense(dt, shape) switch d.t.Kind() { case reflect.Int: @@ -68,7 +72,7 @@ func Ones(dt Dtype, shape ...int) *Dense { // ⎢1 0 0 0⎥ // ⎢0 1 0 0⎥ // ⎣0 0 1 0⎦ -func I(dt Dtype, r, c, k int) *Dense { +func I(dt dtype.Dtype, r, c, k int) *Dense { ret := New(Of(dt), WithShape(r, c)) i := k if k < 0 { diff --git a/dense_generated_test.go b/dense_generated_test.go index e87baa0..edd0850 100644 --- a/dense_generated_test.go +++ b/dense_generated_test.go @@ -1,15 +1,16 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor import ( "testing" "github.com/stretchr/testify/assert" + "gorgonia.org/dtype" ) +// Code generated by genlib2. DO NOT EDIT. + var onesTests = []struct { - of Dtype + of dtype.Dtype shape Shape correct interface{} }{ @@ -56,7 +57,7 @@ func TestOnes(t *testing.T) { // yes, it's a pun on eye tests, stop asking and go see your optometrist var eyeTests = []struct { - E Dtype + E dtype.Dtype R, C, K int correct interface{} diff --git a/dense_getset_test.go b/dense_getset_test.go index 8ab8e44..899e855 100644 --- a/dense_getset_test.go +++ b/dense_getset_test.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor import ( @@ -8,10 +6,13 @@ import ( "testing/quick" "github.com/stretchr/testify/assert" + "gorgonia.org/dtype" ) +// Code generated by genlib2. DO NOT EDIT. + var denseSetGetTests = []struct { - of Dtype + of dtype.Dtype data interface{} set interface{} @@ -48,7 +49,7 @@ func TestDense_setget(t *testing.T) { } var denseMemsetTests = []struct { - of Dtype + of dtype.Dtype data interface{} val interface{} shape Shape @@ -88,7 +89,7 @@ func TestDense_memset(t *testing.T) { } var denseZeroTests = []struct { - of Dtype + of dtype.Dtype data interface{} correct interface{} diff --git a/dense_io.go b/dense_io.go index 7bb9608..374daf0 100644 --- a/dense_io.go +++ b/dense_io.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor import ( @@ -16,10 +14,13 @@ import ( flatbuffers "github.com/google/flatbuffers/go" "github.com/pkg/errors" + "gorgonia.org/dtype" "gorgonia.org/tensor/internal/serialization/fb" "gorgonia.org/tensor/internal/serialization/pb" ) +// Code generated by genlib2. DO NOT EDIT. + /* GOB SERIALIZATION */ // GobEncode implements gob.GobEncoder @@ -163,7 +164,7 @@ func (r *binaryReader) Err() error { // If tensor is masked, invalid values are replaced by the default fill value. func (t *Dense) WriteNpy(w io.Writer) (err error) { var npdt string - if npdt, err = t.t.numpyDtype(); err != nil { + if npdt, err = t.t.NumpyDtype(); err != nil { return } @@ -242,7 +243,7 @@ func (t *Dense) ReadNpy(r io.Reader) (err error) { } // TODO: check for endianness. For now we assume everything is little endian - if t.t, err = fromNumpyDtype(string(match[1][1:])); err != nil { + if t.t, err = dtype.FromNumpyDtype(string(match[1][1:])); err != nil { return } @@ -423,7 +424,7 @@ func (t *Dense) WriteCSV(w io.Writer, formats ...string) (err error) { // convFromStrs converts a []string to a slice of the Dtype provided. It takes a provided backing slice. // If into is nil, then a backing slice will be created. -func convFromStrs(to Dtype, record []string, into interface{}) (interface{}, error) { +func convFromStrs(to dtype.Dtype, record []string, into interface{}) (interface{}, error) { var err error switch to.Kind() { case reflect.Int: @@ -793,12 +794,11 @@ func (t *Dense) FBDecode(buf []byte) error { t.strides[i] = int(serialized.Strides(i)) } typ := string(serialized.Type()) - for _, dt := range allTypes.set { - if dt.String() == typ { - t.t = dt - break - } + dt, err := dtype.FindByName(typ) + if err != nil { + return errors.Wrap(err, "Failed to decode FlatBuffers") } + t.t = dt if t.e == nil { t.e = StdEng{} @@ -870,12 +870,11 @@ func (t *Dense) PBDecode(buf []byte) error { } t.Δ = Triangle(toSerialize.T) typ := string(toSerialize.Type) - for _, dt := range allTypes.set { - if dt.String() == typ { - t.t = dt - break - } + dt, err := dtype.FindByName(typ) + if err != nil { + return errors.Wrap(err, "Failed to decode ProtoBuf") } + t.t = dt if t.e == nil { t.e = StdEng{} diff --git a/dense_linalg.go b/dense_linalg.go index 3caa8e7..10eb936 100644 --- a/dense_linalg.go +++ b/dense_linalg.go @@ -2,14 +2,15 @@ package tensor import ( "github.com/pkg/errors" + "gorgonia.org/dtype" ) // Trace returns the trace of the matrix (i.e. the sum of the diagonal elements). It only works for matrices func (t *Dense) Trace() (retVal interface{}, err error) { e := t.e - + ctx := ctxFromEngine(e) if tracer, ok := e.(Tracer); ok { - return tracer.Trace(t) + return tracer.Trace(ctx, t) } return nil, errors.Errorf("Engine %T does not support Trace", e) } @@ -17,7 +18,7 @@ func (t *Dense) Trace() (retVal interface{}, err error) { // Inner performs a dot product on two vectors. If t or other are not vectors, it will return an error. func (t *Dense) Inner(other Tensor) (retVal interface{}, err error) { // check that the data is a float - if err = typeclassCheck(t.t, floatcmplxTypes); err != nil { + if err = dtype.TypeClassCheck(t.t, dtype.FloatComplex); err != nil { return nil, errors.Wrapf(err, unsupportedDtype, t.t, "Inner") } @@ -33,13 +34,14 @@ func (t *Dense) Inner(other Tensor) (retVal interface{}, err error) { } e := t.e + ctx := ctxFromEngine(e) switch ip := e.(type) { case InnerProderF32: - return ip.Inner(t, other) + return ip.Inner(ctx, t, other) case InnerProderF64: - return ip.Inner(t, other) + return ip.Inner(ctx, t, other) case InnerProder: - return ip.Inner(t, other) + return ip.Inner(ctx, t, other) } return nil, errors.Errorf("Engine does not support Inner()") @@ -93,11 +95,11 @@ func (t *Dense) MatVecMul(other Tensor, opts ...FuncOpt) (retVal *Dense, err err AsFortran(nil)(retVal) } } + ctx := fo.Context() e := t.e - if mvm, ok := e.(MatVecMuler); ok { - if err = mvm.MatVecMul(t, other, retVal); err != nil { + if err = mvm.MatVecMul(ctx, t, other, retVal); err != nil { return nil, errors.Wrapf(err, opFail, "MatVecMul") } return handleIncr(retVal, fo.Reuse(), fo.Incr(), expectedShape) @@ -142,10 +144,11 @@ func (t *Dense) MatMul(other Tensor, opts ...FuncOpt) (retVal *Dense, err error) AsFortran(nil)(retVal) } } + ctx := fo.Context() e := t.e if mm, ok := e.(MatMuler); ok { - if err = mm.MatMul(t, other, retVal); err != nil { + if err = mm.MatMul(ctx, t, other, retVal); err != nil { return } return handleIncr(retVal, fo.Reuse(), fo.Incr(), expectedShape) @@ -156,12 +159,6 @@ func (t *Dense) MatMul(other Tensor, opts ...FuncOpt) (retVal *Dense, err error) // Outer finds the outer product of two vectors func (t *Dense) Outer(other Tensor, opts ...FuncOpt) (retVal *Dense, err error) { - // check both are vectors - if !t.Shape().IsVector() || !other.Shape().IsVector() { - err = errors.Errorf("Outer only works when there are two vectors. t's shape: %v. other's shape: %v", t.Shape(), other.Shape()) - return - } - m := t.Size() n := other.Size() @@ -181,13 +178,14 @@ func (t *Dense) Outer(other Tensor, opts ...FuncOpt) (retVal *Dense, err error) AsFortran(nil)(retVal) } } + ctx := fo.Context() e := t.e // DGER does not have any beta. So the values have to be zeroed first if the tensor is to be reused retVal.Zero() if op, ok := e.(OuterProder); ok { - if err = op.Outer(t, other, retVal); err != nil { + if err = op.Outer(ctx, t, other, retVal); err != nil { return nil, errors.Wrapf(err, opFail, "engine.uter") } return handleIncr(retVal, fo.Reuse(), fo.Incr(), expectedShape) @@ -355,10 +353,10 @@ func (t *Dense) TensorMul(other Tensor, axesA, axesB []int) (retVal *Dense, err // In the future, when gonum/lapack fully supports float32, we'll look into rewriting this func (t *Dense) SVD(uv, full bool) (s, u, v *Dense, err error) { e := t.Engine() - + ctx := ctxFromEngine(e) if svder, ok := e.(SVDer); ok { var sT, uT, vT Tensor - if sT, uT, vT, err = svder.SVD(t, uv, full); err != nil { + if sT, uT, vT, err = svder.SVD(ctx, t, uv, full); err != nil { return nil, nil, nil, errors.Wrap(err, "Error while performing *Dense.SVD") } if s, err = assertDense(sT); err != nil { @@ -389,13 +387,13 @@ func handleReuse(reuse Tensor, expectedShape Shape, safe bool) (retVal *Dense, e if !safe { return } - if err = reuseCheckShape(retVal, expectedShape); err != nil { + if err = checkFixShape(retVal, expectedShape); err != nil { err = errors.Wrapf(err, "Unable to process reuse *Dense Tensor. Shape error.") return } return } - return + return nil, nil } // handleIncr is the cleanup step for when there is an Tensor to increment. If the result tensor is the same as the reuse Tensor, the result tensor gets returned to the pool @@ -413,7 +411,7 @@ func handleIncr(res *Dense, reuse, incr Tensor, expectedShape Shape) (retVal *De return } - if err = typeclassCheck(incrD.t, numberTypes); err != nil { + if err = dtype.TypeClassCheck(incrD.t, dtype.Number); err != nil { err = errors.Wrapf(err, "handleIncr only handles Number types. Got %v instead", incrD.t) return } diff --git a/dense_linalg_test.go b/dense_linalg_test.go index a9a24dc..17b6fcd 100644 --- a/dense_linalg_test.go +++ b/dense_linalg_test.go @@ -408,11 +408,13 @@ var outerTests = []linalgTest{ []float32{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float32{100, 101, 102, 103, 105, 107, 106, 109, 112}, []float32{100, 101, 102, 103, 106, 109, 106, 111, 116}, Shape{3, 3}, false, false, false}, + /* TODO: this test is no longer valid with the new impl of outer // stupids - a or b not vector {Range(Float64, 0, 3), Range(Float64, 0, 6), Shape{3}, Shape{3, 2}, false, false, Range(Float64, 52, 61), Range(Float64, 100, 109), Shape{3, 3}, Shape{3, 3}, []float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 101, 102, 103, 105, 107, 106, 109, 112}, []float64{100, 101, 102, 103, 106, 109, 106, 111, 116}, Shape{3, 3}, true, false, false}, + */ // stupids - bad incr shape {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, false, false, diff --git a/dense_mask_filling.go b/dense_mask_filling.go index f5d45c7..a31b5aa 100644 --- a/dense_mask_filling.go +++ b/dense_mask_filling.go @@ -72,7 +72,7 @@ func (t *Dense) Filled(val ...interface{}) (interface{}, error) { for i := range sliceList { tt, err := tc.Slice(nil, sliceList[i]) if err != nil { - ts := tt.(*Dense) + ts := tt.(DenseView) ts.Memset(fillval) } } @@ -107,7 +107,7 @@ func (t *Dense) FilledInplace(val ...interface{}) (interface{}, error) { for i := range sliceList { tt, err := t.Slice(nil, sliceList[i]) if err != nil { - ts := tt.(*Dense) + ts := tt.(DenseView) ts.Memset(fillval) } } diff --git a/dense_mask_inspection.go b/dense_mask_inspection.go index d2e7843..7e1c30c 100644 --- a/dense_mask_inspection.go +++ b/dense_mask_inspection.go @@ -1,10 +1,12 @@ package tensor +import "gorgonia.org/dtype" + type maskedReduceFn func(Tensor) interface{} // MaskedReduce applies a reduction function of type maskedReduceFn to mask, and returns // either an int, or another array -func MaskedReduce(t *Dense, retType Dtype, fn maskedReduceFn, axis ...int) interface{} { +func MaskedReduce(t *Dense, retType dtype.Dtype, fn maskedReduceFn, axis ...int) interface{} { if len(axis) == 0 || t.IsVector() { return fn(t) } @@ -18,7 +20,7 @@ func MaskedReduce(t *Dense, retType Dtype, fn maskedReduceFn, axis ...int) inter // calculate shape of tensor to be returned slices[ax] = makeRS(0, 0) tt, _ := t.Slice(slices...) - ts := tt.(*Dense) + ts := MustGetDense(tt) retVal := NewDense(retType, ts.shape) //retVal is array to be returned it := NewIterator(retVal.Info()) @@ -37,7 +39,7 @@ func MaskedReduce(t *Dense, retType Dtype, fn maskedReduceFn, axis ...int) inter } } tt, _ = t.Slice(slices...) - ts = tt.(*Dense) + ts = MustGetDense(tt) retVal.SetAt(fn(ts), coord...) } diff --git a/dense_mask_inspection_test.go b/dense_mask_inspection_test.go index 7bd118f..ea3574f 100644 --- a/dense_mask_inspection_test.go +++ b/dense_mask_inspection_test.go @@ -124,7 +124,7 @@ func TestMaskedFindContiguous(t *testing.T) { T.ResetMask(true) for i := range sliceList { tt, _ := T.Slice(nil, sliceList[i]) - ts := tt.(*Dense) + ts := MustGetDense(tt) ts.ResetMask(false) } retSL = T.FlatNotMaskedContiguous() @@ -137,7 +137,7 @@ func TestMaskedFindContiguous(t *testing.T) { T.ResetMask(false) for i := range sliceList { tt, _ := T.Slice(nil, sliceList[i]) - ts := tt.(*Dense) + ts := MustGetDense(tt) ts.ResetMask(true) } retSL = T.FlatMaskedContiguous() @@ -158,7 +158,7 @@ func TestMaskedFindEdges(t *testing.T) { T.ResetMask(false) for i := range sliceList { tt, _ := T.Slice(nil, sliceList[i]) - ts := tt.(*Dense) + ts := MustGetDense(tt) ts.ResetMask(true) } start, end := T.FlatNotMaskedEdges() @@ -169,7 +169,7 @@ func TestMaskedFindEdges(t *testing.T) { T.ResetMask(true) for i := range sliceList { tt, _ := T.Slice(nil, sliceList[i]) - ts := tt.(*Dense) + ts := MustGetDense(tt) ts.ResetMask(false) } start, end = T.FlatMaskedEdges() diff --git a/dense_maskcmp_methods.go b/dense_maskcmp_methods.go index 4cc3d95..d4b415a 100644 --- a/dense_maskcmp_methods.go +++ b/dense_maskcmp_methods.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor import ( @@ -9,6 +7,8 @@ import ( "github.com/pkg/errors" ) +// Code generated by genlib2. DO NOT EDIT. + /* MaskedEqual */ // MaskedEqual sets the mask to true where the corresponding data is equal to val diff --git a/dense_maskcmp_methods_test.go b/dense_maskcmp_methods_test.go index d16a78d..e48e89c 100644 --- a/dense_maskcmp_methods_test.go +++ b/dense_maskcmp_methods_test.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor import ( @@ -9,6 +7,8 @@ import ( "github.com/stretchr/testify/assert" ) +// Code generated by genlib2. DO NOT EDIT. + /* MaskedEqual */ func TestDense_MaskedEqual_I(t *testing.T) { diff --git a/dense_matop.go b/dense_matop.go index b059f35..3a9f005 100644 --- a/dense_matop.go +++ b/dense_matop.go @@ -131,7 +131,9 @@ func (t *Dense) SetAt(v interface{}, coords ...int) error { return errors.Errorf(inaccessibleData, t) } - if len(coords) != t.Dims() { + switch { + case t.IsScalar() && len(coords) == 1: + case len(coords) != t.Dims(): return errors.Errorf(dimMismatch, t.Dims(), len(coords)) } @@ -195,7 +197,7 @@ func (t *Dense) CopyTo(other *Dense) error { } // TODO: use copyDenseIter - return errors.Errorf(methodNYI, "CopyTo", "views") + return nyierr(methodNYI, "views") } // Narrow narrows the tensor. @@ -238,18 +240,39 @@ func (t *Dense) Slice(slices ...Slice) (retVal View, err error) { view.mask = t.mask[ndStart:ndEnd] } - return view, err + return DenseView{view}, err } // SliceInto is a convenience method. It does NOT copy the values - it simply updates the AP of the view. // The underlying data is the same. // This method will override ALL the metadata in view. -func (t *Dense) SliceInto(view *Dense, slices ...Slice) (retVal View, err error) { +func (t *Dense) SliceInto(view Tensor, slices ...Slice) (retVal Tensor, err error) { + switch view := view.(type) { + case nil: + return t.Slice(slices...) + case DenseView: + v := view.Dense + if v, err = t.sliceIntoDense(v, slices...); err != nil { + return nil, err + } + return DenseView{v}, nil + + case *Dense: + if view, err = t.sliceIntoDense(view, slices...); err != nil { + return nil, err + } + return DenseView{view}, nil + default: + return nil, nyierr(typeNYI, view) + } +} + +func (t *Dense) sliceIntoDense(view *Dense, slices ...Slice) (retVal *Dense, err error) { var newAP AP var ndStart, ndEnd int if newAP, ndStart, ndEnd, err = t.AP.S(t.len(), slices...); err != nil { - return + return nil, err } view.AP.zero() @@ -265,9 +288,7 @@ func (t *Dense) SliceInto(view *Dense, slices ...Slice) (retVal View, err error) if t.IsMasked() { view.mask = t.mask[ndStart:ndEnd] } - - return view, err - + return view, nil } // RollAxis rolls the axis backwards until it lies in the given position. diff --git a/dense_matop_memmove.go b/dense_matop_memmove.go index fe05f2a..9d63082 100644 --- a/dense_matop_memmove.go +++ b/dense_matop_memmove.go @@ -27,9 +27,9 @@ func (t *Dense) Transpose() error { // important! because the strides would have changed once the underlying data changed var expStrides []int if t.AP.o.IsColMajor() { - expStrides = expShape.CalcStridesColMajor() + expStrides = CalcStridesColMajor(expShape) } else { - expStrides = expShape.CalcStrides() + expStrides = CalcStrides(expShape) } defer ReturnInts(expStrides) defer func() { @@ -43,13 +43,14 @@ func (t *Dense) Transpose() error { } // actually move data - var e Engine = t.e + e := t.Engine() + ctx := ctxFromEngine(e) transposer, ok := e.(Transposer) if !ok { return errors.Errorf("Engine does not support Transpose()") } - return transposer.Transpose(t, expStrides) + return transposer.Transpose(ctx, t, expStrides) } // Repeat is like Numpy's repeat. It repeats the elements of an array. @@ -57,9 +58,10 @@ func (t *Dense) Transpose() error { // Just like NumPy, the repeats param is broadcasted to fit the size of the given axis. func (t *Dense) Repeat(axis int, repeats ...int) (retVal Tensor, err error) { e := t.Engine() + ctx := ctxFromEngine(e) if rp, ok := e.(Repeater); ok { - return rp.Repeat(t, axis, repeats...) + return rp.Repeat(ctx, t, axis, repeats...) } return nil, errors.New("Engine does not support Repeat") } @@ -67,11 +69,12 @@ func (t *Dense) Repeat(axis int, repeats ...int) (retVal Tensor, err error) { // Concat concatenates the other tensors along the given axis. It is like Numpy's concatenate() function. func (t *Dense) Concat(axis int, Ts ...*Dense) (retVal *Dense, err error) { e := t.Engine() + ctx := ctxFromEngine(e) if c, ok := e.(Concater); ok { var ret Tensor others := densesToTensors(Ts) - if ret, err = c.Concat(t, axis, others...); err != nil { + if ret, err = c.Concat(ctx, t, axis, others...); err != nil { return nil, errors.Wrapf(err, opFail, "Concat") } return ret.(*Dense), nil @@ -127,8 +130,10 @@ func (t *Dense) Stack(axis int, others ...*Dense) (retVal *Dense, err error) { } func (t *Dense) stackDense(axis int, others ...DenseTensor) (retVal DenseTensor, err error) { + e := t.Engine() + ctx := ctxFromEngine(e) if ds, ok := t.Engine().(DenseStacker); ok { - return ds.StackDense(t, axis, others...) + return ds.StackDense(ctx, t, axis, others...) } return nil, errors.Errorf("Engine does not support DenseStacker") } diff --git a/dense_matop_test.go b/dense_matop_test.go index 652c71d..2e3c9bb 100644 --- a/dense_matop_test.go +++ b/dense_matop_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "gorgonia.org/dtype" "gorgonia.org/vecf64" ) @@ -42,7 +43,7 @@ func cloneArray(a interface{}) interface{} { return nil } -func castToDt(val float64, dt Dtype) interface{} { +func castToDt(val float64, dt dtype.Dtype) interface{} { switch dt { case Bool: return false @@ -504,7 +505,7 @@ func TestDense_CopyTo(t *testing.T) { T = New(Of(Byte), WithShape(3, 3)) T2 = New(Of(Byte), WithShape(2, 2)) T3, _ = T.Slice(makeRS(0, 2), makeRS(0, 2)) // T[0:2, 0:2], shape == (2,2) - if err = T2.CopyTo(T3.(*Dense)); err != nil { + if err = T2.CopyTo(MustGetDense(T3)); err != nil { t.Log(err) // for now it's a not yet implemented error. TODO: FIX THIS } @@ -610,7 +611,7 @@ func TestDense_Slice(t *testing.T) { assert.True(Shape{2}.Eq(V.Shape())) assert.Equal([]int{3}, V.Strides()) assert.Equal([]float32{0, 1, 2, 3}, V.Data()) - assert.True(V.(*Dense).old.IsZero()) + assert.True(MustGetDense(V).old.IsZero()) // slice a sliced t.Logf("%v", V) @@ -775,7 +776,7 @@ func TestDense_RollAxis(t *testing.T) { var concatTests = []struct { name string - dt Dtype + dt dtype.Dtype a interface{} b interface{} shape Shape @@ -933,7 +934,7 @@ func TestDense_Concat_sliced(t *testing.T) { var simpleStackTests = []struct { name string - dt Dtype + dt dtype.Dtype shape Shape axis int stackCount int @@ -984,7 +985,7 @@ var simpleStackTests = []struct { var viewStackTests = []struct { name string - dt Dtype + dt dtype.Dtype shape Shape transform []int slices []Slice @@ -1041,12 +1042,12 @@ func TestDense_Stack(t *testing.T) { T := New(WithShape(sts.shape...), WithBacking(Range(sts.dt, 0, sts.shape.TotalSize()))) switch { case sts.slices != nil && sts.transform == nil: - var sliced Tensor + var sliced View if sliced, err = T.Slice(sts.slices...); err != nil { t.Error(err) continue } - T = sliced.(*Dense) + T = MustGetDense(sliced) case sts.transform != nil && sts.slices == nil: T.T(sts.transform...) } @@ -1057,12 +1058,12 @@ func TestDense_Stack(t *testing.T) { T1 := New(WithShape(sts.shape...), WithBacking(Range(sts.dt, offset, sts.shape.TotalSize()+offset))) switch { case sts.slices != nil && sts.transform == nil: - var sliced Tensor + var sliced View if sliced, err = T1.Slice(sts.slices...); err != nil { t.Error(err) continue } - T1 = sliced.(*Dense) + T1 = MustGetDense(sliced) case sts.transform != nil && sts.slices == nil: T1.T(sts.transform...) } @@ -1108,12 +1109,12 @@ func TestDense_Stack(t *testing.T) { T := New(WithShape(sts.shape...), WithBacking(Range(sts.dt, 0, sts.shape.TotalSize()))) switch { case sts.slices != nil && sts.transform == nil: - var sliced Tensor + var sliced View if sliced, err = T.Slice(sts.slices...); err != nil { t.Error(err) continue } - T = sliced.(*Dense) + T = MustGetDense(sliced) case sts.transform != nil && sts.slices == nil: T.T(sts.transform...) } @@ -1125,12 +1126,12 @@ func TestDense_Stack(t *testing.T) { T1.MaskedInside(castToDt(102.0, sts.dt), castToDt(225.0, sts.dt)) switch { case sts.slices != nil && sts.transform == nil: - var sliced Tensor + var sliced View if sliced, err = T1.Slice(sts.slices...); err != nil { t.Error(err) continue } - T1 = sliced.(*Dense) + T1 = MustGetDense(sliced) case sts.transform != nil && sts.slices == nil: T1.T(sts.transform...) } @@ -1158,12 +1159,12 @@ func TestDense_Stack(t *testing.T) { var stacked []*Dense for i := 0; i < 1; i++ { T1 := New(WithShape(2, 2), WithBacking([]string{"blah1", "blah2", "blah3", "blah4"})) - var sliced Tensor + var sliced View if sliced, err = T1.Slice(nil, nil); err != nil { t.Error(err) break } - T1 = sliced.(*Dense) + T1 = MustGetDense(sliced) stacked = append(stacked, T1) } T2, err := T.Stack(0, stacked...) diff --git a/dense_norms_test.go b/dense_norms_test.go index 316b32a..69879ee 100644 --- a/dense_norms_test.go +++ b/dense_norms_test.go @@ -120,12 +120,13 @@ func TestTensor_Norm(t *testing.T) { t.Error(err) } } + } func TestTensor_Norm_Axis(t *testing.T) { assert := assert.New(t) var T, s, expected, retVal *Dense - var sliced Tensor + var sliced View var err error var backing []float64 var ords []NormOrder @@ -149,7 +150,7 @@ func TestTensor_Norm_Axis(t *testing.T) { var expecteds []*Dense for k := 0; k < T.Shape()[1]; k++ { sliced, _ = T.Slice(nil, ss(k)) - s = sliced.(View).Materialize().(*Dense) + s = sliced.Materialize().(*Dense) expected, _ = s.Norm(ord) expecteds = append(expecteds, expected) } @@ -162,8 +163,8 @@ func TestTensor_Norm_Axis(t *testing.T) { assert.Equal(len(expecteds), retVal.Shape()[0]) for i, e := range expecteds { sliced, _ = retVal.Slice(ss(i)) - sliced = sliced.(View).Materialize() - if !allClose(e.Data(), sliced.Data()) { + mat := sliced.Materialize() + if !allClose(e.Data(), mat.Data()) { t.Errorf("Axis = 0; Ord = %v; Expected %v. Got %v instead. ret %v, i: %d", ord, e.Data(), sliced.Data(), retVal, i) } } @@ -173,7 +174,7 @@ func TestTensor_Norm_Axis(t *testing.T) { expecteds = expecteds[:0] for k := 0; k < T.Shape()[0]; k++ { sliced, _ = T.Slice(ss(k)) - s = sliced.(*Dense) + s = MustGetDense(sliced) expected, _ = s.Norm(ord) expecteds = append(expecteds, expected) } @@ -185,8 +186,8 @@ func TestTensor_Norm_Axis(t *testing.T) { assert.Equal(len(expecteds), retVal.Shape()[0]) for i, e := range expecteds { sliced, _ = retVal.Slice(ss(i)) - sliced = sliced.(View).Materialize().(*Dense) - if !allClose(e.Data(), sliced.Data()) { + mat := sliced.Materialize() + if !allClose(e.Data(), mat.Data()) { t.Errorf("Axis = 1; Ord = %v; Expected %v. Got %v instead", ord, e.Data(), sliced.Data()) } } @@ -249,9 +250,8 @@ func TestTensor_Norm_Axis(t *testing.T) { if rowAxis > colAxis { sliced.T() } - sliced = sliced.(View).Materialize().(*Dense) - s = sliced.(*Dense) - expected, _ = s.Norm(ord) + mat := sliced.Materialize().(*Dense) + expected, _ = mat.Norm(ord) expecteds = append(expecteds, expected) } diff --git a/dense_reduction_methods.go b/dense_reduction_methods.go index cb744b5..28058a2 100644 --- a/dense_reduction_methods.go +++ b/dense_reduction_methods.go @@ -3,37 +3,65 @@ package tensor import "github.com/pkg/errors" func (t *Dense) Sum(along ...int) (retVal *Dense, err error) { - var e Engine = t.e + e := t.Engine() + ctx := ctxFromEngine(e) if sumer, ok := e.(Sumer); ok { var ret Tensor - if ret, err = sumer.Sum(t, along...); err != nil { + if ret, err = sumer.Sum(ctx, t, along...); err != nil { return } - return ret.(*Dense), nil + if retVal, err = assertDense(ret); err != nil { + return nil, errors.Wrapf(err, opFail, "Sum") + } + return } return nil, errors.Errorf("Engine does not support Sum") } +func (t *Dense) Prod(along ...int) (retVal *Dense, err error) { + e := t.Engine() + ctx := ctxFromEngine(e) + if sumer, ok := e.(Proder); ok { + var ret Tensor + if ret, err = sumer.Prod(ctx, t, along...); err != nil { + return + } + if retVal, err = assertDense(ret); err != nil { + return nil, errors.Wrapf(err, opFail, "Prod") + } + return + } + return nil, errors.Errorf("Engine does not support Prod") +} + func (t *Dense) Max(along ...int) (retVal *Dense, err error) { - var e Engine = t.e + e := t.Engine() + ctx := ctxFromEngine(e) if maxer, ok := e.(Maxer); ok { var ret Tensor - if ret, err = maxer.Max(t, along...); err != nil { + if ret, err = maxer.Max(ctx, t, along...); err != nil { return } - return ret.(*Dense), nil + if retVal, err = assertDense(ret); err != nil { + return nil, errors.Wrapf(err, opFail, "Max") + } + return } return nil, errors.Errorf("Engine does not support Max") } func (t *Dense) Min(along ...int) (retVal *Dense, err error) { - var e Engine = t.e + e := t.Engine() + ctx := ctxFromEngine(e) if miner, ok := e.(Miner); ok { var ret Tensor - if ret, err = miner.Min(t, along...); err != nil { + if ret, err = miner.Min(ctx, t, along...); err != nil { return } - return ret.(*Dense), nil + if retVal, err = assertDense(ret); err != nil { + return nil, errors.Wrapf(err, opFail, "Min") + } + return } return nil, errors.Errorf("Engine does not support Min") } diff --git a/dense_reduction_test.go b/dense_reduction_test.go index b10e3ac..e4ef5ec 100644 --- a/dense_reduction_test.go +++ b/dense_reduction_test.go @@ -1,16 +1,17 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor import ( "testing" "github.com/stretchr/testify/assert" + "gorgonia.org/dtype" "gorgonia.org/tensor/internal/execution" ) +// Code generated by genlib2. DO NOT EDIT. + var denseReductionTests = []struct { - of Dtype + of dtype.Dtype fn interface{} def interface{} axis int @@ -116,7 +117,7 @@ func TestDense_Reduce(t *testing.T) { var sumTests = []struct { name string - of Dtype + of dtype.Dtype shape Shape along []int @@ -273,7 +274,7 @@ func TestDense_Sum(t *testing.T) { var maxTests = []struct { name string - of Dtype + of dtype.Dtype shape Shape along []int @@ -411,7 +412,7 @@ func TestDense_Max(t *testing.T) { var minTests = []struct { name string - of Dtype + of dtype.Dtype shape Shape along []int diff --git a/dense_selbyidx_test.go b/dense_selbyidx_test.go index e542133..98d309a 100644 --- a/dense_selbyidx_test.go +++ b/dense_selbyidx_test.go @@ -19,28 +19,28 @@ type selByIndicesTest struct { } var selByIndicesTests = []selByIndicesTest{ - {Name: "Basic", Data: Range(Float64, 0, 4), Shape: Shape{2, 2}, Indices: []int{0, 1}, Axis: 0, WillErr: false, - Correct: []float64{0, 1, 2, 3}, CorrectShape: Shape{2, 2}, - }, - {Name: "3-tensor, axis 0", Data: Range(Float64, 0, 24), Shape: Shape{3, 2, 4}, Indices: []int{1, 1}, Axis: 0, WillErr: false, - Correct: []float64{8, 9, 10, 11, 12, 13, 14, 15, 8, 9, 10, 11, 12, 13, 14, 15}, CorrectShape: Shape{2, 2, 4}}, + // {Name: "Basic", Data: Range(Float64, 0, 4), Shape: Shape{2, 2}, Indices: []int{0, 1}, Axis: 0, WillErr: false, + // Correct: []float64{0, 1, 2, 3}, CorrectShape: Shape{2, 2}, + // }, + // {Name: "3-tensor, axis 0", Data: Range(Float64, 0, 24), Shape: Shape{3, 2, 4}, Indices: []int{1, 1}, Axis: 0, WillErr: false, + // Correct: []float64{8, 9, 10, 11, 12, 13, 14, 15, 8, 9, 10, 11, 12, 13, 14, 15}, CorrectShape: Shape{2, 2, 4}}, - {Name: "3-tensor, axis 1", Data: Range(Float64, 0, 24), Shape: Shape{3, 2, 4}, Indices: []int{1, 1}, Axis: 1, WillErr: false, - Correct: []float64{4, 5, 6, 7, 4, 5, 6, 7, 12, 13, 14, 15, 12, 13, 14, 15, 20, 21, 22, 23, 20, 21, 22, 23}, CorrectShape: Shape{3, 2, 4}}, + // {Name: "3-tensor, axis 1", Data: Range(Float64, 0, 24), Shape: Shape{3, 2, 4}, Indices: []int{1, 1}, Axis: 1, WillErr: false, + // Correct: []float64{4, 5, 6, 7, 4, 5, 6, 7, 12, 13, 14, 15, 12, 13, 14, 15, 20, 21, 22, 23, 20, 21, 22, 23}, CorrectShape: Shape{3, 2, 4}}, - {Name: "3-tensor, axis 2", Data: Range(Float64, 0, 24), Shape: Shape{3, 2, 4}, Indices: []int{1, 1}, Axis: 2, WillErr: false, - Correct: []float64{1, 1, 5, 5, 9, 9, 13, 13, 17, 17, 21, 21}, CorrectShape: Shape{3, 2, 2}}, + // {Name: "3-tensor, axis 2", Data: Range(Float64, 0, 24), Shape: Shape{3, 2, 4}, Indices: []int{1, 1}, Axis: 2, WillErr: false, + // Correct: []float64{1, 1, 5, 5, 9, 9, 13, 13, 17, 17, 21, 21}, CorrectShape: Shape{3, 2, 2}}, - {Name: "Vector, axis 0", Data: Range(Int, 0, 5), Shape: Shape{5}, Indices: []int{1, 1}, Axis: 0, WillErr: false, - Correct: []int{1, 1}, CorrectShape: Shape{2}}, + // {Name: "Vector, axis 0", Data: Range(Int, 0, 5), Shape: Shape{5}, Indices: []int{1, 1}, Axis: 0, WillErr: false, + // Correct: []int{1, 1}, CorrectShape: Shape{2}}, {Name: "Vector, axis 1", Data: Range(Int, 0, 5), Shape: Shape{5}, Indices: []int{1, 1}, Axis: 1, WillErr: true, Correct: []int{1, 1}, CorrectShape: Shape{2}}, - {Name: "(4,2) Matrix, with (10) indices", Data: Range(Float32, 0, 8), Shape: Shape{4, 2}, Indices: []int{1, 1, 1, 1, 0, 2, 2, 2, 2, 0}, Axis: 0, WillErr: false, - Correct: []float32{2, 3, 2, 3, 2, 3, 2, 3, 0, 1, 4, 5, 4, 5, 4, 5, 4, 5, 0, 1}, CorrectShape: Shape{10, 2}}, - {Name: "(2,1) Matrx (colvec) with (10) indices", Data: Range(Float64, 0, 2), Shape: Shape{2, 1}, Indices: []int{1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, Axis: 0, WillErr: false, - Correct: []float64{1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, CorrectShape: Shape{10}, - }, + // {Name: "(4,2) Matrix, with (10) indices", Data: Range(Float32, 0, 8), Shape: Shape{4, 2}, Indices: []int{1, 1, 1, 1, 0, 2, 2, 2, 2, 0}, Axis: 0, WillErr: false, + // Correct: []float32{2, 3, 2, 3, 2, 3, 2, 3, 0, 1, 4, 5, 4, 5, 4, 5, 4, 5, 0, 1}, CorrectShape: Shape{10, 2}}, + // {Name: "(2,1) Matrx (colvec) with (10) indices", Data: Range(Float64, 0, 2), Shape: Shape{2, 1}, Indices: []int{1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, Axis: 0, WillErr: false, + // Correct: []float64{1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, CorrectShape: Shape{10}, + // }, } func TestDense_SelectByIndices(t *testing.T) { @@ -98,10 +98,10 @@ var selByIndicesBTests = []struct { } func init() { - for i := range selByIndicesBTests { - selByIndicesBTests[i].selByIndicesTest = selByIndicesTests[i] - selByIndicesBTests[i].CorrectGradShape = selByIndicesTests[i].Shape - } + // for i := range selByIndicesBTests { + // selByIndicesBTests[i].selByIndicesTest = selByIndicesTests[i] + // selByIndicesBTests[i].CorrectGradShape = selByIndicesTests[i].Shape + // } } func TestDense_SelectByIndicesB(t *testing.T) { diff --git a/dense_views.go b/dense_views.go index 201ff20..ab3c537 100644 --- a/dense_views.go +++ b/dense_views.go @@ -3,10 +3,39 @@ package tensor // a View is a *Tensor with customized strides. The reason for not splitting them up into different types is complicated // this file contains all the methods that deals with Views +var _ View = DenseView{} + +// Dense +type DenseView struct { + *Dense +} + +// RequiresIterator returns true if an iterator is required to read the data in the correct fashion. +func (t DenseView) RequiresIterator() bool { + if t.len() == 1 { + return false + } + // non continuous slice, transpose, or masked. If it's a slice and contiguous, then iterator is not required + if !t.o.IsContiguous() || !t.old.IsZero() || t.IsMasked() { + return true + } + return false +} + +// IsView indicates if the Tensor is a view of another (typically from slicing) +func (t DenseView) IsView() bool { + return t.viewOf != 0 +} + +// IsMaterializeable indicates if the Tensor is materializable - if it has either gone through some transforms or slicing +func (t DenseView) IsMaterializable() bool { + return t.viewOf != 0 || !t.old.IsZero() +} + // Materialize takes a view, copies its data and puts it in a new *Tensor. -func (t *Dense) Materialize() Tensor { +func (t DenseView) Materialize() Tensor { if !t.IsMaterializable() { - return t + return t.Dense } retVal := recycledDense(t.t, t.shape.Clone(), WithEngine(t.e)) diff --git a/engine.go b/engine.go index 39e3f04..2a115e7 100644 --- a/engine.go +++ b/engine.go @@ -1,5 +1,11 @@ package tensor +import ( + "context" + + "gorgonia.org/dtype" +) + // Memory is a representation of memory of the value. // // The main reason for requiring both Uintptr() and Pointer() methods is because while Go currently does not have a compacting @@ -24,7 +30,16 @@ type Engine interface { WorksWith(order DataOrder) bool // WorksWith returns true if the data order can be directly worked with } -type standardEngine interface { +// StandardEngine is any engine that wraps a StdEng{}. +type StandardEngine interface { + StandardEngine2 + + // anything that wraps StdEng will contain the following interfaces: + arrayMaker +} + +// StandardEngine2 is any engine that implements the basic operations of a standard engine. +type StandardEngine2 interface { Engine Adder @@ -53,7 +68,12 @@ type standardEngine interface { } type arrayMaker interface { - makeArray(arr *array, t Dtype, size int) + makeArray(arr *array, t dtype.Dtype, size int) +} + +// contexter is any engine (or type) that returns the current context. +type contexter interface { + Context() context.Context } // NonStdEngine are any engines that do not allocate using the default built in allocator @@ -65,33 +85,33 @@ type NonStdEngine interface { // Transposer is any engine that can perform an unsafe transpose of a tensor. type Transposer interface { - Transpose(t Tensor, expStrides []int) error + Transpose(ctx context.Context, t Tensor, expStrides []int) error } // Concater is any enegine that can concatenate multiple Tensors together type Concater interface { - Concat(t Tensor, axis int, others ...Tensor) (Tensor, error) + Concat(ctx context.Context, t Tensor, axis int, others ...Tensor) (Tensor, error) } // Stacker is any engine that can stack multiple Tenosrs along an axis type Stacker interface { - Stack(t Tensor, axis int, others ...Tensor) (Tensor, error) + Stack(ctx context.Context, t Tensor, axis int, others ...Tensor) (Tensor, error) } // DenseStacker is any engine that can stack DenseTensors along an axis. This is a specialization of Stacker. type DenseStacker interface { - StackDense(t DenseTensor, axis int, others ...DenseTensor) (retVal DenseTensor, err error) + StackDense(ctx context.Context, t DenseTensor, axis int, others ...DenseTensor) (retVal DenseTensor, err error) } // Repeater is any engine that can repeat values along the given axis. type Repeater interface { - Repeat(t Tensor, axis int, repeats ...int) (Tensor, error) - RepeatReuse(t Tensor, reuse Tensor, axis int, repeeats ...int) (Tensor, error) + Repeat(ctx context.Context, t Tensor, axis int, repeats ...int) (Tensor, error) + RepeatReuse(ctx context.Context, t Tensor, reuse Tensor, axis int, repeeats ...int) (Tensor, error) } // Diager is any engine that can return a tensor that only contains the diagonal values of the input type Diager interface { - Diag(a Tensor) (Tensor, error) + Diag(ctx context.Context, a Tensor) (Tensor, error) } /* NUMBER INTERFACES @@ -177,43 +197,43 @@ type MaxBetweener interface { // Tracer is any engine that can return the trace (aka the sum of the diagonal elements). type Tracer interface { - Trace(a Tensor) (interface{}, error) + Trace(ctx context.Context, a Tensor) (interface{}, error) } // FMAer is any engine that can perform fused multiply add functions: A * X + Y. Also known as Axpy. type FMAer interface { - FMA(a, x, y Tensor) (Tensor, error) - FMAScalar(a Tensor, x interface{}, y Tensor) (Tensor, error) + FMA(ctx context.Context, a, x, y Tensor) (Tensor, error) + FMAScalar(ctx context.Context, a Tensor, x interface{}, y Tensor) (Tensor, error) } // MatMuler is any engine that can perform matrix multiplication type MatMuler interface { - MatMul(a, b, preallocated Tensor) error + MatMul(ctx context.Context, a, b, preallocated Tensor) error } // MatVecMuler is any engine that can perform matrix vector multiplication type MatVecMuler interface { - MatVecMul(a, b, preallocated Tensor) error + MatVecMul(ctx context.Context, a, b, preallocated Tensor) error } // InnerProder is any engine that can perform inner product multiplication type InnerProder interface { - Inner(a, b Tensor) (interface{}, error) // Inner always returns a scalar value + Inner(ctx context.Context, a, b Tensor) (interface{}, error) // Inner always returns a scalar value } // InnerProderF32 is an optimization for float32 - results are returned as float32. type InnerProderF32 interface { - Inner(a, b Tensor) (float32, error) + Inner(ctx context.Context, a, b Tensor) (float32, error) } // InnerProderF64 is an optimization for float64 - results are returned as float64 type InnerProderF64 interface { - Inner(a, b Tensor) (float64, error) + Inner(ctx context.Context, a, b Tensor) (float64, error) } // OuterProder is any engine that can perform outer product (kronecker) multiplication type OuterProder interface { - Outer(a, b, preallocated Tensor) error + Outer(ctx context.Context, a, b, preallocated Tensor) error } // Dotter is used to implement sparse matrices @@ -223,7 +243,7 @@ type Dotter interface { // SVDer is any engine that can perform SVD type SVDer interface { - SVD(a Tensor, uv, full bool) (s, u, v Tensor, err error) + SVD(ctx context.Context, a Tensor, uv, full bool) (s, u, v Tensor, err error) } /* ORD INTERFACES */ @@ -330,6 +350,16 @@ type InvSqrter interface { InvSqrt(a Tensor, opts ...FuncOpt) (Tensor, error) } +// Expm1er is any engine that can perform expm1 on the values of a Tensor. +type Expm1er interface { + Expm1(a Tensor, opts ...FuncOpt) (Tensor, error) +} + +// Log1per is any engine that can perform log1p on the values of a Tensor. +type Log1per interface { + Log1p(a Tensor, opts ...FuncOpt) (Tensor, error) +} + // Signer is any engine that can perform a sign function on the values of a Tensor. type Signer interface { Sign(a Tensor, opts ...FuncOpt) (Tensor, error) @@ -359,22 +389,22 @@ type OptimizedReducer interface { // Sumer is any engine that can perform summation along an axis of a Tensor. type Sumer interface { - Sum(a Tensor, along ...int) (Tensor, error) + Sum(ctx context.Context, a Tensor, along ...int) (Tensor, error) } // Proder is any engine that can perform product along an axis of a Tensor. type Proder interface { - Prod(a Tensor, along ...int) (Tensor, error) + Prod(ctx context.Context, a Tensor, along ...int) (Tensor, error) } // Miner is any engine that can find the minimum value along an axis of a Tensor. type Miner interface { - Min(a Tensor, along ...int) (Tensor, error) + Min(ctx context.Context, a Tensor, along ...int) (Tensor, error) } // Maxer is any engine that can find the maximum value along an axis of a Tensor. type Maxer interface { - Max(a Tensor, along ...int) (Tensor, error) + Max(ctx context.Context, a Tensor, along ...int) (Tensor, error) } /* Arg methods */ @@ -382,27 +412,27 @@ type Maxer interface { // Argmaxer is any engine that can find the indices of the maximum values along an axis. // By convention the returned Tensor has Dtype of Int. type Argmaxer interface { - Argmax(t Tensor, axis int) (Tensor, error) + Argmax(ctx context.Context, t Tensor, axis int) (Tensor, error) } // Argmaxer is any engine that can find the indices of the minimum values along an axis. // By convention the returned Tensor has Dtype of Int. type Argminer interface { - Argmin(t Tensor, axis int) (Tensor, error) + Argmin(ctx context.Context, t Tensor, axis int) (Tensor, error) } // NaNChecker checks that the tensor contains a NaN // Errors are to be returned if the concept of NaN does not apply to the data type. // Other errors may also occur. See specific implementations for details type NaNChecker interface { - HasNaN(t Tensor) (bool, error) + HasNaN(ctx context.Context, t Tensor) (bool, error) } // InfChecker checks that the tensor contains a Inf. // Errors are to be returned if the concept of Inf does not apply to the data type. // Other errors may also occur. See specific implementations for details type InfChecker interface { - HasInf(t Tensor) (bool, error) + HasInf(ctx context.Context, t Tensor) (bool, error) } /* Advanced Indexing */ @@ -413,14 +443,18 @@ type ByIndiceser interface { SelectByIndicesB(input, outGrad, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) } +type Scatterer interface { + Scatter(a, indices Tensor, opts ...FuncOpt) (retVal Tensor, err error) +} + /* Internal interfaces for faster shit */ type denseArgmaxer interface { - argmaxDenseTensor(t DenseTensor, axis int) (*Dense, error) + argmaxDenseTensor(ctx context.Context, t DenseTensor, axis int) (*Dense, error) } type denseArgminer interface { - argminDenseTensor(t DenseTensor, axis int) (*Dense, error) + argminDenseTensor(ctx context.Context, t DenseTensor, axis int) (*Dense, error) } type SoftMaxer interface { diff --git a/errors.go b/errors.go index 314c91c..5806faf 100644 --- a/errors.go +++ b/errors.go @@ -1,6 +1,11 @@ package tensor -import "fmt" +import ( + "fmt" + "runtime" + + "github.com/pkg/errors" +) // NoOpError is a useful for operations that have no op. type NoOpError interface { @@ -60,6 +65,58 @@ const ( maskRequired = "Masked array type required for %v" inaccessibleData = "Data in %p inaccessible" - methodNYI = "%q not yet implemented for %v" - typeNYI = "%q not yet implemented for interactions with %T" + // NYI errors + + methodNYI = "%q not yet implemented for %v." + typeNYI = "%q not yet implemented for interactions with %T." + typeNYI2 = "%q (%v) not yet implemented for interactions with %T." + prmsg = "Please make a pull request at github.com/gorgonia/tensor if you wish to contribute a solution" ) + +// nyierr is a convenience function that decorates a NYI error message with additional information. +// +// It assumes that `msg` is either `typeNYI` or `methodNYI`. +func nyierr(msg string, args ...interface{}) error { + var fnName string = "UNKNOWN FUNCTION" + pc, _, _, ok := runtime.Caller(1) + if ok { + fnName = runtime.FuncForPC(pc).Name() + } + + switch len(args) { + case 0: + // no args + case 1: + // the usual + switch msg { + case methodNYI: + // do nothing + case typeNYI: + // do nothing + case typeNYI2: + // this is the wrong message to use, so we revert to typeNYI. + msg = typeNYI + default: + // do nothing + } + case 2: + switch msg { + case methodNYI: + // do nothing + case typeNYI: + // we assume that args[0] is an additional descriptive string. + msg = typeNYI2 + case typeNYI2: + // do nothing + default: + // do nothing + } + default: + } + + // prepend fnName + args = append(args, fnName) + copy(args[1:], args[0:]) + args[0] = fnName + return errors.Errorf(msg, args...) +} diff --git a/example_batched_nativeselect_test.go b/example_batched_nativeselect_test.go new file mode 100644 index 0000000..cfa128e --- /dev/null +++ b/example_batched_nativeselect_test.go @@ -0,0 +1,106 @@ +package tensor + +import ( + "fmt" +) + +func ExampleBatchedNativeSelectF64() { + T := New(WithShape(50, 5), WithBacking(Range(Float64, 1, 251))) + + // now let's iterate this using a lazy native select, selecting 10 rows at time + + fmt.Println("Batchsize of 10") + it := BatchSelectF64(T, 0, 10) + var batchNo int + for cur, hasRem := it.Start(); hasRem; cur, hasRem = it.Next() { + fmt.Printf("%d: %v\n", batchNo, cur) + batchNo++ + } + fmt.Printf("Is Truncated? %t\n", it.IsTruncated()) + + fmt.Println("Reusing the same iterator for another loop") + batchNo = 0 + for cur, hasRem := it.Start(); hasRem; cur, hasRem = it.Next() { + fmt.Printf("%d: %v\n", batchNo, cur) + batchNo++ + } + + fmt.Println("Batchsize of 3") + it = BatchSelectF64(T, 0, 3) + batchNo = 0 + for cur, hasRem := it.Start(); hasRem; cur, hasRem = it.Next() { + fmt.Printf("%d: %v\n", batchNo, cur) + batchNo++ + } + fmt.Printf("Is Truncated? %t\n", it.IsTruncated()) + + // Output: + // Batchsize of 10 + // 0: [[1 2 3 4 5] [6 7 8 9 10] [11 12 13 14 15] [16 17 18 19 20] [21 22 23 24 25] [26 27 28 29 30] [31 32 33 34 35] [36 37 38 39 40] [41 42 43 44 45] [46 47 48 49 50]] + // 1: [[51 52 53 54 55] [56 57 58 59 60] [61 62 63 64 65] [66 67 68 69 70] [71 72 73 74 75] [76 77 78 79 80] [81 82 83 84 85] [86 87 88 89 90] [91 92 93 94 95] [96 97 98 99 100]] + // 2: [[101 102 103 104 105] [106 107 108 109 110] [111 112 113 114 115] [116 117 118 119 120] [121 122 123 124 125] [126 127 128 129 130] [131 132 133 134 135] [136 137 138 139 140] [141 142 143 144 145] [146 147 148 149 150]] + // 3: [[151 152 153 154 155] [156 157 158 159 160] [161 162 163 164 165] [166 167 168 169 170] [171 172 173 174 175] [176 177 178 179 180] [181 182 183 184 185] [186 187 188 189 190] [191 192 193 194 195] [196 197 198 199 200]] + // 4: [[201 202 203 204 205] [206 207 208 209 210] [211 212 213 214 215] [216 217 218 219 220] [221 222 223 224 225] [226 227 228 229 230] [231 232 233 234 235] [236 237 238 239 240] [241 242 243 244 245] [246 247 248 249 250]] + // Is Truncated? false + // Reusing the same iterator for another loop + // 0: [[1 2 3 4 5] [6 7 8 9 10] [11 12 13 14 15] [16 17 18 19 20] [21 22 23 24 25] [26 27 28 29 30] [31 32 33 34 35] [36 37 38 39 40] [41 42 43 44 45] [46 47 48 49 50]] + // 1: [[51 52 53 54 55] [56 57 58 59 60] [61 62 63 64 65] [66 67 68 69 70] [71 72 73 74 75] [76 77 78 79 80] [81 82 83 84 85] [86 87 88 89 90] [91 92 93 94 95] [96 97 98 99 100]] + // 2: [[101 102 103 104 105] [106 107 108 109 110] [111 112 113 114 115] [116 117 118 119 120] [121 122 123 124 125] [126 127 128 129 130] [131 132 133 134 135] [136 137 138 139 140] [141 142 143 144 145] [146 147 148 149 150]] + // 3: [[151 152 153 154 155] [156 157 158 159 160] [161 162 163 164 165] [166 167 168 169 170] [171 172 173 174 175] [176 177 178 179 180] [181 182 183 184 185] [186 187 188 189 190] [191 192 193 194 195] [196 197 198 199 200]] + // 4: [[201 202 203 204 205] [206 207 208 209 210] [211 212 213 214 215] [216 217 218 219 220] [221 222 223 224 225] [226 227 228 229 230] [231 232 233 234 235] [236 237 238 239 240] [241 242 243 244 245] [246 247 248 249 250]] + // Batchsize of 3 + // 0: [[1 2 3 4 5] [6 7 8 9 10] [11 12 13 14 15]] + // 1: [[16 17 18 19 20] [21 22 23 24 25] [26 27 28 29 30]] + // 2: [[31 32 33 34 35] [36 37 38 39 40] [41 42 43 44 45]] + // 3: [[46 47 48 49 50] [51 52 53 54 55] [56 57 58 59 60]] + // 4: [[61 62 63 64 65] [66 67 68 69 70] [71 72 73 74 75]] + // 5: [[76 77 78 79 80] [81 82 83 84 85] [86 87 88 89 90]] + // 6: [[91 92 93 94 95] [96 97 98 99 100] [101 102 103 104 105]] + // 7: [[106 107 108 109 110] [111 112 113 114 115] [116 117 118 119 120]] + // 8: [[121 122 123 124 125] [126 127 128 129 130] [131 132 133 134 135]] + // 9: [[136 137 138 139 140] [141 142 143 144 145] [146 147 148 149 150]] + // 10: [[151 152 153 154 155] [156 157 158 159 160] [161 162 163 164 165]] + // 11: [[166 167 168 169 170] [171 172 173 174 175] [176 177 178 179 180]] + // 12: [[181 182 183 184 185] [186 187 188 189 190] [191 192 193 194 195]] + // 13: [[196 197 198 199 200] [201 202 203 204 205] [206 207 208 209 210]] + // 14: [[211 212 213 214 215] [216 217 218 219 220] [221 222 223 224 225]] + // 15: [[226 227 228 229 230] [231 232 233 234 235] [236 237 238 239 240]] + // 16: [[241 242 243 244 245] [246 247 248 249 250]] + // Is Truncated? true + +} + +func ExampleIterSelect() { + T := New(WithShape(20, 5), WithBacking(Range(Float64, 1, 101))) + it := NewIterSelect(T, 0) + data := T.Float64s() + var rowNo int + for start, end, hasRem := it.Start(); hasRem; start, end, hasRem = it.Next() { + sl := data[start:end] + fmt.Printf("%d: %v\n", rowNo, sl) + rowNo++ + } + + // Output: + // 0: [1 2 3 4 5] + // 1: [6 7 8 9 10] + // 2: [11 12 13 14 15] + // 3: [16 17 18 19 20] + // 4: [21 22 23 24 25] + // 5: [26 27 28 29 30] + // 6: [31 32 33 34 35] + // 7: [36 37 38 39 40] + // 8: [41 42 43 44 45] + // 9: [46 47 48 49 50] + // 10: [51 52 53 54 55] + // 11: [56 57 58 59 60] + // 12: [61 62 63 64 65] + // 13: [66 67 68 69 70] + // 14: [71 72 73 74 75] + // 15: [76 77 78 79 80] + // 16: [81 82 83 84 85] + // 17: [86 87 88 89 90] + // 18: [91 92 93 94 95] + // 19: [96 97 98 99 100] + +} diff --git a/example_dense_arith_test.go b/example_dense_arith_test.go index a78fd21..4a8ec10 100644 --- a/example_dense_arith_test.go +++ b/example_dense_arith_test.go @@ -13,7 +13,7 @@ func ExampleDense_Add_basic() { T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 10, 14)), WithShape(2, 2)) T3, _ = V.Add(T2) fmt.Printf("Default operation is safe (sliced operations)\n=============================================\nT3 = T1[0:2, 0:2] + T2\nT3:\n%v\nT1 is unchanged:\n%v\n", T3, T1) @@ -57,7 +57,7 @@ func ExampleDense_Add_unsafe() { T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 10, 14)), WithShape(2, 2)) V.Add(T2, UseUnsafe()) // unsafe overwrites the data in T1 @@ -100,7 +100,7 @@ func ExampleDense_Add_reuse() { // You can also use it on operations on sliced tensors - note your reuse tensor has to be the same shape as the result T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 10, 14)), WithShape(2, 2)) Reuse = New(WithBacking(Range(Float64, 100, 104)), WithShape(2, 2)) // same shape as result T3, _ = V.Add(T2, WithReuse(Reuse)) @@ -171,7 +171,7 @@ func ExampleDense_Add_incr() { // Operations on sliced tensor is also allowed. Note that your Incr tensor has to be the same shape as the result T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 10, 14)), WithShape(2, 2)) Incr = New(WithBacking([]float64{100, 100, 100, 100}), WithShape(2, 2)) T3, _ = V.Add(T2, WithIncr(Incr)) @@ -209,7 +209,7 @@ func ExampleDense_Sub_basic() { T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 10, 14)), WithShape(2, 2)) T3, _ = V.Sub(T2) fmt.Printf("Default operation is safe (sliced operations)\n=============================================\nT3 = T1[0:2, 0:2] + T2\nT3:\n%v\nT1 is unchanged:\n%v\n", T3, T1) @@ -253,7 +253,7 @@ func ExampleDense_Sub_unsafe() { T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 10, 14)), WithShape(2, 2)) V.Sub(T2, UseUnsafe()) // unsafe overwrites the data in T1 @@ -296,7 +296,7 @@ func ExampleDense_Sub_reuse() { // You can also use it on operations on sliced tensors - note your reuse tensor has to be the same shape as the result T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 10, 14)), WithShape(2, 2)) Reuse = New(WithBacking(Range(Float64, 100, 104)), WithShape(2, 2)) // same shape as result T3, _ = V.Sub(T2, WithReuse(Reuse)) @@ -365,7 +365,7 @@ func ExampleDense_Sub_incr() { // Operations on sliced tensor is also allowed. Note that your Incr tensor has to be the same shape as the result T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 10, 14)), WithShape(2, 2)) Incr = New(WithBacking([]float64{100, 100, 100, 100}), WithShape(2, 2)) T3, _ = V.Sub(T2, WithIncr(Incr)) @@ -403,7 +403,7 @@ func ExampleDense_Mul_basic() { T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 10, 14)), WithShape(2, 2)) T3, _ = V.Mul(T2) fmt.Printf("Default operation is safe (sliced operations)\n=============================================\nT3 = T1[0:2, 0:2] × T2\nT3:\n%v\nT1 is unchanged:\n%v\n", T3, T1) @@ -447,7 +447,7 @@ func ExampleDense_Mul_unsafe() { T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 10, 14)), WithShape(2, 2)) V.Mul(T2, UseUnsafe()) // unsafe overwrites the data in T1 @@ -490,7 +490,7 @@ func ExampleDense_Mul_reuse() { // You can also use it on operations on sliced tensors - note your reuse tensor has to be the same shape as the result T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 10, 14)), WithShape(2, 2)) Reuse = New(WithBacking(Range(Float64, 100, 104)), WithShape(2, 2)) // same shape as result T3, _ = V.Mul(T2, WithReuse(Reuse)) @@ -560,7 +560,7 @@ func ExampleDense_Mul_incr() { // Operations on sliced tensor is also allowed. Note that your Incr tensor has to be the same shape as the result T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 10, 14)), WithShape(2, 2)) Incr = New(WithBacking([]float64{100, 100, 100, 100}), WithShape(2, 2)) T3, _ = V.Mul(T2, WithIncr(Incr)) @@ -598,7 +598,7 @@ func ExampleDense_Div_basic() { T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 10, 14)), WithShape(2, 2)) T3, _ = V.Div(T2) fmt.Printf("Default operation is safe (sliced operations)\n=============================================\nT3 = T1[0:2, 0:2] ÷ T2\nT3:\n%1.1v\nT1 is unchanged:\n%1.1v\n", T3, T1) @@ -642,7 +642,7 @@ func ExampleDense_Div_unsafe() { T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 10, 14)), WithShape(2, 2)) V.Div(T2, UseUnsafe()) // unsafe overwrites the data in T1 @@ -685,7 +685,7 @@ func ExampleDense_Div_reuse() { // You can also use it on operations on sliced tensors - note your reuse tensor has to be the same shape as the result T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 10, 14)), WithShape(2, 2)) Reuse = New(WithBacking(Range(Float64, 100, 104)), WithShape(2, 2)) // same shape as result T3, _ = V.Div(T2, WithReuse(Reuse)) @@ -754,7 +754,7 @@ func ExampleDense_Div_incr() { // Operations on sliced tensor is also allowed. Note that your Incr tensor has to be the same shape as the result T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 10, 14)), WithShape(2, 2)) Incr = New(WithBacking([]float64{100, 100, 100, 100}), WithShape(2, 2)) T3, _ = V.Div(T2, WithIncr(Incr)) @@ -792,7 +792,7 @@ func ExampleDense_Pow_basic() { T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 10, 14)), WithShape(2, 2)) T3, _ = V.Pow(T2) fmt.Printf("Default operation is safe (sliced operations)\n=============================================\nT3 = T1[0:2, 0:2] ^ T2\nT3:\n%1.1v\nT1 is unchanged:\n%v\n", T3, T1) @@ -836,7 +836,7 @@ func ExampleDense_Pow_unsafe() { T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 10, 14)), WithShape(2, 2)) V.Pow(T2, UseUnsafe()) // unsafe overwrites the data in T1 @@ -880,7 +880,7 @@ func ExampleDense_Pow_reuse() { // You can also use it on operations on sliced tensors - note your reuse tensor has to be the same shape as the result T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 10, 14)), WithShape(2, 2)) Reuse = New(WithBacking(Range(Float64, 100, 104)), WithShape(2, 2)) // same shape as result T3, _ = V.Pow(T2, WithReuse(Reuse)) @@ -917,7 +917,7 @@ func ExampleDense_Pow_incr() { // Operations on sliced tensor is also allowed. Note that your Incr tensor has to be the same shape as the result T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 10, 14)), WithShape(2, 2)) Incr = New(WithBacking([]float64{100, 100, 100, 100}), WithShape(2, 2)) T3, _ = V.Pow(T2, WithIncr(Incr)) @@ -955,7 +955,7 @@ func ExampleDense_Mod_basic() { T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 10, 14)), WithShape(2, 2)) T3, _ = V.Mod(T2) fmt.Printf("Default operation is safe (sliced operations)\n=============================================\nT3 = T1[0:2, 0:2] %% T2\nT3:\n%v\nT1 is unchanged:\n%v\n", T3, T1) @@ -999,7 +999,7 @@ func ExampleDense_Mod_unsafe() { T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 10, 14)), WithShape(2, 2)) V.Mod(T2, UseUnsafe()) // unsafe overwrites the data in T1 @@ -1043,7 +1043,7 @@ func ExampleDense_Mod_reuse() { // You can also use it on operations on sliced tensors - note your reuse tensor has to be the same shape as the result T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 10, 14)), WithShape(2, 2)) Reuse = New(WithBacking(Range(Float64, 100, 104)), WithShape(2, 2)) // same shape as result T3, _ = V.Mod(T2, WithReuse(Reuse)) @@ -1080,7 +1080,7 @@ func ExampleDense_Mod_incr() { // Operations on sliced tensor is also allowed. Note that your Incr tensor has to be the same shape as the result T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 10, 14)), WithShape(2, 2)) Incr = New(WithBacking([]float64{100, 100, 100, 100}), WithShape(2, 2)) T3, _ = V.Mod(T2, WithIncr(Incr)) @@ -1122,13 +1122,13 @@ func ExampleDense_AddScalar_basic() { T1 = New(WithBacking(Range(Float32, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(nil, makeRS(1, 3)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T3, _ = V.AddScalar(float32(5), true) fmt.Printf("Default operation is safe (sliced operations - tensor is left operand)\n=============================================\nT3 = T1[:, 1:3] + 5\nT3:\n%v\nT1 is unchanged:\n%v\n", T3, T1) T1 = New(WithBacking(Range(Float32, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(nil, makeRS(1, 3)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T3, _ = V.AddScalar(float32(5), false) fmt.Printf("Default operation is safe (sliced operations - tensor is right operand)\n=============================================\nT3 = 5 + T1[:, 1:3]\nT3:\n%v\nT1 is unchanged:\n%v\n", T3, T1) @@ -1198,15 +1198,15 @@ func ExampleDense_AddScalar_unsafe() { T1 = New(WithBacking(Range(Float32, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(nil, makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T3, _ = V.AddScalar(float32(5), true, UseUnsafe()) - fmt.Printf("Operation is unsafe (sliced operations - tensor is left operand)\n=============================================\nT3 = T1[:, 0:2] + 5\nT3:\n%v\nsliced == T3: %t\nT1 is changed:\n%v\n", T3, sliced == T3, T1) + fmt.Printf("Operation is unsafe (sliced operations - tensor is left operand)\n=============================================\nT3 = T1[:, 0:2] + 5\nT3:\n%v\nV == T3: %t\nT1 is changed:\n%v\n", T3, V == T3, T1) T1 = New(WithBacking(Range(Float32, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(nil, makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T3, _ = V.AddScalar(float32(5), false, UseUnsafe()) - fmt.Printf("Operation is unsafe (sliced operations - tensor is right operand)\n=============================================\nT3 = 5 + T1[:, 0:2]\nT3:\n%v\nsliced == T3: %t\nT1 is changed:\n%v\n", T3, sliced == T3, T1) + fmt.Printf("Operation is unsafe (sliced operations - tensor is right operand)\n=============================================\nT3 = 5 + T1[:, 0:2]\nT3:\n%v\nV == T3: %t\nT1 is changed:\n%v\n", T3, V == T3, T1) // Output: // Operation is unsafe (tensor is left operand) @@ -1245,7 +1245,7 @@ func ExampleDense_AddScalar_unsafe() { // ⎢ 8 9⎥ // ⎣11 12⎦ // - // sliced == T3: true + // V == T3: true // T1 is changed: // ⎡ 5 6 2⎤ // ⎢ 8 9 5⎥ @@ -1259,7 +1259,7 @@ func ExampleDense_AddScalar_unsafe() { // ⎢ 8 9⎥ // ⎣11 12⎦ // - // sliced == T3: true + // V == T3: true // T1 is changed: // ⎡ 5 6 2⎤ // ⎢ 8 9 5⎥ @@ -1286,7 +1286,7 @@ func ExampleDense_AddScalar_reuse() { // Tensor is left operand T1 = New(WithBacking(Range(Float32, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) Reuse = New(WithBacking(Range(Float32, 100, 104)), WithShape(2, 2)) // same shape as result T3, _ = V.AddScalar(float32(5), true, WithReuse(Reuse)) fmt.Printf("Reuse tensor passed in (sliced tensor - Tensor is left operand)\n======================================\nT3 == Reuse: %t\nT3:\n%v\n", T3 == Reuse, T3) @@ -1294,7 +1294,7 @@ func ExampleDense_AddScalar_reuse() { // Tensor is left operand T1 = New(WithBacking(Range(Float32, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) Reuse = New(WithBacking(Range(Float32, 100, 104)), WithShape(2, 2)) // same shape as result T3, _ = V.AddScalar(float32(5), false, WithReuse(Reuse)) fmt.Printf("Reuse tensor passed in (sliced tensor - Tensor is left operand)\n======================================\nT3 == Reuse: %t\nT3:\n%v", T3 == Reuse, T3) @@ -1344,7 +1344,7 @@ func ExampleDense_AddScalar_incr() { // Operations on sliced tensor is also allowed. Note that your Incr tensor has to be the same shape as the result T1 = New(WithBacking(Range(Float32, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) Incr = New(WithBacking([]float32{100, 100, 100, 100}), WithShape(2, 2)) T3, _ = V.AddScalar(float32(5), true, WithIncr(Incr)) fmt.Printf("Incr tensor passed in (sliced tensor)\n======================================\nIncr += T1 + T2\nIncr == T3: %t\nT3:\n%v\n", Incr == T3, T3) @@ -1381,13 +1381,13 @@ func ExampleDense_SubScalar_basic() { T1 = New(WithBacking(Range(Float32, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(nil, makeRS(1, 3)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T3, _ = V.SubScalar(float32(5), true) fmt.Printf("Default operation is safe (sliced operations - tensor is left operand)\n=============================================\nT3 = T1[:, 1:3] + 5\nT3:\n%v\nT1 is unchanged:\n%v\n", T3, T1) T1 = New(WithBacking(Range(Float32, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(nil, makeRS(1, 3)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T3, _ = V.SubScalar(float32(5), false) fmt.Printf("Default operation is safe (sliced operations - tensor is right operand)\n=============================================\nT3 = 5 - T1[:, 1:3]\nT3:\n%v\nT1 is unchanged:\n%v\n", T3, T1) @@ -1458,15 +1458,15 @@ func ExampleDense_SubScalar_unsafe() { T1 = New(WithBacking(Range(Float32, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(nil, makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T3, _ = V.SubScalar(float32(5), true, UseUnsafe()) - fmt.Printf("Operation is unsafe (sliced operations - tensor is left operand)\n=============================================\nT3 = T1[:, 0:2] + 5\nT3:\n%v\nsliced == T3: %t\nT1 is changed:\n%v\n", T3, sliced == T3, T1) + fmt.Printf("Operation is unsafe (sliced operations - tensor is left operand)\n=============================================\nT3 = T1[:, 0:2] + 5\nT3:\n%v\nV == T3: %t\nT1 is changed:\n%v\n", T3, V == T3, T1) T1 = New(WithBacking(Range(Float32, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(nil, makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T3, _ = V.SubScalar(float32(5), false, UseUnsafe()) - fmt.Printf("Operation is unsafe (sliced operations - tensor is right operand)\n=============================================\nT3 = 5 - T1[:, 0:2]\nT3:\n%v\nsliced == T3: %t\nT1 is changed:\n%v\n", T3, sliced == T3, T1) + fmt.Printf("Operation is unsafe (sliced operations - tensor is right operand)\n=============================================\nT3 = 5 - T1[:, 0:2]\nT3:\n%v\nV == T3: %t\nT1 is changed:\n%v\n", T3, V == T3, T1) // Output: // Operation is unsafe (tensor is left operand) @@ -1505,7 +1505,7 @@ func ExampleDense_SubScalar_unsafe() { // ⎢-2 -1⎥ // ⎣ 1 2⎦ // - // sliced == T3: true + // V == T3: true // T1 is changed: // ⎡-5 -4 2⎤ // ⎢-2 -1 5⎥ @@ -1519,7 +1519,7 @@ func ExampleDense_SubScalar_unsafe() { // ⎢ 2 1⎥ // ⎣-1 -2⎦ // - // sliced == T3: true + // V == T3: true // T1 is changed: // ⎡ 5 4 2⎤ // ⎢ 2 1 5⎥ @@ -1546,7 +1546,7 @@ func ExampleDense_SubScalar_reuse() { // Tensor is left operand T1 = New(WithBacking(Range(Float32, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) Reuse = New(WithBacking(Range(Float32, 100, 104)), WithShape(2, 2)) // same shape as result T3, _ = V.SubScalar(float32(5), true, WithReuse(Reuse)) fmt.Printf("Reuse tensor passed in (sliced tensor - Tensor is left operand)\n======================================\nT3 == Reuse: %t\nT3:\n%v\n", T3 == Reuse, T3) @@ -1554,7 +1554,7 @@ func ExampleDense_SubScalar_reuse() { // Tensor is left operand T1 = New(WithBacking(Range(Float32, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) Reuse = New(WithBacking(Range(Float32, 100, 104)), WithShape(2, 2)) // same shape as result T3, _ = V.SubScalar(float32(5), false, WithReuse(Reuse)) fmt.Printf("Reuse tensor passed in (sliced tensor - Tensor is left operand)\n======================================\nT3 == Reuse: %t\nT3:\n%v", T3 == Reuse, T3) @@ -1604,7 +1604,7 @@ func ExampleDense_SubScalar_incr() { // Operations on sliced tensor is also allowed. Note that your Incr tensor has to be the same shape as the result T1 = New(WithBacking(Range(Float32, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) Incr = New(WithBacking([]float32{100, 100, 100, 100}), WithShape(2, 2)) T3, _ = V.SubScalar(float32(5), true, WithIncr(Incr)) fmt.Printf("Incr tensor passed in (sliced tensor)\n======================================\nIncr += T1 - T2\nIncr == T3: %t\nT3:\n%v\n", Incr == T3, T3) @@ -1643,13 +1643,13 @@ func ExampleDense_MulScalar_basic() { T1 = New(WithBacking(Range(Float32, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(nil, makeRS(1, 3)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T3, _ = V.MulScalar(float32(5), true) fmt.Printf("Default operation is safe (sliced operations - tensor is left operand)\n=============================================\nT3 = T1[:, 1:3] + 5\nT3:\n%v\nT1 is unchanged:\n%v\n", T3, T1) T1 = New(WithBacking(Range(Float32, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(nil, makeRS(1, 3)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T3, _ = V.MulScalar(float32(5), false) fmt.Printf("Default operation is safe (sliced operations - tensor is right operand)\n=============================================\nT3 = 5 * T1[:, 1:3]\nT3:\n%v\nT1 is unchanged:\n%v\n", T3, T1) @@ -1719,15 +1719,15 @@ func ExampleDense_MulScalar_unsafe() { T1 = New(WithBacking(Range(Float32, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(nil, makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T3, _ = V.MulScalar(float32(5), true, UseUnsafe()) - fmt.Printf("Operation is unsafe (sliced operations - tensor is left operand)\n=============================================\nT3 = T1[:, 0:2] + 5\nT3:\n%v\nsliced == T3: %t\nT1 is changed:\n%v\n", T3, sliced == T3, T1) + fmt.Printf("Operation is unsafe (sliced operations - tensor is left operand)\n=============================================\nT3 = T1[:, 0:2] + 5\nT3:\n%v\nV == T3: %t\nT1 is changed:\n%v\n", T3, V == T3, T1) T1 = New(WithBacking(Range(Float32, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(nil, makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T3, _ = V.MulScalar(float32(5), false, UseUnsafe()) - fmt.Printf("Operation is unsafe (sliced operations - tensor is right operand)\n=============================================\nT3 = 5 * T1[:, 0:2]\nT3:\n%v\nsliced == T3: %t\nT1 is changed:\n%v\n", T3, sliced == T3, T1) + fmt.Printf("Operation is unsafe (sliced operations - tensor is right operand)\n=============================================\nT3 = 5 * T1[:, 0:2]\nT3:\n%v\nV == T3: %t\nT1 is changed:\n%v\n", T3, V == T3, T1) // Output: // Operation is unsafe (tensor is left operand) @@ -1766,7 +1766,7 @@ func ExampleDense_MulScalar_unsafe() { // ⎢15 20⎥ // ⎣30 35⎦ // - // sliced == T3: true + // V == T3: true // T1 is changed: // ⎡ 0 5 2⎤ // ⎢15 20 5⎥ @@ -1780,7 +1780,7 @@ func ExampleDense_MulScalar_unsafe() { // ⎢15 20⎥ // ⎣30 35⎦ // - // sliced == T3: true + // V == T3: true // T1 is changed: // ⎡ 0 5 2⎤ // ⎢15 20 5⎥ @@ -1807,7 +1807,7 @@ func ExampleDense_MulScalar_reuse() { // Tensor is left operand T1 = New(WithBacking(Range(Float32, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) Reuse = New(WithBacking(Range(Float32, 100, 104)), WithShape(2, 2)) // same shape as result T3, _ = V.MulScalar(float32(5), true, WithReuse(Reuse)) fmt.Printf("Reuse tensor passed in (sliced tensor - Tensor is left operand)\n======================================\nT3 == Reuse: %t\nT3:\n%v\n", T3 == Reuse, T3) @@ -1815,7 +1815,7 @@ func ExampleDense_MulScalar_reuse() { // Tensor is left operand T1 = New(WithBacking(Range(Float32, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) Reuse = New(WithBacking(Range(Float32, 100, 104)), WithShape(2, 2)) // same shape as result T3, _ = V.MulScalar(float32(5), false, WithReuse(Reuse)) fmt.Printf("Reuse tensor passed in (sliced tensor - Tensor is left operand)\n======================================\nT3 == Reuse: %t\nT3:\n%v", T3 == Reuse, T3) @@ -1865,7 +1865,7 @@ func ExampleDense_MulScalar_incr() { // Operations on sliced tensor is also allowed. Note that your Incr tensor has to be the same shape as the result T1 = New(WithBacking(Range(Float32, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) Incr = New(WithBacking([]float32{100, 100, 100, 100}), WithShape(2, 2)) T3, _ = V.MulScalar(float32(5), true, WithIncr(Incr)) fmt.Printf("Incr tensor passed in (sliced tensor)\n======================================\nIncr += T1 * T2\nIncr == T3: %t\nT3:\n%v\n", Incr == T3, T3) @@ -1902,13 +1902,13 @@ func ExampleDense_DivScalar_basic() { T1 = New(WithBacking(Range(Float32, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(nil, makeRS(1, 3)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T3, _ = V.DivScalar(float32(5), true) fmt.Printf("Default operation is safe (sliced operations - tensor is left operand)\n=============================================\nT3 = T1[:, 1:3] + 5\nT3:\n%1.1v\nT1 is unchanged:\n%1.1v\n", T3, T1) T1 = New(WithBacking(Range(Float32, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(nil, makeRS(1, 3)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T3, _ = V.DivScalar(float32(5), false) fmt.Printf("Default operation is safe (sliced operations - tensor is right operand)\n=============================================\nT3 = 5 / T1[:, 1:3]\nT3:\n%1.1v\nT1 is unchanged:\n%1.1v\n", T3, T1) @@ -1978,15 +1978,15 @@ func ExampleDense_DivScalar_unsafe() { T1 = New(WithBacking(Range(Float32, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(nil, makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T3, _ = V.DivScalar(float32(5), true, UseUnsafe()) - fmt.Printf("Operation is unsafe (sliced operations - tensor is left operand)\n=============================================\nT3 = T1[:, 0:2] + 5\nT3:\n%1.1v\nsliced == T3: %t\nT1 is changed:\n%1.1v\n", T3, sliced == T3, T1) + fmt.Printf("Operation is unsafe (sliced operations - tensor is left operand)\n=============================================\nT3 = T1[:, 0:2] + 5\nT3:\n%1.1v\nV == T3: %t\nT1 is changed:\n%1.1v\n", T3, V == T3, T1) T1 = New(WithBacking(Range(Float32, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(nil, makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T3, _ = V.DivScalar(float32(5), false, UseUnsafe()) - fmt.Printf("Operation is unsafe (sliced operations - tensor is right operand)\n=============================================\nT3 = 5 / T1[:, 0:2]\nT3:\n%1.1v\nsliced == T3: %t\nT1 is changed:\n%1.1v\n", T3, sliced == T3, T1) + fmt.Printf("Operation is unsafe (sliced operations - tensor is right operand)\n=============================================\nT3 = 5 / T1[:, 0:2]\nT3:\n%1.1v\nV == T3: %t\nT1 is changed:\n%1.1v\n", T3, V == T3, T1) // Output: // Operation is unsafe (tensor is left operand) @@ -2025,7 +2025,7 @@ func ExampleDense_DivScalar_unsafe() { // ⎢0.6 0.8⎥ // ⎣ 1 1⎦ // - // sliced == T3: true + // V == T3: true // T1 is changed: // ⎡ 0 0.2 2⎤ // ⎢0.6 0.8 5⎥ @@ -2039,7 +2039,7 @@ func ExampleDense_DivScalar_unsafe() { // ⎢ 2 1⎥ // ⎣ 0.8 0.7⎦ // - // sliced == T3: true + // V == T3: true // T1 is changed: // ⎡+Inf 5 2⎤ // ⎢ 2 1 5⎥ @@ -2066,7 +2066,7 @@ func ExampleDense_DivScalar_reuse() { // Tensor is left operand T1 = New(WithBacking(Range(Float32, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) Reuse = New(WithBacking(Range(Float32, 100, 104)), WithShape(2, 2)) // same shape as result T3, _ = V.DivScalar(float32(5), true, WithReuse(Reuse)) fmt.Printf("Reuse tensor passed in (sliced tensor - Tensor is left operand)\n======================================\nT3 == Reuse: %t\nT3:\n%1.1v\n", T3 == Reuse, T3) @@ -2074,7 +2074,7 @@ func ExampleDense_DivScalar_reuse() { // Tensor is left operand T1 = New(WithBacking(Range(Float32, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) Reuse = New(WithBacking(Range(Float32, 100, 104)), WithShape(2, 2)) // same shape as result T3, _ = V.DivScalar(float32(5), false, WithReuse(Reuse)) fmt.Printf("Reuse tensor passed in (sliced tensor - Tensor is left operand)\n======================================\nT3 == Reuse: %t\nT3:\n%1.1v", T3 == Reuse, T3) @@ -2124,7 +2124,7 @@ func ExampleDense_DivScalar_incr() { // Operations on sliced tensor is also allowed. Note that your Incr tensor has to be the same shape as the result T1 = New(WithBacking(Range(Float32, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) Incr = New(WithBacking([]float32{100, 100, 100, 100}), WithShape(2, 2)) T3, _ = V.DivScalar(float32(5), true, WithIncr(Incr)) fmt.Printf("Incr tensor passed in (sliced tensor)\n======================================\nIncr += T1 / T2\nIncr == T3: %t\nT3:\n%3.1v\n", Incr == T3, T3) @@ -2161,13 +2161,13 @@ func ExampleDense_PowScalar_basic() { T1 = New(WithBacking(Range(Float32, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T3, _ = V.PowScalar(float32(5), true) fmt.Printf("Default operation is safe (sliced operations - tensor is left operand)\n=============================================\nT3 = T1[0:2, 0:2] ^ 5\nT3:\n%v\nT1 is unchanged:\n%v\n", T3, T1) T1 = New(WithBacking(Range(Float32, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T3, _ = V.PowScalar(float32(5), false) fmt.Printf("Default operation is safe (sliced operations - tensor is right operand)\n=============================================\nT3 = 5 ^ T1[0:2, 0:2]\nT3:\n%v\nT1 is unchanged:\n%v\n", T3, T1) @@ -2236,13 +2236,13 @@ func ExampleDense_ModScalar_basic() { T1 = New(WithBacking(Range(Float32, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T3, _ = V.ModScalar(float32(5), true) fmt.Printf("Default operation is safe (sliced operations - tensor is left operand)\n=============================================\nT3 = T1[0:2, 0:2] %% 5\nT3:\n%v\nT1 is unchanged:\n%v\n", T3, T1) T1 = New(WithBacking(Range(Float32, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T3, _ = V.ModScalar(float32(5), false) fmt.Printf("Default operation is safe (sliced operations - tensor is right operand)\n=============================================\nT3 = 5 %% T1[0:2, 0:2]\nT3:\n%v\nT1 is unchanged:\n%v\n", T3, T1) diff --git a/example_dense_cmp_test.go b/example_dense_cmp_test.go index 6d72c4d..9166821 100644 --- a/example_dense_cmp_test.go +++ b/example_dense_cmp_test.go @@ -20,7 +20,7 @@ func ExampleDense_Gt_basic() { // Sliced tensors are safe too T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 1, 5)), WithShape(2, 2)) T3, _ = V.Gt(T2) fmt.Printf("Safe slicing\n============\nT3:\n%v\nT1 remains unchanged:\n%v\n", T3, T1) @@ -28,7 +28,7 @@ func ExampleDense_Gt_basic() { // Similarly for tensors that return the same type T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 1, 5)), WithShape(2, 2)) T3, _ = V.Gt(T2, AsSameType()) // AsSameType returns a tensor of the same type fmt.Printf("Safe slicing (Same type)\n========================\nT3:\n%v\nT1 remains unchanged:\n%v\n", T3, T1) @@ -83,7 +83,7 @@ func ExampleDense_Gt_unsafe() { T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 1, 5)), WithShape(2, 2)) V.Gt(T2, UseUnsafe()) fmt.Printf("Unsafe operation, with a sliced Tensor\n======================================\nT1:\n%v", T1) @@ -129,7 +129,7 @@ func ExampleDense_Gt_reuse() { // Slicing is similar: T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 0, 4)), WithShape(2, 2)) T3 = New(WithBacking([]bool{true, true, true, true}), WithShape(2, 2)) V.Gt(T2, WithReuse(T3)) @@ -138,7 +138,7 @@ func ExampleDense_Gt_reuse() { // Again, bear in mind same types T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 0, 4)), WithShape(2, 2)) T3 = New(WithBacking(Range(Float64, 100, 104)), WithShape(2, 2)) V.Gt(T2, WithReuse(T3), AsSameType()) @@ -192,7 +192,7 @@ func ExampleDense_Gte_basic() { // Sliced tensors are safe too T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 1, 5)), WithShape(2, 2)) T3, _ = V.Gte(T2) fmt.Printf("Safe slicing\n============\nT3:\n%v\nT1 remains unchanged:\n%v\n", T3, T1) @@ -200,7 +200,7 @@ func ExampleDense_Gte_basic() { // Similarly for tensors that return the same type T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 1, 5)), WithShape(2, 2)) T3, _ = V.Gte(T2, AsSameType()) // AsSameType returns a tensor of the same type fmt.Printf("Safe slicing (Same type)\n========================\nT3:\n%v\nT1 remains unchanged:\n%v\n", T3, T1) @@ -255,7 +255,7 @@ func ExampleDense_Gte_unsafe() { T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 1, 5)), WithShape(2, 2)) V.Gte(T2, UseUnsafe()) fmt.Printf("Unsafe operation, with a sliced Tensor\n======================================\nT1:\n%v", T1) @@ -301,7 +301,7 @@ func ExampleDense_Gte_reuse() { // Slicing is similar: T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 0, 4)), WithShape(2, 2)) T3 = New(WithBacking([]bool{true, true, true, true}), WithShape(2, 2)) V.Gte(T2, WithReuse(T3)) @@ -310,7 +310,7 @@ func ExampleDense_Gte_reuse() { // Again, bear in mind same types T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 0, 4)), WithShape(2, 2)) T3 = New(WithBacking(Range(Float64, 100, 104)), WithShape(2, 2)) V.Gte(T2, WithReuse(T3), AsSameType()) @@ -364,7 +364,7 @@ func ExampleDense_Lt_basic() { // Sliced tensors are safe too T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 1, 5)), WithShape(2, 2)) T3, _ = V.Lt(T2) fmt.Printf("Safe slicing\n============\nT3:\n%v\nT1 remains unchanged:\n%v\n", T3, T1) @@ -372,7 +372,7 @@ func ExampleDense_Lt_basic() { // Similarly for tensors that return the same type T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 1, 5)), WithShape(2, 2)) T3, _ = V.Lt(T2, AsSameType()) // AsSameType returns a tensor of the same type fmt.Printf("Safe slicing (Same type)\n========================\nT3:\n%v\nT1 remains unchanged:\n%v\n", T3, T1) @@ -427,7 +427,7 @@ func ExampleDense_Lt_unsafe() { T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 1, 5)), WithShape(2, 2)) V.Lt(T2, UseUnsafe()) fmt.Printf("Unsafe operation, with a sliced Tensor\n======================================\nT1:\n%v", T1) @@ -473,7 +473,7 @@ func ExampleDense_Lt_reuse() { // Slicing is similar: T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 0, 4)), WithShape(2, 2)) T3 = New(WithBacking([]bool{true, true, true, true}), WithShape(2, 2)) V.Lt(T2, WithReuse(T3)) @@ -482,7 +482,7 @@ func ExampleDense_Lt_reuse() { // Again, bear in mind same types T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 0, 4)), WithShape(2, 2)) T3 = New(WithBacking(Range(Float64, 100, 104)), WithShape(2, 2)) V.Lt(T2, WithReuse(T3), AsSameType()) @@ -535,7 +535,7 @@ func ExampleDense_Lte_basic() { // Sliced tensors are safe too T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 1, 5)), WithShape(2, 2)) T3, _ = V.Lte(T2) fmt.Printf("Safe slicing\n============\nT3:\n%v\nT1 remains unchanged:\n%v\n", T3, T1) @@ -543,7 +543,7 @@ func ExampleDense_Lte_basic() { // Similarly for tensors that return the same type T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 1, 5)), WithShape(2, 2)) T3, _ = V.Lte(T2, AsSameType()) // AsSameType returns a tensor of the same type fmt.Printf("Safe slicing (Same type)\n========================\nT3:\n%v\nT1 remains unchanged:\n%v\n", T3, T1) @@ -598,7 +598,7 @@ func ExampleDense_Lte_unsafe() { T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 1, 5)), WithShape(2, 2)) V.Lte(T2, UseUnsafe()) fmt.Printf("Unsafe operation, with a sliced Tensor\n======================================\nT1:\n%v", T1) @@ -644,7 +644,7 @@ func ExampleDense_Lte_reuse() { // Slicing is similar: T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 0, 4)), WithShape(2, 2)) T3 = New(WithBacking([]bool{true, true, true, true}), WithShape(2, 2)) V.Lte(T2, WithReuse(T3)) @@ -653,7 +653,7 @@ func ExampleDense_Lte_reuse() { // Again, bear in mind same types T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 0, 4)), WithShape(2, 2)) T3 = New(WithBacking(Range(Float64, 100, 104)), WithShape(2, 2)) V.Lte(T2, WithReuse(T3), AsSameType()) @@ -707,7 +707,7 @@ func ExampleDense_ElEq_basic() { // Sliced tensors are safe too T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 1, 5)), WithShape(2, 2)) T3, _ = V.ElEq(T2) fmt.Printf("Safe slicing\n============\nT3:\n%v\nT1 remains unchanged:\n%v\n", T3, T1) @@ -715,7 +715,7 @@ func ExampleDense_ElEq_basic() { // Similarly for tensors that return the same type T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 1, 5)), WithShape(2, 2)) T3, _ = V.ElEq(T2, AsSameType()) // AsSameType returns a tensor of the same type fmt.Printf("Safe slicing (Same type)\n========================\nT3:\n%v\nT1 remains unchanged:\n%v\n", T3, T1) @@ -770,7 +770,7 @@ func ExampleDense_ElEq_unsafe() { T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 1, 5)), WithShape(2, 2)) V.ElEq(T2, UseUnsafe()) fmt.Printf("Unsafe operation, with a sliced Tensor\n======================================\nT1:\n%v", T1) @@ -817,7 +817,7 @@ func ExampleDense_ElEq_reuse() { // Slicing is similar: T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 0, 4)), WithShape(2, 2)) T3 = New(WithBacking([]bool{true, true, true, true}), WithShape(2, 2)) V.ElEq(T2, WithReuse(T3)) @@ -826,7 +826,7 @@ func ExampleDense_ElEq_reuse() { // Again, bear in mind same types T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 0, 4)), WithShape(2, 2)) T3 = New(WithBacking(Range(Float64, 100, 104)), WithShape(2, 2)) V.ElEq(T2, WithReuse(T3), AsSameType()) @@ -880,7 +880,7 @@ func ExampleDense_ElNe_basic() { // Sliced tensors are safe too T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 1, 5)), WithShape(2, 2)) T3, _ = V.ElNe(T2) fmt.Printf("Safe slicing\n============\nT3:\n%v\nT1 remains unchanged:\n%v\n", T3, T1) @@ -888,7 +888,7 @@ func ExampleDense_ElNe_basic() { // Similarly for tensors that return the same type T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 1, 5)), WithShape(2, 2)) T3, _ = V.ElNe(T2, AsSameType()) // AsSameType returns a tensor of the same type fmt.Printf("Safe slicing (Same type)\n========================\nT3:\n%v\nT1 remains unchanged:\n%v\n", T3, T1) @@ -943,7 +943,7 @@ func ExampleDense_ElNe_unsafe() { T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 1, 5)), WithShape(2, 2)) V.ElNe(T2, UseUnsafe()) fmt.Printf("Unsafe operation, with a sliced Tensor\n======================================\nT1:\n%v", T1) @@ -990,7 +990,7 @@ func ExampleDense_ElNe_reuse() { // Slicing is similar: T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 0, 4)), WithShape(2, 2)) T3 = New(WithBacking([]bool{true, true, true, true}), WithShape(2, 2)) V.ElNe(T2, WithReuse(T3)) @@ -999,7 +999,7 @@ func ExampleDense_ElNe_reuse() { // Again, bear in mind same types T1 = New(WithBacking(Range(Float64, 0, 9)), WithShape(3, 3)) sliced, _ = T1.Slice(makeRS(0, 2), makeRS(0, 2)) - V = sliced.(*Dense) + V = MustGetDense(sliced) T2 = New(WithBacking(Range(Float64, 0, 4)), WithShape(2, 2)) T3 = New(WithBacking(Range(Float64, 100, 104)), WithShape(2, 2)) V.ElNe(T2, WithReuse(T3), AsSameType()) diff --git a/example_dense_linalg_test.go b/example_dense_linalg_test.go index d558481..13c9dcf 100644 --- a/example_dense_linalg_test.go +++ b/example_dense_linalg_test.go @@ -76,7 +76,7 @@ func ExampleDense_MatVecMul_rowMajorSliced() { // here we print the underlying slice of T3 just to show that it's actually a much larger slice fmt.Printf("Underlying Slice: %v\n", T3.Data()) - T4, err := T2.(*Dense).MatVecMul(T3) + T4, err := MustGetDense(T2).MatVecMul(T3) handleErr(err) fmt.Printf("T4:\n%v\n", T4) @@ -120,7 +120,7 @@ func ExampleDense_MatMul_sliced() { handleErr(err) fmt.Printf("T4:\n%v", T4) - T5, err := T3.(*Dense).MatMul(T4) + T5, err := MustGetDense(T3).MatMul(T4) handleErr(err) fmt.Printf("T3xT4:\n%v", T5) diff --git a/example_dense_matop_test.go b/example_dense_matop_test.go index 497e9d1..c6f50a5 100644 --- a/example_dense_matop_test.go +++ b/example_dense_matop_test.go @@ -31,6 +31,71 @@ func ExampleDense_Slice() { // [1 4] } +func ExampleDense_SliceInto() { + var v Tensor + var err error + T := New(WithBacking(Range(Int, 0, 9)), WithShape(3, 3)) + fmt.Println("SliceInto works with nil values. It simply creates a View.\n==========================================================") + fmt.Printf("T:\n%v\n", T) + + if v, err = T.SliceInto(v, makeRS(0, 2), makeRS(0, 2)); err != nil { + fmt.Printf("ERR %v\n", err) + return + } + fmt.Printf("T[0:2, 0:2]:\n%v\n", v) + + v.Zero() + fmt.Printf("When v is zeroed, T is zeroed too.\n==================================\nv:\n%v\nT:\n%v\n", v, T) + + fmt.Println("Primary use case of SliceInto.\n==============================") + T = New(WithBacking(Range(Int, 0, 9)), WithShape(3, 3)) + fmt.Printf("T:\n%v\nv:\n%v\n", T, v) + if v, err = T.SliceInto(v, makeRS(0, 2), makeRS(0, 2)); err != nil { + fmt.Printf("ERR %v\n", err) + return + } + fmt.Printf("v = T[0:2, 0:2]:\n%v\n", v) + + // Output: + // SliceInto works with nil values. It simply creates a View. + // ========================================================== + // T: + // ⎡0 1 2⎤ + // ⎢3 4 5⎥ + // ⎣6 7 8⎦ + // + // T[0:2, 0:2]: + // ⎡0 1⎤ + // ⎣3 4⎦ + // + // When v is zeroed, T is zeroed too. + // ================================== + // v: + // ⎡0 0⎤ + // ⎣0 0⎦ + // + // T: + // ⎡0 0 0⎤ + // ⎢0 0 5⎥ + // ⎣6 7 8⎦ + // + // Primary use case of SliceInto. + // ============================== + // T: + // ⎡0 1 2⎤ + // ⎢3 4 5⎥ + // ⎣6 7 8⎦ + // + // v: + // ⎡0 0⎤ + // ⎣0 0⎦ + // + // v = T[0:2, 0:2]: + // ⎡0 1⎤ + // ⎣3 4⎦ + +} + // Slicing works on one dimensional arrays too: func ExampleDense_Slice_oneDimension() { var T Tensor @@ -58,7 +123,7 @@ func ExampleDense_Slice_viewMutation() { fmt.Printf("V:\n%v\n", V) // Now we modify V's 0th value - V.(*Dense).Set(0, 1000) + MustGetDense(V).Set(0, 1000) fmt.Printf("V[0] = 1000:\n%v\n", V) fmt.Printf("T is also mutated:\n%v", T) @@ -93,7 +158,7 @@ func ExampleView() { fmt.Printf("V:\n%v\n", V) // Now we modify V's 0th value - V.(*Dense).Set(0, 1000) + MustGetDense(V).Set(0, 1000) fmt.Printf("V[0] = 1000:\n%v\n", V) fmt.Printf("T is also mutated:\n%v\n", T) diff --git a/example_dense_scatter_test.go b/example_dense_scatter_test.go new file mode 100644 index 0000000..d71d4ff --- /dev/null +++ b/example_dense_scatter_test.go @@ -0,0 +1,79 @@ +package tensor + +import "fmt" + +func ExampleScatter() { + T := New(WithShape(2, 3, 4), WithBacking([]float32{ + 0, 1, 2, 3, + 4, 5, 6, 7, + 8, 9, 10, 11, + + 0, 1, 2, 3, + 4, 5, 6, 7, + 8, 9, 10, 11, + })) + + indices := New(WithShape(2, 3, 4), WithBacking([]int{ + 3, 2, 1, 0, + 3, 2, 1, 0, + 4, 3, 2, 1, + + 0, 4, 1, 2, + 4, 4, 4, 4, + 3, 3, 3, 3, + })) + + s, err := Scatter(T, indices) + if err != nil { + fmt.Println(err) + return + } + + fmt.Printf("%v\n", s) + + // Output: + // ⎡ 3 2 1 0 0⎤ + // ⎢ 7 6 5 4 0⎥ + // ⎣ 0 11 10 9 8⎦ + // + // ⎡ 0 2 3 0 1⎤ + // ⎢ 0 0 0 0 7⎥ + // ⎣ 0 0 0 11 0⎦ + +} + +func ExampleScatter_matrixIndices() { + T := New(WithShape(2, 3, 4), WithBacking([]float32{ + 0, 1, 2, 3, + 4, 5, 6, 7, + 8, 9, 10, 11, + + 0, 1, 2, 3, + 4, 5, 6, 7, + 8, 9, 10, 11, + })) + + indices := New(WithShape(5, 4), WithBacking([]int{ + 3, 2, 1, 0, + 3, 2, 1, 0, + 4, 3, 2, 1, + 0, 4, 1, 2, + 4, 4, 4, 4, + })) + + s, err := Scatter(T, indices) + if err != nil { + fmt.Println(err) + return + } + + fmt.Printf("%v\n", s) + + // Output: + // ⎡ 3 2 1 0 0⎤ + // ⎢ 7 6 5 4 0⎥ + // ⎢ 0 11 10 9 8⎥ + // ⎢ 0 2 3 0 1⎥ + // ⎣ 0 0 0 0 7⎦ + +} diff --git a/example_extension_test.go b/example_extension_test.go index e5c2b22..23be0f7 100644 --- a/example_extension_test.go +++ b/example_extension_test.go @@ -6,6 +6,7 @@ import ( "reflect" "github.com/pkg/errors" + "gorgonia.org/dtype" "gorgonia.org/tensor" ) @@ -21,7 +22,7 @@ type MyType struct { func (T MyType) Format(s fmt.State, c rune) { fmt.Fprintf(s, "(%d, %d)", T.x, T.y) } // MyDtype this the dtype of MyType. This value is populated in the init() function below -var MyDtype tensor.Dtype +var MyDtype dtype.Dtype // MyEngine supports additions of MyType, as well as other Dtypes type MyEngine struct { @@ -73,7 +74,7 @@ func (e MyEngine) Add(a, b tensor.Tensor, opts ...tensor.FuncOpt) (retVal tensor } func init() { - MyDtype = tensor.Dtype{reflect.TypeOf(&MyType{})} + MyDtype = dtype.Dtype{reflect.TypeOf(&MyType{})} } func Example_extension() { diff --git a/example_iterator_test.go b/example_iterator_test.go index aff34e3..0fa3025 100644 --- a/example_iterator_test.go +++ b/example_iterator_test.go @@ -1,6 +1,9 @@ package tensor -import "fmt" +import ( + "fmt" + "sync" +) // This is an example of how to use `IteratorFromDense` from a row-major Dense tensor func Example_iteratorRowmajor() { @@ -58,8 +61,8 @@ func ExampleSliceIter() { fmt.Printf("Err %v\n", err) return } - fmt.Printf("S (requires iterator? %t)\n%v\n", S.(*Dense).RequiresIterator(), S) - it := IteratorFromDense(S.(*Dense)) + fmt.Printf("S (requires iterator? %t)\n%v\n", S.(DenseView).RequiresIterator(), S) + it := IteratorFromDense(S.(DenseView)) for i, err := it.Start(); err == nil; i, err = it.Next() { fmt.Printf("i %d, coord %v\n", i, it.Coord()) } @@ -75,3 +78,136 @@ func ExampleSliceIter() { // i 4, coord [0 0] } + +func ExampleAxialIterator() { + T := New(WithShape(2, 3, 4), WithBacking([]float64{ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + })) + fmt.Printf("T:\n%v", T) + it := AxialIteratorFromDense(T, 1, 0, false) + + for i, err := it.Start(); err == nil; i, err = it.Next() { + fmt.Printf("i %d coord %v\n", i, it.Coord()) + } + + // Output: + // T: + // ⎡ 0 1 2 3⎤ + // ⎢ 4 5 6 7⎥ + // ⎣ 8 9 10 11⎦ + // + // ⎡ 0 1 2 3⎤ + // ⎢ 4 5 6 7⎥ + // ⎣ 8 9 10 11⎦ + // + // i 0 coord [0 0 1] + // i 1 coord [0 0 2] + // i 2 coord [0 0 3] + // i 3 coord [1 0 0] + // i 12 coord [1 0 1] + // i 13 coord [1 0 2] + // i 14 coord [1 0 3] + // i 15 coord [0 1 0] + // i 4 coord [0 1 1] + // i 5 coord [0 1 2] + // i 6 coord [0 1 3] + // i 7 coord [1 1 0] + // i 16 coord [1 1 1] + // i 17 coord [1 1 2] + // i 18 coord [1 1 3] + // i 19 coord [0 2 0] + // i 8 coord [0 2 1] + // i 9 coord [0 2 2] + // i 10 coord [0 2 3] + // i 11 coord [1 2 0] + // i 20 coord [1 2 1] + // i 21 coord [1 2 2] + // i 22 coord [1 2 3] + // i 23 coord [0 0 0] +} + +func ExampleAxialIterator_2() { + T := New(WithShape(2, 3, 4), WithBacking([]float64{ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + })) + fmt.Printf("T:\n%v", T) + it := AxialIteratorFromDense(T, 1, 1, true) + + for i, err := it.Start(); err == nil; i, err = it.Next() { + fmt.Printf("i %d coord %v\n", i, it.Coord()) + } + + // Output: + // T: + // ⎡ 0 1 2 3⎤ + // ⎢ 4 5 6 7⎥ + // ⎣ 8 9 10 11⎦ + // + // ⎡ 0 1 2 3⎤ + // ⎢ 4 5 6 7⎥ + // ⎣ 8 9 10 11⎦ + // + // i 4 coord [0 1 1] + // i 5 coord [0 1 2] + // i 6 coord [0 1 3] + // i 7 coord [1 1 0] + // i 16 coord [1 1 1] + // i 17 coord [1 1 2] + // i 18 coord [1 1 3] + // i 19 coord [0 0 0] +} + +func ExampleAxialIterator_concurrent() { + T := New(WithShape(2, 3, 4), WithBacking([]float64{ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + })) + fmt.Printf("T:\n%v", T) + + axis := 1 + var its []Iterator + for i := 0; i < T.Shape()[axis]; i++ { + it := AxialIteratorFromDense(T, axis, i, true) + its = append(its, it) + } + + done := make(chan float64, T.Shape()[axis]) + var wg sync.WaitGroup + for _, it := range its { + wg.Add(1) + go func(it Iterator, t *Dense, done chan float64, wg *sync.WaitGroup) { + data := t.Data().([]float64) + var sum float64 + for i, err := it.Start(); err == nil; i, err = it.Next() { + sum += data[i] + } + done <- sum + wg.Done() + }(it, T, done, &wg) + } + + wg.Wait() + close(done) + + var total float64 + for v := range done { + total += v + } + + fmt.Printf("Total: %v", total) + + // Output: + // T: + // ⎡ 0 1 2 3⎤ + // ⎢ 4 5 6 7⎥ + // ⎣ 8 9 10 11⎦ + // + // ⎡ 0 1 2 3⎤ + // ⎢ 4 5 6 7⎥ + // ⎣ 8 9 10 11⎦ + // + // Total: 132 + +} diff --git a/example_mapreduce_test.go b/example_mapreduce_test.go index 47bd2ce..51c1dfe 100644 --- a/example_mapreduce_test.go +++ b/example_mapreduce_test.go @@ -89,7 +89,6 @@ func ExampleArgmax_sliced() { // // Argmax: 0 // Argmax is *tensor.Dense of int - } func ExampleArgmin() { @@ -109,3 +108,22 @@ func ExampleArgmin() { // Argmin: [0 1] // Argmin is *tensor.Dense of int } + +func ExampleMax() { + T := New(WithBacking([]int{1, 2, 5, 3, 4, 1}), WithShape(2, 3)) + fmt.Printf("T\n%v\n", T) + + // Max along all axes + m, _ := Max(T) + fmt.Printf("Max: %v\n", m) + fmt.Printf("Max is %T of %v", m, m.Dtype()) + + // Output: + // T + // ⎡1 2 5⎤ + // ⎣3 4 1⎦ + // + // Max: 5 + // Max is *tensor.Dense of int + +} diff --git a/example_matop_test.go b/example_matop_test.go new file mode 100644 index 0000000..4c0d4da --- /dev/null +++ b/example_matop_test.go @@ -0,0 +1,59 @@ +package tensor_test + +import ( + "fmt" + + "gorgonia.org/tensor" +) + +func ExampleTranspose() { + t := tensor.New(tensor.WithShape(2, 3), tensor.WithBacking([]int{1, 2, 3, 4, 5, 6})) + t2, err := tensor.Transpose(t) + if err != nil { + fmt.Printf("ERR: %v\n", err) + } + fmt.Printf("Transpose is a safe operation.\nT:\n%v\nT':\n%v\n", t, t2) + fmt.Printf("The data is changed:\nT : %v\nT': %v", t.Data(), t2.Data()) + + // Output: + // Transpose is a safe operation. + // T: + // ⎡1 2 3⎤ + // ⎣4 5 6⎦ + // + // T': + // ⎡1 4⎤ + // ⎢2 5⎥ + // ⎣3 6⎦ + // + // The data is changed: + // T : [1 2 3 4 5 6] + // T': [1 4 2 5 3 6] + +} + +func ExampleT() { + t := tensor.New(tensor.WithShape(2, 3), tensor.WithBacking([]int{1, 2, 3, 4, 5, 6})) + t2, err := tensor.T(t) + if err != nil { + fmt.Printf("ERR: %v\n", err) + } + fmt.Printf("T is a safe version of the .T() method\nT:\n%v\nT':\n%v\n", t, t2) + fmt.Printf("The data is unchanged:\nT : %v\nT': %v\n", t.Data(), t2.Data()) + + // Output: + // T is a safe version of the .T() method + // T: + // ⎡1 2 3⎤ + // ⎣4 5 6⎦ + // + // T': + // ⎡1 4⎤ + // ⎢2 5⎥ + // ⎣3 6⎦ + // + // The data is unchanged: + // T : [1 2 3 4 5 6] + // T': [1 2 3 4 5 6] + +} diff --git a/flags.go b/flags.go index 22fed67..5cc0bae 100644 --- a/flags.go +++ b/flags.go @@ -116,52 +116,3 @@ func MakeMemoryFlag(fs ...MemoryFlag) (retVal MemoryFlag) { func (f MemoryFlag) nativelyAccessible() bool { return !((f & NativelyInaccessible) != 0) } func (f MemoryFlag) manuallyManaged() bool { return (f & ManuallyManaged) != 0 } func (f MemoryFlag) isOverallocated() bool { return (f & IsOverallocated) != 0 } - -// OpOpt are the options used to call ops -type OpOpt struct { - reuse Tensor - incr Tensor - unsafe bool - same bool - t Dtype -} - -// ParseFuncOpts parses a list of FuncOpt into a single unified method call structure. -func ParseFuncOpts(opts ...FuncOpt) *OpOpt { - retVal := borrowOpOpt() - for _, opt := range opts { - opt(retVal) - } - return retVal -} - -// Incr returns the tensor to be incremented in the call. Can be nil. -func (fo *OpOpt) Incr() Tensor { return fo.incr } - -// Reuse returns the tensor to be reused in the call. Can be nil. -func (fo *OpOpt) Reuse() Tensor { return fo.reuse } - -// IncReuse returns whether a reuse tensor is to be used as the incr Tensor -func (fo *OpOpt) IncrReuse() (Tensor, bool) { - if fo.incr != nil { - return fo.incr, true - } - return fo.reuse, false -} - -// Safe signals if the op is to be done safely -func (fo *OpOpt) Safe() bool { return !fo.unsafe } - -// Same signals if the op is to return the same type as its inputs -func (fo *OpOpt) Same() bool { return fo.same } - -// As returns the dtype of the return value of the method call. -// For example: -// a.Lt(b, As(Bool)) -// indicates that the result of the `Lt()` should be a Tensor of Bool. -// -// Another example: -// a.Add(b, As(Int)) -// indicates that the result of `Add()` should be converted to a Tensor of Int. -// Note that this function is not yet supported in most operations. -func (fo *OpOpt) As() Dtype { return fo.t } diff --git a/flags_test.go b/flags_test.go index 83dd3be..26d10e8 100644 --- a/flags_test.go +++ b/flags_test.go @@ -1,90 +1,90 @@ -package tensor - -import "testing" - -func TestMemoryFlag(t *testing.T) { - var defaultFlag MemoryFlag - if defaultFlag.manuallyManaged() || !defaultFlag.nativelyAccessible() { - t.Errorf("Something went wrong with the creation of flags") - } - - a := ManuallyManaged - if !a.manuallyManaged() { - t.Errorf("Expected ManuallyManaged to be true") - } - if !a.nativelyAccessible() { - t.Errorf("Expected ManuallyManaged to be nativelyAccessible") - } - - b := NativelyInaccessible - if b.manuallyManaged() { - t.Errorf("Expected NativelyInaccessible to not be manually managed") - } - if b.nativelyAccessible() { - t.Errorf("Expected NativelyInaccessible to be false %v", b.nativelyAccessible()) - } - - c := MakeMemoryFlag(ManuallyManaged, NativelyInaccessible) - if !c.manuallyManaged() { - t.Errorf("Expected c to be manually managed") - } - if c.nativelyAccessible() { - t.Errorf("Expected c to be natively inaccessible") - } -} - -func TestDataOrder(t *testing.T) { - var defaultFlag DataOrder - if defaultFlag.IsColMajor() || defaultFlag.IsNotContiguous() || defaultFlag.IsTransposed() { - t.Error("Expected default flag to be row major and contiguous and not transposed") - } - if !(defaultFlag.IsRowMajor() && defaultFlag.IsContiguous()) { - t.Error("Expected default flag to be row major and contiguous") - } - if defaultFlag.String() != "Contiguous, RowMajor" { - t.Errorf("Expected string is \"Contiguous, RowMajor\". Got %q", defaultFlag.String()) - } - - ncrm := MakeDataOrder(NonContiguous) - if ncrm.IsColMajor() || ncrm.IsContiguous() { - t.Error("Expected noncontiguous row major.") - } - if ncrm.String() != "NonContiguous, RowMajor" { - t.Errorf("Expected string is \"NonContiguous, RowMajor\". Got %q", defaultFlag.String()) - } - - cm := ColMajor - if cm.IsRowMajor() { - t.Error("colMajor cannot be rowMajor") - } - if cm.IsNotContiguous() { - t.Error("ColMajor by default is contiguous") - } - if cm.String() != "Contiguous, ColMajor" { - t.Errorf(`Expected string is "Contiguous, ColMajor". Got %q`, cm.String()) - } - - // check toggle - rm := cm.toggleColMajor() - if rm.IsColMajor() { - t.Errorf("toggled cm should be rm") - } - - cm = rm.toggleColMajor() - if cm.IsRowMajor() { - t.Errorf("toggled rm should be cm") - } - - transposed := MakeDataOrder(Transposed) - if !transposed.IsTransposed() { - t.Error("Expected transposed flag to be set") - } - if transposed.String() != "Contiguous, RowMajorᵀ" { - t.Errorf("Expected string is \"Contiguous, RowMajorᵀ\". Got %q", defaultFlag.String()) - } - untransposed := transposed.clearTransposed() - if untransposed != defaultFlag { - t.Error("Expected default flag after untransposing") - } - -} +package tensor + +import "testing" + +func TestMemoryFlag(t *testing.T) { + var defaultFlag MemoryFlag + if defaultFlag.manuallyManaged() || !defaultFlag.nativelyAccessible() { + t.Errorf("Something went wrong with the creation of flags") + } + + a := ManuallyManaged + if !a.manuallyManaged() { + t.Errorf("Expected ManuallyManaged to be true") + } + if !a.nativelyAccessible() { + t.Errorf("Expected ManuallyManaged to be nativelyAccessible") + } + + b := NativelyInaccessible + if b.manuallyManaged() { + t.Errorf("Expected NativelyInaccessible to not be manually managed") + } + if b.nativelyAccessible() { + t.Errorf("Expected NativelyInaccessible to be false %v", b.nativelyAccessible()) + } + + c := MakeMemoryFlag(ManuallyManaged, NativelyInaccessible) + if !c.manuallyManaged() { + t.Errorf("Expected c to be manually managed") + } + if c.nativelyAccessible() { + t.Errorf("Expected c to be natively inaccessible") + } +} + +func TestDataOrder(t *testing.T) { + var defaultFlag DataOrder + if defaultFlag.IsColMajor() || defaultFlag.IsNotContiguous() || defaultFlag.IsTransposed() { + t.Error("Expected default flag to be row major and contiguous and not transposed") + } + if !(defaultFlag.IsRowMajor() && defaultFlag.IsContiguous()) { + t.Error("Expected default flag to be row major and contiguous") + } + if defaultFlag.String() != "Contiguous, RowMajor" { + t.Errorf("Expected string is \"Contiguous, RowMajor\". Got %q", defaultFlag.String()) + } + + ncrm := MakeDataOrder(NonContiguous) + if ncrm.IsColMajor() || ncrm.IsContiguous() { + t.Error("Expected noncontiguous row major.") + } + if ncrm.String() != "NonContiguous, RowMajor" { + t.Errorf("Expected string is \"NonContiguous, RowMajor\". Got %q", defaultFlag.String()) + } + + cm := ColMajor + if cm.IsRowMajor() { + t.Error("colMajor cannot be rowMajor") + } + if cm.IsNotContiguous() { + t.Error("ColMajor by default is contiguous") + } + if cm.String() != "Contiguous, ColMajor" { + t.Errorf(`Expected string is "Contiguous, ColMajor". Got %q`, cm.String()) + } + + // check toggle + rm := cm.toggleColMajor() + if rm.IsColMajor() { + t.Errorf("toggled cm should be rm") + } + + cm = rm.toggleColMajor() + if cm.IsRowMajor() { + t.Errorf("toggled rm should be cm") + } + + transposed := MakeDataOrder(Transposed) + if !transposed.IsTransposed() { + t.Error("Expected transposed flag to be set") + } + if transposed.String() != "Contiguous, RowMajorᵀ" { + t.Errorf("Expected string is \"Contiguous, RowMajorᵀ\". Got %q", defaultFlag.String()) + } + untransposed := transposed.clearTransposed() + if untransposed != defaultFlag { + t.Error("Expected default flag after untransposing") + } + +} diff --git a/funcopts.go b/funcopts.go new file mode 100644 index 0000000..65e3d4b --- /dev/null +++ b/funcopts.go @@ -0,0 +1,153 @@ +package tensor + +import ( + "context" + + "gorgonia.org/dtype" +) + +// FuncOpt are optionals for calling Tensor functions. +// The `*opOpt` type is unexported, but it's methods are exported. +// This is intentional as use of the `*opOpt` is very specialized. +// See funcopts.go for more information. +type FuncOpt func(*opOpt) + +// WithIncr passes in a Tensor to be incremented. +func WithIncr(incr Tensor) FuncOpt { + f := func(opt *opOpt) { + opt.incr = incr + } + return f +} + +// WithReuse passes in a Tensor to be reused. +func WithReuse(reuse Tensor) FuncOpt { + f := func(opt *opOpt) { + opt.reuse = reuse + } + return f +} + +// UseSafe ensures that the operation is a safe operation (copies data, does not clobber). This is the default option for most methods and functions +func UseSafe() FuncOpt { + f := func(opt *opOpt) { + opt.unsafe = false + } + return f +} + +// UseUnsafe ensures that the operation is an unsafe operation - data will be clobbered, and operations performed inplace +func UseUnsafe() FuncOpt { + f := func(opt *opOpt) { + opt.unsafe = true + } + return f +} + +// AsSameType makes sure that the return Tensor is the same type as input Tensors. +func AsSameType() FuncOpt { + f := func(opt *opOpt) { + opt.same = true + } + return f +} + +// As makes sure that the the return Tensor is of the type specified. Currently only works for FromMat64 +func As(t dtype.Dtype) FuncOpt { + f := func(opt *opOpt) { + opt.t = t + } + return f +} + +// WithContext allows a function to be called with a given context +func WithContext(ctx context.Context) FuncOpt { + f := func(opt *opOpt) { + opt.ctx = ctx + } + return f +} + +// opOpt are the options used to call ops +type opOpt struct { + reuse Tensor + incr Tensor + unsafe bool + same bool + t dtype.Dtype + ctx context.Context +} + +// ParseFuncOpts parses a list of FuncOpt into a single unified method call structure. +func ParseFuncOpts(opts ...FuncOpt) *opOpt { + retVal := borrowOpOpt() + + for _, opt := range opts { + opt(retVal) + } + if retVal.ctx == nil { + retVal.ctx = context.Background() // default context - required for no panics. + } + return retVal +} + +// Incr returns the tensor to be incremented in the call. Can be nil. +func (fo *opOpt) Incr() Tensor { return fo.incr } + +// Reuse returns the tensor to be reused in the call. Can be nil. +func (fo *opOpt) Reuse() Tensor { return fo.reuse } + +// IncrReuse returns whether a reuse tensor is to be used as the incr Tensor +func (fo *opOpt) IncrReuse() (Tensor, bool) { + if fo.incr != nil { + return fo.incr, true + } + return fo.reuse, false +} + +// Safe signals if the op is to be done safely +func (fo *opOpt) Safe() bool { return !fo.unsafe } + +// Same signals if the op is to return the same type as its inputs +func (fo *opOpt) Same() bool { return fo.same } + +// As returns the dtype of the return value of the method call. +// For example: +// a.Lt(b, As(Bool)) +// indicates that the result of the `Lt()` should be a Tensor of Bool. +// +// Another example: +// a.Add(b, As(Int)) +// indicates that the result of `Add()` should be converted to a Tensor of Int. +// Note that this function is not yet supported in most operations. +func (fo *opOpt) As() dtype.Dtype { return fo.t } + +// Context returns a context.Context that may have been passed in as a function option. +func (fo *opOpt) Context() context.Context { return fo.ctx } + +// SetReuse allows the reuse parameter to be set. +func (fo *opOpt) SetReuse(reuse Tensor) { fo.reuse = reuse } + +// SetIncr allows the incr parameter to be set. +func (fo *opOpt) SetIncr(incr Tensor) { fo.incr = incr } + +// FuncOpts is the inverse of ParseFuncOpts. +func (fo *opOpt) FuncOpts() []FuncOpt { + retVal := make([]FuncOpt, 0, 4) + if fo.reuse != nil { + retVal = append(retVal, WithReuse(fo.reuse)) + } + if fo.incr != nil { + retVal = append(retVal, WithIncr(fo.incr)) + } + if fo.unsafe { + retVal = append(retVal, UseUnsafe()) + } + if fo.same { + retVal = append(retVal, AsSameType()) + } + if fo.t != (Dtype{}) { + retVal = append(retVal, As(fo.t)) + } + return retVal +} diff --git a/generic_utils.go b/generic_utils.go index 24310b5..9a44263 100644 --- a/generic_utils.go +++ b/generic_utils.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor import ( @@ -7,14 +5,17 @@ import ( "reflect" "github.com/pkg/errors" + "gorgonia.org/dtype" "gorgonia.org/vecf32" "gorgonia.org/vecf64" ) +// Code generated by genlib2. DO NOT EDIT. + // Range creates a ranged array with a given type. It panics if the Dtype is not supported or does not represent a naturally orderable type (strings, pointers etc) // Do note that the range algorithm is very simple, and simply does increments or decrements of 1. This means for floating point types // you're not able to create a range with a 0.1 increment step, and for complex number types, the imaginary part will always be 0i -func Range(dt Dtype, start, end int) interface{} { +func Range(dt dtype.Dtype, start, end int) interface{} { size := end - start incr := true if start > end { @@ -172,7 +173,7 @@ func Range(dt Dtype, start, end int) interface{} { // For complex Dtypes, the imaginary component will be 0. // // This function is only useful in cases where the randomness is not vital. -func Random(dt Dtype, size int) interface{} { +func Random(dt dtype.Dtype, size int) interface{} { r := rand.New(rand.NewSource(1337)) switch dt.Kind() { case reflect.Int: diff --git a/genlib2/agg2_body.go b/genlib2/agg2_body.go index 54dd1f2..6f85f90 100644 --- a/genlib2/agg2_body.go +++ b/genlib2/agg2_body.go @@ -5,27 +5,40 @@ import "text/template" // level 2 aggregation (tensor.StdEng) templates const cmpPrepRaw = `var safe, same bool - if reuse, safe, _, _, same, err = handleFuncOpts({{.VecVar}}.Shape(), {{.VecVar}}.Dtype(), {{.VecVar}}.DataOrder(),false, opts...); err != nil{ + var ctx context.Context + if ctx, reuse, safe, _, _, same, err = handleFuncOpts({{.VecVar}}.Shape(), {{.VecVar}}.Dtype(), {{.VecVar}}.DataOrder(),false, opts...); err != nil{ return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !safe { same = true } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } ` const arithPrepRaw = `var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts({{.VecVar}}.Shape(), {{.VecVar}}.Dtype(), {{.VecVar}}.DataOrder(), true, opts...); err != nil{ + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts({{.VecVar}}.Shape(), {{.VecVar}}.Dtype(), {{.VecVar}}.DataOrder(), true, opts...); err != nil{ return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil { + return nil, err // this err will be noopError{}, no need to wrap. + } ` const minmaxPrepRaw = `var safe bool - if reuse, safe, _, _, _, err = handleFuncOpts({{.VecVar}}.Shape(), {{.VecVar}}.Dtype(), {{.VecVar}}.DataOrder(), true, opts...); err != nil{ + var ctx context.Context + if ctx, reuse, safe, _, _, _, err = handleFuncOpts({{.VecVar}}.Shape(), {{.VecVar}}.Dtype(), {{.VecVar}}.DataOrder(), true, opts...); err != nil{ return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err !=nil{ + return nil, err // this err will be noopError{}, no need to wrap. + } ` -const prepVVRaw = `if err = binaryCheck(a, b, {{.TypeClassCheck | lower}}Types); err != nil { +const prepVVRaw = `if err = binaryCheck(a, b, dtype.{{.TypeClassCheck}}); err != nil { + return nil, errors.Wrapf(err, "{{.Name}} failed") } @@ -42,7 +55,7 @@ const prepVVRaw = `if err = binaryCheck(a, b, {{.TypeClassCheck | lower}}Types); } ` -const prepMixedRaw = `if err = unaryCheck(t, {{.TypeClassCheck | lower}}Types); err != nil { +const prepMixedRaw = `if err = unaryCheck(t, dtype.{{.TypeClassCheck}}); err != nil { return nil, errors.Wrapf(err, "{{.Name}} failed") } @@ -73,15 +86,19 @@ const prepMixedRaw = `if err = unaryCheck(t, {{.TypeClassCheck | lower}}Types); ` -const prepUnaryRaw = `if err = unaryCheck(a, {{.TypeClassCheck | lower}}Types); err != nil { +const prepUnaryRaw = `if err = unaryCheck(a, dtype.{{.TypeClassCheck}}); err != nil { err = errors.Wrapf(err, "{{.Name}} failed") return } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { + var ctx context.Context + if ctx, reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } + if err = handleCtx(ctx); err != nil{ + return nil, err // this err will be a noopError{}, no need to wrap. + } typ := a.Dtype().Type var ait, rit Iterator diff --git a/genlib2/agg3_body.go b/genlib2/agg3_body.go index c780e90..024204a 100644 --- a/genlib2/agg3_body.go +++ b/genlib2/agg3_body.go @@ -66,6 +66,9 @@ const denseIdentityArithTestBodyRaw = `iden := func(a *Dense) bool { _, ok := a.Engine().({{interfaceName .Name}}); we = we || !ok {{template "call0" . }} + {{if eq .FuncOpt "context" -}} + {{template "funcoptcheck" -}} + {{end -}} if err, retEarly := qcErrCheck(t, "{{.Name}}", a, b, we, err); retEarly{ if err != nil { return false @@ -76,7 +79,9 @@ const denseIdentityArithTestBodyRaw = `iden := func(a *Dense) bool { if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { return false } - {{template "funcoptcheck" -}} + {{if ne .FuncOpt "context" -}} + {{template "funcoptcheck" -}} + {{end -}} return true } @@ -96,6 +101,9 @@ const denseIdentityArithScalarTestRaw = `iden1 := func(q *Dense) bool { _, ok := q.Engine().({{interfaceName .Name}}); we = we || !ok {{template "call0" . }} + {{if eq .FuncOpt "context" -}} + {{template "funcoptcheck" -}} + {{end -}} if err, retEarly := qcErrCheck(t, "{{.Name}}", a, b, we, err); retEarly{ if err != nil { return false @@ -106,7 +114,9 @@ const denseIdentityArithScalarTestRaw = `iden1 := func(q *Dense) bool { if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { return false } - {{template "funcoptcheck" -}} + {{if ne .FuncOpt "context" -}} + {{template "funcoptcheck" -}} + {{end -}} return true } @@ -127,6 +137,9 @@ iden2 := func(q *Dense) bool { _, ok := q.Engine().({{interfaceName .Name}}); we = we || !ok {{template "call1" . }} + {{if eq .FuncOpt "context" -}} + {{template "funcoptcheck" -}} + {{end -}} if err, retEarly := qcErrCheck(t, "{{.Name}}", a, b, we, err); retEarly{ if err != nil { return false @@ -137,7 +150,9 @@ iden2 := func(q *Dense) bool { if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { return false } - {{template "funcoptcheck" -}} + {{if ne .FuncOpt "context" -}} + {{template "funcoptcheck" -}} + {{end -}} return true } @@ -160,6 +175,9 @@ const denseInvArithTestBodyRaw = `inv := func(a *Dense) bool { _, ok := a.Engine().({{interfaceName .Name}}); we = we || !ok {{template "call0" . }} + {{if eq .FuncOpt "context" -}} + {{template "funcoptcheck" -}} + {{end -}} if err, retEarly := qcErrCheck(t, "{{.Name}}", a, b, we, err); retEarly{ if err != nil { return false @@ -171,7 +189,10 @@ const denseInvArithTestBodyRaw = `inv := func(a *Dense) bool { if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { return false } - {{template "funcoptcheck" -}} + {{if ne .FuncOpt "context" -}} + {{template "funcoptcheck" -}} + {{end -}} + return true } @@ -191,6 +212,9 @@ const denseInvArithScalarTestRaw = `inv1 := func(q *Dense) bool { _, ok := q.Engine().({{interfaceName .Name}}); we = we || !ok {{template "call0" . }} + {{if eq .FuncOpt "context" -}} + {{template "funcoptcheck" -}} + {{end -}} if err, retEarly := qcErrCheck(t, "{{.Name}}VS", a, b, we, err); retEarly{ if err != nil { return false @@ -202,7 +226,9 @@ const denseInvArithScalarTestRaw = `inv1 := func(q *Dense) bool { if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { return false } - {{template "funcoptcheck" -}} + {{if ne .FuncOpt "context" -}} + {{template "funcoptcheck" -}} + {{end -}} return true } @@ -224,6 +250,9 @@ inv2 := func(q *Dense) bool { _, ok := q.Engine().({{interfaceName .Name}}); we = we || !ok {{template "call1" . }} + {{if eq .FuncOpt "context" -}} + {{template "funcoptcheck" -}} + {{end -}} if err, retEarly := qcErrCheck(t, "{{.Name}}SV", a, b, we, err); retEarly{ if err != nil { return false @@ -235,7 +264,9 @@ inv2 := func(q *Dense) bool { if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { return false } - {{template "funcoptcheck" -}} + {{if ne .FuncOpt "context" -}} + {{template "funcoptcheck" -}} + {{end -}} return true } diff --git a/genlib2/arith_tests.go b/genlib2/arith_tests.go index c65a97f..77b1647 100644 --- a/genlib2/arith_tests.go +++ b/genlib2/arith_tests.go @@ -205,10 +205,10 @@ func generateAPIArithTests(f io.Writer, ak Kinds) { t := &ArithTest{ arithOp: op, lvl: API, - EqFailTypeClassName: "nil", + EqFailTypeClassName: "nilTC", } if t.name == "Pow" { - t.EqFailTypeClassName = "complexTypes" + t.EqFailTypeClassName = "dtype.Complexes" } tests = append(tests, t) } @@ -234,6 +234,13 @@ func generateAPIArithTests(f io.Writer, ak Kinds) { fn.FuncOpt = "incr" } + for _, fn := range tests { + if fn.canWrite() { + fn.Write(f) + } + fn.FuncOpt = "context" + } + for _, fn := range tests { if fn.canWrite() { fn.Write(f) @@ -248,13 +255,13 @@ func generateAPIArithScalarTests(f io.Writer, ak Kinds) { arithOp: op, scalars: true, lvl: API, - EqFailTypeClassName: "nil", + EqFailTypeClassName: "nilTC", } switch t.name { case "Pow": - t.EqFailTypeClassName = "complexTypes" + t.EqFailTypeClassName = "dtype.Complexes" case "Sub": - t.EqFailTypeClassName = "unsignedTypes" + t.EqFailTypeClassName = "dtype.Unsigned" } tests = append(tests, t) } @@ -280,6 +287,13 @@ func generateAPIArithScalarTests(f io.Writer, ak Kinds) { fn.FuncOpt = "incr" } + for _, fn := range tests { + if fn.canWrite() { + fn.Write(f) + } + fn.FuncOpt = "context" + } + for _, fn := range tests { if fn.canWrite() { fn.Write(f) @@ -293,10 +307,10 @@ func generateDenseMethodArithTests(f io.Writer, ak Kinds) { t := &ArithTest{ arithOp: op, lvl: Dense, - EqFailTypeClassName: "nil", + EqFailTypeClassName: "nilTC", } if t.name == "Pow" { - t.EqFailTypeClassName = "complexTypes" + t.EqFailTypeClassName = "dtype.Complexes" } tests = append(tests, t) } @@ -336,13 +350,13 @@ func generateDenseMethodScalarTests(f io.Writer, ak Kinds) { arithOp: op, scalars: true, lvl: Dense, - EqFailTypeClassName: "nil", + EqFailTypeClassName: "nilTC", } switch t.name { case "Pow": - t.EqFailTypeClassName = "complexTypes" + t.EqFailTypeClassName = "dtype.Complexes" case "Sub": - t.EqFailTypeClassName = "unsignedTypes" + t.EqFailTypeClassName = "dtype.Unsigned" } tests = append(tests, t) } diff --git a/genlib2/cmp_tests.go b/genlib2/cmp_tests.go index 8d3d8f6..1110e6e 100644 --- a/genlib2/cmp_tests.go +++ b/genlib2/cmp_tests.go @@ -1,471 +1,471 @@ -package main - -import ( - "fmt" - "io" - "text/template" -) - -const ( - APICallVVaxbRaw = `axb, err := {{.Name}}(a, b {{template "funcoptuse" . -}})` - APICallVVbxcRaw = `bxc, err := {{.Name}}(b, c {{template "funcoptuse" . -}})` - APICallVVaxcRaw = `axc, err := {{.Name}}(a, c {{template "funcoptuse" . -}})` - APICallVVbxaRaw = `bxa, err := {{.Name}}(b, a {{template "funcoptuse" . -}})` - APICallMixedaxbRaw = `axb, err := {{.Name}}(a, b {{template "funcoptuse" . -}})` - APICallMixedbxcRaw = `bxc, err := {{.Name}}(b, c {{template "funcoptuse" . -}})` - APICallMixedaxcRaw = `axc, err := {{.Name}}(a, c {{template "funcoptuse" . -}})` - APICallMixedbxaRaw = `bxa, err := {{.Name}}(b, a {{template "funcoptuse" . -}})` - - DenseMethodCallVVaxbRaw = `axb, err := a.{{.Name}}(b {{template "funcoptuse" . -}})` - DenseMethodCallVVbxcRaw = `bxc, err := b.{{.Name}}(c {{template "funcoptuse" . -}})` - DenseMethodCallVVaxcRaw = `axc, err := a.{{.Name}}(c {{template "funcoptuse" . -}})` - DenseMethodCallVVbxaRaw = `bxa, err := b.{{.Name}}(a {{template "funcoptuse" . -}})` - DenseMethodCallMixedaxbRaw = `axb, err := a.{{.Name}}Scalar(b, true {{template "funcoptuse" . -}})` - DenseMethodCallMixedbxcRaw = `bxc, err := c.{{.Name}}Scalar(b, false {{template "funcoptuse" . -}})` - DenseMethodCallMixedaxcRaw = `axc, err := a.{{.Name}}(c {{template "funcoptuse" . -}})` - DenseMethodCallMixedbxaRaw = `bxa, err := a.{{.Name}}Scalar(b, false {{template "funcoptuse" . -}})` -) - -const transitivityCheckRaw = `{{if eq .FuncOpt "assame" -}} - if !threewayEq(axb.Data(), bxc.Data(), axc.Data()){ - t.Errorf("a: %-v", a) - t.Errorf("b: %-v", b) - t.Errorf("c: %-v", c) - t.Errorf("axb.Data() %v", axb.Data()) - t.Errorf("bxc.Data() %v", bxc.Data()) - t.Errorf("axc.Data() %v", axc.Data()) - return false - } -{{else -}} - {{if eq .Level "API" -}} - ab := axb.(*Dense).Bools() - bc := bxc.(*Dense).Bools() - ac := axc.(*Dense).Bools() - {{else -}} - ab := axb.Bools() - bc := bxc.Bools() - ac := axc.Bools() - {{end -}} - for i, vab := range ab { - if vab && bc[i] { - if !ac[i]{ - return false - } - } - } -{{end -}} -` - -const transitivityBodyRaw = `transFn := func(q *Dense) bool { - we, _ := willerr(q, {{.TypeClassName}}, {{.EqFailTypeClassName}}) - _, ok := q.Engine().({{interfaceName .Name}}); we = we || !ok - - {{template "funcoptdecl" . -}} - - r := newRand() - a := q.Clone().(*Dense) - b := q.Clone().(*Dense) - c := q.Clone().(*Dense) - - bv, _ := quick.Value(b.Dtype().Type, r) - cv, _ := quick.Value(c.Dtype().Type, r) - b.Memset(bv.Interface()) - c.Memset(cv.Interface()) - - {{template "axb" .}} - if err, retEarly := qcErrCheck(t, "{{.Name}} - a∙b", a, b, we, err); retEarly{ - if err != nil { - return false - } - return true - } - - {{template "bxc" . }} - if err, retEarly := qcErrCheck(t, "{{.Name}} - b∙c", b, c, we, err); retEarly{ - if err != nil { - return false - } - return true - } - - {{template "axc" . }} - if err, retEarly := qcErrCheck(t, "{{.Name}} - a∙c", a, c, we, err); retEarly{ - if err != nil { - return false - } - return true - } - - {{template "transitivityCheck" .}} - return true -} -if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Errorf("Transitivity test for {{.Name}} failed: %v", err) -} -` - -const transitivityMixedBodyRaw = `transFn := func(q *Dense) bool { - we, _ := willerr(q, {{.TypeClassName}}, {{.EqFailTypeClassName}}) - _, ok := q.Engine().({{interfaceName .Name}}); we = we || !ok - - {{template "funcoptdecl" . -}} - - r := newRand() - a := q.Clone().(*Dense) - bv, _ := quick.Value(a.Dtype().Type, r) - b := bv.Interface() - c := q.Clone().(*Dense) - cv, _ := quick.Value(c.Dtype().Type, r) - c.Memset(cv.Interface()) - - {{template "axb" . }} - if err, retEarly := qcErrCheck(t, "{{.Name}} - a∙b", a, b, we, err); retEarly{ - if err != nil { - return false - } - return true - } - - {{template "bxc" . }} - if err, retEarly := qcErrCheck(t, "{{.Name}} - b∙c", c, b, we, err); retEarly{ - if err != nil { - return false - } - return true - } - - {{template "axc" . }} - if err, retEarly := qcErrCheck(t, "{{.Name}} - a∙c", a, c, we, err); retEarly{ - if err != nil { - return false - } - return true - } - - {{template "transitivityCheck" .}} - return true -} -if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Errorf("Transitivity test for {{.Name}} failed: %v", err) -} -` - -const symmetryBodyRaw = `symFn := func(q *Dense) bool { - we, _ := willerr(q, {{.TypeClassName}}, {{.EqFailTypeClassName}}) - _, ok := q.Engine().({{interfaceName .Name}}); we = we || !ok - - {{template "funcoptdecl" . -}} - - r := newRand() - a := q.Clone().(*Dense) - b := q.Clone().(*Dense) - - bv, _ := quick.Value(b.Dtype().Type, r) - b.Memset(bv.Interface()) - - {{template "axb" .}} - if err, retEarly := qcErrCheck(t, "{{.Name}} - a∙b", a, b, we, err); retEarly{ - if err != nil { - return false - } - return true - } - - {{template "bxa" .}} - if err, retEarly := qcErrCheck(t, "{{.Name}} - b∙a", a, b, we, err); retEarly{ - if err != nil { - return false - } - return true - } - return reflect.DeepEqual(axb.Data(), bxa.Data()) - -} -if err := quick.Check(symFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Errorf("Transitivity test for {{.Name}} failed: %v", err) -} -` - -const symmetryMixedBodyRaw = `symFn := func(q *Dense) bool { - we, _ := willerr(q, {{.TypeClassName}}, {{.EqFailTypeClassName}}) - _, ok := q.Engine().({{interfaceName .Name}}); we = we || !ok - - {{template "funcoptdecl" . -}} - - r := newRand() - a := q.Clone().(*Dense) - bv, _ := quick.Value(a.Dtype().Type, r) - b := bv.Interface() - - {{template "axb" .}} - if err, retEarly := qcErrCheck(t, "{{.Name}} - a∙b", a, b, we, err); retEarly{ - if err != nil { - return false - } - return true - } - - {{template "bxa" .}} - if err, retEarly := qcErrCheck(t, "{{.Name}} - b∙a", a, b, we, err); retEarly{ - if err != nil { - return false - } - return true - } - return reflect.DeepEqual(axb.Data(), bxa.Data()) - -} -if err := quick.Check(symFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Errorf("Symmetry test for {{.Name}} failed: %v", err) -} -` - -type CmpTest struct { - cmpOp - scalars bool - lvl Level - FuncOpt string - EqFailTypeClassName string -} - -func (fn *CmpTest) Name() string { - if fn.cmpOp.Name() == "Eq" || fn.cmpOp.Name() == "Ne" { - return "El" + fn.cmpOp.Name() - } - return fn.cmpOp.Name() -} - -func (fn *CmpTest) Level() string { - switch fn.lvl { - case API: - return "API" - case Dense: - return "Dense" - } - return "" -} - -func (fn *CmpTest) Signature() *Signature { - var name string - switch fn.lvl { - case API: - name = fmt.Sprintf("Test%s", fn.cmpOp.Name()) - case Dense: - name = fmt.Sprintf("TestDense_%s", fn.Name()) - } - if fn.scalars { - name += "Scalar" - } - if fn.FuncOpt != "" { - name += "_" + fn.FuncOpt - } - return &Signature{ - Name: name, - NameTemplate: plainName, - ParamNames: []string{"t"}, - ParamTemplates: []*template.Template{testingType}, - } -} - -func (fn *CmpTest) canWrite() bool { - return fn.IsTransitive || fn.IsSymmetric -} - -func (fn *CmpTest) WriteBody(w io.Writer) { - if fn.IsTransitive { - fn.writeTransitivity(w) - fmt.Fprintf(w, "\n") - } - if fn.IsSymmetric { - fn.writeSymmetry(w) - } -} - -func (fn *CmpTest) writeTransitivity(w io.Writer) { - var t *template.Template - if fn.scalars { - t = template.Must(template.New("dense cmp transitivity test").Funcs(funcs).Parse(transitivityMixedBodyRaw)) - } else { - t = template.Must(template.New("dense cmp transitivity test").Funcs(funcs).Parse(transitivityBodyRaw)) - } - - switch fn.lvl { - case API: - if fn.scalars { - template.Must(t.New("axb").Parse(APICallMixedaxbRaw)) - template.Must(t.New("bxc").Parse(APICallMixedbxcRaw)) - template.Must(t.New("axc").Parse(APICallMixedaxcRaw)) - } else { - template.Must(t.New("axb").Parse(APICallVVaxbRaw)) - template.Must(t.New("bxc").Parse(APICallVVbxcRaw)) - template.Must(t.New("axc").Parse(APICallVVaxcRaw)) - } - case Dense: - if fn.scalars { - template.Must(t.New("axb").Parse(DenseMethodCallMixedaxbRaw)) - template.Must(t.New("bxc").Parse(DenseMethodCallMixedbxcRaw)) - template.Must(t.New("axc").Parse(DenseMethodCallMixedaxcRaw)) - } else { - template.Must(t.New("axb").Parse(DenseMethodCallVVaxbRaw)) - template.Must(t.New("bxc").Parse(DenseMethodCallVVbxcRaw)) - template.Must(t.New("axc").Parse(DenseMethodCallVVaxcRaw)) - } - } - template.Must(t.New("transitivityCheck").Parse(transitivityCheckRaw)) - template.Must(t.New("funcoptdecl").Parse(funcOptDecl[fn.FuncOpt])) - template.Must(t.New("funcoptcorrect").Parse(funcOptCorrect[fn.FuncOpt])) - template.Must(t.New("funcoptuse").Parse(funcOptUse[fn.FuncOpt])) - template.Must(t.New("funcoptcheck").Parse(funcOptCheck[fn.FuncOpt])) - - t.Execute(w, fn) -} - -func (fn *CmpTest) writeSymmetry(w io.Writer) { - var t *template.Template - if fn.scalars { - t = template.Must(template.New("dense cmp symmetry test").Funcs(funcs).Parse(symmetryMixedBodyRaw)) - } else { - t = template.Must(template.New("dense cmp symmetry test").Funcs(funcs).Parse(symmetryBodyRaw)) - } - - switch fn.lvl { - case API: - if fn.scalars { - template.Must(t.New("axb").Parse(APICallMixedaxbRaw)) - template.Must(t.New("bxa").Parse(APICallMixedbxaRaw)) - } else { - template.Must(t.New("axb").Parse(APICallVVaxbRaw)) - template.Must(t.New("bxa").Parse(APICallVVbxaRaw)) - } - case Dense: - if fn.scalars { - template.Must(t.New("axb").Parse(DenseMethodCallMixedaxbRaw)) - template.Must(t.New("bxa").Parse(DenseMethodCallMixedbxaRaw)) - } else { - template.Must(t.New("axb").Parse(DenseMethodCallVVaxbRaw)) - template.Must(t.New("bxa").Parse(DenseMethodCallVVbxaRaw)) - } - } - template.Must(t.New("funcoptdecl").Parse(funcOptDecl[fn.FuncOpt])) - template.Must(t.New("funcoptcorrect").Parse(funcOptCorrect[fn.FuncOpt])) - template.Must(t.New("funcoptuse").Parse(funcOptUse[fn.FuncOpt])) - template.Must(t.New("funcoptcheck").Parse(funcOptCheck[fn.FuncOpt])) - - t.Execute(w, fn) -} - -func (fn *CmpTest) Write(w io.Writer) { - sig := fn.Signature() - w.Write([]byte("func ")) - sig.Write(w) - w.Write([]byte("{\n")) - fn.WriteBody(w) - w.Write([]byte("}\n")) -} - -func generateAPICmpTests(f io.Writer, ak Kinds) { - var tests []*CmpTest - - for _, op := range cmpBinOps { - t := &CmpTest{ - cmpOp: op, - lvl: API, - EqFailTypeClassName: "nil", - } - tests = append(tests, t) - } - - for _, fn := range tests { - if fn.canWrite() { - fn.Write(f) - } - fn.FuncOpt = "assame" - fn.TypeClassName = "nonComplexNumberTypes" - } - for _, fn := range tests { - if fn.canWrite() { - fn.Write(f) - } - } - -} - -func generateAPICmpMixedTests(f io.Writer, ak Kinds) { - var tests []*CmpTest - - for _, op := range cmpBinOps { - t := &CmpTest{ - cmpOp: op, - lvl: API, - scalars: true, - EqFailTypeClassName: "nil", - } - tests = append(tests, t) - } - - for _, fn := range tests { - if fn.canWrite() { - fn.Write(f) - } - fn.FuncOpt = "assame" - fn.TypeClassName = "nonComplexNumberTypes" - } - for _, fn := range tests { - if fn.canWrite() { - fn.Write(f) - } - } -} - -func generateDenseMethodCmpTests(f io.Writer, ak Kinds) { - var tests []*CmpTest - - for _, op := range cmpBinOps { - t := &CmpTest{ - cmpOp: op, - lvl: Dense, - EqFailTypeClassName: "nil", - } - tests = append(tests, t) - } - - for _, fn := range tests { - if fn.canWrite() { - fn.Write(f) - } - fn.FuncOpt = "assame" - fn.TypeClassName = "nonComplexNumberTypes" - } - for _, fn := range tests { - if fn.canWrite() { - fn.Write(f) - } - } -} - -func generateDenseMethodCmpMixedTests(f io.Writer, ak Kinds) { - var tests []*CmpTest - - for _, op := range cmpBinOps { - t := &CmpTest{ - cmpOp: op, - lvl: Dense, - scalars: true, - EqFailTypeClassName: "nil", - } - tests = append(tests, t) - } - - for _, fn := range tests { - if fn.canWrite() { - fn.Write(f) - } - fn.FuncOpt = "assame" - fn.TypeClassName = "nonComplexNumberTypes" - } - for _, fn := range tests { - if fn.canWrite() { - fn.Write(f) - } - } -} +package main + +import ( + "fmt" + "io" + "text/template" +) + +const ( + APICallVVaxbRaw = `axb, err := {{.Name}}(a, b {{template "funcoptuse" . -}})` + APICallVVbxcRaw = `bxc, err := {{.Name}}(b, c {{template "funcoptuse" . -}})` + APICallVVaxcRaw = `axc, err := {{.Name}}(a, c {{template "funcoptuse" . -}})` + APICallVVbxaRaw = `bxa, err := {{.Name}}(b, a {{template "funcoptuse" . -}})` + APICallMixedaxbRaw = `axb, err := {{.Name}}(a, b {{template "funcoptuse" . -}})` + APICallMixedbxcRaw = `bxc, err := {{.Name}}(b, c {{template "funcoptuse" . -}})` + APICallMixedaxcRaw = `axc, err := {{.Name}}(a, c {{template "funcoptuse" . -}})` + APICallMixedbxaRaw = `bxa, err := {{.Name}}(b, a {{template "funcoptuse" . -}})` + + DenseMethodCallVVaxbRaw = `axb, err := a.{{.Name}}(b {{template "funcoptuse" . -}})` + DenseMethodCallVVbxcRaw = `bxc, err := b.{{.Name}}(c {{template "funcoptuse" . -}})` + DenseMethodCallVVaxcRaw = `axc, err := a.{{.Name}}(c {{template "funcoptuse" . -}})` + DenseMethodCallVVbxaRaw = `bxa, err := b.{{.Name}}(a {{template "funcoptuse" . -}})` + DenseMethodCallMixedaxbRaw = `axb, err := a.{{.Name}}Scalar(b, true {{template "funcoptuse" . -}})` + DenseMethodCallMixedbxcRaw = `bxc, err := c.{{.Name}}Scalar(b, false {{template "funcoptuse" . -}})` + DenseMethodCallMixedaxcRaw = `axc, err := a.{{.Name}}(c {{template "funcoptuse" . -}})` + DenseMethodCallMixedbxaRaw = `bxa, err := a.{{.Name}}Scalar(b, false {{template "funcoptuse" . -}})` +) + +const transitivityCheckRaw = `{{if eq .FuncOpt "assame" -}} + if !threewayEq(axb.Data(), bxc.Data(), axc.Data()){ + t.Errorf("a: %-v", a) + t.Errorf("b: %-v", b) + t.Errorf("c: %-v", c) + t.Errorf("axb.Data() %v", axb.Data()) + t.Errorf("bxc.Data() %v", bxc.Data()) + t.Errorf("axc.Data() %v", axc.Data()) + return false + } +{{else -}} + {{if eq .Level "API" -}} + ab := axb.(*Dense).Bools() + bc := bxc.(*Dense).Bools() + ac := axc.(*Dense).Bools() + {{else -}} + ab := axb.Bools() + bc := bxc.Bools() + ac := axc.Bools() + {{end -}} + for i, vab := range ab { + if vab && bc[i] { + if !ac[i]{ + return false + } + } + } +{{end -}} +` + +const transitivityBodyRaw = `transFn := func(q *Dense) bool { + we, _ := willerr(q, {{.TypeClassName}}, {{.EqFailTypeClassName}}) + _, ok := q.Engine().({{interfaceName .Name}}); we = we || !ok + + {{template "funcoptdecl" . -}} + + r := newRand() + a := q.Clone().(*Dense) + b := q.Clone().(*Dense) + c := q.Clone().(*Dense) + + bv, _ := quick.Value(b.Dtype().Type, r) + cv, _ := quick.Value(c.Dtype().Type, r) + b.Memset(bv.Interface()) + c.Memset(cv.Interface()) + + {{template "axb" .}} + if err, retEarly := qcErrCheck(t, "{{.Name}} - a∙b", a, b, we, err); retEarly{ + if err != nil { + return false + } + return true + } + + {{template "bxc" . }} + if err, retEarly := qcErrCheck(t, "{{.Name}} - b∙c", b, c, we, err); retEarly{ + if err != nil { + return false + } + return true + } + + {{template "axc" . }} + if err, retEarly := qcErrCheck(t, "{{.Name}} - a∙c", a, c, we, err); retEarly{ + if err != nil { + return false + } + return true + } + + {{template "transitivityCheck" .}} + return true +} +if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Transitivity test for {{.Name}} failed: %v", err) +} +` + +const transitivityMixedBodyRaw = `transFn := func(q *Dense) bool { + we, _ := willerr(q, {{.TypeClassName}}, {{.EqFailTypeClassName}}) + _, ok := q.Engine().({{interfaceName .Name}}); we = we || !ok + + {{template "funcoptdecl" . -}} + + r := newRand() + a := q.Clone().(*Dense) + bv, _ := quick.Value(a.Dtype().Type, r) + b := bv.Interface() + c := q.Clone().(*Dense) + cv, _ := quick.Value(c.Dtype().Type, r) + c.Memset(cv.Interface()) + + {{template "axb" . }} + if err, retEarly := qcErrCheck(t, "{{.Name}} - a∙b", a, b, we, err); retEarly{ + if err != nil { + return false + } + return true + } + + {{template "bxc" . }} + if err, retEarly := qcErrCheck(t, "{{.Name}} - b∙c", c, b, we, err); retEarly{ + if err != nil { + return false + } + return true + } + + {{template "axc" . }} + if err, retEarly := qcErrCheck(t, "{{.Name}} - a∙c", a, c, we, err); retEarly{ + if err != nil { + return false + } + return true + } + + {{template "transitivityCheck" .}} + return true +} +if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Transitivity test for {{.Name}} failed: %v", err) +} +` + +const symmetryBodyRaw = `symFn := func(q *Dense) bool { + we, _ := willerr(q, {{.TypeClassName}}, {{.EqFailTypeClassName}}) + _, ok := q.Engine().({{interfaceName .Name}}); we = we || !ok + + {{template "funcoptdecl" . -}} + + r := newRand() + a := q.Clone().(*Dense) + b := q.Clone().(*Dense) + + bv, _ := quick.Value(b.Dtype().Type, r) + b.Memset(bv.Interface()) + + {{template "axb" .}} + if err, retEarly := qcErrCheck(t, "{{.Name}} - a∙b", a, b, we, err); retEarly{ + if err != nil { + return false + } + return true + } + + {{template "bxa" .}} + if err, retEarly := qcErrCheck(t, "{{.Name}} - b∙a", a, b, we, err); retEarly{ + if err != nil { + return false + } + return true + } + return reflect.DeepEqual(axb.Data(), bxa.Data()) + +} +if err := quick.Check(symFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Transitivity test for {{.Name}} failed: %v", err) +} +` + +const symmetryMixedBodyRaw = `symFn := func(q *Dense) bool { + we, _ := willerr(q, {{.TypeClassName}}, {{.EqFailTypeClassName}}) + _, ok := q.Engine().({{interfaceName .Name}}); we = we || !ok + + {{template "funcoptdecl" . -}} + + r := newRand() + a := q.Clone().(*Dense) + bv, _ := quick.Value(a.Dtype().Type, r) + b := bv.Interface() + + {{template "axb" .}} + if err, retEarly := qcErrCheck(t, "{{.Name}} - a∙b", a, b, we, err); retEarly{ + if err != nil { + return false + } + return true + } + + {{template "bxa" .}} + if err, retEarly := qcErrCheck(t, "{{.Name}} - b∙a", a, b, we, err); retEarly{ + if err != nil { + return false + } + return true + } + return reflect.DeepEqual(axb.Data(), bxa.Data()) + +} +if err := quick.Check(symFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { + t.Errorf("Symmetry test for {{.Name}} failed: %v", err) +} +` + +type CmpTest struct { + cmpOp + scalars bool + lvl Level + FuncOpt string + EqFailTypeClassName string +} + +func (fn *CmpTest) Name() string { + if fn.cmpOp.Name() == "Eq" || fn.cmpOp.Name() == "Ne" { + return "El" + fn.cmpOp.Name() + } + return fn.cmpOp.Name() +} + +func (fn *CmpTest) Level() string { + switch fn.lvl { + case API: + return "API" + case Dense: + return "Dense" + } + return "" +} + +func (fn *CmpTest) Signature() *Signature { + var name string + switch fn.lvl { + case API: + name = fmt.Sprintf("Test%s", fn.cmpOp.Name()) + case Dense: + name = fmt.Sprintf("TestDense_%s", fn.Name()) + } + if fn.scalars { + name += "Scalar" + } + if fn.FuncOpt != "" { + name += "_" + fn.FuncOpt + } + return &Signature{ + Name: name, + NameTemplate: plainName, + ParamNames: []string{"t"}, + ParamTemplates: []*template.Template{testingType}, + } +} + +func (fn *CmpTest) canWrite() bool { + return fn.IsTransitive || fn.IsSymmetric +} + +func (fn *CmpTest) WriteBody(w io.Writer) { + if fn.IsTransitive { + fn.writeTransitivity(w) + fmt.Fprintf(w, "\n") + } + if fn.IsSymmetric { + fn.writeSymmetry(w) + } +} + +func (fn *CmpTest) writeTransitivity(w io.Writer) { + var t *template.Template + if fn.scalars { + t = template.Must(template.New("dense cmp transitivity test").Funcs(funcs).Parse(transitivityMixedBodyRaw)) + } else { + t = template.Must(template.New("dense cmp transitivity test").Funcs(funcs).Parse(transitivityBodyRaw)) + } + + switch fn.lvl { + case API: + if fn.scalars { + template.Must(t.New("axb").Parse(APICallMixedaxbRaw)) + template.Must(t.New("bxc").Parse(APICallMixedbxcRaw)) + template.Must(t.New("axc").Parse(APICallMixedaxcRaw)) + } else { + template.Must(t.New("axb").Parse(APICallVVaxbRaw)) + template.Must(t.New("bxc").Parse(APICallVVbxcRaw)) + template.Must(t.New("axc").Parse(APICallVVaxcRaw)) + } + case Dense: + if fn.scalars { + template.Must(t.New("axb").Parse(DenseMethodCallMixedaxbRaw)) + template.Must(t.New("bxc").Parse(DenseMethodCallMixedbxcRaw)) + template.Must(t.New("axc").Parse(DenseMethodCallMixedaxcRaw)) + } else { + template.Must(t.New("axb").Parse(DenseMethodCallVVaxbRaw)) + template.Must(t.New("bxc").Parse(DenseMethodCallVVbxcRaw)) + template.Must(t.New("axc").Parse(DenseMethodCallVVaxcRaw)) + } + } + template.Must(t.New("transitivityCheck").Parse(transitivityCheckRaw)) + template.Must(t.New("funcoptdecl").Parse(funcOptDecl[fn.FuncOpt])) + template.Must(t.New("funcoptcorrect").Parse(funcOptCorrect[fn.FuncOpt])) + template.Must(t.New("funcoptuse").Parse(funcOptUse[fn.FuncOpt])) + template.Must(t.New("funcoptcheck").Parse(funcOptCheck[fn.FuncOpt])) + + t.Execute(w, fn) +} + +func (fn *CmpTest) writeSymmetry(w io.Writer) { + var t *template.Template + if fn.scalars { + t = template.Must(template.New("dense cmp symmetry test").Funcs(funcs).Parse(symmetryMixedBodyRaw)) + } else { + t = template.Must(template.New("dense cmp symmetry test").Funcs(funcs).Parse(symmetryBodyRaw)) + } + + switch fn.lvl { + case API: + if fn.scalars { + template.Must(t.New("axb").Parse(APICallMixedaxbRaw)) + template.Must(t.New("bxa").Parse(APICallMixedbxaRaw)) + } else { + template.Must(t.New("axb").Parse(APICallVVaxbRaw)) + template.Must(t.New("bxa").Parse(APICallVVbxaRaw)) + } + case Dense: + if fn.scalars { + template.Must(t.New("axb").Parse(DenseMethodCallMixedaxbRaw)) + template.Must(t.New("bxa").Parse(DenseMethodCallMixedbxaRaw)) + } else { + template.Must(t.New("axb").Parse(DenseMethodCallVVaxbRaw)) + template.Must(t.New("bxa").Parse(DenseMethodCallVVbxaRaw)) + } + } + template.Must(t.New("funcoptdecl").Parse(funcOptDecl[fn.FuncOpt])) + template.Must(t.New("funcoptcorrect").Parse(funcOptCorrect[fn.FuncOpt])) + template.Must(t.New("funcoptuse").Parse(funcOptUse[fn.FuncOpt])) + template.Must(t.New("funcoptcheck").Parse(funcOptCheck[fn.FuncOpt])) + + t.Execute(w, fn) +} + +func (fn *CmpTest) Write(w io.Writer) { + sig := fn.Signature() + w.Write([]byte("func ")) + sig.Write(w) + w.Write([]byte("{\n")) + fn.WriteBody(w) + w.Write([]byte("}\n")) +} + +func generateAPICmpTests(f io.Writer, ak Kinds) { + var tests []*CmpTest + + for _, op := range cmpBinOps { + t := &CmpTest{ + cmpOp: op, + lvl: API, + EqFailTypeClassName: "nilTC", + } + tests = append(tests, t) + } + + for _, fn := range tests { + if fn.canWrite() { + fn.Write(f) + } + fn.FuncOpt = "assame" + fn.TypeClassName = "dtype.NonComplexNumber" + } + for _, fn := range tests { + if fn.canWrite() { + fn.Write(f) + } + } + +} + +func generateAPICmpMixedTests(f io.Writer, ak Kinds) { + var tests []*CmpTest + + for _, op := range cmpBinOps { + t := &CmpTest{ + cmpOp: op, + lvl: API, + scalars: true, + EqFailTypeClassName: "nilTC", + } + tests = append(tests, t) + } + + for _, fn := range tests { + if fn.canWrite() { + fn.Write(f) + } + fn.FuncOpt = "assame" + fn.TypeClassName = "dtype.NonComplexNumber" + } + for _, fn := range tests { + if fn.canWrite() { + fn.Write(f) + } + } +} + +func generateDenseMethodCmpTests(f io.Writer, ak Kinds) { + var tests []*CmpTest + + for _, op := range cmpBinOps { + t := &CmpTest{ + cmpOp: op, + lvl: Dense, + EqFailTypeClassName: "nilTC", + } + tests = append(tests, t) + } + + for _, fn := range tests { + if fn.canWrite() { + fn.Write(f) + } + fn.FuncOpt = "assame" + fn.TypeClassName = "dtype.NonComplexNumber" + } + for _, fn := range tests { + if fn.canWrite() { + fn.Write(f) + } + } +} + +func generateDenseMethodCmpMixedTests(f io.Writer, ak Kinds) { + var tests []*CmpTest + + for _, op := range cmpBinOps { + t := &CmpTest{ + cmpOp: op, + lvl: Dense, + scalars: true, + EqFailTypeClassName: "nilTC", + } + tests = append(tests, t) + } + + for _, fn := range tests { + if fn.canWrite() { + fn.Write(f) + } + fn.FuncOpt = "assame" + fn.TypeClassName = "dtype.NonComplexNumber" + } + for _, fn := range tests { + if fn.canWrite() { + fn.Write(f) + } + } +} diff --git a/genlib2/declarations.go b/genlib2/declarations.go index 7bcd6bc..9d70630 100644 --- a/genlib2/declarations.go +++ b/genlib2/declarations.go @@ -25,7 +25,7 @@ var cmpSymbolTemplates = [...]string{ } var nonFloatConditionalUnarySymbolTemplates = [...]string{ - `{{if isFloat .Kind -}} + `{{if isFloat .Kind -}} {{.Range}}[{{.Index0}}] = {{mathPkg .Kind}}Abs({{.Range}}[{{.Index0}}]) {{else -}} if {{.Range}}[{{.Index0}}] < 0 { {{.Range}}[{{.Index0}}] = -{{.Range}}[{{.Index0}}] @@ -57,10 +57,11 @@ var unconditionalFloatUnarySymbolTemplates = [...]string{ } var funcOptUse = map[string]string{ - "reuse": ",WithReuse(reuse)", - "incr": ",WithIncr(incr)", - "unsafe": ",UseUnsafe()", - "assame": ", AsSameType()", + "reuse": ",WithReuse(reuse)", + "incr": ",WithIncr(incr)", + "unsafe": ",UseUnsafe()", + "assame": ", AsSameType()", + "context": ", WithContext(ctx)", } var funcOptCheck = map[string]string{ @@ -77,7 +78,10 @@ var funcOptCheck = map[string]string{ t.Errorf("Expected ret to be the same as a") return false } - + `, + "context": `if _, ok := err.(NoOpError); ok && r < 5 { + return true // short circuit + } `, } @@ -85,10 +89,21 @@ var funcOptDecl = map[string]string{ "reuse": "reuse := New(Of(a.t), WithShape(a.Shape().Clone()...))\n", "incr": "incr := New(Of(a.t), WithShape(a.Shape().Clone()...))\n", "unsafe": "", - "assame": `if err := typeclassCheck(q.Dtype(), {{.TypeClassName}}); err != nil { + "assame": `if err := dtype.TypeClassCheck(q.Dtype(), {{.TypeClassName}}); err != nil { return true // we exit early if the generated type is not something we can handle } `, + "context": `rng := newRand() + r := rng.Intn(10) + var ctx context.Context + var cancel context.CancelFunc + if r < 5 { + ctx, cancel = context.WithTimeout(context.Background(), 1 * time.Microsecond) + } else { + ctx, cancel = context.WithTimeout(context.Background(), time.Duration(r * 100)*time.Second) + } + defer cancel() +`, } var funcOptCorrect = map[string]string{ @@ -96,7 +111,8 @@ var funcOptCorrect = map[string]string{ "incr": `incr.Memset(identityVal(100, a.t)) correct.Add(incr, UseUnsafe()) `, - "unsafe": "", + "unsafe": "", + "context": "", } var stdTypes = [...]string{ @@ -427,51 +443,51 @@ func init() { // ops arithBinOps = []arithOp{ - {basicBinOp{"", "Add", false, isAddable}, "numberTypes", true, 0, false, "", true, false}, - {basicBinOp{"", "Sub", false, isNumber}, "numberTypes", false, 0, true, "Add", false, true}, - {basicBinOp{"", "Mul", false, isNumber}, "numberTypes", true, 1, false, "", true, false}, - {basicBinOp{"", "Div", false, isNumber}, "numberTypes", false, 1, true, "Mul", false, false}, - {basicBinOp{"", "Pow", true, isFloatCmplx}, "floatcmplxTypes", true, 1, false, "", false, false}, - {basicBinOp{"", "Mod", false, isNonComplexNumber}, "nonComplexNumberTypes", false, 0, false, "", false, false}, + {basicBinOp{"", "Add", false, isAddable}, "dtype.Number", true, 0, false, "", true, false}, + {basicBinOp{"", "Sub", false, isNumber}, "dtype.Number", false, 0, true, "Add", false, true}, + {basicBinOp{"", "Mul", false, isNumber}, "dtype.Number", true, 1, false, "", true, false}, + {basicBinOp{"", "Div", false, isNumber}, "dtype.Number", false, 1, true, "Mul", false, false}, + {basicBinOp{"", "Pow", true, isFloatCmplx}, "dtype.FloatComplex", true, 1, false, "", false, false}, + {basicBinOp{"", "Mod", false, isNonComplexNumber}, "dtype.NonComplexNumber", false, 0, false, "", false, false}, } for i := range arithBinOps { arithBinOps[i].symbol = arithSymbolTemplates[i] } cmpBinOps = []cmpOp{ - {basicBinOp{"", "Gt", false, isOrd}, "ordTypes", "Lt", true, false}, - {basicBinOp{"", "Gte", false, isOrd}, "ordTypes", "Lte", true, false}, - {basicBinOp{"", "Lt", false, isOrd}, "ordTypes", "Gt", true, false}, - {basicBinOp{"", "Lte", false, isOrd}, "ordTypes", "Gte", true, false}, - {basicBinOp{"", "Eq", false, isEq}, "eqTypes", "Eq", true, true}, - {basicBinOp{"", "Ne", false, isEq}, "eqTypes", "Ne", false, true}, + {basicBinOp{"", "Gt", false, isOrd}, "dtype.Ord", "Lt", true, false}, + {basicBinOp{"", "Gte", false, isOrd}, "dtype.Ord", "Lte", true, false}, + {basicBinOp{"", "Lt", false, isOrd}, "dtype.Ord", "Gt", true, false}, + {basicBinOp{"", "Lte", false, isOrd}, "dtype.Ord", "Gte", true, false}, + {basicBinOp{"", "Eq", false, isEq}, "dtype.Eq", "Eq", true, true}, + {basicBinOp{"", "Ne", false, isEq}, "dtype.Eq", "Ne", false, true}, } for i := range cmpBinOps { cmpBinOps[i].symbol = cmpSymbolTemplates[i] } conditionalUnaries = []unaryOp{ - {"", "Abs", false, isSignedNumber, "signedTypes", ""}, - {"", "Sign", false, isSignedNumber, "signedTypes", ""}, + {"", "Abs", false, isSignedNumber, "dtype.Signed", ""}, + {"", "Sign", false, isSignedNumber, "dtype.Signed", ""}, } for i := range conditionalUnaries { conditionalUnaries[i].symbol = nonFloatConditionalUnarySymbolTemplates[i] } unconditionalUnaries = []unaryOp{ - {"", "Neg", false, isNumber, "numberTypes", "Neg"}, - {"", "Inv", false, isNumber, "numberTypes", ""}, - {"", "Square", false, isNumber, "numberTypes", "Sqrt"}, - {"", "Cube", false, isNumber, "numberTypes", "Cbrt"}, - - {"", "Exp", true, isFloatCmplx, "floatcmplxTypes", "Log"}, - {"", "Tanh", true, isFloatCmplx, "floatcmplxTypes", ""}, - {"", "Log", true, isFloatCmplx, "floatcmplxTypes", "Exp"}, - {"", "Log2", true, isFloat, "floatTypes", ""}, - {"", "Log10", true, isFloatCmplx, "floatcmplxTypes", ""}, - {"", "Sqrt", true, isFloatCmplx, "floatcmplxTypes", "Square"}, - {"", "Cbrt", true, isFloat, "floatTypes", "Cube"}, - {"", "InvSqrt", true, isFloat, "floatTypes", ""}, // TODO: cmplx requires to much finagling to the template. Come back to it later + {"", "Neg", false, isNumber, "dtype.Number", "Neg"}, + {"", "Inv", false, isNumber, "dtype.Number", ""}, + {"", "Square", false, isNumber, "dtype.Number", "Sqrt"}, + {"", "Cube", false, isNumber, "dtype.Number", "Cbrt"}, + + {"", "Exp", true, isFloatCmplx, "dtype.FloatComplex", "Log"}, + {"", "Tanh", true, isFloatCmplx, "dtype.FloatComplex", ""}, + {"", "Log", true, isFloatCmplx, "dtype.FloatComplex", "Exp"}, + {"", "Log2", true, isFloat, "dtype.Floats", ""}, + {"", "Log10", true, isFloatCmplx, "dtype.FloatComplex", ""}, + {"", "Sqrt", true, isFloatCmplx, "dtype.FloatComplex", "Square"}, + {"", "Cbrt", true, isFloat, "dtype.Floats", "Cube"}, + {"", "InvSqrt", true, isFloat, "dtype.Floats", ""}, // TODO: cmplx requires to much finagling to the template. Come back to it later } nonF := len(unconditionalNumUnarySymbolTemplates) for i := range unconditionalNumUnarySymbolTemplates { @@ -482,7 +498,7 @@ func init() { } specialUnaries = []UnaryOp{ - specialUnaryOp{unaryOp{clampBody, "Clamp", false, isNonComplexNumber, "nonComplexNumberTypes", ""}, []string{"min", "max"}}, + specialUnaryOp{unaryOp{clampBody, "Clamp", false, isNonComplexNumber, "dtype.NonComplexNumber", ""}, []string{"min", "max"}}, } // typed operations diff --git a/genlib2/dense_compat.go b/genlib2/dense_compat.go index e3b5b52..afb3fda 100644 --- a/genlib2/dense_compat.go +++ b/genlib2/dense_compat.go @@ -13,7 +13,7 @@ const importsArrowRaw = `import ( ) ` -const conversionsRaw = `func convFromFloat64s(to Dtype, data []float64) interface{} { +const conversionsRaw = `func convFromFloat64s(to dtype.Dtype, data []float64) interface{} { switch to { {{range .Kinds -}} {{if isNumber . -}} @@ -220,10 +220,10 @@ func ToMat64(t *Dense, opts ...FuncOpt) (retVal *mat.Dense, err error) { var data []float64 switch { - case t.t == Float64 && toCopy && !t.IsMaterializable(): + case t.t == Float64 && toCopy && !t.RequiresIterator() && t.viewOf == 0: data = make([]float64, t.len()) copy(data, t.Float64s()) - case !t.IsMaterializable(): + case !t.RequiresIterator() && t.viewOf == 0: data = convToFloat64s(t) default: it := newFlatIterator(&t.AP) @@ -235,7 +235,7 @@ func ToMat64(t *Dense, opts ...FuncOpt) (retVal *mat.Dense, err error) { data = append(data, convToFloat64(t.Get(next))) } err = nil - + } retVal = mat.NewDense(r, c, data) diff --git a/genlib2/dense_compat_tests.go b/genlib2/dense_compat_tests.go index d21831a..334f2a8 100644 --- a/genlib2/dense_compat_tests.go +++ b/genlib2/dense_compat_tests.go @@ -1,287 +1,287 @@ -package main - -import ( - "io" - "text/template" -) - -const compatTestsRaw = `var toMat64Tests = []struct{ - data interface{} - sliced interface{} - shape Shape - dt Dtype -}{ - {{range .Kinds -}} - {{if isNumber . -}} - { Range({{asType . | title | strip}}, 0, 6), []{{asType .}}{0,1,3,4}, Shape{2,3}, {{asType . | title | strip}} }, - {{end -}} - {{end -}} -} -func TestToMat64(t *testing.T){ - assert := assert.New(t) - for i, tmt := range toMat64Tests { - T := New(WithBacking(tmt.data), WithShape(tmt.shape...)) - var m *mat.Dense - var err error - if m, err = ToMat64(T); err != nil { - t.Errorf("ToMat basic test %d failed : %v", i, err) - continue - } - conv := anyToFloat64s(tmt.data) - assert.Equal(conv, m.RawMatrix().Data, "i %d from %v", i, tmt.dt) - - if T, err = sliceDense(T, nil, makeRS(0, 2)); err != nil{ - t.Errorf("Slice failed %v", err) - continue - } - if m, err = ToMat64(T); err != nil { - t.Errorf("ToMat of slice test %d failed : %v", i, err) - continue - } - conv = anyToFloat64s(tmt.sliced) - assert.Equal(conv, m.RawMatrix().Data, "sliced test %d from %v", i, tmt.dt) - t.Logf("Done") - - if tmt.dt == Float64 { - T = New(WithBacking(tmt.data), WithShape(tmt.shape...)) - if m, err = ToMat64(T, UseUnsafe()); err != nil { - t.Errorf("ToMat64 unsafe test %d failed: %v", i, err) - } - conv = anyToFloat64s(tmt.data) - assert.Equal(conv, m.RawMatrix().Data, "float64 unsafe i %d from %v", i, tmt.dt) - conv[0] = 1000 - assert.Equal(conv, m.RawMatrix().Data,"float64 unsafe i %d from %v", i, tmt.dt) - conv[0] = 0 // reset for future tests that use the same backing - } - } - // idiocy test - T := New(Of(Float64), WithShape(2,3,4)) - _, err := ToMat64(T) - if err == nil { - t.Error("Expected an error when trying to convert a 3-T to *mat.Dense") - } -} - -func TestFromMat64(t *testing.T){ - assert := assert.New(t) - var m *mat.Dense - var T *Dense - var backing []float64 - - - for i, tmt := range toMat64Tests { - backing = Range(Float64, 0, 6).([]float64) - m = mat.NewDense(2, 3, backing) - T = FromMat64(m) - conv := anyToFloat64s(tmt.data) - assert.Equal(conv, T.Float64s(), "test %d: []float64 from %v", i, tmt.dt) - assert.True(T.Shape().Eq(tmt.shape)) - - T = FromMat64(m, As(tmt.dt)) - assert.Equal(tmt.data, T.Data()) - assert.True(T.Shape().Eq(tmt.shape)) - - if tmt.dt == Float64{ - backing = Range(Float64, 0, 6).([]float64) - m = mat.NewDense(2, 3, backing) - T = FromMat64(m, UseUnsafe()) - assert.Equal(backing, T.Float64s()) - assert.True(T.Shape().Eq(tmt.shape)) - backing[0] = 1000 - assert.Equal(backing, T.Float64s(), "test %d - unsafe float64", i) - } - } -} -` - -const compatArrowArrayTestsRaw = `var toArrowArrayTests = []struct{ - data interface{} - valid []bool - dt arrow.DataType - shape Shape -}{ - {{range .PrimitiveTypes -}} - { - data: Range({{.}}, 0, 6), - valid: []bool{true, true, true, false, true, true}, - dt: arrow.PrimitiveTypes.{{ . }}, - shape: Shape{6,1}, - }, - {{end -}} -} -func TestFromArrowArray(t *testing.T){ - assert := assert.New(t) - var T *Dense - pool := memory.NewGoAllocator() - - for i, taat := range toArrowArrayTests { - var m arrowArray.Interface - - switch taat.dt { - {{range .BinaryTypes -}} - case arrow.BinaryTypes.{{ . }}: - b := arrowArray.New{{ . }}Builder(pool) - defer b.Release() - b.AppendValues( - {{if eq . "String" -}} - []string{"0", "1", "2", "3", "4", "5"}, - {{else -}} - Range({{ . }}, 0, 6).([]{{lower . }}), - {{end -}} - taat.valid, - ) - m = b.NewArray() - defer m.Release() - {{end -}} - {{range .FixedWidthTypes -}} - case arrow.FixedWidthTypes.{{ . }}: - b := arrowArray.New{{ . }}Builder(pool) - defer b.Release() - b.AppendValues( - {{if eq . "Boolean" -}} - []bool{true, false, true, false, true, false}, - {{else -}} - Range({{ . }}, 0, 6).([]{{lower . }}), - {{end -}} - taat.valid, - ) - m = b.NewArray() - defer m.Release() - {{end -}} - {{range .PrimitiveTypes -}} - case arrow.PrimitiveTypes.{{ . }}: - b := arrowArray.New{{ . }}Builder(pool) - defer b.Release() - b.AppendValues( - Range({{ . }}, 0, 6).([]{{lower . }}), - taat.valid, - ) - m = b.NewArray() - defer m.Release() - {{end -}} - default: - t.Errorf("DataType not supported in tests: %v", taat.dt) - } - - T = FromArrowArray(m) - switch taat.dt { - {{range .PrimitiveTypes -}} - case arrow.PrimitiveTypes.{{ . }}: - conv := taat.data.([]{{lower . }}) - assert.Equal(conv, T.{{ . }}s(), "test %d: []{{lower . }} from %v", i, taat.dt) - {{end -}} - default: - t.Errorf("DataType not supported in tests: %v", taat.dt) - } - for i, invalid := range T.Mask() { - assert.Equal(taat.valid[i], !invalid) - } - assert.True(T.Shape().Eq(taat.shape)) - } -} -` - -const compatArrowTensorTestsRaw = `var toArrowTensorTests = []struct{ - rowMajorData interface{} - colMajorData interface{} - rowMajorValid []bool - colMajorValid []bool - dt arrow.DataType - shape Shape -}{ - {{range .PrimitiveTypes -}} - { - rowMajorData: []{{lower .}}{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, - colMajorData: []{{lower .}}{1, 6, 2, 7, 3, 8, 4, 9, 5, 10}, - rowMajorValid: []bool{true, false, true, false, true, false, true, false, true, false}, - colMajorValid: []bool{true, false, false, true, true, false, false, true, true, false}, - dt: arrow.PrimitiveTypes.{{ . }}, - shape: Shape{2,5}, - }, - {{end -}} -} -func TestFromArrowTensor(t *testing.T){ - assert := assert.New(t) - var rowMajorT *Dense - var colMajorT *Dense - pool := memory.NewGoAllocator() - - for i, taat := range toArrowTensorTests { - var rowMajorArr arrowArray.Interface - var colMajorArr arrowArray.Interface - var rowMajor arrowTensor.Interface - var colMajor arrowTensor.Interface - - switch taat.dt { - {{range .PrimitiveTypes -}} - case arrow.PrimitiveTypes.{{ . }}: - b := arrowArray.New{{ . }}Builder(pool) - defer b.Release() - b.AppendValues( - []{{lower . }}{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, - taat.rowMajorValid, - ) - rowMajorArr = b.NewArray() - defer rowMajorArr.Release() - - b.AppendValues( - []{{lower .}}{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, - taat.rowMajorValid, - ) - colMajorArr = b.NewArray() - defer colMajorArr.Release() - - rowMajor = arrowTensor.New{{.}}(rowMajorArr.Data(), []int64{2, 5}, nil, []string{"x", "y"}) - defer rowMajor.Release() - colMajor = arrowTensor.New{{.}}(colMajorArr.Data(), []int64{2, 5}, []int64{int64(arrow.{{ . }}SizeBytes), int64(arrow.{{ . }}SizeBytes * 2)}, []string{"x", "y"}) - defer colMajor.Release() - {{end -}} - default: - t.Errorf("DataType not supported in tests: %v", taat.dt) - } - - rowMajorT = FromArrowTensor(rowMajor) - colMajorT = FromArrowTensor(colMajor) - - assert.Equal(taat.rowMajorData, rowMajorT.Data(), "test %d: row major %v", i, taat.dt) - assert.Equal(len(taat.rowMajorValid), len(rowMajorT.Mask()), "test %d: row major %v mask length incorrect", i, taat.dt) - for i, invalid := range rowMajorT.Mask() { - assert.Equal(taat.rowMajorValid[i], !invalid, "test %d: row major %v mask value incorrect", i, taat.dt) - } - assert.True(colMajorT.Shape().Eq(taat.shape)) - - assert.Equal(taat.colMajorData, colMajorT.Data(), "test %d: column major %v", i, taat.dt) - assert.Equal(len(taat.colMajorValid), len(colMajorT.Mask()), "test %d: column major %v mask length incorrect", i, taat.dt) - for i, invalid := range colMajorT.Mask() { - assert.Equal(taat.colMajorValid[i], !invalid, "test %d: column major %v mask value incorrect", i, taat.dt) - } - assert.True(rowMajorT.Shape().Eq(taat.shape)) - } -} -` - -var ( - compatTests *template.Template - compatArrowArrayTests *template.Template - compatArrowTensorTests *template.Template -) - -func init() { - compatTests = template.Must(template.New("testCompat").Funcs(funcs).Parse(compatTestsRaw)) - compatArrowArrayTests = template.Must(template.New("testArrowArrayCompat").Funcs(funcs).Parse(compatArrowArrayTestsRaw)) - compatArrowTensorTests = template.Must(template.New("testArrowTensorCompat").Funcs(funcs).Parse(compatArrowTensorTestsRaw)) -} - -func generateDenseCompatTests(f io.Writer, generic Kinds) { - // NOTE(poopoothegorilla): an alias is needed for the Arrow Array pkg to prevent naming - // collisions - importsArrow.Execute(f, generic) - compatTests.Execute(f, generic) - arrowData := ArrowData{ - BinaryTypes: arrowBinaryTypes, - FixedWidthTypes: arrowFixedWidthTypes, - PrimitiveTypes: arrowPrimitiveTypes, - } - compatArrowArrayTests.Execute(f, arrowData) - compatArrowTensorTests.Execute(f, arrowData) -} +package main + +import ( + "io" + "text/template" +) + +const compatTestsRaw = `var toMat64Tests = []struct{ + data interface{} + sliced interface{} + shape Shape + dt dtype.Dtype +}{ + {{range .Kinds -}} + {{if isNumber . -}} + { Range({{asType . | title | strip}}, 0, 6), []{{asType .}}{0,1,3,4}, Shape{2,3}, {{asType . | title | strip}} }, + {{end -}} + {{end -}} +} +func TestToMat64(t *testing.T){ + assert := assert.New(t) + for i, tmt := range toMat64Tests { + T := New(WithBacking(tmt.data), WithShape(tmt.shape...)) + var m *mat.Dense + var err error + if m, err = ToMat64(T); err != nil { + t.Errorf("ToMat basic test %d failed : %v", i, err) + continue + } + conv := anyToFloat64s(tmt.data) + assert.Equal(conv, m.RawMatrix().Data, "i %d from %v", i, tmt.dt) + + if T, err = sliceDense(T, nil, makeRS(0, 2)); err != nil{ + t.Errorf("Slice failed %v", err) + continue + } + if m, err = ToMat64(T); err != nil { + t.Errorf("ToMat of slice test %d failed : %v", i, err) + continue + } + conv = anyToFloat64s(tmt.sliced) + assert.Equal(conv, m.RawMatrix().Data, "sliced test %d from %v", i, tmt.dt) + t.Logf("Done") + + if tmt.dt == Float64 { + T = New(WithBacking(tmt.data), WithShape(tmt.shape...)) + if m, err = ToMat64(T, UseUnsafe()); err != nil { + t.Errorf("ToMat64 unsafe test %d failed: %v", i, err) + } + conv = anyToFloat64s(tmt.data) + assert.Equal(conv, m.RawMatrix().Data, "float64 unsafe i %d from %v", i, tmt.dt) + conv[0] = 1000 + assert.Equal(conv, m.RawMatrix().Data,"float64 unsafe i %d from %v", i, tmt.dt) + conv[0] = 0 // reset for future tests that use the same backing + } + } + // idiocy test + T := New(Of(Float64), WithShape(2,3,4)) + _, err := ToMat64(T) + if err == nil { + t.Error("Expected an error when trying to convert a 3-T to *mat.Dense") + } +} + +func TestFromMat64(t *testing.T){ + assert := assert.New(t) + var m *mat.Dense + var T *Dense + var backing []float64 + + + for i, tmt := range toMat64Tests { + backing = Range(Float64, 0, 6).([]float64) + m = mat.NewDense(2, 3, backing) + T = FromMat64(m) + conv := anyToFloat64s(tmt.data) + assert.Equal(conv, T.Float64s(), "test %d: []float64 from %v", i, tmt.dt) + assert.True(T.Shape().Eq(tmt.shape)) + + T = FromMat64(m, As(tmt.dt)) + assert.Equal(tmt.data, T.Data()) + assert.True(T.Shape().Eq(tmt.shape)) + + if tmt.dt == Float64{ + backing = Range(Float64, 0, 6).([]float64) + m = mat.NewDense(2, 3, backing) + T = FromMat64(m, UseUnsafe()) + assert.Equal(backing, T.Float64s()) + assert.True(T.Shape().Eq(tmt.shape)) + backing[0] = 1000 + assert.Equal(backing, T.Float64s(), "test %d - unsafe float64", i) + } + } +} +` + +const compatArrowArrayTestsRaw = `var toArrowArrayTests = []struct{ + data interface{} + valid []bool + dt arrow.DataType + shape Shape +}{ + {{range .PrimitiveTypes -}} + { + data: Range({{.}}, 0, 6), + valid: []bool{true, true, true, false, true, true}, + dt: arrow.PrimitiveTypes.{{ . }}, + shape: Shape{6,1}, + }, + {{end -}} +} +func TestFromArrowArray(t *testing.T){ + assert := assert.New(t) + var T *Dense + pool := memory.NewGoAllocator() + + for i, taat := range toArrowArrayTests { + var m arrowArray.Interface + + switch taat.dt { + {{range .BinaryTypes -}} + case arrow.BinaryTypes.{{ . }}: + b := arrowArray.New{{ . }}Builder(pool) + defer b.Release() + b.AppendValues( + {{if eq . "String" -}} + []string{"0", "1", "2", "3", "4", "5"}, + {{else -}} + Range({{ . }}, 0, 6).([]{{lower . }}), + {{end -}} + taat.valid, + ) + m = b.NewArray() + defer m.Release() + {{end -}} + {{range .FixedWidthTypes -}} + case arrow.FixedWidthTypes.{{ . }}: + b := arrowArray.New{{ . }}Builder(pool) + defer b.Release() + b.AppendValues( + {{if eq . "Boolean" -}} + []bool{true, false, true, false, true, false}, + {{else -}} + Range({{ . }}, 0, 6).([]{{lower . }}), + {{end -}} + taat.valid, + ) + m = b.NewArray() + defer m.Release() + {{end -}} + {{range .PrimitiveTypes -}} + case arrow.PrimitiveTypes.{{ . }}: + b := arrowArray.New{{ . }}Builder(pool) + defer b.Release() + b.AppendValues( + Range({{ . }}, 0, 6).([]{{lower . }}), + taat.valid, + ) + m = b.NewArray() + defer m.Release() + {{end -}} + default: + t.Errorf("DataType not supported in tests: %v", taat.dt) + } + + T = FromArrowArray(m) + switch taat.dt { + {{range .PrimitiveTypes -}} + case arrow.PrimitiveTypes.{{ . }}: + conv := taat.data.([]{{lower . }}) + assert.Equal(conv, T.{{ . }}s(), "test %d: []{{lower . }} from %v", i, taat.dt) + {{end -}} + default: + t.Errorf("DataType not supported in tests: %v", taat.dt) + } + for i, invalid := range T.Mask() { + assert.Equal(taat.valid[i], !invalid) + } + assert.True(T.Shape().Eq(taat.shape)) + } +} +` + +const compatArrowTensorTestsRaw = `var toArrowTensorTests = []struct{ + rowMajorData interface{} + colMajorData interface{} + rowMajorValid []bool + colMajorValid []bool + dt arrow.DataType + shape Shape +}{ + {{range .PrimitiveTypes -}} + { + rowMajorData: []{{lower .}}{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + colMajorData: []{{lower .}}{1, 6, 2, 7, 3, 8, 4, 9, 5, 10}, + rowMajorValid: []bool{true, false, true, false, true, false, true, false, true, false}, + colMajorValid: []bool{true, false, false, true, true, false, false, true, true, false}, + dt: arrow.PrimitiveTypes.{{ . }}, + shape: Shape{2,5}, + }, + {{end -}} +} +func TestFromArrowTensor(t *testing.T){ + assert := assert.New(t) + var rowMajorT *Dense + var colMajorT *Dense + pool := memory.NewGoAllocator() + + for i, taat := range toArrowTensorTests { + var rowMajorArr arrowArray.Interface + var colMajorArr arrowArray.Interface + var rowMajor arrowTensor.Interface + var colMajor arrowTensor.Interface + + switch taat.dt { + {{range .PrimitiveTypes -}} + case arrow.PrimitiveTypes.{{ . }}: + b := arrowArray.New{{ . }}Builder(pool) + defer b.Release() + b.AppendValues( + []{{lower . }}{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + taat.rowMajorValid, + ) + rowMajorArr = b.NewArray() + defer rowMajorArr.Release() + + b.AppendValues( + []{{lower .}}{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + taat.rowMajorValid, + ) + colMajorArr = b.NewArray() + defer colMajorArr.Release() + + rowMajor = arrowTensor.New{{.}}(rowMajorArr.Data(), []int64{2, 5}, nil, []string{"x", "y"}) + defer rowMajor.Release() + colMajor = arrowTensor.New{{.}}(colMajorArr.Data(), []int64{2, 5}, []int64{int64(arrow.{{ . }}SizeBytes), int64(arrow.{{ . }}SizeBytes * 2)}, []string{"x", "y"}) + defer colMajor.Release() + {{end -}} + default: + t.Errorf("DataType not supported in tests: %v", taat.dt) + } + + rowMajorT = FromArrowTensor(rowMajor) + colMajorT = FromArrowTensor(colMajor) + + assert.Equal(taat.rowMajorData, rowMajorT.Data(), "test %d: row major %v", i, taat.dt) + assert.Equal(len(taat.rowMajorValid), len(rowMajorT.Mask()), "test %d: row major %v mask length incorrect", i, taat.dt) + for i, invalid := range rowMajorT.Mask() { + assert.Equal(taat.rowMajorValid[i], !invalid, "test %d: row major %v mask value incorrect", i, taat.dt) + } + assert.True(colMajorT.Shape().Eq(taat.shape)) + + assert.Equal(taat.colMajorData, colMajorT.Data(), "test %d: column major %v", i, taat.dt) + assert.Equal(len(taat.colMajorValid), len(colMajorT.Mask()), "test %d: column major %v mask length incorrect", i, taat.dt) + for i, invalid := range colMajorT.Mask() { + assert.Equal(taat.colMajorValid[i], !invalid, "test %d: column major %v mask value incorrect", i, taat.dt) + } + assert.True(rowMajorT.Shape().Eq(taat.shape)) + } +} +` + +var ( + compatTests *template.Template + compatArrowArrayTests *template.Template + compatArrowTensorTests *template.Template +) + +func init() { + compatTests = template.Must(template.New("testCompat").Funcs(funcs).Parse(compatTestsRaw)) + compatArrowArrayTests = template.Must(template.New("testArrowArrayCompat").Funcs(funcs).Parse(compatArrowArrayTestsRaw)) + compatArrowTensorTests = template.Must(template.New("testArrowTensorCompat").Funcs(funcs).Parse(compatArrowTensorTestsRaw)) +} + +func generateDenseCompatTests(f io.Writer, generic Kinds) { + // NOTE(poopoothegorilla): an alias is needed for the Arrow Array pkg to prevent naming + // collisions + importsArrow.Execute(f, generic) + compatTests.Execute(f, generic) + arrowData := ArrowData{ + BinaryTypes: arrowBinaryTypes, + FixedWidthTypes: arrowFixedWidthTypes, + PrimitiveTypes: arrowPrimitiveTypes, + } + compatArrowArrayTests.Execute(f, arrowData) + compatArrowTensorTests.Execute(f, arrowData) +} diff --git a/genlib2/dense_cons.go b/genlib2/dense_cons.go index fee0df5..aa6bab8 100644 --- a/genlib2/dense_cons.go +++ b/genlib2/dense_cons.go @@ -6,7 +6,7 @@ import ( ) const onesRaw = `// Ones creates a *Dense with the provided shape and type -func Ones(dt Dtype, shape ...int) *Dense { +func Ones(dt dtype.Dtype, shape ...int) *Dense { d := recycledDense(dt, shape) switch d.t.Kind() { {{range .Kinds -}} @@ -48,7 +48,7 @@ const Iraw = `// I creates the identity matrix (usually a square) matrix with 1s // ⎢1 0 0 0⎥ // ⎢0 1 0 0⎥ // ⎣0 0 1 0⎦ -func I(dt Dtype, r, c, k int) *Dense{ +func I(dt dtype.Dtype, r, c, k int) *Dense{ ret := New(Of(dt), WithShape(r,c)) i := k if k < 0 { diff --git a/genlib2/dense_cons_tests.go b/genlib2/dense_cons_tests.go index 938d6fa..29d1366 100644 --- a/genlib2/dense_cons_tests.go +++ b/genlib2/dense_cons_tests.go @@ -1,85 +1,85 @@ -package main - -import ( - "io" - "text/template" -) - -const onesTestsRaw = `var onesTests = []struct { - of Dtype - shape Shape - correct interface{} -}{ - {{range .Kinds -}} - {{if isNumber . -}} - { {{asType . | title | strip}}, ScalarShape(), {{asType .}}(1)}, - { {{asType . | title | strip}}, Shape{2,2}, []{{asType .}}{1,1,1,1}}, - {{end -}} - {{end -}} - {Bool, ScalarShape(), true}, - {Bool, Shape{2,2}, []bool{true, true, true, true}}, -} - -func TestOnes(t *testing.T){ - assert := assert.New(t) - for _, ot := range onesTests{ - T := Ones(ot.of, ot.shape...) - assert.True(ot.shape.Eq(T.Shape())) - assert.Equal(ot.correct, T.Data()) - } -} -` - -const eyeTestsRaw = `// yes, it's a pun on eye tests, stop asking and go see your optometrist -var eyeTests = []struct{ - E Dtype - R, C, K int - - - correct interface{} -}{ - {{range .Kinds -}} - {{if isNumber . -}} - { {{asType . | title | strip}}, 4,4, 0, []{{asType .}}{1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1}}, - { {{asType . | title | strip}}, 4,4, 1, []{{asType .}}{0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0}}, - { {{asType . | title | strip}}, 4,4, 2, []{{asType .}}{0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0}}, - { {{asType . | title | strip}}, 4,4, 3, []{{asType .}}{0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}}, - { {{asType . | title | strip}}, 4,4, 4, []{{asType .}}{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}}, - { {{asType . | title | strip}}, 4,4, -1, []{{asType .}}{0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0}}, - { {{asType . | title | strip}}, 4,4, -2, []{{asType .}}{0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0}}, - { {{asType . | title | strip}}, 4,4, -3, []{{asType .}}{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0}}, - { {{asType . | title | strip}}, 4,4, -4, []{{asType .}}{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}}, - { {{asType . | title | strip}}, 4,5, 0, []{{asType .}}{1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0}}, - { {{asType . | title | strip}}, 4,5, 1, []{{asType .}}{0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1}}, - { {{asType . | title | strip}}, 4,5, -1, []{{asType .}}{0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0}}, - {{end -}} - {{end -}} -} - -func TestI(t *testing.T){ - assert := assert.New(t) - var T Tensor - - for i, it := range eyeTests { - T = I(it.E, it.R, it.C, it.K) - assert.True(Shape{it.R, it.C}.Eq(T.Shape())) - assert.Equal(it.correct, T.Data(), "Test %d-R: %d, C: %d K: %d", i, it.R, it.C, it.K) - } - -} -` - -var ( - onesTests *template.Template - eyeTests *template.Template -) - -func init() { - onesTests = template.Must(template.New("onesTest").Funcs(funcs).Parse(onesTestsRaw)) - eyeTests = template.Must(template.New("eyeTest").Funcs(funcs).Parse(eyeTestsRaw)) -} - -func generateDenseConsTests(f io.Writer, generic Kinds) { - onesTests.Execute(f, generic) - eyeTests.Execute(f, generic) -} +package main + +import ( + "io" + "text/template" +) + +const onesTestsRaw = `var onesTests = []struct { + of dtype.Dtype + shape Shape + correct interface{} +}{ + {{range .Kinds -}} + {{if isNumber . -}} + { {{asType . | title | strip}}, ScalarShape(), {{asType .}}(1)}, + { {{asType . | title | strip}}, Shape{2,2}, []{{asType .}}{1,1,1,1}}, + {{end -}} + {{end -}} + {Bool, ScalarShape(), true}, + {Bool, Shape{2,2}, []bool{true, true, true, true}}, +} + +func TestOnes(t *testing.T){ + assert := assert.New(t) + for _, ot := range onesTests{ + T := Ones(ot.of, ot.shape...) + assert.True(ot.shape.Eq(T.Shape())) + assert.Equal(ot.correct, T.Data()) + } +} +` + +const eyeTestsRaw = `// yes, it's a pun on eye tests, stop asking and go see your optometrist +var eyeTests = []struct{ + E dtype.Dtype + R, C, K int + + + correct interface{} +}{ + {{range .Kinds -}} + {{if isNumber . -}} + { {{asType . | title | strip}}, 4,4, 0, []{{asType .}}{1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1}}, + { {{asType . | title | strip}}, 4,4, 1, []{{asType .}}{0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0}}, + { {{asType . | title | strip}}, 4,4, 2, []{{asType .}}{0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0}}, + { {{asType . | title | strip}}, 4,4, 3, []{{asType .}}{0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}}, + { {{asType . | title | strip}}, 4,4, 4, []{{asType .}}{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}}, + { {{asType . | title | strip}}, 4,4, -1, []{{asType .}}{0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0}}, + { {{asType . | title | strip}}, 4,4, -2, []{{asType .}}{0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0}}, + { {{asType . | title | strip}}, 4,4, -3, []{{asType .}}{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0}}, + { {{asType . | title | strip}}, 4,4, -4, []{{asType .}}{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}}, + { {{asType . | title | strip}}, 4,5, 0, []{{asType .}}{1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0}}, + { {{asType . | title | strip}}, 4,5, 1, []{{asType .}}{0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1}}, + { {{asType . | title | strip}}, 4,5, -1, []{{asType .}}{0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0}}, + {{end -}} + {{end -}} +} + +func TestI(t *testing.T){ + assert := assert.New(t) + var T Tensor + + for i, it := range eyeTests { + T = I(it.E, it.R, it.C, it.K) + assert.True(Shape{it.R, it.C}.Eq(T.Shape())) + assert.Equal(it.correct, T.Data(), "Test %d-R: %d, C: %d K: %d", i, it.R, it.C, it.K) + } + +} +` + +var ( + onesTests *template.Template + eyeTests *template.Template +) + +func init() { + onesTests = template.Must(template.New("onesTest").Funcs(funcs).Parse(onesTestsRaw)) + eyeTests = template.Must(template.New("eyeTest").Funcs(funcs).Parse(eyeTestsRaw)) +} + +func generateDenseConsTests(f io.Writer, generic Kinds) { + onesTests.Execute(f, generic) + eyeTests.Execute(f, generic) +} diff --git a/genlib2/dense_getset_tests.go b/genlib2/dense_getset_tests.go index 15cc820..50bafb3 100644 --- a/genlib2/dense_getset_tests.go +++ b/genlib2/dense_getset_tests.go @@ -102,8 +102,8 @@ func makeZeroTests(generic Kinds) []testData { } const getTestRaw = `var denseSetGetTests = []struct { - of Dtype - data interface{} + of dtype.Dtype + data interface{} set interface{} correct []interface{} @@ -129,7 +129,7 @@ func TestDense_setget(t *testing.T) { ` const memsetTestRaw = `var denseMemsetTests = []struct{ - of Dtype + of dtype.Dtype data interface{} val interface{} shape Shape @@ -139,7 +139,7 @@ const memsetTestRaw = `var denseMemsetTests = []struct{ {{range . -}} {{$val := .Set -}} {{$k := .Kind -}} - { {{title .Kind.String | strip}}, []{{asType .Kind}}{ {{range .TestData0 -}}{{printf "%v" .}}, {{end -}} }, {{asType .Kind}}({{$val}}), Shape{2,3}, []{{asType .Kind}}{ {{range .Correct}} {{printf "%v" .}}, {{end -}} } }, + { {{title .Kind.String | strip}}, []{{asType .Kind}}{ {{range .TestData0 -}}{{printf "%v" .}}, {{end -}} }, {{asType .Kind}}({{$val}}), Shape{2,3}, []{{asType .Kind}}{ {{range .Correct}} {{printf "%v" .}}, {{end -}} } }, {{end -}} } @@ -159,7 +159,7 @@ func TestDense_memset(t *testing.T){ ` const zeroTestRaw = `var denseZeroTests = []struct{ - of Dtype + of dtype.Dtype data interface{} correct interface{} @@ -167,18 +167,18 @@ const zeroTestRaw = `var denseZeroTests = []struct{ {{range . -}} {{$val := .Set -}} {{$k := .Kind -}} - { {{title .Kind.String | strip}}, []{{asType .Kind}}{ {{range .TestData0 -}}{{printf "%v" .}}, {{end -}} }, []{{asType .Kind}}{ {{range .Correct}} {{printf "%v" .}}, {{end -}} } }, + { {{title .Kind.String | strip}}, []{{asType .Kind}}{ {{range .TestData0 -}}{{printf "%v" .}}, {{end -}} }, []{{asType .Kind}}{ {{range .Correct}} {{printf "%v" .}}, {{end -}} } }, {{end -}} } func TestDense_Zero(t *testing.T) { assert := assert.New(t) for _, mts := range denseZeroTests { - + typ := reflect.TypeOf(mts.data) val := reflect.ValueOf(mts.data) data := reflect.MakeSlice(typ, val.Len(), val.Cap()) - reflect.Copy(data, val) + reflect.Copy(data, val) T := New(Of(mts.of), WithBacking(data.Interface())) T.Zero() @@ -188,7 +188,7 @@ func TestDense_Zero(t *testing.T) { T2, _ := T.Slice(nil) T2.Zero() assert.Equal(mts.correct, T2.Data()) - } + } } ` diff --git a/genlib2/dense_io.go b/genlib2/dense_io.go index 0fe010a..9b971ac 100644 --- a/genlib2/dense_io.go +++ b/genlib2/dense_io.go @@ -63,7 +63,7 @@ func (r *binaryReader) Err() error { // If tensor is masked, invalid values are replaced by the default fill value. func (t *Dense) WriteNpy(w io.Writer) (err error) { var npdt string - if npdt, err = t.t.numpyDtype(); err != nil{ + if npdt, err = t.t.NumpyDtype(); err != nil{ return } @@ -290,7 +290,7 @@ func (t *Dense) ReadNpy(r io.Reader) (err error){ } // TODO: check for endianness. For now we assume everything is little endian - if t.t, err = fromNumpyDtype(string(match[1][1:])); err != nil { + if t.t, err = dtype.FromNumpyDtype(string(match[1][1:])); err != nil { return } @@ -348,7 +348,7 @@ func (t *Dense) ReadNpy(r io.Reader) (err error){ const readCSVRaw = `// convFromStrs converts a []string to a slice of the Dtype provided. It takes a provided backing slice. // If into is nil, then a backing slice will be created. -func convFromStrs(to Dtype, record []string, into interface{}) (interface{}, error) { +func convFromStrs(to dtype.Dtype, record []string, into interface{}) (interface{}, error) { var err error switch to.Kind() { {{range .Kinds -}} @@ -545,12 +545,11 @@ func (t *Dense) FBDecode(buf []byte) error { t.strides[i] = int(serialized.Strides(i)) } typ := string(serialized.Type()) - for _, dt := range allTypes.set { - if dt.String() == typ { - t.t = dt - break - } + dt, err := dtype.FindByName(typ) + if err != nil { + return errors.Wrap(err, "Failed to decode FlatBuffers") } + t.t = dt if t.e == nil { t.e = StdEng{} @@ -621,12 +620,11 @@ func (t *Dense) PBDecode(buf []byte) error { } t.Δ = Triangle(toSerialize.T) typ := string(toSerialize.Type) - for _, dt := range allTypes.set { - if dt.String() == typ { - t.t = dt - break - } + dt, err := dtype.FindByName(typ) + if err != nil { + return errors.Wrap(err, "Failed to decode ProtoBuf") } + t.t = dt if t.e == nil { t.e = StdEng{} diff --git a/genlib2/dense_maskedmethods.go b/genlib2/dense_maskedmethods.go index ce1133c..644e37a 100644 --- a/genlib2/dense_maskedmethods.go +++ b/genlib2/dense_maskedmethods.go @@ -1,103 +1,103 @@ -package main - -import ( - "fmt" - "io" - "reflect" - "text/template" -) - -var maskcmpMethods = []struct { - Name string - Desc string - NumArgs int - CmpFn string - ReqFloat bool - Kinds []reflect.Kind -}{ - {"MaskedEqual", "equal to ", 1, "a == x", false, nil}, - {"MaskedNotEqual", "not equal to ", 1, "a != x", false, nil}, - {"MaskedValues", " equal to ", 3, "math.Abs(float64(a-x)) <= delta", true, nil}, - {"MaskedGreater", " greater than ", 1, "a > x", false, nil}, - {"MaskedGreaterEqual", " greater than or equal to ", 1, "a >= x", false, nil}, - {"MaskedLess", " less than ", 1, "a < x", false, nil}, - {"MaskedLessEqual", " less than or equal to ", 1, "a <= x", false, nil}, - {"MaskedInside", " inside range of ", 2, "(a >= x) && (a <= y)", false, nil}, - {"MaskedOutside", " outside range of ", 2, "(a < x) || (a > y)", false, nil}, -} - -const maskCmpMethodRaw = `// {{.Name}} sets the mask to true where the corresponding data is {{.Desc}} val -// Any values must be the same type as the tensor -func (t *Dense) {{.Name}}({{if ge .NumArgs 1 -}} val1 interface{} {{end}} {{if ge .NumArgs 2 -}} , val2 interface{} {{end}} {{if ge .NumArgs 3 -}} , val3 ...interface{}{{end}})(err error){ - {{if .ReqFloat}} - if !isFloat(t.t) { - err = errors.Errorf("Can only do {{.Name}} with floating point types") - return - } - {{end}} - - if !t.IsMasked() { - t.makeMask() - } - - {{$numargs := .NumArgs}} - {{$name := .Name}} - {{$fn := .CmpFn}} - {{$reqFloat := .ReqFloat}} - switch t.t.Kind(){ - {{range .Kinds -}} - {{if isParameterized . -}} - {{else -}} - {{if or (not (isOrd .)) (and $reqFloat (isntFloat .)) -}} - {{else -}} - case reflect.{{reflectKind .}}: - data := t.{{sliceOf .}} - mask := t.mask - {{if ge $numargs 1 -}} x := val1.({{asType .}}) {{end}} - {{if ge $numargs 2 -}} y := val2.({{asType .}}){{end}} - {{if ge $numargs 3 -}} - {{if eq $name "MaskedValues"}} - delta := float64(1.0e-8) - if len(val3) > 0 { - delta = float64(val3[0].({{asType .}})) + float64(y)*math.Abs(float64(x)) - } - {{else}} - z := val3.({{asType .}}) - {{end}} - {{end}} - if t.maskIsSoft{ - for i := range data { - a := data[i] - mask[i] = ({{$fn}}) - } - } else { - for i := range data { - a := data[i] - mask[i] = mask[i] || ({{$fn}}) - } - } - - {{end}} - {{end}} - {{end}} -} -return nil -} -` - -var ( - maskCmpMethod *template.Template -) - -func init() { - maskCmpMethod = template.Must(template.New("maskcmpmethod").Funcs(funcs).Parse(maskCmpMethodRaw)) -} - -func generateDenseMaskedMethods(f io.Writer, generic Kinds) { - for _, mm := range maskcmpMethods { - mm.Kinds = generic.Kinds - fmt.Fprintf(f, "/* %s */ \n\n", mm.Name) - maskCmpMethod.Execute(f, mm) - - } -} +package main + +import ( + "fmt" + "io" + "reflect" + "text/template" +) + +var maskcmpMethods = []struct { + Name string + Desc string + NumArgs int + CmpFn string + ReqFloat bool + Kinds []reflect.Kind +}{ + {"MaskedEqual", "equal to ", 1, "a == x", false, nil}, + {"MaskedNotEqual", "not equal to ", 1, "a != x", false, nil}, + {"MaskedValues", " equal to ", 3, "math.Abs(float64(a-x)) <= delta", true, nil}, + {"MaskedGreater", " greater than ", 1, "a > x", false, nil}, + {"MaskedGreaterEqual", " greater than or equal to ", 1, "a >= x", false, nil}, + {"MaskedLess", " less than ", 1, "a < x", false, nil}, + {"MaskedLessEqual", " less than or equal to ", 1, "a <= x", false, nil}, + {"MaskedInside", " inside range of ", 2, "(a >= x) && (a <= y)", false, nil}, + {"MaskedOutside", " outside range of ", 2, "(a < x) || (a > y)", false, nil}, +} + +const maskCmpMethodRaw = `// {{.Name}} sets the mask to true where the corresponding data is {{.Desc}} val +// Any values must be the same type as the tensor +func (t *Dense) {{.Name}}({{if ge .NumArgs 1 -}} val1 interface{} {{end}} {{if ge .NumArgs 2 -}} , val2 interface{} {{end}} {{if ge .NumArgs 3 -}} , val3 ...interface{}{{end}})(err error){ + {{if .ReqFloat}} + if !isFloat(t.t) { + err = errors.Errorf("Can only do {{.Name}} with floating point types") + return + } + {{end}} + + if !t.IsMasked() { + t.makeMask() + } + + {{$numargs := .NumArgs}} + {{$name := .Name}} + {{$fn := .CmpFn}} + {{$reqFloat := .ReqFloat}} + switch t.t.Kind(){ + {{range .Kinds -}} + {{if isParameterized . -}} + {{else -}} + {{if or (not (isOrd .)) (and $reqFloat (isntFloat .)) -}} + {{else -}} + case reflect.{{reflectKind .}}: + data := t.{{sliceOf .}} + mask := t.mask + {{if ge $numargs 1 -}} x := val1.({{asType .}}) {{end}} + {{if ge $numargs 2 -}} y := val2.({{asType .}}){{end}} + {{if ge $numargs 3 -}} + {{if eq $name "MaskedValues"}} + delta := float64(1.0e-8) + if len(val3) > 0 { + delta = float64(val3[0].({{asType .}})) + float64(y)*math.Abs(float64(x)) + } + {{else}} + z := val3.({{asType .}}) + {{end}} + {{end}} + if t.maskIsSoft{ + for i := range data { + a := data[i] + mask[i] = ({{$fn}}) + } + } else { + for i := range data { + a := data[i] + mask[i] = mask[i] || ({{$fn}}) + } + } + + {{end}} + {{end}} + {{end}} +} +return nil +} +` + +var ( + maskCmpMethod *template.Template +) + +func init() { + maskCmpMethod = template.Must(template.New("maskcmpmethod").Funcs(funcs).Parse(maskCmpMethodRaw)) +} + +func generateDenseMaskedMethods(f io.Writer, generic Kinds) { + for _, mm := range maskcmpMethods { + mm.Kinds = generic.Kinds + fmt.Fprintf(f, "/* %s */ \n\n", mm.Name) + maskCmpMethod.Execute(f, mm) + + } +} diff --git a/genlib2/dense_reduction_methods_tests.go b/genlib2/dense_reduction_methods_tests.go index 30defc9..342b1d4 100644 --- a/genlib2/dense_reduction_methods_tests.go +++ b/genlib2/dense_reduction_methods_tests.go @@ -1,164 +1,164 @@ -package main - -import ( - "fmt" - "io" - "text/template" -) - -const testDenseSumRaw = `var sumTests = []struct { - name string - of Dtype - shape Shape - along []int - - correctShape Shape - correct interface{} -}{ - {{range .Kinds -}} - {{if isNumber . -}} - {"common case: T.Sum() for {{.}}", {{asType . | title}}, Shape{2,3}, []int{}, ScalarShape(), {{asType .}}(15)}, - {"A.Sum(0) for {{.}}", {{asType . | title}}, Shape{2,3}, []int{0}, Shape{3}, []{{asType .}}{3, 5, 7}}, - {"A.Sum(1) for {{.}}", {{asType . | title}}, Shape{2,3},[]int{1}, Shape{2}, []{{asType .}}{3, 12}}, - {"A.Sum(0,1) for {{.}}", {{asType . | title}}, Shape{2,3},[]int{0, 1}, ScalarShape(), {{asType .}}(15)}, - {"A.Sum(1,0) for {{.}}", {{asType . | title}}, Shape{2,3},[]int{1, 0}, ScalarShape(), {{asType .}}(15)}, - {"3T.Sum(1,2) for {{.}}", {{asType . | title}}, Shape{2,3,4}, []int{1,2}, Shape{2}, []{{asType .}}{66, {{if eq .String "int8"}}-46{{else}}210{{end}} }}, - {"4T.Sum() for {{.}}", {{asType . | title}}, Shape{2, 2, 2, 2},[]int{}, ScalarShape(), {{asType .}}(120)}, - {"4T.Sum(1,3) for {{.}}", {{asType . | title}}, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []{{asType .}}{10, 18, 42, 50}}, - {"4T.Sum(0, 2, 3) for {{.}}", {{asType . | title}}, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []{{asType .}}{44, 76}}, - {{end -}} - {{end -}} -} -func TestDense_Sum(t *testing.T){ - assert := assert.New(t) - var T, T2 *Dense - var err error - - for _, sts := range sumTests { - T = New(WithShape(sts.shape...), WithBacking(Range(sts.of, 0, sts.shape.TotalSize()))) - if T2, err = T.Sum(sts.along ...); err != nil { - t.Error(err) - continue - } - assert.True(sts.correctShape.Eq(T2.Shape())) - assert.Equal(sts.correct, T2.Data()) - } - - // idiots - _,err =T.Sum(1000) - assert.NotNil(err) -} -` - -const testDenseMaxRaw = `var maxTests = []struct { - name string - of Dtype - shape Shape - along []int - - correctShape Shape - correct interface{} -}{ - {{range .Kinds -}} - {{if isNumber . -}} - {{if isOrd . -}} - {"common case: T.Max() for {{.}}", {{asType . | title}}, Shape{2,3}, []int{}, ScalarShape(), {{asType .}}(5)}, - {"A.Max(0)", {{asType . | title}}, Shape{2,3},[]int{0}, Shape{3}, []{{asType . }}{3, 4, 5}}, - {"A.Max(1)", {{asType . | title}}, Shape{2,3},[]int{1}, Shape{2}, []{{asType . }}{2,5}}, - {"A.Max(0,1)", {{asType . | title}}, Shape{2,3},[]int{0, 1}, ScalarShape(), {{asType .}}(5)}, - {"A.Max(1,0)", {{asType . | title}}, Shape{2,3},[]int{1, 0}, ScalarShape(), {{asType .}}(5)}, - {"3T.Max(1,2)", {{asType . | title}}, Shape{2,3,4}, []int{1,2}, Shape{2}, []{{asType .}}{11, 23} }, - {"4T.Max()", {{asType . | title}}, Shape{2, 2, 2, 2},[]int{}, ScalarShape(), {{asType .}}(15)}, - {"4T.Max(1,3)", {{asType . | title}}, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []{{asType .}}{5, 7, 13, 15}}, - {"4T.Max(0, 2, 3)", {{asType . | title}}, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []{{asType .}}{11, 15}}, - {{end -}} - {{end -}} - {{end -}} -} - -func TestDense_Max(t *testing.T){ - assert := assert.New(t) - var T, T2 *Dense - var err error - - for _, mts := range maxTests { - T = New(WithShape(mts.shape...), WithBacking(Range(mts.of, 0, mts.shape.TotalSize()))) - if T2, err = T.Max(mts.along...); err != nil{ - t.Error(err) - continue - } - assert.True(mts.correctShape.Eq(T2.Shape())) - assert.Equal(mts.correct, T2.Data()) - } - /* IDIOT TESTING TIME */ - _, err = T.Max(1000) - assert.NotNil(err) -} -` - -const testDenseMinRaw = `var minTests = []struct { - name string - of Dtype - shape Shape - along []int - - correctShape Shape - correct interface{} -}{ - {{range .Kinds -}} - {{if isNumber . -}} - {{if isOrd . -}} - {"common case: T.Min() for {{.}}", {{asType .|title}}, Shape{2,3}, []int{}, ScalarShape(), {{asType .}}(0)}, - {"A.Min(0)", {{asType .|title}}, Shape{2,3}, []int{0}, Shape{3}, []{{asType .}}{0, 1, 2}}, - {"A.Min(1)", {{asType .|title}}, Shape{2,3}, []int{1}, Shape{2}, []{{asType .}}{0, 3}}, - {"A.Min(0,1)", {{asType .|title}}, Shape{2,3}, []int{0, 1}, ScalarShape(), {{asType .}}(0)}, - {"A.Min(1,0)", {{asType .|title}}, Shape{2,3}, []int{1, 0}, ScalarShape(), {{asType .}}(0)}, - {"3T.Min(1,2)", {{asType . | title}}, Shape{2,3,4}, []int{1,2}, Shape{2}, []{{asType .}}{0,12} }, - {"4T.Min()", {{asType . | title}}, Shape{2, 2, 2, 2},[]int{}, ScalarShape(), {{asType .}}(0)}, - {"4T.Min(1,3)", {{asType . | title}}, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []{{asType .}}{0, 2, 8, 10}}, - {"4T.Min(0, 2, 3)", {{asType . | title}}, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []{{asType .}}{0, 4}}, - {{end -}} - {{end -}} - {{end -}} -} - -func TestDense_Min(t *testing.T){ - assert := assert.New(t) - var T, T2 *Dense - var err error - - for _, mts := range minTests { - T = New(WithShape(mts.shape...), WithBacking(Range(mts.of, 0, mts.shape.TotalSize()))) - if T2, err = T.Min(mts.along...); err != nil{ - t.Error(err) - continue - } - assert.True(mts.correctShape.Eq(T2.Shape())) - assert.Equal(mts.correct, T2.Data()) - } - - /* IDIOT TESTING TIME */ - _, err = T.Min(1000) - assert.NotNil(err) -} -` - -var ( - testDenseSum *template.Template - testDenseMax *template.Template - testDenseMin *template.Template -) - -func init() { - testDenseSum = template.Must(template.New("testDenseSum").Funcs(funcs).Parse(testDenseSumRaw)) - testDenseMax = template.Must(template.New("testDenseMax").Funcs(funcs).Parse(testDenseMaxRaw)) - testDenseMin = template.Must(template.New("testDenseMin").Funcs(funcs).Parse(testDenseMinRaw)) -} - -func generateDenseReductionMethodsTests(f io.Writer, generic Kinds) { - testDenseSum.Execute(f, generic) - fmt.Fprint(f, "\n") - testDenseMax.Execute(f, generic) - fmt.Fprint(f, "\n") - testDenseMin.Execute(f, generic) -} +package main + +import ( + "fmt" + "io" + "text/template" +) + +const testDenseSumRaw = `var sumTests = []struct { + name string + of dtype.Dtype + shape Shape + along []int + + correctShape Shape + correct interface{} +}{ + {{range .Kinds -}} + {{if isNumber . -}} + {"common case: T.Sum() for {{.}}", {{asType . | title}}, Shape{2,3}, []int{}, ScalarShape(), {{asType .}}(15)}, + {"A.Sum(0) for {{.}}", {{asType . | title}}, Shape{2,3}, []int{0}, Shape{3}, []{{asType .}}{3, 5, 7}}, + {"A.Sum(1) for {{.}}", {{asType . | title}}, Shape{2,3},[]int{1}, Shape{2}, []{{asType .}}{3, 12}}, + {"A.Sum(0,1) for {{.}}", {{asType . | title}}, Shape{2,3},[]int{0, 1}, ScalarShape(), {{asType .}}(15)}, + {"A.Sum(1,0) for {{.}}", {{asType . | title}}, Shape{2,3},[]int{1, 0}, ScalarShape(), {{asType .}}(15)}, + {"3T.Sum(1,2) for {{.}}", {{asType . | title}}, Shape{2,3,4}, []int{1,2}, Shape{2}, []{{asType .}}{66, {{if eq .String "int8"}}-46{{else}}210{{end}} }}, + {"4T.Sum() for {{.}}", {{asType . | title}}, Shape{2, 2, 2, 2},[]int{}, ScalarShape(), {{asType .}}(120)}, + {"4T.Sum(1,3) for {{.}}", {{asType . | title}}, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []{{asType .}}{10, 18, 42, 50}}, + {"4T.Sum(0, 2, 3) for {{.}}", {{asType . | title}}, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []{{asType .}}{44, 76}}, + {{end -}} + {{end -}} +} +func TestDense_Sum(t *testing.T){ + assert := assert.New(t) + var T, T2 *Dense + var err error + + for _, sts := range sumTests { + T = New(WithShape(sts.shape...), WithBacking(Range(sts.of, 0, sts.shape.TotalSize()))) + if T2, err = T.Sum(sts.along ...); err != nil { + t.Error(err) + continue + } + assert.True(sts.correctShape.Eq(T2.Shape())) + assert.Equal(sts.correct, T2.Data()) + } + + // idiots + _,err =T.Sum(1000) + assert.NotNil(err) +} +` + +const testDenseMaxRaw = `var maxTests = []struct { + name string + of dtype.Dtype + shape Shape + along []int + + correctShape Shape + correct interface{} +}{ + {{range .Kinds -}} + {{if isNumber . -}} + {{if isOrd . -}} + {"common case: T.Max() for {{.}}", {{asType . | title}}, Shape{2,3}, []int{}, ScalarShape(), {{asType .}}(5)}, + {"A.Max(0)", {{asType . | title}}, Shape{2,3},[]int{0}, Shape{3}, []{{asType . }}{3, 4, 5}}, + {"A.Max(1)", {{asType . | title}}, Shape{2,3},[]int{1}, Shape{2}, []{{asType . }}{2,5}}, + {"A.Max(0,1)", {{asType . | title}}, Shape{2,3},[]int{0, 1}, ScalarShape(), {{asType .}}(5)}, + {"A.Max(1,0)", {{asType . | title}}, Shape{2,3},[]int{1, 0}, ScalarShape(), {{asType .}}(5)}, + {"3T.Max(1,2)", {{asType . | title}}, Shape{2,3,4}, []int{1,2}, Shape{2}, []{{asType .}}{11, 23} }, + {"4T.Max()", {{asType . | title}}, Shape{2, 2, 2, 2},[]int{}, ScalarShape(), {{asType .}}(15)}, + {"4T.Max(1,3)", {{asType . | title}}, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []{{asType .}}{5, 7, 13, 15}}, + {"4T.Max(0, 2, 3)", {{asType . | title}}, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []{{asType .}}{11, 15}}, + {{end -}} + {{end -}} + {{end -}} +} + +func TestDense_Max(t *testing.T){ + assert := assert.New(t) + var T, T2 *Dense + var err error + + for _, mts := range maxTests { + T = New(WithShape(mts.shape...), WithBacking(Range(mts.of, 0, mts.shape.TotalSize()))) + if T2, err = T.Max(mts.along...); err != nil{ + t.Error(err) + continue + } + assert.True(mts.correctShape.Eq(T2.Shape())) + assert.Equal(mts.correct, T2.Data()) + } + /* IDIOT TESTING TIME */ + _, err = T.Max(1000) + assert.NotNil(err) +} +` + +const testDenseMinRaw = `var minTests = []struct { + name string + of dtype.Dtype + shape Shape + along []int + + correctShape Shape + correct interface{} +}{ + {{range .Kinds -}} + {{if isNumber . -}} + {{if isOrd . -}} + {"common case: T.Min() for {{.}}", {{asType .|title}}, Shape{2,3}, []int{}, ScalarShape(), {{asType .}}(0)}, + {"A.Min(0)", {{asType .|title}}, Shape{2,3}, []int{0}, Shape{3}, []{{asType .}}{0, 1, 2}}, + {"A.Min(1)", {{asType .|title}}, Shape{2,3}, []int{1}, Shape{2}, []{{asType .}}{0, 3}}, + {"A.Min(0,1)", {{asType .|title}}, Shape{2,3}, []int{0, 1}, ScalarShape(), {{asType .}}(0)}, + {"A.Min(1,0)", {{asType .|title}}, Shape{2,3}, []int{1, 0}, ScalarShape(), {{asType .}}(0)}, + {"3T.Min(1,2)", {{asType . | title}}, Shape{2,3,4}, []int{1,2}, Shape{2}, []{{asType .}}{0,12} }, + {"4T.Min()", {{asType . | title}}, Shape{2, 2, 2, 2},[]int{}, ScalarShape(), {{asType .}}(0)}, + {"4T.Min(1,3)", {{asType . | title}}, Shape{2, 2, 2, 2}, []int{1, 3}, Shape{2, 2}, []{{asType .}}{0, 2, 8, 10}}, + {"4T.Min(0, 2, 3)", {{asType . | title}}, Shape{2, 2, 2, 2}, []int{0, 2, 3}, Shape{2}, []{{asType .}}{0, 4}}, + {{end -}} + {{end -}} + {{end -}} +} + +func TestDense_Min(t *testing.T){ + assert := assert.New(t) + var T, T2 *Dense + var err error + + for _, mts := range minTests { + T = New(WithShape(mts.shape...), WithBacking(Range(mts.of, 0, mts.shape.TotalSize()))) + if T2, err = T.Min(mts.along...); err != nil{ + t.Error(err) + continue + } + assert.True(mts.correctShape.Eq(T2.Shape())) + assert.Equal(mts.correct, T2.Data()) + } + + /* IDIOT TESTING TIME */ + _, err = T.Min(1000) + assert.NotNil(err) +} +` + +var ( + testDenseSum *template.Template + testDenseMax *template.Template + testDenseMin *template.Template +) + +func init() { + testDenseSum = template.Must(template.New("testDenseSum").Funcs(funcs).Parse(testDenseSumRaw)) + testDenseMax = template.Must(template.New("testDenseMax").Funcs(funcs).Parse(testDenseMaxRaw)) + testDenseMin = template.Must(template.New("testDenseMin").Funcs(funcs).Parse(testDenseMinRaw)) +} + +func generateDenseReductionMethodsTests(f io.Writer, generic Kinds) { + testDenseSum.Execute(f, generic) + fmt.Fprint(f, "\n") + testDenseMax.Execute(f, generic) + fmt.Fprint(f, "\n") + testDenseMin.Execute(f, generic) +} diff --git a/genlib2/dense_reduction_tests.go b/genlib2/dense_reduction_tests.go index 2c35efa..06f78c0 100644 --- a/genlib2/dense_reduction_tests.go +++ b/genlib2/dense_reduction_tests.go @@ -6,7 +6,7 @@ import ( ) const testDenseReduceRaw = `var denseReductionTests = []struct { - of Dtype + of dtype.Dtype fn interface{} def interface{} axis int diff --git a/genlib2/engine.go b/genlib2/engine.go index f48a0eb..6f551fc 100644 --- a/genlib2/engine.go +++ b/genlib2/engine.go @@ -7,6 +7,7 @@ import ( ) type EngineArith struct { + isStdDenseEng bool Name string VecVar string PrepData string @@ -34,9 +35,11 @@ func (fn *EngineArith) Signature() *Signature { case fn.VV: paramNames = []string{"a", "b", "opts"} paramTemplates = []*template.Template{tensorType, tensorType, splatFuncOptType} + default: paramNames = []string{"t", "s", "leftTensor", "opts"} paramTemplates = []*template.Template{tensorType, interfaceType, boolType, splatFuncOptType} + } return &Signature{ Name: fn.methName(), @@ -388,18 +391,18 @@ func (fn *EngineUnary) Write(w io.Writer) { func generateStdEngUncondUnary(f io.Writer, ak Kinds) { tcc := []string{ - "Number", // Neg - "Number", // Inv - "Number", // Square - "Number", // Cube - "FloatCmplx", // Exp - "FloatCmplx", // Tanhh - "FloatCmplx", // Log - "Float", // Log2 - "FloatCmplx", // Log10 - "FloatCmplx", // Sqrt - "Float", // Cbrt - "Float", // InvSqrt + "Number", // Neg + "Number", // Inv + "Number", // Square + "Number", // Cube + "FloatComplex", // Exp + "FloatComplex", // Tanhh + "FloatComplex", // Log + "Floats", // Log2 + "FloatComplex", // Log10 + "FloatComplex", // Sqrt + "Floats", // Cbrt + "Floats", // InvSqrt } var gen []*EngineUnary for i, u := range unconditionalUnaries { diff --git a/genlib2/generic_utils.go b/genlib2/generic_utils.go index 7c207fa..8d5f87b 100644 --- a/genlib2/generic_utils.go +++ b/genlib2/generic_utils.go @@ -8,7 +8,7 @@ import ( const rangeRaw = `// Range creates a ranged array with a given type. It panics if the Dtype is not supported or does not represent a naturally orderable type (strings, pointers etc) // Do note that the range algorithm is very simple, and simply does increments or decrements of 1. This means for floating point types // you're not able to create a range with a 0.1 increment step, and for complex number types, the imaginary part will always be 0i -func Range(dt Dtype, start, end int) interface{} { +func Range(dt dtype.Dtype, start, end int) interface{} { size := end - start incr := true if start > end { @@ -58,8 +58,8 @@ func Range(dt Dtype, start, end int) interface{} { const randomRaw = `// Random creates an array of random numbers of the given type. // For complex Dtypes, the imaginary component will be 0. // -// This function is only useful in cases where the randomness is not vital. -func Random(dt Dtype, size int) interface{} { +// This function is only useful in cases where the randomness is not vital. +func Random(dt dtype.Dtype, size int) interface{} { r := rand.New(rand.NewSource(1337)) switch dt.Kind() { {{range .Kinds -}} diff --git a/genlib2/main.go b/genlib2/main.go index 46327c4..f9923dc 100644 --- a/genlib2/main.go +++ b/genlib2/main.go @@ -116,15 +116,25 @@ func main() { pipeline(tensorPkgLoc, "api_cmp_generated_test.go", Kinds{allKinds}, generateAPICmpTests, generateAPICmpMixedTests) pipeline(tensorPkgLoc, "dense_cmp_test.go", Kinds{allKinds}, generateDenseMethodCmpTests, generateDenseMethodCmpMixedTests) - // native iterators - pipeline(nativePkgLoc, "iterator_native.go", Kinds{allKinds}, generateNativeIterators) - pipeline(nativePkgLoc, "iterator_native_test.go", Kinds{allKinds}, generateNativeIteratorTests) - pipeline(nativePkgLoc, "iterator_native2.go", Kinds{allKinds}, generateNativeSelect) - pipeline(nativePkgLoc, "iterator_native2_test.go", Kinds{allKinds}, generateNativeSelectTests) + // native iterators - the ones in the tensor package + pipeline(tensorPkgLoc, "iterator_native.go", Kinds{allKinds}, generateNativeIterators(false)) + pipeline(tensorPkgLoc, "iterator_native_test.go", Kinds{allKinds}, generateNativeIteratorTests(false)) + pipeline(tensorPkgLoc, "select_native.go", Kinds{allKinds}, generateNativeSelect(false)) + pipeline(tensorPkgLoc, "select_native_test.go", Kinds{allKinds}, generateNativeSelectTests(false)) + + // native iterators - exported into gorgonia.org/tensor/native + pipeline(nativePkgLoc+"_unsafe", "iterator_native.go", Kinds{allKinds}, generateNativeIteratorStubs) + pipeline(nativePkgLoc+"_purego", "iterator_native_purego.go", Kinds{allKinds}, generateNativeIterators(true)) + pipeline(nativePkgLoc, "iterator_native_test.go", Kinds{allKinds}, generateNativeIteratorTests(true)) + pipeline(nativePkgLoc+"_unsafe", "select_native.go", Kinds{allKinds}, generateNativeSelectStubs) + pipeline(nativePkgLoc+"_purego", "select_native_purego.go", Kinds{allKinds}, generateNativeSelect(true)) + pipeline(nativePkgLoc, "select_native_test.go", Kinds{allKinds}, generateNativeSelectTests(true)) + pipeline(nativePkgLoc, "utils.go", Kinds{allKinds}, generateNativeIterChecks, generateNativeSelChecks) } func pipeline(pkg, filename string, kinds Kinds, fns ...func(io.Writer, Kinds)) { - fullpath := path.Join(pkg, filename) + pkgpath := strings.Replace(strings.Replace(pkg, "_unsafe", "", -1), "_purego", "", -1) + fullpath := path.Join(pkgpath, filename) f, err := os.Create(fullpath) if err != nil { log.Printf("fullpath %q", fullpath) diff --git a/genlib2/native_iterator.go b/genlib2/native_iterator.go index 565d9e9..1d7a85c 100644 --- a/genlib2/native_iterator.go +++ b/genlib2/native_iterator.go @@ -3,10 +3,11 @@ package main import ( "fmt" "io" + "reflect" "text/template" ) -const checkNativeiterable = `func checkNativeIterable(t *Dense, dims int, dt Dtype) error { +const checkNativeiterable = `func checkNativeIterable(t *Dense, dims int, dt dtype.Dtype) error { // checks: if !t.IsNativelyAccessible() { return errors.Errorf("Cannot convert *Dense to *mat.Dense. Data is inaccessible") @@ -28,35 +29,45 @@ const checkNativeiterable = `func checkNativeIterable(t *Dense, dims int, dt Dty } ` -const nativeIterRaw = `// Vector{{short .}} converts a *Dense into a []{{asType .}} +const nativeIterRaw = ` +{{- $vecName := ( printf "nativeDenseVector%s" (short .K) ) -}} +{{- $matName := ( printf "nativeDenseMatrix%s" (short .K) ) -}} +{{- $T3Name := ( printf "nativeDenseTensor3%s" (short .K) ) -}} +{{- if .N -}} + {{- $vecName = ( printf "Vector%s" (short .K) ) -}} + {{- $matName = ( printf "Matrix%s" (short .K) ) -}} + {{- $T3Name = ( printf "Tensor3%s" (short .K) ) -}} +{{- end -}} + +// {{$vecName}} converts a *Dense into a []{{asType .K}} // If the *Dense does not represent a vector of the wanted type, it will return // an error. -func Vector{{short .}}(t *Dense) (retVal []{{asType .}}, err error) { - if err = checkNativeIterable(t, 1, {{reflectKind .}}); err != nil { +func {{$vecName}}(t *Dense) (retVal []{{asType .K}}, err error) { + if err = checkNativeIterable(t, 1, {{reflectKind .K}}); err != nil { return nil, err } - return t.{{sliceOf .}}, nil + return t.{{sliceOf .K}}, nil } -// Matrix{{short .}} converts a *Dense into a [][]{{asType .}} +// {{$matName}} converts a *Dense into a [][]{{asType .K}} // If the *Dense does not represent a matrix of the wanted type, it // will return an error. -func Matrix{{short .}}(t *Dense) (retVal [][]{{asType .}}, err error) { - if err = checkNativeIterable(t, 2, {{reflectKind .}}); err != nil { +func {{$matName}}(t *Dense) (retVal [][]{{asType .K}}, err error) { + if err = checkNativeIterable(t, 2, {{reflectKind .K}}); err != nil { return nil, err } - data := t.{{sliceOf .}} + data := t.{{sliceOf .K}} shape := t.Shape() strides := t.Strides() rows := shape[0] cols := shape[1] rowStride := strides[0] - retVal = make([][]{{asType .}}, rows) + retVal = make([][]{{asType .K}}, rows) for i := range retVal { start := i * rowStride - retVal[i] = make([]{{asType .}}, 0) + retVal[i] = make([]{{asType .K}}, 0) hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) hdr.Data = uintptr(unsafe.Pointer(&data[start])) hdr.Cap = cols @@ -65,14 +76,14 @@ func Matrix{{short .}}(t *Dense) (retVal [][]{{asType .}}, err error) { return } -// Tensor3{{short .}} converts a *Dense into a [][][]{{asType .}}. +// {{$T3Name}} converts a *Dense into a [][][]{{asType .K}}. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. -func Tensor3{{short .}}(t *Dense) (retVal [][][]{{asType .}}, err error) { - if err = checkNativeIterable(t, 3, {{reflectKind .}}); err != nil { +func {{$T3Name}}(t *Dense) (retVal [][][]{{asType .K}}, err error) { + if err = checkNativeIterable(t, 3, {{reflectKind .K}}); err != nil { return nil, err } - data := t.{{sliceOf .}} + data := t.{{sliceOf .K}} shape := t.Shape() strides := t.Strides() @@ -81,11 +92,11 @@ func Tensor3{{short .}}(t *Dense) (retVal [][][]{{asType .}}, err error) { cols := shape[2] layerStride := strides[0] rowStride := strides[1] - retVal = make([][][]{{asType .}}, layers) + retVal = make([][][]{{asType .K}}, layers) for i := range retVal { - retVal[i] = make([][]{{asType .}}, rows) + retVal[i] = make([][]{{asType .K}}, rows) for j := range retVal[i] { - retVal[i][j] = make([]{{asType .}}, 0) + retVal[i][j] = make([]{{asType .K}}, 0) start := i*layerStride + j*rowStride hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) hdr.Data = uintptr(unsafe.Pointer(&data[start])) @@ -97,15 +108,57 @@ func Tensor3{{short .}}(t *Dense) (retVal [][][]{{asType .}}, err error) { } ` -const nativeIterTestRaw = `func Test_Vector{{short .}}(t *testing.T) { +const nativeIterStubsRaw = `//go:linkname Vector{{short .}} gorgonia.org/tensor.nativeDenseVector{{short .}} + +// Vector{{short .}} converts a *Dense into a []{{asType .}} +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func Vector{{short .}}(t *tensor.Dense) (retVal []{{asType .}}, err error) + +//go:linkname Matrix{{short .}} gorgonia.org/tensor.nativeDenseMatrix{{short .}} + +// Matrix{{short .}} converts a *Dense into a [][]{{asType .}} +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func Matrix{{short .}}(t *tensor.Dense) (retVal [][]{{asType .}}, err error) + +//go:linkname Tensor3{{short .}} gorgonia.org/tensor.nativeDenseTensor3{{short .}} + +// Tensor3{{short .}} converts a *Dense into a [][][]{{asType .}}. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func Tensor3{{short .}}(t *tensor.Dense) (retVal [][][]{{asType .}}, err error) +` + +const nativeIterTestRaw = ` +{{- $pkgTVecName := ( printf "nativeDenseVector%s" (short .K) ) -}} +{{- $pkgTMatName := ( printf "nativeDenseMatrix%s" (short .K) ) -}} +{{- $pkgTT3Name := ( printf "nativeDenseTensor3%s" (short .K) ) -}} +{{- $pkgNVecName := ( printf "Vector%s" (short .K) ) -}} +{{- $pkgNMatName := ( printf "Matrix%s" (short .K) ) -}} +{{- $pkgNT3Name := ( printf "Tensor3%s" (short .K) ) -}} +{{- $vecName := "" -}} +{{- $matName := "" -}} +{{- $T3Name := "" -}} +{{- if .N -}} + {{- $vecName = $pkgNVecName -}} + {{- $matName = $pkgNMatName -}} + {{- $T3Name = $pkgNT3Name -}} +{{- else -}} + {{- $vecName = $pkgTVecName -}} + {{- $matName = $pkgTMatName -}} + {{- $T3Name = $pkgTT3Name -}} +{{end -}} + + +func Test_{{$vecName}}(t *testing.T) { assert := assert.New(t) var T *Dense - {{if isRangeable . -}} - T = New(WithBacking(Range({{reflectKind .}}, 0, 6)), WithShape(6)) + {{if isRangeable .K -}} + T = New(WithBacking(Range({{reflectKind .K}}, 0, 6)), WithShape(6)) {{else -}} - T = New(Of({{reflectKind .}}), WithShape(6)) + T = New(Of({{reflectKind .K}}), WithShape(6)) {{end -}} - it, err := Vector{{short .}}(T) + it, err := {{$vecName}}(T) if err != nil { t.Fatal(err) } @@ -113,15 +166,15 @@ const nativeIterTestRaw = `func Test_Vector{{short .}}(t *testing.T) { assert.Equal(6, len(it)) } -func Test_Matrix{{short .}}(t *testing.T) { +func Test_{{$matName}}(t *testing.T) { assert := assert.New(t) var T *Dense - {{if isRangeable . -}} - T = New(WithBacking(Range({{reflectKind .}}, 0, 6)), WithShape(2, 3)) + {{if isRangeable .K -}} + T = New(WithBacking(Range({{reflectKind .K}}, 0, 6)), WithShape(2, 3)) {{else -}} - T = New(Of({{reflectKind .}}), WithShape(2, 3)) + T = New(Of({{reflectKind .K}}), WithShape(2, 3)) {{end -}} - it, err := Matrix{{short .}}(T) + it, err := {{$matName}}(T) if err != nil { t.Fatal(err) } @@ -130,15 +183,15 @@ func Test_Matrix{{short .}}(t *testing.T) { assert.Equal(3, len(it[0])) } -func Test_Tensor3{{short .}}(t *testing.T) { +func Test_{{$T3Name}}(t *testing.T) { assert := assert.New(t) var T *Dense - {{if isRangeable . -}} - T = New(WithBacking(Range({{reflectKind .}}, 0, 24)), WithShape(2, 3, 4)) + {{if isRangeable .K -}} + T = New(WithBacking(Range({{reflectKind .K}}, 0, 24)), WithShape(2, 3, 4)) {{else -}} - T = New(Of({{reflectKind .}}), WithShape(2, 3, 4)) + T = New(Of({{reflectKind .K}}), WithShape(2, 3, 4)) {{end -}} - it, err := Tensor3{{short .}}(T) + it, err := {{$T3Name}}(T) if err != nil { t.Fatal(err) } @@ -150,31 +203,68 @@ func Test_Tensor3{{short .}}(t *testing.T) { ` var ( - NativeIter *template.Template - NativeIterTest *template.Template + NativeIter *template.Template + NativeIterTest *template.Template + NativeIterStubs *template.Template ) func init() { NativeIter = template.Must(template.New("NativeIter").Funcs(funcs).Parse(nativeIterRaw)) NativeIterTest = template.Must(template.New("NativeIterTest").Funcs(funcs).Parse(nativeIterTestRaw)) + NativeIterStubs = template.Must(template.New("NativeStubs").Funcs(funcs).Parse(nativeIterStubsRaw)) } -func generateNativeIterators(f io.Writer, ak Kinds) { - fmt.Fprintf(f, importUnqualifiedTensor) - fmt.Fprintf(f, "%v\n", checkNativeiterable) - ks := filter(ak.Kinds, isSpecialized) - for _, k := range ks { - fmt.Fprintf(f, "/* Native Iterables for %v */\n\n", k) - NativeIter.Execute(f, k) - fmt.Fprint(f, "\n\n") +// generateNativeIterators generates the code for native iterators. `isNative` represents whether the code is generated for the `native` package or not. +// isNative will only be true for the `purego` build tag. +func generateNativeIterators(isNative bool) func(f io.Writer, ak Kinds) { + type IterTup struct { + N bool + K reflect.Kind + } + return func(f io.Writer, ak Kinds) { + if isNative { + // checkNativeIteratble is separately generated and placed into util.go in the `native` package + // so there is no need to generate that here. + fmt.Fprintf(f, importUnqualifiedTensor) + } else { + fmt.Fprintf(f, "%v\n", checkNativeiterable) + } + ks := filter(ak.Kinds, isSpecialized) + for _, k := range ks { + fmt.Fprintf(f, "/* Native Iterables for %v */\n\n", k) + NativeIter.Execute(f, IterTup{N: isNative, K: k}) + fmt.Fprint(f, "\n\n") + } } } -func generateNativeIteratorTests(f io.Writer, ak Kinds) { - fmt.Fprintf(f, importUnqualifiedTensor) +func generateNativeIteratorTests(isNative bool) func(f io.Writer, ak Kinds) { + type IterTup struct { + N bool + K reflect.Kind + } + return func(f io.Writer, ak Kinds) { + if isNative { + fmt.Fprintf(f, importUnqualifiedTensor) + } + ks := filter(ak.Kinds, isSpecialized) + for _, k := range ks { + NativeIterTest.Execute(f, IterTup{N: isNative, K: k}) + fmt.Fprint(f, "\n\n") + } + } +} + +func generateNativeIteratorStubs(f io.Writer, ak Kinds) { + fmt.Fprintf(f, importUnsafe) // this is required for go:linkname to work ks := filter(ak.Kinds, isSpecialized) for _, k := range ks { - NativeIterTest.Execute(f, k) + NativeIterStubs.Execute(f, k) fmt.Fprint(f, "\n\n") } } + +func generateNativeIterChecks(f io.Writer, ak Kinds) { + fmt.Fprintf(f, importUnqualifiedTensor) + fmt.Fprintf(f, "%v\n", checkNativeiterable) +} diff --git a/genlib2/native_select.go b/genlib2/native_select.go index 6b1e277..1095668 100644 --- a/genlib2/native_select.go +++ b/genlib2/native_select.go @@ -3,10 +3,11 @@ package main import ( "fmt" "io" + "reflect" "text/template" ) -const checkNativeSelectable = `func checkNativeSelectable(t *Dense, axis int, dt Dtype) error { +const checkNativeSelectable = `func checkNativeSelectable(t *Dense, axis int, dt dtype.Dtype) error { if !t.IsNativelyAccessible() { return errors.New("Cannot select on non-natively accessible data") } @@ -22,29 +23,35 @@ const checkNativeSelectable = `func checkNativeSelectable(t *Dense, axis int, dt return nil } ` -const nativeSelectRaw = `// Select{{short .}} creates a slice of flat data types. See Example of NativeSelectF64. -func Select{{short .}}(t *Dense, axis int) (retVal [][]{{asType .}}, err error) { - if err := checkNativeSelectable(t, axis, {{reflectKind .}}); err != nil { +const nativeSelectRaw = ` +{{- $selName := ( printf "nativeSelect%s" (short .K) ) -}} +{{- if .N -}} + {{- $selName = ( printf "Select%s" (short .K) ) -}} +{{- end -}} + +// {{$selName}} creates a slice of flat data types. See Example of NativeSelectF64. +func {{$selName}}(t *Dense, axis int) (retVal [][]{{asType .K}}, err error) { + if err := checkNativeSelectable(t, axis, {{reflectKind .K}}); err != nil { return nil, err } switch t.Shape().Dims() { case 0, 1: - retVal = make([][]{{asType .}}, 1) - retVal[0] = t.{{sliceOf .}} + retVal = make([][]{{asType .K}}, 1) + retVal[0] = t.{{sliceOf .K}} case 2: if axis == 0 { - return Matrix{{short .}}(t) + return {{if .N}}Matrix{{short .K}}{{else}}nativeDenseMatrix{{short .K}}{{end}}(t) } fallthrough default: // size := t.Shape()[axis] - data := t.{{sliceOf .}} + data := t.{{sliceOf .K}} stride := t.Strides()[axis] upper := ProdInts(t.Shape()[:axis+1]) - retVal = make([][]{{asType .}}, 0, upper) + retVal = make([][]{{asType .K}}, 0, upper) for i, r := 0, 0; r < upper; i += stride { - s := make([]{{asType .}}, 0) + s := make([]{{asType .K}}, 0) hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) hdr.Data = uintptr(unsafe.Pointer(&data[i])) hdr.Len = stride @@ -58,85 +65,132 @@ func Select{{short .}}(t *Dense, axis int) (retVal [][]{{asType .}}, err error) return } ` -const nativeSelectTestRaw = `func TestSelect{{short .}}(t *testing.T) { +const nativeSelectTestRaw = ` +{{- $selName := ( printf "nativeSelect%s" (short .K) ) -}} +{{- if .N -}} + {{- $selName = ( printf "Select%s" (short .K) ) -}} +{{- end -}} +func Test{{$selName}}(t *testing.T) { assert := assert.New(t) var T *Dense var err error - var x [][]{{asType .}} - T = New(Of({{reflectKind .}}), WithShape(2, 3, 4, 5), ) - if x, err = Select{{short .}}(T, 1); err != nil { + var x [][]{{asType .K}} + T = New(Of({{reflectKind .K}}), WithShape(2, 3, 4, 5), ) + if x, err = {{$selName}}(T, 1); err != nil { t.Fatal(err) } assert.Equal(6, len(x)) assert.Equal(20, len(x[0])) - T = New(Of({{reflectKind .}}), WithShape(2, 3, 4, 5), ) - if x, err = Select{{short .}}(T, 0); err != nil { + T = New(Of({{reflectKind .K}}), WithShape(2, 3, 4, 5), ) + if x, err = {{$selName}}(T, 0); err != nil { t.Fatal(err) } assert.Equal(2, len(x)) assert.Equal(60, len(x[0])) - T = New(Of({{reflectKind .}}), WithShape(2, 3, 4, 5), ) - if x, err = Select{{short .}}(T, 3); err != nil { + T = New(Of({{reflectKind .K}}), WithShape(2, 3, 4, 5), ) + if x, err = {{$selName}}(T, 3); err != nil { t.Fatal(err) } assert.Equal(120, len(x)) assert.Equal(1, len(x[0])) - T = New(Of({{reflectKind .}}), WithShape(2, 3), ) - if x, err = Select{{short .}}(T, 0); err != nil { + T = New(Of({{reflectKind .K}}), WithShape(2, 3), ) + if x, err = {{$selName}}(T, 0); err != nil { t.Fatal(err) } assert.Equal(2, len(x)) assert.Equal(3, len(x[0])) - T = New(Of({{reflectKind .}}), WithShape(2, 3), ) - if x, err = Select{{short .}}(T, 1); err != nil { + T = New(Of({{reflectKind .K}}), WithShape(2, 3), ) + if x, err = {{$selName}}(T, 1); err != nil { t.Fatal(err) } assert.Equal(6, len(x)) assert.Equal(1, len(x[0])) - T = New(FromScalar({{if eq .String "bool" -}}false{{else if eq .String "string" -}}""{{else -}}{{asType .}}(0) {{end -}} )) - if x, err = Select{{short .}}(T, 0); err != nil { + T = New(FromScalar({{if eq .K.String "bool" -}}false{{else if eq .K.String "string" -}}""{{else -}}{{asType .K}}(0) {{end -}} )) + if x, err = {{$selName}}(T, 0); err != nil { t.Fatal(err) } assert.Equal(1, len(x)) assert.Equal(1, len(x[0])) - if _, err = Select{{short .}}(T, 10); err == nil{ + if _, err = {{$selName}}(T, 10); err == nil{ t.Fatal("Expected errors") } } ` +const nativeSelectStubsRaw = `//go:linkname Select{{short .}} gorgonia.org/tensor.nativeSelect{{short .}} + +// Select{{short .}} creates a slice of {{asType .}}s. See Example of NativeSelectF64. +func Select{{short .}}(t *tensor.Dense, axis int) (retVal [][]{{asType .}}, err error) +` + var ( - NativeSelect *template.Template - NativeSelectTest *template.Template + NativeSelect *template.Template + NativeSelectTest *template.Template + NativeSelectStubs *template.Template ) func init() { NativeSelect = template.Must(template.New("NativeSelect").Funcs(funcs).Parse(nativeSelectRaw)) NativeSelectTest = template.Must(template.New("NativeSelectTest").Funcs(funcs).Parse(nativeSelectTestRaw)) + NativeSelectStubs = template.Must(template.New("NativeSelectStub").Funcs(funcs).Parse(nativeSelectStubsRaw)) } -func generateNativeSelect(f io.Writer, ak Kinds) { - fmt.Fprintf(f, importUnqualifiedTensor) - fmt.Fprintf(f, "%v\n", checkNativeSelectable) - ks := filter(ak.Kinds, isSpecialized) - for _, k := range ks { - fmt.Fprintf(f, "/* Native Select for %v */\n\n", k) - NativeSelect.Execute(f, k) - fmt.Fprint(f, "\n\n") +// generateNativeSelect generates code for the native selection. `isNative` indicates if the +// code is meant to be generated for the native package. The code is generated for the native package +// only for the purposes of the `purego` build tag. +func generateNativeSelect(isNative bool) func(io.Writer, Kinds) { + type IterTup struct { + N bool + K reflect.Kind + } + return func(f io.Writer, ak Kinds) { + if isNative { + fmt.Fprintf(f, importUnqualifiedTensor) + } else { + fmt.Fprintf(f, "%v\n", checkNativeSelectable) + } + ks := filter(ak.Kinds, isSpecialized) + for _, k := range ks { + fmt.Fprintf(f, "/* Native Select for %v */\n\n", k) + NativeSelect.Execute(f, IterTup{N: isNative, K: k}) + fmt.Fprint(f, "\n\n") + } } } -func generateNativeSelectTests(f io.Writer, ak Kinds) { - fmt.Fprintf(f, importUnqualifiedTensor) +func generateNativeSelectTests(isNative bool) func(f io.Writer, ak Kinds) { + type IterTup struct { + N bool + K reflect.Kind + } + return func(f io.Writer, ak Kinds) { + if isNative { + fmt.Fprintf(f, importUnqualifiedTensor) + } + ks := filter(ak.Kinds, isSpecialized) + for _, k := range ks { + NativeSelectTest.Execute(f, IterTup{N: isNative, K: k}) + fmt.Fprint(f, "\n\n") + } + } +} + +func generateNativeSelectStubs(f io.Writer, ak Kinds) { + fmt.Fprintf(f, importUnsafe) // this is required for go:linkname to work ks := filter(ak.Kinds, isSpecialized) for _, k := range ks { - NativeSelectTest.Execute(f, k) - fmt.Fprint(f, "\n\n") + NativeSelectStubs.Execute(f, k) + fmt.Fprintf(f, "\n\n") } } + +func generateNativeSelChecks(f io.Writer, ak Kinds) { + // fmt.Fprintf(f, importUnqualifiedTensor) // already generated by generateNativeIterChecks + fmt.Fprintf(f, "%v\n", checkNativeSelectable) +} diff --git a/genlib2/package.go b/genlib2/package.go index 4380b6b..8ffcf79 100644 --- a/genlib2/package.go +++ b/genlib2/package.go @@ -8,17 +8,27 @@ import ( func writePkgName(f io.Writer, pkg string) { switch pkg { case tensorPkgLoc: - fmt.Fprintf(f, "// %s\n\npackage tensor\n\n", genmsg) + fmt.Fprintf(f, "package tensor\n\n // %s\n\n", genmsg) case nativePkgLoc: - fmt.Fprintf(f, "// %s\n\npackage native\n\n", genmsg) + fmt.Fprintf(f, "package native\n\n // %s\n\n", genmsg) + case nativePkgLoc + "_unsafe": + fmt.Fprintf(f, "// +build !purego \n\npackage native\n\n // %s\n\n", genmsg) + case nativePkgLoc + "_purego": + fmt.Fprintf(f, "// +build purego \n\npackage native\n\n // %s\n\n", genmsg) case execLoc: - fmt.Fprintf(f, "// %s\n\npackage execution\n\n", genmsg) + fmt.Fprintf(f, "package execution\n\n // %s\n\n", genmsg) case storageLoc: - fmt.Fprintf(f, "// %s\n\npackage storage\n\n", genmsg) + fmt.Fprintf(f, "package storage\n\n // %s\n\n", genmsg) default: - fmt.Fprintf(f, "// %s\n\npackage unknown\n\n", genmsg) + fmt.Fprintf(f, "package unknown\n\n %s\n\n", genmsg) } } const importUnqualifiedTensor = `import . "gorgonia.org/tensor" ` + +const importInternalNative = `import inative "gorgonia.org/tensor/internal/native" +` + +const importUnsafe = `import _ "unsafe" +` diff --git a/genlib2/testutils.go b/genlib2/testutils.go index 177333f..c7dbe81 100644 --- a/genlib2/testutils.go +++ b/genlib2/testutils.go @@ -90,7 +90,7 @@ const qcGenraw = `func randomQC(a Tensor, r *rand.Rand) { s[i] = randomString() {{else if eq .String "unsafe.Pointer" -}} s[i] = nil - {{end -}} + {{end -}} } {{end -}} {{end -}} @@ -99,7 +99,7 @@ const qcGenraw = `func randomQC(a Tensor, r *rand.Rand) { ` const testQCRaw = `type QCDense{{short .}} struct { - *Dense + *Dense } func (*QCDense{{short .}}) Generate(r *rand.Rand, size int) reflect.Value { s := make([]{{asType .}}, size) @@ -137,11 +137,11 @@ const mutateFnsRaw = `func mutate{{short .}}(a {{asType . }}){{asType .}} { {{if {{else if eq .String "bool" -}}return true } {{else if eq .String "string" -}}return "Hello World"} {{else if eq .String "uintptr" -}}return 0xdeadbeef} -{{else if eq .String "unsafe.Pointer" -}}return unsafe.Pointer(uintptr(0xdeadbeef))} -{{end -}} +{{else if eq .String "unsafe.Pointer" -}}return unsafe.Pointer(uintptr(0xdeadbeef))} +{{end -}} ` -const identityValsRaw = `func identityVal(x int, dt Dtype) interface{} { +const identityValsRaw = `func identityVal(x int, dt dtype.Dtype) interface{} { switch dt { {{range .Kinds -}} case {{reflectKind .}}: diff --git a/genlib2/unary_tests.go b/genlib2/unary_tests.go index dedd02d..5153f2b 100644 --- a/genlib2/unary_tests.go +++ b/genlib2/unary_tests.go @@ -1,151 +1,151 @@ -package main - -import ( - "fmt" - "io" - "text/template" -) - -const unaryTestBodyRaw = `invFn := func(q *Dense) bool { - a := q.Clone().(*Dense) - {{template "funcoptdecl" -}} - correct := a.Clone().(*Dense) - {{template "funcoptcorrect" -}} - - - we, willFailEq := willerr(a, {{.TypeClassName}}, {{.EqFailTypeClassName}}) - _, ok := q.Engine().({{interfaceName .Name}}); we = we || !ok - - ret, err := {{.Name}}(a {{template "funcoptuse"}}) - if err, retEarly := qcErrCheck(t, "{{.Name}}", a, nil, we, err); retEarly{ - if err != nil { - return false - } - return true - } - {{if ne .InvTypeClass "" -}} - if err := typeclassCheck(a.Dtype(), {{.InvTypeClass}}); err != nil { - return true // uninvertible due to type class implementation issues - } - {{end -}} - {{if eq .FuncOpt "incr" -}} - if ret, err = Sub(ret, identityVal(100, a.Dtype()), UseUnsafe()) ; err != nil { - t.Errorf("err while subtracting incr: %v", err) - return false - } - {{end -}} - {{.Inv}}(ret, UseUnsafe()) - if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { - return false - } - {{template "funcoptcheck" -}} - return true -} - -if err := quick.Check(invFn, &quick.Config{Rand:newRand(), MaxCount: quickchecks}); err != nil{ - t.Errorf("Inv tests for {{.Name}} failed: %v", err) -} -` - -type unaryTest struct { - unaryOp - FuncOpt string - EqFailTypeClassName string - InvTypeClass string -} - -func (fn *unaryTest) Name() string { - if fn.unaryOp.Name() == "Eq" || fn.unaryOp.Name() == "Ne" { - return "El" + fn.unaryOp.Name() - } - return fn.unaryOp.Name() -} - -func (fn *unaryTest) Signature() *Signature { - name := fmt.Sprintf("Test%s", fn.unaryOp.Name()) - if fn.FuncOpt != "" { - name += "_" + fn.FuncOpt - } - return &Signature{ - Name: name, - NameTemplate: plainName, - ParamNames: []string{"t"}, - ParamTemplates: []*template.Template{testingType}, - } -} - -func (fn *unaryTest) WriteBody(w io.Writer) { - t := template.Must(template.New("unary test body").Funcs(funcs).Parse(unaryTestBodyRaw)) - template.Must(t.New("funcoptdecl").Parse(funcOptDecl[fn.FuncOpt])) - template.Must(t.New("funcoptcorrect").Parse(funcOptCorrect[fn.FuncOpt])) - template.Must(t.New("funcoptuse").Parse(funcOptUse[fn.FuncOpt])) - template.Must(t.New("funcoptcheck").Parse(funcOptCheck[fn.FuncOpt])) - t.Execute(w, fn) -} - -func (fn *unaryTest) canWrite() bool { return fn.Inv != "" } - -func (fn *unaryTest) Write(w io.Writer) { - sig := fn.Signature() - w.Write([]byte("func ")) - sig.Write(w) - w.Write([]byte("{\n")) - fn.WriteBody(w) - w.Write([]byte("}\n")) -} - -func generateAPIUnaryTests(f io.Writer, ak Kinds) { - var tests []*unaryTest - for _, op := range conditionalUnaries { - t := &unaryTest{ - unaryOp: op, - EqFailTypeClassName: "nil", - } - - tests = append(tests, t) - } - - for _, op := range unconditionalUnaries { - t := &unaryTest{ - unaryOp: op, - EqFailTypeClassName: "nil", - } - switch op.name { - case "Square": - t.InvTypeClass = "floatcmplxTypes" - case "Cube": - t.InvTypeClass = "floatTypes" - } - - tests = append(tests, t) - } - - for _, fn := range tests { - if fn.canWrite() { - fn.Write(f) - } - fn.FuncOpt = "unsafe" - } - - for _, fn := range tests { - if fn.canWrite() { - fn.Write(f) - } - fn.FuncOpt = "reuse" - } - - for _, fn := range tests { - if fn.canWrite() { - fn.Write(f) - } - fn.FuncOpt = "incr" - } - - // for now incr cannot be quickchecked - - for _, fn := range tests { - if fn.canWrite() { - fn.Write(f) - } - } -} +package main + +import ( + "fmt" + "io" + "text/template" +) + +const unaryTestBodyRaw = `invFn := func(q *Dense) bool { + a := q.Clone().(*Dense) + {{template "funcoptdecl" -}} + correct := a.Clone().(*Dense) + {{template "funcoptcorrect" -}} + + + we, willFailEq := willerr(a, {{.TypeClassName}}, {{.EqFailTypeClassName}}) + _, ok := q.Engine().({{interfaceName .Name}}); we = we || !ok + + ret, err := {{.Name}}(a {{template "funcoptuse"}}) + if err, retEarly := qcErrCheck(t, "{{.Name}}", a, nil, we, err); retEarly{ + if err != nil { + return false + } + return true + } + {{if ne .InvTypeClass "" -}} + if err := dtype.TypeClassCheck(a.Dtype(), {{.InvTypeClass}}); err != nil { + return true // uninvertible due to type class implementation issues + } + {{end -}} + {{if eq .FuncOpt "incr" -}} + if ret, err = Sub(ret, identityVal(100, a.Dtype()), UseUnsafe()) ; err != nil { + t.Errorf("err while subtracting incr: %v", err) + return false + } + {{end -}} + {{.Inv}}(ret, UseUnsafe()) + if !qcEqCheck(t, a.Dtype(), willFailEq, correct.Data(), ret.Data()) { + return false + } + {{template "funcoptcheck" -}} + return true +} + +if err := quick.Check(invFn, &quick.Config{Rand:newRand(), MaxCount: quickchecks}); err != nil{ + t.Errorf("Inv tests for {{.Name}} failed: %v", err) +} +` + +type unaryTest struct { + unaryOp + FuncOpt string + EqFailTypeClassName string + InvTypeClass string +} + +func (fn *unaryTest) Name() string { + if fn.unaryOp.Name() == "Eq" || fn.unaryOp.Name() == "Ne" { + return "El" + fn.unaryOp.Name() + } + return fn.unaryOp.Name() +} + +func (fn *unaryTest) Signature() *Signature { + name := fmt.Sprintf("Test%s", fn.unaryOp.Name()) + if fn.FuncOpt != "" { + name += "_" + fn.FuncOpt + } + return &Signature{ + Name: name, + NameTemplate: plainName, + ParamNames: []string{"t"}, + ParamTemplates: []*template.Template{testingType}, + } +} + +func (fn *unaryTest) WriteBody(w io.Writer) { + t := template.Must(template.New("unary test body").Funcs(funcs).Parse(unaryTestBodyRaw)) + template.Must(t.New("funcoptdecl").Parse(funcOptDecl[fn.FuncOpt])) + template.Must(t.New("funcoptcorrect").Parse(funcOptCorrect[fn.FuncOpt])) + template.Must(t.New("funcoptuse").Parse(funcOptUse[fn.FuncOpt])) + template.Must(t.New("funcoptcheck").Parse(funcOptCheck[fn.FuncOpt])) + t.Execute(w, fn) +} + +func (fn *unaryTest) canWrite() bool { return fn.Inv != "" } + +func (fn *unaryTest) Write(w io.Writer) { + sig := fn.Signature() + w.Write([]byte("func ")) + sig.Write(w) + w.Write([]byte("{\n")) + fn.WriteBody(w) + w.Write([]byte("}\n")) +} + +func generateAPIUnaryTests(f io.Writer, ak Kinds) { + var tests []*unaryTest + for _, op := range conditionalUnaries { + t := &unaryTest{ + unaryOp: op, + EqFailTypeClassName: "nilTC", + } + + tests = append(tests, t) + } + + for _, op := range unconditionalUnaries { + t := &unaryTest{ + unaryOp: op, + EqFailTypeClassName: "nilTC", + } + switch op.name { + case "Square": + t.InvTypeClass = "dtype.FloatComplex" + case "Cube": + t.InvTypeClass = "dtype.Floats" + } + + tests = append(tests, t) + } + + for _, fn := range tests { + if fn.canWrite() { + fn.Write(f) + } + fn.FuncOpt = "unsafe" + } + + for _, fn := range tests { + if fn.canWrite() { + fn.Write(f) + } + fn.FuncOpt = "reuse" + } + + for _, fn := range tests { + if fn.canWrite() { + fn.Write(f) + } + fn.FuncOpt = "incr" + } + + // for now incr cannot be quickchecked + + for _, fn := range tests { + if fn.canWrite() { + fn.Write(f) + } + } +} diff --git a/go.mod b/go.mod index f43f495..bcb9359 100644 --- a/go.mod +++ b/go.mod @@ -2,23 +2,30 @@ module gorgonia.org/tensor go 1.18 +replace gorgonia.org/dtype => /home/chewxy/workspace/gorgoniaws/src/gorgonia.org/dtype + +replace gorgonia.org/shapes => /home/chewxy/workspace/gorgoniaws/src/gorgonia.org/shapes + require ( github.com/apache/arrow/go/arrow v0.0.0-20201229220542-30ce2eb5d4dc - github.com/chewxy/hm v1.0.0 + github.com/chewxy/hm v1.0.0 // indirect github.com/chewxy/math32 v1.0.8 github.com/gogo/protobuf v1.3.2 github.com/golang/protobuf v1.4.3 github.com/google/flatbuffers v1.12.0 github.com/pkg/errors v0.9.1 - github.com/stretchr/testify v1.6.1 + github.com/stretchr/testify v1.7.0 go4.org/unsafe/assume-no-moving-gc v0.0.0-20220617031537-928513b29760 gonum.org/v1/gonum v0.8.2 + gorgonia.org/dtype v0.0.0-00010101000000-000000000000 + gorgonia.org/shapes v0.0.0-00010101000000-000000000000 gorgonia.org/vecf32 v0.9.0 gorgonia.org/vecf64 v0.9.0 ) require ( github.com/davecgh/go-spew v1.1.0 // indirect + github.com/google/gofuzz v1.2.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/xtgo/set v1.0.0 // indirect golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect diff --git a/go.sum b/go.sum index 524d845..60fa6ad 100644 --- a/go.sum +++ b/go.sum @@ -44,6 +44,8 @@ github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.0 h1:/QaMHBdZ26BB3SSst0Iwl10Epc+xhTquomWX0oZEB6w= github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= +github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/jung-kurt/gofpdf v1.0.3-0.20190309125859-24315acbbda5/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= @@ -55,8 +57,8 @@ github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1: github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.1.4/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.2.0/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= -github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= -github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/xtgo/set v1.0.0 h1:6BCNBRv3ORNDQ7fyoJXRv+tstJz3m1JVFQErfeZz2pY= github.com/xtgo/set v1.0.0/go.mod h1:d3NHzGzSa0NmB2NhFyECA+QdRp29oEn2xbT+TpeFoM8= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= diff --git a/interfaces.go b/interfaces.go index e33502f..7061997 100644 --- a/interfaces.go +++ b/interfaces.go @@ -3,12 +3,13 @@ package tensor import ( "reflect" + "gorgonia.org/dtype" "gorgonia.org/tensor/internal/storage" ) // Dtyper is any type that has a Dtype type Dtyper interface { - Dtype() Dtype + Dtype() dtype.Dtype } // Eq is any type where you can perform an equality test @@ -71,6 +72,22 @@ type Slicer interface { Slice(...Slice) (View, error) } +// SlicerInto is any tensor that can slice into another tensor. +// The other tensor may already have data allocated in it. +// If that is the case then the slice will be a copy operation. +type SlicerInto interface { + SliceInto(view Tensor, slices ...Slice) (retVal Tensor, err error) +} + +// Reslicer is any tensor that can reslice. +// To reslice is to reuse the container (*Dense, *CS) etc, but with new `Slice`s applied to it. +// +// e.g: A is a (3,3) matrix that has been sliced at [1:3, 1:3]. Call it B. So now B's shape is (2,2). +// B.Reslice(S(0,2), S(0,2)) would reslice the original tensor (A) with the new slices. +type Reslicer interface { + Reslice(...Slice) (View, error) +} + // DenseTensor is the interface for any Dense tensor. type DenseTensor interface { Tensor @@ -131,6 +148,11 @@ type Kinder interface { Kind() reflect.Kind } +// MakeAliker is any Tensor that can make more like itself. +type MakeAliker interface { + MakeAike(opts ...ConsOpt) Tensor +} + type headerer interface { hdr() *storage.Header } @@ -150,3 +172,11 @@ type unsafeMem interface { Complex64s() []complex64 Complex128s() []complex128 } + +type float64ser interface { + Float64s() []float64 +} + +type float32ser interface { + Float32s() []float32 +} diff --git a/internal/execution/eng_argmethods.go b/internal/execution/eng_argmethods.go index 05ed725..9adc173 100644 --- a/internal/execution/eng_argmethods.go +++ b/internal/execution/eng_argmethods.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package execution import ( @@ -9,6 +7,8 @@ import ( "gorgonia.org/tensor/internal/storage" ) +// Code generated by genlib2. DO NOT EDIT. + func (e E) ArgmaxIter(t reflect.Type, a *storage.Header, it Iterator, lastSize int) (indices []int, err error) { var next int switch t { diff --git a/internal/execution/eng_arith.go b/internal/execution/eng_arith.go index bc0af43..9b193ba 100644 --- a/internal/execution/eng_arith.go +++ b/internal/execution/eng_arith.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package execution import ( @@ -9,6 +7,8 @@ import ( "gorgonia.org/tensor/internal/storage" ) +// Code generated by genlib2. DO NOT EDIT. + func (e E) Add(t reflect.Type, a *storage.Header, b *storage.Header) (err error) { as := isScalar(a, t) bs := isScalar(b, t) diff --git a/internal/execution/eng_cmp.go b/internal/execution/eng_cmp.go index b2c4ece..e5d3dd5 100644 --- a/internal/execution/eng_cmp.go +++ b/internal/execution/eng_cmp.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package execution import ( @@ -9,6 +7,8 @@ import ( "gorgonia.org/tensor/internal/storage" ) +// Code generated by genlib2. DO NOT EDIT. + func (e E) Gt(t reflect.Type, a *storage.Header, b *storage.Header, retVal *storage.Header) (err error) { as := isScalar(a, t) bs := isScalar(b, t) diff --git a/internal/execution/eng_map.go b/internal/execution/eng_map.go index 81cb2c4..ecd2b64 100644 --- a/internal/execution/eng_map.go +++ b/internal/execution/eng_map.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package execution import ( @@ -10,6 +8,8 @@ import ( "gorgonia.org/tensor/internal/storage" ) +// Code generated by genlib2. DO NOT EDIT. + func (e E) Map(t reflect.Type, fn interface{}, a *storage.Header, incr bool) (err error) { as := isScalar(a, t) switch t { diff --git a/internal/execution/eng_minmaxbetween.go b/internal/execution/eng_minmaxbetween.go index 5d31706..8c41606 100644 --- a/internal/execution/eng_minmaxbetween.go +++ b/internal/execution/eng_minmaxbetween.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package execution import ( @@ -9,6 +7,8 @@ import ( "gorgonia.org/tensor/internal/storage" ) +// Code generated by genlib2. DO NOT EDIT. + func (e E) MaxBetween(t reflect.Type, a *storage.Header, b *storage.Header) (err error) { as := isScalar(a, t) bs := isScalar(b, t) diff --git a/internal/execution/eng_reduce.go b/internal/execution/eng_reduce.go index 88c7ae5..bebe52f 100644 --- a/internal/execution/eng_reduce.go +++ b/internal/execution/eng_reduce.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package execution import ( @@ -10,6 +8,8 @@ import ( "gorgonia.org/tensor/internal/storage" ) +// Code generated by genlib2. DO NOT EDIT. + func (e E) ReduceFirst(t reflect.Type, data *storage.Header, retVal *storage.Header, split int, size int, fn interface{}) (err error) { switch t { case Bool: diff --git a/internal/execution/eng_unary.go b/internal/execution/eng_unary.go index bd9bd81..4038190 100644 --- a/internal/execution/eng_unary.go +++ b/internal/execution/eng_unary.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package execution import ( @@ -9,6 +7,8 @@ import ( "gorgonia.org/tensor/internal/storage" ) +// Code generated by genlib2. DO NOT EDIT. + func (e E) Neg(t reflect.Type, a *storage.Header) (err error) { switch t { case Int: diff --git a/internal/execution/generic_argmethods.go b/internal/execution/generic_argmethods.go index 3edb606..cdf4b7d 100644 --- a/internal/execution/generic_argmethods.go +++ b/internal/execution/generic_argmethods.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package execution import ( @@ -8,6 +6,8 @@ import ( "github.com/chewxy/math32" ) +// Code generated by genlib2. DO NOT EDIT. + func ArgmaxI(a []int) int { var set bool var f int diff --git a/internal/execution/generic_arith_mixed.go b/internal/execution/generic_arith_mixed.go index 6e8aa72..94f5e8b 100644 --- a/internal/execution/generic_arith_mixed.go +++ b/internal/execution/generic_arith_mixed.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package execution import ( @@ -9,6 +7,8 @@ import ( "github.com/chewxy/math32" ) +// Code generated by genlib2. DO NOT EDIT. + func AddSVI(a int, b []int) { for i := range b { b[i] = a + b[i] diff --git a/internal/execution/generic_arith_vv.go b/internal/execution/generic_arith_vv.go index e2a3c46..a9f3a7a 100644 --- a/internal/execution/generic_arith_vv.go +++ b/internal/execution/generic_arith_vv.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package execution import ( @@ -11,6 +9,8 @@ import ( "gorgonia.org/vecf64" ) +// Code generated by genlib2. DO NOT EDIT. + func VecAddI(a []int, b []int) { a = a[:] b = b[:len(a)] diff --git a/internal/execution/generic_cmp_mixed.go b/internal/execution/generic_cmp_mixed.go index b9a1154..1c53747 100644 --- a/internal/execution/generic_cmp_mixed.go +++ b/internal/execution/generic_cmp_mixed.go @@ -1,9 +1,9 @@ -// Code generated by genlib2. DO NOT EDIT. - package execution import "unsafe" +// Code generated by genlib2. DO NOT EDIT. + func GtSVI(a int, b []int, retVal []bool) { for i := range retVal { retVal[i] = a > b[i] diff --git a/internal/execution/generic_cmp_vv.go b/internal/execution/generic_cmp_vv.go index 7d528c4..a501f93 100644 --- a/internal/execution/generic_cmp_vv.go +++ b/internal/execution/generic_cmp_vv.go @@ -1,9 +1,9 @@ -// Code generated by genlib2. DO NOT EDIT. - package execution import "unsafe" +// Code generated by genlib2. DO NOT EDIT. + func GtI(a []int, b []int, retVal []bool) { a = a[:] b = b[:len(a)] diff --git a/internal/execution/generic_map.go b/internal/execution/generic_map.go index 41c7de8..f054239 100644 --- a/internal/execution/generic_map.go +++ b/internal/execution/generic_map.go @@ -1,9 +1,9 @@ -// Code generated by genlib2. DO NOT EDIT. - package execution import "unsafe" +// Code generated by genlib2. DO NOT EDIT. + func MapB(fn func(bool) bool, a []bool) { for i := range a { a[i] = fn(a[i]) diff --git a/internal/execution/generic_minmax.go b/internal/execution/generic_minmax.go index 8398d5f..011645b 100644 --- a/internal/execution/generic_minmax.go +++ b/internal/execution/generic_minmax.go @@ -1,7 +1,7 @@ -// Code generated by genlib2. DO NOT EDIT. - package execution +// Code generated by genlib2. DO NOT EDIT. + func VecMinI(a, b []int) { a = a[:] b = b[:len(a)] diff --git a/internal/execution/generic_reduce.go b/internal/execution/generic_reduce.go index a489f1c..ef94057 100644 --- a/internal/execution/generic_reduce.go +++ b/internal/execution/generic_reduce.go @@ -1,9 +1,9 @@ -// Code generated by genlib2. DO NOT EDIT. - package execution import "unsafe" +// Code generated by genlib2. DO NOT EDIT. + func ReduceB(f func(a, b bool) bool, def bool, l ...bool) (retVal bool) { retVal = def if len(l) == 0 { diff --git a/internal/execution/generic_unary.go b/internal/execution/generic_unary.go index cb3f87f..7c05acd 100644 --- a/internal/execution/generic_unary.go +++ b/internal/execution/generic_unary.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package execution import ( @@ -9,6 +7,8 @@ import ( "github.com/chewxy/math32" ) +// Code generated by genlib2. DO NOT EDIT. + func NegI(a []int) { for i := range a { a[i] = -a[i] diff --git a/internal/execution/reduction_specialization.go b/internal/execution/reduction_specialization.go index e83e67e..90cfe69 100644 --- a/internal/execution/reduction_specialization.go +++ b/internal/execution/reduction_specialization.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package execution import ( @@ -9,6 +7,8 @@ import ( "gorgonia.org/tensor/internal/storage" ) +// Code generated by genlib2. DO NOT EDIT. + func MonotonicSum(t reflect.Type, a *storage.Header) (retVal interface{}, err error) { switch t { case Int: diff --git a/internal/storage/consts.go b/internal/storage/consts.go index 7304ac5..b6e03cc 100644 --- a/internal/storage/consts.go +++ b/internal/storage/consts.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package storage import ( @@ -7,6 +5,8 @@ import ( "unsafe" ) +// Code generated by genlib2. DO NOT EDIT. + var ( bType = reflect.TypeOf(bool(false)) iType = reflect.TypeOf(int(0)) diff --git a/internal/storage/getset.go b/internal/storage/getset.go index c60d61c..89421f0 100644 --- a/internal/storage/getset.go +++ b/internal/storage/getset.go @@ -1,9 +1,9 @@ -// Code generated by genlib2. DO NOT EDIT. - package storage import "unsafe" +// Code generated by genlib2. DO NOT EDIT. + /* bool */ func (h *Header) Bools() []bool { diff --git a/internal/storage/header.go b/internal/storage/header.go index 93f67e7..65f1069 100644 --- a/internal/storage/header.go +++ b/internal/storage/header.go @@ -3,6 +3,8 @@ package storage // import "gorgonia.org/tensor/internal/storage" import ( "reflect" "unsafe" + + _ "go4.org/unsafe/assume-no-moving-gc" ) // Header is runtime representation of a slice. It's a cleaner version of reflect.SliceHeader. diff --git a/iterator_axial.go b/iterator_axial.go new file mode 100644 index 0000000..493ac21 --- /dev/null +++ b/iterator_axial.go @@ -0,0 +1,170 @@ +package tensor + +// AxialIterator iterates based on a given axis +type AxialIterator struct { + *AP + axis int // the axis to iterate along + + // state + axisSz int // if an axis is of size N, then axisSz indicates the current num (0 - N). + nextIndex int + lastIndex int + track []int + isReverse bool + done bool + fixed bool +} + +// AxialIteratorFromDense creates and axial iterator that will iterate along the given axis. `fixedAxis` defines if the axisSz is fixed. +func AxialIteratorFromDense(t DenseTensor, axis, axisSz int, fixedAxis bool) *AxialIterator { + ap := t.Info() + return &AxialIterator{ + AP: ap, + track: make([]int, len(ap.shape)), + axis: axis, + axisSz: axisSz, + fixed: fixedAxis, + } +} + +// Start returns the first index +func (it *AxialIterator) Start() (retVal int, err error) { + it.Reset() + + // compute the nextIndex + if it.fixed { + it.track[it.axis] = it.axisSz + it.nextIndex, err = Ltoi(it.shape, it.strides, it.track...) + } + + return it.Next() +} + +// Next returns the next index. +// Example: let's say we're iterating on a tensor with the following +// shape: (2, 3, 4); axis: 1 +// At the start, the coordinates are: +// coordinates: (0, 0, 0) +// Next() will yield: +// coordinates: (0, 0, 1) +// But when the coordinates are: +// coordinates: (0, 0, 4) +// Next() will yield: +// coordinates: (1, 0, 0). +// Note that axis 1 is frozen at 0. +func (it *AxialIterator) Next() (int, error) { + if it.done { + return -1, noopError{} + } + + switch { + case it.isReverse: + return it.ndPrevious() + default: + return it.ndNext() + } + +} + +func (it *AxialIterator) ndNext() (int, error) { + v := len(it.shape) - 1 + nextIndex := it.nextIndex + it.lastIndex = nextIndex + + track := it.track[:v+1] // force bounds check + coord := it.shape[:v+1] // force bounds check + strides := it.strides[:v+1] // fource bounds check + sz := it.axisSz + track[it.axis] = sz + + for i := v; i >= 0; i-- { + if i == it.axis { + if i == 0 { + if it.fixed || track[it.axis] == coord[it.axis] || it.axisSz >= coord[it.axis] { + track[it.axis] = 0 + it.done = true + break + } + it.axisSz++ + track[it.axis] = it.axisSz + } + continue // we're iterating along an axis. + } + track[i]++ + shapeI := coord[i] + strideI := strides[i] + if track[i] == shapeI { + track[i] = 0 + nextIndex -= (shapeI - 1) * strideI + if i == 0 { + it.axisSz++ + track[it.axis] = it.axisSz + + if it.fixed || track[it.axis] == coord[it.axis] || it.axisSz >= coord[it.axis] { + track[it.axis] = 0 + it.done = true + break + } + + nextIndex = track[it.axis] * strides[it.axis] + } + + continue + } + nextIndex += strideI + break + } + it.nextIndex = nextIndex + return it.lastIndex, nil +} + +func (it *AxialIterator) ndPrevious() (int, error) { + panic("Not yet implemented") +} + +// NextValidity is like Next, but returns the validity of the value at the index as well. +func (it *AxialIterator) NextValidity() (int, bool, error) { + i, err := it.Next() + return i, true, err +} + +// NextValid returns the next valid index, as well as a skip count. +func (it *AxialIterator) NextValid() (int, int, error) { + if it.done { + return -1, 1, noopError{} + } + + switch { + case it.isReverse: + a, err := it.ndPrevious() + return a, -1, err + default: + a, err := it.ndNext() + return a, 1, err + } +} + +// NextInvalid returns the next invalid index, as well as a skip count. +func (it *AxialIterator) NextInvalid() (int, int, error) { + panic("not implemented") // TODO: Implement +} + +// Reset resets the iterator +func (it *AxialIterator) Reset() { + it.nextIndex = 0 + for i := range it.track { + it.track[i] = 0 + } +} + +// SetReverse tells the iterator to iterate in reverse +func (it *AxialIterator) SetReverse() { it.isReverse = true } + +// SetForward tells the iterator to iterate forwards +func (it *AxialIterator) SetForward() { it.isReverse = false } + +// Coord returns the coordinates +func (it *AxialIterator) Coord() []int { return it.track } + +// Done returns true when the iterator is done iterating. +func (it *AxialIterator) Done() bool { return it.done } diff --git a/iterator_native.go b/iterator_native.go new file mode 100644 index 0000000..470891d --- /dev/null +++ b/iterator_native.go @@ -0,0 +1,1152 @@ +package tensor + +import ( + "reflect" + "unsafe" + + "github.com/pkg/errors" + "gorgonia.org/dtype" +) + +// Code generated by genlib2. DO NOT EDIT. + +func checkNativeIterable(t *Dense, dims int, dt dtype.Dtype) error { + // checks: + if !t.IsNativelyAccessible() { + return errors.Errorf("Cannot convert *Dense to *mat.Dense. Data is inaccessible") + } + + if t.Shape().Dims() != dims { + return errors.Errorf("Cannot convert *Dense to native iterator. Expected number of dimension: %d, T has got %d dimensions (Shape: %v)", dims, t.Dims(), t.Shape()) + } + + if t.F() || t.RequiresIterator() { + return errors.Errorf("Not yet implemented: native matrix for colmajor or unpacked matrices") + } + + if t.Dtype() != dt { + return errors.Errorf("Conversion to native iterable only works on %v. Got %v", dt, t.Dtype()) + } + + return nil +} + +/* Native Iterables for bool */ + +// nativeDenseVectorB converts a *Dense into a []bool +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func nativeDenseVectorB(t *Dense) (retVal []bool, err error) { + if err = checkNativeIterable(t, 1, Bool); err != nil { + return nil, err + } + return t.Bools(), nil +} + +// nativeDenseMatrixB converts a *Dense into a [][]bool +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func nativeDenseMatrixB(t *Dense) (retVal [][]bool, err error) { + if err = checkNativeIterable(t, 2, Bool); err != nil { + return nil, err + } + + data := t.Bools() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]bool, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]bool, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// nativeDenseTensor3B converts a *Dense into a [][][]bool. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func nativeDenseTensor3B(t *Dense) (retVal [][][]bool, err error) { + if err = checkNativeIterable(t, 3, Bool); err != nil { + return nil, err + } + + data := t.Bools() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]bool, layers) + for i := range retVal { + retVal[i] = make([][]bool, rows) + for j := range retVal[i] { + retVal[i][j] = make([]bool, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for int */ + +// nativeDenseVectorI converts a *Dense into a []int +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func nativeDenseVectorI(t *Dense) (retVal []int, err error) { + if err = checkNativeIterable(t, 1, Int); err != nil { + return nil, err + } + return t.Ints(), nil +} + +// nativeDenseMatrixI converts a *Dense into a [][]int +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func nativeDenseMatrixI(t *Dense) (retVal [][]int, err error) { + if err = checkNativeIterable(t, 2, Int); err != nil { + return nil, err + } + + data := t.Ints() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]int, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]int, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// nativeDenseTensor3I converts a *Dense into a [][][]int. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func nativeDenseTensor3I(t *Dense) (retVal [][][]int, err error) { + if err = checkNativeIterable(t, 3, Int); err != nil { + return nil, err + } + + data := t.Ints() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]int, layers) + for i := range retVal { + retVal[i] = make([][]int, rows) + for j := range retVal[i] { + retVal[i][j] = make([]int, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for int8 */ + +// nativeDenseVectorI8 converts a *Dense into a []int8 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func nativeDenseVectorI8(t *Dense) (retVal []int8, err error) { + if err = checkNativeIterable(t, 1, Int8); err != nil { + return nil, err + } + return t.Int8s(), nil +} + +// nativeDenseMatrixI8 converts a *Dense into a [][]int8 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func nativeDenseMatrixI8(t *Dense) (retVal [][]int8, err error) { + if err = checkNativeIterable(t, 2, Int8); err != nil { + return nil, err + } + + data := t.Int8s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]int8, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]int8, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// nativeDenseTensor3I8 converts a *Dense into a [][][]int8. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func nativeDenseTensor3I8(t *Dense) (retVal [][][]int8, err error) { + if err = checkNativeIterable(t, 3, Int8); err != nil { + return nil, err + } + + data := t.Int8s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]int8, layers) + for i := range retVal { + retVal[i] = make([][]int8, rows) + for j := range retVal[i] { + retVal[i][j] = make([]int8, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for int16 */ + +// nativeDenseVectorI16 converts a *Dense into a []int16 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func nativeDenseVectorI16(t *Dense) (retVal []int16, err error) { + if err = checkNativeIterable(t, 1, Int16); err != nil { + return nil, err + } + return t.Int16s(), nil +} + +// nativeDenseMatrixI16 converts a *Dense into a [][]int16 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func nativeDenseMatrixI16(t *Dense) (retVal [][]int16, err error) { + if err = checkNativeIterable(t, 2, Int16); err != nil { + return nil, err + } + + data := t.Int16s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]int16, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]int16, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// nativeDenseTensor3I16 converts a *Dense into a [][][]int16. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func nativeDenseTensor3I16(t *Dense) (retVal [][][]int16, err error) { + if err = checkNativeIterable(t, 3, Int16); err != nil { + return nil, err + } + + data := t.Int16s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]int16, layers) + for i := range retVal { + retVal[i] = make([][]int16, rows) + for j := range retVal[i] { + retVal[i][j] = make([]int16, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for int32 */ + +// nativeDenseVectorI32 converts a *Dense into a []int32 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func nativeDenseVectorI32(t *Dense) (retVal []int32, err error) { + if err = checkNativeIterable(t, 1, Int32); err != nil { + return nil, err + } + return t.Int32s(), nil +} + +// nativeDenseMatrixI32 converts a *Dense into a [][]int32 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func nativeDenseMatrixI32(t *Dense) (retVal [][]int32, err error) { + if err = checkNativeIterable(t, 2, Int32); err != nil { + return nil, err + } + + data := t.Int32s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]int32, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]int32, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// nativeDenseTensor3I32 converts a *Dense into a [][][]int32. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func nativeDenseTensor3I32(t *Dense) (retVal [][][]int32, err error) { + if err = checkNativeIterable(t, 3, Int32); err != nil { + return nil, err + } + + data := t.Int32s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]int32, layers) + for i := range retVal { + retVal[i] = make([][]int32, rows) + for j := range retVal[i] { + retVal[i][j] = make([]int32, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for int64 */ + +// nativeDenseVectorI64 converts a *Dense into a []int64 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func nativeDenseVectorI64(t *Dense) (retVal []int64, err error) { + if err = checkNativeIterable(t, 1, Int64); err != nil { + return nil, err + } + return t.Int64s(), nil +} + +// nativeDenseMatrixI64 converts a *Dense into a [][]int64 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func nativeDenseMatrixI64(t *Dense) (retVal [][]int64, err error) { + if err = checkNativeIterable(t, 2, Int64); err != nil { + return nil, err + } + + data := t.Int64s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]int64, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]int64, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// nativeDenseTensor3I64 converts a *Dense into a [][][]int64. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func nativeDenseTensor3I64(t *Dense) (retVal [][][]int64, err error) { + if err = checkNativeIterable(t, 3, Int64); err != nil { + return nil, err + } + + data := t.Int64s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]int64, layers) + for i := range retVal { + retVal[i] = make([][]int64, rows) + for j := range retVal[i] { + retVal[i][j] = make([]int64, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for uint */ + +// nativeDenseVectorU converts a *Dense into a []uint +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func nativeDenseVectorU(t *Dense) (retVal []uint, err error) { + if err = checkNativeIterable(t, 1, Uint); err != nil { + return nil, err + } + return t.Uints(), nil +} + +// nativeDenseMatrixU converts a *Dense into a [][]uint +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func nativeDenseMatrixU(t *Dense) (retVal [][]uint, err error) { + if err = checkNativeIterable(t, 2, Uint); err != nil { + return nil, err + } + + data := t.Uints() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]uint, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]uint, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// nativeDenseTensor3U converts a *Dense into a [][][]uint. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func nativeDenseTensor3U(t *Dense) (retVal [][][]uint, err error) { + if err = checkNativeIterable(t, 3, Uint); err != nil { + return nil, err + } + + data := t.Uints() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]uint, layers) + for i := range retVal { + retVal[i] = make([][]uint, rows) + for j := range retVal[i] { + retVal[i][j] = make([]uint, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for uint8 */ + +// nativeDenseVectorU8 converts a *Dense into a []uint8 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func nativeDenseVectorU8(t *Dense) (retVal []uint8, err error) { + if err = checkNativeIterable(t, 1, Uint8); err != nil { + return nil, err + } + return t.Uint8s(), nil +} + +// nativeDenseMatrixU8 converts a *Dense into a [][]uint8 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func nativeDenseMatrixU8(t *Dense) (retVal [][]uint8, err error) { + if err = checkNativeIterable(t, 2, Uint8); err != nil { + return nil, err + } + + data := t.Uint8s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]uint8, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]uint8, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// nativeDenseTensor3U8 converts a *Dense into a [][][]uint8. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func nativeDenseTensor3U8(t *Dense) (retVal [][][]uint8, err error) { + if err = checkNativeIterable(t, 3, Uint8); err != nil { + return nil, err + } + + data := t.Uint8s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]uint8, layers) + for i := range retVal { + retVal[i] = make([][]uint8, rows) + for j := range retVal[i] { + retVal[i][j] = make([]uint8, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for uint16 */ + +// nativeDenseVectorU16 converts a *Dense into a []uint16 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func nativeDenseVectorU16(t *Dense) (retVal []uint16, err error) { + if err = checkNativeIterable(t, 1, Uint16); err != nil { + return nil, err + } + return t.Uint16s(), nil +} + +// nativeDenseMatrixU16 converts a *Dense into a [][]uint16 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func nativeDenseMatrixU16(t *Dense) (retVal [][]uint16, err error) { + if err = checkNativeIterable(t, 2, Uint16); err != nil { + return nil, err + } + + data := t.Uint16s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]uint16, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]uint16, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// nativeDenseTensor3U16 converts a *Dense into a [][][]uint16. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func nativeDenseTensor3U16(t *Dense) (retVal [][][]uint16, err error) { + if err = checkNativeIterable(t, 3, Uint16); err != nil { + return nil, err + } + + data := t.Uint16s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]uint16, layers) + for i := range retVal { + retVal[i] = make([][]uint16, rows) + for j := range retVal[i] { + retVal[i][j] = make([]uint16, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for uint32 */ + +// nativeDenseVectorU32 converts a *Dense into a []uint32 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func nativeDenseVectorU32(t *Dense) (retVal []uint32, err error) { + if err = checkNativeIterable(t, 1, Uint32); err != nil { + return nil, err + } + return t.Uint32s(), nil +} + +// nativeDenseMatrixU32 converts a *Dense into a [][]uint32 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func nativeDenseMatrixU32(t *Dense) (retVal [][]uint32, err error) { + if err = checkNativeIterable(t, 2, Uint32); err != nil { + return nil, err + } + + data := t.Uint32s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]uint32, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]uint32, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// nativeDenseTensor3U32 converts a *Dense into a [][][]uint32. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func nativeDenseTensor3U32(t *Dense) (retVal [][][]uint32, err error) { + if err = checkNativeIterable(t, 3, Uint32); err != nil { + return nil, err + } + + data := t.Uint32s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]uint32, layers) + for i := range retVal { + retVal[i] = make([][]uint32, rows) + for j := range retVal[i] { + retVal[i][j] = make([]uint32, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for uint64 */ + +// nativeDenseVectorU64 converts a *Dense into a []uint64 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func nativeDenseVectorU64(t *Dense) (retVal []uint64, err error) { + if err = checkNativeIterable(t, 1, Uint64); err != nil { + return nil, err + } + return t.Uint64s(), nil +} + +// nativeDenseMatrixU64 converts a *Dense into a [][]uint64 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func nativeDenseMatrixU64(t *Dense) (retVal [][]uint64, err error) { + if err = checkNativeIterable(t, 2, Uint64); err != nil { + return nil, err + } + + data := t.Uint64s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]uint64, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]uint64, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// nativeDenseTensor3U64 converts a *Dense into a [][][]uint64. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func nativeDenseTensor3U64(t *Dense) (retVal [][][]uint64, err error) { + if err = checkNativeIterable(t, 3, Uint64); err != nil { + return nil, err + } + + data := t.Uint64s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]uint64, layers) + for i := range retVal { + retVal[i] = make([][]uint64, rows) + for j := range retVal[i] { + retVal[i][j] = make([]uint64, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for float32 */ + +// nativeDenseVectorF32 converts a *Dense into a []float32 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func nativeDenseVectorF32(t *Dense) (retVal []float32, err error) { + if err = checkNativeIterable(t, 1, Float32); err != nil { + return nil, err + } + return t.Float32s(), nil +} + +// nativeDenseMatrixF32 converts a *Dense into a [][]float32 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func nativeDenseMatrixF32(t *Dense) (retVal [][]float32, err error) { + if err = checkNativeIterable(t, 2, Float32); err != nil { + return nil, err + } + + data := t.Float32s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]float32, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]float32, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// nativeDenseTensor3F32 converts a *Dense into a [][][]float32. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func nativeDenseTensor3F32(t *Dense) (retVal [][][]float32, err error) { + if err = checkNativeIterable(t, 3, Float32); err != nil { + return nil, err + } + + data := t.Float32s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]float32, layers) + for i := range retVal { + retVal[i] = make([][]float32, rows) + for j := range retVal[i] { + retVal[i][j] = make([]float32, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for float64 */ + +// nativeDenseVectorF64 converts a *Dense into a []float64 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func nativeDenseVectorF64(t *Dense) (retVal []float64, err error) { + if err = checkNativeIterable(t, 1, Float64); err != nil { + return nil, err + } + return t.Float64s(), nil +} + +// nativeDenseMatrixF64 converts a *Dense into a [][]float64 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func nativeDenseMatrixF64(t *Dense) (retVal [][]float64, err error) { + if err = checkNativeIterable(t, 2, Float64); err != nil { + return nil, err + } + + data := t.Float64s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]float64, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]float64, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// nativeDenseTensor3F64 converts a *Dense into a [][][]float64. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func nativeDenseTensor3F64(t *Dense) (retVal [][][]float64, err error) { + if err = checkNativeIterable(t, 3, Float64); err != nil { + return nil, err + } + + data := t.Float64s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]float64, layers) + for i := range retVal { + retVal[i] = make([][]float64, rows) + for j := range retVal[i] { + retVal[i][j] = make([]float64, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for complex64 */ + +// nativeDenseVectorC64 converts a *Dense into a []complex64 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func nativeDenseVectorC64(t *Dense) (retVal []complex64, err error) { + if err = checkNativeIterable(t, 1, Complex64); err != nil { + return nil, err + } + return t.Complex64s(), nil +} + +// nativeDenseMatrixC64 converts a *Dense into a [][]complex64 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func nativeDenseMatrixC64(t *Dense) (retVal [][]complex64, err error) { + if err = checkNativeIterable(t, 2, Complex64); err != nil { + return nil, err + } + + data := t.Complex64s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]complex64, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]complex64, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// nativeDenseTensor3C64 converts a *Dense into a [][][]complex64. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func nativeDenseTensor3C64(t *Dense) (retVal [][][]complex64, err error) { + if err = checkNativeIterable(t, 3, Complex64); err != nil { + return nil, err + } + + data := t.Complex64s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]complex64, layers) + for i := range retVal { + retVal[i] = make([][]complex64, rows) + for j := range retVal[i] { + retVal[i][j] = make([]complex64, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for complex128 */ + +// nativeDenseVectorC128 converts a *Dense into a []complex128 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func nativeDenseVectorC128(t *Dense) (retVal []complex128, err error) { + if err = checkNativeIterable(t, 1, Complex128); err != nil { + return nil, err + } + return t.Complex128s(), nil +} + +// nativeDenseMatrixC128 converts a *Dense into a [][]complex128 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func nativeDenseMatrixC128(t *Dense) (retVal [][]complex128, err error) { + if err = checkNativeIterable(t, 2, Complex128); err != nil { + return nil, err + } + + data := t.Complex128s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]complex128, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]complex128, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// nativeDenseTensor3C128 converts a *Dense into a [][][]complex128. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func nativeDenseTensor3C128(t *Dense) (retVal [][][]complex128, err error) { + if err = checkNativeIterable(t, 3, Complex128); err != nil { + return nil, err + } + + data := t.Complex128s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]complex128, layers) + for i := range retVal { + retVal[i] = make([][]complex128, rows) + for j := range retVal[i] { + retVal[i][j] = make([]complex128, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for string */ + +// nativeDenseVectorStr converts a *Dense into a []string +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func nativeDenseVectorStr(t *Dense) (retVal []string, err error) { + if err = checkNativeIterable(t, 1, String); err != nil { + return nil, err + } + return t.Strings(), nil +} + +// nativeDenseMatrixStr converts a *Dense into a [][]string +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func nativeDenseMatrixStr(t *Dense) (retVal [][]string, err error) { + if err = checkNativeIterable(t, 2, String); err != nil { + return nil, err + } + + data := t.Strings() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]string, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]string, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// nativeDenseTensor3Str converts a *Dense into a [][][]string. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func nativeDenseTensor3Str(t *Dense) (retVal [][][]string, err error) { + if err = checkNativeIterable(t, 3, String); err != nil { + return nil, err + } + + data := t.Strings() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]string, layers) + for i := range retVal { + retVal[i] = make([][]string, rows) + for j := range retVal[i] { + retVal[i][j] = make([]string, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} diff --git a/iterator_native_test.go b/iterator_native_test.go new file mode 100644 index 0000000..afcd14d --- /dev/null +++ b/iterator_native_test.go @@ -0,0 +1,633 @@ +package tensor + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +// Code generated by genlib2. DO NOT EDIT. + +func Test_nativeDenseVectorB(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(Of(Bool), WithShape(6)) + it, err := nativeDenseVectorB(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_nativeDenseMatrixB(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(Of(Bool), WithShape(2, 3)) + it, err := nativeDenseMatrixB(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_nativeDenseTensor3B(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(Of(Bool), WithShape(2, 3, 4)) + it, err := nativeDenseTensor3B(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_nativeDenseVectorI(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int, 0, 6)), WithShape(6)) + it, err := nativeDenseVectorI(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_nativeDenseMatrixI(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int, 0, 6)), WithShape(2, 3)) + it, err := nativeDenseMatrixI(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_nativeDenseTensor3I(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int, 0, 24)), WithShape(2, 3, 4)) + it, err := nativeDenseTensor3I(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_nativeDenseVectorI8(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int8, 0, 6)), WithShape(6)) + it, err := nativeDenseVectorI8(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_nativeDenseMatrixI8(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int8, 0, 6)), WithShape(2, 3)) + it, err := nativeDenseMatrixI8(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_nativeDenseTensor3I8(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int8, 0, 24)), WithShape(2, 3, 4)) + it, err := nativeDenseTensor3I8(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_nativeDenseVectorI16(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int16, 0, 6)), WithShape(6)) + it, err := nativeDenseVectorI16(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_nativeDenseMatrixI16(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int16, 0, 6)), WithShape(2, 3)) + it, err := nativeDenseMatrixI16(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_nativeDenseTensor3I16(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int16, 0, 24)), WithShape(2, 3, 4)) + it, err := nativeDenseTensor3I16(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_nativeDenseVectorI32(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int32, 0, 6)), WithShape(6)) + it, err := nativeDenseVectorI32(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_nativeDenseMatrixI32(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int32, 0, 6)), WithShape(2, 3)) + it, err := nativeDenseMatrixI32(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_nativeDenseTensor3I32(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int32, 0, 24)), WithShape(2, 3, 4)) + it, err := nativeDenseTensor3I32(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_nativeDenseVectorI64(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int64, 0, 6)), WithShape(6)) + it, err := nativeDenseVectorI64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_nativeDenseMatrixI64(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int64, 0, 6)), WithShape(2, 3)) + it, err := nativeDenseMatrixI64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_nativeDenseTensor3I64(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Int64, 0, 24)), WithShape(2, 3, 4)) + it, err := nativeDenseTensor3I64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_nativeDenseVectorU(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint, 0, 6)), WithShape(6)) + it, err := nativeDenseVectorU(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_nativeDenseMatrixU(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint, 0, 6)), WithShape(2, 3)) + it, err := nativeDenseMatrixU(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_nativeDenseTensor3U(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint, 0, 24)), WithShape(2, 3, 4)) + it, err := nativeDenseTensor3U(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_nativeDenseVectorU8(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint8, 0, 6)), WithShape(6)) + it, err := nativeDenseVectorU8(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_nativeDenseMatrixU8(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint8, 0, 6)), WithShape(2, 3)) + it, err := nativeDenseMatrixU8(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_nativeDenseTensor3U8(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint8, 0, 24)), WithShape(2, 3, 4)) + it, err := nativeDenseTensor3U8(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_nativeDenseVectorU16(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint16, 0, 6)), WithShape(6)) + it, err := nativeDenseVectorU16(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_nativeDenseMatrixU16(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint16, 0, 6)), WithShape(2, 3)) + it, err := nativeDenseMatrixU16(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_nativeDenseTensor3U16(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint16, 0, 24)), WithShape(2, 3, 4)) + it, err := nativeDenseTensor3U16(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_nativeDenseVectorU32(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint32, 0, 6)), WithShape(6)) + it, err := nativeDenseVectorU32(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_nativeDenseMatrixU32(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint32, 0, 6)), WithShape(2, 3)) + it, err := nativeDenseMatrixU32(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_nativeDenseTensor3U32(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint32, 0, 24)), WithShape(2, 3, 4)) + it, err := nativeDenseTensor3U32(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_nativeDenseVectorU64(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint64, 0, 6)), WithShape(6)) + it, err := nativeDenseVectorU64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_nativeDenseMatrixU64(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint64, 0, 6)), WithShape(2, 3)) + it, err := nativeDenseMatrixU64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_nativeDenseTensor3U64(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Uint64, 0, 24)), WithShape(2, 3, 4)) + it, err := nativeDenseTensor3U64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_nativeDenseVectorF32(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Float32, 0, 6)), WithShape(6)) + it, err := nativeDenseVectorF32(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_nativeDenseMatrixF32(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Float32, 0, 6)), WithShape(2, 3)) + it, err := nativeDenseMatrixF32(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_nativeDenseTensor3F32(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Float32, 0, 24)), WithShape(2, 3, 4)) + it, err := nativeDenseTensor3F32(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_nativeDenseVectorF64(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Float64, 0, 6)), WithShape(6)) + it, err := nativeDenseVectorF64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_nativeDenseMatrixF64(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Float64, 0, 6)), WithShape(2, 3)) + it, err := nativeDenseMatrixF64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_nativeDenseTensor3F64(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Float64, 0, 24)), WithShape(2, 3, 4)) + it, err := nativeDenseTensor3F64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_nativeDenseVectorC64(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Complex64, 0, 6)), WithShape(6)) + it, err := nativeDenseVectorC64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_nativeDenseMatrixC64(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Complex64, 0, 6)), WithShape(2, 3)) + it, err := nativeDenseMatrixC64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_nativeDenseTensor3C64(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Complex64, 0, 24)), WithShape(2, 3, 4)) + it, err := nativeDenseTensor3C64(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_nativeDenseVectorC128(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Complex128, 0, 6)), WithShape(6)) + it, err := nativeDenseVectorC128(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_nativeDenseMatrixC128(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Complex128, 0, 6)), WithShape(2, 3)) + it, err := nativeDenseMatrixC128(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_nativeDenseTensor3C128(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(WithBacking(Range(Complex128, 0, 24)), WithShape(2, 3, 4)) + it, err := nativeDenseTensor3C128(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} + +func Test_nativeDenseVectorStr(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(Of(String), WithShape(6)) + it, err := nativeDenseVectorStr(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(6, len(it)) +} + +func Test_nativeDenseMatrixStr(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(Of(String), WithShape(2, 3)) + it, err := nativeDenseMatrixStr(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) +} + +func Test_nativeDenseTensor3Str(t *testing.T) { + assert := assert.New(t) + var T *Dense + T = New(Of(String), WithShape(2, 3, 4)) + it, err := nativeDenseTensor3Str(T) + if err != nil { + t.Fatal(err) + } + + assert.Equal(2, len(it)) + assert.Equal(3, len(it[0])) + assert.Equal(4, len(it[0][0])) +} diff --git a/junkyard_test.go b/junkyard_test.go index 428178a..6d4b43a 100644 --- a/junkyard_test.go +++ b/junkyard_test.go @@ -9,7 +9,7 @@ import ( func TestRandom(t *testing.T) { const size = 50 - for _, typ := range numberTypes.set { + for _, typ := range numberTypes { r := Random(typ, size) typR := reflect.TypeOf(r).Elem() diff --git a/known_issues_test.go b/known_issues_test.go index 36d4125..3175ce7 100644 --- a/known_issues_test.go +++ b/known_issues_test.go @@ -5,6 +5,7 @@ import ( "testing/quick" "github.com/stretchr/testify/assert" + "gorgonia.org/dtype" ) func TestIssue70(t *testing.T) { @@ -43,7 +44,7 @@ func TestIssue72(t *testing.T) { b := identityVal(0, q.t) reuse := New(Of(a.t), WithShape(a.Shape().Clone()...)) correct := a.Clone().(*Dense) - we, willFailEq := willerr(a, numberTypes, unsignedTypes) + we, willFailEq := willerr(a, dtype.Number, dtype.Unsigned) _, ok := q.Engine().(Suber) we = we || !ok //log.Printf("b-a(r) | b:%v, a %v, r %v", b, a.Shape(), reuse.Shape()) diff --git a/native/iterator_native.go b/native/iterator_native.go index d9727fe..1ad0573 100644 --- a/native/iterator_native.go +++ b/native/iterator_native.go @@ -1,1152 +1,332 @@ -// Code generated by genlib2. DO NOT EDIT. +//go:build !purego +// +build !purego package native +// Code generated by genlib2. DO NOT EDIT. + import ( - "reflect" - "unsafe" + _ "unsafe" - "github.com/pkg/errors" - . "gorgonia.org/tensor" + "gorgonia.org/tensor" ) -func checkNativeIterable(t *Dense, dims int, dt Dtype) error { - // checks: - if !t.IsNativelyAccessible() { - return errors.Errorf("Cannot convert *Dense to *mat.Dense. Data is inaccessible") - } - - if t.Shape().Dims() != dims { - return errors.Errorf("Cannot convert *Dense to native iterator. Expected number of dimension: %d, T has got %d dimensions (Shape: %v)", dims, t.Dims(), t.Shape()) - } - - if t.F() || t.RequiresIterator() { - return errors.Errorf("Not yet implemented: native matrix for colmajor or unpacked matrices") - } - - if t.Dtype() != dt { - return errors.Errorf("Conversion to native iterable only works on %v. Got %v", dt, t.Dtype()) - } - - return nil -} - -/* Native Iterables for bool */ +//go:linkname VectorB gorgonia.org/tensor.nativeDenseVectorB // VectorB converts a *Dense into a []bool // If the *Dense does not represent a vector of the wanted type, it will return // an error. -func VectorB(t *Dense) (retVal []bool, err error) { - if err = checkNativeIterable(t, 1, Bool); err != nil { - return nil, err - } - return t.Bools(), nil -} +func VectorB(t *tensor.Dense) (retVal []bool, err error) + +//go:linkname MatrixB gorgonia.org/tensor.nativeDenseMatrixB // MatrixB converts a *Dense into a [][]bool // If the *Dense does not represent a matrix of the wanted type, it // will return an error. -func MatrixB(t *Dense) (retVal [][]bool, err error) { - if err = checkNativeIterable(t, 2, Bool); err != nil { - return nil, err - } - - data := t.Bools() - shape := t.Shape() - strides := t.Strides() - - rows := shape[0] - cols := shape[1] - rowStride := strides[0] - retVal = make([][]bool, rows) - for i := range retVal { - start := i * rowStride - retVal[i] = make([]bool, 0) - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) - hdr.Data = uintptr(unsafe.Pointer(&data[start])) - hdr.Cap = cols - hdr.Len = cols - } - return -} +func MatrixB(t *tensor.Dense) (retVal [][]bool, err error) + +//go:linkname Tensor3B gorgonia.org/tensor.nativeDenseTensor3B // Tensor3B converts a *Dense into a [][][]bool. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. -func Tensor3B(t *Dense) (retVal [][][]bool, err error) { - if err = checkNativeIterable(t, 3, Bool); err != nil { - return nil, err - } - - data := t.Bools() - shape := t.Shape() - strides := t.Strides() - - layers := shape[0] - rows := shape[1] - cols := shape[2] - layerStride := strides[0] - rowStride := strides[1] - retVal = make([][][]bool, layers) - for i := range retVal { - retVal[i] = make([][]bool, rows) - for j := range retVal[i] { - retVal[i][j] = make([]bool, 0) - start := i*layerStride + j*rowStride - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) - hdr.Data = uintptr(unsafe.Pointer(&data[start])) - hdr.Cap = cols - hdr.Len = cols - } - } - return -} - -/* Native Iterables for int */ +func Tensor3B(t *tensor.Dense) (retVal [][][]bool, err error) + +//go:linkname VectorI gorgonia.org/tensor.nativeDenseVectorI // VectorI converts a *Dense into a []int // If the *Dense does not represent a vector of the wanted type, it will return // an error. -func VectorI(t *Dense) (retVal []int, err error) { - if err = checkNativeIterable(t, 1, Int); err != nil { - return nil, err - } - return t.Ints(), nil -} +func VectorI(t *tensor.Dense) (retVal []int, err error) + +//go:linkname MatrixI gorgonia.org/tensor.nativeDenseMatrixI // MatrixI converts a *Dense into a [][]int // If the *Dense does not represent a matrix of the wanted type, it // will return an error. -func MatrixI(t *Dense) (retVal [][]int, err error) { - if err = checkNativeIterable(t, 2, Int); err != nil { - return nil, err - } - - data := t.Ints() - shape := t.Shape() - strides := t.Strides() - - rows := shape[0] - cols := shape[1] - rowStride := strides[0] - retVal = make([][]int, rows) - for i := range retVal { - start := i * rowStride - retVal[i] = make([]int, 0) - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) - hdr.Data = uintptr(unsafe.Pointer(&data[start])) - hdr.Cap = cols - hdr.Len = cols - } - return -} +func MatrixI(t *tensor.Dense) (retVal [][]int, err error) + +//go:linkname Tensor3I gorgonia.org/tensor.nativeDenseTensor3I // Tensor3I converts a *Dense into a [][][]int. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. -func Tensor3I(t *Dense) (retVal [][][]int, err error) { - if err = checkNativeIterable(t, 3, Int); err != nil { - return nil, err - } - - data := t.Ints() - shape := t.Shape() - strides := t.Strides() - - layers := shape[0] - rows := shape[1] - cols := shape[2] - layerStride := strides[0] - rowStride := strides[1] - retVal = make([][][]int, layers) - for i := range retVal { - retVal[i] = make([][]int, rows) - for j := range retVal[i] { - retVal[i][j] = make([]int, 0) - start := i*layerStride + j*rowStride - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) - hdr.Data = uintptr(unsafe.Pointer(&data[start])) - hdr.Cap = cols - hdr.Len = cols - } - } - return -} - -/* Native Iterables for int8 */ +func Tensor3I(t *tensor.Dense) (retVal [][][]int, err error) + +//go:linkname VectorI8 gorgonia.org/tensor.nativeDenseVectorI8 // VectorI8 converts a *Dense into a []int8 // If the *Dense does not represent a vector of the wanted type, it will return // an error. -func VectorI8(t *Dense) (retVal []int8, err error) { - if err = checkNativeIterable(t, 1, Int8); err != nil { - return nil, err - } - return t.Int8s(), nil -} +func VectorI8(t *tensor.Dense) (retVal []int8, err error) + +//go:linkname MatrixI8 gorgonia.org/tensor.nativeDenseMatrixI8 // MatrixI8 converts a *Dense into a [][]int8 // If the *Dense does not represent a matrix of the wanted type, it // will return an error. -func MatrixI8(t *Dense) (retVal [][]int8, err error) { - if err = checkNativeIterable(t, 2, Int8); err != nil { - return nil, err - } - - data := t.Int8s() - shape := t.Shape() - strides := t.Strides() - - rows := shape[0] - cols := shape[1] - rowStride := strides[0] - retVal = make([][]int8, rows) - for i := range retVal { - start := i * rowStride - retVal[i] = make([]int8, 0) - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) - hdr.Data = uintptr(unsafe.Pointer(&data[start])) - hdr.Cap = cols - hdr.Len = cols - } - return -} +func MatrixI8(t *tensor.Dense) (retVal [][]int8, err error) + +//go:linkname Tensor3I8 gorgonia.org/tensor.nativeDenseTensor3I8 // Tensor3I8 converts a *Dense into a [][][]int8. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. -func Tensor3I8(t *Dense) (retVal [][][]int8, err error) { - if err = checkNativeIterable(t, 3, Int8); err != nil { - return nil, err - } - - data := t.Int8s() - shape := t.Shape() - strides := t.Strides() - - layers := shape[0] - rows := shape[1] - cols := shape[2] - layerStride := strides[0] - rowStride := strides[1] - retVal = make([][][]int8, layers) - for i := range retVal { - retVal[i] = make([][]int8, rows) - for j := range retVal[i] { - retVal[i][j] = make([]int8, 0) - start := i*layerStride + j*rowStride - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) - hdr.Data = uintptr(unsafe.Pointer(&data[start])) - hdr.Cap = cols - hdr.Len = cols - } - } - return -} - -/* Native Iterables for int16 */ +func Tensor3I8(t *tensor.Dense) (retVal [][][]int8, err error) + +//go:linkname VectorI16 gorgonia.org/tensor.nativeDenseVectorI16 // VectorI16 converts a *Dense into a []int16 // If the *Dense does not represent a vector of the wanted type, it will return // an error. -func VectorI16(t *Dense) (retVal []int16, err error) { - if err = checkNativeIterable(t, 1, Int16); err != nil { - return nil, err - } - return t.Int16s(), nil -} +func VectorI16(t *tensor.Dense) (retVal []int16, err error) + +//go:linkname MatrixI16 gorgonia.org/tensor.nativeDenseMatrixI16 // MatrixI16 converts a *Dense into a [][]int16 // If the *Dense does not represent a matrix of the wanted type, it // will return an error. -func MatrixI16(t *Dense) (retVal [][]int16, err error) { - if err = checkNativeIterable(t, 2, Int16); err != nil { - return nil, err - } - - data := t.Int16s() - shape := t.Shape() - strides := t.Strides() - - rows := shape[0] - cols := shape[1] - rowStride := strides[0] - retVal = make([][]int16, rows) - for i := range retVal { - start := i * rowStride - retVal[i] = make([]int16, 0) - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) - hdr.Data = uintptr(unsafe.Pointer(&data[start])) - hdr.Cap = cols - hdr.Len = cols - } - return -} +func MatrixI16(t *tensor.Dense) (retVal [][]int16, err error) + +//go:linkname Tensor3I16 gorgonia.org/tensor.nativeDenseTensor3I16 // Tensor3I16 converts a *Dense into a [][][]int16. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. -func Tensor3I16(t *Dense) (retVal [][][]int16, err error) { - if err = checkNativeIterable(t, 3, Int16); err != nil { - return nil, err - } - - data := t.Int16s() - shape := t.Shape() - strides := t.Strides() - - layers := shape[0] - rows := shape[1] - cols := shape[2] - layerStride := strides[0] - rowStride := strides[1] - retVal = make([][][]int16, layers) - for i := range retVal { - retVal[i] = make([][]int16, rows) - for j := range retVal[i] { - retVal[i][j] = make([]int16, 0) - start := i*layerStride + j*rowStride - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) - hdr.Data = uintptr(unsafe.Pointer(&data[start])) - hdr.Cap = cols - hdr.Len = cols - } - } - return -} - -/* Native Iterables for int32 */ +func Tensor3I16(t *tensor.Dense) (retVal [][][]int16, err error) + +//go:linkname VectorI32 gorgonia.org/tensor.nativeDenseVectorI32 // VectorI32 converts a *Dense into a []int32 // If the *Dense does not represent a vector of the wanted type, it will return // an error. -func VectorI32(t *Dense) (retVal []int32, err error) { - if err = checkNativeIterable(t, 1, Int32); err != nil { - return nil, err - } - return t.Int32s(), nil -} +func VectorI32(t *tensor.Dense) (retVal []int32, err error) + +//go:linkname MatrixI32 gorgonia.org/tensor.nativeDenseMatrixI32 // MatrixI32 converts a *Dense into a [][]int32 // If the *Dense does not represent a matrix of the wanted type, it // will return an error. -func MatrixI32(t *Dense) (retVal [][]int32, err error) { - if err = checkNativeIterable(t, 2, Int32); err != nil { - return nil, err - } - - data := t.Int32s() - shape := t.Shape() - strides := t.Strides() - - rows := shape[0] - cols := shape[1] - rowStride := strides[0] - retVal = make([][]int32, rows) - for i := range retVal { - start := i * rowStride - retVal[i] = make([]int32, 0) - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) - hdr.Data = uintptr(unsafe.Pointer(&data[start])) - hdr.Cap = cols - hdr.Len = cols - } - return -} +func MatrixI32(t *tensor.Dense) (retVal [][]int32, err error) + +//go:linkname Tensor3I32 gorgonia.org/tensor.nativeDenseTensor3I32 // Tensor3I32 converts a *Dense into a [][][]int32. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. -func Tensor3I32(t *Dense) (retVal [][][]int32, err error) { - if err = checkNativeIterable(t, 3, Int32); err != nil { - return nil, err - } - - data := t.Int32s() - shape := t.Shape() - strides := t.Strides() - - layers := shape[0] - rows := shape[1] - cols := shape[2] - layerStride := strides[0] - rowStride := strides[1] - retVal = make([][][]int32, layers) - for i := range retVal { - retVal[i] = make([][]int32, rows) - for j := range retVal[i] { - retVal[i][j] = make([]int32, 0) - start := i*layerStride + j*rowStride - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) - hdr.Data = uintptr(unsafe.Pointer(&data[start])) - hdr.Cap = cols - hdr.Len = cols - } - } - return -} - -/* Native Iterables for int64 */ +func Tensor3I32(t *tensor.Dense) (retVal [][][]int32, err error) + +//go:linkname VectorI64 gorgonia.org/tensor.nativeDenseVectorI64 // VectorI64 converts a *Dense into a []int64 // If the *Dense does not represent a vector of the wanted type, it will return // an error. -func VectorI64(t *Dense) (retVal []int64, err error) { - if err = checkNativeIterable(t, 1, Int64); err != nil { - return nil, err - } - return t.Int64s(), nil -} +func VectorI64(t *tensor.Dense) (retVal []int64, err error) + +//go:linkname MatrixI64 gorgonia.org/tensor.nativeDenseMatrixI64 // MatrixI64 converts a *Dense into a [][]int64 // If the *Dense does not represent a matrix of the wanted type, it // will return an error. -func MatrixI64(t *Dense) (retVal [][]int64, err error) { - if err = checkNativeIterable(t, 2, Int64); err != nil { - return nil, err - } - - data := t.Int64s() - shape := t.Shape() - strides := t.Strides() - - rows := shape[0] - cols := shape[1] - rowStride := strides[0] - retVal = make([][]int64, rows) - for i := range retVal { - start := i * rowStride - retVal[i] = make([]int64, 0) - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) - hdr.Data = uintptr(unsafe.Pointer(&data[start])) - hdr.Cap = cols - hdr.Len = cols - } - return -} +func MatrixI64(t *tensor.Dense) (retVal [][]int64, err error) + +//go:linkname Tensor3I64 gorgonia.org/tensor.nativeDenseTensor3I64 // Tensor3I64 converts a *Dense into a [][][]int64. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. -func Tensor3I64(t *Dense) (retVal [][][]int64, err error) { - if err = checkNativeIterable(t, 3, Int64); err != nil { - return nil, err - } - - data := t.Int64s() - shape := t.Shape() - strides := t.Strides() - - layers := shape[0] - rows := shape[1] - cols := shape[2] - layerStride := strides[0] - rowStride := strides[1] - retVal = make([][][]int64, layers) - for i := range retVal { - retVal[i] = make([][]int64, rows) - for j := range retVal[i] { - retVal[i][j] = make([]int64, 0) - start := i*layerStride + j*rowStride - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) - hdr.Data = uintptr(unsafe.Pointer(&data[start])) - hdr.Cap = cols - hdr.Len = cols - } - } - return -} - -/* Native Iterables for uint */ +func Tensor3I64(t *tensor.Dense) (retVal [][][]int64, err error) + +//go:linkname VectorU gorgonia.org/tensor.nativeDenseVectorU // VectorU converts a *Dense into a []uint // If the *Dense does not represent a vector of the wanted type, it will return // an error. -func VectorU(t *Dense) (retVal []uint, err error) { - if err = checkNativeIterable(t, 1, Uint); err != nil { - return nil, err - } - return t.Uints(), nil -} +func VectorU(t *tensor.Dense) (retVal []uint, err error) + +//go:linkname MatrixU gorgonia.org/tensor.nativeDenseMatrixU // MatrixU converts a *Dense into a [][]uint // If the *Dense does not represent a matrix of the wanted type, it // will return an error. -func MatrixU(t *Dense) (retVal [][]uint, err error) { - if err = checkNativeIterable(t, 2, Uint); err != nil { - return nil, err - } - - data := t.Uints() - shape := t.Shape() - strides := t.Strides() - - rows := shape[0] - cols := shape[1] - rowStride := strides[0] - retVal = make([][]uint, rows) - for i := range retVal { - start := i * rowStride - retVal[i] = make([]uint, 0) - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) - hdr.Data = uintptr(unsafe.Pointer(&data[start])) - hdr.Cap = cols - hdr.Len = cols - } - return -} +func MatrixU(t *tensor.Dense) (retVal [][]uint, err error) + +//go:linkname Tensor3U gorgonia.org/tensor.nativeDenseTensor3U // Tensor3U converts a *Dense into a [][][]uint. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. -func Tensor3U(t *Dense) (retVal [][][]uint, err error) { - if err = checkNativeIterable(t, 3, Uint); err != nil { - return nil, err - } - - data := t.Uints() - shape := t.Shape() - strides := t.Strides() - - layers := shape[0] - rows := shape[1] - cols := shape[2] - layerStride := strides[0] - rowStride := strides[1] - retVal = make([][][]uint, layers) - for i := range retVal { - retVal[i] = make([][]uint, rows) - for j := range retVal[i] { - retVal[i][j] = make([]uint, 0) - start := i*layerStride + j*rowStride - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) - hdr.Data = uintptr(unsafe.Pointer(&data[start])) - hdr.Cap = cols - hdr.Len = cols - } - } - return -} - -/* Native Iterables for uint8 */ +func Tensor3U(t *tensor.Dense) (retVal [][][]uint, err error) + +//go:linkname VectorU8 gorgonia.org/tensor.nativeDenseVectorU8 // VectorU8 converts a *Dense into a []uint8 // If the *Dense does not represent a vector of the wanted type, it will return // an error. -func VectorU8(t *Dense) (retVal []uint8, err error) { - if err = checkNativeIterable(t, 1, Uint8); err != nil { - return nil, err - } - return t.Uint8s(), nil -} +func VectorU8(t *tensor.Dense) (retVal []uint8, err error) + +//go:linkname MatrixU8 gorgonia.org/tensor.nativeDenseMatrixU8 // MatrixU8 converts a *Dense into a [][]uint8 // If the *Dense does not represent a matrix of the wanted type, it // will return an error. -func MatrixU8(t *Dense) (retVal [][]uint8, err error) { - if err = checkNativeIterable(t, 2, Uint8); err != nil { - return nil, err - } - - data := t.Uint8s() - shape := t.Shape() - strides := t.Strides() - - rows := shape[0] - cols := shape[1] - rowStride := strides[0] - retVal = make([][]uint8, rows) - for i := range retVal { - start := i * rowStride - retVal[i] = make([]uint8, 0) - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) - hdr.Data = uintptr(unsafe.Pointer(&data[start])) - hdr.Cap = cols - hdr.Len = cols - } - return -} +func MatrixU8(t *tensor.Dense) (retVal [][]uint8, err error) + +//go:linkname Tensor3U8 gorgonia.org/tensor.nativeDenseTensor3U8 // Tensor3U8 converts a *Dense into a [][][]uint8. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. -func Tensor3U8(t *Dense) (retVal [][][]uint8, err error) { - if err = checkNativeIterable(t, 3, Uint8); err != nil { - return nil, err - } - - data := t.Uint8s() - shape := t.Shape() - strides := t.Strides() - - layers := shape[0] - rows := shape[1] - cols := shape[2] - layerStride := strides[0] - rowStride := strides[1] - retVal = make([][][]uint8, layers) - for i := range retVal { - retVal[i] = make([][]uint8, rows) - for j := range retVal[i] { - retVal[i][j] = make([]uint8, 0) - start := i*layerStride + j*rowStride - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) - hdr.Data = uintptr(unsafe.Pointer(&data[start])) - hdr.Cap = cols - hdr.Len = cols - } - } - return -} - -/* Native Iterables for uint16 */ +func Tensor3U8(t *tensor.Dense) (retVal [][][]uint8, err error) + +//go:linkname VectorU16 gorgonia.org/tensor.nativeDenseVectorU16 // VectorU16 converts a *Dense into a []uint16 // If the *Dense does not represent a vector of the wanted type, it will return // an error. -func VectorU16(t *Dense) (retVal []uint16, err error) { - if err = checkNativeIterable(t, 1, Uint16); err != nil { - return nil, err - } - return t.Uint16s(), nil -} +func VectorU16(t *tensor.Dense) (retVal []uint16, err error) + +//go:linkname MatrixU16 gorgonia.org/tensor.nativeDenseMatrixU16 // MatrixU16 converts a *Dense into a [][]uint16 // If the *Dense does not represent a matrix of the wanted type, it // will return an error. -func MatrixU16(t *Dense) (retVal [][]uint16, err error) { - if err = checkNativeIterable(t, 2, Uint16); err != nil { - return nil, err - } - - data := t.Uint16s() - shape := t.Shape() - strides := t.Strides() - - rows := shape[0] - cols := shape[1] - rowStride := strides[0] - retVal = make([][]uint16, rows) - for i := range retVal { - start := i * rowStride - retVal[i] = make([]uint16, 0) - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) - hdr.Data = uintptr(unsafe.Pointer(&data[start])) - hdr.Cap = cols - hdr.Len = cols - } - return -} +func MatrixU16(t *tensor.Dense) (retVal [][]uint16, err error) + +//go:linkname Tensor3U16 gorgonia.org/tensor.nativeDenseTensor3U16 // Tensor3U16 converts a *Dense into a [][][]uint16. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. -func Tensor3U16(t *Dense) (retVal [][][]uint16, err error) { - if err = checkNativeIterable(t, 3, Uint16); err != nil { - return nil, err - } - - data := t.Uint16s() - shape := t.Shape() - strides := t.Strides() - - layers := shape[0] - rows := shape[1] - cols := shape[2] - layerStride := strides[0] - rowStride := strides[1] - retVal = make([][][]uint16, layers) - for i := range retVal { - retVal[i] = make([][]uint16, rows) - for j := range retVal[i] { - retVal[i][j] = make([]uint16, 0) - start := i*layerStride + j*rowStride - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) - hdr.Data = uintptr(unsafe.Pointer(&data[start])) - hdr.Cap = cols - hdr.Len = cols - } - } - return -} - -/* Native Iterables for uint32 */ +func Tensor3U16(t *tensor.Dense) (retVal [][][]uint16, err error) + +//go:linkname VectorU32 gorgonia.org/tensor.nativeDenseVectorU32 // VectorU32 converts a *Dense into a []uint32 // If the *Dense does not represent a vector of the wanted type, it will return // an error. -func VectorU32(t *Dense) (retVal []uint32, err error) { - if err = checkNativeIterable(t, 1, Uint32); err != nil { - return nil, err - } - return t.Uint32s(), nil -} +func VectorU32(t *tensor.Dense) (retVal []uint32, err error) + +//go:linkname MatrixU32 gorgonia.org/tensor.nativeDenseMatrixU32 // MatrixU32 converts a *Dense into a [][]uint32 // If the *Dense does not represent a matrix of the wanted type, it // will return an error. -func MatrixU32(t *Dense) (retVal [][]uint32, err error) { - if err = checkNativeIterable(t, 2, Uint32); err != nil { - return nil, err - } - - data := t.Uint32s() - shape := t.Shape() - strides := t.Strides() - - rows := shape[0] - cols := shape[1] - rowStride := strides[0] - retVal = make([][]uint32, rows) - for i := range retVal { - start := i * rowStride - retVal[i] = make([]uint32, 0) - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) - hdr.Data = uintptr(unsafe.Pointer(&data[start])) - hdr.Cap = cols - hdr.Len = cols - } - return -} +func MatrixU32(t *tensor.Dense) (retVal [][]uint32, err error) + +//go:linkname Tensor3U32 gorgonia.org/tensor.nativeDenseTensor3U32 // Tensor3U32 converts a *Dense into a [][][]uint32. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. -func Tensor3U32(t *Dense) (retVal [][][]uint32, err error) { - if err = checkNativeIterable(t, 3, Uint32); err != nil { - return nil, err - } - - data := t.Uint32s() - shape := t.Shape() - strides := t.Strides() - - layers := shape[0] - rows := shape[1] - cols := shape[2] - layerStride := strides[0] - rowStride := strides[1] - retVal = make([][][]uint32, layers) - for i := range retVal { - retVal[i] = make([][]uint32, rows) - for j := range retVal[i] { - retVal[i][j] = make([]uint32, 0) - start := i*layerStride + j*rowStride - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) - hdr.Data = uintptr(unsafe.Pointer(&data[start])) - hdr.Cap = cols - hdr.Len = cols - } - } - return -} - -/* Native Iterables for uint64 */ +func Tensor3U32(t *tensor.Dense) (retVal [][][]uint32, err error) + +//go:linkname VectorU64 gorgonia.org/tensor.nativeDenseVectorU64 // VectorU64 converts a *Dense into a []uint64 // If the *Dense does not represent a vector of the wanted type, it will return // an error. -func VectorU64(t *Dense) (retVal []uint64, err error) { - if err = checkNativeIterable(t, 1, Uint64); err != nil { - return nil, err - } - return t.Uint64s(), nil -} +func VectorU64(t *tensor.Dense) (retVal []uint64, err error) + +//go:linkname MatrixU64 gorgonia.org/tensor.nativeDenseMatrixU64 // MatrixU64 converts a *Dense into a [][]uint64 // If the *Dense does not represent a matrix of the wanted type, it // will return an error. -func MatrixU64(t *Dense) (retVal [][]uint64, err error) { - if err = checkNativeIterable(t, 2, Uint64); err != nil { - return nil, err - } - - data := t.Uint64s() - shape := t.Shape() - strides := t.Strides() - - rows := shape[0] - cols := shape[1] - rowStride := strides[0] - retVal = make([][]uint64, rows) - for i := range retVal { - start := i * rowStride - retVal[i] = make([]uint64, 0) - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) - hdr.Data = uintptr(unsafe.Pointer(&data[start])) - hdr.Cap = cols - hdr.Len = cols - } - return -} +func MatrixU64(t *tensor.Dense) (retVal [][]uint64, err error) + +//go:linkname Tensor3U64 gorgonia.org/tensor.nativeDenseTensor3U64 // Tensor3U64 converts a *Dense into a [][][]uint64. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. -func Tensor3U64(t *Dense) (retVal [][][]uint64, err error) { - if err = checkNativeIterable(t, 3, Uint64); err != nil { - return nil, err - } - - data := t.Uint64s() - shape := t.Shape() - strides := t.Strides() - - layers := shape[0] - rows := shape[1] - cols := shape[2] - layerStride := strides[0] - rowStride := strides[1] - retVal = make([][][]uint64, layers) - for i := range retVal { - retVal[i] = make([][]uint64, rows) - for j := range retVal[i] { - retVal[i][j] = make([]uint64, 0) - start := i*layerStride + j*rowStride - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) - hdr.Data = uintptr(unsafe.Pointer(&data[start])) - hdr.Cap = cols - hdr.Len = cols - } - } - return -} - -/* Native Iterables for float32 */ +func Tensor3U64(t *tensor.Dense) (retVal [][][]uint64, err error) + +//go:linkname VectorF32 gorgonia.org/tensor.nativeDenseVectorF32 // VectorF32 converts a *Dense into a []float32 // If the *Dense does not represent a vector of the wanted type, it will return // an error. -func VectorF32(t *Dense) (retVal []float32, err error) { - if err = checkNativeIterable(t, 1, Float32); err != nil { - return nil, err - } - return t.Float32s(), nil -} +func VectorF32(t *tensor.Dense) (retVal []float32, err error) + +//go:linkname MatrixF32 gorgonia.org/tensor.nativeDenseMatrixF32 // MatrixF32 converts a *Dense into a [][]float32 // If the *Dense does not represent a matrix of the wanted type, it // will return an error. -func MatrixF32(t *Dense) (retVal [][]float32, err error) { - if err = checkNativeIterable(t, 2, Float32); err != nil { - return nil, err - } - - data := t.Float32s() - shape := t.Shape() - strides := t.Strides() - - rows := shape[0] - cols := shape[1] - rowStride := strides[0] - retVal = make([][]float32, rows) - for i := range retVal { - start := i * rowStride - retVal[i] = make([]float32, 0) - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) - hdr.Data = uintptr(unsafe.Pointer(&data[start])) - hdr.Cap = cols - hdr.Len = cols - } - return -} +func MatrixF32(t *tensor.Dense) (retVal [][]float32, err error) + +//go:linkname Tensor3F32 gorgonia.org/tensor.nativeDenseTensor3F32 // Tensor3F32 converts a *Dense into a [][][]float32. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. -func Tensor3F32(t *Dense) (retVal [][][]float32, err error) { - if err = checkNativeIterable(t, 3, Float32); err != nil { - return nil, err - } - - data := t.Float32s() - shape := t.Shape() - strides := t.Strides() - - layers := shape[0] - rows := shape[1] - cols := shape[2] - layerStride := strides[0] - rowStride := strides[1] - retVal = make([][][]float32, layers) - for i := range retVal { - retVal[i] = make([][]float32, rows) - for j := range retVal[i] { - retVal[i][j] = make([]float32, 0) - start := i*layerStride + j*rowStride - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) - hdr.Data = uintptr(unsafe.Pointer(&data[start])) - hdr.Cap = cols - hdr.Len = cols - } - } - return -} - -/* Native Iterables for float64 */ +func Tensor3F32(t *tensor.Dense) (retVal [][][]float32, err error) + +//go:linkname VectorF64 gorgonia.org/tensor.nativeDenseVectorF64 // VectorF64 converts a *Dense into a []float64 // If the *Dense does not represent a vector of the wanted type, it will return // an error. -func VectorF64(t *Dense) (retVal []float64, err error) { - if err = checkNativeIterable(t, 1, Float64); err != nil { - return nil, err - } - return t.Float64s(), nil -} +func VectorF64(t *tensor.Dense) (retVal []float64, err error) + +//go:linkname MatrixF64 gorgonia.org/tensor.nativeDenseMatrixF64 // MatrixF64 converts a *Dense into a [][]float64 // If the *Dense does not represent a matrix of the wanted type, it // will return an error. -func MatrixF64(t *Dense) (retVal [][]float64, err error) { - if err = checkNativeIterable(t, 2, Float64); err != nil { - return nil, err - } - - data := t.Float64s() - shape := t.Shape() - strides := t.Strides() - - rows := shape[0] - cols := shape[1] - rowStride := strides[0] - retVal = make([][]float64, rows) - for i := range retVal { - start := i * rowStride - retVal[i] = make([]float64, 0) - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) - hdr.Data = uintptr(unsafe.Pointer(&data[start])) - hdr.Cap = cols - hdr.Len = cols - } - return -} +func MatrixF64(t *tensor.Dense) (retVal [][]float64, err error) + +//go:linkname Tensor3F64 gorgonia.org/tensor.nativeDenseTensor3F64 // Tensor3F64 converts a *Dense into a [][][]float64. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. -func Tensor3F64(t *Dense) (retVal [][][]float64, err error) { - if err = checkNativeIterable(t, 3, Float64); err != nil { - return nil, err - } - - data := t.Float64s() - shape := t.Shape() - strides := t.Strides() - - layers := shape[0] - rows := shape[1] - cols := shape[2] - layerStride := strides[0] - rowStride := strides[1] - retVal = make([][][]float64, layers) - for i := range retVal { - retVal[i] = make([][]float64, rows) - for j := range retVal[i] { - retVal[i][j] = make([]float64, 0) - start := i*layerStride + j*rowStride - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) - hdr.Data = uintptr(unsafe.Pointer(&data[start])) - hdr.Cap = cols - hdr.Len = cols - } - } - return -} - -/* Native Iterables for complex64 */ +func Tensor3F64(t *tensor.Dense) (retVal [][][]float64, err error) + +//go:linkname VectorC64 gorgonia.org/tensor.nativeDenseVectorC64 // VectorC64 converts a *Dense into a []complex64 // If the *Dense does not represent a vector of the wanted type, it will return // an error. -func VectorC64(t *Dense) (retVal []complex64, err error) { - if err = checkNativeIterable(t, 1, Complex64); err != nil { - return nil, err - } - return t.Complex64s(), nil -} +func VectorC64(t *tensor.Dense) (retVal []complex64, err error) + +//go:linkname MatrixC64 gorgonia.org/tensor.nativeDenseMatrixC64 // MatrixC64 converts a *Dense into a [][]complex64 // If the *Dense does not represent a matrix of the wanted type, it // will return an error. -func MatrixC64(t *Dense) (retVal [][]complex64, err error) { - if err = checkNativeIterable(t, 2, Complex64); err != nil { - return nil, err - } - - data := t.Complex64s() - shape := t.Shape() - strides := t.Strides() - - rows := shape[0] - cols := shape[1] - rowStride := strides[0] - retVal = make([][]complex64, rows) - for i := range retVal { - start := i * rowStride - retVal[i] = make([]complex64, 0) - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) - hdr.Data = uintptr(unsafe.Pointer(&data[start])) - hdr.Cap = cols - hdr.Len = cols - } - return -} +func MatrixC64(t *tensor.Dense) (retVal [][]complex64, err error) + +//go:linkname Tensor3C64 gorgonia.org/tensor.nativeDenseTensor3C64 // Tensor3C64 converts a *Dense into a [][][]complex64. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. -func Tensor3C64(t *Dense) (retVal [][][]complex64, err error) { - if err = checkNativeIterable(t, 3, Complex64); err != nil { - return nil, err - } - - data := t.Complex64s() - shape := t.Shape() - strides := t.Strides() - - layers := shape[0] - rows := shape[1] - cols := shape[2] - layerStride := strides[0] - rowStride := strides[1] - retVal = make([][][]complex64, layers) - for i := range retVal { - retVal[i] = make([][]complex64, rows) - for j := range retVal[i] { - retVal[i][j] = make([]complex64, 0) - start := i*layerStride + j*rowStride - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) - hdr.Data = uintptr(unsafe.Pointer(&data[start])) - hdr.Cap = cols - hdr.Len = cols - } - } - return -} - -/* Native Iterables for complex128 */ +func Tensor3C64(t *tensor.Dense) (retVal [][][]complex64, err error) + +//go:linkname VectorC128 gorgonia.org/tensor.nativeDenseVectorC128 // VectorC128 converts a *Dense into a []complex128 // If the *Dense does not represent a vector of the wanted type, it will return // an error. -func VectorC128(t *Dense) (retVal []complex128, err error) { - if err = checkNativeIterable(t, 1, Complex128); err != nil { - return nil, err - } - return t.Complex128s(), nil -} +func VectorC128(t *tensor.Dense) (retVal []complex128, err error) + +//go:linkname MatrixC128 gorgonia.org/tensor.nativeDenseMatrixC128 // MatrixC128 converts a *Dense into a [][]complex128 // If the *Dense does not represent a matrix of the wanted type, it // will return an error. -func MatrixC128(t *Dense) (retVal [][]complex128, err error) { - if err = checkNativeIterable(t, 2, Complex128); err != nil { - return nil, err - } - - data := t.Complex128s() - shape := t.Shape() - strides := t.Strides() - - rows := shape[0] - cols := shape[1] - rowStride := strides[0] - retVal = make([][]complex128, rows) - for i := range retVal { - start := i * rowStride - retVal[i] = make([]complex128, 0) - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) - hdr.Data = uintptr(unsafe.Pointer(&data[start])) - hdr.Cap = cols - hdr.Len = cols - } - return -} +func MatrixC128(t *tensor.Dense) (retVal [][]complex128, err error) + +//go:linkname Tensor3C128 gorgonia.org/tensor.nativeDenseTensor3C128 // Tensor3C128 converts a *Dense into a [][][]complex128. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. -func Tensor3C128(t *Dense) (retVal [][][]complex128, err error) { - if err = checkNativeIterable(t, 3, Complex128); err != nil { - return nil, err - } - - data := t.Complex128s() - shape := t.Shape() - strides := t.Strides() - - layers := shape[0] - rows := shape[1] - cols := shape[2] - layerStride := strides[0] - rowStride := strides[1] - retVal = make([][][]complex128, layers) - for i := range retVal { - retVal[i] = make([][]complex128, rows) - for j := range retVal[i] { - retVal[i][j] = make([]complex128, 0) - start := i*layerStride + j*rowStride - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) - hdr.Data = uintptr(unsafe.Pointer(&data[start])) - hdr.Cap = cols - hdr.Len = cols - } - } - return -} - -/* Native Iterables for string */ +func Tensor3C128(t *tensor.Dense) (retVal [][][]complex128, err error) + +//go:linkname VectorStr gorgonia.org/tensor.nativeDenseVectorStr // VectorStr converts a *Dense into a []string // If the *Dense does not represent a vector of the wanted type, it will return // an error. -func VectorStr(t *Dense) (retVal []string, err error) { - if err = checkNativeIterable(t, 1, String); err != nil { - return nil, err - } - return t.Strings(), nil -} +func VectorStr(t *tensor.Dense) (retVal []string, err error) + +//go:linkname MatrixStr gorgonia.org/tensor.nativeDenseMatrixStr // MatrixStr converts a *Dense into a [][]string // If the *Dense does not represent a matrix of the wanted type, it // will return an error. -func MatrixStr(t *Dense) (retVal [][]string, err error) { - if err = checkNativeIterable(t, 2, String); err != nil { - return nil, err - } - - data := t.Strings() - shape := t.Shape() - strides := t.Strides() - - rows := shape[0] - cols := shape[1] - rowStride := strides[0] - retVal = make([][]string, rows) - for i := range retVal { - start := i * rowStride - retVal[i] = make([]string, 0) - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) - hdr.Data = uintptr(unsafe.Pointer(&data[start])) - hdr.Cap = cols - hdr.Len = cols - } - return -} +func MatrixStr(t *tensor.Dense) (retVal [][]string, err error) + +//go:linkname Tensor3Str gorgonia.org/tensor.nativeDenseTensor3Str // Tensor3Str converts a *Dense into a [][][]string. // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. -func Tensor3Str(t *Dense) (retVal [][][]string, err error) { - if err = checkNativeIterable(t, 3, String); err != nil { - return nil, err - } - - data := t.Strings() - shape := t.Shape() - strides := t.Strides() - - layers := shape[0] - rows := shape[1] - cols := shape[2] - layerStride := strides[0] - rowStride := strides[1] - retVal = make([][][]string, layers) - for i := range retVal { - retVal[i] = make([][]string, rows) - for j := range retVal[i] { - retVal[i][j] = make([]string, 0) - start := i*layerStride + j*rowStride - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) - hdr.Data = uintptr(unsafe.Pointer(&data[start])) - hdr.Cap = cols - hdr.Len = cols - } - } - return -} +func Tensor3Str(t *tensor.Dense) (retVal [][][]string, err error) diff --git a/native/iterator_native_purego.go b/native/iterator_native_purego.go new file mode 100644 index 0000000..aba1b50 --- /dev/null +++ b/native/iterator_native_purego.go @@ -0,0 +1,1133 @@ +//go:build purego +// +build purego + +package native + +// Code generated by genlib2. DO NOT EDIT. + +import ( + "reflect" + "unsafe" + + . "gorgonia.org/tensor" +) + +/* Native Iterables for bool */ + +// VectorB converts a *Dense into a []bool +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func VectorB(t *Dense) (retVal []bool, err error) { + if err = checkNativeIterable(t, 1, Bool); err != nil { + return nil, err + } + return t.Bools(), nil +} + +// MatrixB converts a *Dense into a [][]bool +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func MatrixB(t *Dense) (retVal [][]bool, err error) { + if err = checkNativeIterable(t, 2, Bool); err != nil { + return nil, err + } + + data := t.Bools() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]bool, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]bool, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// Tensor3B converts a *Dense into a [][][]bool. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func Tensor3B(t *Dense) (retVal [][][]bool, err error) { + if err = checkNativeIterable(t, 3, Bool); err != nil { + return nil, err + } + + data := t.Bools() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]bool, layers) + for i := range retVal { + retVal[i] = make([][]bool, rows) + for j := range retVal[i] { + retVal[i][j] = make([]bool, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for int */ + +// VectorI converts a *Dense into a []int +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func VectorI(t *Dense) (retVal []int, err error) { + if err = checkNativeIterable(t, 1, Int); err != nil { + return nil, err + } + return t.Ints(), nil +} + +// MatrixI converts a *Dense into a [][]int +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func MatrixI(t *Dense) (retVal [][]int, err error) { + if err = checkNativeIterable(t, 2, Int); err != nil { + return nil, err + } + + data := t.Ints() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]int, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]int, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// Tensor3I converts a *Dense into a [][][]int. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func Tensor3I(t *Dense) (retVal [][][]int, err error) { + if err = checkNativeIterable(t, 3, Int); err != nil { + return nil, err + } + + data := t.Ints() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]int, layers) + for i := range retVal { + retVal[i] = make([][]int, rows) + for j := range retVal[i] { + retVal[i][j] = make([]int, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for int8 */ + +// VectorI8 converts a *Dense into a []int8 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func VectorI8(t *Dense) (retVal []int8, err error) { + if err = checkNativeIterable(t, 1, Int8); err != nil { + return nil, err + } + return t.Int8s(), nil +} + +// MatrixI8 converts a *Dense into a [][]int8 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func MatrixI8(t *Dense) (retVal [][]int8, err error) { + if err = checkNativeIterable(t, 2, Int8); err != nil { + return nil, err + } + + data := t.Int8s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]int8, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]int8, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// Tensor3I8 converts a *Dense into a [][][]int8. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func Tensor3I8(t *Dense) (retVal [][][]int8, err error) { + if err = checkNativeIterable(t, 3, Int8); err != nil { + return nil, err + } + + data := t.Int8s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]int8, layers) + for i := range retVal { + retVal[i] = make([][]int8, rows) + for j := range retVal[i] { + retVal[i][j] = make([]int8, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for int16 */ + +// VectorI16 converts a *Dense into a []int16 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func VectorI16(t *Dense) (retVal []int16, err error) { + if err = checkNativeIterable(t, 1, Int16); err != nil { + return nil, err + } + return t.Int16s(), nil +} + +// MatrixI16 converts a *Dense into a [][]int16 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func MatrixI16(t *Dense) (retVal [][]int16, err error) { + if err = checkNativeIterable(t, 2, Int16); err != nil { + return nil, err + } + + data := t.Int16s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]int16, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]int16, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// Tensor3I16 converts a *Dense into a [][][]int16. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func Tensor3I16(t *Dense) (retVal [][][]int16, err error) { + if err = checkNativeIterable(t, 3, Int16); err != nil { + return nil, err + } + + data := t.Int16s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]int16, layers) + for i := range retVal { + retVal[i] = make([][]int16, rows) + for j := range retVal[i] { + retVal[i][j] = make([]int16, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for int32 */ + +// VectorI32 converts a *Dense into a []int32 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func VectorI32(t *Dense) (retVal []int32, err error) { + if err = checkNativeIterable(t, 1, Int32); err != nil { + return nil, err + } + return t.Int32s(), nil +} + +// MatrixI32 converts a *Dense into a [][]int32 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func MatrixI32(t *Dense) (retVal [][]int32, err error) { + if err = checkNativeIterable(t, 2, Int32); err != nil { + return nil, err + } + + data := t.Int32s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]int32, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]int32, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// Tensor3I32 converts a *Dense into a [][][]int32. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func Tensor3I32(t *Dense) (retVal [][][]int32, err error) { + if err = checkNativeIterable(t, 3, Int32); err != nil { + return nil, err + } + + data := t.Int32s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]int32, layers) + for i := range retVal { + retVal[i] = make([][]int32, rows) + for j := range retVal[i] { + retVal[i][j] = make([]int32, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for int64 */ + +// VectorI64 converts a *Dense into a []int64 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func VectorI64(t *Dense) (retVal []int64, err error) { + if err = checkNativeIterable(t, 1, Int64); err != nil { + return nil, err + } + return t.Int64s(), nil +} + +// MatrixI64 converts a *Dense into a [][]int64 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func MatrixI64(t *Dense) (retVal [][]int64, err error) { + if err = checkNativeIterable(t, 2, Int64); err != nil { + return nil, err + } + + data := t.Int64s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]int64, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]int64, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// Tensor3I64 converts a *Dense into a [][][]int64. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func Tensor3I64(t *Dense) (retVal [][][]int64, err error) { + if err = checkNativeIterable(t, 3, Int64); err != nil { + return nil, err + } + + data := t.Int64s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]int64, layers) + for i := range retVal { + retVal[i] = make([][]int64, rows) + for j := range retVal[i] { + retVal[i][j] = make([]int64, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for uint */ + +// VectorU converts a *Dense into a []uint +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func VectorU(t *Dense) (retVal []uint, err error) { + if err = checkNativeIterable(t, 1, Uint); err != nil { + return nil, err + } + return t.Uints(), nil +} + +// MatrixU converts a *Dense into a [][]uint +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func MatrixU(t *Dense) (retVal [][]uint, err error) { + if err = checkNativeIterable(t, 2, Uint); err != nil { + return nil, err + } + + data := t.Uints() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]uint, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]uint, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// Tensor3U converts a *Dense into a [][][]uint. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func Tensor3U(t *Dense) (retVal [][][]uint, err error) { + if err = checkNativeIterable(t, 3, Uint); err != nil { + return nil, err + } + + data := t.Uints() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]uint, layers) + for i := range retVal { + retVal[i] = make([][]uint, rows) + for j := range retVal[i] { + retVal[i][j] = make([]uint, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for uint8 */ + +// VectorU8 converts a *Dense into a []uint8 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func VectorU8(t *Dense) (retVal []uint8, err error) { + if err = checkNativeIterable(t, 1, Uint8); err != nil { + return nil, err + } + return t.Uint8s(), nil +} + +// MatrixU8 converts a *Dense into a [][]uint8 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func MatrixU8(t *Dense) (retVal [][]uint8, err error) { + if err = checkNativeIterable(t, 2, Uint8); err != nil { + return nil, err + } + + data := t.Uint8s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]uint8, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]uint8, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// Tensor3U8 converts a *Dense into a [][][]uint8. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func Tensor3U8(t *Dense) (retVal [][][]uint8, err error) { + if err = checkNativeIterable(t, 3, Uint8); err != nil { + return nil, err + } + + data := t.Uint8s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]uint8, layers) + for i := range retVal { + retVal[i] = make([][]uint8, rows) + for j := range retVal[i] { + retVal[i][j] = make([]uint8, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for uint16 */ + +// VectorU16 converts a *Dense into a []uint16 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func VectorU16(t *Dense) (retVal []uint16, err error) { + if err = checkNativeIterable(t, 1, Uint16); err != nil { + return nil, err + } + return t.Uint16s(), nil +} + +// MatrixU16 converts a *Dense into a [][]uint16 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func MatrixU16(t *Dense) (retVal [][]uint16, err error) { + if err = checkNativeIterable(t, 2, Uint16); err != nil { + return nil, err + } + + data := t.Uint16s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]uint16, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]uint16, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// Tensor3U16 converts a *Dense into a [][][]uint16. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func Tensor3U16(t *Dense) (retVal [][][]uint16, err error) { + if err = checkNativeIterable(t, 3, Uint16); err != nil { + return nil, err + } + + data := t.Uint16s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]uint16, layers) + for i := range retVal { + retVal[i] = make([][]uint16, rows) + for j := range retVal[i] { + retVal[i][j] = make([]uint16, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for uint32 */ + +// VectorU32 converts a *Dense into a []uint32 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func VectorU32(t *Dense) (retVal []uint32, err error) { + if err = checkNativeIterable(t, 1, Uint32); err != nil { + return nil, err + } + return t.Uint32s(), nil +} + +// MatrixU32 converts a *Dense into a [][]uint32 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func MatrixU32(t *Dense) (retVal [][]uint32, err error) { + if err = checkNativeIterable(t, 2, Uint32); err != nil { + return nil, err + } + + data := t.Uint32s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]uint32, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]uint32, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// Tensor3U32 converts a *Dense into a [][][]uint32. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func Tensor3U32(t *Dense) (retVal [][][]uint32, err error) { + if err = checkNativeIterable(t, 3, Uint32); err != nil { + return nil, err + } + + data := t.Uint32s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]uint32, layers) + for i := range retVal { + retVal[i] = make([][]uint32, rows) + for j := range retVal[i] { + retVal[i][j] = make([]uint32, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for uint64 */ + +// VectorU64 converts a *Dense into a []uint64 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func VectorU64(t *Dense) (retVal []uint64, err error) { + if err = checkNativeIterable(t, 1, Uint64); err != nil { + return nil, err + } + return t.Uint64s(), nil +} + +// MatrixU64 converts a *Dense into a [][]uint64 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func MatrixU64(t *Dense) (retVal [][]uint64, err error) { + if err = checkNativeIterable(t, 2, Uint64); err != nil { + return nil, err + } + + data := t.Uint64s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]uint64, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]uint64, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// Tensor3U64 converts a *Dense into a [][][]uint64. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func Tensor3U64(t *Dense) (retVal [][][]uint64, err error) { + if err = checkNativeIterable(t, 3, Uint64); err != nil { + return nil, err + } + + data := t.Uint64s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]uint64, layers) + for i := range retVal { + retVal[i] = make([][]uint64, rows) + for j := range retVal[i] { + retVal[i][j] = make([]uint64, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for float32 */ + +// VectorF32 converts a *Dense into a []float32 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func VectorF32(t *Dense) (retVal []float32, err error) { + if err = checkNativeIterable(t, 1, Float32); err != nil { + return nil, err + } + return t.Float32s(), nil +} + +// MatrixF32 converts a *Dense into a [][]float32 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func MatrixF32(t *Dense) (retVal [][]float32, err error) { + if err = checkNativeIterable(t, 2, Float32); err != nil { + return nil, err + } + + data := t.Float32s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]float32, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]float32, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// Tensor3F32 converts a *Dense into a [][][]float32. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func Tensor3F32(t *Dense) (retVal [][][]float32, err error) { + if err = checkNativeIterable(t, 3, Float32); err != nil { + return nil, err + } + + data := t.Float32s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]float32, layers) + for i := range retVal { + retVal[i] = make([][]float32, rows) + for j := range retVal[i] { + retVal[i][j] = make([]float32, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for float64 */ + +// VectorF64 converts a *Dense into a []float64 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func VectorF64(t *Dense) (retVal []float64, err error) { + if err = checkNativeIterable(t, 1, Float64); err != nil { + return nil, err + } + return t.Float64s(), nil +} + +// MatrixF64 converts a *Dense into a [][]float64 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func MatrixF64(t *Dense) (retVal [][]float64, err error) { + if err = checkNativeIterable(t, 2, Float64); err != nil { + return nil, err + } + + data := t.Float64s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]float64, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]float64, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// Tensor3F64 converts a *Dense into a [][][]float64. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func Tensor3F64(t *Dense) (retVal [][][]float64, err error) { + if err = checkNativeIterable(t, 3, Float64); err != nil { + return nil, err + } + + data := t.Float64s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]float64, layers) + for i := range retVal { + retVal[i] = make([][]float64, rows) + for j := range retVal[i] { + retVal[i][j] = make([]float64, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for complex64 */ + +// VectorC64 converts a *Dense into a []complex64 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func VectorC64(t *Dense) (retVal []complex64, err error) { + if err = checkNativeIterable(t, 1, Complex64); err != nil { + return nil, err + } + return t.Complex64s(), nil +} + +// MatrixC64 converts a *Dense into a [][]complex64 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func MatrixC64(t *Dense) (retVal [][]complex64, err error) { + if err = checkNativeIterable(t, 2, Complex64); err != nil { + return nil, err + } + + data := t.Complex64s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]complex64, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]complex64, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// Tensor3C64 converts a *Dense into a [][][]complex64. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func Tensor3C64(t *Dense) (retVal [][][]complex64, err error) { + if err = checkNativeIterable(t, 3, Complex64); err != nil { + return nil, err + } + + data := t.Complex64s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]complex64, layers) + for i := range retVal { + retVal[i] = make([][]complex64, rows) + for j := range retVal[i] { + retVal[i][j] = make([]complex64, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for complex128 */ + +// VectorC128 converts a *Dense into a []complex128 +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func VectorC128(t *Dense) (retVal []complex128, err error) { + if err = checkNativeIterable(t, 1, Complex128); err != nil { + return nil, err + } + return t.Complex128s(), nil +} + +// MatrixC128 converts a *Dense into a [][]complex128 +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func MatrixC128(t *Dense) (retVal [][]complex128, err error) { + if err = checkNativeIterable(t, 2, Complex128); err != nil { + return nil, err + } + + data := t.Complex128s() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]complex128, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]complex128, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// Tensor3C128 converts a *Dense into a [][][]complex128. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func Tensor3C128(t *Dense) (retVal [][][]complex128, err error) { + if err = checkNativeIterable(t, 3, Complex128); err != nil { + return nil, err + } + + data := t.Complex128s() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]complex128, layers) + for i := range retVal { + retVal[i] = make([][]complex128, rows) + for j := range retVal[i] { + retVal[i][j] = make([]complex128, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} + +/* Native Iterables for string */ + +// VectorStr converts a *Dense into a []string +// If the *Dense does not represent a vector of the wanted type, it will return +// an error. +func VectorStr(t *Dense) (retVal []string, err error) { + if err = checkNativeIterable(t, 1, String); err != nil { + return nil, err + } + return t.Strings(), nil +} + +// MatrixStr converts a *Dense into a [][]string +// If the *Dense does not represent a matrix of the wanted type, it +// will return an error. +func MatrixStr(t *Dense) (retVal [][]string, err error) { + if err = checkNativeIterable(t, 2, String); err != nil { + return nil, err + } + + data := t.Strings() + shape := t.Shape() + strides := t.Strides() + + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + retVal = make([][]string, rows) + for i := range retVal { + start := i * rowStride + retVal[i] = make([]string, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + return +} + +// Tensor3Str converts a *Dense into a [][][]string. +// If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. +func Tensor3Str(t *Dense) (retVal [][][]string, err error) { + if err = checkNativeIterable(t, 3, String); err != nil { + return nil, err + } + + data := t.Strings() + shape := t.Shape() + strides := t.Strides() + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal = make([][][]string, layers) + for i := range retVal { + retVal[i] = make([][]string, rows) + for j := range retVal[i] { + retVal[i][j] = make([]string, 0) + start := i*layerStride + j*rowStride + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) + hdr.Data = uintptr(unsafe.Pointer(&data[start])) + hdr.Cap = cols + hdr.Len = cols + } + } + return +} diff --git a/native/iterator_native_test.go b/native/iterator_native_test.go index 09236a0..2e99966 100644 --- a/native/iterator_native_test.go +++ b/native/iterator_native_test.go @@ -1,7 +1,7 @@ -// Code generated by genlib2. DO NOT EDIT. - package native +// Code generated by genlib2. DO NOT EDIT. + import ( "testing" diff --git a/native/select_native.go b/native/select_native.go new file mode 100644 index 0000000..b048ae9 --- /dev/null +++ b/native/select_native.go @@ -0,0 +1,92 @@ +//go:build !purego +// +build !purego + +package native + +// Code generated by genlib2. DO NOT EDIT. + +import ( + _ "unsafe" + + "gorgonia.org/tensor" +) + +//go:linkname SelectB gorgonia.org/tensor.nativeSelectB + +// SelectB creates a slice of bools. See Example of NativeSelectF64. +func SelectB(t *tensor.Dense, axis int) (retVal [][]bool, err error) + +//go:linkname SelectI gorgonia.org/tensor.nativeSelectI + +// SelectI creates a slice of ints. See Example of NativeSelectF64. +func SelectI(t *tensor.Dense, axis int) (retVal [][]int, err error) + +//go:linkname SelectI8 gorgonia.org/tensor.nativeSelectI8 + +// SelectI8 creates a slice of int8s. See Example of NativeSelectF64. +func SelectI8(t *tensor.Dense, axis int) (retVal [][]int8, err error) + +//go:linkname SelectI16 gorgonia.org/tensor.nativeSelectI16 + +// SelectI16 creates a slice of int16s. See Example of NativeSelectF64. +func SelectI16(t *tensor.Dense, axis int) (retVal [][]int16, err error) + +//go:linkname SelectI32 gorgonia.org/tensor.nativeSelectI32 + +// SelectI32 creates a slice of int32s. See Example of NativeSelectF64. +func SelectI32(t *tensor.Dense, axis int) (retVal [][]int32, err error) + +//go:linkname SelectI64 gorgonia.org/tensor.nativeSelectI64 + +// SelectI64 creates a slice of int64s. See Example of NativeSelectF64. +func SelectI64(t *tensor.Dense, axis int) (retVal [][]int64, err error) + +//go:linkname SelectU gorgonia.org/tensor.nativeSelectU + +// SelectU creates a slice of uints. See Example of NativeSelectF64. +func SelectU(t *tensor.Dense, axis int) (retVal [][]uint, err error) + +//go:linkname SelectU8 gorgonia.org/tensor.nativeSelectU8 + +// SelectU8 creates a slice of uint8s. See Example of NativeSelectF64. +func SelectU8(t *tensor.Dense, axis int) (retVal [][]uint8, err error) + +//go:linkname SelectU16 gorgonia.org/tensor.nativeSelectU16 + +// SelectU16 creates a slice of uint16s. See Example of NativeSelectF64. +func SelectU16(t *tensor.Dense, axis int) (retVal [][]uint16, err error) + +//go:linkname SelectU32 gorgonia.org/tensor.nativeSelectU32 + +// SelectU32 creates a slice of uint32s. See Example of NativeSelectF64. +func SelectU32(t *tensor.Dense, axis int) (retVal [][]uint32, err error) + +//go:linkname SelectU64 gorgonia.org/tensor.nativeSelectU64 + +// SelectU64 creates a slice of uint64s. See Example of NativeSelectF64. +func SelectU64(t *tensor.Dense, axis int) (retVal [][]uint64, err error) + +//go:linkname SelectF32 gorgonia.org/tensor.nativeSelectF32 + +// SelectF32 creates a slice of float32s. See Example of NativeSelectF64. +func SelectF32(t *tensor.Dense, axis int) (retVal [][]float32, err error) + +//go:linkname SelectF64 gorgonia.org/tensor.nativeSelectF64 + +// SelectF64 creates a slice of float64s. See Example of NativeSelectF64. +func SelectF64(t *tensor.Dense, axis int) (retVal [][]float64, err error) + +//go:linkname SelectC64 gorgonia.org/tensor.nativeSelectC64 + +// SelectC64 creates a slice of complex64s. See Example of NativeSelectF64. +func SelectC64(t *tensor.Dense, axis int) (retVal [][]complex64, err error) + +//go:linkname SelectC128 gorgonia.org/tensor.nativeSelectC128 + +// SelectC128 creates a slice of complex128s. See Example of NativeSelectF64. +func SelectC128(t *tensor.Dense, axis int) (retVal [][]complex128, err error) + +//go:linkname SelectStr gorgonia.org/tensor.nativeSelectStr + +// SelectStr creates a slice of strings. See Example of NativeSelectF64. +func SelectStr(t *tensor.Dense, axis int) (retVal [][]string, err error) diff --git a/native/iterator_native2.go b/native/select_native_purego.go similarity index 95% rename from native/iterator_native2.go rename to native/select_native_purego.go index 934863d..6285fe0 100644 --- a/native/iterator_native2.go +++ b/native/select_native_purego.go @@ -1,31 +1,17 @@ -// Code generated by genlib2. DO NOT EDIT. +//go:build purego +// +build purego package native +// Code generated by genlib2. DO NOT EDIT. + import ( "reflect" "unsafe" - "github.com/pkg/errors" . "gorgonia.org/tensor" ) -func checkNativeSelectable(t *Dense, axis int, dt Dtype) error { - if !t.IsNativelyAccessible() { - return errors.New("Cannot select on non-natively accessible data") - } - if axis >= t.Shape().Dims() && !(t.IsScalar() && axis == 0) { - return errors.Errorf("Cannot select on axis %d. Shape is %v", axis, t.Shape()) - } - if t.F() || t.RequiresIterator() { - return errors.Errorf("Not yet implemented: native select for colmajor or unpacked matrices") - } - if t.Dtype() != dt { - return errors.Errorf("Native selection only works on %v. Got %v", dt, t.Dtype()) - } - return nil -} - /* Native Select for bool */ // SelectB creates a slice of flat data types. See Example of NativeSelectF64. diff --git a/native/iterator_native2_test.go b/native/select_native_test.go similarity index 100% rename from native/iterator_native2_test.go rename to native/select_native_test.go index df56b5e..a6f247f 100644 --- a/native/iterator_native2_test.go +++ b/native/select_native_test.go @@ -1,7 +1,7 @@ -// Code generated by genlib2. DO NOT EDIT. - package native +// Code generated by genlib2. DO NOT EDIT. + import ( "testing" diff --git a/native/utils.go b/native/utils.go new file mode 100644 index 0000000..341388e --- /dev/null +++ b/native/utils.go @@ -0,0 +1,46 @@ +package native + +// Code generated by genlib2. DO NOT EDIT. + +import ( + "github.com/pkg/errors" + "gorgonia.org/dtype" + . "gorgonia.org/tensor" +) + +func checkNativeIterable(t *Dense, dims int, dt dtype.Dtype) error { + // checks: + if !t.IsNativelyAccessible() { + return errors.Errorf("Cannot convert *Dense to *mat.Dense. Data is inaccessible") + } + + if t.Shape().Dims() != dims { + return errors.Errorf("Cannot convert *Dense to native iterator. Expected number of dimension: %d, T has got %d dimensions (Shape: %v)", dims, t.Dims(), t.Shape()) + } + + if t.F() || t.RequiresIterator() { + return errors.Errorf("Not yet implemented: native matrix for colmajor or unpacked matrices") + } + + if t.Dtype() != dt { + return errors.Errorf("Conversion to native iterable only works on %v. Got %v", dt, t.Dtype()) + } + + return nil +} + +func checkNativeSelectable(t *Dense, axis int, dt dtype.Dtype) error { + if !t.IsNativelyAccessible() { + return errors.New("Cannot select on non-natively accessible data") + } + if axis >= t.Shape().Dims() && !(t.IsScalar() && axis == 0) { + return errors.Errorf("Cannot select on axis %d. Shape is %v", axis, t.Shape()) + } + if t.F() || t.RequiresIterator() { + return errors.Errorf("Not yet implemented: native select for colmajor or unpacked matrices") + } + if t.Dtype() != dt { + return errors.Errorf("Native selection only works on %v. Got %v", dt, t.Dtype()) + } + return nil +} diff --git a/optimizations_test.go b/optimizations_test.go index 18bb677..9b8afcb 100644 --- a/optimizations_test.go +++ b/optimizations_test.go @@ -1,15 +1,15 @@ -package tensor - -import ( - "testing" -) - -// this file contains tests to make sure certain algorithms/optimizations aren't crazy - -func TestRequiresIterator(t *testing.T) { - T := New(Of(Int), WithBacking([]int{1, 2, 3, 4})) - sliced, _ := T.Slice(makeRS(1, 3)) - if sliced.RequiresIterator() { - t.Errorf("Slicing on rows should not require Iterator") - } -} +package tensor + +import ( + "testing" +) + +// this file contains tests to make sure certain algorithms/optimizations aren't crazy + +func TestRequiresIterator(t *testing.T) { + T := New(Of(Int), WithBacking([]int{1, 2, 3, 4})) + sliced, _ := T.Slice(makeRS(1, 3)) + if sliced.RequiresIterator() { + t.Errorf("Slicing on rows should not require Iterator") + } +} diff --git a/perf.go b/perf.go index bc5c3aa..a37c610 100644 --- a/perf.go +++ b/perf.go @@ -4,6 +4,7 @@ import ( "runtime" "sync" + "gorgonia.org/dtype" "gorgonia.org/tensor/internal/storage" ) @@ -89,7 +90,7 @@ func ReturnTensor(t Tensor) { } // array reset - tt.t = Dtype{} + tt.t = dtype.Dtype{} tt.array.Header.Raw = nil // engine and flag reset @@ -238,10 +239,10 @@ func ReturnBools(is []bool) { // var optPool = make(chan *OpOpt, PoolSize) // var optPool = newRingbuffer(PoolSize) var optPool = &sync.Pool{ - New: func() interface{} { return new(OpOpt) }, + New: func() interface{} { return new(opOpt) }, } -func borrowOpOpt() *OpOpt { +func borrowOpOpt() *opOpt { // select { // case fo := <-optPool: // return fo @@ -249,7 +250,7 @@ func borrowOpOpt() *OpOpt { // return new(OpOpt) // } - return optPool.Get().(*OpOpt) + return optPool.Get().(*opOpt) // if fo, err := optPool.Get(); err == nil { // return (*OpOpt)(fo) @@ -257,12 +258,13 @@ func borrowOpOpt() *OpOpt { // return new(OpOpt) } -func returnOpOpt(oo *OpOpt) { +func returnOpOpt(oo *opOpt) { oo.reuse = nil oo.incr = nil oo.unsafe = false oo.same = false - oo.t = Dtype{} + oo.t = dtype.Dtype{} + oo.ctx = nil // if len(optPool) < cap(optPool) { // optPool <- oo // } diff --git a/scalar.go b/scalar.go new file mode 100644 index 0000000..ee37ba0 --- /dev/null +++ b/scalar.go @@ -0,0 +1,89 @@ +// +build ignore + +package tensor + +import ( + "fmt" + "io" + "reflect" + "unsafe" + + "gorgonia.org/dtype" + + "github.com/pkg/errors" + "gorgonia.org/tensor/internal/storage" +) + +var _ Tensor = Scalar{} +var _ ScalarRep = Scalar{} +var _ ScalarRep = ScalarDense{} + +// ScalarDense wraps a *Dense to provide a typesafe alternative for a scalar to be represented in a *Dense. +type ScalarDense struct { + *Dense +} + +func (s ScalarDense) IsScalar() bool { return true } + +func (s ScalarDense) ScalarValue() interface{} { return s.Dense.Data() } + +// Scalar is a representation of a scalar value on the CPU. +type Scalar struct{ v interface{} } + +func MakeScalar(v interface{}) Scalar { + if s, ok := v.(Scalar); ok { + return s + } + if s, ok := v.(*Scalar); ok { + return Scalar{s.v} + } + return Scalar{v} +} + +func (s Scalar) Shape() Shape { return ScalarShape() } +func (s Scalar) Strides() []int { return nil } +func (s Scalar) Dtype() dtype.Dtype { return dtype.Dtype{reflect.TypeOf(s.v)} } +func (s Scalar) Dims() int { return 0 } +func (s Scalar) Size() int { return 0 } // TODO +func (s Scalar) DataSize() int { return 0 } +func (s Scalar) RequiresIterator() bool { return false } +func (s Scalar) Iterator() Iterator { return nil } +func (s Scalar) DataOrder() DataOrder { return 0 } // TODO + +func (s Scalar) Slice(...Slice) (View, error) { return nil, errors.New("Cannot slice a scalar") } +func (s Scalar) At(at ...int) (interface{}, error) { return nil, errors.New("Get a value of a scalar") } +func (s Scalar) SetAt(_ interface{}, _ ...int) error { return errors.New("Cannot set value of scalar") } +func (s Scalar) Reshape(_ ...int) error { return errors.New("Cannot reshape a scalar") } +func (s Scalar) T(_ ...int) error { return errors.New("Cannot transpose a scalar") } +func (s Scalar) UT() {} +func (s Scalar) Transpose() error { return errors.New("Cannot transpose a scalar") } +func (s Scalar) Apply(fn interface{}, opts ...FuncOpt) (Tensor, error) { return nyierr(typeNYI, s) } + +func (s Scalar) Zero() {} //TODO +func (s Scalar) Memset(interface{}) error { return errors.New("Cannot Memset") } +func (s Scalar) Data() interface{} { return s.v } +func (s Scalar) Eq(other interface{}) bool { return s == other } +func (s Scalar) Clone() interface{} { return s } + +func (s Scalar) IsScalar() bool { return true } +func (s Scalar) ScalarValue() interface{} { return s.v } + +func (s Scalar) Engine() Engine { return nil } +func (s Scalar) MemSize() uintptr { return 0 } +func (s Scalar) Uintptr() uintptr { return 0 } +func (s Scalar) Pointer() unsafe.Pointer { return nil } +func (s Scalar) IsNativelyAccessible() bool { return true } +func (s Scalar) IsManuallyManaged() bool { return false } + +func (s Scalar) Format(t fmt.State, c rune) {} // TODO +func (s Scalar) String() string { return fmt.Sprintf("%v", s) } + +func (s Scalar) WriteNpy(io.Writer) error { return nyierr(typeNYI, s) } +func (s Scalar) ReadNpy(io.Reader) error { return nyierr(typeNYI, s) } +func (s Scalar) GobEncode() ([]byte, error) { return nil, nyierr(typeNYI, s) } +func (s Scalar) GobDecode([]byte) error { return nyierr(typeNYI, s) } + +func (s Scalar) standardEngine() StandardEngine { return StdEng{} } +func (s Scalar) hdr() *storage.Header { return nil } +func (s Scalar) arr() array { return array{} } +func (s Scalar) arrPtr() *array { return nil } diff --git a/select_native.go b/select_native.go new file mode 100644 index 0000000..d3cf1f2 --- /dev/null +++ b/select_native.go @@ -0,0 +1,635 @@ +package tensor + +import ( + "reflect" + "unsafe" + + "github.com/pkg/errors" + "gorgonia.org/dtype" +) + +// Code generated by genlib2. DO NOT EDIT. + +func checkNativeSelectable(t *Dense, axis int, dt dtype.Dtype) error { + if !t.IsNativelyAccessible() { + return errors.New("Cannot select on non-natively accessible data") + } + if axis >= t.Shape().Dims() && !(t.IsScalar() && axis == 0) { + return errors.Errorf("Cannot select on axis %d. Shape is %v", axis, t.Shape()) + } + if t.F() || t.RequiresIterator() { + return errors.Errorf("Not yet implemented: native select for colmajor or unpacked matrices") + } + if t.Dtype() != dt { + return errors.Errorf("Native selection only works on %v. Got %v", dt, t.Dtype()) + } + return nil +} + +/* Native Select for bool */ + +// nativeSelectB creates a slice of flat data types. See Example of NativeSelectF64. +func nativeSelectB(t *Dense, axis int) (retVal [][]bool, err error) { + if err := checkNativeSelectable(t, axis, Bool); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]bool, 1) + retVal[0] = t.Bools() + case 2: + if axis == 0 { + return nativeDenseMatrixB(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Bools() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]bool, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]bool, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for int */ + +// nativeSelectI creates a slice of flat data types. See Example of NativeSelectF64. +func nativeSelectI(t *Dense, axis int) (retVal [][]int, err error) { + if err := checkNativeSelectable(t, axis, Int); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]int, 1) + retVal[0] = t.Ints() + case 2: + if axis == 0 { + return nativeDenseMatrixI(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Ints() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]int, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]int, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for int8 */ + +// nativeSelectI8 creates a slice of flat data types. See Example of NativeSelectF64. +func nativeSelectI8(t *Dense, axis int) (retVal [][]int8, err error) { + if err := checkNativeSelectable(t, axis, Int8); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]int8, 1) + retVal[0] = t.Int8s() + case 2: + if axis == 0 { + return nativeDenseMatrixI8(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Int8s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]int8, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]int8, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for int16 */ + +// nativeSelectI16 creates a slice of flat data types. See Example of NativeSelectF64. +func nativeSelectI16(t *Dense, axis int) (retVal [][]int16, err error) { + if err := checkNativeSelectable(t, axis, Int16); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]int16, 1) + retVal[0] = t.Int16s() + case 2: + if axis == 0 { + return nativeDenseMatrixI16(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Int16s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]int16, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]int16, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for int32 */ + +// nativeSelectI32 creates a slice of flat data types. See Example of NativeSelectF64. +func nativeSelectI32(t *Dense, axis int) (retVal [][]int32, err error) { + if err := checkNativeSelectable(t, axis, Int32); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]int32, 1) + retVal[0] = t.Int32s() + case 2: + if axis == 0 { + return nativeDenseMatrixI32(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Int32s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]int32, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]int32, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for int64 */ + +// nativeSelectI64 creates a slice of flat data types. See Example of NativeSelectF64. +func nativeSelectI64(t *Dense, axis int) (retVal [][]int64, err error) { + if err := checkNativeSelectable(t, axis, Int64); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]int64, 1) + retVal[0] = t.Int64s() + case 2: + if axis == 0 { + return nativeDenseMatrixI64(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Int64s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]int64, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]int64, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for uint */ + +// nativeSelectU creates a slice of flat data types. See Example of NativeSelectF64. +func nativeSelectU(t *Dense, axis int) (retVal [][]uint, err error) { + if err := checkNativeSelectable(t, axis, Uint); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]uint, 1) + retVal[0] = t.Uints() + case 2: + if axis == 0 { + return nativeDenseMatrixU(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Uints() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]uint, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]uint, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for uint8 */ + +// nativeSelectU8 creates a slice of flat data types. See Example of NativeSelectF64. +func nativeSelectU8(t *Dense, axis int) (retVal [][]uint8, err error) { + if err := checkNativeSelectable(t, axis, Uint8); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]uint8, 1) + retVal[0] = t.Uint8s() + case 2: + if axis == 0 { + return nativeDenseMatrixU8(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Uint8s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]uint8, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]uint8, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for uint16 */ + +// nativeSelectU16 creates a slice of flat data types. See Example of NativeSelectF64. +func nativeSelectU16(t *Dense, axis int) (retVal [][]uint16, err error) { + if err := checkNativeSelectable(t, axis, Uint16); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]uint16, 1) + retVal[0] = t.Uint16s() + case 2: + if axis == 0 { + return nativeDenseMatrixU16(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Uint16s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]uint16, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]uint16, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for uint32 */ + +// nativeSelectU32 creates a slice of flat data types. See Example of NativeSelectF64. +func nativeSelectU32(t *Dense, axis int) (retVal [][]uint32, err error) { + if err := checkNativeSelectable(t, axis, Uint32); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]uint32, 1) + retVal[0] = t.Uint32s() + case 2: + if axis == 0 { + return nativeDenseMatrixU32(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Uint32s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]uint32, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]uint32, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for uint64 */ + +// nativeSelectU64 creates a slice of flat data types. See Example of NativeSelectF64. +func nativeSelectU64(t *Dense, axis int) (retVal [][]uint64, err error) { + if err := checkNativeSelectable(t, axis, Uint64); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]uint64, 1) + retVal[0] = t.Uint64s() + case 2: + if axis == 0 { + return nativeDenseMatrixU64(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Uint64s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]uint64, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]uint64, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for float32 */ + +// nativeSelectF32 creates a slice of flat data types. See Example of NativeSelectF64. +func nativeSelectF32(t *Dense, axis int) (retVal [][]float32, err error) { + if err := checkNativeSelectable(t, axis, Float32); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]float32, 1) + retVal[0] = t.Float32s() + case 2: + if axis == 0 { + return nativeDenseMatrixF32(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Float32s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]float32, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]float32, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for float64 */ + +// nativeSelectF64 creates a slice of flat data types. See Example of NativeSelectF64. +func nativeSelectF64(t *Dense, axis int) (retVal [][]float64, err error) { + if err := checkNativeSelectable(t, axis, Float64); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]float64, 1) + retVal[0] = t.Float64s() + case 2: + if axis == 0 { + return nativeDenseMatrixF64(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Float64s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]float64, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]float64, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for complex64 */ + +// nativeSelectC64 creates a slice of flat data types. See Example of NativeSelectF64. +func nativeSelectC64(t *Dense, axis int) (retVal [][]complex64, err error) { + if err := checkNativeSelectable(t, axis, Complex64); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]complex64, 1) + retVal[0] = t.Complex64s() + case 2: + if axis == 0 { + return nativeDenseMatrixC64(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Complex64s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]complex64, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]complex64, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for complex128 */ + +// nativeSelectC128 creates a slice of flat data types. See Example of NativeSelectF64. +func nativeSelectC128(t *Dense, axis int) (retVal [][]complex128, err error) { + if err := checkNativeSelectable(t, axis, Complex128); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]complex128, 1) + retVal[0] = t.Complex128s() + case 2: + if axis == 0 { + return nativeDenseMatrixC128(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Complex128s() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]complex128, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]complex128, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} + +/* Native Select for string */ + +// nativeSelectStr creates a slice of flat data types. See Example of NativeSelectF64. +func nativeSelectStr(t *Dense, axis int) (retVal [][]string, err error) { + if err := checkNativeSelectable(t, axis, String); err != nil { + return nil, err + } + + switch t.Shape().Dims() { + case 0, 1: + retVal = make([][]string, 1) + retVal[0] = t.Strings() + case 2: + if axis == 0 { + return nativeDenseMatrixStr(t) + } + fallthrough + default: + // size := t.Shape()[axis] + data := t.Strings() + stride := t.Strides()[axis] + upper := ProdInts(t.Shape()[:axis+1]) + retVal = make([][]string, 0, upper) + for i, r := 0, 0; r < upper; i += stride { + s := make([]string, 0) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + retVal = append(retVal, s) + r++ + } + return retVal, nil + + } + return +} diff --git a/select_native_batched.go b/select_native_batched.go new file mode 100644 index 0000000..c05bd76 --- /dev/null +++ b/select_native_batched.go @@ -0,0 +1,163 @@ +package tensor + +import ( + "reflect" + "runtime" + "unsafe" +) + +type BatchedNativeSelectF64 struct { + t *Dense + it [][]float64 // FUTURE: this can be made into generic in the future + + // state + + upper int // the outer dimension after being "reshaped" + limit int // limit as to how many rows the `it` can store + stride int // stride + r int // current row +} + +func BatchSelectF64(t *Dense, axis int, limit int) *BatchedNativeSelectF64 { + if err := checkNativeSelectable(t, axis, Float64); err != nil { + panic(err) + } + + if limit <= 0 { + limit = runtime.NumCPU() // default + } + upper := ProdInts(t.Shape()[:axis+1]) + if limit > upper { + limit = upper + // `it` should come from nativeSelectF64 + } + stride := t.Strides()[axis] + data := t.Float64s() + + it := make([][]float64, 0, limit) + var i, r int + for i, r = 0, 0; r < limit; i += stride { + // this block of code is basically + // it = append(it, data[i:i+stride]) + // TODO: benchmark + it = append(it, make([]float64, 0)) + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&it[len(it)-1])) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = stride + hdr.Cap = stride + r++ + } + + return &BatchedNativeSelectF64{ + t: t, + it: it, + upper: upper, + limit: limit, + stride: stride, + r: r, + } +} + +func (it *BatchedNativeSelectF64) Start() (curBatch [][]float64, hasRemainingRows bool) { + if it.r != it.limit || it.IsTruncated() { + // then it's been moved, so we reset + it.Reset() + } + curBatch = it.it + hasRemainingRows = it.upper > it.r + return +} + +// Next moves the next batch into the native iterator. +func (it *BatchedNativeSelectF64) Next() (curBatch [][]float64, hasRemaingRows bool) { + var ( + i int // data ptr + r int // relative row / row counter for this batch + s int // absolute row + ) + if it.r == it.upper { + return it.it, false + } + data := it.t.Float64s() + + // this loop statement looks scary. But it isn't. Let me break it down: + // Initialization: + // i := it.r*it.stride // the data pointer is the row number * the stride of the matrix. + // r := 0 // loop counter. We're gonna iterate `it.limit` times. + // s := it.r // the current row number of the matrix. + // Condition (continue if the following are true): + // r < it.limit // we only want to iterate at most `it.limit` times. + // s < it.upper // we want to make sure we don't iterate more rows than there are rows in the matrix. + // Next: + // i = i + it.stride // we're ready to go to the next row. + // r = r+1 // we increment the row counter. + // s = s+1 // we increment the absolute row number. + // + // Could this be written in a less concise way? Sure. But then there'd be a lot more places to keep track of things. + for i, r, s = it.r*it.stride, 0, it.r; r < it.limit && s < it.upper; i, r, s = i+it.stride, r+1, s+1 { + // the block of code below is basically: + // it.it[r] = data[i:i+stride] + // r++ + // For some reason when this is done, Go actually does a lot more allocations. + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&it.it[r])) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + } + it.r = s + + if it.r == it.upper && r < it.limit { + // truncate it.it because iterated rows is less than the limit. + // This implies that there are some extra rows. + it.it = it.it[:r] + } + + return it.it, true +} + +func (it *BatchedNativeSelectF64) Native() [][]float64 { return it.it } + +func (it *BatchedNativeSelectF64) Reset() { + it.it = it.it[:it.limit:it.limit] + + data := it.t.Float64s() + var i, r int + for i, r = 0, 0; r < it.limit; i += it.stride { + sl := it.it[r] + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&sl)) + hdr.Data = uintptr(unsafe.Pointer(&data[i])) + hdr.Len = it.stride + hdr.Cap = it.stride + it.it[r] = sl + r++ + } + it.r = r +} + +func (it *BatchedNativeSelectF64) IsTruncated() bool { return len(it.it) != it.limit } + +type IterSelect struct { + r int + upper int + stride int + total int +} + +func NewIterSelect(t *Dense, axis int) *IterSelect { + upper := ProdInts(t.Shape()[:axis+1]) + stride := t.Strides()[axis] + total := t.DataSize() + return &IterSelect{upper: upper, stride: stride, total: total} +} + +func (it *IterSelect) Start() (start, end int, hasRem bool) { + if it.r > it.stride { + it.Reset() + } + return it.r, it.stride, it.r*it.stride+it.stride < it.total +} + +func (it *IterSelect) Next() (start, end int, hasRem bool) { + it.r += it.stride + return it.r, it.r + it.stride, it.r+it.stride <= it.total +} + +func (it *IterSelect) Reset() { it.r = 0 } diff --git a/select_native_test.go b/select_native_test.go new file mode 100644 index 0000000..02291b5 --- /dev/null +++ b/select_native_test.go @@ -0,0 +1,841 @@ +package tensor + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +// Code generated by genlib2. DO NOT EDIT. + +func TestnativeSelectB(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]bool + T = New(Of(Bool), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectB(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Bool), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectB(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Bool), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectB(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Bool), WithShape(2, 3)) + if x, err = nativeSelectB(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Bool), WithShape(2, 3)) + if x, err = nativeSelectB(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(false)) + if x, err = nativeSelectB(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = nativeSelectB(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestnativeSelectI(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]int + T = New(Of(Int), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectI(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Int), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectI(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Int), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectI(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Int), WithShape(2, 3)) + if x, err = nativeSelectI(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Int), WithShape(2, 3)) + if x, err = nativeSelectI(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(int(0))) + if x, err = nativeSelectI(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = nativeSelectI(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestnativeSelectI8(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]int8 + T = New(Of(Int8), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectI8(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Int8), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectI8(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Int8), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectI8(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Int8), WithShape(2, 3)) + if x, err = nativeSelectI8(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Int8), WithShape(2, 3)) + if x, err = nativeSelectI8(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(int8(0))) + if x, err = nativeSelectI8(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = nativeSelectI8(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestnativeSelectI16(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]int16 + T = New(Of(Int16), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectI16(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Int16), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectI16(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Int16), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectI16(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Int16), WithShape(2, 3)) + if x, err = nativeSelectI16(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Int16), WithShape(2, 3)) + if x, err = nativeSelectI16(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(int16(0))) + if x, err = nativeSelectI16(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = nativeSelectI16(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestnativeSelectI32(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]int32 + T = New(Of(Int32), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectI32(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Int32), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectI32(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Int32), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectI32(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Int32), WithShape(2, 3)) + if x, err = nativeSelectI32(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Int32), WithShape(2, 3)) + if x, err = nativeSelectI32(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(int32(0))) + if x, err = nativeSelectI32(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = nativeSelectI32(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestnativeSelectI64(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]int64 + T = New(Of(Int64), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectI64(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Int64), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectI64(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Int64), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectI64(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Int64), WithShape(2, 3)) + if x, err = nativeSelectI64(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Int64), WithShape(2, 3)) + if x, err = nativeSelectI64(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(int64(0))) + if x, err = nativeSelectI64(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = nativeSelectI64(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestnativeSelectU(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]uint + T = New(Of(Uint), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectU(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Uint), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectU(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Uint), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectU(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Uint), WithShape(2, 3)) + if x, err = nativeSelectU(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Uint), WithShape(2, 3)) + if x, err = nativeSelectU(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(uint(0))) + if x, err = nativeSelectU(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = nativeSelectU(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestnativeSelectU8(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]uint8 + T = New(Of(Uint8), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectU8(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Uint8), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectU8(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Uint8), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectU8(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Uint8), WithShape(2, 3)) + if x, err = nativeSelectU8(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Uint8), WithShape(2, 3)) + if x, err = nativeSelectU8(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(uint8(0))) + if x, err = nativeSelectU8(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = nativeSelectU8(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestnativeSelectU16(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]uint16 + T = New(Of(Uint16), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectU16(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Uint16), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectU16(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Uint16), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectU16(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Uint16), WithShape(2, 3)) + if x, err = nativeSelectU16(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Uint16), WithShape(2, 3)) + if x, err = nativeSelectU16(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(uint16(0))) + if x, err = nativeSelectU16(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = nativeSelectU16(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestnativeSelectU32(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]uint32 + T = New(Of(Uint32), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectU32(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Uint32), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectU32(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Uint32), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectU32(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Uint32), WithShape(2, 3)) + if x, err = nativeSelectU32(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Uint32), WithShape(2, 3)) + if x, err = nativeSelectU32(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(uint32(0))) + if x, err = nativeSelectU32(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = nativeSelectU32(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestnativeSelectU64(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]uint64 + T = New(Of(Uint64), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectU64(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Uint64), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectU64(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Uint64), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectU64(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Uint64), WithShape(2, 3)) + if x, err = nativeSelectU64(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Uint64), WithShape(2, 3)) + if x, err = nativeSelectU64(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(uint64(0))) + if x, err = nativeSelectU64(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = nativeSelectU64(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestnativeSelectF32(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]float32 + T = New(Of(Float32), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectF32(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Float32), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectF32(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Float32), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectF32(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Float32), WithShape(2, 3)) + if x, err = nativeSelectF32(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Float32), WithShape(2, 3)) + if x, err = nativeSelectF32(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(float32(0))) + if x, err = nativeSelectF32(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = nativeSelectF32(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestnativeSelectF64(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]float64 + T = New(Of(Float64), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectF64(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Float64), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectF64(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Float64), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectF64(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Float64), WithShape(2, 3)) + if x, err = nativeSelectF64(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Float64), WithShape(2, 3)) + if x, err = nativeSelectF64(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(float64(0))) + if x, err = nativeSelectF64(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = nativeSelectF64(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestnativeSelectC64(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]complex64 + T = New(Of(Complex64), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectC64(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Complex64), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectC64(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Complex64), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectC64(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Complex64), WithShape(2, 3)) + if x, err = nativeSelectC64(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Complex64), WithShape(2, 3)) + if x, err = nativeSelectC64(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(complex64(0))) + if x, err = nativeSelectC64(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = nativeSelectC64(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestnativeSelectC128(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]complex128 + T = New(Of(Complex128), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectC128(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(Complex128), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectC128(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(Complex128), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectC128(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(Complex128), WithShape(2, 3)) + if x, err = nativeSelectC128(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(Complex128), WithShape(2, 3)) + if x, err = nativeSelectC128(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar(complex128(0))) + if x, err = nativeSelectC128(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = nativeSelectC128(T, 10); err == nil { + t.Fatal("Expected errors") + } +} + +func TestnativeSelectStr(t *testing.T) { + assert := assert.New(t) + var T *Dense + var err error + var x [][]string + T = New(Of(String), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectStr(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(20, len(x[0])) + + T = New(Of(String), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectStr(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(60, len(x[0])) + + T = New(Of(String), WithShape(2, 3, 4, 5)) + if x, err = nativeSelectStr(T, 3); err != nil { + t.Fatal(err) + } + assert.Equal(120, len(x)) + assert.Equal(1, len(x[0])) + + T = New(Of(String), WithShape(2, 3)) + if x, err = nativeSelectStr(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(2, len(x)) + assert.Equal(3, len(x[0])) + + T = New(Of(String), WithShape(2, 3)) + if x, err = nativeSelectStr(T, 1); err != nil { + t.Fatal(err) + } + assert.Equal(6, len(x)) + assert.Equal(1, len(x[0])) + + T = New(FromScalar("")) + if x, err = nativeSelectStr(T, 0); err != nil { + t.Fatal(err) + } + assert.Equal(1, len(x)) + assert.Equal(1, len(x[0])) + + if _, err = nativeSelectStr(T, 10); err == nil { + t.Fatal("Expected errors") + } +} diff --git a/shape.go b/shape.go index c1347b4..f8d5d0a 100644 --- a/shape.go +++ b/shape.go @@ -1,9 +1,7 @@ package tensor import ( - "fmt" - - "github.com/pkg/errors" + "gorgonia.org/shapes" ) var scalarShape = Shape{} @@ -11,21 +9,11 @@ var scalarShape = Shape{} // ScalarShape represents a scalar. It has no dimensions, no sizes func ScalarShape() Shape { return scalarShape } -// Shape represents the dimensions of a Tensor. A (2,3) matrix has a shape of (2,3) - 2 rows, 3 columns. -// Likewise, a shape of (2,3,4) means a Tensor has 3 dimensions: 2 layers, 3 rows, 4 columns. -// -// Vectors are of particular note. This package defines a shape of (x, 1) as a column vector and -// a (1, x) as a row vector. Row vectors and column vectors are matrices as well. It is important to note that -// row and column vectors and vanilla vectors are comparable under some circumstances -type Shape []int - -// TotalSize returns the number of elements expected in a Tensor of a certain shape -func (s Shape) TotalSize() int { - return ProdInts([]int(s)) -} +// Shape represents a Shape. See the package shapes +type Shape = shapes.Shape // CalcStrides calculates the default strides for a shape -func (s Shape) CalcStrides() []int { +func CalcStrides(s Shape) []int { if s.IsScalar() { return nil } @@ -51,7 +39,7 @@ func (s Shape) CalcStrides() []int { // CalcStridesWithMask is similar to CalcStrides, except that it has an argument, masks. It is used to mask out given dimensions // during calculation of stride -func (s Shape) CalcStridesWithMask(mask []bool) []int { +func CalcStridesWithMask(s Shape, mask []bool) []int { if s.IsScalarEquiv() { return nil } @@ -86,7 +74,7 @@ func (s Shape) CalcStridesWithMask(mask []bool) []int { } // CalcStridesColMajor is like CalcStrides, but assumes a col major layout -func (s Shape) CalcStridesColMajor() []int { +func CalcStridesColMajor(s Shape) []int { if s.IsScalarEquiv() { return nil } @@ -110,280 +98,37 @@ func (s Shape) CalcStridesColMajor() []int { return retVal } -// Eq indicates if a shape is equal with another. There is a soft concept of equality when it comes to vectors. +// asMat returns a matrix shape from the given shape and axis. The given axis is which dim it will stop in. // -// If s is a column vector and other is a vanilla vector, they're considered equal if the size of the column dimension is the same as the vector size; -// if s is a row vector and other is a vanilla vector, they're considered equal if the size of the row dimension is the same as the vector size -func (s Shape) Eq(other Shape) bool { - if s.IsScalar() && other.IsScalar() { - return true - } - - if s.IsVector() && other.IsVector() { - switch { - case len(s) == 2 && len(other) == 1: - if (s.IsColVec() && s[0] == other[0]) || (s.IsRowVec() && s[1] == other[0]) { - return true - } - return false - case len(s) == 1 && len(other) == 2: - if (other.IsColVec() && other[0] == s[0]) || (other.IsRowVec() && other[1] == s[0]) { - return true - } - return false - } - } - - if len(s) != len(other) { - return false - } - - for i, v := range s { - if other[i] != v { - return false - } - } - return true -} - -// Clone clones a shape. -func (s Shape) Clone() Shape { - retVal := BorrowInts(len(s)) - copy(retVal, s) - return retVal -} - -// IsScalar returns true if the access pattern indicates it's a scalar value -func (s Shape) IsScalar() bool { - return len(s) == 0 -} - -// IsScalarEquiv returns true if the access pattern indicates it's a scalar-like value -func (s Shape) IsScalarEquiv() bool { - if len(s) == 0 { - return true - } - isEquiv := true - for i := range s { - if s[i] != 1 { - return false - } - } - return isEquiv -} - -// IsVector returns whether the access pattern falls into one of three possible definitions of vectors: -// vanilla vector (not a row or a col) -// column vector -// row vector -func (s Shape) IsVector() bool { return s.IsColVec() || s.IsRowVec() || (len(s) == 1) } - -// IsColVec returns true when the access pattern has the shape (x, 1) -func (s Shape) IsColVec() bool { return len(s) == 2 && (s[1] == 1 && s[0] > 1) } - -// IsRowVec returns true when the access pattern has the shape (1, x) -func (s Shape) IsRowVec() bool { return len(s) == 2 && (s[0] == 1 && s[1] > 1) } - -// IsVectorLike returns true when the shape looks like a vector -// e.g. a number that is surrounded by 1s: -// (1, 1, ... 1, 10, 1, 1... 1) -func (s Shape) IsVectorLike() bool { - var nonOnes int - for _, i := range s { - if i != 1 { - nonOnes++ - } - } - return nonOnes == 1 || nonOnes == 0 // if there is only one non-one then it's a vector or a scalarlike. -} - -// IsMatrix returns true if it's a matrix. This is mostly a convenience method. RowVec and ColVecs are also considered matrices -func (s Shape) IsMatrix() bool { return len(s) == 2 } - -// Dims returns the number of dimensions in the shape -func (s Shape) Dims() int { return len(s) } - -// DimSize returns the size of the dimension wanted. -// -// This method implemnents the DimSizer interface in Gorgonia. -func (s Shape) DimSize(d int) (size int, err error) { - if (s.IsScalar() && d != 0) || (!s.IsScalar() && d >= len(s)) { - err = errors.Errorf(dimMismatch, len(s), d) - return - } - - switch { - case s.IsScalar(): - return 0, nil - default: - return s[d], nil - } -} - -// S gives the new shape after a shape has been sliced. It's repeated from the AP S() method mainly because there are other functions in Gorgonia that uses only shape -func (s Shape) S(slices ...Slice) (retVal Shape, err error) { - opDims := len(s) - if len(slices) > opDims { - err = errors.Errorf(dimMismatch, opDims, len(slices)) - return - } - - retVal = s.Clone() - - for d, size := range s { - var sl Slice // default is a nil Slice - if d <= len(slices)-1 { - sl = slices[d] - } - - var start, end, step int - if start, end, step, err = SliceDetails(sl, size); err != nil { - return - } - - if step > 0 { - retVal[d] = (end - start) / step - - //fix - if retVal[d] <= 0 { - retVal[d] = 1 - } - } else { - retVal[d] = (end - start) - } - - } - - // drop any dimension with size 1, except the last dimension - offset := 0 - dims := s.Dims() - for d := 0; d < dims; d++ { - if retVal[d] == 1 && offset+d <= len(slices)-1 && slices[offset+d] != nil /*&& d != t.dims-1 && dims > 2*/ { - retVal = append(retVal[:d], retVal[d+1:]...) - d-- - dims-- - offset++ - } - } - - if retVal.IsScalar() { - ReturnInts(retVal) - return ScalarShape(), nil - } - - return -} - -// Repeat returns the expected new shape given the repetition parameters. -func (s Shape) Repeat(axis int, repeats ...int) (newShape Shape, finalRepeats []int, size int, err error) { +// asMat((5), 0, true) = (1, 5) +// asMat((5), 1, true) = (5, 1) +// asMat((3,4,5), 0, true) = (1, 60) +// asMat((3,4,5), 1, true) = (3, 20) +// asMat((3,4,5), 2, true) = (12, 5) +// asMat((3,4,5), 0, false) = (1, 20) +// asMat((3,4,5), 1, false) = (3, 5) +// asMat((3,4,5), 2, false) = (12, 1) +func asMat(a Shape, axis int, inclusive bool) (retVal Shape) { + // no need to do a check because asMat will only ever be used by internal functions. + + retVal = Shape(BorrowInts(2)) switch { - case axis == AllAxes: - size = s.TotalSize() - newShape = Shape{size} - axis = 0 - case s.IsScalar(): - size = 1 - // special case for row vecs - if axis == 1 { - newShape = Shape{1, 0} - } else { - // otherwise it will be repeated into a vanilla vector - newShape = Shape{0} - } - case s.IsVector() && !s.IsRowVec() && !s.IsColVec() && axis == 1: - size = 1 - newShape = s.Clone() - newShape = append(newShape, 1) - default: - if axis >= len(s) { - // error - err = errors.Errorf(invalidAxis, axis, s.Dims()) - return - } - size = s[axis] - newShape = s.Clone() - } - - // special case to allow generic repeats - if len(repeats) == 1 { - rep := repeats[0] - repeats = make([]int, size) - for i := range repeats { - repeats[i] = rep - } - } - reps := len(repeats) - if reps != size { - err = errors.Errorf(broadcastError, size, reps) - return - } - - newSize := SumInts(repeats) - newShape[axis] = newSize - finalRepeats = repeats - return -} - -// Concat returns the expected new shape given the concatenation parameters -func (s Shape) Concat(axis int, ss ...Shape) (newShape Shape, err error) { - dims := s.Dims() - - // check that all the concatenates have the same dimensions - for _, shp := range ss { - if shp.Dims() != dims { - err = errors.Errorf(dimMismatch, dims, shp.Dims()) - return - } - } - - // special case - if axis == AllAxes { - axis = 0 - } - - // nope... no negative indexing here. - if axis < 0 { - err = errors.Errorf(invalidAxis, axis, len(s)) + case a.Dims() == 1 && axis == 0: + retVal[0] = 1 + retVal[1] = a[0] return - } - - if axis >= dims { - err = errors.Errorf(invalidAxis, axis, len(s)) + case a.Dims() == 1 && axis == 1: + retVal[0] = a[0] + retVal[1] = 1 return } - - newShape = Shape(BorrowInts(dims)) - copy(newShape, s) - - for _, shp := range ss { - for d := 0; d < dims; d++ { - if d == axis { - newShape[d] += shp[d] - } else { - // validate that the rest of the dimensions match up - if newShape[d] != shp[d] { - err = errors.Wrapf(errors.Errorf(dimMismatch, newShape[d], shp[d]), "Axis: %d, dimension it failed at: %d", axis, d) - return - } - } - } + // outer + retVal[0] = ProdInts(a[:axis]) + aplus := axis + if !inclusive { + aplus++ } + // inner + retVal[1] = ProdInts(a[aplus:]) return } - -// Format implements fmt.Formatter, and formats a shape nicely -func (s Shape) Format(st fmt.State, r rune) { - switch r { - case 'v', 's': - st.Write([]byte("(")) - for i, v := range s { - fmt.Fprintf(st, "%d", v) - if i < len(s)-1 { - st.Write([]byte(", ")) - } - } - st.Write([]byte(")")) - default: - fmt.Fprintf(st, "%v", []int(s)) - } -} diff --git a/shape_test.go b/shape_test.go index 9cbc370..9433ba9 100644 --- a/shape_test.go +++ b/shape_test.go @@ -1,323 +1,47 @@ package tensor import ( - "fmt" "testing" "github.com/stretchr/testify/assert" ) -func TestShapeBasics(t *testing.T) { - var s Shape - var ds int - var err error - s = Shape{1, 2} - - if ds, err = s.DimSize(0); err != nil { - t.Error(err) - } - if ds != 1 { - t.Error("Expected DimSize(0) to be 1") - } - - if ds, err = s.DimSize(2); err == nil { - t.Error("Expected a DimensionMismatch error") - } - - s = ScalarShape() - if ds, err = s.DimSize(0); err != nil { - t.Error(err) - } - - if ds != 0 { - t.Error("Expected DimSize(0) of a scalar to be 0") - } - - // format for completeness sake - s = Shape{2, 1} - if fmt.Sprintf("%d", s) != "[2 1]" { - t.Error("Shape.Format() error") - } -} - -func TestShapeIsX(t *testing.T) { - assert := assert.New(t) - var s Shape - - // scalar shape - s = Shape{} - assert.True(s.IsScalar()) - assert.True(s.IsScalarEquiv()) - assert.False(s.IsVector()) - assert.False(s.IsColVec()) - assert.False(s.IsRowVec()) - - // vectors - - // scalar-equiv vector - s = Shape{1} - assert.False(s.IsScalar()) - assert.True(s.IsScalarEquiv()) - assert.True(s.IsVector()) - assert.True(s.IsVectorLike()) - assert.True(s.IsVector()) - assert.False(s.IsColVec()) - assert.False(s.IsRowVec()) - - // vanila vector - s = Shape{2} - assert.False(s.IsScalar()) - assert.True(s.IsVector()) - assert.False(s.IsColVec()) - assert.False(s.IsRowVec()) - - // col vec - s = Shape{2, 1} - assert.False(s.IsScalar()) - assert.True(s.IsVector()) - assert.True(s.IsVectorLike()) - assert.True(s.IsColVec()) - assert.False(s.IsRowVec()) - - // row vec - s = Shape{1, 2} - assert.False(s.IsScalar()) - assert.True(s.IsVector()) - assert.True(s.IsVectorLike()) - assert.False(s.IsColVec()) - assert.True(s.IsRowVec()) - - // matrix and up - s = Shape{2, 2} - assert.False(s.IsScalar()) - assert.False(s.IsVector()) - assert.False(s.IsColVec()) - assert.False(s.IsRowVec()) - - // scalar equiv matrix - s = Shape{1, 1} - assert.False(s.IsScalar()) - assert.True(s.IsScalarEquiv()) - assert.True(s.IsVectorLike()) - assert.False(s.IsVector()) -} - func TestShapeCalcStride(t *testing.T) { assert := assert.New(t) var s Shape // scalar shape s = Shape{} - assert.Nil(s.CalcStrides()) + assert.Nil(CalcStrides(s)) // vector shape s = Shape{1} - assert.Equal([]int{1}, s.CalcStrides()) + assert.Equal([]int{1}, CalcStrides(s)) s = Shape{2, 1} - assert.Equal([]int{1, 1}, s.CalcStrides()) + assert.Equal([]int{1, 1}, CalcStrides(s)) s = Shape{1, 2} - assert.Equal([]int{2, 1}, s.CalcStrides()) + assert.Equal([]int{2, 1}, CalcStrides(s)) s = Shape{2} - assert.Equal([]int{1}, s.CalcStrides()) + assert.Equal([]int{1}, CalcStrides(s)) // matrix strides s = Shape{2, 2} - assert.Equal([]int{2, 1}, s.CalcStrides()) + assert.Equal([]int{2, 1}, CalcStrides(s)) s = Shape{5, 2} - assert.Equal([]int{2, 1}, s.CalcStrides()) + assert.Equal([]int{2, 1}, CalcStrides(s)) // 3D strides s = Shape{2, 3, 4} - assert.Equal([]int{12, 4, 1}, s.CalcStrides()) + assert.Equal([]int{12, 4, 1}, CalcStrides(s)) // stupid shape s = Shape{-2, 1, 2} fail := func() { - s.CalcStrides() + CalcStrides(s) } assert.Panics(fail) } - -func TestShapeEquality(t *testing.T) { - assert := assert.New(t) - var s1, s2 Shape - - // scalar - s1 = Shape{} - s2 = Shape{} - assert.True(s1.Eq(s2)) - assert.True(s2.Eq(s1)) - - // scalars and scalar equiv are not the same! - s1 = Shape{1} - s2 = Shape{} - assert.False(s1.Eq(s2)) - assert.False(s2.Eq(s1)) - - // vector - s1 = Shape{3} - s2 = Shape{5} - assert.False(s1.Eq(s2)) - assert.False(s2.Eq(s1)) - - s1 = Shape{2, 1} - s2 = Shape{2, 1} - assert.True(s1.Eq(s2)) - assert.True(s2.Eq(s1)) - - s2 = Shape{2} - assert.True(s1.Eq(s2)) - assert.True(s2.Eq(s1)) - - s2 = Shape{1, 2} - assert.False(s1.Eq(s2)) - assert.False(s2.Eq(s1)) - - s1 = Shape{2} - assert.True(s1.Eq(s2)) - assert.True(s2.Eq(s1)) - - s2 = Shape{2, 3} - assert.False(s1.Eq(s2)) - assert.False(s2.Eq(s1)) - - // matrix - s1 = Shape{2, 3} - assert.True(s1.Eq(s2)) - assert.True(s2.Eq(s1)) - - s2 = Shape{3, 2} - assert.False(s1.Eq(s2)) - assert.False(s2.Eq(s1)) - - // just for that green coloured code - s1 = Shape{2} - s2 = Shape{1, 3} - assert.False(s1.Eq(s2)) - assert.False(s2.Eq(s1)) -} - -var shapeSliceTests = []struct { - name string - s Shape - sli []Slice - - expected Shape - err bool -}{ - {"slicing a scalar shape", ScalarShape(), nil, ScalarShape(), false}, - {"slicing a scalar shape", ScalarShape(), []Slice{rs{0, 0, 0}}, nil, true}, - {"vec[0]", Shape{2}, []Slice{rs{0, 1, 0}}, ScalarShape(), false}, - {"vec[3]", Shape{2}, []Slice{rs{3, 4, 0}}, nil, true}, - {"vec[:, 0]", Shape{2}, []Slice{nil, rs{0, 1, 0}}, nil, true}, - {"vec[1:4:2]", Shape{5}, []Slice{rs{1, 4, 2}}, ScalarShape(), false}, - {"tensor[0, :, :]", Shape{1, 2, 2}, []Slice{rs{0, 1, 1}, nil, nil}, Shape{2, 2}, false}, - {"tensor[:, 0, :]", Shape{1, 2, 2}, []Slice{nil, rs{0, 1, 1}, nil}, Shape{1, 2}, false}, - {"tensor[0, :, :, :]", Shape{1, 1, 2, 2}, []Slice{rs{0, 1, 1}, nil, nil, nil}, Shape{1, 2, 2}, false}, - {"tensor[0,]", Shape{1, 1, 2, 2}, []Slice{rs{0, 1, 1}}, Shape{1, 2, 2}, false}, -} - -func TestShape_Slice(t *testing.T) { - for i, ssts := range shapeSliceTests { - newShape, err := ssts.s.S(ssts.sli...) - if checkErr(t, ssts.err, err, "Shape slice", i) { - continue - } - - if !ssts.expected.Eq(newShape) { - t.Errorf("Test %q: Expected shape %v. Got %v instead", ssts.name, ssts.expected, newShape) - } - } -} - -var shapeRepeatTests = []struct { - name string - s Shape - repeats []int - axis int - - expected Shape - expectedRepeats []int - expectedSize int - err bool -}{ - {"scalar repeat on axis 0", ScalarShape(), []int{3}, 0, Shape{3}, []int{3}, 1, false}, - {"scalar repeat on axis 1", ScalarShape(), []int{3}, 1, Shape{1, 3}, []int{3}, 1, false}, - {"vector repeat on axis 0", Shape{2}, []int{3}, 0, Shape{6}, []int{3, 3}, 2, false}, - {"vector repeat on axis 1", Shape{2}, []int{3}, 1, Shape{2, 3}, []int{3}, 1, false}, - {"colvec repeats on axis 0", Shape{2, 1}, []int{3}, 0, Shape{6, 1}, []int{3, 3}, 2, false}, - {"colvec repeats on axis 1", Shape{2, 1}, []int{3}, 1, Shape{2, 3}, []int{3}, 1, false}, - {"rowvec repeats on axis 0", Shape{1, 2}, []int{3}, 0, Shape{3, 2}, []int{3}, 1, false}, - {"rowvec repeats on axis 1", Shape{1, 2}, []int{3}, 1, Shape{1, 6}, []int{3, 3}, 2, false}, - {"3-Tensor repeats", Shape{2, 3, 2}, []int{1, 2, 1}, 1, Shape{2, 4, 2}, []int{1, 2, 1}, 3, false}, - {"3-Tensor generic repeats", Shape{2, 3, 2}, []int{2}, AllAxes, Shape{24}, []int{2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}, 12, false}, - {"3-Tensor generic repeat, axis specified", Shape{2, 3, 2}, []int{2}, 2, Shape{2, 3, 4}, []int{2, 2}, 2, false}, - - // stupids - {"nonexisting axis 2", Shape{2, 1}, []int{3}, 2, nil, nil, 0, true}, - {"mismatching repeats", Shape{2, 3, 2}, []int{3, 1, 2}, 0, nil, nil, 0, true}, -} - -func TestShape_Repeat(t *testing.T) { - assert := assert.New(t) - for _, srts := range shapeRepeatTests { - newShape, reps, size, err := srts.s.Repeat(srts.axis, srts.repeats...) - - switch { - case srts.err: - if err == nil { - t.Error("Expected an error") - } - continue - case !srts.err && err != nil: - t.Error(err) - continue - } - - assert.True(srts.expected.Eq(newShape), "Test %q: Want: %v. Got %v", srts.name, srts.expected, newShape) - assert.Equal(srts.expectedRepeats, reps, "Test %q: ", srts.name) - assert.Equal(srts.expectedSize, size, "Test %q: ", srts.name) - } -} - -var shapeConcatTests = []struct { - name string - s Shape - axis int - ss []Shape - - expected Shape - err bool -}{ - {"standard, axis 0 ", Shape{2, 2}, 0, []Shape{{2, 2}, {2, 2}}, Shape{6, 2}, false}, - {"standard, axis 1 ", Shape{2, 2}, 1, []Shape{{2, 2}, {2, 2}}, Shape{2, 6}, false}, - {"standard, axis AllAxes ", Shape{2, 2}, -1, []Shape{{2, 2}, {2, 2}}, Shape{6, 2}, false}, - {"concat to empty", Shape{2}, 0, nil, Shape{2}, false}, - - {"stupids: different dims", Shape{2, 2}, 0, []Shape{{2, 3, 2}}, nil, true}, - {"stupids: negative axes", Shape{2, 2}, -5, []Shape{{2, 2}}, nil, true}, - {"stupids: toobig axis", Shape{2, 2}, 5, []Shape{{2, 2}}, nil, true}, - {"subtle stupids: dim mismatch", Shape{2, 2}, 0, []Shape{{2, 2}, {2, 3}}, nil, true}, -} - -func TestShape_Concat(t *testing.T) { - assert := assert.New(t) - for _, scts := range shapeConcatTests { - newShape, err := scts.s.Concat(scts.axis, scts.ss...) - switch { - case scts.err: - if err == nil { - t.Error("Expected an error") - } - continue - case !scts.err && err != nil: - t.Error(err) - continue - } - assert.Equal(scts.expected, newShape) - } -} diff --git a/slice.go b/slice.go index 41e1419..7ee3522 100644 --- a/slice.go +++ b/slice.go @@ -1,11 +1,13 @@ package tensor -// A Slice represents a slicing operation for a Tensor. -type Slice interface { - Start() int - End() int - Step() int -} +import ( + "gorgonia.org/shapes" +) + +var xxx Slice = ss(1) +var _ shapes.Slice = xxx + +type Slice = shapes.Slice type rs struct { start, end, step int diff --git a/sparse.go b/sparse.go index 1a9da7c..b500db8 100644 --- a/sparse.go +++ b/sparse.go @@ -6,6 +6,7 @@ import ( "sort" "github.com/pkg/errors" + "gorgonia.org/dtype" ) var ( @@ -183,7 +184,7 @@ func CSCFromCoord(shape Shape, xs, ys []int, data interface{}) *CS { func (t *CS) Shape() Shape { return t.s } func (t *CS) Strides() []int { return nil } -func (t *CS) Dtype() Dtype { return t.t } +func (t *CS) Dtype() dtype.Dtype { return t.t } func (t *CS) Dims() int { return 2 } func (t *CS) Size() int { return t.s.TotalSize() } func (t *CS) DataSize() int { return t.Len() } @@ -233,7 +234,7 @@ func (t *CS) T(axes ...int) error { UnsafePermute(axes, []int(t.s)) t.o = t.o.toggleColMajor() t.o = MakeDataOrder(t.o, Transposed) - return errors.Errorf(methodNYI, "T", t) + return nyierr(typeNYI, t) } // UT untransposes the CS @@ -242,9 +243,7 @@ func (t *CS) UT() { t.T(); t.o = t.o.clearTransposed() } // Transpose is a no-op. The data does not move func (t *CS) Transpose() error { return nil } -func (t *CS) Apply(fn interface{}, opts ...FuncOpt) (Tensor, error) { - return nil, errors.Errorf(methodNYI, "Apply", t) -} +func (t *CS) Apply(fn interface{}, opts ...FuncOpt) (Tensor, error) { return nil, nyierr(typeNYI, t) } func (t *CS) Eq(other interface{}) bool { if ot, ok := other.(*CS); ok { @@ -379,4 +378,4 @@ func (t *CS) IsManuallyManaged() bool { return t.f.manuallyManaged() } func (t *CS) arr() array { return t.array } func (t *CS) arrPtr() *array { return &t.array } -func (t *CS) standardEngine() standardEngine { return nil } +func (t *CS) standardEngine() StandardEngine { return nil } diff --git a/sparse_test.go b/sparse_test.go index 86cdad1..34b22dd 100644 --- a/sparse_test.go +++ b/sparse_test.go @@ -1,105 +1,105 @@ -package tensor - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestCS_Basics(t *testing.T) { - assert := assert.New(t) - xs0 := []int{1, 2, 6, 8} - ys0 := []int{1, 2, 1, 6} - xs1 := []int{1, 2, 6, 8} - ys1 := []int{1, 2, 1, 6} - vals0 := []float64{3, 1, 4, 1} - vals1 := []float64{3, 1, 4, 1} - - var T0, T1 *CS - var d0, d1 *Dense - var dp0, dp1 *Dense - var err error - fails := func() { - CSCFromCoord(Shape{7, 6}, xs0, ys0, vals0) - } - assert.Panics(fails) - - // Test CSC - T0 = CSCFromCoord(Shape{9, 7}, xs0, ys0, vals0) - d0 = T0.Dense() - T0.T() - dp0 = T0.Dense() - T0.UT() // untranspose as Materialize() will be called below - - // Test CSR - fails = func() { - CSRFromCoord(Shape{7, 6}, xs1, ys1, vals1) - } - T1 = CSRFromCoord(Shape{9, 7}, xs1, ys1, vals1) - d1 = T1.Dense() - T1.T() - dp1 = T1.Dense() - T1.UT() - - t.Logf("%v %v", T0.indptr, T0.indices) - t.Logf("%v %v", T1.indptr, T1.indices) - - assert.True(d0.Eq(d1), "%+#v\n %+#v\n", d0, d1) - assert.True(dp0.Eq(dp1)) - assert.True(T1.Eq(T1)) - assert.False(T0.Eq(T1)) - - // At - var got interface{} - correct := float64(3.0) - if got, err = T0.At(1, 1); err != nil { - t.Error(err) - } - if got.(float64) != correct { - t.Errorf("Expected %v. Got %v - T0[1,1]", correct, got) - } - if got, err = T1.At(1, 1); err != nil { - t.Error(err) - } - if got.(float64) != correct { - t.Errorf("Expected %v. Got %v - T1[1,1]", correct, got) - } - - correct = 0.0 - if got, err = T0.At(3, 3); err != nil { - t.Error(err) - } - if got.(float64) != correct { - t.Errorf("Expected %v. Got %v - T0[3,3]", correct, got) - } - - if got, err = T1.At(3, 3); err != nil { - t.Error(err) - } - if got.(float64) != correct { - t.Errorf("Expected %v. Got %v - T1[3,3]", correct, got) - } - - // Test clone - T2 := T0.Clone() - assert.True(T0.Eq(T2)) - - // Scalar representation - assert.False(T0.IsScalar()) - fails = func() { - T0.ScalarValue() - } - assert.Panics(fails) - assert.Equal(len(vals0), T0.NonZeroes()) - - // Sparse Iterator - it := T0.Iterator() - var valids []int - correctValids := []int{0, 2, 1, 3} - for i, valid, err := it.NextValidity(); err == nil; i, valid, err = it.NextValidity() { - if valid { - valids = append(valids, i) - } - } - assert.Equal(correctValids, valids) -} +package tensor + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCS_Basics(t *testing.T) { + assert := assert.New(t) + xs0 := []int{1, 2, 6, 8} + ys0 := []int{1, 2, 1, 6} + xs1 := []int{1, 2, 6, 8} + ys1 := []int{1, 2, 1, 6} + vals0 := []float64{3, 1, 4, 1} + vals1 := []float64{3, 1, 4, 1} + + var T0, T1 *CS + var d0, d1 *Dense + var dp0, dp1 *Dense + var err error + fails := func() { + CSCFromCoord(Shape{7, 6}, xs0, ys0, vals0) + } + assert.Panics(fails) + + // Test CSC + T0 = CSCFromCoord(Shape{9, 7}, xs0, ys0, vals0) + d0 = T0.Dense() + T0.T() + dp0 = T0.Dense() + T0.UT() // untranspose as Materialize() will be called below + + // Test CSR + fails = func() { + CSRFromCoord(Shape{7, 6}, xs1, ys1, vals1) + } + T1 = CSRFromCoord(Shape{9, 7}, xs1, ys1, vals1) + d1 = T1.Dense() + T1.T() + dp1 = T1.Dense() + T1.UT() + + t.Logf("%v %v", T0.indptr, T0.indices) + t.Logf("%v %v", T1.indptr, T1.indices) + + assert.True(d0.Eq(d1), "%+#v\n %+#v\n", d0, d1) + assert.True(dp0.Eq(dp1)) + assert.True(T1.Eq(T1)) + assert.False(T0.Eq(T1)) + + // At + var got interface{} + correct := float64(3.0) + if got, err = T0.At(1, 1); err != nil { + t.Error(err) + } + if got.(float64) != correct { + t.Errorf("Expected %v. Got %v - T0[1,1]", correct, got) + } + if got, err = T1.At(1, 1); err != nil { + t.Error(err) + } + if got.(float64) != correct { + t.Errorf("Expected %v. Got %v - T1[1,1]", correct, got) + } + + correct = 0.0 + if got, err = T0.At(3, 3); err != nil { + t.Error(err) + } + if got.(float64) != correct { + t.Errorf("Expected %v. Got %v - T0[3,3]", correct, got) + } + + if got, err = T1.At(3, 3); err != nil { + t.Error(err) + } + if got.(float64) != correct { + t.Errorf("Expected %v. Got %v - T1[3,3]", correct, got) + } + + // Test clone + T2 := T0.Clone() + assert.True(T0.Eq(T2)) + + // Scalar representation + assert.False(T0.IsScalar()) + fails = func() { + T0.ScalarValue() + } + assert.Panics(fails) + assert.Equal(len(vals0), T0.NonZeroes()) + + // Sparse Iterator + it := T0.Iterator() + var valids []int + correctValids := []int{0, 2, 1, 3} + for i, valid, err := it.NextValidity(); err == nil; i, valid, err = it.NextValidity() { + if valid { + valids = append(valids, i) + } + } + assert.Equal(correctValids, valids) +} diff --git a/tensor.go b/tensor.go index 071ca67..8445a39 100644 --- a/tensor.go +++ b/tensor.go @@ -4,16 +4,15 @@ package tensor // import "gorgonia.org/tensor" import ( "encoding/gob" - "fmt" - "io" "github.com/pkg/errors" + "gorgonia.org/dtype" ) var ( _ Tensor = &Dense{} _ Tensor = &CS{} - _ View = &Dense{} + _ View = &DenseView{} ) func init() { @@ -21,16 +20,22 @@ func init() { gob.Register(&CS{}) } -// Tensor represents a variety of n-dimensional arrays. The most commonly used tensor is the Dense tensor. -// It can be used to represent a vector, matrix, 3D matrix and n-dimensional tensors. -type Tensor interface { +// Desc is a description of a tensor. It does not actually deal with data. +type Desc interface { // info about the ndarray Shape() Shape Strides() []int - Dtype() Dtype + Dtype() dtype.Dtype + Dims() int Size() int DataSize() int +} + +// Tensor represents a variety of n-dimensional arrays. The most commonly used tensor is the Dense tensor. +// It can be used to represent a vector, matrix, 3D matrix and n-dimensional tensors. +type Tensor interface { + Desc // Data access related RequiresIterator() bool @@ -54,10 +59,6 @@ type Tensor interface { Eq Cloner - // type overloading methods - IsScalar() bool - ScalarValue() interface{} - // engine/memory related stuff // all Tensors should be able to be expressed of as a slab of memory // Note: the size of each element can be acquired by T.Dtype().Size() @@ -67,18 +68,20 @@ type Tensor interface { IsManuallyManaged() bool // Must Go manage the memory // formatters - fmt.Formatter - fmt.Stringer + // fmt.Formatter + // fmt.Stringer // all Tensors are serializable to these formats - WriteNpy(io.Writer) error - ReadNpy(io.Reader) error - gob.GobEncoder - gob.GobDecoder + //WriteNpy(io.Writer) error + //ReadNpy(io.Reader) error + //gob.GobEncoder + //gob.GobDecoder - standardEngine() standardEngine headerer arrayer + + // TO BE DEPRECATED + ScalarRep } // New creates a new Dense Tensor. For sparse arrays use their relevant construction function @@ -95,15 +98,26 @@ func New(opts ...ConsOpt) *Dense { return d } +// MustGetDense gets a *Dense from a given Tensor. Panics otherwise. +func MustGetDense(T Tensor) *Dense { + d, err := assertDense(T) + if err != nil { + panic(err) + } + return d +} + func assertDense(t Tensor) (*Dense, error) { if t == nil { return nil, errors.New("nil is not a *Dense") } - if retVal, ok := t.(*Dense); ok { - return retVal, nil - } - if retVal, ok := t.(Densor); ok { - return retVal.Dense(), nil + switch tt := t.(type) { + case *Dense: + return tt, nil + case DenseView: + return tt.Dense, nil + case Densor: + return tt.Dense(), nil } return nil, errors.Errorf("%T is not *Dense", t) } @@ -124,7 +138,7 @@ func getFloatDenseTensor(t Tensor) (retVal DenseTensor, err error) { if t == nil { return } - if err = typeclassCheck(t.Dtype(), floatTypes); err != nil { + if err = dtype.TypeClassCheck(t.Dtype(), dtype.Floats); err != nil { err = errors.Wrapf(err, "getFloatDense only handles floats. Got %v instead", t.Dtype()) return } @@ -145,7 +159,7 @@ func getFloatComplexDenseTensor(t Tensor) (retVal DenseTensor, err error) { if t == nil { return } - if err = typeclassCheck(t.Dtype(), floatcmplxTypes); err != nil { + if err = dtype.TypeClassCheck(t.Dtype(), dtype.FloatComplex); err != nil { err = errors.Wrapf(err, "getFloatDense only handles floats and complex. Got %v instead", t.Dtype()) return } @@ -161,10 +175,11 @@ func getFloatComplexDenseTensor(t Tensor) (retVal DenseTensor, err error) { return } +// sliceDense returns a *Dense. func sliceDense(t *Dense, slices ...Slice) (retVal *Dense, err error) { var sliced Tensor if sliced, err = t.Slice(slices...); err != nil { return nil, err } - return sliced.(*Dense), nil + return sliced.(DenseView).Dense, nil } diff --git a/test_test.go b/test_test.go index 5f76d8a..f5a7e0c 100644 --- a/test_test.go +++ b/test_test.go @@ -1,5 +1,3 @@ -// Code generated by genlib2. DO NOT EDIT. - package tensor import ( @@ -9,8 +7,11 @@ import ( "unsafe" "github.com/chewxy/math32" + "gorgonia.org/dtype" ) +// Code generated by genlib2. DO NOT EDIT. + func anyToFloat64s(x interface{}) (retVal []float64) { switch xt := x.(type) { case []int: @@ -120,7 +121,7 @@ func anyToFloat64s(x interface{}) (retVal []float64) { panic("Unreachable") } -func identityVal(x int, dt Dtype) interface{} { +func identityVal(x int, dt dtype.Dtype) interface{} { switch dt { case Int: return int(x) diff --git a/testutils_test.go b/testutils_test.go index 3a0d466..77312fb 100644 --- a/testutils_test.go +++ b/testutils_test.go @@ -2,6 +2,7 @@ package tensor import ( "bytes" + "context" "errors" "math" "math/cmplx" @@ -14,6 +15,8 @@ import ( "github.com/chewxy/math32" "gorgonia.org/tensor/internal/storage" + + "gorgonia.org/dtype" ) func randomBool() bool { @@ -330,7 +333,7 @@ func shuffleInts(a []int, r *rand.Rand) { type TensorGenerator struct { ShapeConstraint Shape - DtypeConstraint Dtype + DtypeConstraint dtype.Dtype } func (g TensorGenerator) Generate(r *rand.Rand, size int) reflect.Value { @@ -342,8 +345,8 @@ func (g TensorGenerator) Generate(r *rand.Rand, size int) reflect.Value { func (t *Dense) Generate(r *rand.Rand, size int) reflect.Value { // generate type - ri := r.Intn(len(specializedTypes.set)) - of := specializedTypes.set[ri] + ri := r.Intn(len(specializedTypes)) + of := specializedTypes[ri] datatyp := reflect.SliceOf(of.Type) gendat, _ := quick.Value(datatyp, r) // generate dims @@ -502,14 +505,26 @@ func (e dummyEngine2) Memcpy(dst, src Memory) error { return e.e.Mem func (e dummyEngine2) Accessible(mem Memory) (Memory, error) { return e.e.Accessible(mem) } func (e dummyEngine2) WorksWith(order DataOrder) bool { return e.e.WorksWith(order) } -func (e dummyEngine2) Argmax(t Tensor, axis int) (Tensor, error) { return e.e.Argmax(t, axis) } -func (e dummyEngine2) Argmin(t Tensor, axis int) (Tensor, error) { return e.e.Argmin(t, axis) } +func (e dummyEngine2) Argmax(ctx context.Context, t Tensor, axis int) (Tensor, error) { + return e.e.Argmax(ctx, t, axis) +} +func (e dummyEngine2) Argmin(ctx context.Context, t Tensor, axis int) (Tensor, error) { + return e.e.Argmin(ctx, t, axis) +} -func willerr(a *Dense, tc, eqtc *typeclass) (retVal, willFailEq bool) { - if err := typeclassCheck(a.Dtype(), eqtc); err == nil { +func willerr(a *Dense, tc, eqtc dtype.TypeClass) (retVal, willFailEq bool) { + if eqtc == nilTC { willFailEq = true + } else { + if err := dtype.TypeClassCheck(a.Dtype(), eqtc); err == nil { + willFailEq = true + } + } + if tc == nilTC { + retVal = !a.IsNativelyAccessible() + return } - if err := typeclassCheck(a.Dtype(), tc); err != nil { + if err := dtype.TypeClassCheck(a.Dtype(), tc); err != nil { return true, willFailEq } @@ -539,14 +554,14 @@ func qcErrCheck(t *testing.T, name string, a Dtyper, b interface{}, we bool, err return nil, false } -func qcIsFloat(dt Dtype) bool { - if err := typeclassCheck(dt, floatcmplxTypes); err == nil { +func qcIsFloat(dt dtype.Dtype) bool { + if err := dtype.TypeClassCheck(dt, dtype.FloatComplex); err == nil { return true } return false } -func qcEqCheck(t *testing.T, dt Dtype, willFailEq bool, correct, got interface{}) bool { +func qcEqCheck(t *testing.T, dt dtype.Dtype, willFailEq bool, correct, got interface{}) bool { isFloatTypes := qcIsFloat(dt) if !willFailEq && (isFloatTypes && !allClose(correct, got) || (!isFloatTypes && !reflect.DeepEqual(correct, got))) { t.Errorf("q.Dtype: %v", dt) diff --git a/type_test.go b/type_test.go index d616b8f..7200f66 100644 --- a/type_test.go +++ b/type_test.go @@ -1,66 +1,13 @@ package tensor import ( - "reflect" - "testing" + "gorgonia.org/dtype" ) -type Float16 uint16 - -func TestRegisterType(t *testing.T) { - dt := Dtype{reflect.TypeOf(Float16(0))} - RegisterFloat(dt) - - if err := typeclassCheck(dt, floatTypes); err != nil { - t.Errorf("Expected %v to be in floatTypes: %v", dt, err) - } - if err := typeclassCheck(dt, numberTypes); err != nil { - t.Errorf("Expected %v to be in numberTypes: %v", dt, err) - } - if err := typeclassCheck(dt, ordTypes); err != nil { - t.Errorf("Expected %v to be in ordTypes: %v", dt, err) - } - if err := typeclassCheck(dt, eqTypes); err != nil { - t.Errorf("Expected %v to be in eqTypes: %v", dt, err) - } - +var numberTypes = []dtype.Dtype{ + Int, Int8, Int16, Int32, Int64, Uint, Uint8, Uint16, Uint32, Uint64, Float32, Float64, Complex64, Complex128, } -func TestDtypeConversions(t *testing.T) { - for k, v := range reverseNumpyDtypes { - if npdt, err := v.numpyDtype(); npdt != k { - t.Errorf("Expected %v to return numpy dtype of %q. Got %q instead", v, k, npdt) - } else if err != nil { - t.Errorf("Error: %v", err) - } - } - dt := Dtype{reflect.TypeOf(Float16(0))} - if _, err := dt.numpyDtype(); err == nil { - t.Errorf("Expected an error when passing in type unknown to np") - } - - for k, v := range numpyDtypes { - if dt, err := fromNumpyDtype(v); dt != k { - // special cases - if Int.Size() == 4 && v == "i4" && dt == Int { - continue - } - if Int.Size() == 8 && v == "i8" && dt == Int { - continue - } - - if Uint.Size() == 4 && v == "u4" && dt == Uint { - continue - } - if Uint.Size() == 8 && v == "u8" && dt == Uint { - continue - } - t.Errorf("Expected %q to return %v. Got %v instead", v, k, dt) - } else if err != nil { - t.Errorf("Error: %v", err) - } - } - if _, err := fromNumpyDtype("EDIUH"); err == nil { - t.Error("Expected error when nonsense is passed into fromNumpyDtype") - } +var specializedTypes = []dtype.Dtype{ + Bool, Int, Int8, Int16, Int32, Int64, Uint, Uint8, Uint16, Uint32, Uint64, Float32, Float64, Complex64, Complex128, String, } diff --git a/types.go b/types.go index 69740cf..0cc7ef1 100644 --- a/types.go +++ b/types.go @@ -4,100 +4,15 @@ import ( "fmt" "math" "reflect" - "unsafe" - "github.com/chewxy/hm" - "github.com/pkg/errors" + "gorgonia.org/dtype" ) -// Dtype represents a data type of a Tensor. Concretely it's implemented as an embedded reflect.Type -// which allows for easy reflection operations. It also implements hm.Type, for type inference in Gorgonia -type Dtype struct { - reflect.Type -} - -// note: the Name() and String() methods are already defined in reflect.Type. Might as well use the composed methods - -func (dt Dtype) Apply(hm.Subs) hm.Substitutable { return dt } -func (dt Dtype) FreeTypeVar() hm.TypeVarSet { return nil } -func (dt Dtype) Normalize(k, v hm.TypeVarSet) (hm.Type, error) { return dt, nil } -func (dt Dtype) Types() hm.Types { return nil } -func (dt Dtype) Format(s fmt.State, c rune) { fmt.Fprintf(s, "%s", dt.Name()) } -func (dt Dtype) Eq(other hm.Type) bool { return other == dt } - -var numpyDtypes map[Dtype]string -var reverseNumpyDtypes map[string]Dtype - -func init() { - numpyDtypes = map[Dtype]string{ - Bool: "b1", - Int: fmt.Sprintf("i%d", Int.Size()), - Int8: "i1", - Int16: "i2", - Int32: "i4", - Int64: "i8", - Uint: fmt.Sprintf("u%d", Uint.Size()), - Uint8: "u1", - Uint16: "u2", - Uint32: "u4", - Uint64: "u8", - Float32: "f4", - Float64: "f8", - Complex64: "c8", - Complex128: "c16", - } - - reverseNumpyDtypes = map[string]Dtype{ - "b1": Bool, - "i1": Int8, - "i2": Int16, - "i4": Int32, - "i8": Int64, - "u1": Uint8, - "u2": Uint16, - "u4": Uint32, - "u8": Uint64, - "f4": Float32, - "f8": Float64, - "c8": Complex64, - "c16": Complex128, - } -} - -// NumpyDtype returns the Numpy's Dtype equivalent. This is predominantly used in converting a Tensor to a Numpy ndarray, -// however, not all Dtypes are supported -func (dt Dtype) numpyDtype() (string, error) { - retVal, ok := numpyDtypes[dt] - if !ok { - return "v", errors.Errorf("Unsupported Dtype conversion to Numpy Dtype: %v", dt) - } - return retVal, nil -} - -func fromNumpyDtype(t string) (Dtype, error) { - retVal, ok := reverseNumpyDtypes[t] - if !ok { - return Dtype{}, errors.Errorf("Unsupported Dtype conversion from %q to Dtype", t) - } - if t == "i4" && Int.Size() == 4 { - return Int, nil - } - if t == "i8" && Int.Size() == 8 { - return Int, nil - } - if t == "u4" && Uint.Size() == 4 { - return Uint, nil - } - if t == "u8" && Uint.Size() == 8 { - return Uint, nil - } - return retVal, nil -} +// Dtype is an alias for dtype.Dtype. This alias is here for backward compatibility purposes, for when users are transitioning out of the older tensor libraries. +type Dtype = dtype.Dtype -type typeclass struct { - name string - set []Dtype -} +// nil type class for skipping type class checks +var nilTC dtype.TypeClass = -1 var parameterizedKinds = [...]reflect.Kind{ reflect.Array, @@ -119,227 +34,31 @@ func isParameterizedKind(k reflect.Kind) bool { return false } -// oh how nice it'd be if I could make them immutable -var ( - Bool = Dtype{reflect.TypeOf(true)} - Int = Dtype{reflect.TypeOf(int(1))} - Int8 = Dtype{reflect.TypeOf(int8(1))} - Int16 = Dtype{reflect.TypeOf(int16(1))} - Int32 = Dtype{reflect.TypeOf(int32(1))} - Int64 = Dtype{reflect.TypeOf(int64(1))} - Uint = Dtype{reflect.TypeOf(uint(1))} - Uint8 = Dtype{reflect.TypeOf(uint8(1))} - Uint16 = Dtype{reflect.TypeOf(uint16(1))} - Uint32 = Dtype{reflect.TypeOf(uint32(1))} - Uint64 = Dtype{reflect.TypeOf(uint64(1))} - Float32 = Dtype{reflect.TypeOf(float32(1))} - Float64 = Dtype{reflect.TypeOf(float64(1))} - Complex64 = Dtype{reflect.TypeOf(complex64(1))} - Complex128 = Dtype{reflect.TypeOf(complex128(1))} - String = Dtype{reflect.TypeOf("")} - - // aliases - Byte = Uint8 +func isFloat(dt dtype.Dtype) bool { return dt == Float64 || dt == Float32 } - // extras - Uintptr = Dtype{reflect.TypeOf(uintptr(0))} - UnsafePointer = Dtype{reflect.TypeOf(unsafe.Pointer(&Uintptr))} +// type aliases +var ( + Bool = dtype.Bool + Int = dtype.Int + Int8 = dtype.Int8 + Int16 = dtype.Int16 + Int32 = dtype.Int32 + Int64 = dtype.Int64 + Uint = dtype.Uint + Uint8 = dtype.Uint8 + Uint16 = dtype.Uint16 + Uint32 = dtype.Uint32 + Uint64 = dtype.Uint64 + Float32 = dtype.Float32 + Float64 = dtype.Float64 + Complex64 = dtype.Complex64 + Complex128 = dtype.Complex128 + String = dtype.String + Byte = dtype.Byte + Uintptr = dtype.Uintptr + UnsafePointer = dtype.UnsafePointer ) -// allTypes for indexing -var allTypes = &typeclass{ - name: "τ", - set: []Dtype{ - Bool, Int, Int8, Int16, Int32, Int64, Uint, Uint8, Uint16, Uint32, Uint64, Float32, Float64, Complex64, Complex128, String, Uintptr, UnsafePointer, - }, -} - -// specialized types indicate that there are specialized code generated for these types -var specializedTypes = &typeclass{ - name: "Specialized", - set: []Dtype{ - Bool, Int, Int8, Int16, Int32, Int64, Uint, Uint8, Uint16, Uint32, Uint64, Float32, Float64, Complex64, Complex128, String, - }, -} - -var addableTypes = &typeclass{ - name: "Addable", - set: []Dtype{ - Int, Int8, Int16, Int32, Int64, Uint, Uint8, Uint16, Uint32, Uint64, Float32, Float64, Complex64, Complex128, String, - }, -} - -var numberTypes = &typeclass{ - name: "Number", - set: []Dtype{ - Int, Int8, Int16, Int32, Int64, Uint, Uint8, Uint16, Uint32, Uint64, Float32, Float64, Complex64, Complex128, - }, -} - -var ordTypes = &typeclass{ - name: "Ord", - set: []Dtype{ - Int, Int8, Int16, Int32, Int64, Uint, Uint8, Uint16, Uint32, Uint64, Float32, Float64, String, - }, -} - -var eqTypes = &typeclass{ - name: "Eq", - set: []Dtype{ - Bool, Int, Int8, Int16, Int32, Int64, Uint, Uint8, Uint16, Uint32, Uint64, Float32, Float64, Complex64, Complex128, String, Uintptr, UnsafePointer, - }, -} - -var unsignedTypes = &typeclass{ - name: "Unsigned", - set: []Dtype{Uint, Uint8, Uint16, Uint32, Uint64}, -} - -var signedTypes = &typeclass{ - name: "Signed", - set: []Dtype{ - Int, Int8, Int16, Int32, Int64, Float32, Float64, Complex64, Complex128, - }, -} - -// this typeclass is ever only used by Sub tests -var signedNonComplexTypes = &typeclass{ - name: "Signed NonComplex", - set: []Dtype{ - Int, Int8, Int16, Int32, Int64, Float32, Float64, - }, -} - -var floatTypes = &typeclass{ - name: "Float", - set: []Dtype{ - Float32, Float64, - }, -} - -var complexTypes = &typeclass{ - name: "Complex Numbers", - set: []Dtype{Complex64, Complex128}, -} - -var floatcmplxTypes = &typeclass{ - name: "Real", - set: []Dtype{ - Float32, Float64, Complex64, Complex128, - }, -} - -var nonComplexNumberTypes = &typeclass{ - name: "Non complex numbers", - set: []Dtype{ - Int, Int8, Int16, Int32, Int64, Uint, Uint8, Uint16, Uint32, Uint64, Float32, Float64, - }, -} - -// this typeclass is ever only used by Pow tests -var generatableTypes = &typeclass{ - name: "Generatable types", - set: []Dtype{ - Bool, Int, Int8, Int16, Int32, Int64, Uint, Uint8, Uint16, Uint32, Uint64, Float32, Float64, String, - }, -} - -func isFloat(dt Dtype) bool { - return dt == Float64 || dt == Float32 -} - -func typeclassCheck(a Dtype, tc *typeclass) error { - if tc == nil { - return nil - } - for _, s := range tc.set { - if s == a { - return nil - } - } - return errors.Errorf("Type %v is not a member of %v", a, tc.name) -} - -// RegisterNumber is a function required to register a new numerical Dtype. -// This package provides the following Dtype: -// Int -// Int8 -// Int16 -// Int32 -// Int64 -// Uint -// Uint8 -// Uint16 -// Uint32 -// Uint64 -// Float32 -// Float64 -// Complex64 -// Complex128 -// -// If a Dtype that is registered already exists on the list, it will not be added to the list. -func RegisterNumber(a Dtype) { - for _, dt := range numberTypes.set { - if dt == a { - return - } - } - numberTypes.set = append(numberTypes.set, a) - RegisterEq(a) -} - -func RegisterFloat(a Dtype) { - for _, dt := range floatTypes.set { - if dt == a { - return - } - } - floatTypes.set = append(floatTypes.set, a) - RegisterNumber(a) - RegisterOrd(a) -} - -// RegisterOrd registers a dtype as a type that can be typed -func RegisterOrd(a Dtype) { - for _, dt := range ordTypes.set { - if dt == a { - return - } - } - ordTypes.set = append(ordTypes.set, a) - RegisterEq(a) -} - -// RegisterEq registers a dtype as a type that can be compared for equality -func RegisterEq(a Dtype) { - for _, dt := range eqTypes.set { - if dt == a { - return - } - } - eqTypes.set = append(eqTypes.set, a) - Register(a) -} - -// Register registers a new Dtype -func Register(a Dtype) { - for _, dt := range allTypes.set { - if a == dt { - return - } - } - allTypes.set = append(allTypes.set, a) -} - -func dtypeID(a Dtype) int { - for i, v := range allTypes.set { - if a == v { - return i - } - } - return -1 -} - // NormOrder represents the order of the norm. Ideally, we'd only represent norms with a uint/byte. // But there are norm types that are outside numerical types, such as nuclear norm and fobenius norm. // So it is internally represented by a float. If Go could use NaN and Inf as consts, it would have been best, @@ -410,54 +129,3 @@ func (n NormOrder) String() string { } panic("unreachable") } - -// FuncOpt are optionals for calling Tensor function. -type FuncOpt func(*OpOpt) - -// WithIncr passes in a Tensor to be incremented. -func WithIncr(incr Tensor) FuncOpt { - f := func(opt *OpOpt) { - opt.incr = incr - } - return f -} - -// WithReuse passes in a Tensor to be reused. -func WithReuse(reuse Tensor) FuncOpt { - f := func(opt *OpOpt) { - opt.reuse = reuse - } - return f -} - -// UseSafe ensures that the operation is a safe operation (copies data, does not clobber). This is the default option for most methods and functions -func UseSafe() FuncOpt { - f := func(opt *OpOpt) { - opt.unsafe = false - } - return f -} - -// UseUnsafe ensures that the operation is an unsafe operation - data will be clobbered, and operations performed inplace -func UseUnsafe() FuncOpt { - f := func(opt *OpOpt) { - opt.unsafe = true - } - return f -} - -// AsSameType makes sure that the return Tensor is the same type as input Tensors. -func AsSameType() FuncOpt { - f := func(opt *OpOpt) { - opt.same = true - } - return f -} - -// As makes sure that the the return Tensor is of the type specified. Currently only works for FromMat64 -func As(t Dtype) FuncOpt { - f := func(opt *OpOpt) { - opt.t = t - } - return f -} diff --git a/utils.go b/utils.go index 2b3aa65..426a1dd 100644 --- a/utils.go +++ b/utils.go @@ -1,6 +1,8 @@ package tensor import ( + "context" + "github.com/pkg/errors" ) @@ -244,37 +246,35 @@ func SliceDetails(s Slice, size int) (start, end, step int, err error) { return } -// reuseDenseCheck checks a reuse tensor, and reshapes it to be the correct one -func reuseDenseCheck(reuse DenseTensor, as DenseTensor) (err error) { - if reuse.DataSize() != as.Size() { - err = errors.Errorf("Reused Tensor %p does not have expected shape %v. Got %v instead. Reuse Size: %v, as Size %v (real: %d)", reuse, as.Shape(), reuse.Shape(), reuse.DataSize(), as.Size(), as.DataSize()) - return - } - return reuseCheckShape(reuse, as.Shape()) - -} - -// reuseCheckShape checks the shape and reshapes it to be correct if the size fits but the shape doesn't. -func reuseCheckShape(reuse DenseTensor, s Shape) (err error) { +// checkFixShape checks the shape and reshapes it to be correct if the size fits but the shape doesn't. +func checkFixShape(reuse Tensor, s Shape) (err error) { throw := BorrowInts(len(s)) copy(throw, s) - if err = reuse.reshape(throw...); err != nil { - err = errors.Wrapf(err, reuseReshapeErr, s, reuse.DataSize()) + d, ok := reuse.(DenseTensor) + if !ok { + if err = reuse.Reshape(throw...); err != nil { + return errors.Wrapf(err, reuseReshapeErr, s, reuse.DataSize()) + } + return nil + } + + if err = d.reshape(throw...); err != nil { + err = errors.Wrapf(err, reuseReshapeErr, s, d.DataSize()) return } // clean up any funny things that may be in the reuse - if oldAP := reuse.oldAP(); !oldAP.IsZero() { + if oldAP := d.oldAP(); !oldAP.IsZero() { oldAP.zero() } - if axes := reuse.transposeAxes(); axes != nil { + if axes := d.transposeAxes(); axes != nil { ReturnInts(axes) } - if viewOf := reuse.parentTensor(); viewOf != nil { - reuse.setParentTensor(nil) + if viewOf := d.parentTensor(); viewOf != nil { + d.setParentTensor(nil) } return nil } @@ -291,6 +291,7 @@ func memsetBools(a []bool, v bool) { } } +// allones checks that a slice of ints are all 1. func allones(a []int) bool { for i := range a { if a[i] != 1 { @@ -300,6 +301,14 @@ func allones(a []int) bool { return true } +// ctxFromEngine gets a context from an engine if it's a contexter. Otherwise it returns a context.Background() +func ctxFromEngine(e Engine) context.Context { + if c, ok := e.(contexter); ok { + return c.Context() + } + return context.Background() +} + func getFloat64s(a Tensor) []float64 { if um, ok := a.(unsafeMem); ok { return um.Float64s() @@ -319,6 +328,7 @@ func getInts(a Tensor) []int { return um.Ints() } return a.Data().([]int) + } /* FOR ILLUSTRATIVE PURPOSES */