Skip to content

Commit

Permalink
Update config logic
Browse files Browse the repository at this point in the history
  • Loading branch information
kklemon committed Sep 25, 2024
1 parent 6466dd7 commit 981dbf9
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 9 deletions.
23 changes: 15 additions & 8 deletions src/penai/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from accsr.remote_storage import RemoteStorage, RemoteStorageConfig, TransactionSummary
from openai import OpenAI

from penai.errors import ConfigError

file_dir = os.path.dirname(__file__) if "__file__" in locals() else os.getcwd()

top_level_directory: str = os.path.abspath(os.path.join(file_dir, os.pardir, os.pardir))
Expand Down Expand Up @@ -122,14 +124,19 @@ def pull_from_remote(
dryrun: bool = False,
) -> TransactionSummary:
"""Pulls from the remote storage using the default storage config."""
return _default_remote_storage().pull(
remote_path=remote_path,
local_base_dir=top_level_directory,
force=force,
include_regex=include_regex,
exclude_regex=exclude_regex,
dryrun=dryrun,
)
try:
return _default_remote_storage().pull(
remote_path=remote_path,
local_base_dir=top_level_directory,
force=force,
include_regex=include_regex,
exclude_regex=exclude_regex,
dryrun=dryrun,
)
except TypeError as e:
raise ConfigError(
"Pulling from remote storage failed. This might be due to missing configuration keys."
) from e


def push_to_remote(
Expand Down
4 changes: 4 additions & 0 deletions src/penai/errors.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,6 @@
class FontFetchError(Exception):
pass


class ConfigError(Exception):
pass
9 changes: 8 additions & 1 deletion src/penai/registries/projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from dataclasses import dataclass
from enum import Enum
from functools import cache
from typing import Literal
from typing import Literal, Self

from sensai.util.cache import pickle_cached

Expand Down Expand Up @@ -93,6 +93,13 @@ class SavedPenpotProject(Enum):
WIREFRAMING_KIT = "Wireframing kit"
GENERATIVE_VARIATIONS = "Generative variations"

@classmethod
def get_by_name(cls, name: str) -> Self:
for project in cls:
if project.value == name:
return project
raise ValueError(f"Project with name '{name}' not found.")

def get_project_name(self) -> str:
return self.value

Expand Down

0 comments on commit 981dbf9

Please sign in to comment.