Skip to content

Commit

Permalink
Remove embeddings from search for now because it's broken
Browse files Browse the repository at this point in the history
  • Loading branch information
michael-brennan2005 committed Jun 20, 2024
1 parent fa1ba59 commit 76e50c8
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 20 deletions.
2 changes: 1 addition & 1 deletion backend/search/base/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func NewSearchService(serviceParams types.ServiceParams) SearchServiceInterface
}

// SearchClubs runs a club search for the given request by delegating to the
// package-level SearchClubs function, passing the service's DB handle and a
// pointer to the request. It returns whatever that function returns.
//
// NOTE(review): two consecutive return statements appear in this span. The
// first (two-argument call) looks like the pre-change side of the diff and the
// second (which additionally passes s.Search) like the post-change side;
// confirm which one the file actually contains. As written, only the first
// return can execute.
func (s *SearchService) SearchClubs(query search_types.ClubSearchRequest) (*search_types.SearchResult[models.Club], error) {
return SearchClubs(s.DB, &query)
return SearchClubs(s.DB, s.Search, &query)
}

func (s *SearchService) SearchEvents(query search_types.EventSearchRequest) (*search_types.SearchResult[models.Event], error) {
Expand Down
39 changes: 26 additions & 13 deletions backend/search/base/transactions.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ import (
"strings"
"time"

"github.com/GenerateNU/sac/backend/config"
"github.com/GenerateNU/sac/backend/entities/models"
"github.com/GenerateNU/sac/backend/search"
"github.com/GenerateNU/sac/backend/search/types"
"gorm.io/gorm"
)
Expand All @@ -21,18 +23,34 @@ func joinQuoted(elems []string, seperator string) string {
return strings.Join(elemsQuoted, seperator)
}

// SearchClubs builds a raw SQL string with clubKeywordSearchSQL, executes it
// through db, scans the rows into a []models.Club, and returns them wrapped in
// a types.SearchResult. If the query fails, the database error is returned
// unchanged together with a nil result.
//
// NOTE(review): two signature lines appear in this span; they look like the
// old (db, query) and new (db, s, query) sides of the diff — confirm which one
// the file actually contains. In the three-argument form, s is not referenced
// anywhere in the visible body.
func SearchClubs(db *gorm.DB, query *types.ClubSearchRequest) (*types.SearchResult[models.Club], error) {
func SearchClubs(db *gorm.DB, s *config.SearchSettings, query *types.ClubSearchRequest) (*types.SearchResult[models.Club], error) {
var clubs []models.Club

dbQuery := db.Model(&clubs)

// The whole statement, including filters, is assembled as a string by the helper.
finalQuery := clubKeywordSearchSQL(query)

// Run the assembled statement as raw SQL and load the rows into clubs.
if err := dbQuery.Raw(finalQuery).Scan(&clubs).Error; err != nil {
return nil, err
}

// Wrap the scanned clubs in the generic search result type.
return &types.SearchResult[models.Club]{
Results: clubs,
}, nil
}

func clubKeywordSearchSQL(query *types.ClubSearchRequest) string {
var whereClauses []string
rankClause := ""
innerJoin := ""

if query.Search != "" {
whereClauses =
append(whereClauses,
fmt.Sprintf("clubsearch_index_col @@ plainto_tsquery('%s')", query.Search))

rankClause = fmt.Sprintf(", RANK () OVER (ORDER BY ts_rank_cd(clubsearch_index_col, plainto_tsquery('%s')) DESC)", query.Search)
}

if query.MaxMembers != 0 {
Expand All @@ -53,19 +71,14 @@ func SearchClubs(db *gorm.DB, query *types.ClubSearchRequest) (*types.SearchResu
fmt.Sprintf("club_tags.tag_id IN (%s)", joinQuoted(query.Tags, ",")))
}

finalQuery := fmt.Sprintf(
"SELECT * FROM clubs %s WHERE %s",
innerJoin,
strings.Join(whereClauses, " AND "))
return fmt.Sprintf("SELECT *%s FROM clubs %s WHERE %s", rankClause, innerJoin, strings.Join(whereClauses, " AND "))
}

if err := dbQuery.Raw(finalQuery).Scan(&clubs).Error; err != nil {
return nil, err
}
// clubSemanticSearchSQL renders a SQL query string that selects club ids
// together with a RANK() window value, ordered by `embedding <=> '[...]'`
// and limited to 20 rows. The embedding is serialized once with
// search.FloatArrayToSql and interpolated twice: in the rank expression and
// in the ORDER BY clause.
//
// NOTE(review): `<=>` is presumably a vector-distance operator provided by
// the database's vector extension — confirm against the schema.
// NOTE(review): the "return results" / `return &types.SearchResult...` lines
// inside this body do not match the string return type and look like removed
// lines from the diff rendering — confirm they are not in the actual file.
//
//lint:ignore U1000 Ignore unused function temporarily for debugging
func clubSemanticSearchSQL(embedding []float32) string {
// Comma-separated textual form of the vector, wrapped in '[...]' below.
embeddingStr := search.FloatArrayToSql(embedding)

// return results
return &types.SearchResult[models.Club]{
Results: clubs,
}, nil
return fmt.Sprintf("SELECT id, RANK () OVER (ORDER BY embedding <=> '[%s]') AS rank FROM clubs ORDER BY embedding <=> '[%s]' LIMIT 20", embeddingStr, embeddingStr)
}

func SearchEvents(db *gorm.DB, query *types.EventSearchRequest) (*types.SearchResult[models.Event], error) {
Expand Down Expand Up @@ -107,7 +120,7 @@ func SearchEvents(db *gorm.DB, query *types.EventSearchRequest) (*types.SearchRe
}

finalQuery := fmt.Sprintf(
"SELECT * FROM events %s WHERE %s",
"SELECT * FROM events %s WHERE %s LIMIT 20",
strings.Join(innerJoins, " "),
strings.Join(whereClauses, " AND "))

Expand Down
12 changes: 6 additions & 6 deletions backend/search/embeddings.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,12 @@ type CreateEmbeddingResponseBody struct {

func UpsertClubEmbedding(db *gorm.DB, s *config.SearchSettings, club *models.Club) error {
embedding, err :=
createEmbedding(s, fmt.Sprintf("%s %s %s", club.Name, club.Preview, club.Description))
CreateEmbedding(s, fmt.Sprintf("%s %s %s", club.Name, club.Preview, club.Description))
if err != nil {
return err
}

embeddingStr := floatArrayToSql(embedding)
embeddingStr := FloatArrayToSql(embedding)

queryString := fmt.Sprintf(
"UPDATE clubs SET embedding = '[%s]' WHERE id = '%s'", embeddingStr, club.ID.String())
Expand All @@ -54,12 +54,12 @@ func UpsertClubEmbedding(db *gorm.DB, s *config.SearchSettings, club *models.Clu

func UpsertEventEmbedding(db *gorm.DB, s *config.SearchSettings, event *models.Event) error {
embedding, err :=
createEmbedding(s, fmt.Sprintf("%s %s %s", event.Name, event.Preview, event.Description))
CreateEmbedding(s, fmt.Sprintf("%s %s %s", event.Name, event.Preview, event.Description))
if err != nil {
return err
}

embeddingStr := floatArrayToSql(embedding)
embeddingStr := FloatArrayToSql(embedding)

queryString := fmt.Sprintf(
"UPDATE events SET embedding = '[%s]' WHERE id = '%s'", embeddingStr, event.ID.String())
Expand All @@ -74,7 +74,7 @@ func UpsertEventEmbedding(db *gorm.DB, s *config.SearchSettings, event *models.E
return nil
}

func floatArrayToSql(embedding []float32) string {
func FloatArrayToSql(embedding []float32) string {
embeddingArr := make([]string, len(embedding))
for i, f := range embedding {
embeddingArr[i] = strconv.FormatFloat(float64(f), 'f', -1, 32)
Expand All @@ -83,7 +83,7 @@ func floatArrayToSql(embedding []float32) string {
return embeddingStr
}

func createEmbedding(s *config.SearchSettings, item string) ([]float32, error) {
func CreateEmbedding(s *config.SearchSettings, item string) ([]float32, error) {
embeddingBody, err := json.Marshal(
CreateEmbeddingRequestBody{
Input: []string{item},
Expand Down

0 comments on commit 76e50c8

Please sign in to comment.