-
Notifications
You must be signed in to change notification settings - Fork 120
/
lancedb_functions.py
55 lines (48 loc) · 1.67 KB
/
lancedb_functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from typing import Union
from indexify.functions_sdk.indexify_functions import IndexifyFunction
from common_objects import ImageWithEmbedding, TextChunk
import lancedb
from lancedb.pydantic import LanceModel, Vector
from images import lance_image
class ImageEmbeddingTable(LanceModel):
    """Row schema used for the LanceDB table that stores image embeddings.

    Each row pairs one embedding vector with the raw bytes of the image it
    was computed from and the page number associated with that image.
    """

    # Fixed-size embedding column; rows must carry exactly 512 components.
    vector: Vector(512)
    # Raw image payload stored alongside its embedding.
    image_bytes: bytes
    # Page number recorded for the image.
    page_number: int
class TextEmbeddingTable(LanceModel):
    """Row schema used for the LanceDB table that stores text embeddings.

    Each row pairs one embedding vector with the text it was computed from
    and the page number associated with that text.
    """

    # Fixed-size embedding column; rows must carry exactly 384 components.
    vector: Vector(384)
    # The text content stored alongside its embedding.
    text: str
    # Page number recorded for the text.
    page_number: int
class LanceDBWriter(IndexifyFunction):
    """Indexify function that writes embeddings into LanceDB tables.

    On construction it connects to the LanceDB database at
    ``"vectordb.lance"`` and obtains two tables: ``"text_embeddings"``
    (schema ``TextEmbeddingTable``) and ``"image_embeddings"`` (schema
    ``ImageEmbeddingTable``). Each call to :meth:`run` appends exactly one
    row to the table matching the input's type.
    """

    name = "lancedb_writer"
    image = lance_image

    def __init__(self):
        """Connect to the database and create (or reuse) both tables."""
        super().__init__()
        self._client = lancedb.connect("vectordb.lance")
        # exist_ok=True lets repeated instantiations reuse a table that was
        # already created instead of failing on the second run.
        self._text_table = self._client.create_table(
            "text_embeddings", schema=TextEmbeddingTable, exist_ok=True
        )
        self._clip_table = self._client.create_table(
            "image_embeddings", schema=ImageEmbeddingTable, exist_ok=True
        )

    def run(self, input: Union[ImageWithEmbedding, TextChunk]) -> bool:
        """Append one embedding row to the table matching ``input``'s type.

        Args:
            input: An ``ImageWithEmbedding`` (written to the image table) or
                a ``TextChunk`` (written to the text table). Subclasses of
                either type are accepted.

        Returns:
            ``True`` once the row has been added.

        Raises:
            TypeError: If ``input`` is neither an ``ImageWithEmbedding`` nor
                a ``TextChunk``. Previously such inputs were dropped
                silently while still reporting success.
        """
        # isinstance (rather than an exact type comparison) so that
        # subclasses of the two supported types are routed correctly.
        if isinstance(input, ImageWithEmbedding):
            self._clip_table.add(
                [
                    ImageEmbeddingTable(
                        vector=input.embedding,
                        image_bytes=input.image_bytes,
                        page_number=input.page_number,
                    )
                ]
            )
        elif isinstance(input, TextChunk):
            # NOTE(review): TextChunk exposes `embeddings` (plural) while
            # ImageWithEmbedding exposes `embedding` - confirm against
            # common_objects that both attribute names are intended.
            self._text_table.add(
                [
                    TextEmbeddingTable(
                        vector=input.embeddings,
                        text=input.chunk,
                        page_number=input.page_number,
                    )
                ]
            )
        else:
            # Nothing was written, so returning True here would be a lie.
            raise TypeError(
                "LanceDBWriter.run expected ImageWithEmbedding or TextChunk, "
                f"got {type(input).__name__}"
            )
        return True