From a0cc4fb57475c9c43492ca0228e5eaf50758ad42 Mon Sep 17 00:00:00 2001 From: Michael Brennan Date: Tue, 18 Jun 2024 17:01:48 -0400 Subject: [PATCH] stash --- backend/config/search.go | 19 +++++++- backend/config/settings.go | 9 +++- backend/search/embeddings.go | 91 ++++++++++++++++++++++++++++++++++++ 3 files changed, 116 insertions(+), 3 deletions(-) create mode 100644 backend/search/embeddings.go diff --git a/backend/config/search.go b/backend/config/search.go index 27f40eaf..5d236482 100644 --- a/backend/config/search.go +++ b/backend/config/search.go @@ -1,5 +1,22 @@ package config +import m "github.com/garrettladley/mattress" + type SearchSettings struct { - URI string `env:"URI"` + OpenAIApiKey *m.Secret[string] +} + +type intermediateSearchSettings struct { + OpenAIApiKey string `env:"OPENAI_API_KEY` +} + +func (i *intermediateSearchSettings) into() (*SearchSettings, error) { + openAiApiKey, err := m.NewSecret(i.OpenAIApiKey) + if err != nil { + return nil, err + } + + return &SearchSettings{ + OpenAIApiKey: openAiApiKey, + }, nil } diff --git a/backend/config/settings.go b/backend/config/settings.go index b257ba4e..1b3b7ffe 100644 --- a/backend/config/settings.go +++ b/backend/config/settings.go @@ -32,7 +32,7 @@ type intermediateSettings struct { Calendar intermediateCalendarSettings `envPrefix:"SAC_CALENDAR_"` Google intermediateGoogleOAuthSettings `envPrefix:"SAC_GOOGLE_OAUTH_"` Microsft intermediateMicrosoftOAuthSetting `envPrefix:"SAC_MICROSOFT_OAUTH_"` - Search SearchSettings `envPrefix:"SAC_SEARCH_"` + Search intermediateSearchSettings `envPrefix:"SAC_SEARCH_"` } func (i *intermediateSettings) into() (*Settings, error) { @@ -91,6 +91,11 @@ func (i *intermediateSettings) into() (*Settings, error) { return nil, err } + search, err := i.Search.into() + if err != nil { + return nil, err + } + return &Settings{ Application: i.Application, DBCache: *dbCache, @@ -104,7 +109,7 @@ func (i *intermediateSettings) into() (*Settings, error) { Microsft: *microsoft, AWS: *aws, Resend: *resend, - Search: i.Search, + Search: *search, }, }, nil } diff --git a/backend/search/embeddings.go b/backend/search/embeddings.go new file mode 100644 index 00000000..f9d828a1 --- /dev/null +++ b/backend/search/embeddings.go @@ -0,0 +1,91 @@ +package search + +import ( + "bytes" + "fmt" + "net/http" + "strconv" + "strings" + + "github.com/GenerateNU/sac/backend/config" + "github.com/GenerateNU/sac/backend/entities/models" + "github.com/GenerateNU/sac/backend/utilities" + "github.com/goccy/go-json" + "github.com/gofiber/fiber/v2" + "gorm.io/gorm" +) + +type CreateEmbeddingRequestBody struct { + Input []string `json:"input"` + Model string `json:"model"` +} + +type CreateEmbeddingResponseBody struct { + Data [][]float32 `json:"data"` +} + +func UpsertClubEmbedding(db *gorm.DB, s *config.Settings, club *models.Club) error { + embedding, err := + createEmbedding(s, fmt.Sprintf("%s %s %s", club.Name, club.Preview, club.Description)) + if err != nil { + return err + } + + embeddingArr := make([]string, len(embedding)) + for i, f := range embedding { + embeddingArr[i] = strconv.FormatFloat(float64(f), 'f', -1, 32) + } + embeddingStr := strings.Join(embeddingArr, ", ") + + queryString := fmt.Sprintf( + "UPDATE clubs SET embedding = '[%s]' WHERE id = '%s'", embeddingStr, club.ID.String()) + + if err := db.Exec(queryString).Error; err != nil { + return err + } + + return nil +} + +func createEmbedding(s *config.Settings, item string) ([]float32, error) { + embeddingBody, err := json.Marshal( + CreateEmbeddingRequestBody{ + Input: []string{item}, + Model: "text-embedding-ada-002", + }) + if err != nil { + return nil, err + } + + req, err := http.NewRequest(fiber.MethodPost, + "https://api.openai.com/v1/embeddings", + bytes.NewBuffer(embeddingBody)) + if err != nil { + return nil, err + } + + req = utilities.ApplyModifiers(req, + utilities.Authorization(s.Integrations.Search.OpenAIApiKey.Expose()), + utilities.JSON(), + ) + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, err + } + + defer resp.Body.Close() + + var embeddingResultBody CreateEmbeddingResponseBody + + err = json.NewDecoder(resp.Body).Decode(&embeddingResultBody) + if err != nil { + return nil, err + } + + if len(embeddingResultBody.Data) < 1 { + return nil, err + } + + return embeddingResultBody.Data[0], nil +}