diff --git a/Dockerfile b/Dockerfile
index e4fef14..e0dd44a 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -29,4 +29,4 @@ RUN prisma generate
 EXPOSE 8000
 
 # Command to run the application
-CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
\ No newline at end of file
+CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
diff --git a/Dockerfile-airgapped b/Dockerfile-airgapped
new file mode 100644
index 0000000..fbb0330
--- /dev/null
+++ b/Dockerfile-airgapped
@@ -0,0 +1,55 @@
+##############################
+# Stage 1: builder (ONLINE)
+##############################
+FROM python:3.11-slim-bookworm AS builder
+ENV PYTHONDONTWRITEBYTECODE=1 PIP_NO_CACHE_DIR=1
+WORKDIR /app
+
+# minimal system deps
+RUN apt-get update && apt-get install -y --no-install-recommends ca-certificates \
+    && rm -rf /var/lib/apt/lists/*
+
+# isolated venv we'll copy into the runtime
+RUN python -m venv /venv
+ENV PATH="/venv/bin:$PATH"
+
+# install deps
+COPY requirements.txt .
+RUN pip install -r requirements.txt
+
+# bring in app code
+COPY . .
+
+# Bake Prisma engines into a deterministic cache dir, and pre-generate client
+# NOTE: we DO NOT set PRISMA_QUERY_ENGINE_BINARY here; we only cache binaries.
+ENV PRISMA_BINARY_CACHE_DIR=/opt/prisma-engines
+RUN mkdir -p "$PRISMA_BINARY_CACHE_DIR" \
+    && python -m prisma py fetch \
+    && python -m prisma generate --schema=prisma/schema.prisma
+
+##############################
+# Stage 2: runtime (AIR-GAPPED)
+##############################
+FROM python:3.11-slim-bookworm AS runtime
+ENV PYTHONDONTWRITEBYTECODE=1 PIP_DISABLE_PIP_VERSION_CHECK=1
+WORKDIR /app
+
+# copy venv with installed deps + generated client
+COPY --from=builder /venv /venv
+ENV PATH="/venv/bin:$PATH"
+
+# copy app code
+COPY . .
+
+# copy pre-fetched Prisma engines
+COPY --from=builder /opt/prisma-engines /opt/prisma-engines
+
+# Tell Prisma Python where the baked-in cache lives (no network at runtime)
+# Important: don't set PRISMA_QUERY_ENGINE_BINARY here; let Prisma pick from the cache.
+ENV PRISMA_BINARY_CACHE_DIR=/opt/prisma-engines \
+    PRISMA_HIDE_UPDATE_MESSAGE=true
+
+EXPOSE 8000
+ENV HOST=0.0.0.0 PORT=8000
+CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
+
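Before shipping the image into the isolated network, it may be worth smoke-testing that the Prisma engine cache actually made it into the runtime stage (for example via `docker run --rm --network=none <image> python check_engines.py`). The snippet below is a minimal sketch, not something this diff adds: the default cache path and the `query-engine` filename pattern are assumptions about how prisma-client-py lays out its binary cache.

```python
# check_engines.py -- hypothetical smoke test: confirm the Prisma query engine
# was baked into PRISMA_BINARY_CACHE_DIR so no download is needed at runtime.
import os
from pathlib import Path

cache_dir = Path(os.environ.get("PRISMA_BINARY_CACHE_DIR", "/opt/prisma-engines"))
engines = [p for p in cache_dir.rglob("*") if p.is_file() and "query-engine" in p.name]
if not engines:
    raise SystemExit(f"no query engine found under {cache_dir}; image is not air-gap ready")
print(f"found {len(engines)} engine binary(ies), e.g. {engines[0]}")
```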
diff --git a/k8s/deployment.yaml b/k8s/deployment.yaml
new file mode 100644
index 0000000..03380fd
--- /dev/null
+++ b/k8s/deployment.yaml
@@ -0,0 +1,53 @@
+apiVersion: apps/v1
+kind: Deployment
+metadata:
+  name: litellm-pgvector
+  namespace: {{ YOUR_NAMESPACE }}
+spec:
+  replicas: 1
+  selector:
+    matchLabels:
+      app: litellm-pgvector
+  template:
+    metadata:
+      labels:
+        app: litellm-pgvector
+    spec:
+      containers:
+        - name: api
+          image: {{ YOUR_IMAGE }}
+          imagePullPolicy: IfNotPresent
+          envFrom:
+            - secretRef:
+                name: litellm-pgvector-env
+          # do NOT run prisma generate at runtime
+          env:
+            - name: RUN_DB_PUSH
+              value: "false"
+          ports:
+            - name: http
+              containerPort: 8000
+          readinessProbe:
+            httpGet:
+              path: /health
+              port: http
+            initialDelaySeconds: 10
+            periodSeconds: 5
+            timeoutSeconds: 3
+            failureThreshold: 6
+          livenessProbe:
+            httpGet:
+              path: /health
+              port: http
+            initialDelaySeconds: 20
+            periodSeconds: 10
+            timeoutSeconds: 5
+            failureThreshold: 6
+          resources:
+            requests:
+              cpu: "250m"
+              memory: "512Mi"
+            limits:
+              cpu: "1"
+              memory: "1Gi"
+
diff --git a/k8s/secret.yaml b/k8s/secret.yaml
new file mode 100644
index 0000000..f8f5da9
--- /dev/null
+++ b/k8s/secret.yaml
@@ -0,0 +1,33 @@
+apiVersion: v1
+kind: Secret
+metadata:
+  name: litellm-pgvector-env
+  namespace: {{ YOUR_NAMESPACE }}
+type: Opaque
+stringData:
+  # Server
+  SERVER_API_KEY: "your-api-key-here"
+  # Database Configuration
+  DATABASE_URL: "postgresql://username:password@localhost:5432/vectordb?schema=public"
+
+  # API Configuration
+  OPENAI_API_KEY: "your-api-key-here"
+
+  # Server Configuration
+  HOST: "0.0.0.0"
+  PORT: "8000"
+
+  # LiteLLM Proxy Configuration
+  EMBEDDING__MODEL: "text-embedding-ada-002"
+  EMBEDDING__BASE_URL: "http://localhost:4000"
+  EMBEDDING__API_KEY: "sk-1234"
+  EMBEDDING__DIMENSIONS: "1536"
+
+  # Database Field Configuration (optional)
+  DB_FIELDS__ID_FIELD: "id"
+  DB_FIELDS__CONTENT_FIELD: "content"
+  DB_FIELDS__METADATA_FIELD: "metadata"
+  DB_FIELDS__EMBEDDING_FIELD: "embedding"
+  DB_FIELDS__VECTOR_STORE_ID_FIELD: "vector_store_id"
+  DB_FIELDS__CREATED_AT_FIELD: "created_at"
+
diff --git a/k8s/service.yaml b/k8s/service.yaml
new file mode 100644
index 0000000..40e351b
--- /dev/null
+++ b/k8s/service.yaml
@@ -0,0 +1,14 @@
+apiVersion: v1
+kind: Service
+metadata:
+  name: litellm-pgvector-svc
+  namespace: {{ YOUR_NAMESPACE }}
+spec:
+  type: LoadBalancer
+  selector:
+    app: litellm-pgvector
+  ports:
+    - name: http
+      port: 8000
+      targetPort: http
+
diff --git a/main.py b/main.py
index eee3bbc..eee7311 100644
--- a/main.py
+++ b/main.py
@@ -1,8 +1,7 @@
 import os
-import asyncio
 import time
 from typing import List, Optional
-from fastapi import FastAPI, HTTPException, Depends, Header
+from fastapi import FastAPI, HTTPException, Depends
 from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
 from fastapi.middleware.cors import CORSMiddleware
 from prisma import Prisma
@@ -19,7 +18,7 @@
     EmbeddingBatchCreateRequest,
     EmbeddingBatchCreateResponse,
     VectorStoreListResponse,
-    ContentChunk
+    ContentChunk,
 )
 from config import settings
 from embedding_service import embedding_service
@@ -29,10 +28,10 @@
 app = FastAPI(
     title="OpenAI Vector Stores API",
     description="OpenAI-compatible Vector Stores API using PGVector",
-    version="1.0.0"
+    version="1.0.0",
 )
 
-# CORS middleware
+# CORS
 app.add_middleware(
CORSMiddleware, allow_origins=["*"], @@ -41,14 +40,11 @@ allow_headers=["*"], ) -# Global Prisma client db = Prisma() - security = HTTPBearer() async def get_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)): - """Validate API key from Authorization header""" expected_key = settings.server_api_key if credentials.credentials != expected_key: raise HTTPException(status_code=401, detail="Invalid API key") @@ -57,73 +53,73 @@ async def get_api_key(credentials: HTTPAuthorizationCredentials = Depends(securi @app.on_event("startup") async def startup(): - """Connect to database on startup""" await db.connect() @app.on_event("shutdown") async def shutdown(): - """Disconnect from database on shutdown""" await db.disconnect() async def generate_query_embedding(query: str) -> List[float]: - """ - Generate an embedding for the query using LiteLLM - """ return await embedding_service.generate_embedding(query) +# ----------------------------- +# Vector Stores +# ----------------------------- @app.post("/v1/vector_stores", response_model=VectorStoreResponse) async def create_vector_store( - request: VectorStoreCreateRequest, - api_key: str = Depends(get_api_key) + request: VectorStoreCreateRequest, api_key: str = Depends(get_api_key) ): - """ - Create a new vector store. - """ try: - # Use raw SQL to insert the vector store with configurable table/field names - vector_store_table = settings.table_names["vector_stores"] - - result = await db.query_raw( + table = settings.table_names["vector_stores"] + + res = await db.query_raw( f""" - INSERT INTO {vector_store_table} (id, name, file_counts, status, usage_bytes, expires_after, metadata, created_at) - VALUES (gen_random_uuid(), $1, $2, $3, $4, $5, $6, NOW()) - RETURNING id, name, file_counts, status, usage_bytes, expires_after, expires_at, last_active_at, metadata, - EXTRACT(EPOCH FROM created_at)::bigint as created_at_timestamp + INSERT INTO {table} + (id, name, file_counts, status, usage_bytes, expires_after, metadata, created_at) + VALUES + (gen_random_uuid(), $1, + $2, $3, $4, $5, $6, NOW()) + RETURNING + id AS id, + name AS name, + file_counts AS file_counts, + status AS status, + usage_bytes AS usage_bytes, + expires_after AS expires_after, + EXTRACT(EPOCH FROM expires_at)::bigint AS expires_at_ts, + EXTRACT(EPOCH FROM last_active_at)::bigint AS last_active_at_ts, + metadata AS metadata, + EXTRACT(EPOCH FROM created_at)::bigint AS created_at_ts """, request.name, {"in_progress": 0, "completed": 0, "failed": 0, "cancelled": 0, "total": 0}, "completed", 0, request.expires_after, - request.metadata or {} + request.metadata or {}, ) - - if not result: + if not res: raise HTTPException(status_code=500, detail="Failed to create vector store") - - vector_store = result[0] - - # Convert to response format - created_at = int(vector_store["created_at_timestamp"]) - expires_at = int(vector_store["expires_at"].timestamp()) if vector_store.get("expires_at") else None - last_active_at = int(vector_store["last_active_at"].timestamp()) if vector_store.get("last_active_at") else None - + + row = res[0] return VectorStoreResponse( - id=vector_store["id"], - created_at=created_at, - name=vector_store["name"], - usage_bytes=vector_store["usage_bytes"] or 0, - file_counts=vector_store["file_counts"] or {"in_progress": 0, "completed": 0, "failed": 0, "cancelled": 0, "total": 0}, - status=vector_store["status"], - expires_after=vector_store["expires_after"], - expires_at=expires_at, - last_active_at=last_active_at, - metadata=vector_store["metadata"] + 
id=row["id"], + created_at=int(row["created_at_ts"]) if row.get("created_at_ts") is not None else int(time.time()), + name=row["name"], + usage_bytes=row["usage_bytes"] or 0, + file_counts=row["file_counts"] + or {"in_progress": 0, "completed": 0, "failed": 0, "cancelled": 0, "total": 0}, + status=row["status"], + expires_after=row["expires_after"], + expires_at=int(row["expires_at_ts"]) if row.get("expires_at_ts") is not None else None, + last_active_at=int(row["last_active_at_ts"]) if row.get("last_active_at_ts") is not None else None, + metadata=row["metadata"], ) - + except HTTPException: + raise except Exception as e: raise HTTPException(status_code=500, detail=f"Failed to create vector store: {str(e)}") @@ -133,390 +129,324 @@ async def list_vector_stores( limit: Optional[int] = 20, after: Optional[str] = None, before: Optional[str] = None, - api_key: str = Depends(get_api_key) + api_key: str = Depends(get_api_key), ): - """ - List vector stores with optional pagination. - """ try: - limit = min(limit or 20, 100) # Cap at 100 results - - vector_store_table = settings.table_names["vector_stores"] - - # Build base query - base_query = f""" - SELECT id, name, file_counts, status, usage_bytes, expires_after, expires_at, last_active_at, metadata, - EXTRACT(EPOCH FROM created_at)::bigint as created_at_timestamp - FROM {vector_store_table} + limit = min(limit or 20, 100) + table = settings.table_names["vector_stores"] + + base = f""" + SELECT + id AS id, + name AS name, + file_counts AS file_counts, + status AS status, + usage_bytes AS usage_bytes, + expires_after AS expires_after, + EXTRACT(EPOCH FROM expires_at)::bigint AS expires_at_ts, + EXTRACT(EPOCH FROM last_active_at)::bigint AS last_active_at_ts, + metadata AS metadata, + EXTRACT(EPOCH FROM created_at)::bigint AS created_at_ts + FROM {table} """ - - # Add pagination conditions - conditions = [] + + clauses = [] params = [] - param_count = 1 - + i = 1 if after: - conditions.append(f"id > ${param_count}") + clauses.append(f"id > ${i}") params.append(after) - param_count += 1 - + i += 1 if before: - conditions.append(f"id < ${param_count}") + clauses.append(f"id < ${i}") params.append(before) - param_count += 1 - - if conditions: - base_query += " WHERE " + " AND ".join(conditions) - - # Add ordering and limit - final_query = base_query + f" ORDER BY created_at DESC LIMIT {limit + 1}" - - # Execute query - results = await db.query_raw(final_query, *params) - - # Check if there are more results - has_more = len(results) > limit + i += 1 + if clauses: + base += " WHERE " + " AND ".join(clauses) + + sql = base + f" ORDER BY created_at DESC LIMIT {limit + 1}" + rows = await db.query_raw(sql, *params) + + has_more = len(rows) > limit if has_more: - results = results[:limit] # Remove extra result - - # Convert to response format - vector_stores = [] - for row in results: - created_at = int(row["created_at_timestamp"]) - expires_at = int(row["expires_at"].timestamp()) if row.get("expires_at") else None - last_active_at = int(row["last_active_at"].timestamp()) if row.get("last_active_at") else None - - vector_store = VectorStoreResponse( - id=row["id"], - created_at=created_at, - name=row["name"], - usage_bytes=row["usage_bytes"] or 0, - file_counts=row["file_counts"] or {"in_progress": 0, "completed": 0, "failed": 0, "cancelled": 0, "total": 0}, - status=row["status"], - expires_after=row["expires_after"], - expires_at=expires_at, - last_active_at=last_active_at, - metadata=row["metadata"] + rows = rows[:limit] + + data: 
List[VectorStoreResponse] = [] + for r in rows: + data.append( + VectorStoreResponse( + id=r["id"], + created_at=int(r["created_at_ts"]) if r.get("created_at_ts") is not None else None, + name=r["name"], + usage_bytes=r["usage_bytes"] or 0, + file_counts=r["file_counts"] + or {"in_progress": 0, "completed": 0, "failed": 0, "cancelled": 0, "total": 0}, + status=r["status"], + expires_after=r["expires_after"], + expires_at=int(r["expires_at_ts"]) if r.get("expires_at_ts") is not None else None, + last_active_at=int(r["last_active_at_ts"]) if r.get("last_active_at_ts") is not None else None, + metadata=r["metadata"], + ) ) - vector_stores.append(vector_store) - - # Determine first_id and last_id - first_id = vector_stores[0].id if vector_stores else None - last_id = vector_stores[-1].id if vector_stores else None - - return VectorStoreListResponse( - data=vector_stores, - first_id=first_id, - last_id=last_id, - has_more=has_more - ) - + + first_id = data[0].id if data else None + last_id = data[-1].id if data else None + + return VectorStoreListResponse(data=data, first_id=first_id, last_id=last_id, has_more=has_more) + except HTTPException: + raise except Exception as e: import traceback + traceback.print_exc() raise HTTPException(status_code=500, detail=f"Failed to list vector stores: {str(e)}") +# ----------------------------- +# Search +# ----------------------------- @app.post("/v1/vector_stores/{vector_store_id}/search", response_model=VectorStoreSearchResponse) @app.post("/vector_stores/{vector_store_id}/search", response_model=VectorStoreSearchResponse) async def search_vector_store( - vector_store_id: str, - request: VectorStoreSearchRequest, - api_key: str = Depends(get_api_key) + vector_store_id: str, request: VectorStoreSearchRequest, api_key: str = Depends(get_api_key) ): - """ - Search a vector store for similar content. 
- """ try: - # Check if vector store exists - vector_store_table = settings.table_names["vector_stores"] - vector_store_result = await db.query_raw( - f"SELECT id FROM {vector_store_table} WHERE id = $1", - vector_store_id - ) - if not vector_store_result: + vs_table = settings.table_names["vector_stores"] + found = await db.query_raw(f"SELECT id AS id FROM {vs_table} WHERE id = $1", vector_store_id) + if not found: raise HTTPException(status_code=404, detail="Vector store not found") - - # Generate embedding for query + query_embedding = await generate_query_embedding(request.query) - query_vector_str = "[" + ",".join(map(str, query_embedding)) + "]" - - # Build the raw SQL query for vector similarity search - limit = min(request.limit or 20, 100) # Cap at 100 results - - # Base query with vector similarity using cosine distance - # Use configurable field names + query_vec = "[" + ",".join(map(str, query_embedding)) + "]" + + limit = min(request.limit or 20, 100) fields = settings.db_fields - table_name = settings.table_names["embeddings"] - - # Build query with proper parameter placeholders for Prisma - param_count = 1 - query_params = [query_vector_str, vector_store_id] - - base_query = f""" - SELECT - {fields.id_field}, - {fields.content_field}, - {fields.metadata_field}, - ({fields.embedding_field} <=> ${param_count}::vector) as distance - FROM {table_name} - WHERE {fields.vector_store_id_field} = ${param_count + 1} + emb_table = settings.table_names["embeddings"] + + i = 1 + params = [query_vec, vector_store_id] + base = f""" + SELECT + {fields.id_field} AS id, + {fields.content_field} AS content, + {fields.metadata_field} AS metadata, + ({fields.embedding_field} <=> ${i}::vector) AS distance + FROM {emb_table} + WHERE {fields.vector_store_id_field} = ${i + 1} """ - param_count += 2 - - # Add metadata filters if provided - filter_conditions = [] - + i += 2 + + filters = [] if request.filters: - for key, value in request.filters.items(): - filter_conditions.append(f"{fields.metadata_field}->>${param_count} = ${param_count + 1}") - query_params.extend([key, str(value)]) - param_count += 2 - - if filter_conditions: - base_query += " AND " + " AND ".join(filter_conditions) - - # Add ordering and limit - final_query = base_query + f" ORDER BY distance ASC LIMIT {limit}" - - # Execute the query - results = await db.query_raw(final_query, *query_params) - - # Convert results to SearchResult objects - search_results = [] - for row in results: - # Convert distance to similarity score (1 - normalized_distance) - # Cosine distance ranges from 0 (identical) to 2 (opposite) - similarity_score = max(0, 1 - (row['distance'] / 2)) - - # Extract filename from metadata or use a default - metadata = row[fields.metadata_field] or {} - filename = metadata.get('filename', 'document.txt') - - content_chunks = [ContentChunk(type="text", text=row[fields.content_field])] - - result = SearchResult( - file_id=row[fields.id_field], - filename=filename, - score=similarity_score, - attributes=metadata if request.return_metadata else None, - content=content_chunks + for k, v in request.filters.items(): + filters.append(f"{fields.metadata_field}->>${i} = ${i + 1}") + params.extend([k, str(v)]) + i += 2 + if filters: + base += " AND " + " AND ".join(filters) + + sql = base + f" ORDER BY distance ASC LIMIT {limit}" + rows = await db.query_raw(sql, *params) + + results: List[SearchResult] = [] + for r in rows: + similarity = max(0, 1 - (r["distance"] / 2)) + metadata = r["metadata"] or {} + filename = 
metadata.get("filename", "document.txt") + content_chunks = [ContentChunk(type="text", text=r["content"])] + results.append( + SearchResult( + file_id=r["id"], + filename=filename, + score=similarity, + attributes=metadata if request.return_metadata else None, + content=content_chunks, + ) ) - search_results.append(result) - - return VectorStoreSearchResponse( - search_query=request.query, - data=search_results, - has_more=False, # TODO: Implement pagination - next_page=None - ) - + + return VectorStoreSearchResponse(search_query=request.query, data=results, has_more=False, next_page=None) except HTTPException: raise except Exception as e: import traceback + traceback.print_exc() raise HTTPException(status_code=500, detail=f"Search failed: {str(e)}") +# ----------------------------- +# Embeddings +# ----------------------------- @app.post("/v1/vector_stores/{vector_store_id}/embeddings", response_model=EmbeddingResponse) async def create_embedding( - vector_store_id: str, - request: EmbeddingCreateRequest, - api_key: str = Depends(get_api_key) + vector_store_id: str, request: EmbeddingCreateRequest, api_key: str = Depends(get_api_key) ): - """ - Add a single embedding to a vector store. - """ try: - # Check if vector store exists - vector_store_table = settings.table_names["vector_stores"] - vector_store_result = await db.query_raw( - f"SELECT id FROM {vector_store_table} WHERE id = $1", - vector_store_id - ) - if not vector_store_result: + vs_table = settings.table_names["vector_stores"] + ok = await db.query_raw(f"SELECT id AS id FROM {vs_table} WHERE id = $1", vector_store_id) + if not ok: raise HTTPException(status_code=404, detail="Vector store not found") - - # Convert embedding to vector string format - embedding_vector_str = "[" + ",".join(map(str, request.embedding)) + "]" - - # Insert embedding using configurable field names + + emb_vec = "[" + ",".join(map(str, request.embedding)) + "]" fields = settings.db_fields - table_name = settings.table_names["embeddings"] - - result = await db.query_raw( + emb_table = settings.table_names["embeddings"] + + res = await db.query_raw( f""" - INSERT INTO {table_name} ({fields.id_field}, {fields.vector_store_id_field}, {fields.content_field}, - {fields.embedding_field}, {fields.metadata_field}, {fields.created_at_field}) - VALUES (gen_random_uuid(), $1, $2, $3::vector, $4, NOW()) - RETURNING {fields.id_field}, {fields.vector_store_id_field}, {fields.content_field}, - {fields.metadata_field}, EXTRACT(EPOCH FROM {fields.created_at_field})::bigint as created_at_timestamp + INSERT INTO {emb_table} + ({fields.id_field}, {fields.vector_store_id_field}, {fields.content_field}, + {fields.embedding_field}, {fields.metadata_field}, {fields.created_at_field}) + VALUES + (gen_random_uuid(), $1, $2, $3::vector, $4, NOW()) + RETURNING + {fields.id_field} AS id, + {fields.vector_store_id_field} AS vector_store_id, + {fields.content_field} AS content, + {fields.metadata_field} AS metadata, + EXTRACT(EPOCH FROM {fields.created_at_field})::bigint AS created_at_ts """, vector_store_id, request.content, - embedding_vector_str, - request.metadata or {} + emb_vec, + request.metadata or {}, ) - - if not result: + if not res: raise HTTPException(status_code=500, detail="Failed to create embedding") - - embedding = result[0] - - # Update vector store statistics + + row = res[0] + + # single assignment to file_counts: +1 to completed & total await db.query_raw( f""" - UPDATE {vector_store_table} + UPDATE {vs_table} SET file_counts = jsonb_set( - COALESCE(file_counts, 
'{{"in_progress": 0, "completed": 0, "failed": 0, "cancelled": 0, "total": 0}}'::jsonb), - '{{completed}}', - (COALESCE(file_counts->>'completed', '0')::int + 1)::text::jsonb - ), - file_counts = jsonb_set( - file_counts, - '{{total}}', - (COALESCE(file_counts->>'total', '0')::int + 1)::text::jsonb - ), + jsonb_set( + COALESCE(file_counts, '{{"in_progress": 0, "completed": 0, "failed": 0, "cancelled": 0, "total": 0}}'::jsonb), + '{{completed}}', + to_jsonb(COALESCE(file_counts->>'completed', '0')::int + 1) + ), + '{{total}}', + to_jsonb(COALESCE(file_counts->>'total', '0')::int + 1) + ), usage_bytes = COALESCE(usage_bytes, 0) + LENGTH($2), last_active_at = NOW() WHERE id = $1 """, vector_store_id, - request.content + request.content, ) - + return EmbeddingResponse( - id=embedding[fields.id_field], - vector_store_id=embedding[fields.vector_store_id_field], - content=embedding[fields.content_field], - metadata=embedding[fields.metadata_field], - created_at=int(embedding["created_at_timestamp"]) + id=row["id"], + vector_store_id=row["vector_store_id"], + content=row["content"], + metadata=row["metadata"], + created_at=int(row["created_at_ts"]) if row.get("created_at_ts") is not None else int(time.time()), ) - except HTTPException: raise except Exception as e: import traceback + traceback.print_exc() raise HTTPException(status_code=500, detail=f"Failed to create embedding: {str(e)}") @app.post("/v1/vector_stores/{vector_store_id}/embeddings/batch", response_model=EmbeddingBatchCreateResponse) async def create_embeddings_batch( - vector_store_id: str, - request: EmbeddingBatchCreateRequest, - api_key: str = Depends(get_api_key) + vector_store_id: str, request: EmbeddingBatchCreateRequest, api_key: str = Depends(get_api_key) ): - """ - Add multiple embeddings to a vector store in batch. 
- """ try: - # Check if vector store exists - vector_store_table = settings.table_names["vector_stores"] - vector_store_result = await db.query_raw( - f"SELECT id FROM {vector_store_table} WHERE id = $1", - vector_store_id - ) - if not vector_store_result: + vs_table = settings.table_names["vector_stores"] + ok = await db.query_raw(f"SELECT id AS id FROM {vs_table} WHERE id = $1", vector_store_id) + if not ok: raise HTTPException(status_code=404, detail="Vector store not found") - + if not request.embeddings: raise HTTPException(status_code=400, detail="No embeddings provided") - - # Prepare batch insert + fields = settings.db_fields - table_name = settings.table_names["embeddings"] - - # Build VALUES clause for batch insert - values_clauses = [] + emb_table = settings.table_names["embeddings"] + + vals = [] params = [] - param_count = 1 - - for embedding_req in request.embeddings: - embedding_vector_str = "[" + ",".join(map(str, embedding_req.embedding)) + "]" - values_clauses.append(f"(gen_random_uuid(), ${param_count}, ${param_count + 1}, ${param_count + 2}::vector, ${param_count + 3}, NOW())") - params.extend([ - vector_store_id, - embedding_req.content, - embedding_vector_str, - embedding_req.metadata or {} - ]) - param_count += 4 - - values_clause = ", ".join(values_clauses) - - # Execute batch insert - result = await db.query_raw( + i = 1 + for e in request.embeddings: + vec = "[" + ",".join(map(str, e.embedding)) + "]" + vals.append(f"(gen_random_uuid(), ${i}, ${i+1}, ${i+2}::vector, ${i+3}, NOW())") + params.extend([vector_store_id, e.content, vec, e.metadata or {}]) + i += 4 + + res = await db.query_raw( f""" - INSERT INTO {table_name} ({fields.id_field}, {fields.vector_store_id_field}, {fields.content_field}, - {fields.embedding_field}, {fields.metadata_field}, {fields.created_at_field}) - VALUES {values_clause} - RETURNING {fields.id_field}, {fields.vector_store_id_field}, {fields.content_field}, - {fields.metadata_field}, EXTRACT(EPOCH FROM {fields.created_at_field})::bigint as created_at_timestamp + INSERT INTO {emb_table} + ({fields.id_field}, {fields.vector_store_id_field}, {fields.content_field}, + {fields.embedding_field}, {fields.metadata_field}, {fields.created_at_field}) + VALUES + {", ".join(vals)} + RETURNING + {fields.id_field} AS id, + {fields.vector_store_id_field} AS vector_store_id, + {fields.content_field} AS content, + {fields.metadata_field} AS metadata, + EXTRACT(EPOCH FROM {fields.created_at_field})::bigint AS created_at_ts """, - *params + *params, ) - - if not result: + if not res: raise HTTPException(status_code=500, detail="Failed to create embeddings") - - # Calculate total content length for usage bytes update - total_content_length = sum(len(emb.content) for emb in request.embeddings) - - # Update vector store statistics + + total_len = sum(len(e.content) for e in request.embeddings) + batch_size = len(request.embeddings) + await db.query_raw( f""" - UPDATE {vector_store_table} + UPDATE {vs_table} SET file_counts = jsonb_set( - COALESCE(file_counts, '{{"in_progress": 0, "completed": 0, "failed": 0, "cancelled": 0, "total": 0}}'::jsonb), - '{{completed}}', - (COALESCE(file_counts->>'completed', '0')::int + $2)::text::jsonb - ), - file_counts = jsonb_set( - file_counts, - '{{total}}', - (COALESCE(file_counts->>'total', '0')::int + $2)::text::jsonb - ), + jsonb_set( + COALESCE(file_counts, '{{"in_progress": 0, "completed": 0, "failed": 0, "cancelled": 0, "total": 0}}'::jsonb), + '{{completed}}', + to_jsonb(COALESCE(file_counts->>'completed', '0')::int + 
$2) + ), + '{{total}}', + to_jsonb(COALESCE(file_counts->>'total', '0')::int + $2) + ), usage_bytes = COALESCE(usage_bytes, 0) + $3, last_active_at = NOW() WHERE id = $1 """, vector_store_id, - len(request.embeddings), - total_content_length - ) - - # Convert results to response format - embeddings = [] - for row in result: - embeddings.append(EmbeddingResponse( - id=row[fields.id_field], - vector_store_id=row[fields.vector_store_id_field], - content=row[fields.content_field], - metadata=row[fields.metadata_field], - created_at=int(row["created_at_timestamp"]) - )) - - return EmbeddingBatchCreateResponse( - data=embeddings, - created=int(time.time()) + batch_size, + total_len, ) - + + out = [ + EmbeddingResponse( + id=r["id"], + vector_store_id=r["vector_store_id"], + content=r["content"], + metadata=r["metadata"], + created_at=int(r["created_at_ts"]) if r.get("created_at_ts") is not None else int(time.time()), + ) + for r in res + ] + return EmbeddingBatchCreateResponse(data=out, created=int(time.time())) except HTTPException: raise except Exception as e: import traceback + traceback.print_exc() raise HTTPException(status_code=500, detail=f"Failed to create embeddings batch: {str(e)}") @app.get("/health") async def health_check(): - """Health check endpoint""" return {"status": "healthy", "timestamp": int(time.time())} if __name__ == "__main__": import uvicorn - uvicorn.run("main:app", host=settings.host, port=settings.port, reload=True) \ No newline at end of file + + uvicorn.run("main:app", host=settings.host, port=settings.port, reload=True) +
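One detail worth calling out in the search handler: pgvector's `<=>` operator returns cosine distance in the range [0, 2], and the handler maps it to a similarity score in [0, 1] with `max(0, 1 - distance / 2)`. A tiny standalone sketch of that mapping (it does not import the app, just restates the arithmetic):

```python
# Cosine distance (pgvector `<=>`) lies in [0, 2]: 0 = same direction,
# 1 = orthogonal, 2 = opposite. The API reports a similarity score in [0, 1].
def distance_to_score(distance: float) -> float:
    return max(0.0, 1.0 - distance / 2.0)

assert distance_to_score(0.0) == 1.0   # identical vectors
assert distance_to_score(1.0) == 0.5   # unrelated vectors
assert distance_to_score(2.0) == 0.0   # opposite vectors
```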
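The trickiest SQL in the diff is the single nested `jsonb_set` that bumps both `completed` and `total` in one assignment; the removed version assigned `file_counts` twice in the same UPDATE, which Postgres rejects as multiple assignments to one column. In plain Python terms the counter update amounts to the following (a sketch of the semantics only, not code the service runs):

```python
# Semantics of the nested jsonb_set counter bump, modelled on a plain dict.
def bump_counts(file_counts: dict | None, n: int) -> dict:
    counts = dict(file_counts or {"in_progress": 0, "completed": 0, "failed": 0, "cancelled": 0, "total": 0})
    counts["completed"] = int(counts.get("completed", 0)) + n
    counts["total"] = int(counts.get("total", 0)) + n
    return counts

assert bump_counts(None, 1) == {"in_progress": 0, "completed": 1, "failed": 0, "cancelled": 0, "total": 1}
assert bump_counts({"completed": 2, "total": 5}, 3) == {"completed": 5, "total": 8}
```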
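For an end-to-end check once the Deployment and Service are up, a client call against the search endpoint looks roughly like the sketch below. It assumes the service has been port-forwarded locally (for example `kubectl port-forward svc/litellm-pgvector-svc 8000:8000`); the vector store id, API key, and query text are placeholders, not values from this change.

```python
# Hypothetical client call against a locally port-forwarded service.
import requests

resp = requests.post(
    "http://localhost:8000/v1/vector_stores/<vector-store-id>/search",
    headers={"Authorization": "Bearer your-api-key-here"},
    json={"query": "how do I rotate credentials?", "limit": 5, "return_metadata": True},
    timeout=30,
)
resp.raise_for_status()
for hit in resp.json()["data"]:
    print(round(hit["score"], 3), hit["filename"])
```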