diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py b/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py
index c2eb7fd2a705..2e2eb4c12912 100644
--- a/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py
+++ b/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py
@@ -14,6 +14,7 @@
 from typing import Any, Dict, List, Optional, Tuple
 
+import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -198,6 +199,8 @@ def __init__(
 
         self.gradient_checkpointing = False
 
+        self.enable_teacache = False
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -305,27 +308,77 @@ def forward(
         # [B, 1, 1, N], for broadcasting across attention heads
         attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)
 
-        if torch.is_grad_enabled() and self.gradient_checkpointing:
-            for block in self.transformer_blocks:
-                hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
-                    block, hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
+        if self.enable_teacache:
+            hidden_states_ = hidden_states.clone()
+            temb_ = temb.clone()
+            modulated_inp = self.transformer_blocks[0].norm1(hidden_states_, emb=temb_)[0]
+
+            if self.cnt == 0 or self.cnt == self.num_steps - 1:
+                should_calc = True
+                self.accumulated_rel_l1_distance = 0
+            else:
+                curr_rel_l1 = (
+                    (
+                        (modulated_inp - self.previous_modulated_input).abs().mean()
+                        / self.previous_modulated_input.abs().mean()
+                    )
+                    .cpu()
+                    .item()
                 )
-            for block in self.single_transformer_blocks:
-                hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
-                    block, hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
-                )
+                self.accumulated_rel_l1_distance += self.teacache_rescale_func(curr_rel_l1)
 
-        else:
-            for block in self.transformer_blocks:
-                hidden_states, encoder_hidden_states = block(
-                    hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
-                )
+                if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
+                    should_calc = False
+                else:
+                    should_calc = True
+                    self.accumulated_rel_l1_distance = 0
 
-            for block in self.single_transformer_blocks:
-                hidden_states, encoder_hidden_states = block(
-                    hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
-                )
+            self.previous_modulated_input = modulated_inp
+            self.cnt += 1
+
+            if self.cnt == self.num_steps:
+                self.cnt = 0
+
+            if not should_calc:
+                hidden_states += self.previous_residual
+            else:
+                ori_hidden_states = hidden_states.clone()
+
+                for block in self.transformer_blocks:
+                    hidden_states, encoder_hidden_states = block(
+                        hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
+                    )
+
+                for block in self.single_transformer_blocks:
+                    hidden_states, encoder_hidden_states = block(
+                        hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
+                    )
+
+                self.previous_residual = hidden_states - ori_hidden_states
+
+        else:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+                for block in self.transformer_blocks:
+                    hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
+                        block, hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
+                    )
+
+                for block in self.single_transformer_blocks:
+                    hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
+                        block, hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
+                    )
+
+            else:
+                for block in self.transformer_blocks:
+                    hidden_states, encoder_hidden_states = block(
+                        hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
+                    )
+
+                for block in self.single_transformer_blocks:
+                    hidden_states, encoder_hidden_states = block(
+                        hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
+                    )
 
         hidden_states = hidden_states[:, -original_context_length:]
         hidden_states = self.norm_out(hidden_states, temb)
@@ -398,6 +451,17 @@ def _pad_rotary_emb(
         freqs_sin = freqs_sin.flatten(2).permute(0, 2, 1).squeeze(0)
         return freqs_cos, freqs_sin
 
+    def initialize_teacache(self, enable_teacache=True, num_steps=25, rel_l1_thresh=0.15):
+        self.enable_teacache = enable_teacache
+        self.cnt = 0
+        self.num_steps = num_steps
+        self.rel_l1_thresh = rel_l1_thresh  # 0.1 for 1.6x speedup, 0.15 for 2.1x speedup
+        self.accumulated_rel_l1_distance = 0
+        self.previous_modulated_input = None
+        self.previous_residual = None
+        self.coeffs = [7.33226126e02, -4.01131952e02, 6.75869174e01, -3.14987800e00, 9.61237896e-02]
+        self.teacache_rescale_func = np.poly1d(self.coeffs)
+
 
 def _pad_for_3d_conv(x, kernel_size):
     if isinstance(x, (tuple, list)):
diff --git a/tests/models/transformers/test_models_transformer_hunyuan_video_framepack.py b/tests/models/transformers/test_models_transformer_hunyuan_video_framepack.py
index ddb79925a7fe..9e9b1247a1bb 100644
--- a/tests/models/transformers/test_models_transformer_hunyuan_video_framepack.py
+++ b/tests/models/transformers/test_models_transformer_hunyuan_video_framepack.py
@@ -114,3 +114,24 @@ def prepare_init_args_and_inputs_for_common(self):
     def test_gradient_checkpointing_is_applied(self):
         expected_set = {"HunyuanVideoFramepackTransformer3DModel"}
         super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+    def test_teacache_initialization(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+
+        custom_num_steps = 50
+        custom_thresh = 0.1
+
+        model.initialize_teacache(enable_teacache=True, num_steps=custom_num_steps, rel_l1_thresh=custom_thresh)
+
+        self.assertTrue(model.enable_teacache)
+        self.assertEqual(model.num_steps, custom_num_steps)
+        self.assertEqual(model.rel_l1_thresh, custom_thresh)
+
+        self.assertEqual(model.cnt, 0)
+        self.assertEqual(model.accumulated_rel_l1_distance, 0)
+        self.assertIsNone(model.previous_modulated_input)
+        self.assertIsNone(model.previous_residual)
+        self.assertTrue(hasattr(model, "coeffs"))
+        self.assertTrue(hasattr(model, "teacache_rescale_func"))
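
For context, here is a minimal usage sketch of the API this diff adds, not part of the change itself. It shows how a caller might enable TeaCache on the Framepack transformer before sampling; the checkpoint id and the step/threshold values are illustrative assumptions, not prescribed by the PR.

```python
# Minimal sketch, assuming the transformer is loaded standalone and then handed to the pipeline.
import torch

from diffusers import HunyuanVideoFramepackTransformer3DModel

transformer = HunyuanVideoFramepackTransformer3DModel.from_pretrained(
    "lllyasviel/FramePackI2V_HY",  # assumed checkpoint id; substitute whichever Framepack weights you use
    torch_dtype=torch.bfloat16,
)

# Enable TeaCache before running inference. `num_steps` should match the number of denoising
# steps of the sampling run so the internal step counter resets at the final step;
# `rel_l1_thresh` trades quality for speed (per the inline comment in the diff:
# 0.1 for ~1.6x speedup, 0.15 for ~2.1x speedup).
transformer.initialize_teacache(enable_teacache=True, num_steps=25, rel_l1_thresh=0.15)

# The transformer is then passed to the Framepack video pipeline as usual. On steps where the
# accumulated relative L1 distance of the modulated input stays below the threshold, the cached
# residual from the previous full pass is reused instead of running the transformer blocks.
```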