|
9 | 9 | import oracledb |
10 | 10 | from langchain_community.vectorstores import oraclevs as LangchainVS |
11 | 11 |
|
12 | | -import server.api.core.databases as core_databases |
13 | 12 | import server.api.core.settings as core_settings |
| 13 | +from server.bootstrap.bootstrap import DATABASE_OBJECTS |
14 | 14 |
|
15 | 15 | from common.schema import ( |
16 | 16 | Database, |
@@ -38,6 +38,56 @@ def __init__(self, status_code: int, detail: str): |
38 | 38 | super().__init__(detail) |
39 | 39 |
|
40 | 40 |
|
| 41 | +class ExistsDatabaseError(ValueError): |
| 42 | + """Raised when the database already exist.""" |
| 43 | + |
| 44 | + |
| 45 | +class UnknownDatabaseError(ValueError): |
| 46 | + """Raised when the database doesn't exist.""" |
| 47 | + |
| 48 | + |
| 49 | +##################################################### |
| 50 | +# CRUD Functions |
| 51 | +##################################################### |
| 52 | +def create(database: Database) -> Database: |
| 53 | + """Create a new Database definition""" |
| 54 | + |
| 55 | + try: |
| 56 | + _ = get(name=database.name) |
| 57 | + raise ExistsDatabaseError(f"Database: {database.name} already exists") |
| 58 | + except UnknownDatabaseError: |
| 59 | + pass |
| 60 | + |
| 61 | + if any(not getattr(database, key) for key in ("user", "password", "dsn")): |
| 62 | + raise ValueError("'user', 'password', and 'dsn' are required") |
| 63 | + |
| 64 | + DATABASE_OBJECTS.append(database) |
| 65 | + return get(name=database.name) |
| 66 | + |
| 67 | + |
| 68 | +def get(name: Optional[DatabaseNameType] = None) -> Union[list[Database], None]: |
| 69 | + """ |
| 70 | + Return all Database objects if `name` is not provided, |
| 71 | + or the single Database if `name` is provided. |
| 72 | + If a `name` is provided and not found, raise exception |
| 73 | + """ |
| 74 | + database_objects = DATABASE_OBJECTS |
| 75 | + |
| 76 | + logger.debug("%i databases are defined", len(database_objects)) |
| 77 | + database_filtered = [db for db in database_objects if (name is None or db.name == name)] |
| 78 | + logger.debug("%i databases after filtering", len(database_filtered)) |
| 79 | + |
| 80 | + if name and not database_filtered: |
| 81 | + raise UnknownDatabaseError(f"{name} not found") |
| 82 | + |
| 83 | + return database_filtered |
| 84 | + |
| 85 | + |
| 86 | +def delete(name: DatabaseNameType) -> None: |
| 87 | + """Remove database from database objects""" |
| 88 | + DATABASE_OBJECTS[:] = [d for d in DATABASE_OBJECTS if d.name != name] |
| 89 | + |
| 90 | + |
41 | 91 | ##################################################### |
42 | 92 | # Protected Functions |
43 | 93 | ##################################################### |
@@ -231,7 +281,7 @@ def get_databases( |
231 | 281 | db_name: Optional[DatabaseNameType] = None, validate: bool = False |
232 | 282 | ) -> Union[list[Database], Database, None]: |
233 | 283 | """Return list of Database Objects""" |
234 | | - databases = core_databases.get_database(db_name) |
| 284 | + databases = get(db_name) |
235 | 285 | if validate: |
236 | 286 | for db in databases: |
237 | 287 | try: |
|
0 commit comments