From 5c8f681d409893beace30c03ca57fe7f1b2c5041 Mon Sep 17 00:00:00 2001 From: zhenghaoz Date: Thu, 19 Dec 2024 23:49:44 +0800 Subject: [PATCH] dashboard: support non-personalized recommender (#901) --- config/config.go | 6 ++-- go.mod | 7 ++-- go.sum | 6 ++-- master/rest.go | 55 ++++++++---------------------- master/rest_test.go | 11 +++--- master/tasks.go | 14 +++++++- master/tasks_test.go | 80 ++++++++++++++++++++++++++++++++++++++++++++ 7 files changed, 124 insertions(+), 55 deletions(-) diff --git a/config/config.go b/config/config.go index 9b23d52ff..2145c6d29 100644 --- a/config/config.go +++ b/config/config.go @@ -132,9 +132,9 @@ type DataSourceConfig struct { } type NonPersonalizedConfig struct { - Name string `mapstructure:"name"` - Score string `mapstructure:"score" validate:"required,item_expr"` - Filter string `mapstructure:"filter" validate:"item_expr"` + Name string `mapstructure:"name" json:"name"` + Score string `mapstructure:"score" json:"score" validate:"required,item_expr"` + Filter string `mapstructure:"filter" json:"filter" validate:"item_expr"` } type PopularConfig struct { diff --git a/go.mod b/go.mod index 0ddfc4e49..365988233 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/zhenghaoz/gorse -go 1.23.3 +go 1.23.4 require ( github.com/XSAM/otelsql v0.35.0 @@ -19,9 +19,10 @@ require ( github.com/go-playground/validator/v10 v10.22.1 github.com/go-resty/resty/v2 v2.7.0 github.com/go-sql-driver/mysql v1.6.0 + github.com/go-viper/mapstructure/v2 v2.2.1 github.com/google/uuid v1.6.0 github.com/gorilla/securecookie v1.1.1 - github.com/gorse-io/dashboard v0.0.0-20241207032532-3b75acd211c4 + github.com/gorse-io/dashboard v0.0.0-20241219140402-1035820fbe77 github.com/haxii/go-swagger-ui v0.0.0-20210203093335-a63a6bbde946 github.com/jaswdr/faker v1.16.0 github.com/jellydator/ttlcache/v3 v3.3.0 @@ -33,7 +34,6 @@ require ( github.com/lib/pq v1.10.6 github.com/madflojo/testcerts v1.3.0 github.com/mailru/go-clickhouse/v2 v2.0.1-0.20221121001540-b259988ad8e5 - github.com/mitchellh/mapstructure v1.5.0 github.com/orcaman/concurrent-map v1.0.0 github.com/prometheus/client_golang v1.13.0 github.com/rakyll/statik v0.1.7 @@ -133,6 +133,7 @@ require ( github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect github.com/minio/blake2b-simd v0.0.0-20160723061019-3f5f724cb5b1 // indirect github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db // indirect + github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/montanaflynn/stats v0.7.1 // indirect diff --git a/go.sum b/go.sum index c4aa36b0f..8f535388a 100644 --- a/go.sum +++ b/go.sum @@ -211,6 +211,8 @@ github.com/go-resty/resty/v2 v2.7.0/go.mod h1:9PWDzw47qPphMRFfhsyk0NnSgvluHcljSM github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE= github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= +github.com/go-viper/mapstructure/v2 v2.2.1 h1:ZAaOCxANMuZx5RCeg0mBdEZk7DZasvvZIxtHqx8aGss= +github.com/go-viper/mapstructure/v2 v2.2.1/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= github.com/gofrs/uuid v4.0.0+incompatible h1:1SD/1F5pU8p29ybwgQSwpQk+mwdRrXCYuPhW6m+TnJw= github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= @@ -307,10 +309,10 @@ github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyC github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= github.com/gorse-io/clickhouse v0.3.3-0.20220715124633-688011a495bb h1:z/oOWE+Vy0PLcwIulZmIug4FtmvE3dJ1YOGprLeHwwY= github.com/gorse-io/clickhouse v0.3.3-0.20220715124633-688011a495bb/go.mod h1:iILWzbul8U+gsf4kqbheF2QzBmdvVp63mloGGK8emDI= -github.com/gorse-io/dashboard v0.0.0-20241115145254-4def1c814899 h1:1BQ8+NLDKMYp7BcBhjJgEska+Gt8t2JTj6Rj0afYwG8= -github.com/gorse-io/dashboard v0.0.0-20241115145254-4def1c814899/go.mod h1:LBLzsMv3XVLmpaM/1q8/sGvv2Avj1YxmHBZfXcdqRjU= github.com/gorse-io/dashboard v0.0.0-20241207032532-3b75acd211c4 h1:FOUvD2HvTY/8j1/I4j/FlX3LEqKGLWPWQLl6jPtUqQ0= github.com/gorse-io/dashboard v0.0.0-20241207032532-3b75acd211c4/go.mod h1:LBLzsMv3XVLmpaM/1q8/sGvv2Avj1YxmHBZfXcdqRjU= +github.com/gorse-io/dashboard v0.0.0-20241219140402-1035820fbe77 h1:WA5kRl4LNduJuM59vvMoAyBPU+7KZL2ROjE2fPUy6sE= +github.com/gorse-io/dashboard v0.0.0-20241219140402-1035820fbe77/go.mod h1:6h/3EYChEyiynyCMMDsCsDEVBSOPLSo1L/+aHqj9kdc= github.com/gorse-io/gorgonia v0.0.0-20230817132253-6dd1dbf95849 h1:Hwywr6NxzYeZYn35KwOsw7j8ZiMT60TBzpbn1MbEido= github.com/gorse-io/gorgonia v0.0.0-20230817132253-6dd1dbf95849/go.mod h1:TtVGAt7ENNmgBnC0JA68CAjIDCEtcqaRHvnkAWJ/Fu0= github.com/gorse-io/sqlite v1.3.3-0.20220713123255-c322aec4e59e h1:uPQtYQzG1QcC3Qbv+tuEe8Q2l++V4KEcqYSSwB9qobg= diff --git a/master/rest.go b/master/rest.go index 411aef29a..4e6d6fb6d 100644 --- a/master/rest.go +++ b/master/rest.go @@ -32,10 +32,10 @@ import ( mapset "github.com/deckarep/golang-set/v2" restfulspec "github.com/emicklei/go-restful-openapi/v2" "github.com/emicklei/go-restful/v3" + "github.com/go-viper/mapstructure/v2" "github.com/gorilla/securecookie" _ "github.com/gorse-io/dashboard" "github.com/juju/errors" - "github.com/mitchellh/mapstructure" "github.com/rakyll/statik/fs" "github.com/samber/lo" "github.com/zhenghaoz/gorse/base" @@ -137,38 +137,16 @@ func (m *Master) CreateWebService() { Param(ws.QueryParameter("cursor", "cursor for next page").DataType("string")). Returns(http.StatusOK, "OK", UserIterator{}). Writes(UserIterator{})) - // Get popular items - ws.Route(ws.GET("/dashboard/popular/").To(m.getPopular). - Doc("get popular items"). + // Get non-personalized recommendation + ws.Route(ws.GET("/non-personalized/{name}").To(m.getNonPersonalized). + Doc("Get non-personalized recommendations."). Metadata(restfulspec.KeyOpenAPITags, []string{"dashboard"}). - Param(ws.QueryParameter("n", "number of returned items").DataType("int")). - Param(ws.QueryParameter("offset", "offset of the list").DataType("int")). - Returns(http.StatusOK, "OK", []ScoredItem{}). - Writes([]ScoredItem{})) - ws.Route(ws.GET("/dashboard/popular/{category}").To(m.getPopular). - Doc("get popular items"). - Metadata(restfulspec.KeyOpenAPITags, []string{"dashboard"}). - Param(ws.PathParameter("category", "category of items").DataType("string")). - Param(ws.QueryParameter("n", "number of returned items").DataType("int")). - Param(ws.QueryParameter("offset", "offset of the list").DataType("int")). - Returns(http.StatusOK, "OK", []ScoredItem{}). - Writes([]ScoredItem{})) - // Get latest items - ws.Route(ws.GET("/dashboard/latest/").To(m.getLatest). - Doc("get latest items"). - Metadata(restfulspec.KeyOpenAPITags, []string{"dashboard"}). - Param(ws.QueryParameter("n", "number of returned items").DataType("int")). - Param(ws.QueryParameter("offset", "offset of the list").DataType("int")). - Returns(http.StatusOK, "OK", []ScoredItem{}). - Writes([]ScoredItem{})) - ws.Route(ws.GET("/dashboard/latest/{category}").To(m.getLatest). - Doc("get latest items"). - Metadata(restfulspec.KeyOpenAPITags, []string{"dashboard"}). - Param(ws.PathParameter("category", "category of items").DataType("string")). - Param(ws.QueryParameter("n", "number of returned items").DataType("int")). - Param(ws.QueryParameter("offset", "offset of the list").DataType("int")). - Returns(http.StatusOK, "OK", []ScoredItem{}). - Writes([]ScoredItem{})) + Param(ws.QueryParameter("category", "Category of returned items.").DataType("string")). + Param(ws.QueryParameter("n", "Number of returned users").DataType("integer")). + Param(ws.QueryParameter("offset", "Offset of returned users").DataType("integer")). + Param(ws.QueryParameter("user-id", "Remove read items of a user").DataType("string")). + Returns(http.StatusOK, "OK", []cache.Score{}). + Writes([]cache.Score{})) ws.Route(ws.GET("/dashboard/recommend/{user-id}").To(m.getRecommend). Doc("Get recommendation for user."). Metadata(restfulspec.KeyOpenAPITags, []string{"dashboard"}). @@ -920,15 +898,10 @@ func (m *Master) searchDocuments(collection, subset, category string, request *r } } -// getPopular gets popular items from database. -func (m *Master) getPopular(request *restful.Request, response *restful.Response) { - category := request.PathParameter("category") - m.searchDocuments(cache.NonPersonalized, cache.Popular, category, request, response, data.Item{}) -} - -func (m *Master) getLatest(request *restful.Request, response *restful.Response) { - category := request.PathParameter("category") - m.searchDocuments(cache.NonPersonalized, cache.Latest, category, request, response, data.Item{}) +func (m *Master) getNonPersonalized(request *restful.Request, response *restful.Response) { + name := request.PathParameter("name") + category := request.QueryParameter("category") + m.searchDocuments(cache.NonPersonalized, name, category, request, response, data.Item{}) } func (m *Master) getItemNeighbors(request *restful.Request, response *restful.Response) { diff --git a/master/rest_test.go b/master/rest_test.go index 4346bf1ec..2ff94b44b 100644 --- a/master/rest_test.go +++ b/master/rest_test.go @@ -28,8 +28,8 @@ import ( "time" "github.com/emicklei/go-restful/v3" + "github.com/go-viper/mapstructure/v2" "github.com/juju/errors" - "github.com/mitchellh/mapstructure" "github.com/samber/lo" "github.com/steinfletcher/apitest" "github.com/stretchr/testify/assert" @@ -515,10 +515,10 @@ func TestServer_SearchDocumentsOfItems(t *testing.T) { operators := []ListOperator{ {"Item Neighbors", cache.ItemNeighbors, "0", "", "/api/dashboard/item/0/neighbors"}, {"Item Neighbors in Category", cache.ItemNeighbors, "0", "*", "/api/dashboard/item/0/neighbors/*"}, - {"Latest Items", cache.NonPersonalized, cache.Latest, "", "/api/dashboard/latest/"}, - {"Popular Items", cache.NonPersonalized, cache.Popular, "", "/api/dashboard/popular/"}, - {"Latest Items in Category", cache.NonPersonalized, cache.Latest, "*", "/api/dashboard/latest/*"}, - {"Popular Items in Category", cache.NonPersonalized, cache.Popular, "*", "/api/dashboard/popular/*"}, + {"Latest Items", cache.NonPersonalized, cache.Latest, "", "/api/non-personalized/latest/"}, + {"Popular Items", cache.NonPersonalized, cache.Popular, "", "/api/non-personalized/popular/"}, + {"Latest Items in Category", cache.NonPersonalized, cache.Latest, "*", "/api/non-personalized/latest/"}, + {"Popular Items in Category", cache.NonPersonalized, cache.Popular, "*", "/api/non-personalized/popular/"}, } for i, operator := range operators { t.Run(operator.Name, func(t *testing.T) { @@ -551,6 +551,7 @@ func TestServer_SearchDocumentsOfItems(t *testing.T) { Handler(s.handler). Get(operator.Get). Header("Cookie", cookie). + Query("category", operator.Category). Expect(t). Status(http.StatusOK). Body(marshal(t, []ScoredItem{items[0], items[1], items[2], items[4]})). diff --git a/master/tasks.go b/master/tasks.go index b7f133d58..a15be24c0 100644 --- a/master/tasks.go +++ b/master/tasks.go @@ -1560,6 +1560,7 @@ func (m *Master) LoadDataFromDatabase( err = parallel.Parallel(len(itemGroups), m.Config.Master.NumJobs, func(_, i int) error { var itemFeedback []data.Feedback var itemGroupIndex int + itemHasFeedback := make([]bool, len(itemGroups[i])) feedbackChan, errChan := database.GetFeedbackStream(newCtx, batchSize, data.WithBeginItemId(itemGroups[i][0].ItemId), data.WithEndItemId(itemGroups[i][len(itemGroups[i])-1].ItemId), @@ -1598,6 +1599,7 @@ func (m *Master) LoadDataFromDatabase( itemFeedback = append(itemFeedback, f) } else { // add item to non-personalized recommenders + itemHasFeedback[itemGroupIndex] = true for _, recommender := range nonPersonalizedRecommenders { recommender.Push(itemGroups[i][itemGroupIndex], itemFeedback) } @@ -1611,12 +1613,22 @@ func (m *Master) LoadDataFromDatabase( } } } + } - // add item to non-personalized recommenders + // add item to non-personalized recommenders + if len(itemFeedback) > 0 { + itemHasFeedback[itemGroupIndex] = true for _, recommender := range nonPersonalizedRecommenders { recommender.Push(itemGroups[i][itemGroupIndex], itemFeedback) } } + for index, hasFeedback := range itemHasFeedback { + if !hasFeedback { + for _, recommender := range nonPersonalizedRecommenders { + recommender.Push(itemGroups[i][index], nil) + } + } + } if err = <-errChan; err != nil { return errors.Trace(err) } diff --git a/master/tasks_test.go b/master/tasks_test.go index d7a168fbe..51e831d73 100644 --- a/master/tasks_test.go +++ b/master/tasks_test.go @@ -619,6 +619,86 @@ func (s *MasterTestSuite) TestLoadDataFromDatabase() { s.Equal([]string{"0", "1", "2"}, categories) } +func (s *MasterTestSuite) TestNonPersonalizedRecommend() { + ctx := context.Background() + // create config + s.Config = &config.Config{} + s.Config.Recommend.CacheSize = 3 + s.Config.Recommend.DataSource.PositiveFeedbackTypes = []string{"positive"} + s.Config.Recommend.DataSource.ReadFeedbackTypes = []string{"negative"} + s.Config.Master.NumJobs = runtime.NumCPU() + + // insert items + var items []data.Item + for i := 0; i < 10; i++ { + items = append(items, data.Item{ + ItemId: strconv.Itoa(i), + Timestamp: time.Date(2000+i%2, 1, 1, i, 1, 0, 0, time.UTC), + }) + } + err := s.DataClient.BatchInsertItems(ctx, items) + s.NoError(err) + + // insert users + var users []data.User + for i := 0; i < 10; i++ { + users = append(users, data.User{ + UserId: strconv.Itoa(i), + }) + } + err = s.DataClient.BatchInsertUsers(ctx, users) + s.NoError(err) + + // insert feedback + feedbacks := make([]data.Feedback, 0) + for i := 0; i < 10; i++ { + // positive feedback + // item 0: user 0 + // ... + // item 8: user 0 ... user 8 + if i%2 == 0 { + for j := 0; j <= i; j++ { + feedbacks = append(feedbacks, data.Feedback{ + FeedbackKey: data.FeedbackKey{ + ItemId: strconv.Itoa(i), + UserId: strconv.Itoa(j), + FeedbackType: "positive", + }, + Timestamp: time.Now(), + }) + } + } + } + err = s.DataClient.BatchInsertFeedback(ctx, feedbacks, false, false, true) + s.NoError(err) + + // load dataset + err = s.runLoadDatasetTask() + s.NoError(err) + + // check latest items + latest, err := s.CacheClient.SearchScores(ctx, cache.NonPersonalized, cache.Latest, []string{""}, 0, 3) + s.NoError(err) + s.Equal([]cache.Score{ + {Id: items[9].ItemId, Score: float64(items[9].Timestamp.Unix())}, + {Id: items[7].ItemId, Score: float64(items[7].Timestamp.Unix())}, + {Id: items[5].ItemId, Score: float64(items[5].Timestamp.Unix())}, + }, lo.Map(latest, func(document cache.Score, _ int) cache.Score { + return cache.Score{Id: document.Id, Score: document.Score} + })) + + // check popular items + popular, err := s.CacheClient.SearchScores(ctx, cache.NonPersonalized, cache.Popular, []string{""}, 0, 3) + s.NoError(err) + s.Equal([]cache.Score{ + {Id: items[8].ItemId, Score: 9}, + {Id: items[6].ItemId, Score: 7}, + {Id: items[4].ItemId, Score: 5}, + }, lo.Map(popular, func(document cache.Score, _ int) cache.Score { + return cache.Score{Id: document.Id, Score: document.Score} + })) +} + func (s *MasterTestSuite) TestCheckItemNeighborCacheTimeout() { s.Config = config.GetDefaultConfig() ctx := context.Background()