From d17240457f2b36273db1b877d5e1bea617f94b77 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 14 Sep 2023 12:42:54 +0200 Subject: [PATCH] [Import] Add missing settings / Correct some dummy imports (#5036) * [Import] Add missing settings * up * up * up --- .../stable_diffusion_safe/__init__.py | 54 ++++++++++--------- .../text_to_video_synthesis/__init__.py | 2 + .../pipelines/vq_diffusion/__init__.py | 3 ++ .../pipelines/wuerstchen/__init__.py | 4 +- 4 files changed, 37 insertions(+), 26 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_safe/__init__.py b/src/diffusers/pipelines/stable_diffusion_safe/__init__.py index 67c6ab1f6686..2bab91c5524a 100644 --- a/src/diffusers/pipelines/stable_diffusion_safe/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion_safe/__init__.py @@ -50,13 +50,26 @@ class SafetyConfig(object): _dummy_objects = {} _additional_imports = {} -_import_structure = { - "pipeline_output": ["StableDiffusionSafePipelineOutput"], - "pipeline_stable_diffusion_safe": ["StableDiffusionPipelineSafe"], - "safety_checker": ["StableDiffusionSafetyChecker"], -} +_import_structure = {} + _additional_imports.update({"SafetyConfig": SafetyConfig}) +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure.update( + { + "pipeline_output": ["StableDiffusionSafePipelineOutput"], + "pipeline_stable_diffusion_safe": ["StableDiffusionPipelineSafe"], + "safety_checker": ["StableDiffusionSafetyChecker"], + } + ) + if TYPE_CHECKING: try: @@ -70,25 +83,16 @@ class SafetyConfig(object): from .safety_checker import SafeStableDiffusionSafetyChecker else: - try: - if not (is_transformers_available() and is_torch_available()): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from ...utils import dummy_torch_and_transformers_objects - - _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) - - else: - import sys + import sys - sys.modules[__name__] = _LazyModule( - __name__, - globals()["__file__"], - _import_structure, - module_spec=__spec__, - ) + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) - for name, value in _dummy_objects.items(): - setattr(sys.modules[__name__], name, value) - for name, value in _additional_imports.items(): - setattr(sys.modules[__name__], name, value) + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) + for name, value in _additional_imports.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/text_to_video_synthesis/__init__.py b/src/diffusers/pipelines/text_to_video_synthesis/__init__.py index a09a63476b7c..8bc8e407d4f9 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/__init__.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/__init__.py @@ -47,3 +47,5 @@ _import_structure, module_spec=__spec__, ) + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/vq_diffusion/__init__.py b/src/diffusers/pipelines/vq_diffusion/__init__.py index b8fb7f55e8ce..dac43806a51b 100644 --- a/src/diffusers/pipelines/vq_diffusion/__init__.py +++ b/src/diffusers/pipelines/vq_diffusion/__init__.py @@ -51,3 +51,6 @@ _import_structure, module_spec=__spec__, ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/wuerstchen/__init__.py b/src/diffusers/pipelines/wuerstchen/__init__.py index 13407f2cd10c..3a6a464aef05 100644 --- a/src/diffusers/pipelines/wuerstchen/__init__.py +++ b/src/diffusers/pipelines/wuerstchen/__init__.py @@ -41,7 +41,6 @@ from .pipeline_wuerstchen import WuerstchenDecoderPipeline from .pipeline_wuerstchen_combined import WuerstchenCombinedPipeline from .pipeline_wuerstchen_prior import WuerstchenPriorPipeline - else: import sys @@ -51,3 +50,6 @@ _import_structure, module_spec=__spec__, ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value)