Skip to content

Commit

Permalink
implement multi-categories filter for all APIs (#902)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenghaoz authored Dec 21, 2024
1 parent 5c8f681 commit c9be1ea
Show file tree
Hide file tree
Showing 6 changed files with 80 additions and 87 deletions.
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
[![build](https://github.com/zhenghaoz/gorse/workflows/build/badge.svg)](https://github.com/zhenghaoz/gorse/actions?query=workflow%3Abuild)
[![codecov](https://codecov.io/gh/gorse-io/gorse/branch/master/graph/badge.svg)](https://codecov.io/gh/gorse-io/gorse)
[![Go Report Card](https://goreportcard.com/badge/github.com/zhenghaoz/gorse)](https://goreportcard.com/report/github.com/zhenghaoz/gorse)
[![GoDoc](https://godoc.org/github.com/zhenghaoz/gorse?status.svg)](https://godoc.org/github.com/zhenghaoz/gorse)
[![Discord](https://img.shields.io/discord/830635934210588743)](https://discord.gg/x6gAtNNkAE)
[![Twitter Follow](https://img.shields.io/twitter/follow/gorse_io?label=Follow&style=social)](https://twitter.com/gorse_io)
[![Gurubase](https://img.shields.io/badge/Gurubase-Ask%20Gorse%20Guru-006BFF)](https://gurubase.io/g/gorse)
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ require (
github.com/go-viper/mapstructure/v2 v2.2.1
github.com/google/uuid v1.6.0
github.com/gorilla/securecookie v1.1.1
github.com/gorse-io/dashboard v0.0.0-20241219140402-1035820fbe77
github.com/gorse-io/dashboard v0.0.0-20241220180536-6acaf5256606
github.com/haxii/go-swagger-ui v0.0.0-20210203093335-a63a6bbde946
github.com/jaswdr/faker v1.16.0
github.com/jellydator/ttlcache/v3 v3.3.0
Expand Down
6 changes: 2 additions & 4 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -309,10 +309,8 @@ github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyC
github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4=
github.com/gorse-io/clickhouse v0.3.3-0.20220715124633-688011a495bb h1:z/oOWE+Vy0PLcwIulZmIug4FtmvE3dJ1YOGprLeHwwY=
github.com/gorse-io/clickhouse v0.3.3-0.20220715124633-688011a495bb/go.mod h1:iILWzbul8U+gsf4kqbheF2QzBmdvVp63mloGGK8emDI=
github.com/gorse-io/dashboard v0.0.0-20241207032532-3b75acd211c4 h1:FOUvD2HvTY/8j1/I4j/FlX3LEqKGLWPWQLl6jPtUqQ0=
github.com/gorse-io/dashboard v0.0.0-20241207032532-3b75acd211c4/go.mod h1:LBLzsMv3XVLmpaM/1q8/sGvv2Avj1YxmHBZfXcdqRjU=
github.com/gorse-io/dashboard v0.0.0-20241219140402-1035820fbe77 h1:WA5kRl4LNduJuM59vvMoAyBPU+7KZL2ROjE2fPUy6sE=
github.com/gorse-io/dashboard v0.0.0-20241219140402-1035820fbe77/go.mod h1:6h/3EYChEyiynyCMMDsCsDEVBSOPLSo1L/+aHqj9kdc=
github.com/gorse-io/dashboard v0.0.0-20241220180536-6acaf5256606 h1:5Vh8xik8c905IYFg66ujt7FuuuPtzSW6e2DRzBUYc58=
github.com/gorse-io/dashboard v0.0.0-20241220180536-6acaf5256606/go.mod h1:6h/3EYChEyiynyCMMDsCsDEVBSOPLSo1L/+aHqj9kdc=
github.com/gorse-io/gorgonia v0.0.0-20230817132253-6dd1dbf95849 h1:Hwywr6NxzYeZYn35KwOsw7j8ZiMT60TBzpbn1MbEido=
github.com/gorse-io/gorgonia v0.0.0-20230817132253-6dd1dbf95849/go.mod h1:TtVGAt7ENNmgBnC0JA68CAjIDCEtcqaRHvnkAWJ/Fu0=
github.com/gorse-io/sqlite v1.3.3-0.20220713123255-c322aec4e59e h1:uPQtYQzG1QcC3Qbv+tuEe8Q2l++V4KEcqYSSwB9qobg=
Expand Down
85 changes: 28 additions & 57 deletions master/rest.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import (
"io"
"net/http"
"os"
"reflect"
"sort"
"strconv"
"strings"
Expand Down Expand Up @@ -138,7 +137,7 @@ func (m *Master) CreateWebService() {
Returns(http.StatusOK, "OK", UserIterator{}).
Writes(UserIterator{}))
// Get non-personalized recommendation
ws.Route(ws.GET("/non-personalized/{name}").To(m.getNonPersonalized).
ws.Route(ws.GET("/dashboard/non-personalized/{name}").To(m.getNonPersonalized).
Doc("Get non-personalized recommendations.").
Metadata(restfulspec.KeyOpenAPITags, []string{"dashboard"}).
Param(ws.QueryParameter("category", "Category of returned items.").DataType("string")).
Expand All @@ -151,6 +150,7 @@ func (m *Master) CreateWebService() {
Doc("Get recommendation for user.").
Metadata(restfulspec.KeyOpenAPITags, []string{"dashboard"}).
Param(ws.PathParameter("user-id", "identifier of the user").DataType("string")).
Param(ws.QueryParameter("category", "category of items").DataType("string")).
Param(ws.QueryParameter("n", "number of returned items").DataType("int")).
Returns(http.StatusOK, "OK", []data.Item{}).
Writes([]data.Item{}))
Expand All @@ -159,6 +159,7 @@ func (m *Master) CreateWebService() {
Metadata(restfulspec.KeyOpenAPITags, []string{"dashboard"}).
Param(ws.PathParameter("user-id", "identifier of the user").DataType("string")).
Param(ws.PathParameter("recommender", "one of `final`, `collaborative`, `user_based` and `item_based`").DataType("string")).
Param(ws.QueryParameter("category", "category of items").DataType("string")).
Param(ws.QueryParameter("n", "number of returned items").DataType("int")).
Returns(http.StatusOK, "OK", []data.Item{}).
Writes([]data.Item{}))
Expand All @@ -175,6 +176,7 @@ func (m *Master) CreateWebService() {
Doc("get neighbors of a item").
Metadata(restfulspec.KeyOpenAPITags, []string{"recommendation"}).
Param(ws.PathParameter("item-id", "identifier of the item").DataType("string")).
Param(ws.QueryParameter("category", "category of items").DataType("string")).
Param(ws.QueryParameter("n", "number of returned items").DataType("int")).
Param(ws.QueryParameter("offset", "offset of the list").DataType("int")).
Returns(http.StatusOK, "OK", []ScoredItem{}).
Expand Down Expand Up @@ -743,7 +745,7 @@ func (m *Master) getRecommend(request *restful.Request, response *restful.Respon
// parse arguments
recommender := request.PathParameter("recommender")
userId := request.PathParameter("user-id")
categories := []string{request.PathParameter("category")}
categories := server.ReadCategories(request)
n, err := server.ParseInt(request, "n", m.Config.Server.DefaultN)
if err != nil {
server.BadRequest(response, err)
Expand Down Expand Up @@ -844,80 +846,49 @@ type ScoreUser struct {
Score float64
}

func (m *Master) searchDocuments(collection, subset, category string, request *restful.Request, response *restful.Response, retType interface{}) {
ctx := context.Background()
if request != nil && request.Request != nil {
ctx = request.Request.Context()
}
var n, offset int
func (m *Master) GetItem(score cache.Score) (any, error) {
var item ScoredItem
var err error
// read arguments
if offset, err = server.ParseInt(request, "offset", 0); err != nil {
server.BadRequest(response, err)
return
}
if n, err = server.ParseInt(request, "n", m.Config.Server.DefaultN); err != nil {
server.BadRequest(response, err)
return
}
// Get the popular list
scores, err := m.CacheClient.SearchScores(ctx, collection, subset, []string{category}, offset, m.Config.Recommend.CacheSize)
item.Score = score.Score
item.Item, err = m.DataClient.GetItem(context.Background(), score.Id)
if err != nil {
server.InternalServerError(response, err)
return
}
if n > 0 && len(scores) > n {
scores = scores[:n]
return nil, err
}
// Send result
switch retType.(type) {
case data.Item:
details := make([]ScoredItem, len(scores))
for i := range scores {
details[i].Score = scores[i].Score
details[i].Item, err = m.DataClient.GetItem(ctx, scores[i].Id)
if err != nil {
server.InternalServerError(response, err)
return
}
}
server.Ok(response, details)
case data.User:
details := make([]ScoreUser, len(scores))
for i := range scores {
details[i].Score = scores[i].Score
details[i].User, err = m.DataClient.GetUser(ctx, scores[i].Id)
if err != nil {
server.InternalServerError(response, err)
return
}
}
server.Ok(response, details)
default:
log.ResponseLogger(response).Fatal("unknown return type", zap.Any("ret_type", reflect.TypeOf(retType)))
return item, nil
}

func (m *Master) GetUser(score cache.Score) (any, error) {
var user ScoreUser
var err error
user.Score = score.Score
user.User, err = m.DataClient.GetUser(context.Background(), score.Id)
if err != nil {
return nil, err
}
return user, nil
}

func (m *Master) getNonPersonalized(request *restful.Request, response *restful.Response) {
name := request.PathParameter("name")
category := request.QueryParameter("category")
m.searchDocuments(cache.NonPersonalized, name, category, request, response, data.Item{})
categories := server.ReadCategories(request)
m.SearchDocuments(cache.NonPersonalized, name, categories, m.GetItem, request, response)
}

func (m *Master) getItemNeighbors(request *restful.Request, response *restful.Response) {
itemId := request.PathParameter("item-id")
m.searchDocuments(cache.ItemNeighbors, itemId, "", request, response, data.Item{})
categories := server.ReadCategories(request)
m.SearchDocuments(cache.ItemNeighbors, itemId, categories, m.GetItem, request, response)
}

func (m *Master) getItemCategorizedNeighbors(request *restful.Request, response *restful.Response) {
itemId := request.PathParameter("item-id")
category := request.PathParameter("category")
m.searchDocuments(cache.ItemNeighbors, itemId, category, request, response, data.Item{})
categories := server.ReadCategories(request)
m.SearchDocuments(cache.ItemNeighbors, itemId, categories, m.GetItem, request, response)
}

func (m *Master) getUserNeighbors(request *restful.Request, response *restful.Response) {
userId := request.PathParameter("user-id")
m.searchDocuments(cache.UserNeighbors, userId, "", request, response, data.User{})
m.SearchDocuments(cache.UserNeighbors, userId, []string{""}, m.GetUser, request, response)
}

func (m *Master) importExportUsers(response http.ResponseWriter, request *http.Request) {
Expand Down
8 changes: 4 additions & 4 deletions master/rest_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -515,10 +515,10 @@ func TestServer_SearchDocumentsOfItems(t *testing.T) {
operators := []ListOperator{
{"Item Neighbors", cache.ItemNeighbors, "0", "", "/api/dashboard/item/0/neighbors"},
{"Item Neighbors in Category", cache.ItemNeighbors, "0", "*", "/api/dashboard/item/0/neighbors/*"},
{"Latest Items", cache.NonPersonalized, cache.Latest, "", "/api/non-personalized/latest/"},
{"Popular Items", cache.NonPersonalized, cache.Popular, "", "/api/non-personalized/popular/"},
{"Latest Items in Category", cache.NonPersonalized, cache.Latest, "*", "/api/non-personalized/latest/"},
{"Popular Items in Category", cache.NonPersonalized, cache.Popular, "*", "/api/non-personalized/popular/"},
{"Latest Items", cache.NonPersonalized, cache.Latest, "", "/api/dashboard/non-personalized/latest/"},
{"Popular Items", cache.NonPersonalized, cache.Popular, "", "/api/dashboard/non-personalized/popular/"},
{"Latest Items in Category", cache.NonPersonalized, cache.Latest, "*", "/api/dashboard/non-personalized/latest/"},
{"Popular Items in Category", cache.NonPersonalized, cache.Popular, "*", "/api/dashboard/non-personalized/popular/"},
}
for i, operator := range operators {
t.Run(operator.Name, func(t *testing.T) {
Expand Down
65 changes: 45 additions & 20 deletions server/rest.go
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,7 @@ func (s *RestServer) CreateWebService() {
Doc("Get popular items.").
Metadata(restfulspec.KeyOpenAPITags, []string{RecommendationAPITag}).
Param(ws.HeaderParameter("X-API-Key", "API key").DataType("string")).
Param(ws.QueryParameter("category", "Category of returned items").DataType("string")).
Param(ws.QueryParameter("n", "Number of returned recommendations").DataType("integer")).
Param(ws.QueryParameter("offset", "Offset of returned recommendations").DataType("integer")).
Param(ws.QueryParameter("user-id", "Remove read items of a user").DataType("string")).
Expand All @@ -456,6 +457,7 @@ func (s *RestServer) CreateWebService() {
Doc("Get the latest items.").
Metadata(restfulspec.KeyOpenAPITags, []string{RecommendationAPITag}).
Param(ws.HeaderParameter("X-API-Key", "API key").DataType("string")).
Param(ws.QueryParameter("category", "Category of returned items").DataType("string")).
Param(ws.QueryParameter("n", "Number of returned items").DataType("integer")).
Param(ws.QueryParameter("offset", "Offset of returned items").DataType("integer")).
Param(ws.QueryParameter("user-id", "Remove read items of a user").DataType("string")).
Expand Down Expand Up @@ -585,7 +587,10 @@ func ParseDuration(request *restful.Request, name string) (time.Duration, error)
return time.ParseDuration(valueString)
}

func (s *RestServer) searchDocuments(collection, subset, category string, isItem bool, request *restful.Request, response *restful.Response) {
func (s *RestServer) SearchDocuments(collection, subset string, categories []string,
iteratee func(item cache.Score) (any, error),
request *restful.Request, response *restful.Response,
) {
var (
ctx = request.Request.Context()
n int
Expand Down Expand Up @@ -623,7 +628,7 @@ func (s *RestServer) searchDocuments(collection, subset, category string, isItem
}

// Get the sorted list
items, err := s.CacheClient.SearchScores(ctx, collection, subset, []string{category}, offset, end)
items, err := s.CacheClient.SearchScores(ctx, collection, subset, categories, offset, end)
if err != nil {
InternalServerError(response, err)
return
Expand All @@ -644,26 +649,39 @@ func (s *RestServer) searchDocuments(collection, subset, category string, isItem
if n > 0 && len(items) > n {
items = items[:n]
}
Ok(response, items)
if iteratee != nil {
var results []any
for _, item := range items {
result, err := iteratee(item)
if err != nil {
InternalServerError(response, err)
return
}
results = append(results, result)
}
Ok(response, results)
} else {
Ok(response, items)
}
}

func (s *RestServer) getPopular(request *restful.Request, response *restful.Response) {
category := request.PathParameter("category")
log.ResponseLogger(response).Debug("get category popular items in category", zap.String("category", category))
s.searchDocuments(cache.NonPersonalized, cache.Popular, category, true, request, response)
categories := ReadCategories(request)
log.ResponseLogger(response).Debug("get category popular items in category", zap.Strings("categories", categories))
s.SearchDocuments(cache.NonPersonalized, cache.Popular, categories, nil, request, response)
}

func (s *RestServer) getLatest(request *restful.Request, response *restful.Response) {
category := request.PathParameter("category")
log.ResponseLogger(response).Debug("get category latest items in category", zap.String("category", category))
s.searchDocuments(cache.NonPersonalized, cache.Latest, category, true, request, response)
categories := ReadCategories(request)
log.ResponseLogger(response).Debug("get category latest items in category", zap.Strings("categories", categories))
s.SearchDocuments(cache.NonPersonalized, cache.Latest, categories, nil, request, response)
}

func (s *RestServer) getNonPersonalized(request *restful.Request, response *restful.Response) {
name := request.PathParameter("name")
category := request.QueryParameter("category")
categories := ReadCategories(request)
log.ResponseLogger(response).Debug("get leaderboard", zap.String("name", name))
s.searchDocuments(cache.NonPersonalized, name, category, false, request, response)
s.SearchDocuments(cache.NonPersonalized, name, categories, nil, request, response)
}

// get feedback by item-id with feedback type
Expand Down Expand Up @@ -701,23 +719,23 @@ func (s *RestServer) getFeedbackByItem(request *restful.Request, response *restf
func (s *RestServer) getItemNeighbors(request *restful.Request, response *restful.Response) {
// Get item id
itemId := request.PathParameter("item-id")
category := request.PathParameter("category")
s.searchDocuments(cache.ItemNeighbors, itemId, category, true, request, response)
categories := ReadCategories(request)
s.SearchDocuments(cache.ItemNeighbors, itemId, categories, nil, request, response)
}

// getUserNeighbors gets neighbors of a user from database.
func (s *RestServer) getUserNeighbors(request *restful.Request, response *restful.Response) {
// Get item id
userId := request.PathParameter("user-id")
s.searchDocuments(cache.UserNeighbors, userId, "", false, request, response)
s.SearchDocuments(cache.UserNeighbors, userId, []string{""}, nil, request, response)
}

// getCollaborative gets cached recommended items from database.
func (s *RestServer) getCollaborative(request *restful.Request, response *restful.Response) {
// Get user id
userId := request.PathParameter("user-id")
category := request.PathParameter("category")
s.searchDocuments(cache.OfflineRecommend, userId, category, true, request, response)
categories := ReadCategories(request)
s.SearchDocuments(cache.OfflineRecommend, userId, categories, nil, request, response)
}

// Recommend items to users.
Expand Down Expand Up @@ -995,10 +1013,7 @@ func (s *RestServer) getRecommend(request *restful.Request, response *restful.Re
BadRequest(response, err)
return
}
categories := request.QueryParameters("category")
if len(categories) == 0 {
categories = []string{request.PathParameter("category")}
}
categories := ReadCategories(request)
offset, err := ParseInt(request, "offset", 0)
if err != nil {
BadRequest(response, err)
Expand Down Expand Up @@ -1963,3 +1978,13 @@ func withWildCard(categories []string) []string {
result = append(result, "")
return result
}

func ReadCategories(request *restful.Request) []string {
if pathValue := request.PathParameter("category"); pathValue != "" {
return []string{pathValue}
} else if queryValues := request.QueryParameters("category"); len(queryValues) > 0 {
return queryValues
} else {
return []string{""}
}
}

0 comments on commit c9be1ea

Please sign in to comment.