Skip to content

Teacache implemented into hunyuan video framepack transformer model #11949

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
Expand Down Expand Up @@ -198,6 +199,8 @@ def __init__(

self.gradient_checkpointing = False

self.enable_teacache = False

def forward(
self,
hidden_states: torch.Tensor,
Expand Down Expand Up @@ -305,27 +308,77 @@ def forward(
# [B, 1, 1, N], for broadcasting across attention heads
attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)

if torch.is_grad_enabled() and self.gradient_checkpointing:
for block in self.transformer_blocks:
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
block, hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
if self.enable_teacache:
hidden_states_ = hidden_states.clone()
temb_ = temb.clone()
modulated_inp = self.transformer_blocks[0].norm1(hidden_states_, emb=temb_)[0]

if self.cnt == 0 or self.cnt == self.num_steps - 1:
should_calc = True
self.accumulated_rel_l1_distance = 0
else:
curr_rel_l1 = (
(
(modulated_inp - self.previous_modulated_input).abs().mean()
/ self.previous_modulated_input.abs().mean()
)
.cpu()
.item()
)

for block in self.single_transformer_blocks:
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
block, hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
)
self.accumulated_rel_l1_distance += self.teacache_rescale_func(curr_rel_l1)

else:
for block in self.transformer_blocks:
hidden_states, encoder_hidden_states = block(
hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
)
if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
should_calc = False
else:
should_calc = True
self.accumulated_rel_l1_distance = 0

for block in self.single_transformer_blocks:
hidden_states, encoder_hidden_states = block(
hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
)
self.previous_modulated_input = modulated_inp
self.cnt += 1

if self.cnt == self.num_steps:
self.cnt = 0

if not should_calc:
hidden_states += self.previous_residual
else:
ori_hidden_states = hidden_states.clone()

for block in self.transformer_blocks:
hidden_states, encoder_hidden_states = block(
hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
)

for block in self.single_transformer_blocks:
hidden_states, encoder_hidden_states = block(
hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
)

self.previous_residual = hidden_states - ori_hidden_states

else:
if torch.is_grad_enabled() and self.gradient_checkpointing:
for block in self.transformer_blocks:
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
block, hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
)

for block in self.single_transformer_blocks:
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
block, hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
)

else:
for block in self.transformer_blocks:
hidden_states, encoder_hidden_states = block(
hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
)

for block in self.single_transformer_blocks:
hidden_states, encoder_hidden_states = block(
hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
)

hidden_states = hidden_states[:, -original_context_length:]
hidden_states = self.norm_out(hidden_states, temb)
Expand Down Expand Up @@ -398,6 +451,17 @@ def _pad_rotary_emb(
freqs_sin = freqs_sin.flatten(2).permute(0, 2, 1).squeeze(0)
return freqs_cos, freqs_sin

def initialize_teacache(self, enable_teacache=True, num_steps=25, rel_l1_thresh=0.15):
self.enable_teacache = enable_teacache
self.cnt = 0
self.num_steps = num_steps
self.rel_l1_thresh = rel_l1_thresh # 0.1 for 1.6x speedup, 0.15 for 2.1x speedup
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = None
self.previous_residual = None
self.coeffs = [7.33226126e02, -4.01131952e02, 6.75869174e01, -3.14987800e00, 9.61237896e-02]
self.teacache_rescale_func = np.poly1d(self.coeffs)


def _pad_for_3d_conv(x, kernel_size):
if isinstance(x, (tuple, list)):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,3 +114,24 @@ def prepare_init_args_and_inputs_for_common(self):
def test_gradient_checkpointing_is_applied(self):
expected_set = {"HunyuanVideoFramepackTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

def test_teacache_initialization(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)

custom_num_steps = 50
custom_thresh = 0.1

model.initialize_teacache(enable_teacache=True, num_steps=custom_num_steps, rel_l1_thresh=custom_thresh)

self.assertTrue(model.enable_teacache)
self.assertEqual(model.num_steps, custom_num_steps)
self.assertEqual(model.rel_l1_thresh, custom_thresh)

self.assertEqual(model.cnt, 0)
self.assertEqual(model.accumulated_rel_l1_distance, 0)
self.assertIsNone(model.previous_modulated_input)
self.assertIsNone(model.previous_residual)
self.assertTrue(hasattr(model, "coeffs"))
self.assertTrue(hasattr(model, "teacache_rescale_func"))