From 3edacfb6f90c1c031439dba8a450955d99abad8b Mon Sep 17 00:00:00 2001 From: sonaalthaker Date: Wed, 6 Mar 2024 10:49:31 -0800 Subject: [PATCH] additional tests --- rubicon_ml/client/project.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/rubicon_ml/client/project.py b/rubicon_ml/client/project.py index 4acb0b5c..cf60a859 100644 --- a/rubicon_ml/client/project.py +++ b/rubicon_ml/client/project.py @@ -8,7 +8,7 @@ import pandas as pd from rubicon_ml import domain -from rubicon_ml.client import ArtifactMixin, Base, DataframeMixin, Experiment, TagMixin +from rubicon_ml.client import ArtifactMixin, Base, DataframeMixin, Experiment from rubicon_ml.client.utils.exception_handling import failsafe from rubicon_ml.client.utils.tags import filter_children from rubicon_ml.exceptions import RubiconException @@ -20,7 +20,7 @@ from rubicon_ml.domain import Project as ProjectDomain -class Project(Base, ArtifactMixin, DataframeMixin, SchemaMixin, TagMixin): +class Project(Base, ArtifactMixin, DataframeMixin, SchemaMixin): """A client project. A `project` is a collection of `experiments`, @@ -74,6 +74,7 @@ def _create_experiment_domain( commit_hash, training_metadata, tags, + comments, ): """Instantiates and returns an experiment domain object.""" if self.is_auto_git_enabled: @@ -94,6 +95,7 @@ def _create_experiment_domain( commit_hash=commit_hash, training_metadata=training_metadata, tags=tags, + comments=comments, ) def _group_experiments(self, experiments: List[Experiment], group_by: Optional[str] = None): @@ -218,6 +220,7 @@ def log_experiment( commit_hash: Optional[str] = None, training_metadata: Optional[Union[Tuple, List[Tuple]]] = None, tags: Optional[List[str]] = None, + comments: Optional[List[str]] = None, ) -> Experiment: """Log a new experiment to this project. @@ -248,6 +251,8 @@ def log_experiment( to differentiate between the type of model or classifier used during the experiment (i.e. `linear regression` or `random forest`). + comments : list of str, optional + Values to comment the experiment with. Returns ------- @@ -260,6 +265,14 @@ def log_experiment( if not isinstance(tags, list) or not all([isinstance(tag, str) for tag in tags]): raise ValueError("`tags` must be `list` of type `str`") + if comments is None: + comments = [] + + if not isinstance(comments, list) or not all( + [isinstance(comment, str) for comment in comments] + ): + raise ValueError("`comments` must be `list` of type `str`") + experiment = self._create_experiment_domain( name, description, @@ -268,6 +281,7 @@ def log_experiment( commit_hash, training_metadata, tags, + comments, ) for repo in self.repositories: repo.create_experiment(experiment)