Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

draft: implement DeepFM from scratch #871

Merged
merged 30 commits into from
Dec 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
3edb8e8
implement deep learning framework
zhenghaoz Oct 18, 2024
c887fe9
implement forward
zhenghaoz Oct 18, 2024
c92536b
implement backward
zhenghaoz Oct 19, 2024
5f1b38a
implement layers
zhenghaoz Oct 19, 2024
7a62d83
implement activate functions
zhenghaoz Oct 19, 2024
03dab3d
remove example
zhenghaoz Oct 19, 2024
c8c9d02
implement embedding
zhenghaoz Oct 19, 2024
691f9bc
implement DeepFM from scratch
zhenghaoz Oct 19, 2024
8d85fba
implement batch matmul
zhenghaoz Oct 20, 2024
2ccb75f
Fix derivative of ln(x)
zhenghaoz Oct 22, 2024
cb2371f
Fix derivative of sigmoid(x)
zhenghaoz Oct 22, 2024
89e2e7f
Fix derivative of reuse
zhenghaoz Oct 22, 2024
9406583
Stash
zhenghaoz Oct 25, 2024
39174d3
implement partial sum
zhenghaoz Oct 26, 2024
94882e7
implement zero_grad()
zhenghaoz Oct 26, 2024
5d6f107
Refactor
zhenghaoz Oct 26, 2024
917d1c6
implement adam
zhenghaoz Oct 26, 2024
7866394
implement adam
zhenghaoz Oct 27, 2024
85c43ff
implement BCEWithLogits
zhenghaoz Oct 27, 2024
2e68b71
implement Slice
zhenghaoz Oct 27, 2024
cf500e4
implement Slice
zhenghaoz Oct 27, 2024
6e793cf
implement MatMul with SIMD
zhenghaoz Oct 27, 2024
dbfab9f
save
zhenghaoz Oct 30, 2024
e0c3290
Fix DeepFM
zhenghaoz Nov 2, 2024
42cd63c
Fix DeepFM
zhenghaoz Nov 6, 2024
5b8cfe6
Merge branch 'refs/heads/master' into zhenghaoz/nn
zhenghaoz Nov 17, 2024
7f3fcdf
Merge branch 'master' into zhenghaoz/nn
zhenghaoz Dec 6, 2024
e7fe64a
add dataset
zhenghaoz Dec 7, 2024
9ff4807
Merge remote-tracking branch 'origin/zhenghaoz/nn' into zhenghaoz/nn
zhenghaoz Dec 8, 2024
7a59927
Fix build
zhenghaoz Dec 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 184 additions & 0 deletions common/dataset/dataset.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
// Copyright 2024 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package dataset

import (
"archive/zip"
"encoding/csv"
"fmt"
"github.com/zhenghaoz/gorse/base/log"
"github.com/zhenghaoz/gorse/common/util"
"go.uber.org/zap"
"io"
"net/http"
"os"
"os/user"
"path/filepath"
"strings"
)

var (
tempDir string
datasetDir string
)

func init() {
usr, err := user.Current()
if err != nil {
log.Logger().Fatal("failed to get user directory", zap.Error(err))
}
datasetDir = filepath.Join(usr.HomeDir, ".gorse", "dataset")
tempDir = filepath.Join(usr.HomeDir, ".gorse", "temp")
}

func LoadIris() ([][]float32, []int, error) {
// Download dataset
path, err := downloadAndUnzip("iris")
if err != nil {
return nil, nil, err
}
dataFile := filepath.Join(path, "iris.data")
// Load data
f, err := os.Open(dataFile)
if err != nil {
return nil, nil, err
}
reader := csv.NewReader(f)
rows, err := reader.ReadAll()
if err != nil {
return nil, nil, err
}
// Parse data
data := make([][]float32, len(rows))
target := make([]int, len(rows))
types := make(map[string]int)
for i, row := range rows {
data[i] = make([]float32, 4)
for j, cell := range row[:4] {
data[i][j], err = util.ParseFloat32(cell)
if err != nil {
return nil, nil, err
}
}
if _, exist := types[row[4]]; !exist {
types[row[4]] = len(types)
}
target[i] = types[row[4]]
}
return data, target, nil
}

func downloadAndUnzip(name string) (string, error) {
url := fmt.Sprintf("https://pub-64226d9f34c64d6f829f5b63a5540d27.r2.dev/datasets/%s.zip", name)
path := filepath.Join(datasetDir, name)
if _, err := os.Stat(path); os.IsNotExist(err) {
zipFileName, _ := downloadFromUrl(url, tempDir)
if _, err := unzip(zipFileName, path); err != nil {
return "", err
}
}
return path, nil
}

// downloadFromUrl downloads file from URL.
func downloadFromUrl(src, dst string) (string, error) {
log.Logger().Info("Download dataset", zap.String("source", src), zap.String("destination", dst))
// Extract file name
tokens := strings.Split(src, "/")
fileName := filepath.Join(dst, tokens[len(tokens)-1])
// Create file
if err := os.MkdirAll(filepath.Dir(fileName), os.ModePerm); err != nil {
return fileName, err
}
output, err := os.Create(fileName)
if err != nil {
log.Logger().Error("failed to create file", zap.Error(err), zap.String("filename", fileName))
return fileName, err
}
defer output.Close()
// Download file
response, err := http.Get(src)
if err != nil {
log.Logger().Error("failed to download", zap.Error(err), zap.String("source", src))
return fileName, err
}
defer response.Body.Close()
// Save file
_, err = io.Copy(output, response.Body)
if err != nil {
log.Logger().Error("failed to download", zap.Error(err), zap.String("source", src))
return fileName, err
}
return fileName, nil
}

// unzip zip file.
func unzip(src, dst string) ([]string, error) {
var fileNames []string
// Open zip file
r, err := zip.OpenReader(src)
if err != nil {
return fileNames, err
}
defer r.Close()
// Extract files
for _, f := range r.File {
// Open file
rc, err := f.Open()
if err != nil {
return fileNames, err
}
// Store filename/path for returning and using later on
filePath := filepath.Join(dst, f.Name)
// Check for ZipSlip. More Info: http://bit.ly/2MsjAWE
if !strings.HasPrefix(filePath, filepath.Clean(dst)+string(os.PathSeparator)) {
return fileNames, fmt.Errorf("%s: illegal file path", filePath)
}
// Add filename
fileNames = append(fileNames, filePath)
if f.FileInfo().IsDir() {
// Create folder
if err = os.MkdirAll(filePath, os.ModePerm); err != nil {
return fileNames, err
}
} else {
// Create all folders
if err = os.MkdirAll(filepath.Dir(filePath), os.ModePerm); err != nil {
return fileNames, err
}
// Create file
outFile, err := os.OpenFile(filePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode())
if err != nil {
return fileNames, err
}
// Save file
_, err = io.Copy(outFile, rc)
if err != nil {
return nil, err
}
// Close the file without defer to close before next iteration of loop
err = outFile.Close()
if err != nil {
return nil, err
}
}
// Close file
err = rc.Close()
if err != nil {
return nil, err
}
}
return fileNames, nil
}
14 changes: 14 additions & 0 deletions common/dataset/dataset_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package dataset

import (
"github.com/stretchr/testify/assert"
"testing"
)

func TestLoadIris(t *testing.T) {
data, target, err := LoadIris()
assert.NoError(t, err)
assert.Len(t, data, 150)
assert.Len(t, data[0], 4)
assert.Len(t, target, 150)
}
Loading
Loading