diff --git a/backend/search/base/service.go b/backend/search/base/service.go index aac32b1e..00576e82 100644 --- a/backend/search/base/service.go +++ b/backend/search/base/service.go @@ -21,7 +21,7 @@ func NewSearchService(serviceParams types.ServiceParams) SearchServiceInterface } func (s *SearchService) SearchClubs(query search_types.ClubSearchRequest) (*search_types.SearchResult[models.Club], error) { - return SearchClubs(s.DB, &query) + return SearchClubs(s.DB, s.Search, &query) } func (s *SearchService) SearchEvents(query search_types.EventSearchRequest) (*search_types.SearchResult[models.Event], error) { diff --git a/backend/search/base/transactions.go b/backend/search/base/transactions.go index 657d7222..2532d08f 100644 --- a/backend/search/base/transactions.go +++ b/backend/search/base/transactions.go @@ -5,7 +5,9 @@ import ( "strings" "time" + "github.com/GenerateNU/sac/backend/config" "github.com/GenerateNU/sac/backend/entities/models" + "github.com/GenerateNU/sac/backend/search" "github.com/GenerateNU/sac/backend/search/types" "gorm.io/gorm" ) @@ -21,18 +23,34 @@ func joinQuoted(elems []string, seperator string) string { return strings.Join(elemsQuoted, seperator) } -func SearchClubs(db *gorm.DB, query *types.ClubSearchRequest) (*types.SearchResult[models.Club], error) { +func SearchClubs(db *gorm.DB, s *config.SearchSettings, query *types.ClubSearchRequest) (*types.SearchResult[models.Club], error) { var clubs []models.Club dbQuery := db.Model(&clubs) + finalQuery := clubKeywordSearchSQL(query) + + if err := dbQuery.Raw(finalQuery).Scan(&clubs).Error; err != nil { + return nil, err + } + + // return results + return &types.SearchResult[models.Club]{ + Results: clubs, + }, nil +} + +func clubKeywordSearchSQL(query *types.ClubSearchRequest) string { var whereClauses []string + rankClause := "" innerJoin := "" if query.Search != "" { whereClauses = append(whereClauses, fmt.Sprintf("clubsearch_index_col @@ plainto_tsquery('%s')", query.Search)) + + rankClause = fmt.Sprintf(", RANK () OVER (ORDER BY ts_rank_cd(clubsearch_index_col, plainto_tsquery('%s')) DESC)", query.Search) } if query.MaxMembers != 0 { @@ -53,19 +71,14 @@ func SearchClubs(db *gorm.DB, query *types.ClubSearchRequest) (*types.SearchResu fmt.Sprintf("club_tags.tag_id IN (%s)", joinQuoted(query.Tags, ","))) } - finalQuery := fmt.Sprintf( - "SELECT * FROM clubs %s WHERE %s", - innerJoin, - strings.Join(whereClauses, " AND ")) + return fmt.Sprintf("SELECT *%s FROM clubs %s WHERE %s", rankClause, innerJoin, strings.Join(whereClauses, " AND ")) +} - if err := dbQuery.Raw(finalQuery).Scan(&clubs).Error; err != nil { - return nil, err - } +//lint:ignore U1000 Ignore unused function temporarily for debugging +func clubSemanticSearchSQL(embedding []float32) string { + embeddingStr := search.FloatArrayToSql(embedding) - // return results - return &types.SearchResult[models.Club]{ - Results: clubs, - }, nil + return fmt.Sprintf("SELECT id, RANK () OVER (ORDER BY embedding <=> '[%s]') AS rank FROM clubs ORDER BY embedding <=> '[%s]' LIMIT 20", embeddingStr, embeddingStr) } func SearchEvents(db *gorm.DB, query *types.EventSearchRequest) (*types.SearchResult[models.Event], error) { @@ -107,7 +120,7 @@ func SearchEvents(db *gorm.DB, query *types.EventSearchRequest) (*types.SearchRe } finalQuery := fmt.Sprintf( - "SELECT * FROM events %s WHERE %s", + "SELECT * FROM events %s WHERE %s LIMIT 20", strings.Join(innerJoins, " "), strings.Join(whereClauses, " AND ")) diff --git a/backend/search/embeddings.go b/backend/search/embeddings.go index e37a1615..e0484598 100644 --- a/backend/search/embeddings.go +++ b/backend/search/embeddings.go @@ -32,12 +32,12 @@ type CreateEmbeddingResponseBody struct { func UpsertClubEmbedding(db *gorm.DB, s *config.SearchSettings, club *models.Club) error { embedding, err := - createEmbedding(s, fmt.Sprintf("%s %s %s", club.Name, club.Preview, club.Description)) + CreateEmbedding(s, fmt.Sprintf("%s %s %s", club.Name, club.Preview, club.Description)) if err != nil { return err } - embeddingStr := floatArrayToSql(embedding) + embeddingStr := FloatArrayToSql(embedding) queryString := fmt.Sprintf( "UPDATE clubs SET embedding = '[%s]' WHERE id = '%s'", embeddingStr, club.ID.String()) @@ -54,12 +54,12 @@ func UpsertClubEmbedding(db *gorm.DB, s *config.SearchSettings, club *models.Clu func UpsertEventEmbedding(db *gorm.DB, s *config.SearchSettings, event *models.Event) error { embedding, err := - createEmbedding(s, fmt.Sprintf("%s %s %s", event.Name, event.Preview, event.Description)) + CreateEmbedding(s, fmt.Sprintf("%s %s %s", event.Name, event.Preview, event.Description)) if err != nil { return err } - embeddingStr := floatArrayToSql(embedding) + embeddingStr := FloatArrayToSql(embedding) queryString := fmt.Sprintf( "UPDATE events SET embedding = '[%s]' WHERE id = '%s'", embeddingStr, event.ID.String()) @@ -74,7 +74,7 @@ func UpsertEventEmbedding(db *gorm.DB, s *config.SearchSettings, event *models.E return nil } -func floatArrayToSql(embedding []float32) string { +func FloatArrayToSql(embedding []float32) string { embeddingArr := make([]string, len(embedding)) for i, f := range embedding { embeddingArr[i] = strconv.FormatFloat(float64(f), 'f', -1, 32) @@ -83,7 +83,7 @@ func floatArrayToSql(embedding []float32) string { return embeddingStr } -func createEmbedding(s *config.SearchSettings, item string) ([]float32, error) { +func CreateEmbedding(s *config.SearchSettings, item string) ([]float32, error) { embeddingBody, err := json.Marshal( CreateEmbeddingRequestBody{ Input: []string{item},