From f820f2a21c7ce45cb8ce0255cf6029401c5589a9 Mon Sep 17 00:00:00 2001 From: Setepenre Date: Fri, 26 Aug 2022 13:20:14 -0400 Subject: [PATCH 01/25] Alchemy --- src/orion/storage/sql.py | 417 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 417 insertions(+) create mode 100644 src/orion/storage/sql.py diff --git a/src/orion/storage/sql.py b/src/orion/storage/sql.py new file mode 100644 index 000000000..4031b0c95 --- /dev/null +++ b/src/orion/storage/sql.py @@ -0,0 +1,417 @@ +import contextlib +import logging +import datetime +import pickle + +import sqlalchemy +from sqlalchemy import Column +from sqlalchemy import ForeignKey +from sqlalchemy import Integer, JSON +from sqlalchemy import String, DateTime, select, delete, update +from sqlalchemy.orm import declarative_base, Session + +import orion.core +from orion.core.worker.trial import validate_status +from orion.storage.base import BaseStorageProtocol, LockedAlgorithmState, get_trial_uid_and_exp, get_uid + +log = logging.getLogger(__name__) + +Base = declarative_base() + +# fmt: off +class User(Base): + """Defines the User table""" + __tablename__ = "users" + + uid = Column(Integer, primary_key=True) + name = Column(String(30)) + token = Column(String(30)) + created_at = Column(DateTime) + last_seen = Column(DateTime) + + +class Experiment(Base): + """Defines the Experiment table""" + __tablename__ = "experiments" + + uid = Column(Integer, primary_key=True) + name = Column(String(30)) + config = Column(JSON) + version = Column(Integer) + owner_id = Column(Integer, ForeignKey("user.uid"), nullable=False) + datetime = Column(DateTime) + + +class Trial: + """Defines the Trial table""" + __tablename__ = "trial" + + uid = Column(Integer, primary_key=True) + experiment_id = Column(Integer, ForeignKey("experiment.uid"), nullable=False) + owner_id = Column(Integer, ForeignKey("user.uid"), nullable=False) + status = Column(String(30)) + results = Column(JSON) + start_time = Column(DateTime) + end_time = Column(DateTime) + heartbeat = Column(DateTime) + + +class Algo: + """Defines the Algo table""" + __tablename__ = "algo" + + uid = Column(Integer, primary_key=True) + experiment_id = Column(Integer, ForeignKey("experiment.uid"), nullable=False) + owner_id = Column(Integer, ForeignKey("user.uid"), nullable=False) + configuration = Column(JSON) + locked = Column(Integer) + state = Column(JSON) + heartbeat = Column(DateTime) +# fmt: on + + + +class SQLAlchemy(BaseStorageProtocol): # noqa: F811 + """Implement a generic protocol to allow Orion to communicate using + different storage backend + + Parameters + ---------- + uri: str + PostgreSQL backend to use for storage; the format is as follow + `protocol://[username:password@]host1[:port1][,...hostN[:portN]]][/[database][?options]]` + + """ + + def __init__(self, uri): + self.engine = sqlalchemy.create_engine("", echo=True, future=True) + + # Create the schema + Base.metadata.create_all(self.engine) + + with Session(self.engine) as session: + stmt = select(User).where(User.token == self.token) + self.user = session.scalars(stmt).one() + + + # Experiment Operations + # ===================== + + def create_experiment(self, config): + """Insert a new experiment inside the database""" + + with Session(self.engine) as session: + experiment = Experiment( + name=config['name'], + config=config, + onwer_id=self.user.uid, + version=0 + ) + session.add(experiment) + session.commit() + + def delete_experiment(self, experiment, uid): + uid = get_uid(experiment, uid) + + with Session(self.engine) as session: 
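# NOTE (review): in this first revision the constructor ignores its `uri`
# argument (create_engine("") above) and reads `self.token`, which is never
# assigned anywhere. Patch 04 below reworks both; a minimal sketch of where
# the constructor is headed, assuming the `token` keyword that patch 04
# introduces:
#
#     self.engine = sqlalchemy.create_engine(uri or "sqlite://", future=True)
#     Base.metadata.create_all(self.engine)
#     with Session(self.engine) as session:
#         stmt = select(User).where(User.token == token)
#         self.user = session.scalars(stmt).one()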
+ stmt = delete(Experiment).where(Experiment.uid == uid) + session.execute(stmt) + session.commit() + + def update_experiment(self, experiment=None, uid=None, where=None, **kwargs): + uid = get_uid(experiment, uid) + + query = True + if uid is not None: + query = Experiment.uid == uid + + query = query and self._to_query(Experiment, where) + + with Session(self.engine) as session: + stmt = select(Experiment).where(query) + experiment = session.scalars(stmt).one() + experiment.config = kwargs + session.commit() + + def fetch_experiments(self, query, selection=None): + query = self._to_query(query) + + with Session(self.engine) as session: + stmt = select(Experiment).where(query) + experiments = session.scalars(stmt).all() + + if selection is not None: + assert False, 'Not Implemented' + + return experiments + + # Benchmarks + # ========== + + + # Trials + # ====== + def fetch_trials(self, experiment=None, uid=None, where=None): + uid = get_uid(experiment, uid) + + query = True + if uid is not None: + query = Trial.experiment_id == uid + + query = query and self._to_query(Trial, where) + + with Session(self.engine) as session: + stmt = select(Trial).where(query) + return session.scalars(stmt).all() + + def register_trial(self, trial): + config = trial.to_dict() + + with Session(self.engine) as session: + stmt = select(Trial).where(Trial.uid == trial._id) + trial = session.scalars(stmt).one() + self._set_from_dict(trial, config) + session.commit() + + def delete_trials(self, experiment=None, uid=None, where=None): + uid = get_uid(experiment, uid) + + query = True + if uid is not None: + query = Trial.experiment_id == uid + + query = query and self._to_query(Trial, where) + + with Session(self.engine) as session: + stmt = delete(Trial).where(query) + session.execute(stmt) + session.commit() + + def retrieve_result(self, trial, **kwargs): + return trial + + def get_trial(self, trial=None, uid=None, experiment_uid=None): + trial_uid, experiment_uid = get_trial_uid_and_exp(trial, uid, experiment_uid) + + with Session(self.engine) as session: + stmt = select(Trial).where(Trial.experiment_id == experiment_uid and Trial.uid == trial_uid) + return session.scalars(stmt).one() + + def update_trials(self, experiment=None, uid=None, where=None, **kwargs): + uid = get_uid(experiment, uid) + query = Trial.uid == trial._id and self._to_query(Trial, where) + + with Session(self.engine) as session: + stmt = select(Trial).where(query) + trials = session.scalars(stmt).all() + for trial in trials: + self._set_from_dict(trial, kwargs) + session.commit() + + return trial + + def update_trial( + self, trial=None, uid=None, experiment_uid=None, where=None, **kwargs + ): + + trial_uid, experiment_uid = get_trial_uid_and_exp(trial, uid, experiment_uid) + query = Trial.uid == trial_uid and Trial.experiment_id == experiment_uid and self._to_query(where) + + with Session(self.engine) as session: + stmt = select(Trial).where(query) + trial = session.scalars(stmt).one() + self._set_from_dict(trial, kwargs) + session.commit() + + return trial + + def fetch_lost_trials(self, experiment): + heartbeat = orion.core.config.worker.heartbeat + threshold = datetime.datetime.utcnow() - datetime.timedelta( + seconds=heartbeat * 5 + ) + + with Session(self.engine) as session: + stmt = select(Trial).where(Trial.experiment_id == experiment._id and Trial.status == 'reserved' and Trial.heartbeat < threshold) + return session.scalars(stmt).all() + + def push_trial_results(self, trial): + with Session(self.engine) as session: + stmt = 
select(Trial).where(Trial.experiment_id == trial.experiment and Trial.uid == trial.id and Trial.status == 'reserved') + trial = session.scalars(stmt).one() + self._set_from_dict(trial, trial.to_dict()) + session.commit() + + return trial + + def set_trial_status(self, trial, status, heartbeat=None, was=None): + validate_status(status) + validate_status(was) + + query = Trial.uid == trial.id # and Trial.experiment_id == trial.experiment + if was: + query = query and Trial.status == was + + values = dict(status=status, experiment=trial.experiment) + if heartbeat: + values['heartbeat'] = heartbeat + + with Session(self.engine) as session: + update(Trial).where(query).values(**values) + session.commit() + + def fetch_pending_trials(self, experiment): + with Session(self.engine) as session: + stmt = select(Trial).where(Trial.status.in_("interrupted", "new", "suspended") and Trial.experiment_id == experiment._id) + return session.scalars(stmt).all() + + def reserve_trial(self, experiment): + with Session(self.engine) as session: + + with session.begin(): + # not sure it prevents other worker from reserving the same trial + stmt = select(Trial).where(Trial.status.in_("interrupted", "new", "suspended") and Trial.experiment_id == experiment._id) + trial = session.scalars(stmt).one() + + now = datetime.datetime.utcnow() + trial.status = 'reserved' + trial.start_time = now + trial.heartbeat= now + + return trial + + def fetch_trials_by_status(self, experiment, status): + with Session(self.engine) as session: + stmt = select(Trial).where(Trial.status == status and Trial.experiment_id == experiment._id) + return session.scalars(stmt).all() + + def fetch_noncompleted_trials(self, experiment): + with Session(self.engine) as session: + stmt = select(Trial).where(Trial.status != 'completed' and Trial.experiment_id == experiment._id) + return session.scalars(stmt).all() + + def count_completed(self, experiment): + with Session(self.engine) as session: + stmt = select(Trial).where(Trial.status == 'completed' and Trial.experiment_id == experiment._id) + return session.query(stmt).count() + + def count_broken_trials(self, experiment): + with Session(self.engine) as session: + stmt = select(Trial).where(Trial.status == 'broken' and Trial.experiment_id == experiment._id) + return session.query(stmt).count() + + def update_heartbeat(self, trial): + """Update trial's heartbeat""" + + with Session(self.engine) as session: + update(Trial).where(Trial.uid == trial.id, Trial.status == 'reserved').values(heartbeat=datetime.datetime.utcnow()) + session.commit() + + # Algorithm + # ========= + def initialize_algorithm_lock(self, experiment_id, algorithm_config): + with Session(self.engine) as session: + algo = Algo( + experiment_id=experiment_id, + onwer_id=self.user.uid, + configuration=algorithm_config, + locked=0, + heartbeat=datetime.datetime.utcnow() + ) + session.add(algo) + session.commit() + + def release_algorithm_lock(self, experiment=None, uid=None, new_state=None): + uid = get_uid(experiment, uid) + + values = dict( + locked=0, + heartbeat=datetime.datetime.utcnow(), + ) + if new_state is not None: + values["state"] = pickle.dumps(new_state) + + with Session(self.engine) as session: + update(Algo).where( + Algo.experiment_id == uid and + Algo.locked == 1 + ).values(**values) + + def get_algorithm_lock_info(self, experiment=None, uid=None): + """See :func:`orion.storage.base.BaseStorageProtocol.get_algorithm_lock_info`""" + uid = get_uid(experiment, uid) + + with Session(self.engine) as session: + stmt = 
select(Algo).where(Algo.experiment_id==uid) + algo = session.scalar(stmt).one() + + return LockedAlgorithmState( + state=pickle.loads(algo.state) if algo.state is not None else None, + configuration=algo.configuration, + locked=algo.locked, + ) + + def delete_algorithm_lock(self, experiment=None, uid=None): + """See :func:`orion.storage.base.BaseStorageProtocol.delete_algorithm_lock`""" + uid = get_uid(experiment, uid) + + with Session(self.engine) as session: + stmt = delete(Algo).where(Algo.experiment_id==uid) + session.execute(stmt) + session.commit() + + @contextlib.contextmanager + def acquire_algorithm_lock(self, experiment=None, uid=None, timeout=60, retry_interval=1): + uid = get_uid(experiment, uid) + + with Session(self.engine) as session: + stmt = update(Algo).where(Algo.experiment_id==uid, Algo.locked==0).values( + locked=1, + heartbeat=datetime.datetime.utcnow() + ) + algo = session.scalar(stmt).one() + session.commit() + + if algo is None: + return + + algo_state = LockedAlgorithmState( + state=pickle.loads(algo.state) if algo.state is not None else None, + configuration=algo.configuration, + locked=True, + ) + + yield algo_state + + self.release_algorithm_lock(uid, new_state=algo_state.state) + + + # Utilities + # ========= + def _set_from_dict(self, obj, data, rest=None): + meta = dict() + while data: + k, v = data.popitem() + + if hasattr(obj, k): + setattr(obj, k, v) + else: + meta[k] = v + + if meta and rest: + setattr(obj, rest, meta) + return + + if meta: + log.warning("Data was discarded %s", meta) + + def _to_query(self, table, where): + query = True + for k, v in where.items(): + + if hash(table, k): + query = query and getattr(k) == v + else: + log.warning("constrained ignored %s = %s", k, v) + + return query From 1c533b89ddbe3180294fac4c1fe163b4b3d8a0fe Mon Sep 17 00:00:00 2001 From: Setepenre Date: Fri, 26 Aug 2022 13:21:19 -0400 Subject: [PATCH 02/25] - --- setup.py | 1 + src/orion/storage/sql.py | 141 ++++++++++++++++++++++++++------------- 2 files changed, 97 insertions(+), 45 deletions(-) diff --git a/setup.py b/setup.py index 12136245e..d0769e21e 100644 --- a/setup.py +++ b/setup.py @@ -117,6 +117,7 @@ "BaseStorageProtocol": [ "track = orion.storage.track:Track", "legacy = orion.storage.legacy:Legacy", + "sqlalchemy = orion.storage.sql:SQLAlchemy", ], "BaseExecutor": [ "singleexecutor = orion.executor.single_backend:SingleExecutor", diff --git a/src/orion/storage/sql.py b/src/orion/storage/sql.py index 4031b0c95..82b662344 100644 --- a/src/orion/storage/sql.py +++ b/src/orion/storage/sql.py @@ -1,18 +1,30 @@ import contextlib -import logging import datetime +import logging import pickle import sqlalchemy -from sqlalchemy import Column -from sqlalchemy import ForeignKey -from sqlalchemy import Integer, JSON -from sqlalchemy import String, DateTime, select, delete, update -from sqlalchemy.orm import declarative_base, Session +from sqlalchemy import ( + JSON, + Column, + DateTime, + ForeignKey, + Integer, + String, + delete, + select, + update, +) +from sqlalchemy.orm import Session, declarative_base import orion.core from orion.core.worker.trial import validate_status -from orion.storage.base import BaseStorageProtocol, LockedAlgorithmState, get_trial_uid_and_exp, get_uid +from orion.storage.base import ( + BaseStorageProtocol, + LockedAlgorithmState, + get_trial_uid_and_exp, + get_uid, +) log = logging.getLogger(__name__) @@ -70,7 +82,6 @@ class Algo: # fmt: on - class SQLAlchemy(BaseStorageProtocol): # noqa: F811 """Implement a generic protocol to allow 
Orion to communicate using different storage backend @@ -84,6 +95,20 @@ class SQLAlchemy(BaseStorageProtocol): # noqa: F811 """ def __init__(self, uri): + # dialect+driver://username:password@host:port/database + # + # postgresql://scott:tiger@localhost/mydatabase + # postgresql+psycopg2://scott:tiger@localhost/mydatabase + # postgresql+pg8000://scott:tiger@localhost/mydatabase + # + # mysql://scott:tiger@localhost/foo + # mysql+mysqldb://scott:tiger@localhost/foo + # mysql+pymysql://scott:tiger@localhost/foo + # + # sqlite:///foo.db + # sqlite:// # in memory + + # engine_from_config self.engine = sqlalchemy.create_engine("", echo=True, future=True) # Create the schema @@ -93,7 +118,6 @@ def __init__(self, uri): stmt = select(User).where(User.token == self.token) self.user = session.scalars(stmt).one() - # Experiment Operations # ===================== @@ -102,10 +126,7 @@ def create_experiment(self, config): with Session(self.engine) as session: experiment = Experiment( - name=config['name'], - config=config, - onwer_id=self.user.uid, - version=0 + name=config["name"], config=config, onwer_id=self.user.uid, version=0 ) session.add(experiment) session.commit() @@ -141,14 +162,13 @@ def fetch_experiments(self, query, selection=None): experiments = session.scalars(stmt).all() if selection is not None: - assert False, 'Not Implemented' + assert False, "Not Implemented" return experiments # Benchmarks # ========== - # Trials # ====== def fetch_trials(self, experiment=None, uid=None, where=None): @@ -194,12 +214,14 @@ def get_trial(self, trial=None, uid=None, experiment_uid=None): trial_uid, experiment_uid = get_trial_uid_and_exp(trial, uid, experiment_uid) with Session(self.engine) as session: - stmt = select(Trial).where(Trial.experiment_id == experiment_uid and Trial.uid == trial_uid) + stmt = select(Trial).where( + Trial.experiment_id == experiment_uid and Trial.uid == trial_uid + ) return session.scalars(stmt).one() def update_trials(self, experiment=None, uid=None, where=None, **kwargs): uid = get_uid(experiment, uid) - query = Trial.uid == trial._id and self._to_query(Trial, where) + query = Experiment.uid == uid and self._to_query(Trial, where) with Session(self.engine) as session: stmt = select(Trial).where(query) @@ -211,11 +233,15 @@ def update_trials(self, experiment=None, uid=None, where=None, **kwargs): return trial def update_trial( - self, trial=None, uid=None, experiment_uid=None, where=None, **kwargs - ): + self, trial=None, uid=None, experiment_uid=None, where=None, **kwargs + ): trial_uid, experiment_uid = get_trial_uid_and_exp(trial, uid, experiment_uid) - query = Trial.uid == trial_uid and Trial.experiment_id == experiment_uid and self._to_query(where) + query = ( + Trial.uid == trial_uid + and Trial.experiment_id == experiment_uid + and self._to_query(where) + ) with Session(self.engine) as session: stmt = select(Trial).where(query) @@ -232,12 +258,20 @@ def fetch_lost_trials(self, experiment): ) with Session(self.engine) as session: - stmt = select(Trial).where(Trial.experiment_id == experiment._id and Trial.status == 'reserved' and Trial.heartbeat < threshold) + stmt = select(Trial).where( + Trial.experiment_id == experiment._id + and Trial.status == "reserved" + and Trial.heartbeat < threshold + ) return session.scalars(stmt).all() def push_trial_results(self, trial): with Session(self.engine) as session: - stmt = select(Trial).where(Trial.experiment_id == trial.experiment and Trial.uid == trial.id and Trial.status == 'reserved') + stmt = select(Trial).where( + 
Trial.experiment_id == trial.experiment + and Trial.uid == trial.id + and Trial.status == "reserved" + ) trial = session.scalars(stmt).one() self._set_from_dict(trial, trial.to_dict()) session.commit() @@ -248,13 +282,13 @@ def set_trial_status(self, trial, status, heartbeat=None, was=None): validate_status(status) validate_status(was) - query = Trial.uid == trial.id # and Trial.experiment_id == trial.experiment + query = Trial.uid == trial.id # and Trial.experiment_id == trial.experiment if was: query = query and Trial.status == was values = dict(status=status, experiment=trial.experiment) if heartbeat: - values['heartbeat'] = heartbeat + values["heartbeat"] = heartbeat with Session(self.engine) as session: update(Trial).where(query).values(**values) @@ -262,7 +296,10 @@ def set_trial_status(self, trial, status, heartbeat=None, was=None): def fetch_pending_trials(self, experiment): with Session(self.engine) as session: - stmt = select(Trial).where(Trial.status.in_("interrupted", "new", "suspended") and Trial.experiment_id == experiment._id) + stmt = select(Trial).where( + Trial.status.in_("interrupted", "new", "suspended") + and Trial.experiment_id == experiment._id + ) return session.scalars(stmt).all() def reserve_trial(self, experiment): @@ -270,41 +307,54 @@ def reserve_trial(self, experiment): with session.begin(): # not sure it prevents other worker from reserving the same trial - stmt = select(Trial).where(Trial.status.in_("interrupted", "new", "suspended") and Trial.experiment_id == experiment._id) + stmt = select(Trial).where( + Trial.status.in_("interrupted", "new", "suspended") + and Trial.experiment_id == experiment._id + ) trial = session.scalars(stmt).one() now = datetime.datetime.utcnow() - trial.status = 'reserved' + trial.status = "reserved" trial.start_time = now - trial.heartbeat= now + trial.heartbeat = now return trial def fetch_trials_by_status(self, experiment, status): with Session(self.engine) as session: - stmt = select(Trial).where(Trial.status == status and Trial.experiment_id == experiment._id) + stmt = select(Trial).where( + Trial.status == status and Trial.experiment_id == experiment._id + ) return session.scalars(stmt).all() def fetch_noncompleted_trials(self, experiment): with Session(self.engine) as session: - stmt = select(Trial).where(Trial.status != 'completed' and Trial.experiment_id == experiment._id) + stmt = select(Trial).where( + Trial.status != "completed" and Trial.experiment_id == experiment._id + ) return session.scalars(stmt).all() def count_completed(self, experiment): with Session(self.engine) as session: - stmt = select(Trial).where(Trial.status == 'completed' and Trial.experiment_id == experiment._id) + stmt = select(Trial).where( + Trial.status == "completed" and Trial.experiment_id == experiment._id + ) return session.query(stmt).count() def count_broken_trials(self, experiment): with Session(self.engine) as session: - stmt = select(Trial).where(Trial.status == 'broken' and Trial.experiment_id == experiment._id) + stmt = select(Trial).where( + Trial.status == "broken" and Trial.experiment_id == experiment._id + ) return session.query(stmt).count() def update_heartbeat(self, trial): """Update trial's heartbeat""" with Session(self.engine) as session: - update(Trial).where(Trial.uid == trial.id, Trial.status == 'reserved').values(heartbeat=datetime.datetime.utcnow()) + update(Trial).where( + Trial.uid == trial.id, Trial.status == "reserved" + ).values(heartbeat=datetime.datetime.utcnow()) session.commit() # Algorithm @@ -316,7 +366,7 @@ 
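# NOTE (review): two SQLAlchemy pitfalls recur in the revision above. First,
# update() constructs are inert until executed, so statements such as
# `update(Trial).where(query).values(**values)` followed only by
# session.commit() never touch the database. Second, Python's `and` cannot
# combine column expressions into a SQL AND; criteria must be passed
# comma-separated to .where() (or wrapped in and_()). Patch 04 below fixes
# both; a sketch modelled on its update_heartbeat/set_trial_status changes:
#
#     stmt = (
#         update(Trial)
#         .where(Trial._id == trial.id, Trial.status == "reserved")
#         .values(heartbeat=datetime.datetime.utcnow())
#     )
#     result = session.execute(stmt)
#     session.commit()
#     if result.rowcount <= 0:
#         raise FailedUpdate()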
def initialize_algorithm_lock(self, experiment_id, algorithm_config): onwer_id=self.user.uid, configuration=algorithm_config, locked=0, - heartbeat=datetime.datetime.utcnow() + heartbeat=datetime.datetime.utcnow(), ) session.add(algo) session.commit() @@ -332,17 +382,16 @@ def release_algorithm_lock(self, experiment=None, uid=None, new_state=None): values["state"] = pickle.dumps(new_state) with Session(self.engine) as session: - update(Algo).where( - Algo.experiment_id == uid and - Algo.locked == 1 - ).values(**values) + update(Algo).where(Algo.experiment_id == uid and Algo.locked == 1).values( + **values + ) def get_algorithm_lock_info(self, experiment=None, uid=None): """See :func:`orion.storage.base.BaseStorageProtocol.get_algorithm_lock_info`""" uid = get_uid(experiment, uid) with Session(self.engine) as session: - stmt = select(Algo).where(Algo.experiment_id==uid) + stmt = select(Algo).where(Algo.experiment_id == uid) algo = session.scalar(stmt).one() return LockedAlgorithmState( @@ -356,18 +405,21 @@ def delete_algorithm_lock(self, experiment=None, uid=None): uid = get_uid(experiment, uid) with Session(self.engine) as session: - stmt = delete(Algo).where(Algo.experiment_id==uid) + stmt = delete(Algo).where(Algo.experiment_id == uid) session.execute(stmt) session.commit() @contextlib.contextmanager - def acquire_algorithm_lock(self, experiment=None, uid=None, timeout=60, retry_interval=1): + def acquire_algorithm_lock( + self, experiment=None, uid=None, timeout=60, retry_interval=1 + ): uid = get_uid(experiment, uid) with Session(self.engine) as session: - stmt = update(Algo).where(Algo.experiment_id==uid, Algo.locked==0).values( - locked=1, - heartbeat=datetime.datetime.utcnow() + stmt = ( + update(Algo) + .where(Algo.experiment_id == uid, Algo.locked == 0) + .values(locked=1, heartbeat=datetime.datetime.utcnow()) ) algo = session.scalar(stmt).one() session.commit() @@ -385,7 +437,6 @@ def acquire_algorithm_lock(self, experiment=None, uid=None, timeout=60, retry_in self.release_algorithm_lock(uid, new_state=algo_state.state) - # Utilities # ========= def _set_from_dict(self, obj, data, rest=None): From a7b6a6d5e7c7cdc02f837ac49543d8875d659027 Mon Sep 17 00:00:00 2001 From: Setepenre Date: Mon, 29 Aug 2022 11:15:59 -0400 Subject: [PATCH 03/25] Fix some update queries to make sure no race conditions occur --- src/orion/storage/sql.py | 92 ++++++++++++++++++++++++++++++++++------ 1 file changed, 79 insertions(+), 13 deletions(-) diff --git a/src/orion/storage/sql.py b/src/orion/storage/sql.py index 82b662344..a4672ec95 100644 --- a/src/orion/storage/sql.py +++ b/src/orion/storage/sql.py @@ -302,23 +302,67 @@ def fetch_pending_trials(self, experiment): ) return session.scalars(stmt).all() + def _reserve_trial_postgre(self, experiment): + now = datetime.datetime.utcnow() + + with Session(self.engine) as session: + # In PostgrerSQL we can do single query + stmt = ( + update(Trial) + .where( + True + and Trial.status.in_("interrupted", "new", "suspended") + and Trial.experiment_id == experiment._id + ) + .values( + status="reserved", + start_time=now, + heartbeat=now, + ) + .limit(1) + .returning() + ) + trial = session.scalar(stmt) + return trial + def reserve_trial(self, experiment): + if False: + return self._reserve_trial_postgre(experiment) + + now = datetime.datetime.utcnow() + with Session(self.engine) as session: + stmt = select(Trial).where( + Trial.status.in_("interrupted", "new", "suspended") + and Trial.experiment_id == experiment._id + ) + trial = session.scalars(stmt).one() - 
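# NOTE (review): this hunk is the race-condition fix named in the commit
# message. Reservation becomes an optimistic compare-and-swap: read a
# candidate trial, UPDATE it only if its status is still the one observed,
# then re-read and verify that this worker's timestamp is the one stored.
# In outline (sketch; `candidate` stands for the trial fetched just above --
# note the hunk below filters on experiment_id and status only, which can
# flip several rows at once, whereas matching the candidate's primary key
# narrows the write to a single row):
#
#     stmt = (
#         update(Trial)
#         .where(Trial._id == candidate._id, Trial.status == candidate.status)
#         .values(status="reserved", start_time=now, heartbeat=now)
#     )
#     session.execute(stmt)
#     session.commit()
#     # reserved only if our `now` survived as the stored heartbeat
#
# (The PostgreSQL fast path above relies on UPDATE ... RETURNING; SQLAlchemy's
# Update offers no generic .limit() and .returning() expects column
# arguments, so that branch likely needs adjustment.)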
with session.begin(): - # not sure it prevents other worker from reserving the same trial - stmt = select(Trial).where( - Trial.status.in_("interrupted", "new", "suspended") + # Update the trial iff the status has not been changed yet + stmt = ( + update(Trial) + .where( + True + and Trial.status == trial.status and Trial.experiment_id == experiment._id ) - trial = session.scalars(stmt).one() + .values( + status="reserved", + start_time=now, + heartbeat=now, + ) + ) - now = datetime.datetime.utcnow() - trial.status = "reserved" - trial.start_time = now - trial.heartbeat = now + session.execute(stmt) - return trial + stmt = select(Trial).where(Trial.experiment_id == experiment._id) + trial = session.scalars(stmt).one() + + # time needs to match, could have been reserved by another worker + if trial.status == "reserved" and trial.heartbeat == now: + return trial + + return None def fetch_trials_by_status(self, experiment, status): with Session(self.engine) as session: @@ -409,6 +453,23 @@ def delete_algorithm_lock(self, experiment=None, uid=None): session.execute(stmt) session.commit() + def _acquire_algorithm_lock_postgre( + self, experiment=None, uid=None, timeout=60, retry_interval=1 + ): + with Session(self.engine) as session: + now = datetime.datetime.utcnow() + + stmt = ( + update(Algo) + .where(Algo.experiment_id == uid, Algo.locked == 0) + .values(locked=1, heartbeat=now) + .returning() + ) + + algo = session.scalar(stmt).one() + session.commit() + return algo + @contextlib.contextmanager def acquire_algorithm_lock( self, experiment=None, uid=None, timeout=60, retry_interval=1 @@ -416,15 +477,20 @@ def acquire_algorithm_lock( uid = get_uid(experiment, uid) with Session(self.engine) as session: + now = datetime.datetime.utcnow() + stmt = ( update(Algo) .where(Algo.experiment_id == uid, Algo.locked == 0) - .values(locked=1, heartbeat=datetime.datetime.utcnow()) + .values(locked=1, heartbeat=now) ) - algo = session.scalar(stmt).one() + + session.execute(stmt) session.commit() - if algo is None: + algo = select(Algo).where(Algo.experiment_id == uid, Algo.locked == 1) + + if algo is None or algo.heartbead != now: return algo_state = LockedAlgorithmState( From 0b8ee89cfaca7c22bc43737e46e32fe9c75ca9fd Mon Sep 17 00:00:00 2001 From: Setepenre Date: Mon, 29 Aug 2022 16:40:17 -0400 Subject: [PATCH 04/25] Fix some queries --- setup.py | 1 + src/orion/core/__init__.py | 8 + src/orion/core/io/config.py | 3 + src/orion/storage/legacy.py | 2 +- src/orion/storage/sql.py | 410 +++++++++++++++++------- src/orion/testing/state.py | 13 +- tests/unittests/storage/test_storage.py | 67 ++-- 7 files changed, 374 insertions(+), 130 deletions(-) diff --git a/setup.py b/setup.py index d0769e21e..03bb05bfb 100644 --- a/setup.py +++ b/setup.py @@ -40,6 +40,7 @@ ], "dask": ["dask[complete]"], "track": ["track @ git+https://github.com/Delaunay/track@master#egg=track"], + "sqlalchemy": ["sqlalchemy"], "profet": ["emukit", "GPy", "torch", "pybnn"], "configspace": ["ConfigSpace"], "ax": [ diff --git a/src/orion/core/__init__.py b/src/orion/core/__init__.py index b1da52534..4d9434fd1 100644 --- a/src/orion/core/__init__.py +++ b/src/orion/core/__init__.py @@ -106,6 +106,14 @@ def define_storage_config(config): "type", option_type=str, default="legacy", env_var="ORION_STORAGE_TYPE" ) + storage_config.add_option( + "uri", option_type=str, default="", env_var="ORION_STORAGE_URI" + ) + + storage_config.add_option( + "token", option_type=str, default="", env_var="ORION_STORAGE_TOKEN" + ) + config.storage = storage_config 
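# NOTE (review): together with the "sqlalchemy" entry point registered in
# setup.py (patch 02), these options make the new backend selectable from
# user configuration. A sketch of what that could look like (field names
# follow the options defined above; the exact YAML shape is assumed):
#
#     storage:
#         type: sqlalchemy
#         uri: sqlite:///foo.db   # or postgresql://scott:tiger@localhost/db
#
# or equivalently through the ORION_STORAGE_URI / ORION_STORAGE_TOKEN
# environment variables.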
define_database_config(config.storage) diff --git a/src/orion/core/io/config.py b/src/orion/core/io/config.py index b67ed4230..d8a59a1be 100644 --- a/src/orion/core/io/config.py +++ b/src/orion/core/io/config.py @@ -456,6 +456,9 @@ def to_dict(self): def from_dict(self, config): """Set the configuration from a dictionary""" + if config is None: + return + logger.debug("Setting config to %s", config) logger.debug("Config was %s", repr(self)) diff --git a/src/orion/storage/legacy.py b/src/orion/storage/legacy.py index 16484eb83..54f35a2bb 100644 --- a/src/orion/storage/legacy.py +++ b/src/orion/storage/legacy.py @@ -63,7 +63,7 @@ class Legacy(BaseStorageProtocol): """ - def __init__(self, database=None, setup=True): + def __init__(self, database=None, setup=True, **kwargs): self._db = setup_database(database) if setup: diff --git a/src/orion/storage/sql.py b/src/orion/storage/sql.py index a4672ec95..0ad4b3cb6 100644 --- a/src/orion/storage/sql.py +++ b/src/orion/storage/sql.py @@ -1,7 +1,10 @@ import contextlib import datetime +import getpass import logging import pickle +import uuid +from copy import deepcopy import sqlalchemy from sqlalchemy import ( @@ -11,16 +14,21 @@ ForeignKey, Integer, String, + UniqueConstraint, delete, select, update, ) +from sqlalchemy.exc import DBAPIError, NoResultFound from sqlalchemy.orm import Session, declarative_base import orion.core +from orion.core.io.database import DuplicateKeyError +from orion.core.worker.trial import Trial as OrionTrial from orion.core.worker.trial import validate_status from orion.storage.base import ( BaseStorageProtocol, + FailedUpdate, LockedAlgorithmState, get_trial_uid_and_exp, get_uid, @@ -35,9 +43,9 @@ class User(Base): """Defines the User table""" __tablename__ = "users" - uid = Column(Integer, primary_key=True) + _id = Column(Integer, primary_key=True, autoincrement=True) name = Column(String(30)) - token = Column(String(30)) + token = Column(String(32)) created_at = Column(DateTime) last_seen = Column(DateTime) @@ -46,35 +54,52 @@ class Experiment(Base): """Defines the Experiment table""" __tablename__ = "experiments" - uid = Column(Integer, primary_key=True) + _id = Column(Integer, primary_key=True, autoincrement=True) name = Column(String(30)) - config = Column(JSON) + meta = Column(JSON) # metadata field is reserved version = Column(Integer) - owner_id = Column(Integer, ForeignKey("user.uid"), nullable=False) + owner_id = Column(Integer, ForeignKey("users._id"), nullable=False) datetime = Column(DateTime) + algorithms = Column(JSON) + remaining = Column(JSON) + space = Column(JSON) + __table_args__ = ( + UniqueConstraint('name', 'owner_id', name='_one_name_per_owner'), + ) -class Trial: + +class Trial(Base): """Defines the Trial table""" - __tablename__ = "trial" + __tablename__ = "trials" - uid = Column(Integer, primary_key=True) - experiment_id = Column(Integer, ForeignKey("experiment.uid"), nullable=False) - owner_id = Column(Integer, ForeignKey("user.uid"), nullable=False) + _id = Column(Integer, primary_key=True, autoincrement=True) + experiment_id = Column(Integer, ForeignKey("experiments._id"), nullable=False) + owner_id = Column(Integer, ForeignKey("users._id"), nullable=False) status = Column(String(30)) results = Column(JSON) start_time = Column(DateTime) end_time = Column(DateTime) heartbeat = Column(DateTime) + parent = Column(Integer, ForeignKey("experiments._id")) + params = Column(JSON) + worker = Column(JSON) + submit_time = Column(String(30)) + exp_working_dir = Column(String(30)) + id = 
Column(String(30)) + + __table_args__ = ( + UniqueConstraint('experiment_id', 'id', name='_one_trial_hash_per_experiment'), + ) -class Algo: +class Algo(Base): """Defines the Algo table""" __tablename__ = "algo" - uid = Column(Integer, primary_key=True) - experiment_id = Column(Integer, ForeignKey("experiment.uid"), nullable=False) - owner_id = Column(Integer, ForeignKey("user.uid"), nullable=False) + _id = Column(Integer, primary_key=True, autoincrement=True) + experiment_id = Column(Integer, ForeignKey("experiments._id"), nullable=False) + owner_id = Column(Integer, ForeignKey("users._id"), nullable=False) configuration = Column(JSON) locked = Column(Integer) state = Column(JSON) @@ -94,7 +119,7 @@ class SQLAlchemy(BaseStorageProtocol): # noqa: F811 """ - def __init__(self, uri): + def __init__(self, uri, token=None, **kwargs): # dialect+driver://username:password@host:port/database # # postgresql://scott:tiger@localhost/mydatabase @@ -108,34 +133,91 @@ def __init__(self, uri): # sqlite:///foo.db # sqlite:// # in memory + self.uri = uri + if uri == "": + uri = "sqlite://" + # engine_from_config - self.engine = sqlalchemy.create_engine("", echo=True, future=True) + self.engine = sqlalchemy.create_engine(uri, echo=True, future=True) # Create the schema Base.metadata.create_all(self.engine) - with Session(self.engine) as session: - stmt = select(User).where(User.token == self.token) - self.user = session.scalars(stmt).one() + self.token = token + self.user_id = None + self.user = None + self._connect(token) + + def _connect(self, token): + if token is not None and token != "": + with Session(self.engine) as session: + stmt = select(User).where(User.token == self.token) + self.user = session.scalars(stmt).one() + + self.user_id = self.user._id + else: + # Local database, create a default user + user = getpass.getuser() + now = datetime.datetime.utcnow() + + with Session(self.engine) as session: + self.user = User( + name=user, + token=uuid.uuid5(uuid.NAMESPACE_OID, user).hex, + created_at=now, + last_seen=now, + ) + session.add(self.user) + session.commit() + + assert self.user._id > 0 + self.user_id = self.user._id + + def __getstate__(self): + return dict( + uri=self.uri, + token=self.token, + ) + + def __setstate__(self, state): + self.uri = state["uri"] + self.token = state["token"] + self.engine = sqlalchemy.create_engine(self.uri, echo=True, future=True) + self._connect(self.token) # Experiment Operations # ===================== def create_experiment(self, config): """Insert a new experiment inside the database""" + config = deepcopy(config) - with Session(self.engine) as session: - experiment = Experiment( - name=config["name"], config=config, onwer_id=self.user.uid, version=0 - ) - session.add(experiment) - session.commit() + try: + with Session(self.engine) as session: + experiment = Experiment( + owner_id=self.user_id, + version=0, + ) + + config["meta"] = config.pop("metadata") - def delete_experiment(self, experiment, uid): + # old way + # if 'space' not in config: + # config['space'] = config['meta'].pop('priors', dict()) + + self._set_from_dict(experiment, config, "remaining") + + session.add(experiment) + session.commit() + + except DBAPIError: + raise DuplicateKeyError() + + def delete_experiment(self, experiment=None, uid=None): uid = get_uid(experiment, uid) with Session(self.engine) as session: - stmt = delete(Experiment).where(Experiment.uid == uid) + stmt = delete(Experiment).where(Experiment._id == uid) session.execute(stmt) session.commit() @@ -143,28 +225,35 @@ def 
update_experiment(self, experiment=None, uid=None, where=None, **kwargs): uid = get_uid(experiment, uid) query = True + if where is None: + where = dict() + if uid is not None: - query = Experiment.uid == uid + where["_id"] = uid - query = query and self._to_query(Experiment, where) + query = self._to_query(Experiment, where) with Session(self.engine) as session: - stmt = select(Experiment).where(query) + stmt = select(Experiment).where(*query) experiment = session.scalars(stmt).one() - experiment.config = kwargs + + metadata = kwargs.pop("metadata", dict()) + self._set_from_dict(experiment, kwargs, "remaining") + experiment.meta.update(metadata) + session.commit() def fetch_experiments(self, query, selection=None): - query = self._to_query(query) + query = self._to_query(Experiment, query) with Session(self.engine) as session: - stmt = select(Experiment).where(query) + stmt = select(Experiment).where(*query) experiments = session.scalars(stmt).all() if selection is not None: assert False, "Not Implemented" - return experiments + return [self._to_experiment(exp) for exp in experiments] # Benchmarks # ========== @@ -174,39 +263,56 @@ def fetch_experiments(self, query, selection=None): def fetch_trials(self, experiment=None, uid=None, where=None): uid = get_uid(experiment, uid) - query = True + if where is None: + where = dict() + if uid is not None: - query = Trial.experiment_id == uid + where["experiment_id"] = uid - query = query and self._to_query(Trial, where) + query = self._to_query(Trial, where) with Session(self.engine) as session: - stmt = select(Trial).where(query) + stmt = select(Trial).where(*query) return session.scalars(stmt).all() def register_trial(self, trial): config = trial.to_dict() - with Session(self.engine) as session: - stmt = select(Trial).where(Trial.uid == trial._id) - trial = session.scalars(stmt).one() - self._set_from_dict(trial, config) - session.commit() + try: + with Session(self.engine) as session: + experiment_id = config.pop("experiment", None) + + db_trial = Trial(experiment_id=experiment_id, owner_id=self.user_id) + + self._set_from_dict(db_trial, config) + + session.add(db_trial) + session.commit() + + session.refresh(db_trial) + + return OrionTrial(**self._to_trial(db_trial)) + except DBAPIError: + raise DuplicateKeyError() def delete_trials(self, experiment=None, uid=None, where=None): uid = get_uid(experiment, uid) - query = True + if where is None: + where = dict() + if uid is not None: - query = Trial.experiment_id == uid + where["experiment_id"] = uid - query = query and self._to_query(Trial, where) + query = self._to_query(Trial, where) with Session(self.engine) as session: - stmt = delete(Trial).where(query) - session.execute(stmt) + stmt = delete(Trial).where(*query) + count = session.execute(stmt) session.commit() + return count.rowcount + def retrieve_result(self, trial, **kwargs): return trial @@ -215,37 +321,55 @@ def get_trial(self, trial=None, uid=None, experiment_uid=None): with Session(self.engine) as session: stmt = select(Trial).where( - Trial.experiment_id == experiment_uid and Trial.uid == trial_uid + Trial.experiment_id == experiment_uid, + Trial.id == trial_uid, ) - return session.scalars(stmt).one() + trial = session.scalars(stmt).one() + + return OrionTrial(**self._to_trial(trial)) def update_trials(self, experiment=None, uid=None, where=None, **kwargs): uid = get_uid(experiment, uid) - query = Experiment.uid == uid and self._to_query(Trial, where) + + if where is None: + where = dict() + + where["experiment_id"] = uid + query 
= self._to_query(Trial, where) with Session(self.engine) as session: - stmt = select(Trial).where(query) + stmt = select(Trial).where(*query) trials = session.scalars(stmt).all() + for trial in trials: self._set_from_dict(trial, kwargs) + session.commit() - return trial + return len(trials) def update_trial( self, trial=None, uid=None, experiment_uid=None, where=None, **kwargs ): trial_uid, experiment_uid = get_trial_uid_and_exp(trial, uid, experiment_uid) - query = ( - Trial.uid == trial_uid - and Trial.experiment_id == experiment_uid - and self._to_query(where) - ) + + if where is None: + where = dict() + + # THIS IS NOT THE UNIQUE ID OF THE TRIAL + where["id"] = trial_uid + where["experiment_id"] = experiment_uid + query = self._to_query(Trial, where) + + print() + print(query) + print() with Session(self.engine) as session: - stmt = select(Trial).where(query) + stmt = select(Trial).where(*query) trial = session.scalars(stmt).one() + self._set_from_dict(trial, kwargs) session.commit() @@ -259,18 +383,18 @@ def fetch_lost_trials(self, experiment): with Session(self.engine) as session: stmt = select(Trial).where( - Trial.experiment_id == experiment._id - and Trial.status == "reserved" - and Trial.heartbeat < threshold + Trial.experiment_id == experiment._id, + Trial.status == "reserved", + Trial.heartbeat < threshold, ) return session.scalars(stmt).all() def push_trial_results(self, trial): with Session(self.engine) as session: stmt = select(Trial).where( - Trial.experiment_id == trial.experiment - and Trial.uid == trial.id - and Trial.status == "reserved" + Trial.experiment_id == trial.experiment, + Trial._id == trial.id, + Trial.status == "reserved", ) trial = session.scalars(stmt).one() self._set_from_dict(trial, trial.to_dict()) @@ -279,26 +403,37 @@ def push_trial_results(self, trial): return trial def set_trial_status(self, trial, status, heartbeat=None, was=None): + heartbeat = heartbeat or datetime.datetime.utcnow() + was = was or trial.status + validate_status(status) validate_status(was) - query = Trial.uid == trial.id # and Trial.experiment_id == trial.experiment - if was: - query = query and Trial.status == was + query = [ + Trial.id == trial.id, + Trial.experiment_id == trial.experiment, + Trial.status == was, + ] - values = dict(status=status, experiment=trial.experiment) + values = dict(status=status) if heartbeat: values["heartbeat"] = heartbeat with Session(self.engine) as session: - update(Trial).where(query).values(**values) + stmt = update(Trial).where(*query).values(**values) + result = session.execute(stmt) session.commit() + if result.rowcount == 1: + trial.status = status + else: + raise FailedUpdate() + def fetch_pending_trials(self, experiment): with Session(self.engine) as session: stmt = select(Trial).where( - Trial.status.in_("interrupted", "new", "suspended") - and Trial.experiment_id == experiment._id + Trial.status.in_(("interrupted", "new", "suspended")), + Trial.experiment_id == experiment._id, ) return session.scalars(stmt).all() @@ -310,9 +445,8 @@ def _reserve_trial_postgre(self, experiment): stmt = ( update(Trial) .where( - True - and Trial.status.in_("interrupted", "new", "suspended") - and Trial.experiment_id == experiment._id + Trial.status.in_(("interrupted", "new", "suspended")), + Trial.experiment_id == experiment._id, ) .values( status="reserved", @@ -333,18 +467,20 @@ def reserve_trial(self, experiment): with Session(self.engine) as session: stmt = select(Trial).where( - Trial.status.in_("interrupted", "new", "suspended") - and 
Trial.experiment_id == experiment._id + Trial.status.in_(("interrupted", "new", "suspended")), + Trial.experiment_id == experiment._id, ) - trial = session.scalars(stmt).one() + try: + trial = session.scalars(stmt).one() + except NoResultFound: + return None # Update the trial iff the status has not been changed yet stmt = ( update(Trial) .where( - True - and Trial.status == trial.status - and Trial.experiment_id == experiment._id + Trial.status == trial.status, + Trial.experiment_id == experiment._id, ) .values( status="reserved", @@ -369,45 +505,67 @@ def fetch_trials_by_status(self, experiment, status): stmt = select(Trial).where( Trial.status == status and Trial.experiment_id == experiment._id ) - return session.scalars(stmt).all() + return [ + OrionTrial(**self._to_trial(trial)) + for trial in session.scalars(stmt).all() + ] def fetch_noncompleted_trials(self, experiment): with Session(self.engine) as session: stmt = select(Trial).where( - Trial.status != "completed" and Trial.experiment_id == experiment._id + Trial.status != "completed", + Trial.experiment_id == experiment._id, ) return session.scalars(stmt).all() - def count_completed(self, experiment): + def count_completed_trials(self, experiment): with Session(self.engine) as session: - stmt = select(Trial).where( - Trial.status == "completed" and Trial.experiment_id == experiment._id + return ( + session.query(Trial) + .filter( + Trial.status == "completed", + Trial.experiment_id == experiment._id, + ) + .count() ) - return session.query(stmt).count() def count_broken_trials(self, experiment): with Session(self.engine) as session: - stmt = select(Trial).where( - Trial.status == "broken" and Trial.experiment_id == experiment._id + return ( + session.query(Trial) + .filter( + Trial.status == "broken", + Trial.experiment_id == experiment._id, + ) + .count() ) - return session.query(stmt).count() def update_heartbeat(self, trial): """Update trial's heartbeat""" with Session(self.engine) as session: - update(Trial).where( - Trial.uid == trial.id, Trial.status == "reserved" - ).values(heartbeat=datetime.datetime.utcnow()) + stmt = ( + update(Trial) + .where( + Trial._id == trial.id_override, + Trial.status == "reserved", + ) + .values(heartbeat=datetime.datetime.utcnow()) + ) + + cursor = session.execute(stmt) session.commit() + if cursor.rowcount <= 0: + raise FailedUpdate() + # Algorithm # ========= def initialize_algorithm_lock(self, experiment_id, algorithm_config): with Session(self.engine) as session: algo = Algo( experiment_id=experiment_id, - onwer_id=self.user.uid, + owner_id=self.user._id, configuration=algorithm_config, locked=0, heartbeat=datetime.datetime.utcnow(), @@ -415,8 +573,8 @@ def initialize_algorithm_lock(self, experiment_id, algorithm_config): session.add(algo) session.commit() - def release_algorithm_lock(self, experiment=None, uid=None, new_state=None): - uid = get_uid(experiment, uid) + def release_algorithm_lock(self, experiment=None, _id=None, new_state=None): + _id = get_uid(experiment, _id) values = dict( locked=0, @@ -426,13 +584,13 @@ def release_algorithm_lock(self, experiment=None, uid=None, new_state=None): values["state"] = pickle.dumps(new_state) with Session(self.engine) as session: - update(Algo).where(Algo.experiment_id == uid and Algo.locked == 1).values( + update(Algo).where(Algo.experiment_id == _id and Algo.locked == 1).values( **values ) def get_algorithm_lock_info(self, experiment=None, uid=None): """See :func:`orion.storage.base.BaseStorageProtocol.get_algorithm_lock_info`""" - uid = 
get_uid(experiment, uid) + _id = get_uid(experiment, uid) with Session(self.engine) as session: stmt = select(Algo).where(Algo.experiment_id == uid) @@ -472,25 +630,27 @@ def _acquire_algorithm_lock_postgre( @contextlib.contextmanager def acquire_algorithm_lock( - self, experiment=None, uid=None, timeout=60, retry_interval=1 + self, experiment=None, _id=None, timeout=60, retry_interval=1 ): - uid = get_uid(experiment, uid) + _id = get_uid(experiment, _id) with Session(self.engine) as session: now = datetime.datetime.utcnow() stmt = ( update(Algo) - .where(Algo.experiment_id == uid, Algo.locked == 0) + .where(Algo.experiment_id == _id, Algo.locked == 0) .values(locked=1, heartbeat=now) ) session.execute(stmt) session.commit() - algo = select(Algo).where(Algo.experiment_id == uid, Algo.locked == 1) + stmt = select(Algo).where(Algo.experiment_id == _id, Algo.locked == 1) + algo = session.scalar(stmt) - if algo is None or algo.heartbead != now: + if algo is None or algo.heartbeat != now: + yield None return algo_state = LockedAlgorithmState( @@ -501,15 +661,19 @@ def acquire_algorithm_lock( yield algo_state - self.release_algorithm_lock(uid, new_state=algo_state.state) + self.release_algorithm_lock(_id, new_state=algo_state.state) # Utilities # ========= def _set_from_dict(self, obj, data, rest=None): + data = deepcopy(data) meta = dict() while data: k, v = data.popitem() + if v is None: + continue + if hasattr(obj, k): setattr(obj, k, v) else: @@ -523,12 +687,42 @@ def _set_from_dict(self, obj, data, rest=None): log.warning("Data was discarded %s", meta) def _to_query(self, table, where): - query = True - for k, v in where.items(): + query = [] - if hash(table, k): - query = query and getattr(k) == v + for k, v in where.items(): + if hasattr(table, k): + query.append(getattr(table, k) == v) else: log.warning("constrained ignored %s = %s", k, v) return query + + def _to_experiment(self, experiment): + exp = deepcopy(experiment.__dict__) + exp["metadata"] = exp.pop("meta", {}) + exp.pop("_sa_instance_state") + exp.pop("owner_id") + exp.pop("datetime") + + none_keys = [] + for k, v in exp.items(): + if v is None: + none_keys.append(k) + + for k in none_keys: + exp.pop(k) + + rest = exp.pop("remaining", {}) + if rest is None: + rest = {} + + exp.update(rest) + + return exp + + def _to_trial(self, trial): + trial = deepcopy(trial.__dict__) + trial.pop("_sa_instance_state") + trial["experiment"] = trial.pop("experiment_id") + trial.pop("owner_id") + return trial diff --git a/src/orion/testing/state.py b/src/orion/testing/state.py index 33cae686c..f98686984 100644 --- a/src/orion/testing/state.py +++ b/src/orion/testing/state.py @@ -92,6 +92,7 @@ def __init__( self._workers = _select(workers, []) self._resources = _select(resources, []) self._lies = _select(lies, []) + self.expname_to_uid = dict() # In case of track we also store the inserted object # so the user can compare in tests the different values @@ -107,7 +108,7 @@ def init(self, config): def get_experiment(self, name, version=None): """Make experiment id deterministic""" - exp = experiment_builder.build(name=name, version=version) + exp = experiment_builder.build(name=name, version=version, storage=self.storage) return exp def get_trial(self, index): @@ -139,7 +140,15 @@ def _set_tables(self): for exp in self._experiments: self.storage.create_experiment(exp) + exp = self.storage.fetch_experiments(dict(name=exp["name"]))[0] + self.expname_to_uid[exp["name"]] = exp["_id"] + for trial in self._trials: + exp_id = 
self.expname_to_uid.get(trial["experiment"], None) + + if exp_id is not None: + trial["experiment"] = exp_id + nt = self.storage.register_trial(Trial(**trial)) self.trials.append(nt.to_dict()) @@ -248,7 +257,7 @@ def init(self, config): def get_experiment(self, name, version=None): """Make experiment id deterministic""" - exp = experiment_builder.build(name, version=version) + exp = experiment_builder.build(name, version=version, storage=self.storage) exp._id = exp.name return exp diff --git a/tests/unittests/storage/test_storage.py b/tests/unittests/storage/test_storage.py index fb70cc0a8..fbffadfc8 100644 --- a/tests/unittests/storage/test_storage.py +++ b/tests/unittests/storage/test_storage.py @@ -28,7 +28,10 @@ log = logging.getLogger(__name__) log.setLevel(logging.WARNING) -storage_backends = [None] # defaults to legacy with PickleDB +storage_backends = [ + None, + dict(type="sqlalchemy", uri="sqlite://"), +] # defaults to legacy with PickleDB if not HAS_TRACK: log.warning("Track is not tested because: %s!", REASON) @@ -282,9 +285,16 @@ def test_delete_experiment(self, storage): def test_register_trial(self, storage): """Test register trial""" + global base_trial + + new_trial = base_trial + if storage and storage["type"] == "sqlalchemy": + new_trial = copy.deepcopy(base_trial) + new_trial["experiment"] = 1 + with OrionState(experiments=[base_experiment], storage=storage) as cfg: storage = cfg.storage - trial1 = storage.register_trial(Trial(**base_trial)) + trial1 = storage.register_trial(Trial(**new_trial)) trial2 = storage.get_trial(trial1) assert ( @@ -298,8 +308,11 @@ def test_register_duplicate_trial(self, storage): ) as cfg: storage = cfg.storage + # Get the trial with its experiment_id populated + trial = cfg.trials[0] + with pytest.raises(DuplicateKeyError): - storage.register_trial(Trial(**base_trial)) + storage.register_trial(Trial(**trial)) def test_update_trials(self, storage): """Test update many trials""" @@ -414,21 +427,29 @@ def test_delete_all_trials(self, storage): trial_from_other_exp = copy.deepcopy(trials[0]) trial_from_other_exp["experiment"] = "other" trials.append(trial_from_other_exp) + + other_experiment = copy.deepcopy(base_experiment) + other_experiment["name"] = "other" + with OrionState( - experiments=[base_experiment], trials=trials, storage=storage + experiments=[base_experiment, other_experiment], + trials=trials, + storage=storage, ) as cfg: storage = cfg.storage + experiment_uid = cfg.expname_to_uid.get("default_name", "default_name") # Make sure we have sufficient trials to test deletion - trials = storage.fetch_trials(uid="default_name") + trials = storage.fetch_trials(uid=experiment_uid) assert len(trials) > 2 - count = storage.delete_trials(uid="default_name") + count = storage.delete_trials(uid=experiment_uid) assert count == len(trials) - assert storage.fetch_trials(uid="default_name") == [] + assert storage.fetch_trials(uid=experiment_uid) == [] # Make sure trials from other experiments were not deleted - assert len(storage.fetch_trials(uid="other")) == 1 + other_uid = cfg.expname_to_uid.get("other", "other") + assert len(storage.fetch_trials(uid=other_uid)) == 1 def test_delete_trials_with_query(self, storage): """Test delete experiment trials matching a query""" @@ -439,8 +460,14 @@ def test_delete_trials_with_query(self, storage): trial_from_other_exp = copy.deepcopy(trials[0]) trial_from_other_exp["experiment"] = "other" trials.append(trial_from_other_exp) + + other_experiment = copy.deepcopy(base_experiment) + other_experiment["name"] = 
"other" + with OrionState( - experiments=[base_experiment], trials=trials, storage=storage + experiments=[base_experiment, other_experiment], + trials=trials, + storage=storage, ) as cfg: storage = cfg.storage experiment = cfg.get_experiment("default_name") @@ -453,7 +480,8 @@ def test_delete_trials_with_query(self, storage): assert len(trials) > len(trials_with_status) # Test deletion - count = storage.delete_trials(uid="default_name", where={"status": status}) + experiment_uid = cfg.expname_to_uid.get("default_name", "default_name") + count = storage.delete_trials(uid=experiment_uid, where={"status": status}) assert count == len(trials_with_status) assert storage.fetch_trials_by_status(experiment, status) == [] assert len(storage.fetch_trials(experiment)) == len(trials) - len( @@ -461,7 +489,8 @@ def test_delete_trials_with_query(self, storage): ) # Make sure trials from other experiments were not deleted - assert len(storage.fetch_trials(uid="other")) == 1 + other_uid = cfg.expname_to_uid.get("other", "other") + assert len(storage.fetch_trials(uid=other_uid)) == 1 def test_get_trial(self, storage): """Test get trial""" @@ -512,15 +541,15 @@ def check_status_change(new_status): with OrionState( experiments=[base_experiment], trials=generate_trials(), storage=storage ) as cfg: - trial = setup_storage().get_trial(cfg.get_trial(0)) + trial = cfg.storage.get_trial(cfg.get_trial(0)) assert trial is not None, "was not able to retrieve trial for test" - setup_storage().set_trial_status(trial, status=new_status) + cfg.storage.set_trial_status(trial, status=new_status) assert ( trial.status == new_status ), "Trial status should have been updated locally" - trial = setup_storage().get_trial(trial) + trial = cfg.storage.get_trial(trial) assert ( trial.status == new_status ), "Trial status should have been updated in the storage" @@ -537,11 +566,11 @@ def test_change_status_invalid(self, storage): with OrionState( experiments=[base_experiment], trials=generate_trials(), storage=storage ) as cfg: - trial = setup_storage().get_trial(cfg.get_trial(0)) + trial = cfg.storage.get_trial(cfg.get_trial(0)) assert trial is not None, "Was not able to retrieve trial for test" with pytest.raises(ValueError) as exc: - setup_storage().set_trial_status(trial, status="moo") + cfg.storage.set_trial_status(trial, status="moo") assert exc.match("Given status `moo` not one of") @@ -561,7 +590,7 @@ def check_status_change(new_status): with pytest.raises(FailedUpdate): trial.status = new_status - setup_storage().set_trial_status(trial, status=new_status) + cfg.storage.set_trial_status(trial, status=new_status) check_status_change("completed") check_status_change("broken") @@ -589,9 +618,9 @@ def check_status_change(new_status): trial.status = "broken" assert correct_status != "broken" with pytest.raises(FailedUpdate): - setup_storage().set_trial_status(trial, status=new_status) + cfg.storage.set_trial_status(trial, status=new_status) - setup_storage().set_trial_status( + cfg.storage.set_trial_status( trial, status=new_status, was=correct_status ) From 30ad72f75be095c4f1eb86b0215eb30af8052edf Mon Sep 17 00:00:00 2001 From: Setepenre Date: Tue, 30 Aug 2022 14:43:07 -0400 Subject: [PATCH 05/25] Misc --- .github/workflows/build.yml | 2 +- src/orion/core/evc/experiment.py | 8 +- src/orion/core/io/experiment_builder.py | 4 +- .../interactive_commands/branching_prompt.py | 2 +- src/orion/core/io/resolve_config.py | 6 +- src/orion/core/utils/compat.py | 30 ++ src/orion/storage/sql.py | 267 +++++++++++++----- 
 src/orion/testing/state.py                |   2 +
 tests/conftest.py                         |   2 +-
 .../commands/test_insert_command.py       |  15 +-
 tests/stress/client/stress_experiment.py  | 152 +++++++---
 tests/stress/requirements.txt             |   3 +
 tests/unittests/core/conftest.py          |   2 +-
 tests/unittests/storage/test_storage.py   |  15 +-
 14 files changed, 388 insertions(+), 122 deletions(-)
 create mode 100644 src/orion/core/utils/compat.py

diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 91d04d00f..51c8400cb 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -145,7 +145,10 @@ jobs:
   test_no_extras:
     needs: [pre-commit, pretest]
-    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        os: [ubuntu-latest, windows-latest]
+    runs-on: ${{ matrix.os }}
     steps:
       - uses: actions/checkout@v1
       - name: Set up Python 3.9
diff --git a/src/orion/core/evc/experiment.py b/src/orion/core/evc/experiment.py
index a1e54255f..40b940106 100644
--- a/src/orion/core/evc/experiment.py
+++ b/src/orion/core/evc/experiment.py
@@ -100,8 +100,13 @@ def parent(self):
         """
         if self._parent is None and self._no_parent_lookup:
+            parent_id = self.item.refers.get("parent_id")
+
+            if parent_id is None:
+                return self._parent
+
             self._no_parent_lookup = False
-            query = {"_id": self.item.refers.get("parent_id")}
+            query = {"_id": parent_id}
             selection = {"name": 1, "version": 1}
             experiments = self.storage.fetch_experiments(query, selection)
 
@@ -113,6 +118,7 @@ def parent(self):
                     storage=self.storage,
                 )
                 self.set_parent(exp_node)
+        return self._parent
 
     @property
diff --git a/src/orion/core/io/experiment_builder.py b/src/orion/core/io/experiment_builder.py
index 59912cb58..c19b4e0e6 100644
--- a/src/orion/core/io/experiment_builder.py
+++ b/src/orion/core/io/experiment_builder.py
@@ -77,7 +77,6 @@
 
 import copy
 import datetime
-import getpass
 import logging
 import pprint
 import sys
@@ -94,6 +93,7 @@
 from orion.core.io.experiment_branch_builder import ExperimentBranchBuilder
 from orion.core.io.interactive_commands.branching_prompt import BranchingPrompt
 from orion.core.io.space_builder import SpaceBuilder
+from orion.core.utils.compat import getuser
 from orion.core.utils.exceptions import (
     BranchingEvent,
     NoConfigurationError,
@@ -864,7 +864,7 @@ def _default(v: T | None, default: V) -> T | V:
     max_broken = _default(max_broken, orion.core.config.experiment.max_broken)
     working_dir = _default(working_dir, orion.core.config.experiment.working_dir)
 
-    metadata = _default(metadata, {"user": _default(user, getpass.getuser())})
+    metadata = _default(metadata, {"user": _default(user, getuser())})
 
     refers = _default(refers, dict(parent_id=None, root_id=None, adapter=[]))
     refers["adapter"] = _instantiate_adapters(refers.get("adapter", []))  # type: ignore
diff --git a/src/orion/core/io/interactive_commands/branching_prompt.py b/src/orion/core/io/interactive_commands/branching_prompt.py
index 5c1ab8b58..e674536e2 100644
--- a/src/orion/core/io/interactive_commands/branching_prompt.py
+++ b/src/orion/core/io/interactive_commands/branching_prompt.py
@@ -12,12 +12,12 @@
 import functools
 import io
 import os
-import readline
 import shlex
 import traceback
 
 from orion.algo.space import Dimension
 from orion.core.evc import adapters, conflicts
+from orion.core.utils.compat import readline
 from orion.core.utils.diff import green, red
 
 readline.set_completer_delims(" ")
diff --git a/src/orion/core/io/resolve_config.py b/src/orion/core/io/resolve_config.py
index d8bc2f0a9..c9c87ea1c 100644
--- a/src/orion/core/io/resolve_config.py
+++ b/src/orion/core/io/resolve_config.py
@@ -4,7 +4,6 @@
 """
 import copy
-import getpass
 import hashlib
 import logging
 import os
@@ -16,6 +15,7 @@
 import orion
 import orion.core
 from orion.core.io.orion_cmdline_parser import OrionCmdlineParser
+from orion.core.utils.compat import getuser
 from orion.core.utils.flatten import unflatten
@@ -267,7 +267,7 @@ def fetch_env_vars():
 
 def fetch_metadata(user=None, user_args=None, user_script_config=None):
     """Infer rest information about the process + versioning"""
-    metadata = {"user": user if user else getpass.getuser()}
+    metadata = {"user": user if user else getuser()}
 
     metadata["orion_version"] = orion.core.__version__
 
@@ -300,7 +300,7 @@ def fetch_metadata(user=None, user_args=None, user_script_config=None):
 
 def update_metadata(metadata):
     """Update information about the process + versioning"""
-    metadata.setdefault("user", getpass.getuser())
+    metadata.setdefault("user", getuser())
     metadata["orion_version"] = orion.core.__version__
 
     if not metadata.get("user_args"):
diff --git a/src/orion/core/utils/compat.py b/src/orion/core/utils/compat.py
new file mode 100644
index 000000000..61cf7ba23
--- /dev/null
+++ b/src/orion/core/utils/compat.py
@@ -0,0 +1,30 @@
+"""Windows compatibility utilities"""
+import os
+
+
+def getuser():
+    """getpass uses the pwd module, which is UNIX only"""
+
+    if os.name == 'nt':
+        return os.getlogin()
+
+    import getpass
+    return getpass.getuser()
+
+
+class _readline:
+    def set_completer_delims(*args, **kwargs):
+        """Fake method for Windows"""
+        pass
+
+
+def get_readline():
+    """Return a fake readline interface on Windows; readline is UNIX only"""
+    if os.name == 'nt':
+        return _readline
+
+    import readline
+    return readline
+
+
+readline = get_readline()
diff --git a/src/orion/storage/sql.py b/src/orion/storage/sql.py
index 0ad4b3cb6..0ab682803 100644
--- a/src/orion/storage/sql.py
+++ b/src/orion/storage/sql.py
@@ -1,17 +1,23 @@
 import contextlib
 import datetime
-import getpass
 import logging
 import pickle
+import time
 import uuid
 from copy import deepcopy
 
 import sqlalchemy
+
+# Use MongoDB's JSON serializer so documents round-trip like in the legacy backend
+from bson.json_util import dumps as to_json
+from bson.json_util import loads as from_json
 from sqlalchemy import (
+    BINARY,
     JSON,
     Column,
     DateTime,
     ForeignKey,
+    Index,
     Integer,
     String,
     UniqueConstraint,
@@ -20,15 +26,18 @@
     update,
 )
 from sqlalchemy.exc import DBAPIError, NoResultFound
+from sqlalchemy.ext.compiler import compiles
 from sqlalchemy.orm import Session, declarative_base
 
 import orion.core
 from orion.core.io.database import DuplicateKeyError
+from orion.core.utils.compat import getuser
 from orion.core.worker.trial import Trial as OrionTrial
 from orion.core.worker.trial import validate_status
 from orion.storage.base import (
     BaseStorageProtocol,
     FailedUpdate,
+    LockAcquisitionTimeout,
     LockedAlgorithmState,
     get_trial_uid_and_exp,
     get_uid,
@@ -38,6 +47,13 @@
 
 Base = declarative_base()
 
+
+@compiles(BINARY, "postgresql")
+def compile_binary_postgresql(type_, compiler, **kw):
+    """PostgreSQL does not know the generic BINARY type; render its BYTEA (byte array) type instead"""
+    return "BYTEA"
+
+
 # fmt: off
 class User(Base):
     """Defines the User table"""
@@ -66,6 +82,7 @@ class Experiment(Base):
 
     __table_args__ = (
         UniqueConstraint('name', 'owner_id', name='_one_name_per_owner'),
+        Index('idx_experiment_name_version', 'name', 'version'),
     )
 
 
@@ -81,7 +98,7 @@ class Trial(Base):
     start_time = Column(DateTime)
     end_time = Column(DateTime)
     heartbeat = Column(DateTime)
-    parent = Column(Integer, ForeignKey("experiments._id"))
+    parent = Column(Integer, ForeignKey("trials._id"), nullable=True)
     params = Column(JSON)
     worker = Column(JSON)
     submit_time = Column(String(30))
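For reference, the @compiles override registered above is what lets the same declarative models run on every backend: SQLAlchemy's generic BINARY renders unchanged on SQLite and MySQL, but PostgreSQL only understands BYTEA, so the hook swaps the rendered DDL for that one dialect. A minimal, self-contained sketch of the mechanism (the `Blob` table is hypothetical, everything else is standard sqlalchemy API):

    from sqlalchemy import BINARY, Column, Integer
    from sqlalchemy.dialects import postgresql, sqlite
    from sqlalchemy.ext.compiler import compiles
    from sqlalchemy.orm import declarative_base
    from sqlalchemy.schema import CreateTable

    Base = declarative_base()


    @compiles(BINARY, "postgresql")
    def compile_binary_postgresql(type_, compiler, **kw):
        # Only the postgresql dialect hits this override; others keep the default
        return "BYTEA"


    class Blob(Base):  # hypothetical table, for the demo only
        __tablename__ = "blobs"
        _id = Column(Integer, primary_key=True)
        state = Column(BINARY)


    # The same model renders different DDL depending on the dialect:
    print(CreateTable(Blob.__table__).compile(dialect=sqlite.dialect()))      # ... state BINARY ...
    print(CreateTable(Blob.__table__).compile(dialect=postgresql.dialect()))  # ... state BYTEA ...

Registering the hook once at import time is enough; every Column(BINARY), including Algo.state below, picks it up with no per-table changes.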
@@ -90,6 +107,12 @@ class Trial(Base):
 
     __table_args__ = (
         UniqueConstraint('experiment_id', 'id', name='_one_trial_hash_per_experiment'),
+        Index('idx_trial_experiment_id', 'experiment_id'),
+        Index('idx_trial_status', 'status'),
+        # Can't put an index on json
+        # Index('idx_trial_results', 'results'),
+        Index('idx_trial_start_time', 'start_time'),
+        Index('idx_trial_end_time', 'end_time'),
     )
 
 
@@ -97,16 +120,26 @@ class Algo(Base):
     """Defines the Algo table"""
     __tablename__ = "algo"
 
+    # There is one algo per experiment, so we could set experiment_id as the primary key
+    # and make it a 1-1 relation
     _id = Column(Integer, primary_key=True, autoincrement=True)
     experiment_id = Column(Integer, ForeignKey("experiments._id"), nullable=False)
     owner_id = Column(Integer, ForeignKey("users._id"), nullable=False)
     configuration = Column(JSON)
     locked = Column(Integer)
-    state = Column(JSON)
+    state = Column(BINARY)
     heartbeat = Column(DateTime)
+
+    __table_args__ = (
+        Index('idx_algo_experiment_id', 'experiment_id'),
+    )
 # fmt: on
 
 
+def get_tables():
+    return [User, Experiment, Trial, Algo]
+
+
 class SQLAlchemy(BaseStorageProtocol):  # noqa: F811
     """Implement a generic protocol to allow Orion to communicate using
     different storage backend
@@ -130,7 +163,8 @@ def __init__(self, uri, token=None, **kwargs):
         # mysql+mysqldb://scott:tiger@localhost/foo
         # mysql+pymysql://scott:tiger@localhost/foo
         #
-        # sqlite:///foo.db
+        # sqlite:///foo.db   # relative
+        # sqlite:////foo.db  # absolute
         # sqlite://           # in memory
 
         self.uri = uri
@@ -138,7 +172,13 @@ def __init__(self, uri, token=None, **kwargs):
             uri = "sqlite://"
 
         # engine_from_config
-        self.engine = sqlalchemy.create_engine(uri, echo=True, future=True)
+        self.engine = sqlalchemy.create_engine(
+            uri,
+            echo=True,
+            future=True,
+            json_serializer=to_json,
+            json_deserializer=from_json,
+        )
 
         # Create the schema
         Base.metadata.create_all(self.engine)
@@ -157,7 +197,7 @@ def _connect(self, token):
             self.user_id = self.user._id
         else:
             # Local database, create a default user
-            user = getpass.getuser()
+            user = getuser()
             now = datetime.datetime.utcnow()
 
             with Session(self.engine) as session:
@@ -183,6 +223,11 @@ def __setstate__(self, state):
         self.uri = state["uri"]
         self.token = state["token"]
         self.engine = sqlalchemy.create_engine(self.uri, echo=True, future=True)
+
+        if self.uri == "sqlite://" or self.uri == "":
+            log.warning("You are deserializing an in-memory database, its data is lost")
+            Base.metadata.create_all(self.engine)
+
         self._connect(self.token)
 
     # Experiment Operations
@@ -190,7 +235,7 @@ def __setstate__(self, state):
 
     def create_experiment(self, config):
         """Insert a new experiment inside the database"""
-        config = deepcopy(config)
+        cpy = deepcopy(config)
 
         try:
             with Session(self.engine) as session:
@@ -199,21 +244,19 @@ def create_experiment(self, config):
                     version=0,
                 )
 
-                config["meta"] = config.pop("metadata")
-
-                # old way
-                # if 'space' not in config:
-                #     config['space'] = config['meta'].pop('priors', dict())
-
-                self._set_from_dict(experiment, config, "remaining")
+                cpy["meta"] = cpy.pop("metadata")
+                self._set_from_dict(experiment, cpy, "remaining")
 
                 session.add(experiment)
                 session.commit()
+                session.refresh(experiment)
+
+                config.update(self._to_experiment(experiment))
         except DBAPIError:
             raise DuplicateKeyError()
 
     def delete_experiment(self, experiment=None, uid=None):
+        """See :func:`orion.storage.base.BaseStorageProtocol.delete_experiment`"""
         uid = get_uid(experiment, uid)
 
         with Session(self.engine) as session:
@@ -222,11 +265,10 @@ def delete_experiment(self, experiment=None, uid=None):
             session.commit()
 
     def update_experiment(self, experiment=None, uid=None, where=None, **kwargs):
+        """See :func:`orion.storage.base.BaseStorageProtocol.update_experiment`"""
         uid = get_uid(experiment, uid)
 
-        query = True
-        if where is None:
-            where = dict()
+        where = self._get_query(where)
 
         if uid is not None:
             where["_id"] = uid
@@ -243,17 +285,41 @@ def update_experiment(self, experiment=None, uid=None, where=None, **kwargs):
 
             session.commit()
 
+    def _fetch_experiments_with_select(self, query, selection=None):
+        query = self._get_query(query)
+
+        where = self._to_query(Experiment, query)
+
+        with Session(self.engine) as session:
+            columns = self._selection(Experiment, selection)
+            stmt = select(columns).where(*where)
+
+            rows = session.execute(stmt).all()
+
+            results = []
+
+            for row in rows:
+                obj = dict()
+                for value, k in zip(row, columns):
+                    obj[str(k).split(".")[-1]] = value
+                results.append(obj)
+
+            return results
+
     def fetch_experiments(self, query, selection=None):
-        query = self._to_query(Experiment, query)
+        """See :func:`orion.storage.base.BaseStorageProtocol.fetch_experiments`"""
+        if selection:
+            return self._fetch_experiments_with_select(query, selection)
+
+        query = self._get_query(query)
+        where = self._to_query(Experiment, query)
 
         with Session(self.engine) as session:
-            stmt = select(Experiment).where(*query)
-            experiments = session.scalars(stmt).all()
+            stmt = select(Experiment).where(*where)
 
-            if selection is not None:
-                assert False, "Not Implemented"
+            experiments = session.scalars(stmt).all()
 
             return [self._to_experiment(exp) for exp in experiments]
 
     # Benchmarks
     # ==========
 
@@ -261,10 +329,10 @@ def fetch_experiments(self, query, selection=None):
 
     # Trials
     # ======
     def fetch_trials(self, experiment=None, uid=None, where=None):
+        """See :func:`orion.storage.base.BaseStorageProtocol.fetch_trials`"""
         uid = get_uid(experiment, uid)
 
-        if where is None:
-            where = dict()
+        where = self._get_query(where)
 
         if uid is not None:
             where["experiment_id"] = uid
@@ -276,6 +344,7 @@ def fetch_trials(self, experiment=None, uid=None, where=None):
             return session.scalars(stmt).all()
 
     def register_trial(self, trial):
+        """See :func:`orion.storage.base.BaseStorageProtocol.register_trial`"""
         config = trial.to_dict()
 
         try:
@@ -296,10 +365,10 @@ def register_trial(self, trial):
             raise DuplicateKeyError()
 
     def delete_trials(self, experiment=None, uid=None, where=None):
+        """See :func:`orion.storage.base.BaseStorageProtocol.delete_trials`"""
         uid = get_uid(experiment, uid)
 
-        if where is None:
-            where = dict()
+        where = self._get_query(where)
 
         if uid is not None:
             where["experiment_id"] = uid
@@ -314,9 +383,13 @@ def delete_trials(self, experiment=None, uid=None, where=None):
 
         return count.rowcount
 
     def retrieve_result(self, trial, **kwargs):
+        """Update the trial's results array from storage"""
+        new_trial = self.get_trial(trial)
+        trial.results = new_trial.results
         return trial
 
     def get_trial(self, trial=None, uid=None, experiment_uid=None):
+        """See :func:`orion.storage.base.BaseStorageProtocol.get_trial`"""
         trial_uid, experiment_uid = get_trial_uid_and_exp(trial, uid, experiment_uid)
 
         with Session(self.engine) as session:
@@ -329,11 +402,10 @@ def get_trial(self, trial=None, uid=None, experiment_uid=None):
 
         return OrionTrial(**self._to_trial(trial))
 
     def update_trials(self, experiment=None, uid=None, where=None, **kwargs):
+        """See :func:`orion.storage.base.BaseStorageProtocol.update_trials`"""
        uid =
get_uid(experiment, uid) - if where is None: - where = dict() - + where = self._get_query(where) where["experiment_id"] = uid query = self._to_query(Trial, where) @@ -351,21 +423,16 @@ def update_trials(self, experiment=None, uid=None, where=None, **kwargs): def update_trial( self, trial=None, uid=None, experiment_uid=None, where=None, **kwargs ): - + """See :func:`orion.storage.base.BaseStorageProtocol.update_trial`""" trial_uid, experiment_uid = get_trial_uid_and_exp(trial, uid, experiment_uid) - if where is None: - where = dict() + where = self._get_query(where) # THIS IS NOT THE UNIQUE ID OF THE TRIAL where["id"] = trial_uid where["experiment_id"] = experiment_uid query = self._to_query(Trial, where) - print() - print(query) - print() - with Session(self.engine) as session: stmt = select(Trial).where(*query) trial = session.scalars(stmt).one() @@ -376,6 +443,7 @@ def update_trial( return trial def fetch_lost_trials(self, experiment): + """See :func:`orion.storage.base.BaseStorageProtocol.fetch_lost_trials`""" heartbeat = orion.core.config.worker.heartbeat threshold = datetime.datetime.utcnow() - datetime.timedelta( seconds=heartbeat * 5 @@ -390,6 +458,7 @@ def fetch_lost_trials(self, experiment): return session.scalars(stmt).all() def push_trial_results(self, trial): + """See :func:`orion.storage.base.BaseStorageProtocol.push_trial_results`""" with Session(self.engine) as session: stmt = select(Trial).where( Trial.experiment_id == trial.experiment, @@ -403,6 +472,7 @@ def push_trial_results(self, trial): return trial def set_trial_status(self, trial, status, heartbeat=None, was=None): + """See :func:`orion.storage.base.BaseStorageProtocol.set_trial_status`""" heartbeat = heartbeat or datetime.datetime.utcnow() was = was or trial.status @@ -430,6 +500,7 @@ def set_trial_status(self, trial, status, heartbeat=None, was=None): raise FailedUpdate() def fetch_pending_trials(self, experiment): + """See :func:`orion.storage.base.BaseStorageProtocol.fetch_pending_trials`""" with Session(self.engine) as session: stmt = select(Trial).where( Trial.status.in_(("interrupted", "new", "suspended")), @@ -460,6 +531,7 @@ def _reserve_trial_postgre(self, experiment): return trial def reserve_trial(self, experiment): + """See :func:`orion.storage.base.BaseStorageProtocol.reserve_trial`""" if False: return self._reserve_trial_postgre(experiment) @@ -501,6 +573,7 @@ def reserve_trial(self, experiment): return None def fetch_trials_by_status(self, experiment, status): + """See :func:`orion.storage.base.BaseStorageProtocol.fetch_trials_by_status`""" with Session(self.engine) as session: stmt = select(Trial).where( Trial.status == status and Trial.experiment_id == experiment._id @@ -511,6 +584,7 @@ def fetch_trials_by_status(self, experiment, status): ] def fetch_noncompleted_trials(self, experiment): + """See :func:`orion.storage.base.BaseStorageProtocol.fetch_noncompleted_trials`""" with Session(self.engine) as session: stmt = select(Trial).where( Trial.status != "completed", @@ -519,6 +593,7 @@ def fetch_noncompleted_trials(self, experiment): return session.scalars(stmt).all() def count_completed_trials(self, experiment): + """See :func:`orion.storage.base.BaseStorageProtocol.count_completed_trials`""" with Session(self.engine) as session: return ( session.query(Trial) @@ -530,6 +605,7 @@ def count_completed_trials(self, experiment): ) def count_broken_trials(self, experiment): + """See :func:`orion.storage.base.BaseStorageProtocol.count_broken_trials`""" with Session(self.engine) as session: return ( 
session.query(Trial) @@ -562,6 +638,7 @@ def update_heartbeat(self, trial): # Algorithm # ========= def initialize_algorithm_lock(self, experiment_id, algorithm_config): + """See :func:`orion.storage.base.BaseStorageProtocol.initialize_algorithm_lock`""" with Session(self.engine) as session: algo = Algo( experiment_id=experiment_id, @@ -573,8 +650,10 @@ def initialize_algorithm_lock(self, experiment_id, algorithm_config): session.add(algo) session.commit() - def release_algorithm_lock(self, experiment=None, _id=None, new_state=None): - _id = get_uid(experiment, _id) + def release_algorithm_lock(self, experiment=None, uid=None, new_state=None): + """See :func:`orion.storage.base.BaseStorageProtocol.release_algorithm_lock`""" + + uid = get_uid(experiment, uid) values = dict( locked=0, @@ -584,17 +663,27 @@ def release_algorithm_lock(self, experiment=None, _id=None, new_state=None): values["state"] = pickle.dumps(new_state) with Session(self.engine) as session: - update(Algo).where(Algo.experiment_id == _id and Algo.locked == 1).values( - **values + stmt = ( + update(Algo) + .where( + Algo.experiment_id == uid, + Algo.locked == 1, + ) + .values(**values) ) + session.execute(stmt) + session.commit() def get_algorithm_lock_info(self, experiment=None, uid=None): """See :func:`orion.storage.base.BaseStorageProtocol.get_algorithm_lock_info`""" - _id = get_uid(experiment, uid) + uid = get_uid(experiment, uid) with Session(self.engine) as session: stmt = select(Algo).where(Algo.experiment_id == uid) - algo = session.scalar(stmt).one() + algo = session.scalar(stmt) + + if algo is None: + return None return LockedAlgorithmState( state=pickle.loads(algo.state) if algo.state is not None else None, @@ -608,9 +697,11 @@ def delete_algorithm_lock(self, experiment=None, uid=None): with Session(self.engine) as session: stmt = delete(Algo).where(Algo.experiment_id == uid) - session.execute(stmt) + cursor = session.execute(stmt) session.commit() + return cursor.rowcount + def _acquire_algorithm_lock_postgre( self, experiment=None, uid=None, timeout=60, retry_interval=1 ): @@ -628,43 +719,86 @@ def _acquire_algorithm_lock_postgre( session.commit() return algo - @contextlib.contextmanager - def acquire_algorithm_lock( - self, experiment=None, _id=None, timeout=60, retry_interval=1 + def _acquire_algorithm_lock( + self, experiment=None, uid=None, timeout=60, retry_interval=1 ): - _id = get_uid(experiment, _id) + uid = get_uid(experiment, uid) + algo_state_lock = None + start = time.perf_counter() with Session(self.engine) as session: - now = datetime.datetime.utcnow() + while algo_state_lock is None and time.perf_counter() - start < timeout: + now = datetime.datetime.utcnow() - stmt = ( - update(Algo) - .where(Algo.experiment_id == _id, Algo.locked == 0) - .values(locked=1, heartbeat=now) - ) + stmt = ( + update(Algo) + .where(Algo.experiment_id == uid, Algo.locked == 0) + .values(locked=1, heartbeat=now) + ) - session.execute(stmt) - session.commit() + cursor = session.execute(stmt) + session.commit() - stmt = select(Algo).where(Algo.experiment_id == _id, Algo.locked == 1) - algo = session.scalar(stmt) + if cursor.rowcount == 0: + time.sleep(retry_interval) + else: + stmt = select(Algo).where( + Algo.experiment_id == uid, Algo.locked == 1 + ) + algo_state_lock = session.scalar(stmt) + break - if algo is None or algo.heartbeat != now: - yield None - return + if algo_state_lock is None: + raise LockAcquisitionTimeout() - algo_state = LockedAlgorithmState( - state=pickle.loads(algo.state) if algo.state is not None 
else None, - configuration=algo.configuration, + if algo_state_lock.state is not None: + state = pickle.loads(algo_state_lock.state) + else: + state = None + + return LockedAlgorithmState( + state=state, + configuration=algo_state_lock.configuration, locked=True, ) - yield algo_state + @contextlib.contextmanager + def acquire_algorithm_lock( + self, experiment=None, uid=None, timeout=60, retry_interval=1 + ): + """See :func:`orion.storage.base.BaseStorageProtocol.acquire_algorithm_lock`""" + locked_algo_state = self._acquire_algorithm_lock( + experiment, uid, timeout, retry_interval + ) - self.release_algorithm_lock(_id, new_state=algo_state.state) + try: + yield locked_algo_state + except Exception: + # Reset algo to state fetched lock time + locked_algo_state.reset() + raise + finally: + uid = get_uid(experiment, uid) + self.release_algorithm_lock(uid=uid, new_state=locked_algo_state.state) # Utilities # ========= + def _get_query(self, query): + if query is None: + query = dict() + + query["owner_id"] = self.user_id + return query + + def _selection(self, table, selection): + selected = [] + + for k, v in selection.items(): + if hasattr(table, k) and v: + selected.append(getattr(table, k)) + + return selected + def _set_from_dict(self, obj, data, rest=None): data = deepcopy(data) meta = dict() @@ -717,7 +851,6 @@ def _to_experiment(self, experiment): rest = {} exp.update(rest) - return exp def _to_trial(self, trial): diff --git a/src/orion/testing/state.py b/src/orion/testing/state.py index f98686984..f705de2b7 100644 --- a/src/orion/testing/state.py +++ b/src/orion/testing/state.py @@ -143,6 +143,8 @@ def _set_tables(self): exp = self.storage.fetch_experiments(dict(name=exp["name"]))[0] self.expname_to_uid[exp["name"]] = exp["_id"] + self.storage.initialize_algorithm_lock(exp["_id"], exp.get("algorithms")) + for trial in self._trials: exp_id = self.expname_to_uid.get(trial["experiment"], None) diff --git a/tests/conftest.py b/tests/conftest.py index ca74830d8..656a20b69 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,5 @@ #!/usr/bin/env python """Common fixtures and utils for unittests and functional tests.""" -import getpass import os import numpy @@ -10,6 +9,7 @@ import orion.core import orion.core.utils.backward as backward +import orion.core.utils.compat as getpass from orion.algo.base import BaseAlgorithm from orion.algo.space import Space from orion.core.io import resolve_config diff --git a/tests/functional/commands/test_insert_command.py b/tests/functional/commands/test_insert_command.py index b93dd3d2b..bcabb4478 100644 --- a/tests/functional/commands/test_insert_command.py +++ b/tests/functional/commands/test_insert_command.py @@ -5,6 +5,7 @@ import pytest import orion.core.cli +import orion.core.utils.compat as getpass def get_user_corneau(): @@ -15,7 +16,7 @@ def get_user_corneau(): def test_insert_invalid_experiment(storage, monkeypatch, capsys): """Test the insertion of an invalid experiment""" monkeypatch.chdir(os.path.dirname(os.path.abspath(__file__))) - monkeypatch.setattr("getpass.getuser", get_user_corneau) + monkeypatch.setattr(getpass, "getuser", get_user_corneau) returncode = orion.core.cli.main( [ @@ -40,7 +41,7 @@ def test_insert_invalid_experiment(storage, monkeypatch, capsys): def test_insert_single_trial(storage, monkeypatch, script_path): """Try to insert a single trial""" monkeypatch.chdir(os.path.dirname(os.path.abspath(__file__))) - monkeypatch.setattr("getpass.getuser", get_user_corneau) + monkeypatch.setattr(getpass, "getuser", 
get_user_corneau) orion.core.cli.main( [ @@ -74,7 +75,7 @@ def test_insert_single_trial(storage, monkeypatch, script_path): def test_insert_single_trial_default_value(storage, monkeypatch): """Try to insert a single trial using a default value""" monkeypatch.chdir(os.path.dirname(os.path.abspath(__file__))) - monkeypatch.setattr("getpass.getuser", get_user_corneau) + monkeypatch.setattr(getpass, "getuser", get_user_corneau) orion.core.cli.main( [ @@ -106,7 +107,7 @@ def test_insert_single_trial_default_value(storage, monkeypatch): def test_insert_with_no_default_value(monkeypatch): """Try to insert a single trial by omitting a namespace with no default value""" monkeypatch.chdir(os.path.dirname(os.path.abspath(__file__))) - monkeypatch.setattr("getpass.getuser", get_user_corneau) + monkeypatch.setattr(getpass, "getuser", get_user_corneau) with pytest.raises(ValueError) as exc_info: orion.core.cli.main( @@ -127,7 +128,7 @@ def test_insert_with_no_default_value(monkeypatch): def test_insert_with_incorrect_namespace(monkeypatch): """Try to insert a single trial with a namespace not inside the experiment space""" monkeypatch.chdir(os.path.dirname(os.path.abspath(__file__))) - monkeypatch.setattr("getpass.getuser", get_user_corneau) + monkeypatch.setattr(getpass, "getuser", get_user_corneau) with pytest.raises(ValueError) as exc_info: orion.core.cli.main( @@ -149,7 +150,7 @@ def test_insert_with_incorrect_namespace(monkeypatch): def test_insert_with_outside_bound_value(monkeypatch): """Try to insert a single trial with value outside the distribution's interval""" monkeypatch.chdir(os.path.dirname(os.path.abspath(__file__))) - monkeypatch.setattr("getpass.getuser", get_user_corneau) + monkeypatch.setattr(getpass, "getuser", get_user_corneau) with pytest.raises(ValueError) as exc_info: orion.core.cli.main( @@ -171,7 +172,7 @@ def test_insert_with_outside_bound_value(monkeypatch): def test_insert_two_hyperparameters(storage, monkeypatch): """Try to insert a single trial with two hyperparameters""" monkeypatch.chdir(os.path.dirname(os.path.abspath(__file__))) - monkeypatch.setattr("getpass.getuser", get_user_corneau) + monkeypatch.setattr(getpass, "getuser", get_user_corneau) orion.core.cli.main( [ "insert", diff --git a/tests/stress/client/stress_experiment.py b/tests/stress/client/stress_experiment.py index 3433e3c40..432601c32 100644 --- a/tests/stress/client/stress_experiment.py +++ b/tests/stress/client/stress_experiment.py @@ -4,6 +4,7 @@ import random import time import traceback +from collections import OrderedDict from multiprocessing import Pool import matplotlib.pyplot as plt @@ -14,6 +15,114 @@ from orion.core.utils.exceptions import ReservationTimeout DB_FILE = "stress.pkl" +SQLITE_FILE = "db.sqlite" + +ADDRESS = "192.168.0.16" + +# +# Create the stress test user +# +# MongoDB +# +# mongosh +# > use admin +# > db.createUser({ +# user: "user", +# pwd: "pass", +# roles: [ +# {role: 'readWrite', db: 'stress'}, +# ] +# }) +# +# PostgreSQL -- DO NOT USE THIS IN PROD - TESTING ONLY +# +# # Switch to the user running the database +# sudo su postgres +# +# # open an interactive connection to the server +# psql +# > CREATE USER username WITH PASSWORD 'pass'; +# > CREATE ROLE orion_database_admin; +# > CREATE ROLE orion_database_user LOGIN; +# > GRANT orion_database_user, orion_database_user TO username; +# > +# > GRANT pg_write_all_data, pg_read_all_data TO username; +# > CREATE DATABASE stress OWNER orion_database_admin; +# \q +# +# > + + +BACKENDS_CONFIGS = OrderedDict( + [ + # ("pickleddb", 
{"type": "legacy", "database": {"type": "pickleddb", "host": DB_FILE}}), + # ("sqlite", {"type": "sqlalchemy", "uri": f"sqlite:///{SQLITE_FILE}"}), + ( + "postgresql", + { + "type": "sqlalchemy", + "uri": f"postgresql://username:pass@{ADDRESS}/stress", + }, + ), + ( + "mongodb", + { + "type": "legacy", + "database": { + "type": "mongodb", + "name": "stress", + "host": f"mongodb://user:pass@{ADDRESS}", + }, + }, + ), + ] +) + + +def cleanup_storage(backend): + if backend == "pickleddb": + if os.path.exists(DB_FILE): + os.remove(DB_FILE) + + elif backend == "sqlite": + if os.path.exists(SQLITE_FILE): + os.remove(SQLITE_FILE) + + elif backend == "postgresql": + import sqlalchemy + from sqlalchemy.orm import Session + + from orion.storage.sql import get_tables + + engine = sqlalchemy.create_engine( + f"postgresql://username:pass@{ADDRESS}/stress", + echo=True, + future=True, + ) + + # if the tables are missing just skip + for table in get_tables(): + try: + with Session(engine) as session: + session.execute(f"DROP TABLE {table.__tablename__} CASCADE;") + session.commit() + except: + traceback.print_exc() + + elif backend == "mongodb": + client = MongoClient( + host=ADDRESS, username="user", password="pass", authSource="stress" + ) + database = client.stress + database.experiments.drop() + database.lying_trials.drop() + database.trials.drop() + database.workers.drop() + database.resources.drop() + client.close() + + else: + raise RuntimeError("You need to cleam your backend") def f(x, worker): @@ -40,14 +149,7 @@ def get_experiment(storage, space_type, size): This defines `max_trials`, and the size of the search space (`uniform(0, size)`). """ - if storage == "pickleddb": - storage_config = {"type": "pickleddb", "host": DB_FILE} - elif storage == "mongodb": - storage_config = { - "type": "mongodb", - "name": "stress", - "host": "mongodb://user:pass@localhost", - } + storage_config = BACKENDS_CONFIGS[storage] discrete = space_type == "discrete" high = size # * 2 @@ -58,7 +160,7 @@ def get_experiment(storage, space_type, size): max_trials=size, max_idle_time=60 * 5, algorithms={"random": {"seed": None if space_type == "real" else 1}}, - storage={"type": "legacy", "database": storage_config}, + storage=storage_config, ) @@ -130,18 +232,7 @@ def stress_test(storage, space_type, workers, size): List of all trials at the end of the stress test """ - if storage == "pickleddb": - if os.path.exists(DB_FILE): - os.remove(DB_FILE) - elif storage == "mongodb": - client = MongoClient(username="user", password="pass", authSource="stress") - database = client.stress - database.experiments.drop() - database.lying_trials.drop() - database.trials.drop() - database.workers.drop() - database.resources.drop() - client.close() + cleanup_storage(storage) print("Worker | Point") @@ -170,18 +261,7 @@ def stress_test(storage, space_type, workers, size): trials = experiment.fetch_trials() - if storage == "pickleddb": - os.remove(DB_FILE) - elif storage == "mongodb": - client = MongoClient(username="user", password="pass", authSource="stress") - database = client.stress - database.experiments.drop() - database.lying_trials.drop() - database.trials.drop() - database.workers.drop() - database.resources.drop() - client.close() - + cleanup_storage(storage) return trials @@ -246,8 +326,10 @@ def benchmark(workers, size): """ results = {} - for backend in ["mongodb", "pickleddb"]: + for backend in BACKENDS_CONFIGS.keys(): for space_type in ["discrete", "real", "real-seeded"]: + print(backend, space_type) + trials = 
stress_test(backend, space_type, workers, size) results[(backend, space_type)] = get_timestamps(trials, size, space_type) @@ -274,7 +356,7 @@ def main(): results[workers] = benchmark(workers, size) - for backend in ["mongodb", "pickleddb"]: + for backend in BACKENDS_CONFIGS.keys(): for space_type in ["discrete", "real", "real-seeded"]: x, y = results[workers][(backend, space_type)] axis[i].plot(x, y, label=f"{backend}-{space_type}") diff --git a/tests/stress/requirements.txt b/tests/stress/requirements.txt index 6ccafc3f9..931f5f7cd 100644 --- a/tests/stress/requirements.txt +++ b/tests/stress/requirements.txt @@ -1 +1,4 @@ matplotlib +psycopg2 +readline +sqlalchemy diff --git a/tests/unittests/core/conftest.py b/tests/unittests/core/conftest.py index 947b8921d..5894aa0a9 100644 --- a/tests/unittests/core/conftest.py +++ b/tests/unittests/core/conftest.py @@ -2,7 +2,6 @@ """Common fixtures and utils for tests.""" import copy -import getpass import os import pytest @@ -10,6 +9,7 @@ import orion.core.io.experiment_builder as experiment_builder import orion.core.utils.backward as backward +import orion.core.utils.compat as getpass from orion.algo.space import Categorical, Integer, Real, Space from orion.core.evc import conflicts from orion.core.io.convert import JSONConverter, YAMLConverter diff --git a/tests/unittests/storage/test_storage.py b/tests/unittests/storage/test_storage.py index fbffadfc8..fb7b9e4c4 100644 --- a/tests/unittests/storage/test_storage.py +++ b/tests/unittests/storage/test_storage.py @@ -29,9 +29,10 @@ log.setLevel(logging.WARNING) storage_backends = [ - None, - dict(type="sqlalchemy", uri="sqlite://"), -] # defaults to legacy with PickleDB + None, # defaults to legacy with PickleDB + dict(type="sqlalchemy", uri="sqlite:///${file}"), # Temporary file + dict(type="sqlalchemy", uri="sqlite://"), # In-memory +] if not HAS_TRACK: log.warning("Track is not tested because: %s!", REASON) @@ -781,6 +782,14 @@ def test_update_heartbeat(self, storage): def test_serializable(self, storage): """Test storage can be serialized""" + if ( + storage + and storage["type"] == "sqlalchemy" + and storage["uri"] == "sqlite://" + ): + # Cannot serialize an in-memory database + return + with OrionState( experiments=[base_experiment], trials=generate_trials(), storage=storage ) as cfg: From 88a2be0cd0f9f32f475395d29ec15471885d8dcf Mon Sep 17 00:00:00 2001 From: Setepenre Date: Wed, 31 Aug 2022 11:15:46 -0400 Subject: [PATCH 06/25] Tweaks --- src/orion/core/evc/experiment.py | 580 ++++++++-------- src/orion/core/io/database/pickleddb.py | 619 ++++++++--------- src/orion/core/utils/compat.py | 29 +- src/orion/storage/sql.py | 153 +++-- tests/stress/client/stress_experiment.py | 821 ++++++++++++----------- 5 files changed, 1150 insertions(+), 1052 deletions(-) diff --git a/src/orion/core/evc/experiment.py b/src/orion/core/evc/experiment.py index 40b940106..2245e3502 100644 --- a/src/orion/core/evc/experiment.py +++ b/src/orion/core/evc/experiment.py @@ -1,293 +1,287 @@ -# pylint:disable=protected-access -""" -Experiment node for EVC -======================= - -Experiment nodes connecting experiments to the EVC tree - -The experiments are connected to one another through the experiment nodes. The former can be created -standalone without an EVC tree. When connected to an `ExperimentNode`, the experiments gain access -to trials of other experiments by using method `ExperimentNode.fetch_trials`. - -Helper functions are provided to fetch trials keeping the tree structure. 
Those can be helpful when -analyzing an EVC tree. - -""" -import functools -import logging - -from orion.core.utils.tree import TreeNode - -log = logging.getLogger(__name__) - - -class ExperimentNode(TreeNode): - """Experiment node to connect experiments to EVC tree. - - The node carries an experiment in attribute `item`. The node can be instantiated only using the - name of the experiment. The experiment will be created lazily on access to `node.item`. - - Attributes - ---------- - name: str - Name of the experiment - item: None or :class:`orion.core.worker.experiment.Experiment` - None if the experiment is not initialized yet. When initializing lazily, it creates an - `Experiment` in read only mode. - - .. seealso:: - - :py:class:`orion.core.utils.tree.TreeNode` for tree-specific attributes and methods. - - """ - - __slots__ = ( - "name", - "version", - "_no_parent_lookup", - "_no_children_lookup", - "storage", - ) + TreeNode.__slots__ - - def __init__( - self, - name, - version, - experiment=None, - parent=None, - children=tuple(), - storage=None, - ): - """Initialize experiment node with item, experiment, parent and children - - .. seealso:: - :class:`orion.core.utils.tree.TreeNode` for information about the attributes - """ - super().__init__(experiment, parent, children) - self.name = name - self.version = version - - self._no_parent_lookup = True - self._no_children_lookup = True - self.storage = storage or experiment._storage - - @property - def item(self): - """Get the experiment associated to the node - - Note that accessing `item` may trigger the lazy initialization of the experiment if it was - not done already. - """ - if self._item is None: - # TODO: Find another way around the circular import - from orion.core.io import experiment_builder - - self._item = experiment_builder.load( - name=self.name, version=self.version, storage=self.storage - ) - self._item._node = self - - return self._item - - @property - def parent(self): - """Get parent of the experiment, None if no parent - - .. note:: - - The instantiation of an EVC tree is lazy, which means accessing the parent of a node - may trigger a call to database to build this parent live. - - """ - if self._parent is None and self._no_parent_lookup: - parent_id = self.item.refers.get("parent_id") - - if parent_id is None: - return self._parent - - self._no_parent_lookup = False - query = {"_id": parent_id} - selection = {"name": 1, "version": 1} - experiments = self.storage.fetch_experiments(query, selection) - - if experiments: - parent = experiments[0] - exp_node = ExperimentNode( - name=parent["name"], - version=parent.get("version", 1), - storage=self.storage, - ) - self.set_parent(exp_node) - - return self._parent - - @property - def children(self): - """Get children of the experiment, empty list if no children - - .. note:: - - The instantiation of an EVC tree is lazy, which means accessing the children of a node - may trigger a call to database to build those children live. 
- - """ - if self._no_children_lookup: - self._children = [] - self._no_children_lookup = False - query = {"refers.parent_id": self.item.id} - selection = {"name": 1, "version": 1} - experiments = self.storage.fetch_experiments(query, selection) - for child in experiments: - self.add_children( - ExperimentNode( - name=child["name"], - version=child.get("version", 1), - storage=self.storage, - ) - ) - - return self._children - - @property - def adapter(self): - """Get the adapter of the experiment with respect to its parent""" - return self.item.refers["adapter"] - - @property - def tree_name(self): - """Return a formatted name of the Node for a tree pretty-print.""" - if self.item is not None: - return f"{self.name}-v{self.item.version}" - - return self.name - - def fetch_lost_trials(self): - """See :meth:`orion.core.evc.experiment.ExperimentNode.recurvise_fetch`""" - return self.recurvise_fetch("fetch_lost_trials") - - def fetch_trials(self): - """See :meth:`orion.core.evc.experiment.ExperimentNode.recurvise_fetch`""" - return self.recurvise_fetch("fetch_trials") - - def fetch_pending_trials(self): - """See :meth:`orion.core.evc.experiment.ExperimentNode.recurvise_fetch`""" - return self.recurvise_fetch("fetch_pending_trials") - - def fetch_noncompleted_trials(self): - """See :meth:`orion.core.evc.experiment.ExperimentNode.recurvise_fetch`""" - return self.recurvise_fetch("fetch_noncompleted_trials") - - def fetch_trials_by_status(self, status): - """See :meth:`orion.core.evc.experiment.ExperimentNode.recurvise_fetch`""" - return self.recurvise_fetch("fetch_trials_by_status", status=status) - - def recurvise_fetch(self, fun_name, *args, **kwargs): - """Fetch trials recursively in the EVC tree using the fetch function `fun_name`. - - Parameters - ---------- - fun_name: callable - Function name to call to fetch trials. The function must be an attribute of - :class:`orion.core.worker.experiment.Experiment` - - *args: - Positional arguments to pass to `fun_name`. - - **kwargs - Keyword arguments to pass to `fun_name`. - - """ - - def retrieve_trials(node, parent_or_children): - """Retrieve the trials of a node/experiment.""" - fun = getattr(node.item, fun_name) - # with_evc_tree needs to be False here or we will have an infinite loop - trials = fun(*args, with_evc_tree=False, **kwargs) - return dict(trials=trials, experiment=node.item), parent_or_children - - # get the trials of the parents - parent_trials = None - if self.parent is not None: - parent_trials = self.parent.map(retrieve_trials, self.parent.parent) - - # get the trials of the children - children_trials = self.map(retrieve_trials, self.children) - children_trials.set_parent(parent_trials) - - adapt_trials(children_trials) - - return sum((node.item["trials"] for node in children_trials.root), []) - - -def _adapt_parent_trials(node, parent_trials_node, ids): - """Adapt trials from the parent recursively - - .. note:: - - To call with node.map(fct, node.parent) to connect with parents - - """ - # Ids from children are passed to prioritized them if they are also present in parent nodes. 
- node_ids = { - trial.compute_trial_hash(trial, ignore_lie=True) - for trial in node.item["trials"] - } | ids - if parent_trials_node is not None: - adapter = node.item["experiment"].refers["adapter"] - for parent in parent_trials_node.root: - parent.item["trials"] = adapter.forward(parent.item["trials"]) - - # if trial is in current exp, filter out - parent.item["trials"] = [ - trial - for trial in parent.item["trials"] - if trial.compute_trial_hash( - trial, ignore_lie=True, ignore_experiment=True - ) - not in node_ids - ] - - return node.item, parent_trials_node - - -def _adapt_children_trials(node, children_trials_nodes): - """Adapt trials from the children recursively - - .. note:: - - To call with node.map(fct, node.children) to connect with children - - """ - ids = { - trial.compute_trial_hash(trial, ignore_lie=True) - for trial in node.item["trials"] - } - - for child in children_trials_nodes: - adapter = child.item["experiment"].refers["adapter"] - for subchild in child: # Includes child itself - subchild.item["trials"] = adapter.backward(subchild.item["trials"]) - - # if trial is in current node, filter out - subchild.item["trials"] = [ - trial - for trial in subchild.item["trials"] - if trial.compute_trial_hash( - trial, ignore_lie=True, ignore_experiment=True - ) - not in ids - ] - - return node.item, children_trials_nodes - - -def adapt_trials(trials_tree): - """Adapt trials recursively so that they are all compatible with current experiment.""" - trials_tree.map(_adapt_children_trials, trials_tree.children) - ids = set() - for child in trials_tree.children: - for trial in child.item["trials"]: - ids.add(trial.compute_trial_hash(trial, ignore_lie=True)) - trials_tree.map( - functools.partial(_adapt_parent_trials, ids=ids), trials_tree.parent - ) +# pylint:disable=protected-access +""" +Experiment node for EVC +======================= + +Experiment nodes connecting experiments to the EVC tree + +The experiments are connected to one another through the experiment nodes. The former can be created +standalone without an EVC tree. When connected to an `ExperimentNode`, the experiments gain access +to trials of other experiments by using method `ExperimentNode.fetch_trials`. + +Helper functions are provided to fetch trials keeping the tree structure. Those can be helpful when +analyzing an EVC tree. + +""" +import functools +import logging + +from orion.core.utils.tree import TreeNode + +log = logging.getLogger(__name__) + + +class ExperimentNode(TreeNode): + """Experiment node to connect experiments to EVC tree. + + The node carries an experiment in attribute `item`. The node can be instantiated only using the + name of the experiment. The experiment will be created lazily on access to `node.item`. + + Attributes + ---------- + name: str + Name of the experiment + item: None or :class:`orion.core.worker.experiment.Experiment` + None if the experiment is not initialized yet. When initializing lazily, it creates an + `Experiment` in read only mode. + + .. seealso:: + + :py:class:`orion.core.utils.tree.TreeNode` for tree-specific attributes and methods. + + """ + + __slots__ = ( + "name", + "version", + "_no_parent_lookup", + "_no_children_lookup", + "storage", + ) + TreeNode.__slots__ + + def __init__( + self, + name, + version, + experiment=None, + parent=None, + children=tuple(), + storage=None, + ): + """Initialize experiment node with item, experiment, parent and children + + .. 
seealso:: + :class:`orion.core.utils.tree.TreeNode` for information about the attributes + """ + super().__init__(experiment, parent, children) + self.name = name + self.version = version + + self._no_parent_lookup = True + self._no_children_lookup = True + self.storage = storage or experiment._storage + + @property + def item(self): + """Get the experiment associated to the node + + Note that accessing `item` may trigger the lazy initialization of the experiment if it was + not done already. + """ + if self._item is None: + # TODO: Find another way around the circular import + from orion.core.io import experiment_builder + + self._item = experiment_builder.load( + name=self.name, version=self.version, storage=self.storage + ) + self._item._node = self + + return self._item + + @property + def parent(self): + """Get parent of the experiment, None if no parent + + .. note:: + + The instantiation of an EVC tree is lazy, which means accessing the parent of a node + may trigger a call to database to build this parent live. + + """ + if self._parent is None and self._no_parent_lookup: + self._no_parent_lookup = False + query = {"_id": self.item.refers.get("parent_id")} + selection = {"name": 1, "version": 1} + experiments = self.storage.fetch_experiments(query, selection) + + if experiments: + parent = experiments[0] + exp_node = ExperimentNode( + name=parent["name"], + version=parent.get("version", 1), + storage=self.storage, + ) + self.set_parent(exp_node) + return self._parent + + @property + def children(self): + """Get children of the experiment, empty list if no children + + .. note:: + + The instantiation of an EVC tree is lazy, which means accessing the children of a node + may trigger a call to database to build those children live. + + """ + if self._no_children_lookup: + self._children = [] + self._no_children_lookup = False + query = {"refers.parent_id": self.item.id} + selection = {"name": 1, "version": 1} + experiments = self.storage.fetch_experiments(query, selection) + for child in experiments: + self.add_children( + ExperimentNode( + name=child["name"], + version=child.get("version", 1), + storage=self.storage, + ) + ) + + return self._children + + @property + def adapter(self): + """Get the adapter of the experiment with respect to its parent""" + return self.item.refers["adapter"] + + @property + def tree_name(self): + """Return a formatted name of the Node for a tree pretty-print.""" + if self.item is not None: + return f"{self.name}-v{self.item.version}" + + return self.name + + def fetch_lost_trials(self): + """See :meth:`orion.core.evc.experiment.ExperimentNode.recurvise_fetch`""" + return self.recurvise_fetch("fetch_lost_trials") + + def fetch_trials(self): + """See :meth:`orion.core.evc.experiment.ExperimentNode.recurvise_fetch`""" + return self.recurvise_fetch("fetch_trials") + + def fetch_pending_trials(self): + """See :meth:`orion.core.evc.experiment.ExperimentNode.recurvise_fetch`""" + return self.recurvise_fetch("fetch_pending_trials") + + def fetch_noncompleted_trials(self): + """See :meth:`orion.core.evc.experiment.ExperimentNode.recurvise_fetch`""" + return self.recurvise_fetch("fetch_noncompleted_trials") + + def fetch_trials_by_status(self, status): + """See :meth:`orion.core.evc.experiment.ExperimentNode.recurvise_fetch`""" + return self.recurvise_fetch("fetch_trials_by_status", status=status) + + def recurvise_fetch(self, fun_name, *args, **kwargs): + """Fetch trials recursively in the EVC tree using the fetch function `fun_name`. 
+ + Parameters + ---------- + fun_name: callable + Function name to call to fetch trials. The function must be an attribute of + :class:`orion.core.worker.experiment.Experiment` + + *args: + Positional arguments to pass to `fun_name`. + + **kwargs + Keyword arguments to pass to `fun_name`. + + """ + + def retrieve_trials(node, parent_or_children): + """Retrieve the trials of a node/experiment.""" + fun = getattr(node.item, fun_name) + # with_evc_tree needs to be False here or we will have an infinite loop + trials = fun(*args, with_evc_tree=False, **kwargs) + return dict(trials=trials, experiment=node.item), parent_or_children + + # get the trials of the parents + parent_trials = None + if self.parent is not None: + parent_trials = self.parent.map(retrieve_trials, self.parent.parent) + + # get the trials of the children + children_trials = self.map(retrieve_trials, self.children) + children_trials.set_parent(parent_trials) + + adapt_trials(children_trials) + + return sum((node.item["trials"] for node in children_trials.root), []) + + +def _adapt_parent_trials(node, parent_trials_node, ids): + """Adapt trials from the parent recursively + + .. note:: + + To call with node.map(fct, node.parent) to connect with parents + + """ + # Ids from children are passed to prioritized them if they are also present in parent nodes. + node_ids = { + trial.compute_trial_hash(trial, ignore_lie=True) + for trial in node.item["trials"] + } | ids + if parent_trials_node is not None: + adapter = node.item["experiment"].refers["adapter"] + for parent in parent_trials_node.root: + parent.item["trials"] = adapter.forward(parent.item["trials"]) + + # if trial is in current exp, filter out + parent.item["trials"] = [ + trial + for trial in parent.item["trials"] + if trial.compute_trial_hash( + trial, ignore_lie=True, ignore_experiment=True + ) + not in node_ids + ] + + return node.item, parent_trials_node + + +def _adapt_children_trials(node, children_trials_nodes): + """Adapt trials from the children recursively + + .. note:: + + To call with node.map(fct, node.children) to connect with children + + """ + ids = { + trial.compute_trial_hash(trial, ignore_lie=True) + for trial in node.item["trials"] + } + + for child in children_trials_nodes: + adapter = child.item["experiment"].refers["adapter"] + for subchild in child: # Includes child itself + subchild.item["trials"] = adapter.backward(subchild.item["trials"]) + + # if trial is in current node, filter out + subchild.item["trials"] = [ + trial + for trial in subchild.item["trials"] + if trial.compute_trial_hash( + trial, ignore_lie=True, ignore_experiment=True + ) + not in ids + ] + + return node.item, children_trials_nodes + + +def adapt_trials(trials_tree): + """Adapt trials recursively so that they are all compatible with current experiment.""" + trials_tree.map(_adapt_children_trials, trials_tree.children) + ids = set() + for child in trials_tree.children: + for trial in child.item["trials"]: + ids.add(trial.compute_trial_hash(trial, ignore_lie=True)) + trials_tree.map( + functools.partial(_adapt_parent_trials, ids=ids), trials_tree.parent + ) diff --git a/src/orion/core/io/database/pickleddb.py b/src/orion/core/io/database/pickleddb.py index 2286b928d..f882c18a9 100644 --- a/src/orion/core/io/database/pickleddb.py +++ b/src/orion/core/io/database/pickleddb.py @@ -1,309 +1,310 @@ -""" -Pickled Database -================ - -Implement permanent version of :class:`orion.core.io.database.ephemeraldb.EphemeralDB`. 
- -""" - -import logging -import os -import pickle -from contextlib import contextmanager -from pickle import PicklingError - -import psutil -from filelock import FileLock, SoftFileLock, Timeout - -import orion.core -from orion.core.io.database import Database, DatabaseTimeout -from orion.core.io.database.ephemeraldb import EphemeralDB - -log = logging.getLogger(__name__) - -DEFAULT_HOST = os.path.join(orion.core.DIRS.user_data_dir, "orion", "orion_db.pkl") - -TIMEOUT_ERROR_MESSAGE = """\ -Could not acquire lock for PickledDB after {} seconds. - -This is likely due to one or many of the following scenarios: - -1. There is a large amount of workers and many simultaneous queries. This typically occurs - when the task to optimize is short (few minutes). Try to reduce the amount of workers - at least below 50. - -2. The database is growing large with thousands of trials and many experiments. - If so, you can use a different PickleDB (different file, that is, different `host`) - for each experiment separately to alleviate this issue. - -3. The filesystem is slow. Parallel filesystems on HPC often suffer from - large pool of users generating frequent I/O. In this case try using a separate - partition that may be less affected. - -If you cannot solve the issues listed above that are causing timeouts, you -may need to setup the MongoDB backend for better performance. -See https://orion.readthedocs.io/en/stable/install/database.html -""" - - -def find_unpickable_doc(dict_of_dict): - """Look for a dictionary that cannot be pickled.""" - for name, collection in dict_of_dict.items(): - documents = collection.find() - - for doc in documents: - try: - pickle.dumps(doc) - - except (PicklingError, AttributeError): - return name, doc - - return None, None - - -def find_unpickable_field(doc): - """Look for a field in a dictionary that cannot be pickled""" - if not isinstance(doc, dict): - doc = doc.to_dict() - - for k, v in doc.items(): - try: - pickle.dumps(v) - - except (PicklingError, AttributeError): - return k, v - - return None, None - - -# pylint: disable=too-many-public-methods -class PickledDB(Database): - """Pickled EphemeralDB to support permanancy and concurrency - - This is a very simple and inefficient implementation of a permanent database on disk for Oríon. - The data is loaded from disk for every operation, and every operation is protected with a - filelock. - - Parameters - ---------- - host: str - File path to save pickled ephemeraldb. Default is {user data dir}/orion/orion_db.pkl ex: - $HOME/.local/share/orion/orion_db.pkl - timeout: int - Maximum number of seconds to wait for the lock before raising DatabaseTimeout. - Default is 60. - - """ - - # pylint: disable=unused-argument - def __init__(self, host="", timeout=60, *args, **kwargs): - if host == "": - host = DEFAULT_HOST - super().__init__(host) - - self.host = os.path.abspath(host) - - self.timeout = timeout - - if os.path.dirname(host): - os.makedirs(os.path.dirname(host), exist_ok=True) - - def __repr__(self) -> str: - return f"{type(self).__qualname__}(host={self.host}, timeout={self.timeout})" - - @property - def is_connected(self): - """Return true, always.""" - return True - - def initiate_connection(self): - """Do nothing""" - - def close_connection(self): - """Do nothing""" - - def ensure_index(self, collection_name, keys, unique=False): - """Create given indexes if they do not already exist in database. - - Indexes are only created if `unique` is True. 
- """ - with self.locked_database() as database: - database.ensure_index(collection_name, keys, unique=unique) - - def index_information(self, collection_name): - """Return dict of names and sorting order of indexes""" - with self.locked_database(write=False) as database: - return database.index_information(collection_name) - - def drop_index(self, collection_name, name): - """Remove index from the database""" - with self.locked_database() as database: - return database.drop_index(collection_name, name) - - def write(self, collection_name, data, query=None): - """Write new information to a collection. Perform insert or update. - - .. seealso:: :meth:`orion.core.io.database.Database.write` for argument documentation. - - """ - with self.locked_database() as database: - return database.write(collection_name, data, query=query) - - def read(self, collection_name, query=None, selection=None): - """Read a collection and return a value according to the query. - - .. seealso:: :meth:`orion.core.io.database.Database.read` for argument documentation. - - """ - with self.locked_database(write=False) as database: - return database.read(collection_name, query=query, selection=selection) - - def read_and_write(self, collection_name, query, data, selection=None): - """Read a collection's document and update the found document. - - Returns the updated document, or None if nothing found. - - .. seealso:: :meth:`orion.core.io.database.Database.read_and_write` for - argument documentation. - - """ - with self.locked_database() as database: - return database.read_and_write( - collection_name, query=query, data=data, selection=selection - ) - - def count(self, collection_name, query=None): - """Count the number of documents in a collection which match the `query`. - - .. seealso:: :meth:`orion.core.io.database.Database.count` for argument documentation. - - """ - with self.locked_database(write=False) as database: - return database.count(collection_name, query=query) - - def remove(self, collection_name, query): - """Delete from a collection document[s] which match the `query`. - - .. seealso:: :meth:`orion.core.io.database.Database.remove` for argument documentation. 
- - """ - with self.locked_database() as database: - return database.remove(collection_name, query=query) - - def _get_database(self): - """Read fresh DB state from pickled file""" - if not os.path.exists(self.host): - return EphemeralDB() - - with open(self.host, "rb") as f: - data = f.read() - if not data: - database = EphemeralDB() - else: - database = pickle.loads(data) - - return database - - def _dump_database(self, database): - """Write pickled DB on disk""" - tmp_file = self.host + ".tmp" - - try: - with open(tmp_file, "wb") as f: - pickle.dump(database, f) - - except (PicklingError, AttributeError): - # pylint: disable=protected-access - collection, doc = find_unpickable_doc(database._db) - log.error( - "Document in (collection: %s) is not pickable\ndoc: %s", - collection, - doc.to_dict() if hasattr(doc, "to_dict") else str(doc), - ) - - key, value = find_unpickable_field(doc) - log.error("because (value %s) in (field: %s) is not pickable", value, key) - raise - - os.rename(tmp_file, self.host) - - @contextmanager - def locked_database(self, write=True): - """Lock database file during wrapped operation call.""" - lock = _create_lock(self.host + ".lock") - - try: - with lock.acquire(timeout=self.timeout): - database = self._get_database() - - yield database - - if write: - self._dump_database(database) - except Timeout as e: - raise DatabaseTimeout(TIMEOUT_ERROR_MESSAGE.format(self.timeout)) from e - - @classmethod - def get_defaults(cls): - """Get database arguments needed to create a database instance. - - .. seealso:: :meth:`orion.core.io.database.Database.get_defaults` - for argument documentation. - - """ - return {"host": DEFAULT_HOST} - - -local_file_systems = ["ext2", "ext3", "ext4", "ntfs"] - - -def _fs_support_globalflock(file_system): - if file_system.fstype == "lustre": - return ("flock" in file_system.opts) and ("localflock" not in file_system.opts) - - elif file_system.fstype == "beegfs": - return "tuneUseGlobalFileLocks" in file_system.opts - - elif file_system.fstype == "gpfs": - return True - - elif file_system.fstype == "nfs": - return False - - return file_system.fstype in local_file_systems - - -def _find_mount_point(path): - """Finds the mount point used to access `path`.""" - path = os.path.abspath(path) - while not os.path.ismount(path): - path = os.path.dirname(path) - - return path - - -def _get_fs(path): - """Gets info about the filesystem on which `path` lives.""" - mount = _find_mount_point(path) - - for file_system in psutil.disk_partitions(True): - if file_system.mountpoint == mount: - return file_system - - return None - - -def _create_lock(path): - """Create lock based on file system capabilities - - Determine if we can rely on the fcntl module for locking files. - Otherwise, fallback on using the directory creation atomicity as a locking mechanism. - """ - file_system = _get_fs(path) - - if _fs_support_globalflock(file_system): - log.debug("Using flock.") - return FileLock(path) - else: - log.debug("Cluster does not support flock. Falling back to softfilelock.") - return SoftFileLock(path) +""" +Pickled Database +================ + +Implement permanent version of :class:`orion.core.io.database.ephemeraldb.EphemeralDB`. 
+
+"""
+
+import logging
+import os
+import pickle
+from contextlib import contextmanager
+from pickle import PicklingError
+
+import psutil
+from filelock import FileLock, SoftFileLock, Timeout
+
+import orion.core
+from orion.core.io.database import Database, DatabaseTimeout
+from orion.core.io.database.ephemeraldb import EphemeralDB
+from orion.core.utils.compat import replace
+
+log = logging.getLogger(__name__)
+
+DEFAULT_HOST = os.path.join(orion.core.DIRS.user_data_dir, "orion", "orion_db.pkl")
+
+TIMEOUT_ERROR_MESSAGE = """\
+Could not acquire lock for PickledDB after {} seconds.
+
+This is likely due to one or more of the following scenarios:
+
+1. There is a large number of workers and many simultaneous queries. This typically occurs
+   when the task to optimize is short (few minutes). Try to reduce the number of workers
+   at least below 50.
+
+2. The database is growing large with thousands of trials and many experiments.
+   If so, you can use a different PickleDB (different file, that is, different `host`)
+   for each experiment separately to alleviate this issue.
+
+3. The filesystem is slow. Parallel filesystems on HPC often suffer from
+   a large pool of users generating frequent I/O. In this case try using a separate
+   partition that may be less affected.
+
+If you cannot solve the issues listed above that are causing timeouts, you
+may need to set up the MongoDB backend for better performance.
+See https://orion.readthedocs.io/en/stable/install/database.html
+"""
+
+
+def find_unpickable_doc(dict_of_dict):
+    """Look for a dictionary that cannot be pickled."""
+    for name, collection in dict_of_dict.items():
+        documents = collection.find()
+
+        for doc in documents:
+            try:
+                pickle.dumps(doc)
+
+            except (PicklingError, AttributeError):
+                return name, doc
+
+    return None, None
+
+
+def find_unpickable_field(doc):
+    """Look for a field in a dictionary that cannot be pickled"""
+    if not isinstance(doc, dict):
+        doc = doc.to_dict()
+
+    for k, v in doc.items():
+        try:
+            pickle.dumps(v)
+
+        except (PicklingError, AttributeError):
+            return k, v
+
+    return None, None
+
+
+# pylint: disable=too-many-public-methods
+class PickledDB(Database):
+    """Pickled EphemeralDB to support permanency and concurrency
+
+    This is a very simple and inefficient implementation of a permanent database on disk for Oríon.
+    The data is loaded from disk for every operation, and every operation is protected with a
+    filelock.
+
+    Parameters
+    ----------
+    host: str
+        File path to save pickled ephemeraldb. Default is {user data dir}/orion/orion_db.pkl ex:
+        $HOME/.local/share/orion/orion_db.pkl
+    timeout: int
+        Maximum number of seconds to wait for the lock before raising DatabaseTimeout.
+        Default is 60.
+
+    """
+
+    # pylint: disable=unused-argument
+    def __init__(self, host="", timeout=60, *args, **kwargs):
+        if host == "":
+            host = DEFAULT_HOST
+        super().__init__(host)
+
+        self.host = os.path.abspath(host)
+
+        self.timeout = timeout
+
+        if os.path.dirname(host):
+            os.makedirs(os.path.dirname(host), exist_ok=True)
+
+    def __repr__(self) -> str:
+        return f"{type(self).__qualname__}(host={self.host}, timeout={self.timeout})"
+
+    @property
+    def is_connected(self):
+        """Return true, always."""
+        return True
+
+    def initiate_connection(self):
+        """Do nothing"""
+
+    def close_connection(self):
+        """Do nothing"""
+
+    def ensure_index(self, collection_name, keys, unique=False):
+        """Create given indexes if they do not already exist in database.
+
+        Indexes are only created if `unique` is True.
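
As an aside, the two find_unpickable_* helpers above are easy to exercise in isolation. A minimal sketch, assuming a plain dict stands in for a stored document and using a renamed copy of the helper (the lambda value is the hypothetical offender):

    import pickle
    from pickle import PicklingError

    def first_unpickable_field(doc):
        # Probe each value independently, like find_unpickable_field above.
        for key, value in doc.items():
            try:
                pickle.dumps(value)
            except (PicklingError, AttributeError):
                return key, value
        return None, None

    doc = {"objective": 0.5, "callback": lambda x: x}
    key, value = first_unpickable_field(doc)
    print(key)  # "callback": inline functions cannot be pickled by reference
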
+ """ + with self.locked_database() as database: + database.ensure_index(collection_name, keys, unique=unique) + + def index_information(self, collection_name): + """Return dict of names and sorting order of indexes""" + with self.locked_database(write=False) as database: + return database.index_information(collection_name) + + def drop_index(self, collection_name, name): + """Remove index from the database""" + with self.locked_database() as database: + return database.drop_index(collection_name, name) + + def write(self, collection_name, data, query=None): + """Write new information to a collection. Perform insert or update. + + .. seealso:: :meth:`orion.core.io.database.Database.write` for argument documentation. + + """ + with self.locked_database() as database: + return database.write(collection_name, data, query=query) + + def read(self, collection_name, query=None, selection=None): + """Read a collection and return a value according to the query. + + .. seealso:: :meth:`orion.core.io.database.Database.read` for argument documentation. + + """ + with self.locked_database(write=False) as database: + return database.read(collection_name, query=query, selection=selection) + + def read_and_write(self, collection_name, query, data, selection=None): + """Read a collection's document and update the found document. + + Returns the updated document, or None if nothing found. + + .. seealso:: :meth:`orion.core.io.database.Database.read_and_write` for + argument documentation. + + """ + with self.locked_database() as database: + return database.read_and_write( + collection_name, query=query, data=data, selection=selection + ) + + def count(self, collection_name, query=None): + """Count the number of documents in a collection which match the `query`. + + .. seealso:: :meth:`orion.core.io.database.Database.count` for argument documentation. + + """ + with self.locked_database(write=False) as database: + return database.count(collection_name, query=query) + + def remove(self, collection_name, query): + """Delete from a collection document[s] which match the `query`. + + .. seealso:: :meth:`orion.core.io.database.Database.remove` for argument documentation. 
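
Every accessor above funnels through locked_database, so each call pays a full load of the pickled state and, for writes, a full dump; that is the inefficiency the class docstring warns about. A minimal sketch of that cycle, assuming filelock is installed (the file names are illustrative):

    import os
    import pickle
    import tempfile
    from contextlib import contextmanager

    from filelock import FileLock

    @contextmanager
    def locked_state(path, write=True, timeout=60):
        # Hold the lock for the whole load -> mutate -> dump cycle.
        lock = FileLock(path + ".lock")
        with lock.acquire(timeout=timeout):
            state = {}
            if os.path.exists(path):
                with open(path, "rb") as f:
                    data = f.read()
                if data:  # tolerate an empty file, like _get_database above
                    state = pickle.loads(data)
            yield state
            if write:
                tmp = path + ".tmp"
                with open(tmp, "wb") as f:
                    pickle.dump(state, f)
                os.replace(tmp, path)  # atomic swap, as _dump_database does

    path = os.path.join(tempfile.gettempdir(), "demo_db.pkl")
    with locked_state(path) as db:
        db.setdefault("trials", []).append({"x": 1})
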
+ + """ + with self.locked_database() as database: + return database.remove(collection_name, query=query) + + def _get_database(self): + """Read fresh DB state from pickled file""" + if not os.path.exists(self.host): + return EphemeralDB() + + with open(self.host, "rb") as f: + data = f.read() + if not data: + database = EphemeralDB() + else: + database = pickle.loads(data) + + return database + + def _dump_database(self, database): + """Write pickled DB on disk""" + tmp_file = self.host + ".tmp" + + try: + with open(tmp_file, "wb") as f: + pickle.dump(database, f) + + except (PicklingError, AttributeError): + # pylint: disable=protected-access + collection, doc = find_unpickable_doc(database._db) + log.error( + "Document in (collection: %s) is not pickable\ndoc: %s", + collection, + doc.to_dict() if hasattr(doc, "to_dict") else str(doc), + ) + + key, value = find_unpickable_field(doc) + log.error("because (value %s) in (field: %s) is not pickable", value, key) + raise + + replace(tmp_file, self.host) + + @contextmanager + def locked_database(self, write=True): + """Lock database file during wrapped operation call.""" + lock = _create_lock(self.host + ".lock") + + try: + with lock.acquire(timeout=self.timeout): + database = self._get_database() + + yield database + + if write: + self._dump_database(database) + except Timeout as e: + raise DatabaseTimeout(TIMEOUT_ERROR_MESSAGE.format(self.timeout)) from e + + @classmethod + def get_defaults(cls): + """Get database arguments needed to create a database instance. + + .. seealso:: :meth:`orion.core.io.database.Database.get_defaults` + for argument documentation. + + """ + return {"host": DEFAULT_HOST} + + +local_file_systems = ["ext2", "ext3", "ext4", "ntfs"] + + +def _fs_support_globalflock(file_system): + if file_system.fstype == "lustre": + return ("flock" in file_system.opts) and ("localflock" not in file_system.opts) + + elif file_system.fstype == "beegfs": + return "tuneUseGlobalFileLocks" in file_system.opts + + elif file_system.fstype == "gpfs": + return True + + elif file_system.fstype == "nfs": + return False + + return file_system.fstype in local_file_systems + + +def _find_mount_point(path): + """Finds the mount point used to access `path`.""" + path = os.path.abspath(path) + while not os.path.ismount(path): + path = os.path.dirname(path) + + return path + + +def _get_fs(path): + """Gets info about the filesystem on which `path` lives.""" + mount = _find_mount_point(path) + + for file_system in psutil.disk_partitions(True): + if file_system.mountpoint == mount: + return file_system + + return None + + +def _create_lock(path): + """Create lock based on file system capabilities + + Determine if we can rely on the fcntl module for locking files. + Otherwise, fallback on using the directory creation atomicity as a locking mechanism. + """ + file_system = _get_fs(path) + + if _fs_support_globalflock(file_system): + log.debug("Using flock.") + return FileLock(path) + else: + log.debug("Cluster does not support flock. 
Falling back to softfilelock.") + return SoftFileLock(path) diff --git a/src/orion/core/utils/compat.py b/src/orion/core/utils/compat.py index 61cf7ba23..67770f0ca 100644 --- a/src/orion/core/utils/compat.py +++ b/src/orion/core/utils/compat.py @@ -1,30 +1,53 @@ """Windows compatibility utilities""" import os +import time def getuser(): """getpass use pwd which is UNIX only""" - if os.name == 'nt': + if os.name == "nt": return os.getlogin() import getpass + return getpass.getuser() class _readline: def set_completer_delims(*args, **kwargs): """Fake method for windows""" - pass def get_readline(): """Fake readline interface, readline is UNIX only""" - if os.name == 'nt': + if os.name == "nt": return _readline import readline + return readline readline = get_readline() + + +def replace(old, new, tries=3, sleep=0.01): + """Windows file replacing is more strict than linux""" + if os.name != "nt": + os.replace(old, new) + return + + # if the file is open already windows will raise permission error + # even if the lock was free + exception = None + for _ in range(tries): + try: + os.replace(old, new) + return + except PermissionError as exc: + time.sleep(sleep) + exception = exc + + if exception: + raise exception diff --git a/src/orion/storage/sql.py b/src/orion/storage/sql.py index 0ab682803..369b26583 100644 --- a/src/orion/storage/sql.py +++ b/src/orion/storage/sql.py @@ -60,7 +60,7 @@ class User(Base): __tablename__ = "users" _id = Column(Integer, primary_key=True, autoincrement=True) - name = Column(String(30)) + name = Column(String(30), unique=True) token = Column(String(32)) created_at = Column(DateTime) last_seen = Column(DateTime) @@ -79,6 +79,7 @@ class Experiment(Base): algorithms = Column(JSON) remaining = Column(JSON) space = Column(JSON) + parent_id = Column(Integer) __table_args__ = ( UniqueConstraint('name', 'owner_id', name='_one_name_per_owner'), @@ -101,7 +102,7 @@ class Trial(Base): parent = Column(Integer, ForeignKey("trials._id"), nullable=True) params = Column(JSON) worker = Column(JSON) - submit_time = Column(String(30)) + submit_time = Column(DateTime) exp_working_dir = Column(String(30)) id = Column(String(30)) @@ -174,7 +175,7 @@ def __init__(self, uri, token=None, **kwargs): # engine_from_config self.engine = sqlalchemy.create_engine( uri, - echo=True, + echo=False, future=True, json_serializer=to_json, json_deserializer=from_json, @@ -189,29 +190,44 @@ def __init__(self, uri, token=None, **kwargs): self._connect(token) def _connect(self, token): + name = getuser() + + user = self._find_user(name, token) + + if user is None: + user = self._create_user(name) + + assert user is not None + + self.user_id = user._id + self.user = user + self.token = user.token + + def _find_user(self, name, token) -> User: + query = [User.name == name] if token is not None and token != "": - with Session(self.engine) as session: - stmt = select(User).where(User.token == self.token) - self.user = session.scalars(stmt).one() + query.append(User.token == token) - self.user_id = self.user._id - else: - # Local database, create a default user - user = getuser() - now = datetime.datetime.utcnow() + with Session(self.engine) as session: + stmt = select(User).where(*query) - with Session(self.engine) as session: - self.user = User( - name=user, - token=uuid.uuid5(uuid.NAMESPACE_OID, user).hex, - created_at=now, - last_seen=now, - ) - session.add(self.user) - session.commit() + return session.execute(stmt).scalar() + + def _create_user(self, name) -> User: + now = datetime.datetime.utcnow() - 
assert self.user._id > 0
-            self.user_id = self.user._id
+        with Session(self.engine) as session:
+            user = User(
+                name=name,
+                token=uuid.uuid5(uuid.NAMESPACE_OID, name).hex,
+                created_at=now,
+                last_seen=now,
+            )
+            session.add(user)
+            session.commit()
+
+            assert user._id > 0
+            return user
 
     def __getstate__(self):
         return dict(
@@ -244,6 +260,11 @@ def create_experiment(self, config):
                 version=0,
             )
 
+            if "refers" in config:
+                ref = config.get("refers")
+                if "parent_id" in ref:
+                    config["parent_id"] = ref.pop("parent_id")
+
             cpy["meta"] = cpy.pop("metadata")
 
             self._set_from_dict(experiment, cpy, "remaining")
@@ -252,6 +273,9 @@ def create_experiment(self, config):
                 session.refresh(experiment)
 
                 config.update(self._to_experiment(experiment))
+
+                # Already create the algo lock as well
+                self.initialize_algorithm_lock(config["_id"], config.get("algorithms", {}))
         except DBAPIError:
             raise DuplicateKeyError()
 
@@ -268,6 +292,9 @@ def update_experiment(self, experiment=None, uid=None, where=None, **kwargs):
         """See :func:`orion.storage.base.BaseStorageProtocol.update_experiment`"""
         uid = get_uid(experiment, uid)
 
+        if where and "refers.parent_id" in where:
+            where["parent_id"] = where.pop("refers.parent_id")
+
         where = self._get_query(where)
 
         if uid is not None:
@@ -308,6 +335,9 @@ def _fetch_experiments_with_select(self, query, selection=None):
 
     def fetch_experiments(self, query, selection=None):
         """See :func:`orion.storage.base.BaseStorageProtocol.fetch_experiments`"""
+        if "refers.parent_id" in query:
+            query["parent_id"] = query.pop("refers.parent_id")
+
         if selection:
             return self._fetch_experiments_with_select(query, selection)
 
@@ -320,7 +350,6 @@ def fetch_experiments(self, query, selection=None):
             experiments = session.scalars(stmt).all()
 
             r = [self._to_experiment(exp) for exp in experiments]
-            print("RESULT", r)
             return r
 
     # Benchmarks
@@ -332,16 +361,18 @@ def fetch_trials(self, experiment=None, uid=None, where=None):
         """See :func:`orion.storage.base.BaseStorageProtocol.fetch_trials`"""
         uid = get_uid(experiment, uid)
 
-        where = self._get_query(where)
+        query = self._get_query(where)
 
         if uid is not None:
-            where["experiment_id"] = uid
+            query["experiment_id"] = uid
 
-        query = self._to_query(Trial, where)
+        query = self._to_query(Trial, query)
 
         with Session(self.engine) as session:
             stmt = select(Trial).where(*query)
-            return session.scalars(stmt).all()
+            results = session.scalars(stmt).all()
+
+            return [OrionTrial(**self._to_trial(t)) for t in results]
 
     def register_trial(self, trial):
         """See :func:`orion.storage.base.BaseStorageProtocol.register_trial`"""
@@ -359,6 +390,7 @@ def register_trial(self, trial):
                 session.commit()
                 session.refresh(db_trial)
 
+                trial.id_override = db_trial._id
                 return OrionTrial(**self._to_trial(db_trial))
 
         except DBAPIError:
@@ -384,8 +416,6 @@ def delete_trials(self, experiment=None, uid=None, where=None):
 
     def retrieve_result(self, trial, **kwargs):
         """Updates the results array"""
-        new_trial = self.get_trial(trial)
-        trial.results = new_trial.results
         return trial
 
     def get_trial(self, trial=None, uid=None, experiment_uid=None):
@@ -440,7 +470,7 @@ def update_trial(
 
             self._set_from_dict(trial, kwargs)
             session.commit()
-            return trial
+            return OrionTrial(**self._to_trial(trial))
 
     def fetch_lost_trials(self, experiment):
         """See :func:`orion.storage.base.BaseStorageProtocol.fetch_lost_trials`"""
@@ -455,21 +485,32 @@ def fetch_lost_trials(self, experiment):
                 Trial.status == "reserved",
                 Trial.heartbeat < threshold,
             )
-            return session.scalars(stmt).all()
+            results = session.scalars(stmt).all()
+
+            return [OrionTrial(**self._to_trial(t)) for t in results]
 
     def push_trial_results(self, trial):
         """See :func:`orion.storage.base.BaseStorageProtocol.push_trial_results`"""
+
+        log.debug("push trial to storage")
+        original = trial
+        config = trial.to_dict()
+
+        # Don't need to set that one
+        config.pop("experiment")
+
         with Session(self.engine) as session:
             stmt = select(Trial).where(
-                Trial.experiment_id == trial.experiment,
-                Trial._id == trial.id,
+                # Trial.experiment_id == trial.experiment,
+                # Trial.id == trial.id,
+                Trial._id == trial.id_override,
                 Trial.status == "reserved",
             )
             trial = session.scalars(stmt).one()
-            self._set_from_dict(trial, trial.to_dict())
+            self._set_from_dict(trial, config)
             session.commit()
 
-            return trial
+            return original
 
     def set_trial_status(self, trial, status, heartbeat=None, was=None):
         """See :func:`orion.storage.base.BaseStorageProtocol.set_trial_status`"""
@@ -506,7 +547,10 @@ def fetch_pending_trials(self, experiment):
                 Trial.status.in_(("interrupted", "new", "suspended")),
                 Trial.experiment_id == experiment._id,
             )
-            return session.scalars(stmt).all()
+            results = session.scalars(stmt).all()
+            trials = OrionTrial.build([self._to_trial(t) for t in results])
+
+            return trials
 
     def _reserve_trial_postgre(self, experiment):
         now = datetime.datetime.utcnow()
@@ -528,13 +572,14 @@ def _reserve_trial_postgre(self, experiment):
                 .returning()
             )
             trial = session.scalar(stmt)
-            return trial
+            return OrionTrial(**self._to_trial(trial))
 
     def reserve_trial(self, experiment):
         """See :func:`orion.storage.base.BaseStorageProtocol.reserve_trial`"""
         if False:
             return self._reserve_trial_postgre(experiment)
 
+        log.debug("reserve trial")
         now = datetime.datetime.utcnow()
 
         with Session(self.engine) as session:
@@ -542,6 +587,7 @@ def reserve_trial(self, experiment):
                 Trial.status.in_(("interrupted", "new", "suspended")),
                 Trial.experiment_id == experiment._id,
             )
+
             try:
                 trial = session.scalars(stmt).one()
             except NoResultFound:
@@ -552,7 +598,7 @@ def reserve_trial(self, experiment):
                 update(Trial)
                 .where(
                     Trial.status == trial.status,
-                    Trial.experiment_id == experiment._id,
+                    Trial._id == trial._id,
                 )
                 .values(
                     status="reserved",
@@ -561,14 +607,13 @@ def reserve_trial(self, experiment):
                 )
             )
 
-            session.execute(stmt)
-
-            stmt = select(Trial).where(Trial.experiment_id == experiment._id)
-            trial = session.scalars(stmt).one()
+            result = session.execute(stmt)
 
             # time needs to match, could have been reserved by another worker
-            if trial.status == "reserved" and trial.heartbeat == now:
-                return trial
+            if result.rowcount == 1:
+                session.commit()
+                session.refresh(trial)
+                return OrionTrial(**self._to_trial(trial))
 
             return None
 
@@ -576,12 +621,11 @@ def fetch_trials_by_status(self, experiment, status):
         """See :func:`orion.storage.base.BaseStorageProtocol.fetch_trials_by_status`"""
         with Session(self.engine) as session:
             stmt = select(Trial).where(
-                Trial.status == status and Trial.experiment_id == experiment._id
+                Trial.status == status, Trial.experiment_id == experiment._id
            )
-            return [
-                OrionTrial(**self._to_trial(trial))
-                for trial in session.scalars(stmt).all()
-            ]
+            results = session.scalars(stmt).all()
+
+            return [OrionTrial(**self._to_trial(trial)) for trial in results]
 
     def fetch_noncompleted_trials(self, experiment):
         """See :func:`orion.storage.base.BaseStorageProtocol.fetch_noncompleted_trials`"""
@@ -590,7 +634,9 @@ def fetch_noncompleted_trials(self, experiment):
                 Trial.status != "completed",
                 Trial.experiment_id == experiment._id,
            )
-            return session.scalars(stmt).all()
+            results = 
session.scalars(stmt).all() + + return [OrionTrial(**self._to_trial(trial)) for trial in results] def count_completed_trials(self, experiment): """See :func:`orion.storage.base.BaseStorageProtocol.count_completed_trials`""" @@ -642,7 +688,7 @@ def initialize_algorithm_lock(self, experiment_id, algorithm_config): with Session(self.engine) as session: algo = Algo( experiment_id=experiment_id, - owner_id=self.user._id, + owner_id=self.user_id, configuration=algorithm_config, locked=0, heartbeat=datetime.datetime.utcnow(), @@ -720,7 +766,7 @@ def _acquire_algorithm_lock_postgre( return algo def _acquire_algorithm_lock( - self, experiment=None, uid=None, timeout=60, retry_interval=1 + self, experiment=None, uid=None, timeout=1, retry_interval=1 ): uid = get_uid(experiment, uid) algo_state_lock = None @@ -772,12 +818,14 @@ def acquire_algorithm_lock( ) try: + log.debug("lock algo") yield locked_algo_state except Exception: # Reset algo to state fetched lock time locked_algo_state.reset() raise finally: + log.debug("unlock algo") uid = get_uid(experiment, uid) self.release_algorithm_lock(uid=uid, new_state=locked_algo_state.state) @@ -819,6 +867,7 @@ def _set_from_dict(self, obj, data, rest=None): if meta: log.warning("Data was discarded %s", meta) + assert False def _to_query(self, table, where): query = [] diff --git a/tests/stress/client/stress_experiment.py b/tests/stress/client/stress_experiment.py index 432601c32..37c7d0e74 100644 --- a/tests/stress/client/stress_experiment.py +++ b/tests/stress/client/stress_experiment.py @@ -1,395 +1,426 @@ -#!/usr/bin/env python -"""Perform a stress tests on python API.""" -import os -import random -import time -import traceback -from collections import OrderedDict -from multiprocessing import Pool - -import matplotlib.pyplot as plt -from pymongo import MongoClient - -from orion.client import create_experiment -from orion.core.io.database import DatabaseTimeout -from orion.core.utils.exceptions import ReservationTimeout - -DB_FILE = "stress.pkl" -SQLITE_FILE = "db.sqlite" - -ADDRESS = "192.168.0.16" - -# -# Create the stress test user -# -# MongoDB -# -# mongosh -# > use admin -# > db.createUser({ -# user: "user", -# pwd: "pass", -# roles: [ -# {role: 'readWrite', db: 'stress'}, -# ] -# }) -# -# PostgreSQL -- DO NOT USE THIS IN PROD - TESTING ONLY -# -# # Switch to the user running the database -# sudo su postgres -# -# # open an interactive connection to the server -# psql -# > CREATE USER username WITH PASSWORD 'pass'; -# > CREATE ROLE orion_database_admin; -# > CREATE ROLE orion_database_user LOGIN; -# > GRANT orion_database_user, orion_database_user TO username; -# > -# > GRANT pg_write_all_data, pg_read_all_data TO username; -# > CREATE DATABASE stress OWNER orion_database_admin; -# \q -# -# > - - -BACKENDS_CONFIGS = OrderedDict( - [ - # ("pickleddb", {"type": "legacy", "database": {"type": "pickleddb", "host": DB_FILE}}), - # ("sqlite", {"type": "sqlalchemy", "uri": f"sqlite:///{SQLITE_FILE}"}), - ( - "postgresql", - { - "type": "sqlalchemy", - "uri": f"postgresql://username:pass@{ADDRESS}/stress", - }, - ), - ( - "mongodb", - { - "type": "legacy", - "database": { - "type": "mongodb", - "name": "stress", - "host": f"mongodb://user:pass@{ADDRESS}", - }, - }, - ), - ] -) - - -def cleanup_storage(backend): - if backend == "pickleddb": - if os.path.exists(DB_FILE): - os.remove(DB_FILE) - - elif backend == "sqlite": - if os.path.exists(SQLITE_FILE): - os.remove(SQLITE_FILE) - - elif backend == "postgresql": - import sqlalchemy - from sqlalchemy.orm 
import Session - - from orion.storage.sql import get_tables - - engine = sqlalchemy.create_engine( - f"postgresql://username:pass@{ADDRESS}/stress", - echo=True, - future=True, - ) - - # if the tables are missing just skip - for table in get_tables(): - try: - with Session(engine) as session: - session.execute(f"DROP TABLE {table.__tablename__} CASCADE;") - session.commit() - except: - traceback.print_exc() - - elif backend == "mongodb": - client = MongoClient( - host=ADDRESS, username="user", password="pass", authSource="stress" - ) - database = client.stress - database.experiments.drop() - database.lying_trials.drop() - database.trials.drop() - database.workers.drop() - database.resources.drop() - client.close() - - else: - raise RuntimeError("You need to cleam your backend") - - -def f(x, worker): - """Sleep and return objective equal to param""" - print(f"{worker: 6d} {x: 5f}") - time.sleep(max(0, random.gauss(1, 0.2))) - return [dict(name="objective", value=x, type="objective")] - - -def get_experiment(storage, space_type, size): - """Create an experiment or load from DB if already existing - - Parameters - ---------- - storage: str - Can be `pickleddb` or `mongodb`. A default configuration is used for each. - space_type: str - Can be one of - - `discrete` Search space is discrete and limited to `max_trials` - - `real-seeded` Search space is continuous and algos is seeded, leading to many race - conditions while algos are sampling the same points in parallel, or - - `real` Search space is real and algo is not seeded, leading to very few race conditions. - size: int - This defines `max_trials`, and the size of the search space (`uniform(0, size)`). - - """ - storage_config = BACKENDS_CONFIGS[storage] - - discrete = space_type == "discrete" - high = size # * 2 - - return create_experiment( - "stress-test", - space={"x": f"uniform(0, {high}, discrete={discrete})"}, - max_trials=size, - max_idle_time=60 * 5, - algorithms={"random": {"seed": None if space_type == "real" else 1}}, - storage=storage_config, - ) - - -def worker(worker_id, storage, space_type, size): - """Run trials until experiment is done - - Parameters - ---------- - worker_id: int - ID of the worker. This is used to distinguish logs from different workers. - storage: str - See `get_experiment`. - space_type: str - See `get_experiment`. - size: int - See `get_experiment`. - - """ - try: - experiment = get_experiment(storage, space_type, size) - - assert experiment.version == 1, experiment.version - - print(f"{worker_id: 6d} enters") - - num_trials = 0 - while not experiment.is_done: - try: - trial = experiment.suggest() - except ReservationTimeout: - trial - None - - if trial is None: - break - - results = f(trial.params["x"], worker_id) - num_trials += 1 - experiment.observe(trial, results=results) - - print(f"{worker_id: 6d} leaves | is done? {experiment.is_done}") - except DatabaseTimeout as e: - print(f"{worker_id: 6d} timeouts and leaves") - return num_trials - except Exception as e: - print(f"{worker_id: 6d} crashes") - traceback.print_exc() - return None - - return num_trials - - -def stress_test(storage, space_type, workers, size): - """Spawn workers and run stress test with verifications - - Parameters - ---------- - storage: str - See `get_experiment`. - space_type: str - See `get_experiment`. - workers: int - Number of workers to run in parallel. - size: int - See `get_experiment`. 
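
The worker above is the whole client-side protocol: suggest a trial, evaluate it, observe the result. Stripped of logging and error handling, the loop reduces to the sketch below (the experiment name, space, and pickleddb host are illustrative, following the create_experiment call used in this file):

    from orion.client import create_experiment

    def evaluate(x):
        # Stand-in objective: Orion expects a list of result dicts.
        return [dict(name="objective", value=x, type="objective")]

    experiment = create_experiment(
        "sketch",
        space={"x": "uniform(0, 10)"},
        storage={"type": "legacy", "database": {"type": "pickleddb", "host": "sketch_db.pkl"}},
    )

    while not experiment.is_done:
        trial = experiment.suggest()
        if trial is None:  # no reservable trial left
            break
        experiment.observe(trial, results=evaluate(trial.params["x"]))
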
- - Returns - ------- - `list` of `orion.core.worker.trial.Trial` - List of all trials at the end of the stress test - - """ - cleanup_storage(storage) - - print("Worker | Point") - - with Pool(workers) as p: - results = p.starmap( - worker, - zip( - range(workers), - [storage] * workers, - [space_type] * workers, - [size] * workers, - ), - ) - - assert ( - None not in results - ), "A worker crashed unexpectedly. See logs for the error messages." - assert all(n > 0 for n in results), "A worker could not execute any trial." - - if space_type in ["discrete", "real-seeded"]: - assert sum(results) == size, results - else: - assert sum(results) >= size, results - - experiment = get_experiment(storage, space_type, size) - - trials = experiment.fetch_trials() - - cleanup_storage(storage) - return trials - - -def get_timestamps(trials, size, space_type): - """Get start timestamps of the trials - - Parameters - ---------- - trials: `list` of `orion.core.worker.trial.Trial` - List of all trials at the end of the stress test - space_type: str - See `get_experiment`. - size: int - See `get_experiment`. - - Returns - ------- - (`list`, `list`) - Where rval[0] is start timestamp and rval[1] is the index of the trial. - For instance the i-th trial timestamp is rval[0][rval[1].index(i)]. - - """ - hparams = set() - x = [] - y = [] - - start_time = None - for i, trial in enumerate(trials): - hparams.add(trial.params["x"]) - assert trial.objective.value == trial.params["x"] - if start_time is None: - start_time = trial.submit_time - x.append((trial.submit_time - start_time).total_seconds()) - y.append(i) - - if space_type in ["discrete", "real-seeded"]: - assert len(hparams) == size - else: - assert len(hparams) >= size - - return x[:size], y[:size] - - -def benchmark(workers, size): - """Get start timestamps of the trials - - Parameters - ---------- - workers: int - see: `stress_test`. - size: int - See `get_experiment`. - - Returns - ------- - dict - Dictionary containing all results of all stress tests. - Each key is (backend, space_type). See `get_experiment` for the supported types - of `backend`s and `space_type`s. Each values result[(backend, space_type)] is - in the form of a (x, y) tuple, where x a the list start timestamps and y is the indexes of - the trials. See `get_timestamps` for more details. 
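
stress_test above fans the worker out with multiprocessing; the zip-of-repeated-lists idiom it uses to build per-worker argument tuples is shown standalone below, with a toy function so the sketch is runnable as-is:

    from multiprocessing import Pool

    def work(worker_id, backend, size):
        # Toy stand-in for the stress worker: pretend to execute trials.
        return worker_id + size

    if __name__ == "__main__":
        workers = 4
        args = zip(range(workers), ["pickleddb"] * workers, [10] * workers)
        with Pool(workers) as pool:
            results = pool.starmap(work, args)
        assert None not in results  # a None would mean a crashed worker
        print(results)
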
- - """ - results = {} - for backend in BACKENDS_CONFIGS.keys(): - for space_type in ["discrete", "real", "real-seeded"]: - print(backend, space_type) - - trials = stress_test(backend, space_type, workers, size) - results[(backend, space_type)] = get_timestamps(trials, size, space_type) - - return results - - -def main(): - """Run all stress tests and render the plot""" - size = 500 - - num_workers = [1, 4, 16, 32, 64, 128] - - fig, axis = plt.subplots( - len(num_workers), - 1, - figsize=(5, 1.8 * len(num_workers)), - gridspec_kw={"hspace": 0.01, "wspace": 0}, - sharex="col", - ) - - results = {} - - for i, workers in enumerate(num_workers): - - results[workers] = benchmark(workers, size) - - for backend in BACKENDS_CONFIGS.keys(): - for space_type in ["discrete", "real", "real-seeded"]: - x, y = results[workers][(backend, space_type)] - axis[i].plot(x, y, label=f"{backend}-{space_type}") - - for i, workers in enumerate(num_workers): - # We pick 'pickleddb' and discrete=True as the reference for the slowest ones - x, y = results[min(num_workers)][("pickleddb", "discrete")] - d_x = max(x) - min(x) - d_y = max(y) - min(y) - if i < len(num_workers) - 1: - axis[i].text( - min(x) + d_x * 0.6, min(y) + d_y * 0.1, f"{workers: 3d} workers" - ) - else: - axis[i].text( - min(x) + d_x * 0.6, min(y) + d_y * 0.7, f"{workers: 3d} workers" - ) - - for i in range(len(num_workers) - 1): - axis[i].spines["top"].set_visible(False) - axis[i].spines["right"].set_visible(False) - - axis[-1].spines["right"].set_visible(False) - axis[-1].spines["top"].set_visible(False) - - axis[-1].set_xlabel("Time (s)") - axis[-1].set_ylabel("Number of trials") - axis[-1].legend() - - plt.subplots_adjust(left=0.15, bottom=0.05, top=1, right=1) - - plt.savefig("test.png") - - -if __name__ == "__main__": - main() +#!/usr/bin/env python +"""Perform a stress tests on python API.""" +import logging +import os +import random +import time +import traceback +from collections import OrderedDict +from contextlib import contextmanager +from multiprocessing import Pool + +import matplotlib.pyplot as plt +from pymongo import MongoClient + +from orion.client import create_experiment +from orion.core.io.database import DatabaseTimeout +from orion.core.utils.exceptions import ReservationTimeout + +DB_FILE = "stress.pkl" +SQLITE_FILE = "db.sqlite" + +ADDRESS = "192.168.0.16" + +NUM_TRIALS = 500 + +NUM_WORKERS = [1, 4, 16, 32, 64, 128] + +LOG_LEVEL = logging.WARNING + +# +# Create the stress test user +# +# MongoDB +# +# mongosh +# > use admin +# > db.createUser({ +# user: "user", +# pwd: "pass", +# roles: [ +# {role: 'readWrite', db: 'stress'}, +# ] +# }) +# +# PostgreSQL -- DO NOT USE THIS IN PROD - TESTING ONLY +# +# # Switch to the user running the database +# sudo su postgres +# +# # open an interactive connection to the server +# psql +# > CREATE USER username WITH PASSWORD 'pass'; +# > CREATE ROLE orion_database_admin; +# > CREATE ROLE orion_database_user LOGIN; +# > GRANT orion_database_user, orion_database_user TO username; +# > +# > GRANT pg_write_all_data, pg_read_all_data TO username; +# > CREATE DATABASE stress OWNER orion_database_admin; +# \q +# +# > \l # list all the database +# > \c stress # Use the datatabase +# > select * from experiments; + + +BACKENDS_CONFIGS = OrderedDict( + [ + ( + "pickleddb", + {"type": "legacy", "database": {"type": "pickleddb", "host": DB_FILE}}, + ), + ("sqlite", {"type": "sqlalchemy", "uri": f"sqlite:///{SQLITE_FILE}"}), + # ( + # "postgresql", + # { + # "type": "sqlalchemy", + # "uri": 
f"postgresql://username:pass@{ADDRESS}/stress", + # }, + # ), + # ( + # "mongodb", + # { + # "type": "legacy", + # "database": { + # "type": "mongodb", + # "name": "stress", + # "host": f"mongodb://user:pass@{ADDRESS}", + # }, + # }, + # ), + ] +) + + +def cleanup_storage(backend): + if backend == "pickleddb": + if os.path.exists(DB_FILE): + os.remove(DB_FILE) + + elif backend == "sqlite": + if os.path.exists(SQLITE_FILE): + os.remove(SQLITE_FILE) + + elif backend == "postgresql": + import sqlalchemy + from sqlalchemy.orm import Session + + from orion.storage.sql import get_tables + + engine = sqlalchemy.create_engine( + f"postgresql://username:pass@{ADDRESS}/stress", + echo=True, + future=True, + ) + + # if the tables are missing just skip + for table in get_tables(): + try: + with Session(engine) as session: + session.execute(f"DROP TABLE {table.__tablename__} CASCADE;") + session.commit() + except: + traceback.print_exc() + + elif backend == "mongodb": + client = MongoClient( + host=ADDRESS, username="user", password="pass", authSource="stress" + ) + database = client.stress + database.experiments.drop() + database.lying_trials.drop() + database.trials.drop() + database.workers.drop() + database.resources.drop() + client.close() + + else: + raise RuntimeError("You need to cleam your backend") + + +def f(x, worker): + """Sleep and return objective equal to param""" + time.sleep(max(0, random.gauss(1, 0.2))) + return [dict(name="objective", value=x, type="objective")] + + +def get_experiment(storage, space_type, size): + """Create an experiment or load from DB if already existing + + Parameters + ---------- + storage: str + Can be `pickleddb` or `mongodb`. A default configuration is used for each. + space_type: str + Can be one of + - `discrete` Search space is discrete and limited to `max_trials` + - `real-seeded` Search space is continuous and algos is seeded, leading to many race + conditions while algos are sampling the same points in parallel, or + - `real` Search space is real and algo is not seeded, leading to very few race conditions. + size: int + This defines `max_trials`, and the size of the search space (`uniform(0, size)`). + + """ + storage_config = BACKENDS_CONFIGS[storage] + + discrete = space_type == "discrete" + high = size # * 2 + + return create_experiment( + "stress-test", + space={"x": f"uniform(0, {high}, discrete={discrete})"}, + max_trials=size, + max_idle_time=60 * 5, + algorithms={"random": {"seed": None if space_type == "real" else 1}}, + storage=storage_config, + ) + + +def worker(worker_id, storage, space_type, size): + """Run trials until experiment is done + + Parameters + ---------- + worker_id: int + ID of the worker. This is used to distinguish logs from different workers. + storage: str + See `get_experiment`. + space_type: str + See `get_experiment`. + size: int + See `get_experiment`. + + """ + try: + experiment = get_experiment(storage, space_type, size) + + assert experiment.version == 1, experiment.version + + print(f"{worker_id: 6d} enters") + + num_trials = 0 + while not experiment.is_done: + try: + trial = experiment.suggest() + except ReservationTimeout: + trial - None + + if trial is None: + break + + x = trial.params["x"] + results = f(x, worker_id) + + num_trials += 1 + print(f" - {worker_id: 6d} {num_trials: 6d} {x: 5f}") + experiment.observe(trial, results=results) + + print(f"{worker_id: 6d} leaves | is done? 
{experiment.is_done}") + except DatabaseTimeout as e: + print(f"{worker_id: 6d} timeouts and leaves") + return num_trials + except Exception as e: + print(f"{worker_id: 6d} crashes") + traceback.print_exc() + return None + + return num_trials + + +@contextmanager +def always_clean(storage): + cleanup_storage(storage) + yield + # cleanup_storage(storage) + + +def stress_test(storage, space_type, workers, size): + """Spawn workers and run stress test with verifications + + Parameters + ---------- + storage: str + See `get_experiment`. + space_type: str + See `get_experiment`. + workers: int + Number of workers to run in parallel. + size: int + See `get_experiment`. + + Returns + ------- + `list` of `orion.core.worker.trial.Trial` + List of all trials at the end of the stress test + + """ + print("Worker | Point") + + with Pool(workers) as p: + results = p.starmap( + worker, + zip( + range(workers), + [storage] * workers, + [space_type] * workers, + [size] * workers, + ), + ) + + assert ( + None not in results + ), "A worker crashed unexpectedly. See logs for the error messages." + assert all(n > 0 for n in results), "A worker could not execute any trial." + + if space_type in ["discrete", "real-seeded"]: + assert sum(results) == size, results + else: + assert sum(results) >= size, results + + experiment = get_experiment(storage, space_type, size) + + trials = experiment.fetch_trials() + + return trials + + +def get_timestamps(trials, size, space_type): + """Get start timestamps of the trials + + Parameters + ---------- + trials: `list` of `orion.core.worker.trial.Trial` + List of all trials at the end of the stress test + space_type: str + See `get_experiment`. + size: int + See `get_experiment`. + + Returns + ------- + (`list`, `list`) + Where rval[0] is start timestamp and rval[1] is the index of the trial. + For instance the i-th trial timestamp is rval[0][rval[1].index(i)]. + + """ + hparams = set() + x = [] + y = [] + + start_time = None + for i, trial in enumerate(trials): + hparams.add(trial.params["x"]) + + assert trial.objective.value == trial.params["x"] + + if start_time is None: + start_time = trial.submit_time + + x.append((trial.submit_time - start_time).total_seconds()) + y.append(i) + + if space_type in ["discrete", "real-seeded"]: + assert len(hparams) == size, f"{len(hparams)} == {size}" + else: + assert len(hparams) >= size + + return x[:size], y[:size] + + +def benchmark(workers, size): + """Get start timestamps of the trials + + Parameters + ---------- + workers: int + see: `stress_test`. + size: int + See `get_experiment`. + + Returns + ------- + dict + Dictionary containing all results of all stress tests. + Each key is (backend, space_type). See `get_experiment` for the supported types + of `backend`s and `space_type`s. Each values result[(backend, space_type)] is + in the form of a (x, y) tuple, where x a the list start timestamps and y is the indexes of + the trials. See `get_timestamps` for more details. 
+
+    """
+    results = {}
+    for backend in BACKENDS_CONFIGS.keys():
+        for space_type in ["discrete", "real", "real-seeded"]:
+            print(backend, space_type)
+
+            # Initialize the storage once before parallel work
+            get_experiment(backend, space_type, size)
+
+            with always_clean(backend):
+                trials = stress_test(backend, space_type, workers, size)
+
+                results[(backend, space_type)] = get_timestamps(
+                    trials, size, space_type
+                )
+
+    return results
+
+
+def main():
+    """Run all stress tests and render the plot"""
+    size = NUM_TRIALS
+
+    logging.basicConfig(level=LOG_LEVEL)
+
+    num_workers = NUM_WORKERS
+
+    fig, axis = plt.subplots(
+        len(num_workers),
+        1,
+        figsize=(5, 1.8 * len(num_workers)),
+        gridspec_kw={"hspace": 0.01, "wspace": 0},
+        sharex="col",
+    )
+
+    results = {}
+
+    for i, workers in enumerate(num_workers):
+
+        results[workers] = benchmark(workers, size)
+
+        for backend in BACKENDS_CONFIGS.keys():
+            for space_type in ["discrete", "real", "real-seeded"]:
+                x, y = results[workers][(backend, space_type)]
+                axis[i].plot(x, y, label=f"{backend}-{space_type}")
+
+    for i, workers in enumerate(num_workers):
+        # We pick 'pickleddb' and discrete=True as the reference for the slowest ones
+        x, y = results[min(num_workers)][("pickleddb", "discrete")]
+        d_x = max(x) - min(x)
+        d_y = max(y) - min(y)
+        if i < len(num_workers) - 1:
+            axis[i].text(
+                min(x) + d_x * 0.6, min(y) + d_y * 0.1, f"{workers: 3d} workers"
+            )
+        else:
+            axis[i].text(
+                min(x) + d_x * 0.6, min(y) + d_y * 0.7, f"{workers: 3d} workers"
+            )
+
+    for i in range(len(num_workers) - 1):
+        axis[i].spines["top"].set_visible(False)
+        axis[i].spines["right"].set_visible(False)
+
+    axis[-1].spines["right"].set_visible(False)
+    axis[-1].spines["top"].set_visible(False)
+
+    axis[-1].set_xlabel("Time (s)")
+    axis[-1].set_ylabel("Number of trials")
+    axis[-1].legend()
+
+    plt.subplots_adjust(left=0.15, bottom=0.05, top=1, right=1)
+
+    plt.savefig("test.png")
+
+
+if __name__ == "__main__":
+    main()

From a0bbe077ce78327167549a8645223953f171086c Mon Sep 17 00:00:00 2001
From: Setepenre
Date: Wed, 31 Aug 2022 16:41:52 -0400
Subject: [PATCH 07/25] Fix stress test

---
 src/orion/core/utils/compat.py           |   9 +-
 src/orion/storage/sql.py                 |  44 +++++----
 tests/stress/client/stress_experiment.py | 117 +++++++++++++++++------
 3 files changed, 121 insertions(+), 49 deletions(-)

diff --git a/src/orion/core/utils/compat.py b/src/orion/core/utils/compat.py
index 67770f0ca..f24a7b932 100644
--- a/src/orion/core/utils/compat.py
+++ b/src/orion/core/utils/compat.py
@@ -35,11 +35,16 @@ def get_readline():
 def replace(old, new, tries=3, sleep=0.01):
     """Windows file replacing is more strict than linux"""
     if os.name != "nt":
-        os.replace(old, new)
+        # Rename on UNIX is practically atomic
+        # so we use that
+        os.rename(old, new)
         return
 
+    # Rename raises an exception on Windows if the file exists
+    # so we have to use replace
+    #
     # if the file is open already windows will raise permission error
-    # even if the lock was free
+    # even if the lock was free, waiting a bit usually fixes the issue
     exception = None
     for _ in range(tries):
         try:
diff --git a/src/orion/storage/sql.py b/src/orion/storage/sql.py
index 369b26583..c5dc247da 100644
--- a/src/orion/storage/sql.py
+++ b/src/orion/storage/sql.py
@@ -182,7 +182,12 @@ def __init__(self, uri, token=None, **kwargs):
         )
 
         # Create the schema
-        Base.metadata.create_all(self.engine)
+        # sqlite3 can fail on a table if it already exists
+        # the doc says it shouldn't but it does
+        try:
+            Base.metadata.create_all(self.engine)
+        except DBAPIError: 
+ pass self.token = token self.user_id = None @@ -214,20 +219,23 @@ def _find_user(self, name, token) -> User: return session.execute(stmt).scalar() def _create_user(self, name) -> User: - now = datetime.datetime.utcnow() + try: + now = datetime.datetime.utcnow() - with Session(self.engine) as session: - user = User( - name=name, - token=uuid.uuid5(uuid.NAMESPACE_OID, name).hex, - created_at=now, - last_seen=now, - ) - session.add(user) - session.commit() + with Session(self.engine) as session: + user = User( + name=name, + token=uuid.uuid5(uuid.NAMESPACE_OID, name).hex, + created_at=now, + last_seen=now, + ) + session.add(user) + session.commit() - assert user._id > 0 - return user + assert user._id > 0 + return user + except DBAPIError: + return self._find_user(name, self.token) def __getstate__(self): return dict( @@ -583,9 +591,13 @@ def reserve_trial(self, experiment): now = datetime.datetime.utcnow() with Session(self.engine) as session: - stmt = select(Trial).where( - Trial.status.in_(("interrupted", "new", "suspended")), - Trial.experiment_id == experiment._id, + stmt = ( + select(Trial) + .where( + Trial.status.in_(("interrupted", "new", "suspended")), + Trial.experiment_id == experiment._id, + ) + .limit(1) ) try: diff --git a/tests/stress/client/stress_experiment.py b/tests/stress/client/stress_experiment.py index 37c7d0e74..f2ed7fb35 100644 --- a/tests/stress/client/stress_experiment.py +++ b/tests/stress/client/stress_experiment.py @@ -14,19 +14,30 @@ from orion.client import create_experiment from orion.core.io.database import DatabaseTimeout -from orion.core.utils.exceptions import ReservationTimeout +from orion.core.utils.exceptions import ( + CompletedExperiment, + ReservationRaceCondition, + ReservationTimeout, + WaitingForTrials, +) DB_FILE = "stress.pkl" SQLITE_FILE = "db.sqlite" ADDRESS = "192.168.0.16" -NUM_TRIALS = 500 +NUM_TRIALS = 1000 -NUM_WORKERS = [1, 4, 16, 32, 64, 128] +NUM_WORKERS = [32, 64] LOG_LEVEL = logging.WARNING +SPACE = ["discrete", "real", "real-seeded"] +SPACE = ["real-seeded"] + +# raw_worker or runner_worker +METHOD = "runner_worker" + # # Create the stress test user # @@ -65,11 +76,11 @@ BACKENDS_CONFIGS = OrderedDict( [ + ("sqlite", {"type": "sqlalchemy", "uri": f"sqlite:///{SQLITE_FILE}"}), ( "pickleddb", {"type": "legacy", "database": {"type": "pickleddb", "host": DB_FILE}}, ), - ("sqlite", {"type": "sqlalchemy", "uri": f"sqlite:///{SQLITE_FILE}"}), # ( # "postgresql", # { @@ -138,9 +149,9 @@ def cleanup_storage(backend): raise RuntimeError("You need to cleam your backend") -def f(x, worker): +def f(x, worker=-1): """Sleep and return objective equal to param""" - time.sleep(max(0, random.gauss(1, 0.2))) + time.sleep(max(0, random.gauss(0.1, 1))) return [dict(name="objective", value=x, type="objective")] @@ -164,7 +175,7 @@ def get_experiment(storage, space_type, size): storage_config = BACKENDS_CONFIGS[storage] discrete = space_type == "discrete" - high = size # * 2 + high = size * 2 return create_experiment( "stress-test", @@ -176,7 +187,7 @@ def get_experiment(storage, space_type, size): ) -def worker(worker_id, storage, space_type, size): +def raw_worker(worker_id, storage, space_type, size, pool_size): """Run trials until experiment is done Parameters @@ -201,7 +212,13 @@ def worker(worker_id, storage, space_type, size): num_trials = 0 while not experiment.is_done: try: - trial = experiment.suggest() + trial = experiment.suggest(pool_size=pool_size) + except WaitingForTrials: + continue + except CompletedExperiment: + continue + except 
ReservationRaceCondition: + continue except ReservationTimeout: trial - None @@ -212,15 +229,15 @@ def worker(worker_id, storage, space_type, size): results = f(x, worker_id) num_trials += 1 - print(f" - {worker_id: 6d} {num_trials: 6d} {x: 5f}") + print(f"\r - {worker_id: 6d} {num_trials: 6d} {x: 5.0f}", end="") experiment.observe(trial, results=results) - print(f"{worker_id: 6d} leaves | is done? {experiment.is_done}") + print(f"\n{worker_id: 6d} leaves | is done? {experiment.is_done}") except DatabaseTimeout as e: - print(f"{worker_id: 6d} timeouts and leaves") + print(f"\n{worker_id: 6d} timeouts and leaves") return num_trials except Exception as e: - print(f"{worker_id: 6d} crashes") + print(f"\n{worker_id: 6d} crashes") traceback.print_exc() return None @@ -231,10 +248,10 @@ def worker(worker_id, storage, space_type, size): def always_clean(storage): cleanup_storage(storage) yield - # cleanup_storage(storage) + cleanup_storage(storage) -def stress_test(storage, space_type, workers, size): +def stress_test_raw_worker(storage, space_type, workers, size, pool_size): """Spawn workers and run stress test with verifications Parameters @@ -258,12 +275,13 @@ def stress_test(storage, space_type, workers, size): with Pool(workers) as p: results = p.starmap( - worker, + raw_worker, zip( range(workers), [storage] * workers, [space_type] * workers, [size] * workers, + [pool_size] * workers, ), ) @@ -272,18 +290,43 @@ def stress_test(storage, space_type, workers, size): ), "A worker crashed unexpectedly. See logs for the error messages." assert all(n > 0 for n in results), "A worker could not execute any trial." - if space_type in ["discrete", "real-seeded"]: - assert sum(results) == size, results - else: - assert sum(results) >= size, results + assert sum(results) >= size, f"sum({results}) = {sum(results)} != {size}" experiment = get_experiment(storage, space_type, size) - trials = experiment.fetch_trials() + trials = experiment.fetch_trials_by_status("completed") return trials +def stress_test_runner(storage, space_type, workers, size, pool_size): + """Spawn workers and run stress test with verifications + + Parameters + ---------- + storage: str + See `get_experiment`. + space_type: str + See `get_experiment`. + workers: int + Number of workers to run in parallel. + size: int + See `get_experiment`. 
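
The except-and-continue ladder above treats some reservation failures as transient and only gives up on a timeout. A condensed restatement of that control flow, as a sketch (the exception classes are the ones imported at the top of this file; `objective` stands in for `f`):

    from orion.core.utils.exceptions import (
        CompletedExperiment,
        ReservationRaceCondition,
        ReservationTimeout,
        WaitingForTrials,
    )

    def tolerant_worker(experiment, objective, pool_size):
        # Keep retrying transient reservation failures; only a timeout ends the loop.
        num_trials = 0
        while not experiment.is_done:
            try:
                trial = experiment.suggest(pool_size=pool_size)
            except (WaitingForTrials, CompletedExperiment, ReservationRaceCondition):
                continue  # transient: another worker may free or complete trials
            except ReservationTimeout:
                trial = None
            if trial is None:
                break
            experiment.observe(trial, results=objective(trial.params["x"]))
            num_trials += 1
        return num_trials
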
+ + Returns + ------- + `list` of `orion.core.worker.trial.Trial` + List of all trials at the end of the stress test + + """ + + experiment = get_experiment(storage, space_type, size) + + experiment.workon(fct=f, n_workers=workers, pool_size=pool_size, max_trials=size) + + return experiment.fetch_trials() + + def get_timestamps(trials, size, space_type): """Get start timestamps of the trials @@ -307,10 +350,16 @@ def get_timestamps(trials, size, space_type): x = [] y = [] + empty_trial = [] + start_time = None for i, trial in enumerate(trials): hparams.add(trial.params["x"]) + if trial.objective is None: + empty_trial.append(trial) + continue + assert trial.objective.value == trial.params["x"] if start_time is None: @@ -319,15 +368,12 @@ def get_timestamps(trials, size, space_type): x.append((trial.submit_time - start_time).total_seconds()) y.append(i) - if space_type in ["discrete", "real-seeded"]: - assert len(hparams) == size, f"{len(hparams)} == {size}" - else: - assert len(hparams) >= size - + print(f"Found empty trials {empty_trial}") + assert len(hparams) >= size, f"{len(hparams)} == {size}" return x[:size], y[:size] -def benchmark(workers, size): +def benchmark(workers, size, pool_size): """Get start timestamps of the trials Parameters @@ -348,15 +394,24 @@ def benchmark(workers, size): """ results = {} + + stres_test_method = None + if METHOD == "raw_worker": + stres_test_method = stress_test_raw_worker + else: + stres_test_method = stress_test_runner + for backend in BACKENDS_CONFIGS.keys(): - for space_type in ["discrete", "real", "real-seeded"]: + for space_type in SPACE: print(backend, space_type) # Initialize the storage once before parallel work get_experiment(backend, space_type, size) with always_clean(backend): - trials = stress_test(backend, space_type, workers, size) + trials = stres_test_method( + backend, space_type, workers, size, pool_size + ) results[(backend, space_type)] = get_timestamps( trials, size, space_type @@ -385,10 +440,10 @@ def main(): for i, workers in enumerate(num_workers): - results[workers] = benchmark(workers, size) + results[workers] = benchmark(workers, size, pool_size=workers) for backend in BACKENDS_CONFIGS.keys(): - for space_type in ["discrete", "real", "real-seeded"]: + for space_type in SPACE: x, y = results[workers][(backend, space_type)] axis[i].plot(x, y, label=f"{backend}-{space_type}") From d597140056f123c3af51cbe6751465b225315e14 Mon Sep 17 00:00:00 2001 From: Setepenre Date: Tue, 6 Sep 2022 12:58:51 -0400 Subject: [PATCH 08/25] - --- src/orion/client/runner.py | 998 +++++++++++------------ src/orion/storage/sql.py | 4 +- tests/stress/client/stress_experiment.py | 24 +- 3 files changed, 518 insertions(+), 508 deletions(-) diff --git a/src/orion/client/runner.py b/src/orion/client/runner.py index f25e6f498..f1eb27a3d 100644 --- a/src/orion/client/runner.py +++ b/src/orion/client/runner.py @@ -1,499 +1,499 @@ -# pylint:disable=too-many-arguments -# pylint:disable=too-many-instance-attributes -""" -Runner -====== - -Executes the optimization process -""" -from __future__ import annotations - -import logging -import os -import shutil -import signal -import time -import typing -from contextlib import contextmanager -from dataclasses import dataclass -from typing import Callable - -import orion.core -from orion.core.utils import backward -from orion.core.utils.exceptions import ( - BrokenExperiment, - CompletedExperiment, - InvalidResult, - LazyWorkers, - ReservationRaceCondition, - WaitingForTrials, -) -from orion.core.utils.flatten 
import flatten, unflatten -from orion.core.worker.consumer import ExecutionError -from orion.core.worker.trial import AlreadyReleased -from orion.executor.base import AsyncException, AsyncResult -from orion.storage.base import LockAcquisitionTimeout - -if typing.TYPE_CHECKING: - from orion.client.experiment import ExperimentClient - from orion.core.worker.trial import Trial - -log = logging.getLogger(__name__) - - -class Protected: - """Prevent a signal to be raised during the execution of some code""" - - def __init__(self): - self.signal_received = None - self.handlers = {} - self.start = 0 - self.delayed = 0 - self.signal_installed = False - - def __enter__(self): - """Override the signal handlers with our delayed handler""" - self.signal_received = False - - try: - self.handlers[signal.SIGINT] = signal.signal(signal.SIGINT, self.handler) - self.handlers[signal.SIGTERM] = signal.signal(signal.SIGTERM, self.handler) - self.signal_installed = True - - except ValueError: # ValueError: signal only works in main thread - log.warning( - "SIGINT/SIGTERM protection hooks could not be installed because " - "Runner is executing inside a thread/subprocess, results could get lost " - "on interruptions" - ) - - return self - - def handler(self, sig, frame): - """Register the received signal for later""" - log.warning("Delaying signal %d to finish operations", sig) - log.warning( - "Press CTRL-C again to terminate the program now (You may lose results)" - ) - - self.start = time.time() - - self.signal_received = (sig, frame) - - # if CTRL-C is pressed again the original handlers will handle it - # and make the program stop - self.restore_handlers() - - def restore_handlers(self): - """Restore old signal handlers""" - if not self.signal_installed: - return - - signal.signal(signal.SIGINT, self.handlers[signal.SIGINT]) - signal.signal(signal.SIGTERM, self.handlers[signal.SIGTERM]) - - def stop_now(self): - """Raise the delayed signal if any or restore the old signal handlers""" - - if not self.signal_received: - self.restore_handlers() - - else: - self.delayed = time.time() - self.start - - log.warning("Termination was delayed by %.4f s", self.delayed) - handler = self.handlers[self.signal_received[0]] - - if callable(handler): - handler(*self.signal_received) - - def __exit__(self, *args): - self.stop_now() - - -def _optimize(trial, fct, trial_arg, **kwargs): - """Execute a trial on a worker""" - - kwargs.update(flatten(trial.params)) - - if trial_arg: - kwargs[trial_arg] = trial - - return fct(**unflatten(kwargs)) - - -def delayed_exception(exception: Exception): - """Raise exception when called...""" - raise exception - - -@dataclass -class _Stat: - sample: int = 0 - scatter: int = 0 - gather: int = 0 - - @contextmanager - def time(self, name): - """Measure elapsed time of a given block""" - start = time.time() - yield - total = time.time() - start - - value = getattr(self, name) - setattr(self, name, value + total) - - def report(self): - """Show the elapsed time of different blocks""" - lines = [ - f"Sample {self.sample:7.4f}", - f"Scatter {self.scatter:7.4f}", - f"Gather {self.gather:7.4f}", - ] - return "\n".join(lines) - - -def prepare_trial_working_dir( - experiment_client: ExperimentClient, trial: Trial -) -> None: - """Prepare working directory of a trial. - - This will create a working directory based on ``trial.working_dir`` if not already existing. If - the trial has a parent, the ``working_dir`` of the parent will be copied to the ``working_dir`` - of the current trial. 
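
Protected above defers SIGINT/SIGTERM until the scatter-gather step completes, so observed results are committed before the process dies. The core mechanism, reduced to SIGINT only as a sketch (not the class itself; as noted above, signal handlers can only be installed from the main thread):

    import signal
    from contextlib import contextmanager

    @contextmanager
    def deferred_sigint():
        received = []
        # Remember the signal instead of raising KeyboardInterrupt immediately.
        previous = signal.signal(signal.SIGINT, lambda sig, frame: received.append((sig, frame)))
        try:
            yield
        finally:
            signal.signal(signal.SIGINT, previous)
            if received and callable(previous):
                previous(*received[0])  # replay the delayed signal

    with deferred_sigint():
        pass  # critical section: finish observing trials before interruption
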
- - Parameters - ---------- - experiment_client: orion.client.experiment.ExperimentClient - The experiment client being executed. - trial: orion.core.worker.trial.Trial - The trial that will be executed. - - Raises - ------ - ``ValueError`` - If the parent is not found in the storage of ``experiment_client``. - - """ - backward.ensure_trial_working_dir(experiment_client, trial) - - # TODO: Test that this works when resuming a trial. - if os.path.exists(trial.working_dir): - return - - if trial.parent: - parent_trial = experiment_client.get_trial(uid=trial.parent) - if parent_trial is None: - raise ValueError( - f"Parent id {trial.parent} not available in storage. (From trial {trial.id})" - ) - shutil.copytree(parent_trial.working_dir, trial.working_dir) - else: - os.makedirs(trial.working_dir) - - -class Runner: - """Run the optimization process given the current executor""" - - def __init__( - self, - client: ExperimentClient, - fct: Callable, - pool_size: int, - idle_timeout: int, - max_trials_per_worker: int, - max_broken: int, - trial_arg: str, - on_error: Callable[[ExperimentClient, Exception, int], bool] | None = None, - prepare_trial: Callable[ - [ExperimentClient, Trial], None - ] = prepare_trial_working_dir, - interrupt_signal_code: int | None = None, - gather_timeout: float = 0.01, - n_workers: int | None = None, - **kwargs, - ): - self.client = client - self.fct = fct - self.batch_size = pool_size - self.max_trials_per_worker = max_trials_per_worker - self.max_broken = max_broken - self.trial_arg = trial_arg - self.on_error = on_error - self.prepare_trial = prepare_trial - self.kwargs = kwargs - - self.gather_timeout = gather_timeout - self.idle_timeout = idle_timeout - - self.worker_broken_trials = 0 - self.trials = 0 - self.futures = [] - self.pending_trials = {} - self.stat = _Stat() - self.n_worker_override = n_workers - - if interrupt_signal_code is None: - interrupt_signal_code = orion.core.config.worker.interrupt_signal_code - - self.interrupt_signal_code = interrupt_signal_code - - @property - def free_worker(self): - """Returns the number of free worker""" - n_workers = self.client.executor.n_workers - - if self.n_worker_override is not None: - n_workers = self.n_worker_override - - return max(n_workers - len(self.pending_trials), 0) - - @property - def is_done(self): - """Returns true if the experiment has finished.""" - return self.client.is_done - - @property - def is_broken(self): - """Returns true if the experiment is broken""" - return self.worker_broken_trials >= self.max_broken - - @property - def has_remaining(self) -> bool: - """Returns true if the worker can still pick up work""" - return self.max_trials_per_worker - self.trials > 0 - - @property - def is_idle(self): - """Returns true if none of the workers are running a trial""" - return len(self.pending_trials) <= 0 - - @property - def is_running(self): - """Returns true if we are still running trials.""" - return len(self.pending_trials) > 0 or (self.has_remaining and not self.is_done) - - def run(self): - """Run the optimizing process until completion. 
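
prepare_trial_working_dir above seeds a branched trial from its parent's directory and leaves an existing directory untouched when resuming. The decision tree, as a standalone sketch with plain paths instead of trial objects:

    import os
    import shutil

    def prepare_dir(working_dir, parent_dir=None):
        if os.path.exists(working_dir):
            return  # resuming: keep the existing artifacts
        if parent_dir is not None:
            # Branched trial: start from a copy of the parent's artifacts.
            shutil.copytree(parent_dir, working_dir)
        else:
            os.makedirs(working_dir)
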
- - Returns - ------- - the total number of trials processed - - """ - idle_start = time.time() - idle_end = 0 - idle_time = 0 - - while self.is_running: - try: - - # Protected will prevent Keyboard interrupts from - # happening in the middle of the scatter-gather process - # that we can be sure that completed trials are observed - with Protected(): - - # Get new trials for our free workers - with self.stat.time("sample"): - new_trials = self.sample() - - # Scatter the new trials to our free workers - with self.stat.time("scatter"): - scattered = self.scatter(new_trials) - - # Gather the results of the workers that have finished - with self.stat.time("gather"): - gathered = self.gather() - - if scattered == 0 and gathered == 0 and self.is_idle: - idle_end = time.time() - idle_time += idle_end - idle_start - idle_start = idle_end - - log.debug(f"Workers have been idle for {idle_time:.2f} s") - else: - idle_start = time.time() - idle_time = 0 - - if self.is_idle and idle_time > self.idle_timeout: - msg = f"Workers have been idle for {idle_time:.2f} s" - - if self.has_remaining and not self.is_done: - msg = ( - f"{msg}; worker has leg room (has_remaining: {self.has_remaining})" - f" and optimization is not done (is_done: {self.is_done})" - ) - - raise LazyWorkers(msg) - - except KeyboardInterrupt: - self._release_all() - raise - except: - self._release_all() - raise - - return self.trials - - def should_sample(self): - """Check if more trials could be generated""" - - if self.free_worker <= 0 or (self.is_broken or self.is_done): - return 0 - - pending = len(self.pending_trials) + self.trials - remains = self.max_trials_per_worker - pending - - n_trial = min(self.free_worker, remains) - should_sample_more = self.free_worker > 0 and remains > 0 - - return int(should_sample_more) * n_trial - - def sample(self): - """Sample new trials for all free workers""" - n_trial = self.should_sample() - - if n_trial > 0: - # the producer does the job of limiting the number of new trials - # already no need to worry about it - # NB: suggest reserve the trial already - new_trials = self._suggest_trials(n_trial) - log.debug(f"Sampled {len(new_trials)} new configs") - return new_trials - - return [] - - def scatter(self, new_trials): - """Schedule new trials to be computed""" - new_futures = [] - for trial in new_trials: - try: - self.prepare_trial(self.client, trial) - prepared = True - # pylint:disable=broad-except - except Exception as e: - future = self.client.executor.submit(delayed_exception, e) - prepared = False - - if prepared: - future = self.client.executor.submit( - _optimize, trial, self.fct, self.trial_arg, **self.kwargs - ) - - self.pending_trials[future] = trial - new_futures.append(future) - - self.futures.extend(new_futures) - if new_futures: - log.debug("Scheduled new trials") - return len(new_futures) - - def gather(self): - """Gather the results from each worker asynchronously""" - results = self.client.executor.async_get( - self.futures, timeout=self.gather_timeout - ) - - to_be_raised = None - if results: - log.debug(f"Gathered new results {len(results)}") - # register the results - # NOTE: For Ptera instrumentation - trials = 0 # pylint:disable=unused-variable - for result in results: - trial = self.pending_trials.pop(result.future) - - if isinstance(result, AsyncResult): - try: - # NB: observe release the trial already - self.client.observe(trial, result.value) - self.trials += 1 - # NOTE: For Ptera instrumentation - trials = self.trials # pylint:disable=unused-variable - except 
InvalidResult as exception: - # stop the optimization process if we received `InvalidResult` - # as all the trials are assumed to be returning those - to_be_raised = exception - self.client.release(trial, status="broken") - - if isinstance(result, AsyncException): - if ( - isinstance(result.exception, ExecutionError) - and result.exception.return_code == self.interrupt_signal_code - ): - to_be_raised = KeyboardInterrupt() - self.client.release(trial, status="interrupted") - continue - - # Regular exception, might be caused by the chosen hyperparameters - # themselves rather than the code in particular (like Out of Memory error - # for big batch sizes) - exception = result.exception - self.worker_broken_trials += 1 - self.client.release(trial, status="broken") - - if self.on_error is None or self.on_error( - self, trial, exception, self.worker_broken_trials - ): - log.error(result.traceback) - - else: - log.error(str(exception)) - log.debug(result.traceback) - - # if we receive too many broken trials, it might indicate the user script - # is broken, stop the experiment and let the user investigate - if self.is_broken: - to_be_raised = BrokenExperiment( - "Worker has reached broken trials threshold" - ) - - if to_be_raised is not None: - log.debug("Runner was interrupted") - self._release_all() - raise to_be_raised - - return len(results) - - def _release_all(self): - """Release all the trials that were reserved by this runner. - This is only called during exception handling to avoid retaining trials - that cannot be retrieved anymore - - """ - # Sanity check - for _, trial in self.pending_trials.items(): - try: - self.client.release(trial, status="interrupted") - except AlreadyReleased: - pass - - self.pending_trials = {} - - def _suggest_trials(self, count): - """Suggest a bunch of trials to be dispatched to the workers""" - trials = [] - for _ in range(count): - try: - batch_size = count if self.batch_size == 0 else self.batch_size - trial = self.client.suggest(pool_size=batch_size) - trials.append(trial) - - # non critical errors - except WaitingForTrials: - log.debug("Runner cannot sample because WaitingForTrials") - break - - except ReservationRaceCondition: - log.debug("Runner cannot sample because ReservationRaceCondition") - break - - except LockAcquisitionTimeout: - log.debug("Runner cannot sample because LockAcquisitionTimeout") - break - - except CompletedExperiment: - log.debug("Runner cannot sample because CompletedExperiment") - break - - return trials +# pylint:disable=too-many-arguments +# pylint:disable=too-many-instance-attributes +""" +Runner +====== + +Executes the optimization process +""" +from __future__ import annotations + +import logging +import os +import shutil +import signal +import time +import typing +from contextlib import contextmanager +from dataclasses import dataclass +from typing import Callable + +import orion.core +from orion.core.utils import backward +from orion.core.utils.exceptions import ( + BrokenExperiment, + CompletedExperiment, + InvalidResult, + LazyWorkers, + ReservationRaceCondition, + WaitingForTrials, +) +from orion.core.utils.flatten import flatten, unflatten +from orion.core.worker.consumer import ExecutionError +from orion.core.worker.trial import AlreadyReleased +from orion.executor.base import AsyncException, AsyncResult +from orion.storage.base import LockAcquisitionTimeout + +if typing.TYPE_CHECKING: + from orion.client.experiment import ExperimentClient + from orion.core.worker.trial import Trial + +log = 
logging.getLogger(__name__)
+
+
+class Protected:
+    """Prevent a signal from being raised during the execution of some code"""
+
+    def __init__(self):
+        self.signal_received = None
+        self.handlers = {}
+        self.start = 0
+        self.delayed = 0
+        self.signal_installed = False
+
+    def __enter__(self):
+        """Override the signal handlers with our delayed handler"""
+        self.signal_received = False
+
+        try:
+            self.handlers[signal.SIGINT] = signal.signal(signal.SIGINT, self.handler)
+            self.handlers[signal.SIGTERM] = signal.signal(signal.SIGTERM, self.handler)
+            self.signal_installed = True
+
+        except ValueError:  # ValueError: signal only works in main thread
+            log.warning(
+                "SIGINT/SIGTERM protection hooks could not be installed because "
+                "Runner is executing inside a thread/subprocess, results could be lost "
+                "on interruption"
+            )
+
+        return self
+
+    def handler(self, sig, frame):
+        """Register the received signal for later"""
+        log.warning("Delaying signal %d to finish operations", sig)
+        log.warning(
+            "Press CTRL-C again to terminate the program now (You may lose results)"
+        )
+
+        self.start = time.time()
+
+        self.signal_received = (sig, frame)
+
+        # if CTRL-C is pressed again the original handlers will handle it
+        # and make the program stop
+        self.restore_handlers()
+
+    def restore_handlers(self):
+        """Restore old signal handlers"""
+        if not self.signal_installed:
+            return
+
+        signal.signal(signal.SIGINT, self.handlers[signal.SIGINT])
+        signal.signal(signal.SIGTERM, self.handlers[signal.SIGTERM])
+
+    def stop_now(self):
+        """Raise the delayed signal, if any, or restore the old signal handlers"""
+
+        if not self.signal_received:
+            self.restore_handlers()
+
+        else:
+            self.delayed = time.time() - self.start
+
+            log.warning("Termination was delayed by %.4f s", self.delayed)
+            handler = self.handlers[self.signal_received[0]]
+
+            if callable(handler):
+                handler(*self.signal_received)
+
+    def __exit__(self, *args):
+        self.stop_now()
+
+
+def _optimize(trial, fct, trial_arg, **kwargs):
+    """Execute a trial on a worker"""
+
+    kwargs.update(flatten(trial.params))
+
+    if trial_arg:
+        kwargs[trial_arg] = trial
+
+    return fct(**unflatten(kwargs))
+
+
+def delayed_exception(exception: Exception):
+    """Raise the given exception when called."""
+    raise exception
+
+
+@dataclass
+class _Stat:
+    sample: int = 0
+    scatter: int = 0
+    gather: int = 0
+
+    @contextmanager
+    def time(self, name):
+        """Measure elapsed time of a given block"""
+        start = time.time()
+        yield
+        total = time.time() - start
+
+        value = getattr(self, name)
+        setattr(self, name, value + total)
+
+    def report(self):
+        """Show the elapsed time of different blocks"""
+        lines = [
+            f"Sample {self.sample:7.4f}",
+            f"Scatter {self.scatter:7.4f}",
+            f"Gather {self.gather:7.4f}",
+        ]
+        return "\n".join(lines)
+
+
+def prepare_trial_working_dir(
+    experiment_client: ExperimentClient, trial: Trial
+) -> None:
+    """Prepare working directory of a trial.
+
+    This will create a working directory based on ``trial.working_dir`` if it does not already exist. If
+    the trial has a parent, the ``working_dir`` of the parent will be copied to the ``working_dir``
+    of the current trial.
+
+    Parameters
+    ----------
+    experiment_client: orion.client.experiment.ExperimentClient
+        The experiment client being executed.
+    trial: orion.core.worker.trial.Trial
+        The trial that will be executed.
+
+    Raises
+    ------
+    ``ValueError``
+        If the parent is not found in the storage of ``experiment_client``.
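As an aside, the `_Stat` helper defined above accumulates wall-clock seconds per phase through its `time` context manager; roughly (a minimal sketch, with `sample_trials` as a hypothetical placeholder):

    stat = _Stat()
    with stat.time("sample"):
        sample_trials()      # elapsed seconds are added to stat.sample
    print(stat.report())     # one aligned line per phase: Sample/Scatter/Gather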
+ + """ + backward.ensure_trial_working_dir(experiment_client, trial) + + # TODO: Test that this works when resuming a trial. + if os.path.exists(trial.working_dir): + return + + if trial.parent: + parent_trial = experiment_client.get_trial(uid=trial.parent) + if parent_trial is None: + raise ValueError( + f"Parent id {trial.parent} not available in storage. (From trial {trial.id})" + ) + shutil.copytree(parent_trial.working_dir, trial.working_dir) + else: + os.makedirs(trial.working_dir) + + +class Runner: + """Run the optimization process given the current executor""" + + def __init__( + self, + client: ExperimentClient, + fct: Callable, + pool_size: int, + idle_timeout: int, + max_trials_per_worker: int, + max_broken: int, + trial_arg: str, + on_error: Callable[[ExperimentClient, Exception, int], bool] | None = None, + prepare_trial: Callable[ + [ExperimentClient, Trial], None + ] = prepare_trial_working_dir, + interrupt_signal_code: int | None = None, + gather_timeout: float = 0.01, + n_workers: int | None = None, + **kwargs, + ): + self.client = client + self.fct = fct + self.batch_size = pool_size + self.max_trials_per_worker = max_trials_per_worker + self.max_broken = max_broken + self.trial_arg = trial_arg + self.on_error = on_error + self.prepare_trial = prepare_trial + self.kwargs = kwargs + + self.gather_timeout = gather_timeout + self.idle_timeout = idle_timeout + + self.worker_broken_trials = 0 + self.trials = 0 + self.futures = [] + self.pending_trials = {} + self.stat = _Stat() + self.n_worker_override = n_workers + + if interrupt_signal_code is None: + interrupt_signal_code = orion.core.config.worker.interrupt_signal_code + + self.interrupt_signal_code = interrupt_signal_code + + @property + def free_worker(self): + """Returns the number of free worker""" + n_workers = self.client.executor.n_workers + + if self.n_worker_override is not None: + n_workers = self.n_worker_override + + return max(n_workers - len(self.pending_trials), 0) + + @property + def is_done(self): + """Returns true if the experiment has finished.""" + return self.client.is_done + + @property + def is_broken(self): + """Returns true if the experiment is broken""" + return self.worker_broken_trials >= self.max_broken + + @property + def has_remaining(self) -> bool: + """Returns true if the worker can still pick up work""" + return self.max_trials_per_worker - self.trials > 0 + + @property + def is_idle(self): + """Returns true if none of the workers are running a trial""" + return len(self.pending_trials) <= 0 + + @property + def is_running(self): + """Returns true if we are still running trials.""" + return len(self.pending_trials) > 0 or (self.has_remaining and not self.is_done) + + def run(self): + """Run the optimizing process until completion. 
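Given the constructor arguments above, driving an optimization end to end looks roughly like this (a minimal sketch; `client` and `objective` are hypothetical stand-ins for a real ExperimentClient and objective function):

    runner = Runner(
        client=client,              # an orion.client.experiment.ExperimentClient
        fct=objective,              # called with the trial's params as kwargs
        pool_size=4,                # trials suggested per scatter step
        idle_timeout=60,            # raise LazyWorkers after 60 s of idleness
        max_trials_per_worker=100,
        max_broken=3,
        trial_arg="",               # falsy: the Trial object is not forwarded to fct
    )
    n_trials = runner.run()         # total number of trials processed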
+ + Returns + ------- + the total number of trials processed + + """ + idle_start = time.time() + idle_end = 0 + idle_time = 0 + + while self.is_running: + try: + + # Protected will prevent Keyboard interrupts from + # happening in the middle of the scatter-gather process + # that we can be sure that completed trials are observed + with Protected(): + + # Get new trials for our free workers + with self.stat.time("sample"): + new_trials = self.sample() + + # Scatter the new trials to our free workers + with self.stat.time("scatter"): + scattered = self.scatter(new_trials) + + # Gather the results of the workers that have finished + with self.stat.time("gather"): + gathered = self.gather() + + if scattered == 0 and gathered == 0 and self.is_idle: + idle_end = time.time() + idle_time += idle_end - idle_start + idle_start = idle_end + + log.debug(f"Workers have been idle for {idle_time:.2f} s") + else: + idle_start = time.time() + idle_time = 0 + + if self.is_idle and idle_time > self.idle_timeout: + msg = f"Workers have been idle for {idle_time:.2f} s" + + if self.has_remaining and not self.is_done: + msg = ( + f"{msg}; worker has leg room (has_remaining: {self.has_remaining})" + f" and optimization is not done (is_done: {self.is_done})" + ) + + raise LazyWorkers(msg) + + except KeyboardInterrupt: + self._release_all() + raise + except: + self._release_all() + raise + + return self.trials + + def should_sample(self): + """Check if more trials could be generated""" + + if self.free_worker <= 0 or (self.is_broken or self.is_done): + return 0 + + pending = len(self.pending_trials) + self.trials + remains = self.max_trials_per_worker - pending + + n_trial = min(self.free_worker, remains) + should_sample_more = self.free_worker > 0 and remains > 0 + + return int(should_sample_more) * n_trial + + def sample(self): + """Sample new trials for all free workers""" + n_trial = self.should_sample() + + if n_trial > 0: + # the producer does the job of limiting the number of new trials + # already no need to worry about it + # NB: suggest reserve the trial already + new_trials = self._suggest_trials(n_trial) + log.debug(f"Sampled {len(new_trials)} new configs") + return new_trials + + return [] + + def scatter(self, new_trials): + """Schedule new trials to be computed""" + new_futures = [] + for trial in new_trials: + try: + self.prepare_trial(self.client, trial) + prepared = True + # pylint:disable=broad-except + except Exception as e: + future = self.client.executor.submit(delayed_exception, e) + prepared = False + + if prepared: + future = self.client.executor.submit( + _optimize, trial, self.fct, self.trial_arg, **self.kwargs + ) + + self.pending_trials[future] = trial + new_futures.append(future) + + self.futures.extend(new_futures) + if new_futures: + log.debug("Scheduled new trials") + return len(new_futures) + + def gather(self): + """Gather the results from each worker asynchronously""" + results = self.client.executor.async_get( + self.futures, timeout=self.gather_timeout + ) + + to_be_raised = None + if results: + log.debug(f"Gathered new results {len(results)}") + # register the results + # NOTE: For Ptera instrumentation + trials = 0 # pylint:disable=unused-variable + for result in results: + trial = self.pending_trials.pop(result.future) + + if isinstance(result, AsyncResult): + try: + # NB: observe release the trial already + self.client.observe(trial, result.value) + self.trials += 1 + # NOTE: For Ptera instrumentation + trials = self.trials # pylint:disable=unused-variable + except 
InvalidResult as exception: + # stop the optimization process if we received `InvalidResult` + # as all the trials are assumed to be returning those + to_be_raised = exception + self.client.release(trial, status="broken") + + if isinstance(result, AsyncException): + if ( + isinstance(result.exception, ExecutionError) + and result.exception.return_code == self.interrupt_signal_code + ): + to_be_raised = KeyboardInterrupt() + self.client.release(trial, status="interrupted") + continue + + # Regular exception, might be caused by the chosen hyperparameters + # themselves rather than the code in particular (like Out of Memory error + # for big batch sizes) + exception = result.exception + self.worker_broken_trials += 1 + self.client.release(trial, status="broken") + + if self.on_error is None or self.on_error( + self, trial, exception, self.worker_broken_trials + ): + log.error(result.traceback) + + else: + log.error(str(exception)) + log.debug(result.traceback) + + # if we receive too many broken trials, it might indicate the user script + # is broken, stop the experiment and let the user investigate + if self.is_broken: + to_be_raised = BrokenExperiment( + "Worker has reached broken trials threshold" + ) + + if to_be_raised is not None: + log.debug("Runner was interrupted") + self._release_all() + raise to_be_raised + + return len(results) + + def _release_all(self): + """Release all the trials that were reserved by this runner. + This is only called during exception handling to avoid retaining trials + that cannot be retrieved anymore + + """ + # Sanity check + for _, trial in self.pending_trials.items(): + try: + self.client.release(trial, status="interrupted") + except AlreadyReleased: + pass + + self.pending_trials = {} + + def _suggest_trials(self, count): + """Suggest a bunch of trials to be dispatched to the workers""" + trials = [] + for _ in range(count): + try: + batch_size = count if self.batch_size == 0 else self.batch_size + trial = self.client.suggest(pool_size=batch_size) + trials.append(trial) + + # non critical errors + except WaitingForTrials: + log.debug("Runner cannot sample because WaitingForTrials") + break + + except ReservationRaceCondition: + log.debug("Runner cannot sample because ReservationRaceCondition") + break + + except LockAcquisitionTimeout: + log.debug("Runner cannot sample because LockAcquisitionTimeout") + break + + except CompletedExperiment: + log.debug("Runner cannot sample because CompletedExperiment") + break + + return trials diff --git a/src/orion/storage/sql.py b/src/orion/storage/sql.py index c5dc247da..617236a62 100644 --- a/src/orion/storage/sql.py +++ b/src/orion/storage/sql.py @@ -478,7 +478,7 @@ def update_trial( self._set_from_dict(trial, kwargs) session.commit() - return OrionTrial(*self._to_trial(trial)) + return OrionTrial(**self._to_trial(trial)) def fetch_lost_trials(self, experiment): """See :func:`orion.storage.base.BaseStorageProtocol.fetch_lost_trials`""" @@ -495,7 +495,7 @@ def fetch_lost_trials(self, experiment): ) results = session.scalars(stmt).all() - return [OrionTrial(*self._to_trial(t)) for t in results] + return [OrionTrial(**self._to_trial(t)) for t in results] def push_trial_results(self, trial): """See :func:`orion.storage.base.BaseStorageProtocol.push_trial_results`""" diff --git a/tests/stress/client/stress_experiment.py b/tests/stress/client/stress_experiment.py index f2ed7fb35..f31d4472f 100644 --- a/tests/stress/client/stress_experiment.py +++ b/tests/stress/client/stress_experiment.py @@ -1,6 +1,7 @@ #!/usr/bin/env 
python """Perform a stress tests on python API.""" import logging +import multiprocessing import os import random import time @@ -26,14 +27,17 @@ ADDRESS = "192.168.0.16" -NUM_TRIALS = 1000 +NUM_TRIALS = 500 -NUM_WORKERS = [32, 64] +NUM_WORKERS = [1, 16, 32, 64] + +# int or 'workers +POOL_SIZE = 0 LOG_LEVEL = logging.WARNING -SPACE = ["discrete", "real", "real-seeded"] -SPACE = ["real-seeded"] +# SPACE = ["discrete", "real", "real-seeded"] +SPACE = ["discrete"] # raw_worker or runner_worker METHOD = "runner_worker" @@ -151,7 +155,10 @@ def cleanup_storage(backend): def f(x, worker=-1): """Sleep and return objective equal to param""" - time.sleep(max(0, random.gauss(0.1, 1))) + time.sleep(max(0, random.gauss(1, 0.2))) + + print(f'\r {x:5.2f}', end='') + return [dict(name="objective", value=x, type="objective")] @@ -175,7 +182,7 @@ def get_experiment(storage, space_type, size): storage_config = BACKENDS_CONFIGS[storage] discrete = space_type == "discrete" - high = size * 2 + high = size # * 2 return create_experiment( "stress-test", @@ -439,8 +446,11 @@ def main(): results = {} for i, workers in enumerate(num_workers): + pool_size = POOL_SIZE + if POOL_SIZE == 'worker': + pool_size = workers - results[workers] = benchmark(workers, size, pool_size=workers) + results[workers] = benchmark(workers, size, pool_size=pool_size) for backend in BACKENDS_CONFIGS.keys(): for space_type in SPACE: From 26fd7ce922adda3438a21c15820244a5c9d4fc33 Mon Sep 17 00:00:00 2001 From: Pierre Delaunay Date: Mon, 21 Nov 2022 13:53:59 -0500 Subject: [PATCH 09/25] - --- src/orion/core/io/experiment_builder.py | 2 +- src/orion/storage/sql.py | 96 ++++++++++++------------ tests/stress/client/stress_experiment.py | 7 +- 3 files changed, 52 insertions(+), 53 deletions(-) diff --git a/src/orion/core/io/experiment_builder.py b/src/orion/core/io/experiment_builder.py index 999eb1a65..696521055 100644 --- a/src/orion/core/io/experiment_builder.py +++ b/src/orion/core/io/experiment_builder.py @@ -94,8 +94,8 @@ from orion.core.io.experiment_branch_builder import ExperimentBranchBuilder from orion.core.io.interactive_commands.branching_prompt import BranchingPrompt from orion.core.io.space_builder import SpaceBuilder -from orion.core.utils.compat import getuser from orion.core.utils import backward +from orion.core.utils.compat import getuser from orion.core.utils.exceptions import ( BranchingEvent, NoConfigurationError, diff --git a/src/orion/storage/sql.py b/src/orion/storage/sql.py index 617236a62..e8d549ce0 100644 --- a/src/orion/storage/sql.py +++ b/src/orion/storage/sql.py @@ -54,87 +54,87 @@ def compile_binary_postgresql(type_, compiler, **kw): return "BYTEA" -# fmt: off class User(Base): """Defines the User table""" + __tablename__ = "users" - _id = Column(Integer, primary_key=True, autoincrement=True) - name = Column(String(30), unique=True) - token = Column(String(32)) - created_at = Column(DateTime) - last_seen = Column(DateTime) + _id = Column(Integer, primary_key=True, autoincrement=True) + name = Column(String(30), unique=True) + token = Column(String(32)) + created_at = Column(DateTime) + last_seen = Column(DateTime) class Experiment(Base): """Defines the Experiment table""" + __tablename__ = "experiments" - _id = Column(Integer, primary_key=True, autoincrement=True) - name = Column(String(30)) - meta = Column(JSON) # metadata field is reserved - version = Column(Integer) - owner_id = Column(Integer, ForeignKey("users._id"), nullable=False) - datetime = Column(DateTime) - algorithms = Column(JSON) - remaining = 
Column(JSON) - space = Column(JSON) - parent_id = Column(Integer) + _id = Column(Integer, primary_key=True, autoincrement=True) + name = Column(String(30)) + meta = Column(JSON) # metadata field is reserved + version = Column(Integer) + owner_id = Column(Integer, ForeignKey("users._id"), nullable=False) + datetime = Column(DateTime) + algorithms = Column(JSON) + remaining = Column(JSON) + space = Column(JSON) + parent_id = Column(Integer) __table_args__ = ( - UniqueConstraint('name', 'owner_id', name='_one_name_per_owner'), - Index('idx_experiment_name_version', 'name', 'version'), + UniqueConstraint("name", "owner_id", name="_one_name_per_owner"), + Index("idx_experiment_name_version", "name", "version"), ) class Trial(Base): """Defines the Trial table""" + __tablename__ = "trials" - _id = Column(Integer, primary_key=True, autoincrement=True) - experiment_id = Column(Integer, ForeignKey("experiments._id"), nullable=False) - owner_id = Column(Integer, ForeignKey("users._id"), nullable=False) - status = Column(String(30)) - results = Column(JSON) - start_time = Column(DateTime) - end_time = Column(DateTime) - heartbeat = Column(DateTime) - parent = Column(Integer, ForeignKey("trials._id"), nullable=True) - params = Column(JSON) - worker = Column(JSON) - submit_time = Column(DateTime) + _id = Column(Integer, primary_key=True, autoincrement=True) + experiment_id = Column(Integer, ForeignKey("experiments._id"), nullable=False) + owner_id = Column(Integer, ForeignKey("users._id"), nullable=False) + status = Column(String(30)) + results = Column(JSON) + start_time = Column(DateTime) + end_time = Column(DateTime) + heartbeat = Column(DateTime) + parent = Column(Integer, ForeignKey("trials._id"), nullable=True) + params = Column(JSON) + worker = Column(JSON) + submit_time = Column(DateTime) exp_working_dir = Column(String(30)) - id = Column(String(30)) + id = Column(String(30)) __table_args__ = ( - UniqueConstraint('experiment_id', 'id', name='_one_trial_hash_per_experiment'), - Index('idx_trial_experiment_id', 'experiment_id'), - Index('idx_trial_status', 'status'), + UniqueConstraint("experiment_id", "id", name="_one_trial_hash_per_experiment"), + Index("idx_trial_experiment_id", "experiment_id"), + Index("idx_trial_status", "status"), # Can't put an index on json # Index('idx_trial_results', 'results'), - Index('idx_trial_start_time', 'start_time'), - Index('idx_trial_end_time', 'end_time'), + Index("idx_trial_start_time", "start_time"), + Index("idx_trial_end_time", "end_time"), ) class Algo(Base): """Defines the Algo table""" + __tablename__ = "algo" # it is one algo per experiment so we could set experiment_id as the primary key # and make it a 1-1 relation - _id = Column(Integer, primary_key=True, autoincrement=True) - experiment_id = Column(Integer, ForeignKey("experiments._id"), nullable=False) - owner_id = Column(Integer, ForeignKey("users._id"), nullable=False) - configuration = Column(JSON) - locked = Column(Integer) - state = Column(BINARY) - heartbeat = Column(DateTime) + _id = Column(Integer, primary_key=True, autoincrement=True) + experiment_id = Column(Integer, ForeignKey("experiments._id"), nullable=False) + owner_id = Column(Integer, ForeignKey("users._id"), nullable=False) + configuration = Column(JSON) + locked = Column(Integer) + state = Column(BINARY) + heartbeat = Column(DateTime) - __table_args__ = ( - Index('idx_algo_experiment_id', 'experiment_id'), - ) -# fmt: on + __table_args__ = (Index("idx_algo_experiment_id", "experiment_id"),) def get_tables(): @@ -183,7 +183,7 @@ 
def __init__(self, uri, token=None, **kwargs): # Create the schema # sqlite3 can fail on table if it already exist - # the doc says it shouldnt but it does + # the doc says it shouldn't but it does try: Base.metadata.create_all(self.engine) except DBAPIError: diff --git a/tests/stress/client/stress_experiment.py b/tests/stress/client/stress_experiment.py index f31d4472f..f1f8c8d32 100644 --- a/tests/stress/client/stress_experiment.py +++ b/tests/stress/client/stress_experiment.py @@ -1,7 +1,6 @@ #!/usr/bin/env python """Perform a stress tests on python API.""" import logging -import multiprocessing import os import random import time @@ -157,7 +156,7 @@ def f(x, worker=-1): """Sleep and return objective equal to param""" time.sleep(max(0, random.gauss(1, 0.2))) - print(f'\r {x:5.2f}', end='') + print(f"\r {x:5.2f}", end="") return [dict(name="objective", value=x, type="objective")] @@ -182,7 +181,7 @@ def get_experiment(storage, space_type, size): storage_config = BACKENDS_CONFIGS[storage] discrete = space_type == "discrete" - high = size # * 2 + high = size # * 2 return create_experiment( "stress-test", @@ -447,7 +446,7 @@ def main(): for i, workers in enumerate(num_workers): pool_size = POOL_SIZE - if POOL_SIZE == 'worker': + if POOL_SIZE == "worker": pool_size = workers results[workers] = benchmark(workers, size, pool_size=pool_size) From 0f28bc94baa5c7bbe99deec2ab0125abd059eb2e Mon Sep 17 00:00:00 2001 From: Pierre Delaunay Date: Mon, 21 Nov 2022 13:58:48 -0500 Subject: [PATCH 10/25] - --- src/orion/core/utils/compat.py | 3 +++ tests/conftest.py | 1 - 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/orion/core/utils/compat.py b/src/orion/core/utils/compat.py index f24a7b932..6a8bb1454 100644 --- a/src/orion/core/utils/compat.py +++ b/src/orion/core/utils/compat.py @@ -3,6 +3,7 @@ import time +# pylint: no-method-argument def getuser(): """getpass use pwd which is UNIX only""" @@ -14,6 +15,8 @@ def getuser(): return getpass.getuser() + +# pylint: too-few-public-methods class _readline: def set_completer_delims(*args, **kwargs): """Fake method for windows""" diff --git a/tests/conftest.py b/tests/conftest.py index 2c4058f0c..dd9fa68bf 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,7 +2,6 @@ """Common fixtures and utils for unittests and functional tests.""" from __future__ import annotations -import getpass import os from typing import Any From 6fd8641ec20e20a709f27c8e8cde62434dc5a6f1 Mon Sep 17 00:00:00 2001 From: Pierre Delaunay Date: Mon, 21 Nov 2022 14:33:27 -0500 Subject: [PATCH 11/25] - --- src/orion/core/utils/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/orion/core/utils/compat.py b/src/orion/core/utils/compat.py index 6a8bb1454..92d57f5bc 100644 --- a/src/orion/core/utils/compat.py +++ b/src/orion/core/utils/compat.py @@ -3,7 +3,6 @@ import time -# pylint: no-method-argument def getuser(): """getpass use pwd which is UNIX only""" @@ -18,6 +17,7 @@ def getuser(): # pylint: too-few-public-methods class _readline: + # pylint: no-method-argument def set_completer_delims(*args, **kwargs): """Fake method for windows""" From 3d2fe01abced5cece2eea3d12cdbf9518a6a430e Mon Sep 17 00:00:00 2001 From: Pierre Delaunay Date: Mon, 21 Nov 2022 14:35:30 -0500 Subject: [PATCH 12/25] - --- src/orion/core/utils/compat.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/orion/core/utils/compat.py b/src/orion/core/utils/compat.py index 92d57f5bc..21912b247 100644 --- a/src/orion/core/utils/compat.py +++ 
b/src/orion/core/utils/compat.py @@ -15,9 +15,9 @@ def getuser(): -# pylint: too-few-public-methods +# pylint: disable=too-few-public-methods class _readline: - # pylint: no-method-argument + # pylint: disable=no-method-argument def set_completer_delims(*args, **kwargs): """Fake method for windows""" From 8cdae1de8ef5ad17dc28a889ea1a9ae977f0d159 Mon Sep 17 00:00:00 2001 From: Pierre Delaunay Date: Mon, 21 Nov 2022 14:47:39 -0500 Subject: [PATCH 13/25] - --- docs/src/conf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/conf.py b/docs/src/conf.py index c06b60a0a..ea48f9790 100644 --- a/docs/src/conf.py +++ b/docs/src/conf.py @@ -213,7 +213,7 @@ # -- Autodoc configuration ----------------------------------------------- -autodoc_mock_imports = ["_version", "utils._appdirs", "nevergrad", "torch"] +autodoc_mock_imports = ["_version", "utils._appdirs", "nevergrad", "torch", "sqlalchemy"] # -- Gallery configuration ----------------------------------------------- From 8e1f796f4ee82a38f507e691871726474672fdd1 Mon Sep 17 00:00:00 2001 From: Pierre Delaunay Date: Tue, 22 Nov 2022 11:33:11 -0500 Subject: [PATCH 14/25] - --- docs/src/conf.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/docs/src/conf.py b/docs/src/conf.py index ea48f9790..f0fd9f812 100644 --- a/docs/src/conf.py +++ b/docs/src/conf.py @@ -213,7 +213,13 @@ # -- Autodoc configuration ----------------------------------------------- -autodoc_mock_imports = ["_version", "utils._appdirs", "nevergrad", "torch", "sqlalchemy"] +autodoc_mock_imports = [ + "_version", + "utils._appdirs", + "nevergrad", + "torch", + "sqlalchemy", +] # -- Gallery configuration ----------------------------------------------- From 76ae866e881a3673f383677e138303ff2fbba83d Mon Sep 17 00:00:00 2001 From: Pierre Delaunay Date: Tue, 22 Nov 2022 11:34:10 -0500 Subject: [PATCH 15/25] - --- src/orion/core/utils/compat.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/orion/core/utils/compat.py b/src/orion/core/utils/compat.py index 21912b247..9c8e16045 100644 --- a/src/orion/core/utils/compat.py +++ b/src/orion/core/utils/compat.py @@ -14,7 +14,6 @@ def getuser(): return getpass.getuser() - # pylint: disable=too-few-public-methods class _readline: # pylint: disable=no-method-argument From 8772d1c80b67862a1e4bd08633159b9249b8423c Mon Sep 17 00:00:00 2001 From: Pierre Delaunay Date: Tue, 22 Nov 2022 11:55:54 -0500 Subject: [PATCH 16/25] - --- src/orion/storage/sql.py | 1440 +++++++++++++++++++------------------- 1 file changed, 726 insertions(+), 714 deletions(-) diff --git a/src/orion/storage/sql.py b/src/orion/storage/sql.py index e8d549ce0..24c676f99 100644 --- a/src/orion/storage/sql.py +++ b/src/orion/storage/sql.py @@ -6,29 +6,34 @@ import uuid from copy import deepcopy -import sqlalchemy - # Use MongoDB json serializer from bson.json_util import dumps as to_json from bson.json_util import loads as from_json -from sqlalchemy import ( - BINARY, - JSON, - Column, - DateTime, - ForeignKey, - Index, - Integer, - String, - UniqueConstraint, - delete, - select, - update, -) -from sqlalchemy.exc import DBAPIError, NoResultFound -from sqlalchemy.ext.compiler import compiles -from sqlalchemy.orm import Session, declarative_base +IMPORT_ERROR = None +try: + from sqlalchemy import ( + BINARY, + JSON, + Column, + DateTime, + ForeignKey, + Index, + Integer, + String, + UniqueConstraint, + delete, + select, + update, + ) + import sqlalchemy + from sqlalchemy.exc import DBAPIError, NoResultFound + from sqlalchemy.ext.compiler 
import compiles + from sqlalchemy.orm import Session, declarative_base + +except ImportError as err: + IMPORT_ERROR = err + import orion.core from orion.core.io.database import DuplicateKeyError from orion.core.utils.compat import getuser @@ -43,880 +48,887 @@ get_uid, ) -log = logging.getLogger(__name__) -Base = declarative_base() +if IMPORT_ERROR is not None: + class SQLAlchemy(BaseStorageProtocol): + def __init__(self, uri, token=None, **kwargs): + raise IMPORT_ERROR +else: + log = logging.getLogger(__name__) -@compiles(BINARY, "postgresql") -def compile_binary_postgresql(type_, compiler, **kw): - """Postgresql does not know about Binary type we should byte array instead""" - return "BYTEA" + Base = declarative_base() -class User(Base): - """Defines the User table""" + @compiles(BINARY, "postgresql") + def compile_binary_postgresql(type_, compiler, **kw): + """Postgresql does not know about Binary type we should byte array instead""" + return "BYTEA" - __tablename__ = "users" - _id = Column(Integer, primary_key=True, autoincrement=True) - name = Column(String(30), unique=True) - token = Column(String(32)) - created_at = Column(DateTime) - last_seen = Column(DateTime) + class User(Base): + """Defines the User table""" + __tablename__ = "users" -class Experiment(Base): - """Defines the Experiment table""" + _id = Column(Integer, primary_key=True, autoincrement=True) + name = Column(String(30), unique=True) + token = Column(String(32)) + created_at = Column(DateTime) + last_seen = Column(DateTime) - __tablename__ = "experiments" - _id = Column(Integer, primary_key=True, autoincrement=True) - name = Column(String(30)) - meta = Column(JSON) # metadata field is reserved - version = Column(Integer) - owner_id = Column(Integer, ForeignKey("users._id"), nullable=False) - datetime = Column(DateTime) - algorithms = Column(JSON) - remaining = Column(JSON) - space = Column(JSON) - parent_id = Column(Integer) + class Experiment(Base): + """Defines the Experiment table""" - __table_args__ = ( - UniqueConstraint("name", "owner_id", name="_one_name_per_owner"), - Index("idx_experiment_name_version", "name", "version"), - ) + __tablename__ = "experiments" + _id = Column(Integer, primary_key=True, autoincrement=True) + name = Column(String(30)) + meta = Column(JSON) # metadata field is reserved + version = Column(Integer) + owner_id = Column(Integer, ForeignKey("users._id"), nullable=False) + datetime = Column(DateTime) + algorithms = Column(JSON) + remaining = Column(JSON) + space = Column(JSON) + parent_id = Column(Integer) -class Trial(Base): - """Defines the Trial table""" - - __tablename__ = "trials" - - _id = Column(Integer, primary_key=True, autoincrement=True) - experiment_id = Column(Integer, ForeignKey("experiments._id"), nullable=False) - owner_id = Column(Integer, ForeignKey("users._id"), nullable=False) - status = Column(String(30)) - results = Column(JSON) - start_time = Column(DateTime) - end_time = Column(DateTime) - heartbeat = Column(DateTime) - parent = Column(Integer, ForeignKey("trials._id"), nullable=True) - params = Column(JSON) - worker = Column(JSON) - submit_time = Column(DateTime) - exp_working_dir = Column(String(30)) - id = Column(String(30)) - - __table_args__ = ( - UniqueConstraint("experiment_id", "id", name="_one_trial_hash_per_experiment"), - Index("idx_trial_experiment_id", "experiment_id"), - Index("idx_trial_status", "status"), - # Can't put an index on json - # Index('idx_trial_results', 'results'), - Index("idx_trial_start_time", "start_time"), - 
Index("idx_trial_end_time", "end_time"), - ) + __table_args__ = ( + UniqueConstraint("name", "owner_id", name="_one_name_per_owner"), + Index("idx_experiment_name_version", "name", "version"), + ) -class Algo(Base): - """Defines the Algo table""" - - __tablename__ = "algo" - - # it is one algo per experiment so we could set experiment_id as the primary key - # and make it a 1-1 relation - _id = Column(Integer, primary_key=True, autoincrement=True) - experiment_id = Column(Integer, ForeignKey("experiments._id"), nullable=False) - owner_id = Column(Integer, ForeignKey("users._id"), nullable=False) - configuration = Column(JSON) - locked = Column(Integer) - state = Column(BINARY) - heartbeat = Column(DateTime) - - __table_args__ = (Index("idx_algo_experiment_id", "experiment_id"),) - - -def get_tables(): - return [User, Experiment, Trial, Algo, User] - - -class SQLAlchemy(BaseStorageProtocol): # noqa: F811 - """Implement a generic protocol to allow Orion to communicate using - different storage backend - - Parameters - ---------- - uri: str - PostgreSQL backend to use for storage; the format is as follow - `protocol://[username:password@]host1[:port1][,...hostN[:portN]]][/[database][?options]]` - - """ - - def __init__(self, uri, token=None, **kwargs): - # dialect+driver://username:password@host:port/database - # - # postgresql://scott:tiger@localhost/mydatabase - # postgresql+psycopg2://scott:tiger@localhost/mydatabase - # postgresql+pg8000://scott:tiger@localhost/mydatabase - # - # mysql://scott:tiger@localhost/foo - # mysql+mysqldb://scott:tiger@localhost/foo - # mysql+pymysql://scott:tiger@localhost/foo - # - # sqlite:///foo.db # relative - # sqlite:////foo.db # absolute - # sqlite:// # in memory - - self.uri = uri - if uri == "": - uri = "sqlite://" - - # engine_from_config - self.engine = sqlalchemy.create_engine( - uri, - echo=False, - future=True, - json_serializer=to_json, - json_deserializer=from_json, + class Trial(Base): + """Defines the Trial table""" + + __tablename__ = "trials" + + _id = Column(Integer, primary_key=True, autoincrement=True) + experiment_id = Column(Integer, ForeignKey("experiments._id"), nullable=False) + owner_id = Column(Integer, ForeignKey("users._id"), nullable=False) + status = Column(String(30)) + results = Column(JSON) + start_time = Column(DateTime) + end_time = Column(DateTime) + heartbeat = Column(DateTime) + parent = Column(Integer, ForeignKey("trials._id"), nullable=True) + params = Column(JSON) + worker = Column(JSON) + submit_time = Column(DateTime) + exp_working_dir = Column(String(30)) + id = Column(String(30)) + + __table_args__ = ( + UniqueConstraint("experiment_id", "id", name="_one_trial_hash_per_experiment"), + Index("idx_trial_experiment_id", "experiment_id"), + Index("idx_trial_status", "status"), + # Can't put an index on json + # Index('idx_trial_results', 'results'), + Index("idx_trial_start_time", "start_time"), + Index("idx_trial_end_time", "end_time"), ) - # Create the schema - # sqlite3 can fail on table if it already exist - # the doc says it shouldn't but it does - try: - Base.metadata.create_all(self.engine) - except DBAPIError: - pass - self.token = token - self.user_id = None - self.user = None - self._connect(token) - - def _connect(self, token): - name = getuser() + class Algo(Base): + """Defines the Algo table""" + + __tablename__ = "algo" + + # it is one algo per experiment so we could set experiment_id as the primary key + # and make it a 1-1 relation + _id = Column(Integer, primary_key=True, autoincrement=True) + 
experiment_id = Column(Integer, ForeignKey("experiments._id"), nullable=False)
+        owner_id = Column(Integer, ForeignKey("users._id"), nullable=False)
+        configuration = Column(JSON)
+        locked = Column(Integer)
+        state = Column(BINARY)
+        heartbeat = Column(DateTime)
+
+        __table_args__ = (Index("idx_algo_experiment_id", "experiment_id"),)
+
+
+    def get_tables():
+        return [User, Experiment, Trial, Algo]
+
+
+    class SQLAlchemy(BaseStorageProtocol):  # noqa: F811
+        """Implement a generic protocol to allow Orion to communicate using
+        different storage backends
+
+        Parameters
+        ----------
+        uri: str
+            PostgreSQL backend to use for storage; the format is as follows
+            `protocol://[username:password@]host1[:port1][,...hostN[:portN]]][/[database][?options]]`
+
+        """
+
+        def __init__(self, uri, token=None, **kwargs):
+            # dialect+driver://username:password@host:port/database
+            #
+            # postgresql://scott:tiger@localhost/mydatabase
+            # postgresql+psycopg2://scott:tiger@localhost/mydatabase
+            # postgresql+pg8000://scott:tiger@localhost/mydatabase
+            #
+            # mysql://scott:tiger@localhost/foo
+            # mysql+mysqldb://scott:tiger@localhost/foo
+            # mysql+pymysql://scott:tiger@localhost/foo
+            #
+            # sqlite:///foo.db  # relative
+            # sqlite:////foo.db # absolute
+            # sqlite://         # in memory
+
+            self.uri = uri
+            if uri == "":
+                uri = "sqlite://"
+
+            # engine_from_config
+            self.engine = sqlalchemy.create_engine(
+                uri,
+                echo=False,
+                future=True,
+                json_serializer=to_json,
+                json_deserializer=from_json,
+            )

-        # Create the schema
-        # sqlite3 can fail on table if it already exist
-        # the doc says it shouldn't but it does
-        try:
-            Base.metadata.create_all(self.engine)
-        except DBAPIError:
-            pass
+            # Create the schema
+            # sqlite3 can fail on a table if it already exists
+            # the doc says it shouldn't but it does
+            try:

-        self.token = token
-        self.user_id = None
-        self.user = None
-        self._connect(token)
-
-    def _connect(self, token):
-        name = getuser()
+                Base.metadata.create_all(self.engine)
+            except DBAPIError:
+                pass

-        user = self._find_user(name, token)
+            self.token = token
+            self.user_id = None
+            self.user = None
+            self._connect(token)

-        if user is None:
-            user = self._create_user(name)
+        def _connect(self, token):
+            name = getuser()

-        assert user is not None
+            user = self._find_user(name, token)

-        self.user_id = user._id
-        self.user = user
-        self.token = user.token
+            if user is None:
+                user = self._create_user(name)

-    def _find_user(self, name, token) -> User:
-        query = [User.name == name]
-        if token is not None and token != "":
-            query.append(User.token == token)
+            assert user is not None

-        with Session(self.engine) as session:
-            stmt = select(User).where(*query)
+            self.user_id = user._id
+            self.user = user
+            self.token = user.token

-        return session.execute(stmt).scalar()
+        def _find_user(self, name, token) -> User:
+            query = [User.name == name]
+            if token is not None and token != "":
+                query.append(User.token == token)

-    def _create_user(self, name) -> User:
-        try:
-            now = datetime.datetime.utcnow()
+            with Session(self.engine) as session:
+                stmt = select(User).where(*query)

-        with Session(self.engine) as session:
-            user = User(
-                name=name,
-                token=uuid.uuid5(uuid.NAMESPACE_OID, name).hex,
-                created_at=now,
-                last_seen=now,
-            )
-            session.add(user)
-            session.commit()
+                return session.execute(stmt).scalar()

-        assert user._id > 0
-        return user
-        except DBAPIError:
-            return self._find_user(name, self.token)
+        def _create_user(self, name) -> User:
+            try:
+                now = datetime.datetime.utcnow()

-    def __getstate__(self):
-        return dict(
-            uri=self.uri,
-            token=self.token,
-        )
+                with Session(self.engine) as session:
+                    user = User(
+                        name=name,
+                        token=uuid.uuid5(uuid.NAMESPACE_OID, name).hex,
+                        created_at=now,
+                        last_seen=now,
+                    )
+                    session.add(user)
+                    session.commit()
+
+                assert user._id > 0
+                return user
+            except DBAPIError:
+                return self._find_user(name, self.token)
+
+        def __getstate__(self):
+            return dict(
+                uri=self.uri,
+                token=self.token,
+            )

-    def __setstate__(self, state):
-        self.uri = state["uri"]
-        self.token = state["token"]
-        self.engine = sqlalchemy.create_engine(self.uri, echo=True, future=True)
+        def __setstate__(self, state):
+            self.uri = state["uri"]
+            self.token = state["token"]
+            self.engine = sqlalchemy.create_engine(self.uri, echo=True, future=True)

-        if self.uri == "sqlite://" or self.uri == "":
-            log.warning("You are serializing an in-memory database, data will be lost")
-            Base.metadata.create_all(self.engine)
+            if self.uri == "sqlite://" or self.uri == "":
+                log.warning("You are serializing an in-memory database, data will be lost")
+                Base.metadata.create_all(self.engine)

-        self._connect(self.token)
+            self._connect(self.token)

-    # Experiment Operations
-    # =====================
+        # Experiment Operations
+        # =====================

-    def create_experiment(self, config):
-        """Insert a new experiment inside the database"""
-        cpy = deepcopy(config)
+        def create_experiment(self, config):
+            """Insert a new experiment inside the database"""
+            cpy = deepcopy(config)

-        try:
-            with Session(self.engine) as session:
-                experiment = Experiment(
-                    owner_id=self.user_id,
-                    version=0,
-                )
+            try:
+                with Session(self.engine) as session:
+                    experiment = Experiment(
+                        owner_id=self.user_id,
+                        version=0,
+                    )

-                if "refers" in config:
-                    ref = config.get("refers")
-                    if "parent_id" in ref:
-                        config["parent_id"] = ref.pop("parent_id")
+                    if "refers" in config:
+                        ref = config.get("refers")
+                        if "parent_id" in ref:
+                            config["parent_id"] = ref.pop("parent_id")

-                cpy["meta"] = cpy.pop("metadata")
-                self._set_from_dict(experiment, cpy, "remaining")
+                    cpy["meta"] = cpy.pop("metadata")
+                    self._set_from_dict(experiment, cpy, "remaining")

-                session.add(experiment)
-                session.commit()
+                    session.add(experiment)
+                    session.commit()

-                session.refresh(experiment)
-                config.update(self._to_experiment(experiment))
+                    session.refresh(experiment)
+                    config.update(self._to_experiment(experiment))

-            # Alreadyc reate the algo lock as well
-            self.initialize_algorithm_lock(config["_id"], config.get("algorithms", {}))
-        except DBAPIError:
-            raise DuplicateKeyError()
+                # Already create the algo lock as well
+                self.initialize_algorithm_lock(config["_id"], config.get("algorithms", {}))
+            except DBAPIError:
+                raise DuplicateKeyError()

-    def delete_experiment(self, experiment=None, uid=None):
-        """See :func:`orion.storage.base.BaseStorageProtocol.delete_experiment`"""
-        uid = get_uid(experiment, uid)
+        def delete_experiment(self, experiment=None, uid=None):
+            """See :func:`orion.storage.base.BaseStorageProtocol.delete_experiment`"""
+            uid = get_uid(experiment, uid)

-        with Session(self.engine) as session:
-            stmt = delete(Experiment).where(Experiment._id == uid)
-            session.execute(stmt)
-            session.commit()
+            with Session(self.engine) as session:
+                stmt = delete(Experiment).where(Experiment._id == uid)
+                session.execute(stmt)
+                session.commit()

-    def update_experiment(self, experiment=None, uid=None, where=None, **kwargs):
-        """See :func:`orion.storage.base.BaseStorageProtocol.update_experiment`"""
-        uid = get_uid(experiment, uid)
+        def update_experiment(self, experiment=None, uid=None, where=None, **kwargs):
+            """See :func:`orion.storage.base.BaseStorageProtocol.update_experiment`"""
+            uid = get_uid(experiment, uid)

-        if where and "refers.parent_id" in where:
-            where["parent_id"] = 
where.pop("refers.parent_id") - if uid is not None: - where["_id"] = uid + where = self._get_query(where) - query = self._to_query(Experiment, where) + if uid is not None: + where["_id"] = uid - with Session(self.engine) as session: - stmt = select(Experiment).where(*query) - experiment = session.scalars(stmt).one() + query = self._to_query(Experiment, where) - metadata = kwargs.pop("metadata", dict()) - self._set_from_dict(experiment, kwargs, "remaining") - experiment.meta.update(metadata) + with Session(self.engine) as session: + stmt = select(Experiment).where(*query) + experiment = session.scalars(stmt).one() - session.commit() + metadata = kwargs.pop("metadata", dict()) + self._set_from_dict(experiment, kwargs, "remaining") + experiment.meta.update(metadata) - def _fetch_experiments_with_select(self, query, selection=None): - query = self._get_query(query) + session.commit() - where = self._to_query(Experiment, query) + def _fetch_experiments_with_select(self, query, selection=None): + query = self._get_query(query) - with Session(self.engine) as session: - columns = self._selection(Experiment, selection) - stmt = select(columns).where(*where) + where = self._to_query(Experiment, query) - rows = session.execute(stmt).all() + with Session(self.engine) as session: + columns = self._selection(Experiment, selection) + stmt = select(columns).where(*where) - results = [] + rows = session.execute(stmt).all() - for row in rows: - obj = dict() - for value, k in zip(row, columns): - obj[str(k).split(".")[-1]] = value - results.append(obj) + results = [] - return results + for row in rows: + obj = dict() + for value, k in zip(row, columns): + obj[str(k).split(".")[-1]] = value + results.append(obj) - def fetch_experiments(self, query, selection=None): - """See :func:`orion.storage.base.BaseStorageProtocol.fetch_experiments`""" - if "refers.parent_id" in query: - query["parent_id"] = query.pop("refers.parent_id") + return results - if selection: - return self._fetch_experiments_with_select(query, selection) + def fetch_experiments(self, query, selection=None): + """See :func:`orion.storage.base.BaseStorageProtocol.fetch_experiments`""" + if "refers.parent_id" in query: + query["parent_id"] = query.pop("refers.parent_id") - query = self._get_query(query) - where = self._to_query(Experiment, query) + if selection: + return self._fetch_experiments_with_select(query, selection) - with Session(self.engine) as session: - stmt = select(Experiment).where(*where) + query = self._get_query(query) + where = self._to_query(Experiment, query) - experiments = session.scalars(stmt).all() + with Session(self.engine) as session: + stmt = select(Experiment).where(*where) - r = [self._to_experiment(exp) for exp in experiments] - return r + experiments = session.scalars(stmt).all() - # Benchmarks - # ========== + r = [self._to_experiment(exp) for exp in experiments] + return r - # Trials - # ====== - def fetch_trials(self, experiment=None, uid=None, where=None): - """See :func:`orion.storage.base.BaseStorageProtocol.fetch_trials`""" - uid = get_uid(experiment, uid) + # Benchmarks + # ========== - query = self._get_query(where) + # Trials + # ====== + def fetch_trials(self, experiment=None, uid=None, where=None): + """See :func:`orion.storage.base.BaseStorageProtocol.fetch_trials`""" + uid = get_uid(experiment, uid) - if uid is not None: - query["experiment_id"] = uid + query = self._get_query(where) - query = self._to_query(Trial, query) + if uid is not None: + query["experiment_id"] = uid - with 
Session(self.engine) as session: - stmt = select(Trial).where(*query) - results = session.scalars(stmt).all() + query = self._to_query(Trial, query) - return [OrionTrial(**self._to_trial(t)) for t in results] + with Session(self.engine) as session: + stmt = select(Trial).where(*query) + results = session.scalars(stmt).all() - def register_trial(self, trial): - """See :func:`orion.storage.base.BaseStorageProtocol.register_trial`""" - config = trial.to_dict() + return [OrionTrial(**self._to_trial(t)) for t in results] - try: - with Session(self.engine) as session: - experiment_id = config.pop("experiment", None) + def register_trial(self, trial): + """See :func:`orion.storage.base.BaseStorageProtocol.register_trial`""" + config = trial.to_dict() - db_trial = Trial(experiment_id=experiment_id, owner_id=self.user_id) + try: + with Session(self.engine) as session: + experiment_id = config.pop("experiment", None) - self._set_from_dict(db_trial, config) + db_trial = Trial(experiment_id=experiment_id, owner_id=self.user_id) - session.add(db_trial) - session.commit() + self._set_from_dict(db_trial, config) - session.refresh(db_trial) - trial.id_override = db_trial._id + session.add(db_trial) + session.commit() - return OrionTrial(**self._to_trial(db_trial)) - except DBAPIError: - raise DuplicateKeyError() + session.refresh(db_trial) + trial.id_override = db_trial._id - def delete_trials(self, experiment=None, uid=None, where=None): - """See :func:`orion.storage.base.BaseStorageProtocol.delete_trials`""" - uid = get_uid(experiment, uid) + return OrionTrial(**self._to_trial(db_trial)) + except DBAPIError: + raise DuplicateKeyError() - where = self._get_query(where) + def delete_trials(self, experiment=None, uid=None, where=None): + """See :func:`orion.storage.base.BaseStorageProtocol.delete_trials`""" + uid = get_uid(experiment, uid) - if uid is not None: - where["experiment_id"] = uid + where = self._get_query(where) - query = self._to_query(Trial, where) + if uid is not None: + where["experiment_id"] = uid - with Session(self.engine) as session: - stmt = delete(Trial).where(*query) - count = session.execute(stmt) - session.commit() + query = self._to_query(Trial, where) - return count.rowcount + with Session(self.engine) as session: + stmt = delete(Trial).where(*query) + count = session.execute(stmt) + session.commit() - def retrieve_result(self, trial, **kwargs): - """Updates the results array""" - return trial + return count.rowcount - def get_trial(self, trial=None, uid=None, experiment_uid=None): - """See :func:`orion.storage.base.BaseStorageProtocol.get_trial`""" - trial_uid, experiment_uid = get_trial_uid_and_exp(trial, uid, experiment_uid) + def retrieve_result(self, trial, **kwargs): + """Updates the results array""" + return trial - with Session(self.engine) as session: - stmt = select(Trial).where( - Trial.experiment_id == experiment_uid, - Trial.id == trial_uid, - ) - trial = session.scalars(stmt).one() + def get_trial(self, trial=None, uid=None, experiment_uid=None): + """See :func:`orion.storage.base.BaseStorageProtocol.get_trial`""" + trial_uid, experiment_uid = get_trial_uid_and_exp(trial, uid, experiment_uid) - return OrionTrial(**self._to_trial(trial)) + with Session(self.engine) as session: + stmt = select(Trial).where( + Trial.experiment_id == experiment_uid, + Trial.id == trial_uid, + ) + trial = session.scalars(stmt).one() - def update_trials(self, experiment=None, uid=None, where=None, **kwargs): - """See :func:`orion.storage.base.BaseStorageProtocol.update_trials`""" - uid = 
get_uid(experiment, uid) + return OrionTrial(**self._to_trial(trial)) - where = self._get_query(where) - where["experiment_id"] = uid - query = self._to_query(Trial, where) + def update_trials(self, experiment=None, uid=None, where=None, **kwargs): + """See :func:`orion.storage.base.BaseStorageProtocol.update_trials`""" + uid = get_uid(experiment, uid) - with Session(self.engine) as session: - stmt = select(Trial).where(*query) - trials = session.scalars(stmt).all() + where = self._get_query(where) + where["experiment_id"] = uid + query = self._to_query(Trial, where) - for trial in trials: - self._set_from_dict(trial, kwargs) + with Session(self.engine) as session: + stmt = select(Trial).where(*query) + trials = session.scalars(stmt).all() - session.commit() + for trial in trials: + self._set_from_dict(trial, kwargs) - return len(trials) + session.commit() - def update_trial( - self, trial=None, uid=None, experiment_uid=None, where=None, **kwargs - ): - """See :func:`orion.storage.base.BaseStorageProtocol.update_trial`""" - trial_uid, experiment_uid = get_trial_uid_and_exp(trial, uid, experiment_uid) + return len(trials) - where = self._get_query(where) + def update_trial( + self, trial=None, uid=None, experiment_uid=None, where=None, **kwargs + ): + """See :func:`orion.storage.base.BaseStorageProtocol.update_trial`""" + trial_uid, experiment_uid = get_trial_uid_and_exp(trial, uid, experiment_uid) - # THIS IS NOT THE UNIQUE ID OF THE TRIAL - where["id"] = trial_uid - where["experiment_id"] = experiment_uid - query = self._to_query(Trial, where) + where = self._get_query(where) - with Session(self.engine) as session: - stmt = select(Trial).where(*query) - trial = session.scalars(stmt).one() + # THIS IS NOT THE UNIQUE ID OF THE TRIAL + where["id"] = trial_uid + where["experiment_id"] = experiment_uid + query = self._to_query(Trial, where) - self._set_from_dict(trial, kwargs) - session.commit() + with Session(self.engine) as session: + stmt = select(Trial).where(*query) + trial = session.scalars(stmt).one() - return OrionTrial(**self._to_trial(trial)) + self._set_from_dict(trial, kwargs) + session.commit() - def fetch_lost_trials(self, experiment): - """See :func:`orion.storage.base.BaseStorageProtocol.fetch_lost_trials`""" - heartbeat = orion.core.config.worker.heartbeat - threshold = datetime.datetime.utcnow() - datetime.timedelta( - seconds=heartbeat * 5 - ) + return OrionTrial(**self._to_trial(trial)) - with Session(self.engine) as session: - stmt = select(Trial).where( - Trial.experiment_id == experiment._id, - Trial.status == "reserved", - Trial.heartbeat < threshold, + def fetch_lost_trials(self, experiment): + """See :func:`orion.storage.base.BaseStorageProtocol.fetch_lost_trials`""" + heartbeat = orion.core.config.worker.heartbeat + threshold = datetime.datetime.utcnow() - datetime.timedelta( + seconds=heartbeat * 5 ) - results = session.scalars(stmt).all() - return [OrionTrial(**self._to_trial(t)) for t in results] - - def push_trial_results(self, trial): - """See :func:`orion.storage.base.BaseStorageProtocol.push_trial_results`""" + with Session(self.engine) as session: + stmt = select(Trial).where( + Trial.experiment_id == experiment._id, + Trial.status == "reserved", + Trial.heartbeat < threshold, + ) + results = session.scalars(stmt).all() - log.debug("push trial to storage") - original = trial - config = trial.to_dict() + return [OrionTrial(**self._to_trial(t)) for t in results] - # Don't need to set that one - config.pop("experiment") + def push_trial_results(self, trial): + 
"""See :func:`orion.storage.base.BaseStorageProtocol.push_trial_results`""" - with Session(self.engine) as session: - stmt = select(Trial).where( - # Trial.experiment_id == trial.experiment, - # Trial.id == trial.id, - Trial._id == trial.id_override, - Trial.status == "reserved", - ) - trial = session.scalars(stmt).one() - self._set_from_dict(trial, config) - session.commit() + log.debug("push trial to storage") + original = trial + config = trial.to_dict() - return original + # Don't need to set that one + config.pop("experiment") - def set_trial_status(self, trial, status, heartbeat=None, was=None): - """See :func:`orion.storage.base.BaseStorageProtocol.set_trial_status`""" - heartbeat = heartbeat or datetime.datetime.utcnow() - was = was or trial.status + with Session(self.engine) as session: + stmt = select(Trial).where( + # Trial.experiment_id == trial.experiment, + # Trial.id == trial.id, + Trial._id == trial.id_override, + Trial.status == "reserved", + ) + trial = session.scalars(stmt).one() + self._set_from_dict(trial, config) + session.commit() - validate_status(status) - validate_status(was) + return original - query = [ - Trial.id == trial.id, - Trial.experiment_id == trial.experiment, - Trial.status == was, - ] + def set_trial_status(self, trial, status, heartbeat=None, was=None): + """See :func:`orion.storage.base.BaseStorageProtocol.set_trial_status`""" + heartbeat = heartbeat or datetime.datetime.utcnow() + was = was or trial.status - values = dict(status=status) - if heartbeat: - values["heartbeat"] = heartbeat + validate_status(status) + validate_status(was) - with Session(self.engine) as session: - stmt = update(Trial).where(*query).values(**values) - result = session.execute(stmt) - session.commit() + query = [ + Trial.id == trial.id, + Trial.experiment_id == trial.experiment, + Trial.status == was, + ] - if result.rowcount == 1: - trial.status = status - else: - raise FailedUpdate() - - def fetch_pending_trials(self, experiment): - """See :func:`orion.storage.base.BaseStorageProtocol.fetch_pending_trials`""" - with Session(self.engine) as session: - stmt = select(Trial).where( - Trial.status.in_(("interrupted", "new", "suspended")), - Trial.experiment_id == experiment._id, - ) - results = session.scalars(stmt).all() - trials = OrionTrial.build([self._to_trial(t) for t in results]) + values = dict(status=status) + if heartbeat: + values["heartbeat"] = heartbeat - return trials + with Session(self.engine) as session: + stmt = update(Trial).where(*query).values(**values) + result = session.execute(stmt) + session.commit() - def _reserve_trial_postgre(self, experiment): - now = datetime.datetime.utcnow() + if result.rowcount == 1: + trial.status = status + else: + raise FailedUpdate() - with Session(self.engine) as session: - # In PostgrerSQL we can do single query - stmt = ( - update(Trial) - .where( + def fetch_pending_trials(self, experiment): + """See :func:`orion.storage.base.BaseStorageProtocol.fetch_pending_trials`""" + with Session(self.engine) as session: + stmt = select(Trial).where( Trial.status.in_(("interrupted", "new", "suspended")), Trial.experiment_id == experiment._id, ) - .values( - status="reserved", - start_time=now, - heartbeat=now, - ) - .limit(1) - .returning() - ) - trial = session.scalar(stmt) - return OrionTrial(**self._to_trial(trial)) + results = session.scalars(stmt).all() + trials = OrionTrial.build([self._to_trial(t) for t in results]) - def reserve_trial(self, experiment): - """See 
:func:`orion.storage.base.BaseStorageProtocol.reserve_trial`""" - if False: - return self._reserve_trial_postgre(experiment) + return trials - log.debug("reserve trial") - now = datetime.datetime.utcnow() + def _reserve_trial_postgre(self, experiment): + now = datetime.datetime.utcnow() - with Session(self.engine) as session: - stmt = ( - select(Trial) - .where( - Trial.status.in_(("interrupted", "new", "suspended")), - Trial.experiment_id == experiment._id, + with Session(self.engine) as session: + # In PostgrerSQL we can do single query + stmt = ( + update(Trial) + .where( + Trial.status.in_(("interrupted", "new", "suspended")), + Trial.experiment_id == experiment._id, + ) + .values( + status="reserved", + start_time=now, + heartbeat=now, + ) + .limit(1) + .returning() ) - .limit(1) - ) + trial = session.scalar(stmt) + return OrionTrial(**self._to_trial(trial)) - try: - trial = session.scalars(stmt).one() - except NoResultFound: - return None + def reserve_trial(self, experiment): + """See :func:`orion.storage.base.BaseStorageProtocol.reserve_trial`""" + if False: + return self._reserve_trial_postgre(experiment) - # Update the trial iff the status has not been changed yet - stmt = ( - update(Trial) - .where( - Trial.status == trial.status, - Trial._id == trial._id, - ) - .values( - status="reserved", - start_time=now, - heartbeat=now, + log.debug("reserve trial") + now = datetime.datetime.utcnow() + + with Session(self.engine) as session: + stmt = ( + select(Trial) + .where( + Trial.status.in_(("interrupted", "new", "suspended")), + Trial.experiment_id == experiment._id, + ) + .limit(1) ) - ) - result = session.execute(stmt) + try: + trial = session.scalars(stmt).one() + except NoResultFound: + return None - # time needs to match, could have been reserved by another worker - if result.rowcount == 1: - session.commit() - session.refresh(trial) - return OrionTrial(**self._to_trial(trial)) + # Update the trial iff the status has not been changed yet + stmt = ( + update(Trial) + .where( + Trial.status == trial.status, + Trial._id == trial._id, + ) + .values( + status="reserved", + start_time=now, + heartbeat=now, + ) + ) - return None + result = session.execute(stmt) - def fetch_trials_by_status(self, experiment, status): - """See :func:`orion.storage.base.BaseStorageProtocol.fetch_trials_by_status`""" - with Session(self.engine) as session: - stmt = select(Trial).where( - Trial.status == status, Trial.experiment_id == experiment._id - ) - results = session.scalars(stmt).all() + # time needs to match, could have been reserved by another worker + if result.rowcount == 1: + session.commit() + session.refresh(trial) + return OrionTrial(**self._to_trial(trial)) - return [OrionTrial(**self._to_trial(trial)) for trial in results] + return None - def fetch_noncompleted_trials(self, experiment): - """See :func:`orion.storage.base.BaseStorageProtocol.fetch_noncompleted_trials`""" - with Session(self.engine) as session: - stmt = select(Trial).where( - Trial.status != "completed", - Trial.experiment_id == experiment._id, - ) - results = session.scalars(stmt).all() + def fetch_trials_by_status(self, experiment, status): + """See :func:`orion.storage.base.BaseStorageProtocol.fetch_trials_by_status`""" + with Session(self.engine) as session: + stmt = select(Trial).where( + Trial.status == status, Trial.experiment_id == experiment._id + ) + results = session.scalars(stmt).all() - return [OrionTrial(**self._to_trial(trial)) for trial in results] + return [OrionTrial(**self._to_trial(trial)) for trial in 
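`reserve_trial` above stays portable by splitting reservation in two steps: read one pending candidate, then issue a guarded UPDATE keyed on the status observed at read time. Losing the race is not an error, the method just returns None, so callers are expected to retry. A sketch of a caller-side retry policy (the helper and its backoff are assumptions, not part of the patch):

```python
import time

def reserve_with_retry(storage, experiment, attempts=5, backoff=0.1):
    """Retry storage.reserve_trial a few times before giving up.

    reserve_trial returns None both when nothing is pending and when a
    concurrent worker wins the guarded UPDATE, so a short linear backoff
    between attempts is a reasonable client-side policy.
    """
    for attempt in range(attempts):
        trial = storage.reserve_trial(experiment)
        if trial is not None:
            return trial
        time.sleep(backoff * (attempt + 1))
    return None
```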
results] - def count_completed_trials(self, experiment): - """See :func:`orion.storage.base.BaseStorageProtocol.count_completed_trials`""" - with Session(self.engine) as session: - return ( - session.query(Trial) - .filter( - Trial.status == "completed", + def fetch_noncompleted_trials(self, experiment): + """See :func:`orion.storage.base.BaseStorageProtocol.fetch_noncompleted_trials`""" + with Session(self.engine) as session: + stmt = select(Trial).where( + Trial.status != "completed", Trial.experiment_id == experiment._id, ) - .count() - ) + results = session.scalars(stmt).all() - def count_broken_trials(self, experiment): - """See :func:`orion.storage.base.BaseStorageProtocol.count_broken_trials`""" - with Session(self.engine) as session: - return ( - session.query(Trial) - .filter( - Trial.status == "broken", - Trial.experiment_id == experiment._id, - ) - .count() - ) + return [OrionTrial(**self._to_trial(trial)) for trial in results] - def update_heartbeat(self, trial): - """Update trial's heartbeat""" + def count_completed_trials(self, experiment): + """See :func:`orion.storage.base.BaseStorageProtocol.count_completed_trials`""" + with Session(self.engine) as session: + return ( + session.query(Trial) + .filter( + Trial.status == "completed", + Trial.experiment_id == experiment._id, + ) + .count() + ) - with Session(self.engine) as session: - stmt = ( - update(Trial) - .where( - Trial._id == trial.id_override, - Trial.status == "reserved", + def count_broken_trials(self, experiment): + """See :func:`orion.storage.base.BaseStorageProtocol.count_broken_trials`""" + with Session(self.engine) as session: + return ( + session.query(Trial) + .filter( + Trial.status == "broken", + Trial.experiment_id == experiment._id, + ) + .count() ) - .values(heartbeat=datetime.datetime.utcnow()) - ) - cursor = session.execute(stmt) - session.commit() - - if cursor.rowcount <= 0: - raise FailedUpdate() - - # Algorithm - # ========= - def initialize_algorithm_lock(self, experiment_id, algorithm_config): - """See :func:`orion.storage.base.BaseStorageProtocol.initialize_algorithm_lock`""" - with Session(self.engine) as session: - algo = Algo( - experiment_id=experiment_id, - owner_id=self.user_id, - configuration=algorithm_config, - locked=0, - heartbeat=datetime.datetime.utcnow(), - ) - session.add(algo) - session.commit() + def update_heartbeat(self, trial): + """Update trial's heartbeat""" - def release_algorithm_lock(self, experiment=None, uid=None, new_state=None): - """See :func:`orion.storage.base.BaseStorageProtocol.release_algorithm_lock`""" + with Session(self.engine) as session: + stmt = ( + update(Trial) + .where( + Trial._id == trial.id_override, + Trial.status == "reserved", + ) + .values(heartbeat=datetime.datetime.utcnow()) + ) - uid = get_uid(experiment, uid) + cursor = session.execute(stmt) + session.commit() - values = dict( - locked=0, - heartbeat=datetime.datetime.utcnow(), - ) - if new_state is not None: - values["state"] = pickle.dumps(new_state) - - with Session(self.engine) as session: - stmt = ( - update(Algo) - .where( - Algo.experiment_id == uid, - Algo.locked == 1, - ) - .values(**values) - ) - session.execute(stmt) - session.commit() + if cursor.rowcount <= 0: + raise FailedUpdate() - def get_algorithm_lock_info(self, experiment=None, uid=None): - """See :func:`orion.storage.base.BaseStorageProtocol.get_algorithm_lock_info`""" - uid = get_uid(experiment, uid) + # Algorithm + # ========= + def initialize_algorithm_lock(self, experiment_id, algorithm_config): + """See 
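`update_heartbeat` above only matches rows that are still `reserved`, so a worker whose reservation was revoked gets `FailedUpdate` instead of silently resurrecting the trial. A worker-side loop honouring that contract could look like this (sketch; the threading setup is an assumption):

```python
import threading

from orion.storage.base import FailedUpdate

def heartbeat_loop(storage, trial, interval, stop_event):
    """Refresh the trial heartbeat every `interval` seconds until stopped.

    FailedUpdate means the trial is no longer reserved by this worker,
    so the loop exits rather than fighting for the row.
    """
    while not stop_event.wait(interval):  # wait() returns True once stop is set
        try:
            storage.update_heartbeat(trial)
        except FailedUpdate:
            break

# usage sketch:
# stop = threading.Event()
# threading.Thread(target=heartbeat_loop, args=(storage, trial, 10, stop)).start()
```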
:func:`orion.storage.base.BaseStorageProtocol.initialize_algorithm_lock`""" + with Session(self.engine) as session: + algo = Algo( + experiment_id=experiment_id, + owner_id=self.user_id, + configuration=algorithm_config, + locked=0, + heartbeat=datetime.datetime.utcnow(), + ) + session.add(algo) + session.commit() - with Session(self.engine) as session: - stmt = select(Algo).where(Algo.experiment_id == uid) - algo = session.scalar(stmt) + def release_algorithm_lock(self, experiment=None, uid=None, new_state=None): + """See :func:`orion.storage.base.BaseStorageProtocol.release_algorithm_lock`""" - if algo is None: - return None + uid = get_uid(experiment, uid) - return LockedAlgorithmState( - state=pickle.loads(algo.state) if algo.state is not None else None, - configuration=algo.configuration, - locked=algo.locked, - ) + values = dict( + locked=0, + heartbeat=datetime.datetime.utcnow(), + ) + if new_state is not None: + values["state"] = pickle.dumps(new_state) - def delete_algorithm_lock(self, experiment=None, uid=None): - """See :func:`orion.storage.base.BaseStorageProtocol.delete_algorithm_lock`""" - uid = get_uid(experiment, uid) + with Session(self.engine) as session: + stmt = ( + update(Algo) + .where( + Algo.experiment_id == uid, + Algo.locked == 1, + ) + .values(**values) + ) + session.execute(stmt) + session.commit() - with Session(self.engine) as session: - stmt = delete(Algo).where(Algo.experiment_id == uid) - cursor = session.execute(stmt) - session.commit() + def get_algorithm_lock_info(self, experiment=None, uid=None): + """See :func:`orion.storage.base.BaseStorageProtocol.get_algorithm_lock_info`""" + uid = get_uid(experiment, uid) - return cursor.rowcount + with Session(self.engine) as session: + stmt = select(Algo).where(Algo.experiment_id == uid) + algo = session.scalar(stmt) - def _acquire_algorithm_lock_postgre( - self, experiment=None, uid=None, timeout=60, retry_interval=1 - ): - with Session(self.engine) as session: - now = datetime.datetime.utcnow() + if algo is None: + return None - stmt = ( - update(Algo) - .where(Algo.experiment_id == uid, Algo.locked == 0) - .values(locked=1, heartbeat=now) - .returning() + return LockedAlgorithmState( + state=pickle.loads(algo.state) if algo.state is not None else None, + configuration=algo.configuration, + locked=algo.locked, ) - algo = session.scalar(stmt).one() - session.commit() - return algo + def delete_algorithm_lock(self, experiment=None, uid=None): + """See :func:`orion.storage.base.BaseStorageProtocol.delete_algorithm_lock`""" + uid = get_uid(experiment, uid) + + with Session(self.engine) as session: + stmt = delete(Algo).where(Algo.experiment_id == uid) + cursor = session.execute(stmt) + session.commit() - def _acquire_algorithm_lock( - self, experiment=None, uid=None, timeout=1, retry_interval=1 - ): - uid = get_uid(experiment, uid) - algo_state_lock = None - start = time.perf_counter() + return cursor.rowcount - with Session(self.engine) as session: - while algo_state_lock is None and time.perf_counter() - start < timeout: + def _acquire_algorithm_lock_postgre( + self, experiment=None, uid=None, timeout=60, retry_interval=1 + ): + with Session(self.engine) as session: now = datetime.datetime.utcnow() stmt = ( update(Algo) .where(Algo.experiment_id == uid, Algo.locked == 0) .values(locked=1, heartbeat=now) + .returning() ) - cursor = session.execute(stmt) + algo = session.scalar(stmt).one() session.commit() + return algo - if cursor.rowcount == 0: - time.sleep(retry_interval) - else: - stmt = select(Algo).where( - 
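The algorithm lock above is one row per experiment: an integer `locked` flag flipped between 0 and 1 plus a pickled `state` blob, and releasing is a single guarded UPDATE. For readers who think better in SQL, this is roughly the statement `release_algorithm_lock` emits (a sketch of the equivalent raw query, not code from the patch):

```python
from sqlalchemy import text

# Raw-SQL view of release_algorithm_lock; the patch itself goes through
# the ORM update() construct shown above.
RELEASE = text(
    "UPDATE algo SET locked = 0, heartbeat = :now, state = :state "
    "WHERE experiment_id = :uid AND locked = 1"
)

def release_raw(engine, uid, now, state_blob):
    with engine.begin() as conn:  # begin() commits on successful exit
        conn.execute(RELEASE, {"uid": uid, "now": now, "state": state_blob})
```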
Algo.experiment_id == uid, Algo.locked == 1 - ) - algo_state_lock = session.scalar(stmt) - break - - if algo_state_lock is None: - raise LockAcquisitionTimeout() - - if algo_state_lock.state is not None: - state = pickle.loads(algo_state_lock.state) - else: - state = None - - return LockedAlgorithmState( - state=state, - configuration=algo_state_lock.configuration, - locked=True, - ) - - @contextlib.contextmanager - def acquire_algorithm_lock( - self, experiment=None, uid=None, timeout=60, retry_interval=1 - ): - """See :func:`orion.storage.base.BaseStorageProtocol.acquire_algorithm_lock`""" - locked_algo_state = self._acquire_algorithm_lock( - experiment, uid, timeout, retry_interval - ) - - try: - log.debug("lock algo") - yield locked_algo_state - except Exception: - # Reset algo to state fetched lock time - locked_algo_state.reset() - raise - finally: - log.debug("unlock algo") + def _acquire_algorithm_lock( + self, experiment=None, uid=None, timeout=1, retry_interval=1 + ): uid = get_uid(experiment, uid) - self.release_algorithm_lock(uid=uid, new_state=locked_algo_state.state) + algo_state_lock = None + start = time.perf_counter() - # Utilities - # ========= - def _get_query(self, query): - if query is None: - query = dict() + with Session(self.engine) as session: + while algo_state_lock is None and time.perf_counter() - start < timeout: + now = datetime.datetime.utcnow() - query["owner_id"] = self.user_id - return query + stmt = ( + update(Algo) + .where(Algo.experiment_id == uid, Algo.locked == 0) + .values(locked=1, heartbeat=now) + ) - def _selection(self, table, selection): - selected = [] + cursor = session.execute(stmt) + session.commit() - for k, v in selection.items(): - if hasattr(table, k) and v: - selected.append(getattr(table, k)) + if cursor.rowcount == 0: + time.sleep(retry_interval) + else: + stmt = select(Algo).where( + Algo.experiment_id == uid, Algo.locked == 1 + ) + algo_state_lock = session.scalar(stmt) + break - return selected + if algo_state_lock is None: + raise LockAcquisitionTimeout() - def _set_from_dict(self, obj, data, rest=None): - data = deepcopy(data) - meta = dict() - while data: - k, v = data.popitem() + if algo_state_lock.state is not None: + state = pickle.loads(algo_state_lock.state) + else: + state = None - if v is None: - continue + return LockedAlgorithmState( + state=state, + configuration=algo_state_lock.configuration, + locked=True, + ) - if hasattr(obj, k): - setattr(obj, k, v) - else: - meta[k] = v + @contextlib.contextmanager + def acquire_algorithm_lock( + self, experiment=None, uid=None, timeout=60, retry_interval=1 + ): + """See :func:`orion.storage.base.BaseStorageProtocol.acquire_algorithm_lock`""" + locked_algo_state = self._acquire_algorithm_lock( + experiment, uid, timeout, retry_interval + ) - if meta and rest: - setattr(obj, rest, meta) - return + try: + log.debug("lock algo") + yield locked_algo_state + except Exception: + # Reset algo to state fetched lock time + locked_algo_state.reset() + raise + finally: + log.debug("unlock algo") + uid = get_uid(experiment, uid) + self.release_algorithm_lock(uid=uid, new_state=locked_algo_state.state) + + # Utilities + # ========= + def _get_query(self, query): + if query is None: + query = dict() + + query["owner_id"] = self.user_id + return query + + def _selection(self, table, selection): + selected = [] + + for k, v in selection.items(): + if hasattr(table, k) and v: + selected.append(getattr(table, k)) + + return selected + + def _set_from_dict(self, obj, data, rest=None): + data = 
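`acquire_algorithm_lock` above wraps the polling compare-and-swap in a context manager so the flag is always dropped, and the pickled state is rolled back on error. From the caller's side the protocol reduces to the following (usage sketch; `storage` and `experiment` are assumed to exist):

```python
# Hold the algorithm lock while producing new suggestions.
with storage.acquire_algorithm_lock(experiment=experiment, timeout=60) as lock:
    state = lock.state  # unpickled algorithm state, or None if never saved
    # ... advance the optimizer here ...
    # On normal exit the finally block persists lock.state with locked=0;
    # on exception, lock.reset() restores the state fetched at lock time.
```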
deepcopy(data) + meta = dict() + while data: + k, v = data.popitem() + + if v is None: + continue + + if hasattr(obj, k): + setattr(obj, k, v) + else: + meta[k] = v - if meta: - log.warning("Data was discarded %s", meta) - assert False + if meta and rest: + setattr(obj, rest, meta) + return - def _to_query(self, table, where): - query = [] + if meta: + log.warning("Data was discarded %s", meta) + assert False - for k, v in where.items(): - if hasattr(table, k): - query.append(getattr(table, k) == v) - else: - log.warning("constrained ignored %s = %s", k, v) - - return query - - def _to_experiment(self, experiment): - exp = deepcopy(experiment.__dict__) - exp["metadata"] = exp.pop("meta", {}) - exp.pop("_sa_instance_state") - exp.pop("owner_id") - exp.pop("datetime") - - none_keys = [] - for k, v in exp.items(): - if v is None: - none_keys.append(k) - - for k in none_keys: - exp.pop(k) - - rest = exp.pop("remaining", {}) - if rest is None: - rest = {} - - exp.update(rest) - return exp - - def _to_trial(self, trial): - trial = deepcopy(trial.__dict__) - trial.pop("_sa_instance_state") - trial["experiment"] = trial.pop("experiment_id") - trial.pop("owner_id") - return trial + def _to_query(self, table, where): + query = [] + + for k, v in where.items(): + if hasattr(table, k): + query.append(getattr(table, k) == v) + else: + log.warning("constrained ignored %s = %s", k, v) + + return query + + def _to_experiment(self, experiment): + exp = deepcopy(experiment.__dict__) + exp["metadata"] = exp.pop("meta", {}) + exp.pop("_sa_instance_state") + exp.pop("owner_id") + exp.pop("datetime") + + none_keys = [] + for k, v in exp.items(): + if v is None: + none_keys.append(k) + + for k in none_keys: + exp.pop(k) + + rest = exp.pop("remaining", {}) + if rest is None: + rest = {} + + exp.update(rest) + return exp + + def _to_trial(self, trial): + trial = deepcopy(trial.__dict__) + trial.pop("_sa_instance_state") + trial["experiment"] = trial.pop("experiment_id") + trial.pop("owner_id") + return trial From 969da8603d7e386ea75c4abc98efcd624655a59c Mon Sep 17 00:00:00 2001 From: Pierre Delaunay Date: Tue, 22 Nov 2022 14:30:47 -0500 Subject: [PATCH 17/25] - --- src/orion/storage/sql.py | 928 +--------------------------------- src/orion/storage/sql_impl.py | 925 +++++++++++++++++++++++++++++++++ 2 files changed, 930 insertions(+), 923 deletions(-) create mode 100644 src/orion/storage/sql_impl.py diff --git a/src/orion/storage/sql.py b/src/orion/storage/sql.py index 24c676f99..cd028b318 100644 --- a/src/orion/storage/sql.py +++ b/src/orion/storage/sql.py @@ -1,934 +1,16 @@ -import contextlib -import datetime -import logging -import pickle -import time -import uuid -from copy import deepcopy - -# Use MongoDB json serializer -from bson.json_util import dumps as to_json -from bson.json_util import loads as from_json - IMPORT_ERROR = None try: - from sqlalchemy import ( - BINARY, - JSON, - Column, - DateTime, - ForeignKey, - Index, - Integer, - String, - UniqueConstraint, - delete, - select, - update, - ) - import sqlalchemy - from sqlalchemy.exc import DBAPIError, NoResultFound - from sqlalchemy.ext.compiler import compiles - from sqlalchemy.orm import Session, declarative_base - + from orion.storage.sql_impl import SQLAlchemy as SQLAlchemyImpl except ImportError as err: IMPORT_ERROR = err - -import orion.core -from orion.core.io.database import DuplicateKeyError -from orion.core.utils.compat import getuser -from orion.core.worker.trial import Trial as OrionTrial -from orion.core.worker.trial import validate_status 
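This is where the excerpt switches to PATCH 17/25, whose whole point is the split visible below: `sql.py` shrinks to a guard that records the import failure and re-raises it only when the backend is instantiated, so importing Orion keeps working without SQLAlchemy installed. The pattern in isolation (a generic sketch, not the verbatim shim):

```python
# Deferred optional-dependency failure: remember the ImportError and
# raise it at use time instead of import time.
IMPORT_ERROR = None
try:
    import sqlalchemy  # noqa: F401  # stand-in for any heavy optional dependency
except ImportError as err:
    IMPORT_ERROR = err

if IMPORT_ERROR is not None:
    class SQLAlchemy:  # placeholder exposing the real constructor signature
        def __init__(self, uri, token=None, **kwargs):
            raise IMPORT_ERROR
else:
    from orion.storage.sql_impl import SQLAlchemy  # the real backend
```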
-from orion.storage.base import ( - BaseStorageProtocol, - FailedUpdate, - LockAcquisitionTimeout, - LockedAlgorithmState, - get_trial_uid_and_exp, - get_uid, -) if IMPORT_ERROR is not None: - class SQLAlchemy(BaseStorageProtocol): + from orion.storage.base import BaseStorageProtocol + + class SQLAlchemy(BaseStorageProtocol): def __init__(self, uri, token=None, **kwargs): raise IMPORT_ERROR else: - log = logging.getLogger(__name__) - - Base = declarative_base() - - - @compiles(BINARY, "postgresql") - def compile_binary_postgresql(type_, compiler, **kw): - """Postgresql does not know about Binary type we should byte array instead""" - return "BYTEA" - - - class User(Base): - """Defines the User table""" - - __tablename__ = "users" - - _id = Column(Integer, primary_key=True, autoincrement=True) - name = Column(String(30), unique=True) - token = Column(String(32)) - created_at = Column(DateTime) - last_seen = Column(DateTime) - - - class Experiment(Base): - """Defines the Experiment table""" - - __tablename__ = "experiments" - - _id = Column(Integer, primary_key=True, autoincrement=True) - name = Column(String(30)) - meta = Column(JSON) # metadata field is reserved - version = Column(Integer) - owner_id = Column(Integer, ForeignKey("users._id"), nullable=False) - datetime = Column(DateTime) - algorithms = Column(JSON) - remaining = Column(JSON) - space = Column(JSON) - parent_id = Column(Integer) - - __table_args__ = ( - UniqueConstraint("name", "owner_id", name="_one_name_per_owner"), - Index("idx_experiment_name_version", "name", "version"), - ) - - - class Trial(Base): - """Defines the Trial table""" - - __tablename__ = "trials" - - _id = Column(Integer, primary_key=True, autoincrement=True) - experiment_id = Column(Integer, ForeignKey("experiments._id"), nullable=False) - owner_id = Column(Integer, ForeignKey("users._id"), nullable=False) - status = Column(String(30)) - results = Column(JSON) - start_time = Column(DateTime) - end_time = Column(DateTime) - heartbeat = Column(DateTime) - parent = Column(Integer, ForeignKey("trials._id"), nullable=True) - params = Column(JSON) - worker = Column(JSON) - submit_time = Column(DateTime) - exp_working_dir = Column(String(30)) - id = Column(String(30)) - - __table_args__ = ( - UniqueConstraint("experiment_id", "id", name="_one_trial_hash_per_experiment"), - Index("idx_trial_experiment_id", "experiment_id"), - Index("idx_trial_status", "status"), - # Can't put an index on json - # Index('idx_trial_results', 'results'), - Index("idx_trial_start_time", "start_time"), - Index("idx_trial_end_time", "end_time"), - ) - - - class Algo(Base): - """Defines the Algo table""" - - __tablename__ = "algo" - - # it is one algo per experiment so we could set experiment_id as the primary key - # and make it a 1-1 relation - _id = Column(Integer, primary_key=True, autoincrement=True) - experiment_id = Column(Integer, ForeignKey("experiments._id"), nullable=False) - owner_id = Column(Integer, ForeignKey("users._id"), nullable=False) - configuration = Column(JSON) - locked = Column(Integer) - state = Column(BINARY) - heartbeat = Column(DateTime) - - __table_args__ = (Index("idx_algo_experiment_id", "experiment_id"),) - - - def get_tables(): - return [User, Experiment, Trial, Algo, User] - - - class SQLAlchemy(BaseStorageProtocol): # noqa: F811 - """Implement a generic protocol to allow Orion to communicate using - different storage backend - - Parameters - ---------- - uri: str - PostgreSQL backend to use for storage; the format is as follow - 
`protocol://[username:password@]host1[:port1][,...hostN[:portN]]][/[database][?options]]` - - """ - - def __init__(self, uri, token=None, **kwargs): - # dialect+driver://username:password@host:port/database - # - # postgresql://scott:tiger@localhost/mydatabase - # postgresql+psycopg2://scott:tiger@localhost/mydatabase - # postgresql+pg8000://scott:tiger@localhost/mydatabase - # - # mysql://scott:tiger@localhost/foo - # mysql+mysqldb://scott:tiger@localhost/foo - # mysql+pymysql://scott:tiger@localhost/foo - # - # sqlite:///foo.db # relative - # sqlite:////foo.db # absolute - # sqlite:// # in memory - - self.uri = uri - if uri == "": - uri = "sqlite://" - - # engine_from_config - self.engine = sqlalchemy.create_engine( - uri, - echo=False, - future=True, - json_serializer=to_json, - json_deserializer=from_json, - ) - - # Create the schema - # sqlite3 can fail on table if it already exist - # the doc says it shouldn't but it does - try: - Base.metadata.create_all(self.engine) - except DBAPIError: - pass - - self.token = token - self.user_id = None - self.user = None - self._connect(token) - - def _connect(self, token): - name = getuser() - - user = self._find_user(name, token) - - if user is None: - user = self._create_user(name) - - assert user is not None - - self.user_id = user._id - self.user = user - self.token = user.token - - def _find_user(self, name, token) -> User: - query = [User.name == name] - if token is not None and token != "": - query.append(User.token == token) - - with Session(self.engine) as session: - stmt = select(User).where(*query) - - return session.execute(stmt).scalar() - - def _create_user(self, name) -> User: - try: - now = datetime.datetime.utcnow() - - with Session(self.engine) as session: - user = User( - name=name, - token=uuid.uuid5(uuid.NAMESPACE_OID, name).hex, - created_at=now, - last_seen=now, - ) - session.add(user) - session.commit() - - assert user._id > 0 - return user - except DBAPIError: - return self._find_user(name, self.token) - - def __getstate__(self): - return dict( - uri=self.uri, - token=self.token, - ) - - def __setstate__(self, state): - self.uri = state["uri"] - self.token = state["token"] - self.engine = sqlalchemy.create_engine(self.uri, echo=True, future=True) - - if self.uri == "sqlite://" or self.uri == "": - log.warning("You are serializing an in-memory database, data will be lost") - Base.metadata.create_all(self.engine) - - self._connect(self.token) - - # Experiment Operations - # ===================== - - def create_experiment(self, config): - """Insert a new experiment inside the database""" - cpy = deepcopy(config) - - try: - with Session(self.engine) as session: - experiment = Experiment( - owner_id=self.user_id, - version=0, - ) - - if "refers" in config: - ref = config.get("refers") - if "parent_id" in ref: - config["parent_id"] = ref.pop("parent_id") - - cpy["meta"] = cpy.pop("metadata") - self._set_from_dict(experiment, cpy, "remaining") - - session.add(experiment) - session.commit() - - session.refresh(experiment) - config.update(self._to_experiment(experiment)) - - # Alreadyc reate the algo lock as well - self.initialize_algorithm_lock(config["_id"], config.get("algorithms", {})) - except DBAPIError: - raise DuplicateKeyError() - - def delete_experiment(self, experiment=None, uid=None): - """See :func:`orion.storage.base.BaseStorageProtocol.delete_experiment`""" - uid = get_uid(experiment, uid) - - with Session(self.engine) as session: - stmt = delete(Experiment).where(Experiment._id == uid) - session.execute(stmt) - 
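The constructor above routes every JSON column through `bson.json_util`, so values MongoDB handles natively (datetimes in particular) round-trip identically on the SQL side, and an empty URI silently becomes an in-memory SQLite database. The engine setup in isolation (sketch; `make_engine` is a hypothetical factory):

```python
import sqlalchemy
from bson.json_util import dumps as to_json
from bson.json_util import loads as from_json

def make_engine(uri=""):
    """Engine with MongoDB-compatible JSON handling, as in __init__ above."""
    if uri == "":
        uri = "sqlite://"  # empty URI falls back to in-memory SQLite
    return sqlalchemy.create_engine(
        uri,
        future=True,
        json_serializer=to_json,      # dict -> BSON-flavoured JSON string
        json_deserializer=from_json,  # and back, preserving datetimes
    )
```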
session.commit() - - def update_experiment(self, experiment=None, uid=None, where=None, **kwargs): - """See :func:`orion.storage.base.BaseStorageProtocol.update_experiment`""" - uid = get_uid(experiment, uid) - - if where and "refers.parent_id" in where: - where["parent_id"] = where.pop("refers.parent_id") - - where = self._get_query(where) - - if uid is not None: - where["_id"] = uid - - query = self._to_query(Experiment, where) - - with Session(self.engine) as session: - stmt = select(Experiment).where(*query) - experiment = session.scalars(stmt).one() - - metadata = kwargs.pop("metadata", dict()) - self._set_from_dict(experiment, kwargs, "remaining") - experiment.meta.update(metadata) - - session.commit() - - def _fetch_experiments_with_select(self, query, selection=None): - query = self._get_query(query) - - where = self._to_query(Experiment, query) - - with Session(self.engine) as session: - columns = self._selection(Experiment, selection) - stmt = select(columns).where(*where) - - rows = session.execute(stmt).all() - - results = [] - - for row in rows: - obj = dict() - for value, k in zip(row, columns): - obj[str(k).split(".")[-1]] = value - results.append(obj) - - return results - - def fetch_experiments(self, query, selection=None): - """See :func:`orion.storage.base.BaseStorageProtocol.fetch_experiments`""" - if "refers.parent_id" in query: - query["parent_id"] = query.pop("refers.parent_id") - - if selection: - return self._fetch_experiments_with_select(query, selection) - - query = self._get_query(query) - where = self._to_query(Experiment, query) - - with Session(self.engine) as session: - stmt = select(Experiment).where(*where) - - experiments = session.scalars(stmt).all() - - r = [self._to_experiment(exp) for exp in experiments] - return r - - # Benchmarks - # ========== - - # Trials - # ====== - def fetch_trials(self, experiment=None, uid=None, where=None): - """See :func:`orion.storage.base.BaseStorageProtocol.fetch_trials`""" - uid = get_uid(experiment, uid) - - query = self._get_query(where) - - if uid is not None: - query["experiment_id"] = uid - - query = self._to_query(Trial, query) - - with Session(self.engine) as session: - stmt = select(Trial).where(*query) - results = session.scalars(stmt).all() - - return [OrionTrial(**self._to_trial(t)) for t in results] - - def register_trial(self, trial): - """See :func:`orion.storage.base.BaseStorageProtocol.register_trial`""" - config = trial.to_dict() - - try: - with Session(self.engine) as session: - experiment_id = config.pop("experiment", None) - - db_trial = Trial(experiment_id=experiment_id, owner_id=self.user_id) - - self._set_from_dict(db_trial, config) - - session.add(db_trial) - session.commit() - - session.refresh(db_trial) - trial.id_override = db_trial._id - - return OrionTrial(**self._to_trial(db_trial)) - except DBAPIError: - raise DuplicateKeyError() - - def delete_trials(self, experiment=None, uid=None, where=None): - """See :func:`orion.storage.base.BaseStorageProtocol.delete_trials`""" - uid = get_uid(experiment, uid) - - where = self._get_query(where) - - if uid is not None: - where["experiment_id"] = uid - - query = self._to_query(Trial, where) - - with Session(self.engine) as session: - stmt = delete(Trial).where(*query) - count = session.execute(stmt) - session.commit() - - return count.rowcount - - def retrieve_result(self, trial, **kwargs): - """Updates the results array""" - return trial - - def get_trial(self, trial=None, uid=None, experiment_uid=None): - """See 
:func:`orion.storage.base.BaseStorageProtocol.get_trial`""" - trial_uid, experiment_uid = get_trial_uid_and_exp(trial, uid, experiment_uid) - - with Session(self.engine) as session: - stmt = select(Trial).where( - Trial.experiment_id == experiment_uid, - Trial.id == trial_uid, - ) - trial = session.scalars(stmt).one() - - return OrionTrial(**self._to_trial(trial)) - - def update_trials(self, experiment=None, uid=None, where=None, **kwargs): - """See :func:`orion.storage.base.BaseStorageProtocol.update_trials`""" - uid = get_uid(experiment, uid) - - where = self._get_query(where) - where["experiment_id"] = uid - query = self._to_query(Trial, where) - - with Session(self.engine) as session: - stmt = select(Trial).where(*query) - trials = session.scalars(stmt).all() - - for trial in trials: - self._set_from_dict(trial, kwargs) - - session.commit() - - return len(trials) - - def update_trial( - self, trial=None, uid=None, experiment_uid=None, where=None, **kwargs - ): - """See :func:`orion.storage.base.BaseStorageProtocol.update_trial`""" - trial_uid, experiment_uid = get_trial_uid_and_exp(trial, uid, experiment_uid) - - where = self._get_query(where) - - # THIS IS NOT THE UNIQUE ID OF THE TRIAL - where["id"] = trial_uid - where["experiment_id"] = experiment_uid - query = self._to_query(Trial, where) - - with Session(self.engine) as session: - stmt = select(Trial).where(*query) - trial = session.scalars(stmt).one() - - self._set_from_dict(trial, kwargs) - session.commit() - - return OrionTrial(**self._to_trial(trial)) - - def fetch_lost_trials(self, experiment): - """See :func:`orion.storage.base.BaseStorageProtocol.fetch_lost_trials`""" - heartbeat = orion.core.config.worker.heartbeat - threshold = datetime.datetime.utcnow() - datetime.timedelta( - seconds=heartbeat * 5 - ) - - with Session(self.engine) as session: - stmt = select(Trial).where( - Trial.experiment_id == experiment._id, - Trial.status == "reserved", - Trial.heartbeat < threshold, - ) - results = session.scalars(stmt).all() - - return [OrionTrial(**self._to_trial(t)) for t in results] - - def push_trial_results(self, trial): - """See :func:`orion.storage.base.BaseStorageProtocol.push_trial_results`""" - - log.debug("push trial to storage") - original = trial - config = trial.to_dict() - - # Don't need to set that one - config.pop("experiment") - - with Session(self.engine) as session: - stmt = select(Trial).where( - # Trial.experiment_id == trial.experiment, - # Trial.id == trial.id, - Trial._id == trial.id_override, - Trial.status == "reserved", - ) - trial = session.scalars(stmt).one() - self._set_from_dict(trial, config) - session.commit() - - return original - - def set_trial_status(self, trial, status, heartbeat=None, was=None): - """See :func:`orion.storage.base.BaseStorageProtocol.set_trial_status`""" - heartbeat = heartbeat or datetime.datetime.utcnow() - was = was or trial.status - - validate_status(status) - validate_status(was) - - query = [ - Trial.id == trial.id, - Trial.experiment_id == trial.experiment, - Trial.status == was, - ] - - values = dict(status=status) - if heartbeat: - values["heartbeat"] = heartbeat - - with Session(self.engine) as session: - stmt = update(Trial).where(*query).values(**values) - result = session.execute(stmt) - session.commit() - - if result.rowcount == 1: - trial.status = status - else: - raise FailedUpdate() - - def fetch_pending_trials(self, experiment): - """See :func:`orion.storage.base.BaseStorageProtocol.fetch_pending_trials`""" - with Session(self.engine) as session: - stmt = 
select(Trial).where( - Trial.status.in_(("interrupted", "new", "suspended")), - Trial.experiment_id == experiment._id, - ) - results = session.scalars(stmt).all() - trials = OrionTrial.build([self._to_trial(t) for t in results]) - - return trials - - def _reserve_trial_postgre(self, experiment): - now = datetime.datetime.utcnow() - - with Session(self.engine) as session: - # In PostgrerSQL we can do single query - stmt = ( - update(Trial) - .where( - Trial.status.in_(("interrupted", "new", "suspended")), - Trial.experiment_id == experiment._id, - ) - .values( - status="reserved", - start_time=now, - heartbeat=now, - ) - .limit(1) - .returning() - ) - trial = session.scalar(stmt) - return OrionTrial(**self._to_trial(trial)) - - def reserve_trial(self, experiment): - """See :func:`orion.storage.base.BaseStorageProtocol.reserve_trial`""" - if False: - return self._reserve_trial_postgre(experiment) - - log.debug("reserve trial") - now = datetime.datetime.utcnow() - - with Session(self.engine) as session: - stmt = ( - select(Trial) - .where( - Trial.status.in_(("interrupted", "new", "suspended")), - Trial.experiment_id == experiment._id, - ) - .limit(1) - ) - - try: - trial = session.scalars(stmt).one() - except NoResultFound: - return None - - # Update the trial iff the status has not been changed yet - stmt = ( - update(Trial) - .where( - Trial.status == trial.status, - Trial._id == trial._id, - ) - .values( - status="reserved", - start_time=now, - heartbeat=now, - ) - ) - - result = session.execute(stmt) - - # time needs to match, could have been reserved by another worker - if result.rowcount == 1: - session.commit() - session.refresh(trial) - return OrionTrial(**self._to_trial(trial)) - - return None - - def fetch_trials_by_status(self, experiment, status): - """See :func:`orion.storage.base.BaseStorageProtocol.fetch_trials_by_status`""" - with Session(self.engine) as session: - stmt = select(Trial).where( - Trial.status == status, Trial.experiment_id == experiment._id - ) - results = session.scalars(stmt).all() - - return [OrionTrial(**self._to_trial(trial)) for trial in results] - - def fetch_noncompleted_trials(self, experiment): - """See :func:`orion.storage.base.BaseStorageProtocol.fetch_noncompleted_trials`""" - with Session(self.engine) as session: - stmt = select(Trial).where( - Trial.status != "completed", - Trial.experiment_id == experiment._id, - ) - results = session.scalars(stmt).all() - - return [OrionTrial(**self._to_trial(trial)) for trial in results] - - def count_completed_trials(self, experiment): - """See :func:`orion.storage.base.BaseStorageProtocol.count_completed_trials`""" - with Session(self.engine) as session: - return ( - session.query(Trial) - .filter( - Trial.status == "completed", - Trial.experiment_id == experiment._id, - ) - .count() - ) - - def count_broken_trials(self, experiment): - """See :func:`orion.storage.base.BaseStorageProtocol.count_broken_trials`""" - with Session(self.engine) as session: - return ( - session.query(Trial) - .filter( - Trial.status == "broken", - Trial.experiment_id == experiment._id, - ) - .count() - ) - - def update_heartbeat(self, trial): - """Update trial's heartbeat""" - - with Session(self.engine) as session: - stmt = ( - update(Trial) - .where( - Trial._id == trial.id_override, - Trial.status == "reserved", - ) - .values(heartbeat=datetime.datetime.utcnow()) - ) - - cursor = session.execute(stmt) - session.commit() - - if cursor.rowcount <= 0: - raise FailedUpdate() - - # Algorithm - # ========= - def 
initialize_algorithm_lock(self, experiment_id, algorithm_config): - """See :func:`orion.storage.base.BaseStorageProtocol.initialize_algorithm_lock`""" - with Session(self.engine) as session: - algo = Algo( - experiment_id=experiment_id, - owner_id=self.user_id, - configuration=algorithm_config, - locked=0, - heartbeat=datetime.datetime.utcnow(), - ) - session.add(algo) - session.commit() - - def release_algorithm_lock(self, experiment=None, uid=None, new_state=None): - """See :func:`orion.storage.base.BaseStorageProtocol.release_algorithm_lock`""" - - uid = get_uid(experiment, uid) - - values = dict( - locked=0, - heartbeat=datetime.datetime.utcnow(), - ) - if new_state is not None: - values["state"] = pickle.dumps(new_state) - - with Session(self.engine) as session: - stmt = ( - update(Algo) - .where( - Algo.experiment_id == uid, - Algo.locked == 1, - ) - .values(**values) - ) - session.execute(stmt) - session.commit() - - def get_algorithm_lock_info(self, experiment=None, uid=None): - """See :func:`orion.storage.base.BaseStorageProtocol.get_algorithm_lock_info`""" - uid = get_uid(experiment, uid) - - with Session(self.engine) as session: - stmt = select(Algo).where(Algo.experiment_id == uid) - algo = session.scalar(stmt) - - if algo is None: - return None - - return LockedAlgorithmState( - state=pickle.loads(algo.state) if algo.state is not None else None, - configuration=algo.configuration, - locked=algo.locked, - ) - - def delete_algorithm_lock(self, experiment=None, uid=None): - """See :func:`orion.storage.base.BaseStorageProtocol.delete_algorithm_lock`""" - uid = get_uid(experiment, uid) - - with Session(self.engine) as session: - stmt = delete(Algo).where(Algo.experiment_id == uid) - cursor = session.execute(stmt) - session.commit() - - return cursor.rowcount - - def _acquire_algorithm_lock_postgre( - self, experiment=None, uid=None, timeout=60, retry_interval=1 - ): - with Session(self.engine) as session: - now = datetime.datetime.utcnow() - - stmt = ( - update(Algo) - .where(Algo.experiment_id == uid, Algo.locked == 0) - .values(locked=1, heartbeat=now) - .returning() - ) - - algo = session.scalar(stmt).one() - session.commit() - return algo - - def _acquire_algorithm_lock( - self, experiment=None, uid=None, timeout=1, retry_interval=1 - ): - uid = get_uid(experiment, uid) - algo_state_lock = None - start = time.perf_counter() - - with Session(self.engine) as session: - while algo_state_lock is None and time.perf_counter() - start < timeout: - now = datetime.datetime.utcnow() - - stmt = ( - update(Algo) - .where(Algo.experiment_id == uid, Algo.locked == 0) - .values(locked=1, heartbeat=now) - ) - - cursor = session.execute(stmt) - session.commit() - - if cursor.rowcount == 0: - time.sleep(retry_interval) - else: - stmt = select(Algo).where( - Algo.experiment_id == uid, Algo.locked == 1 - ) - algo_state_lock = session.scalar(stmt) - break - - if algo_state_lock is None: - raise LockAcquisitionTimeout() - - if algo_state_lock.state is not None: - state = pickle.loads(algo_state_lock.state) - else: - state = None - - return LockedAlgorithmState( - state=state, - configuration=algo_state_lock.configuration, - locked=True, - ) - - @contextlib.contextmanager - def acquire_algorithm_lock( - self, experiment=None, uid=None, timeout=60, retry_interval=1 - ): - """See :func:`orion.storage.base.BaseStorageProtocol.acquire_algorithm_lock`""" - locked_algo_state = self._acquire_algorithm_lock( - experiment, uid, timeout, retry_interval - ) - - try: - log.debug("lock algo") - yield 
locked_algo_state
-            except Exception:
-                # Reset algo to state fetched lock time
-                locked_algo_state.reset()
-                raise
-            finally:
-                log.debug("unlock algo")
-                uid = get_uid(experiment, uid)
-                self.release_algorithm_lock(uid=uid, new_state=locked_algo_state.state)
-
-        # Utilities
-        # =========
-        def _get_query(self, query):
-            if query is None:
-                query = dict()
-
-            query["owner_id"] = self.user_id
-            return query
-
-        def _selection(self, table, selection):
-            selected = []
-
-            for k, v in selection.items():
-                if hasattr(table, k) and v:
-                    selected.append(getattr(table, k))
-
-            return selected
-
-        def _set_from_dict(self, obj, data, rest=None):
-            data = deepcopy(data)
-            meta = dict()
-            while data:
-                k, v = data.popitem()
-
-                if v is None:
-                    continue
-
-                if hasattr(obj, k):
-                    setattr(obj, k, v)
-                else:
-                    meta[k] = v
-
-            if meta and rest:
-                setattr(obj, rest, meta)
-                return
-
-            if meta:
-                log.warning("Data was discarded %s", meta)
-                assert False
-
-        def _to_query(self, table, where):
-            query = []
-
-            for k, v in where.items():
-                if hasattr(table, k):
-                    query.append(getattr(table, k) == v)
-                else:
-                    log.warning("constrained ignored %s = %s", k, v)
-
-            return query
-
-        def _to_experiment(self, experiment):
-            exp = deepcopy(experiment.__dict__)
-            exp["metadata"] = exp.pop("meta", {})
-            exp.pop("_sa_instance_state")
-            exp.pop("owner_id")
-            exp.pop("datetime")
-
-            none_keys = []
-            for k, v in exp.items():
-                if v is None:
-                    none_keys.append(k)
-
-            for k in none_keys:
-                exp.pop(k)
-
-            rest = exp.pop("remaining", {})
-            if rest is None:
-                rest = {}
-
-            exp.update(rest)
-            return exp
-
-        def _to_trial(self, trial):
-            trial = deepcopy(trial.__dict__)
-            trial.pop("_sa_instance_state")
-            trial["experiment"] = trial.pop("experiment_id")
-            trial.pop("owner_id")
-            return trial
+    SQLAlchemy = SQLAlchemyImpl
diff --git a/src/orion/storage/sql_impl.py b/src/orion/storage/sql_impl.py
new file mode 100644
index 000000000..02c7d06c9
--- /dev/null
+++ b/src/orion/storage/sql_impl.py
@@ -0,0 +1,925 @@
+import contextlib
+import datetime
+import logging
+import pickle
+import time
+import uuid
+from copy import deepcopy
+
+# Use MongoDB json serializer
+from bson.json_util import dumps as to_json
+from bson.json_util import loads as from_json
+
+import sqlalchemy
+from sqlalchemy import (
+    BINARY,
+    JSON,
+    Column,
+    DateTime,
+    ForeignKey,
+    Index,
+    Integer,
+    String,
+    UniqueConstraint,
+    delete,
+    select,
+    update,
+)
+from sqlalchemy.exc import DBAPIError, NoResultFound
+from sqlalchemy.ext.compiler import compiles
+from sqlalchemy.orm import Session, declarative_base
+
+import orion.core
+from orion.core.io.database import DuplicateKeyError
+from orion.core.utils.compat import getuser
+from orion.core.worker.trial import Trial as OrionTrial
+from orion.core.worker.trial import validate_status
+from orion.storage.base import (
+    BaseStorageProtocol,
+    FailedUpdate,
+    LockAcquisitionTimeout,
+    LockedAlgorithmState,
+    get_trial_uid_and_exp,
+    get_uid,
+)
+
+log = logging.getLogger(__name__)
+
+Base = declarative_base()
+
+@compiles(BINARY, "postgresql")
+def compile_binary_postgresql(type_, compiler, **kw):
+    """PostgreSQL has no BINARY type; compile it to a byte array (BYTEA) instead"""
+    return "BYTEA"
+
+class User(Base):
+    """Defines the User table"""
+
+    __tablename__ = "users"
+
+    _id = Column(Integer, primary_key=True, autoincrement=True)
+    name = Column(String(30), unique=True)
+    token = Column(String(32))
+    created_at = Column(DateTime)
+    last_seen = Column(DateTime)
+
+class Experiment(Base):
+    """Defines the Experiment table"""
+
+    __tablename__ = "experiments"
+
+    _id = Column(Integer, primary_key=True, autoincrement=True)
+    name = Column(String(30))
+    meta = Column(JSON)  # metadata field is reserved
+    version = Column(Integer)
+    owner_id = Column(Integer, ForeignKey("users._id"), nullable=False)
+    datetime = Column(DateTime)
+    algorithms = Column(JSON)
+    remaining = Column(JSON)
+    space = Column(JSON)
+    parent_id = Column(Integer)
+
+    __table_args__ = (
+        UniqueConstraint("name", "owner_id", name="_one_name_per_owner"),
+        Index("idx_experiment_name_version", "name", "version"),
+    )
+
+class Trial(Base):
+    """Defines the Trial table"""
+
+    __tablename__ = "trials"
+
+    _id = Column(Integer, primary_key=True, autoincrement=True)
+    experiment_id = Column(Integer, ForeignKey("experiments._id"), nullable=False)
+    owner_id = Column(Integer, ForeignKey("users._id"), nullable=False)
+    status = Column(String(30))
+    results = Column(JSON)
+    start_time = Column(DateTime)
+    end_time = Column(DateTime)
+    heartbeat = Column(DateTime)
+    parent = Column(Integer, ForeignKey("trials._id"), nullable=True)
+    params = Column(JSON)
+    worker = Column(JSON)
+    submit_time = Column(DateTime)
+    exp_working_dir = Column(String(30))
+    id = Column(String(30))
+
+    __table_args__ = (
+        UniqueConstraint(
+            "experiment_id", "id", name="_one_trial_hash_per_experiment"
+        ),
+        Index("idx_trial_experiment_id", "experiment_id"),
+        Index("idx_trial_status", "status"),
+        # Can't put an index on json
+        # Index('idx_trial_results', 'results'),
+        Index("idx_trial_start_time", "start_time"),
+        Index("idx_trial_end_time", "end_time"),
+    )
+
+class Algo(Base):
+    """Defines the Algo table"""
+
+    __tablename__ = "algo"
+
+    # it is one algo per experiment, so experiment_id could be the primary key,
+    # making this a 1-1 relation
+    _id = Column(Integer, primary_key=True, autoincrement=True)
+    experiment_id = Column(Integer, ForeignKey("experiments._id"), nullable=False)
+    owner_id = Column(Integer, ForeignKey("users._id"), nullable=False)
+    configuration = Column(JSON)
+    locked = Column(Integer)
+    state = Column(BINARY)
+    heartbeat = Column(DateTime)
+
+    __table_args__ = (Index("idx_algo_experiment_id", "experiment_id"),)
+
+def get_tables():
+    return [User, Experiment, Trial, Algo]
+
+class SQLAlchemy(BaseStorageProtocol):  # noqa: F811
+    """Implement a generic protocol to allow Orion to communicate using
+    different storage backends
+
+    Parameters
+    ----------
+    uri: str
+        PostgreSQL backend to use for storage; the format is as follows
+        `protocol://[username:password@]host1[:port1][,...hostN[:portN]]][/[database][?options]]`
+
+    """
+
+    def __init__(self, uri, token=None, **kwargs):
+        # dialect+driver://username:password@host:port/database
+        #
+        # postgresql://scott:tiger@localhost/mydatabase
+        # postgresql+psycopg2://scott:tiger@localhost/mydatabase
+        # postgresql+pg8000://scott:tiger@localhost/mydatabase
+        #
+        # mysql://scott:tiger@localhost/foo
+        # mysql+mysqldb://scott:tiger@localhost/foo
+        # mysql+pymysql://scott:tiger@localhost/foo
+        #
+        # sqlite:///foo.db   # relative
+        # sqlite:////foo.db  # absolute
+        # sqlite://          # in memory
+
+        self.uri = uri
+        if uri == "":
+            uri = "sqlite://"
+
+        # engine_from_config
+        self.engine = sqlalchemy.create_engine(
+            uri,
+            echo=False,
+            future=True,
+            json_serializer=to_json,
+            json_deserializer=from_json,
+        )
+
+        # Create the schema
+        # sqlite3 can fail on create_all if the tables already exist;
+        # the docs say it shouldn't, but it does
+        try:
+            Base.metadata.create_all(self.engine)
+        except DBAPIError:
+            pass
+
+        self.token = token
+        self.user_id = None
+        self.user = None
+        self._connect(token)
+
+    def _connect(self, token):
+        name = getuser()
+
+        user = self._find_user(name, token)
+
+        if user is None:
+            user = self._create_user(name)
+
+        assert user is not None
+
+        self.user_id = user._id
+        self.user = user
+        self.token = user.token
+
+    def _find_user(self, name, token) -> User:
+        query = [User.name == name]
+        if token is not None and token != "":
+            query.append(User.token == token)
+
+        with Session(self.engine) as session:
+            stmt = select(User).where(*query)
+
+            return session.execute(stmt).scalar()
+
+    def _create_user(self, name) -> User:
+        try:
+            now = datetime.datetime.utcnow()
+
+            with Session(self.engine) as session:
+                user = User(
+                    name=name,
+                    token=uuid.uuid5(uuid.NAMESPACE_OID, name).hex,
+                    created_at=now,
+                    last_seen=now,
+                )
+                session.add(user)
+                session.commit()
+
+                assert user._id > 0
+                return user
+        except DBAPIError:
+            return self._find_user(name, self.token)
+
+    def __getstate__(self):
+        return dict(
+            uri=self.uri,
+            token=self.token,
+        )
+
+    def __setstate__(self, state):
+        self.uri = state["uri"]
+        self.token = state["token"]
+        self.engine = sqlalchemy.create_engine(self.uri, echo=True, future=True)
+
+        if self.uri == "sqlite://" or self.uri == "":
+            log.warning(
+                "You are serializing an in-memory database, data will be lost"
+            )
+            Base.metadata.create_all(self.engine)
+
+        self._connect(self.token)
+
+    # Experiment Operations
+    # =====================
+
+    def create_experiment(self, config):
+        """Insert a new experiment inside the database"""
+        cpy = deepcopy(config)
+
+        try:
+            with Session(self.engine) as session:
+                experiment = Experiment(
+                    owner_id=self.user_id,
+                    version=0,
+                )
+
+                if "refers" in config:
+                    ref = config.get("refers")
+                    if "parent_id" in ref:
+                        config["parent_id"] = ref.pop("parent_id")
+
+                cpy["meta"] = cpy.pop("metadata")
+                self._set_from_dict(experiment, cpy, "remaining")
+
+                session.add(experiment)
+                session.commit()
+
+                session.refresh(experiment)
+                config.update(self._to_experiment(experiment))
+
+            # Already create the algo lock as well
+            self.initialize_algorithm_lock(
+                config["_id"], config.get("algorithms", {})
+            )
+        except DBAPIError:
+            raise DuplicateKeyError()
+
+    def delete_experiment(self, experiment=None, uid=None):
+        """See :func:`orion.storage.base.BaseStorageProtocol.delete_experiment`"""
+        uid = get_uid(experiment, uid)
+
+        with Session(self.engine) as session:
+            stmt = delete(Experiment).where(Experiment._id == uid)
+            session.execute(stmt)
+            session.commit()
+
+    def update_experiment(self, experiment=None, uid=None, where=None, **kwargs):
+        """See :func:`orion.storage.base.BaseStorageProtocol.update_experiment`"""
+        uid = get_uid(experiment, uid)
+
+        if where and "refers.parent_id" in where:
+            where["parent_id"] = where.pop("refers.parent_id")
+
+        where = self._get_query(where)
+
+        if uid is not None:
+            where["_id"] = uid
+
+        query = self._to_query(Experiment, where)
+
+        with Session(self.engine) as session:
+            stmt = select(Experiment).where(*query)
+            experiment = session.scalars(stmt).one()
+
+            metadata = kwargs.pop("metadata", dict())
+            self._set_from_dict(experiment, kwargs, "remaining")
+            experiment.meta.update(metadata)
+
+            session.commit()
+
+    def _fetch_experiments_with_select(self, query, selection=None):
+        query = self._get_query(query)
+
+        where = self._to_query(Experiment, query)
+
+        with Session(self.engine) as session:
+            columns = self._selection(Experiment, selection)
+            stmt = select(columns).where(*where)
+
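`_create_user` above derives the token with `uuid5`, a namespaced SHA-1 UUID, so the token is a pure function of the user name: two processes racing to create the same user compute the same token, and the loser of the unique-name constraint can simply re-fetch. The derivation in isolation (sketch):

```python
import uuid

def user_token(name: str) -> str:
    """Deterministic 32-character token for a user name, as in _create_user."""
    return uuid.uuid5(uuid.NAMESPACE_OID, name).hex

# Stable across processes and machines, which makes the create/fetch race benign:
assert user_token("alice") == user_token("alice")
```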
+            rows = session.execute(stmt).all()
+
+            results = []
+
+            for row in rows:
+                obj = dict()
+                for value, k in zip(row, columns):
+                    obj[str(k).split(".")[-1]] = value
+                results.append(obj)
+
+            return results
+
+    def fetch_experiments(self, query, selection=None):
+        """See :func:`orion.storage.base.BaseStorageProtocol.fetch_experiments`"""
+        if "refers.parent_id" in query:
+            query["parent_id"] = query.pop("refers.parent_id")
+
+        if selection:
+            return self._fetch_experiments_with_select(query, selection)
+
+        query = self._get_query(query)
+        where = self._to_query(Experiment, query)
+
+        with Session(self.engine) as session:
+            stmt = select(Experiment).where(*where)
+
+            experiments = session.scalars(stmt).all()
+
+        r = [self._to_experiment(exp) for exp in experiments]
+        return r
+
+    # Benchmarks
+    # ==========
+
+    # Trials
+    # ======
+    def fetch_trials(self, experiment=None, uid=None, where=None):
+        """See :func:`orion.storage.base.BaseStorageProtocol.fetch_trials`"""
+        uid = get_uid(experiment, uid)
+
+        query = self._get_query(where)
+
+        if uid is not None:
+            query["experiment_id"] = uid
+
+        query = self._to_query(Trial, query)
+
+        with Session(self.engine) as session:
+            stmt = select(Trial).where(*query)
+            results = session.scalars(stmt).all()
+
+        return [OrionTrial(**self._to_trial(t)) for t in results]
+
+    def register_trial(self, trial):
+        """See :func:`orion.storage.base.BaseStorageProtocol.register_trial`"""
+        config = trial.to_dict()
+
+        try:
+            with Session(self.engine) as session:
+                experiment_id = config.pop("experiment", None)
+
+                db_trial = Trial(experiment_id=experiment_id, owner_id=self.user_id)
+
+                self._set_from_dict(db_trial, config)
+
+                session.add(db_trial)
+                session.commit()
+
+                session.refresh(db_trial)
+                trial.id_override = db_trial._id
+
+                return OrionTrial(**self._to_trial(db_trial))
+        except DBAPIError:
+            raise DuplicateKeyError()
+
+    def delete_trials(self, experiment=None, uid=None, where=None):
+        """See :func:`orion.storage.base.BaseStorageProtocol.delete_trials`"""
+        uid = get_uid(experiment, uid)
+
+        where = self._get_query(where)
+
+        if uid is not None:
+            where["experiment_id"] = uid
+
+        query = self._to_query(Trial, where)
+
+        with Session(self.engine) as session:
+            stmt = delete(Trial).where(*query)
+            count = session.execute(stmt)
+            session.commit()
+
+        return count.rowcount
+
+    def retrieve_result(self, trial, **kwargs):
+        """No-op for this backend: results are written by push_trial_results,
+        so the trial is returned unchanged"""
+        return trial
+
+    def get_trial(self, trial=None, uid=None, experiment_uid=None):
+        """See :func:`orion.storage.base.BaseStorageProtocol.get_trial`"""
+        trial_uid, experiment_uid = get_trial_uid_and_exp(
+            trial, uid, experiment_uid
+        )
+
+        with Session(self.engine) as session:
+            stmt = select(Trial).where(
+                Trial.experiment_id == experiment_uid,
+                Trial.id == trial_uid,
+            )
+            trial = session.scalars(stmt).one()
+
+        return OrionTrial(**self._to_trial(trial))
+
+    def update_trials(self, experiment=None, uid=None, where=None, **kwargs):
+        """See :func:`orion.storage.base.BaseStorageProtocol.update_trials`"""
+        uid = get_uid(experiment, uid)
+
+        where = self._get_query(where)
+        where["experiment_id"] = uid
+        query = self._to_query(Trial, where)
+
+        with Session(self.engine) as session:
+            stmt = select(Trial).where(*query)
+            trials = session.scalars(stmt).all()
+
+            for trial in trials:
+                self._set_from_dict(trial, kwargs)
+
+            session.commit()
+
+        return len(trials)
+
+    def update_trial(
+        self, trial=None, uid=None, experiment_uid=None, where=None, **kwargs
+    ):
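`_fetch_experiments_with_select` above emulates MongoDB-style projections: the selection dict picks mapped columns, and each row is rebuilt into a plain dict keyed by the bare column name (the last segment of `str(column)`). The mapping in isolation (sketch; `project` is a hypothetical helper over this patch's `Experiment` model):

```python
from sqlalchemy import select
from sqlalchemy.orm import Session

def project(engine, selection, *where):
    """Mongo-like projection: {'name': 1, 'version': 1} -> list of dicts."""
    columns = [getattr(Experiment, k) for k, keep in selection.items() if keep]
    with Session(engine) as session:
        rows = session.execute(select(*columns).where(*where)).all()
    # str(col) looks like 'Experiment.name'; keep only the attribute name
    return [
        {str(col).split(".")[-1]: value for value, col in zip(row, columns)}
        for row in rows
    ]
```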
+        """See :func:`orion.storage.base.BaseStorageProtocol.update_trial`"""
+        trial_uid, experiment_uid = get_trial_uid_and_exp(
+            trial, uid, experiment_uid
+        )
+
+        where = self._get_query(where)
+
+        # NOTE: this is the trial hash, unique only within one experiment,
+        # not the globally unique row id (_id)
+        where["id"] = trial_uid
+        where["experiment_id"] = experiment_uid
+        query = self._to_query(Trial, where)
+
+        with Session(self.engine) as session:
+            stmt = select(Trial).where(*query)
+            trial = session.scalars(stmt).one()
+
+            self._set_from_dict(trial, kwargs)
+            session.commit()
+
+        return OrionTrial(**self._to_trial(trial))
+
+    def fetch_lost_trials(self, experiment):
+        """See :func:`orion.storage.base.BaseStorageProtocol.fetch_lost_trials`"""
+        heartbeat = orion.core.config.worker.heartbeat
+        threshold = datetime.datetime.utcnow() - datetime.timedelta(
+            seconds=heartbeat * 5
+        )
+
+        with Session(self.engine) as session:
+            stmt = select(Trial).where(
+                Trial.experiment_id == experiment._id,
+                Trial.status == "reserved",
+                Trial.heartbeat < threshold,
+            )
+            results = session.scalars(stmt).all()
+
+        return [OrionTrial(**self._to_trial(t)) for t in results]
+
+    def push_trial_results(self, trial):
+        """See :func:`orion.storage.base.BaseStorageProtocol.push_trial_results`"""
+
+        log.debug("push trial to storage")
+        original = trial
+        config = trial.to_dict()
+
+        # Don't need to set that one
+        config.pop("experiment")
+
+        with Session(self.engine) as session:
+            stmt = select(Trial).where(
+                # Trial.experiment_id == trial.experiment,
+                # Trial.id == trial.id,
+                Trial._id == trial.id_override,
+                Trial.status == "reserved",
+            )
+            trial = session.scalars(stmt).one()
+            self._set_from_dict(trial, config)
+            session.commit()
+
+        return original
+
+    def set_trial_status(self, trial, status, heartbeat=None, was=None):
+        """See :func:`orion.storage.base.BaseStorageProtocol.set_trial_status`"""
+        heartbeat = heartbeat or datetime.datetime.utcnow()
+        was = was or trial.status
+
+        validate_status(status)
+        validate_status(was)
+
+        query = [
+            Trial.id == trial.id,
+            Trial.experiment_id == trial.experiment,
+            Trial.status == was,
+        ]
+
+        values = dict(status=status)
+        if heartbeat:
+            values["heartbeat"] = heartbeat
+
+        with Session(self.engine) as session:
+            stmt = update(Trial).where(*query).values(**values)
+            result = session.execute(stmt)
+            session.commit()
+
+        if result.rowcount == 1:
+            trial.status = status
+        else:
+            raise FailedUpdate()
+
+    def fetch_pending_trials(self, experiment):
+        """See :func:`orion.storage.base.BaseStorageProtocol.fetch_pending_trials`"""
+        with Session(self.engine) as session:
+            stmt = select(Trial).where(
+                Trial.status.in_(("interrupted", "new", "suspended")),
+                Trial.experiment_id == experiment._id,
+            )
+            results = session.scalars(stmt).all()
+            trials = OrionTrial.build([self._to_trial(t) for t in results])
+
+        return trials
+
+    def _reserve_trial_postgre(self, experiment):
+        now = datetime.datetime.utcnow()
+
+        with Session(self.engine) as session:
+            # In PostgreSQL we can do this in a single query
+            stmt = (
+                update(Trial)
+                .where(
+                    Trial.status.in_(("interrupted", "new", "suspended")),
+                    Trial.experiment_id == experiment._id,
+                )
+                .values(
+                    status="reserved",
+                    start_time=now,
+                    heartbeat=now,
+                )
+                .limit(1)
+                .returning()
+            )
+            trial = session.scalar(stmt)
+            return OrionTrial(**self._to_trial(trial))
+
+    def reserve_trial(self, experiment):
+        """See :func:`orion.storage.base.BaseStorageProtocol.reserve_trial`"""
+        if False:
+            return self._reserve_trial_postgre(experiment)
+
+        log.debug("reserve trial")
+        now = datetime.datetime.utcnow()
+
+    def reserve_trial(self, experiment):
+        """See :func:`orion.storage.base.BaseStorageProtocol.reserve_trial`"""
+        if False:
+            # TODO: dispatch here only when the backend is PostgreSQL; the
+            # generic two-step path below works on every backend.
+            return self._reserve_trial_postgre(experiment)
+
+        log.debug("reserve trial")
+        now = datetime.datetime.utcnow()
+
+        with Session(self.engine) as session:
+            stmt = (
+                select(Trial)
+                .where(
+                    Trial.status.in_(("interrupted", "new", "suspended")),
+                    Trial.experiment_id == experiment._id,
+                )
+                .limit(1)
+            )
+
+            try:
+                trial = session.scalars(stmt).one()
+            except NoResultFound:
+                return None
+
+            # Update the trial only if its status has not changed since we
+            # read it; another worker may have reserved it in the meantime.
+            stmt = (
+                update(Trial)
+                .where(
+                    Trial.status == trial.status,
+                    Trial._id == trial._id,
+                )
+                .values(
+                    status="reserved",
+                    start_time=now,
+                    heartbeat=now,
+                )
+            )
+
+            result = session.execute(stmt)
+
+            # rowcount == 0 means the guard failed: the trial was taken by
+            # another worker between our SELECT and our UPDATE.
+            if result.rowcount == 1:
+                session.commit()
+                session.refresh(trial)
+                return OrionTrial(**self._to_trial(trial))
+
+            return None
+
+    def fetch_trials_by_status(self, experiment, status):
+        """See :func:`orion.storage.base.BaseStorageProtocol.fetch_trials_by_status`"""
+        with Session(self.engine) as session:
+            stmt = select(Trial).where(
+                Trial.status == status, Trial.experiment_id == experiment._id
+            )
+            results = session.scalars(stmt).all()
+
+        return [OrionTrial(**self._to_trial(trial)) for trial in results]
+
+    def fetch_noncompleted_trials(self, experiment):
+        """See :func:`orion.storage.base.BaseStorageProtocol.fetch_noncompleted_trials`"""
+        with Session(self.engine) as session:
+            stmt = select(Trial).where(
+                Trial.status != "completed",
+                Trial.experiment_id == experiment._id,
+            )
+            results = session.scalars(stmt).all()
+
+        return [OrionTrial(**self._to_trial(trial)) for trial in results]
+
+    def count_completed_trials(self, experiment):
+        """See :func:`orion.storage.base.BaseStorageProtocol.count_completed_trials`"""
+        with Session(self.engine) as session:
+            return (
+                session.query(Trial)
+                .filter(
+                    Trial.status == "completed",
+                    Trial.experiment_id == experiment._id,
+                )
+                .count()
+            )
+
+    def count_broken_trials(self, experiment):
+        """See :func:`orion.storage.base.BaseStorageProtocol.count_broken_trials`"""
+        with Session(self.engine) as session:
+            return (
+                session.query(Trial)
+                .filter(
+                    Trial.status == "broken",
+                    Trial.experiment_id == experiment._id,
+                )
+                .count()
+            )
+
+    def update_heartbeat(self, trial):
+        """Update trial's heartbeat"""
+
+        with Session(self.engine) as session:
+            stmt = (
+                update(Trial)
+                .where(
+                    Trial._id == trial.id_override,
+                    Trial.status == "reserved",
+                )
+                .values(heartbeat=datetime.datetime.utcnow())
+            )
+
+            cursor = session.execute(stmt)
+            session.commit()
+
+        if cursor.rowcount <= 0:
+            raise FailedUpdate()
+
+    # Algorithm
+    # =========
+    def initialize_algorithm_lock(self, experiment_id, algorithm_config):
+        """See :func:`orion.storage.base.BaseStorageProtocol.initialize_algorithm_lock`"""
+        with Session(self.engine) as session:
+            algo = Algo(
+                experiment_id=experiment_id,
+                owner_id=self.user_id,
+                configuration=algorithm_config,
+                locked=0,
+                heartbeat=datetime.datetime.utcnow(),
+            )
+            session.add(algo)
+            session.commit()
+
+    def release_algorithm_lock(self, experiment=None, uid=None, new_state=None):
+        """See :func:`orion.storage.base.BaseStorageProtocol.release_algorithm_lock`"""
+
+        uid = get_uid(experiment, uid)
+
+        values = dict(
+            locked=0,
+            heartbeat=datetime.datetime.utcnow(),
+        )
+        if new_state is not None:
+            values["state"] = pickle.dumps(new_state)
+
+        with Session(self.engine) as session:
+            stmt = (
+                update(Algo)
+                .where(
+                    Algo.experiment_id == uid,
+                    Algo.locked == 1,
+                )
+                .values(**values)
+            )
+            session.execute(stmt)
+            session.commit()
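The lock rows round-trip the optimizer state through the `Algo.state` column as a pickle blob, so releasing and later re-acquiring the lock is effectively a serialize/deserialize round trip. A tiny sketch; the state dict here is hypothetical:

import pickle

state = {"points_seen": 42, "rng_state": (1, 2, 3)}  # hypothetical algo state

blob = pickle.dumps(state)     # what release_algorithm_lock stores
restored = pickle.loads(blob)  # what get_algorithm_lock_info loads back
assert restored == state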
"""See :func:`orion.storage.base.BaseStorageProtocol.get_algorithm_lock_info`""" + uid = get_uid(experiment, uid) + + with Session(self.engine) as session: + stmt = select(Algo).where(Algo.experiment_id == uid) + algo = session.scalar(stmt) + + if algo is None: + return None + + return LockedAlgorithmState( + state=pickle.loads(algo.state) if algo.state is not None else None, + configuration=algo.configuration, + locked=algo.locked, + ) + + def delete_algorithm_lock(self, experiment=None, uid=None): + """See :func:`orion.storage.base.BaseStorageProtocol.delete_algorithm_lock`""" + uid = get_uid(experiment, uid) + + with Session(self.engine) as session: + stmt = delete(Algo).where(Algo.experiment_id == uid) + cursor = session.execute(stmt) + session.commit() + + return cursor.rowcount + + def _acquire_algorithm_lock_postgre( + self, experiment=None, uid=None, timeout=60, retry_interval=1 + ): + with Session(self.engine) as session: + now = datetime.datetime.utcnow() + + stmt = ( + update(Algo) + .where(Algo.experiment_id == uid, Algo.locked == 0) + .values(locked=1, heartbeat=now) + .returning() + ) + + algo = session.scalar(stmt).one() + session.commit() + return algo + + def _acquire_algorithm_lock( + self, experiment=None, uid=None, timeout=1, retry_interval=1 + ): + uid = get_uid(experiment, uid) + algo_state_lock = None + start = time.perf_counter() + + with Session(self.engine) as session: + while algo_state_lock is None and time.perf_counter() - start < timeout: + now = datetime.datetime.utcnow() + + stmt = ( + update(Algo) + .where(Algo.experiment_id == uid, Algo.locked == 0) + .values(locked=1, heartbeat=now) + ) + + cursor = session.execute(stmt) + session.commit() + + if cursor.rowcount == 0: + time.sleep(retry_interval) + else: + stmt = select(Algo).where( + Algo.experiment_id == uid, Algo.locked == 1 + ) + algo_state_lock = session.scalar(stmt) + break + + if algo_state_lock is None: + raise LockAcquisitionTimeout() + + if algo_state_lock.state is not None: + state = pickle.loads(algo_state_lock.state) + else: + state = None + + return LockedAlgorithmState( + state=state, + configuration=algo_state_lock.configuration, + locked=True, + ) + + @contextlib.contextmanager + def acquire_algorithm_lock( + self, experiment=None, uid=None, timeout=60, retry_interval=1 + ): + """See :func:`orion.storage.base.BaseStorageProtocol.acquire_algorithm_lock`""" + locked_algo_state = self._acquire_algorithm_lock( + experiment, uid, timeout, retry_interval + ) + + try: + log.debug("lock algo") + yield locked_algo_state + except Exception: + # Reset algo to state fetched lock time + locked_algo_state.reset() + raise + finally: + log.debug("unlock algo") + uid = get_uid(experiment, uid) + self.release_algorithm_lock(uid=uid, new_state=locked_algo_state.state) + + # Utilities + # ========= + def _get_query(self, query): + if query is None: + query = dict() + + query["owner_id"] = self.user_id + return query + + def _selection(self, table, selection): + selected = [] + + for k, v in selection.items(): + if hasattr(table, k) and v: + selected.append(getattr(table, k)) + + return selected + + def _set_from_dict(self, obj, data, rest=None): + data = deepcopy(data) + meta = dict() + while data: + k, v = data.popitem() + + if v is None: + continue + + if hasattr(obj, k): + setattr(obj, k, v) + else: + meta[k] = v + + if meta and rest: + setattr(obj, rest, meta) + return + + if meta: + log.warning("Data was discarded %s", meta) + assert False + + def _to_query(self, table, where): + query = [] + + for k, v 
+    def _to_query(self, table, where):
+        query = []
+
+        for k, v in where.items():
+            if hasattr(table, k):
+                query.append(getattr(table, k) == v)
+            else:
+                log.warning("constraint ignored %s = %s", k, v)
+
+        return query
+
+    def _to_experiment(self, experiment):
+        exp = deepcopy(experiment.__dict__)
+        exp["metadata"] = exp.pop("meta", {})
+        exp.pop("_sa_instance_state")
+        exp.pop("owner_id")
+        exp.pop("datetime")
+
+        none_keys = []
+        for k, v in exp.items():
+            if v is None:
+                none_keys.append(k)
+
+        for k in none_keys:
+            exp.pop(k)
+
+        rest = exp.pop("remaining", {})
+        if rest is None:
+            rest = {}
+
+        exp.update(rest)
+        return exp
+
+    def _to_trial(self, trial):
+        trial = deepcopy(trial.__dict__)
+        trial.pop("_sa_instance_state")
+        trial["experiment"] = trial.pop("experiment_id")
+        trial.pop("owner_id")
+        return trial

From e706c1ae5630cc89a61e01796540ab49e289bf8c Mon Sep 17 00:00:00 2001
From: Pierre Delaunay
Date: Tue, 22 Nov 2022 14:36:16 -0500
Subject: [PATCH 18/25] -

---
 src/orion/storage/sql_impl.py | 33 +++++++++++++++------------------
 1 file changed, 15 insertions(+), 18 deletions(-)

diff --git a/src/orion/storage/sql_impl.py b/src/orion/storage/sql_impl.py
index 02c7d06c9..a56754828 100644
--- a/src/orion/storage/sql_impl.py
+++ b/src/orion/storage/sql_impl.py
@@ -6,11 +6,11 @@
 import uuid
 from copy import deepcopy
 
+import sqlalchemy
+
 # Use MongoDB json serializer
 from bson.json_util import dumps as to_json
 from bson.json_util import loads as from_json
-
-import sqlalchemy
 from sqlalchemy import (
     BINARY,
     JSON,
@@ -42,16 +42,18 @@
     get_trial_uid_and_exp,
     get_uid,
 )
-
+
 log = logging.getLogger(__name__)
 
 Base = declarative_base()
 
+
 @compiles(BINARY, "postgresql")
 def compile_binary_postgresql(type_, compiler, **kw):
     """PostgreSQL does not know the BINARY type; emit a byte array (BYTEA) instead"""
     return "BYTEA"
 
+
 class User(Base):
     """Defines the User table"""
 
@@ -63,6 +65,7 @@ class User(Base):
     created_at = Column(DateTime)
     last_seen = Column(DateTime)
 
+
 class Experiment(Base):
     """Defines the Experiment table"""
 
@@ -84,6 +87,7 @@ class Experiment(Base):
         Index("idx_experiment_name_version", "name", "version"),
     )
 
+
 class Trial(Base):
     """Defines the Trial table"""
 
@@ -105,9 +109,7 @@ class Trial(Base):
     id = Column(String(30))
 
     __table_args__ = (
-        UniqueConstraint(
-            "experiment_id", "id", name="_one_trial_hash_per_experiment"
-        ),
+        UniqueConstraint("experiment_id", "id", name="_one_trial_hash_per_experiment"),
         Index("idx_trial_experiment_id", "experiment_id"),
         Index("idx_trial_status", "status"),
         # Can't put an index on json
@@ -116,6 +118,7 @@ class Trial(Base):
         Index("idx_trial_end_time", "end_time"),
     )
 
+
 class Algo(Base):
     """Defines the Algo table"""
 
@@ -133,9 +136,11 @@ class Algo(Base):
 
     __table_args__ = (Index("idx_algo_experiment_id", "experiment_id"),)
 
+
 def get_tables():
     return [User, Experiment, Trial, Algo]
 
+
 class SQLAlchemy(BaseStorageProtocol):  # noqa: F811
     """Implement a generic protocol to allow Orion to communicate using
     different storage backend
@@ -244,9 +249,7 @@ def __setstate__(self, state):
         self.engine = sqlalchemy.create_engine(self.uri, echo=True, future=True)
 
         if self.uri == "sqlite://" or self.uri == "":
-            log.warning(
-                "You are serializing an in-memory database, data will be lost"
-            )
+            log.warning("You are serializing an in-memory database, data will be lost")
             Base.metadata.create_all(self.engine)
 
         self._connect(self.token)
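The `_one_trial_hash_per_experiment` unique constraint above is what makes the `except DBAPIError` in the next hunk meaningful: a duplicate insert surfaces as `IntegrityError`, a `DBAPIError` subclass, which the storage maps to Orion's `DuplicateKeyError`. A standalone sketch of that behaviour; the `TrialRow` model and all names are illustrative, not part of the patch:

# Duplicate (experiment_id, id) pairs violate the unique constraint and
# raise at commit time; Orion would re-raise this as DuplicateKeyError.
import sqlalchemy
from sqlalchemy import Column, Integer, String, UniqueConstraint
from sqlalchemy.exc import DBAPIError
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class TrialRow(Base):
    __tablename__ = "trials"
    uid = Column(Integer, primary_key=True)
    experiment_id = Column(Integer)
    id = Column(String(30))
    __table_args__ = (
        UniqueConstraint("experiment_id", "id", name="_one_trial_hash_per_experiment"),
    )


engine = sqlalchemy.create_engine("sqlite://", future=True)
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(TrialRow(experiment_id=1, id="abc"))
    session.commit()
    try:
        session.add(TrialRow(experiment_id=1, id="abc"))  # same hash, same exp
        session.commit()
    except DBAPIError:
        print("duplicate trial")  # would become DuplicateKeyError in Orion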
config.get("algorithms", {}) - ) + self.initialize_algorithm_lock(config["_id"], config.get("algorithms", {})) except DBAPIError: raise DuplicateKeyError() @@ -427,9 +428,7 @@ def retrieve_result(self, trial, **kwargs): def get_trial(self, trial=None, uid=None, experiment_uid=None): """See :func:`orion.storage.base.BaseStorageProtocol.get_trial`""" - trial_uid, experiment_uid = get_trial_uid_and_exp( - trial, uid, experiment_uid - ) + trial_uid, experiment_uid = get_trial_uid_and_exp(trial, uid, experiment_uid) with Session(self.engine) as session: stmt = select(Trial).where( @@ -463,9 +462,7 @@ def update_trial( self, trial=None, uid=None, experiment_uid=None, where=None, **kwargs ): """See :func:`orion.storage.base.BaseStorageProtocol.update_trial`""" - trial_uid, experiment_uid = get_trial_uid_and_exp( - trial, uid, experiment_uid - ) + trial_uid, experiment_uid = get_trial_uid_and_exp(trial, uid, experiment_uid) where = self._get_query(where) From 766194859329cdc17e15941aa0ea9f03540f9f23 Mon Sep 17 00:00:00 2001 From: Pierre Delaunay Date: Tue, 22 Nov 2022 16:33:45 -0500 Subject: [PATCH 19/25] - --- src/orion/storage/sql.py | 5 ++++- tests/unittests/storage/test_storage.py | 14 ++++++++++++-- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/src/orion/storage/sql.py b/src/orion/storage/sql.py index cd028b318..d715d017f 100644 --- a/src/orion/storage/sql.py +++ b/src/orion/storage/sql.py @@ -1,8 +1,11 @@ IMPORT_ERROR = None try: from orion.storage.sql_impl import SQLAlchemy as SQLAlchemyImpl -except ImportError as err: + + HAS_SQLALCHEMY = True +except ModuleNotFoundError as err: IMPORT_ERROR = err + HAS_SQLALCHEMY = False if IMPORT_ERROR is not None: diff --git a/tests/unittests/storage/test_storage.py b/tests/unittests/storage/test_storage.py index fb7b9e4c4..8ff9ef617 100644 --- a/tests/unittests/storage/test_storage.py +++ b/tests/unittests/storage/test_storage.py @@ -22,6 +22,7 @@ setup_storage, ) from orion.storage.legacy import Legacy +from orion.storage.sql import HAS_SQLALCHEMY from orion.storage.track import HAS_TRACK, REASON from orion.testing import OrionState, base_experiment @@ -30,10 +31,19 @@ storage_backends = [ None, # defaults to legacy with PickleDB - dict(type="sqlalchemy", uri="sqlite:///${file}"), # Temporary file - dict(type="sqlalchemy", uri="sqlite://"), # In-memory ] +if not HAS_SQLALCHEMY: + log.warning("Track is not tested because: %s!", REASON) +else: + storage_backends.extend( + [ + dict(type="sqlalchemy", uri="sqlite:///${file}"), # Temporary file + dict(type="sqlalchemy", uri="sqlite://"), # In-memory + ] + ) + + if not HAS_TRACK: log.warning("Track is not tested because: %s!", REASON) else: From 7eade6cdd14829f555bf4eef19ca415f8de0933f Mon Sep 17 00:00:00 2001 From: Pierre Delaunay Date: Wed, 23 Nov 2022 11:25:20 -0500 Subject: [PATCH 20/25] - --- src/orion/client/runner.py | 998 +++++++++--------- src/orion/core/evc/experiment.py | 574 +++++----- src/orion/core/io/database/pickleddb.py | 620 +++++------ src/orion/core/io/experiment_builder.py | 6 +- src/orion/core/io/resolve_config.py | 8 +- src/orion/storage/sql_impl.py | 4 +- src/orion/testing/__init__.py | 4 +- tests/conftest.py | 28 +- .../commands/test_insert_command.py | 21 +- .../configuration/test_all_options.py | 10 + tests/unittests/core/conftest.py | 19 - tests/unittests/core/io/orion_config.yaml | 4 + .../core/io/test_experiment_builder.py | 4 +- .../unittests/core/io/test_resolve_config.py | 6 +- .../unittests/core/worker/test_experiment.py | 2 +- 
tests/unittests/core/worker/test_producer.py | 3 +- tests/unittests/storage/test_legacy.py | 5 +- tests/unittests/storage/test_storage.py | 2 +- 18 files changed, 1171 insertions(+), 1147 deletions(-) diff --git a/src/orion/client/runner.py b/src/orion/client/runner.py index f1eb27a3d..f25e6f498 100644 --- a/src/orion/client/runner.py +++ b/src/orion/client/runner.py @@ -1,499 +1,499 @@ -# pylint:disable=too-many-arguments -# pylint:disable=too-many-instance-attributes -""" -Runner -====== - -Executes the optimization process -""" -from __future__ import annotations - -import logging -import os -import shutil -import signal -import time -import typing -from contextlib import contextmanager -from dataclasses import dataclass -from typing import Callable - -import orion.core -from orion.core.utils import backward -from orion.core.utils.exceptions import ( - BrokenExperiment, - CompletedExperiment, - InvalidResult, - LazyWorkers, - ReservationRaceCondition, - WaitingForTrials, -) -from orion.core.utils.flatten import flatten, unflatten -from orion.core.worker.consumer import ExecutionError -from orion.core.worker.trial import AlreadyReleased -from orion.executor.base import AsyncException, AsyncResult -from orion.storage.base import LockAcquisitionTimeout - -if typing.TYPE_CHECKING: - from orion.client.experiment import ExperimentClient - from orion.core.worker.trial import Trial - -log = logging.getLogger(__name__) - - -class Protected: - """Prevent a signal to be raised during the execution of some code""" - - def __init__(self): - self.signal_received = None - self.handlers = {} - self.start = 0 - self.delayed = 0 - self.signal_installed = False - - def __enter__(self): - """Override the signal handlers with our delayed handler""" - self.signal_received = False - - try: - self.handlers[signal.SIGINT] = signal.signal(signal.SIGINT, self.handler) - self.handlers[signal.SIGTERM] = signal.signal(signal.SIGTERM, self.handler) - self.signal_installed = True - - except ValueError: # ValueError: signal only works in main thread - log.warning( - "SIGINT/SIGTERM protection hooks could not be installed because " - "Runner is executing inside a thread/subprocess, results could get lost " - "on interruptions" - ) - - return self - - def handler(self, sig, frame): - """Register the received signal for later""" - log.warning("Delaying signal %d to finish operations", sig) - log.warning( - "Press CTRL-C again to terminate the program now (You may lose results)" - ) - - self.start = time.time() - - self.signal_received = (sig, frame) - - # if CTRL-C is pressed again the original handlers will handle it - # and make the program stop - self.restore_handlers() - - def restore_handlers(self): - """Restore old signal handlers""" - if not self.signal_installed: - return - - signal.signal(signal.SIGINT, self.handlers[signal.SIGINT]) - signal.signal(signal.SIGTERM, self.handlers[signal.SIGTERM]) - - def stop_now(self): - """Raise the delayed signal if any or restore the old signal handlers""" - - if not self.signal_received: - self.restore_handlers() - - else: - self.delayed = time.time() - self.start - - log.warning("Termination was delayed by %.4f s", self.delayed) - handler = self.handlers[self.signal_received[0]] - - if callable(handler): - handler(*self.signal_received) - - def __exit__(self, *args): - self.stop_now() - - -def _optimize(trial, fct, trial_arg, **kwargs): - """Execute a trial on a worker""" - - kwargs.update(flatten(trial.params)) - - if trial_arg: - kwargs[trial_arg] = trial - - return 
fct(**unflatten(kwargs)) - - -def delayed_exception(exception: Exception): - """Raise exception when called...""" - raise exception - - -@dataclass -class _Stat: - sample: int = 0 - scatter: int = 0 - gather: int = 0 - - @contextmanager - def time(self, name): - """Measure elapsed time of a given block""" - start = time.time() - yield - total = time.time() - start - - value = getattr(self, name) - setattr(self, name, value + total) - - def report(self): - """Show the elapsed time of different blocks""" - lines = [ - f"Sample {self.sample:7.4f}", - f"Scatter {self.scatter:7.4f}", - f"Gather {self.gather:7.4f}", - ] - return "\n".join(lines) - - -def prepare_trial_working_dir( - experiment_client: ExperimentClient, trial: Trial -) -> None: - """Prepare working directory of a trial. - - This will create a working directory based on ``trial.working_dir`` if not already existing. If - the trial has a parent, the ``working_dir`` of the parent will be copied to the ``working_dir`` - of the current trial. - - Parameters - ---------- - experiment_client: orion.client.experiment.ExperimentClient - The experiment client being executed. - trial: orion.core.worker.trial.Trial - The trial that will be executed. - - Raises - ------ - ``ValueError`` - If the parent is not found in the storage of ``experiment_client``. - - """ - backward.ensure_trial_working_dir(experiment_client, trial) - - # TODO: Test that this works when resuming a trial. - if os.path.exists(trial.working_dir): - return - - if trial.parent: - parent_trial = experiment_client.get_trial(uid=trial.parent) - if parent_trial is None: - raise ValueError( - f"Parent id {trial.parent} not available in storage. (From trial {trial.id})" - ) - shutil.copytree(parent_trial.working_dir, trial.working_dir) - else: - os.makedirs(trial.working_dir) - - -class Runner: - """Run the optimization process given the current executor""" - - def __init__( - self, - client: ExperimentClient, - fct: Callable, - pool_size: int, - idle_timeout: int, - max_trials_per_worker: int, - max_broken: int, - trial_arg: str, - on_error: Callable[[ExperimentClient, Exception, int], bool] | None = None, - prepare_trial: Callable[ - [ExperimentClient, Trial], None - ] = prepare_trial_working_dir, - interrupt_signal_code: int | None = None, - gather_timeout: float = 0.01, - n_workers: int | None = None, - **kwargs, - ): - self.client = client - self.fct = fct - self.batch_size = pool_size - self.max_trials_per_worker = max_trials_per_worker - self.max_broken = max_broken - self.trial_arg = trial_arg - self.on_error = on_error - self.prepare_trial = prepare_trial - self.kwargs = kwargs - - self.gather_timeout = gather_timeout - self.idle_timeout = idle_timeout - - self.worker_broken_trials = 0 - self.trials = 0 - self.futures = [] - self.pending_trials = {} - self.stat = _Stat() - self.n_worker_override = n_workers - - if interrupt_signal_code is None: - interrupt_signal_code = orion.core.config.worker.interrupt_signal_code - - self.interrupt_signal_code = interrupt_signal_code - - @property - def free_worker(self): - """Returns the number of free worker""" - n_workers = self.client.executor.n_workers - - if self.n_worker_override is not None: - n_workers = self.n_worker_override - - return max(n_workers - len(self.pending_trials), 0) - - @property - def is_done(self): - """Returns true if the experiment has finished.""" - return self.client.is_done - - @property - def is_broken(self): - """Returns true if the experiment is broken""" - return self.worker_broken_trials >= 
self.max_broken - - @property - def has_remaining(self) -> bool: - """Returns true if the worker can still pick up work""" - return self.max_trials_per_worker - self.trials > 0 - - @property - def is_idle(self): - """Returns true if none of the workers are running a trial""" - return len(self.pending_trials) <= 0 - - @property - def is_running(self): - """Returns true if we are still running trials.""" - return len(self.pending_trials) > 0 or (self.has_remaining and not self.is_done) - - def run(self): - """Run the optimizing process until completion. - - Returns - ------- - the total number of trials processed - - """ - idle_start = time.time() - idle_end = 0 - idle_time = 0 - - while self.is_running: - try: - - # Protected will prevent Keyboard interrupts from - # happening in the middle of the scatter-gather process - # that we can be sure that completed trials are observed - with Protected(): - - # Get new trials for our free workers - with self.stat.time("sample"): - new_trials = self.sample() - - # Scatter the new trials to our free workers - with self.stat.time("scatter"): - scattered = self.scatter(new_trials) - - # Gather the results of the workers that have finished - with self.stat.time("gather"): - gathered = self.gather() - - if scattered == 0 and gathered == 0 and self.is_idle: - idle_end = time.time() - idle_time += idle_end - idle_start - idle_start = idle_end - - log.debug(f"Workers have been idle for {idle_time:.2f} s") - else: - idle_start = time.time() - idle_time = 0 - - if self.is_idle and idle_time > self.idle_timeout: - msg = f"Workers have been idle for {idle_time:.2f} s" - - if self.has_remaining and not self.is_done: - msg = ( - f"{msg}; worker has leg room (has_remaining: {self.has_remaining})" - f" and optimization is not done (is_done: {self.is_done})" - ) - - raise LazyWorkers(msg) - - except KeyboardInterrupt: - self._release_all() - raise - except: - self._release_all() - raise - - return self.trials - - def should_sample(self): - """Check if more trials could be generated""" - - if self.free_worker <= 0 or (self.is_broken or self.is_done): - return 0 - - pending = len(self.pending_trials) + self.trials - remains = self.max_trials_per_worker - pending - - n_trial = min(self.free_worker, remains) - should_sample_more = self.free_worker > 0 and remains > 0 - - return int(should_sample_more) * n_trial - - def sample(self): - """Sample new trials for all free workers""" - n_trial = self.should_sample() - - if n_trial > 0: - # the producer does the job of limiting the number of new trials - # already no need to worry about it - # NB: suggest reserve the trial already - new_trials = self._suggest_trials(n_trial) - log.debug(f"Sampled {len(new_trials)} new configs") - return new_trials - - return [] - - def scatter(self, new_trials): - """Schedule new trials to be computed""" - new_futures = [] - for trial in new_trials: - try: - self.prepare_trial(self.client, trial) - prepared = True - # pylint:disable=broad-except - except Exception as e: - future = self.client.executor.submit(delayed_exception, e) - prepared = False - - if prepared: - future = self.client.executor.submit( - _optimize, trial, self.fct, self.trial_arg, **self.kwargs - ) - - self.pending_trials[future] = trial - new_futures.append(future) - - self.futures.extend(new_futures) - if new_futures: - log.debug("Scheduled new trials") - return len(new_futures) - - def gather(self): - """Gather the results from each worker asynchronously""" - results = self.client.executor.async_get( - self.futures, 
timeout=self.gather_timeout - ) - - to_be_raised = None - if results: - log.debug(f"Gathered new results {len(results)}") - # register the results - # NOTE: For Ptera instrumentation - trials = 0 # pylint:disable=unused-variable - for result in results: - trial = self.pending_trials.pop(result.future) - - if isinstance(result, AsyncResult): - try: - # NB: observe release the trial already - self.client.observe(trial, result.value) - self.trials += 1 - # NOTE: For Ptera instrumentation - trials = self.trials # pylint:disable=unused-variable - except InvalidResult as exception: - # stop the optimization process if we received `InvalidResult` - # as all the trials are assumed to be returning those - to_be_raised = exception - self.client.release(trial, status="broken") - - if isinstance(result, AsyncException): - if ( - isinstance(result.exception, ExecutionError) - and result.exception.return_code == self.interrupt_signal_code - ): - to_be_raised = KeyboardInterrupt() - self.client.release(trial, status="interrupted") - continue - - # Regular exception, might be caused by the chosen hyperparameters - # themselves rather than the code in particular (like Out of Memory error - # for big batch sizes) - exception = result.exception - self.worker_broken_trials += 1 - self.client.release(trial, status="broken") - - if self.on_error is None or self.on_error( - self, trial, exception, self.worker_broken_trials - ): - log.error(result.traceback) - - else: - log.error(str(exception)) - log.debug(result.traceback) - - # if we receive too many broken trials, it might indicate the user script - # is broken, stop the experiment and let the user investigate - if self.is_broken: - to_be_raised = BrokenExperiment( - "Worker has reached broken trials threshold" - ) - - if to_be_raised is not None: - log.debug("Runner was interrupted") - self._release_all() - raise to_be_raised - - return len(results) - - def _release_all(self): - """Release all the trials that were reserved by this runner. 
- This is only called during exception handling to avoid retaining trials - that cannot be retrieved anymore - - """ - # Sanity check - for _, trial in self.pending_trials.items(): - try: - self.client.release(trial, status="interrupted") - except AlreadyReleased: - pass - - self.pending_trials = {} - - def _suggest_trials(self, count): - """Suggest a bunch of trials to be dispatched to the workers""" - trials = [] - for _ in range(count): - try: - batch_size = count if self.batch_size == 0 else self.batch_size - trial = self.client.suggest(pool_size=batch_size) - trials.append(trial) - - # non critical errors - except WaitingForTrials: - log.debug("Runner cannot sample because WaitingForTrials") - break - - except ReservationRaceCondition: - log.debug("Runner cannot sample because ReservationRaceCondition") - break - - except LockAcquisitionTimeout: - log.debug("Runner cannot sample because LockAcquisitionTimeout") - break - - except CompletedExperiment: - log.debug("Runner cannot sample because CompletedExperiment") - break - - return trials +# pylint:disable=too-many-arguments +# pylint:disable=too-many-instance-attributes +""" +Runner +====== + +Executes the optimization process +""" +from __future__ import annotations + +import logging +import os +import shutil +import signal +import time +import typing +from contextlib import contextmanager +from dataclasses import dataclass +from typing import Callable + +import orion.core +from orion.core.utils import backward +from orion.core.utils.exceptions import ( + BrokenExperiment, + CompletedExperiment, + InvalidResult, + LazyWorkers, + ReservationRaceCondition, + WaitingForTrials, +) +from orion.core.utils.flatten import flatten, unflatten +from orion.core.worker.consumer import ExecutionError +from orion.core.worker.trial import AlreadyReleased +from orion.executor.base import AsyncException, AsyncResult +from orion.storage.base import LockAcquisitionTimeout + +if typing.TYPE_CHECKING: + from orion.client.experiment import ExperimentClient + from orion.core.worker.trial import Trial + +log = logging.getLogger(__name__) + + +class Protected: + """Prevent a signal to be raised during the execution of some code""" + + def __init__(self): + self.signal_received = None + self.handlers = {} + self.start = 0 + self.delayed = 0 + self.signal_installed = False + + def __enter__(self): + """Override the signal handlers with our delayed handler""" + self.signal_received = False + + try: + self.handlers[signal.SIGINT] = signal.signal(signal.SIGINT, self.handler) + self.handlers[signal.SIGTERM] = signal.signal(signal.SIGTERM, self.handler) + self.signal_installed = True + + except ValueError: # ValueError: signal only works in main thread + log.warning( + "SIGINT/SIGTERM protection hooks could not be installed because " + "Runner is executing inside a thread/subprocess, results could get lost " + "on interruptions" + ) + + return self + + def handler(self, sig, frame): + """Register the received signal for later""" + log.warning("Delaying signal %d to finish operations", sig) + log.warning( + "Press CTRL-C again to terminate the program now (You may lose results)" + ) + + self.start = time.time() + + self.signal_received = (sig, frame) + + # if CTRL-C is pressed again the original handlers will handle it + # and make the program stop + self.restore_handlers() + + def restore_handlers(self): + """Restore old signal handlers""" + if not self.signal_installed: + return + + signal.signal(signal.SIGINT, self.handlers[signal.SIGINT]) + 
signal.signal(signal.SIGTERM, self.handlers[signal.SIGTERM]) + + def stop_now(self): + """Raise the delayed signal if any or restore the old signal handlers""" + + if not self.signal_received: + self.restore_handlers() + + else: + self.delayed = time.time() - self.start + + log.warning("Termination was delayed by %.4f s", self.delayed) + handler = self.handlers[self.signal_received[0]] + + if callable(handler): + handler(*self.signal_received) + + def __exit__(self, *args): + self.stop_now() + + +def _optimize(trial, fct, trial_arg, **kwargs): + """Execute a trial on a worker""" + + kwargs.update(flatten(trial.params)) + + if trial_arg: + kwargs[trial_arg] = trial + + return fct(**unflatten(kwargs)) + + +def delayed_exception(exception: Exception): + """Raise exception when called...""" + raise exception + + +@dataclass +class _Stat: + sample: int = 0 + scatter: int = 0 + gather: int = 0 + + @contextmanager + def time(self, name): + """Measure elapsed time of a given block""" + start = time.time() + yield + total = time.time() - start + + value = getattr(self, name) + setattr(self, name, value + total) + + def report(self): + """Show the elapsed time of different blocks""" + lines = [ + f"Sample {self.sample:7.4f}", + f"Scatter {self.scatter:7.4f}", + f"Gather {self.gather:7.4f}", + ] + return "\n".join(lines) + + +def prepare_trial_working_dir( + experiment_client: ExperimentClient, trial: Trial +) -> None: + """Prepare working directory of a trial. + + This will create a working directory based on ``trial.working_dir`` if not already existing. If + the trial has a parent, the ``working_dir`` of the parent will be copied to the ``working_dir`` + of the current trial. + + Parameters + ---------- + experiment_client: orion.client.experiment.ExperimentClient + The experiment client being executed. + trial: orion.core.worker.trial.Trial + The trial that will be executed. + + Raises + ------ + ``ValueError`` + If the parent is not found in the storage of ``experiment_client``. + + """ + backward.ensure_trial_working_dir(experiment_client, trial) + + # TODO: Test that this works when resuming a trial. + if os.path.exists(trial.working_dir): + return + + if trial.parent: + parent_trial = experiment_client.get_trial(uid=trial.parent) + if parent_trial is None: + raise ValueError( + f"Parent id {trial.parent} not available in storage. 
(From trial {trial.id})" + ) + shutil.copytree(parent_trial.working_dir, trial.working_dir) + else: + os.makedirs(trial.working_dir) + + +class Runner: + """Run the optimization process given the current executor""" + + def __init__( + self, + client: ExperimentClient, + fct: Callable, + pool_size: int, + idle_timeout: int, + max_trials_per_worker: int, + max_broken: int, + trial_arg: str, + on_error: Callable[[ExperimentClient, Exception, int], bool] | None = None, + prepare_trial: Callable[ + [ExperimentClient, Trial], None + ] = prepare_trial_working_dir, + interrupt_signal_code: int | None = None, + gather_timeout: float = 0.01, + n_workers: int | None = None, + **kwargs, + ): + self.client = client + self.fct = fct + self.batch_size = pool_size + self.max_trials_per_worker = max_trials_per_worker + self.max_broken = max_broken + self.trial_arg = trial_arg + self.on_error = on_error + self.prepare_trial = prepare_trial + self.kwargs = kwargs + + self.gather_timeout = gather_timeout + self.idle_timeout = idle_timeout + + self.worker_broken_trials = 0 + self.trials = 0 + self.futures = [] + self.pending_trials = {} + self.stat = _Stat() + self.n_worker_override = n_workers + + if interrupt_signal_code is None: + interrupt_signal_code = orion.core.config.worker.interrupt_signal_code + + self.interrupt_signal_code = interrupt_signal_code + + @property + def free_worker(self): + """Returns the number of free worker""" + n_workers = self.client.executor.n_workers + + if self.n_worker_override is not None: + n_workers = self.n_worker_override + + return max(n_workers - len(self.pending_trials), 0) + + @property + def is_done(self): + """Returns true if the experiment has finished.""" + return self.client.is_done + + @property + def is_broken(self): + """Returns true if the experiment is broken""" + return self.worker_broken_trials >= self.max_broken + + @property + def has_remaining(self) -> bool: + """Returns true if the worker can still pick up work""" + return self.max_trials_per_worker - self.trials > 0 + + @property + def is_idle(self): + """Returns true if none of the workers are running a trial""" + return len(self.pending_trials) <= 0 + + @property + def is_running(self): + """Returns true if we are still running trials.""" + return len(self.pending_trials) > 0 or (self.has_remaining and not self.is_done) + + def run(self): + """Run the optimizing process until completion. 
+ + Returns + ------- + the total number of trials processed + + """ + idle_start = time.time() + idle_end = 0 + idle_time = 0 + + while self.is_running: + try: + + # Protected will prevent Keyboard interrupts from + # happening in the middle of the scatter-gather process + # that we can be sure that completed trials are observed + with Protected(): + + # Get new trials for our free workers + with self.stat.time("sample"): + new_trials = self.sample() + + # Scatter the new trials to our free workers + with self.stat.time("scatter"): + scattered = self.scatter(new_trials) + + # Gather the results of the workers that have finished + with self.stat.time("gather"): + gathered = self.gather() + + if scattered == 0 and gathered == 0 and self.is_idle: + idle_end = time.time() + idle_time += idle_end - idle_start + idle_start = idle_end + + log.debug(f"Workers have been idle for {idle_time:.2f} s") + else: + idle_start = time.time() + idle_time = 0 + + if self.is_idle and idle_time > self.idle_timeout: + msg = f"Workers have been idle for {idle_time:.2f} s" + + if self.has_remaining and not self.is_done: + msg = ( + f"{msg}; worker has leg room (has_remaining: {self.has_remaining})" + f" and optimization is not done (is_done: {self.is_done})" + ) + + raise LazyWorkers(msg) + + except KeyboardInterrupt: + self._release_all() + raise + except: + self._release_all() + raise + + return self.trials + + def should_sample(self): + """Check if more trials could be generated""" + + if self.free_worker <= 0 or (self.is_broken or self.is_done): + return 0 + + pending = len(self.pending_trials) + self.trials + remains = self.max_trials_per_worker - pending + + n_trial = min(self.free_worker, remains) + should_sample_more = self.free_worker > 0 and remains > 0 + + return int(should_sample_more) * n_trial + + def sample(self): + """Sample new trials for all free workers""" + n_trial = self.should_sample() + + if n_trial > 0: + # the producer does the job of limiting the number of new trials + # already no need to worry about it + # NB: suggest reserve the trial already + new_trials = self._suggest_trials(n_trial) + log.debug(f"Sampled {len(new_trials)} new configs") + return new_trials + + return [] + + def scatter(self, new_trials): + """Schedule new trials to be computed""" + new_futures = [] + for trial in new_trials: + try: + self.prepare_trial(self.client, trial) + prepared = True + # pylint:disable=broad-except + except Exception as e: + future = self.client.executor.submit(delayed_exception, e) + prepared = False + + if prepared: + future = self.client.executor.submit( + _optimize, trial, self.fct, self.trial_arg, **self.kwargs + ) + + self.pending_trials[future] = trial + new_futures.append(future) + + self.futures.extend(new_futures) + if new_futures: + log.debug("Scheduled new trials") + return len(new_futures) + + def gather(self): + """Gather the results from each worker asynchronously""" + results = self.client.executor.async_get( + self.futures, timeout=self.gather_timeout + ) + + to_be_raised = None + if results: + log.debug(f"Gathered new results {len(results)}") + # register the results + # NOTE: For Ptera instrumentation + trials = 0 # pylint:disable=unused-variable + for result in results: + trial = self.pending_trials.pop(result.future) + + if isinstance(result, AsyncResult): + try: + # NB: observe release the trial already + self.client.observe(trial, result.value) + self.trials += 1 + # NOTE: For Ptera instrumentation + trials = self.trials # pylint:disable=unused-variable + except 
InvalidResult as exception: + # stop the optimization process if we received `InvalidResult` + # as all the trials are assumed to be returning those + to_be_raised = exception + self.client.release(trial, status="broken") + + if isinstance(result, AsyncException): + if ( + isinstance(result.exception, ExecutionError) + and result.exception.return_code == self.interrupt_signal_code + ): + to_be_raised = KeyboardInterrupt() + self.client.release(trial, status="interrupted") + continue + + # Regular exception, might be caused by the chosen hyperparameters + # themselves rather than the code in particular (like Out of Memory error + # for big batch sizes) + exception = result.exception + self.worker_broken_trials += 1 + self.client.release(trial, status="broken") + + if self.on_error is None or self.on_error( + self, trial, exception, self.worker_broken_trials + ): + log.error(result.traceback) + + else: + log.error(str(exception)) + log.debug(result.traceback) + + # if we receive too many broken trials, it might indicate the user script + # is broken, stop the experiment and let the user investigate + if self.is_broken: + to_be_raised = BrokenExperiment( + "Worker has reached broken trials threshold" + ) + + if to_be_raised is not None: + log.debug("Runner was interrupted") + self._release_all() + raise to_be_raised + + return len(results) + + def _release_all(self): + """Release all the trials that were reserved by this runner. + This is only called during exception handling to avoid retaining trials + that cannot be retrieved anymore + + """ + # Sanity check + for _, trial in self.pending_trials.items(): + try: + self.client.release(trial, status="interrupted") + except AlreadyReleased: + pass + + self.pending_trials = {} + + def _suggest_trials(self, count): + """Suggest a bunch of trials to be dispatched to the workers""" + trials = [] + for _ in range(count): + try: + batch_size = count if self.batch_size == 0 else self.batch_size + trial = self.client.suggest(pool_size=batch_size) + trials.append(trial) + + # non critical errors + except WaitingForTrials: + log.debug("Runner cannot sample because WaitingForTrials") + break + + except ReservationRaceCondition: + log.debug("Runner cannot sample because ReservationRaceCondition") + break + + except LockAcquisitionTimeout: + log.debug("Runner cannot sample because LockAcquisitionTimeout") + break + + except CompletedExperiment: + log.debug("Runner cannot sample because CompletedExperiment") + break + + return trials diff --git a/src/orion/core/evc/experiment.py b/src/orion/core/evc/experiment.py index 2245e3502..a1e54255f 100644 --- a/src/orion/core/evc/experiment.py +++ b/src/orion/core/evc/experiment.py @@ -1,287 +1,287 @@ -# pylint:disable=protected-access -""" -Experiment node for EVC -======================= - -Experiment nodes connecting experiments to the EVC tree - -The experiments are connected to one another through the experiment nodes. The former can be created -standalone without an EVC tree. When connected to an `ExperimentNode`, the experiments gain access -to trials of other experiments by using method `ExperimentNode.fetch_trials`. - -Helper functions are provided to fetch trials keeping the tree structure. Those can be helpful when -analyzing an EVC tree. - -""" -import functools -import logging - -from orion.core.utils.tree import TreeNode - -log = logging.getLogger(__name__) - - -class ExperimentNode(TreeNode): - """Experiment node to connect experiments to EVC tree. 
- - The node carries an experiment in attribute `item`. The node can be instantiated only using the - name of the experiment. The experiment will be created lazily on access to `node.item`. - - Attributes - ---------- - name: str - Name of the experiment - item: None or :class:`orion.core.worker.experiment.Experiment` - None if the experiment is not initialized yet. When initializing lazily, it creates an - `Experiment` in read only mode. - - .. seealso:: - - :py:class:`orion.core.utils.tree.TreeNode` for tree-specific attributes and methods. - - """ - - __slots__ = ( - "name", - "version", - "_no_parent_lookup", - "_no_children_lookup", - "storage", - ) + TreeNode.__slots__ - - def __init__( - self, - name, - version, - experiment=None, - parent=None, - children=tuple(), - storage=None, - ): - """Initialize experiment node with item, experiment, parent and children - - .. seealso:: - :class:`orion.core.utils.tree.TreeNode` for information about the attributes - """ - super().__init__(experiment, parent, children) - self.name = name - self.version = version - - self._no_parent_lookup = True - self._no_children_lookup = True - self.storage = storage or experiment._storage - - @property - def item(self): - """Get the experiment associated to the node - - Note that accessing `item` may trigger the lazy initialization of the experiment if it was - not done already. - """ - if self._item is None: - # TODO: Find another way around the circular import - from orion.core.io import experiment_builder - - self._item = experiment_builder.load( - name=self.name, version=self.version, storage=self.storage - ) - self._item._node = self - - return self._item - - @property - def parent(self): - """Get parent of the experiment, None if no parent - - .. note:: - - The instantiation of an EVC tree is lazy, which means accessing the parent of a node - may trigger a call to database to build this parent live. - - """ - if self._parent is None and self._no_parent_lookup: - self._no_parent_lookup = False - query = {"_id": self.item.refers.get("parent_id")} - selection = {"name": 1, "version": 1} - experiments = self.storage.fetch_experiments(query, selection) - - if experiments: - parent = experiments[0] - exp_node = ExperimentNode( - name=parent["name"], - version=parent.get("version", 1), - storage=self.storage, - ) - self.set_parent(exp_node) - return self._parent - - @property - def children(self): - """Get children of the experiment, empty list if no children - - .. note:: - - The instantiation of an EVC tree is lazy, which means accessing the children of a node - may trigger a call to database to build those children live. 
- - """ - if self._no_children_lookup: - self._children = [] - self._no_children_lookup = False - query = {"refers.parent_id": self.item.id} - selection = {"name": 1, "version": 1} - experiments = self.storage.fetch_experiments(query, selection) - for child in experiments: - self.add_children( - ExperimentNode( - name=child["name"], - version=child.get("version", 1), - storage=self.storage, - ) - ) - - return self._children - - @property - def adapter(self): - """Get the adapter of the experiment with respect to its parent""" - return self.item.refers["adapter"] - - @property - def tree_name(self): - """Return a formatted name of the Node for a tree pretty-print.""" - if self.item is not None: - return f"{self.name}-v{self.item.version}" - - return self.name - - def fetch_lost_trials(self): - """See :meth:`orion.core.evc.experiment.ExperimentNode.recurvise_fetch`""" - return self.recurvise_fetch("fetch_lost_trials") - - def fetch_trials(self): - """See :meth:`orion.core.evc.experiment.ExperimentNode.recurvise_fetch`""" - return self.recurvise_fetch("fetch_trials") - - def fetch_pending_trials(self): - """See :meth:`orion.core.evc.experiment.ExperimentNode.recurvise_fetch`""" - return self.recurvise_fetch("fetch_pending_trials") - - def fetch_noncompleted_trials(self): - """See :meth:`orion.core.evc.experiment.ExperimentNode.recurvise_fetch`""" - return self.recurvise_fetch("fetch_noncompleted_trials") - - def fetch_trials_by_status(self, status): - """See :meth:`orion.core.evc.experiment.ExperimentNode.recurvise_fetch`""" - return self.recurvise_fetch("fetch_trials_by_status", status=status) - - def recurvise_fetch(self, fun_name, *args, **kwargs): - """Fetch trials recursively in the EVC tree using the fetch function `fun_name`. - - Parameters - ---------- - fun_name: callable - Function name to call to fetch trials. The function must be an attribute of - :class:`orion.core.worker.experiment.Experiment` - - *args: - Positional arguments to pass to `fun_name`. - - **kwargs - Keyword arguments to pass to `fun_name`. - - """ - - def retrieve_trials(node, parent_or_children): - """Retrieve the trials of a node/experiment.""" - fun = getattr(node.item, fun_name) - # with_evc_tree needs to be False here or we will have an infinite loop - trials = fun(*args, with_evc_tree=False, **kwargs) - return dict(trials=trials, experiment=node.item), parent_or_children - - # get the trials of the parents - parent_trials = None - if self.parent is not None: - parent_trials = self.parent.map(retrieve_trials, self.parent.parent) - - # get the trials of the children - children_trials = self.map(retrieve_trials, self.children) - children_trials.set_parent(parent_trials) - - adapt_trials(children_trials) - - return sum((node.item["trials"] for node in children_trials.root), []) - - -def _adapt_parent_trials(node, parent_trials_node, ids): - """Adapt trials from the parent recursively - - .. note:: - - To call with node.map(fct, node.parent) to connect with parents - - """ - # Ids from children are passed to prioritized them if they are also present in parent nodes. 
- node_ids = { - trial.compute_trial_hash(trial, ignore_lie=True) - for trial in node.item["trials"] - } | ids - if parent_trials_node is not None: - adapter = node.item["experiment"].refers["adapter"] - for parent in parent_trials_node.root: - parent.item["trials"] = adapter.forward(parent.item["trials"]) - - # if trial is in current exp, filter out - parent.item["trials"] = [ - trial - for trial in parent.item["trials"] - if trial.compute_trial_hash( - trial, ignore_lie=True, ignore_experiment=True - ) - not in node_ids - ] - - return node.item, parent_trials_node - - -def _adapt_children_trials(node, children_trials_nodes): - """Adapt trials from the children recursively - - .. note:: - - To call with node.map(fct, node.children) to connect with children - - """ - ids = { - trial.compute_trial_hash(trial, ignore_lie=True) - for trial in node.item["trials"] - } - - for child in children_trials_nodes: - adapter = child.item["experiment"].refers["adapter"] - for subchild in child: # Includes child itself - subchild.item["trials"] = adapter.backward(subchild.item["trials"]) - - # if trial is in current node, filter out - subchild.item["trials"] = [ - trial - for trial in subchild.item["trials"] - if trial.compute_trial_hash( - trial, ignore_lie=True, ignore_experiment=True - ) - not in ids - ] - - return node.item, children_trials_nodes - - -def adapt_trials(trials_tree): - """Adapt trials recursively so that they are all compatible with current experiment.""" - trials_tree.map(_adapt_children_trials, trials_tree.children) - ids = set() - for child in trials_tree.children: - for trial in child.item["trials"]: - ids.add(trial.compute_trial_hash(trial, ignore_lie=True)) - trials_tree.map( - functools.partial(_adapt_parent_trials, ids=ids), trials_tree.parent - ) +# pylint:disable=protected-access +""" +Experiment node for EVC +======================= + +Experiment nodes connecting experiments to the EVC tree + +The experiments are connected to one another through the experiment nodes. The former can be created +standalone without an EVC tree. When connected to an `ExperimentNode`, the experiments gain access +to trials of other experiments by using method `ExperimentNode.fetch_trials`. + +Helper functions are provided to fetch trials keeping the tree structure. Those can be helpful when +analyzing an EVC tree. + +""" +import functools +import logging + +from orion.core.utils.tree import TreeNode + +log = logging.getLogger(__name__) + + +class ExperimentNode(TreeNode): + """Experiment node to connect experiments to EVC tree. + + The node carries an experiment in attribute `item`. The node can be instantiated only using the + name of the experiment. The experiment will be created lazily on access to `node.item`. + + Attributes + ---------- + name: str + Name of the experiment + item: None or :class:`orion.core.worker.experiment.Experiment` + None if the experiment is not initialized yet. When initializing lazily, it creates an + `Experiment` in read only mode. + + .. seealso:: + + :py:class:`orion.core.utils.tree.TreeNode` for tree-specific attributes and methods. + + """ + + __slots__ = ( + "name", + "version", + "_no_parent_lookup", + "_no_children_lookup", + "storage", + ) + TreeNode.__slots__ + + def __init__( + self, + name, + version, + experiment=None, + parent=None, + children=tuple(), + storage=None, + ): + """Initialize experiment node with item, experiment, parent and children + + .. 
seealso:: + :class:`orion.core.utils.tree.TreeNode` for information about the attributes + """ + super().__init__(experiment, parent, children) + self.name = name + self.version = version + + self._no_parent_lookup = True + self._no_children_lookup = True + self.storage = storage or experiment._storage + + @property + def item(self): + """Get the experiment associated to the node + + Note that accessing `item` may trigger the lazy initialization of the experiment if it was + not done already. + """ + if self._item is None: + # TODO: Find another way around the circular import + from orion.core.io import experiment_builder + + self._item = experiment_builder.load( + name=self.name, version=self.version, storage=self.storage + ) + self._item._node = self + + return self._item + + @property + def parent(self): + """Get parent of the experiment, None if no parent + + .. note:: + + The instantiation of an EVC tree is lazy, which means accessing the parent of a node + may trigger a call to database to build this parent live. + + """ + if self._parent is None and self._no_parent_lookup: + self._no_parent_lookup = False + query = {"_id": self.item.refers.get("parent_id")} + selection = {"name": 1, "version": 1} + experiments = self.storage.fetch_experiments(query, selection) + + if experiments: + parent = experiments[0] + exp_node = ExperimentNode( + name=parent["name"], + version=parent.get("version", 1), + storage=self.storage, + ) + self.set_parent(exp_node) + return self._parent + + @property + def children(self): + """Get children of the experiment, empty list if no children + + .. note:: + + The instantiation of an EVC tree is lazy, which means accessing the children of a node + may trigger a call to database to build those children live. + + """ + if self._no_children_lookup: + self._children = [] + self._no_children_lookup = False + query = {"refers.parent_id": self.item.id} + selection = {"name": 1, "version": 1} + experiments = self.storage.fetch_experiments(query, selection) + for child in experiments: + self.add_children( + ExperimentNode( + name=child["name"], + version=child.get("version", 1), + storage=self.storage, + ) + ) + + return self._children + + @property + def adapter(self): + """Get the adapter of the experiment with respect to its parent""" + return self.item.refers["adapter"] + + @property + def tree_name(self): + """Return a formatted name of the Node for a tree pretty-print.""" + if self.item is not None: + return f"{self.name}-v{self.item.version}" + + return self.name + + def fetch_lost_trials(self): + """See :meth:`orion.core.evc.experiment.ExperimentNode.recurvise_fetch`""" + return self.recurvise_fetch("fetch_lost_trials") + + def fetch_trials(self): + """See :meth:`orion.core.evc.experiment.ExperimentNode.recurvise_fetch`""" + return self.recurvise_fetch("fetch_trials") + + def fetch_pending_trials(self): + """See :meth:`orion.core.evc.experiment.ExperimentNode.recurvise_fetch`""" + return self.recurvise_fetch("fetch_pending_trials") + + def fetch_noncompleted_trials(self): + """See :meth:`orion.core.evc.experiment.ExperimentNode.recurvise_fetch`""" + return self.recurvise_fetch("fetch_noncompleted_trials") + + def fetch_trials_by_status(self, status): + """See :meth:`orion.core.evc.experiment.ExperimentNode.recurvise_fetch`""" + return self.recurvise_fetch("fetch_trials_by_status", status=status) + + def recurvise_fetch(self, fun_name, *args, **kwargs): + """Fetch trials recursively in the EVC tree using the fetch function `fun_name`. 
+ + Parameters + ---------- + fun_name: callable + Function name to call to fetch trials. The function must be an attribute of + :class:`orion.core.worker.experiment.Experiment` + + *args: + Positional arguments to pass to `fun_name`. + + **kwargs + Keyword arguments to pass to `fun_name`. + + """ + + def retrieve_trials(node, parent_or_children): + """Retrieve the trials of a node/experiment.""" + fun = getattr(node.item, fun_name) + # with_evc_tree needs to be False here or we will have an infinite loop + trials = fun(*args, with_evc_tree=False, **kwargs) + return dict(trials=trials, experiment=node.item), parent_or_children + + # get the trials of the parents + parent_trials = None + if self.parent is not None: + parent_trials = self.parent.map(retrieve_trials, self.parent.parent) + + # get the trials of the children + children_trials = self.map(retrieve_trials, self.children) + children_trials.set_parent(parent_trials) + + adapt_trials(children_trials) + + return sum((node.item["trials"] for node in children_trials.root), []) + + +def _adapt_parent_trials(node, parent_trials_node, ids): + """Adapt trials from the parent recursively + + .. note:: + + To call with node.map(fct, node.parent) to connect with parents + + """ + # Ids from children are passed to prioritized them if they are also present in parent nodes. + node_ids = { + trial.compute_trial_hash(trial, ignore_lie=True) + for trial in node.item["trials"] + } | ids + if parent_trials_node is not None: + adapter = node.item["experiment"].refers["adapter"] + for parent in parent_trials_node.root: + parent.item["trials"] = adapter.forward(parent.item["trials"]) + + # if trial is in current exp, filter out + parent.item["trials"] = [ + trial + for trial in parent.item["trials"] + if trial.compute_trial_hash( + trial, ignore_lie=True, ignore_experiment=True + ) + not in node_ids + ] + + return node.item, parent_trials_node + + +def _adapt_children_trials(node, children_trials_nodes): + """Adapt trials from the children recursively + + .. note:: + + To call with node.map(fct, node.children) to connect with children + + """ + ids = { + trial.compute_trial_hash(trial, ignore_lie=True) + for trial in node.item["trials"] + } + + for child in children_trials_nodes: + adapter = child.item["experiment"].refers["adapter"] + for subchild in child: # Includes child itself + subchild.item["trials"] = adapter.backward(subchild.item["trials"]) + + # if trial is in current node, filter out + subchild.item["trials"] = [ + trial + for trial in subchild.item["trials"] + if trial.compute_trial_hash( + trial, ignore_lie=True, ignore_experiment=True + ) + not in ids + ] + + return node.item, children_trials_nodes + + +def adapt_trials(trials_tree): + """Adapt trials recursively so that they are all compatible with current experiment.""" + trials_tree.map(_adapt_children_trials, trials_tree.children) + ids = set() + for child in trials_tree.children: + for trial in child.item["trials"]: + ids.add(trial.compute_trial_hash(trial, ignore_lie=True)) + trials_tree.map( + functools.partial(_adapt_parent_trials, ids=ids), trials_tree.parent + ) diff --git a/src/orion/core/io/database/pickleddb.py b/src/orion/core/io/database/pickleddb.py index f882c18a9..66f0feab0 100644 --- a/src/orion/core/io/database/pickleddb.py +++ b/src/orion/core/io/database/pickleddb.py @@ -1,310 +1,310 @@ -""" -Pickled Database -================ - -Implement permanent version of :class:`orion.core.io.database.ephemeraldb.EphemeralDB`. 
- -""" - -import logging -import os -import pickle -from contextlib import contextmanager -from pickle import PicklingError - -import psutil -from filelock import FileLock, SoftFileLock, Timeout - -import orion.core -from orion.core.io.database import Database, DatabaseTimeout -from orion.core.io.database.ephemeraldb import EphemeralDB -from orion.core.utils.compat import replace - -log = logging.getLogger(__name__) - -DEFAULT_HOST = os.path.join(orion.core.DIRS.user_data_dir, "orion", "orion_db.pkl") - -TIMEOUT_ERROR_MESSAGE = """\ -Could not acquire lock for PickledDB after {} seconds. - -This is likely due to one or many of the following scenarios: - -1. There is a large amount of workers and many simultaneous queries. This typically occurs - when the task to optimize is short (few minutes). Try to reduce the amount of workers - at least below 50. - -2. The database is growing large with thousands of trials and many experiments. - If so, you can use a different PickleDB (different file, that is, different `host`) - for each experiment separately to alleviate this issue. - -3. The filesystem is slow. Parallel filesystems on HPC often suffer from - large pool of users generating frequent I/O. In this case try using a separate - partition that may be less affected. - -If you cannot solve the issues listed above that are causing timeouts, you -may need to setup the MongoDB backend for better performance. -See https://orion.readthedocs.io/en/stable/install/database.html -""" - - -def find_unpickable_doc(dict_of_dict): - """Look for a dictionary that cannot be pickled.""" - for name, collection in dict_of_dict.items(): - documents = collection.find() - - for doc in documents: - try: - pickle.dumps(doc) - - except (PicklingError, AttributeError): - return name, doc - - return None, None - - -def find_unpickable_field(doc): - """Look for a field in a dictionary that cannot be pickled""" - if not isinstance(doc, dict): - doc = doc.to_dict() - - for k, v in doc.items(): - try: - pickle.dumps(v) - - except (PicklingError, AttributeError): - return k, v - - return None, None - - -# pylint: disable=too-many-public-methods -class PickledDB(Database): - """Pickled EphemeralDB to support permanancy and concurrency - - This is a very simple and inefficient implementation of a permanent database on disk for Oríon. - The data is loaded from disk for every operation, and every operation is protected with a - filelock. - - Parameters - ---------- - host: str - File path to save pickled ephemeraldb. Default is {user data dir}/orion/orion_db.pkl ex: - $HOME/.local/share/orion/orion_db.pkl - timeout: int - Maximum number of seconds to wait for the lock before raising DatabaseTimeout. - Default is 60. - - """ - - # pylint: disable=unused-argument - def __init__(self, host="", timeout=60, *args, **kwargs): - if host == "": - host = DEFAULT_HOST - super().__init__(host) - - self.host = os.path.abspath(host) - - self.timeout = timeout - - if os.path.dirname(host): - os.makedirs(os.path.dirname(host), exist_ok=True) - - def __repr__(self) -> str: - return f"{type(self).__qualname__}(host={self.host}, timeout={self.timeout})" - - @property - def is_connected(self): - """Return true, always.""" - return True - - def initiate_connection(self): - """Do nothing""" - - def close_connection(self): - """Do nothing""" - - def ensure_index(self, collection_name, keys, unique=False): - """Create given indexes if they do not already exist in database. - - Indexes are only created if `unique` is True. 
- """ - with self.locked_database() as database: - database.ensure_index(collection_name, keys, unique=unique) - - def index_information(self, collection_name): - """Return dict of names and sorting order of indexes""" - with self.locked_database(write=False) as database: - return database.index_information(collection_name) - - def drop_index(self, collection_name, name): - """Remove index from the database""" - with self.locked_database() as database: - return database.drop_index(collection_name, name) - - def write(self, collection_name, data, query=None): - """Write new information to a collection. Perform insert or update. - - .. seealso:: :meth:`orion.core.io.database.Database.write` for argument documentation. - - """ - with self.locked_database() as database: - return database.write(collection_name, data, query=query) - - def read(self, collection_name, query=None, selection=None): - """Read a collection and return a value according to the query. - - .. seealso:: :meth:`orion.core.io.database.Database.read` for argument documentation. - - """ - with self.locked_database(write=False) as database: - return database.read(collection_name, query=query, selection=selection) - - def read_and_write(self, collection_name, query, data, selection=None): - """Read a collection's document and update the found document. - - Returns the updated document, or None if nothing found. - - .. seealso:: :meth:`orion.core.io.database.Database.read_and_write` for - argument documentation. - - """ - with self.locked_database() as database: - return database.read_and_write( - collection_name, query=query, data=data, selection=selection - ) - - def count(self, collection_name, query=None): - """Count the number of documents in a collection which match the `query`. - - .. seealso:: :meth:`orion.core.io.database.Database.count` for argument documentation. - - """ - with self.locked_database(write=False) as database: - return database.count(collection_name, query=query) - - def remove(self, collection_name, query): - """Delete from a collection document[s] which match the `query`. - - .. seealso:: :meth:`orion.core.io.database.Database.remove` for argument documentation. 
- - """ - with self.locked_database() as database: - return database.remove(collection_name, query=query) - - def _get_database(self): - """Read fresh DB state from pickled file""" - if not os.path.exists(self.host): - return EphemeralDB() - - with open(self.host, "rb") as f: - data = f.read() - if not data: - database = EphemeralDB() - else: - database = pickle.loads(data) - - return database - - def _dump_database(self, database): - """Write pickled DB on disk""" - tmp_file = self.host + ".tmp" - - try: - with open(tmp_file, "wb") as f: - pickle.dump(database, f) - - except (PicklingError, AttributeError): - # pylint: disable=protected-access - collection, doc = find_unpickable_doc(database._db) - log.error( - "Document in (collection: %s) is not pickable\ndoc: %s", - collection, - doc.to_dict() if hasattr(doc, "to_dict") else str(doc), - ) - - key, value = find_unpickable_field(doc) - log.error("because (value %s) in (field: %s) is not pickable", value, key) - raise - - replace(tmp_file, self.host) - - @contextmanager - def locked_database(self, write=True): - """Lock database file during wrapped operation call.""" - lock = _create_lock(self.host + ".lock") - - try: - with lock.acquire(timeout=self.timeout): - database = self._get_database() - - yield database - - if write: - self._dump_database(database) - except Timeout as e: - raise DatabaseTimeout(TIMEOUT_ERROR_MESSAGE.format(self.timeout)) from e - - @classmethod - def get_defaults(cls): - """Get database arguments needed to create a database instance. - - .. seealso:: :meth:`orion.core.io.database.Database.get_defaults` - for argument documentation. - - """ - return {"host": DEFAULT_HOST} - - -local_file_systems = ["ext2", "ext3", "ext4", "ntfs"] - - -def _fs_support_globalflock(file_system): - if file_system.fstype == "lustre": - return ("flock" in file_system.opts) and ("localflock" not in file_system.opts) - - elif file_system.fstype == "beegfs": - return "tuneUseGlobalFileLocks" in file_system.opts - - elif file_system.fstype == "gpfs": - return True - - elif file_system.fstype == "nfs": - return False - - return file_system.fstype in local_file_systems - - -def _find_mount_point(path): - """Finds the mount point used to access `path`.""" - path = os.path.abspath(path) - while not os.path.ismount(path): - path = os.path.dirname(path) - - return path - - -def _get_fs(path): - """Gets info about the filesystem on which `path` lives.""" - mount = _find_mount_point(path) - - for file_system in psutil.disk_partitions(True): - if file_system.mountpoint == mount: - return file_system - - return None - - -def _create_lock(path): - """Create lock based on file system capabilities - - Determine if we can rely on the fcntl module for locking files. - Otherwise, fallback on using the directory creation atomicity as a locking mechanism. - """ - file_system = _get_fs(path) - - if _fs_support_globalflock(file_system): - log.debug("Using flock.") - return FileLock(path) - else: - log.debug("Cluster does not support flock. Falling back to softfilelock.") - return SoftFileLock(path) +""" +Pickled Database +================ + +Implement permanent version of :class:`orion.core.io.database.ephemeraldb.EphemeralDB`. 
+
+"""
+
+import logging
+import os
+import pickle
+from contextlib import contextmanager
+from pickle import PicklingError
+
+import psutil
+from filelock import FileLock, SoftFileLock, Timeout
+
+import orion.core
+import orion.core.utils.compat
+from orion.core.io.database import Database, DatabaseTimeout
+from orion.core.io.database.ephemeraldb import EphemeralDB
+
+log = logging.getLogger(__name__)
+
+DEFAULT_HOST = os.path.join(orion.core.DIRS.user_data_dir, "orion", "orion_db.pkl")
+
+TIMEOUT_ERROR_MESSAGE = """\
+Could not acquire lock for PickledDB after {} seconds.
+
+This is likely due to one or more of the following scenarios:
+
+1. There is a large number of workers and many simultaneous queries. This typically occurs
+   when the task to optimize is short (few minutes). Try to reduce the number of workers
+   to below 50.
+
+2. The database is growing large with thousands of trials and many experiments.
+   If so, you can use a different PickledDB (different file, that is, different `host`)
+   for each experiment separately to alleviate this issue.
+
+3. The filesystem is slow. Parallel filesystems on HPC often suffer from
+   large pools of users generating frequent I/O. In this case try using a separate
+   partition that may be less affected.
+
+If you cannot solve the issues listed above that are causing timeouts, you
+may need to set up the MongoDB backend for better performance.
+See https://orion.readthedocs.io/en/stable/install/database.html
+"""
+
+
+def find_unpickable_doc(dict_of_dict):
+    """Look for a dictionary that cannot be pickled."""
+    for name, collection in dict_of_dict.items():
+        documents = collection.find()
+
+        for doc in documents:
+            try:
+                pickle.dumps(doc)
+
+            except (PicklingError, AttributeError):
+                return name, doc
+
+    return None, None
+
+
+def find_unpickable_field(doc):
+    """Look for a field in a dictionary that cannot be pickled."""
+    if not isinstance(doc, dict):
+        doc = doc.to_dict()
+
+    for k, v in doc.items():
+        try:
+            pickle.dumps(v)
+
+        except (PicklingError, AttributeError):
+            return k, v
+
+    return None, None
+
+
+# pylint: disable=too-many-public-methods
+class PickledDB(Database):
+    """Pickled EphemeralDB to support permanency and concurrency
+
+    This is a very simple and inefficient implementation of a permanent database on disk for
+    Oríon. The data is loaded from disk for every operation, and every operation is protected
+    with a filelock.
+
+    Parameters
+    ----------
+    host: str
+        File path to save pickled ephemeraldb. Default is {user data dir}/orion/orion_db.pkl,
+        e.g. $HOME/.local/share/orion/orion_db.pkl
+    timeout: int
+        Maximum number of seconds to wait for the lock before raising DatabaseTimeout.
+        Default is 60.
+
+    """
+
+    # pylint: disable=unused-argument
+    def __init__(self, host="", timeout=60, *args, **kwargs):
+        if host == "":
+            host = DEFAULT_HOST
+        super().__init__(host)
+
+        self.host = os.path.abspath(host)
+
+        self.timeout = timeout
+
+        if os.path.dirname(host):
+            os.makedirs(os.path.dirname(host), exist_ok=True)
+
+    def __repr__(self) -> str:
+        return f"{type(self).__qualname__}(host={self.host}, timeout={self.timeout})"
+
+    @property
+    def is_connected(self):
+        """Return true, always."""
+        return True
+
+    def initiate_connection(self):
+        """Do nothing"""
+
+    def close_connection(self):
+        """Do nothing"""
+
+    def ensure_index(self, collection_name, keys, unique=False):
+        """Create given indexes if they do not already exist in database.
+
+        Indexes are only created if `unique` is True.
+ """ + with self.locked_database() as database: + database.ensure_index(collection_name, keys, unique=unique) + + def index_information(self, collection_name): + """Return dict of names and sorting order of indexes""" + with self.locked_database(write=False) as database: + return database.index_information(collection_name) + + def drop_index(self, collection_name, name): + """Remove index from the database""" + with self.locked_database() as database: + return database.drop_index(collection_name, name) + + def write(self, collection_name, data, query=None): + """Write new information to a collection. Perform insert or update. + + .. seealso:: :meth:`orion.core.io.database.Database.write` for argument documentation. + + """ + with self.locked_database() as database: + return database.write(collection_name, data, query=query) + + def read(self, collection_name, query=None, selection=None): + """Read a collection and return a value according to the query. + + .. seealso:: :meth:`orion.core.io.database.Database.read` for argument documentation. + + """ + with self.locked_database(write=False) as database: + return database.read(collection_name, query=query, selection=selection) + + def read_and_write(self, collection_name, query, data, selection=None): + """Read a collection's document and update the found document. + + Returns the updated document, or None if nothing found. + + .. seealso:: :meth:`orion.core.io.database.Database.read_and_write` for + argument documentation. + + """ + with self.locked_database() as database: + return database.read_and_write( + collection_name, query=query, data=data, selection=selection + ) + + def count(self, collection_name, query=None): + """Count the number of documents in a collection which match the `query`. + + .. seealso:: :meth:`orion.core.io.database.Database.count` for argument documentation. + + """ + with self.locked_database(write=False) as database: + return database.count(collection_name, query=query) + + def remove(self, collection_name, query): + """Delete from a collection document[s] which match the `query`. + + .. seealso:: :meth:`orion.core.io.database.Database.remove` for argument documentation. 
+
+        """
+        with self.locked_database() as database:
+            return database.remove(collection_name, query=query)
+
+    def _get_database(self):
+        """Read fresh DB state from pickled file"""
+        if not os.path.exists(self.host):
+            return EphemeralDB()
+
+        with open(self.host, "rb") as f:
+            data = f.read()
+            if not data:
+                database = EphemeralDB()
+            else:
+                database = pickle.loads(data)
+
+        return database
+
+    def _dump_database(self, database):
+        """Write pickled DB on disk"""
+        tmp_file = self.host + ".tmp"
+
+        try:
+            with open(tmp_file, "wb") as f:
+                pickle.dump(database, f)
+
+        except (PicklingError, AttributeError):
+            # pylint: disable=protected-access
+            collection, doc = find_unpickable_doc(database._db)
+            log.error(
+                "Document in (collection: %s) is not picklable\ndoc: %s",
+                collection,
+                doc.to_dict() if hasattr(doc, "to_dict") else str(doc),
+            )
+
+            key, value = find_unpickable_field(doc)
+            log.error("because (value %s) in (field: %s) is not picklable", value, key)
+            raise
+
+        orion.core.utils.compat.replace(tmp_file, self.host)
+
+    @contextmanager
+    def locked_database(self, write=True):
+        """Lock database file during wrapped operation call."""
+        lock = _create_lock(self.host + ".lock")
+
+        try:
+            with lock.acquire(timeout=self.timeout):
+                database = self._get_database()
+
+                yield database
+
+                if write:
+                    self._dump_database(database)
+        except Timeout as e:
+            raise DatabaseTimeout(TIMEOUT_ERROR_MESSAGE.format(self.timeout)) from e
+
+    @classmethod
+    def get_defaults(cls):
+        """Get database arguments needed to create a database instance.
+
+        .. seealso:: :meth:`orion.core.io.database.Database.get_defaults`
+                     for argument documentation.
+
+        """
+        return {"host": DEFAULT_HOST}
+
+
+local_file_systems = ["ext2", "ext3", "ext4", "ntfs"]
+
+
+def _fs_support_globalflock(file_system):
+    if file_system.fstype == "lustre":
+        return ("flock" in file_system.opts) and ("localflock" not in file_system.opts)
+
+    elif file_system.fstype == "beegfs":
+        return "tuneUseGlobalFileLocks" in file_system.opts
+
+    elif file_system.fstype == "gpfs":
+        return True
+
+    elif file_system.fstype == "nfs":
+        return False
+
+    return file_system.fstype in local_file_systems
+
+
+def _find_mount_point(path):
+    """Find the mount point used to access `path`."""
+    path = os.path.abspath(path)
+    while not os.path.ismount(path):
+        path = os.path.dirname(path)
+
+    return path
+
+
+def _get_fs(path):
+    """Get info about the filesystem on which `path` lives."""
+    mount = _find_mount_point(path)
+
+    for file_system in psutil.disk_partitions(True):
+        if file_system.mountpoint == mount:
+            return file_system
+
+    return None
+
+
+def _create_lock(path):
+    """Create lock based on file system capabilities
+
+    Determine if we can rely on the fcntl module for locking files.
+    Otherwise, fall back on using the directory creation atomicity as a locking mechanism.
+    """
+    file_system = _get_fs(path)
+
+    if _fs_support_globalflock(file_system):
+        log.debug("Using flock.")
+        return FileLock(path)
+    else:
+        log.debug("Cluster does not support flock. 
Falling back to softfilelock.") + return SoftFileLock(path) diff --git a/src/orion/core/io/experiment_builder.py b/src/orion/core/io/experiment_builder.py index 696521055..72fa18710 100644 --- a/src/orion/core/io/experiment_builder.py +++ b/src/orion/core/io/experiment_builder.py @@ -84,6 +84,7 @@ from typing import Any, TypeVar import orion.core +import orion.core.utils.compat from orion.algo.base import BaseAlgorithm, algo_factory from orion.algo.space import Space from orion.core.evc.adapters import BaseAdapter @@ -95,7 +96,6 @@ from orion.core.io.interactive_commands.branching_prompt import BranchingPrompt from orion.core.io.space_builder import SpaceBuilder from orion.core.utils import backward -from orion.core.utils.compat import getuser from orion.core.utils.exceptions import ( BranchingEvent, NoConfigurationError, @@ -934,9 +934,11 @@ def _default(v: T | None, default: V) -> T | V: knowledge_base=knowledge_base, ) + username = orion.core.utils.compat.getuser() + max_broken = _default(max_broken, orion.core.config.experiment.max_broken) working_dir = _default(working_dir, orion.core.config.experiment.working_dir) - metadata = _default(metadata, {"user": _default(user, getuser())}) + metadata = _default(metadata, {"user": _default(user, username)}) refers = _default(refers, dict(parent_id=None, root_id=None, adapter=[])) refers["adapter"] = _instantiate_adapters(refers.get("adapter", [])) # type: ignore diff --git a/src/orion/core/io/resolve_config.py b/src/orion/core/io/resolve_config.py index c9c87ea1c..5478edac1 100644 --- a/src/orion/core/io/resolve_config.py +++ b/src/orion/core/io/resolve_config.py @@ -14,8 +14,8 @@ import orion import orion.core +import orion.core.utils.compat from orion.core.io.orion_cmdline_parser import OrionCmdlineParser -from orion.core.utils.compat import getuser from orion.core.utils.flatten import unflatten @@ -267,7 +267,8 @@ def fetch_env_vars(): def fetch_metadata(user=None, user_args=None, user_script_config=None): """Infer rest information about the process + versioning""" - metadata = {"user": user if user else getuser()} + username = orion.core.utils.compat.getuser() + metadata = {"user": user if user else username} metadata["orion_version"] = orion.core.__version__ @@ -300,7 +301,8 @@ def fetch_metadata(user=None, user_args=None, user_script_config=None): def update_metadata(metadata): """Update information about the process + versioning""" - metadata.setdefault("user", getuser()) + username = orion.core.utils.compat.getuser() + metadata.setdefault("user", username) metadata["orion_version"] = orion.core.__version__ if not metadata.get("user_args"): diff --git a/src/orion/storage/sql_impl.py b/src/orion/storage/sql_impl.py index a56754828..e2853af11 100644 --- a/src/orion/storage/sql_impl.py +++ b/src/orion/storage/sql_impl.py @@ -30,8 +30,8 @@ from sqlalchemy.orm import Session, declarative_base import orion.core +import orion.core.utils.compat from orion.core.io.database import DuplicateKeyError -from orion.core.utils.compat import getuser from orion.core.worker.trial import Trial as OrionTrial from orion.core.worker.trial import validate_status from orion.storage.base import ( @@ -195,7 +195,7 @@ def __init__(self, uri, token=None, **kwargs): self._connect(token) def _connect(self, token): - name = getuser() + name = orion.core.utils.compat.getuser() user = self._find_user(name, token) diff --git a/src/orion/testing/__init__.py b/src/orion/testing/__init__.py index 175f50ade..e2056d660 100644 --- a/src/orion/testing/__init__.py +++ 
b/src/orion/testing/__init__.py
@@ -29,7 +29,7 @@
         "user": "default_user",
         "user_script": "abc",
         "priors": {"x": "uniform(0, 10)"},
-        "datetime": "2017-11-23T02:00:00",
+        "datetime": datetime.datetime.fromisoformat("2017-11-23T02:00:00"),
         "orion_version": "XYZ",
     },
     "algorithms": {"random": {"seed": 1}},
@@ -39,7 +39,7 @@
     "experiment": "default_name",
     "status": "new",  # new, reserved, suspended, completed, broken
     "worker": None,
-    "submit_time": "2017-11-23T02:00:00",
+    "submit_time": datetime.datetime.fromisoformat("2017-11-23T02:00:00"),
     "start_time": None,
     "end_time": None,
     "heartbeat": None,
diff --git a/tests/conftest.py b/tests/conftest.py
index dd9fa68bf..55cad798b 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -12,7 +12,7 @@
 import orion.core
 import orion.core.utils.backward as backward
-import orion.core.utils.compat as getpass
+import orion.core.utils.compat
 from orion.algo.base import BaseAlgorithm
 from orion.algo.space import Space
 from orion.core.io import resolve_config
@@ -364,7 +364,31 @@ def fixed_dictionary(user_script):
 @pytest.fixture()
 def with_user_userxyz(monkeypatch):
     """Make ``getpass.getuser()`` return ``'userxyz'``."""
-    monkeypatch.setattr(getpass, "getuser", lambda: "userxyz")
+    monkeypatch.setattr(orion.core.utils.compat, "getuser", lambda: "userxyz")
+
+
+@pytest.fixture()
+def with_user_tsirif(monkeypatch):
+    """Make ``orion.core.utils.compat.getuser()`` return ``'tsirif'``."""
+    monkeypatch.setattr(orion.core.utils.compat, "getuser", lambda: "tsirif")
+
+
+@pytest.fixture()
+def with_user_bouthilx(monkeypatch):
+    """Make ``orion.core.utils.compat.getuser()`` return ``'bouthilx'``."""
+    monkeypatch.setattr(orion.core.utils.compat, "getuser", lambda: "bouthilx")
+
+
+@pytest.fixture()
+def with_user_dendi(monkeypatch):
+    """Make ``orion.core.utils.compat.getuser()`` return ``'dendi'``."""
+    monkeypatch.setattr(orion.core.utils.compat, "getuser", lambda: "dendi")
+
+
+@pytest.fixture()
+def with_user_corneau(monkeypatch):
+    """Make ``orion.core.utils.compat.getuser()`` return ``'corneau'``."""
+    monkeypatch.setattr(orion.core.utils.compat, "getuser", lambda: "corneau")
 
 
 @pytest.fixture()
diff --git a/tests/functional/commands/test_insert_command.py b/tests/functional/commands/test_insert_command.py
index bcabb4478..69ad8a199 100644
--- a/tests/functional/commands/test_insert_command.py
+++ b/tests/functional/commands/test_insert_command.py
@@ -5,7 +5,6 @@
 import pytest
 
 import orion.core.cli
-import orion.core.utils.compat as getpass
 
 
 def get_user_corneau():
@@ -13,10 +12,10 @@ def get_user_corneau():
     return "corneau"
 
 
+@pytest.mark.usefixtures("with_user_corneau")
 def test_insert_invalid_experiment(storage, monkeypatch, capsys):
     """Test the insertion of an invalid experiment"""
     monkeypatch.chdir(os.path.dirname(os.path.abspath(__file__)))
-    monkeypatch.setattr(getpass, "getuser", get_user_corneau)
 
     returncode = orion.core.cli.main(
         [
@@ -37,11 +36,10 @@
     assert "Error: No experiment with given name 'dumb_experiment'"
 
 
-@pytest.mark.usefixtures("only_experiments_db", "version_XYZ")
+@pytest.mark.usefixtures("only_experiments_db", "version_XYZ", "with_user_corneau")
 def test_insert_single_trial(storage, monkeypatch, script_path):
     """Try to insert a single trial"""
     monkeypatch.chdir(os.path.dirname(os.path.abspath(__file__)))
-    monkeypatch.setattr(getpass, "getuser", get_user_corneau)
 
     orion.core.cli.main(
         [
@@ -71,11 +69,10 @@
     assert trial.params["/x"] == 1
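The `orion.core.utils.compat` helpers that these hunks switch to never appear in the series itself. A minimal sketch of what such a module could contain, assuming its only job is to smooth over platform differences: the names `getuser` and `replace` come from the imports above, while the bodies below are assumptions, not the actual implementation.

    # Hypothetical sketch of src/orion/core/utils/compat.py -- not part of this
    # diff; the fallback behaviour below is an assumption.
    import getpass
    import os


    def getuser():
        """Return the current username without crashing on Windows.

        getpass.getuser() raises when no login name can be resolved (common on
        Windows CI), hence the environment-variable fallback.
        """
        try:
            return getpass.getuser()
        except (ImportError, KeyError, OSError):
            return os.environ.get("USERNAME") or os.environ.get("USER", "unknown")


    def replace(src, dst):
        """Atomically replace dst with src; os.replace is atomic on POSIX and Windows."""
        os.replace(src, dst)
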
-@pytest.mark.usefixtures("only_experiments_db", "version_XYZ") +@pytest.mark.usefixtures("only_experiments_db", "version_XYZ", "with_user_corneau") def test_insert_single_trial_default_value(storage, monkeypatch): """Try to insert a single trial using a default value""" monkeypatch.chdir(os.path.dirname(os.path.abspath(__file__))) - monkeypatch.setattr(getpass, "getuser", get_user_corneau) orion.core.cli.main( [ @@ -103,11 +100,10 @@ def test_insert_single_trial_default_value(storage, monkeypatch): assert trial.params["/x"] == 1 -@pytest.mark.usefixtures("only_experiments_db") +@pytest.mark.usefixtures("only_experiments_db", "with_user_corneau") def test_insert_with_no_default_value(monkeypatch): """Try to insert a single trial by omitting a namespace with no default value""" monkeypatch.chdir(os.path.dirname(os.path.abspath(__file__))) - monkeypatch.setattr(getpass, "getuser", get_user_corneau) with pytest.raises(ValueError) as exc_info: orion.core.cli.main( @@ -124,11 +120,10 @@ def test_insert_with_no_default_value(monkeypatch): assert "Dimension /x is unspecified and has no default value" in str(exc_info.value) -@pytest.mark.usefixtures("only_experiments_db") +@pytest.mark.usefixtures("only_experiments_db", "with_user_corneau") def test_insert_with_incorrect_namespace(monkeypatch): """Try to insert a single trial with a namespace not inside the experiment space""" monkeypatch.chdir(os.path.dirname(os.path.abspath(__file__))) - monkeypatch.setattr(getpass, "getuser", get_user_corneau) with pytest.raises(ValueError) as exc_info: orion.core.cli.main( @@ -146,11 +141,10 @@ def test_insert_with_incorrect_namespace(monkeypatch): assert "Found namespace outside of experiment space : /p" in str(exc_info.value) -@pytest.mark.usefixtures("only_experiments_db") +@pytest.mark.usefixtures("only_experiments_db", "with_user_corneau") def test_insert_with_outside_bound_value(monkeypatch): """Try to insert a single trial with value outside the distribution's interval""" monkeypatch.chdir(os.path.dirname(os.path.abspath(__file__))) - monkeypatch.setattr(getpass, "getuser", get_user_corneau) with pytest.raises(ValueError) as exc_info: orion.core.cli.main( @@ -168,11 +162,10 @@ def test_insert_with_outside_bound_value(monkeypatch): assert "Value 100 is outside of" in str(exc_info.value) -@pytest.mark.usefixtures("only_experiments_db", "version_XYZ") +@pytest.mark.usefixtures("only_experiments_db", "version_XYZ", "with_user_corneau") def test_insert_two_hyperparameters(storage, monkeypatch): """Try to insert a single trial with two hyperparameters""" monkeypatch.chdir(os.path.dirname(os.path.abspath(__file__))) - monkeypatch.setattr(getpass, "getuser", get_user_corneau) orion.core.cli.main( [ "insert", diff --git a/tests/functional/configuration/test_all_options.py b/tests/functional/configuration/test_all_options.py index ef8d9bf01..d4f0371c9 100644 --- a/tests/functional/configuration/test_all_options.py +++ b/tests/functional/configuration/test_all_options.py @@ -186,6 +186,8 @@ class TestStorage(ConfigurationTestSuite): config = { "storage": { "type": "legacy", + "uri": "config", + "token": "config", "database": { "name": "test_name", "type": "pickleddb", @@ -197,6 +199,8 @@ class TestStorage(ConfigurationTestSuite): env_vars = { "ORION_STORAGE_TYPE": "legacy", + "ORION_STORAGE_TOKEN": "env", + "ORION_STORAGE_URI": "env", "ORION_DB_NAME": "test_env_var_name", "ORION_DB_TYPE": "pickleddb", "ORION_DB_ADDRESS": "${tmp_path}/there.pkl", @@ -206,6 +210,8 @@ class TestStorage(ConfigurationTestSuite): local 
= { "storage": { "type": "legacy", + "uri": "local", + "token": "local", "database": {"type": "pickleddb", "host": "${tmp_path}/local.pkl"}, } } @@ -242,6 +248,8 @@ def check_env_var_config(self, tmp_path, monkeypatch): assert orion.core.config.storage.to_dict() == { "type": self.env_vars["ORION_STORAGE_TYPE"], + "uri": self.env_vars["ORION_STORAGE_URI"], + "token": self.env_vars["ORION_STORAGE_TOKEN"], "database": { "name": self.env_vars["ORION_DB_NAME"], "type": self.env_vars["ORION_DB_TYPE"], @@ -277,6 +285,8 @@ def check_local_config(self, tmp_path, conf_file, monkeypatch): assert orion.core.config.storage.to_dict() == { "type": self.env_vars["ORION_STORAGE_TYPE"], + "uri": self.env_vars["ORION_STORAGE_URI"], + "token": self.env_vars["ORION_STORAGE_TOKEN"], "database": { "name": self.env_vars["ORION_DB_NAME"], "type": self.env_vars["ORION_DB_TYPE"], diff --git a/tests/unittests/core/conftest.py b/tests/unittests/core/conftest.py index 5894aa0a9..3cfbc4bd6 100644 --- a/tests/unittests/core/conftest.py +++ b/tests/unittests/core/conftest.py @@ -9,7 +9,6 @@ import orion.core.io.experiment_builder as experiment_builder import orion.core.utils.backward as backward -import orion.core.utils.compat as getpass from orion.algo.space import Categorical, Integer, Real, Space from orion.core.evc import conflicts from orion.core.io.convert import JSONConverter, YAMLConverter @@ -138,24 +137,6 @@ def fixed_suggestion(fixed_suggestion_value, space): return format_trials.tuple_to_trial(fixed_suggestion_value, space) -@pytest.fixture() -def with_user_tsirif(monkeypatch): - """Make ``getpass.getuser()`` return ``'tsirif'``.""" - monkeypatch.setattr(getpass, "getuser", lambda: "tsirif") - - -@pytest.fixture() -def with_user_bouthilx(monkeypatch): - """Make ``getpass.getuser()`` return ``'bouthilx'``.""" - monkeypatch.setattr(getpass, "getuser", lambda: "bouthilx") - - -@pytest.fixture() -def with_user_dendi(monkeypatch): - """Make ``getpass.getuser()`` return ``'dendi'``.""" - monkeypatch.setattr(getpass, "getuser", lambda: "dendi") - - dendi_exp_config = dict( name="supernaedo2-dendi", space={ diff --git a/tests/unittests/core/io/orion_config.yaml b/tests/unittests/core/io/orion_config.yaml index a8c0dc1be..e0dc5624a 100644 --- a/tests/unittests/core/io/orion_config.yaml +++ b/tests/unittests/core/io/orion_config.yaml @@ -10,3 +10,7 @@ database: type: 'pickleddb' name: 'orion_test' host: '${FILE}' + +storage: + uri: sqlite:// + token: tok diff --git a/tests/unittests/core/io/test_experiment_builder.py b/tests/unittests/core/io/test_experiment_builder.py index 5611eecec..3449bd741 100644 --- a/tests/unittests/core/io/test_experiment_builder.py +++ b/tests/unittests/core/io/test_experiment_builder.py @@ -189,11 +189,13 @@ def test_get_cmd_config(raw_config): assert local_config["max_broken"] == 5 assert local_config["name"] == "voila_voici" assert local_config["storage"] == { + "token": "tok", + "uri": "sqlite://", "database": { "host": "${FILE}", "name": "orion_test", "type": "pickleddb", - } + }, } assert local_config["metadata"] == {"orion_version": "XYZ", "user": "tsirif"} diff --git a/tests/unittests/core/io/test_resolve_config.py b/tests/unittests/core/io/test_resolve_config.py index c78fc9483..3654cf5cd 100644 --- a/tests/unittests/core/io/test_resolve_config.py +++ b/tests/unittests/core/io/test_resolve_config.py @@ -209,11 +209,13 @@ def test_fetch_config(raw_config): config = resolve_config.fetch_config({"config": raw_config}) assert config.pop("storage") == { + "token": "tok", + "uri": 
"sqlite://", "database": { "host": "${FILE}", "name": "orion_test", "type": "pickleddb", - } + }, } assert config.pop("experiment") == { @@ -240,6 +242,8 @@ def mocked_config(file_object): storage_config = config.pop("storage") database_config = storage_config.pop("database") assert storage_config.pop("type") == orion.core.config.storage.type + assert storage_config.pop("uri") == orion.core.config.storage.uri + assert storage_config.pop("token") == orion.core.config.storage.token assert storage_config == {} assert database_config.pop("host") == orion.core.config.storage.database.host diff --git a/tests/unittests/core/worker/test_experiment.py b/tests/unittests/core/worker/test_experiment.py index 710f61225..de16611ec 100644 --- a/tests/unittests/core/worker/test_experiment.py +++ b/tests/unittests/core/worker/test_experiment.py @@ -119,7 +119,7 @@ def _generate(obj, *args, value): "experiment": 0, "status": "new", # new, reserved, suspended, completed, broken "worker": None, - "submit_time": "2017-11-23T02:00:00", + "submit_time": datetime.datetime.fromisoformat("2017-11-23T02:00:00"), "start_time": None, "end_time": None, "heartbeat": None, diff --git a/tests/unittests/core/worker/test_producer.py b/tests/unittests/core/worker/test_producer.py index 454c422cb..75b84e199 100644 --- a/tests/unittests/core/worker/test_producer.py +++ b/tests/unittests/core/worker/test_producer.py @@ -2,6 +2,7 @@ """Collection of tests for :mod:`orion.core.worker.producer`.""" import contextlib import copy +import datetime import threading import time @@ -31,7 +32,7 @@ def update_algorithm(producer): "metadata": { "user": "default_user", "user_script": "abc", - "datetime": "2017-11-23T02:00:00", + "datetime": datetime.datetime.fromisoformat("2017-11-23T02:00:00"), "orion_version": "XYZ", }, "algorithms": { diff --git a/tests/unittests/storage/test_legacy.py b/tests/unittests/storage/test_legacy.py index 24481abd7..91a269cf6 100644 --- a/tests/unittests/storage/test_legacy.py +++ b/tests/unittests/storage/test_legacy.py @@ -1,6 +1,7 @@ #!/usr/bin/env python """Collection of tests for :mod:`orion.storage`.""" import copy +import datetime import logging import os @@ -22,7 +23,7 @@ "metadata": { "user": "default_user", "user_script": "abc", - "datetime": "2017-11-23T02:00:00", + "datetime": datetime.datetime.fromisoformat("2017-11-23T02:00:00"), }, } @@ -30,7 +31,7 @@ "experiment": "default_name", "status": "new", # new, reserved, suspended, completed, broken "worker": None, - "submit_time": "2017-11-23T02:00:00", + "submit_time": datetime.datetime.fromisoformat("2017-11-23T02:00:00"), "start_time": None, "end_time": None, "heartbeat": None, diff --git a/tests/unittests/storage/test_storage.py b/tests/unittests/storage/test_storage.py index 8ff9ef617..599be6558 100644 --- a/tests/unittests/storage/test_storage.py +++ b/tests/unittests/storage/test_storage.py @@ -54,7 +54,7 @@ "experiment": "default_name", "status": "new", # new, reserved, suspended, completed, broken "worker": None, - "submit_time": "2017-11-23T02:00:00", + "submit_time": datetime.datetime.fromisoformat("2017-11-23T02:00:00"), "start_time": None, "end_time": None, "heartbeat": None, From 66c3c9b958ea74efb8ff857251f6fe85280450c0 Mon Sep 17 00:00:00 2001 From: Pierre Delaunay Date: Wed, 23 Nov 2022 12:19:06 -0500 Subject: [PATCH 21/25] Update algo lock --- src/orion/storage/sql_impl.py | 58 +++++++++++++++++++++++++---------- src/orion/testing/__init__.py | 19 ++++++------ src/orion/testing/state.py | 5 +++ 3 files changed, 56 insertions(+), 26 
deletions(-) diff --git a/src/orion/storage/sql_impl.py b/src/orion/storage/sql_impl.py index e2853af11..ea0b53fad 100644 --- a/src/orion/storage/sql_impl.py +++ b/src/orion/storage/sql_impl.py @@ -127,7 +127,9 @@ class Algo(Base): # it is one algo per experiment so we could set experiment_id as the primary key # and make it a 1-1 relation _id = Column(Integer, primary_key=True, autoincrement=True) - experiment_id = Column(Integer, ForeignKey("experiments._id"), nullable=False) + experiment_id = Column( + Integer, ForeignKey("experiments._id"), nullable=False, unique=True + ) owner_id = Column(Integer, ForeignKey("users._id"), nullable=False) configuration = Column(JSON) locked = Column(Integer) @@ -173,13 +175,7 @@ def __init__(self, uri, token=None, **kwargs): uri = "sqlite://" # engine_from_config - self.engine = sqlalchemy.create_engine( - uri, - echo=False, - future=True, - json_serializer=to_json, - json_deserializer=from_json, - ) + self.engine = self._create_engine(uri) # Create the schema # sqlite3 can fail on table if it already exist @@ -194,6 +190,16 @@ def __init__(self, uri, token=None, **kwargs): self.user = None self._connect(token) + @staticmethod + def _create_engine(uri): + return sqlalchemy.create_engine( + uri, + echo=False, + future=True, + json_serializer=to_json, + json_deserializer=from_json, + ) + def _connect(self, token): name = orion.core.utils.compat.getuser() @@ -246,7 +252,7 @@ def __getstate__(self): def __setstate__(self, state): self.uri = state["uri"] self.token = state["token"] - self.engine = sqlalchemy.create_engine(self.uri, echo=True, future=True) + self.engine = self._create_engine(self.uri) if self.uri == "sqlite://" or self.uri == "": log.warning("You are serializing an in-memory database, data will be lost") @@ -278,12 +284,11 @@ def create_experiment(self, config): session.add(experiment) session.commit() - session.refresh(experiment) config.update(self._to_experiment(experiment)) - # Alreadyc reate the algo lock as well - self.initialize_algorithm_lock(config["_id"], config.get("algorithms", {})) + # Already create the algo lock as well + self._insert_algorithm_lock(config["_id"], config.get("algorithms", {})) except DBAPIError: raise DuplicateKeyError() @@ -478,7 +483,7 @@ def update_trial( self._set_from_dict(trial, kwargs) session.commit() - return OrionTrial(**self._to_trial(trial)) + return 1 # OrionTrial(**self._to_trial(trial)) def fetch_lost_trials(self, experiment): """See :func:`orion.storage.base.BaseStorageProtocol.fetch_lost_trials`""" @@ -695,8 +700,7 @@ def update_heartbeat(self, trial): # Algorithm # ========= - def initialize_algorithm_lock(self, experiment_id, algorithm_config): - """See :func:`orion.storage.base.BaseStorageProtocol.initialize_algorithm_lock`""" + def _insert_algorithm_lock(self, experiment_id, algorithm_config): with Session(self.engine) as session: algo = Algo( experiment_id=experiment_id, @@ -705,9 +709,29 @@ def initialize_algorithm_lock(self, experiment_id, algorithm_config): locked=0, heartbeat=datetime.datetime.utcnow(), ) + session.add(algo) session.commit() + def initialize_algorithm_lock(self, experiment_id, algorithm_config): + """See :func:`orion.storage.base.BaseStorageProtocol.initialize_algorithm_lock`""" + with Session(self.engine) as session: + stmt = ( + update(Algo) + .where( + Algo.experiment_id == experiment_id, + Algo.owner_id == self.user_id, + ) + .values( + configuration=algorithm_config, + locked=0, + heartbeat=datetime.datetime.utcnow(), + ) + ) + + session.execute(stmt) + 
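+            # The Algo row is inserted once by _insert_algorithm_lock() at experiment
+            # creation (experiment_id is unique), so initializing the lock here only
+            # needs to reset the existing row with an UPDATE instead of inserting.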
session.commit() + def release_algorithm_lock(self, experiment=None, uid=None, new_state=None): """See :func:`orion.storage.base.BaseStorageProtocol.release_algorithm_lock`""" @@ -859,8 +883,8 @@ def _selection(self, table, selection): return selected - def _set_from_dict(self, obj, data, rest=None): - data = deepcopy(data) + def _set_from_dict(self, obj, argdata, rest=None): + data = deepcopy(argdata) meta = dict() while data: k, v = data.popitem() diff --git a/src/orion/testing/__init__.py b/src/orion/testing/__init__.py index e2056d660..f3f1bc231 100644 --- a/src/orion/testing/__init__.py +++ b/src/orion/testing/__init__.py @@ -22,6 +22,16 @@ from orion.serving.webapi import WebApi from orion.testing.state import OrionState + +class MockDatetime(datetime.datetime): + """Fake Datetime""" + + @classmethod + def utcnow(cls): + """Return our random/fixed datetime""" + return default_datetime() + + base_experiment = { "name": "default_name", "version": 0, @@ -246,15 +256,6 @@ def falcon_client(exp_config=None, trial_config=None, statuses=None): yield cfg, experiment, exp_client, falcon_client -class MockDatetime(datetime.datetime): - """Fake Datetime""" - - @classmethod - def utcnow(cls): - """Return our random/fixed datetime""" - return default_datetime() - - @contextlib.contextmanager def mocked_datetime(monkeypatch): """Make ``datetime.datetime.utcnow()`` return an arbitrary date.""" diff --git a/src/orion/testing/state.py b/src/orion/testing/state.py index f705de2b7..1c81ccd4b 100644 --- a/src/orion/testing/state.py +++ b/src/orion/testing/state.py @@ -136,6 +136,7 @@ def add_trials(self, *trials): def _set_tables(self): self.trials = [] self.lies = [] + exp_ids = set() for exp in self._experiments: self.storage.create_experiment(exp) @@ -143,6 +144,10 @@ def _set_tables(self): exp = self.storage.fetch_experiments(dict(name=exp["name"]))[0] self.expname_to_uid[exp["name"]] = exp["_id"] + if exp["_id"] in exp_ids: + raise RuntimeError("Duplicate experiment during setup") + + exp_ids.add(exp["_id"]) self.storage.initialize_algorithm_lock(exp["_id"], exp.get("algorithms")) for trial in self._trials: From 17856de9e86aaf77fe1c4ad3f3eb9f648bc6e1f4 Mon Sep 17 00:00:00 2001 From: Pierre Delaunay Date: Wed, 23 Nov 2022 13:41:27 -0500 Subject: [PATCH 22/25] Replace MockDatetime by freezegun --- src/orion/testing/__init__.py | 24 +++-------------- tests/conftest.py | 9 ++++--- .../core/worker/test_experiment_functional.py | 6 +++-- tests/requirements.txt | 3 ++- tests/unittests/core/conftest.py | 26 +++++++++---------- 5 files changed, 28 insertions(+), 40 deletions(-) diff --git a/src/orion/testing/__init__.py b/src/orion/testing/__init__.py index f3f1bc231..62f7a6095 100644 --- a/src/orion/testing/__init__.py +++ b/src/orion/testing/__init__.py @@ -23,13 +23,9 @@ from orion.testing.state import OrionState -class MockDatetime(datetime.datetime): - """Fake Datetime""" - - @classmethod - def utcnow(cls): - """Return our random/fixed datetime""" - return default_datetime() +def default_datetime(): + """Return default datetime""" + return datetime.datetime(1903, 4, 25, 0, 0, 0) base_experiment = { @@ -58,11 +54,6 @@ def utcnow(cls): } -def default_datetime(): - """Return default datetime""" - return datetime.datetime(1903, 4, 25, 0, 0, 0) - - all_status = ["completed", "broken", "reserved", "interrupted", "suspended", "new"] @@ -256,15 +247,6 @@ def falcon_client(exp_config=None, trial_config=None, statuses=None): yield cfg, experiment, exp_client, falcon_client -@contextlib.contextmanager -def 
mocked_datetime(monkeypatch): - """Make ``datetime.datetime.utcnow()`` return an arbitrary date.""" - with monkeypatch.context() as m: - m.setattr(datetime, "datetime", MockDatetime) - - yield MockDatetime - - class AssertNewFile: def __init__(self, filename): self.filename = filename diff --git a/tests/conftest.py b/tests/conftest.py index 55cad798b..b3cb8f493 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,12 +2,14 @@ """Common fixtures and utils for unittests and functional tests.""" from __future__ import annotations +import datetime import os from typing import Any import numpy import pytest import yaml +from freezegun import freeze_time from pymongo import MongoClient import orion.core @@ -23,7 +25,8 @@ # So that assert messages show up in tests defined outside testing suite. pytest.register_assert_rewrite("orion.testing") -from orion.testing import OrionState, mocked_datetime + +from orion.testing import OrionState, default_datetime def pytest_addoption(parser): @@ -394,8 +397,8 @@ def with_user_corneau(monkeypatch): @pytest.fixture() def random_dt(monkeypatch): """Make ``datetime.datetime.utcnow()`` return an arbitrary date.""" - with mocked_datetime(monkeypatch) as datetime: - yield datetime.utcnow() + with freeze_time(default_datetime()): + yield datetime.datetime.utcnow() @pytest.fixture(scope="function") diff --git a/tests/functional/core/worker/test_experiment_functional.py b/tests/functional/core/worker/test_experiment_functional.py index 1e6737c2a..c875a18bd 100644 --- a/tests/functional/core/worker/test_experiment_functional.py +++ b/tests/functional/core/worker/test_experiment_functional.py @@ -2,10 +2,12 @@ """Collection of functional tests for :mod:`orion.core.worker.experiment`.""" import logging +from freezegun import freeze_time + from orion.client import build_experiment, get_experiment from orion.core.io.database import DuplicateKeyError from orion.core.worker.trial import Trial -from orion.testing import mocked_datetime +from orion.testing import default_datetime from orion.testing.evc import ( build_child_experiment, build_root_experiment, @@ -146,7 +148,7 @@ def test_fix_lost_trials_in_evc(storage, monkeypatch): `fix_lost_trials` is tested more carefully in experiment's unit-tests (without the EVC). 
""" - with disable_duplication(monkeypatch), mocked_datetime(monkeypatch): + with disable_duplication(monkeypatch), freeze_time(default_datetime()): build_evc_tree(list(range(5))) for exp_name in ["root", "parent", "experiment", "child", "grand-child"]: diff --git a/tests/requirements.txt b/tests/requirements.txt index 4ebee8d7d..c6d70f8cf 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -6,4 +6,5 @@ pytest-lazy-fixture pytest-custom_exit_code scikit-learn ptera >= 1.1.0 -selenium \ No newline at end of file +selenium +freezegun diff --git a/tests/unittests/core/conftest.py b/tests/unittests/core/conftest.py index 3cfbc4bd6..1303d12ee 100644 --- a/tests/unittests/core/conftest.py +++ b/tests/unittests/core/conftest.py @@ -2,6 +2,7 @@ """Common fixtures and utils for tests.""" import copy +import datetime import os import pytest @@ -15,7 +16,6 @@ from orion.core.io.space_builder import DimensionBuilder from orion.core.utils import format_trials from orion.core.worker.trial import Trial -from orion.testing import MockDatetime TEST_DIR = os.path.dirname(os.path.abspath(__file__)) YAML_SAMPLE = os.path.join(TEST_DIR, "sample_config.yml") @@ -167,9 +167,9 @@ def fixed_suggestion(fixed_suggestion_value, space): { "status": "completed", "worker": 12512301, - "submit_time": MockDatetime(2017, 11, 22, 23), + "submit_time": datetime.datetime(2017, 11, 22, 23), "start_time": None, - "end_time": MockDatetime(2017, 11, 22, 23), + "end_time": datetime.datetime(2017, 11, 22, 23), "results": [{"name": None, "type": "objective", "value": 3}], "params": [ {"name": "/decoding_layer", "type": "categorical", "value": "rnn"}, @@ -180,9 +180,9 @@ def fixed_suggestion(fixed_suggestion_value, space): { "status": "completed", "worker": 23415151, - "submit_time": MockDatetime(2017, 11, 23, 0), + "submit_time": datetime.datetime(2017, 11, 23, 0), "start_time": None, - "end_time": MockDatetime(2017, 11, 23, 0), + "end_time": datetime.datetime(2017, 11, 23, 0), "results": [ {"name": "yolo", "type": "objective", "value": 10}, {"name": "contra", "type": "constraint", "value": 1.2}, @@ -201,9 +201,9 @@ def fixed_suggestion(fixed_suggestion_value, space): { "status": "completed", "worker": 1251231, - "submit_time": MockDatetime(2017, 11, 22, 23), + "submit_time": datetime.datetime(2017, 11, 22, 23), "start_time": None, - "end_time": MockDatetime(2017, 11, 22, 22), + "end_time": datetime.datetime(2017, 11, 22, 22), "results": [ {"name": None, "type": "objective", "value": 2}, {"name": "naedw_grad", "type": "gradient", "value": [-0.1, 2]}, @@ -217,7 +217,7 @@ def fixed_suggestion(fixed_suggestion_value, space): { "status": "new", "worker": None, - "submit_time": MockDatetime(2017, 11, 23, 1), + "submit_time": datetime.datetime(2017, 11, 23, 1), "start_time": None, "end_time": None, "results": [{"name": None, "type": "objective", "value": None}], @@ -230,7 +230,7 @@ def fixed_suggestion(fixed_suggestion_value, space): { "status": "new", "worker": None, - "submit_time": MockDatetime(2017, 11, 23, 2), + "submit_time": datetime.datetime(2017, 11, 23, 2), "start_time": None, "end_time": None, "results": [{"name": None, "type": "objective", "value": None}], @@ -247,8 +247,8 @@ def fixed_suggestion(fixed_suggestion_value, space): { "status": "interrupted", "worker": None, - "submit_time": MockDatetime(2017, 11, 23, 3), - "start_time": MockDatetime(2017, 11, 23, 3), + "submit_time": datetime.datetime(2017, 11, 23, 3), + "start_time": datetime.datetime(2017, 11, 23, 3), "end_time": None, "results": [{"name": None, 
"type": "objective", "value": None}], "params": [ @@ -264,8 +264,8 @@ def fixed_suggestion(fixed_suggestion_value, space): { "status": "suspended", "worker": None, - "submit_time": MockDatetime(2017, 11, 23, 4), - "start_time": MockDatetime(2017, 11, 23, 4), + "submit_time": datetime.datetime(2017, 11, 23, 4), + "start_time": datetime.datetime(2017, 11, 23, 4), "end_time": None, "results": [{"name": None, "type": "objective", "value": None}], "params": [ From 9d5707b7127e21cbf42a8ad73eef999fdd985b02 Mon Sep 17 00:00:00 2001 From: Pierre Delaunay Date: Wed, 23 Nov 2022 13:52:24 -0500 Subject: [PATCH 23/25] - --- tests/unittests/storage/test_storage.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unittests/storage/test_storage.py b/tests/unittests/storage/test_storage.py index 599be6558..f4342fe21 100644 --- a/tests/unittests/storage/test_storage.py +++ b/tests/unittests/storage/test_storage.py @@ -34,7 +34,7 @@ ] if not HAS_SQLALCHEMY: - log.warning("Track is not tested because: %s!", REASON) + log.warning("SQLAlchemy is not tested because it is not installed") else: storage_backends.extend( [ From 91cc462a641d0826735f57d0eb15fd5d77acaf86 Mon Sep 17 00:00:00 2001 From: Pierre Delaunay Date: Wed, 23 Nov 2022 13:56:22 -0500 Subject: [PATCH 24/25] - --- src/orion/core/io/experiment_builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/orion/core/io/experiment_builder.py b/src/orion/core/io/experiment_builder.py index 72fa18710..8caa71fcb 100644 --- a/src/orion/core/io/experiment_builder.py +++ b/src/orion/core/io/experiment_builder.py @@ -864,7 +864,7 @@ def _branch_experiment(self, experiment, conflicts, version, branching_arguments return self.create_experiment(mode="x", **config) - # pylint: disable=too-many-arguments + # pylint: disable=too-many-arguments,too-many-locals def create_experiment( self, name: str, From 702791fbe0451598d8124f4dbdd68b8f63601665 Mon Sep 17 00:00:00 2001 From: Pierre Delaunay Date: Wed, 23 Nov 2022 15:36:52 -0500 Subject: [PATCH 25/25] - --- .github/workflows/build.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 51c8400cb..fbf5e6a2a 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -145,7 +145,7 @@ jobs: test_no_extras: needs: [pre-commit, pretest] - runs-on: [ubuntu-latest, windows-latest] + runs-on: [ubuntu-latest] steps: - uses: actions/checkout@v1 - name: Set up Python 3.9