diff --git a/common/dataset/dataset.go b/common/dataset/dataset.go new file mode 100644 index 000000000..8063bc496 --- /dev/null +++ b/common/dataset/dataset.go @@ -0,0 +1,184 @@ +// Copyright 2024 gorse Project Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dataset + +import ( + "archive/zip" + "encoding/csv" + "fmt" + "github.com/zhenghaoz/gorse/base/log" + "github.com/zhenghaoz/gorse/common/util" + "go.uber.org/zap" + "io" + "net/http" + "os" + "os/user" + "path/filepath" + "strings" +) + +var ( + tempDir string + datasetDir string +) + +func init() { + usr, err := user.Current() + if err != nil { + log.Logger().Fatal("failed to get user directory", zap.Error(err)) + } + datasetDir = filepath.Join(usr.HomeDir, ".gorse", "dataset") + tempDir = filepath.Join(usr.HomeDir, ".gorse", "temp") +} + +func LoadIris() ([][]float32, []int, error) { + // Download dataset + path, err := downloadAndUnzip("iris") + if err != nil { + return nil, nil, err + } + dataFile := filepath.Join(path, "iris.data") + // Load data + f, err := os.Open(dataFile) + if err != nil { + return nil, nil, err + } + reader := csv.NewReader(f) + rows, err := reader.ReadAll() + if err != nil { + return nil, nil, err + } + // Parse data + data := make([][]float32, len(rows)) + target := make([]int, len(rows)) + types := make(map[string]int) + for i, row := range rows { + data[i] = make([]float32, 4) + for j, cell := range row[:4] { + data[i][j], err = util.ParseFloat32(cell) + if err != nil { + return nil, nil, err + } + } + if _, exist := types[row[4]]; !exist { + types[row[4]] = len(types) + } + target[i] = types[row[4]] + } + return data, target, nil +} + +func downloadAndUnzip(name string) (string, error) { + url := fmt.Sprintf("https://pub-64226d9f34c64d6f829f5b63a5540d27.r2.dev/datasets/%s.zip", name) + path := filepath.Join(datasetDir, name) + if _, err := os.Stat(path); os.IsNotExist(err) { + zipFileName, _ := downloadFromUrl(url, tempDir) + if _, err := unzip(zipFileName, path); err != nil { + return "", err + } + } + return path, nil +} + +// downloadFromUrl downloads file from URL. +func downloadFromUrl(src, dst string) (string, error) { + log.Logger().Info("Download dataset", zap.String("source", src), zap.String("destination", dst)) + // Extract file name + tokens := strings.Split(src, "/") + fileName := filepath.Join(dst, tokens[len(tokens)-1]) + // Create file + if err := os.MkdirAll(filepath.Dir(fileName), os.ModePerm); err != nil { + return fileName, err + } + output, err := os.Create(fileName) + if err != nil { + log.Logger().Error("failed to create file", zap.Error(err), zap.String("filename", fileName)) + return fileName, err + } + defer output.Close() + // Download file + response, err := http.Get(src) + if err != nil { + log.Logger().Error("failed to download", zap.Error(err), zap.String("source", src)) + return fileName, err + } + defer response.Body.Close() + // Save file + _, err = io.Copy(output, response.Body) + if err != nil { + log.Logger().Error("failed to download", zap.Error(err), zap.String("source", src)) + return fileName, err + } + return fileName, nil +} + +// unzip zip file. +func unzip(src, dst string) ([]string, error) { + var fileNames []string + // Open zip file + r, err := zip.OpenReader(src) + if err != nil { + return fileNames, err + } + defer r.Close() + // Extract files + for _, f := range r.File { + // Open file + rc, err := f.Open() + if err != nil { + return fileNames, err + } + // Store filename/path for returning and using later on + filePath := filepath.Join(dst, f.Name) + // Check for ZipSlip. More Info: http://bit.ly/2MsjAWE + if !strings.HasPrefix(filePath, filepath.Clean(dst)+string(os.PathSeparator)) { + return fileNames, fmt.Errorf("%s: illegal file path", filePath) + } + // Add filename + fileNames = append(fileNames, filePath) + if f.FileInfo().IsDir() { + // Create folder + if err = os.MkdirAll(filePath, os.ModePerm); err != nil { + return fileNames, err + } + } else { + // Create all folders + if err = os.MkdirAll(filepath.Dir(filePath), os.ModePerm); err != nil { + return fileNames, err + } + // Create file + outFile, err := os.OpenFile(filePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode()) + if err != nil { + return fileNames, err + } + // Save file + _, err = io.Copy(outFile, rc) + if err != nil { + return nil, err + } + // Close the file without defer to close before next iteration of loop + err = outFile.Close() + if err != nil { + return nil, err + } + } + // Close file + err = rc.Close() + if err != nil { + return nil, err + } + } + return fileNames, nil +} diff --git a/common/dataset/dataset_test.go b/common/dataset/dataset_test.go new file mode 100644 index 000000000..78ef60ccd --- /dev/null +++ b/common/dataset/dataset_test.go @@ -0,0 +1,14 @@ +package dataset + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func TestLoadIris(t *testing.T) { + data, target, err := LoadIris() + assert.NoError(t, err) + assert.Len(t, data, 150) + assert.Len(t, data[0], 4) + assert.Len(t, target, 150) +} diff --git a/common/nn/functions.go b/common/nn/functions.go new file mode 100644 index 000000000..3b7fe048d --- /dev/null +++ b/common/nn/functions.go @@ -0,0 +1,207 @@ +// Copyright 2024 gorse Project Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package nn + +func Neg(x *Tensor) *Tensor { + return apply(&neg{}, x) +} + +// Add returns the element-wise sum of two tensors. The shape of the second tensor must be a suffix sequence of the shape of the first tensor. +func Add(x0, x1 *Tensor) *Tensor { + if len(x0.shape) < len(x1.shape) { + x0, x1 = x1, x0 + } + for i := 0; i < len(x1.shape); i++ { + if x0.shape[len(x0.shape)-len(x1.shape)+i] != x1.shape[i] { + panic("the shape of the second tensor must be a suffix sequence of the shape of the first tensor") + } + } + return apply(&add{}, x0, x1) +} + +// Sub returns the element-wise difference of two tensors. The shape of the second tensor must be a suffix sequence of the shape of the first tensor. +func Sub(x0, x1 *Tensor) *Tensor { + if len(x0.shape) < len(x1.shape) { + x0, x1 = x1, x0 + } + for i := 0; i < len(x1.shape); i++ { + if x0.shape[len(x0.shape)-len(x1.shape)+i] != x1.shape[i] { + panic("the shape of the second tensor must be a suffix sequence of the shape of the first tensor") + } + } + return apply(&sub{}, x0, x1) +} + +// Mul returns the element-wise product of two tensors. The shape of the second tensor must be a suffix sequence of the shape of the first tensor. +func Mul(x0, x1 *Tensor) *Tensor { + if len(x0.shape) < len(x1.shape) { + x0, x1 = x1, x0 + } + for i := 0; i < len(x1.shape); i++ { + if x0.shape[len(x0.shape)-len(x1.shape)+i] != x1.shape[i] { + panic("the shape of the second tensor must be a suffix sequence of the shape of the first tensor") + } + } + return apply(&mul{}, x0, x1) +} + +// Div returns the element-wise division of two tensors. The shape of the second tensor must be a suffix sequence of the shape of the first tensor. +func Div(x0, x1 *Tensor) *Tensor { + if len(x0.shape) < len(x1.shape) { + x0, x1 = x1, x0 + } + for i := 0; i < len(x1.shape); i++ { + if x0.shape[len(x0.shape)-len(x1.shape)+i] != x1.shape[i] { + panic("the shape of the second tensor must be a suffix sequence of the shape of the first tensor") + } + } + return apply(&div{}, x0, x1) +} + +// Square returns the element-wise square of a tensor. +func Square(x *Tensor) *Tensor { + return apply(&square{}, x) +} + +// Pow returns the element-wise power of a tensor. The shape of the second tensor must be a suffix sequence of the shape of the first tensor. +func Pow(x *Tensor, n *Tensor) *Tensor { + if len(x.shape) < len(x.shape) { + panic("the shape of the second tensor must be a suffix sequence of the shape of the first tensor") + } + for i := 0; i < len(x.shape); i++ { + if x.shape[len(x.shape)-len(x.shape)+i] != x.shape[i] { + panic("the shape of the second tensor must be a suffix sequence of the shape of the first tensor") + } + } + return apply(&pow{}, x, n) +} + +// Exp returns the element-wise exponential of a tensor. +func Exp(x *Tensor) *Tensor { + return apply(&exp{}, x) +} + +// Log returns the element-wise natural logarithm of a tensor. +func Log(x *Tensor) *Tensor { + return apply(&log{}, x) +} + +// Sin returns the element-wise sine of a tensor. +func Sin(x *Tensor) *Tensor { + return apply(&sin{}, x) +} + +func Cos(x *Tensor) *Tensor { + return apply(&cos{}, x) +} + +// Sum returns the sum of all elements in a tensor. +func Sum(x *Tensor, along ...int) *Tensor { + if len(along) > 1 { + panic("only one along is allowed") + } else if len(along) == 1 { + return apply(&partialSum{along: int64(along[0])}, x) + } + return apply(&sum{}, x) +} + +// Mean returns the mean of all elements in a tensor. +func Mean(x *Tensor) *Tensor { + return apply(&mean{}, x) +} + +func MatMul(x, y *Tensor, transpose ...bool) *Tensor { + op := &matMul{} + if len(transpose) > 2 { + panic("only two transpose is allowed") + } + if len(transpose) > 0 { + op.transpose1 = transpose[0] + } + if len(transpose) > 1 { + op.transpose2 = transpose[1] + } + return apply(op, x, y) +} + +func BMM(x, y *Tensor, transpose ...bool) *Tensor { + op := &batchMatMul{} + if len(transpose) > 2 { + panic("only two transpose is allowed") + } + if len(transpose) > 0 { + op.transpose1 = transpose[0] + } + if len(transpose) > 1 { + op.transpose2 = transpose[1] + } + return apply(op, x, y) +} + +func Broadcast(x *Tensor, shape ...int) *Tensor { + return apply(&broadcast{shape: shape}, x) +} + +func Flatten(x *Tensor) *Tensor { + return apply(&flatten{}, x) +} + +func Reshape(x *Tensor, shape ...int) *Tensor { + size1 := 1 + for i := range x.shape { + size1 *= x.shape[i] + } + size2 := 1 + for i := range shape { + size2 *= shape[i] + } + if size1 != size2 { + panic("the size of the tensor must be equal to the size of the new shape") + } + return apply(&reshape{shape: shape}, x) +} + +func Embedding(w, x *Tensor) *Tensor { + return apply(&embedding{}, w, x) +} + +func Sigmoid(x *Tensor) *Tensor { + return apply(&sigmoid{}, x) +} + +func ReLu(x *Tensor) *Tensor { + return apply(&relu{}, x) +} + +func MSE(x, y *Tensor) *Tensor { + return Mean(Square(Sub(x, y))) +} + +// BCEWithLogits is equivalent to: +// +// (1 + target) * math32.Log(1+math32.Exp(-prediction)) / 2 + (1 - target) * math32.Log(1+math32.Exp(prediction)) / 2 +func BCEWithLogits(target, prediction *Tensor) *Tensor { + return Add( + Div( + Mul( + Add(NewScalar(1), target), + Log(Add(NewScalar(1), Exp(Neg(prediction))))), + NewScalar(2)), + Div( + Mul( + Sub(NewScalar(1), target), + Log(Add(NewScalar(1), Exp(prediction)))), + NewScalar(2))) +} diff --git a/common/nn/layers.go b/common/nn/layers.go new file mode 100644 index 000000000..ae6fba718 --- /dev/null +++ b/common/nn/layers.go @@ -0,0 +1,112 @@ +// Copyright 2024 gorse Project Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package nn + +type Layer interface { + Parameters() []*Tensor + Forward(x *Tensor) *Tensor +} + +type Model Layer + +type linearLayer struct { + w *Tensor + b *Tensor +} + +func NewLinear(in, out int) Layer { + return &linearLayer{ + w: RandN(in, out).RequireGrad(), + b: RandN(out).RequireGrad(), + } +} + +func (l *linearLayer) Forward(x *Tensor) *Tensor { + return Add(MatMul(x, l.w), l.b) +} + +func (l *linearLayer) Parameters() []*Tensor { + return []*Tensor{l.w, l.b} +} + +type flattenLayer struct{} + +func NewFlatten() Layer { + return &flattenLayer{} +} + +func (f *flattenLayer) Parameters() []*Tensor { + return nil +} + +func (f *flattenLayer) Forward(x *Tensor) *Tensor { + return Flatten(x) +} + +type embeddingLayer struct { + w *Tensor +} + +func NewEmbedding(n int, shape ...int) Layer { + wShape := append([]int{n}, shape...) + return &embeddingLayer{ + w: RandN(wShape...), + } +} + +func (e *embeddingLayer) Parameters() []*Tensor { + return []*Tensor{e.w} +} + +func (e *embeddingLayer) Forward(x *Tensor) *Tensor { + return Embedding(e.w, x) +} + +type reluLayer struct{} + +func NewReLU() Layer { + return &reluLayer{} +} + +func (r *reluLayer) Parameters() []*Tensor { + return nil +} + +func (r *reluLayer) Forward(x *Tensor) *Tensor { + return ReLu(x) +} + +type Sequential struct { + layers []Layer +} + +func NewSequential(layers ...Layer) Model { + return &Sequential{layers: layers} +} + +func (s *Sequential) Parameters() []*Tensor { + var params []*Tensor + for _, l := range s.layers { + params = append(params, l.Parameters()...) + } + return params +} + +func (s *Sequential) Forward(x *Tensor) *Tensor { + for _, l := range s.layers { + x = l.Forward(x) + } + return x +} diff --git a/common/nn/op.go b/common/nn/op.go new file mode 100644 index 000000000..44f117384 --- /dev/null +++ b/common/nn/op.go @@ -0,0 +1,698 @@ +// Copyright 2024 gorse Project Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package nn + +import ( + "github.com/chewxy/math32" +) + +type op interface { + String() string + forward(inputs ...*Tensor) *Tensor + backward(dy *Tensor) []*Tensor + inputsAndOutput() ([]*Tensor, *Tensor) + setInputs(inputs ...*Tensor) + setOutput(y *Tensor) +} + +type base struct { + inputs []*Tensor + output *Tensor +} + +func (b *base) inputsAndOutput() ([]*Tensor, *Tensor) { + return b.inputs, b.output +} + +func (b *base) setInputs(inputs ...*Tensor) { + b.inputs = inputs +} + +func (b *base) setOutput(y *Tensor) { + b.output = y +} + +func apply[T op](f T, inputs ...*Tensor) *Tensor { + y := f.forward(inputs...) + f.setInputs(inputs...) + f.setOutput(y) + y.op = f + return y +} + +type neg struct { + base +} + +func (n *neg) String() string { + return "Neg" +} + +func (n *neg) forward(inputs ...*Tensor) *Tensor { + y := inputs[0].clone() + y.neg() + return y +} + +func (n *neg) backward(dy *Tensor) []*Tensor { + dx := dy.clone() + dx.neg() + return []*Tensor{dx} +} + +type add struct { + base +} + +func (a *add) String() string { + return "Add" +} + +func (a *add) forward(inputs ...*Tensor) *Tensor { + y := inputs[0].clone() + y.add(inputs[1]) + return y +} + +func (a *add) backward(dy *Tensor) []*Tensor { + gx0 := dy.clone() + gx1 := Zeros(a.inputs[1].shape...) + wSize := 1 + for i := range gx1.shape { + wSize *= gx1.shape[i] + } + for i := range dy.data { + gx1.data[i%wSize] += dy.data[i] + } + return []*Tensor{gx0, gx1} +} + +type sub struct { + base +} + +func (s *sub) String() string { + return "Sub" +} + +func (s *sub) forward(inputs ...*Tensor) *Tensor { + y := inputs[0].clone() + y.sub(inputs[1]) + return y +} + +func (s *sub) backward(dy *Tensor) []*Tensor { + gx0 := dy.clone() + gx1 := Zeros(s.inputs[1].shape...) + wSize := 1 + for i := range gx1.shape { + wSize *= gx1.shape[i] + } + for i := range dy.data { + gx1.data[i%wSize] -= dy.data[i] + } + return []*Tensor{gx0, gx1} +} + +type mul struct { + base +} + +func (m *mul) String() string { + return "Mul" +} + +func (m *mul) forward(inputs ...*Tensor) *Tensor { + y := inputs[0].clone() + y.mul(inputs[1]) + return y +} + +func (m *mul) backward(dy *Tensor) []*Tensor { + gx0 := dy.clone() + gx0.mul(m.inputs[1]) + gx1 := Zeros(m.inputs[1].shape...) + wSize := 1 + for i := range gx1.shape { + wSize *= gx1.shape[i] + } + for i := range dy.data { + gx1.data[i%wSize] += dy.data[i] * m.inputs[0].data[i] + } + return []*Tensor{gx0, gx1} +} + +type div struct { + base +} + +func (d *div) String() string { + return "Div" +} + +func (d *div) forward(inputs ...*Tensor) *Tensor { + y := inputs[0].clone() + y.div(inputs[1]) + return y +} + +func (d *div) backward(dy *Tensor) []*Tensor { + wSize := 1 + for i := range d.inputs[1].shape { + wSize *= d.inputs[1].shape[i] + } + gx0 := Zeros(d.inputs[0].shape...) + for i := range dy.data { + gx0.data[i] = dy.data[i] / d.inputs[1].data[i%wSize] + } + gx1 := Zeros(d.inputs[1].shape...) + for i := range dy.data { + gx1.data[i%wSize] -= dy.data[i] * d.inputs[0].data[i] / d.inputs[1].data[i%wSize] / d.inputs[1].data[i%wSize] + } + return []*Tensor{gx0, gx1} +} + +type sin struct { + base +} + +func (s *sin) String() string { + return "Sin" +} + +func (s *sin) forward(inputs ...*Tensor) *Tensor { + y := inputs[0].clone() + y.sin() + return y +} + +func (s *sin) backward(dy *Tensor) []*Tensor { + dx := s.inputs[0].clone() + dx.cos() + dx.mul(dy) + return []*Tensor{dx} +} + +type cos struct { + base +} + +func (c *cos) String() string { + return "Cos" +} + +func (c *cos) forward(inputs ...*Tensor) *Tensor { + y := inputs[0].clone() + y.cos() + return y +} + +func (c *cos) backward(dy *Tensor) []*Tensor { + dx := c.inputs[0].clone() + dx.sin() + dx.neg() + dx.mul(dy) + return []*Tensor{dx} +} + +type square struct { + base +} + +func (s *square) String() string { + return "Square" +} + +func (s *square) forward(inputs ...*Tensor) *Tensor { + y := inputs[0].clone() + y.square() + return y +} + +func (s *square) backward(dy *Tensor) []*Tensor { + dx := s.inputs[0].clone() + dx.mul(dy) + for i := range dx.data { + dx.data[i] *= 2 + } + return []*Tensor{dx} +} + +type pow struct { + base +} + +func (p *pow) String() string { + return "Pow" +} + +func (p *pow) forward(inputs ...*Tensor) *Tensor { + y := inputs[0].clone() + y.pow(inputs[1]) + return y +} + +func (p *pow) backward(dy *Tensor) []*Tensor { + dx0 := p.inputs[0].clone() + dx0.pow(p.inputs[1]) + dx0.mul(p.inputs[1]) + dx0.div(p.inputs[0]) + dx0.mul(dy) + wSize := 1 + for i := range p.inputs[1].shape { + wSize *= p.inputs[1].shape[i] + } + dx1 := Zeros(p.inputs[1].shape...) + for i := range dy.data { + dx1.data[i%wSize] += dy.data[i] * p.output.data[i] * math32.Log(p.inputs[0].data[i]) + } + return []*Tensor{dx0, dx1} +} + +type exp struct { + base +} + +func (e *exp) String() string { + return "Exp" +} + +func (e *exp) forward(inputs ...*Tensor) *Tensor { + y := inputs[0].clone() + y.exp() + return y +} + +func (e *exp) backward(dy *Tensor) []*Tensor { + dx := e.inputs[0].clone() + dx.exp() + dx.mul(dy) + return []*Tensor{dx} +} + +type log struct { + base +} + +func (l *log) String() string { + return "Log" +} + +func (l *log) forward(inputs ...*Tensor) *Tensor { + y := inputs[0].clone() + y.log() + return y +} + +func (l *log) backward(dy *Tensor) []*Tensor { + dx := dy.clone() + dx.div(l.inputs[0]) + return []*Tensor{dx} +} + +type sum struct { + base +} + +func (s *sum) String() string { + return "Sum" +} + +func (s *sum) forward(inputs ...*Tensor) *Tensor { + x := inputs[0] + y := NewTensor([]float32{0}) + for i := range x.data { + y.data[0] += x.data[i] + } + return y +} + +func (s *sum) backward(dy *Tensor) []*Tensor { + dx := Zeros(s.inputs[0].shape...) + for i := range dx.data { + dx.data[i] = dy.data[0] + } + return []*Tensor{dx} +} + +type partialSum struct { + base + along int64 +} + +func (p *partialSum) String() string { + return "Sum" +} + +func (p *partialSum) forward(inputs ...*Tensor) *Tensor { + x := inputs[0] + // Squash the shape. + s1, s2, s3 := 1, 1, 1 + for i := 0; i < len(x.shape); i++ { + if int64(i) == p.along { + s2 = x.shape[i] + } else if int64(i) < p.along { + s1 *= x.shape[i] + } else { + s3 *= x.shape[i] + } + } + // Calculate the output size and shape. + outputSize := s1 * s3 + outputShape := make([]int, 0) + for i := 0; i < len(x.shape); i++ { + if int64(i) != p.along { + outputShape = append(outputShape, x.shape[i]) + } + } + // Calculate the output. + y := NewTensor(make([]float32, outputSize), outputShape...) + for i := 0; i < s1; i++ { + for j := 0; j < s2; j++ { + for k := 0; k < s3; k++ { + y.data[i*s3+k] += x.data[i*s2*s3+j*s3+k] + } + } + } + return y +} + +func (p *partialSum) backward(dy *Tensor) []*Tensor { + x := p.inputs[0] + // Squash the shape. + s1, s2, s3 := 1, 1, 1 + for i := 0; i < len(x.shape); i++ { + if int64(i) == p.along { + s2 = x.shape[i] + } else if int64(i) < p.along { + s1 *= x.shape[i] + } else { + s3 *= x.shape[i] + } + } + // Calculate the output. + dx := Zeros(x.shape...) + for i := 0; i < s1; i++ { + for j := 0; j < s2; j++ { + for k := 0; k < s3; k++ { + dx.data[i*s2*s3+j*s3+k] = dy.data[i*s3+k] + } + } + } + return []*Tensor{dx} +} + +type mean struct { + base +} + +func (m *mean) String() string { + return "Mean" +} + +func (m *mean) forward(inputs ...*Tensor) *Tensor { + x := inputs[0] + y := NewTensor([]float32{0}) + for i := range x.data { + y.data[0] += x.data[i] + } + y.data[0] /= float32(len(x.data)) + return y +} + +func (m *mean) backward(dy *Tensor) []*Tensor { + dx := Zeros(m.inputs[0].shape...) + for i := range dx.data { + dx.data[i] = dy.data[0] / float32(len(dx.data)) + } + return []*Tensor{dx} +} + +type matMul struct { + base + transpose1 bool + transpose2 bool +} + +func (m *matMul) String() string { + return "MatMul" +} + +func (m *matMul) forward(inputs ...*Tensor) *Tensor { + return inputs[0].matMul(inputs[1], m.transpose1, m.transpose2) +} + +func (m *matMul) backward(dy *Tensor) []*Tensor { + var dx0, dx1 *Tensor + if !m.transpose1 && !m.transpose2 { // y = x0 * x1 + // dx0 = dy * x1^T + dx0 = dy.matMul(m.inputs[1], false, true) + // dx1 = x0^T * dy + dx1 = m.inputs[0].matMul(dy, true, false) + } else if m.transpose1 && !m.transpose2 { // y = x0^T * x1 + // dx0 = dy * x1^T + dx0 = m.inputs[1].matMul(dy, false, true) + // dx1 = dy^T * x0 + dx1 = m.inputs[0].matMul(dy, false, false) + } else if !m.transpose1 && m.transpose2 { // y = x0 * x1^T + // dx0 = dy * x1 + dx0 = dy.matMul(m.inputs[1], false, false) + // dx1 = dy^T * x0 + dx1 = dy.matMul(m.inputs[0], true, false) + } else { // y = x0^T * x1^T + // dx0 = x1 * dy^T + dx0 = m.inputs[1].matMul(dy, true, true) + // dx1 = dy * x0^T + dx1 = dy.matMul(m.inputs[0], true, true) + } + return []*Tensor{dx0, dx1} +} + +type batchMatMul struct { + base + transpose1 bool + transpose2 bool +} + +func (b *batchMatMul) String() string { + return "BatchMatMul" +} + +func (b *batchMatMul) forward(inputs ...*Tensor) *Tensor { + return inputs[0].batchMatMul(inputs[1], b.transpose1, b.transpose2) +} + +func (b *batchMatMul) backward(dy *Tensor) []*Tensor { + var dx0, dx1 *Tensor + if !b.transpose1 && !b.transpose2 { // y = x0 * x1 + // dx0 = dy * x1^T + dx0 = dy.batchMatMul(b.inputs[1], false, true) + // dx1 = x0^T * dy + dx1 = b.inputs[0].batchMatMul(dy, true, false) + } else if b.transpose1 && !b.transpose2 { // y = x0^T * x1 + // dx0 = dy * x1^T + dx0 = b.inputs[1].batchMatMul(dy, false, true) + // dx1 = dy^T * x0 + dx1 = b.inputs[0].batchMatMul(dy, false, false) + } else if !b.transpose1 && b.transpose2 { // y = x0 * x1^T + // dx0 = dy * x1 + dx0 = dy.batchMatMul(b.inputs[1], false, false) + // dx1 = dy^T * x0 + dx1 = dy.batchMatMul(b.inputs[0], true, false) + } else { // y = x0^T * x1^T + // dx0 = x1 * dy^T + dx0 = b.inputs[1].batchMatMul(dy, true, true) + // dx1 = dy * x0^T + dx1 = dy.batchMatMul(b.inputs[0], true, true) + } + return []*Tensor{dx0, dx1} +} + +type broadcast struct { + base + shape []int +} + +func (b *broadcast) String() string { + return "Broadcast" +} + +func (b *broadcast) forward(inputs ...*Tensor) *Tensor { + x := inputs[0] + // Concatenate the shape + shape := make([]int, len(x.shape)) + copy(shape, x.shape) + shape = append(shape, b.shape...) + size := 1 + for i := range shape { + size *= shape[i] + } + // Create a new tensor with the new shape + y := NewTensor(make([]float32, size), shape...) + wSize := 1 + for i := range b.shape { + wSize *= b.shape[i] + } + for i := range x.data { + for j := i * wSize; j < (i+1)*wSize; j++ { + y.data[j] = x.data[i] + } + } + return y +} + +func (b *broadcast) backward(dy *Tensor) []*Tensor { + gx := Zeros(b.inputs[0].shape...) + wSize := 1 + for i := range b.shape { + wSize *= b.shape[i] + } + for i := range gx.data { + for j := i * wSize; j < (i+1)*wSize; j++ { + gx.data[i] += dy.data[j] + } + } + return []*Tensor{gx} +} + +type flatten struct { + base +} + +func (f *flatten) String() string { + return "Flatten" +} + +func (f *flatten) forward(inputs ...*Tensor) *Tensor { + return NewTensor(inputs[0].data, len(inputs[0].data)) +} + +func (f *flatten) backward(dy *Tensor) []*Tensor { + return []*Tensor{NewTensor(dy.data, f.inputs[0].shape...)} +} + +type reshape struct { + base + shape []int +} + +func (r *reshape) String() string { + return "Reshape" +} + +func (r *reshape) forward(inputs ...*Tensor) *Tensor { + return NewTensor(inputs[0].data, r.shape...) +} + +func (r *reshape) backward(dy *Tensor) []*Tensor { + return []*Tensor{NewTensor(dy.data, r.inputs[0].shape...)} +} + +type embedding struct { + base +} + +func (e *embedding) String() string { + return "Embedding" +} + +func (e *embedding) forward(inputs ...*Tensor) *Tensor { + w, x := inputs[0], inputs[1] + // Calculate embedding size + dim := 1 + for i := 1; i < len(w.shape); i++ { + dim *= w.shape[i] + } + // Calculate shape + shape := make([]int, len(x.shape), len(x.shape)+1) + copy(shape, x.shape) + shape = append(shape, w.shape[1:]...) + // Calculate data size + size := 1 + for _, s := range shape { + size *= s + } + // Create output tensor + data := make([]float32, size) + for i := 0; i < len(x.data); i++ { + index := int(x.data[i]) + copy(data[i*dim:(i+1)*dim], w.data[index*dim:(index+1)*dim]) + } + return NewTensor(data, shape...) +} + +func (e *embedding) backward(dy *Tensor) []*Tensor { + w, x := e.inputs[0], e.inputs[1] + dim := 1 + for i := 1; i < len(w.shape); i++ { + dim *= w.shape[i] + } + dw := Zeros(w.shape...) + for i := 0; i < len(x.data); i++ { + index := int(x.data[i]) + for j := 0; j < dim; j++ { + dw.data[index*dim+j] += dy.data[i*dim+j] + } + } + return []*Tensor{dw} +} + +type sigmoid struct { + base +} + +func (s *sigmoid) String() string { + return "Sigmoid" +} + +func (s *sigmoid) forward(inputs ...*Tensor) *Tensor { + // y = tanh(x * 0.5) * 0.5 + 0.5 + y := inputs[0].clone() + y.mul(NewScalar(0.5)) + y.tanh() + y.mul(NewScalar(0.5)) + y.add(NewScalar(0.5)) + return y +} + +func (s *sigmoid) backward(dy *Tensor) []*Tensor { + // dx = dy * y * (1 - y) + dx := s.output.clone() + dx.neg() + dx.add(NewScalar(1)) + dx.mul(s.output) + dx.mul(dy) + return []*Tensor{dx} +} + +type relu struct { + base +} + +func (r *relu) String() string { + return "ReLU" +} + +func (r *relu) forward(inputs ...*Tensor) *Tensor { + y := inputs[0].clone() + y.maximum(NewScalar(0)) + return y +} + +func (r *relu) backward(dy *Tensor) []*Tensor { + dx := dy.clone() + dx.maximum(NewScalar(0)) + return []*Tensor{dx} +} diff --git a/common/nn/op_test.go b/common/nn/op_test.go new file mode 100644 index 000000000..5fb034abd --- /dev/null +++ b/common/nn/op_test.go @@ -0,0 +1,546 @@ +// Copyright 2024 gorse Project Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package nn + +import ( + "github.com/chewxy/math32" + "github.com/stretchr/testify/assert" + "testing" +) + +const ( + eps = 1e-4 + rtol = 1e-2 + atol = 1e-4 +) + +func numericalDiff(f func(*Tensor) *Tensor, x *Tensor) *Tensor { + x0 := Sub(x, NewVariable([]float32{eps})) + x1 := Add(x, NewVariable([]float32{eps})) + y0 := f(x0) + y1 := f(x1) + dx := Div(Sub(y1, y0), NewVariable([]float32{2 * eps})) + return dx +} + +func allClose(t *testing.T, a, b *Tensor) { + if !assert.Equal(t, a.shape, b.shape) { + return + } + for i := range a.data { + if math32.Abs(a.data[i]-b.data[i]) > atol+rtol*math32.Abs(b.data[i]) { + t.Fatalf("a.data[%d] = %f, b.data[%d] = %f\n", i, a.data[i], i, b.data[i]) + return + } + } +} + +func TestAdd(t *testing.T) { + // (2,3) + (2,3) -> (2,3) + x := NewVariable([]float32{1, 2, 3, 4, 5, 6}, 2, 3) + y := NewVariable([]float32{2, 3, 4, 5, 6, 7}, 2, 3) + z := Add(x, y) + assert.Equal(t, []float32{3, 5, 7, 9, 11, 13}, z.data) + + // Test gradient + x = RandN(2, 3).RequireGrad() + y = RandN(2, 3).RequireGrad() + z = Add(x, y) + z.Backward() + dx := numericalDiff(func(x *Tensor) *Tensor { return Add(x, y) }, x) + allClose(t, x.grad, dx) + dy := numericalDiff(func(y *Tensor) *Tensor { return Add(x, y) }, y) + allClose(t, y.grad, dy) + + // (2,3) + () -> (2,3) + x = NewVariable([]float32{1, 2, 3, 4, 5, 6}, 2, 3) + y = NewVariable([]float32{2}) + z = Add(x, y) + assert.Equal(t, []float32{3, 4, 5, 6, 7, 8}, z.data) + + // Test gradient + z.Backward() + assert.Equal(t, []float32{1, 1, 1, 1, 1, 1}, x.grad.data) + assert.Equal(t, []float32{6}, y.grad.data) + + // (2,3) + (3) -> (2,3) + x = NewVariable([]float32{1, 2, 3, 4, 5, 6}, 2, 3) + y = NewVariable([]float32{2, 3, 4}, 3) + z = Add(x, y) + assert.Equal(t, []float32{3, 5, 7, 6, 8, 10}, z.data) + + // Test gradient + z.Backward() + assert.Equal(t, []float32{1, 1, 1, 1, 1, 1}, x.grad.data) + assert.Equal(t, []float32{2, 2, 2}, y.grad.data) +} + +func TestSub(t *testing.T) { + // (2,3) - (2,3) -> (2,3) + x := NewVariable([]float32{1, 2, 3, 4, 5, 6}, 2, 3) + y := NewVariable([]float32{2, 3, 4, 5, 6, 7}, 2, 3) + z := Sub(x, y) + assert.Equal(t, []float32{-1, -1, -1, -1, -1, -1}, z.data) + + // Test gradient + x = RandN(2, 3).RequireGrad() + y = RandN(2, 3).RequireGrad() + z = Sub(x, y) + z.Backward() + dx := numericalDiff(func(x *Tensor) *Tensor { return Sub(x, y) }, x) + allClose(t, x.grad, dx) + dy := numericalDiff(func(y *Tensor) *Tensor { return Sub(x, y) }, y) + allClose(t, y.grad, dy) + + // (2,3) - () -> (2,3) + x = NewVariable([]float32{1, 2, 3, 4, 5, 6}, 2, 3) + y = NewVariable([]float32{2}) + z = Sub(x, y) + assert.Equal(t, []float32{-1, 0, 1, 2, 3, 4}, z.data) + + // Test gradient + z.Backward() + assert.Equal(t, []float32{1, 1, 1, 1, 1, 1}, x.grad.data) + assert.Equal(t, []float32{-6}, y.grad.data) + + // (2,3) - (3) -> (2,3) + x = NewVariable([]float32{1, 2, 3, 4, 5, 6}, 2, 3) + y = NewVariable([]float32{2, 3, 4}, 3) + z = Sub(x, y) + assert.Equal(t, []float32{-1, -1, -1, 2, 2, 2}, z.data) + + // Test gradient + z.Backward() + assert.Equal(t, []float32{1, 1, 1, 1, 1, 1}, x.grad.data) + assert.Equal(t, []float32{-2, -2, -2}, y.grad.data) +} + +func TestMul(t *testing.T) { + // (2,3) * (2,3) -> (2,3) + x := NewVariable([]float32{1, 2, 3, 4, 5, 6}, 2, 3) + y := NewVariable([]float32{2, 3, 4, 5, 6, 7}, 2, 3) + z := Mul(x, y) + assert.Equal(t, []float32{2, 6, 12, 20, 30, 42}, z.data) + + // Test gradient + x = RandN(2, 3).RequireGrad() + y = RandN(2, 3).RequireGrad() + z = Mul(x, y) + z.Backward() + dx := numericalDiff(func(x *Tensor) *Tensor { return Mul(x, y) }, x) + allClose(t, x.grad, dx) + dy := numericalDiff(func(y *Tensor) *Tensor { return Mul(x, y) }, y) + allClose(t, y.grad, dy) + + // (2,3) * () -> (2,3) + x = NewVariable([]float32{1, 2, 3, 4, 5, 6}, 2, 3) + y = NewVariable([]float32{2}) + z = Mul(x, y) + assert.Equal(t, []float32{2, 4, 6, 8, 10, 12}, z.data) + + // Test gradient + z.Backward() + assert.Equal(t, []float32{2, 2, 2, 2, 2, 2}, x.grad.data) + assert.Equal(t, []float32{21}, y.grad.data) + + // (2,3) * (3) -> (2,3) + x = NewVariable([]float32{1, 2, 3, 4, 5, 6}, 2, 3) + y = NewVariable([]float32{2, 3, 4}, 3) + z = Mul(x, y) + assert.Equal(t, []float32{2, 6, 12, 8, 15, 24}, z.data) + + // Test gradient + z.Backward() + assert.Equal(t, []float32{2, 3, 4, 2, 3, 4}, x.grad.data) + assert.Equal(t, []float32{5, 7, 9}, y.grad.data) +} + +func TestDiv(t *testing.T) { + // (2,3) / (2,3) -> (2,3) + x := NewVariable([]float32{1, 2, 3, 4, 5, 6}, 2, 3) + y := NewVariable([]float32{2, 3, 4, 5, 6, 7}, 2, 3) + z := Div(x, y) + assert.InDeltaSlice(t, []float32{0.5, 2.0 / 3.0, 0.75, 4.0 / 5.0, 5.0 / 6.0, 6.0 / 7.0}, z.data, 1e-6) + + // Test gradient + x = RandN(2, 3).RequireGrad() + y = RandN(2, 3).RequireGrad() + z = Div(x, y) + z.Backward() + dx := numericalDiff(func(x *Tensor) *Tensor { return Div(x, y) }, x) + allClose(t, x.grad, dx) + dy := numericalDiff(func(y *Tensor) *Tensor { return Div(x, y) }, y) + allClose(t, y.grad, dy) + + // (2,3) / () -> (2,3) + x = NewVariable([]float32{1, 2, 3, 4, 5, 6}, 2, 3) + y = NewVariable([]float32{2}) + z = Div(x, y) + assert.InDeltaSlice(t, []float32{0.5, 1, 1.5, 2, 2.5, 3}, z.data, 1e-6) + + // Test gradient + z.Backward() + assert.InDeltaSlice(t, []float32{0.5, 0.5, 0.5, 0.5, 0.5, 0.5}, x.grad.data, 1e-6) + assert.InDeltaSlice(t, []float32{-21.0 / 4.0}, y.grad.data, 1e-6) + + // (2,3) / (3) -> (2,3) + x = NewVariable([]float32{1, 2, 3, 4, 5, 6}, 2, 3) + y = NewVariable([]float32{2, 3, 4}, 3) + z = Div(x, y) + assert.InDeltaSlice(t, []float32{0.5, 2.0 / 3.0, 3.0 / 4.0, 2, 5.0 / 3.0, 1.5}, z.data, 1e-6) + + // Test gradient + z.Backward() + assert.InDeltaSlice(t, []float32{1.0 / 2, 1.0 / 3, 1.0 / 4, 1.0 / 2, 1.0 / 3, 1.0 / 4}, x.grad.data, 1e-6) + assert.InDeltaSlice(t, []float32{-5.0 / 4.0, -7.0 / 9.0, -9.0 / 16.0}, y.grad.data, 1e-6) +} + +func TestSquare(t *testing.T) { + // (2,3) -> (2,3) + x := NewVariable([]float32{1, 2, 3, 4, 5, 6}, 2, 3) + y := Square(x) + assert.Equal(t, []float32{1, 4, 9, 16, 25, 36}, y.data) + + // Test gradient + x = RandN(2, 3).RequireGrad() + y = Square(x) + y.Backward() + dx := numericalDiff(Square, x) + allClose(t, x.grad, dx) +} + +func TestPow(t *testing.T) { + // (2,3) ** (2,3) -> (2,3) + x := NewVariable([]float32{1, 2, 3, 4, 5, 6}, 2, 3) + y := NewVariable([]float32{2, 3, 4, 5, 6, 7}, 2, 3) + z := Pow(x, y) + assert.InDeltaSlice(t, []float32{1, 8, 81, 1024, 15625, 279936}, z.data, 1e-6) + + // Test gradient + x = RandN(2, 3).RequireGrad() + y = RandN(2, 3).RequireGrad() + z = Pow(x, y) + z.Backward() + dx := numericalDiff(func(x *Tensor) *Tensor { return Pow(x, y) }, x) + allClose(t, x.grad, dx) + dy := numericalDiff(func(y *Tensor) *Tensor { return Pow(x, y) }, y) + allClose(t, y.grad, dy) + + // (2,3) ** () -> (2,3) + x = NewVariable([]float32{1, 2, 3, 4, 5, 6}, 2, 3) + y = NewVariable([]float32{2}) + z = Pow(x, y) + assert.InDeltaSlice(t, []float32{1, 4, 9, 16, 25, 36}, z.data, 1e-6) + + // Test gradient + z.Backward() + assert.InDeltaSlice(t, []float32{2, 4, 6, 8, 10, 12}, x.grad.data, 1e-6) + assert.InDeltaSlice(t, []float32{ + math32.Pow(1, 2)*math32.Log(1) + + math32.Pow(2, 2)*math32.Log(2) + + math32.Pow(3, 2)*math32.Log(3) + + math32.Pow(4, 2)*math32.Log(4) + + math32.Pow(5, 2)*math32.Log(5) + + math32.Pow(6, 2)*math32.Log(6), + }, y.grad.data, 1e-6) +} + +func TestExp(t *testing.T) { + // (2,3) -> (2,3) + x := NewVariable([]float32{0, 1, 2, 3, 4, 5}, 2, 3) + y := Exp(x) + assert.InDeltaSlice(t, []float32{1, math32.Exp(1), math32.Exp(2), math32.Exp(3), math32.Exp(4), math32.Exp(5)}, y.data, 1e-6) + + // Test gradient + x = RandN(2, 3).RequireGrad() + y = Exp(x) + y.Backward() + dx := numericalDiff(Exp, x) + allClose(t, x.grad, dx) +} + +func TestLog(t *testing.T) { + // (2,3) -> (2,3) + x := NewVariable([]float32{1, 2, 3, 4, 5, 6}, 2, 3) + y := Log(x) + assert.InDeltaSlice(t, []float32{0, math32.Log(2), math32.Log(3), math32.Log(4), math32.Log(5), math32.Log(6)}, y.data, 1e-6) + + // Test gradient + x = RandN(2, 3).RequireGrad() + y = Log(x) + y.Backward() + dx := numericalDiff(Log, x) + allClose(t, x.grad, dx) +} + +func TestSum(t *testing.T) { + // (2,3) -> () + x := NewVariable([]float32{1, 2, 3, 4, 5, 6}, 2, 3) + y := Sum(x) + assert.Equal(t, []float32{21}, y.data) + + // Test gradient + x = RandN(2, 3).RequireGrad() + y = Sum(x) + y.Backward() + assert.Equal(t, []float32{1, 1, 1, 1, 1, 1}, x.grad.data) + + // (2,3,2) -> (2,2) + x = NewVariable([]float32{1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6}, 2, 3, 2) + y = Sum(x, 1) + assert.Equal(t, []int{2, 2}, y.shape) + assert.Equal(t, []float32{9, 12, 9, 12}, y.data) + + // Test gradient + x = RandN(2, 3, 2).RequireGrad() + y = Sum(x, 1) + y.Backward() + assert.Equal(t, []int{2, 3, 2}, x.grad.shape) + assert.Equal(t, []float32{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, x.grad.data) +} + +func TestMean(t *testing.T) { + // (2,3) -> () + x := NewVariable([]float32{1, 2, 3, 4, 5, 6}, 2, 3) + y := Mean(x) + assert.Equal(t, []float32{3.5}, y.data) + + // Test gradient + x = RandN(2, 3).RequireGrad() + y = Mean(x) + y.Backward() + assert.Equal(t, []float32{1.0 / 6, 1.0 / 6, 1.0 / 6, 1.0 / 6, 1.0 / 6, 1.0 / 6}, x.grad.data) +} + +func TestCos(t *testing.T) { + // (2,3) -> (2,3) + x := NewVariable([]float32{0, 0.1, 0.2, 0.3, 0.4, 0.5}, 2, 3) + y := Cos(x) + assert.InDeltaSlice(t, []float32{1, 0.9950041652780258, 0.9800665778412416, 0.955336489125606, 0.9210609940028851, 0.8775825618903728}, y.data, 1e-6) + + // Test gradient + x = RandN(2, 3).RequireGrad() + y = Cos(x) + y.Backward() + dx := numericalDiff(Cos, x) + allClose(t, x.grad, dx) +} + +func TestSin(t *testing.T) { + // (2,3) -> (2,3) + x := NewVariable([]float32{0, 1, 2, 3, 4, 5}, 2, 3) + y := Sin(x) + assert.InDeltaSlice(t, []float32{0, 0.8414709848078965, 0.9092974268256817, 0.1411200080598672, -0.7568024953079282, -0.9589242746631385}, y.data, 1e-6) + + // Test gradient + x = RandN(2, 3).RequireGrad() + y = Sin(x) + y.Backward() + dx := numericalDiff(Sin, x) + allClose(t, x.grad, dx) +} + +func TestMatMul(t *testing.T) { + // (2,3) * (3,4) -> (2,4) + x := NewVariable([]float32{1, 2, 3, 4, 5, 6}, 2, 3) + y := NewVariable([]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, 3, 4) + z := MatMul(x, y) + assert.Equal(t, []int{2, 4}, z.shape) + assert.Equal(t, []float32{38, 44, 50, 56, 83, 98, 113, 128}, z.data) + + // Test gradient + z.Backward() + assert.Equal(t, []int{2, 3}, x.grad.shape) + assert.Equal(t, []float32{10, 26, 42, 10, 26, 42}, x.grad.data) + assert.Equal(t, []int{3, 4}, y.grad.shape) + assert.Equal(t, []float32{5, 5, 5, 5, 7, 7, 7, 7, 9, 9, 9, 9}, y.grad.data) + + // (3,2).T * (3,4) -> (2,4) + x = RandN(3, 2).RequireGrad() + y = RandN(3, 4).RequireGrad() + z = MatMul(x, y, true, false) + assert.Equal(t, []int{2, 4}, z.shape) + z.Backward() + assert.Equal(t, []int{3, 2}, x.grad.shape) + assert.Equal(t, []int{3, 4}, y.grad.shape) + + // (2,3) * (4,3).T -> (2,4) + x = RandN(2, 3).RequireGrad() + y = RandN(4, 3).RequireGrad() + z = MatMul(x, y, false, true) + assert.Equal(t, []int{2, 4}, z.shape) + z.Backward() + assert.Equal(t, []int{2, 3}, x.grad.shape) + assert.Equal(t, []int{4, 3}, y.grad.shape) + + // (3,2).T * (4,3).T -> (2,4) + x = RandN(3, 2).RequireGrad() + y = RandN(4, 3).RequireGrad() + z = MatMul(x, y, true, true) + assert.Equal(t, []int{2, 4}, z.shape) + z.Backward() + assert.Equal(t, []int{3, 2}, x.grad.shape) +} + +func TestBMM(t *testing.T) { + // (2,2,3) * (2,3,4) -> (2,2,4) + x := NewVariable([]float32{1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6}, 2, 2, 3) + y := NewVariable([]float32{ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + }, 2, 3, 4) + z := BMM(x, y) + assert.Equal(t, []int{2, 2, 4}, z.shape) + assert.Equal(t, []float32{ + 38, 44, 50, 56, 83, 98, 113, 128, + 38, 44, 50, 56, 83, 98, 113, 128, + }, z.data) + + // Test gradient + z.Backward() + assert.Equal(t, []int{2, 2, 3}, x.grad.shape) + assert.Equal(t, []float32{ + 10, 26, 42, 10, 26, 42, + 10, 26, 42, 10, 26, 42, + }, x.grad.data) + assert.Equal(t, []int{2, 3, 4}, y.grad.shape) + assert.Equal(t, []float32{ + 5, 5, 5, 5, 7, 7, 7, 7, 9, 9, 9, 9, + 5, 5, 5, 5, 7, 7, 7, 7, 9, 9, 9, 9, + }, y.grad.data) + + // (2,3,2).T * (2,3,4) -> (2,2,4) + x = RandN(2, 3, 2).RequireGrad() + y = RandN(2, 3, 4).RequireGrad() + z = BMM(x, y, true, false) + assert.Equal(t, []int{2, 2, 4}, z.shape) + z.Backward() + assert.Equal(t, []int{2, 3, 2}, x.grad.shape) + + // (2,2,3) * (2,4,3).T -> (2,2,4) + x = RandN(2, 2, 3).RequireGrad() + y = RandN(2, 4, 3).RequireGrad() + z = BMM(x, y, false, true) + assert.Equal(t, []int{2, 2, 4}, z.shape) + z.Backward() + assert.Equal(t, []int{2, 2, 3}, x.grad.shape) + + // (2,3,2).T * (2,43).T -> (2,2,4) + x = RandN(2, 3, 2).RequireGrad() + y = RandN(2, 4, 3).RequireGrad() + z = BMM(x, y, true, true) + assert.Equal(t, []int{2, 2, 4}, z.shape) + z.Backward() + assert.Equal(t, []int{2, 3, 2}, x.grad.shape) +} + +func TestBroadcast(t *testing.T) { + // (2) -> (2,3) + x := NewVariable([]float32{1, 2}, 2) + y := Broadcast(x, 3) + assert.Equal(t, []float32{1, 1, 1, 2, 2, 2}, y.data) + + // Test gradient + y.Backward() + assert.Equal(t, []float32{3, 3}, x.grad.data) +} + +func TestEmbedding(t *testing.T) { + // (2,3) -> (2,3,2) + x := NewVariable([]float32{0, 1, 0, 3, 0, 5}, 2, 3) + w := NewVariable([]float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, 6, 2) + y := Embedding(w, x) + assert.Equal(t, []int{2, 3, 2}, y.shape) + assert.Equal(t, []float32{0, 1, 2, 3, 0, 1, 6, 7, 0, 1, 10, 11}, y.data) + + // Test gradient + y.Backward() + assert.Nil(t, x.grad) + assert.Equal(t, []float32{3, 3, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1}, w.grad.data) + + // (2,3) -> (2,3,1,2) + x = NewVariable([]float32{0, 1, 0, 3, 0, 5}, 2, 3) + w = NewVariable([]float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, 6, 1, 2) + y = Embedding(w, x) + assert.Equal(t, []int{2, 3, 1, 2}, y.shape) + assert.Equal(t, []float32{0, 1, 2, 3, 0, 1, 6, 7, 0, 1, 10, 11}, y.data) + + // Test gradient + y.Backward() + assert.Nil(t, x.grad) + assert.Equal(t, []float32{3, 3, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1}, w.grad.data) +} + +func TestSigmoid(t *testing.T) { + // (2,3) -> (2,3) + x := NewVariable([]float32{0, 1, 2, 3, 4, 5}, 2, 3) + y := Sigmoid(x) + assert.InDeltaSlice(t, []float32{0.5, 0.7310585786300049, 0.8807970779778823, 0.9525741268224334, 0.9820137900379085, 0.9933071490757153}, y.data, 1e-6) + + // Test gradient + x = RandN(2, 3).RequireGrad() + y = Sigmoid(x) + y.Backward() + dx := numericalDiff(Sigmoid, x) + allClose(t, x.grad, dx) +} + +func TestReLu(t *testing.T) { + // (2,3) -> (2,3) + x := NewVariable([]float32{-1, 0, 1, 2, 3, 4}, 2, 3) + y := ReLu(x) + assert.Equal(t, []float32{0, 0, 1, 2, 3, 4}, y.data) + + // Test gradient + x = RandN(2, 3).RequireGrad() + y = ReLu(x) + y.Backward() + dx := numericalDiff(ReLu, x) + allClose(t, x.grad, dx) +} + +func TestFlatten(t *testing.T) { + // (2,3) -> (6) + x := NewVariable([]float32{1, 2, 3, 4, 5, 6}, 2, 3) + y := Flatten(x) + assert.Equal(t, []float32{1, 2, 3, 4, 5, 6}, y.data) + + // Test gradient + y.Backward() + assert.Equal(t, []float32{1, 1, 1, 1, 1, 1}, x.grad.data) +} + +func TestReshape(t *testing.T) { + // (2,3) -> (3,2) + x := NewVariable([]float32{1, 2, 3, 4, 5, 6}, 2, 3) + y := Reshape(x, 3, 2) + assert.Equal(t, []float32{1, 2, 3, 4, 5, 6}, y.data) + + // Test gradient + y.Backward() + assert.Equal(t, []float32{1, 1, 1, 1, 1, 1}, x.grad.data) +} + +func TestReuse(t *testing.T) { + // x + x + x := NewVariable([]float32{1, 2, 3, 4, 5, 6}, 2, 3) + y := Add(x, x) + assert.Equal(t, []float32{2, 4, 6, 8, 10, 12}, y.data) + + // Test gradient + y.Backward() + dx := numericalDiff(func(x *Tensor) *Tensor { return Add(x, x) }, x) + allClose(t, x.grad, dx) +} diff --git a/common/nn/optimizers.go b/common/nn/optimizers.go new file mode 100644 index 000000000..314980ade --- /dev/null +++ b/common/nn/optimizers.go @@ -0,0 +1,98 @@ +// Copyright 2024 gorse Project Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package nn + +import ( + "github.com/chewxy/math32" + "github.com/google/uuid" +) + +type Optimizer interface { + ZeroGrad() + Step() +} + +type baseOptimizer struct { + params []*Tensor +} + +func (o *baseOptimizer) ZeroGrad() { + for _, p := range o.params { + p.grad = nil + } +} + +type SGD struct { + baseOptimizer + lr float32 +} + +func NewSGD(params []*Tensor, lr float32) Optimizer { + return &SGD{ + baseOptimizer: baseOptimizer{params: params}, + lr: lr, + } +} + +func (s *SGD) Step() { + for _, p := range s.params { + for i := range p.data { + p.data[i] -= s.lr * p.grad.data[i] + } + } +} + +type Adam struct { + baseOptimizer + alpha float32 + beta1 float32 + beta2 float32 + eps float32 + ms map[uuid.UUID]*Tensor + vs map[uuid.UUID]*Tensor +} + +func NewAdam(params []*Tensor, alpha float32) Optimizer { + return &Adam{ + baseOptimizer: baseOptimizer{params: params}, + alpha: alpha, + beta1: 0.9, + beta2: 0.999, + eps: 1e-8, + ms: make(map[uuid.UUID]*Tensor), + vs: make(map[uuid.UUID]*Tensor), + } +} + +func (a *Adam) Step() { + for _, p := range a.params { + if _, ok := a.ms[p.id]; !ok { + a.ms[p.id] = Zeros(p.shape...) + a.vs[p.id] = Zeros(p.shape...) + } + + m, v := a.ms[p.id], a.vs[p.id] + grad := p.grad.data + + for i := range m.data { + // m += (1 - beta1) * (grad - m) + m.data[i] += (1 - a.beta1) * (grad[i] - m.data[i]) + // v += (1 - beta2) * (grad * grad - v) + v.data[i] += (1 - a.beta2) * (grad[i]*grad[i] - v.data[i]) + // param.data -= self.lr * m / (xp.sqrt(v) + eps) + p.data[i] -= a.alpha * m.data[i] / (math32.Sqrt(v.data[i]) + a.eps) + } + } +} diff --git a/common/nn/tensor.go b/common/nn/tensor.go new file mode 100644 index 000000000..48f0e800b --- /dev/null +++ b/common/nn/tensor.go @@ -0,0 +1,569 @@ +// Copyright 2024 gorse Project Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package nn + +import ( + "fmt" + "github.com/chewxy/math32" + "github.com/google/uuid" + "github.com/zhenghaoz/gorse/base/floats" + "golang.org/x/exp/slices" + "math/rand" + "strings" +) + +type Tensor struct { + data []float32 + shape []int + grad *Tensor + op op + + requireGrad bool + id uuid.UUID // Only assigned if requireGrad is true +} + +func NewTensor(data []float32, shape ...int) *Tensor { + size := 1 + for i := range shape { + size *= shape[i] + } + if len(data) != size { + panic(fmt.Sprintf("shape %v does not match data size %v", shape, len(data))) + } + return &Tensor{ + data: data, + shape: shape, + } +} + +func NewVariable(data []float32, shape ...int) *Tensor { + return NewTensor(data, shape...).RequireGrad() +} + +func NewScalar(data float32) *Tensor { + return &Tensor{ + data: []float32{data}, + shape: []int{}, + } +} + +func LinSpace(start, end float32, shape ...int) *Tensor { + n := 1 + for _, s := range shape { + n *= s + } + data := make([]float32, n) + delta := (end - start) / float32(n-1) + for i := range data { + data[i] = start + delta*float32(i) + } + return &Tensor{ + data: data, + shape: shape, + } +} + +func RandN(shape ...int) *Tensor { + n := 1 + for _, s := range shape { + n *= s + } + data := make([]float32, n) + for i := range data { + data[i] = rand.Float32() + } + return &Tensor{ + data: data, + shape: shape, + } +} + +// Ones creates a tensor filled with ones. +func Ones(shape ...int) *Tensor { + n := 1 + for _, s := range shape { + n *= s + } + data := make([]float32, n) + for i := range data { + data[i] = 1 + } + return &Tensor{ + data: data, + shape: shape, + } +} + +// Zeros creates a tensor filled with zeros. +func Zeros(shape ...int) *Tensor { + n := 1 + for _, s := range shape { + n *= s + } + data := make([]float32, n) + return &Tensor{ + data: data, + shape: shape, + } +} + +func (t *Tensor) IsScalar() bool { + return len(t.shape) == 0 +} + +// NoGrad creates a tensor does not require gradient. +func (t *Tensor) NoGrad() *Tensor { + if t.op != nil { + t.op = nil + } + return t +} + +func (t *Tensor) RequireGrad() *Tensor { + t.requireGrad = true + t.id = uuid.New() + return t +} + +func (t *Tensor) Shape() []int { + return t.shape +} + +// Slice returns a slice of the tensor. +func (t *Tensor) Slice(start, end int) *Tensor { + if len(t.shape) < 1 { + panic("slice requires at least 1-D tensor") + } + if start < 0 || end > t.shape[0] { + panic("slice out of range") + } + subSize := 1 + for i := 1; i < len(t.shape); i++ { + subSize *= t.shape[i] + } + return &Tensor{ + data: t.data[start*subSize : end*subSize], + shape: append([]int{end - start}, t.shape[1:]...), + } +} + +// Get returns the value of the tensor at the given indices. +func (t *Tensor) Get(indices ...int) float32 { + if len(indices) != len(t.shape) { + panic("the number of indices does not match the shape of the tensor") + } + index := 0 + for i := range indices { + if indices[i] < 0 || indices[i] >= t.shape[i] { + panic("index out of range") + } + index = index*t.shape[i] + indices[i] + } + return t.data[index] +} + +func (t *Tensor) String() string { + // Print scalar value + if len(t.shape) == 0 { + return fmt.Sprint(t.data[0]) + } + + builder := strings.Builder{} + builder.WriteString("[") + if len(t.data) <= 10 { + for i := 0; i < len(t.data); i++ { + builder.WriteString(fmt.Sprint(t.data[i])) + if i != len(t.data)-1 { + builder.WriteString(", ") + } + } + } else { + for i := 0; i < 5; i++ { + builder.WriteString(fmt.Sprint(t.data[i])) + builder.WriteString(", ") + } + builder.WriteString("..., ") + for i := len(t.data) - 5; i < len(t.data); i++ { + builder.WriteString(fmt.Sprint(t.data[i])) + if i != len(t.data)-1 { + builder.WriteString(", ") + } + } + } + builder.WriteString("]") + return builder.String() +} + +func (t *Tensor) Backward() { + t.grad = Ones(t.shape...) + ops := []op{t.op} + for len(ops) > 0 { + op := ops[0] + ops = ops[1:] + inputs, output := op.inputsAndOutput() + grads := op.backward(output.grad) + // Clear gradient of non-leaf tensor + //output.grad = nil + for i := range grads { + if !slices.Equal(inputs[i].shape, grads[i].shape) { + panic(fmt.Sprintf("%s: shape %v does not match shape %v", op.String(), inputs[i].shape, grads[i].shape)) + } + if inputs[i].grad == nil { + inputs[i].grad = grads[i] + } else { + inputs[i].grad.add(grads[i]) + } + if inputs[i].op != nil { + ops = append(ops, inputs[i].op) + } else if !inputs[i].requireGrad { + // Clear gradient if the leaf tensor does not require gradient + //inputs[i].grad = nil + } + } + } +} + +func (t *Tensor) Grad() *Tensor { + return t.grad +} + +func (t *Tensor) Data() []float32 { + return t.data +} + +func (t *Tensor) clone() *Tensor { + newData := make([]float32, len(t.data)) + copy(newData, t.data) + return &Tensor{ + data: newData, + shape: t.shape, + } +} + +func (t *Tensor) add(other *Tensor) *Tensor { + wSize := 1 + for i := range other.shape { + wSize *= other.shape[i] + } + if wSize == 1 { + floats.AddConst(t.data, other.data[0]) + } else { + for i := 0; i < len(t.data); i += wSize { + floats.Add(t.data[i:i+wSize], other.data) + } + } + return t +} + +func (t *Tensor) sub(other *Tensor) *Tensor { + wSize := 1 + for i := range other.shape { + wSize *= other.shape[i] + } + for i := range t.data { + t.data[i] -= other.data[i%wSize] + } + return t +} + +func (t *Tensor) mul(other *Tensor) *Tensor { + wSize := 1 + for i := range other.shape { + wSize *= other.shape[i] + } + for i := range t.data { + t.data[i] *= other.data[i%wSize] + } + return t +} + +func (t *Tensor) div(other *Tensor) *Tensor { + wSize := 1 + for i := range other.shape { + wSize *= other.shape[i] + } + for i := range t.data { + t.data[i] /= other.data[i%wSize] + } + return t +} + +func (t *Tensor) square() *Tensor { + for i := range t.data { + t.data[i] = t.data[i] * t.data[i] + } + return t +} + +func (t *Tensor) pow(other *Tensor) *Tensor { + wSize := 1 + for i := range other.shape { + wSize *= other.shape[i] + } + for i := range t.data { + t.data[i] = math32.Pow(t.data[i], other.data[i%wSize]) + } + return t +} + +func (t *Tensor) exp() *Tensor { + for i := range t.data { + t.data[i] = math32.Exp(t.data[i]) + } + return t +} + +func (t *Tensor) log() *Tensor { + for i := range t.data { + t.data[i] = math32.Log(t.data[i]) + } + return t +} + +func (t *Tensor) sin() *Tensor { + for i := range t.data { + t.data[i] = math32.Sin(t.data[i]) + } + return t +} + +func (t *Tensor) cos() *Tensor { + for i := range t.data { + t.data[i] = math32.Cos(t.data[i]) + } + return t +} + +func (t *Tensor) tanh() *Tensor { + for i := range t.data { + t.data[i] = math32.Tanh(t.data[i]) + } + return t +} + +func (t *Tensor) neg() *Tensor { + for i := range t.data { + t.data[i] = -t.data[i] + } + return t +} + +func (t *Tensor) matMul(other *Tensor, transpose1, transpose2 bool) *Tensor { + if !transpose1 && !transpose2 { + if len(t.shape) != 2 || len(other.shape) != 2 { + panic("matMul requires 2-D tensors") + } + if t.shape[1] != other.shape[0] { + panic(fmt.Sprintf("matMul requires the shapes of tensors are compatible, but got %v and %v", t.shape, other.shape)) + } + m, n, p := t.shape[0], t.shape[1], other.shape[1] + result := make([]float32, m*p) + for i := 0; i < m; i++ { + for j, aij := range t.data[i*n : (i+1)*n] { + // C_j += A_{ij} * B_i + floats.MulConstAddTo(other.data[j*p:(j+1)*p], aij, result[i*p:(i+1)*p]) + } + } + return &Tensor{ + data: result, + shape: []int{m, p}, + } + } else if transpose1 && !transpose2 { + if len(t.shape) != 2 || len(other.shape) != 2 { + panic("matMul requires 2-D tensors") + } + if t.shape[0] != other.shape[0] { + panic(fmt.Sprintf("matMul requires the shapes of tensors are compatible, but got %v and %v", t.shape, other.shape)) + } + m, n, p := t.shape[1], t.shape[0], other.shape[1] + result := make([]float32, m*p) + for i := 0; i < m; i++ { + for j := 0; j < n; j++ { + // C_j += A_{ji} * B_i + floats.MulConstAddTo(other.data[j*p:(j+1)*p], t.data[j*m+i], result[i*p:(i+1)*p]) + } + } + return &Tensor{ + data: result, + shape: []int{m, p}, + } + } else if !transpose1 && transpose2 { + if len(t.shape) != 2 || len(other.shape) != 2 { + panic("matMul requires 2-D tensors") + } + if t.shape[1] != other.shape[1] { + panic(fmt.Sprintf("matMul requires the shapes of tensors are compatible, but got %v and %v", t.shape, other.shape)) + } + m, n, p := t.shape[0], t.shape[1], other.shape[0] + result := make([]float32, m*p) + for i := 0; i < m; i++ { + for j := 0; j < p; j++ { + result[i*p+j] = floats.Dot(t.data[i*n:(i+1)*n], other.data[j*n:(j+1)*n]) + } + } + return &Tensor{ + data: result, + shape: []int{m, p}, + } + } else { + // (n,m).T @ (p,n).T = (m,p) + if len(t.shape) != 2 || len(other.shape) != 2 { + panic("matMul requires 2-D tensors") + } + if t.shape[0] != other.shape[1] { + panic(fmt.Sprintf("matMul requires the shapes of tensors are compatible, but got %v and %v", t.shape, other.shape)) + } + m, n, p := t.shape[1], t.shape[0], other.shape[0] + result := make([]float32, m*p) + for i := 0; i < m; i++ { + for j := 0; j < p; j++ { + for k := 0; k < n; k++ { + result[i*p+j] += t.data[k*m+i] * other.data[j*n+k] + } + } + } + return &Tensor{ + data: result, + shape: []int{m, p}, + } + } +} + +func (t *Tensor) batchMatMul(other *Tensor, transpose1, transpose2 bool) *Tensor { + if !transpose1 && !transpose2 { + if len(t.shape) != 3 || len(other.shape) != 3 { + panic("BatchMatMul requires 3-D tensors") + } + if t.shape[0] != other.shape[0] || t.shape[2] != other.shape[1] { + panic("BatchMatMul requires the shapes of tensors are compatible") + } + batches, m, n, p := t.shape[0], t.shape[1], t.shape[2], other.shape[2] + result := make([]float32, batches*m*p) + for b := 0; b < batches; b++ { + for i := 0; i < m; i++ { + for j := 0; j < n; j++ { + // C_{bj} += A_{bij} * B_{bi} + floats.MulConstAddTo(other.data[b*n*p+j*p:b*n*p+(j+1)*p], t.data[b*m*n+i*n+j], result[b*m*p+i*p:b*m*p+(i+1)*p]) + } + } + } + return &Tensor{ + data: result, + shape: []int{batches, m, p}, + } + } else if transpose1 && !transpose2 { + if len(t.shape) != 3 || len(other.shape) != 3 { + panic("batchMatMul requires 3-D tensors") + } + if t.shape[0] != other.shape[0] || t.shape[1] != other.shape[1] { + panic("batchMatMul requires the shapes of tensors are compatible") + } + batches, m, n, p := t.shape[0], t.shape[2], t.shape[1], other.shape[2] + result := make([]float32, batches*m*p) + for b := 0; b < batches; b++ { + for i := 0; i < m; i++ { + for j := 0; j < n; j++ { + floats.MulConstAddTo(other.data[b*n*p+j*p:b*n*p+(j+1)*p], t.data[b*n*m+j*m+i], result[b*m*p+i*p:b*m*p+(i+1)*p]) + } + } + } + return &Tensor{ + data: result, + shape: []int{batches, m, p}, + } + } else if !transpose1 && transpose2 { + if len(t.shape) != 3 || len(other.shape) != 3 { + panic("batchMatMul requires 3-D tensors") + } + if t.shape[0] != other.shape[0] || t.shape[2] != other.shape[2] { + panic("batchMatMul requires the shapes of tensors are compatible") + } + batches, m, n, p := t.shape[0], t.shape[1], t.shape[2], other.shape[1] + result := make([]float32, batches*m*p) + for b := 0; b < batches; b++ { + for i := 0; i < m; i++ { + for j := 0; j < p; j++ { + result[b*m*p+i*p+j] = floats.Dot(t.data[b*m*n+i*n:b*m*n+(i+1)*n], + other.data[b*p*n+j*n:b*p*n+(j+1)*n]) + } + } + } + return &Tensor{ + data: result, + shape: []int{batches, m, p}, + } + } else { + // (b,n,m).T @ (b,p,n).T = (b,m,p) + if len(t.shape) != 3 || len(other.shape) != 3 { + panic("batchMatMul requires 3-D tensors") + } + if t.shape[0] != other.shape[0] || t.shape[1] != other.shape[2] { + panic("batchMatMul requires the shapes of tensors are compatible") + } + batches, m, n, p := t.shape[0], t.shape[2], t.shape[1], other.shape[1] + result := make([]float32, m*n*p) + for b := 0; b < batches; b++ { + for i := 0; i < m; i++ { + for j := 0; j < n; j++ { + for k := 0; k < p; k++ { + result[i*n*p+j*p+k] += t.data[b*m*n+j*m+i] * other.data[b*p*n+k*n+j] + } + } + } + } + return &Tensor{ + data: result, + shape: []int{batches, m, p}, + } + } +} + +func (t *Tensor) maximum(other *Tensor) { + if other.IsScalar() { + for i := range t.data { + t.data[i] = max(t.data[i], other.data[0]) + } + } else { + for i := range t.data { + t.data[i] = max(t.data[i], other.data[i]) + } + } +} + +func (t *Tensor) transpose() *Tensor { + if len(t.shape) < 2 { + panic("transpose requires at least 2-D tensor") + } + shape := make([]int, 0, len(t.shape)) + batchSize := 0 + for i := 0; i < len(t.shape)-2; i++ { + batchSize += t.shape[i] + shape = append(shape, t.shape[i]) + } + m, n := t.shape[len(t.shape)-2], t.shape[len(t.shape)-1] + shape = append(shape, n, m) + data := make([]float32, batchSize*m*n) + for b := 0; b < batchSize; b++ { + for i := 0; i < m; i++ { + for j := 0; j < n; j++ { + data[b*m*n+j*m+i] = t.data[b*m*n+i*n+j] + } + } + } + return &Tensor{ + data: data, + shape: shape, + } +} diff --git a/common/nn/tensor_test.go b/common/nn/tensor_test.go new file mode 100644 index 000000000..acb02a6ac --- /dev/null +++ b/common/nn/tensor_test.go @@ -0,0 +1,266 @@ +// Copyright 2024 gorse Project Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package nn + +import ( + "fmt" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestTensor_Slice(t *testing.T) { + x := RandN(3, 4, 5) + y := x.Slice(1, 3) + assert.Equal(t, []int{2, 4, 5}, y.Shape()) + for i := 0; i < 2; i++ { + for j := 0; j < 4; j++ { + for k := 0; k < 5; k++ { + assert.Equal(t, x.Get(i+1, j, k), y.Get(i, j, k)) + } + } + } +} + +func (t *Tensor) matMulLegacy(other *Tensor, transpose1, transpose2 bool) *Tensor { + if !transpose1 && !transpose2 { + if len(t.shape) != 2 || len(other.shape) != 2 { + panic("matMul requires 2-D tensors") + } + if t.shape[1] != other.shape[0] { + panic("matMul requires the shapes of tensors are compatible") + } + m, n, p := t.shape[0], t.shape[1], other.shape[1] + result := make([]float32, m*p) + for i := 0; i < m; i++ { + for j := 0; j < p; j++ { + for k := 0; k < n; k++ { + result[i*p+j] += t.data[i*n+k] * other.data[k*p+j] + } + } + } + return &Tensor{ + data: result, + shape: []int{m, p}, + } + } else if transpose1 && !transpose2 { + if len(t.shape) != 2 || len(other.shape) != 2 { + panic("matMul requires 2-D tensors") + } + if t.shape[0] != other.shape[0] { + panic("matMul requires the shapes of tensors are compatible") + } + m, n, p := t.shape[1], t.shape[0], other.shape[1] + result := make([]float32, m*p) + for i := 0; i < m; i++ { + for j := 0; j < p; j++ { + for k := 0; k < n; k++ { + result[i*p+j] += t.data[k*m+i] * other.data[k*p+j] + } + } + } + return &Tensor{ + data: result, + shape: []int{m, p}, + } + } else if !transpose1 && transpose2 { + if len(t.shape) != 2 || len(other.shape) != 2 { + panic("matMul requires 2-D tensors") + } + if t.shape[1] != other.shape[1] { + panic("matMul requires the shapes of tensors are compatible") + } + m, n, p := t.shape[0], t.shape[1], other.shape[0] + result := make([]float32, m*p) + for i := 0; i < m; i++ { + for j := 0; j < p; j++ { + for k := 0; k < n; k++ { + result[i*p+j] += t.data[i*n+k] * other.data[j*n+k] + } + } + } + return &Tensor{ + data: result, + shape: []int{m, p}, + } + } else { + if len(t.shape) != 2 || len(other.shape) != 2 { + panic("matMul requires 2-D tensors") + } + if t.shape[0] != other.shape[0] { + panic("matMul requires the shapes of tensors are compatible") + } + m, n, p := t.shape[1], t.shape[0], other.shape[1] + result := make([]float32, m*p) + for i := 0; i < m; i++ { + for j := 0; j < p; j++ { + for k := 0; k < n; k++ { + result[i*p+j] += t.data[k*m+i] * other.data[j*n+k] + } + } + } + return &Tensor{ + data: result, + shape: []int{m, p}, + } + } +} + +func (t *Tensor) batchMatMulLegacy(other *Tensor, transpose1, transpose2 bool) *Tensor { + if !transpose1 && !transpose2 { + if len(t.shape) != 3 || len(other.shape) != 3 { + panic("BatchMatMul requires 3-D tensors") + } + if t.shape[0] != other.shape[0] || t.shape[2] != other.shape[1] { + panic("BatchMatMul requires the shapes of tensors are compatible") + } + m, n, p := t.shape[0], t.shape[1], other.shape[2] + result := make([]float32, m*n*p) + for i := 0; i < m; i++ { + for j := 0; j < n; j++ { + for k := 0; k < p; k++ { + for l := 0; l < t.shape[2]; l++ { + result[i*n*p+j*p+k] += t.data[i*n*t.shape[2]+j*t.shape[2]+l] * other.data[i*other.shape[1]*other.shape[2]+l*other.shape[2]+k] + } + } + } + } + return &Tensor{ + data: result, + shape: []int{m, n, p}, + } + } else if transpose1 && !transpose2 { + if len(t.shape) != 3 || len(other.shape) != 3 { + panic("batchMatMul requires 3-D tensors") + } + if t.shape[0] != other.shape[0] || t.shape[1] != other.shape[1] { + panic("batchMatMul requires the shapes of tensors are compatible") + } + m, n, p := t.shape[0], t.shape[2], other.shape[2] + result := make([]float32, m*n*p) + for i := 0; i < m; i++ { + for j := 0; j < n; j++ { + for k := 0; k < p; k++ { + for l := 0; l < t.shape[1]; l++ { + result[i*n*p+j*p+k] += t.data[i*t.shape[1]*t.shape[2]+l*t.shape[2]+j] * other.data[i*other.shape[1]*other.shape[2]+l*other.shape[2]+k] + } + } + } + } + return &Tensor{ + data: result, + shape: []int{m, n, p}, + } + } else if !transpose1 && transpose2 { + if len(t.shape) != 3 || len(other.shape) != 3 { + panic("batchMatMul requires 3-D tensors") + } + if t.shape[0] != other.shape[0] || t.shape[2] != other.shape[2] { + panic("batchMatMul requires the shapes of tensors are compatible") + } + m, n, p := t.shape[0], t.shape[1], other.shape[1] + result := make([]float32, m*n*p) + for i := 0; i < m; i++ { + for j := 0; j < n; j++ { + for k := 0; k < p; k++ { + for l := 0; l < t.shape[2]; l++ { + result[i*n*p+j*p+k] += t.data[i*n*t.shape[2]+j*t.shape[2]+l] * other.data[i*other.shape[1]*other.shape[2]+k*other.shape[2]+l] + } + } + } + } + return &Tensor{ + data: result, + shape: []int{m, n, p}, + } + } else { + if len(t.shape) != 3 || len(other.shape) != 3 { + panic("batchMatMul requires 3-D tensors") + } + if t.shape[0] != other.shape[0] || t.shape[2] != other.shape[2] { + panic("batchMatMul requires the shapes of tensors are compatible") + } + m, n, p := t.shape[1], t.shape[2], other.shape[2] + result := make([]float32, m*n*p) + for i := 0; i < m; i++ { + for j := 0; j < n; j++ { + for k := 0; k < p; k++ { + for l := 0; l < t.shape[0]; l++ { + result[i*n*p+j*p+k] += t.data[l*t.shape[1]*t.shape[2]+i*t.shape[2]+j] * other.data[l*other.shape[1]*other.shape[2]+j*other.shape[2]+k] + } + } + } + } + return &Tensor{ + data: result, + shape: []int{m, n, p}, + } + } +} + +func BenchmarkMatMulLegacy64(b *testing.B) { + x := RandN(64, 64) + y := RandN(64, 64) + for t1 := 0; t1 < 2; t1++ { + for t2 := 0; t2 < 2; t2++ { + b.Run(fmt.Sprintf("(%d,%d)", t1, t2), func(b *testing.B) { + for i := 0; i < b.N; i++ { + x.matMulLegacy(y, t1 == 1, t2 == 1) + } + }) + } + } +} + +func BenchmarkMatMul64(b *testing.B) { + x := RandN(64, 64) + y := RandN(64, 64) + for t1 := 0; t1 < 2; t1++ { + for t2 := 0; t2 < 2; t2++ { + b.Run(fmt.Sprintf("(%d,%d)", t1, t2), func(b *testing.B) { + for i := 0; i < b.N; i++ { + x.matMul(y, t1 == 1, t2 == 1) + } + }) + } + } +} + +func BenchmarkBatchMatMulLegacy64(b *testing.B) { + x := RandN(64, 64, 64) + y := RandN(64, 64, 64) + for t1 := 0; t1 < 2; t1++ { + for t2 := 0; t2 < 2; t2++ { + b.Run(fmt.Sprintf("(%d,%d)", t1, t2), func(b *testing.B) { + for i := 0; i < b.N; i++ { + x.batchMatMulLegacy(y, t1 == 1, t2 == 1) + } + }) + } + } +} + +func BenchmarkBatchMatMul64(b *testing.B) { + x := RandN(64, 64, 64) + y := RandN(64, 64, 64) + for t1 := 0; t1 < 2; t1++ { + for t2 := 0; t2 < 2; t2++ { + b.Run(fmt.Sprintf("(%d,%d)", t1, t2), func(b *testing.B) { + for i := 0; i < b.N; i++ { + x.batchMatMul(y, t1 == 1, t2 == 1) + } + }) + } + } +} diff --git a/common/util/strconv.go b/common/util/strconv.go new file mode 100644 index 000000000..7d60af99f --- /dev/null +++ b/common/util/strconv.go @@ -0,0 +1,8 @@ +package util + +import "strconv" + +func ParseFloat32(s string) (float32, error) { + v, err := strconv.ParseFloat(s, 32) + return float32(v), err +} diff --git a/model/click/deepfm_v2.go b/model/click/deepfm_v2.go new file mode 100644 index 000000000..d2f039c34 --- /dev/null +++ b/model/click/deepfm_v2.go @@ -0,0 +1,389 @@ +// Copyright 2023 gorse Project Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package click + +import ( + "bytes" + "context" + "fmt" + "github.com/chewxy/math32" + "github.com/juju/errors" + "github.com/samber/lo" + "github.com/zhenghaoz/gorse/base" + "github.com/zhenghaoz/gorse/base/encoding" + "github.com/zhenghaoz/gorse/base/log" + "github.com/zhenghaoz/gorse/common/nn" + "github.com/zhenghaoz/gorse/model" + "go.uber.org/zap" + "io" + "modernc.org/mathutil" + "runtime" + "sync" + "time" +) + +type DeepFMV2 struct { + BaseFactorizationMachine + + // runtime + numCPU int + mu sync.RWMutex + + // dataset stats + minTarget float32 + maxTarget float32 + numFeatures int + numDimension int + + // tuned parameters + v [][]float32 + w []float32 + w0 [][]float32 + bData []float32 + b0Data []float32 + w1Data [][]float32 + b1Data [][]float32 + marshables []any + + // params and layers + bias *nn.Tensor + embeddingW nn.Layer + embeddingV nn.Layer + linear []nn.Layer + + // Adam optimizer variables + m_v [][]float32 + m_w []float32 + m_w0 [][]float32 + v_v [][]float32 + v_w []float32 + v_w0 [][]float32 + t int + + // preallocated arrays + dataV []float32 + dataW []float32 + dataW0 []float32 + + // Hyper parameters + batchSize int + nFactors int + nEpochs int + lr float32 + reg float32 + initMean float32 + initStdDev float32 + hiddenLayers []int +} + +func NewDeepFMV2(params model.Params) *DeepFMV2 { + fm := new(DeepFMV2) + fm.SetParams(params) + fm.numCPU = runtime.NumCPU() + fm.marshables = []any{&fm.v, &fm.w, &fm.w0, &fm.bData, &fm.b0Data, &fm.w1Data, &fm.b1Data} + return fm +} + +func (fm *DeepFMV2) Clear() { + fm.Index = nil +} + +func (fm *DeepFMV2) Invalid() bool { + return fm == nil || + fm.Index == nil +} + +func (fm *DeepFMV2) SetParams(params model.Params) { + fm.BaseFactorizationMachine.SetParams(params) + fm.batchSize = fm.Params.GetInt(model.BatchSize, 1024) + fm.nFactors = fm.Params.GetInt(model.NFactors, 16) + fm.nEpochs = fm.Params.GetInt(model.NEpochs, 50) + fm.lr = fm.Params.GetFloat32(model.Lr, 0.001) + fm.reg = fm.Params.GetFloat32(model.Reg, 0.0) + fm.initMean = fm.Params.GetFloat32(model.InitMean, 0) + fm.initStdDev = fm.Params.GetFloat32(model.InitStdDev, 0.01) + fm.hiddenLayers = fm.Params.GetIntSlice(model.HiddenLayers, []int{200, 200}) +} + +func (fm *DeepFMV2) GetParamsGrid(withSize bool) model.ParamsGrid { + return model.ParamsGrid{ + model.NFactors: lo.If(withSize, []interface{}{8, 16, 32, 64}).Else([]interface{}{16}), + model.Lr: []interface{}{0.001, 0.005, 0.01, 0.05, 0.1}, + model.Reg: []interface{}{0.001, 0.005, 0.01, 0.05, 0.1}, + model.InitMean: []interface{}{0}, + model.InitStdDev: []interface{}{0.001, 0.005, 0.01, 0.05, 0.1}, + } +} + +func (fm *DeepFMV2) Predict(userId, itemId string, userFeatures, itemFeatures []Feature) float32 { + panic("Predict is unsupported for deep learning models") +} + +func (fm *DeepFMV2) InternalPredict(indices []int32, values []float32) float32 { + panic("InternalPredict is unsupported for deep learning models") +} + +func (fm *DeepFMV2) BatchInternalPredict(x []lo.Tuple2[[]int32, []float32]) []float32 { + fm.mu.RLock() + defer fm.mu.RUnlock() + indicesTensor, valuesTensor, _ := fm.convertToTensors(x, nil) + predictions := make([]float32, 0, len(x)) + for i := 0; i < len(x); i += fm.batchSize { + output := fm.Forward( + indicesTensor.Slice(i, i+fm.batchSize), + valuesTensor.Slice(i, i+fm.batchSize)) + predictions = append(predictions, output.Data()...) + } + return predictions[:len(x)] +} + +func (fm *DeepFMV2) BatchPredict(inputs []lo.Tuple4[string, string, []Feature, []Feature]) []float32 { + x := make([]lo.Tuple2[[]int32, []float32], len(inputs)) + for i, input := range inputs { + // encode user + if userIndex := fm.Index.EncodeUser(input.A); userIndex != base.NotId { + x[i].A = append(x[i].A, userIndex) + x[i].B = append(x[i].B, 1) + } + // encode item + if itemIndex := fm.Index.EncodeItem(input.B); itemIndex != base.NotId { + x[i].A = append(x[i].A, itemIndex) + x[i].B = append(x[i].B, 1) + } + // encode user labels + for _, userFeature := range input.C { + if userFeatureIndex := fm.Index.EncodeUserLabel(userFeature.Name); userFeatureIndex != base.NotId { + x[i].A = append(x[i].A, userFeatureIndex) + x[i].B = append(x[i].B, userFeature.Value) + } + } + // encode item labels + for _, itemFeature := range input.D { + if itemFeatureIndex := fm.Index.EncodeItemLabel(itemFeature.Name); itemFeatureIndex != base.NotId { + x[i].A = append(x[i].A, itemFeatureIndex) + x[i].B = append(x[i].B, itemFeature.Value) + } + } + } + return fm.BatchInternalPredict(x) +} + +func (fm *DeepFMV2) Fit(ctx context.Context, trainSet *Dataset, testSet *Dataset, config *FitConfig) Score { + fm.Init(trainSet) + evalStart := time.Now() + score := EvaluateClassification(fm, testSet) + evalTime := time.Since(evalStart) + fields := append([]zap.Field{zap.String("eval_time", evalTime.String())}, score.ZapFields()...) + log.Logger().Info(fmt.Sprintf("fit DeepFM %v/%v", 0, fm.nEpochs), fields...) + + var x []lo.Tuple2[[]int32, []float32] + var y []float32 + for i := 0; i < trainSet.Target.Len(); i++ { + fm.minTarget = math32.Min(fm.minTarget, trainSet.Target.Get(i)) + fm.maxTarget = math32.Max(fm.maxTarget, trainSet.Target.Get(i)) + indices, values, target := trainSet.Get(i) + x = append(x, lo.Tuple2[[]int32, []float32]{A: indices, B: values}) + y = append(y, target) + } + indices, values, target := fm.convertToTensors(x, y) + + optimizer := nn.NewAdam(fm.Parameters(), fm.lr) + for epoch := 1; epoch <= fm.nEpochs; epoch++ { + fitStart := time.Now() + cost := float32(0) + for i := 0; i < trainSet.Count(); i += fm.batchSize { + batchIndices := indices.Slice(i, i+fm.batchSize) + batchValues := values.Slice(i, i+fm.batchSize) + batchTarget := target.Slice(i, i+fm.batchSize) + batchOutput := fm.Forward(batchIndices, batchValues) + batchLoss := nn.BCEWithLogits(batchTarget, batchOutput) + cost += batchLoss.Data()[0] + optimizer.ZeroGrad() + batchLoss.Backward() + optimizer.Step() + } + + fitTime := time.Since(fitStart) + // Cross validation + if epoch%config.Verbose == 0 || epoch == fm.nEpochs { + evalStart = time.Now() + score = EvaluateClassification(fm, testSet) + evalTime = time.Since(evalStart) + fields = append([]zap.Field{ + zap.String("fit_time", fitTime.String()), + zap.String("eval_time", evalTime.String()), + zap.Float32("loss", cost), + }, score.ZapFields()...) + log.Logger().Info(fmt.Sprintf("fit DeepFM %v/%v", epoch, fm.nEpochs), fields...) + // check NaN + if math32.IsNaN(cost) || math32.IsNaN(score.GetValue()) { + log.Logger().Warn("model diverged", zap.Float32("lr", fm.lr)) + break + } + } + } + return score +} + +// Init parameters for DeepFM. +func (fm *DeepFMV2) Init(trainSet *Dataset) { + fm.numFeatures = trainSet.ItemCount() + trainSet.UserCount() + len(trainSet.UserFeatures) + len(trainSet.ItemFeatures) + len(trainSet.ContextFeatures) + fm.numDimension = 0 + for i := 0; i < trainSet.Count(); i++ { + _, x, _ := trainSet.Get(i) + fm.numDimension = mathutil.MaxVal(fm.numDimension, len(x)) + } + fm.bias = nn.RandN() + fm.embeddingW = nn.NewEmbedding(fm.numFeatures, 1) + fm.embeddingV = nn.NewEmbedding(fm.numFeatures, fm.nFactors) + fm.linear = []nn.Layer{nn.NewLinear(fm.numDimension*fm.nFactors, fm.hiddenLayers[0])} + for i := 0; i < len(fm.hiddenLayers); i++ { + if i < len(fm.hiddenLayers)-1 { + fm.linear = append(fm.linear, nn.NewLinear(fm.hiddenLayers[i], fm.hiddenLayers[i+1])) + } else { + fm.linear = append(fm.linear, nn.NewLinear(fm.hiddenLayers[i], 1)) + } + } + fm.BaseFactorizationMachine.Init(trainSet) +} + +func (fm *DeepFMV2) Marshal(w io.Writer) error { + // write params + if err := encoding.WriteGob(w, fm.Params); err != nil { + return errors.Trace(err) + } + // write index + if err := MarshalIndex(w, fm.Index); err != nil { + return errors.Trace(err) + } + // write dataset stats + if err := encoding.WriteGob(w, fm.minTarget); err != nil { + return errors.Trace(err) + } + if err := encoding.WriteGob(w, fm.maxTarget); err != nil { + return errors.Trace(err) + } + if err := encoding.WriteGob(w, fm.numFeatures); err != nil { + return errors.Trace(err) + } + if err := encoding.WriteGob(w, fm.numDimension); err != nil { + return errors.Trace(err) + } + // write weights + for _, data := range fm.marshables { + if err := encoding.WriteGob(w, data); err != nil { + return errors.Trace(err) + } + } + return nil +} + +func (fm *DeepFMV2) Unmarshal(r io.Reader) error { + return nil +} + +func (fm *DeepFMV2) Forward(indices, values *nn.Tensor) *nn.Tensor { + // embedding + e := fm.embeddingV.Forward(indices) + + // factorization machine + x := nn.Reshape(values, fm.batchSize, fm.numDimension, 1) + vx := nn.BMM(e, x, true) + sumSquare := nn.Square(vx) + e2 := nn.Square(e) + x2 := nn.Square(x) + squareSum := nn.BMM(e2, x2, true) + sum := nn.Sub(sumSquare, squareSum) + sum = nn.Sum(sum, 1) + sum = nn.Mul(sum, nn.NewScalar(0.5)) + w := fm.embeddingW.Forward(indices) + linear := nn.BMM(w, x, true) + fmOutput := nn.Add(linear, fm.bias) + fmOutput = nn.Flatten(fmOutput) + + // deep network + a := nn.Reshape(e, fm.batchSize, fm.numDimension*fm.nFactors) + for i := 0; i < len(fm.linear); i++ { + a = fm.linear[i].Forward(a) + if i < len(fm.linear)-1 { + a = nn.ReLu(a) + } else { + a = nn.Sigmoid(a) + } + } + dnnOutput := nn.Flatten(a) + + // output + return nn.Add(fmOutput, dnnOutput) +} + +func (fm *DeepFMV2) Parameters() []*nn.Tensor { + var params []*nn.Tensor + params = append(params, fm.bias) + params = append(params, fm.embeddingV.Parameters()...) + params = append(params, fm.embeddingW.Parameters()...) + for _, layer := range fm.linear { + params = append(params, layer.Parameters()...) + } + return params +} + +func (fm *DeepFMV2) convertToTensors(x []lo.Tuple2[[]int32, []float32], y []float32) (indicesTensor, valuesTensor, targetTensor *nn.Tensor) { + if y != nil && len(x) != len(y) { + panic("length of x and y must be equal") + } + + numBatch := (len(x) + fm.batchSize - 1) / fm.batchSize + alignedSize := numBatch * fm.batchSize + alignedIndices := make([]float32, alignedSize*fm.numDimension) + alignedValues := make([]float32, alignedSize*fm.numDimension) + alignedTarget := make([]float32, alignedSize) + for i := range x { + if len(x[i].A) != len(x[i].B) { + panic("length of indices and values must be equal") + } + for j := range x[i].A { + alignedIndices[i*fm.numDimension+j] = float32(x[i].A[j]) + alignedValues[i*fm.numDimension+j] = x[i].B[j] + } + if y != nil { + alignedTarget[i] = y[i] + } + } + + indicesTensor = nn.NewTensor(alignedIndices, alignedSize, fm.numDimension) + valuesTensor = nn.NewTensor(alignedValues, alignedSize, fm.numDimension) + if y != nil { + targetTensor = nn.NewTensor(alignedTarget, alignedSize) + } + return +} + +func (fm *DeepFMV2) Clone() FactorizationMachine { + buf := bytes.NewBuffer(nil) + if err := MarshalModel(buf, fm); err != nil { + panic(err) + } + if copied, err := UnmarshalModel(buf); err != nil { + panic(err) + } else { + copied.SetParams(copied.GetParams()) + return copied + } +} + +func (fm *DeepFMV2) Spawn() FactorizationMachine { + return fm.Clone() +} diff --git a/model/click/deepfm_v2_test.go b/model/click/deepfm_v2_test.go new file mode 100644 index 000000000..bafcc093c --- /dev/null +++ b/model/click/deepfm_v2_test.go @@ -0,0 +1,88 @@ +// Copyright 2023 gorse Project Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package click + +import ( + "bytes" + "context" + "testing" + + "github.com/samber/lo" + "github.com/stretchr/testify/assert" + "github.com/zhenghaoz/gorse/model" +) + +func TestDeepFMV2_Classification_Frappe(t *testing.T) { + t.Skip() + train, test, err := LoadDataFromBuiltIn("frappe") + assert.NoError(t, err) + m := NewDeepFMV2(model.Params{ + model.InitStdDev: 0.01, + model.NFactors: 8, + model.NEpochs: 10, + model.Lr: 0.01, + model.Reg: 0.0001, + model.BatchSize: 1024, + }) + fitConfig := newFitConfigWithTestTracker(20) + score := m.Fit(context.Background(), train, test, fitConfig) + //assert.InDelta(t, 0.9439709, score.Accuracy, classificationDelta) + _ = score +} + +func TestDeepFMV2_Classification_Criteo(t *testing.T) { + t.Skip() + train, test, err := LoadDataFromBuiltIn("criteo") + assert.NoError(t, err) + m := NewDeepFM(model.Params{ + model.InitStdDev: 0.01, + model.NFactors: 8, + model.NEpochs: 10, + model.Lr: 0.01, + model.Reg: 0.0001, + model.BatchSize: 1024, + }) + fitConfig := newFitConfigWithTestTracker(10) + score := m.Fit(context.Background(), train, test, fitConfig) + assert.InDelta(t, 0.77, score.Accuracy, classificationDelta) + + // test prediction + assert.Equal(t, m.BatchInternalPredict([]lo.Tuple2[[]int32, []float32]{{A: []int32{1, 2, 3, 4, 5, 6}, B: []float32{1, 1, 0.3, 0.4, 0.5, 0.6}}}), + m.BatchPredict([]lo.Tuple4[string, string, []Feature, []Feature]{{ + A: "1", + B: "2", + C: []Feature{ + {Name: "3", Value: 0.3}, + {Name: "4", Value: 0.4}, + }, + D: []Feature{ + {Name: "5", Value: 0.5}, + {Name: "6", Value: 0.6}, + }}})) + + // test marshal and unmarshal + buf := bytes.NewBuffer(nil) + err = MarshalModel(buf, m) + assert.NoError(t, err) + tmp, err := UnmarshalModel(buf) + assert.NoError(t, err) + scoreClone := EvaluateClassification(tmp, test) + assert.InDelta(t, 0.77, scoreClone.Accuracy, regressionDelta) + + // test clear + assert.False(t, m.Invalid()) + m.Clear() + assert.True(t, m.Invalid()) +}