-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
♻ REFACTOR: Uses db proxy to dynamically define connection
- Loading branch information
Showing
26 changed files
with
178 additions
and
164 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,127 +1,134 @@ | ||
from contextvars import ContextVar | ||
import logging | ||
import os | ||
from dataclasses import dataclass | ||
from contextvars import ContextVar | ||
from pathlib import Path | ||
from unicodedata import name | ||
|
||
import peewee as pw | ||
|
||
from conf.local import get_database_info | ||
|
||
# ! Importing settings will create circular import | ||
# from conf.setup import settings | ||
from conf.local import ( | ||
read_database_info, | ||
write_config, | ||
read_config, | ||
append_database_list, | ||
) | ||
from models.base import database_proxy | ||
from samudra.conf.database.options import DatabaseEngine | ||
|
||
# TODO: Enforce requirements per database engine | ||
db_state_default = {"closed": None, "conn": None, "ctx": None, "transactions": None} | ||
db_state = ContextVar("db_state", default=db_state_default.copy()) | ||
|
||
|
||
# As settings | ||
# ENGINE = settings.get("database").get("engine", None) | ||
# DATABASE_NAME = settings.get("database").get("name", "samudra") | ||
class SQLiteConnectionState(pw._ConnectionState): | ||
"""Defaults to make SQLite DB async-compatible (according to FastAPI/Pydantic)""" | ||
|
||
db_state_default = {"closed": None, "conn": None, "ctx": None, "transactions": None} | ||
db_state = ContextVar("db_state", default=db_state_default.copy()) | ||
def __init__(self, **kwargs): | ||
super().__setattr__("_state", db_state) | ||
super().__init__(**kwargs) | ||
|
||
def __setattr__(self, name, value): | ||
self._state.get()[name] = value | ||
|
||
def get_database(db_name: str, engine: DatabaseEngine, **kwargs) -> pw.Database: | ||
""" | ||
Returns the connection class based on the engine. | ||
""" | ||
if engine is None or engine not in DatabaseEngine.__members__.values(): | ||
raise ValueError( | ||
"Please specify database engine in conf.toml. You entered {}. Valid values are: \n - {}".format( | ||
engine, "\n - ".join(DatabaseEngine.__members__.values()) | ||
) | ||
) | ||
if engine == DatabaseEngine.SQLite: | ||
return get_sqlite( | ||
folder=db_name, path=kwargs.pop("path"), new=kwargs.pop("new"), **kwargs | ||
def __getattr__(self, name): | ||
return self._state.get()[name] | ||
|
||
|
||
def set_active_database(name: str) -> None: | ||
"""Sets the database as the currently active database""" | ||
# Check if the name is already registered in .samudra/databases.toml | ||
db_obj = read_database_info(name=name) | ||
if db_obj is None: | ||
raise FileNotFoundError( | ||
f"The database name `{name}` is not found. Perhaps it is not created yet." | ||
) | ||
# Write the info in .samudra/config.toml | ||
write_config({"active": name}) | ||
# TODO ? Write relevant variables into .env for server? | ||
|
||
|
||
def new_database(name: str, engine: DatabaseEngine, path: str) -> pw.Database: | ||
"""Create and register a SQLite database or just register a database if not SQLite""" | ||
# ? Should this be a function parameters? | ||
DATABASE_HOST = os.getenv("DATABASE_HOST") | ||
DATABASE_PORT = int(os.getenv("DATABASE_PORT")) | ||
DATABASE_OPTIONS = os.getenv("DATABASE_OPTIONS") | ||
USERNAME = os.getenv("DATABASE_USERNAME") | ||
PASSWORD = os.getenv("DATABASE_PASSWORD") | ||
SSL_MODE = os.getenv("SSL_MODE") | ||
|
||
if engine is None or engine not in DatabaseEngine.__members__.values(): | ||
raise ValueError( | ||
"Invalid engine. You entered {}. Valid values are: \n - {}".format( | ||
engine, "\n - ".join(DatabaseEngine.__members__.values()) | ||
) | ||
) | ||
if engine == DatabaseEngine.SQLite: | ||
return create_sqlite(name=name, path=Path(path)) | ||
if engine == DatabaseEngine.MySQL: | ||
conn_str = f"mysql://{USERNAME}:{PASSWORD}@{DATABASE_HOST}:{DATABASE_PORT}/{db_name}?ssl-mode=REQUIRED" | ||
conn_str = f"mysql://{USERNAME}:{PASSWORD}@{DATABASE_HOST}:{DATABASE_PORT}/{name}?ssl-mode=REQUIRED" | ||
return_db = pw.MySQLDatabase(conn_str) | ||
logging.info(f"Connecting to {return_db.database} as {USERNAME}") | ||
if engine == DatabaseEngine.CockroachDB: | ||
from playhouse.cockroachdb import CockroachDatabase | ||
|
||
conn_str = f"postgresql://{USERNAME}:{PASSWORD}@{DATABASE_HOST}:{DATABASE_PORT}/{db_name}?sslmode=verify-full&options={DATABASE_OPTIONS}" | ||
return_db = CockroachDatabase(conn_str) | ||
append_database_list(name=name, path=conn_str, engine=engine) | ||
logging.info(f"Connecting to {return_db.database} as {USERNAME}") | ||
else: | ||
raise NotImplementedError("Invalid engine") | ||
return return_db | ||
|
||
|
||
def get_sqlite(folder: str, path: str, db_file: str = "samudra.db", new: bool = False): | ||
# Defaults to make it async-compatible (according to FastAPI/Pydantic) | ||
class PeeweeConnectionState(pw._ConnectionState): | ||
def __init__(self, **kwargs): | ||
super().__setattr__("_state", db_state) | ||
super().__init__(**kwargs) | ||
|
||
def __setattr__(self, name, value): | ||
self._state.get()[name] = value | ||
|
||
def __getattr__(self, name): | ||
return self._state.get()[name] | ||
|
||
# The DB connection object | ||
# ? Perlu ke test? | ||
# TODO Add Test | ||
base_path: Path = Path(path, folder) | ||
full_path: Path = Path(base_path, db_file) | ||
if new: | ||
# If a new database is to be created, check if the given path is occupied. | ||
# If not occupied, create the database. | ||
# If occupied, raise FileExistsError | ||
try: | ||
base_path.mkdir(parents=True) | ||
except FileExistsError: | ||
if full_path in [*base_path.iterdir()]: | ||
raise FileExistsError( | ||
f"A samudra database already exists in {full_path.resolve()}" | ||
) | ||
elif [*base_path.iterdir()] is [None]: | ||
print(f"Populating empty folder `{base_path.resolve()}` with {db_file}") | ||
else: | ||
raise FileExistsError( | ||
f"The path `{base_path.resolve()}` is already occupied with something else. Try passing `new=False` to access the database or consider creating new database in another folder." | ||
) | ||
def get_active_database() -> pw.Database: | ||
active_database_name = read_config(key="active") | ||
if not active_database_name: | ||
raise KeyError("No active database is defined") | ||
return get_database(name=active_database_name) | ||
|
||
|
||
def get_database(name: str) -> pw.Database: | ||
"""Returns the connection class based on the name.""" | ||
info = read_database_info(name) | ||
if info.get("engine") == DatabaseEngine.SQLite: | ||
return_db = pw.SqliteDatabase(info.get("path")) | ||
return_db._state = SQLiteConnectionState() | ||
return return_db | ||
if info.get("engine") == DatabaseEngine.MySQL: | ||
return pw.MySQLDatabase(info.get("path")) | ||
|
||
|
||
def create_sqlite( | ||
name: str, path: Path, filename: str = "samudra.db", description: str = "" | ||
) -> pw.SqliteDatabase: | ||
base_path: Path = Path(path, name) | ||
full_path: Path = Path(base_path, filename) | ||
# Check if the given path is occupied. | ||
# If not occupied, create the database. | ||
# If occupied, raise FileExistsError. | ||
try: | ||
base_path.mkdir(parents=True) | ||
except FileExistsError: | ||
if full_path in [*base_path.iterdir()]: | ||
raise FileExistsError( | ||
f"A samudra database already exists in {full_path.resolve()}" | ||
) | ||
elif [*base_path.iterdir()] is [None]: | ||
print(f"Populating empty folder `{base_path.resolve()}` with {filename}") | ||
else: | ||
raise FileExistsError( | ||
f"The path `{base_path.resolve()}` is already occupied with something else. Consider creating new database in another folder." | ||
) | ||
# Set up readme | ||
README = Path(base_path, "README.md") | ||
README.touch() | ||
with README.open(mode="w") as f: | ||
f.writelines( | ||
[ | ||
f"# {folder.title()}\n", | ||
f"# {name.title()}\n", | ||
"Created using [samudra](https://github.com/samudradev/samudra)", | ||
"", | ||
description, | ||
] | ||
) | ||
# else: | ||
# # Originally, this part was intended to get the path to database via its name in the local `~/.samudra/databases.toml` file when `new=False`. | ||
# # However, that would mean that the `path` parameter is rendered meaningless unless `new=True`. | ||
# # Perhaps this functionality should be outside the function with paths and folder parameters as its result | ||
# # which will be passed to `get_database/get_sqlite` with explicit `new=False` | ||
# db_obj = get_database_info(name=db_file) | ||
# if db_obj is None: | ||
# return FileNotFoundError( | ||
# f"The database name {db_file} is not found. Perhaps it is not created yet. Pass the key `new=True` if that's the case" | ||
# ) | ||
# base_path: Path = Path(db_obj["path"], folder=db_file) | ||
# full_path: Path = Path(base_path, db_file) | ||
return_db = pw.SqliteDatabase( | ||
full_path.resolve(), | ||
check_same_thread=False, | ||
) | ||
return_db._state = PeeweeConnectionState() | ||
return_db._state = SQLiteConnectionState() | ||
database_proxy.init(return_db) | ||
database_proxy.create_tables() | ||
logging.info(f"Connecting to {return_db.database}") | ||
append_database_list(name=name, path=full_path, engine="sqlite") | ||
return return_db |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,4 +4,3 @@ | |
class DatabaseEngine(str, enum.Enum): | ||
SQLite = "sqlite" | ||
MySQL = "mysql" | ||
CockroachDB = "cockroachdb" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,29 +1,46 @@ | ||
from collections import defaultdict | ||
from pathlib import Path | ||
from typing import List, Dict, Optional | ||
from typing import Dict, Optional, Union, Any | ||
|
||
import pytomlpp as toml | ||
|
||
HOME: Path = Path("~") | ||
|
||
dotconfig = Path(HOME, ".samudra").expanduser() | ||
db_dotconfig = Path(dotconfig, "databases.toml") | ||
database_list_file = Path(dotconfig, "databases.toml") | ||
config_file = Path(dotconfig, "config.toml") | ||
|
||
database_list_file.touch() | ||
config_file.touch() | ||
|
||
if not dotconfig.exists(): | ||
dotconfig.mkdir() | ||
|
||
|
||
def save_database(db_name: str, path: Path): | ||
def append_database_list(name: str, path: Union[Path, str], engine: str): | ||
databases: Dict[Dict] = defaultdict(dict) | ||
databases[db_name] = path.resolve().__str__() | ||
with open(db_dotconfig, mode="a") as f: | ||
databases[name] = {"path": path.resolve().__str__(), "engine": engine} | ||
with open(database_list_file, mode="a") as f: | ||
f.write(toml.dumps({"databases": databases})) | ||
|
||
|
||
def get_databases_config() -> dict: | ||
with open(db_dotconfig, mode="r") as f: | ||
return toml.load(f) | ||
def read_databases_list() -> dict: | ||
return toml.load(database_list_file) | ||
|
||
|
||
def read_database_info(name: str) -> Optional[dict]: | ||
return read_databases_list().get("databases").get(name, None) | ||
|
||
|
||
def read_config(key: Optional[str] = None) -> Any: | ||
configs = toml.load(config_file, mode="r") | ||
if key: | ||
return configs.get(key) | ||
return configs | ||
|
||
|
||
def get_database_info(name: str) -> Optional[dict]: | ||
return get_databases_config().get("databases").get(name, None) | ||
def write_config(content: Dict) -> None: | ||
configs = read_config(key=None) | ||
for key, val in zip(content.keys(), content.values()): | ||
configs[key] = val | ||
toml.dump(configs, config_file, mode="w") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.