From effe9d66ebad017d54b47e11c79f347fa7556ecf Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Thu, 13 Oct 2022 21:05:17 +0200 Subject: [PATCH] [FlaxStableDiffusionPipeline] fix bug when nsfw is detected (#832) fix nsfw bug --- .../stable_diffusion/pipeline_flax_stable_diffusion.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py index 6cd678293f5a..7e58d048b4ad 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py @@ -291,7 +291,8 @@ def __call__( # block images if any(has_nsfw_concept): for i, is_nsfw in enumerate(has_nsfw_concept): - images[i] = np.asarray(images_uint8_casted[i]) + if is_nsfw: + images[i] = np.asarray(images_uint8_casted[i]) images = images.reshape(num_devices, batch_size, height, width, 3) else: