From c5ce9e0f54f2eb16fa29a4944b3bd3517275cb6f Mon Sep 17 00:00:00 2001 From: Xingyou Song Date: Wed, 29 Nov 2023 17:03:07 -0800 Subject: [PATCH] Access row elements with `row.x` rather than `row[x]`, compatible with sqlalchemy 2.0. Resolves https://github.com/google/vizier/issues/993 PiperOrigin-RevId: 586493566 --- requirements-client.txt | 8 ------- requirements.txt | 2 +- vizier/__init__.py | 2 +- vizier/_src/service/sql_datastore.py | 31 ++++++++++------------------ 4 files changed, 13 insertions(+), 30 deletions(-) delete mode 100644 requirements-client.txt diff --git a/requirements-client.txt b/requirements-client.txt deleted file mode 100644 index 20e6be8dd..000000000 --- a/requirements-client.txt +++ /dev/null @@ -1,8 +0,0 @@ -attrs==23.1.0 -absl-py>=1.0.0 -numpy>=1.21.5 -protobuf>=3.6 -portpicker>=1.3.1 -grpcio>=1.35.0 -grpcio-tools>=1.35.0 -googleapis-common-protos>=1.56.4 diff --git a/requirements.txt b/requirements.txt index aa76424e3..abf37b257 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,4 +6,4 @@ portpicker>=1.3.1 grpcio>=1.35.0 grpcio-tools>=1.35.0 googleapis-common-protos>=1.56.4 -sqlalchemy>=1.4,<=1.4.20 +sqlalchemy>=1.4 diff --git a/vizier/__init__.py b/vizier/__init__.py index 7843c132b..eaa719065 100644 --- a/vizier/__init__.py +++ b/vizier/__init__.py @@ -23,4 +23,4 @@ sys.path.append(PROTO_ROOT) -__version__ = "0.1.12" +__version__ = "0.1.13" diff --git a/vizier/_src/service/sql_datastore.py b/vizier/_src/service/sql_datastore.py index f6977da65..47167bc55 100644 --- a/vizier/_src/service/sql_datastore.py +++ b/vizier/_src/service/sql_datastore.py @@ -18,7 +18,7 @@ import collections import threading -from typing import Callable, DefaultDict, Iterable, List, Optional +from typing import Callable, Iterable, List, Optional from absl import logging import sqlalchemy as sqla @@ -124,7 +124,7 @@ def load_study(self, study_name: str) -> study_pb2.Study: row = result.fetchone() if not row: raise NotFoundError('Failed to find study name: %s' % study_name) - return study_pb2.Study.FromString(row['serialized_study']) + return study_pb2.Study.FromString(row.serialized_study) def update_study(self, study: study_pb2.Study) -> resources.StudyResource: study_resource = resources.StudyResource.from_name(study.name) @@ -190,9 +190,7 @@ def list_studies(self, owner_name: str) -> List[study_pb2.Study]: raise NotFoundError('Owner name %s does not exist.' % owner_name) result = self._connection.execute(lq).fetchall() - return [ - study_pb2.Study.FromString(row['serialized_study']) for row in result - ] + return [study_pb2.Study.FromString(row.serialized_study) for row in result] def create_trial(self, trial: study_pb2.Trial) -> resources.TrialResource: trial_resource = resources.TrialResource.from_name(trial.name) @@ -223,7 +221,7 @@ def get_trial(self, trial_name: str) -> study_pb2.Trial: row = result.fetchone() if not row: raise NotFoundError('Failed to find trial name: %s' % trial_name) - return study_pb2.Trial.FromString(row['serialized_trial']) + return study_pb2.Trial.FromString(row.serialized_trial) def update_trial(self, trial: study_pb2.Trial) -> resources.TrialResource: trial_resource = resources.TrialResource.from_name(trial.name) @@ -269,9 +267,7 @@ def list_trials(self, study_name: str) -> List[study_pb2.Trial]: raise NotFoundError('Study name %s does not exist.' % study_name) result = self._connection.execute(lq) - return [ - study_pb2.Trial.FromString(row['serialized_trial']) for row in result - ] + return [study_pb2.Trial.FromString(row.serialized_trial) for row in result] def delete_trial(self, trial_name: str) -> None: # Exist query @@ -347,7 +343,7 @@ def get_suggestion_operation( row = result.fetchone() if not row: raise NotFoundError('Failed to find suggest op name: %s' % operation_name) - return operations_pb2.Operation.FromString(row['serialized_op']) + return operations_pb2.Operation.FromString(row.serialized_op) def update_suggestion_operation( self, operation: operations_pb2.Operation @@ -407,8 +403,7 @@ def list_suggestion_operations( result = self._connection.execute(q) all_ops = [ - operations_pb2.Operation.FromString(row['serialized_op']) - for row in result + operations_pb2.Operation.FromString(row.serialized_op) for row in result ] if filter_fn is not None: output_list = [] @@ -495,9 +490,7 @@ def get_early_stopping_operation( raise NotFoundError( 'Failed to find early stopping op name: %s' % operation_name ) - return vizier_oss_pb2.EarlyStoppingOperation.FromString( - row['serialized_op'] - ) + return vizier_oss_pb2.EarlyStoppingOperation.FromString(row.serialized_op) def update_early_stopping_operation( self, operation: vizier_oss_pb2.EarlyStoppingOperation @@ -552,7 +545,7 @@ def update_metadata( row = study_result.fetchone() if not row: raise NotFoundError('No such study:', s_resource.name) - original_study = study_pb2.Study.FromString(row['serialized_study']) + original_study = study_pb2.Study.FromString(row.serialized_study) # Store Study-related metadata into the database. vz.metadata_util.merge_study_metadata( @@ -565,9 +558,7 @@ def update_metadata( self._connection.execute(usq) # Split the trial-related metadata by Trial. - split_metadata: DefaultDict[str, List[datastore.UnitMetadataUpdate]] = ( - collections.defaultdict(list) - ) + split_metadata = collections.defaultdict(list) for md in trial_metadata: split_metadata[md.trial_id].append(md) @@ -583,7 +574,7 @@ def update_metadata( row = trial_result.fetchone() if not row: raise NotFoundError('No such trial:', trial_name) - original_trial = study_pb2.Trial.FromString(row['serialized_trial']) + original_trial = study_pb2.Trial.FromString(row.serialized_trial) # Update Trial. vz.metadata_util.merge_trial_metadata(original_trial, md_list)