diff --git a/docs/source/api/models.mdx b/docs/source/api/models.mdx
index c92fdccb8333..525548e7c302 100644
--- a/docs/source/api/models.mdx
+++ b/docs/source/api/models.mdx
@@ -45,21 +45,3 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module
 ## AutoencoderKL
 [[autodoc]] AutoencoderKL
-
-## FlaxModelMixin
-[[autodoc]] FlaxModelMixin
-
-## FlaxUNet2DConditionOutput
-[[autodoc]] models.unet_2d_condition_flax.FlaxUNet2DConditionOutput
-
-## FlaxUNet2DConditionModel
-[[autodoc]] FlaxUNet2DConditionModel
-
-## FlaxDecoderOutput
-[[autodoc]] models.vae_flax.FlaxDecoderOutput
-
-## FlaxAutoencoderKLOutput
-[[autodoc]] models.vae_flax.FlaxAutoencoderKLOutput
-
-## FlaxAutoencoderKL
-[[autodoc]] FlaxAutoencoderKL
diff --git a/docs/source/api/schedulers.mdx b/docs/source/api/schedulers.mdx
index 12a6b5c587bc..d2c4c4d408e0 100644
--- a/docs/source/api/schedulers.mdx
+++ b/docs/source/api/schedulers.mdx
@@ -36,7 +36,7 @@ This allows for rapid experimentation and cleaner abstractions in the code, wher
 To this end, the design of schedulers is such that:
 
 - Schedulers can be used interchangeably between diffusion models in inference to find the preferred trade-off between speed and generation quality.
-- Schedulers are currently by default in PyTorch, but are designed to be framework independent (partial Jax support currently exists).
+- Schedulers are currently by default in PyTorch.
 
 ## API
diff --git a/setup.py b/setup.py
index bce9f5e401d2..8a74789cf794 100644
--- a/setup.py
+++ b/setup.py
@@ -84,13 +84,10 @@
     "datasets",
     "filelock",
     "flake8>=3.8.3",
-    "flax>=0.4.1",
     "hf-doc-builder>=0.3.0",
     "huggingface-hub>=0.10.0",
     "importlib_metadata",
     "isort>=5.5.4",
-    "jax>=0.2.8,!=0.3.2,<=0.3.6",
-    "jaxlib>=0.1.65,<=0.3.6",
     "modelcards>=0.1.4",
     "numpy",
     "onnxruntime",
@@ -188,15 +185,9 @@ def run(self):
     "torchvision",
     "transformers"
 )
-extras["torch"] = deps_list("torch")
-
-if os.name == "nt":  # windows
-    extras["flax"] = []  # jax is not supported on windows
-else:
-    extras["flax"] = deps_list("jax", "jaxlib", "flax")
 
 extras["dev"] = (
-    extras["quality"] + extras["test"] + extras["training"] + extras["docs"] + extras["torch"] + extras["flax"]
+    extras["quality"] + extras["test"] + extras["training"] + extras["docs"]
 )
 
 install_requires = [
@@ -207,6 +198,7 @@ def run(self):
     deps["regex"],
     deps["requests"],
     deps["Pillow"],
+    deps["torch"]
 ]
 
 setup(
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 1cf64a4a2ebf..219f2d8bf9d1 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -1,5 +1,4 @@
 from .utils import (
-    is_flax_available,
     is_inflect_available,
     is_onnx_available,
     is_scipy_available,
@@ -61,25 +60,3 @@
     from .pipelines import StableDiffusionOnnxPipeline
 else:
     from .utils.dummy_torch_and_transformers_and_onnx_objects import *  # noqa F403
-
-if is_flax_available():
-    from .modeling_flax_utils import FlaxModelMixin
-    from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
-    from .models.vae_flax import FlaxAutoencoderKL
-    from .pipeline_flax_utils import FlaxDiffusionPipeline
-    from .schedulers import (
-        FlaxDDIMScheduler,
-        FlaxDDPMScheduler,
-        FlaxKarrasVeScheduler,
-        FlaxLMSDiscreteScheduler,
-        FlaxPNDMScheduler,
-        FlaxSchedulerMixin,
-        FlaxScoreSdeVeScheduler,
-    )
-else:
-    from .utils.dummy_flax_objects import *  # noqa F403
-
-if is_flax_available() and is_transformers_available():
-    from .pipelines import FlaxStableDiffusionPipeline
-else:
-    from .utils.dummy_flax_and_transformers_objects import *  # noqa F403
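The `src/diffusers/__init__.py` hunks above edit the library's optional-backend pattern: framework-specific imports sit behind an `is_*_available()` check, and a dummy placeholder module is imported when the backend is missing so that the public names still resolve. The sketch below only illustrates the general shape of that pattern; the names in it (`is_backend_available`, `_MissingBackendStub`, `UNetStub`) are illustrative stand-ins, not the library's actual internals.

```python
import importlib.util


def is_backend_available(name: str) -> bool:
    """Return True if `name` can be imported in the current environment."""
    return importlib.util.find_spec(name) is not None


class _MissingBackendStub:
    """Importable placeholder that only raises when it is actually used."""

    def __init__(self, *args, **kwargs):
        raise ImportError("This object requires a backend that is not installed.")


if is_backend_available("torch"):
    # In the real package, the torch-backed implementations are imported here.
    class UNetStub:
        pass
else:
    # The name still exists, so `from package import UNetStub` keeps working;
    # instantiating it surfaces a clear ImportError instead of a NameError.
    UNetStub = _MissingBackendStub
```

The deleted `else: from .utils.dummy_flax_objects import *` branches played the placeholder role for the Flax classes; with the Flax imports gone, there is nothing left for those placeholders to cover.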
diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py
index 8b10d70a26f7..7ea7a66f19d2 100644
--- a/src/diffusers/dependency_versions_table.py
+++ b/src/diffusers/dependency_versions_table.py
@@ -8,13 +8,10 @@
     "datasets": "datasets",
     "filelock": "filelock",
     "flake8": "flake8>=3.8.3",
-    "flax": "flax>=0.4.1",
     "hf-doc-builder": "hf-doc-builder>=0.3.0",
     "huggingface-hub": "huggingface-hub>=0.10.0",
     "importlib_metadata": "importlib_metadata",
     "isort": "isort>=5.5.4",
-    "jax": "jax>=0.2.8,!=0.3.2,<=0.3.6",
-    "jaxlib": "jaxlib>=0.1.65,<=0.3.6",
     "modelcards": "modelcards>=0.1.4",
     "numpy": "numpy",
     "onnxruntime": "onnxruntime",
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index 1242ad6fca7f..aff1ec1c57a5 100644
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -12,14 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ..utils import is_flax_available, is_torch_available
+from ..utils import is_torch_available
 
 
 if is_torch_available():
     from .unet_2d import UNet2DModel
     from .unet_2d_condition import UNet2DConditionModel
     from .vae import AutoencoderKL, VQModel
-
-if is_flax_available():
-    from .unet_2d_condition_flax import FlaxUNet2DConditionModel
-    from .vae_flax import FlaxAutoencoderKL
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index 1c31595fb0cf..3f3df460d62f 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -21,6 +21,3 @@
 
 if is_transformers_available() and is_onnx_available():
     from .stable_diffusion import StableDiffusionOnnxPipeline
-
-if is_transformers_available() and is_flax_available():
-    from .stable_diffusion import FlaxStableDiffusionPipeline
diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py
index 615fa404da0b..289c6e1a948a 100644
--- a/src/diffusers/pipelines/stable_diffusion/__init__.py
+++ b/src/diffusers/pipelines/stable_diffusion/__init__.py
@@ -6,7 +6,7 @@
 import PIL
 from PIL import Image
 
-from ...utils import BaseOutput, is_flax_available, is_onnx_available, is_torch_available, is_transformers_available
+from ...utils import BaseOutput, is_onnx_available, is_torch_available, is_transformers_available
 
 
 @dataclass
@@ -35,27 +35,3 @@ class StableDiffusionPipelineOutput(BaseOutput):
 
 if is_transformers_available() and is_onnx_available():
     from .pipeline_stable_diffusion_onnx import StableDiffusionOnnxPipeline
-
-if is_transformers_available() and is_flax_available():
-    import flax
-
-    @flax.struct.dataclass
-    class FlaxStableDiffusionPipelineOutput(BaseOutput):
-        """
-        Output class for Stable Diffusion pipelines.
-
-        Args:
-            images (`List[PIL.Image.Image]` or `np.ndarray`)
-                List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
-                num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
-            nsfw_content_detected (`List[bool]`)
-                List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
-                (nsfw) content.
-        """
-
-        images: Union[List[PIL.Image.Image], np.ndarray]
-        nsfw_content_detected: List[bool]
-
-    from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState
-    from .pipeline_flax_stable_diffusion import FlaxStableDiffusionPipeline
-    from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py
index a906c39eb24c..a3c23d0f99ce 100644
--- a/src/diffusers/schedulers/__init__.py
+++ b/src/diffusers/schedulers/__init__.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 
-from ..utils import is_flax_available, is_scipy_available, is_torch_available
+from ..utils import is_scipy_available, is_torch_available
 
 
 if is_torch_available():
@@ -27,17 +27,6 @@
 else:
     from ..utils.dummy_pt_objects import *  # noqa F403
 
-if is_flax_available():
-    from .scheduling_ddim_flax import FlaxDDIMScheduler
-    from .scheduling_ddpm_flax import FlaxDDPMScheduler
-    from .scheduling_karras_ve_flax import FlaxKarrasVeScheduler
-    from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler
-    from .scheduling_pndm_flax import FlaxPNDMScheduler
-    from .scheduling_sde_ve_flax import FlaxScoreSdeVeScheduler
-    from .scheduling_utils_flax import FlaxSchedulerMixin
-else:
-    from ..utils.dummy_flax_objects import *  # noqa F403
-
 if is_scipy_available() and is_torch_available():
     from .scheduling_lms_discrete import LMSDiscreteScheduler
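The `setup.py` and `dependency_versions_table.py` hunks draw from a single name-to-pin table, so dropping `flax`, `jax`, and `jaxlib` in both places keeps the pins consistent. The stand-alone sketch below shows roughly how such a table feeds a `deps_list()` helper, the `extras` groups, and `install_requires`; the table entries and helper here are a simplified illustration under that assumption, not the repository's exact code.

```python
# Simplified sketch of the dependency-table pattern edited above.
# The entries and helper are illustrative; only the shape matches the diff.
deps = {
    "torch": "torch",
    "numpy": "numpy",
    "flake8": "flake8>=3.8.3",
    "isort": "isort>=5.5.4",
}


def deps_list(*pkgs: str) -> list:
    """Resolve short package names to their pinned requirement strings."""
    return [deps[pkg] for pkg in pkgs]


extras = {"quality": deps_list("flake8", "isort")}
# After this change there is no "flax" (or "torch") extra; "dev" aggregates
# the remaining groups, and torch moves into the hard requirements instead.
extras["dev"] = extras["quality"]

install_requires = [
    deps["numpy"],
    deps["torch"],
]

if __name__ == "__main__":
    print(install_requires)  # ['numpy', 'torch']
    print(extras["dev"])     # ['flake8>=3.8.3', 'isort>=5.5.4']
```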