Skip to content

Commit 8cff360

Browse files
authored
Fix for max_chunk_size; restructure databases (server); update mxbai to 512 (#332)
1 parent 0590309 commit 8cff360

File tree

10 files changed

+630
-541
lines changed

10 files changed

+630
-541
lines changed

src/client/content/config/tabs/models.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -200,16 +200,19 @@ def _render_model_specific_config(model: dict, model_type: str, provider_models:
200200
value=max_tokens,
201201
)
202202
else:
203-
output_vector_size = next(
204-
(m.get("output_vector_size", 8191) for m in provider_models if m.get("key") == model["id"]),
205-
model.get("output_vector_size", 8191),
206-
)
203+
# First try to get max_chunk_size from the model, then fall back to output_vector_size from provider
204+
max_chunk_size = model.get("max_chunk_size")
205+
if max_chunk_size is None:
206+
max_chunk_size = next(
207+
(m.get("max_chunk_size", 8192) for m in provider_models if m.get("key") == model["id"]),
208+
8192,
209+
)
207210
model["max_chunk_size"] = st.number_input(
208211
"Max Chunk Size:",
209212
help=help_text.help_dict["chunk_size"],
210213
min_value=0,
211214
key="add_model_max_chunk_size",
212-
value=output_vector_size,
215+
value=max_chunk_size,
213216
)
214217

215218
return model

src/server/api/core/databases.py

Lines changed: 0 additions & 58 deletions
This file was deleted.

src/server/api/utils/databases.py

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
import oracledb
1010
from langchain_community.vectorstores import oraclevs as LangchainVS
1111

12-
import server.api.core.databases as core_databases
1312
import server.api.core.settings as core_settings
13+
from server.bootstrap.bootstrap import DATABASE_OBJECTS
1414

1515
from common.schema import (
1616
Database,
@@ -38,6 +38,56 @@ def __init__(self, status_code: int, detail: str):
3838
super().__init__(detail)
3939

4040

41+
class ExistsDatabaseError(ValueError):
42+
"""Raised when the database already exist."""
43+
44+
45+
class UnknownDatabaseError(ValueError):
46+
"""Raised when the database doesn't exist."""
47+
48+
49+
#####################################################
50+
# CRUD Functions
51+
#####################################################
52+
def create(database: Database) -> Database:
53+
"""Create a new Database definition"""
54+
55+
try:
56+
_ = get(name=database.name)
57+
raise ExistsDatabaseError(f"Database: {database.name} already exists")
58+
except UnknownDatabaseError:
59+
pass
60+
61+
if any(not getattr(database, key) for key in ("user", "password", "dsn")):
62+
raise ValueError("'user', 'password', and 'dsn' are required")
63+
64+
DATABASE_OBJECTS.append(database)
65+
return get(name=database.name)
66+
67+
68+
def get(name: Optional[DatabaseNameType] = None) -> Union[list[Database], None]:
69+
"""
70+
Return all Database objects if `name` is not provided,
71+
or the single Database if `name` is provided.
72+
If a `name` is provided and not found, raise exception
73+
"""
74+
database_objects = DATABASE_OBJECTS
75+
76+
logger.debug("%i databases are defined", len(database_objects))
77+
database_filtered = [db for db in database_objects if (name is None or db.name == name)]
78+
logger.debug("%i databases after filtering", len(database_filtered))
79+
80+
if name and not database_filtered:
81+
raise UnknownDatabaseError(f"{name} not found")
82+
83+
return database_filtered
84+
85+
86+
def delete(name: DatabaseNameType) -> None:
87+
"""Remove database from database objects"""
88+
DATABASE_OBJECTS[:] = [d for d in DATABASE_OBJECTS if d.name != name]
89+
90+
4191
#####################################################
4292
# Protected Functions
4393
#####################################################
@@ -231,7 +281,7 @@ def get_databases(
231281
db_name: Optional[DatabaseNameType] = None, validate: bool = False
232282
) -> Union[list[Database], Database, None]:
233283
"""Return list of Database Objects"""
234-
databases = core_databases.get_database(db_name)
284+
databases = get(db_name)
235285
if validate:
236286
for db in databases:
237287
try:

src/server/api/utils/models.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -89,18 +89,18 @@ def get(
8989
def update(payload: schema.Model) -> schema.Model:
9090
"""Update an existing Model definition"""
9191

92-
(model_update,) = get(model_provider=payload.provider, model_id=payload.id)
93-
if payload.enabled and model_update.api_base and not is_url_accessible(model_update.api_base)[0]:
94-
model_update.enabled = False
95-
raise URLUnreachableError("Model: Unable to update. API URL is inaccessible.")
92+
# Get the existing model from MODEL_OBJECTS (this is a reference to the object in the list)
93+
(model_existing,) = get(model_provider=payload.provider, model_id=payload.id)
9694

97-
for key, value in payload:
98-
if hasattr(model_update, key):
99-
setattr(model_update, key, value)
100-
else:
101-
raise InvalidModelError(f"Model: Invalid setting - {key}.")
95+
# Check URL accessibility if enabling the model
96+
if payload.enabled and payload.api_base and not is_url_accessible(payload.api_base)[0]:
97+
model_existing.enabled = False
98+
raise URLUnreachableError("Model: Unable to update. API URL is inaccessible.")
10299

103-
return model_update
100+
# Update all fields from payload in place
101+
for key, value in payload.model_dump().items():
102+
setattr(model_existing, key, value)
103+
return model_existing
104104

105105

106106
def delete(model_provider: schema.ModelProviderType, model_id: schema.ModelIdType) -> None:

src/server/api/v1/databases.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
"""
55

66
from fastapi import APIRouter, HTTPException
7+
import oracledb
78

89
import server.api.utils.databases as utils_databases
910

@@ -15,7 +16,7 @@
1516
# Validate the DEFAULT Databases
1617
try:
1718
_ = utils_databases.get_databases(db_name="DEFAULT", validate=True)
18-
except Exception:
19+
except (ValueError, PermissionError, ConnectionError, LookupError, oracledb.DatabaseError):
1920
pass
2021

2122
auth = APIRouter()

src/server/bootstrap/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def _get_base_models_list() -> list[dict]:
144144
"provider": "ollama",
145145
"api_base": os.environ.get("ON_PREM_OLLAMA_URL", default="http://127.0.0.1:11434"),
146146
"api_key": "",
147-
"max_chunk_size": 8192,
147+
"max_chunk_size": 512,
148148
},
149149
]
150150

tests/server/integration/test_endpoints_models.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,52 @@ def test_models_update_edge_cases(self, client, auth_headers):
321321
)
322322
assert response.status_code == 404
323323

324+
def test_models_update_max_chunk_size(self, client, auth_headers):
325+
"""Test updating max_chunk_size for embedding models (regression test)"""
326+
# Create an embedding model with default max_chunk_size
327+
payload = {
328+
"id": "test-embed-chunk-size",
329+
"enabled": False,
330+
"type": "embed",
331+
"provider": "test_provider",
332+
"api_base": "http://127.0.0.1:11434",
333+
"max_chunk_size": 8192,
334+
}
335+
336+
# Create the model
337+
response = client.post("/v1/models", headers=auth_headers["valid_auth"], json=payload)
338+
assert response.status_code == 201
339+
assert response.json()["max_chunk_size"] == 8192
340+
341+
# Update the max_chunk_size to 512
342+
payload["max_chunk_size"] = 512
343+
response = client.patch(
344+
f"/v1/models/{payload['provider']}/{payload['id']}", headers=auth_headers["valid_auth"], json=payload
345+
)
346+
assert response.status_code == 200
347+
assert response.json()["max_chunk_size"] == 512
348+
349+
# Verify the update persists by fetching the model again
350+
response = client.get(f"/v1/models/{payload['provider']}/{payload['id']}", headers=auth_headers["valid_auth"])
351+
assert response.status_code == 200
352+
assert response.json()["max_chunk_size"] == 512
353+
354+
# Update to a different value to ensure it's not cached
355+
payload["max_chunk_size"] = 1024
356+
response = client.patch(
357+
f"/v1/models/{payload['provider']}/{payload['id']}", headers=auth_headers["valid_auth"], json=payload
358+
)
359+
assert response.status_code == 200
360+
assert response.json()["max_chunk_size"] == 1024
361+
362+
# Verify again
363+
response = client.get(f"/v1/models/{payload['provider']}/{payload['id']}", headers=auth_headers["valid_auth"])
364+
assert response.status_code == 200
365+
assert response.json()["max_chunk_size"] == 1024
366+
367+
# Clean up
368+
client.delete(f"/v1/models/{payload['provider']}/{payload['id']}", headers=auth_headers["valid_auth"])
369+
324370
def test_models_response_schema_validation(self, client, auth_headers):
325371
"""Test response schema validation for all endpoints"""
326372
# Test /v1/models response schema

0 commit comments

Comments
 (0)