Skip to content

Commit

Permalink
fix(T2V-Turbo): fix the bug of lora save&load in T2V-Turbo (#830)
Browse files Browse the repository at this point in the history
* bugfix num_train_epochs

* replace ops.repeat() with repeat_interleave() or tile()

* fix lora save&load

* bugfix and aligns tile

* bugfix lora
  • Loading branch information
hqkate authored Feb 7, 2025
1 parent 8802b97 commit 8be96ce
Show file tree
Hide file tree
Showing 7 changed files with 15 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -985,7 +985,7 @@ def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple
if is_decoder:
batch_size, seq_length = input_shape
seq_ids = ops.arange(seq_length)
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
causal_mask = seq_ids[None, None, :].tile((batch_size, seq_length, 1)) <= seq_ids[None, :, None]
# in case past_key_values are used we need to add a prefix ones mask to the causal mask
# causal and attention masks must have same type with pytorch version < 1.3
causal_mask = causal_mask.to(attention_mask.dtype)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -908,7 +908,7 @@ def pretrain_internvideo2_6b_patch14_224(config):
dim=-1,
).to(ms.bool)

output = model(ops.rand((4, 3, num_frames, img_size, img_size)), mask.repeat(4, 1))
output = model(ops.rand((4, 3, num_frames, img_size, img_size)), mask.repeat_interleave(4, 1))
print(output[0].shape)
print(output[1].shape)
print(output[2].shape)
Expand Down
2 changes: 1 addition & 1 deletion examples/t2v_turbo/lvdm/modules/encoders/ip_resampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def __init__(
)

def construct(self, x):
latents = self.latents.repeat(x.shape[0], 1, 1)
latents = self.latents.tile((x.shape[0], 1, 1))

x = self.proj_in(x)

Expand Down
4 changes: 2 additions & 2 deletions examples/t2v_turbo/pipeline/t2v_turbo_ms_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def _encode_prompt(

bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(num_videos_per_prompt, 1)
prompt_embeds = prompt_embeds.tile((1, num_videos_per_prompt, 1))
prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1)

# Don't need to get uncond prompt embedding because of LCM Guided Distillation
Expand Down Expand Up @@ -177,7 +177,7 @@ def __call__(
bs = batch_size * num_videos_per_prompt

# 6. Get Guidance Scale Embedding
w = ms.Tensor(guidance_scale).repeat(bs)
w = ms.Tensor(guidance_scale).tile((bs,))
w_embedding = self.get_w_embedding(w, embedding_dim=256)

# 7. LCM MultiStep Sampling Loop:
Expand Down
4 changes: 2 additions & 2 deletions examples/t2v_turbo/pipeline/t2v_turbo_vc2_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def _encode_prompt(

bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(num_videos_per_prompt, 1)
prompt_embeds = prompt_embeds.tile((1, num_videos_per_prompt, 1))
prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1)

# Don't need to get uncond prompt embedding because of LCM Guided Distillation
Expand Down Expand Up @@ -171,7 +171,7 @@ def __call__(
bs = batch_size * num_videos_per_prompt

# 6. Get Guidance Scale Embedding
w = ms.Tensor(guidance_scale).repeat(bs)
w = ms.Tensor(guidance_scale).tile((bs,))
w_embedding = self.get_w_embedding(w, embedding_dim=256)

# 7. LCM MultiStep Sampling Loop:
Expand Down
8 changes: 6 additions & 2 deletions examples/t2v_turbo/utils/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from itertools import groupby
from typing import Dict, List, Optional, Set, Tuple, Type, Union

import numpy as np

import mindspore as ms
import mindspore.common.initializer as init
from mindspore import nn, ops
Expand All @@ -19,6 +21,8 @@
def load_lora_from_pkl(file_path):
with open(file_path, "rb") as file:
loras = pickle.load(file)

loras = [ms.Tensor(lora_weight) if isinstance(lora_weight, np.ndarray) else lora_weight for lora_weight in loras]
return loras


Expand Down Expand Up @@ -571,8 +575,8 @@ def save_lora_weight(
):
weights = []
for _up, _down in extract_lora_ups_down(model, target_replace_module=target_replace_module):
weights.append(_up.weight.value().to(ms.float32))
weights.append(_down.weight.value().to(ms.float32))
weights.append(_up.weight.value().asnumpy())
weights.append(_down.weight.value().asnumpy())

import pickle

Expand Down
4 changes: 2 additions & 2 deletions examples/t2v_turbo/viclip/viclip_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,11 +219,11 @@ def inflate_weight(weight_2d, time_dim, center=True):
logger.info(f"Init center: {center}")
if center:
weight_3d = ops.zeros(*weight_2d.shape)
weight_3d = weight_3d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1)
weight_3d = weight_3d.unsqueeze(2).tile((1, 1, time_dim, 1, 1))
middle_idx = time_dim // 2
weight_3d[:, :, middle_idx, :, :] = weight_2d
else:
weight_3d = weight_2d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1)
weight_3d = weight_2d.unsqueeze(2).tile((1, 1, time_dim, 1, 1))
weight_3d = weight_3d / time_dim
return weight_3d

Expand Down

0 comments on commit 8be96ce

Please sign in to comment.