Merge branch 'main' into sayakpaul-patch-1

sayakpaul authored Jan 11, 2025
2 parents 8657340 + e7db062 commit 80a3a04
Showing 6 changed files with 194 additions and 3 deletions.
3 changes: 3 additions & 0 deletions examples/community/rerender_a_video.py
@@ -908,6 +908,9 @@ def __call__(
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)

                if XLA_AVAILABLE:
                    xm.mark_step()

        if not output_type == "latent":
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
        else:
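The `XLA_AVAILABLE` flag and `xm` alias used by this hunk are defined outside it. Presumably they come from the guarded-import pattern diffusers uses elsewhere for optional torch_xla support; a minimal sketch, with names assumed to match the rest of the file:

```python
from diffusers.utils import is_torch_xla_available

# torch_xla is an optional dependency: only import it (and only call
# xm.mark_step(), which flushes the pending lazy-tensor graph on TPU)
# when it is actually installed.
if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False
```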
104 changes: 102 additions & 2 deletions src/diffusers/models/autoencoders/autoencoder_dc.py
@@ -486,6 +486,9 @@ def __init__(
        self.tile_sample_stride_height = 448
        self.tile_sample_stride_width = 448

        self.tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
        self.tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio

    def enable_tiling(
        self,
        tile_sample_min_height: Optional[int] = None,
@@ -515,6 +518,8 @@ def enable_tiling(
        self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
        self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
        self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
        self.tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
        self.tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio

    def disable_tiling(self) -> None:
        r"""
@@ -606,11 +611,106 @@ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutp
            return (decoded,)
        return DecoderOutput(sample=decoded)
    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        blend_extent = min(a.shape[2], b.shape[2], blend_extent)
        for y in range(blend_extent):
            b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
        return b

    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        blend_extent = min(a.shape[3], b.shape[3], blend_extent)
        for x in range(blend_extent):
            b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
        return b
    def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> torch.Tensor:
        batch_size, num_channels, height, width = x.shape
        latent_height = height // self.spatial_compression_ratio
        latent_width = width // self.spatial_compression_ratio

        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
        blend_height = tile_latent_min_height - tile_latent_stride_height
        blend_width = tile_latent_min_width - tile_latent_stride_width

        # Split x into overlapping tiles and encode them separately.
        # The tiles have an overlap to avoid seams between tiles.
        rows = []
        for i in range(0, x.shape[2], self.tile_sample_stride_height):
            row = []
            for j in range(0, x.shape[3], self.tile_sample_stride_width):
                tile = x[:, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
                if (
                    tile.shape[2] % self.spatial_compression_ratio != 0
                    or tile.shape[3] % self.spatial_compression_ratio != 0
                ):
                    pad_h = (self.spatial_compression_ratio - tile.shape[2]) % self.spatial_compression_ratio
                    pad_w = (self.spatial_compression_ratio - tile.shape[3]) % self.spatial_compression_ratio
                    tile = F.pad(tile, (0, pad_w, 0, pad_h))
                tile = self.encoder(tile)
                row.append(tile)
            rows.append(row)
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_width)
                result_row.append(tile[:, :, :tile_latent_stride_height, :tile_latent_stride_width])
            result_rows.append(torch.cat(result_row, dim=3))

        encoded = torch.cat(result_rows, dim=2)[:, :, :latent_height, :latent_width]

        if not return_dict:
            return (encoded,)
        return EncoderOutput(latent=encoded)

    def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
        batch_size, num_channels, height, width = z.shape

        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio

        blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
        blend_width = self.tile_sample_min_width - self.tile_sample_stride_width

        # Split z into overlapping tiles and decode them separately.
        # The tiles have an overlap to avoid seams between tiles.
        rows = []
        for i in range(0, height, tile_latent_stride_height):
            row = []
            for j in range(0, width, tile_latent_stride_width):
                tile = z[:, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
                decoded = self.decoder(tile)
                row.append(decoded)
            rows.append(row)

        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_width)
                result_row.append(tile[:, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
            result_rows.append(torch.cat(result_row, dim=3))

        decoded = torch.cat(result_rows, dim=2)

        if not return_dict:
            return (decoded,)
        return DecoderOutput(sample=decoded)

    def forward(self, sample: torch.Tensor, return_dict: bool = True) -> torch.Tensor:
        encoded = self.encode(sample, return_dict=False)[0]
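Taken together, this gives AutoencoderDC the same seam-blended tiling that other diffusers autoencoders have: overlapping tiles are encoded/decoded independently, and `blend_v`/`blend_h` linearly cross-fade the overlap. A hypothetical usage sketch — the checkpoint id is illustrative, the 512 px tile size is assumed (only the 448 px strides are visible above), and it assumes `encode`/`decode` dispatch to the tiled paths once tiling is enabled and the input exceeds one tile, as in other diffusers VAEs:

```python
import torch
from diffusers import AutoencoderDC

# Illustrative checkpoint id -- any AutoencoderDC weights would do.
vae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers")

# 512 px tiles with a 448 px stride leave a 64 px overlap for
# blend_v/blend_h to feather away (512 is an assumed default).
vae.enable_tiling(
    tile_sample_min_height=512,
    tile_sample_min_width=512,
    tile_sample_stride_height=448,
    tile_sample_stride_width=448,
)

with torch.no_grad():
    image = torch.randn(1, 3, 1024, 1024)       # larger than one tile
    latent = vae.encode(image).latent           # routed through tiled_encode
    reconstruction = vae.decode(latent).sample  # routed through tiled_decode
```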
29 changes: 29 additions & 0 deletions src/diffusers/pipelines/pag/pipeline_pag_sana.py
@@ -183,6 +183,35 @@ def __init__(
            pag_attn_processors=(PAGCFGSanaLinearAttnProcessor2_0(), PAGIdentitySanaLinearAttnProcessor2_0()),
        )

    def enable_vae_slicing(self):
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.vae.enable_slicing()

    def disable_vae_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_slicing()

    def enable_vae_tiling(self):
        r"""
        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to
        allow processing larger images.
        """
        self.vae.enable_tiling()

    def disable_vae_tiling(self):
        r"""
        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_tiling()

    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
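A hypothetical sketch of the new pass-throughs on the PAG pipeline — the class and checkpoint names are assumptions, and the methods simply forward to the underlying VAE:

```python
import torch
from diffusers import SanaPAGPipeline  # assumed class name for this pipeline

# Illustrative checkpoint id.
pipe = SanaPAGPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_diffusers",
    torch_dtype=torch.float16,
).to("cuda")

# Decode the batch one latent at a time instead of all at once:
# lower peak memory for multi-image batches, at a small speed cost.
pipe.enable_vae_slicing()
images = pipe(prompt=["a cat", "a dog", "a fox"], num_inference_steps=20).images
pipe.disable_vae_slicing()  # restore single-step decoding
```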
29 changes: 29 additions & 0 deletions src/diffusers/pipelines/sana/pipeline_sana.py
@@ -218,6 +218,35 @@ def __init__(
        )
        self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)

    def enable_vae_slicing(self):
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.vae.enable_slicing()

    def disable_vae_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_slicing()

    def enable_vae_tiling(self):
        r"""
        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to
        allow processing larger images.
        """
        self.vae.enable_tiling()

    def disable_vae_tiling(self):
        r"""
        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_tiling()

    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
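And the tiling variant on the plain SanaPipeline, again as a hedged sketch — the checkpoint id is illustrative, and `enable_vae_tiling` just flips the flag on the VAE, whose tiled paths were added earlier in this commit:

```python
import torch
from diffusers import SanaPipeline

# Illustrative checkpoint id.
pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_diffusers",
    torch_dtype=torch.float16,
).to("cuda")

# With tiling on, any decode larger than one tile (512 px by default,
# assumed) is processed tile-by-tile and seam-blended.
pipe.enable_vae_tiling()
image = pipe(prompt="a photo of an astronaut", height=1024, width=1024).images[0]
pipe.disable_vae_tiling()
```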
2 changes: 1 addition & 1 deletion src/diffusers/schedulers/scheduling_ddim_inverse.py
@@ -266,7 +266,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic

        self.num_inference_steps = num_inference_steps

        # "leading" and "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
        if self.config.timestep_spacing == "leading":
            step_ratio = self.config.num_train_timesteps // self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
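For reference, Table 2 of https://arxiv.org/abs/2305.08891 is where the paper catalogues these timestep spacings. A self-contained sketch of the arithmetic the two branches implement (steps_offset omitted for brevity; the inverse scheduler walks the result in ascending order):

```python
import numpy as np

num_train_timesteps = 1000
num_inference_steps = 4

# "leading": grid anchored at timestep 0; never reaches the final one.
step_ratio = num_train_timesteps // num_inference_steps
leading = (np.arange(0, num_inference_steps) * step_ratio).round().astype(np.int64)

# "trailing": grid anchored at the final training timestep instead.
step_ratio = num_train_timesteps / num_inference_steps
trailing = np.round(np.arange(num_train_timesteps, 0, -step_ratio)).astype(np.int64)[::-1] - 1

print(leading)   # [  0 250 500 750]
print(trailing)  # [249 499 749 999]
```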
30 changes: 30 additions & 0 deletions tests/pipelines/sana/test_sana.py
@@ -254,6 +254,36 @@ def test_attention_slicing_forward_pass(
"Attention slicing should not affect the inference results",
)

    def test_vae_tiling(self, expected_diff_max: float = 0.2):
        generator_device = "cpu"
        components = self.get_dummy_components()

        pipe = self.pipeline_class(**components)
        pipe.to("cpu")
        pipe.set_progress_bar_config(disable=None)

        # Without tiling
        inputs = self.get_dummy_inputs(generator_device)
        inputs["height"] = inputs["width"] = 128
        output_without_tiling = pipe(**inputs)[0]

        # With tiling
        pipe.vae.enable_tiling(
            tile_sample_min_height=96,
            tile_sample_min_width=96,
            tile_sample_stride_height=64,
            tile_sample_stride_width=64,
        )
        inputs = self.get_dummy_inputs(generator_device)
        inputs["height"] = inputs["width"] = 128
        output_with_tiling = pipe(**inputs)[0]

        self.assertLess(
            (to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
            expected_diff_max,
            "VAE tiling should not affect the inference results",
        )

    # TODO(aryan): Create a dummy gemma model with smol vocab size
    @unittest.skip(
        "A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to an embedding lookup error. This test uses a long prompt that causes the error."
    )
