Skip to content

Commit

Permalink
modelcards and tensorboard are optional
Browse files Browse the repository at this point in the history
  • Loading branch information
patrickvonplaten committed Jul 21, 2022
1 parent b1b99b5 commit 416749f
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 4 deletions.
5 changes: 2 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,10 +161,11 @@ def run(self):
extras = {}
extras["quality"] = ["black ~= 22.0", "isort >= 5.5.4", "flake8 >= 3.8.3"]
extras["docs"] = []
extras["training"] = ["tensorboard", "modelcards"]
extras["test"] = [
"pytest",
]
extras["dev"] = extras["quality"] + extras["test"]
extras["dev"] = extras["quality"] + extras["test"] + extras["training"]

install_requires = [
deps["filelock"],
Expand All @@ -174,8 +175,6 @@ def run(self):
deps["requests"],
deps["torch"],
deps["Pillow"],
deps["tensorboard"],
deps["modelcards"],
]

setup(
Expand Down
12 changes: 11 additions & 1 deletion src/diffusers/hub_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@

from diffusers import DiffusionPipeline
from huggingface_hub import HfFolder, Repository, whoami
from modelcards import CardData, ModelCard
from utils import is_modelcards_available


if is_modelcards_available():
from modelcards import CardData, ModelCard

from .utils import logging

Expand Down Expand Up @@ -147,6 +151,12 @@ def push_to_hub(


def create_model_card(args, model_name):
if not is_modelcards_available:
raise ValueError(
"Please make sure to have `modelcards` installed when using the `create_model_card` function. You can"
" install the package with `pip install modelcards`."
)

if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]:
return

Expand Down
12 changes: 12 additions & 0 deletions src/diffusers/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,14 @@
_unidecode_available = False


_modelcards_available = importlib.util.find_spec("modelcards") is not None
try:
_modelcards_version = importlib_metadata.version("modelcards")
logger.debug(f"Successfully imported modelcards version {_modelcards_version}")
except importlib_metadata.PackageNotFoundError:
_modelcards_available = False


def is_transformers_available():
return _transformers_available

Expand All @@ -73,6 +81,10 @@ def is_unidecode_available():
return _unidecode_available


def is_modelcards_available():
return _modelcards_available


class RepositoryNotFoundError(HTTPError):
"""
Raised when trying to access a hf.co URL with an invalid repository name, or with a private repo name the user does
Expand Down

0 comments on commit 416749f

Please sign in to comment.