Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MultiGPU support #184

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added enhance_a_video/__pycache__/__init__.cpython-310.pyc
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added hyvideo/__pycache__/__init__.cpython-310.pyc
Binary file not shown.
Binary file added hyvideo/__pycache__/constants.cpython-310.pyc
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
2 changes: 2 additions & 0 deletions hyvideo/diffusion/pipelines/pipeline_hunyuan_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
)
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers import DPMSolverMultistepScheduler

from ...modules import HYVideoDiffusionTransformer
from comfy.utils import ProgressBar
Expand Down Expand Up @@ -636,6 +637,7 @@ def __call__(
logger.info(f"Sampling {video_length} frames in {latents.shape[2]} latents at {width}x{height} with {len(timesteps)} inference steps")
comfy_pbar = ProgressBar(len(timesteps))
with self.progress_bar(total=len(timesteps)) as progress_bar:
old_pred_original_sample = None # for DPM-solver++
for i, t in enumerate(timesteps):
if self.interrupt:
continue
Expand Down
Binary file not shown.
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class FlowMatchDiscreteScheduler(SchedulerMixin, ConfigMixin):
def __init__(
self,
num_train_timesteps: int = 1000,
shift: float = 1.0,
flow_shift: float = 1.0,
reverse: bool = True,
solver: str = "euler",
n_tokens: Optional[int] = None,
Expand All @@ -80,7 +80,7 @@ def __init__(
print("Scheduler config:", self.config)
if not reverse:
sigmas = sigmas.flip(0)
self.shift = shift
self.flow_shift = flow_shift

self.sigmas = sigmas
# the value fed to model
Expand Down Expand Up @@ -184,7 +184,7 @@ def scale_model_input(
return sample

def sd3_time_shift(self, t: torch.Tensor):
return (self.shift * t) / (1 + (self.shift - 1) * t)
return (self.flow_shift * t) / (1 + (self.flow_shift - 1) * t)

def step(
self,
Expand Down
4 changes: 1 addition & 3 deletions hyvideo/modules/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1 @@
from .models import HYVideoDiffusionTransformer, HUNYUAN_VIDEO_CONFIG


from .models import HYVideoDiffusionTransformer
Binary file added hyvideo/modules/__pycache__/__init__.cpython-310.pyc
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
229 changes: 134 additions & 95 deletions hyvideo/modules/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@

import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

from diffusers.models import ModelMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config
Expand Down Expand Up @@ -671,11 +672,23 @@ def __init__(
get_activation_layer("silu"),
**factory_kwargs,
)
#init block swap variables
self.double_blocks_to_swap = -1
self.single_blocks_to_swap = -1
self.offload_txt_in = False
self.offload_img_in = False

#init TeaCache variables
self.enable_teacache = False
self.cnt = 0
self.num_steps = 0
self.rel_l1_thresh = 0.15
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = None
self.previous_residual = None
self.last_dimensions = None
self.last_frame_count = None

# thanks @2kpr for the initial block swap code!
def block_swap(self, double_blocks_to_swap, single_blocks_to_swap, offload_txt_in=False, offload_img_in=False):
print(f"Swapping {double_blocks_to_swap + 1} double blocks and {single_blocks_to_swap + 1} single blocks")
Expand Down Expand Up @@ -866,6 +879,30 @@ def forward(
stg_block_idx: int = -1,
return_dict: bool = True,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:

def _process_double_blocks(img, txt, vec, block_args):
for b, block in enumerate(self.double_blocks):
if b <= self.double_blocks_to_swap and self.double_blocks_to_swap >= 0:
block.to(self.main_device)

img, txt = block(img, txt, vec, *block_args)

if b <= self.double_blocks_to_swap and self.double_blocks_to_swap >= 0:
block.to(self.offload_device, non_blocking=True)
return img, txt

def _process_single_blocks(x, vec, txt_seq_len, block_args, stg_mode=None, stg_block_idx=None):
for b, block in enumerate(self.single_blocks):
if b <= self.single_blocks_to_swap and self.single_blocks_to_swap >= 0:
block.to(self.main_device)

curr_stg_mode = stg_mode if b == stg_block_idx else None
x = block(x, vec, txt_seq_len, *block_args, curr_stg_mode)

if b <= self.single_blocks_to_swap and self.single_blocks_to_swap >= 0:
block.to(self.offload_device, non_blocking=True)
return x

out = {}
img = x
txt = text_states
Expand All @@ -877,6 +914,21 @@ def forward(
)
set_num_frames(img.shape[2])

current_dims = (ot, oh, ow)

# Check if dimensions changed since last run
if not hasattr(self, 'last_dims') or self.last_dims != current_dims:
# Reset TeaCache state on dimension change
self.cnt = 0
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = None
self.previous_residual = None
self.last_dims = current_dims

out = {}
img = x
txt = text_states

# Prepare modulation vectors.
vec = self.time_in(t)

Expand Down Expand Up @@ -931,57 +983,70 @@ def forward(
cu_seqlens_kv = cu_seqlens_q

freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
# --------------------- Pass through DiT blocks ------------------------
for b, block in enumerate(self.double_blocks):
if b <= self.double_blocks_to_swap and self.double_blocks_to_swap >= 0:
#print(f"Moving double_block {b} to main device")
block.to(self.main_device)
double_block_args = [
img,
txt,
vec,
cu_seqlens_q,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
freqs_cis,
attn_mask
]

img, txt = block(*double_block_args)
if b <= self.double_blocks_to_swap and self.double_blocks_to_swap >= 0:
#print(f"Moving double_block {b} to offload device")
block.to(self.offload_device, non_blocking=True)

# Merge txt and img to pass through single stream blocks.
x = torch.cat((img, txt), 1)
if len(self.single_blocks) > 0:
for b, block in enumerate(self.single_blocks):
if b <= self.single_blocks_to_swap and self.single_blocks_to_swap >= 0:
#print(f"Moving single_block {b} to main device")
#mm.soft_empty_cache()
block.to(self.main_device)
curr_stg_mode = stg_mode if b == stg_block_idx else None
single_block_args = [
x,
vec,
txt_seq_len,
cu_seqlens_q,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
(freqs_cos, freqs_sin),
attn_mask,
curr_stg_mode,
]

x = block(*single_block_args)
if b <= self.single_blocks_to_swap and self.single_blocks_to_swap >= 0:
#print(f"Moving single_block {b} to offload device")
#mm.soft_empty_cache()
block.to(self.offload_device, non_blocking=True)
block_args = [cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, freqs_cis, attn_mask]

#tea_cache
if self.enable_teacache:
inp = img.clone()
vec_ = vec.clone()
txt_ = txt.clone()
self.double_blocks[0].to(self.main_device)
(
img_mod1_shift,
img_mod1_scale,
img_mod1_gate,
img_mod2_shift,
img_mod2_scale,
img_mod2_gate,
) = self.double_blocks[0].img_mod(vec_).chunk(6, dim=-1)
normed_inp = self.double_blocks[0].img_norm1(inp)
modulated_inp = modulate(
normed_inp, shift=img_mod1_shift, scale=img_mod1_scale
)

img = x[:, :img_seq_len, ...]
if self.cnt == 0 or self.cnt == self.num_steps-1:
should_calc = True
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = modulated_inp.clone()
else:
coefficients = [7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02]
rescale_func = np.poly1d(coefficients)
self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
should_calc = False
else:
should_calc = True
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = modulated_inp.clone()
self.cnt += 1
if self.cnt == self.num_steps:
self.cnt = 0

if not should_calc and self.previous_residual is not None:
# Verify tensor dimensions match before adding
if img.shape == self.previous_residual.shape:
img = img + self.previous_residual
else:
should_calc = True # Force recalculation if dimensions don't match

if should_calc:
ori_img = img.clone()
# Pass through DiT blocks
img, txt = _process_double_blocks(img, txt, vec, block_args)
# Merge txt and img to pass through single stream blocks.
x = torch.cat((img, txt), 1)
x = _process_single_blocks(x, vec, txt.shape[1], block_args, stg_mode, stg_block_idx)

img = x[:, :img_seq_len, ...]
self.previous_residual = img - ori_img
else:
# Pass through DiT blocks
img, txt = _process_double_blocks(img, txt, vec, block_args)
# Merge txt and img to pass through single stream blocks.
x = torch.cat((img, txt), 1)
x = _process_single_blocks(x, vec, txt.shape[1], block_args, stg_mode, stg_block_idx)
img = x[:, :img_seq_len, ...]

# ---------------------------- Final layer ------------------------------
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
Expand All @@ -1007,52 +1072,26 @@ def unpatchify(self, x, t, h, w):

return imgs

def params_count(self):
counts = {
"double": sum(
[
sum(p.numel() for p in block.img_attn_qkv.parameters())
+ sum(p.numel() for p in block.img_attn_proj.parameters())
+ sum(p.numel() for p in block.img_mlp.parameters())
+ sum(p.numel() for p in block.txt_attn_qkv.parameters())
+ sum(p.numel() for p in block.txt_attn_proj.parameters())
+ sum(p.numel() for p in block.txt_mlp.parameters())
for block in self.double_blocks
]
),
"single": sum(
[
sum(p.numel() for p in block.linear1.parameters())
+ sum(p.numel() for p in block.linear2.parameters())
for block in self.single_blocks
]
),
"total": sum(p.numel() for p in self.parameters()),
}
counts["attn+mlp"] = counts["double"] + counts["single"]
return counts


#################################################################################
# HunyuanVideo Configs #
#################################################################################

HUNYUAN_VIDEO_CONFIG = {
"HYVideo-T/2": {
"mm_double_blocks_depth": 20,
"mm_single_blocks_depth": 40,
"rope_dim_list": [16, 56, 56],
"hidden_size": 3072,
"heads_num": 24,
"mlp_width_ratio": 4,
},
"HYVideo-T/2-cfgdistill": {
"mm_double_blocks_depth": 20,
"mm_single_blocks_depth": 40,
"rope_dim_list": [16, 56, 56],
"hidden_size": 3072,
"heads_num": 24,
"mlp_width_ratio": 4,
"guidance_embed": True,
},
}
# HUNYUAN_VIDEO_CONFIG = {
# "HYVideo-T/2": {
# "mm_double_blocks_depth": 20,
# "mm_single_blocks_depth": 40,
# "rope_dim_list": [16, 56, 56],
# "hidden_size": 3072,
# "heads_num": 24,
# "mlp_width_ratio": 4,
# },
# "HYVideo-T/2-cfgdistill": {
# "mm_double_blocks_depth": 20,
# "mm_single_blocks_depth": 40,
# "rope_dim_list": [16, 56, 56],
# "hidden_size": 3072,
# "heads_num": 24,
# "mlp_width_ratio": 4,
# "guidance_embed": True,
# },
# }
Binary file not shown.
Binary file added hyvideo/utils/__pycache__/__init__.cpython-310.pyc
Binary file not shown.
Binary file added hyvideo/utils/__pycache__/data_utils.cpython-310.pyc
Binary file not shown.
Binary file added hyvideo/utils/__pycache__/helpers.cpython-310.pyc
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added hyvideo/vae/__pycache__/vae.cpython-310.pyc
Binary file not shown.
Loading