From ed0c95a97b909b82d44f9cd659c37b10eb78aa31 Mon Sep 17 00:00:00 2001 From: carl-andersson Date: Tue, 4 Feb 2025 09:44:56 +0100 Subject: [PATCH] Feature/SK-1367 | Refactor the creation of the databaseconnection (#807) --- fedn/network/api/shared.py | 75 +++----- fedn/network/combiner/shared.py | 63 +++--- fedn/network/storage/dbconnection.py | 180 ++++++++++++++++++ .../storage/statestore/stores/client_store.py | 19 +- .../statestore/stores/combiner_store.py | 15 +- .../storage/statestore/stores/model_store.py | 23 ++- .../statestore/stores/package_store.py | 23 ++- .../statestore/stores/prediction_store.py | 13 +- .../storage/statestore/stores/round_store.py | 15 +- .../statestore/stores/session_store.py | 14 +- .../storage/statestore/stores/status_store.py | 13 +- .../storage/statestore/stores/store.py | 28 +-- .../statestore/stores/validation_store.py | 13 +- 13 files changed, 315 insertions(+), 179 deletions(-) create mode 100644 fedn/network/storage/dbconnection.py diff --git a/fedn/network/api/shared.py b/fedn/network/api/shared.py index 5cd397566..b1969bd6f 100644 --- a/fedn/network/api/shared.py +++ b/fedn/network/api/shared.py @@ -1,70 +1,39 @@ import os -import pymongo -from pymongo.database import Database from werkzeug.security import safe_join -from fedn.common.config import get_modelstorage_config, get_network_config, get_statestore_config +from fedn.common.config import get_modelstorage_config, get_network_config from fedn.network.controller.control import Control +from fedn.network.storage.dbconnection import DatabaseConnection from fedn.network.storage.s3.base import RepositoryBase from fedn.network.storage.s3.miniorepository import MINIORepository from fedn.network.storage.s3.repository import Repository -from fedn.network.storage.statestore.stores.client_store import ClientStore, MongoDBClientStore, SQLClientStore -from fedn.network.storage.statestore.stores.combiner_store import CombinerStore, MongoDBCombinerStore, SQLCombinerStore -from fedn.network.storage.statestore.stores.model_store import MongoDBModelStore, SQLModelStore -from fedn.network.storage.statestore.stores.package_store import MongoDBPackageStore, PackageStore, SQLPackageStore -from fedn.network.storage.statestore.stores.prediction_store import MongoDBPredictionStore, PredictionStore, SQLPredictionStore -from fedn.network.storage.statestore.stores.round_store import MongoDBRoundStore, RoundStore, SQLRoundStore -from fedn.network.storage.statestore.stores.session_store import MongoDBSessionStore, SQLSessionStore +from fedn.network.storage.statestore.stores.client_store import ClientStore +from fedn.network.storage.statestore.stores.combiner_store import CombinerStore +from fedn.network.storage.statestore.stores.model_store import ModelStore +from fedn.network.storage.statestore.stores.package_store import PackageStore +from fedn.network.storage.statestore.stores.prediction_store import PredictionStore +from fedn.network.storage.statestore.stores.round_store import RoundStore +from fedn.network.storage.statestore.stores.session_store import SessionStore from fedn.network.storage.statestore.stores.shared import EntityNotFound -from fedn.network.storage.statestore.stores.status_store import MongoDBStatusStore, SQLStatusStore, StatusStore -from fedn.network.storage.statestore.stores.store import MyAbstractBase, engine -from fedn.network.storage.statestore.stores.validation_store import MongoDBValidationStore, SQLValidationStore, ValidationStore +from fedn.network.storage.statestore.stores.status_store import StatusStore +from fedn.network.storage.statestore.stores.validation_store import ValidationStore from fedn.utils.checksum import sha -statestore_config = get_statestore_config() modelstorage_config = get_modelstorage_config() network_id = get_network_config() -client_store: ClientStore = None -validation_store: ValidationStore = None -combiner_store: CombinerStore = None -status_store: StatusStore = None -prediction_store: PredictionStore = None -round_store: RoundStore = None -package_store: PackageStore = None -model_store: SQLModelStore = None -session_store: SQLSessionStore = None - -if statestore_config["type"] == "MongoDB": - mc = pymongo.MongoClient(**statestore_config["mongo_config"]) - mc.server_info() - mdb: Database = mc[network_id] - - client_store = MongoDBClientStore(mdb, "network.clients") - validation_store = MongoDBValidationStore(mdb, "control.validations") - combiner_store = MongoDBCombinerStore(mdb, "network.combiners") - status_store = MongoDBStatusStore(mdb, "control.status") - prediction_store = MongoDBPredictionStore(mdb, "control.predictions") - round_store = MongoDBRoundStore(mdb, "control.rounds") - package_store = MongoDBPackageStore(mdb, "control.packages") - model_store = MongoDBModelStore(mdb, "control.models") - session_store = MongoDBSessionStore(mdb, "control.sessions") - -elif statestore_config["type"] in ["SQLite", "PostgreSQL"]: - MyAbstractBase.metadata.create_all(engine, checkfirst=True) - - client_store = SQLClientStore() - validation_store = SQLValidationStore() - combiner_store = SQLCombinerStore() - status_store = SQLStatusStore() - prediction_store = SQLPredictionStore() - round_store = SQLRoundStore() - package_store = SQLPackageStore() - model_store = SQLModelStore() - session_store = SQLSessionStore() -else: - raise ValueError("Unknown statestore type") +# TODO: Refactor all access to the stores to use the DatabaseConnection +stores = DatabaseConnection().get_stores() +session_store: SessionStore = stores.session_store +model_store: ModelStore = stores.model_store +round_store: RoundStore = stores.round_store +package_store: PackageStore = stores.package_store +combiner_store: CombinerStore = stores.combiner_store +client_store: ClientStore = stores.client_store +status_store: StatusStore = stores.status_store +validation_store: ValidationStore = stores.validation_store +prediction_store: PredictionStore = stores.prediction_store repository = Repository(modelstorage_config["storage_config"]) diff --git a/fedn/network/combiner/shared.py b/fedn/network/combiner/shared.py index a0aa66441..d770cc070 100644 --- a/fedn/network/combiner/shared.py +++ b/fedn/network/combiner/shared.py @@ -1,50 +1,31 @@ -import pymongo -from pymongo.database import Database - -from fedn.common.config import get_modelstorage_config, get_network_config, get_statestore_config +from fedn.common.config import get_modelstorage_config from fedn.network.combiner.modelservice import ModelService +from fedn.network.storage.dbconnection import DatabaseConnection from fedn.network.storage.s3.repository import Repository -from fedn.network.storage.statestore.stores.client_store import ClientStore, MongoDBClientStore, SQLClientStore -from fedn.network.storage.statestore.stores.combiner_store import CombinerStore, MongoDBCombinerStore, SQLCombinerStore -from fedn.network.storage.statestore.stores.prediction_store import MongoDBPredictionStore, PredictionStore, SQLPredictionStore -from fedn.network.storage.statestore.stores.round_store import MongoDBRoundStore, RoundStore, SQLRoundStore -from fedn.network.storage.statestore.stores.status_store import MongoDBStatusStore, SQLStatusStore, StatusStore -from fedn.network.storage.statestore.stores.store import MyAbstractBase, engine -from fedn.network.storage.statestore.stores.validation_store import MongoDBValidationStore, SQLValidationStore, ValidationStore +from fedn.network.storage.statestore.stores.client_store import ClientStore +from fedn.network.storage.statestore.stores.combiner_store import CombinerStore +from fedn.network.storage.statestore.stores.model_store import ModelStore +from fedn.network.storage.statestore.stores.package_store import PackageStore +from fedn.network.storage.statestore.stores.prediction_store import PredictionStore +from fedn.network.storage.statestore.stores.round_store import RoundStore +from fedn.network.storage.statestore.stores.session_store import SessionStore +from fedn.network.storage.statestore.stores.status_store import StatusStore +from fedn.network.storage.statestore.stores.validation_store import ValidationStore -statestore_config = get_statestore_config() modelstorage_config = get_modelstorage_config() -network_id = get_network_config() - -client_store: ClientStore = None -validation_store: ValidationStore = None -combiner_store: CombinerStore = None -status_store: StatusStore = None -prediction_store: PredictionStore = None -round_store: RoundStore = None - -if statestore_config["type"] == "MongoDB": - mc = pymongo.MongoClient(**statestore_config["mongo_config"]) - mc.server_info() - mdb: Database = mc[network_id] - client_store = MongoDBClientStore(mdb, "network.clients") - validation_store = MongoDBValidationStore(mdb, "control.validations") - combiner_store = MongoDBCombinerStore(mdb, "network.combiners") - status_store = MongoDBStatusStore(mdb, "control.status") - prediction_store = MongoDBPredictionStore(mdb, "control.predictions") - round_store = MongoDBRoundStore(mdb, "control.rounds") -elif statestore_config["type"] in ["SQLite", "PostgreSQL"]: - MyAbstractBase.metadata.create_all(engine, checkfirst=True) +# TODO: Refactor all access to the stores to use the DatabaseConnection +stores = DatabaseConnection().get_stores() +session_store: SessionStore = stores.session_store +model_store: ModelStore = stores.model_store +round_store: RoundStore = stores.round_store +package_store: PackageStore = stores.package_store +combiner_store: CombinerStore = stores.combiner_store +client_store: ClientStore = stores.client_store +status_store: StatusStore = stores.status_store +validation_store: ValidationStore = stores.validation_store +prediction_store: PredictionStore = stores.prediction_store - client_store = SQLClientStore() - validation_store = SQLValidationStore() - combiner_store = SQLCombinerStore() - status_store = SQLStatusStore() - prediction_store = SQLPredictionStore() - round_store = SQLRoundStore() -else: - raise ValueError("Unknown statestore type") repository = Repository(modelstorage_config["storage_config"], init_buckets=False) diff --git a/fedn/network/storage/dbconnection.py b/fedn/network/storage/dbconnection.py new file mode 100644 index 000000000..fb6e05cce --- /dev/null +++ b/fedn/network/storage/dbconnection.py @@ -0,0 +1,180 @@ +"""This module provides classes for managing database connections and stores in a federated network environment. + +Classes: + StoreContainer: A container for various store instances. + DatabaseConnection: A singleton class for managing database connections and stores. +""" + +import pymongo +from pymongo.database import Database +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + +from fedn.common.config import get_network_config, get_statestore_config +from fedn.network.storage.statestore.stores.client_store import ClientStore, MongoDBClientStore, SQLClientStore +from fedn.network.storage.statestore.stores.combiner_store import CombinerStore, MongoDBCombinerStore, SQLCombinerStore +from fedn.network.storage.statestore.stores.model_store import ModelStore, MongoDBModelStore, SQLModelStore +from fedn.network.storage.statestore.stores.package_store import MongoDBPackageStore, PackageStore, SQLPackageStore +from fedn.network.storage.statestore.stores.prediction_store import MongoDBPredictionStore, PredictionStore, SQLPredictionStore +from fedn.network.storage.statestore.stores.round_store import MongoDBRoundStore, RoundStore, SQLRoundStore +from fedn.network.storage.statestore.stores.session_store import MongoDBSessionStore, SessionStore, SQLSessionStore +from fedn.network.storage.statestore.stores.status_store import MongoDBStatusStore, SQLStatusStore, StatusStore +from fedn.network.storage.statestore.stores.store import MyAbstractBase +from fedn.network.storage.statestore.stores.validation_store import MongoDBValidationStore, SQLValidationStore, ValidationStore + + +class StoreContainer: + """A container for various store instances.""" + + def __init__( # noqa: PLR0913 + self, + client_store: ClientStore, + validation_store: ValidationStore, + combiner_store: CombinerStore, + status_store: StatusStore, + prediction_store: PredictionStore, + round_store: RoundStore, + package_store: PackageStore, + model_store: ModelStore, + session_store: SessionStore, + ) -> None: + """Initialize the StoreContainer with various store instances.""" + self.client_store = client_store + self.validation_store = validation_store + self.combiner_store = combiner_store + self.status_store = status_store + self.prediction_store = prediction_store + self.round_store = round_store + self.package_store = package_store + self.model_store = model_store + self.session_store = session_store + + +class DatabaseConnection: + """Singleton class for managing database connections and stores.""" + + _instance = None + + def __new__(cls, *, force_create_new: bool = False) -> "DatabaseConnection": + """Create a new instance of DatabaseConnection or return the existing singleton instance. + + Args: + force_create_new (bool): If True, a new instance will be created regardless of the singleton pattern. + + Returns: + DatabaseConnection: A new instance if force_create_new is True, otherwise the existing singleton instance. + + """ + if cls._instance is None or force_create_new: + obj = super(DatabaseConnection, cls).__new__(cls) + obj._init_connection() + cls._instance = obj + + return cls._instance + + def _init_connection(self, statestore_config: dict = None, network_id: dict = None) -> None: + if statestore_config is None: + statestore_config = get_statestore_config() + if network_id is None: + network_id = get_network_config() + + if statestore_config["type"] == "MongoDB": + mdb: Database = self._setup_mongo(statestore_config, network_id) + + client_store = MongoDBClientStore(mdb, "network.clients") + validation_store = MongoDBValidationStore(mdb, "control.validations") + combiner_store = MongoDBCombinerStore(mdb, "network.combiners") + status_store = MongoDBStatusStore(mdb, "control.status") + prediction_store = MongoDBPredictionStore(mdb, "control.predictions") + round_store = MongoDBRoundStore(mdb, "control.rounds") + package_store = MongoDBPackageStore(mdb, "control.packages") + model_store = MongoDBModelStore(mdb, "control.models") + session_store = MongoDBSessionStore(mdb, "control.sessions") + + elif statestore_config["type"] in ["SQLite", "PostgreSQL"]: + Session = self._setup_sql(statestore_config) # noqa: N806 + + client_store = SQLClientStore(Session) + validation_store = SQLValidationStore(Session) + combiner_store = SQLCombinerStore(Session) + status_store = SQLStatusStore(Session) + prediction_store = SQLPredictionStore(Session) + round_store = SQLRoundStore(Session) + package_store = SQLPackageStore(Session) + model_store = SQLModelStore(Session) + session_store = SQLSessionStore(Session) + else: + raise ValueError("Unknown statestore type") + + self.sc = StoreContainer( + client_store, validation_store, combiner_store, status_store, prediction_store, round_store, package_store, model_store, session_store + ) + + def close(self) -> None: + """Close the database connection.""" + pass + + def _setup_mongo(self, statestore_config: dict, network_id: str) -> "DatabaseConnection": + mc = pymongo.MongoClient(**statestore_config["mongo_config"]) + mc.server_info() + mdb: Database = mc[network_id] + + return mdb + + def _setup_sql(self, statestore_config: dict) -> "DatabaseConnection": + if statestore_config["type"] == "SQLite": + engine = create_engine("sqlite:///my_database.db", echo=False) + elif statestore_config["type"] == "PostgreSQL": + postgres_config = statestore_config["postgres_config"] + username = postgres_config["username"] + password = postgres_config["password"] + host = postgres_config["host"] + port = postgres_config["port"] + + engine = create_engine(f"postgresql://{username}:{password}@{host}:{port}/fedn_db", echo=False) + + Session = sessionmaker(engine) # noqa: N806 + + MyAbstractBase.metadata.create_all(engine, checkfirst=True) + + return Session + + def get_stores(self) -> StoreContainer: + """Get the StoreContainer instance.""" + return self.sc + + @property + def client_store(self) -> ClientStore: + return self.sc.client_store + + @property + def validation_store(self) -> ValidationStore: + return self.sc.validation_store + + @property + def combiner_store(self) -> CombinerStore: + return self.sc.combiner_store + + @property + def status_store(self) -> StatusStore: + return self.sc.status_store + + @property + def prediction_store(self) -> PredictionStore: + return self.sc.prediction_store + + @property + def round_store(self) -> RoundStore: + return self.sc.round_store + + @property + def package_store(self) -> PackageStore: + return self.sc.package_store + + @property + def model_store(self) -> ModelStore: + return self.sc.model_store + + @property + def session_store(self) -> SessionStore: + return self.sc.session_store diff --git a/fedn/network/storage/statestore/stores/client_store.py b/fedn/network/storage/statestore/stores/client_store.py index 9753238fd..3265ed332 100644 --- a/fedn/network/storage/statestore/stores/client_store.py +++ b/fedn/network/storage/statestore/stores/client_store.py @@ -8,7 +8,7 @@ from sqlalchemy import String, func, or_, select from sqlalchemy.orm import Mapped, mapped_column -from fedn.network.storage.statestore.stores.store import MongoDBStore, MyAbstractBase, Session, SQLStore, Store +from fedn.network.storage.statestore.stores.store import MongoDBStore, MyAbstractBase, SQLStore, Store from .shared import EntityNotFound, from_document @@ -171,8 +171,11 @@ def from_row(row: ClientModel) -> Client: class SQLClientStore(ClientStore, SQLStore[Client]): + def __init__(self, Session): + super().__init__(Session) + def get(self, id: str) -> Client: - with Session() as session: + with self.Session() as session: stmt = select(ClientModel).where(or_(ClientModel.id == id, ClientModel.client_id == id)) item = session.scalars(stmt).first() @@ -182,7 +185,7 @@ def get(self, id: str) -> Client: return from_row(item) def update(self, id: str, item: Client) -> Tuple[bool, Any]: - with Session() as session: + with self.Session() as session: stmt = select(ClientModel).where(or_(ClientModel.id == id, ClientModel.client_id == id)) existing_item = session.scalars(stmt).first() @@ -201,7 +204,7 @@ def update(self, id: str, item: Client) -> Tuple[bool, Any]: return True, from_row(existing_item) def add(self, item: Client) -> Tuple[bool, Any]: - with Session() as session: + with self.Session() as session: entity = ClientModel( client_id=item.get("client_id"), combiner=item.get("combiner"), @@ -221,7 +224,7 @@ def delete(self, id): raise NotImplementedError def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDING, **kwargs): - with Session() as session: + with self.Session() as session: stmt = select(ClientModel) for key, value in kwargs.items(): @@ -249,7 +252,7 @@ def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDI return {"count": count, "result": result} def count(self, **kwargs): - with Session() as session: + with self.Session() as session: stmt = select(func.count()).select_from(ClientModel) for key, value in kwargs.items(): @@ -260,7 +263,7 @@ def count(self, **kwargs): return count def upsert(self, item: Client) -> Tuple[bool, Any]: - with Session() as session: + with self.Session() as session: id = item.get("id") client_id = item.get("client_id") @@ -295,7 +298,7 @@ def upsert(self, item: Client) -> Tuple[bool, Any]: return True, from_row(existing_item) def connected_client_count(self, combiners): - with Session() as session: + with self.Session() as session: stmt = select(ClientModel.combiner, func.count(ClientModel.combiner)).group_by(ClientModel.combiner) if combiners: stmt = stmt.where(ClientModel.combiner.in_(combiners)) diff --git a/fedn/network/storage/statestore/stores/combiner_store.py b/fedn/network/storage/statestore/stores/combiner_store.py index 448985a84..dcc46d87e 100644 --- a/fedn/network/storage/statestore/stores/combiner_store.py +++ b/fedn/network/storage/statestore/stores/combiner_store.py @@ -7,7 +7,7 @@ from sqlalchemy import String, func, or_, select from sqlalchemy.orm import Mapped, mapped_column -from fedn.network.storage.statestore.stores.store import MongoDBStore, MyAbstractBase, Session, SQLStore, Store +from fedn.network.storage.statestore.stores.store import MongoDBStore, MyAbstractBase, SQLStore, Store from .shared import EntityNotFound, from_document @@ -137,8 +137,11 @@ def from_row(row: CombinerModel) -> Combiner: class SQLCombinerStore(CombinerStore, SQLStore[Combiner]): + def __init__(self, Session): + super().__init__(Session) + def get(self, id: str) -> Combiner: - with Session() as session: + with self.Session() as session: stmt = select(CombinerModel).where(or_(CombinerModel.id == id, CombinerModel.name == id)) item = session.scalars(stmt).first() if item is None: @@ -149,7 +152,7 @@ def update(self, id, item): raise NotImplementedError def add(self, item): - with Session() as session: + with self.Session() as session: entity = CombinerModel( address=item["address"], fqdn=item["fqdn"], @@ -163,7 +166,7 @@ def add(self, item): return True, from_row(entity) def delete(self, id: str) -> bool: - with Session() as session: + with self.Session() as session: stmt = select(CombinerModel).where(CombinerModel.id == id) item = session.scalars(stmt).first() if item is None: @@ -172,7 +175,7 @@ def delete(self, id: str) -> bool: return True def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDING, **kwargs): - with Session() as session: + with self.Session() as session: stmt = select(CombinerModel) for key, value in kwargs.items(): @@ -200,7 +203,7 @@ def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDI return {"count": count, "result": result} def count(self, **kwargs): - with Session() as session: + with self.Session() as session: stmt = select(func.count()).select_from(CombinerModel) for key, value in kwargs.items(): diff --git a/fedn/network/storage/statestore/stores/model_store.py b/fedn/network/storage/statestore/stores/model_store.py index 98334cd53..8e4aa0326 100644 --- a/fedn/network/storage/statestore/stores/model_store.py +++ b/fedn/network/storage/statestore/stores/model_store.py @@ -10,7 +10,7 @@ from fedn.network.storage.statestore.stores.shared import EntityNotFound, from_document from fedn.network.storage.statestore.stores.sql.shared import ModelModel -from fedn.network.storage.statestore.stores.store import MongoDBStore, Session, SQLStore, Store +from fedn.network.storage.statestore.stores.store import MongoDBStore, SQLStore, Store class Model: @@ -263,8 +263,11 @@ def from_row(row: ModelModel) -> Model: class SQLModelStore(ModelStore, SQLStore[Model]): + def __init__(self, Session): + super().__init__(Session) + def get(self, id: str) -> Model: - with Session() as session: + with self.Session() as session: stmt = select(ModelModel).where(ModelModel.id == id) item = session.scalars(stmt).first() if item is None: @@ -275,7 +278,7 @@ def update(self, id: str, item: Model) -> Tuple[bool, Any]: valid, message = validate(item) if not valid: return False, message - with Session() as session: + with self.Session() as session: stmt = select(ModelModel).where(ModelModel.id == id) existing_item = session.execute(stmt).first() if existing_item is None: @@ -294,7 +297,7 @@ def add(self, item: Model) -> Tuple[bool, Any]: if not valid: return False, message - with Session() as session: + with self.Session() as session: id: str = None if "model" in item: id = item["model"] @@ -317,7 +320,7 @@ def delete(self, id: str) -> bool: raise NotImplementedError def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDING, **kwargs): - with Session() as session: + with self.Session() as session: stmt = select(ModelModel) for key, value in kwargs.items(): @@ -345,7 +348,7 @@ def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDI return {"count": count, "result": result} def count(self, **kwargs): - with Session() as session: + with self.Session() as session: stmt = select(func.count()).select_from(ModelModel) for key, value in kwargs.items(): @@ -356,7 +359,7 @@ def count(self, **kwargs): return count def list_descendants(self, id: str, limit: int): - with Session() as session: + with self.Session() as session: # Define the recursive CTE descendant = aliased(ModelModel) # Alias for recursion cte = select(ModelModel).where(ModelModel.parent_model == id).cte(name="descendant_cte", recursive=True) @@ -378,7 +381,7 @@ def list_descendants(self, id: str, limit: int): return result def list_ancestors(self, id: str, limit: int, include_self=False, reverse=False): - with Session() as session: + with self.Session() as session: # Define the recursive CTE ancestor = aliased(ModelModel) # Alias for recursion cte = select(ModelModel).where(ModelModel.id == id).cte(name="ancestor_cte", recursive=True) @@ -400,7 +403,7 @@ def list_ancestors(self, id: str, limit: int, include_self=False, reverse=False) return result def get_active(self) -> str: - with Session() as session: + with self.Session() as session: active_stmt = select(ModelModel).where(ModelModel.active) active_item = session.scalars(active_stmt).first() if active_item: @@ -408,7 +411,7 @@ def get_active(self) -> str: raise EntityNotFound("Entity not found") def set_active(self, id: str) -> bool: - with Session() as session: + with self.Session() as session: active_stmt = select(ModelModel).where(ModelModel.active) active_item = session.scalars(active_stmt).first() if active_item: diff --git a/fedn/network/storage/statestore/stores/package_store.py b/fedn/network/storage/statestore/stores/package_store.py index 55d74d5e2..96f51be7c 100644 --- a/fedn/network/storage/statestore/stores/package_store.py +++ b/fedn/network/storage/statestore/stores/package_store.py @@ -11,7 +11,7 @@ from werkzeug.utils import secure_filename from fedn.network.storage.statestore.stores.shared import EntityNotFound -from fedn.network.storage.statestore.stores.store import MongoDBStore, MyAbstractBase, Session, SQLStore, Store +from fedn.network.storage.statestore.stores.store import MongoDBStore, MyAbstractBase, SQLStore, Store def from_document(data: dict, active_package: dict): @@ -306,6 +306,9 @@ def from_row(row: PackageModel) -> Package: class SQLPackageStore(PackageStore, SQLStore[Package]): + def __init__(self, Session): + super().__init__(Session) + def _complement(self, item: Package): if "committed_at" not in item or item.committed_at is None: item["committed_at"] = datetime.now() @@ -322,7 +325,7 @@ def add(self, item: Package) -> Tuple[bool, Any]: return False, message self._complement(item) - with Session() as session: + with self.Session() as session: item = PackageModel( committed_at=item["committed_at"], description=item["description"] if "description" in item else "", @@ -336,7 +339,7 @@ def add(self, item: Package) -> Tuple[bool, Any]: return True, from_row(item) def get(self, id: str) -> Package: - with Session() as session: + with self.Session() as session: stmt = select(PackageModel).where(PackageModel.id == id) item = session.scalars(stmt).first() if item is None: @@ -347,7 +350,7 @@ def update(self, id: str, item: Package) -> bool: raise NotImplementedError def delete(self, id: str) -> bool: - with Session() as session: + with self.Session() as session: stmt = select(PackageModel).where(PackageModel.id == id) item = session.scalars(stmt).first() if item is None: @@ -357,7 +360,7 @@ def delete(self, id: str) -> bool: return True def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDING, **kwargs): - with Session() as session: + with self.Session() as session: stmt = select(PackageModel) for key, value in kwargs.items(): @@ -385,7 +388,7 @@ def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDI return {"count": count, "result": result} def count(self, **kwargs): - with Session() as session: + with self.Session() as session: stmt = select(func.count()).select_from(PackageModel) for key, value in kwargs.items(): @@ -396,7 +399,7 @@ def count(self, **kwargs): return count def set_active(self, id: str): - with Session() as session: + with self.Session() as session: active_stmt = select(PackageModel).where(PackageModel.active) active_item = session.scalars(active_stmt).first() if active_item: @@ -413,7 +416,7 @@ def set_active(self, id: str): return True def get_active(self) -> Package: - with Session() as session: + with self.Session() as session: active_stmt = select(PackageModel).where(PackageModel.active) active_item = session.scalars(active_stmt).first() if active_item: @@ -424,7 +427,7 @@ def set_active_helper(self, helper: str) -> bool: if not helper or helper == "" or helper not in ["numpyhelper", "binaryhelper", "androidhelper"]: raise ValueError() - with Session() as session: + with self.Session() as session: active_stmt = select(PackageModel).where(PackageModel.active) active_item = session.scalars(active_stmt).first() if active_item: @@ -445,7 +448,7 @@ def set_active_helper(self, helper: str) -> bool: session.commit() def delete_active(self) -> bool: - with Session() as session: + with self.Session() as session: active_stmt = select(PackageModel).where(PackageModel.active) active_item = session.scalars(active_stmt).first() if active_item: diff --git a/fedn/network/storage/statestore/stores/prediction_store.py b/fedn/network/storage/statestore/stores/prediction_store.py index c019b72c1..2d7294865 100644 --- a/fedn/network/storage/statestore/stores/prediction_store.py +++ b/fedn/network/storage/statestore/stores/prediction_store.py @@ -6,7 +6,7 @@ from sqlalchemy.orm import Mapped, mapped_column from fedn.network.storage.statestore.stores.shared import EntityNotFound -from fedn.network.storage.statestore.stores.store import MongoDBStore, MyAbstractBase, Session, SQLStore, Store +from fedn.network.storage.statestore.stores.store import MongoDBStore, MyAbstractBase, SQLStore, Store class Prediction: @@ -98,8 +98,11 @@ def from_row(row: PredictionModel) -> Prediction: class SQLPredictionStore(PredictionStore, SQLStore[Prediction]): + def __init__(self, Session): + super().__init__(Session) + def get(self, id: str) -> Prediction: - with Session() as session: + with self.Session() as session: stmt = select(Prediction).where(Prediction.id == id) item = session.scalars(stmt).first() @@ -112,7 +115,7 @@ def update(self, id: str, item: Prediction) -> bool: raise NotImplementedError("Update not implemented for PredictionStore") def add(self, item: Prediction) -> Tuple[bool, Any]: - with Session() as session: + with self.Session() as session: sender = item["sender"] if "sender" in item else None receiver = item["receiver"] if "receiver" in item else None @@ -137,7 +140,7 @@ def delete(self, id: str) -> bool: raise NotImplementedError("Delete not implemented for PredictionStore") def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDING, **kwargs): - with Session() as session: + with self.Session() as session: stmt = select(PredictionModel) for key, value in kwargs.items(): @@ -199,7 +202,7 @@ def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDI return {"count": len(result), "result": result} def count(self, **kwargs): - with Session() as session: + with self.Session() as session: stmt = select(func.count()).select_from(PredictionModel) for key, value in kwargs.items(): diff --git a/fedn/network/storage/statestore/stores/round_store.py b/fedn/network/storage/statestore/stores/round_store.py index c74b8d599..d433ca191 100644 --- a/fedn/network/storage/statestore/stores/round_store.py +++ b/fedn/network/storage/statestore/stores/round_store.py @@ -7,7 +7,7 @@ from sqlalchemy import Integer, func, or_, select from fedn.network.storage.statestore.stores.sql.shared import RoundCombinerModel, RoundConfigModel, RoundDataModel, RoundModel -from fedn.network.storage.statestore.stores.store import MongoDBStore, Session, SQLStore, Store +from fedn.network.storage.statestore.stores.store import MongoDBStore, SQLStore, Store from .shared import EntityNotFound, from_document @@ -173,8 +173,11 @@ def from_row(row: RoundModel) -> Round: class SQLRoundStore(RoundStore, SQLStore[Round]): + def __init__(self, Session): + super().__init__(Session) + def get(self, id: str) -> Round: - with Session() as session: + with self.Session() as session: stmt = select(RoundModel).where(or_(RoundModel.id == id, RoundModel.round_id == id)) item = session.scalars(stmt).first() @@ -184,7 +187,7 @@ def get(self, id: str) -> Round: return from_row(item) def update(self, id, item: Round) -> Tuple[bool, Any]: - with Session() as session: + with self.Session() as session: stmt = select(RoundModel).where(or_(RoundModel.id == id, RoundModel.round_id == id)) existing_item = session.scalars(stmt).first() @@ -295,7 +298,7 @@ def update(self, id, item: Round) -> Tuple[bool, Any]: return True, from_row(existing_item) def add(self, item: Round) -> Tuple[bool, Any]: - with Session() as session: + with self.Session() as session: round_id = item["round_id"] stmt = select(RoundModel).where(RoundModel.round_id == round_id) existing_item = session.scalars(stmt).first() @@ -391,7 +394,7 @@ def delete(self, id: str) -> bool: raise NotImplementedError def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDING, **kwargs): - with Session() as session: + with self.Session() as session: stmt = select(RoundModel) for key, value in kwargs.items(): @@ -426,7 +429,7 @@ def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDI return {"count": len(result), "result": result} def count(self, **kwargs): - with Session() as session: + with self.Session() as session: stmt = select(func.count()).select_from(RoundModel) for key, value in kwargs.items(): diff --git a/fedn/network/storage/statestore/stores/session_store.py b/fedn/network/storage/statestore/stores/session_store.py index efc59806f..d86f4709e 100644 --- a/fedn/network/storage/statestore/stores/session_store.py +++ b/fedn/network/storage/statestore/stores/session_store.py @@ -10,7 +10,6 @@ from fedn.network.storage.statestore.stores.shared import EntityNotFound, from_document from fedn.network.storage.statestore.stores.sql.shared import SessionConfigModel, SessionModel from fedn.network.storage.statestore.stores.store import MongoDBStore, SQLStore, Store -from fedn.network.storage.statestore.stores.store import Session as SQLSession class SessionConfig: @@ -213,8 +212,11 @@ def from_row(row: dict) -> Session: class SQLSessionStore(SessionStore, SQLStore[Session]): + def __init__(self, Session): + super().__init__(Session) + def get(self, id: str) -> Session: - with SQLSession() as session: + with self.Session() as session: stmt = select(SessionModel, SessionConfigModel).join(SessionModel.session_config).where(SessionModel.id == id) item = session.execute(stmt).first() if item is None: @@ -242,7 +244,7 @@ def update(self, id: str, item: Session) -> Tuple[bool, Any]: valid, message = validate(item) if not valid: return False, message - with SQLSession() as session: + with self.Session() as session: stmt = select(SessionModel, SessionConfigModel).join(SessionModel.session_config).where(SessionModel.id == id) existing_item = session.execute(stmt).first() if existing_item is None: @@ -292,7 +294,7 @@ def add(self, item: Session) -> Tuple[bool, Any]: complement(item) - with SQLSession() as session: + with self.Session() as session: parent_item = SessionModel( id=item["session_id"], status=item["status"], name=item["name"] if "name" in item else None, committed_at=item["committed_at"] or None ) @@ -337,7 +339,7 @@ def delete(self, id: str) -> bool: raise NotImplementedError def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDING, **kwargs): - with SQLSession() as session: + with self.Session() as session: stmt = select(SessionModel, SessionConfigModel).join(SessionModel.session_config) for key, value in kwargs.items(): if "session_config" in key: @@ -384,7 +386,7 @@ def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDI return {"count": len(result), "result": result} def count(self, **kwargs): - with SQLSession() as session: + with self.Session() as session: stmt = select(func.count()).select_from(SessionModel) for key, value in kwargs.items(): diff --git a/fedn/network/storage/statestore/stores/status_store.py b/fedn/network/storage/statestore/stores/status_store.py index 78c063782..7dcfb42e1 100644 --- a/fedn/network/storage/statestore/stores/status_store.py +++ b/fedn/network/storage/statestore/stores/status_store.py @@ -6,7 +6,7 @@ from sqlalchemy.orm import Mapped, mapped_column from fedn.network.storage.statestore.stores.shared import EntityNotFound -from fedn.network.storage.statestore.stores.store import MongoDBStore, MyAbstractBase, Session, SQLStore, Store +from fedn.network.storage.statestore.stores.store import MongoDBStore, MyAbstractBase, SQLStore, Store class Status: @@ -100,8 +100,11 @@ def from_row(row: StatusModel) -> Status: class SQLStatusStore(StatusStore, SQLStore[Status]): + def __init__(self, Session): + super().__init__(Session) + def get(self, id: str) -> Status: - with Session() as session: + with self.Session() as session: stmt = select(StatusModel).where(StatusModel.id == id) item = session.scalars(stmt).first() @@ -114,7 +117,7 @@ def update(self, id, item): raise NotImplementedError def add(self, item: Status) -> Tuple[bool, Any]: - with Session() as session: + with self.Session() as session: sender = item["sender"] if "sender" in item else None status = StatusModel( @@ -137,7 +140,7 @@ def delete(self, id: str) -> bool: raise NotImplementedError def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDING, **kwargs): - with Session() as session: + with self.Session() as session: stmt = select(StatusModel) for key, value in kwargs.items(): @@ -189,7 +192,7 @@ def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDI return {"count": len(result), "result": result} def count(self, **kwargs): - with Session() as session: + with self.Session() as session: stmt = select(func.count()).select_from(StatusModel) for key, value in kwargs.items(): diff --git a/fedn/network/storage/statestore/stores/store.py b/fedn/network/storage/statestore/stores/store.py index a628a40e8..8563ec309 100644 --- a/fedn/network/storage/statestore/stores/store.py +++ b/fedn/network/storage/statestore/stores/store.py @@ -6,10 +6,9 @@ import pymongo from bson import ObjectId from pymongo.database import Database -from sqlalchemy import MetaData, create_engine -from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, sessionmaker +from sqlalchemy import MetaData +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column -from fedn.common.config import get_statestore_config from fedn.network.storage.statestore.stores.shared import EntityNotFound, from_document T = TypeVar("T") @@ -133,7 +132,8 @@ def count(self, **kwargs) -> int: class SQLStore(Store[T]): - pass + def __init__(self, Session): + self.Session = Session constraint_naming_conventions = { @@ -154,23 +154,3 @@ class MyAbstractBase(Base): id: Mapped[str] = mapped_column(primary_key=True, default=lambda: str(uuid.uuid4())) committed_at: Mapped[datetime] = mapped_column(default=datetime.now()) - - -statestore_config = get_statestore_config() - -engine = None -Session = None - -if statestore_config["type"] in ["SQLite", "PostgreSQL"]: - if statestore_config["type"] == "SQLite": - engine = create_engine("sqlite:///my_database.db", echo=True) - elif statestore_config["type"] == "PostgreSQL": - postgres_config = statestore_config["postgres_config"] - username = postgres_config["username"] - password = postgres_config["password"] - host = postgres_config["host"] - port = postgres_config["port"] - - engine = create_engine(f"postgresql://{username}:{password}@{host}:{port}/fedn_db", echo=True) - - Session = sessionmaker(engine) diff --git a/fedn/network/storage/statestore/stores/validation_store.py b/fedn/network/storage/statestore/stores/validation_store.py index 694195d00..27c951ada 100644 --- a/fedn/network/storage/statestore/stores/validation_store.py +++ b/fedn/network/storage/statestore/stores/validation_store.py @@ -6,7 +6,7 @@ from sqlalchemy.orm import Mapped, mapped_column from fedn.network.storage.statestore.stores.shared import EntityNotFound -from fedn.network.storage.statestore.stores.store import MongoDBStore, MyAbstractBase, Session, SQLStore, Store +from fedn.network.storage.statestore.stores.store import MongoDBStore, MyAbstractBase, SQLStore, Store class Validation: @@ -97,8 +97,11 @@ def from_row(row: ValidationModel) -> Validation: class SQLValidationStore(ValidationStore, SQLStore[Validation]): + def __init__(self, Session): + super().__init__(Session) + def get(self, id: str) -> Validation: - with Session() as session: + with self.Session() as session: stmt = select(ValidationModel).where(ValidationModel.id == id) item = session.scalars(stmt).first() @@ -111,7 +114,7 @@ def update(self, id: str, item: Validation) -> bool: raise NotImplementedError("Update not implemented for ValidationStore") def add(self, item: Validation) -> Tuple[bool, Any]: - with Session() as session: + with self.Session() as session: sender = item["sender"] if "sender" in item else None receiver = item["receiver"] if "receiver" in item else None @@ -136,7 +139,7 @@ def delete(self, id: str) -> bool: raise NotImplementedError("Delete not implemented for ValidationStore") def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDING, **kwargs): - with Session() as session: + with self.Session() as session: stmt = select(ValidationModel) for key, value in kwargs.items(): @@ -202,7 +205,7 @@ def list(self, limit: int, skip: int, sort_key: str, sort_order=pymongo.DESCENDI return {"count": len(result), "result": result} def count(self, **kwargs): - with Session() as session: + with self.Session() as session: stmt = select(func.count()).select_from(ValidationModel) for key, value in kwargs.items():