-
Notifications
You must be signed in to change notification settings - Fork 8
/
ent_test.go
77 lines (67 loc) · 2.12 KB
/
ent_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
package pgvector_test
import (
"context"
"reflect"
"testing"
"entgo.io/ent/dialect/sql"
_ "github.com/lib/pq"
"github.com/pgvector/pgvector-go"
"github.com/pgvector/pgvector-go/ent"
)
// TestEnt exercises the ent integration end to end against a local Postgres
// database at postgres://localhost/pgvector_go_test. It requires that the
// pgvector extension can be created there.
//
// The test drops and recreates the items table, so row IDs start at 1. It then
// inserts three rows and orders them by L2 distance (`<->`) to [1, 1, 1].
// It expects the rows back in ID order 1, 3, 2, and checks that the vector,
// half-vector and sparse-vector columns round-trip.
func TestEnt(t *testing.T) {
	ctx := context.Background()

	client, err := ent.Open("postgres", "postgres://localhost/pgvector_go_test?sslmode=disable")
	if err != nil {
		t.Fatalf("opening ent client: %v", err)
	}
	defer client.Close()

	if _, err := client.ExecContext(ctx, "CREATE EXTENSION IF NOT EXISTS vector"); err != nil {
		t.Fatalf("creating vector extension: %v", err)
	}
	// Start from an empty table so the generated IDs are predictable.
	if _, err := client.ExecContext(ctx, "DROP TABLE IF EXISTS items"); err != nil {
		t.Fatalf("dropping items table: %v", err)
	}
	if err := client.Schema.Create(ctx); err != nil {
		t.Fatalf("creating schema: %v", err)
	}

	embedding := pgvector.NewVector([]float32{1, 1, 1})
	halfEmbedding := pgvector.NewHalfVector([]float32{1, 1, 1})
	binaryEmbedding := "000"
	sparseEmbedding := pgvector.NewSparseVector([]float32{1, 1, 1})
	_, err = client.Item.Create().SetEmbedding(embedding).SetHalfEmbedding(halfEmbedding).SetBinaryEmbedding(binaryEmbedding).SetSparseEmbedding(sparseEmbedding).Save(ctx)
	if err != nil {
		t.Fatalf("creating item: %v", err)
	}

	_, err = client.Item.CreateBulk(
		client.Item.Create().SetEmbedding(pgvector.NewVector([]float32{2, 2, 2})).SetHalfEmbedding(pgvector.NewHalfVector([]float32{2, 2, 2})).SetBinaryEmbedding("101").SetSparseEmbedding(pgvector.NewSparseVector([]float32{2, 2, 2})),
		client.Item.Create().SetEmbedding(pgvector.NewVector([]float32{1, 1, 2})).SetHalfEmbedding(pgvector.NewHalfVector([]float32{1, 1, 2})).SetBinaryEmbedding("111").SetSparseEmbedding(pgvector.NewSparseVector([]float32{1, 1, 2})),
	).Save(ctx)
	if err != nil {
		t.Fatalf("bulk-creating items: %v", err)
	}

	// Nearest neighbours of [1,1,1] by L2 distance: the identical row (ID 1),
	// then [1,1,2] (ID 3, distance 1), then [2,2,2] (ID 2, distance sqrt(3)).
	items, err := client.Item.
		Query().
		Order(func(s *sql.Selector) {
			s.OrderExpr(sql.ExprP("embedding <-> $1", embedding))
		}).
		Limit(5).
		All(ctx)
	if err != nil {
		t.Fatalf("querying items: %v", err)
	}
	// Guard the indexing below so a short result is reported, not a panic.
	if len(items) != 3 {
		t.Fatalf("got %d items, want 3", len(items))
	}

	if items[0].ID != 1 || items[1].ID != 3 || items[2].ID != 2 {
		t.Errorf("unexpected order: got IDs %v, %v, %v; want 1, 3, 2", items[0].ID, items[1].ID, items[2].ID)
	}
	want := []float32{1, 1, 2}
	if got := items[1].Embedding.Slice(); !reflect.DeepEqual(got, want) {
		t.Errorf("Embedding = %v, want %v", got, want)
	}
	if got := items[1].HalfEmbedding.Slice(); !reflect.DeepEqual(got, want) {
		t.Errorf("HalfEmbedding = %v, want %v", got, want)
	}
	if got := items[1].SparseEmbedding.Slice(); !reflect.DeepEqual(got, want) {
		t.Errorf("SparseEmbedding = %v, want %v", got, want)
	}
}