Skip to content

Commit

Permalink
nn: support save and load weights (#917)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenghaoz authored Jan 5, 2025
1 parent b5d6890 commit 515c3c9
Show file tree
Hide file tree
Showing 14 changed files with 588 additions and 288 deletions.
6 changes: 3 additions & 3 deletions cmd/gorse-server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import (
"github.com/spf13/cobra"
"github.com/zhenghaoz/gorse/base/log"
"github.com/zhenghaoz/gorse/cmd/version"
"github.com/zhenghaoz/gorse/protocol"
"github.com/zhenghaoz/gorse/common/util"
"github.com/zhenghaoz/gorse/server"
"go.uber.org/zap"
)
Expand Down Expand Up @@ -50,9 +50,9 @@ var serverCommand = &cobra.Command{
caFile, _ := cmd.PersistentFlags().GetString("ssl-ca")
certFile, _ := cmd.PersistentFlags().GetString("ssl-cert")
keyFile, _ := cmd.PersistentFlags().GetString("ssl-key")
var tlsConfig *protocol.TLSConfig
var tlsConfig *util.TLSConfig
if caFile != "" && certFile != "" && keyFile != "" {
tlsConfig = &protocol.TLSConfig{
tlsConfig = &util.TLSConfig{
SSLCA: caFile,
SSLCert: certFile,
SSLKey: keyFile,
Expand Down
6 changes: 3 additions & 3 deletions cmd/gorse-worker/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import (
"github.com/spf13/cobra"
"github.com/zhenghaoz/gorse/base/log"
"github.com/zhenghaoz/gorse/cmd/version"
"github.com/zhenghaoz/gorse/protocol"
"github.com/zhenghaoz/gorse/common/util"
"github.com/zhenghaoz/gorse/worker"
"go.uber.org/zap"
)
Expand Down Expand Up @@ -49,9 +49,9 @@ var workerCommand = &cobra.Command{
caFile, _ := cmd.PersistentFlags().GetString("ssl-ca")
certFile, _ := cmd.PersistentFlags().GetString("ssl-cert")
keyFile, _ := cmd.PersistentFlags().GetString("ssl-key")
var tlsConfig *protocol.TLSConfig
var tlsConfig *util.TLSConfig
if caFile != "" && certFile != "" && keyFile != "" {
tlsConfig = &protocol.TLSConfig{
tlsConfig = &util.TLSConfig{
SSLCA: caFile,
SSLCert: certFile,
SSLKey: keyFile,
Expand Down
7 changes: 4 additions & 3 deletions protocol/decoder.go → common/encoding/decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,19 @@
// See the License for the specific language governing permissions and
// limitations under the License.

package protocol
package encoding

import (
"github.com/zhenghaoz/gorse/base/log"
"github.com/zhenghaoz/gorse/model/click"
"github.com/zhenghaoz/gorse/model/ranking"
"github.com/zhenghaoz/gorse/protocol"
"go.uber.org/zap"
"io"
)

// UnmarshalClickModel unmarshals a click model from gRPC.
func UnmarshalClickModel(receiver Master_GetClickModelClient) (click.FactorizationMachine, error) {
func UnmarshalClickModel(receiver protocol.Master_GetClickModelClient) (click.FactorizationMachine, error) {
// receive model
reader, writer := io.Pipe()
var finalError error
Expand Down Expand Up @@ -66,7 +67,7 @@ func UnmarshalClickModel(receiver Master_GetClickModelClient) (click.Factorizati
}

// UnmarshalRankingModel unmarshals a ranking model from gRPC.
func UnmarshalRankingModel(receiver Master_GetRankingModelClient) (ranking.MatrixFactorization, error) {
func UnmarshalRankingModel(receiver protocol.Master_GetRankingModelClient) (ranking.MatrixFactorization, error) {
// receive model
reader, writer := io.Pipe()
var receiverError error
Expand Down
164 changes: 141 additions & 23 deletions common/nn/layers.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,16 @@

package nn

import "github.com/chewxy/math32"
import (
"github.com/chewxy/math32"
"github.com/juju/errors"
"github.com/matttproud/golang_protobuf_extensions/pbutil"
"github.com/zhenghaoz/gorse/protocol"
"io"
"os"
"reflect"
"strconv"
)

type Layer interface {
Parameters() []*Tensor
Expand All @@ -23,24 +32,24 @@ type Layer interface {

type Model Layer

type linearLayer struct {
w *Tensor
b *Tensor
type LinearLayer struct {
W *Tensor
B *Tensor
}

func NewLinear(in, out int) Layer {
return &linearLayer{
w: Normal(0, 1.0/math32.Sqrt(float32(in)), in, out).RequireGrad(),
b: Zeros(out).RequireGrad(),
return &LinearLayer{
W: Normal(0, 1.0/math32.Sqrt(float32(in)), in, out).RequireGrad(),
B: Zeros(out).RequireGrad(),
}
}

func (l *linearLayer) Forward(x *Tensor) *Tensor {
return Add(MatMul(x, l.w), l.b)
func (l *LinearLayer) Forward(x *Tensor) *Tensor {
return Add(MatMul(x, l.W), l.B)
}

func (l *linearLayer) Parameters() []*Tensor {
return []*Tensor{l.w, l.b}
func (l *LinearLayer) Parameters() []*Tensor {
return []*Tensor{l.W, l.B}
}

type flattenLayer struct{}
Expand All @@ -57,23 +66,23 @@ func (f *flattenLayer) Forward(x *Tensor) *Tensor {
return Flatten(x)
}

type embeddingLayer struct {
w *Tensor
type EmbeddingLayer struct {
W *Tensor
}

func NewEmbedding(n int, shape ...int) Layer {
wShape := append([]int{n}, shape...)
return &embeddingLayer{
w: Rand(wShape...),
return &EmbeddingLayer{
W: Rand(wShape...),
}
}

func (e *embeddingLayer) Parameters() []*Tensor {
return []*Tensor{e.w}
func (e *EmbeddingLayer) Parameters() []*Tensor {
return []*Tensor{e.W}
}

func (e *embeddingLayer) Forward(x *Tensor) *Tensor {
return Embedding(e.w, x)
func (e *EmbeddingLayer) Forward(x *Tensor) *Tensor {
return Embedding(e.W, x)
}

type sigmoidLayer struct{}
Expand Down Expand Up @@ -105,24 +114,133 @@ func (r *reluLayer) Forward(x *Tensor) *Tensor {
}

type Sequential struct {
layers []Layer
Layers []Layer
}

func NewSequential(layers ...Layer) Model {
return &Sequential{layers: layers}
return &Sequential{Layers: layers}
}

func (s *Sequential) Parameters() []*Tensor {
var params []*Tensor
for _, l := range s.layers {
for _, l := range s.Layers {
params = append(params, l.Parameters()...)
}
return params
}

func (s *Sequential) Forward(x *Tensor) *Tensor {
for _, l := range s.layers {
for _, l := range s.Layers {
x = l.Forward(x)
}
return x
}

// Save writes every tensor reachable from o to the file at path.
//
// The model is walked with reflection: pointers are dereferenced, exported
// struct fields are visited in declaration order and slice elements in index
// order. Each non-nil *Tensor found is converted with toPB and written as a
// length-delimited protobuf message whose Key holds the path of field names
// and slice indices that leads to it. The file is created, or truncated if it
// already exists.
//
// Nil pointers, nil interfaces and unexported fields are skipped because they
// hold nothing that can be persisted through reflection. An error is returned
// if the file cannot be created, written or closed, or if a visited value is
// not a pointer, struct, slice or *Tensor.
func Save[T Model](o T, path string) (err error) {
	// Open file
	file, err := os.Create(path)
	if err != nil {
		return err
	}
	// A failed Close on a freshly written file can mean lost data, so report
	// it unless an earlier error is already being returned.
	defer func() {
		if closeErr := file.Close(); err == nil {
			err = closeErr
		}
	}()

	// extend returns a fresh key with name appended, so that sibling fields
	// never share (and overwrite) the same backing array.
	extend := func(key []string, name string) []string {
		next := make([]string, len(key), len(key)+1)
		copy(next, key)
		return append(next, name)
	}

	// Save function
	var save func(o any, key []string) error
	save = func(o any, key []string) error {
		if tensor, ok := o.(*Tensor); ok {
			if tensor == nil {
				return nil
			}
			pb := tensor.toPB()
			pb.Key = key
			_, writeErr := pbutil.WriteDelimited(file, pb)
			return writeErr
		}

		value := reflect.ValueOf(o)
		switch value.Kind() {
		case reflect.Invalid:
			// A nil interface value: nothing to save.
			return nil
		case reflect.Ptr:
			if value.IsNil() {
				return nil
			}
			return save(value.Elem().Interface(), key)
		case reflect.Struct:
			tp := value.Type()
			for i := 0; i < tp.NumField(); i++ {
				field := tp.Field(i)
				// Interface() panics on unexported fields.
				if !field.IsExported() {
					continue
				}
				if walkErr := save(value.Field(i).Interface(), extend(key, field.Name)); walkErr != nil {
					return walkErr
				}
			}
			return nil
		case reflect.Slice:
			for i := 0; i < value.Len(); i++ {
				if walkErr := save(value.Index(i).Interface(), extend(key, strconv.Itoa(i))); walkErr != nil {
					return walkErr
				}
			}
			return nil
		default:
			return errors.New("unexpected type")
		}
	}
	return save(o, nil)
}

// Load reads the tensors stored in the file at path and places them into o.
//
// The file is read as a sequence of length-delimited protocol.Tensor messages
// until EOF. For each message, its Key is followed through o: a struct
// consumes one key element as a field name, a slice consumes one key element
// as a decimal index, and pointers are dereferenced without consuming
// anything. When a *Tensor is reached, fromPB is called on it with the
// message.
//
// Key elements that name a field the struct does not have (or an unexported
// one) are ignored, so o may omit fields that the file contains. An error is
// returned if the file cannot be opened or decoded, if a key ends before a
// tensor is reached, if a slice index is malformed or out of range, if the
// path runs through a nil pointer, or if it reaches a value that is not a
// pointer, struct, slice or *Tensor.
func Load[T Model](o T, path string) error {
	// Open file
	file, err := os.Open(path)
	if err != nil {
		return err
	}
	// The file is only read, so a Close error carries no useful information.
	defer file.Close()

	// Place function
	var place func(o any, key []string, pb *protocol.Tensor) error
	place = func(o any, key []string, pb *protocol.Tensor) error {
		if tensor, ok := o.(*Tensor); ok {
			if tensor == nil {
				return errors.New("cannot load into nil tensor")
			}
			tensor.fromPB(pb)
			return nil
		}

		value := reflect.ValueOf(o)
		switch value.Kind() {
		case reflect.Ptr:
			if value.IsNil() {
				return errors.New("cannot load into nil pointer")
			}
			return place(value.Elem().Interface(), key, pb)
		case reflect.Struct:
			if len(key) == 0 {
				return errors.New("key ends before reaching a tensor")
			}
			field := value.FieldByName(key[0])
			// Unknown fields are skipped; unexported ones cannot be read
			// through reflection without panicking.
			if !field.IsValid() || !field.CanInterface() {
				return nil
			}
			return place(field.Interface(), key[1:], pb)
		case reflect.Slice:
			if len(key) == 0 {
				return errors.New("key ends before reaching a tensor")
			}
			index, convErr := strconv.Atoi(key[0])
			if convErr != nil {
				return convErr
			}
			// Index panics when out of range, so check explicitly.
			if index < 0 || index >= value.Len() {
				return errors.Errorf("index %d out of range [0, %d)", index, value.Len())
			}
			return place(value.Index(index).Interface(), key[1:], pb)
		default:
			return errors.New("unexpected type")
		}
	}

	// Read data
	for {
		pb := new(protocol.Tensor)
		if _, err = pbutil.ReadDelimited(file, pb); err != nil {
			if errors.Is(err, io.EOF) {
				break
			}
			return err
		}
		if err = place(o, pb.Key, pb); err != nil {
			return err
		}
	}
	return nil
}
67 changes: 65 additions & 2 deletions common/nn/nn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"bufio"
"encoding/csv"
"fmt"
"math/rand"
"os"
"path/filepath"
"strconv"
Expand Down Expand Up @@ -68,8 +69,8 @@ func TestNeuralNetwork(t *testing.T) {
NewSigmoid(),
NewLinear(10, 1),
)
NormalInit(model.(*Sequential).layers[0].(*linearLayer).w, 0, 0.01)
NormalInit(model.(*Sequential).layers[2].(*linearLayer).w, 0, 0.01)
NormalInit(model.(*Sequential).Layers[0].(*LinearLayer).W, 0, 0.01)
NormalInit(model.(*Sequential).Layers[2].(*LinearLayer).W, 0, 0.01)
optimizer := NewSGD(model.Parameters(), 0.2)

var l float32
Expand Down Expand Up @@ -254,3 +255,65 @@ func TestMNIST(t *testing.T) {
precision /= float32(len(test.B.data))
assert.Greater(t, float64(precision), 0.92)
}

// spiral generates a shuffled toy classification dataset of three classes
// with 100 two-dimensional points each.
//
// Within a class, the radius of a point grows with its position in the class
// (from 0 up to just under 1) and its angle is the class offset plus a term
// proportional to that position plus Gaussian noise. The first returned tensor
// holds the (sin, cos) coordinates, the second holds the class index of every
// point as a float32. Both are reordered by the same random permutation. The
// returned error is always nil.
func spiral() (*Tensor, *Tensor, error) {
	const samplesPerClass, classes, dims int = 100, 3, 2
	total := classes * samplesPerClass
	points := Zeros(total, dims)
	labels := Zeros(total)

	// Walk the samples in class-major order so the random draws happen in
	// the same sequence as a nested class/sample loop.
	for ix := 0; ix < total; ix++ {
		class, i := ix/samplesPerClass, ix%samplesPerClass
		rate := float32(i) / float32(samplesPerClass)
		radius := 1.0 * rate
		theta := float32(class)*4.0 + 4.0*rate + float32(rand.NormFloat64())*0.2
		base := ix * dims
		points.data[base] = radius * math32.Sin(theta)
		points.data[base+1] = radius * math32.Cos(theta)
		labels.data[ix] = float32(class)
	}

	// Apply one shared permutation so points and labels stay aligned.
	order := rand.Perm(total)
	points = points.SliceIndices(order...)
	labels = labels.SliceIndices(order...)
	return points, labels, nil
}

// TestSaveAndLoad trains a small classifier on the spiral dataset, saves it,
// loads the weights into a freshly constructed model of the same architecture
// and checks that the loaded model reproduces the trained model's loss.
func TestSaveAndLoad(t *testing.T) {
	x, y, err := spiral()
	assert.NoError(t, err)

	model := NewSequential(
		NewLinear(2, 10),
		NewSigmoid(),
		NewLinear(10, 3),
	)
	optimizer := NewAdam(model.Parameters(), 0.01)

	for i := 0; i < 300; i++ {
		yPred := model.Forward(x)
		loss := SoftmaxCrossEntropy(yPred, y)

		optimizer.ZeroGrad()
		loss.Backward()

		optimizer.Step()
	}
	// Measure the reference loss after the final optimizer step so that it
	// describes exactly the weights that are about to be saved.
	expected := SoftmaxCrossEntropy(model.Forward(x), y).data[0]

	// t.TempDir gives a per-test directory that is removed automatically,
	// which avoids leaking the file and clashing with concurrent test runs.
	modelPath := filepath.Join(t.TempDir(), "spiral.nn")
	err = Save(model, modelPath)
	assert.NoError(t, err)
	modelLoaded := NewSequential(
		NewLinear(2, 10),
		NewSigmoid(),
		NewLinear(10, 3),
	)
	err = Load(modelLoaded, modelPath)
	assert.NoError(t, err)
	yPred := modelLoaded.Forward(x)
	loss := SoftmaxCrossEntropy(yPred, y)
	assert.InDelta(t, float64(expected), float64(loss.data[0]), 0.01)
}
Loading

0 comments on commit 515c3c9

Please sign in to comment.