Skip to content

Commit

Permalink
More changes throughout to ease multiple repos
Browse files Browse the repository at this point in the history
  • Loading branch information
stephenpardy committed Nov 14, 2023
1 parent dc068fd commit 4700ea9
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 38 deletions.
36 changes: 28 additions & 8 deletions rubicon_ml/client/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, List, Optional, Union

from rubicon_ml.exceptions import RubiconException

Expand All @@ -21,29 +21,49 @@ class Base:
The config, which injects the repository to use.
"""

def __init__(self, domain: DOMAIN_TYPES, config: Optional[Config] = None):
def __init__(self, domain: DOMAIN_TYPES, config: Optional[Union[Config, List[Config]]] = None):
self._config = config
self._domain = domain

def __str__(self) -> str:
return self._domain.__str__()

def is_auto_git_enabled(self) -> bool:
"""Is git enabled for any of the configs."""
if isinstance(self._config, list):
return any(_config.is_auto_git_enabled for _config in self._config)

if self._config is None:
return False

return self._config.is_auto_git_enabled

def _raise_rubicon_exception(self, exception: Exception):
if len(self.repositories) > 1:
if self.repositories is None or len(self.repositories) > 1:
raise RubiconException("all configured storage backends failed") from exception
else:
raise exception

@property
def repository(self) -> Optional[BaseRepository]:
return self._config.repository if self._config is not None else None
"""Get the repository."""
if self._config is None:
return None

if isinstance(self._config, list):
if len(self._config) > 1:
raise ValueError("More than one repository available. Use `.repositories` instead.")
return self._config[0].repository

return self._config.repository

@property
def repositories(self) -> Optional[List[BaseRepository]]:
"""Get all repositories."""
if self._config is None:
return None

if hasattr(self._config, "repositories"):
return self._config.repositories
else:
return [self._config.repository]
if isinstance(self._config, list):
return [_config.repository for _config in self._config]

return [self._config.repository]
4 changes: 2 additions & 2 deletions rubicon_ml/client/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class Project(Base, ArtifactMixin, DataframeMixin, SchemaMixin):
The config, which specifies the underlying repository.
"""

def __init__(self, domain: ProjectDomain, config: Optional[Config] = None):
def __init__(self, domain: ProjectDomain, config: Optional[Union[Config, List[Config]]] = None):
super().__init__(domain, config)

self._domain: ProjectDomain
Expand Down Expand Up @@ -76,7 +76,7 @@ def _create_experiment_domain(
tags,
):
"""Instantiates and returns an experiment domain object."""
if self._config.is_auto_git_enabled:
if self.is_auto_git_enabled:
if branch_name is None:
branch_name = self._get_branch_name()
if commit_hash is None:
Expand Down
20 changes: 13 additions & 7 deletions rubicon_ml/client/rubicon.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,9 @@ def _get_github_url(self):
def _create_project_domain(
self,
name: str,
description: str,
github_url: str,
training_metadata: Union[List[Tuple], Tuple],
description: Optional[str],
github_url: Optional[str],
training_metadata: Optional[Union[List[Tuple], Tuple]],
):
"""Instantiates and returns a project domain object."""
if self.config.is_auto_git_enabled and github_url is None:
Expand All @@ -121,7 +121,13 @@ def _create_project_domain(
)

@failsafe
def create_project(self, name, description=None, github_url=None, training_metadata=None):
def create_project(
self,
name: str,
description: Optional[str] = None,
github_url: Optional[str] = None,
training_metadata: Optional[Union[Tuple, List[Tuple]]] = None,
) -> Project:
"""Create a project.
Parameters
Expand All @@ -147,10 +153,10 @@ def create_project(self, name, description=None, github_url=None, training_metad
for repo in self.repositories:
repo.create_project(project)

return Project(project, self.config)
return Project(project, self.configs)

@failsafe
def get_project(self, name=None, id=None):
def get_project(self, name: Optional[str] = None, id: Optional[str] = None) -> Project:
"""Get a project.
Parameters
Expand Down Expand Up @@ -223,7 +229,7 @@ def get_project_as_df(self, name, df_type="pandas", group_by=None):
return project.to_df(df_type=df_type, group_by=None)

@failsafe
def get_or_create_project(self, name, **kwargs):
def get_or_create_project(self, name: str, **kwargs):
"""Get or create a project.
Parameters
Expand Down
12 changes: 11 additions & 1 deletion rubicon_ml/domain/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from typing import Union

from rubicon_ml.domain import utils
from rubicon_ml.domain.artifact import Artifact
from rubicon_ml.domain.dataframe import Dataframe
from rubicon_ml.domain.experiment import Experiment
Expand All @@ -12,4 +13,13 @@

DOMAIN_TYPES = Union[Artifact, Dataframe, Experiment, Feature, Metric, Parameter, Project]

__all__ = ["Artifact", "Dataframe", "Experiment", "Feature", "Metric", "Parameter", "Project"]
__all__ = [
"Artifact",
"Dataframe",
"Experiment",
"Feature",
"Metric",
"Parameter",
"Project",
"utils",
]
20 changes: 0 additions & 20 deletions tests/unit/client/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,26 +43,6 @@ def test_init_memory_repository():
assert config.root_dir == config.repository.root_dir


def test_init_multiple_backend():
config = Config(
composite_config=[
{"persistence": "filesystem", "root_dir": "./local/root", "is_auto_git_enabled": False},
{
"persistence": "filesystem",
"root_dir": "s3://remote/bucket/root",
"is_auto_git_enabled": False,
},
{"persistence": "memory", "root_dir": "./memory/root", "is_auto_git_enabled": False},
]
)

assert hasattr(config, "repositories")

assert isinstance(config.repositories[0], LocalRepository)
assert isinstance(config.repositories[1], S3Repository)
assert isinstance(config.repositories[2], MemoryRepository)


def test_invalid_persistence():
with pytest.raises(ValueError) as e:
Config("invalid")
Expand Down

0 comments on commit 4700ea9

Please sign in to comment.