Skip to content

Commit

Permalink
[v0.4.0] Temporarily remove Flax modules from the public API (#755)
Browse files Browse the repository at this point in the history
Temporarily remove Flax modules from the public API
  • Loading branch information
anton-l authored Oct 6, 2022
1 parent 9c9462f commit 2e209c3
Show file tree
Hide file tree
Showing 9 changed files with 6 additions and 100 deletions.
18 changes: 0 additions & 18 deletions docs/source/api/models.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -45,21 +45,3 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module

## AutoencoderKL
[[autodoc]] AutoencoderKL

## FlaxModelMixin
[[autodoc]] FlaxModelMixin

## FlaxUNet2DConditionOutput
[[autodoc]] models.unet_2d_condition_flax.FlaxUNet2DConditionOutput

## FlaxUNet2DConditionModel
[[autodoc]] FlaxUNet2DConditionModel

## FlaxDecoderOutput
[[autodoc]] models.vae_flax.FlaxDecoderOutput

## FlaxAutoencoderKLOutput
[[autodoc]] models.vae_flax.FlaxAutoencoderKLOutput

## FlaxAutoencoderKL
[[autodoc]] FlaxAutoencoderKL
2 changes: 1 addition & 1 deletion docs/source/api/schedulers.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ This allows for rapid experimentation and cleaner abstractions in the code, wher
To this end, the design of schedulers is such that:

- Schedulers can be used interchangeably between diffusion models in inference to find the preferred trade-off between speed and generation quality.
- Schedulers are currently by default in PyTorch, but are designed to be framework independent (partial Jax support currently exists).
- Schedulers are currently by default in PyTorch.


## API
Expand Down
12 changes: 2 additions & 10 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,10 @@
"datasets",
"filelock",
"flake8>=3.8.3",
"flax>=0.4.1",
"hf-doc-builder>=0.3.0",
"huggingface-hub>=0.10.0",
"importlib_metadata",
"isort>=5.5.4",
"jax>=0.2.8,!=0.3.2,<=0.3.6",
"jaxlib>=0.1.65,<=0.3.6",
"modelcards>=0.1.4",
"numpy",
"onnxruntime",
Expand Down Expand Up @@ -188,15 +185,9 @@ def run(self):
"torchvision",
"transformers"
)
extras["torch"] = deps_list("torch")

if os.name == "nt": # windows
extras["flax"] = [] # jax is not supported on windows
else:
extras["flax"] = deps_list("jax", "jaxlib", "flax")

extras["dev"] = (
extras["quality"] + extras["test"] + extras["training"] + extras["docs"] + extras["torch"] + extras["flax"]
extras["quality"] + extras["test"] + extras["training"] + extras["docs"]
)

install_requires = [
Expand All @@ -207,6 +198,7 @@ def run(self):
deps["regex"],
deps["requests"],
deps["Pillow"],
deps["torch"]
]

setup(
Expand Down
23 changes: 0 additions & 23 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from .utils import (
is_flax_available,
is_inflect_available,
is_onnx_available,
is_scipy_available,
Expand Down Expand Up @@ -61,25 +60,3 @@
from .pipelines import StableDiffusionOnnxPipeline
else:
from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403

if is_flax_available():
from .modeling_flax_utils import FlaxModelMixin
from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
from .models.vae_flax import FlaxAutoencoderKL
from .pipeline_flax_utils import FlaxDiffusionPipeline
from .schedulers import (
FlaxDDIMScheduler,
FlaxDDPMScheduler,
FlaxKarrasVeScheduler,
FlaxLMSDiscreteScheduler,
FlaxPNDMScheduler,
FlaxSchedulerMixin,
FlaxScoreSdeVeScheduler,
)
else:
from .utils.dummy_flax_objects import * # noqa F403

if is_flax_available() and is_transformers_available():
from .pipelines import FlaxStableDiffusionPipeline
else:
from .utils.dummy_flax_and_transformers_objects import * # noqa F403
3 changes: 0 additions & 3 deletions src/diffusers/dependency_versions_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,10 @@
"datasets": "datasets",
"filelock": "filelock",
"flake8": "flake8>=3.8.3",
"flax": "flax>=0.4.1",
"hf-doc-builder": "hf-doc-builder>=0.3.0",
"huggingface-hub": "huggingface-hub>=0.10.0",
"importlib_metadata": "importlib_metadata",
"isort": "isort>=5.5.4",
"jax": "jax>=0.2.8,!=0.3.2,<=0.3.6",
"jaxlib": "jaxlib>=0.1.65,<=0.3.6",
"modelcards": "modelcards>=0.1.4",
"numpy": "numpy",
"onnxruntime": "onnxruntime",
Expand Down
6 changes: 1 addition & 5 deletions src/diffusers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from ..utils import is_flax_available, is_torch_available
from ..utils import is_torch_available


if is_torch_available():
from .unet_2d import UNet2DModel
from .unet_2d_condition import UNet2DConditionModel
from .vae import AutoencoderKL, VQModel

if is_flax_available():
from .unet_2d_condition_flax import FlaxUNet2DConditionModel
from .vae_flax import FlaxAutoencoderKL
3 changes: 0 additions & 3 deletions src/diffusers/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,3 @@

if is_transformers_available() and is_onnx_available():
from .stable_diffusion import StableDiffusionOnnxPipeline

if is_transformers_available() and is_flax_available():
from .stable_diffusion import FlaxStableDiffusionPipeline
26 changes: 1 addition & 25 deletions src/diffusers/pipelines/stable_diffusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import PIL
from PIL import Image

from ...utils import BaseOutput, is_flax_available, is_onnx_available, is_torch_available, is_transformers_available
from ...utils import BaseOutput, is_onnx_available, is_torch_available, is_transformers_available


@dataclass
Expand Down Expand Up @@ -35,27 +35,3 @@ class StableDiffusionPipelineOutput(BaseOutput):

if is_transformers_available() and is_onnx_available():
from .pipeline_stable_diffusion_onnx import StableDiffusionOnnxPipeline

if is_transformers_available() and is_flax_available():
import flax

@flax.struct.dataclass
class FlaxStableDiffusionPipelineOutput(BaseOutput):
"""
Output class for Stable Diffusion pipelines.
Args:
images (`List[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
nsfw_content_detected (`List[bool]`)
List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content.
"""

images: Union[List[PIL.Image.Image], np.ndarray]
nsfw_content_detected: List[bool]

from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState
from .pipeline_flax_stable_diffusion import FlaxStableDiffusionPipeline
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
13 changes: 1 addition & 12 deletions src/diffusers/schedulers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.


from ..utils import is_flax_available, is_scipy_available, is_torch_available
from ..utils import is_scipy_available, is_torch_available


if is_torch_available():
Expand All @@ -27,17 +27,6 @@
else:
from ..utils.dummy_pt_objects import * # noqa F403

if is_flax_available():
from .scheduling_ddim_flax import FlaxDDIMScheduler
from .scheduling_ddpm_flax import FlaxDDPMScheduler
from .scheduling_karras_ve_flax import FlaxKarrasVeScheduler
from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler
from .scheduling_pndm_flax import FlaxPNDMScheduler
from .scheduling_sde_ve_flax import FlaxScoreSdeVeScheduler
from .scheduling_utils_flax import FlaxSchedulerMixin
else:
from ..utils.dummy_flax_objects import * # noqa F403


if is_scipy_available() and is_torch_available():
from .scheduling_lms_discrete import LMSDiscreteScheduler
Expand Down

0 comments on commit 2e209c3

Please sign in to comment.