Implement framewise encoding/decoding in LTX Video VAE #10488

Open
wants to merge 8 commits into base: main
141 changes: 98 additions & 43 deletions src/diffusers/models/autoencoders/autoencoder_kl_ltx.py
@@ -998,8 +998,8 @@ def __init__(

# When decoding temporally long video latents, the memory requirement is very high. By decoding latent frames
# at a fixed frame batch size (based on `self.num_latent_frames_batch_sizes`), the memory requirement can be lowered.
self.use_framewise_encoding = False
self.use_framewise_decoding = False
self.use_framewise_encoding = True
self.use_framewise_decoding = True

# This can be configured based on the amount of GPU memory available.
# `16` for sample frames and `2` for latent frames are sensible defaults for consumer GPUs.
@@ -1010,10 +1010,12 @@ def __init__(
# The minimal tile height and width for spatial tiling to be used
self.tile_sample_min_height = 512
self.tile_sample_min_width = 512
self.tile_sample_min_num_frames = 16

# The minimal distance between two spatial tiles
self.tile_sample_stride_height = 448
self.tile_sample_stride_width = 448
self.tile_sample_stride_num_frames = 8
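
For reference, with `tile_sample_min_num_frames = 16` and `tile_sample_stride_num_frames = 8`, consecutive temporal tiles overlap by 8 sample frames, which is exactly the region that the `blend_t` helper added later in this diff cross-fades. A minimal arithmetic sketch, assuming the default LTX Video VAE temporal compression ratio of 8:

```python
# Sketch of how the new temporal tiling defaults relate to each other.
# The 8x temporal compression ratio is an assumption based on the default LTX Video VAE config.
temporal_compression_ratio = 8

tile_sample_min_num_frames = 16    # smallest clip length that triggers temporal tiling
tile_sample_stride_num_frames = 8  # distance between the starts of consecutive temporal tiles

# Overlap between consecutive tiles, in sample frames and in latent frames.
overlap_sample_frames = tile_sample_min_num_frames - tile_sample_stride_num_frames  # 8
overlap_latent_frames = overlap_sample_frames // temporal_compression_ratio         # 1
print(overlap_sample_frames, overlap_latent_frames)  # 8 1
```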

def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (LTXVideoEncoder3d, LTXVideoDecoder3d)):
@@ -1023,8 +1025,10 @@ def enable_tiling(
self,
tile_sample_min_height: Optional[int] = None,
tile_sample_min_width: Optional[int] = None,
tile_sample_min_num_frames: Optional[int] = None,
tile_sample_stride_height: Optional[float] = None,
tile_sample_stride_width: Optional[float] = None,
tile_sample_stride_num_frames: Optional[float] = None,
) -> None:
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
@@ -1046,8 +1050,10 @@ def enable_tiling(
self.use_tiling = True
self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
self.tile_sample_min_num_frames = tile_sample_min_num_frames or self.tile_sample_min_num_frames
self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames
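
As a usage sketch, this is how a caller would opt into the new temporal knobs alongside the existing spatial ones. The keyword names follow the signature above; loading the VAE from the `Lightricks/LTX-Video` checkpoint is an assumption about where diffusers-format weights live:

```python
import torch
from diffusers import AutoencoderKLLTXVideo

# Hypothetical configuration of the tiling knobs, including the two added in this PR.
vae = AutoencoderKLLTXVideo.from_pretrained(
    "Lightricks/LTX-Video", subfolder="vae", torch_dtype=torch.bfloat16
)
vae.enable_tiling(
    tile_sample_min_height=512,
    tile_sample_min_width=512,
    tile_sample_min_num_frames=16,    # new: temporal tile size, in sample frames
    tile_sample_stride_height=448,
    tile_sample_stride_width=448,
    tile_sample_stride_num_frames=8,  # new: temporal stride, in sample frames
)
```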

def disable_tiling(self) -> None:
r"""
@@ -1073,18 +1079,13 @@ def disable_slicing(self) -> None:
def _encode(self, x: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = x.shape

if self.use_framewise_encoding and num_frames > self.tile_sample_min_num_frames:
return self._temporal_tiled_encode(x)

if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
return self.tiled_encode(x)

if self.use_framewise_encoding:
# TODO(aryan): requires investigation
raise NotImplementedError(
"Frame-wise encoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to "
"quality issues caused by splitting inference across frame dimension. If you believe this "
"should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
)
else:
enc = self.encoder(x)
enc = self.encoder(x)

return enc

@@ -1121,19 +1122,15 @@ def _decode(
batch_size, num_channels, num_frames, height, width = z.shape
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio

if self.use_framewise_decoding and num_frames > tile_latent_min_num_frames:
return self._temporal_tiled_decode(z, temb, return_dict=return_dict)

if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
return self.tiled_decode(z, temb, return_dict=return_dict)

if self.use_framewise_decoding:
# TODO(aryan): requires investigation
raise NotImplementedError(
"Frame-wise decoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to "
"quality issues caused by splitting inference across frame dimension. If you believe this "
"should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
)
else:
dec = self.decoder(z, temb)
dec = self.decoder(z, temb)

if not return_dict:
return (dec,)
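
Note that `_decode` converts the sample-space thresholds into latent units before comparing against the latent's shape, while `_encode` compares sample-space values directly. With the defaults above, and assuming the standard LTX VAE compression ratios (32x spatial, 8x temporal), the latent-space thresholds work out to:

```python
# Latent-space thresholds used by _decode; the 32x/8x ratios are assumptions
# based on the standard LTX Video VAE configuration.
spatial_compression_ratio, temporal_compression_ratio = 32, 8

tile_latent_min_height = 512 // spatial_compression_ratio      # 16 latent rows
tile_latent_min_num_frames = 16 // temporal_compression_ratio  # 2 latent frames
print(tile_latent_min_height, tile_latent_min_num_frames)      # 16 2
```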
@@ -1189,6 +1186,14 @@ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
)
return b

def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
for x in range(blend_extent):
b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (
x / blend_extent
)
return b
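
`blend_t` follows the same pattern as the existing `blend_v` and `blend_h` helpers, but along the frame axis: the first `blend_extent` frames of the current tile are linearly cross-faded with the last `blend_extent` frames of the previous tile. A self-contained numeric sketch of the same weighting:

```python
import torch

def blend_t(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    # Same logic as the method above, written as a free function for illustration.
    blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
    for x in range(blend_extent):
        # Frame x of `b` becomes a weighted mix of the tail of `a` and the head of `b`.
        b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (
            x / blend_extent
        )
    return b

# Two dummy "tiles" with 4 frames each; blend a 2-frame overlap.
a = torch.zeros(1, 3, 4, 8, 8)  # previous tile (all zeros)
b = torch.ones(1, 3, 4, 8, 8)   # current tile (all ones)
out = blend_t(a.clone(), b.clone(), blend_extent=2)
print(out[0, 0, :, 0, 0])       # tensor([0.0000, 0.5000, 1.0000, 1.0000])
```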

def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
r"""Encode a batch of images using a tiled encoder.

@@ -1217,17 +1222,9 @@ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
for i in range(0, height, self.tile_sample_stride_height):
row = []
for j in range(0, width, self.tile_sample_stride_width):
if self.use_framewise_encoding:
# TODO(aryan): requires investigation
raise NotImplementedError(
"Frame-wise encoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to "
"quality issues caused by splitting inference across frame dimension. If you believe this "
"should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
)
else:
time = self.encoder(
x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
)
time = self.encoder(
x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
)

row.append(time)
rows.append(row)
@@ -1283,17 +1280,7 @@ def tiled_decode(
for i in range(0, height, tile_latent_stride_height):
row = []
for j in range(0, width, tile_latent_stride_width):
if self.use_framewise_decoding:
# TODO(aryan): requires investigation
raise NotImplementedError(
"Frame-wise decoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to "
"quality issues caused by splitting inference across frame dimension. If you believe this "
"should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
)
else:
time = self.decoder(
z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width], temb
)
time = self.decoder(z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width], temb)

row.append(time)
rows.append(row)
@@ -1318,6 +1305,74 @@ def tiled_decode(

return DecoderOutput(sample=dec)

def _temporal_tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
batch_size, num_channels, num_frames, height, width = x.shape
latent_num_frames = (num_frames - 1) // self.temporal_compression_ratio + 1

tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio
blend_num_frames = tile_latent_min_num_frames - tile_latent_stride_num_frames

row = []
for i in range(0, num_frames, self.tile_sample_stride_num_frames):
tile = x[:, :, i : i + self.tile_sample_min_num_frames + 1, :, :]
if self.use_tiling and (height > self.tile_sample_min_height or width > self.tile_sample_min_width):
tile = self.tiled_encode(tile)
else:
tile = self.encoder(tile)
if i > 0:
tile = tile[:, :, 1:, :, :]
row.append(tile)

result_row = []
for i, tile in enumerate(row):
if i > 0:
tile = self.blend_t(row[i - 1], tile, blend_num_frames)
result_row.append(tile[:, :, :tile_latent_stride_num_frames, :, :])
else:
result_row.append(tile[:, :, : tile_latent_stride_num_frames + 1, :, :])

enc = torch.cat(result_row, dim=2)[:, :, :latent_num_frames]
return enc
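
A worked example of the slicing above may help. Assuming the 16/8 sample-frame defaults from `__init__` and an 8x temporal compression ratio (an assumption about the default LTX Video VAE config), encoding a 49-frame clip proceeds as follows:

```python
# Numeric sketch of _temporal_tiled_encode with the defaults from this PR.
# The 8x temporal compression ratio is assumed from the default LTX Video VAE config.
num_frames = 49
temporal_compression_ratio = 8
tile_sample_min_num_frames, tile_sample_stride_num_frames = 16, 8

# Length of the stitched latent, matching what a non-tiled encode would produce.
latent_num_frames = (num_frames - 1) // temporal_compression_ratio + 1  # 7

# Each temporal tile reads up to tile_sample_min_num_frames + 1 = 17 sample frames,
# starting every tile_sample_stride_num_frames = 8 frames.
tile_starts = list(range(0, num_frames, tile_sample_stride_num_frames))  # [0, 8, 16, 24, 32, 40, 48]

# The first tile keeps 2 latent frames, every later tile keeps at most 1 after its
# leading frame is dropped and blended, and the concatenation is trimmed to latent_num_frames.
print(len(tile_starts), latent_num_frames)  # 7 7
```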

def _temporal_tiled_decode(
self, z: torch.Tensor, temb: Optional[torch.Tensor], return_dict: bool = True
) -> Union[DecoderOutput, torch.Tensor]:
batch_size, num_channels, num_frames, height, width = z.shape
num_sample_frames = (num_frames - 1) * self.temporal_compression_ratio + 1

tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio
blend_num_frames = self.tile_sample_min_num_frames - self.tile_sample_stride_num_frames

row = []
for i in range(0, num_frames, tile_latent_stride_num_frames):
tile = z[:, :, i : i + tile_latent_min_num_frames + 1, :, :]
if self.use_tiling and (tile.shape[-1] > tile_latent_min_width or tile.shape[-2] > tile_latent_min_height):
decoded = self.tiled_decode(tile, temb, return_dict=True).sample
else:
decoded = self.decoder(tile, temb)
if i > 0:
decoded = decoded[:, :, :-1, :, :]
row.append(decoded)

result_row = []
for i, tile in enumerate(row):
if i > 0:
tile = self.blend_t(row[i - 1], tile, blend_num_frames)
tile = tile[:, :, : self.tile_sample_stride_num_frames, :, :]
result_row.append(tile)
else:
result_row.append(tile[:, :, : self.tile_sample_stride_num_frames + 1, :, :])

dec = torch.cat(result_row, dim=2)[:, :, :num_sample_frames]

if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
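
The decode path mirrors this, except that `blend_t` is applied to decoded sample frames rather than to latent frames. A matching numeric sketch, under the same assumption about the compression ratio:

```python
# Numeric sketch of _temporal_tiled_decode, continuing the 49-frame example above.
latent_num_frames = 7
temporal_compression_ratio = 8
tile_sample_min_num_frames, tile_sample_stride_num_frames = 16, 8

num_sample_frames = (latent_num_frames - 1) * temporal_compression_ratio + 1                 # 49
tile_latent_min_num_frames = tile_sample_min_num_frames // temporal_compression_ratio        # 2
tile_latent_stride_num_frames = tile_sample_stride_num_frames // temporal_compression_ratio  # 1

# Latent tiles start every latent frame and read up to 3 latent frames each;
# consecutive decoded tiles overlap by 16 - 8 = 8 sample frames, which blend_t cross-fades.
blend_num_frames = tile_sample_min_num_frames - tile_sample_stride_num_frames                # 8
tile_starts = list(range(0, latent_num_frames, tile_latent_stride_num_frames))               # [0, 1, ..., 6]
print(num_sample_frames, blend_num_frames, len(tile_starts))                                 # 49 8 7
```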

def forward(
self,
sample: torch.Tensor,
@@ -1334,5 +1389,5 @@ def forward(
z = posterior.mode()
dec = self.decode(z, temb)
if not return_dict:
return (dec,)
return (dec.sample,)
return dec
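
Putting the pieces together, a rough end-to-end sketch of the framewise path on a clip long enough to trigger temporal tiling. This runs on random data only; the checkpoint path is an assumption, `temb` is omitted (which assumes a checkpoint whose decoder does not use timestep conditioning), and the 8x temporal compression ratio is assumed from the default config:

```python
import torch
from diffusers import AutoencoderKLLTXVideo

# Hypothetical smoke test of the framewise encode/decode path on random data.
torch.manual_seed(0)
vae = AutoencoderKLLTXVideo.from_pretrained(
    "Lightricks/LTX-Video", subfolder="vae", torch_dtype=torch.float32
)
vae.enable_tiling()  # spatial tiling; framewise encode/decode default to True in this PR

# 49 sample frames -> 7 latent frames at the assumed 8x temporal compression.
video = torch.randn(1, 3, 49, 256, 384)
with torch.no_grad():
    latents = vae.encode(video).latent_dist.sample()
    frames = vae.decode(latents).sample

print(latents.shape[2], frames.shape[2])  # expected: 7 49
```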
31 changes: 31 additions & 0 deletions tests/models/autoencoders/test_models_autoencoder_ltx_video.py
@@ -167,3 +167,34 @@ def test_outputs_equivalence(self):
@unittest.skip("AutoencoderKLLTXVideo does not support `norm_num_groups` because it does not use GroupNorm.")
def test_forward_with_norm_groups(self):
pass

def test_enable_disable_tiling(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

torch.manual_seed(0)
model = self.model_class(**init_dict).to(torch_device)

inputs_dict.update({"return_dict": False})

torch.manual_seed(0)
output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]

torch.manual_seed(0)
model.enable_tiling()
output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]

self.assertLess(
(output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
0.5,
"VAE tiling should not affect the inference results",
)

torch.manual_seed(0)
model.disable_tiling()
output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]

self.assertEqual(
output_without_tiling.detach().cpu().numpy().all(),
output_without_tiling_2.detach().cpu().numpy().all(),
"Without tiling outputs should match with the outputs when tiling is manually disabled.",
)