SDXL flax (#4254)
* support transformer_layers_per_block in Flax UNet

* add support for text_time additional embeddings to Flax UNet

* rename attention layers for VAE

* add shape asserts when renaming attention layers

* transpose VAE attention layers

* add pipeline flax SDXL code [WIP]

* continue add pipeline flax SDXL code [WIP]

* cleanup

* Working on JIT support

Fixed prompt embedding shapes so they work in parallel mode. For now we assume
both text encoders are always present, for simplicity (see the parallel-inference
sketch after this change list).

* Fixing embeddings (untested)

* Remove spurious line

* Shard guidance_scale when jitting.

* Decode images

* Fix sharding

* style

* Refiner UNet can be loaded.

* Refiner / img2img pipeline

* Allow latent outputs from base and latent inputs in refiner

This makes it possible to chain base + refiner without using the VAE decoder in
the base model or the VAE encoder in the refiner, skipping conversions to/from
PIL and avoiding TPU <-> CPU memory copies.

* Adapt to FlaxCLIPTextModelOutput

* Update Flax XL pipeline to FlaxCLIPTextModelOutput

* make fix-copies

* make style

* add euler scheduler

* Fix import

* Fix copies, comment unused code.

* Fix SDXL Flax imports

* Fix euler discrete begin

* improve init import

* finish

* put discrete euler in init

* fix flax euler

* Fix more

* make style

* correct init

* correct init

* Temporarily remove FlaxStableDiffusionXLImg2ImgPipeline

* correct pipelines

* finish
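
Several of the items above concern JIT support and sharding. The following is a
minimal sketch of parallel inference with the new pipeline, following the
conventions of the existing Flax Stable Diffusion pipeline; the checkpoint id,
dtype, and keyword names (prepare_inputs, jit=True) are assumptions carried over
from that pipeline, not guarantees from this commit.

import jax
import jax.numpy as jnp
from flax.jax_utils import replicate
from flax.training.common_utils import shard

from diffusers import FlaxStableDiffusionXLPipeline

# Load the base pipeline (checkpoint id and dtype are illustrative).
pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", dtype=jnp.bfloat16
)

# One prompt per device; prepare_inputs tokenizes for both text encoders.
prompts = ["a photo of an astronaut riding a horse"] * jax.device_count()
prompt_ids = pipeline.prepare_inputs(prompts)

# Replicate parameters and shard inputs across devices.
params = replicate(params)
prompt_ids = shard(prompt_ids)
rng = jax.random.split(jax.random.PRNGKey(0), jax.device_count())

# jit=True takes the pmapped path; per the notes above, guidance_scale is
# sharded internally when jitting.
images = pipeline(prompt_ids, params, rng, num_inference_steps=25, jit=True).images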

---------

Co-authored-by: Martin Müller <[email protected]>
Co-authored-by: patil-suraj <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
4 people authored Sep 22, 2023
1 parent 2e860e8 commit 3651b14
Showing 17 changed files with 1,248 additions and 488 deletions.
4 changes: 4 additions & 0 deletions src/diffusers/__init__.py
@@ -368,6 +368,7 @@
"FlaxDDIMScheduler",
"FlaxDDPMScheduler",
"FlaxDPMSolverMultistepScheduler",
"FlaxEulerDiscreteScheduler",
"FlaxKarrasVeScheduler",
"FlaxLMSDiscreteScheduler",
"FlaxPNDMScheduler",
@@ -395,6 +396,7 @@
"FlaxStableDiffusionImg2ImgPipeline",
"FlaxStableDiffusionInpaintPipeline",
"FlaxStableDiffusionPipeline",
"FlaxStableDiffusionXLPipeline",
]
)

@@ -673,6 +675,7 @@
FlaxDDIMScheduler,
FlaxDDPMScheduler,
FlaxDPMSolverMultistepScheduler,
FlaxEulerDiscreteScheduler,
FlaxKarrasVeScheduler,
FlaxLMSDiscreteScheduler,
FlaxPNDMScheduler,
@@ -691,6 +694,7 @@
FlaxStableDiffusionImg2ImgPipeline,
FlaxStableDiffusionInpaintPipeline,
FlaxStableDiffusionPipeline,
FlaxStableDiffusionXLPipeline,
)

try:
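
The new top-level exports can be exercised as follows; a minimal sketch of
swapping the newly added FlaxEulerDiscreteScheduler into the SDXL Flax pipeline,
assuming the scheduler/state conventions used by the other Flax schedulers. The
checkpoint id and the params["scheduler"] assignment are assumptions, not part
of this diff.

import jax.numpy as jnp

from diffusers import FlaxEulerDiscreteScheduler, FlaxStableDiffusionXLPipeline

model_id = "stabilityai/stable-diffusion-xl-base-1.0"  # illustrative checkpoint

# Flax schedulers return the scheduler object and its immutable state separately.
scheduler, scheduler_state = FlaxEulerDiscreteScheduler.from_pretrained(
    model_id, subfolder="scheduler"
)

pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
    model_id, scheduler=scheduler, dtype=jnp.bfloat16
)
params["scheduler"] = scheduler_state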
18 changes: 17 additions & 1 deletion src/diffusers/models/modeling_flax_pytorch_utils.py
@@ -42,9 +42,25 @@ def rename_key(key):
# and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py
def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict):
"""Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary"""

# conv norm or layer norm
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)

# rename attention layers
if len(pt_tuple_key) > 1:
for rename_from, rename_to in (
("to_out_0", "proj_attn"),
("to_k", "key"),
("to_v", "value"),
("to_q", "query"),
):
if pt_tuple_key[-2] == rename_from:
weight_name = pt_tuple_key[-1]
weight_name = "kernel" if weight_name == "weight" else weight_name
renamed_pt_tuple_key = pt_tuple_key[:-2] + (rename_to, weight_name)
if renamed_pt_tuple_key in random_flax_state_dict:
assert random_flax_state_dict[renamed_pt_tuple_key].shape == pt_tensor.T.shape
return renamed_pt_tuple_key, pt_tensor.T

if (
any("norm" in str_ for str_ in pt_tuple_key)
and (pt_tuple_key[-1] == "bias")
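
For context on the new attention-layer branch above: PyTorch nn.Linear stores
its weight as (out_features, in_features), while a Flax Dense kernel is
(in_features, out_features), which is why the renamed tensor is transposed and
its shape asserted. A tiny standalone sketch of the mapping (not the library
function itself; the key and shape below are illustrative):

import numpy as np

pt_tuple_key = ("mid_block", "attentions_0", "to_q", "weight")
pt_tensor = np.zeros((512, 512))  # illustrative shape

rename_map = {"to_out_0": "proj_attn", "to_k": "key", "to_v": "value", "to_q": "query"}

module, param = pt_tuple_key[-2], pt_tuple_key[-1]
if module in rename_map:
    param = "kernel" if param == "weight" else param
    flax_key = pt_tuple_key[:-2] + (rename_map[module], param)
    flax_tensor = pt_tensor.T  # (out, in) -> (in, out)

print(flax_key)          # ('mid_block', 'attentions_0', 'query', 'kernel')
print(flax_tensor.shape)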
34 changes: 17 additions & 17 deletions src/diffusers/models/modeling_flax_utils.py
@@ -303,23 +303,23 @@ def from_pretrained(
"framework": "flax",
}

-# Load config if we don't provide a configuration
-config_path = config if config is not None else pretrained_model_name_or_path
-model, model_kwargs = cls.from_config(
-    config_path,
-    cache_dir=cache_dir,
-    return_unused_kwargs=True,
-    force_download=force_download,
-    resume_download=resume_download,
-    proxies=proxies,
-    local_files_only=local_files_only,
-    use_auth_token=use_auth_token,
-    revision=revision,
-    subfolder=subfolder,
-    # model args
-    dtype=dtype,
-    **kwargs,
-)
+# Load config if we don't provide one
+if config is None:
+    config, unused_kwargs = cls.load_config(
+        pretrained_model_name_or_path,
+        cache_dir=cache_dir,
+        return_unused_kwargs=True,
+        force_download=force_download,
+        resume_download=resume_download,
+        proxies=proxies,
+        local_files_only=local_files_only,
+        use_auth_token=use_auth_token,
+        revision=revision,
+        subfolder=subfolder,
+        **kwargs,
+    )
+
+model, model_kwargs = cls.from_config(config, dtype=dtype, return_unused_kwargs=True, **unused_kwargs)

# Load model
pretrained_path_with_subfolder = (
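
A short usage sketch of the revised loading path, which now loads the config
explicitly via load_config and then builds the model with from_config. The
checkpoint id and dtype below are illustrative, not taken from this diff.

import jax.numpy as jnp

from diffusers import FlaxUNet2DConditionModel

# from_pretrained returns the model definition and its parameters separately.
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",  # illustrative checkpoint
    subfolder="unet",
    dtype=jnp.bfloat16,  # model arg forwarded to from_config, not a config field
)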
9 changes: 6 additions & 3 deletions src/diffusers/models/unet_2d_blocks_flax.py
@@ -52,6 +52,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
only_cross_attention: bool = False
use_memory_efficient_attention: bool = False
dtype: jnp.dtype = jnp.float32
transformer_layers_per_block: int = 1

def setup(self):
resnets = []
@@ -72,7 +73,7 @@ def setup(self):
in_channels=self.out_channels,
n_heads=self.num_attention_heads,
d_head=self.out_channels // self.num_attention_heads,
-depth=1,
+depth=self.transformer_layers_per_block,
use_linear_projection=self.use_linear_projection,
only_cross_attention=self.only_cross_attention,
use_memory_efficient_attention=self.use_memory_efficient_attention,
@@ -192,6 +193,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
only_cross_attention: bool = False
use_memory_efficient_attention: bool = False
dtype: jnp.dtype = jnp.float32
transformer_layers_per_block: int = 1

def setup(self):
resnets = []
@@ -213,7 +215,7 @@ def setup(self):
in_channels=self.out_channels,
n_heads=self.num_attention_heads,
d_head=self.out_channels // self.num_attention_heads,
-depth=1,
+depth=self.transformer_layers_per_block,
use_linear_projection=self.use_linear_projection,
only_cross_attention=self.only_cross_attention,
use_memory_efficient_attention=self.use_memory_efficient_attention,
@@ -331,6 +333,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
use_linear_projection: bool = False
use_memory_efficient_attention: bool = False
dtype: jnp.dtype = jnp.float32
transformer_layers_per_block: int = 1

def setup(self):
# there is always at least one resnet
@@ -350,7 +353,7 @@ def setup(self):
in_channels=self.in_channels,
n_heads=self.num_attention_heads,
d_head=self.in_channels // self.num_attention_heads,
-depth=1,
+depth=self.transformer_layers_per_block,
use_linear_projection=self.use_linear_projection,
use_memory_efficient_attention=self.use_memory_efficient_attention,
dtype=self.dtype,
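
The new transformer_layers_per_block field controls how many transformer layers
each attention block stacks (it is passed through as depth). A small sketch of
the per-block expansion done in the UNet setup; the SDXL-base depths in the
comment are an assumption based on its published UNet config, not read from
this diff.

from typing import Tuple, Union

def expand_per_block(value: Union[int, Tuple[int, ...]], num_blocks: int):
    """Broadcast an int to one entry per down block, mirroring the UNet setup logic."""
    if isinstance(value, int):
        return [value] * num_blocks
    return list(value)

down_block_types = ("DownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D")

print(expand_per_block(1, len(down_block_types)))           # [1, 1, 1] (SD 1.x / 2.x behaviour)
print(expand_per_block((1, 2, 10), len(down_block_types)))  # [1, 2, 10] (SDXL-base-style depths)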
1 change: 0 additions & 1 deletion src/diffusers/models/unet_2d_condition.py
@@ -883,7 +883,6 @@ def forward(
time_ids = added_cond_kwargs.get("time_ids")
time_embeds = self.add_time_proj(time_ids.flatten())
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))

add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
add_embeds = add_embeds.to(emb.dtype)
aug_emb = self.add_embedding(add_embeds)
67 changes: 65 additions & 2 deletions src/diffusers/models/unet_2d_condition_flax.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Optional, Tuple, Union
+from typing import Dict, Optional, Tuple, Union

import flax
import flax.linen as nn
@@ -116,6 +116,11 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
flip_sin_to_cos: bool = True
freq_shift: int = 0
use_memory_efficient_attention: bool = False
transformer_layers_per_block: Union[int, Tuple[int]] = 1
addition_embed_type: Optional[str] = None
addition_time_embed_dim: Optional[int] = None
addition_embed_type_num_heads: int = 64
projection_class_embeddings_input_dim: Optional[int] = None

def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
# init input tensors
@@ -127,7 +132,17 @@ def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}

-return self.init(rngs, sample, timesteps, encoder_hidden_states)["params"]
+added_cond_kwargs = None
+if self.addition_embed_type == "text_time":
+    # TODO: how to get this from the config? It's no longer cross_attention_dim
+    text_embeds_dim = 1280
+    time_ids_channels = self.projection_class_embeddings_input_dim - text_embeds_dim
+    time_ids_dims = time_ids_channels // self.addition_time_embed_dim
+    added_cond_kwargs = {
+        "text_embeds": jnp.zeros((1, text_embeds_dim), dtype=jnp.float32),
+        "time_ids": jnp.zeros((1, time_ids_dims), dtype=jnp.float32),
+    }
+return self.init(rngs, sample, timesteps, encoder_hidden_states, added_cond_kwargs)["params"]

def setup(self):
block_out_channels = self.block_out_channels
@@ -168,6 +183,24 @@ def setup(self):
if isinstance(num_attention_heads, int):
num_attention_heads = (num_attention_heads,) * len(self.down_block_types)

# transformer layers per block
transformer_layers_per_block = self.transformer_layers_per_block
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * len(self.down_block_types)

# addition embed types
if self.addition_embed_type is None:
self.add_embedding = None
elif self.addition_embed_type == "text_time":
if self.addition_time_embed_dim is None:
raise ValueError(
f"addition_embed_type {self.addition_embed_type} requires `addition_time_embed_dim` to not be None"
)
self.add_time_proj = FlaxTimesteps(self.addition_time_embed_dim, self.flip_sin_to_cos, self.freq_shift)
self.add_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)
else:
raise ValueError(f"addition_embed_type: {self.addition_embed_type} must be None or `text_time`.")

# down
down_blocks = []
output_channel = block_out_channels[0]
@@ -182,6 +215,7 @@
out_channels=output_channel,
dropout=self.dropout,
num_layers=self.layers_per_block,
transformer_layers_per_block=transformer_layers_per_block[i],
num_attention_heads=num_attention_heads[i],
add_downsample=not is_final_block,
use_linear_projection=self.use_linear_projection,
@@ -207,6 +241,7 @@
in_channels=block_out_channels[-1],
dropout=self.dropout,
num_attention_heads=num_attention_heads[-1],
transformer_layers_per_block=transformer_layers_per_block[-1],
use_linear_projection=self.use_linear_projection,
use_memory_efficient_attention=self.use_memory_efficient_attention,
dtype=self.dtype,
@@ -218,6 +253,7 @@
reversed_num_attention_heads = list(reversed(num_attention_heads))
only_cross_attention = list(reversed(only_cross_attention))
output_channel = reversed_block_out_channels[0]
reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
for i, up_block_type in enumerate(self.up_block_types):
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
@@ -231,6 +267,7 @@
out_channels=output_channel,
prev_output_channel=prev_output_channel,
num_layers=self.layers_per_block + 1,
transformer_layers_per_block=reversed_transformer_layers_per_block[i],
num_attention_heads=reversed_num_attention_heads[i],
add_upsample=not is_final_block,
dropout=self.dropout,
@@ -269,6 +306,7 @@ def __call__(
sample,
timesteps,
encoder_hidden_states,
added_cond_kwargs: Optional[Union[Dict, FrozenDict]] = None,
down_block_additional_residuals=None,
mid_block_additional_residual=None,
return_dict: bool = True,
@@ -300,6 +338,31 @@
t_emb = self.time_proj(timesteps)
t_emb = self.time_embedding(t_emb)

# additional embeddings
aug_emb = None
if self.addition_embed_type == "text_time":
if added_cond_kwargs is None:
raise ValueError(
f"Need to provide argument `added_cond_kwargs` for {self.__class__} when using `addition_embed_type={self.addition_embed_type}`"
)
text_embeds = added_cond_kwargs.get("text_embeds")
if text_embeds is None:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
)
time_ids = added_cond_kwargs.get("time_ids")
if time_ids is None:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
)
# compute time embeds
time_embeds = self.add_time_proj(jnp.ravel(time_ids)) # (1, 6) => (6,) => (6, 256)
time_embeds = jnp.reshape(time_embeds, (text_embeds.shape[0], -1))
add_embeds = jnp.concatenate([text_embeds, time_embeds], axis=-1)
aug_emb = self.add_embedding(add_embeds)

t_emb = t_emb + aug_emb if aug_emb is not None else t_emb

# 2. pre-process
sample = jnp.transpose(sample, (0, 2, 3, 1))
sample = self.conv_in(sample)
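
To make the text_time shapes above concrete: with SDXL-base-style values (an
assumption here: addition_time_embed_dim=256, projection_class_embeddings_input_dim=2816,
pooled text embeddings of width 1280 and six time ids), the dummy inputs built
in init_weights and the concatenation in __call__ work out as follows.

import jax.numpy as jnp

# Assumed SDXL-base-style config values (not read from this diff).
addition_time_embed_dim = 256
projection_class_embeddings_input_dim = 2816
text_embeds_dim = 1280

# Dummy inputs as in init_weights.
time_ids_channels = projection_class_embeddings_input_dim - text_embeds_dim  # 1536
time_ids_dims = time_ids_channels // addition_time_embed_dim                 # 6
text_embeds = jnp.zeros((1, text_embeds_dim))
time_ids = jnp.zeros((1, time_ids_dims))

# Shape-only stand-in for add_time_proj: (1, 6) -> (6,) -> (6, 256).
time_embeds = jnp.zeros((time_ids.size, addition_time_embed_dim))
time_embeds = jnp.reshape(time_embeds, (text_embeds.shape[0], -1))           # (1, 1536)

add_embeds = jnp.concatenate([text_embeds, time_embeds], axis=-1)            # (1, 2816)
assert add_embeds.shape[-1] == projection_class_embeddings_input_dim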
(Diffs for the remaining changed files are not shown.)
