Skip to content

Commit

Permalink
[Import] Add missing settings / Correct some dummy imports (#5036)
Browse files Browse the repository at this point in the history
* [Import] Add missing settings

* up

* up

* up
  • Loading branch information
patrickvonplaten committed Sep 14, 2023
1 parent 7512fc4 commit d172404
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 26 deletions.
54 changes: 29 additions & 25 deletions src/diffusers/pipelines/stable_diffusion_safe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,26 @@ class SafetyConfig(object):

_dummy_objects = {}
_additional_imports = {}
_import_structure = {
"pipeline_output": ["StableDiffusionSafePipelineOutput"],
"pipeline_stable_diffusion_safe": ["StableDiffusionPipelineSafe"],
"safety_checker": ["StableDiffusionSafetyChecker"],
}
_import_structure = {}

_additional_imports.update({"SafetyConfig": SafetyConfig})

try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects

_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure.update(
{
"pipeline_output": ["StableDiffusionSafePipelineOutput"],
"pipeline_stable_diffusion_safe": ["StableDiffusionPipelineSafe"],
"safety_checker": ["StableDiffusionSafetyChecker"],
}
)


if TYPE_CHECKING:
try:
Expand All @@ -70,25 +83,16 @@ class SafetyConfig(object):
from .safety_checker import SafeStableDiffusionSafetyChecker

else:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects

_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))

else:
import sys
import sys

sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)

for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
for name, value in _additional_imports.items():
setattr(sys.modules[__name__], name, value)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
for name, value in _additional_imports.items():
setattr(sys.modules[__name__], name, value)
2 changes: 2 additions & 0 deletions src/diffusers/pipelines/text_to_video_synthesis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,5 @@
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
3 changes: 3 additions & 0 deletions src/diffusers/pipelines/vq_diffusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,6 @@
_import_structure,
module_spec=__spec__,
)

for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
4 changes: 3 additions & 1 deletion src/diffusers/pipelines/wuerstchen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
from .pipeline_wuerstchen import WuerstchenDecoderPipeline
from .pipeline_wuerstchen_combined import WuerstchenCombinedPipeline
from .pipeline_wuerstchen_prior import WuerstchenPriorPipeline

else:
import sys

Expand All @@ -51,3 +50,6 @@
_import_structure,
module_spec=__spec__,
)

for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)

0 comments on commit d172404

Please sign in to comment.