From 802e6622b07641a90d19a4e044bd137c60b6c878 Mon Sep 17 00:00:00 2001 From: stephenpardy Date: Mon, 10 Jun 2024 10:18:42 -0500 Subject: [PATCH] Refactor metadata loading --- rubicon_ml/repository/base.py | 208 ++++++++++++++++------------------ 1 file changed, 99 insertions(+), 109 deletions(-) diff --git a/rubicon_ml/repository/base.py b/rubicon_ml/repository/base.py index 160231a0..3dad5d60 100644 --- a/rubicon_ml/repository/base.py +++ b/rubicon_ml/repository/base.py @@ -3,7 +3,8 @@ import tempfile import warnings from datetime import datetime -from typing import List, Optional +from json import JSONDecodeError +from typing import Any, Dict, List, Optional, TypeVar from zipfile import ZipFile import fsspec @@ -13,6 +14,17 @@ from rubicon_ml.exceptions import RubiconException from rubicon_ml.repository.utils import json, slugify +Domains = TypeVar( + "Domains", + domain.Project, + domain.Experiment, + domain.Artifact, + domain.Dataframe, + domain.Feature, + domain.Metric, + domain.Parameter, +) + class BaseRepository: """The base repository defines all the shared interactions @@ -42,34 +54,34 @@ def __init__(self, root_dir: str, **storage_options): # --- Filesystem Helpers --- - def _cat(self, path): + def _cat(self, path: str): """Returns the contents of the file at `path`.""" return self.filesystem.cat(path) - def _cat_paths(self, metadata_paths): + def _cat_paths(self, metadata_paths: List[str]) -> Dict[str, Any]: """Cat `metadata_paths` to get the list of files to include. Ignore FileNotFoundErrors to avoid misc file errors, like hidden dotfiles. """ - files = [] + files = {} for path, metadata in self.filesystem.cat(metadata_paths, on_error="return").items(): if isinstance(metadata, FileNotFoundError): warning = f"{path} not found. Was this file unintentionally created?" 
warnings.warn(warning) else: - files.append(metadata) + files[path] = metadata return files - def _exists(self, path): + def _exists(self, path: str) -> bool: """Returns True if a file exists at `path`, False otherwise.""" return self.filesystem.exists(path) - def _glob(self, globstring): + def _glob(self, globstring: str): """Returns the names of the files matching `globstring`.""" return self.filesystem.glob(globstring, detail=True) - def _ls_directories_only(self, path): + def _ls_directories_only(self, path: str) -> List[str]: """Returns the names of all the directories at path `path`.""" directories = [ os.path.join(p.get("name"), "metadata.json") @@ -79,14 +91,14 @@ def _ls_directories_only(self, path): return directories - def _ls(self, path): + def _ls(self, path: str): return self.filesystem.ls(path) - def _mkdir(self, dirpath): + def _mkdir(self, dirpath: str): """Creates a directory `dirpath` with parents.""" return self.filesystem.mkdirs(dirpath, exist_ok=True) - def _modified(self, path): + def _modified(self, path: str): return self.filesystem.modified(path) def _persist_bytes(self, bytes_data, path): @@ -125,6 +137,33 @@ def _rm(self, path): """Recursively remove all files at `path`.""" return self.filesystem.rm(path, recursive=True) + def _load_metadata_files(self, metadata_root: str, domain_type: type[Domains]) -> List[Domains]: + """Load metadata files from the given root directory and return a list of domain objects.""" + # find all directories, prepare a list of those plus `metadata.json` + metadata_paths = self._ls_directories_only(metadata_root) + + loaded_domains = [] + # cat_paths will check for FileNotFoundErrors and skip any missing files + # it loads the contents of the found files + for path, metadata in self._cat_paths(metadata_paths).items(): + try: + metadata_contents = json.loads(metadata) + except JSONDecodeError: + warnings.warn(f"Failed to load metadata for {domain_type.__name__} at {path}") + continue + + try: + loaded_domain = 
domain_type(**metadata_contents) + except TypeError: + warnings.warn(f"Failed to load {domain_type.__name__} from metadata at {path}") + continue + + loaded_domains.append(loaded_domain) + + loaded_domains.sort(key=lambda d: d.created_at) + + return loaded_domains + # -------- Projects -------- def _get_project_metadata_path(self, project_name): @@ -148,7 +187,7 @@ def create_project(self, project): self._persist_domain(project, project_metadata_path) - def get_project(self, project_name): + def get_project(self, project_name: str) -> domain.Project: """Retrieve a project from the configured filesystem. Parameters @@ -170,7 +209,7 @@ def get_project(self, project_name): return domain.Project(**project) - def get_projects(self): + def get_projects(self) -> List[domain.Project]: """Get the list of projects from the filesystem. Returns @@ -178,17 +217,7 @@ def get_projects(self): list of rubicon.domain.Project The list of projects from the filesystem. """ - try: - project_metadata_paths = self._ls_directories_only(self.root_dir) - projects = [ - domain.Project(**json.loads(metadata)) - for metadata in self._cat_paths(project_metadata_paths) - ] - projects.sort(key=lambda p: p.created_at) - except FileNotFoundError: - return [] - - return projects + return self._load_metadata_files(self.root_dir, domain.Project) # ------ Experiments ------- @@ -220,7 +249,7 @@ def create_experiment(self, experiment): self._persist_domain(experiment, experiment_metadata_path) - def get_experiment(self, project_name, experiment_id): + def get_experiment(self, project_name: str, experiment_id: str) -> domain.Experiment: """Retrieve an experiment from the configured filesystem. 
Parameters @@ -244,7 +273,7 @@ def get_experiment(self, project_name, experiment_id): return domain.Experiment(**experiment) - def get_experiments(self, project_name): + def get_experiments(self, project_name: str) -> List[domain.Experiment]: """Retrieve all experiments from the configured filesystem that belong to the project with name `project_name`. @@ -261,18 +290,7 @@ def get_experiments(self, project_name): `project_name`. """ experiment_metadata_root = self._get_experiment_metadata_root(project_name) - - try: - experiment_metadata_paths = self._ls_directories_only(experiment_metadata_root) - experiments = [ - domain.Experiment(**json.loads(metadata)) - for metadata in self._cat_paths(experiment_metadata_paths) - ] - experiments.sort(key=lambda e: e.created_at) - except FileNotFoundError: - return [] - - return experiments + return self._load_metadata_files(experiment_metadata_root, domain.Experiment) # ------- Archiving -------- @@ -330,7 +348,10 @@ def _archive( return zip_archive_filename def _experiments_from_archive( - self, project_name, remote_rubicon_root: str, latest_only: Optional[bool] = False + self, + project_name, + remote_rubicon_root: str, + latest_only: Optional[bool] = False, ): """Retrieve archived experiments into this project's experiments folder. @@ -485,18 +506,7 @@ def get_artifacts_metadata(self, project_name, experiment_id=None): The artifacts logged to the specified object. 
""" artifact_metadata_root = self._get_artifact_metadata_root(project_name, experiment_id) - - try: - artifact_metadata_paths = self._ls_directories_only(artifact_metadata_root) - artifacts = [ - domain.Artifact(**json.loads(metadata)) - for metadata in self._cat_paths(artifact_metadata_paths) - ] - artifacts.sort(key=lambda a: a.created_at) - except FileNotFoundError: - return [] - - return artifacts + return self._load_metadata_files(artifact_metadata_root, domain.Artifact) def get_artifact_data(self, project_name, artifact_id, experiment_id=None): """Retrieve an artifact's raw data. @@ -672,7 +682,9 @@ def get_dataframe_metadata(self, project_name, dataframe_id, experiment_id=None) return domain.Dataframe(**dataframe) - def get_dataframes_metadata(self, project_name, experiment_id=None): + def get_dataframes_metadata( + self, project_name: str, experiment_id: Optional[str] = None + ) -> List[domain.Dataframe]: """Retrieve all dataframes' metadata from the configured filesystem that belong to the specified object. @@ -692,18 +704,7 @@ def get_dataframes_metadata(self, project_name, experiment_id=None): The dataframes logged to the specified object. """ dataframe_metadata_root = self._get_dataframe_metadata_root(project_name, experiment_id) - - try: - dataframe_metadata_paths = self._ls_directories_only(dataframe_metadata_root) - dataframes = [ - domain.Dataframe(**json.loads(metadata)) - for metadata in self._cat_paths(dataframe_metadata_paths) - ] - dataframes.sort(key=lambda d: d.created_at) - except FileNotFoundError: - return [] - - return dataframes + return self._load_metadata_files(dataframe_metadata_root, domain.Dataframe) def get_dataframe_data(self, project_name, dataframe_id, experiment_id=None, df_type="pandas"): """Retrieve a dataframe's raw data. 
@@ -803,7 +804,9 @@ def create_feature(self, feature, project_name, experiment_id): self._persist_domain(feature, feature_metadata_path) - def get_feature(self, project_name, experiment_id, feature_name): + def get_feature( + self, project_name: str, experiment_id: str, feature_name: str + ) -> domain.Feature: """Retrieve a feature from the configured filesystem. Parameters @@ -832,7 +835,7 @@ def get_feature(self, project_name, experiment_id, feature_name): return domain.Feature(**feature) - def get_features(self, project_name, experiment_id): + def get_features(self, project_name: str, experiment_id: str) -> List[domain.Feature]: """Retrieve all features from the configured filesystem that belong to the experiment with ID `experiment_id`. @@ -852,18 +855,7 @@ def get_features(self, project_name, experiment_id): `experiment_id`. """ feature_metadata_root = self._get_feature_metadata_root(project_name, experiment_id) - - try: - feature_metadata_paths = self._ls_directories_only(feature_metadata_root) - features = [ - domain.Feature(**json.loads(metadata)) - for metadata in self._cat_paths(feature_metadata_paths) - ] - features.sort(key=lambda f: f.created_at) - except FileNotFoundError: - return [] - - return features + return self._load_metadata_files(feature_metadata_root, domain.Feature) # -------- Metrics --------- @@ -905,7 +897,7 @@ def create_metric(self, metric, project_name, experiment_id): self._persist_domain(metric, metric_metadata_path) - def get_metric(self, project_name, experiment_id, metric_name): + def get_metric(self, project_name: str, experiment_id: str, metric_name: str) -> domain.Metric: """Retrieve a metric from the configured filesystem. 
Parameters @@ -934,7 +926,7 @@ def get_metric(self, project_name, experiment_id, metric_name): return domain.Metric(**metric) - def get_metrics(self, project_name, experiment_id): + def get_metrics(self, project_name: str, experiment_id: str) -> List[domain.Metric]: """Retrieve all metrics from the configured filesystem that belong to the experiment with ID `experiment_id`. @@ -954,18 +946,7 @@ def get_metrics(self, project_name, experiment_id): `experiment_id`. """ metric_metadata_root = self._get_metric_metadata_root(project_name, experiment_id) - - try: - metric_metadata_paths = self._ls_directories_only(metric_metadata_root) - metrics = [ - domain.Metric(**json.loads(metadata)) - for metadata in self._cat_paths(metric_metadata_paths) - ] - metrics.sort(key=lambda m: m.created_at) - except FileNotFoundError: - return [] - - return metrics + return self._load_metadata_files(metric_metadata_root, domain.Metric) # ------- Parameters ------- @@ -1035,7 +1016,7 @@ def get_parameter(self, project_name, experiment_id, parameter_name): return domain.Parameter(**parameter) - def get_parameters(self, project_name, experiment_id): + def get_parameters(self, project_name: str, experiment_id: str) -> List[domain.Parameter]: """Retrieve all parameters from the configured filesystem that belong to the experiment with ID `experiment_id`. @@ -1055,18 +1036,7 @@ def get_parameters(self, project_name, experiment_id): `experiment_id`. 
""" parameter_metadata_root = self._get_parameter_metadata_root(project_name, experiment_id) - - try: - parameter_metadata_paths = self._ls_directories_only(parameter_metadata_root) - parameters = [ - domain.Parameter(**json.loads(metadata)) - for metadata in self._cat_paths(parameter_metadata_paths) - ] - parameters.sort(key=lambda p: p.created_at) - except FileNotFoundError: - return [] - - return parameters + return self._load_metadata_files(parameter_metadata_root, domain.Parameter) # ---------- Tags ---------- @@ -1101,7 +1071,12 @@ def _get_tag_metadata_root( return f"{entity_metadata_root}/{entity_identifier}" def add_tags( - self, project_name, tags, experiment_id=None, entity_identifier=None, entity_type=None + self, + project_name, + tags, + experiment_id=None, + entity_identifier=None, + entity_type=None, ): """Persist tags to the configured filesystem. @@ -1130,7 +1105,12 @@ def add_tags( self._persist_domain({"added_tags": tags}, tag_metadata_path) def remove_tags( - self, project_name, tags, experiment_id=None, entity_identifier=None, entity_type=None + self, + project_name, + tags, + experiment_id=None, + entity_identifier=None, + entity_type=None, ): """Delete tags from the configured filesystem. @@ -1225,7 +1205,12 @@ def _get_comment_metadata_root( ) def add_comments( - self, project_name, comments, experiment_id=None, entity_identifier=None, entity_type=None + self, + project_name, + comments, + experiment_id=None, + entity_identifier=None, + entity_type=None, ): """Persist comments to the configured filesystem. @@ -1254,7 +1239,12 @@ def add_comments( self._persist_domain({"added_comments": comments}, comment_metadata_path) def remove_comments( - self, project_name, comments, experiment_id=None, entity_identifier=None, entity_type=None + self, + project_name, + comments, + experiment_id=None, + entity_identifier=None, + entity_type=None, ): """Delete comments from the configured filesystem.