diff --git a/storage/cache/database_test.go b/storage/cache/database_test.go index 00f214b6d..87449931f 100644 --- a/storage/cache/database_test.go +++ b/storage/cache/database_test.go @@ -300,6 +300,17 @@ func (suite *baseTestSuite) TestDocument() { {Id: "1", Score: 1, Categories: []string{"a", "b"}, Timestamp: ts}, }, documents) + // search documents with nil category + documents, err = suite.SearchScores(ctx, "a", "", nil, 0, -1) + suite.NoError(err) + suite.Equal([]Score{ + {Id: "5", Score: 5, Categories: []string{"b"}, Timestamp: ts}, + {Id: "4", Score: 4, Categories: []string{""}, Timestamp: ts}, + {Id: "3", Score: 3, Categories: []string{"b"}, Timestamp: time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)}, + {Id: "2", Score: 2, Categories: []string{"b", "c"}, Timestamp: ts}, + {Id: "1", Score: 1, Categories: []string{"a", "b"}, Timestamp: ts}, + }, documents) + // search documents with empty category documents, err = suite.SearchScores(ctx, "a", "", []string{""}, 0, -1) suite.NoError(err) diff --git a/storage/cache/mongodb.go b/storage/cache/mongodb.go index 85b321df8..34e516d45 100644 --- a/storage/cache/mongodb.go +++ b/storage/cache/mongodb.go @@ -342,19 +342,19 @@ func (m MongoDB) AddScores(ctx context.Context, collection, subset string, docum } func (m MongoDB) SearchScores(ctx context.Context, collection, subset string, query []string, begin, end int) ([]Score, error) { - if len(query) == 0 { - return nil, nil - } opt := options.Find().SetSkip(int64(begin)).SetSort(bson.M{"score": -1}) if end != -1 { opt.SetLimit(int64(end - begin)) } - cur, err := m.client.Database(m.dbName).Collection(m.DocumentTable()).Find(ctx, bson.M{ + filter := bson.M{ "collection": collection, "subset": subset, "is_hidden": false, - "categories": bson.M{"$all": query}, - }, opt) + } + if len(query) > 0 { + filter["categories"] = bson.M{"$all": query} + } + cur, err := m.client.Database(m.dbName).Collection(m.DocumentTable()).Find(ctx, filter, opt) if err != nil { return nil, errors.Trace(err) } diff --git a/storage/cache/redis.go b/storage/cache/redis.go index b2a0a7386..84f4cfb74 100644 --- a/storage/cache/redis.go +++ b/storage/cache/redis.go @@ -254,9 +254,6 @@ func (r *Redis) AddScores(ctx context.Context, collection, subset string, docume } func (r *Redis) SearchScores(ctx context.Context, collection, subset string, query []string, begin, end int) ([]Score, error) { - if len(query) == 0 { - return nil, nil - } var builder strings.Builder builder.WriteString(fmt.Sprintf("@collection:{ %s } @is_hidden:[0 0]", escape(collection))) if subset != "" { diff --git a/storage/cache/sql.go b/storage/cache/sql.go index d4b8ad76b..95293341a 100644 --- a/storage/cache/sql.go +++ b/storage/cache/sql.go @@ -427,25 +427,27 @@ func (db *SQLDatabase) AddScores(ctx context.Context, collection, subset string, } func (db *SQLDatabase) SearchScores(ctx context.Context, collection, subset string, query []string, begin, end int) ([]Score, error) { - if len(query) == 0 { - return nil, nil - } - tx := db.gormDB.WithContext(ctx).Model(&PostgresDocument{}).Select("id, score, categories, timestamp") - switch db.driver { - case Postgres: - tx = tx.Where("collection = ? and subset = ? and is_hidden = false and categories @> ?", collection, subset, pq.StringArray(query)) - case SQLite, MySQL: - q, err := json.Marshal(query) - if err != nil { - return nil, errors.Trace(err) + tx := db.gormDB.WithContext(ctx). + Model(&PostgresDocument{}). + Select("id, score, categories, timestamp"). + Where("collection = ? and subset = ? and is_hidden = false", collection, subset) + if len(query) > 0 { + switch db.driver { + case Postgres: + tx.Where("categories @> ?", pq.StringArray(query)) + case SQLite, MySQL: + q, err := json.Marshal(query) + if err != nil { + return nil, errors.Trace(err) + } + tx.Where("JSON_CONTAINS(categories,?)", string(q)) } - tx = tx.Where("collection = ? and subset = ? and is_hidden = false and JSON_CONTAINS(categories,?)", collection, subset, string(q)) } - tx = tx.Order("score desc").Offset(begin) + tx.Order("score desc").Offset(begin) if end != -1 { - tx = tx.Limit(end - begin) + tx.Limit(end - begin) } else { - tx = tx.Limit(math.MaxInt64) + tx.Limit(math.MaxInt64) } rows, err := tx.Rows() if err != nil { @@ -484,28 +486,28 @@ func (db *SQLDatabase) UpdateScores(ctx context.Context, collections []string, s if patch.Score == nil && patch.IsHidden == nil && patch.Categories == nil { return nil } - tx := db.gormDB.WithContext(ctx).Model(&PostgresDocument{}) + tx := db.gormDB.WithContext(ctx). + Model(&PostgresDocument{}). + Where("collection in (?) and id = ?", collections, id) if subset != nil { - tx = tx.Where("collection in (?) and id = ? and subset = ?", collections, id, subset) - } else { - tx = tx.Where("collection in (?) and id = ?", collections, id) + tx.Where("subset = ?", subset) } if patch.Score != nil { - tx = tx.Update("score", *patch.Score) + tx.Update("score", *patch.Score) } if patch.IsHidden != nil { - tx = tx.Update("is_hidden", *patch.IsHidden) + tx.Update("is_hidden", *patch.IsHidden) } if patch.Categories != nil { switch db.driver { case Postgres: - tx = tx.Update("categories", pq.StringArray(patch.Categories)) + tx.Update("categories", pq.StringArray(patch.Categories)) case SQLite, MySQL: q, err := json.Marshal(patch.Categories) if err != nil { return errors.Trace(err) } - tx = tx.Update("categories", string(q)) + tx.Update("categories", string(q)) } } return tx.Error diff --git a/storage/data/sql.go b/storage/data/sql.go index c4eb6ba02..5d5d8ddca 100644 --- a/storage/data/sql.go +++ b/storage/data/sql.go @@ -353,7 +353,8 @@ func (d *SQLDatabase) BatchGetItems(ctx context.Context, itemIds []string) ([]It if len(itemIds) == 0 { return nil, nil } - result, err := d.gormDB.WithContext(ctx).Table(d.ItemsTable()). + result, err := d.gormDB.WithContext(ctx). + Table(d.ItemsTable()). Select("item_id, is_hidden, categories, time_stamp, labels, comment"). Where("item_id IN ?", itemIds).Rows() if err != nil { @@ -394,7 +395,10 @@ func (d *SQLDatabase) DeleteItem(ctx context.Context, itemId string) error { func (d *SQLDatabase) GetItem(ctx context.Context, itemId string) (Item, error) { var result *sql.Rows var err error - result, err = d.gormDB.WithContext(ctx).Table(d.ItemsTable()).Select("item_id, is_hidden, categories, time_stamp, labels, comment").Where("item_id = ?", itemId).Rows() + result, err = d.gormDB.WithContext(ctx). + Table(d.ItemsTable()). + Select("item_id, is_hidden, categories, time_stamp, labels, comment"). + Where("item_id = ?", itemId).Rows() if err != nil { return Item{}, errors.Trace(err) } @@ -454,7 +458,9 @@ func (d *SQLDatabase) GetItems(ctx context.Context, cursor string, n int, timeLi return "", nil, errors.Trace(err) } cursorItem := string(buf) - tx := d.gormDB.WithContext(ctx).Table(d.ItemsTable()).Select("item_id, is_hidden, categories, time_stamp, labels, comment") + tx := d.gormDB.WithContext(ctx). + Table(d.ItemsTable()). + Select("item_id, is_hidden, categories, time_stamp, labels, comment") if cursorItem != "" { tx.Where("item_id >= ?", cursorItem) } @@ -488,7 +494,9 @@ func (d *SQLDatabase) GetItemStream(ctx context.Context, batchSize int, timeLimi defer close(itemChan) defer close(errChan) // send query - tx := d.gormDB.WithContext(ctx).Table(d.ItemsTable()).Select("item_id, is_hidden, categories, time_stamp, labels, comment") + tx := d.gormDB.WithContext(ctx). + Table(d.ItemsTable()). + Select("item_id, is_hidden, categories, time_stamp, labels, comment") if timeLimit != nil { tx.Where("time_stamp >= ?", *timeLimit) } @@ -528,15 +536,16 @@ func (d *SQLDatabase) GetItemFeedback(ctx context.Context, itemId string, feedba } else { tx = tx.Table(d.FeedbackTable()) } - tx = tx.Select("user_id, item_id, feedback_type, time_stamp") + tx.Select("user_id, item_id, feedback_type, time_stamp") switch d.driver { case SQLite: - tx.Where("time_stamp <= DATETIME() AND item_id = ?", itemId) + tx.Where("time_stamp <= DATETIME()") case ClickHouse: - tx.Where("time_stamp <= NOW('UTC') AND item_id = ?", itemId) + tx.Where("time_stamp <= NOW('UTC')") default: - tx.Where("time_stamp <= NOW() AND item_id = ?", itemId) + tx.Where("time_stamp <= NOW()") } + tx.Where("item_id = ?", itemId) if len(feedbackTypes) > 0 { tx.Where("feedback_type IN ?", feedbackTypes) } @@ -659,7 +668,9 @@ func (d *SQLDatabase) GetUsers(ctx context.Context, cursor string, n int) (strin return "", nil, errors.Trace(err) } cursorUser := string(buf) - tx := d.gormDB.WithContext(ctx).Table(d.UsersTable()).Select("user_id, labels, subscribe, comment") + tx := d.gormDB.WithContext(ctx). + Table(d.UsersTable()). + Select("user_id, labels, subscribe, comment") if cursorUser != "" { tx.Where("user_id >= ?", cursorUser) } @@ -726,7 +737,7 @@ func (d *SQLDatabase) GetUserFeedback(ctx context.Context, userId string, endTim } else { tx = tx.Table(d.FeedbackTable()) } - tx = tx.Select("feedback_type, user_id, item_id, time_stamp, comment"). + tx.Select("feedback_type, user_id, item_id, time_stamp, comment"). Where("user_id = ?", userId) if endTime != nil { tx.Where("time_stamp <= ?", d.convertTimeZone(endTime)) @@ -1023,7 +1034,7 @@ func (d *SQLDatabase) GetUserItemFeedback(ctx context.Context, userId, itemId st } else { tx = tx.Table(d.FeedbackTable()) } - tx = tx.Select("feedback_type, user_id, item_id, time_stamp, comment"). + tx.Select("feedback_type, user_id, item_id, time_stamp, comment"). Where("user_id = ? AND item_id = ?", userId, itemId) if len(feedbackTypes) > 0 { tx.Where("feedback_type IN ?", feedbackTypes)