Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Club Vector Embeddings #180

Merged
merged 15 commits into from
Feb 17, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions backend/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,15 @@ require (
github.com/mitchellh/mapstructure v1.5.0
github.com/spf13/viper v1.18.2
github.com/swaggo/swag v1.16.3
golang.org/x/crypto v0.19.0
golang.org/x/text v0.14.0
gorm.io/driver/postgres v1.5.6
gorm.io/gorm v1.25.7
gorm.io/driver/postgres v1.5.4
gorm.io/gorm v1.25.6
)

require (
github.com/awnumar/memcall v0.2.0 // indirect
github.com/awnumar/memguard v0.22.4 // indirect
github.com/h2non/gock v1.2.0 // indirect
github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 // indirect
)

require (
Expand Down
5 changes: 5 additions & 0 deletions backend/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/h2non/gock v1.2.0 h1:K6ol8rfrRkUOefooBC8elXoaNGYkpp7y2qcxGG6BzUE=
github.com/h2non/gock v1.2.0/go.mod h1:tNhoxHYW2W42cYkYb1WqzdbYIieALC99kpYr7rH/BQk=
github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 h1:2VTzZjLZBgl62/EtslCrtky5vbi9dd7HrQPQIx6wqiw=
github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542/go.mod h1:Ow0tF8D4Kplbc8s8sSb3V2oUCygFHVp8gC3Dn6U4MNI=
github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
github.com/huandu/go-assert v1.1.6 h1:oaAfYxq9KNDi9qswn/6aE0EydfxSa+tWZC1KabNitYs=
Expand Down Expand Up @@ -97,6 +101,7 @@ github.com/mcnijman/go-emailaddress v1.1.1 h1:AGhgVDG3tCDaL0/Vc6erlPQjDuDN3dAT7r
github.com/mcnijman/go-emailaddress v1.1.1/go.mod h1:5whZrhS8Xp5LxO8zOD35BC+b76kROtsh+dPomeRt/II=
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32/go.mod h1:9wM+0iRr9ahx58uYLpLIr5fm8diHn0JbqRycJi6w0Ms=
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno=
github.com/pelletier/go-toml/v2 v2.1.0 h1:FnwAJ4oYMvbT/34k9zzHuZNrhlz48GB3/s6at6/MHO4=
github.com/pelletier/go-toml/v2 v2.1.0/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc=
Expand Down
4 changes: 4 additions & 0 deletions backend/src/errors/club.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ var (
StatusCode: fiber.StatusInternalServerError,
Message: "failed to get admin ids",
}
FailedToVectorizeClub = Error{
StatusCode: fiber.StatusInternalServerError,
Message: "failed to vectorize club",
}
FailedToGetClubFollowers = Error{
StatusCode: fiber.StatusInternalServerError,
Message: "failed to get club followers",
Expand Down
22 changes: 22 additions & 0 deletions backend/src/errors/search.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package errors

import "github.com/gofiber/fiber/v2"

// Errors reported by the search layer: embedding creation and the Pinecone
// upsert, delete and search operations. Every one of them carries an HTTP 500
// (internal server error) status code.
var (
	// FailedToCreateEmbedding reports that no embedding could be created from a string.
	FailedToCreateEmbedding = Error{
		Message:    "failed to create embedding from string",
		StatusCode: fiber.StatusInternalServerError,
	}

	// FailedToUpsertToPinecone reports a failed upsert to Pinecone.
	FailedToUpsertToPinecone = Error{
		Message:    "failed to upsert to pinecone",
		StatusCode: fiber.StatusInternalServerError,
	}

	// FailedToDeleteToPinecone reports a failed delete from Pinecone.
	FailedToDeleteToPinecone = Error{
		Message:    "failed to delete from pinecone",
		StatusCode: fiber.StatusInternalServerError,
	}

	// FailedToSearchToPinecone reports a failed search on Pinecone.
	FailedToSearchToPinecone = Error{
		Message:    "failed to search on pinecone",
		StatusCode: fiber.StatusInternalServerError,
	}
)
2 changes: 1 addition & 1 deletion backend/src/middleware/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func (m *MiddlewareService) Authorize(requiredPermissions ...types.Permission) f
return errors.FailedToParseAccessToken.FiberError(c)
}

userPermissions := types.GetPermissions(models.UserRole(*role))
userPermissions := models.GetPermissions(models.UserRole(*role))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

GetPermissions/permissions in general is not a model, possibly not a type--can be put in the auth folder instead


for _, requiredPermission := range requiredPermissions {
if !slices.Contains(userPermissions, requiredPermission) {
Expand Down
12 changes: 12 additions & 0 deletions backend/src/models/club.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,15 @@ func (c *Club) AfterDelete(tx *gorm.DB) (err error) {
tx.Model(&c).Update("num_members", c.NumMembers-1)
return
}

// SearchId returns the club's ID rendered as a string.
func (c *Club) SearchId() string {
	id := c.ID
	return id.String()
}

// Namespace returns the fixed namespace name used for clubs: "clubs".
func (c *Club) Namespace() string {
	const clubNamespace = "clubs"
	return clubNamespace
}

func (c *Club) EmbeddingString() string {
return c.Name + " " + c.Name + " " + c.Name + " " + c.Name + " " + c.Description
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be a formatted string; fmt.Sprintf

}
37 changes: 36 additions & 1 deletion backend/src/models/user.go
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see above

Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package models

import "github.com/google/uuid"
import (
"github.com/GenerateNU/sac/backend/src/types"
"github.com/google/uuid"
)

// UserRole names a user's role. It is the key type of the rolePermissions map.
type UserRole string

Expand Down Expand Up @@ -89,3 +92,35 @@ type UpdatePasswordRequestBody struct {
// CreateUserTagsBody holds a list of tag UUIDs, read from the JSON key "tags".
// The validate tag marks the field as required.
type CreateUserTagsBody struct {
Tags []uuid.UUID `json:"tags" validate:"required"`
}

// rolePermissions lists the permissions granted to each UserRole.
//
// Super is granted the read/create/write/delete permissions listed below plus
// the ReadAll permissions; Student is granted only the plain Read permissions.
// Each permission appears once per role.
var rolePermissions = map[UserRole][]types.Permission{
	Super: {
		types.UserRead, types.UserWrite, types.UserDelete,
		types.TagRead, types.TagCreate, types.TagWrite, types.TagDelete,
		types.ClubRead, types.ClubCreate, types.ClubWrite, types.ClubDelete,
		types.PointOfContactRead, types.PointOfContactCreate, types.PointOfContactWrite, types.PointOfContactDelete,
		types.CommentRead, types.CommentCreate, types.CommentWrite, types.CommentDelete,
		types.EventRead, types.EventCreate, types.EventWrite, types.EventDelete,
		types.ContactRead, types.ContactCreate, types.ContactWrite, types.ContactDelete,
		types.CategoryRead, types.CategoryCreate, types.CategoryWrite, types.CategoryDelete,
		types.NotificationRead, types.NotificationCreate, types.NotificationWrite, types.NotificationDelete,
		// ReadAll permissions. UserReadAll is listed here only; it was
		// previously also repeated next to UserRead.
		types.UserReadAll, types.TagReadAll, types.ClubReadAll, types.PointOfContactReadAll, types.CommentReadAll,
		types.EventReadAll, types.ContactReadAll, types.CategoryReadAll, types.NotificationReadAll,
	},
	Student: {
		types.UserRead,
		types.TagRead,
		types.ClubRead,
		types.PointOfContactRead,
		types.CommentRead,
		types.EventRead,
		types.ContactRead,
		types.CategoryRead,
		types.NotificationRead,
	},
}

// GetPermissions returns the permissions registered for role in
// rolePermissions. A role without an entry yields a nil slice. The result is
// the map's own slice value, not a copy.
func GetPermissions(role UserRole) []types.Permission {
	permissions := rolePermissions[role]
	return permissions
}
76 changes: 76 additions & 0 deletions backend/src/search/README.md
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit but solution should be on a new line

Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Jargon

### Embeddings
**Problem**: We have arbitrary-dimension data, such as descriptions for clubs, or searches for
events. Given a piece of this arbitrary-dimension data (search, club desc.) we want to find
other arbitrary-dimension data that is similar to it; think 2 club descriptions where both clubs
are acapella groups, 2 search queries that are both effectively looking for professional
fraternities, etc. **Solution**: Transform the arbitrary-dimension data to fixed-dimension data,
say, a vector of floating-point numbers that is *n*-elements large. Make the transformation in
such a way that similar arbitrary-dimension pieces of data will also have similar
fixed-dimension data, i.e., vectors that are close together (think Euclidean distance). **How do
we do this transformation**: Train a machine learning model on large amounts of text, and then
use the model to make vectors. **So what's an embedding?** Formally, when we
refer to the embedding of a particular object, we refer to the vector created by feeding that
object through the machine-learning model.

This is arguably the most complex/unintuitive part of understanding search, so here are some extra
resources:
- [What are embeddings?](https://www.cloudflare.com/learning/ai/what-are-embeddings/)
- [fastai book - Chapters 10 and 12 are both about natural language processing](https://github.com/fastai/fastbook)
- [Vector Embeddings for Developers: The Basics](https://www.pinecone.io/learn/vector-embeddings-for-developers/)

### OpenAI API
**Problem**: We need a machine learning model to create the embeddings. **Solution**: Use
OpenAI's API to create the embeddings for us; we send text over a REST API and we get back a
vector that represents that text's embedding.

### PineconeDB
**Problem**: We've created a bunch of embeddings for our club descriptions (or event
descriptions, etc.); we now need a place to store them and a way to search through them (with an
embedding for a search query). **Solution**: PineconeDB is a vector database that allows us to
upload our embeddings and then query them by giving a vector to find similar ones to.

# How to create searchable objects for fun and fame and profit

```golang
package search

// in backend/src/search/searchable.go
type Searchable interface {
SearchId() string
Namespace() string
EmbeddingString() string
}

// in backend/src/search/pinecone.go
type PineconeClientInterface interface {
Upsert(item Searchable) *errors.Error
Delete(item Searchable) *errors.Error
Search(item Searchable, topK int) ([]string, *errors.Error)
}
```

1. Implement the `Searchable` interface on whatever model you want to make searchable.
`Searchable` requires 3 methods:
- `SearchId()`: This should return a unique id that can be used to store a model entry's
embedding (if you want to store it at all) in PineconeDB. In practice, this should be the
entry's UUID.
- `Namespace()`: Namespaces are to PineconeDB what tables are to PostgreSQL. Searching in
one namespace will only retrieve vectors in that namespace. In practice, this should be
unique to the model type (e.g., `Club`, `Event`, etc.)
- `EmbeddingString()`: This should return the string you want to feed into the OpenAI API
and create an embedding for. In practice, create a string with the fields you think will
affect the embedding all appended together, and/or try repeating a field multiple times in
the string to see if that gives a better search experience.
2. Use a `PineconeClientInterface` and call `Upsert` with your searchable object to send it to the
database, and `Delete` with your searchable object to delete it from the database. Upserts
should be done on creation and updating of a model entry, and deletes should be done on
deleting of a model entry. In practice, a `PineconeClientInterface` should be passed in to
the various services in `backend/server.go`, similar to how `*gorm.DB` and
`*validator.Validator` instances are passed in.

# How to search for fun and fame and profit

TODO: (probably create a searchable object that just uses namespace and embeddingstring, pass to
pineconeclient search)
73 changes: 73 additions & 0 deletions backend/src/search/openai.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package search

import (
"bytes"
"encoding/json"
"fmt"
"github.com/GenerateNU/sac/backend/src/errors"
"github.com/garrettladley/mattress"
"net/http"
"os"
)

// OpenAiClientInterface is the embedding-creation contract implemented by
// OpenAiClient.
type OpenAiClientInterface interface {
// CreateEmbedding turns payload into an embedding vector, or returns a
// non-nil *errors.Error on failure.
CreateEmbedding(payload string) ([]float32, *errors.Error)
}

// OpenAiClient creates embeddings through the OpenAI HTTP API.
type OpenAiClient struct {
// apiKey is the OpenAI API key, held in a mattress.Secret and exposed only
// when the Authorization header is built.
apiKey *mattress.Secret[string]
}

func NewOpenAiClient() *OpenAiClient {
apiKey, _ := mattress.NewSecret(os.Getenv("SAC_OPENAI_API_KEY"))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should be in config as well, i assume @garrettladley


return &OpenAiClient{apiKey: apiKey}
}

func (c *OpenAiClient) CreateEmbedding(payload string) ([]float32, *errors.Error) {
apiKey := c.apiKey.Expose()

embeddingBody, _ := json.Marshal(map[string]interface{}{
"input": payload,
"model": "text-embedding-ada-002",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would prefer this to be a config option,

})
requestBody := bytes.NewBuffer(embeddingBody)

req, err := http.NewRequest("POST", fmt.Sprintf("https://api.openai.com/v1/embeddings"), requestBody)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these lines below are repeated over in your guys code, would prefer this to be generalized and exported into a utility function

if err != nil {
return nil, &errors.FailedToCreateEmbedding
}

req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey))
req.Header.Set("content-type", "application/json")

resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, &errors.FailedToCreateEmbedding
}

defer resp.Body.Close()

if err != nil {
return nil, &errors.FailedToCreateEmbedding
}

type ResponseBody struct {
Data []struct {
Embedding []float32 `json:"embedding"`
} `json:"data"`
}

embeddingResultBody := ResponseBody{}
err = json.NewDecoder(resp.Body).Decode(&embeddingResultBody)
if err != nil {
return nil, &errors.FailedToCreateEmbedding
}

if len(embeddingResultBody.Data) < 1 {
return nil, &errors.FailedToCreateEmbedding
}

return embeddingResultBody.Data[0].Embedding, nil

}
Loading
Loading