-
Notifications
You must be signed in to change notification settings - Fork 3.7k
/
Copy pathsupabase_datastore.py
95 lines (81 loc) · 3.33 KB
/
supabase_datastore.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import os
from typing import Any, List
from datetime import datetime
from supabase import Client
from datastore.providers.pgvector_datastore import PGClient, PgVectorDataStore
from models.models import (
DocumentMetadataFilter,
)
SUPABASE_URL = os.environ.get("SUPABASE_URL")
assert SUPABASE_URL is not None, "SUPABASE_URL is not set"
SUPABASE_ANON_KEY = os.environ.get("SUPABASE_ANON_KEY")
# use service role key if you want this app to be able to bypass your Row Level Security policies
SUPABASE_SERVICE_ROLE_KEY = os.environ.get("SUPABASE_SERVICE_ROLE_KEY")
assert (
SUPABASE_ANON_KEY is not None or SUPABASE_SERVICE_ROLE_KEY is not None
), "SUPABASE_ANON_KEY or SUPABASE_SERVICE_ROLE_KEY must be set"
# class that implements the DataStore interface for Supabase Datastore provider
class SupabaseDataStore(PgVectorDataStore):
def create_db_client(self):
return SupabaseClient()
class SupabaseClient(PGClient):
def __init__(self) -> None:
super().__init__()
if not SUPABASE_SERVICE_ROLE_KEY:
self.client = Client(SUPABASE_URL, SUPABASE_ANON_KEY)
else:
self.client = Client(SUPABASE_URL, SUPABASE_SERVICE_ROLE_KEY)
async def upsert(self, table: str, json: dict[str, Any]):
"""
Takes in a list of documents and inserts them into the table.
"""
if "created_at" in json:
json["created_at"] = json["created_at"][0].isoformat()
self.client.table(table).upsert(json).execute()
async def rpc(self, function_name: str, params: dict[str, Any]):
"""
Calls a stored procedure in the database with the given parameters.
"""
if "in_start_date" in params:
params["in_start_date"] = params["in_start_date"].isoformat()
if "in_end_date" in params:
params["in_end_date"] = params["in_end_date"].isoformat()
response = self.client.rpc(function_name, params=params).execute()
return response.data
async def delete_like(self, table: str, column: str, pattern: str):
"""
Deletes rows in the table that match the pattern.
"""
self.client.table(table).delete().like(column, pattern).execute()
async def delete_in(self, table: str, column: str, ids: List[str]):
"""
Deletes rows in the table that match the ids.
"""
self.client.table(table).delete().in_(column, ids).execute()
async def delete_by_filters(self, table: str, filter: DocumentMetadataFilter):
"""
Deletes rows in the table that match the filter.
"""
builder = self.client.table(table).delete()
if filter.document_id:
builder = builder.eq(
"document_id",
filter.document_id,
)
if filter.source:
builder = builder.eq("source", filter.source)
if filter.source_id:
builder = builder.eq("source_id", filter.source_id)
if filter.author:
builder = builder.eq("author", filter.author)
if filter.start_date:
builder = builder.gte(
"created_at",
filter.start_date[0].isoformat(),
)
if filter.end_date:
builder = builder.lte(
"created_at",
filter.end_date[0].isoformat(),
)
builder.execute()