diff --git a/src/penai/config.py b/src/penai/config.py index 1eb2b66..e1ff81c 100644 --- a/src/penai/config.py +++ b/src/penai/config.py @@ -13,6 +13,8 @@ from accsr.remote_storage import RemoteStorage, RemoteStorageConfig, TransactionSummary from openai import OpenAI +from penai.errors import ConfigError + file_dir = os.path.dirname(__file__) if "__file__" in locals() else os.getcwd() top_level_directory: str = os.path.abspath(os.path.join(file_dir, os.pardir, os.pardir)) @@ -122,14 +124,19 @@ def pull_from_remote( dryrun: bool = False, ) -> TransactionSummary: """Pulls from the remote storage using the default storage config.""" - return _default_remote_storage().pull( - remote_path=remote_path, - local_base_dir=top_level_directory, - force=force, - include_regex=include_regex, - exclude_regex=exclude_regex, - dryrun=dryrun, - ) + try: + return _default_remote_storage().pull( + remote_path=remote_path, + local_base_dir=top_level_directory, + force=force, + include_regex=include_regex, + exclude_regex=exclude_regex, + dryrun=dryrun, + ) + except TypeError as e: + raise ConfigError( + "Pulling from remote storage failed. This might be due to missing configuration keys." + ) from e def push_to_remote( diff --git a/src/penai/errors.py b/src/penai/errors.py index 3c930de..44dc9bc 100644 --- a/src/penai/errors.py +++ b/src/penai/errors.py @@ -1,2 +1,6 @@ class FontFetchError(Exception): pass + + +class ConfigError(Exception): + pass diff --git a/src/penai/registries/projects.py b/src/penai/registries/projects.py index ca9f714..69f26cd 100644 --- a/src/penai/registries/projects.py +++ b/src/penai/registries/projects.py @@ -3,7 +3,7 @@ from dataclasses import dataclass from enum import Enum from functools import cache -from typing import Literal +from typing import Literal, Self from sensai.util.cache import pickle_cached @@ -93,6 +93,13 @@ class SavedPenpotProject(Enum): WIREFRAMING_KIT = "Wireframing kit" GENERATIVE_VARIATIONS = "Generative variations" + @classmethod + def get_by_name(cls, name: str) -> Self: + for project in cls: + if project.value == name: + return project + raise ValueError(f"Project with name '{name}' not found.") + def get_project_name(self) -> str: return self.value