From 111200d9e306956c2f93fd0227ffcf5ac5285a29 Mon Sep 17 00:00:00 2001
From: Samit <285365963@qq.com>
Date: Tue, 6 Aug 2024 21:30:05 +0800
Subject: [PATCH 001/122] fix config names

---
 .../opensora-v1-1/train/train_stage2.yaml     | 18 +++++++++---------
 .../opensora-v1-1/train/train_stage3.yaml     | 18 +++++++++---------
 2 files changed, 18 insertions(+), 18 deletions(-)

diff --git a/examples/opensora_hpcai/configs/opensora-v1-1/train/train_stage2.yaml b/examples/opensora_hpcai/configs/opensora-v1-1/train/train_stage2.yaml
index 661b76b627..093c520c38 100644
--- a/examples/opensora_hpcai/configs/opensora-v1-1/train/train_stage2.yaml
+++ b/examples/opensora_hpcai/configs/opensora-v1-1/train/train_stage2.yaml
@@ -43,15 +43,15 @@ epochs: 2000
 ckpt_save_interval: 100
 
 mask_ratios:
-  mask_no: 0.75
-  mask_quarter_random: 0.025
-  mask_quarter_head: 0.025
-  mask_quarter_tail: 0.025
-  mask_quarter_head_tail: 0.05
-  mask_image_random: 0.025
-  mask_image_head: 0.025
-  mask_image_tail: 0.025
-  mask_image_head_tail: 0.05
+  identity: 0.75
+  quarter_random: 0.025
+  quarter_head: 0.025
+  quarter_tail: 0.025
+  quarter_head_tail: 0.05
+  image_random: 0.025
+  image_head: 0.025
+  image_tail: 0.025
+  image_head_tail: 0.05
 
 bucket_config:
   # Structure: "resolution": { num_frames: [ keep_prob, batch_size ] }
diff --git a/examples/opensora_hpcai/configs/opensora-v1-1/train/train_stage3.yaml b/examples/opensora_hpcai/configs/opensora-v1-1/train/train_stage3.yaml
index 8463e37a51..dd085233ce 100644
--- a/examples/opensora_hpcai/configs/opensora-v1-1/train/train_stage3.yaml
+++ b/examples/opensora_hpcai/configs/opensora-v1-1/train/train_stage3.yaml
@@ -43,15 +43,15 @@ epochs: 2000
 ckpt_save_interval: 100
 
 mask_ratios:
-  mask_no: 0.75
-  mask_quarter_random: 0.025
-  mask_quarter_head: 0.025
-  mask_quarter_tail: 0.025
-  mask_quarter_head_tail: 0.05
-  mask_image_random: 0.025
-  mask_image_head: 0.025
-  mask_image_tail: 0.025
-  mask_image_head_tail: 0.05
+  identity: 0.75
+  quarter_random: 0.025
+  quarter_head: 0.025
+  quarter_tail: 0.025
+  quarter_head_tail: 0.05
+  image_random: 0.025
+  image_head: 0.025
+  image_tail: 0.025
+  image_head_tail: 0.05
 
 bucket_config:
   # Structure: "resolution": { num_frames: [ keep_prob, batch_size ] }

From 94614d93527f1decc25694fd994d589935c0de01 Mon Sep 17 00:00:00 2001
From: Samit <285365963@qq.com>
Date: Fri, 16 Aug 2024 23:22:09 +0800
Subject: [PATCH 002/122] add mem monitor

---
 .../tools/mem_monitor/monitor.sh              | 19 +++++++++++++++++++
 .../opensora_hpcai/tools/mem_monitor/plot.py  | 19 +++++++++++++++++++
 2 files changed, 38 insertions(+)
 create mode 100644 examples/opensora_hpcai/tools/mem_monitor/monitor.sh
 create mode 100644 examples/opensora_hpcai/tools/mem_monitor/plot.py

diff --git a/examples/opensora_hpcai/tools/mem_monitor/monitor.sh b/examples/opensora_hpcai/tools/mem_monitor/monitor.sh
new file mode 100644
index 0000000000..4eda7cb221
--- /dev/null
+++ b/examples/opensora_hpcai/tools/mem_monitor/monitor.sh
@@ -0,0 +1,19 @@
+#!/bin/bash
+
+# File to store memory usage data
+LOG_FILE="memory_usage.log"
+
+# Clear the log file at the start
+echo "Timestamp,Memory_Usage(%)" > $LOG_FILE
+
+# Monitor memory usage every second for a specified duration
+DURATION=60  # Total duration in seconds
+INTERVAL=1   # Interval in seconds
+
+for ((i=0; i<DURATION; i+=INTERVAL)); do
+    # Append the current timestamp and overall memory usage percentage
+    TIMESTAMP=$(date +"%Y-%m-%d %H:%M:%S")
+    MEMORY_USAGE=$(free | grep Mem | awk '{print $3/$2 * 100.0}')
+    echo "$TIMESTAMP,$MEMORY_USAGE" >> $LOG_FILE
+    sleep $INTERVAL
+done
diff --git a/examples/opensora_hpcai/tools/mem_monitor/plot.py b/examples/opensora_hpcai/tools/mem_monitor/plot.py
new file mode 100644
index 0000000000..65ed5eba6e
--- /dev/null
+++ 
b/examples/opensora_hpcai/tools/mem_monitor/plot.py @@ -0,0 +1,19 @@ +import pandas as pd +import matplotlib.pyplot as plt + +# Load the memory usage data +data = pd.read_csv('memory_usage.log', parse_dates=['Timestamp']) + +# Plotting +plt.figure(figsize=(10, 5)) +plt.plot(data['Timestamp'], data['Memory_Usage(%)'], marker='o') +plt.title('Memory Usage Over Time') +plt.xlabel('Timestamp') +plt.ylabel('Memory Usage (%)') +plt.xticks(rotation=45) +plt.grid() +plt.tight_layout() + +# Save the plot +plt.savefig('memory_usage_plot.png') +plt.show() From f784d36ccac8d2d1b5204c122a5e9b39f013f960 Mon Sep 17 00:00:00 2001 From: Samit <285365963@qq.com> Date: Mon, 19 Aug 2024 14:54:39 +0800 Subject: [PATCH 003/122] update --- examples/opensora_hpcai/tools/mem_monitor/monitor.sh | 4 ++-- examples/opensora_hpcai/tools/mem_monitor/plot.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/opensora_hpcai/tools/mem_monitor/monitor.sh b/examples/opensora_hpcai/tools/mem_monitor/monitor.sh index 4eda7cb221..de080afc55 100644 --- a/examples/opensora_hpcai/tools/mem_monitor/monitor.sh +++ b/examples/opensora_hpcai/tools/mem_monitor/monitor.sh @@ -7,8 +7,8 @@ LOG_FILE="memory_usage.log" echo "Timestamp,Memory_Usage(%)" > $LOG_FILE # Monitor memory usage every second for a specified duration -DURATION=60 # Total duration in seconds -INTERVAL=1 # Interval in seconds +DURATION=6000000000 # Total duration in seconds +INTERVAL=10 # Interval in seconds for ((i=0; i Date: Mon, 19 Aug 2024 19:45:24 +0800 Subject: [PATCH 004/122] update --- .../tools/mem_monitor/monitor.sh | 55 +++++++++++++++---- .../opensora_hpcai/tools/mem_monitor/plot.py | 17 +++--- 2 files changed, 52 insertions(+), 20 deletions(-) diff --git a/examples/opensora_hpcai/tools/mem_monitor/monitor.sh b/examples/opensora_hpcai/tools/mem_monitor/monitor.sh index de080afc55..040480f68c 100644 --- a/examples/opensora_hpcai/tools/mem_monitor/monitor.sh +++ b/examples/opensora_hpcai/tools/mem_monitor/monitor.sh @@ -1,19 +1,50 @@ #!/bin/bash -# File to store memory usage data +# Check if a PID is provided +if [ "$#" -ne 2 ]; then + echo "Usage: $0 " + exit 1 +fi + +PID=$1 LOG_FILE="memory_usage.log" -# Clear the log file at the start -echo "Timestamp,Memory_Usage(%)" > $LOG_FILE +# Check if the process with the given PID exists +if ! ps -p $PID > /dev/null; then + echo "Process with PID $PID does not exist." + exit 1 +fi + +# Initialize the log file +echo "Timestamp,Memory_Usage_Percentage" > "$LOG_FILE" + +# Monitor memory usage +echo "Monitoring memory usage for PID: $PID. Logging to $LOG_FILE" +echo "Press [CTRL+C] to stop." 
+ +# Loop to continuously monitor memory usage +while true; do + # Get the total memory in KB + TOTAL_MEM=$(grep MemTotal /proc/meminfo | awk '{print $2}') + + # Get the RSS memory of the process in KB + MEMORY_INFO=$(pmap -x $PID | tail -n 1) + RSS_MEMORY=$(echo $MEMORY_INFO | awk '{print $3}') # Get the total RSS memory + + # Calculate memory usage percentage + if [ -n "$RSS_MEMORY" ]; then + MEMORY_USAGE_PERCENTAGE=$(echo "scale=2; ($RSS_MEMORY / $TOTAL_MEM) * 100" | bc) + TIMESTAMP=$(date +"%Y-%m-%d %H:%M:%S") + + # Log the timestamp and memory usage percentage + echo "$TIMESTAMP,$MEMORY_USAGE_PERCENTAGE" >> "$LOG_FILE" -# Monitor memory usage every second for a specified duration -DURATION=6000000000 # Total duration in seconds -INTERVAL=10 # Interval in seconds + # Print the memory usage percentage to the console + echo "[$TIMESTAMP] Memory Usage: $MEMORY_USAGE_PERCENTAGE%" + else + echo "Unable to retrieve memory usage for PID $PID." + fi -for ((i=0; i> $LOG_FILE - sleep $INTERVAL + # Sleep for a specified interval (e.g., 1 second) + sleep 10 done diff --git a/examples/opensora_hpcai/tools/mem_monitor/plot.py b/examples/opensora_hpcai/tools/mem_monitor/plot.py index 43b7a68222..fa76dffe47 100644 --- a/examples/opensora_hpcai/tools/mem_monitor/plot.py +++ b/examples/opensora_hpcai/tools/mem_monitor/plot.py @@ -1,19 +1,20 @@ import pandas as pd import matplotlib.pyplot as plt +import sys -# Load the memory usage data -data = pd.read_csv('memory_usage.log', parse_dates=['Timestamp']) +# Read the log file +data = pd.read_csv("memory_usage.log", parse_dates=['Timestamp']) -# Plotting +# Plotting the memory usage plt.figure(figsize=(10, 5)) -plt.plot(data['Timestamp'], data['Memory_Usage(%)']) -plt.title('Memory Usage Over Time') -plt.xlabel('Timestamp') +plt.plot(data['Timestamp'], data['Memory_Usage_Percentage'], label='Memory Usage (%)', color='blue') +plt.title('Memory Usage Percentage Over Time') +plt.xlabel('Time') plt.ylabel('Memory Usage (%)') plt.xticks(rotation=45) +plt.ylim(0, 100) # Set y-axis limits from 0 to 100% plt.grid() +plt.legend() plt.tight_layout() - -# Save the plot plt.savefig('memory_usage_plot.png') plt.show() From 929b7ae6455a2e6743a180693d2a18f9cffaee6f Mon Sep 17 00:00:00 2001 From: Samit <285365963@qq.com> Date: Tue, 22 Oct 2024 15:51:42 +0800 Subject: [PATCH 005/122] debug tae attn --- examples/movie_gen/mg/models/tae/modules.py | 668 ++++++++++++++++++++ examples/movie_gen/mg/models/tae/tae.py | 60 ++ 2 files changed, 728 insertions(+) create mode 100644 examples/movie_gen/mg/models/tae/modules.py create mode 100644 examples/movie_gen/mg/models/tae/tae.py diff --git a/examples/movie_gen/mg/models/tae/modules.py b/examples/movie_gen/mg/models/tae/modules.py new file mode 100644 index 0000000000..28a77013df --- /dev/null +++ b/examples/movie_gen/mg/models/tae/modules.py @@ -0,0 +1,668 @@ +import logging +from packaging import version + +import numpy as np + +import mindspore as ms +from mindspore import nn, ops + +_logger = logging.getLogger(__name__) + + +def divisible_by(num, den): + return (num % den) == 0 + + +def is_odd(n): + return not divisible_by(n, 2) + + +def cast_tuple(t, length=1): + return t if isinstance(t, tuple) else ((t,) * length) + + +def nonlinearity(x): + return x * (ops.sigmoid(x)) + +class GroupNorm5d(nn.GroupNorm): + def construct(self, x): + # x (b c t h w) + x_shape = x.shape + x_ndim = x.ndim + if x_ndim == 5: + # (b c f h w) -> (b c f h*w) + x = ops.reshape(x, (x_shape[0], x_shape[1], x_shape[2], -1)) + + out = super().construct(x) 
+ + if x_ndim == 5: + # (b c f h*w) -> (b c f h w) + out = ops.reshape(out, (x_shape[0], x_shape[1], x_shape[2], x_shape[3], x_shape[4])) + + return out + +def Normalize(in_channels, num_groups=32): + if version.parse(ms.__version__) >= version.parse("2.3.1"): + return nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + else: + return GroupNorm5d(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + + +def rearrange_in_spatial(x): + # (b t c h w) -> (b*t c h w) + B, T, C, H, W = x.shape + x = ops.reshape(x, (B*T, C, H, W)) + return x + +def rearrange_out_spatial(x, T): + # (b*t c h w) -> (b t c h w) + BT, C, H, W = x.shape + x = ops.reshape(x, (BT//T, T, C, H, W)) + return x + +def rearrange_in_temporal(x): + # (b t c h w) -> (b*h*w c t) + B, T, C, H, W = x.shape + # (b t c h w) -> (b h w c t) + x = ops.transpose(x, (0, 3, 4, 2, 1)) + # (b h w c t) -> (b*h*w c t) + x = ops.reshape(x, (B*H*W, C, T)) + return x + +def rearrange_out_temporal(x, H, W): + # (b*h*w c t) -> (b t c h w) + BHW, C, T = x.shape + # (b*h*w c t) -> (b h w c t) + x = ops.reshape(x, (BHW // (H*W), H, W, C, T)) + # (b h w c t) -> (b t c h w) + x = ops.transpose(x, (0, 4, 3, 1, 2)) + return x + + +class TemporalConv1d(nn.Cell): + r""" + Temporal conv1d with symmetrical replicate padding + """ + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + # pad_mode="SYMMETRIC", + # padding=0, + # dilation=1, + has_bias=True, + **kwargs, + ): + # assert dilation ==1 + assert stride == 1, 'not supported for stride > 1' + # TODO; consider stride + self.pad = nn.Pad(paddings=((2, (kernel_size-1)//2)), mode="SYMMETRIC") + self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, pad_mode="valid", has_bias=has_bias) + + def construct(self, x): + r""" + Inputs: + x: (b c t h w) + Outputs: + (b c t h w) + """ + _, _, _, H, W = x.shape + x = rearrange_in_temporal(x) + + x = self.pad(x) + x = self.conv(x) + + x = rearrange_out_temporal(x, H, W) + + return x + +class Conv2_5d(nn.Cell): + r""" + Conv2.5d, a 2D spatial convolution followed by 1D temporal convolution + """ + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + pad_mode="valid", + padding=0, + dilation=1, + has_bias=True, + **kwargs, + ): + super().__init__() + assert stride==1 + assert dilation==1 + # spatial conv + self.conv_spat = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, pad_mode=pad_mode, padding=padding, has_bias=has_bias) + # temporal conv + if kernel_size > 1: + self.pad = nn.Pad(paddings=((0, 0), (0, 0), ((kernel_size-1)//2, (kernel_size-1)//2)), mode='SYMMETRIC') + self.use_pad = True + else: + self.use_pad = False + self.conv_temp = nn.Conv1d(out_channels, out_channels, kernel_size=kernel_size, stride=stride, pad_mode="valid", has_bias=has_bias) + + def construct(self, x): + ''' + Parameters: + x: (b c t h w) + Returns: + (b c t h w) + ''' + + B, Ci, T, Hi, Wi = x.shape + # (b c t h w) -> (b t c h w) + x = ops.transpose(x, (0, 2, 1, 3, 4)) + + # spatial conv2d + # (b t c h w) -> (b*t c h w) + x = ops.reshape(x, (B*T, Ci, Hi, Wi)) + + x = self.conv_spat(x) + + # (b*t c h w) -> (b t c h w) + _, Co, Ho, Wo = x.shape + x = ops.reshape(x, (B, T, Co, Ho, Wo)) + + # temporal conv1d + # (b t c h w) -> (b*h*w c t) + x = ops.transpose(x, (0, 3, 4, 2, 1)) # (b t c h w) -> (b h w c t) + x = ops.reshape(x, (B*Ho*Wo, Co, T)) + + if self.use_pad: + x = self.pad(x) + + x = self.conv_temp(x) + + # (b*h*w c t) -> 
(b t c h w) + _, _, To = x.shape + # (b*h*w c t) -> (b h w c t) + x = ops.reshape(x, (B, Ho, Wo, Co, To)) + # (b h w c t) -> (b c t h w) + x = ops.transpose(x, (0, 3, 4, 1, 2)) + + return x + + +class Upsample(nn.Cell): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=1, pad_mode="pad", padding=1, has_bias=True + ) + + def construct(self, x): + in_shape = x.shape[-2:] + out_shape = tuple(2 * x for x in in_shape) + x = ops.ResizeNearestNeighbor(out_shape)(x) + + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Cell): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=2, pad_mode="valid", padding=0, has_bias=True + ) + + def construct(self, x): + if self.with_conv: + pad = ((0, 0), (0, 0), (0, 1), (0, 1)) + x = nn.Pad(paddings=pad)(x) + x = self.conv(x) + else: + x = ops.AvgPool(kernel_size=2, stride=2)(x) + return x + + +# used in vae +class ResnetBlock(nn.Cell): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout, + temb_channels=512, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + assert conv_shortcut==False + + self.norm1 = Normalize(in_channels) + self.conv1 = Conv2_5d( + in_channels, out_channels, kernel_size=3, stride=1, pad_mode="pad", padding=1, has_bias=True, + ) + + if temb_channels > 0: + self.temb_proj = nn.Dense(temb_channels, out_channels, bias_init="normal") + self.norm2 = Normalize(out_channels) + self.dropout = nn.Dropout(p=dropout) + self.conv2 = Conv2_5d( + out_channels, out_channels, kernel_size=3, stride=1, pad_mode="pad", padding=1, has_bias=True, + ) + if self.in_channels != self.out_channels: + # TODO: + self.nin_shortcut = Conv2_5d( + in_channels, out_channels, kernel_size=1, stride=1, pad_mode="valid", has_bias=True + ) + + def construct(self, x): + # x: (b c t h w) + h = x + h = self.norm1(h) + h = nonlinearity(h) + + h = self.conv1(h) + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + x = self.nin_shortcut(x) + + return x + h + +class AttnBlock(nn.Cell): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + self.bmm = ops.BatchMatMul() + self.norm = Normalize(in_channels) + self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, pad_mode="valid", has_bias=True) + self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, pad_mode="valid", has_bias=True) + self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, pad_mode="valid", has_bias=True) + self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, pad_mode="valid", has_bias=True) + + def construct(self, x): + # x (b c t h w) + h_ = x + h_ = self.norm(h_) + + # rearrange to spatial sequence (b c t h w) -> (bt c h w) + T = x.shape[2] + h_ = ops.transpose(h_, (0, 2, 1, 3, 4)) + h_ = ops.reshape(h_, (h_.shape[0]*h_.shape[1], h_.shape[2], h_.shape[3], h_.shape[4] )) + + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = ops.reshape(q, (b, c, h * w)) + q = 
ops.transpose(q, (0, 2, 1)) # b,hw,c + k = ops.reshape(k, (b, c, h * w)) # b,c,hw + w_ = self.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + + w_ = w_ * (int(c) ** (-0.5)) + w_ = ops.Softmax(axis=2)(w_) + + # attend to values + v = ops.reshape(v, (b, c, h * w)) + w_ = ops.transpose(w_, (0, 2, 1)) # b,hw,hw (first hw of k, second of q) + h_ = self.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = ops.reshape(h_, (b, c, h, w)) + + h_ = self.proj_out(h_) + + # rearrange back + h_ = ops.reshape(h_, (b//T, T, c, h, w)) + + return x + h_ + + +class SpatialAttnBlock(nn.Cell): + # rewritten to reduce transpose and reshape ops + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + self.bmm = ops.BatchMatMul() + self.norm = Normalize(in_channels) + # TODO: load pretrained weight defined in Conv2d + self.to_q = nn.Dense(in_channels, in_channels, has_bias=True) + self.to_k = nn.Dense(in_channels, in_channels, has_bias=True) + self.to_v = nn.Dense(in_channels, in_channels, has_bias=True) + self.proj_out = nn.Dense(in_channels, in_channels, has_bias=True) + + self.scale = ms.Tensor(in_channels**(-0.5), dtype=ms.float32) # hidden_dim = in_channels + + def construct(self, x): + # x (b c t h w) + h_ = x + h_ = self.norm(h_) + + # rearrange h_ to (b*t h*w c) + B, C, T, H, W = x.shape + h_ = ops.transpose(h_, (0, 2, 3, 4, 1)) + ops.reshape(h_, (B*T, H*W, C)) + + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + k = ops.transpose(k, (0, 2, 1)) # (bt c hw) + m = self.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + + m = m.to(ms.float32) + m = m * self.scale + attn = ops.softmax(m, axis=-1).astype(v.dtype) # (bt nq nk) + + # attend to values (nk = nv) + h_ = self.bmm(attn, v) # (bt nq c) = (bt hw c) + h_ = self.proj_out(h_) + + # rearrange back to input shape + h_ = ops.reshape(h_, (B, T, H, W, C)) + h_ = ops.transpose(h_, (0, 4, 1, 2, 3)) + + return x + h_ + + +class TemporalAttnBlock(nn.Cell): + def __init__(self, in_channels, has_bias=True): + super().__init__() + self.in_channels = in_channels + self.bmm = ops.BatchMatMul() + # TODO: instead of GroupNorm, LayerNorm is better for tiling + self.norm = Normalize(in_channels) + # TODO: use mint.nn.Linear + self.to_q = nn.Dense(in_channels, in_channels, has_bias=has_bias) + self.to_k = nn.Dense(in_channels, in_channels, has_bias=has_bias) + self.to_v = nn.Dense(in_channels, in_channels, has_bias=has_bias) + self.proj_out = nn.Dense(in_channels, in_channels, has_bias=has_bias) + + self.scale = ms.Tensor(in_channels**(-0.5), dtype=ms.float32) # hidden_dim = in_channels + + def construct(self, x): + # x (b c t h w) + h_ = x + # TODO: use LayerNorm for (B N C) instead of GroupNorm for (B C N) + h_ = self.norm(h_) + + # (b c t h w) -> (b*h*w t c) = (B S H) + B, C, T, H, W = h_.shape + h_ = ops.transpose(h_, (0, 3, 4, 2, 1)) + h_ = ops.reshape(h_, (B*H*W, T, C)) + + # projection + q = self.to_q(h_) # (bhw t c) + k = self.to_k(h_) # (bhw t c) = (bhw nk c) + v = self.to_v(h_) # (bhw t c) = (bhw nv c) + + # compute attention + # (B S H) -> (B H S) + k = ops.transpose(k, (0, 2, 1)) # (bhw c t) + m = self.bmm(q, k) # bhw, t, t = (bhw nq nk) + + m = m.to(ms.float32) + m = m * self.scale + attn = ops.softmax(m, axis=-1).astype(v.dtype) + + # attend to values + h_ = self.bmm(attn, v) # (bhw nq c) + h_ = self.proj_out(h_) + + # rearrange back to input shape + h_ = ops.reshape(h_, (B, H, W, T, C)) + h_ = ops.transpose(h_, (0, 4, 3, 1, 2)) + + return x + h_ + +# used in vae +class 
Encoder(nn.Cell): + # @ms.lazy_inline() + def __init__( + self, + ch=128, + out_ch=3, + ch_mult=(1, 2, 4, 8), + num_res_blocks=2, + attn_resolutions=[], + dropout=0.0, + resamp_with_conv=True, + in_channels=3, + resolution=256, + z_channels=4, # TODO: use 16 + double_z=True, + use_linear_attn=False, + attn_type="vanilla", + ): + super().__init__() + # if use_linear_attn: attn_type = "linear" + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = Conv2_5d( + in_channels, ch, kernel_size=3, stride=1, pad_mode='pad', padding=1, has_bias=True, + ) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.CellList(auto_prefix=False) + for i_level in range(self.num_resolutions): + block = nn.CellList() + attn = nn.CellList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Cell() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv) + else: + down.downsample = nn.Identity() + curr_res = curr_res // 2 + down.update_parameters_name(prefix=self.param_prefix + f"down.{i_level}.") + self.down.append(down) + + # middle + self.mid = nn.Cell() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + ) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = nn.Conv2d( + block_in, + 2 * z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + pad_mode="pad", + padding=1, + has_bias=True, + ) + + def construct(self, x): + ''' + Args: + x: (b c t h w) + Returns: + (b c t h w) + ''' + # spatial and temporal conv + hs = self.conv_in(x) + + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs = h + if i_level != self.num_resolutions - 1: + hs = self.down[i_level].downsample(hs) + + # middle + h = hs + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + + return h + + +class Decoder(nn.Cell): + # @ms.lazy_inline() + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + give_pre_end=False, + tanh_out=False, + use_linear_attn=False, + attn_type="vanilla", + **ignorekwargs, + ): + super().__init__() + # if use_linear_attn: attn_type = "linear" + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + + # compute in_ch_mult, block_in and curr_res at lowest res + # in_ch_mult = (1,) + tuple(ch_mult) + block_in = ch * ch_mult[self.num_resolutions - 1] + 
curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + _logger.debug("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = nn.Conv2d( + z_channels, block_in, kernel_size=3, stride=1, pad_mode="pad", padding=1, has_bias=True + ) + + # middle + self.mid = nn.Cell() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, out_channels=block_in, dropout=dropout + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, out_channels=block_in, dropout=dropout + ) + + # upsampling + self.up = nn.CellList(auto_prefix=False) + for i_level in reversed(range(self.num_resolutions)): + block = nn.CellList() + attn = nn.CellList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Cell() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + else: + up.upsample = nn.Identity() + curr_res = curr_res * 2 + up.update_parameters_name(prefix=self.param_prefix + f"up.{i_level}.") + if len(self.up) != 0: + self.up.insert(0, up) + else: + self.up.append(up) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, pad_mode="pad", padding=1, has_bias=True) + + def construct(self, z): + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # upsampling + i_level = self.num_resolutions + while i_level > 0: + i_level -= 1 + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + if self.tanh_out: + h = ops.tanh(h) + return h diff --git a/examples/movie_gen/mg/models/tae/tae.py b/examples/movie_gen/mg/models/tae/tae.py new file mode 100644 index 0000000000..8288c66229 --- /dev/null +++ b/examples/movie_gen/mg/models/tae/tae.py @@ -0,0 +1,60 @@ +import mindspore as ms +from mindspore import nn, ops + + +SDXL_CONFIG = { + "double_z": True, + "z_channels": 4, + "resolution": 256, + "in_channels": 3, + "out_ch": 3, + "ch": 128, + "ch_mult": [1, 2, 4, 4], + "num_res_blocks": 2, + "attn_resolutions": [], + "dropout": 0.0, +} + + +class VideoAutoencoder(nn.Cell): + r""" + TAE + + Parameters: + config (`dict`): config dict + pretrained (`str`): checkpoint path + """ + def __init__( + self, + config: dict=SDXL_CONFIG, + pretrained: str=None, + ): + super().__init__() + + # encoder + self.encoder = Encoder(**config) + + # quant and post quant + self.quant_conv = nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1, pad_mode="valid", has_bias=True) + self.post_quant_conv = nn.Conv2d(embed_dim, ddconfig["z_channels"], 1, pad_mode="valid", has_bias=True) + + # decoder + self.decoder = Decoder(**config) + + def encode(self, x: ms.Tensor) -> ms.Tensor: + + return x + + def decode(self, x: ms.Tensor) -> ms.Tensor: + + return x + + def construct(self, x: ms.Tensor) -> ms.Tensor: + """ + video reconstruction + + x: 
(b c t h w) + """ + + return x + From a50a1cf005497b66788d89e04d4a787bf880ed8f Mon Sep 17 00:00:00 2001 From: Samit <285365963@qq.com> Date: Fri, 25 Oct 2024 10:24:13 +0800 Subject: [PATCH 006/122] update --- .../movie_gen/mg/models/tae/autoencoder_kl.py | 120 +++++ examples/movie_gen/mg/models/tae/modules.py | 151 +++++- examples/movie_gen/mg/models/tae/vae.py | 485 ++++++++++++++++++ examples/movie_gen/tests/test_gn.py | 29 ++ examples/movie_gen/tests/test_tae.py | 137 +++++ 5 files changed, 903 insertions(+), 19 deletions(-) create mode 100644 examples/movie_gen/mg/models/tae/autoencoder_kl.py create mode 100644 examples/movie_gen/mg/models/tae/vae.py create mode 100644 examples/movie_gen/tests/test_gn.py create mode 100644 examples/movie_gen/tests/test_tae.py diff --git a/examples/movie_gen/mg/models/tae/autoencoder_kl.py b/examples/movie_gen/mg/models/tae/autoencoder_kl.py new file mode 100644 index 0000000000..779838b71b --- /dev/null +++ b/examples/movie_gen/mg/models/tae/autoencoder_kl.py @@ -0,0 +1,120 @@ +import mindspore as ms +from mindspore import nn, ops + +from ..layers.operation_selector import get_split_op +from .modules import Decoder, Encoder + +__all__ = ["AutoencoderKL"] + + +class AutoencoderKL(nn.Cell): + def __init__( + self, + ddconfig, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + monitor=None, + use_recompute=False, + sample_deterministic=False, + ): + super().__init__() + self.image_key = image_key + self.sample_deterministic = sample_deterministic + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + # assert ddconfig["double_z"] + self.quant_conv = nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1, pad_mode="valid", has_bias=True) + self.post_quant_conv = nn.Conv2d(embed_dim, ddconfig["z_channels"], 1, pad_mode="valid", has_bias=True) + self.embed_dim = embed_dim + + if monitor is not None: + self.monitor = monitor + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + + self.exp = ops.Exp() + self.stdnormal = ops.StandardNormal() + self.split = get_split_op() + + if use_recompute: + self.recompute(self.encoder) + self.recompute(self.quant_conv) + self.recompute(self.post_quant_conv) + self.recompute(self.decoder) + + def recompute(self, b): + if not b._has_config_recompute: + b.recompute() + if isinstance(b, nn.CellList): + self.recompute(b[-1]) + else: + b.add_flags(output_no_recompute=True) + + def init_from_ckpt( + self, path, ignore_keys=list(), remove_prefix=["first_stage_model.", "autoencoder.", "spatial_vae.module."] + ): + # TODO: support auto download pretrained checkpoints + sd = ms.load_checkpoint(path) + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + vae_prefix = ["encoder.", "decoder.", "quant_conv.", "post_quant_conv."] + for pname in keys: + is_vae_param = False + for pf in remove_prefix: + if pname.startswith(pf): + sd[pname.replace(pf, "")] = sd.pop(pname) + is_vae_param = True + for pf in vae_prefix: + if pname.startswith(pf): + is_vae_param = True + if not is_vae_param: + sd.pop(pname) + pu, cu = ms.load_param_into_net(self, sd, strict_load=False) + print(f"Net param not loaded : {pu}") + print(f"Checkpoint param not loaded : {cu}") + print(f"Restored from {path}") + + def _encode(self, x): + # return latent distribution, N(mean, logvar) + h = self.encoder(x) + moments = self.quant_conv(h) + mean, logvar = self.split(moments, moments.shape[1] // 2, 
1) + + return mean, logvar + + def sample(self, mean, logvar): + # sample z from latent distribution + logvar = ops.clip_by_value(logvar, -30.0, 20.0) + std = self.exp(0.5 * logvar) + z = mean + std * self.stdnormal(mean.shape) + + return z + + def encode(self, x): + # embedding, get latent representation z + posterior_mean, posterior_logvar = self._encode(x) + if self.sample_deterministic: + return posterior_mean + z = self.sample(posterior_mean, posterior_logvar) + + return z + + def decode(self, z): + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def construct(self, input): + # overall pass, mostly for training + posterior_mean, posterior_logvar = self._encode(input) + z = self.sample(posterior_mean, posterior_logvar) + + recons = self.decode(z) + + return recons, posterior_mean, posterior_logvar diff --git a/examples/movie_gen/mg/models/tae/modules.py b/examples/movie_gen/mg/models/tae/modules.py index 28a77013df..b59cf0c0d1 100644 --- a/examples/movie_gen/mg/models/tae/modules.py +++ b/examples/movie_gen/mg/models/tae/modules.py @@ -62,7 +62,7 @@ def rearrange_out_spatial(x, T): def rearrange_in_temporal(x): # (b t c h w) -> (b*h*w c t) - B, T, C, H, W = x.shape + B, C, T, H, W = x.shape # (b t c h w) -> (b h w c t) x = ops.transpose(x, (0, 3, 4, 2, 1)) # (b h w c t) -> (b*h*w c t) @@ -98,7 +98,7 @@ def __init__(self, assert stride == 1, 'not supported for stride > 1' # TODO; consider stride self.pad = nn.Pad(paddings=((2, (kernel_size-1)//2)), mode="SYMMETRIC") - self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, pad_mode="valid", has_bias=has_bias) + self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, pad_mode="valid", has_bias=True) def construct(self, x): r""" @@ -207,7 +207,7 @@ def construct(self, x): return x -class Downsample(nn.Cell): +class SpatialDownsample(nn.Cell): def __init__(self, in_channels, with_conv): super().__init__() self.with_conv = with_conv @@ -217,13 +217,105 @@ def __init__(self, in_channels, with_conv): in_channels, in_channels, kernel_size=3, stride=2, pad_mode="valid", padding=0, has_bias=True ) + self.pad = nn.Pad(paddings=((0, 0), (0, 0), (0, 1), (0, 1))) + def construct(self, x): + # x (b c t h w) + # TODO: reduce transpose and reshape op + B, C, T, H, W = x.shape + x = ops.transpose(x, (0, 2, 1, 3, 4)) + x = ops.reshape(x, (B*T, C, H, W)) + if self.with_conv: - pad = ((0, 0), (0, 0), (0, 1), (0, 1)) - x = nn.Pad(paddings=pad)(x) + x = self.pad(x) x = self.conv(x) else: x = ops.AvgPool(kernel_size=2, stride=2)(x) + + # (bt c h w) -> (b c t h w) + _, Co, Ho, Wo = x.shape + x = ops.reshape(x, (B, T, Co, Ho, Wo)) + x = ops.transpose(x, (0, 2, 1, 3, 4)) + + return x + + +class TemporalDownsample(nn.Cell): + def __init__(self, in_channels): + super().__init__() + self.ks = 3 + self.ch = in_channels + self.conv = nn.Conv1d( + in_channels, in_channels, kernel_size=self.ks, stride=2, pad_mode="valid", padding=0, has_bias=True, bias_init='zeros', + ) + # tail padding, pad with last frame + self.time_pad = self.ks - 1 + self.init_weight() + + def init_weight(self): + w = self.conv.weight + value = np.zeros(tuple(w.shape)) + # TODO: ablate with normal init + for i in range(self.ch): + value[i, i, 0, :] = 1/self.ks # (cout, cin, 1, ks) + w.set_data(ms.Tensor(value, dtype=ms.float32)) + + + def construct(self, x): + # x (b c t h w) + + # -> (bhw c t) + B, C, T, H, W = x.shape + x = ops.transpose(x, (0, 3, 4, 1, 2)) + x = ops.reshape(x, (B*H*W, C, T)) + + # tail 
padding + last_frame = x[:, :, -1:] + last_frame_pad = ops.cat([last_frame] * self.time_pad, axis=2) + x = ops.concat((x, last_frame_pad), axis=2) + + x = self.conv(x) + + # (bhw c t) -> (b c t h w) + _, Co, To = x.shape + x = ops.reshape(x, (B, H, W, Co, To)) + x = ops.transpose(x, (0, 3, 4, 1, 2)) + + return x + + +class TemporalUpsample(nn.Cell): + def __init__(self, in_channels): + super().__init__() + self.conv = nn.Conv1d(in_channels, in_channels, kernel_size=3, stride=1, pad_mode="same", has_bias=True, bias_init='zeros') + # TODO: init conv weight so that it pass in image mode + self.ch = in_channels + self.init_weight() + + def init_weight(self): + w = self.conv.weight + value = np.zeros(tuple(w.shape)) + # TODO: ablate with normal init + # consider image input, make sure it's the same + for i in range(self.ch): + value[i, i, 0, 1] = 1 # (cout, cin, 1, ks) + w.set_data(ms.Tensor(value, dtype=ms.float32)) + + def construct(self, x): + # x (b c t h w) + x = ops.interpolate(x, scale_factor=(2.0, 1.0, 1.0), mode="nearest") + + # x (b c t h w) -> (bhw c t) + B, C, T, H, W = x.shape + x = ops.transpose(x, (0, 3, 4, 1, 2)) + x = ops.reshape(x, (B*H*W, C, T)) + + x = self.conv(x) + + # x (bhw c t) -> (b c t h w) + x = ops.reshape(x, (B, H, W, C, T)) + x = ops.transpose(x, (0, 3, 4, 1, 2)) + return x @@ -280,7 +372,7 @@ def construct(self, x): return x + h -class AttnBlock(nn.Cell): +class SpatialAttnBlock(nn.Cell): def __init__(self, in_channels): super().__init__() self.in_channels = in_channels @@ -291,6 +383,9 @@ def __init__(self, in_channels): self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, pad_mode="valid", has_bias=True) self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, pad_mode="valid", has_bias=True) + self.hidden_dim = in_channels + self.scale = ms.Tensor(self.hidden_dim**(-0.5), dtype=ms.float32) + def construct(self, x): # x (b c t h w) h_ = x @@ -312,8 +407,9 @@ def construct(self, x): k = ops.reshape(k, (b, c, h * w)) # b,c,hw w_ = self.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] - w_ = w_ * (int(c) ** (-0.5)) - w_ = ops.Softmax(axis=2)(w_) + # TODO: use float32 for softmax + w_ = w_.to(ms.float32) * self.scale + w_ = ops.Softmax(axis=2)(w_).astype(v.dtype) # attend to values v = ops.reshape(v, (b, c, h * w)) @@ -324,12 +420,14 @@ def construct(self, x): h_ = self.proj_out(h_) # rearrange back + # -> (b t c h w) h_ = ops.reshape(h_, (b//T, T, c, h, w)) + h_ = ops.transpose(h_, (0, 2, 1, 3, 4)) return x + h_ -class SpatialAttnBlock(nn.Cell): +class SpatialAttnBlockV2(nn.Cell): # rewritten to reduce transpose and reshape ops def __init__(self, in_channels): super().__init__() @@ -337,9 +435,9 @@ def __init__(self, in_channels): self.bmm = ops.BatchMatMul() self.norm = Normalize(in_channels) # TODO: load pretrained weight defined in Conv2d - self.to_q = nn.Dense(in_channels, in_channels, has_bias=True) - self.to_k = nn.Dense(in_channels, in_channels, has_bias=True) - self.to_v = nn.Dense(in_channels, in_channels, has_bias=True) + self.q = nn.Dense(in_channels, in_channels, has_bias=True) + self.k = nn.Dense(in_channels, in_channels, has_bias=True) + self.v = nn.Dense(in_channels, in_channels, has_bias=True) self.proj_out = nn.Dense(in_channels, in_channels, has_bias=True) self.scale = ms.Tensor(in_channels**(-0.5), dtype=ms.float32) # hidden_dim = in_channels @@ -350,9 +448,9 @@ def construct(self, x): h_ = self.norm(h_) # rearrange h_ to (b*t h*w c) - B, C, T, H, W = x.shape + B, C, T, H, W = h_.shape h_ = 
ops.transpose(h_, (0, 2, 3, 4, 1)) - ops.reshape(h_, (B*T, H*W, C)) + h_ = ops.reshape(h_, (B*T, H*W, C)) q = self.q(h_) k = self.k(h_) @@ -384,7 +482,7 @@ def __init__(self, in_channels, has_bias=True): self.bmm = ops.BatchMatMul() # TODO: instead of GroupNorm, LayerNorm is better for tiling self.norm = Normalize(in_channels) - # TODO: use mint.nn.Linear + # TODO: compare conv1d with Dense on performance self.to_q = nn.Dense(in_channels, in_channels, has_bias=has_bias) self.to_k = nn.Dense(in_channels, in_channels, has_bias=has_bias) self.to_v = nn.Dense(in_channels, in_channels, has_bias=has_bias) @@ -427,6 +525,18 @@ def construct(self, x): return x + h_ + +def make_attn(in_channels, attn_type="vanilla"): + # assert attn_type in ["vanilla", "vanilla3D"], f"attn_type {attn_type} not supported" + _logger.debug(f"making attention of type '{attn_type}' with {in_channels} in_channels") + if attn_type == "vanilla": + return nn.SequentialCell( + SpatialAttnBlock(in_channels), + TemporalAttnBlock(in_channels), + ) + else: + raise NotImplementedError + # used in vae class Encoder(nn.Cell): # @ms.lazy_inline() @@ -483,9 +593,11 @@ def __init__( down.block = block down.attn = attn if i_level != self.num_resolutions - 1: - down.downsample = Downsample(block_in, resamp_with_conv) + down.downsample_spat = SpatialDownsample(block_in, resamp_with_conv) + down.downsample_temp = TemporalDownsample(block_in) else: - down.downsample = nn.Identity() + down.downsample_spat = nn.Identity() + down.downsample_temp = nn.Identity() curr_res = curr_res // 2 down.update_parameters_name(prefix=self.param_prefix + f"down.{i_level}.") self.down.append(down) @@ -506,7 +618,7 @@ def __init__( # end self.norm_out = Normalize(block_in) - self.conv_out = nn.Conv2d( + self.conv_out = Conv2_5d( block_in, 2 * z_channels if double_z else z_channels, kernel_size=3, @@ -533,7 +645,8 @@ def construct(self, x): h = self.down[i_level].attn[i_block](h) hs = h if i_level != self.num_resolutions - 1: - hs = self.down[i_level].downsample(hs) + hs = self.down[i_level].downsample_spat(hs) + hs = self.down[i_level].downsample_temp(hs) # middle h = hs diff --git a/examples/movie_gen/mg/models/tae/vae.py b/examples/movie_gen/mg/models/tae/vae.py new file mode 100644 index 0000000000..d846d2fdac --- /dev/null +++ b/examples/movie_gen/mg/models/tae/vae.py @@ -0,0 +1,485 @@ +import logging +import os + +from transformers import PretrainedConfig + +import mindspore as ms +from mindspore import mint, nn, ops +from mindspore.communication import get_group_size + +from ...acceleration.communications import GatherFowardSplitBackward, SplitFowardGatherBackward +from ...acceleration.parallel_states import get_sequence_parallel_group +from ..layers.operation_selector import get_split_op +from .autoencoder_kl import AutoencoderKL as AutoencoderKL_SD +from .vae_temporal import VAE_Temporal_SD # noqa: F401 + +__all__ = ["AutoencoderKL"] + + +_logger = logging.getLogger(__name__) +SD_CONFIG = { + "double_z": True, + "z_channels": 4, + "resolution": 256, + "in_channels": 3, + "out_ch": 3, + "ch": 128, + "ch_mult": [1, 2, 4, 4], + "num_res_blocks": 2, + "attn_resolutions": [], + "dropout": 0.0, +} +SDXL_CONFIG = SD_CONFIG.copy() +SDXL_CONFIG.update({"resolution": 512}) + + +class AutoencoderKL(AutoencoderKL_SD): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.split = get_split_op() + + def init_from_ckpt(self, path, ignore_keys=list()): + if not os.path.exists(path): + raise ValueError( + "Maybe download failed. 
Please download the VAE encoder from https://huggingface.co/stabilityai/sd-vae-ft-ema" + ) + param_dict = ms.load_checkpoint(path) + param_not_load, ckpt_not_load = ms.load_param_into_net(self, param_dict, strict_load=True) + if param_not_load or ckpt_not_load: + _logger.warning( + f"{param_not_load} in network is not loaded or {ckpt_not_load} in checkpoint is not loaded!" + ) + + def encode_with_moments_output(self, x): + """For latent caching usage""" + h = self.encoder(x) + moments = self.quant_conv(h) + mean, logvar = self.split(moments, moments.shape[1] // 2, 1) + logvar = ops.clip_by_value(logvar, -30.0, 20.0) + std = self.exp(0.5 * logvar) + + return mean, std + + +class VideoAutoencoderKL(nn.Cell): + """ + Spatial VAE + """ + + def __init__( + self, + config=SDXL_CONFIG, + ckpt_path=None, + micro_batch_size=None, + scale_factor=0.18215, + use_recompute=False, + micro_batch_parallel=False, + sample_deterministic=False, + ): + super().__init__() + + self.module = AutoencoderKL_SD( + ddconfig=config, + embed_dim=config["z_channels"], + ckpt_path=ckpt_path, + use_recompute=use_recompute, + sample_deterministic=sample_deterministic, + ) + + self.out_channels = config["z_channels"] # self.module.config.latent_channels + self.patch_size = (1, 8, 8) + self.micro_batch_size = micro_batch_size + self.micro_batch_parallel = micro_batch_parallel + if self.micro_batch_parallel: + sp_group = get_sequence_parallel_group() + _logger.info(f"Initialize Spatial VAE model with parallel group `{sp_group}`.") + self.sp_size = get_group_size(sp_group) + self.split_forward_gather_backward = SplitFowardGatherBackward(dim=0, grad_scale="down", group=sp_group) + self.gather_forward_split_backward = GatherFowardSplitBackward(dim=0, grad_scale="up", group=sp_group) + # TODO: drop the assertion once conv3d support fp32, test with test suites + assert self.micro_batch_size == 1 + + # FIXME: "scaling_factor": 0.13025 is set in + # https://huggingface.co/PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers/blob/main/vae/config.json. + # This is a mistake made during the training of OpenSora v1.2. + # To re-use the trained model, we need to keep this mistake. + # For training, we should refine to 0.13025. 
+ self.scale_factor = scale_factor + self.split = get_split_op() + self.scale_factor = 0.18215 + + @staticmethod + def rearrange_in(x): + B, C, T, H, W = x.shape + # (b c t h w) -> (b t c h w) + x = ops.transpose(x, (0, 2, 1, 3, 4)) + x = ops.reshape(x, (B * T, C, H, W)) + + return x + + @staticmethod + def rearrange_out(x, B): + # x = rearrange(x, "(B T) C H W -> B C T H W", B=B) + BT, C, H, W = x.shape + T = BT // B + x = ops.reshape(x, (B, T, C, H, W)) + x = ops.transpose(x, (0, 2, 1, 3, 4)) + + return x + + def encode(self, x): + """ + Args: + x: (B, C, T, H, W) + Return: + (B C T H W) + + NOTE: remind to use stop gradient when invoke it + """ + # is_video = (x.ndim == 5) + + B = x.shape[0] + # B C T H W -> (B T) C H W + x = self.rearrange_in(x) + + pad_num = None + if self.micro_batch_parallel: + # select part of x for micro_batch + pad_num = self.get_pad_num(x.shape[0]) + if pad_num > 0: + x = mint.nn.functional.pad(x, (0, 0, 0, 0, 0, 0, 0, pad_num)) + x = self.split_forward_gather_backward(x) + + if self.micro_batch_size is None: + x_out = self.module.encode(x) * self.scale_factor + else: + bs = self.micro_batch_size + x_out = self.module.encode(x[:bs]) * self.scale_factor + for i in range(bs, x.shape[0], bs): + x_cur = self.module.encode(x[i : i + bs]) * self.scale_factor + x_out = ops.cat((x_out, x_cur), axis=0) + + if self.micro_batch_parallel: + x_out = self.gather_forward_split_backward(x_out) + if pad_num > 0: + x_out = x_out.narrow(0, 0, x_out.shape[0] - pad_num) + + # (B T) C H W -> B C T H W + x_out = self.rearrange_out(x_out, B=B) + + return x_out + + def decode(self, x, **kwargs): + # is_video = (x.ndim == 5) + + B = x.shape[0] + # x: (B, Z, T, H, W) + # B Z T H W -> (B T) Z H W + x = self.rearrange_in(x) + + if self.micro_batch_size is None: + x_out = self.module.decode(x / self.scale_factor) + else: + mbs = self.micro_batch_size + + x_out = self.module.decode(x[:mbs] / self.scale_factor) + for i in range(mbs, x.shape[0], mbs): + x_cur = self.module.decode(x[i : i + mbs] / self.scale_factor) + x_out = ops.cat((x_out, x_cur), axis=0) + + # (B T) Z H W -> B Z T H W + x_out = self.rearrange_out(x_out, B=B) + + return x_out + + def get_latent_size(self, input_size): + latent_size = [] + for i in range(3): + # assert ( + # input_size[i] is None or input_size[i] % self.patch_size[i] == 0 + # ), "Input size must be divisible by patch size" + latent_size.append(input_size[i] // self.patch_size[i] if input_size[i] is not None else None) + return latent_size + + def get_pad_num(self, dim_size: int) -> int: + pad = (self.sp_size - (dim_size % self.sp_size)) % self.sp_size + return pad + + +class VideoAutoencoderPipelineConfig(PretrainedConfig): + model_type = "VideoAutoencoderPipeline" + + def __init__( + self, + vae_2d=None, + vae_temporal=None, + from_pretrained=None, + freeze_vae_2d=False, + cal_loss=False, + micro_frame_size=None, + concat_posterior=False, + shift=0.0, + scale=1.0, + micro_frame_parallel=False, + sample_deterministic=False, + **kwargs, + ): + self.vae_2d = vae_2d + self.vae_temporal = vae_temporal + self.from_pretrained = from_pretrained + self.freeze_vae_2d = freeze_vae_2d + self.cal_loss = cal_loss + self.micro_frame_size = micro_frame_size + self.shift = shift + self.scale = scale + self.concat_posterior = (concat_posterior,) + self.micro_frame_parallel = micro_frame_parallel + self.sample_deterministic = sample_deterministic + super().__init__(**kwargs) + + +def build_module_from_config(config): + """ + config dict format: + - type: model class name + - 
others: model init args + """ + cfg = config.copy() + name = cfg.pop("type") + kwargs = cfg + + # FIXME: use importlib with path + module = eval(name)(**kwargs) + return module + + +class VideoAutoencoderPipeline(nn.Cell): + """ + Main model for spatial vae + tempral vae + """ + + # config_class = VideoAutoencoderPipelineConfig + def __init__(self, config: VideoAutoencoderPipelineConfig): + super().__init__() + self.spatial_vae = build_module_from_config(config.vae_2d) + self.temporal_vae = build_module_from_config(config.vae_temporal) + + self.cal_loss = config.cal_loss + self.micro_frame_size = config.micro_frame_size + self.micro_z_frame_size = self.temporal_vae.get_latent_size([config.micro_frame_size, None, None])[0] + print(f"micro_frame_size: {self.micro_frame_size}, micro_z_frame_size: {self.micro_z_frame_size}") + self.micro_frame_parallel = config.micro_frame_parallel + self.sample_deterministic = config.sample_deterministic + + if config.freeze_vae_2d: + for param in self.spatial_vae.get_parameters(): + param.requires_grad = False + + self.out_channels = self.temporal_vae.out_channels + self.split = get_split_op() + + # normalization parameters + scale = ms.Tensor(config.scale) + shift = ms.Tensor(config.shift) + if len(scale.shape) > 0: + scale = scale[None, :, None, None, None] + if len(shift.shape) > 0: + shift = shift[None, :, None, None, None] + self.scale = ms.Parameter(scale, requires_grad=False) + self.shift = ms.Parameter(shift, requires_grad=False) + self.freeze_vae_2d = config.freeze_vae_2d + self.concat_posterior = config.concat_posterior + + if self.micro_frame_parallel: + sp_group = get_sequence_parallel_group() + _logger.info(f"Initialize Temporal VAE model with parallel group `{sp_group}`.") + self.sp_size = get_group_size(sp_group) + self.split_forward_gather_backward = SplitFowardGatherBackward(dim=2, grad_scale="down", group=sp_group) + self.gather_forward_split_backward = GatherFowardSplitBackward(dim=2, grad_scale="up", group=sp_group) + if self.cal_loss: + raise NotImplementedError("Not Supported yet.") + + def encode(self, x): + if self.freeze_vae_2d: + x_z = ops.stop_gradient(self.spatial_vae.encode(x)) + else: + x_z = self.spatial_vae.encode(x) + + if self.micro_frame_parallel: + # TODO: drop assertion and add padding + assert x_z.shape[2] % self.sp_size == 0 + if self.micro_frame_size is not None: + assert x_z.shape[2] % self.micro_frame_size == 0 + x_z = self.split_forward_gather_backward(x_z) + + if self.micro_frame_size is None: + posterior_mean, posterior_logvar = self.temporal_vae._encode(x_z) + if self.sample_deterministic: + z_out = posterior_mean + else: + z_out = self.temporal_vae.sample(posterior_mean, posterior_logvar) + + if self.cal_loss: + return z_out, posterior_mean, posterior_logvar, x_z + else: + if self.micro_frame_parallel: + z_out = self.gather_forward_split_backward(z_out) + return (z_out - self.shift) / self.scale + else: + # x_z: (b z t h w) + mfs = self.micro_frame_size + if self.cal_loss: + # TODO: fix the bug in torch, output concat of the splitted posteriors instead of the last split + posterior_mean, posterior_logvar = self.temporal_vae._encode(x_z[:, :, :mfs]) + if self.sample_deterministic: + z_out = posterior_mean + else: + z_out = self.temporal_vae.sample(posterior_mean, posterior_logvar) + for i in range(mfs, x_z.shape[2], mfs): + posterior_mean, posterior_logvar = self.temporal_vae._encode(x_z[:, :, i : i + mfs]) + if self.sample_deterministic: + z_cur = posterior_mean + else: + z_cur = 
self.temporal_vae.sample(posterior_mean, posterior_logvar) + z_out = ops.cat((z_out, z_cur), axis=2) + + return z_out, posterior_mean, posterior_logvar, x_z + else: + # no posterior cache to reduce memory in inference + z_out = self.temporal_vae.encode(x_z[:, :, :mfs]) + for i in range(mfs, x_z.shape[2], mfs): + z_cur = self.temporal_vae.encode(x_z[:, :, i : i + mfs]) + z_out = ops.cat((z_out, z_cur), axis=2) + + if self.micro_frame_parallel: + z_out = self.gather_forward_split_backward(z_out) + + return (z_out - self.shift) / self.scale + + def decode(self, z, num_frames=None): + if not self.cal_loss: + z = z * self.scale.to(z.dtype) + self.shift.to(z.dtype) + + if self.micro_frame_size is None: + x_z_out = self.temporal_vae.decode(z, num_frames=num_frames) + x = self.spatial_vae.decode(x_z_out) + if self.cal_loss: + return x, x_z_out + else: + return x + else: + mz = self.micro_z_frame_size + remain_frames = num_frames if self.micro_frame_size > num_frames else self.micro_frame_size + x_z_out = self.temporal_vae.decode(z[:, :, :mz], num_frames=remain_frames) + num_frames -= self.micro_frame_size + + for i in range(mz, z.shape[2], mz): + remain_frames = num_frames if self.micro_frame_size > num_frames else self.micro_frame_size + x_z_cur = self.temporal_vae.decode(z[:, :, i : i + mz], num_frames=remain_frames) + x_z_out = ops.cat((x_z_out, x_z_cur), axis=2) + num_frames -= self.micro_frame_size + + x = self.spatial_vae.decode(x_z_out) + + if self.cal_loss: + return x, x_z_out + else: + return x + + def construct(self, x): + # assert self.cal_loss, "This method is only available when cal_loss is True" + z, posterior_mean, posterior_logvar, x_z = self.encode(x) + x_rec, x_z_rec = self.decode(z, num_frames=x_z.shape[2]) + return x_rec, x_z_rec, z, posterior_mean, posterior_logvar, x_z + + def get_latent_size(self, input_size): + if self.micro_frame_size is None or input_size[0] is None: + return self.temporal_vae.get_latent_size(self.spatial_vae.get_latent_size(input_size)) + else: + sub_input_size = [self.micro_frame_size, input_size[1], input_size[2]] + sub_latent_size = self.temporal_vae.get_latent_size(self.spatial_vae.get_latent_size(sub_input_size)) + sub_latent_size[0] = sub_latent_size[0] * (input_size[0] // self.micro_frame_size) + remain_temporal_size = [input_size[0] % self.micro_frame_size, None, None] + if remain_temporal_size[0] > 0: + remain_size = self.temporal_vae.get_latent_size(remain_temporal_size) + sub_latent_size[0] += remain_size[0] + return sub_latent_size + + def get_temporal_last_layer(self): + return self.temporal_vae.decoder.conv_out.conv.weight + + +def OpenSoraVAE_V1_2( + micro_batch_size=4, + micro_frame_size=17, + micro_batch_parallel=False, + micro_frame_parallel=False, + ckpt_path=None, + vae2d_ckpt_path=None, + freeze_vae_2d=False, + cal_loss=False, + use_recompute=False, + sample_deterministic=False, +): + """ + ckpt_path: path to the checkpoint of the overall model (vae2d + temporal vae) + vae_2d_ckpt_path: path to the checkpoint of the vae 2d model. It will only be loaded when `ckpt_path` not provided. 
+ """ + + if isinstance(micro_batch_size, int): + if micro_batch_size <= 0: + micro_batch_size = None + if isinstance(micro_frame_size, int): + if micro_frame_size <= 0: + micro_frame_size = None + + vae_2d = dict( + type="VideoAutoencoderKL", + config=SDXL_CONFIG, + micro_batch_size=micro_batch_size, + micro_batch_parallel=micro_batch_parallel, + use_recompute=use_recompute, + sample_deterministic=sample_deterministic, + ) + vae_temporal = dict( + type="VAE_Temporal_SD", + from_pretrained=None, + use_recompute=use_recompute, + sample_deterministic=sample_deterministic, + ) + shift = (-0.10, 0.34, 0.27, 0.98) + scale = (3.85, 2.32, 2.33, 3.06) + kwargs = dict( + vae_2d=vae_2d, + vae_temporal=vae_temporal, + freeze_vae_2d=freeze_vae_2d, + cal_loss=cal_loss, + micro_frame_size=micro_frame_size, + shift=shift, + scale=scale, + micro_frame_parallel=micro_frame_parallel, + sample_deterministic=sample_deterministic, + ) + + config = VideoAutoencoderPipelineConfig(**kwargs) + model = VideoAutoencoderPipeline(config) + + # load model weights + if (ckpt_path is not None) and (os.path.exists(ckpt_path)): + sd = ms.load_checkpoint(ckpt_path) + + # remove the added prefix in the trained checkpoint + pnames = list(sd.keys()) + for pn in pnames: + new_pn = pn.replace("autoencoder.", "").replace("_backbone.", "") + sd[new_pn] = sd.pop(pn) + + pu, cu = ms.load_param_into_net(model, sd, strict_load=False) + print(f"Net param not loaded : {pu}") + print(f"Checkpoint param not loaded : {cu}") + elif (vae2d_ckpt_path is not None) and (os.path.exists(vae2d_ckpt_path)): + sd = ms.load_checkpoint(vae2d_ckpt_path) + # TODO: add spatial_vae prefix to the param name + pu, cu = ms.load_param_into_net(model.spatial_vae, sd, strict_load=False) + + return model diff --git a/examples/movie_gen/tests/test_gn.py b/examples/movie_gen/tests/test_gn.py new file mode 100644 index 0000000000..8d55ef9a13 --- /dev/null +++ b/examples/movie_gen/tests/test_gn.py @@ -0,0 +1,29 @@ +import numpy as np +import torch +import torch.nn as nn + +# 定义输入形状 +B, C, T, H, W = 2, 3, 16, 256, 256 +x = np.random.normal(size=(B, C, T, H, W)).astype(np.float32) +x_tensor = torch.tensor(x) + +# 定义 GroupNorm 层 +group_norm = nn.GroupNorm(num_groups=3, num_channels=C) + +# 第一次 GroupNorm 操作 +y1 = group_norm(x_tensor) + +# 重新排列形状 +x_rearranged = x_tensor.permute(0, 3, 4, 1, 2).contiguous().view(B * H * W, C, T) + +# 第二次 GroupNorm 操作 +y2 = group_norm(x_rearranged) + +# 恢复形状 +# y1 = y1.view(B, C, T, H, W).permute(0, 2, 1, 3, 4).contiguous() +y2 = y2.view(B, H, W, C, T).permute(0, 3, 4, 1, 2).contiguous() + +# 比较 y1 和 y2 +print(y1.sum()) +print(y2.sum()) +print(torch.allclose(y1, y2)) diff --git a/examples/movie_gen/tests/test_tae.py b/examples/movie_gen/tests/test_tae.py new file mode 100644 index 0000000000..ffe166e65d --- /dev/null +++ b/examples/movie_gen/tests/test_tae.py @@ -0,0 +1,137 @@ +import numpy as np +import mindspore as ms +from mg.models.tae.modules import Conv2_5d, ResnetBlock, SpatialAttnBlock, SpatialAttnBlockV2, TemporalAttnBlock, TemporalUpsample, TemporalDownsample, SpatialDownsample, Encoder + +from mg.models.tae.tae import SDXL_CONFIG + +def test_conv25d(): + in_shape = (B, C, T, H, W) = (2, 3, 16, 256, 256) + cout = 128 + x = np.random.normal(size=in_shape).astype(np.float32) + + ms.set_context(mode=0) + x = ms.Tensor(x) + conv2d = Conv2_5d(C, cout, 3) + + y = conv2d(x) + + print(y.shape) + +def test_resnetblock(): + in_shape = (B, C, T, H, W) = (1, 64, 4, 32, 32) + cout = C + x = np.random.normal(size=in_shape).astype(np.float32) + 
+ rb = ResnetBlock( + in_channels=C, + out_channels=cout, + dropout=0., + ) + + ms.set_context(mode=0) + x = ms.Tensor(x) + y = rb(x) + + print(y.shape) + print(y.mean(), y.std()) + + +def test_spatial_attn(): + in_shape = (B, C, T, H, W) = (1, 64, 4, 32, 32) + cout = C + x = np.random.normal(size=in_shape).astype(np.float32) + + # TODO: compare time cost for v1 and v2 + # sa = SpatialAttnBlock(C) + sa = SpatialAttnBlockV2(C) + + ms.set_context(mode=0) + + x = ms.Tensor(x) + y = sa(x) + + print(y.shape) + print(y.mean(), y.std()) + + +def test_temporal_attn(): + in_shape = (B, C, T, H, W) = (1, 64, 4, 32, 32) + cout = C + x = np.random.normal(size=in_shape).astype(np.float32) + + # TODO: compare time cost for v1 and v2 + ta = TemporalAttnBlock(C) + + ms.set_context(mode=0) + + x = ms.Tensor(x) + y = ta(x) + + print(y.shape) + print(y.mean(), y.std()) + +def test_spatial_downsample(): + # in_shape = (B, C, T, H, W) = (1, 64, 1, 32, 32) + in_shape = (B, C, T, H, W) = (1, 64, 4, 32, 32) + x = np.random.normal(size=in_shape).astype(np.float32) + sd = SpatialDownsample(C, True) + + x = ms.Tensor(x) + y = sd(x) + + print(y.shape) + + +def test_temporal_downsample(): + # in_shape = (B, C, T, H, W) = (1, 64, 1, 32, 32) + in_shape = (B, C, T, H, W) = (1, 64, 1, 32, 32) + x = np.random.normal(size=in_shape).astype(np.float32) + td = TemporalDownsample(C) + + print(x[0, 0, :, 0, 0]) + x = ms.Tensor(x) + y = td(x) + + print(y[0, 0, :, 0, 0]) + print(y.shape) + + + +def test_temporal_upsample(): + # in_shape = (B, C, T, H, W) = (1, 64, 1, 32, 32) + in_shape = (B, C, T, H, W) = (1, 64, 4, 32, 32) + x = np.random.normal(size=in_shape).astype(np.float32) + tu = TemporalUpsample(C) + + print(x[0, 0, :, 0, 0]) + x = ms.Tensor(x) + y = tu(x) + + print(y[0, 0, :, 0, 0]) + print(y.shape) + + +def test_encoder(): + # in_shape = (B, C, T, H, W) = (1, 64, 1, 32, 32) + in_shape = (B, C, T, H, W) = (1, 3, 1, 64, 64) + x = np.random.normal(size=in_shape).astype(np.float32) + enc = Encoder(**SDXL_CONFIG) + + x = ms.Tensor(x) + y = enc(x) + + print(y.shape) + + + +if __name__ == "__main__": + # test_conv25d() + # test_resnetblock() + # test_spatial_attn() + # test_temporal_attn() + # test_spatial_downsample() + # test_temporal_downsample() + # test_temporal_upsample() + test_encoder() + + From 3676a9f9051b3b2583f139d6680d2c621b2e19c6 Mon Sep 17 00:00:00 2001 From: Samit <285365963@qq.com> Date: Fri, 25 Oct 2024 10:26:11 +0800 Subject: [PATCH 007/122] useless change --- .../animatediff/ad/models/diffusion/ddpm.py | 5 +- .../opensora_hpcai/opensora/models/vae/vae.py | 60 ++++++++++--------- 2 files changed, 35 insertions(+), 30 deletions(-) diff --git a/examples/animatediff/ad/models/diffusion/ddpm.py b/examples/animatediff/ad/models/diffusion/ddpm.py index 8010e347d4..ae23a9d6ca 100644 --- a/examples/animatediff/ad/models/diffusion/ddpm.py +++ b/examples/animatediff/ad/models/diffusion/ddpm.py @@ -365,7 +365,10 @@ def construct(self, x: ms.Tensor, text_tokens: ms.Tensor, control=None, **kwargs """ # 1. get image/video latents z using vae - z = self.get_latents(x) + if self.emb_cache: + z = x + else: + z = self.get_latents(x) # 2. 
sample timestep and add noise to latents t = self.uniform_int( diff --git a/examples/opensora_hpcai/opensora/models/vae/vae.py b/examples/opensora_hpcai/opensora/models/vae/vae.py index d846d2fdac..d26b89d4e5 100644 --- a/examples/opensora_hpcai/opensora/models/vae/vae.py +++ b/examples/opensora_hpcai/opensora/models/vae/vae.py @@ -19,7 +19,7 @@ _logger = logging.getLogger(__name__) SD_CONFIG = { "double_z": True, - "z_channels": 4, + "z_channels": 4, # TODO: set 16 "resolution": 256, "in_channels": 3, "out_ch": 3, @@ -33,34 +33,6 @@ SDXL_CONFIG.update({"resolution": 512}) -class AutoencoderKL(AutoencoderKL_SD): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.split = get_split_op() - - def init_from_ckpt(self, path, ignore_keys=list()): - if not os.path.exists(path): - raise ValueError( - "Maybe download failed. Please download the VAE encoder from https://huggingface.co/stabilityai/sd-vae-ft-ema" - ) - param_dict = ms.load_checkpoint(path) - param_not_load, ckpt_not_load = ms.load_param_into_net(self, param_dict, strict_load=True) - if param_not_load or ckpt_not_load: - _logger.warning( - f"{param_not_load} in network is not loaded or {ckpt_not_load} in checkpoint is not loaded!" - ) - - def encode_with_moments_output(self, x): - """For latent caching usage""" - h = self.encoder(x) - moments = self.quant_conv(h) - mean, logvar = self.split(moments, moments.shape[1] // 2, 1) - logvar = ops.clip_by_value(logvar, -30.0, 20.0) - std = self.exp(0.5 * logvar) - - return mean, std - - class VideoAutoencoderKL(nn.Cell): """ Spatial VAE @@ -483,3 +455,33 @@ def OpenSoraVAE_V1_2( pu, cu = ms.load_param_into_net(model.spatial_vae, sd, strict_load=False) return model + + +class AutoencoderKL(AutoencoderKL_SD): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.split = get_split_op() + + def init_from_ckpt(self, path, ignore_keys=list()): + if not os.path.exists(path): + raise ValueError( + "Maybe download failed. Please download the VAE encoder from https://huggingface.co/stabilityai/sd-vae-ft-ema" + ) + param_dict = ms.load_checkpoint(path) + param_not_load, ckpt_not_load = ms.load_param_into_net(self, param_dict, strict_load=True) + if param_not_load or ckpt_not_load: + _logger.warning( + f"{param_not_load} in network is not loaded or {ckpt_not_load} in checkpoint is not loaded!" 
+ ) + + def encode_with_moments_output(self, x): + """For latent caching usage""" + h = self.encoder(x) + moments = self.quant_conv(h) + mean, logvar = self.split(moments, moments.shape[1] // 2, 1) + logvar = ops.clip_by_value(logvar, -30.0, 20.0) + std = self.exp(0.5 * logvar) + + return mean, std + + From a5655fbc3a0b9e0306d79f7bcdebe3b585245685 Mon Sep 17 00:00:00 2001 From: Mike Cheung Date: Tue, 29 Oct 2024 10:53:41 +0800 Subject: [PATCH 008/122] continue work from llama3_movie_pr_20241029 --- examples/moviegen/moviegen/__init__.py | 1 + examples/moviegen/moviegen/models/__init__.py | 2 + .../moviegen/models/llama/__init__.py | 1 + .../moviegen/models/llama/activation.py | 28 + .../moviegen/moviegen/models/llama/block.py | 512 +++++++++++++++ .../moviegen/moviegen/models/llama/network.py | 593 ++++++++++++++++++ .../moviegen/models/text_encoders/__init__.py | 1 + .../models/text_encoders/text_projector.py | 43 ++ .../moviegen/moviegen/parallel/__init__.py | 2 + examples/moviegen/moviegen/parallel/layers.py | 398 ++++++++++++ .../moviegen/parallel/parallel_states.py | 39 ++ .../moviegen/moviegen/pipelines/__init__.py | 1 + .../moviegen/pipelines/train_pipeline.py | 59 ++ .../moviegen/moviegen/schedulers/__init__.py | 1 + .../moviegen/schedulers/rectified_flow.py | 147 +++++ .../parallel/run_test_llama3_parallel.sh | 13 + .../run_test_llama3_parallel_block.sh | 13 + .../run_test_llama3_parallel_layer.sh | 13 + .../tests/parallel/test_llama3_parallel.py | 107 ++++ .../parallel/test_llama3_parallel_block.py | 107 ++++ .../parallel/test_llama3_parallel_layer.py | 125 ++++ examples/moviegen/tests/parallel/utils.py | 32 + .../moviegen/tests/ut/test_byt5_pynative.py | 85 +++ .../moviegen/tests/ut/test_llama3_forward.py | 16 + examples/moviegen/tests/ut/test_rflow.py | 27 + .../moviegen/tests/ut/test_ul2_pynative.py | 83 +++ .../moviegen/tools/download_convert_st.py | 323 ++++++++++ .../opensora/acceleration/parallel_states.py | 42 +- .../datasets/video_dataset_refactored.py | 26 +- .../opensora/models/stdit/__init__.py | 1 + .../opensora/models/stdit/stdit_llama3.py | 87 +++ .../opensora/models/text_encoder/t5.py | 15 +- .../opensora/pipelines/infer_pipeline.py | 2 +- .../opensora/pipelines/train_pipeline.py | 4 + .../opensora/schedulers/rectified_flow.py | 2 + examples/opensora_hpcai/scripts/args_train.py | 22 +- examples/opensora_hpcai/scripts/infer_t5.py | 56 +- examples/opensora_hpcai/scripts/inference.py | 95 ++- examples/opensora_hpcai/scripts/train.py | 61 +- mindone/trainers/zero.py | 14 +- mindone/transformers/modeling_utils.py | 2 +- mindone/transformers/models/t5/modeling_t5.py | 3 +- 42 files changed, 3133 insertions(+), 71 deletions(-) create mode 100644 examples/moviegen/moviegen/__init__.py create mode 100644 examples/moviegen/moviegen/models/__init__.py create mode 100644 examples/moviegen/moviegen/models/llama/__init__.py create mode 100644 examples/moviegen/moviegen/models/llama/activation.py create mode 100644 examples/moviegen/moviegen/models/llama/block.py create mode 100644 examples/moviegen/moviegen/models/llama/network.py create mode 100644 examples/moviegen/moviegen/models/text_encoders/__init__.py create mode 100644 examples/moviegen/moviegen/models/text_encoders/text_projector.py create mode 100644 examples/moviegen/moviegen/parallel/__init__.py create mode 100644 examples/moviegen/moviegen/parallel/layers.py create mode 100644 examples/moviegen/moviegen/parallel/parallel_states.py create mode 100644 examples/moviegen/moviegen/pipelines/__init__.py create mode 
100644 examples/moviegen/moviegen/pipelines/train_pipeline.py create mode 100644 examples/moviegen/moviegen/schedulers/__init__.py create mode 100644 examples/moviegen/moviegen/schedulers/rectified_flow.py create mode 100755 examples/moviegen/tests/parallel/run_test_llama3_parallel.sh create mode 100755 examples/moviegen/tests/parallel/run_test_llama3_parallel_block.sh create mode 100755 examples/moviegen/tests/parallel/run_test_llama3_parallel_layer.sh create mode 100644 examples/moviegen/tests/parallel/test_llama3_parallel.py create mode 100644 examples/moviegen/tests/parallel/test_llama3_parallel_block.py create mode 100644 examples/moviegen/tests/parallel/test_llama3_parallel_layer.py create mode 100644 examples/moviegen/tests/parallel/utils.py create mode 100644 examples/moviegen/tests/ut/test_byt5_pynative.py create mode 100644 examples/moviegen/tests/ut/test_llama3_forward.py create mode 100644 examples/moviegen/tests/ut/test_rflow.py create mode 100644 examples/moviegen/tests/ut/test_ul2_pynative.py create mode 100644 examples/moviegen/tools/download_convert_st.py create mode 100644 examples/opensora_hpcai/opensora/models/stdit/stdit_llama3.py diff --git a/examples/moviegen/moviegen/__init__.py b/examples/moviegen/moviegen/__init__.py new file mode 100644 index 0000000000..aed4fa323c --- /dev/null +++ b/examples/moviegen/moviegen/__init__.py @@ -0,0 +1 @@ +from .models import * diff --git a/examples/moviegen/moviegen/models/__init__.py b/examples/moviegen/moviegen/models/__init__.py new file mode 100644 index 0000000000..00a5d9eab3 --- /dev/null +++ b/examples/moviegen/moviegen/models/__init__.py @@ -0,0 +1,2 @@ +from .llama import * +from .text_encoders import * diff --git a/examples/moviegen/moviegen/models/llama/__init__.py b/examples/moviegen/moviegen/models/llama/__init__.py new file mode 100644 index 0000000000..6cf34ce83b --- /dev/null +++ b/examples/moviegen/moviegen/models/llama/__init__.py @@ -0,0 +1 @@ +from .network import * diff --git a/examples/moviegen/moviegen/models/llama/activation.py b/examples/moviegen/moviegen/models/llama/activation.py new file mode 100644 index 0000000000..7b54d885a1 --- /dev/null +++ b/examples/moviegen/moviegen/models/llama/activation.py @@ -0,0 +1,28 @@ +import logging +from collections import OrderedDict + +import mindspore.mint as mint +import mindspore.nn as nn +from mindspore import Tensor + +logger = logging.getLogger(__name__) + + +class QuickGELU(nn.Cell): + def construct(self, x: Tensor): + return x * mint.sigmoid(1.702 * x) + + +class ClassInstantier(OrderedDict): + def __getitem__(self, key): + content = super().__getitem__(key) + cls, kwargs = content if isinstance(content, tuple) else (content, {}) + return cls(**kwargs) + + +ACT2CLS = { + "quick_gelu": QuickGELU, + "gelu": nn.GELU, + "silu": nn.SiLU, +} +ACT2FN = ClassInstantier(ACT2CLS) diff --git a/examples/moviegen/moviegen/models/llama/block.py b/examples/moviegen/moviegen/models/llama/block.py new file mode 100644 index 0000000000..7f10826401 --- /dev/null +++ b/examples/moviegen/moviegen/models/llama/block.py @@ -0,0 +1,512 @@ +import logging +from typing import Optional, Tuple + +from moviegen.parallel import ( + ColumnParallelLinear, + FusedColumnParallelLinear, + FusedRowParallelLinear, + GatherForwardReduceScatterBackward, + RowParallelLinear, +) + +import mindspore as ms +import mindspore.mint as mint +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore import Parameter, Tensor +from mindspore.communication import GlobalComm +from 
mindspore.ops.operations.nn_ops import FlashAttentionScore + +from .activation import ACT2FN + +logger = logging.getLogger(__name__) + + +class LlamaRMSNorm(nn.Cell): + def __init__(self, hidden_size: int, eps: float = 1e-6, dtype: ms.Type = ms.float32) -> None: + super().__init__() + self.weight = Parameter(mint.ones(hidden_size, dtype=dtype)) + self.variance_epsilon = eps + + def construct(self, hidden_states: Tensor) -> Tensor: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(ms.float32) + variance = mint.pow(hidden_states, 2) + variance = mint.mean(variance, dim=-1, keepdim=True) + hidden_states = hidden_states * mint.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +class LlamaMLP(nn.Cell): + def __init__( + self, + intermediate_size: int = 8192, + hidden_size: int = 3072, + hidden_act: str = "silu", + dtype: ms.Type = ms.float32, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.gate_proj = mint.nn.Linear(self.hidden_size, self.intermediate_size, bias=False, dtype=dtype) + self.up_proj = mint.nn.Linear(self.hidden_size, self.intermediate_size, bias=False, dtype=dtype) + self.down_proj = mint.nn.Linear(self.intermediate_size, self.hidden_size, bias=False, dtype=dtype) + self.act_fn = ACT2FN[hidden_act] + + def construct(self, hidden_state: Tensor) -> Tensor: + return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) + + +class TensorParallelLlamaMLP(nn.Cell): + def __init__( + self, + intermediate_size: int = 8192, + hidden_size: int = 3072, + hidden_act: str = "silu", + group: str = GlobalComm.WORLD_COMM_GROUP, + dtype: ms.Type = ms.float32, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.gate_proj = ColumnParallelLinear( + self.hidden_size, self.intermediate_size, bias=False, gather_output=False, group=group, dtype=dtype + ) + self.up_proj = ColumnParallelLinear( + self.hidden_size, self.intermediate_size, bias=False, gather_output=False, group=group, dtype=dtype + ) + self.down_proj = RowParallelLinear( + self.intermediate_size, self.hidden_size, bias=False, input_is_parallel=True, group=group, dtype=dtype + ) + self.act_fn = ACT2FN[hidden_act] + + def construct(self, hidden_state: Tensor) -> Tensor: + return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) + + def load_weight_from_non_parallel_cell(self, target: LlamaMLP): + self.gate_proj.load_weight_from_non_parallel_cell(target.gate_proj) + self.up_proj.load_weight_from_non_parallel_cell(target.up_proj) + self.down_proj.load_weight_from_non_parallel_cell(target.down_proj) + + +class FusedTensorParallelLlamaMLP(nn.Cell): + def __init__( + self, + intermediate_size: int = 8192, + hidden_size: int = 3072, + hidden_act: str = "silu", + dim: int = 1, + group: str = GlobalComm.WORLD_COMM_GROUP, + dtype: ms.Type = ms.float32, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.gate_proj = FusedColumnParallelLinear( + self.hidden_size, self.intermediate_size, bias=False, gather_output=False, dim=dim, group=group, dtype=dtype + ) + self.up_proj = FusedColumnParallelLinear( + self.hidden_size, self.intermediate_size, bias=False, gather_output=False, dim=dim, group=group, dtype=dtype + ) + self.down_proj = FusedRowParallelLinear( + self.intermediate_size, + self.hidden_size, + bias=False, + 
input_is_parallel=True, + dim=dim, + group=group, + dtype=dtype, + ) + self.act_fn = ACT2FN[hidden_act] + + def construct(self, hidden_state: Tensor) -> Tensor: + return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) + + def load_weight_from_non_parallel_cell(self, target: LlamaMLP): + self.gate_proj.load_weight_from_non_parallel_cell(target.gate_proj) + self.up_proj.load_weight_from_non_parallel_cell(target.up_proj) + self.down_proj.load_weight_from_non_parallel_cell(target.down_proj) + + +def repeat_kv(hidden_states: Tensor, n_rep: int) -> Tensor: + if n_rep == 1: + return hidden_states + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + hidden_states = hidden_states[:, :, None, :, :] + hidden_states = mint.broadcast_to(hidden_states, (batch, num_key_value_heads, n_rep, slen, head_dim)) + hidden_states = ops.reshape(hidden_states, (batch, num_key_value_heads * n_rep, slen, head_dim)) + return hidden_states + + +class LlamaAttention(nn.Cell): + def __init__( + self, + hidden_size: int = 4096, + num_attention_heads: int = 32, + num_key_value_heads: int = 8, + attention_dropout: float = 0.0, + attention_bias: bool = False, + dtype: ms.Type = ms.float32, + ) -> None: + super().__init__() + + self.attention_dropout = attention_dropout + self.hidden_size = hidden_size + self.num_heads = num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + self.q_proj = mint.nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=attention_bias, dtype=dtype) + self.k_proj = mint.nn.Linear( + self.hidden_size, self.num_key_value_heads * self.head_dim, bias=attention_bias, dtype=dtype + ) + self.v_proj = mint.nn.Linear( + self.hidden_size, self.num_key_value_heads * self.head_dim, bias=attention_bias, dtype=dtype + ) + self.o_proj = mint.nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=attention_bias, dtype=dtype) + + def construct(self, hidden_states: Tensor, encoder_hidden_states: Optional[Tensor] = None) -> Tensor: + bsz, q_len, _ = hidden_states.shape + + kv_hidden_states = hidden_states if encoder_hidden_states is None else encoder_hidden_states + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(kv_hidden_states) + value_states = self.v_proj(kv_hidden_states) + + query_states = ops.reshape(query_states, (bsz, -1, self.num_heads, self.head_dim)) + query_states = mint.permute(query_states, (0, 2, 1, 3)) + + key_states = ops.reshape(key_states, (bsz, -1, self.num_key_value_heads, self.head_dim)) + key_states = mint.permute(key_states, (0, 2, 1, 3)) + + value_states = ops.reshape(value_states, (bsz, -1, self.num_key_value_heads, self.head_dim)) + value_states = mint.permute(value_states, (0, 2, 1, 3)) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + key_states = mint.permute(key_states, (0, 1, 3, 2)) + attn_weights = mint.matmul(query_states, key_states) / mint.sqrt(Tensor(self.head_dim)) + + # upcast attention to fp32 + attn_weights = attn_weights.to(ms.float32) + attn_weights = ops.softmax(attn_weights, axis=-1).to(query_states.dtype) + attn_weights = 
ops.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = mint.matmul(attn_weights, value_states) + + attn_output = mint.permute(attn_output, (0, 2, 1, 3)) + attn_output = ops.reshape(attn_output, (bsz, q_len, -1)) + attn_output = self.o_proj(attn_output) + + return attn_output + + +class ContextParallelLlamaAttention(nn.Cell): + def __init__( + self, + hidden_size: int = 4096, + num_attention_heads: int = 32, + num_key_value_heads: int = 8, + attention_dropout: float = 0.0, + attention_bias: bool = False, + group: str = GlobalComm.WORLD_COMM_GROUP, + dtype: ms.Type = ms.float32, + ) -> None: + super().__init__() + self.attention_dropout = attention_dropout + self.hidden_size = hidden_size + self.num_heads = num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + self.q_proj = mint.nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=attention_bias, dtype=dtype) + self.k_proj = mint.nn.Linear( + self.hidden_size, self.num_key_value_heads * self.head_dim, bias=attention_bias, dtype=dtype + ) + self.v_proj = mint.nn.Linear( + self.hidden_size, self.num_key_value_heads * self.head_dim, bias=attention_bias, dtype=dtype + ) + self.o_proj = mint.nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=attention_bias, dtype=dtype) + + self.gather_forward_reduce_scatter_backward = GatherForwardReduceScatterBackward(dim=1, group=group) + + def construct(self, hidden_states: Tensor, encoder_hidden_states: Optional[Tensor] = None) -> Tensor: + bsz, q_len, _ = hidden_states.shape + + kv_hidden_states = hidden_states if encoder_hidden_states is None else encoder_hidden_states + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(kv_hidden_states) + value_states = self.v_proj(kv_hidden_states) + + key_states = self.gather_forward_reduce_scatter_backward(key_states) + value_states = self.gather_forward_reduce_scatter_backward(value_states) + + query_states = ops.reshape(query_states, (bsz, -1, self.num_heads, self.head_dim)) + query_states = mint.permute(query_states, (0, 2, 1, 3)) + + key_states = ops.reshape(key_states, (bsz, -1, self.num_key_value_heads, self.head_dim)) + key_states = mint.permute(key_states, (0, 2, 1, 3)) + + value_states = ops.reshape(value_states, (bsz, -1, self.num_key_value_heads, self.head_dim)) + value_states = mint.permute(value_states, (0, 2, 1, 3)) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + key_states = mint.permute(key_states, (0, 1, 3, 2)) + attn_weights = mint.matmul(query_states, key_states) / mint.sqrt(Tensor(self.head_dim)) + + # upcast attention to fp32 + attn_weights = attn_weights.to(ms.float32) + attn_weights = ops.softmax(attn_weights, axis=-1).to(query_states.dtype) + attn_weights = ops.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = mint.matmul(attn_weights, value_states) + + attn_output = mint.permute(attn_output, (0, 2, 1, 3)) + attn_output = ops.reshape(attn_output, (bsz, q_len, -1)) + attn_output = self.o_proj(attn_output) + + return attn_output + + +class LlamaFlashAttention(LlamaAttention): + def 
__init__( + self, + hidden_size: int = 4096, + num_attention_heads: int = 32, + num_key_value_heads: int = 8, + attention_dropout: float = 0.0, + attention_bias: bool = False, + dtype: ms.Type = ms.float32, + ) -> None: + super().__init__( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + attention_dropout=attention_dropout, + attention_bias=attention_bias, + dtype=dtype, + ) + self.flash_attention = FlashAttentionScore( + self.num_heads, keep_prob=1 - self.attention_dropout, scale_value=self.head_dim**-0.5, input_layout="BSND" + ) + + def construct(self, hidden_states: Tensor, encoder_hidden_states: Optional[Tensor] = None) -> Tensor: + bsz, q_len, _ = hidden_states.shape + + kv_hidden_states = hidden_states if encoder_hidden_states is None else encoder_hidden_states + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(kv_hidden_states) + value_states = self.v_proj(kv_hidden_states) + + query_states = ops.reshape(query_states, (bsz, -1, self.num_heads, self.head_dim)) + query_states = mint.permute(query_states, (0, 2, 1, 3)) + + key_states = ops.reshape(key_states, (bsz, -1, self.num_key_value_heads, self.head_dim)) + key_states = mint.permute(key_states, (0, 2, 1, 3)) + + value_states = ops.reshape(value_states, (bsz, -1, self.num_key_value_heads, self.head_dim)) + value_states = mint.permute(value_states, (0, 2, 1, 3)) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + # Reshape to the expected shape and dtype for Flash Attention + query_states = mint.permute(query_states, (0, 2, 1, 3)) + key_states = mint.permute(key_states, (0, 2, 1, 3)) + value_states = mint.permute(value_states, (0, 2, 1, 3)) + + _, _, _, attn_output = self.flash_attention(query_states, key_states, value_states, None, None, None, None) + attn_output = ops.reshape(attn_output, (bsz, q_len, -1)) + attn_output = self.o_proj(attn_output) + + return attn_output + + +class ContextParallelLlamaFlashAttention(ContextParallelLlamaAttention): + def __init__( + self, + hidden_size: int = 4096, + num_attention_heads: int = 32, + num_key_value_heads: int = 8, + attention_dropout: float = 0.0, + attention_bias: bool = False, + group: str = GlobalComm.WORLD_COMM_GROUP, + dtype: ms.Type = ms.float32, + ) -> None: + super().__init__( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + attention_dropout=attention_dropout, + attention_bias=attention_bias, + group=group, + dtype=dtype, + ) + self.flash_attention = FlashAttentionScore( + self.num_heads, keep_prob=1 - self.attention_dropout, scale_value=self.head_dim**-0.5, input_layout="BSND" + ) + + def construct(self, hidden_states: Tensor, encoder_hidden_states: Optional[Tensor] = None) -> Tensor: + bsz, q_len, _ = hidden_states.shape + + kv_hidden_states = hidden_states if encoder_hidden_states is None else encoder_hidden_states + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(kv_hidden_states) + value_states = self.v_proj(kv_hidden_states) + + key_states = self.gather_forward_reduce_scatter_backward(key_states) + value_states = self.gather_forward_reduce_scatter_backward(value_states) + + query_states = ops.reshape(query_states, (bsz, -1, self.num_heads, self.head_dim)) + query_states = mint.permute(query_states, (0, 2, 1, 3)) + + key_states = ops.reshape(key_states, (bsz, -1, self.num_key_value_heads, self.head_dim)) + 
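The repeat_kv helper shared by the attention variants above implements grouped-query attention: each key/value head is reused by num_heads / num_key_value_heads query heads. A small numpy sketch of the same broadcast-and-reshape, using hypothetical shapes:

import numpy as np

B, num_kv_heads, S, head_dim, n_rep = 1, 2, 16, 8, 4   # 2 KV heads shared by 8 query heads
kv = np.random.randn(B, num_kv_heads, S, head_dim).astype(np.float32)
kv_rep = np.broadcast_to(kv[:, :, None], (B, num_kv_heads, n_rep, S, head_dim))
kv_rep = kv_rep.reshape(B, num_kv_heads * n_rep, S, head_dim)
print(kv_rep.shape)                                        # (1, 8, 16, 8)
print(np.array_equal(kv_rep[0, 0], kv_rep[0, n_rep - 1]))  # True: heads 0..3 are copies of KV head 0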
key_states = mint.permute(key_states, (0, 2, 1, 3)) + + value_states = ops.reshape(value_states, (bsz, -1, self.num_key_value_heads, self.head_dim)) + value_states = mint.permute(value_states, (0, 2, 1, 3)) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + # Reshape to the expected shape and dtype for Flash Attention + query_states = mint.permute(query_states, (0, 2, 1, 3)) + key_states = mint.permute(key_states, (0, 2, 1, 3)) + value_states = mint.permute(value_states, (0, 2, 1, 3)) + + _, _, _, attn_output = self.flash_attention(query_states, key_states, value_states, None, None, None, None) + attn_output = ops.reshape(attn_output, (bsz, q_len, -1)) + attn_output = self.o_proj(attn_output) + + return attn_output + + +class PatchEmbed3D(nn.Cell): + def __init__( + self, + patch_size: Tuple[int, int, int] = (1, 2, 2), + in_channels: int = 8, + hidden_size: int = 4096, + dtype: ms.Type = ms.float32, + ) -> None: + super().__init__() + self.patch_size = patch_size + self.proj = nn.Conv3d( + in_channels, + hidden_size, + kernel_size=self.patch_size, + stride=self.patch_size, + pad_mode="pad", + has_bias=False, + dtype=dtype, + ) + + def construct(self, x: Tensor) -> Tensor: + _, t, _, h, w = x.shape + assert t % self.patch_size[0] == 0 + assert h % self.patch_size[1] == 0 + assert w % self.patch_size[2] == 0 + + x = mint.permute(x, (0, 2, 1, 3, 4)) + x = self.proj(x) # (B C T H W) + x = mint.flatten(x, start_dim=2) + x = mint.permute(x, (0, 2, 1)) + return x + + +class LinearPatchEmbed3D(nn.Cell): + def __init__( + self, + patch_size: Tuple[int, int, int] = (1, 2, 2), + in_channels: int = 8, + hidden_size: int = 4096, + dtype: ms.Type = ms.float32, + ) -> None: + super().__init__() + self.patch_size = patch_size + self.proj = mint.nn.Linear( + patch_size[0] * patch_size[1] * patch_size[2] * in_channels, hidden_size, bias=False, dtype=dtype + ) + + def construct(self, x: Tensor) -> Tensor: + b, t, c, h, w = x.shape + assert t % self.patch_size[0] == 0 + assert h % self.patch_size[1] == 0 + assert w % self.patch_size[2] == 0 + + p0, p1, p2 = self.patch_size[0], self.patch_size[1], self.patch_size[2] + nt, nh, nw = t // p0, h // p1, w // p2 + x = ops.reshape(x, (b, nt, p0, c, nh, p1, nw, p2)) + x = mint.permute(x, (0, 1, 4, 6, 3, 2, 5, 7)) # (B, nt, nh, nw, c, p0, p1, p2) + x = ops.reshape(x, (b, nt * nh * nw, -1)) + x = self.proj(x) + return x + + +class TimestepEmbedder(nn.Cell): + def __init__( + self, + hidden_size: int, + frequency_embedding_size: int = 256, + hidden_act: str = "silu", + dtype: ms.Type = ms.float32, + ) -> None: + super().__init__() + self.mlp = nn.SequentialCell( + mint.nn.Linear(frequency_embedding_size, hidden_size, bias=False, dtype=dtype), + ACT2FN[hidden_act], + mint.nn.Linear(hidden_size, hidden_size, bias=False, dtype=dtype), + ) + self.frequency_embedding_size = frequency_embedding_size + self.dtype = dtype + + @staticmethod + def timestep_embedding(t: Tensor, dim: int, max_period: int = 10000) -> Tensor: + half = dim // 2 + freqs = mint.exp(-mint.log(Tensor(max_period)) * mint.arange(start=0, end=half, dtype=ms.float32) / half) + args = ops.unsqueeze(t, 1).to(ms.float32) * ops.unsqueeze(freqs, 0) + embedding = mint.cat([mint.cos(args), mint.sin(args)], dim=-1) + if dim % 2: + embedding = mint.cat([embedding, mint.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def construct(self, t: Tensor) -> Tensor: + t_freq = self.timestep_embedding(t, 
self.frequency_embedding_size) + t_emb = self.mlp(t_freq.to(self.dtype)) + return t_emb + + +class CaptionEmbedder(nn.Cell): + def __init__( + self, + in_channels: int, + hidden_size: int, + eps: float = 1e-6, + dtype: ms.Type = ms.float32, + ) -> None: + super().__init__() + self.proj = nn.SequentialCell( + mint.nn.Linear(in_channels, hidden_size, bias=False, dtype=dtype), + LlamaRMSNorm((hidden_size,), eps=eps, dtype=dtype), + ) + + def construct(self, caption: Tensor) -> Tensor: + caption_emb = self.proj(caption) + return caption_emb diff --git a/examples/moviegen/moviegen/models/llama/network.py b/examples/moviegen/moviegen/models/llama/network.py new file mode 100644 index 0000000000..f1a3daf485 --- /dev/null +++ b/examples/moviegen/moviegen/models/llama/network.py @@ -0,0 +1,593 @@ +from __future__ import annotations + +from typing import Literal, Optional, Tuple, Union + +import numpy as np +from moviegen.parallel import GatherForwardSplitBackward, SplitForwardGatherBackward +from moviegen.parallel.parallel_states import get_model_parallel_group + +import mindspore as ms +import mindspore.mint as mint +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore import Parameter, Tensor, load_checkpoint, load_param_into_net +from mindspore.communication import GlobalComm, get_group_size + +from mindone.models.utils import normal_, zeros_ + +from .activation import ACT2FN +from .block import ( + CaptionEmbedder, + ContextParallelLlamaAttention, + ContextParallelLlamaFlashAttention, + FusedTensorParallelLlamaMLP, + LinearPatchEmbed3D, + LlamaAttention, + LlamaFlashAttention, + LlamaMLP, + LlamaRMSNorm, + PatchEmbed3D, + TensorParallelLlamaMLP, + TimestepEmbedder, +) + +__all__ = ["LlamaModel", "llama3_1B", "llama3_5B", "llama3_30B"] + +Llama_ATTENTION_CLASSES = { + "eager": LlamaAttention, + "flash_attention": LlamaFlashAttention, +} + +CONTEXT_PARALLEL_Llama_ATTENTION_CLASSES = { + "eager": ContextParallelLlamaAttention, + "flash_attention": ContextParallelLlamaFlashAttention, +} + + +def t2i_modulate(x: Tensor, shift: Tensor, scale: Tensor) -> Tensor: + return x * (1 + scale) + shift + + +class LlamaDecoderLayer(nn.Cell): + def __init__( + self, + hidden_size: int = 4096, + intermediate_size: int = 14336, + num_attention_heads: int = 32, + num_key_value_heads: int = 8, + rms_norm_eps: float = 1e-5, + attention_dropout: float = 0.0, + attention_bias: bool = False, + hidden_act: str = "silu", + attn_implementation: Literal["eager", "flash_attention"] = "eager", + dtype: ms.Type = ms.float32, + ) -> None: + super().__init__() + + self.self_attn = Llama_ATTENTION_CLASSES[attn_implementation]( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + attention_dropout=attention_dropout, + attention_bias=attention_bias, + dtype=dtype, + ) + + self.cross_attn = Llama_ATTENTION_CLASSES[attn_implementation]( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + attention_dropout=attention_dropout, + attention_bias=attention_bias, + dtype=dtype, + ) + + self.mlp = LlamaMLP( + intermediate_size=intermediate_size, hidden_size=hidden_size, hidden_act=hidden_act, dtype=dtype + ) + + self.scale_shift_table = Parameter(Tensor(np.random.randn(1, 6, hidden_size), dtype=dtype) / hidden_size**0.5) + self.input_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype) + self.post_attention_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype) + + def 
construct( + self, + hidden_states: Tensor, + encoder_hidden_states: Tensor, + modulation_parameters: Tensor, + position_embedding: Tensor, + ) -> Tensor: + B = hidden_states.shape[0] + + # 3.1.3 Positional Embedding + hidden_states = hidden_states + position_embedding + + # 3.1.3 Adaptive Layer Norm + modulation_parameters = self.scale_shift_table.to(hidden_states.dtype) + modulation_parameters.reshape(B, 6, -1) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mint.chunk(modulation_parameters, 6, dim=1) + + # Self Attention (Bi-Directional Attention) + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = t2i_modulate(hidden_states, shift_msa, scale_msa) + hidden_states = self.self_attn(hidden_states) + hidden_states = gate_msa * hidden_states + hidden_states = residual + hidden_states + + # 3.1.3 Cross Attention + residual = hidden_states + hidden_states = self.cross_attn(hidden_states, encoder_hidden_states=encoder_hidden_states) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = t2i_modulate(hidden_states, shift_mlp, scale_mlp) + hidden_states = self.mlp(hidden_states) + hidden_states = gate_mlp * hidden_states + hidden_states = residual + hidden_states + + return hidden_states + + +class ModelParallelLlamaDecoderLayer(nn.Cell): + def __init__( + self, + hidden_size: int = 4096, + intermediate_size: int = 14336, + num_attention_heads: int = 32, + num_key_value_heads: int = 8, + rms_norm_eps: float = 1e-5, + attention_dropout: float = 0.0, + attention_bias: bool = False, + hidden_act: str = "silu", + attn_implementation: Literal["eager", "flash_attention"] = "eager", + fused_tensor_parallel: bool = True, + group: str = GlobalComm.WORLD_COMM_GROUP, + dtype: ms.Type = ms.float32, + ) -> None: + super().__init__() + + # 3.1.6 Context Parallelism + self.self_attn = CONTEXT_PARALLEL_Llama_ATTENTION_CLASSES[attn_implementation]( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + attention_dropout=attention_dropout, + attention_bias=attention_bias, + dtype=dtype, + ) + + self.cross_attn = Llama_ATTENTION_CLASSES[attn_implementation]( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + attention_dropout=attention_dropout, + attention_bias=attention_bias, + dtype=dtype, + ) + + # 3.1.6 Tensor Parallelism + if fused_tensor_parallel: + self.mlp = FusedTensorParallelLlamaMLP( + intermediate_size=intermediate_size, + hidden_size=hidden_size, + hidden_act=hidden_act, + dim=1, + group=group, + dtype=dtype, + ) + else: + self.mlp = TensorParallelLlamaMLP( + intermediate_size=intermediate_size, + hidden_size=hidden_size, + hidden_act=hidden_act, + group=group, + dtype=dtype, + ) + + self.scale_shift_table = Parameter(Tensor(np.random.randn(1, 6, hidden_size), dtype=dtype) / hidden_size**0.5) + self.input_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype) + self.post_attention_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype) + + if not fused_tensor_parallel: + self.split_forward_gather_backward = SplitForwardGatherBackward(dim=1, grad_scale="down", group=group) + self.gather_forward_split_backward = GatherForwardSplitBackward(dim=1, grad_scale="up", group=group) + else: + self.split_forward_gather_backward = nn.Identity() + self.gather_forward_split_backward = 
nn.Identity() + + def construct( + self, + hidden_states: Tensor, + encoder_hidden_states: Tensor, + modulation_parameters: Tensor, + position_embedding: Tensor, + ) -> Tensor: + B = hidden_states.shape[0] + + # 3.1.3 Positional Embedding + hidden_states = hidden_states + position_embedding + + # 3.1.3 Adaptive Layer Norm + modulation_parameters = self.scale_shift_table.to(hidden_states.dtype) + modulation_parameters.reshape(B, 6, -1) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mint.chunk(modulation_parameters, 6, dim=1) + + # Self Attention (Bi-Directional Attention) + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = t2i_modulate(hidden_states, shift_msa, scale_msa) + hidden_states = self.self_attn(hidden_states) + hidden_states = gate_msa * hidden_states + hidden_states = residual + hidden_states + + # 3.1.3 Cross Attention + residual = hidden_states + hidden_states = self.cross_attn(hidden_states, encoder_hidden_states=encoder_hidden_states) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = t2i_modulate(hidden_states, shift_mlp, scale_mlp) + hidden_states = self.gather_forward_split_backward(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = self.split_forward_gather_backward(hidden_states) + hidden_states = gate_mlp * hidden_states + hidden_states = residual + hidden_states + + return hidden_states + + +class LlamaFinalLayer(nn.Cell): + def __init__( + self, + hidden_size: int = 4096, + patch_size: Tuple[int, int, int] = (1, 2, 2), + out_channels: int = 8, + rms_norm_eps: float = 1e-5, + dtype: ms.Type = ms.float32, + ) -> None: + super().__init__() + self.input_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype) + self.proj = nn.Dense( + hidden_size, patch_size[0] * patch_size[1] * patch_size[2] * out_channels, has_bias=False, dtype=dtype + ) + self.scale_shift_table = Parameter(Tensor(np.random.randn(2, hidden_size), dtype=dtype) / hidden_size**0.5) + + def construct(self, hidden_states: Tensor, timestep_embedding: Tensor): + shift, scale = mint.chunk( + ops.unsqueeze(self.scale_shift_table, 0) + ops.unsqueeze(timestep_embedding, 1), 2, dim=1 + ) + hidden_states = t2i_modulate(self.input_layernorm(hidden_states), shift, scale) + hidden_states = self.proj(hidden_states) + return hidden_states + + +class LlamaModel(nn.Cell): + def __init__( + self, + in_channels: int = 8, + out_channels: Optional[int] = None, + hidden_size: int = 4096, + intermediate_size: int = 14336, + num_attention_heads: int = 32, + num_hidden_layers: int = 32, + num_key_value_heads: int = 8, + rms_norm_eps: float = 1e-5, + attention_dropout: float = 0.0, + attention_bias: bool = False, + hidden_act: str = "silu", + initializer_range: float = 0.02, + patch_size: Tuple[int, int, int] = (1, 2, 2), + max_length: Tuple[int, int, int] = (128, 64, 64), + caption_channels: int = 4096, + attn_implementation: Literal["eager", "flash_attention"] = "eager", + gradient_checkpointing: bool = False, + use_linear_patch_embedder: bool = True, + model_parallelism: bool = False, + fused_tensor_parallel: bool = True, + post_init_weight: bool = True, + dtype: ms.Type = ms.float32, + ) -> None: + super().__init__() + self.patch_size = patch_size + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + self.hidden_size = hidden_size + 
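For reference, the adaptive layer norm used by the decoder layers above adds the per-layer scale_shift_table to the shared timestep modulation, splits the result into shift/scale/gate triples, and applies t2i_modulate (x * (1 + scale) + shift) before each branch, gating the branch output before the residual add. A minimal numpy sketch with hypothetical sizes:

import numpy as np

hidden, tokens = 8, 4
x = np.random.randn(1, tokens, hidden).astype(np.float32)               # token stream
table = np.random.randn(1, 6, hidden).astype(np.float32) / hidden**0.5  # per-layer scale_shift_table
t_mod = np.random.randn(1, 6 * hidden).astype(np.float32)               # adaLN_modulation(timestep_embedding)
params = table + t_mod.reshape(1, 6, -1)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = np.split(params, 6, axis=1)
attn_in = x * (1 + scale_msa) + shift_msa  # t2i_modulate before self-attention
# The attention output would then be multiplied by gate_msa and added back to x;
# the same shift/scale/gate pattern is repeated for the MLP branch.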
self.num_attention_heads = num_attention_heads
+        self.num_key_value_heads = num_key_value_heads
+        self.rms_norm_eps = rms_norm_eps
+        self.max_length = max_length
+        self.model_parallelism = model_parallelism
+        self._dtype = dtype
+        mp_group = get_model_parallel_group()
+
+        if self.model_parallelism:
+            self.layers = nn.CellList(
+                [
+                    ModelParallelLlamaDecoderLayer(
+                        hidden_size=self.hidden_size,
+                        intermediate_size=intermediate_size,
+                        num_attention_heads=self.num_attention_heads,
+                        num_key_value_heads=self.num_key_value_heads,
+                        rms_norm_eps=rms_norm_eps,
+                        attention_dropout=attention_dropout,
+                        attention_bias=attention_bias,
+                        hidden_act=hidden_act,
+                        attn_implementation=attn_implementation,
+                        fused_tensor_parallel=fused_tensor_parallel,
+                        group=mp_group,
+                        dtype=dtype,
+                    )
+                    for _ in range(num_hidden_layers)
+                ]
+            )
+        else:
+            self.layers = nn.CellList(
+                [
+                    LlamaDecoderLayer(
+                        hidden_size=self.hidden_size,
+                        intermediate_size=intermediate_size,
+                        num_attention_heads=self.num_attention_heads,
+                        num_key_value_heads=self.num_key_value_heads,
+                        rms_norm_eps=rms_norm_eps,
+                        attention_dropout=attention_dropout,
+                        attention_bias=attention_bias,
+                        hidden_act=hidden_act,
+                        attn_implementation=attn_implementation,
+                        dtype=dtype,
+                    )
+                    for _ in range(num_hidden_layers)
+                ]
+            )
+
+        self.final_layer = LlamaFinalLayer(
+            hidden_size=self.hidden_size,
+            patch_size=self.patch_size,
+            out_channels=self.out_channels,
+            rms_norm_eps=rms_norm_eps,
+            dtype=dtype,
+        )
+
+        self.pos_embedding_table_t = nn.Embedding(max_length[0], self.hidden_size, dtype=dtype)
+        self.pos_embedding_table_h = nn.Embedding(max_length[1], self.hidden_size, dtype=dtype)
+        self.pos_embedding_table_w = nn.Embedding(max_length[2], self.hidden_size, dtype=dtype)
+
+        if use_linear_patch_embedder:
+            self.latent_embedder = LinearPatchEmbed3D(self.patch_size, self.in_channels, self.hidden_size, dtype=dtype)
+        else:
+            self.latent_embedder = PatchEmbed3D(self.patch_size, self.in_channels, self.hidden_size, dtype=dtype)
+
+        self.timestep_embedder = TimestepEmbedder(self.hidden_size, dtype=dtype)
+        self.adaLN_modulation = nn.SequentialCell(
+            ACT2FN[hidden_act], mint.nn.Linear(self.hidden_size, 6 * self.hidden_size, bias=False, dtype=dtype)
+        )
+
+        # TODO: drop this
+        self.caption_embedder = CaptionEmbedder(caption_channels, self.hidden_size, eps=rms_norm_eps, dtype=dtype)
+
+        if self.model_parallelism:
+            self.group_size = get_group_size(mp_group)
+            self.split_forward_gather_backward = SplitForwardGatherBackward(dim=1, grad_scale="down", group=mp_group)
+            self.gather_forward_split_backward = GatherForwardSplitBackward(dim=1, grad_scale="up", group=mp_group)
+
+        # post-init
+        if post_init_weight:
+            self.initializer_range = initializer_range
+            self.init_weights()
+
+        # recompute
+        if gradient_checkpointing:
+            self.layers.recompute()
+
+    @property
+    def dtype(self):
+        return self._dtype
+
+    def init_weights(self):
+        std = self.initializer_range
+
+        def _init_weights(module):
+            if isinstance(module, mint.nn.Linear):
+                normal_(module.weight, mean=0.0, std=std)
+                if module.bias is not None:
+                    zeros_(module.bias)
+            elif isinstance(module, nn.Embedding):
+                normal_(module.embedding_table, mean=0.0, std=std)
+
+        self.apply(_init_weights)
+
+        # Initialize patch_embed like nn.Dense (instead of nn.Conv3d):
+        normal_(self.latent_embedder.proj.weight, mean=0.0, std=std)
+        if self.latent_embedder.proj.bias is not None:
+            zeros_(self.latent_embedder.proj.bias)
+
+        # Zero-out adaLN modulation block:
+        zeros_(self.adaLN_modulation[-1].weight)
+        if 
self.adaLN_modulation[-1].bias is not None: + zeros_(self.adaLN_modulation[-1].bias) + + # Zero-out final block as DiT does + zeros_(self.final_layer.proj.weight) + if self.final_layer.proj.bias is not None: + zeros_(self.final_layer.proj.bias) + + def learnable_position_embedding(self, latent_embedding: Tensor) -> Tensor: + # 3.1.3 + _, t, _, h, w = latent_embedding.shape + p0, p1, p2 = self.patch_size[0], self.patch_size[1], self.patch_size[2] + nt, nh, nw = t // p0, h // p1, w // p2 + + assert nt < self.max_length[0] + assert nh < self.max_length[1] + assert nw < self.max_length[2] + + t_inds = mint.arange(nt, dtype=ms.int64) + h_inds = mint.arange(nh, dtype=ms.int64) + w_inds = mint.arange(nw, dtype=ms.int64) + + position_ids = ops.meshgrid(t_inds, h_inds, w_inds, indexing="ij") + position_ids = ops.stack(position_ids, axis=-1) + position_ids = ops.reshape(position_ids, (-1, 3)) + + t_inds, h_inds, w_inds = ops.unbind(position_ids, dim=-1) + pos_embed_t = self.pos_embedding_table_t(t_inds) + pos_embed_h = self.pos_embedding_table_h(h_inds) + pos_embed_w = self.pos_embedding_table_w(w_inds) + pos_embed = pos_embed_t + pos_embed_h + pos_embed_w + pos_embed = ops.unsqueeze(pos_embed, 0) + return pos_embed + + def unpatchify(self, hidden_states: Tensor, t: int, h: int, w: int) -> Tensor: + """ + hidden_states: (N, T, patch_size[0] * patch_size[1] * patch_size[2] * C) + """ + bs = hidden_states.shape[0] + c = self.out_channels + p0, p1, p2 = self.patch_size[0], self.patch_size[1], self.patch_size[2] + nt, nh, nw = t // p0, h // p1, w // p2 + + hidden_states = ops.reshape(hidden_states, (bs, nt, nh, nw, p0, p1, p2, c)) + # bs, nt, p0, c, nh, p1, nw, p2, c + hidden_states = mint.permute(hidden_states, (0, 1, 4, 7, 2, 5, 3, 6)) + output = ops.reshape(hidden_states, (bs, nt * p0, c, nh * p1, nw * p2)) + return output + + def construct( + self, + latent_embedding: Tensor, + timestep: Tensor, + text_embedding: Tensor, + ) -> Tensor: + """ + latent_embedding: (N, T, C, H, W) tensor of inputs (latent representations of video) + timestep: (N,) tensor to indicate denoising step + text_embedding: (N, L, C') tensor of the text embedding + """ + _, t, _, h, w = latent_embedding.shape + + # create position embedding to be shared across the decoder layers + position_embedding = self.learnable_position_embedding(latent_embedding) + position_embedding = position_embedding.to(latent_embedding.dtype) + + # patchify and embed latent in transformer hidden dim. + latent_embedding = self.latent_embedder(latent_embedding) + + # 6.1.2 shared timestep embedding & modulation. 
It does not mention the detail structure, we follow PixArt-Alpha here + timestep_embedding = self.timestep_embedder(timestep) + modulation_parameters = self.adaLN_modulation(timestep_embedding) + + # 3.1.4 text embedding + text_embedding = self.caption_embedder(text_embedding) + + # main blocks + hidden_states = latent_embedding + + # 3.1.6 Sequence Parallelism Start + if self.model_parallelism: + assert hidden_states.shape[1] % self.group_size == 0 + hidden_states = self.split_forward_gather_backward(hidden_states) + position_embedding = self.split_forward_gather_backward(position_embedding) + + for decoder_layer in self.layers: + hidden_states = decoder_layer(hidden_states, text_embedding, modulation_parameters, position_embedding) + + # 3.1.6 Sequence Parallelism End + if self.model_parallelism: + hidden_states = self.gather_forward_split_backward(hidden_states) + + # final block + hidden_states = self.final_layer(hidden_states, timestep_embedding) + + # unpatchify + output = self.unpatchify(hidden_states, t, h, w) + return output + + def construct_with_cfg( + self, + latent_embedding: Tensor, + timestep: Tensor, + text_embedding: Tensor, + cfg_scale: Union[Tensor, float] = 7.5, + ) -> Tensor: + """ + latent_embedding: (2N, T, C, H, W) tensor of inputs (latent representations of video) + timestep: (2N,) tensor to indicate denoising step + text_embedding: (2N, L, C') tensor of the text embedding + cfg_scale: CFG scale + """ + model_out = self(latent_embedding, timestep, text_embedding) + cond_model_out, uncond_model_out = mint.chunk(model_out, 2, dim=0) + model_out = uncond_model_out + cfg_scale * (cond_model_out - uncond_model_out) + model_out = mint.tile(model_out, (2, 1, 1, 1, 1)) + return model_out + + def load_weight_from_non_parallel_cell(self, target: LlamaModel): + param_dict = target.parameters_dict() + + # filter tensor-parallel block + names = ["gate_proj", "up_proj", "down_proj"] + param_dict = {k: v for k, v in param_dict.items() if not any([name in k for name in names])} + load_param_into_net(self, param_dict) + + # load tensor-parallel block + for layer, target_layer in zip(self.layers, target.layers): + layer.mlp.load_weight_from_non_parallel_cell(target_layer.mlp) + + +def llama3_1B(from_pretrained=None, **kwargs): + model = LlamaModel( + attention_bias=False, + attention_dropout=0.0, + hidden_act="silu", + hidden_size=1536, + initializer_range=0.02, + intermediate_size=4096, + num_attention_heads=16, + num_hidden_layers=24, + num_key_value_heads=16, + rms_norm_eps=1e-05, + **kwargs, + ) + if from_pretrained is not None: + load_checkpoint(from_pretrained, model) + return model + + +def llama3_5B(from_pretrained=None, **kwargs): + model = LlamaModel( + attention_bias=False, + attention_dropout=0.0, + hidden_act="silu", + hidden_size=3072, + initializer_range=0.02, + intermediate_size=8192, + num_attention_heads=24, + num_hidden_layers=32, + num_key_value_heads=24, + rms_norm_eps=1e-05, + **kwargs, + ) + if from_pretrained is not None: + load_checkpoint(from_pretrained, model) + return model + + +def llama3_30B(from_pretrained=None, **kwargs): + model = LlamaModel( + attention_bias=False, + attention_dropout=0.0, + hidden_act="silu", + hidden_size=6144, + initializer_range=0.02, + intermediate_size=16384, + num_attention_heads=48, + num_hidden_layers=48, + num_key_value_heads=48, + rms_norm_eps=1e-05, + **kwargs, + ) + if from_pretrained is not None: + load_checkpoint(from_pretrained, model) + return model diff --git 
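A rough smoke test of the factory functions above; shapes follow the constructor defaults (in_channels=8, patch_size=(1, 2, 2), caption_channels=4096), and the import assumes the moviegen package is on the Python path:

import numpy as np
import mindspore as ms
from moviegen import llama3_1B

model = llama3_1B(attn_implementation="eager", dtype=ms.float32)
latent = ms.Tensor(np.random.randn(1, 4, 8, 32, 32), dtype=ms.float32)  # (N, T, C, H, W)
timestep = ms.Tensor(np.array([500]), dtype=ms.int64)                   # (N,)
text_emb = ms.Tensor(np.random.randn(1, 64, 4096), dtype=ms.float32)    # (N, L, C')
out = model(latent, timestep, text_emb)
print(out.shape)  # expected to match the latent shape: (1, 4, 8, 32, 32)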
a/examples/moviegen/moviegen/models/text_encoders/__init__.py b/examples/moviegen/moviegen/models/text_encoders/__init__.py new file mode 100644 index 0000000000..c26604e0d0 --- /dev/null +++ b/examples/moviegen/moviegen/models/text_encoders/__init__.py @@ -0,0 +1 @@ +from .text_projector import TextProjector diff --git a/examples/moviegen/moviegen/models/text_encoders/text_projector.py b/examples/moviegen/moviegen/models/text_encoders/text_projector.py new file mode 100644 index 0000000000..2b6b5b1945 --- /dev/null +++ b/examples/moviegen/moviegen/models/text_encoders/text_projector.py @@ -0,0 +1,43 @@ +from typing import Type + +import mindspore as ms +from mindspore import Tensor, mint, nn + + +class TextProjector(nn.Cell): + def __init__( + self, + ul2_in_features: int = 4096, + metaclip_in_features: int = 1280, + byt5_in_features: int = 1472, + out_features: int = 6144, + layer_norm: Type[nn.Cell] = mint.nn.LayerNorm, + norm_eps: float = 1e-5, + dtype: ms.Type = ms.float32, + ): + super().__init__() + self.ul2_projector = nn.SequentialCell( + [ + mint.nn.Linear(ul2_in_features, out_features, bias=False, dtype=dtype), + layer_norm((out_features,), eps=norm_eps, dtype=dtype), + ] + ) + self.metaclip_projector = nn.SequentialCell( + [ + mint.nn.Linear(metaclip_in_features, out_features, bias=False, dtype=dtype), + layer_norm((out_features,), eps=norm_eps, dtype=dtype), + ] + ) + self.byt5_projector = nn.SequentialCell( + [ + mint.nn.Linear(byt5_in_features, out_features, bias=False, dtype=dtype), + layer_norm((out_features,), eps=norm_eps, dtype=dtype), + ] + ) + + def construct(self, ul2_text: Tensor, metaclip_text: Tensor, byt5_text: Tensor) -> Tensor: + ul2_hidden_states = self.ul2_projector(ul2_text) + metaclip_hidden_states = self.metaclip_projector(metaclip_text) + byt5_hidden_states = self.byt5_projector(byt5_text) + + return mint.cat((ul2_hidden_states, metaclip_hidden_states, byt5_hidden_states), dim=1) diff --git a/examples/moviegen/moviegen/parallel/__init__.py b/examples/moviegen/moviegen/parallel/__init__.py new file mode 100644 index 0000000000..de133abd08 --- /dev/null +++ b/examples/moviegen/moviegen/parallel/__init__.py @@ -0,0 +1,2 @@ +from .layers import * +from .parallel_states import * diff --git a/examples/moviegen/moviegen/parallel/layers.py b/examples/moviegen/moviegen/parallel/layers.py new file mode 100644 index 0000000000..d238d47391 --- /dev/null +++ b/examples/moviegen/moviegen/parallel/layers.py @@ -0,0 +1,398 @@ +import numbers +from typing import Callable, Literal, Optional, Tuple, Union + +import mindspore as ms +import mindspore.mint as mint +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore import Tensor +from mindspore.common.initializer import Initializer +from mindspore.communication import GlobalComm, get_group_size, get_rank + +__all__ = [ + "SplitForwardGatherBackward", + "GatherForwardSplitBackward", + "GatherForwardReduceScatterBackward", + "ColumnParallelLinear", + "RowParallelLinear", + "FusedColumnParallelLinear", + "FusedRowParallelLinear", +] + + +def _communicate_along_dim(x: Tensor, dim: int, func: Callable[[Tensor], Tensor]) -> Tensor: + x = x.swapaxes(0, dim) + x = func(x) + x = x.swapaxes(dim, 0) + return x + + +def _split(x: Tensor, dim: int, rank: int, world_size: int) -> Tensor: + dim_size = x.shape[dim] + tensor_list = x.split(dim_size // world_size, axis=dim) + x = tensor_list[rank] + return x + + +class _CopyToModelParallelRegion(nn.Cell): + def __init__(self, group: str = GlobalComm.WORLD_COMM_GROUP) -> 
None: + super().__init__(auto_prefix=False) + self.reduce = ops.AllReduce(op=ops.ReduceOp.SUM, group=group) + + def construct(self, x: Tensor) -> Tensor: + return x + + def bprop(self, x: Tensor, out: Tensor, dout: Tensor) -> Tuple[Tensor]: + dout = self.reduce(dout) + return (dout,) + + +class _ReduceFromModelParallelRegion(nn.Cell): + def __init__(self, group: str = GlobalComm.WORLD_COMM_GROUP) -> None: + super().__init__(auto_prefix=False) + self.reduce = ops.AllReduce(op=ops.ReduceOp.SUM, group=group) + + def construct(self, x: Tensor) -> Tensor: + return self.reduce(x) + + def bprop(self, x: Tensor, out: Tensor, dout: Tensor) -> Tuple[Tensor]: + return (dout,) + + +class _ScatterToModelParallelRegion(nn.Cell): + def __init__(self, group: str = GlobalComm.WORLD_COMM_GROUP) -> None: + super().__init__(auto_prefix=False) + self.gather = ops.AllGather(group=group) + self.rank = get_rank(group) + self.world_size = get_group_size(group) + + def construct(self, x: Tensor) -> Tensor: + return _split(x, -1, self.rank, self.world_size) + + def bprop(self, x: Tensor, out: Tensor, dout: Tensor) -> Tuple[Tensor]: + dout = _communicate_along_dim(dout, -1, self.gather) + return (dout,) + + +class _GatherFromModelParallelRegion(nn.Cell): + def __init__(self, group: str = GlobalComm.WORLD_COMM_GROUP) -> None: + super().__init__(auto_prefix=False) + self.gather = ops.AllGather(group=group) + self.rank = get_rank(group) + self.world_size = get_group_size(group) + + def construct(self, x: Tensor) -> Tensor: + return _communicate_along_dim(x, -1, self.gather) + + def bprop(self, x: Tensor, out: Tensor, dout: Tensor) -> Tuple[Tensor]: + dout = _split(dout, -1, self.rank, self.world_size) + return (dout,) + + +class _GatherToModelParallelRegion(nn.Cell): + def __init__(self, dim: int = 1, group: str = GlobalComm.WORLD_COMM_GROUP) -> None: + super().__init__(auto_prefix=False) + self.dim = dim + self.gather = ops.AllGather(group=group) + self.reduce_scatter = ops.ReduceScatter(op=ops.ReduceOp.SUM, group=group) + self.rank = get_rank(group) + self.world_size = get_group_size(group) + self.scale = self.world_size + + def construct(self, x: Tensor) -> Tensor: + return _communicate_along_dim(x, self.dim, self.gather) + + def bprop(self, x: Tensor, out: Tensor, dout: Tensor) -> Tuple[Tensor]: + dout = dout * self.scale + dout = _communicate_along_dim(dout, self.dim, self.reduce_scatter) + return (dout,) + + +class _ReduceScatterFromModelParallelRegion(nn.Cell): + def __init__(self, dim: int = 1, group: str = GlobalComm.WORLD_COMM_GROUP) -> None: + super().__init__(auto_prefix=False) + self.dim = dim + self.gather = ops.AllGather(group=group) + self.reduce_scatter = ops.ReduceScatter(op=ops.ReduceOp.SUM, group=group) + self.rank = get_rank(group) + self.world_size = get_group_size(group) + self.scale = 1 / self.world_size + + def construct(self, x: Tensor) -> Tensor: + return _communicate_along_dim(x, self.dim, self.reduce_scatter) + + def bprop(self, x: Tensor, out: Tensor, dout: Tensor) -> Tuple[Tensor]: + dout = dout * self.scale + dout = _communicate_along_dim(dout, self.dim, self.gather) + return (dout,) + + +class SplitForwardGatherBackward(nn.Cell): + def __init__( + self, dim: int = 0, grad_scale: Literal["up", "down"] = "down", group: str = GlobalComm.WORLD_COMM_GROUP + ) -> None: + super().__init__() + self.dim = dim + self.rank = get_rank(group) + self.world_size = get_group_size(group) + self.gather = ops.AllGather(group=group) + + if grad_scale == "up": + self.scale = self.world_size + else: + 
self.scale = 1 / self.world_size + + def construct(self, x: Tensor) -> Tensor: + return _split(x, self.dim, self.rank, self.world_size) + + def bprop(self, x: Tensor, out: Tensor, dout: Tensor) -> Tuple[Tensor]: + dout = dout * self.scale + dout = _communicate_along_dim(dout, self.dim, self.gather) + return (dout,) + + +class GatherForwardSplitBackward(nn.Cell): + def __init__( + self, dim: int = 0, grad_scale: Literal["up", "down"] = "up", group: str = GlobalComm.WORLD_COMM_GROUP + ) -> None: + super().__init__() + self.dim = dim + self.rank = get_rank(group) + self.world_size = get_group_size(group) + self.gather = ops.AllGather(group=group) + + if grad_scale == "up": + self.scale = self.world_size + else: + self.scale = 1 / self.world_size + + def construct(self, x: Tensor) -> Tensor: + x = _communicate_along_dim(x, self.dim, self.gather) + return x + + def bprop(self, x: Tensor, out: Tensor, dout: Tensor) -> Tuple[Tensor]: + dout = dout * self.scale + dout = _split(dout, self.dim, self.rank, self.world_size) + return (dout,) + + +class GatherForwardReduceScatterBackward(nn.Cell): + def __init__(self, dim: int = 0, group: str = GlobalComm.WORLD_COMM_GROUP) -> None: + super().__init__() + self.dim = dim + self.rank = get_rank(group) + self.world_size = get_group_size(group) + self.gather = ops.AllGather(group=group) + self.reduce_scatter = ops.ReduceScatter(op=ops.ReduceOp.SUM, group=group) + + def construct(self, x: Tensor) -> Tensor: + x = _communicate_along_dim(x, self.dim, self.gather) + return x + + def bprop(self, x: Tensor, out: Tensor, dout: Tensor) -> Tuple[Tensor]: + dout = _communicate_along_dim(dout, self.dim, self.reduce_scatter) + return (dout,) + + +class ColumnParallelLinear(nn.Cell): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + weight_init: Union[None, Tensor, str, Initializer, numbers.Number] = None, + bias_init: Union[None, Tensor, str, Initializer, numbers.Number] = None, + gather_output: bool = True, + group: str = GlobalComm.WORLD_COMM_GROUP, + dtype: Optional[ms.Type] = None, + ): + super().__init__(auto_prefix=False) + + self.rank = get_rank(group) + self.world_size = get_group_size(group) + assert out_features % self.world_size == 0 + self.out_features_per_partition = out_features // self.world_size + self.gather_output = gather_output + + self.copy_to_tensor_parallel_region = _CopyToModelParallelRegion(group=group) + self.linear = mint.nn.Linear( + in_features, + self.out_features_per_partition, + bias=bias, + weight_init=weight_init, + bias_init=bias_init, + dtype=dtype, + ) + if self.gather_output: + self.gather_from_tensor_parallel_region = _GatherFromModelParallelRegion(group=group) + + def construct(self, x: Tensor) -> Tensor: + x = self.copy_to_tensor_parallel_region(x) + x = self.linear(x) + if self.gather_output: + x = self.gather_from_tensor_parallel_region(x) + return x + + def load_weight_from_non_parallel_cell(self, target: mint.nn.Linear): + weight = mint.chunk(target.weight, self.world_size, dim=0)[self.rank] + self.linear.weight.set_data(weight) + + if target.bias is not None: + bias = mint.chunk(target.bias, self.world_size, dim=0)[self.rank] + self.linear.bias.set_data(bias) + + +class FusedColumnParallelLinear(nn.Cell): + """For tensor parallel using sequence parallel input + It is a fused operation of gather_forward_split_backward & allreduce backward + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + weight_init: Union[None, Tensor, str, Initializer, 
numbers.Number] = None, + bias_init: Union[None, Tensor, str, Initializer, numbers.Number] = None, + gather_output: bool = True, + dim: int = 1, + group: str = GlobalComm.WORLD_COMM_GROUP, + dtype: Optional[ms.Type] = None, + ): + super().__init__(auto_prefix=False) + + self.rank = get_rank(group) + self.world_size = get_group_size(group) + assert out_features % self.world_size == 0 + self.out_features_per_partition = out_features // self.world_size + self.gather_output = gather_output + + self.gather_to_tensor_parallel_region = _GatherToModelParallelRegion(dim=dim, group=group) + self.linear = mint.nn.Linear( + in_features, + self.out_features_per_partition, + bias=bias, + weight_init=weight_init, + bias_init=bias_init, + dtype=dtype, + ) + if self.gather_output: + self.gather_from_tensor_parallel_region = _GatherFromModelParallelRegion(group=group) + + def construct(self, x: Tensor) -> Tensor: + x = self.gather_to_tensor_parallel_region(x) + x = self.linear(x) + if self.gather_output: + x = self.gather_from_tensor_parallel_region(x) + return x + + def load_weight_from_non_parallel_cell(self, target: mint.nn.Linear): + weight = mint.chunk(target.weight, self.world_size, dim=0)[self.rank] + self.linear.weight.set_data(weight) + + if target.bias is not None: + bias = mint.chunk(target.bias, self.world_size, dim=0)[self.rank] + self.linear.bias.set_data(bias) + + +class RowParallelLinear(nn.Cell): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = False, + weight_init: Union[None, Tensor, str, Initializer, numbers.Number] = None, + bias_init: Union[None, Tensor, str, Initializer, numbers.Number] = None, + input_is_parallel: bool = False, + group: str = GlobalComm.WORLD_COMM_GROUP, + dtype: Optional[ms.Type] = None, + ): + super().__init__(auto_prefix=False) + + self.rank = get_rank(group) + self.world_size = get_group_size(group) + assert in_features % self.world_size == 0 + self.in_features_per_partition = in_features // self.world_size + self.input_is_parallel = input_is_parallel + + self.reduce_from_tensor_parallel_region = _ReduceFromModelParallelRegion(group=group) + self.linear = mint.nn.Linear( + self.in_features_per_partition, + out_features, + bias=bias, + weight_init=weight_init, + bias_init=bias_init, + dtype=dtype, + ) + if not self.input_is_parallel: + self.scatter_to_tensor_parallel_region = _ScatterToModelParallelRegion(group=group) + + def construct(self, x: Tensor) -> Tensor: + if not self.input_is_parallel: + x = self.scatter_to_tensor_parallel_region(x) + x = self.linear.dense(x, self.linear.weight) + x = self.reduce_from_tensor_parallel_region(x) + if self.linear.bias is not None: + x = x + self.linear.bias + return x + + def load_weight_from_non_parallel_cell(self, target: mint.nn.Linear): + weight = mint.chunk(target.weight, self.world_size, dim=1)[self.rank] + self.linear.weight.set_data(weight) + + if target.bias is not None: + self.linear.bias.set_data(target.bias) + + +class FusedRowParallelLinear(nn.Cell): + """For tensor parallel to sequence parallel output + It is a fused operation of split_forward_gather_backward & allreduce forward + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = False, + weight_init: Union[None, Tensor, str, Initializer, numbers.Number] = None, + bias_init: Union[None, Tensor, str, Initializer, numbers.Number] = None, + input_is_parallel: bool = False, + dim: int = 1, + group: str = GlobalComm.WORLD_COMM_GROUP, + dtype: Optional[ms.Type] = None, + ): + 
super().__init__(auto_prefix=False) + + self.rank = get_rank(group) + self.world_size = get_group_size(group) + assert in_features % self.world_size == 0 + self.in_features_per_partition = in_features // self.world_size + self.input_is_parallel = input_is_parallel + + self.reduce_from_tensor_parallel_region = _ReduceScatterFromModelParallelRegion(dim=dim, group=group) + self.linear = mint.nn.Linear( + self.in_features_per_partition, + out_features, + bias=bias, + weight_init=weight_init, + bias_init=bias_init, + dtype=dtype, + ) + if not self.input_is_parallel: + self.scatter_to_tensor_parallel_region = _ScatterToModelParallelRegion(group=group) + + def construct(self, x: Tensor) -> Tensor: + if not self.input_is_parallel: + x = self.scatter_to_tensor_parallel_region(x) + x = self.linear.dense(x, self.linear.weight) + x = self.reduce_from_tensor_parallel_region(x) + if self.linear.bias is not None: + x = x + self.linear.bias + return x + + def load_weight_from_non_parallel_cell(self, target: mint.nn.Linear): + weight = mint.chunk(target.weight, self.world_size, dim=1)[self.rank] + self.linear.weight.set_data(weight) + + if target.bias is not None: + self.linear.bias.set_data(target.bias) diff --git a/examples/moviegen/moviegen/parallel/parallel_states.py b/examples/moviegen/moviegen/parallel/parallel_states.py new file mode 100644 index 0000000000..3effa239c3 --- /dev/null +++ b/examples/moviegen/moviegen/parallel/parallel_states.py @@ -0,0 +1,39 @@ +from typing import Optional + +from mindspore.communication import GlobalComm, create_group, get_group_size, get_rank + +__all__ = ["set_model_parallel_group", "get_model_parallel_group", "create_parallel_group"] + + +_GLOBAL_PARALLEL_GROUPS = dict() + + +def set_model_parallel_group(group: str) -> None: + _GLOBAL_PARALLEL_GROUPS["model"] = group + + +def get_model_parallel_group() -> Optional[str]: + # TODO: change the default value to be None + return _GLOBAL_PARALLEL_GROUPS.get("model", GlobalComm.WORLD_COMM_GROUP) + + +def create_parallel_group(model_parallel_shards: int = 1) -> None: + if model_parallel_shards <= 1: + raise ValueError( + f"`model_parallel_shards` must be larger than 1 to enable model parallel, but get `{model_parallel_shards}`." + ) + + device_num = get_group_size() + if device_num % model_parallel_shards != 0: + raise ValueError( + f"Total number of devices ({device_num}) must be divisible by the number of model parallel shards ({model_parallel_shards})." 
+ ) + + rank_id = get_rank() + + if model_parallel_shards > 1: + mp_group_id = rank_id // model_parallel_shards + mp_group_rank_ids = list(range(mp_group_id * model_parallel_shards, (mp_group_id + 1) * model_parallel_shards)) + mp_group_name = f"mp_group_{mp_group_id}" + create_group(mp_group_name, mp_group_rank_ids) + set_model_parallel_group(mp_group_name) diff --git a/examples/moviegen/moviegen/pipelines/__init__.py b/examples/moviegen/moviegen/pipelines/__init__.py new file mode 100644 index 0000000000..8cf855d610 --- /dev/null +++ b/examples/moviegen/moviegen/pipelines/__init__.py @@ -0,0 +1 @@ +from .train_pipeline import DiffusionWithLoss diff --git a/examples/moviegen/moviegen/pipelines/train_pipeline.py b/examples/moviegen/moviegen/pipelines/train_pipeline.py new file mode 100644 index 0000000000..7327039581 --- /dev/null +++ b/examples/moviegen/moviegen/pipelines/train_pipeline.py @@ -0,0 +1,59 @@ +from typing import Optional + +from mindspore import Tensor, nn, ops + +from ..schedulers.rectified_flow import RFlowScheduler + +__all__ = ["DiffusionWithLoss"] + + +class DiffusionWithLoss(nn.Cell): + def __init__( + self, + network: nn.Cell, + scheduler: RFlowScheduler, + vae: Optional[nn.Cell] = None, + text_encoder: Optional[nn.Cell] = None, + scale_factor: float = 0.18215, + text_emb_cached: bool = True, + video_emb_cached: bool = False, + ): + super().__init__() + + if not text_emb_cached and text_encoder is None: + raise ValueError("`text_encoder` must be provided when `text_emb_cached=False`.") + if not video_emb_cached and vae is None: + raise ValueError("`vae` must be provided when `video_emb_cached=False`.") + + self.network = network + self.vae = vae + self.scheduler = scheduler + self.text_encoder = text_encoder + self.scale_factor = scale_factor + self.text_emb_cached = text_emb_cached + self.video_emb_cached = video_emb_cached + + if self.vae is not None: + for param in self.vae.trainable_params(): + param.requires_grad = False + + if self.text_encoder is not None: + for param in self.text_encoder.trainable_params(): + param.requires_grad = False + + def get_condition_embeddings(self, text_tokens: Tensor) -> Tensor: + if self.text_emb_cached: + return text_tokens + text_emb = ops.stop_gradient(self.text_encoder(text_tokens)) + return text_emb + + def get_latents(self, video_tokens: Tensor) -> Tensor: + if self.video_emb_cached: + return video_tokens + video_emb = ops.stop_gradient(self.vae.encode(video_tokens)) + return video_emb + + def construct(self, video_tokens: Tensor, text_tokens: Tensor) -> Tensor: + latent_embedding = self.get_latents(video_tokens) + text_embedding = self.get_condition_embeddings(text_tokens) + return self.scheduler.training_loss(self.network, latent_embedding, text_embedding) diff --git a/examples/moviegen/moviegen/schedulers/__init__.py b/examples/moviegen/moviegen/schedulers/__init__.py new file mode 100644 index 0000000000..d030f82972 --- /dev/null +++ b/examples/moviegen/moviegen/schedulers/__init__.py @@ -0,0 +1 @@ +from .rectified_flow import * diff --git a/examples/moviegen/moviegen/schedulers/rectified_flow.py b/examples/moviegen/moviegen/schedulers/rectified_flow.py new file mode 100644 index 0000000000..b466487608 --- /dev/null +++ b/examples/moviegen/moviegen/schedulers/rectified_flow.py @@ -0,0 +1,147 @@ +import logging +from typing import Literal, Optional, Tuple + +import numpy as np +from tqdm import tqdm + +import mindspore as ms +import mindspore.mint.nn.functional as F +from mindspore import Tensor, mint, nn, ops + +from 
..models import LlamaModel + +logger = logging.getLogger(__name__) + +__all__ = ["RFLOW", "RFlowLossWrapper"] + + +class LogisticNormal(nn.Cell): + def __init__(self, loc: float = 0.0, scale: float = 1.0) -> None: + super().__init__() + self.mean = loc + self.std = scale + self._min = Tensor(np.finfo(np.float32).tiny, dtype=ms.float32) + self._max = Tensor(1.0 - np.finfo(np.float32).eps, dtype=ms.float32) + + def construct(self, shape: Tuple[int, ...]) -> Tensor: + assert shape[-1] == 1 + x = mint.normal(mean=self.mean, std=self.std, size=shape) + offset = x.shape[-1] + 1 - mint.cumsum(mint.ones(x.shape[-1]), dim=-1) + z = self._clipped_sigmoid(x - mint.log(offset)) + z_cumprod = ops.cumprod((1 - z), dim=-1) + y = F.pad(z, [0, 1], value=1) * F.pad(z_cumprod, [1, 0], value=1) + return y[:, 0] + + def _clipped_sigmoid(self, x: Tensor) -> Tensor: + x = mint.clamp(mint.sigmoid(x), min=self._min, max=self._max) + return x + + +class RFLOW: + def __init__( + self, + num_sampling_steps: int = 50, + num_timesteps: int = 1000, + sample_method: Literal["linear", "linear-quadratic"] = "linear", + ) -> None: + self.num_sampling_steps = num_sampling_steps + self.num_timesteps = num_timesteps + self.sample_method = sample_method + + def __call__(self, model: nn.Cell, x: Tensor, text_embedding: Tensor) -> Tensor: + """ + x: (N, T, C, H, W) tensor of inputs (latent representations of video) + text_embedding: (N, L, C') tensor of the text embedding + """ + # prepare timesteps + if self.sample_method == "linear": + timesteps = (1.0 - np.arange(self.num_sampling_steps) / self.num_sampling_steps) * self.num_timesteps + else: + raise NotImplementedError("Not supported yet.") + + timesteps = np.tile(timesteps[None, ...], (x.shape[0], 1)) + timesteps = Tensor(timesteps, dtype=ms.int64) + + for i, timestep in tqdm(enumerate(timesteps), total=self.num_sampling_steps): + pred = model(x, timestep, text_embedding) + + # update z + dt = timesteps[i] - timesteps[i + 1] if i < len(timesteps) - 1 else timesteps[i] + dt = dt / self.num_timesteps + x = x + pred * dt[:, None, None, None, None] + + return x + + +class RFlowLossWrapper(nn.Cell): + """Wrapper for calculating the training loss""" + + def __init__( + self, + model: LlamaModel, + num_timesteps: int = 1000, + sample_method: Literal["discrete-uniform", "uniform", "logit-normal"] = "logit-normal", + loc: float = 0.0, + scale: float = 1.0, + eps: float = 1e-5, + ) -> None: + super().__init__(auto_prefix=False) + self.num_timesteps = num_timesteps + self.eps = eps + + if sample_method == "discrete-uniform": + self._sample_func = self._discrete_sample + elif sample_method == "uniform": + self._sample_func = self._uniform_sample + elif sample_method == "logit-normal": + self.distribution = LogisticNormal(loc=loc, scale=scale) + self._sample_func = self._logit_normal_sample + else: + raise ValueError(f"Unknown sample method: {sample_method}") + + self.model = model + self.criteria = nn.MSELoss() + + def _discrete_sample(self, size: int) -> Tensor: + return ops.randint(0, self.num_timesteps, (size,), dtype=ms.int64) + + def _uniform_sample(self, size: int) -> Tensor: + return mint.rand((size,), dtype=ms.float32) * self.num_timesteps + + def _logit_normal_sample(self, size: int) -> Tensor: + return self.distribution((size, 1)) * self.num_timesteps + + def construct(self, x: Tensor, text_embedding: Tensor, timestep: Optional[Tensor] = None) -> Tensor: + """Calculate the training loss for the corresponding timestep. 
+ x: (N, T, C, H, W) tensor of inputs (latent representations of video) + text_embedding: (N, L, C') tensor of the text embedding + timestep: (N,) tensor to indicate denoising step + """ + x = x.to(ms.float32) + + if timestep is None: + timestep = self._sample_func(x.shape[0]) + + noise = mint.normal(size=x.shape) + x_t = self.add_noise(x, noise, timestep) + + model_output = self.model(x_t.to(self.model.dtype), timestep, text_embedding.to(self.model.dtype)).to( + ms.float32 + ) + v_t = x - (1 - self.eps) * noise + + # 3.1.2 Eqa (2) + loss = self.criteria(model_output, v_t) + return loss + + def add_noise(self, x: Tensor, noise: Tensor, timesteps: Tensor) -> Tensor: + """ + x: (N, T, C, H, W) tensor of ground truth + noise: (N, T, C, H, W) tensor of white noise + timesteps: (N,) tensor of timestamps with range [0, num_timesteps) + """ + timesteps = 1 - timesteps.to(ms.float32) / self.num_timesteps + timesteps = timesteps[:, None, None, None, None] + + # 3.1.2 First Eqa. + return timesteps * x + (1 - (1 - self.eps) * timesteps) * noise diff --git a/examples/moviegen/tests/parallel/run_test_llama3_parallel.sh b/examples/moviegen/tests/parallel/run_test_llama3_parallel.sh new file mode 100755 index 0000000000..b532dad534 --- /dev/null +++ b/examples/moviegen/tests/parallel/run_test_llama3_parallel.sh @@ -0,0 +1,13 @@ +#!/bin/sh + +SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" +PROJECT_DIR="$(dirname $(dirname "${SCRIPT_DIR}"))" +EXAMPLE_DIR="$(dirname "${PROJECT_DIR}")" +PACKAGE_DIR="$(dirname "${EXAMPLE_DIR}")" + +export PYTHONPATH="${PROJECT_DIR}:${PACKAGE_DIR}:${PYTHONPATH}" + +LOGDIR="./log_test_llama3_parallel_graph" +echo "Graph Mode:" +msrun --master_port=1234 --worker_num=2 --local_worker_num=2 --log_dir=$LOGDIR --join True ${SCRIPT_DIR}/test_llama3_parallel.py --mode 0 +echo "Done. Check the log at '$LOGDIR'." diff --git a/examples/moviegen/tests/parallel/run_test_llama3_parallel_block.sh b/examples/moviegen/tests/parallel/run_test_llama3_parallel_block.sh new file mode 100755 index 0000000000..603aac9fce --- /dev/null +++ b/examples/moviegen/tests/parallel/run_test_llama3_parallel_block.sh @@ -0,0 +1,13 @@ +#!/bin/sh + +SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" +PROJECT_DIR="$(dirname $(dirname "${SCRIPT_DIR}"))" +EXAMPLE_DIR="$(dirname "${PROJECT_DIR}")" +PACKAGE_DIR="$(dirname "${EXAMPLE_DIR}")" + +export PYTHONPATH="${PROJECT_DIR}:${PACKAGE_DIR}:${PYTHONPATH}" + +LOGDIR="./log_test_llama3_parallel_block_graph" +echo "Graph Mode:" +msrun --master_port=1234 --worker_num=2 --local_worker_num=2 --log_dir=$LOGDIR --join True ${SCRIPT_DIR}/test_llama3_parallel_block.py --mode 0 +echo "Done. Check the log at '$LOGDIR'." 
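Reviewer sketch (not part of the diff): the new layers in moviegen/parallel/layers.py follow the Megatron-style pairing of a column-parallel projection feeding a row-parallel projection, which is also how the tests below exercise them. The snippet assumes a process group of at least 2 devices launched with msrun and moviegen on PYTHONPATH; `ParallelMLP` is an illustrative stand-in for the patch's TensorParallelLlamaMLP, and the hidden/intermediate sizes simply mirror the block test config.

import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore.communication import get_group_size, init

from moviegen.parallel import (
    ColumnParallelLinear,
    RowParallelLinear,
    create_parallel_group,
    get_model_parallel_group,
)

init()
# create_parallel_group requires model_parallel_shards > 1, i.e. a multi-device launch.
create_parallel_group(model_parallel_shards=get_group_size())
group = get_model_parallel_group()


class ParallelMLP(nn.Cell):
    """Illustrative stand-in: column-parallel up-projection + row-parallel down-projection."""

    def __init__(self, hidden: int = 3072, intermediate: int = 8192):
        super().__init__()
        # Each rank holds a slice of the intermediate features;
        # gather_output=False keeps the activation sharded for the next layer.
        self.up = ColumnParallelLinear(
            hidden, intermediate, bias=False, gather_output=False, group=group, dtype=ms.float32
        )
        # Consumes the sharded activation and all-reduces the partial outputs.
        self.down = RowParallelLinear(
            intermediate, hidden, bias=False, input_is_parallel=True, group=group, dtype=ms.float32
        )

    def construct(self, x):
        return self.down(ops.silu(self.up(x)))


x = ops.rand([2, 64, 3072], dtype=ms.float32)  # (N, T, H)
print(ParallelMLP()(x).shape)  # (2, 64, 3072), identical on every rank
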
diff --git a/examples/moviegen/tests/parallel/run_test_llama3_parallel_layer.sh b/examples/moviegen/tests/parallel/run_test_llama3_parallel_layer.sh new file mode 100755 index 0000000000..ecf23ff9a8 --- /dev/null +++ b/examples/moviegen/tests/parallel/run_test_llama3_parallel_layer.sh @@ -0,0 +1,13 @@ +#!/bin/sh + +SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" +PROJECT_DIR="$(dirname $(dirname "${SCRIPT_DIR}"))" +EXAMPLE_DIR="$(dirname "${PROJECT_DIR}")" +PACKAGE_DIR="$(dirname "${EXAMPLE_DIR}")" + +export PYTHONPATH="${PROJECT_DIR}:${PACKAGE_DIR}:${PYTHONPATH}" + +LOGDIR="./log_test_llama3_parallel_layer_graph" +echo "Graph Mode:" +msrun --master_port=1234 --worker_num=2 --local_worker_num=2 --log_dir=$LOGDIR --join True ${SCRIPT_DIR}/test_llama3_parallel_layer.py --mode 0 +echo "Done. Check the log at '$LOGDIR'." diff --git a/examples/moviegen/tests/parallel/test_llama3_parallel.py b/examples/moviegen/tests/parallel/test_llama3_parallel.py new file mode 100644 index 0000000000..59b47dd951 --- /dev/null +++ b/examples/moviegen/tests/parallel/test_llama3_parallel.py @@ -0,0 +1,107 @@ +import argparse +from typing import Tuple + +import numpy as np +from moviegen.models.llama.network import LlamaModel +from moviegen.parallel import create_parallel_group +from utils import gather_or_reduce_parallel_gradient + +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore import Tensor +from mindspore.communication import get_group_size, init + +from mindone.utils.seed import set_random_seed + + +class MeanNet(nn.Cell): + def __init__(self, net: nn.Cell) -> None: + super().__init__() + self.net = net + + def construct(self, *inputs): + output = self.net(*inputs) + return output.mean() * 1024.0 + + +def get_sample_data(dtype: ms.Type = ms.float32) -> Tuple[Tensor, Tensor, Tensor]: + latent_embedding = ops.rand([1, 16, 8, 24, 44], dtype=dtype) + timestep = ms.Tensor([35], dtype=ms.int64) + text_embedding = ops.rand([1, 64, 4096], dtype=dtype) + return latent_embedding, timestep, text_embedding + + +def get_network_config(model_parallelism=False, fused_tensor_parallel=False): + config = dict( + num_hidden_layers=2, + attn_implementation="eager", + model_parallelism=model_parallelism, + fused_tensor_parallel=fused_tensor_parallel, + post_init_weight=False, + ) + return config + + +def run_network(mode: int = 0, fused_tensor_parallel: bool = False, dtype: ms.Type = ms.float32): + ms.set_context(mode=mode) + init() + + # prepare data + set_random_seed(1024) + data = get_sample_data(dtype=dtype) + + # prepare group + create_parallel_group(model_parallel_shards=get_group_size()) + + # non parallel network + set_random_seed(1024) + non_parallel_network_cfg = get_network_config(model_parallelism=False, fused_tensor_parallel=fused_tensor_parallel) + non_parallel_network = LlamaModel(**non_parallel_network_cfg, dtype=dtype) + + # parallel netowrk + parallel_network_cfg = get_network_config(model_parallelism=True, fused_tensor_parallel=fused_tensor_parallel) + parallel_network = LlamaModel(**parallel_network_cfg, dtype=dtype) + + # load weight + parallel_network.load_weight_from_non_parallel_cell(non_parallel_network) + + # test forward + non_parallel_out = non_parallel_network(*data).asnumpy() + parallel_out = parallel_network(*data).asnumpy() + + assert np.count_nonzero(non_parallel_out) > 0 + np.testing.assert_equal(non_parallel_out.shape, parallel_out.shape) + np.testing.assert_allclose(non_parallel_out, parallel_out, rtol=1.3e-6, 
atol=1e-5) + print("Test 1 (Forward): Passed.", flush=True) + + # test backward + non_parallel_mean_net = MeanNet(non_parallel_network) + parallel_mean_net = MeanNet(parallel_network) + + # check the parameter gradient + grad_fn = ops.grad(non_parallel_mean_net, grad_position=None, weights=non_parallel_mean_net.trainable_params()) + non_parallel_grads = grad_fn(*data) + + grad_fn = ops.grad(parallel_mean_net, grad_position=None, weights=parallel_mean_net.trainable_params()) + parallel_grads = grad_fn(*data) + + for grad_0, grad_1 in zip(non_parallel_grads, parallel_grads): + grad_1 = gather_or_reduce_parallel_gradient(grad_1, grad_0.shape) + grad_0, grad_1 = grad_0.asnumpy(), grad_1.asnumpy() + assert np.count_nonzero(grad_0) > 0 + np.testing.assert_allclose(grad_0, grad_1, rtol=1.3e-6, atol=2e-5) + print("Test 2 (Backward: Parameter Gradient): Passed.", flush=True) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--mode", default=0, type=int, choices=[0, 1], help="Mode to test. (0: Graph Mode; 1: Pynative mode)" + ) + args = parser.parse_args() + print("Non-fused tensor parallel:", flush=True) + run_network(mode=args.mode, fused_tensor_parallel=False) + + print("Fused tensor parallel:", flush=True) + run_network(mode=args.mode, fused_tensor_parallel=True) diff --git a/examples/moviegen/tests/parallel/test_llama3_parallel_block.py b/examples/moviegen/tests/parallel/test_llama3_parallel_block.py new file mode 100644 index 0000000000..f9b3a765a8 --- /dev/null +++ b/examples/moviegen/tests/parallel/test_llama3_parallel_block.py @@ -0,0 +1,107 @@ +import argparse + +import numpy as np +from moviegen.models.llama.block import LlamaMLP, TensorParallelLlamaMLP +from moviegen.parallel import create_parallel_group +from utils import gather_or_reduce_parallel_gradient + +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore import Tensor +from mindspore.communication import get_group_size, init + +from mindone.utils.seed import set_random_seed + + +class MeanNet(nn.Cell): + def __init__(self, net: nn.Cell) -> None: + super().__init__() + self.net = net + + def construct(self, *inputs): + output = self.net(*inputs) + return output.mean() * 1024.0 + + +def get_sample_data(dtype: ms.Type = ms.float32) -> Tensor: + x = ops.rand([4, 64, 3072], dtype=dtype) # (N, T, H) + return x + + +def get_block_config(): + config = dict(intermediate_size=8192, hidden_size=3072, hidden_act="silu") + return config + + +def run_block(mode: int = 0, dtype: ms.Type = ms.float32): + ms.set_context(mode=mode) + init() + + # prepare data + set_random_seed(1024) + data = get_sample_data(dtype=dtype) + + # prepare group + create_parallel_group(model_parallel_shards=get_group_size()) + + # non parallel block + set_random_seed(1024) + non_parallel_block_cfg = get_block_config() + non_parallel_block = LlamaMLP(**non_parallel_block_cfg, dtype=dtype) + + # parallel block + parallel_block_cfg = get_block_config() + parallel_block = TensorParallelLlamaMLP(**parallel_block_cfg, dtype=dtype) + + # load weight + parallel_block.load_weight_from_non_parallel_cell(non_parallel_block) + + # test forward + non_parallel_out = non_parallel_block(data).asnumpy() + parallel_out = parallel_block(data).asnumpy() + + assert np.count_nonzero(non_parallel_out) > 0 + np.testing.assert_equal(non_parallel_out.shape, parallel_out.shape) + np.testing.assert_allclose(non_parallel_out, parallel_out, rtol=1.3e-6, atol=1e-5) + print("Test 1 (Forward): Passed.") + + # test 
backward + non_parallel_mean_net = MeanNet(non_parallel_block) + parallel_mean_net = MeanNet(parallel_block) + + # check the parameter gradient + grad_fn = ops.grad(non_parallel_mean_net, grad_position=None, weights=non_parallel_mean_net.trainable_params()) + non_parallel_grads = grad_fn(data) + + grad_fn = ops.grad(parallel_mean_net, grad_position=None, weights=parallel_mean_net.trainable_params()) + parallel_grads = grad_fn(data) + + for grad_0, grad_1 in zip(non_parallel_grads, parallel_grads): + grad_1 = gather_or_reduce_parallel_gradient(grad_1, grad_0.shape) + grad_0, grad_1 = grad_0.asnumpy(), grad_1.asnumpy() + assert np.count_nonzero(grad_0) > 0 + np.testing.assert_allclose(grad_0, grad_1, rtol=1.3e-6, atol=1e-5) + print("Test 2 (Backward: Parameter Gradient): Passed.") + + # check the input gradient + grad_fn = ops.grad(non_parallel_mean_net, grad_position=0) + non_parallel_grads = grad_fn(data) + + grad_fn = ops.grad(parallel_mean_net, grad_position=0) + parallel_grads = grad_fn(data) + + for grad_0, grad_1 in zip(non_parallel_grads, parallel_grads): + grad_0, grad_1 = grad_0.asnumpy(), grad_1.asnumpy() + assert np.count_nonzero(grad_0) > 0 + np.testing.assert_allclose(grad_0, grad_1, rtol=1.3e-6, atol=1e-5) + print("Test 3 (Backward: Input Gradient): Passed.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--mode", default=0, type=int, choices=[0, 1], help="Mode to test. (0: Graph Mode; 1: Pynative mode)" + ) + args = parser.parse_args() + run_block(mode=args.mode) diff --git a/examples/moviegen/tests/parallel/test_llama3_parallel_layer.py b/examples/moviegen/tests/parallel/test_llama3_parallel_layer.py new file mode 100644 index 0000000000..a2e35a0576 --- /dev/null +++ b/examples/moviegen/tests/parallel/test_llama3_parallel_layer.py @@ -0,0 +1,125 @@ +import argparse +from typing import Literal + +import numpy as np +from moviegen.parallel import ColumnParallelLinear, RowParallelLinear, create_parallel_group, get_model_parallel_group +from utils import gather_or_reduce_parallel_gradient + +import mindspore as ms +import mindspore.mint as mint +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore import Tensor +from mindspore.communication import get_group_size, init + +from mindone.utils.seed import set_random_seed + + +class MeanNet(nn.Cell): + def __init__(self, net: nn.Cell) -> None: + super().__init__() + self.net = net + + def construct(self, *inputs): + output = self.net(*inputs) + return output.mean() * 1024.0 + + +def get_sample_data(dtype: ms.Type = ms.float32) -> Tensor: + x = ops.rand([4, 64, 256], dtype=dtype) # (N, T, H) + return x + + +def get_layer_config(bias: bool = False): + config = dict(in_features=256, out_features=32, bias=bias) + return config + + +def run_layer(mode: int = 0, dtype: ms.Type = ms.float32): + ms.set_context(mode=mode) + init() + + # prepare data + set_random_seed(1024) + data = get_sample_data(dtype=dtype) + + # prepare group + create_parallel_group(model_parallel_shards=get_group_size()) + + print("Column Parallel Linear (Bias=True):") + run_parallel_linear(data, type="column_parallel", bias=True, dtype=dtype) + print("Column Parallel Linear (Bias=False):") + run_parallel_linear(data, type="column_parallel", bias=False, dtype=dtype) + print("Row Parallel Linear (Bias=True):") + run_parallel_linear(data, type="row_parallel", bias=True, dtype=dtype) + print("Row Parallel Linear (Bias=False):") + run_parallel_linear(data, type="row_parallel", bias=False, dtype=dtype) + + +def 
run_parallel_linear( + data: Tensor, type: Literal["column_parallel", "row_parallel"], bias: bool = False, dtype: ms.Type = ms.float32 +): + # non parallel layer + set_random_seed(1024) + non_parallel_layer_cfg = get_layer_config(bias=bias) + non_parallel_layer = mint.nn.Linear(**non_parallel_layer_cfg, dtype=dtype) + + # parallel layer + group = get_model_parallel_group() + parallel_layer_cfg = get_layer_config(bias=bias) + if type == "column_parallel": + parallel_layer = ColumnParallelLinear(**parallel_layer_cfg, gather_output=True, group=group, dtype=dtype) + else: + parallel_layer = RowParallelLinear(**parallel_layer_cfg, input_is_parallel=False, group=group, dtype=dtype) + + # load weight + parallel_layer.load_weight_from_non_parallel_cell(non_parallel_layer) + + # test forward + non_parallel_out = non_parallel_layer(data).asnumpy() + parallel_out = parallel_layer(data).asnumpy() + + assert np.count_nonzero(non_parallel_out) > 0 + np.testing.assert_equal(non_parallel_out.shape, parallel_out.shape) + np.testing.assert_allclose(non_parallel_out, parallel_out, rtol=1.3e-6, atol=1e-5) + print("Test 1 (Forward): Passed.") + + # test backward + non_parallel_mean_net = MeanNet(non_parallel_layer) + parallel_mean_net = MeanNet(parallel_layer) + + # check the parameter gradient + grad_fn = ops.grad(non_parallel_mean_net, grad_position=None, weights=non_parallel_mean_net.trainable_params()) + non_parallel_grads = grad_fn(data) + + grad_fn = ops.grad(parallel_mean_net, grad_position=None, weights=parallel_mean_net.trainable_params()) + parallel_grads = grad_fn(data) + + for grad_0, grad_1 in zip(non_parallel_grads, parallel_grads): + grad_1 = gather_or_reduce_parallel_gradient(grad_1, grad_0.shape) + grad_0, grad_1 = grad_0.asnumpy(), grad_1.asnumpy() + assert np.count_nonzero(grad_0) > 0 + np.testing.assert_allclose(grad_0, grad_1, rtol=1.3e-6, atol=1e-5) + print("Test 2 (Backward: Parameter Gradient): Passed.") + + # check the input gradient + grad_fn = ops.grad(non_parallel_mean_net, grad_position=0) + non_parallel_grads = grad_fn(data) + + grad_fn = ops.grad(parallel_mean_net, grad_position=0) + parallel_grads = grad_fn(data) + + for grad_0, grad_1 in zip(non_parallel_grads, parallel_grads): + grad_0, grad_1 = grad_0.asnumpy(), grad_1.asnumpy() + assert np.count_nonzero(grad_0) > 0 + np.testing.assert_allclose(grad_0, grad_1, rtol=1.3e-6, atol=1e-5) + print("Test 3 (Backward: Input Gradient): Passed.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--mode", default=0, type=int, choices=[0, 1], help="Mode to test. 
(0: Graph Mode; 1: Pynative mode)" + ) + args = parser.parse_args() + run_layer(mode=args.mode) diff --git a/examples/moviegen/tests/parallel/utils.py b/examples/moviegen/tests/parallel/utils.py new file mode 100644 index 0000000000..2f8d19e2d5 --- /dev/null +++ b/examples/moviegen/tests/parallel/utils.py @@ -0,0 +1,32 @@ +from typing import Callable, Tuple + +import numpy as np + +import mindspore.ops as ops +from mindspore import Tensor +from mindspore.communication import GlobalComm, get_group_size + + +def _communicate_along_dim(x: Tensor, dim: int, func: Callable[[Tensor], Tensor]) -> Tensor: + x = x.swapaxes(0, dim) + x = func(x) + x = x.swapaxes(dim, 0) + return x + + +def gather_or_reduce_parallel_gradient( + parallel_gradient: Tensor, non_parallel_gradient_shape: Tuple[int, ...], group: str = GlobalComm.WORLD_COMM_GROUP +) -> Tensor: + if parallel_gradient.shape == non_parallel_gradient_shape: + # Sequence Parallel / Context Parallel + allreduce = ops.AllReduce(op=ops.ReduceOp.SUM, group=group) + parallel_gradient = allreduce(parallel_gradient) / get_group_size(group) + return parallel_gradient + + scales = np.array(non_parallel_gradient_shape) / np.array(parallel_gradient.shape) + assert np.count_nonzero(scales - 1) == 1 + assert np.prod(scales) == get_group_size(group) + dim = np.argmax(scales).item() + allgather = ops.AllGather(group=group) + parallel_gradient = _communicate_along_dim(parallel_gradient, dim, allgather) + return parallel_gradient diff --git a/examples/moviegen/tests/ut/test_byt5_pynative.py b/examples/moviegen/tests/ut/test_byt5_pynative.py new file mode 100644 index 0000000000..79cca8d623 --- /dev/null +++ b/examples/moviegen/tests/ut/test_byt5_pynative.py @@ -0,0 +1,85 @@ +import os +import sys + +import numpy as np +import pytest +import torch +from transformers import AutoTokenizer +from transformers import T5EncoderModel as T5EncoderModel_PyTorch + +import mindspore as ms + +# FIXME: remove in future when mindone is ready for install +sys.path.append(os.path.join(os.path.dirname(__file__), "../../../..")) +from mindone.transformers.models.t5.modeling_t5 import T5EncoderModel, T5LayerNorm + +ms.set_context(mode=ms.PYNATIVE_MODE) + +fp32_tolerance = 1e-4 +fp16_tolerance = 2e-2 +bf16_tolerance = 2e-1 + +test_samples = [ + "Life is like a box of chocolates.", + "La vie est comme une boîte de chocolat.", + "Today is Monday.", + "Aujourd'hui c'est lundi.", +] + +tokenizer = AutoTokenizer.from_pretrained("google/byt5-small", local_files_only=True) +test_samples = tokenizer(test_samples, padding="longest", return_tensors="np") + + +@pytest.fixture(scope="function") +def byt5_pt(): + return T5EncoderModel_PyTorch.from_pretrained("google/byt5-small", local_files_only=True) + + +@pytest.fixture(scope="function") +def byt5_ms(): + return T5EncoderModel.from_pretrained("google/byt5-small", local_files_only=True) + + +def test_fp32(byt5_ms, byt5_pt): + # set models precision + byt5_pt.to(torch.float32) + + ms_enc = byt5_ms( + ms.Tensor(test_samples.input_ids, dtype=ms.int32), ms.Tensor(test_samples.attention_mask, dtype=ms.uint8) + ) + ms_enc = ms_enc[0].asnumpy().astype(np.float32) + pt_enc = byt5_pt(torch.tensor(test_samples.input_ids), torch.tensor(test_samples.attention_mask), return_dict=False) + pt_enc = pt_enc[0].detach().numpy().astype(np.float32) + assert np.allclose(ms_enc, pt_enc, atol=fp32_tolerance, rtol=0) + + +def test_fp16(byt5_ms, byt5_pt): + # set models precision + byt5_ms = ms.amp.custom_mixed_precision( + byt5_ms, black_list=ms.amp.get_black_list() + 
[T5LayerNorm], dtype=ms.float16 + ) + byt5_pt.to(torch.float16) + + ms_enc = byt5_ms( + ms.Tensor(test_samples.input_ids, dtype=ms.int32), ms.Tensor(test_samples.attention_mask, dtype=ms.uint8) + ) + ms_enc = ms_enc[0].asnumpy().astype(np.float32) + pt_enc = byt5_pt(torch.tensor(test_samples.input_ids), torch.tensor(test_samples.attention_mask), return_dict=False) + pt_enc = pt_enc[0].detach().numpy().astype(np.float32) + assert np.allclose(ms_enc, pt_enc, atol=fp16_tolerance, rtol=0) + + +def test_bf16(byt5_ms, byt5_pt): + # set models precision + byt5_ms = ms.amp.custom_mixed_precision( + byt5_ms, black_list=ms.amp.get_black_list() + [T5LayerNorm], dtype=ms.bfloat16 + ) + byt5_pt.to(torch.bfloat16) + + ms_enc = byt5_ms( + ms.Tensor(test_samples.input_ids, dtype=ms.int32), ms.Tensor(test_samples.attention_mask, dtype=ms.uint8) + ) + ms_enc = ms_enc[0].astype(ms.float32).asnumpy() + pt_enc = byt5_pt(torch.tensor(test_samples.input_ids), torch.tensor(test_samples.attention_mask), return_dict=False) + pt_enc = pt_enc[0].detach().to(torch.float32).numpy() + assert np.allclose(ms_enc, pt_enc, atol=bf16_tolerance, rtol=0) diff --git a/examples/moviegen/tests/ut/test_llama3_forward.py b/examples/moviegen/tests/ut/test_llama3_forward.py new file mode 100644 index 0000000000..f582962559 --- /dev/null +++ b/examples/moviegen/tests/ut/test_llama3_forward.py @@ -0,0 +1,16 @@ +import numpy as np +from moviegen import llama3_1B + +import mindspore as ms + + +def test_llama3_forward_graph(): + ms.set_context(mode=ms.GRAPH_MODE) + network = llama3_1B(attn_implementation="flash_attention", dtype=ms.bfloat16) + + latent_embedding = ms.Tensor(np.ones((1, 16, 8, 24, 44)), dtype=ms.bfloat16) + timestep = ms.Tensor([35], dtype=ms.int64) + text_embedding = ms.Tensor(np.ones((1, 64, 4096)), dtype=ms.bfloat16) + outputs = network(latent_embedding, timestep, text_embedding) + + assert outputs.shape == (1, 16, 8, 24, 44) diff --git a/examples/moviegen/tests/ut/test_rflow.py b/examples/moviegen/tests/ut/test_rflow.py new file mode 100644 index 0000000000..a5bc12da89 --- /dev/null +++ b/examples/moviegen/tests/ut/test_rflow.py @@ -0,0 +1,27 @@ +import numpy as np +from moviegen.schedulers import RFlowLossWrapper + +import mindspore as ms +import mindspore.nn as nn +from mindspore import Tensor + + +class SimpleBF16Net(nn.Cell): + def construct(self, x: Tensor, timestamp: Tensor, text_embedding: Tensor): + return x.to(ms.bfloat16) + + @property + def dtype(self): + return ms.bfloat16 + + +def test_rflow_loss(): + ms.set_context(mode=ms.GRAPH_MODE) + network = RFlowLossWrapper( + SimpleBF16Net(), num_timesteps=1000, sample_method="logit-normal", loc=0.0, scale=1.0, eps=1e-5 + ) + + latent_embedding = ms.Tensor(np.ones((2, 16, 8, 24, 44)), dtype=ms.bfloat16) + text_embedding = ms.Tensor(np.ones((2, 64, 4096)), dtype=ms.bfloat16) + loss = network(latent_embedding, text_embedding).item() + assert loss > 0 diff --git a/examples/moviegen/tests/ut/test_ul2_pynative.py b/examples/moviegen/tests/ut/test_ul2_pynative.py new file mode 100644 index 0000000000..6e03e87942 --- /dev/null +++ b/examples/moviegen/tests/ut/test_ul2_pynative.py @@ -0,0 +1,83 @@ +import os +import sys + +import numpy as np +import pytest +import torch +from transformers import AutoTokenizer +from transformers import T5EncoderModel as T5EncoderModel_PyTorch + +import mindspore as ms + +# FIXME: remove in future when mindone is ready for install +sys.path.append(os.path.join(os.path.dirname(__file__), "../../../..")) +from 
mindone.transformers.models.t5.modeling_t5 import T5EncoderModel, T5LayerNorm + +ms.set_context(mode=ms.PYNATIVE_MODE) + +fp32_tolerance = 1e-4 +fp16_tolerance = 2e-2 +bf16_tolerance = 2e-1 + +test_samples = [ + "Life is like a box of chocolates.", + "La vie est comme une boîte de chocolat.", + "Today is Monday.", + "Aujourd'hui c'est lundi.", +] + +tokenizer = AutoTokenizer.from_pretrained("google/ul2", local_files_only=True) +test_samples = tokenizer(test_samples, padding="max_length", return_tensors="np") + + +@pytest.fixture(scope="function") +def ul2_pt(): + return T5EncoderModel_PyTorch.from_pretrained("google/ul2", local_files_only=True) + + +@pytest.fixture(scope="function") +def ul2_ms(): + return T5EncoderModel.from_pretrained("google/ul2", local_files_only=True) + + +def test_fp32(ul2_ms, ul2_pt): + # set models precision + ul2_pt.to(torch.float32) + + ms_enc = ul2_ms( + ms.Tensor(test_samples.input_ids, dtype=ms.int32), ms.Tensor(test_samples.attention_mask, dtype=ms.uint8) + ) + ms_enc = ms_enc[0].asnumpy().astype(np.float32) + pt_enc = ul2_pt(torch.tensor(test_samples.input_ids), torch.tensor(test_samples.attention_mask), return_dict=False) + pt_enc = pt_enc[0].detach().numpy().astype(np.float32) + assert np.allclose(ms_enc, pt_enc, atol=fp32_tolerance, rtol=0) + + +def test_fp16(ul2_ms, ul2_pt): + # set models precision + ul2_ms = ms.amp.custom_mixed_precision(ul2_ms, black_list=ms.amp.get_black_list() + [T5LayerNorm], dtype=ms.float16) + ul2_pt.to(torch.float16) + + ms_enc = ul2_ms( + ms.Tensor(test_samples.input_ids, dtype=ms.int32), ms.Tensor(test_samples.attention_mask, dtype=ms.uint8) + ) + ms_enc = ms_enc[0].asnumpy().astype(np.float32) + pt_enc = ul2_pt(torch.tensor(test_samples.input_ids), torch.tensor(test_samples.attention_mask), return_dict=False) + pt_enc = pt_enc[0].detach().numpy().astype(np.float32) + assert np.allclose(ms_enc, pt_enc, atol=fp16_tolerance, rtol=0) + + +def test_bf16(ul2_ms, ul2_pt): + # set models precision + ul2_ms = ms.amp.custom_mixed_precision( + ul2_ms, black_list=ms.amp.get_black_list() + [T5LayerNorm], dtype=ms.bfloat16 + ) + ul2_pt.to(torch.bfloat16) + + ms_enc = ul2_ms( + ms.Tensor(test_samples.input_ids, dtype=ms.int32), ms.Tensor(test_samples.attention_mask, dtype=ms.uint8) + ) + ms_enc = ms_enc[0].astype(ms.float32).asnumpy() + pt_enc = ul2_pt(torch.tensor(test_samples.input_ids), torch.tensor(test_samples.attention_mask), return_dict=False) + pt_enc = pt_enc[0].detach().to(torch.float32).numpy() + assert np.allclose(ms_enc, pt_enc, atol=bf16_tolerance, rtol=0) diff --git a/examples/moviegen/tools/download_convert_st.py b/examples/moviegen/tools/download_convert_st.py new file mode 100644 index 0000000000..47684496e6 --- /dev/null +++ b/examples/moviegen/tools/download_convert_st.py @@ -0,0 +1,323 @@ +""" +Modified from +https://github.com/huggingface/safetensors/blob/main/bindings/python/convert.py +""" +import argparse +import json +import os +from collections import defaultdict +from typing import Dict, List, Optional, Set + +import requests +import torch +from huggingface_hub import HfApi, configure_http_backend, hf_hub_download +from safetensors.torch import _find_shared_tensors, _is_complete, load_file, save_file + + +def backend_factory() -> requests.Session: + session = requests.Session() + session.verify = False + return session + + +def _remove_duplicate_names( + state_dict: Dict[str, torch.Tensor], + *, + preferred_names: List[str] = None, + discard_names: List[str] = None, +) -> Dict[str, List[str]]: + if 
preferred_names is None: + preferred_names = [] + preferred_names = set(preferred_names) + if discard_names is None: + discard_names = [] + discard_names = set(discard_names) + + shareds = _find_shared_tensors(state_dict) + to_remove = defaultdict(list) + for shared in shareds: + complete_names = set([name for name in shared if _is_complete(state_dict[name])]) + if not complete_names: + if len(shared) == 1: + # Force contiguous + name = list(shared)[0] + state_dict[name] = state_dict[name].clone() + complete_names = {name} + else: + raise RuntimeError( + "Error while trying to find names to remove to save state dict, but found no suitable name to keep" + f" for saving amongst: {shared}. None is covering the entire storage.Refusing to save/load the model" + " since you could be storing much more memory than needed." + " Please refer to https://huggingface.co/docs/safetensors/torch_shared_tensors for more information. Or open an issue." + ) + + keep_name = sorted(list(complete_names))[0] + + # Mechanism to preferentially select keys to keep + # coming from the on-disk file to allow + # loading models saved with a different choice + # of keep_name + preferred = complete_names.difference(discard_names) + if preferred: + keep_name = sorted(list(preferred))[0] + + if preferred_names: + preferred = preferred_names.intersection(complete_names) + if preferred: + keep_name = sorted(list(preferred))[0] + for name in sorted(shared): + if name != keep_name: + to_remove[keep_name].append(name) + return to_remove + + +def get_discard_names( + model_id: str, revision: Optional[str], folder: str, token: Optional[str], endpoint: str +) -> List[str]: + try: + import json + + import transformers + + config_filename = hf_hub_download( + model_id, revision=revision, filename="config.json", token=token, cache_dir=folder, endpoint=endpoint + ) + with open(config_filename, "r") as f: + config = json.load(f) + architecture = config["architectures"][0] + + class_ = getattr(transformers, architecture) + + # Name for this variable depends on transformers version. 
+ discard_names = getattr(class_, "_tied_weights_keys", []) + + except Exception: + discard_names = [] + return discard_names + + +class AlreadyExists(Exception): + pass + + +def check_file_size(sf_filename: str, pt_filename: str): + sf_size = os.stat(sf_filename).st_size + pt_size = os.stat(pt_filename).st_size + + if (sf_size - pt_size) / pt_size > 0.01: + raise RuntimeError( + f"""The file size different is more than 1%: + - {sf_filename}: {sf_size} + - {pt_filename}: {pt_size} + """ + ) + + +def rename(pt_filename: str) -> str: + filename, ext = os.path.splitext(pt_filename) + local = f"{filename}.safetensors" + local = local.replace("pytorch_model", "model") + return local + + +def convert_multi( + model_id: str, *, revision=Optional[str], folder: str, token: Optional[str], discard_names: List[str], endpoint: str +) -> str: + filename = hf_hub_download( + repo_id=model_id, + revision=revision, + filename="pytorch_model.bin.index.json", + token=token, + cache_dir=folder, + endpoint=endpoint, + ) + save_path = os.path.dirname(filename) + with open(filename, "r") as f: + data = json.load(f) + + filenames = set(data["weight_map"].values()) + for filename in filenames: + pt_filename = hf_hub_download( + repo_id=model_id, filename=filename, token=token, cache_dir=folder, endpoint=endpoint + ) + sf_filename = rename(pt_filename) + sf_filename = os.path.join(save_path, sf_filename) + convert_file(pt_filename, sf_filename, discard_names=discard_names) + + index = os.path.join(save_path, "model.safetensors.index.json") + with open(index, "w") as f: + newdata = {k: v for k, v in data.items()} + newmap = {k: rename(v) for k, v in data["weight_map"].items()} + newdata["weight_map"] = newmap + json.dump(newdata, f, indent=4) + + return save_path + + +def convert_single( + model_id: str, + *, + revision: Optional[str], + folder: str, + token: Optional[str], + discard_names: List[str], + endpoint: str, +) -> str: + pt_filename = hf_hub_download( + repo_id=model_id, + revision=revision, + filename="pytorch_model.bin", + token=token, + cache_dir=folder, + endpoint=endpoint, + ) + save_path = os.path.dirname(pt_filename) + sf_name = "model.safetensors" + sf_filename = os.path.join(save_path, sf_name) + convert_file(pt_filename, sf_filename, discard_names) + return save_path + + +def convert_file( + pt_filename: str, + sf_filename: str, + discard_names: List[str], +): + loaded = torch.load(pt_filename, map_location="cpu", weights_only=True) + if "state_dict" in loaded: + loaded = loaded["state_dict"] + to_removes = _remove_duplicate_names(loaded, discard_names=discard_names) + + metadata = {"format": "pt"} + for kept_name, to_remove_group in to_removes.items(): + for to_remove in to_remove_group: + if to_remove not in metadata: + metadata[to_remove] = kept_name + del loaded[to_remove] + # Force tensors to be contiguous + loaded = {k: v.contiguous() for k, v in loaded.items()} + + dirname = os.path.dirname(sf_filename) + os.makedirs(dirname, exist_ok=True) + save_file(loaded, sf_filename, metadata=metadata) + check_file_size(sf_filename, pt_filename) + reloaded = load_file(sf_filename) + for k in loaded: + pt_tensor = loaded[k] + sf_tensor = reloaded[k] + if not torch.equal(pt_tensor, sf_tensor): + raise RuntimeError(f"The output tensors do not match for key {k}") + + +def convert_generic( + model_id: str, *, revision=Optional[str], folder: str, filenames: Set[str], token: Optional[str], endpoint: str +) -> str: + save_path = "" + extensions = {".bin", ".ckpt"} + for filename in filenames: + prefix, ext = 
os.path.splitext(filename) + if ext in extensions: + pt_filename = hf_hub_download( + model_id, revision=revision, filename=filename, token=token, cache_dir=folder, endpoint=endpoint + ) + save_path = os.path.dirname(pt_filename) + + dirname, raw_filename = os.path.split(filename) + if raw_filename == "pytorch_model.bin": + # XXX: This is a special case to handle `transformers` and the + # `transformers` part of the model which is actually loaded by `transformers`. + sf_in_repo = os.path.join(dirname, "model.safetensors") + else: + sf_in_repo = f"{prefix}.safetensors" + sf_filename = os.path.join(save_path, sf_in_repo) + convert_file(pt_filename, sf_filename, discard_names=[]) + return save_path + + +def convert( + model_id: str, + revision: Optional[str] = None, + folder: str = None, + force: bool = False, + endpoint: str = "https://hf-mirror.com", +) -> str: + api = HfApi(endpoint=endpoint) + info = api.model_info(model_id, revision=revision) + filenames = set(s.rfilename for s in info.siblings) + + library_name = getattr(info, "library_name", None) + if any(filename.endswith(".safetensors") for filename in filenames) and not force: + raise AlreadyExists(f"Model {model_id} is already converted, skipping..") + elif library_name == "transformers": + discard_names = get_discard_names( + model_id, revision=revision, folder=folder, token=api.token, endpoint=endpoint + ) + if "pytorch_model.bin" in filenames: + save_path = convert_single( + model_id, + revision=revision, + folder=folder, + token=api.token, + discard_names=discard_names, + endpoint=endpoint, + ) + elif "pytorch_model.bin.index.json" in filenames: + save_path = convert_multi( + model_id, + revision=revision, + folder=folder, + token=api.token, + discard_names=discard_names, + endpoint=endpoint, + ) + else: + raise RuntimeError(f"Model {model_id} doesn't seem to be a valid pytorch model. Cannot convert") + else: + save_path = convert_generic( + model_id, revision=revision, folder=folder, filenames=filenames, token=api.token, endpoint=endpoint + ) + return save_path + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Downloads and converts weights to `safetensors` format.") + parser.add_argument( + "model_id", + type=str, + help="The name of the model on the hub to convert. E.g. `gpt2` or `facebook/wav2vec2-base-960h`", + ) + parser.add_argument( + "--revision", + type=str, + help="The revision to convert", + ) + parser.add_argument( + "--output_dir", + type=str, + help="The output directory to download and save the converted model", + ) + parser.add_argument( + "--endpoint", + type=str, + default="https://hf-mirror.com", + help="The Huggingface endpoint to use. 
Defaults to `https://hf-mirror.com`.", + ) + parser.add_argument( + "--force", + action="store_true", + help="Force weights re-conversion.", + ) + parser.add_argument( + "--disable_ssl_verify", + action="store_true", + help="Disable SSL verification when downloading the model weights.", + ) + + args = parser.parse_args() + if args.disable_ssl_verify: + configure_http_backend(backend_factory=backend_factory) + + path = convert( + args.model_id, revision=args.revision, folder=args.output_dir, force=args.force, endpoint=args.endpoint + ) + print(f"Converted weights saved to {os.path.dirname(os.path.dirname(path))}") diff --git a/examples/opensora_hpcai/opensora/acceleration/parallel_states.py b/examples/opensora_hpcai/opensora/acceleration/parallel_states.py index c60b9c3932..b7e5f1e7f3 100644 --- a/examples/opensora_hpcai/opensora/acceleration/parallel_states.py +++ b/examples/opensora_hpcai/opensora/acceleration/parallel_states.py @@ -13,23 +13,41 @@ def get_sequence_parallel_group() -> Optional[str]: return _GLOBAL_PARALLEL_GROUPS.get("sequence", None) -def create_parallel_group(sequence_parallel_shards: int) -> None: - if sequence_parallel_shards <= 1: +def set_model_parallel_group(group: str) -> None: + _GLOBAL_PARALLEL_GROUPS["model"] = group + + +def get_model_parallel_group() -> Optional[str]: + return _GLOBAL_PARALLEL_GROUPS.get("model", None) + + +def create_parallel_group(sequence_parallel_shards: int = 1, model_parallel_shards: int = 1) -> None: + if sequence_parallel_shards <= 1 and model_parallel_shards <= 1: raise ValueError( - f"`sequence_parallel_shards` must be larger than 1 to enable sequence parallel, but get `{sequence_parallel_shards}`." + f"`sequence_parallel_shards`/`model_parallel_shards` must be larger than 1 " + f"to enable sequence/model parallel, but get `{sequence_parallel_shards}` and `{model_parallel_shards}`." ) device_num = get_group_size() - if device_num % sequence_parallel_shards != 0: + if device_num % sequence_parallel_shards != 0 or device_num % model_parallel_shards != 0: raise ValueError( - f"Total number of devices ({device_num}) must be devisible by the number of sequence parallel shards ({sequence_parallel_shards})." + f"Total number of devices ({device_num}) must be divisible by the number of " + f"sequence parallel shards ({sequence_parallel_shards}) and model parallel shards ({model_parallel_shards})." 
) rank_id = get_rank() - sp_group_id = rank_id // sequence_parallel_shards - sp_group_rank_ids = list( - range(sp_group_id * sequence_parallel_shards, (sp_group_id + 1) * sequence_parallel_shards) - ) - sp_group_name = f"sp_group_{sp_group_id}" - create_group(sp_group_name, sp_group_rank_ids) - set_sequence_parallel_group(sp_group_name) + + if sequence_parallel_shards > 1: + sp_group_id = rank_id // sequence_parallel_shards + sp_group_rank_ids = list( + range(sp_group_id * sequence_parallel_shards, (sp_group_id + 1) * sequence_parallel_shards) + ) + sp_group_name = f"sp_group_{sp_group_id}" + create_group(sp_group_name, sp_group_rank_ids) + set_sequence_parallel_group(sp_group_name) + elif model_parallel_shards > 1: # not compatible with SP currently + mp_group_id = rank_id // model_parallel_shards + mp_group_rank_ids = list(range(mp_group_id * model_parallel_shards, (mp_group_id + 1) * model_parallel_shards)) + mp_group_name = f"mp_group_{mp_group_id}" + create_group(mp_group_name, mp_group_rank_ids) + set_model_parallel_group(mp_group_name) diff --git a/examples/opensora_hpcai/opensora/datasets/video_dataset_refactored.py b/examples/opensora_hpcai/opensora/datasets/video_dataset_refactored.py index 8c91f917c0..0cf7de0dcf 100644 --- a/examples/opensora_hpcai/opensora/datasets/video_dataset_refactored.py +++ b/examples/opensora_hpcai/opensora/datasets/video_dataset_refactored.py @@ -6,7 +6,7 @@ import sys from concurrent.futures import ThreadPoolExecutor from pathlib import Path -from typing import Any, Callable, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import cv2 import numpy as np @@ -62,7 +62,7 @@ def __init__( self, csv_path: str, video_folder: str, - text_emb_folder: Optional[str] = None, + text_emb_folder: Optional[Union[str, Dict[str, str]]] = None, vae_latent_folder: Optional[str] = None, vae_downsample_rate: float = 8.0, vae_scale_factor: float = 0.18215, @@ -156,7 +156,7 @@ def __init__( def _read_data( data_dir: str, csv_path: str, - text_emb_folder: Optional[str] = None, + text_emb_folder: Optional[Union[str, Dict[str, str]]] = None, vae_latent_folder: Optional[str] = None, filter_data: bool = False, ) -> List[dict]: @@ -178,7 +178,13 @@ def _filter_data(sample_): for item in csv.DictReader(csv_file): sample = {**item, "video": os.path.join(data_dir, item["video"])} if text_emb_folder: - sample["text_emb"] = os.path.join(text_emb_folder, Path(item["video"]).with_suffix(".npz")) + if isinstance(text_emb_folder, str): + sample["text_emb"] = os.path.join(text_emb_folder, Path(item["video"]).with_suffix(".npz")) + else: + sample["text_emb"] = { + name: os.path.join(path, Path(item["video"]).with_suffix(".npz")) + for name, path in text_emb_folder.items() + } if vae_latent_folder: sample["vae_latent"] = os.path.join(vae_latent_folder, Path(item["video"]).with_suffix(".npz")) data.append(sample) @@ -217,9 +223,15 @@ def _get_item(self, idx: int) -> Tuple[Any, ...]: num_frames = self._frames if self._text_emb_folder: - with np.load(text_emb_path) as td: - data["caption"] = td["text_emb"] - data["mask"] = td["mask"].astype(np.uint8) + if isinstance(self._text_emb_folder, str): + with np.load(text_emb_path) as td: + data["caption"] = td["text_emb"] + data["mask"] = td["mask"].astype(np.uint8) + else: + for enc_name, path in text_emb_path.items(): + with np.load(path) as td: + data[enc_name + "_caption"] = td["text_emb"] + data[enc_name + "_mask"] = td["mask"].astype(np.uint8) if self._vae_latent_folder: # pick a resolution randomly if 
there are multi-resolution latents in vae folder diff --git a/examples/opensora_hpcai/opensora/models/stdit/__init__.py b/examples/opensora_hpcai/opensora/models/stdit/__init__.py index 7957e9000f..dc2c63cb06 100644 --- a/examples/opensora_hpcai/opensora/models/stdit/__init__.py +++ b/examples/opensora_hpcai/opensora/models/stdit/__init__.py @@ -1,3 +1,4 @@ from .stdit import STDiT_XL_2 from .stdit2 import STDiT2_XL_2 from .stdit3 import STDiT3_3B_2, STDiT3_XL_2 +from .stdit_llama3 import STDiTLlama3Wrapper diff --git a/examples/opensora_hpcai/opensora/models/stdit/stdit_llama3.py b/examples/opensora_hpcai/opensora/models/stdit/stdit_llama3.py new file mode 100644 index 0000000000..0eeb1883b3 --- /dev/null +++ b/examples/opensora_hpcai/opensora/models/stdit/stdit_llama3.py @@ -0,0 +1,87 @@ +import os +from typing import Literal, Optional + +from moviegen import llama3_1B, llama3_5B, llama3_30B +from moviegen.models import TextProjector +from moviegen.models.llama.block import LlamaRMSNorm + +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore import Tensor, load_checkpoint, load_param_into_net + + +class STDiTLlama3Wrapper(nn.Cell): + def __init__(self, model_size: Literal["1B", "5B", "30B"] = "1B", **kwargs): + super().__init__(auto_prefix=False) + + attn_implementation = "flash_attention" if kwargs.get("enable_flashattn", False) else "eager" + gradient_checkpointing = kwargs.get("use_recompute", False) + model_parallelism = kwargs.get("enable_model_parallelism", False) + + model_kwargs = dict( + in_channels=4, + out_channels=8, + attn_implementation=attn_implementation, + gradient_checkpointing=gradient_checkpointing, + model_parallelism=model_parallelism, + ) + + if model_size == "1B": + self.llama = llama3_1B(**model_kwargs) + elif model_size == "5B": + self.llama = llama3_5B(**model_kwargs) + else: + self.llama = llama3_30B(**model_kwargs) + + self.text_projector = TextProjector( + out_features=self.llama.hidden_size, + layer_norm=LlamaRMSNorm, + norm_eps=self.llama.rms_norm_eps, + dtype=self.llama.dtype, + ) + + self.patch_size = self.llama.patch_size + self.hidden_size = self.llama.hidden_size + self.num_heads = self.llama.num_attention_heads + self.input_sq_size = None + self.in_channels = self.llama.in_channels + + def construct( + self, + x: Tensor, + timestep: Tensor, + y: Tensor, + mask: Optional[Tensor] = None, + frames_mask: Optional[Tensor] = None, + fps: Optional[Tensor] = None, + height: Optional[Tensor] = None, + width: Optional[Tensor] = None, + extra_text_embed1: Optional[Tensor] = None, + extra_mask1: Optional[Tensor] = None, + **kwargs, + ) -> Tensor: + x = ops.transpose(x, (0, 2, 1, 3, 4)) + + if extra_text_embed1 is not None: + y = ops.squeeze(y, axis=1) + # FIXME: placeholder for MetaCLIP + metaclip_text_embed = ops.ones((extra_text_embed1.shape[0], 100, 1280), dtype=extra_text_embed1.dtype) + text_embedding = self.text_projector(y, metaclip_text_embed, extra_text_embed1) + else: + text_embedding = ops.squeeze(y, axis=1) + + latent_embedding = x + output = self.llama(latent_embedding, timestep, text_embedding) + output = ops.transpose(output, (0, 2, 1, 3, 4)) + return output + + def load_from_checkpoint(self, ckpt_path): + if not os.path.exists(ckpt_path): + print(f"WARNING: {ckpt_path} not found. 
No checkpoint loaded!!") + else: + sd = load_checkpoint(ckpt_path) + sd = {k.replace("network.llama.", "").replace("_backbone.", ""): v for k, v in sd.items()} + + m, u = load_param_into_net(self, sd, strict_load=True) + print("net param not load: ", m, len(m)) + print("ckpt param not load: ", u, len(u)) diff --git a/examples/opensora_hpcai/opensora/models/text_encoder/t5.py b/examples/opensora_hpcai/opensora/models/text_encoder/t5.py index 6c79b68352..29609a98a6 100644 --- a/examples/opensora_hpcai/opensora/models/text_encoder/t5.py +++ b/examples/opensora_hpcai/opensora/models/text_encoder/t5.py @@ -3,6 +3,7 @@ import logging import os import re +import sys import urllib.parse as ul import ftfy @@ -14,6 +15,10 @@ from .flan_t5_large.t5 import get_t5_encoder +# FIXME: remove in future when mindone is ready for install +sys.path.append(os.path.join(os.path.dirname(__file__), "../../..")) +from mindone.transformers import T5EncoderModel + logger = logging.getLogger(__name__) @@ -228,7 +233,15 @@ def get_text_encoder_and_tokenizer(name, ckpt_path, **kwargs): logger.info("T5 init") text_encoder = T5Embedder(cache_dir=ckpt_path, pretrained_ckpt=os.path.join(ckpt_path, "model.ckpt"), **kwargs) tokenizer = text_encoder.tokenizer + elif name.lower() == "ul2": + logger.info("UL2 init") + tokenizer = AutoTokenizer.from_pretrained("google/ul2", local_files_only=True, cache_dir=ckpt_path) + text_encoder = T5EncoderModel.from_pretrained("google/ul2", local_files_only=True, cache_dir=ckpt_path) + elif name.lower() == "byt5": + logger.info("ByT5 init") + tokenizer = AutoTokenizer.from_pretrained("google/byt5-small", local_files_only=True, cache_dir=ckpt_path) + text_encoder = T5EncoderModel.from_pretrained("google/byt5-small", local_files_only=True, cache_dir=ckpt_path) else: - raise NotImplementedError + raise NotImplementedError(f"Unknown text encoder: {name}") return text_encoder, tokenizer diff --git a/examples/opensora_hpcai/opensora/pipelines/infer_pipeline.py b/examples/opensora_hpcai/opensora/pipelines/infer_pipeline.py index 09d08dca4d..c123e9174d 100644 --- a/examples/opensora_hpcai/opensora/pipelines/infer_pipeline.py +++ b/examples/opensora_hpcai/opensora/pipelines/infer_pipeline.py @@ -128,9 +128,9 @@ def data_prepare(self, inputs): # for token/text drop in caption embedder for condition-free guidance training. The null mask is the same as text mask. n = x.shape[0] # (n_tokens, dim_emb) -> (b n_tokens dim_emb) - null_emb = self.model.y_embedder.y_embedding[None, :, :].repeat(n, axis=0) if self.use_cfg: + null_emb = self.model.y_embedder.y_embedding[None, :, :].repeat(n, axis=0) y = ops.cat([text_emb, null_emb], axis=0) x_in = ops.concat([x] * 2, axis=0) assert y.shape[0] == x_in.shape[0], "shape mismatch!" 
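For reference, the UL2/ByT5 branches added to get_text_encoder_and_tokenizer above are expected to be exercised roughly as sketched below; this is an illustrative sketch rather than part of the patch, the cache directory is a hypothetical placeholder, and the tokenizer/encoder calls mirror the pattern this patch series adds to scripts/infer_t5.py.

    # Illustrative only: assumes UL2 weights are already cached under a placeholder folder.
    import mindspore as ms
    from opensora.models.text_encoder.t5 import get_text_encoder_and_tokenizer

    text_encoder, tokenizer = get_text_encoder_and_tokenizer("ul2", "models/ul2")  # path is hypothetical
    enc = tokenizer(
        ["a cat playing piano"],
        max_length=300,
        padding="max_length",
        truncation=True,
        return_attention_mask=True,
        return_tensors="np",
    )
    text_emb = text_encoder(
        input_ids=ms.Tensor(enc["input_ids"], dtype=ms.int32),
        attention_mask=ms.Tensor(enc["attention_mask"], dtype=ms.float32),
    )[0]  # (batch, seq_len, hidden) embeddings, typically saved to .npz for cached training/inference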
diff --git a/examples/opensora_hpcai/opensora/pipelines/train_pipeline.py b/examples/opensora_hpcai/opensora/pipelines/train_pipeline.py
index b49f025afa..6b0fe50680 100644
--- a/examples/opensora_hpcai/opensora/pipelines/train_pipeline.py
+++ b/examples/opensora_hpcai/opensora/pipelines/train_pipeline.py
@@ -125,6 +125,8 @@ def construct(
         width: Optional[Tensor] = None,
         fps: Optional[Tensor] = None,
         ar: Optional[Tensor] = None,
+        extra_text_tokens1: Optional[Tensor] = None,
+        extra_mask1: Optional[Tensor] = None,
     ):
         """
         Video diffusion model forward and loss computation for training
@@ -166,6 +168,8 @@ def construct(
             width=width,
             fps=fps,
             ar=ar,
+            extra_text_embed1=extra_text_tokens1,
+            extra_mask1=extra_mask1,
         )
 
         return loss
diff --git a/examples/opensora_hpcai/opensora/schedulers/rectified_flow.py b/examples/opensora_hpcai/opensora/schedulers/rectified_flow.py
index 8a481cde8c..25522c0b4f 100644
--- a/examples/opensora_hpcai/opensora/schedulers/rectified_flow.py
+++ b/examples/opensora_hpcai/opensora/schedulers/rectified_flow.py
@@ -87,6 +87,8 @@ def __call__(
                 noise_added = mask_t_upper
 
             pred = model(z, t, **model_kwargs)
+            # FIXME: a tmp solution for inference with cfg==1.0
+            pred = pred[:, :4]
 
             # update z
             dt = timesteps[i] - timesteps[i + 1] if i < len(timesteps) - 1 else timesteps[i]
diff --git a/examples/opensora_hpcai/scripts/args_train.py b/examples/opensora_hpcai/scripts/args_train.py
index 4666b84af4..1dd144848d 100644
--- a/examples/opensora_hpcai/scripts/args_train.py
+++ b/examples/opensora_hpcai/scripts/args_train.py
@@ -40,7 +40,9 @@ def parse_train_args(parser):
         "--caption_column", default="caption", type=str, help="name of column for captions saved in csv file"
     )
     parser.add_argument("--video_folder", required=True, type=str, help="root dir for the video data")
-    parser.add_argument("--text_embed_folder", type=str, help="root dir for the text embeding data")
+    parser.add_argument("--text_embed_folder", type=str, help="root dir for the text embedding data")
+    parser.add_argument("--ul2_text_embed_folder", type=str, help="root dir for the UL2 text embedding data")
+    parser.add_argument("--byt5_text_embed_folder", type=str, help="root dir for the ByT5 text embedding data")
     parser.add_argument("--vae_latent_folder", type=str, help="root dir for the vae latent data")
     parser.add_argument("--filter_data", default=False, type=str2bool, help="Filter non-existing videos.")
     parser.add_argument("--output_path", default="output/", type=str, help="output directory to save training results")
@@ -49,7 +51,11 @@ def parse_train_args(parser):
     )
     # model
     parser.add_argument(
-        "--model_version", default="v1", type=str, choices=["v1", "v1.1"], help="OpenSora model version."
+        "--model_version",
+        default="v1",
+        type=str,
+        choices=["v1", "v1.1", "v1.2", "llama3_1b", "llama3_5b"],
+        help="OpenSora model version.",
     )
     parser.add_argument(
         "--pretrained_model_path",
@@ -330,6 +336,18 @@ def parse_train_args(parser):
         type=int,
         help="The number of shards in sequence parallel. Default is 1.",
     )
+    parser.add_argument(
+        "--enable_model_parallelism",
+        default=False,
+        type=str2bool,
+        help="whether to enable model parallelism. Default is False. Only for the Llama3 structure.",
+    )
+    parser.add_argument(
+        "--model_parallel_shards",
+        default=1,
+        type=int,
+        help="The number of shards in model parallel. 
Default is 1.", + ) parser.add_argument("--drop_overflow_update", default=True, type=str2bool, help="drop overflow update") parser.add_argument("--loss_scaler_type", default="dynamic", type=str, help="dynamic or static") parser.add_argument( diff --git a/examples/opensora_hpcai/scripts/infer_t5.py b/examples/opensora_hpcai/scripts/infer_t5.py index 7a04f5a4de..974474e16d 100644 --- a/examples/opensora_hpcai/scripts/infer_t5.py +++ b/examples/opensora_hpcai/scripts/infer_t5.py @@ -22,6 +22,7 @@ from opensora.utils.cond_data import read_captions_from_csv, read_captions_from_txt from opensora.utils.model_utils import str2bool # _check_cfgs_in_parser +from mindone.transformers.models.t5.modeling_t5 import T5LayerNorm from mindone.utils.amp import auto_mixed_precision from mindone.utils.logger import set_logger from mindone.utils.misc import to_abspath @@ -127,15 +128,19 @@ def main(args): logger.info(f"Num batches: {dataset_size}") # model initiate and weight loading - ckpt_path = args.t5_model_dir - text_encoder, tokenizer = get_text_encoder_and_tokenizer("t5", ckpt_path, model_max_length=args.model_max_length) + ckpt_path = args.model_dir + text_encoder, tokenizer = get_text_encoder_and_tokenizer( + args.model, ckpt_path, model_max_length=args.model_max_length + ) text_encoder.set_train(False) for param in text_encoder.get_parameters(): # freeze latte_model param.requires_grad = False dtype_map = {"fp16": ms.float16, "bf16": ms.bfloat16} if args.dtype in ["fp16", "bf16"]: - text_encoder = auto_mixed_precision(text_encoder, amp_level=args.amp_level, dtype=dtype_map[args.dtype]) + text_encoder = auto_mixed_precision( + text_encoder, amp_level=args.amp_level, custom_fp32_cells=[T5LayerNorm], dtype=dtype_map[args.dtype] + ) # infer if args.csv_path is not None: @@ -155,8 +160,22 @@ def main(args): captions = [str(captions[i]) for i in range(len(captions))] # print(captions) - text_tokens, mask = text_encoder.get_text_tokens_and_mask(captions, return_tensor=True) - text_emb = text_encoder(text_tokens, mask) + if args.model.lower() == "t5": + text_tokens, mask = text_encoder.get_text_tokens_and_mask(captions, return_tensor=True) + text_emb = text_encoder(text_tokens, mask) + else: + text_tokens_and_mask = tokenizer( + captions, + max_length=args.model_max_length, + padding="max_length", + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="np", + ) + text_tokens = ms.Tensor(text_tokens_and_mask["input_ids"], dtype=ms.int32) + mask = ms.Tensor(text_tokens_and_mask["attention_mask"], dtype=ms.float32) + text_emb = text_encoder(input_ids=text_tokens, attention_mask=mask)[0] end_time = time.time() time_cost = end_time - start_time @@ -199,8 +218,22 @@ def main(args): batch_prompts = captions[i : i + args.batch_size] ns = len(batch_prompts) - batch_text_tokens, batch_mask = text_encoder.get_text_tokens_and_mask(batch_prompts, return_tensor=True) - batch_text_emb = text_encoder(batch_text_tokens, batch_mask) + if args.model.lower() == "t5": + batch_text_tokens, batch_mask = text_encoder.get_text_tokens_and_mask(batch_prompts, return_tensor=True) + batch_text_emb = text_encoder(batch_text_tokens, batch_mask) + else: + text_tokens_and_mask = tokenizer( + batch_prompts, + max_length=args.model_max_length, + padding="max_length", + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="np", + ) + batch_text_tokens = ms.Tensor(text_tokens_and_mask["input_ids"], dtype=ms.int32) + batch_mask = 
ms.Tensor(text_tokens_and_mask["attention_mask"], dtype=ms.float32) + batch_text_emb = text_encoder(input_ids=batch_text_tokens, attention_mask=batch_mask)[0] # save result batch_mask = batch_mask.asnumpy().astype(np.uint8) @@ -245,8 +278,9 @@ def parse_args(): help="output dir to save the embeddings, if None, will treat the parent dir of csv_path as output dir.", ) parser.add_argument("--caption_column", type=str, default="caption", help="caption column num in csv") - parser.add_argument("--t5_model_dir", default="models/t5-v1_1-xxl", type=str, help="the T5 cache folder path") - parser.add_argument("--model_max_length", type=int, default=120, help="T5's embedded sequence length.") + parser.add_argument("--model", default="t5", type=str, choices=["t5", "ul2", "byt5"], help="Name of the model.") + parser.add_argument("--model_dir", type=str, help="the T5 cache folder path") + parser.add_argument("--model_max_length", type=int, default=120, help="Model's embedded sequence length.") # MS new args parser.add_argument("--device_target", type=str, default="Ascend", help="Ascend or GPU") parser.add_argument("--mode", type=int, default=0, help="Running in GRAPH_MODE(0) or PYNATIVE_MODE(1) (default=0)") @@ -304,7 +338,7 @@ def parse_args(): parser.set_defaults( **dict( captions=cfg["captions"], - t5_model_dir=cfg["t5_model_dir"], + model_dir=cfg["model_dir"], ) ) args = parser.parse_args() @@ -312,7 +346,7 @@ def parse_args(): args.csv_path = to_abspath(abs_path, args.csv_path) args.prompt_path = to_abspath(abs_path, args.prompt_path) args.output_path = to_abspath(abs_path, args.output_path) - args.t5_model_dir = to_abspath(abs_path, args.t5_model_dir) + args.model_dir = to_abspath(abs_path, args.model_dir) return args diff --git a/examples/opensora_hpcai/scripts/inference.py b/examples/opensora_hpcai/scripts/inference.py index 3e0defe0a1..f607d48d50 100644 --- a/examples/opensora_hpcai/scripts/inference.py +++ b/examples/opensora_hpcai/scripts/inference.py @@ -17,10 +17,11 @@ mindone_lib_path = os.path.abspath(os.path.join(__dir__, "../../../")) sys.path.insert(0, mindone_lib_path) sys.path.insert(0, os.path.abspath(os.path.join(__dir__, ".."))) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../moviegen"))) from opensora.acceleration.parallel_states import set_sequence_parallel_group from opensora.datasets.aspect import ASPECT_RATIO_MAP, ASPECT_RATIOS, get_image_size, get_num_frames -from opensora.models.stdit import STDiT2_XL_2, STDiT3_XL_2, STDiT_XL_2 +from opensora.models.stdit import STDiT2_XL_2, STDiT3_XL_2, STDiT_XL_2, STDiTLlama3Wrapper from opensora.models.text_encoder.t5 import get_text_encoder_and_tokenizer from opensora.models.vae.vae import SD_CONFIG, OpenSoraVAE_V1_2, VideoAutoencoderKL from opensora.pipelines import InferPipeline, InferPipelineFiTLike @@ -163,16 +164,24 @@ def main(args): latent_condition_frame_length = round(latent_condition_frame_length / 17 * 5) captions = process_prompts(captions, args.loop) # in v1.1 and above, each loop can have a different caption + start_idx, end_idx = 0, len(captions) + if args.text_embed_folder: + end_idx = len(glob.glob(os.path.join(args.text_embed_folder, "*.npz"))) + elif args.ul2_text_embed_folder: + end_idx = len(glob.glob(os.path.join(args.ul2_text_embed_folder, "*.npz"))) if not args.enable_sequence_parallelism: # split samples to NPUs as even as possible - start_idx, end_idx = distribute_samples(len(captions), rank_id, device_num) - captions = captions[start_idx:end_idx] + start_idx, end_idx = 
distribute_samples(end_idx, rank_id, device_num) + if args.reference_path is not None: + args.reference_path = args.reference_path[start_idx:end_idx] + if args.mask_strategy is not None: + args.mask_strategy = args.mask_strategy[start_idx:end_idx] base_data_idx = start_idx else: base_data_idx = 0 if args.use_parallel and not args.enable_sequence_parallelism: - print(f"Num captions for rank {rank_id}: {len(captions)}") + print(f"Num captions for rank {rank_id}: {end_idx - start_idx}") # 2. model initiate and weight loading # 2.1 vae @@ -253,6 +262,12 @@ def main(args): model_extra_args["qk_norm"] = True logger.info(f"{model_name} init") latte_model = STDiT3_XL_2(**model_extra_args) + elif args.model_version == "llama3_1b": + model_name = "Llama3-1B" + latte_model = STDiTLlama3Wrapper(model_size="1B", **model_extra_args) + elif args.model_version == "llama3_5b": + model_name = "Llama3-5B" + latte_model = STDiTLlama3Wrapper(model_size="5B", **model_extra_args) else: raise ValueError(f"Unknown model version: {args.model_version}") @@ -280,10 +295,13 @@ def main(args): logger.warning(f"{model_name} uses random initialization!") # 2.3 text encoder - if args.text_embed_folder is None: + if not args.text_embed_folder and not (args.ul2_text_embed_folder and args.byt5_text_embed_folder): + if args.model_version in ["llama3_1b", "llama3_5b"]: + raise ValueError("UL2 and ByT5 text embedding folders are required for MovieGen.") text_encoder, tokenizer = get_text_encoder_and_tokenizer( "t5", args.t5_model_dir, model_max_length=args.model_max_length ) + captions = captions[start_idx:end_idx] num_prompts = len(captions) text_tokens, mask = zip( *[text_encoder.get_text_tokens_and_mask(caption, return_tensor=False) for caption in captions] @@ -301,28 +319,44 @@ def main(args): ) logger.info(f"Num tokens: {mask.asnumpy().sum(2)}") else: - assert not args.use_parallel, "parallel inference is not supported for t5 cached sampling currently." if args.model_version != "v1": logger.warning("For embedded captions, only one prompt per video is supported at this moment.") - embed_paths = sorted(glob.glob(os.path.join(args.text_embed_folder, "*.npz"))) - prompt_prefix = [] - text_tokens, mask, text_emb = [], [], [] - for fp in embed_paths: - prompt_prefix.append(os.path.basename(fp)[:-4]) - dat = np.load(fp) - text_tokens.append(dat["tokens"]) - mask.append(dat["mask"]) - text_emb.append(dat["text_emb"]) - text_tokens = np.concatenate(text_tokens) - mask = np.concatenate(mask) - text_emb = np.concatenate(text_emb) - logger.info(f"Num tokens: {mask.sum(1)}") + extra_embed_paths1 = None + if args.text_embed_folder: + assert args.model_version not in [ + "llama3_1b", + "llama3_5b", + ], "UL2 and ByT5 text embedding folders are required for MovieGen." 
+ main_embed_paths = sorted(glob.glob(os.path.join(args.text_embed_folder, "*.npz")))[start_idx:end_idx] + elif args.ul2_text_embed_folder and args.byt5_text_embed_folder: + main_embed_paths = sorted(glob.glob(os.path.join(args.ul2_text_embed_folder, "*.npz")))[start_idx:end_idx] + extra_embed_paths1 = sorted(glob.glob(os.path.join(args.byt5_text_embed_folder, "*.npz")))[ + start_idx:end_idx + ] + else: + raise NotImplementedError("T5 or UL2 and ByT5 text embedding should be provided.") + + def read_embeddings(embed_paths): + prefix = [] + _mask, _text_emb = [], [] + for fp in embed_paths: + prefix.append(os.path.basename(fp)[:-4]) + with np.load(fp) as dat: + _mask.append(dat["mask"]) + _text_emb.append(dat["text_emb"]) + return ( + ms.Tensor(np.concatenate(_mask), dtype=ms.uint8), + ms.Tensor(np.concatenate(_text_emb), dtype=ms.float32), + prefix, + ) + mask, text_emb, prompt_prefix = read_embeddings(main_embed_paths) + extra_mask1, extra_text_emb1, _ = ( + read_embeddings(extra_embed_paths1) if extra_embed_paths1 else (None, None, None) + ) + logger.info(f"Num tokens: {mask.sum(1)}") num_prompts = text_emb.shape[0] - text_tokens = ms.Tensor(text_tokens) - mask = ms.Tensor(mask, dtype=ms.uint8) - text_emb = ms.Tensor(text_emb, dtype=ms.float32) text_encoder = None if (args.model_version == "v1" or args.reference_path is None) and num_prompts < 1: @@ -457,6 +491,9 @@ def main(args): inputs["text_tokens"] = None inputs["text_emb"] = text_emb[i : i + ns] inputs["mask"] = mask[i : i + ns] + if extra_text_emb1 is not None: + model_args["extra_text_embed1"] = extra_text_emb1[i : i + ns] + model_args["extra_mask1"] = extra_mask1[i : i + ns] logger.info("Sampling captions:") for j in range(ns): @@ -489,13 +526,13 @@ def main(args): # save result for j in range(ns): - global_idx = base_data_idx + i + j - if args.text_embed_folder is None: + if not args.text_embed_folder and not (args.ul2_text_embed_folder and args.byt5_text_embed_folder): + global_idx = base_data_idx + i + j prompt = "-".join((batch_prompts[j][0].replace("/", "").split(" ")[:10])) save_fp = f"{save_dir}/{global_idx:03d}-{prompt}.{args.save_format}" latent_save_fp = f"{latent_dir}/{global_idx:03d}-{prompt}.npy" else: - fn = prompt_prefix[global_idx] + fn = prompt_prefix[i + j] save_fp = f"{save_dir}/{fn}.{args.save_format}" latent_save_fp = f"{latent_dir}/{fn}.npy" @@ -520,7 +557,11 @@ def parse_args(): help="path to load a config yaml file that describes the setting which will override the default arguments", ) parser.add_argument( - "--model_version", default="v1", type=str, choices=["v1", "v1.1", "v1.2"], help="OpenSora model version." 
+        "--model_version",
+        default="v1",
+        type=str,
+        choices=["v1", "v1.1", "v1.2", "llama3_1b", "llama3_5b"],
+        help="OpenSora model version.",
     )
     parser.add_argument("--image_size", type=int, nargs="+", help="image size in [256, 512]")
     parser.add_argument("--resolution", type=str, help=f"Supported video resolutions: {list(ASPECT_RATIOS.keys())}")
@@ -696,6 +737,8 @@ def parse_args():
     parser.add_argument("--fps", type=int, default=8, help="FPS in the saved video")
     parser.add_argument("--batch_size", default=4, type=int, help="infer batch size")
     parser.add_argument("--text_embed_folder", type=str, default=None, help="path to t5 embedding")
+    parser.add_argument("--ul2_text_embed_folder", type=str, help="path to ul2 embedding")
+    parser.add_argument("--byt5_text_embed_folder", type=str, help="path to byt5 embedding")
     parser.add_argument(
         "--save_latent",
         type=str2bool,
diff --git a/examples/opensora_hpcai/scripts/train.py b/examples/opensora_hpcai/scripts/train.py
index 1b4146c666..af14e1a155 100644
--- a/examples/opensora_hpcai/scripts/train.py
+++ b/examples/opensora_hpcai/scripts/train.py
@@ -22,11 +22,13 @@
 mindone_lib_path = os.path.abspath(os.path.join(__dir__, "../../../"))
 sys.path.insert(0, mindone_lib_path)
 sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "..")))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../moviegen")))
+
 from args_train import parse_args
 from opensora.acceleration.parallel_states import create_parallel_group
 from opensora.datasets.aspect import ASPECT_RATIOS, get_image_size
 from opensora.models.layers.operation_selector import set_dynamic_mode
-from opensora.models.stdit import STDiT2_XL_2, STDiT3_XL_2, STDiT_XL_2
+from opensora.models.stdit import STDiT2_XL_2, STDiT3_XL_2, STDiT_XL_2, STDiTLlama3Wrapper
 from opensora.models.vae.vae import SD_CONFIG, OpenSoraVAE_V1_2, VideoAutoencoderKL
 from opensora.pipelines import (
     DiffusionWithLoss,
@@ -69,7 +71,9 @@ def init_env(
     global_bf16: bool = False,
     dynamic_shape: bool = False,
     enable_sequence_parallelism: bool = False,
+    enable_model_parallelism: bool = False,
     sequence_parallel_shards: int = 1,
+    model_parallel_shards: int = 1,
     debug: bool = False,
 ) -> Tuple[int, int]:
     """
@@ -84,12 +88,16 @@
     """
     set_random_seed(seed)
 
-    if enable_sequence_parallelism:
+    if enable_sequence_parallelism or enable_model_parallelism:
         if parallel_mode != "data" or not distributed:
             raise ValueError(
-                "sequence parallel can only be used in data parallel mode, "
+                "sequence parallel / tensor parallel can only be used in data parallel mode, "
                 f"but get parallel_mode=`{parallel_mode}` with distributed=`{distributed}`."
             )
+        if enable_sequence_parallelism and enable_model_parallelism:
+            raise ValueError(
+                "Cannot turn on sequence parallel (non-Llama structure) and model parallel (Llama structure) at the same time."
+ ) if debug and mode == ms.GRAPH_MODE: # force PyNative mode when debugging logger.warning("Debug mode is on, switching execution mode to PyNative.") @@ -126,9 +134,12 @@ def init_env( ) if enable_sequence_parallelism: - create_parallel_group(sequence_parallel_shards) + create_parallel_group(sequence_parallel_shards=sequence_parallel_shards) ms.set_auto_parallel_context(enable_alltoall=True) + if enable_model_parallelism: + create_parallel_group(model_parallel_shards=model_parallel_shards) + var_info = ["device_num", "rank_id", "device_num / 8", "rank_id / 8"] var_value = [device_num, rank_id, int(device_num / 8), int(rank_id / 8)] logger.info(dict(zip(var_info, var_value))) @@ -186,7 +197,6 @@ def initialize_dataset( args, csv_path, video_folder, - text_embed_folder, vae_latent_folder, batch_size, img_h, @@ -204,7 +214,7 @@ def initialize_dataset( ds_config = dict( csv_path=csv_path, video_folder=video_folder, - text_emb_folder=text_embed_folder, + text_emb_folder=args.text_embed_folder, return_text_emb=True, vae_latent_folder=vae_latent_folder, return_vae_latent=args.train_with_vae_latent, @@ -255,6 +265,24 @@ def initialize_dataset( if args.pre_patchify: output_columns.extend(["spatial_pos", "spatial_mask", "temporal_pos", "temporal_mask"]) + text_embed_folder = {} + if args.text_embed_folder: + text_embed_folder["t5"] = args.text_embed_folder + if args.ul2_text_embed_folder: + text_embed_folder["ul2"] = args.ul2_text_embed_folder + if args.byt5_text_embed_folder: + text_embed_folder["byt5"] = args.byt5_text_embed_folder + + if not len(text_embed_folder): + text_embed_folder = None + elif len(text_embed_folder) == 1: + text_embed_folder = list(text_embed_folder.values())[0] + else: + # FIXME: hardcoding + output_columns[1] = "ul2_caption" + output_columns[2] = "ul2_mask" + output_columns.extend(["byt5_caption", "byt5_mask"]) + datasets = [ VideoDatasetRefactored( csv_path=csv_path, @@ -359,7 +387,9 @@ def main(args): global_bf16=args.global_bf16, dynamic_shape=(args.bucket_config is not None), enable_sequence_parallelism=args.enable_sequence_parallelism, + enable_model_parallelism=args.enable_model_parallelism, sequence_parallel_shards=args.sequence_parallel_shards, + model_parallel_shards=args.model_parallel_shards, debug=args.debug, ) set_logger(name="", output_dir=args.output_path, rank=rank_id, log_level=eval(args.log_level)) @@ -430,6 +460,7 @@ def main(args): manual_pad=args.manual_pad, enable_flashattn=args.enable_flash_attention, enable_sequence_parallelism=args.enable_sequence_parallelism, + enable_model_parallelism=args.enable_model_parallelism, use_recompute=args.use_recompute, num_recompute_blocks=args.num_recompute_blocks, ) @@ -455,6 +486,12 @@ def main(args): model_extra_args["qk_norm"] = True model_extra_args["freeze_y_embedder"] = args.freeze_y_embedder latte_model = STDiT3_XL_2(**model_extra_args) + elif args.model_version == "llama3_1b": + model_name = "Llama3-1B" + latte_model = STDiTLlama3Wrapper(model_size="1B", **model_extra_args) + elif args.model_version == "llama3_5b": + model_name = "Llama3-5B" + latte_model = STDiTLlama3Wrapper(model_size="5B", **model_extra_args) else: raise ValueError(f"Unknown model version: {args.model_version}") logger.info(f"{model_name} input size: {latent_size if args.bucket_config is None else 'Variable'}") @@ -545,6 +582,10 @@ def main(args): data_device_num = device_num // args.sequence_parallel_shards data_rank_id = rank_id // args.sequence_parallel_shards logger.info(f"Creating dataloader: ID={rank_id}, group={data_rank_id}, 
num_groups={data_device_num}") + elif args.enable_model_parallelism: + data_device_num = device_num // args.model_parallel_shards + data_rank_id = rank_id // args.model_parallel_shards + logger.info(f"Creating dataloader: ID={rank_id}, group={data_rank_id}, num_groups={data_device_num}") else: data_device_num = device_num data_rank_id = rank_id @@ -553,7 +594,6 @@ def main(args): args, args.csv_path, args.video_folder, - args.text_embed_folder, args.vae_latent_folder, args.batch_size, img_h, @@ -747,7 +787,10 @@ def main(args): logger.info( "As steps per epoch are inaccurate with bucket config, TimeMonitor is disabled. See result.log for the actual step time" ) - if rank_id == 0: + if rank_id == 0 or args.enable_model_parallelism: + if args.enable_model_parallelism: + ckpt_dir = os.path.join(ckpt_dir, f"rank_{rank_id}") + save_cb = EvalSaveCallback( network=latent_diffusion_with_loss.network, rank_id=rank_id, @@ -766,6 +809,8 @@ def main(args): record_lr=False, train_steps=args.train_steps, ) + + if rank_id == 0: rec_cb = PerfRecorderCallback( save_dir=args.output_path, file_name="result_val.log", diff --git a/mindone/trainers/zero.py b/mindone/trainers/zero.py index 478904a149..8d5e1c5353 100644 --- a/mindone/trainers/zero.py +++ b/mindone/trainers/zero.py @@ -587,7 +587,19 @@ def prepare_train_network( is_parallel = _get_parallel_mode() == ParallelMode.DATA_PARALLEL if not is_parallel and zero_stage == 0: _logger.info("No need prepare train_network with zero.") - return network, optimizer + train_network = TrainOneStepWrapper( + network, + optimizer, + scale_sense=scale_sense, + ema=ema, + updates=updates, + drop_overflow_update=drop_overflow_update, + gradient_accumulation_steps=gradient_accumulation_steps, + clip_grad=clip_grad, + clip_norm=clip_norm, + verbose=verbose, + ) + return train_network if zero_stage not in [0, 1, 2, 3]: raise ValueError("Not support zero_stage {zero_stage}") diff --git a/mindone/transformers/modeling_utils.py b/mindone/transformers/modeling_utils.py index b327d8709b..7d413cf204 100644 --- a/mindone/transformers/modeling_utils.py +++ b/mindone/transformers/modeling_utils.py @@ -1301,7 +1301,7 @@ def from_pretrained( state_dict = kwargs.pop("state_dict", None) from_tf = kwargs.pop("from_tf", False) from_flax = kwargs.pop("from_flax", False) - resume_download = kwargs.pop("resume_download", False) + resume_download = kwargs.pop("resume_download", None) proxies = kwargs.pop("proxies", None) output_loading_info = kwargs.pop("output_loading_info", False) use_auth_token = kwargs.pop("use_auth_token", None) diff --git a/mindone/transformers/models/t5/modeling_t5.py b/mindone/transformers/models/t5/modeling_t5.py index 0a72dc6374..2326c3446b 100644 --- a/mindone/transformers/models/t5/modeling_t5.py +++ b/mindone/transformers/models/t5/modeling_t5.py @@ -1072,7 +1072,7 @@ def construct( use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = False, + return_dict: Optional[bool] = None, ) -> Union[Tuple[ms.Tensor], Seq2SeqLMOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -1099,6 +1099,7 @@ def construct( >>> logits = outputs[1] ```""" use_cache = use_cache if use_cache is not None else self.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict # Encode if needed (training, first prediction pass) if encoder_outputs is None: From 5b77353bc7f438ef854b0cb7f2704f819266bc80 Mon Sep 17 00:00:00 2001 
From: Mike Cheung Date: Tue, 29 Oct 2024 11:39:51 +0800 Subject: [PATCH 009/122] add parallel test case for scheduler and fix some minor bug --- .../moviegen/parallel/parallel_states.py | 5 +- .../moviegen/pipelines/train_pipeline.py | 10 ++-- .../moviegen/schedulers/rectified_flow.py | 18 +++++- .../tests/parallel/run_test_rflow_parallel.sh | 13 ++++ .../tests/parallel/test_rflow_parallel.py | 59 +++++++++++++++++++ 5 files changed, 94 insertions(+), 11 deletions(-) create mode 100755 examples/moviegen/tests/parallel/run_test_rflow_parallel.sh create mode 100644 examples/moviegen/tests/parallel/test_rflow_parallel.py diff --git a/examples/moviegen/moviegen/parallel/parallel_states.py b/examples/moviegen/moviegen/parallel/parallel_states.py index 3effa239c3..2a8d9c0a0c 100644 --- a/examples/moviegen/moviegen/parallel/parallel_states.py +++ b/examples/moviegen/moviegen/parallel/parallel_states.py @@ -1,6 +1,6 @@ from typing import Optional -from mindspore.communication import GlobalComm, create_group, get_group_size, get_rank +from mindspore.communication import create_group, get_group_size, get_rank __all__ = ["set_model_parallel_group", "get_model_parallel_group", "create_parallel_group"] @@ -13,8 +13,7 @@ def set_model_parallel_group(group: str) -> None: def get_model_parallel_group() -> Optional[str]: - # TODO: change the default value to be None - return _GLOBAL_PARALLEL_GROUPS.get("model", GlobalComm.WORLD_COMM_GROUP) + return _GLOBAL_PARALLEL_GROUPS.get("model", None) def create_parallel_group(model_parallel_shards: int = 1) -> None: diff --git a/examples/moviegen/moviegen/pipelines/train_pipeline.py b/examples/moviegen/moviegen/pipelines/train_pipeline.py index 7327039581..d1c2d83260 100644 --- a/examples/moviegen/moviegen/pipelines/train_pipeline.py +++ b/examples/moviegen/moviegen/pipelines/train_pipeline.py @@ -2,7 +2,7 @@ from mindspore import Tensor, nn, ops -from ..schedulers.rectified_flow import RFlowScheduler +from ..schedulers import RFlowLossWrapper __all__ = ["DiffusionWithLoss"] @@ -10,8 +10,7 @@ class DiffusionWithLoss(nn.Cell): def __init__( self, - network: nn.Cell, - scheduler: RFlowScheduler, + network: RFlowLossWrapper, vae: Optional[nn.Cell] = None, text_encoder: Optional[nn.Cell] = None, scale_factor: float = 0.18215, @@ -27,7 +26,6 @@ def __init__( self.network = network self.vae = vae - self.scheduler = scheduler self.text_encoder = text_encoder self.scale_factor = scale_factor self.text_emb_cached = text_emb_cached @@ -50,10 +48,10 @@ def get_condition_embeddings(self, text_tokens: Tensor) -> Tensor: def get_latents(self, video_tokens: Tensor) -> Tensor: if self.video_emb_cached: return video_tokens - video_emb = ops.stop_gradient(self.vae.encode(video_tokens)) + video_emb = ops.stop_gradient(self.vae.encode(video_tokens) * self.scale_factor) return video_emb def construct(self, video_tokens: Tensor, text_tokens: Tensor) -> Tensor: latent_embedding = self.get_latents(video_tokens) text_embedding = self.get_condition_embeddings(text_tokens) - return self.scheduler.training_loss(self.network, latent_embedding, text_embedding) + return self.network(latent_embedding, text_embedding) diff --git a/examples/moviegen/moviegen/schedulers/rectified_flow.py b/examples/moviegen/moviegen/schedulers/rectified_flow.py index b466487608..234a86f63e 100644 --- a/examples/moviegen/moviegen/schedulers/rectified_flow.py +++ b/examples/moviegen/moviegen/schedulers/rectified_flow.py @@ -7,8 +7,10 @@ import mindspore as ms import mindspore.mint.nn.functional as F from mindspore 
import Tensor, mint, nn, ops +from mindspore.communication import get_rank from ..models import LlamaModel +from ..parallel import get_model_parallel_group logger = logging.getLogger(__name__) @@ -102,6 +104,13 @@ def __init__( self.model = model self.criteria = nn.MSELoss() + self.mp_group = get_model_parallel_group() + if self.mp_group is not None: + logging.info( + f"Broadcasting all random variables from rank (0) to current rank ({get_rank(self.mp_group)}) in group `{self.mp_group}`." + ) + self.broadcast = ops.Broadcast(0, group=self.mp_group) + def _discrete_sample(self, size: int) -> Tensor: return ops.randint(0, self.num_timesteps, (size,), dtype=ms.int64) @@ -111,6 +120,11 @@ def _uniform_sample(self, size: int) -> Tensor: def _logit_normal_sample(self, size: int) -> Tensor: return self.distribution((size, 1)) * self.num_timesteps + def _broadcast(self, x: Tensor) -> Tensor: + if self.mp_group is None: + return x + return self.broadcast((x,))[0] + def construct(self, x: Tensor, text_embedding: Tensor, timestep: Optional[Tensor] = None) -> Tensor: """Calculate the training loss for the corresponding timestep. x: (N, T, C, H, W) tensor of inputs (latent representations of video) @@ -120,9 +134,9 @@ def construct(self, x: Tensor, text_embedding: Tensor, timestep: Optional[Tensor x = x.to(ms.float32) if timestep is None: - timestep = self._sample_func(x.shape[0]) + timestep = self._broadcast(self._sample_func(x.shape[0])) - noise = mint.normal(size=x.shape) + noise = self._broadcast(mint.normal(size=x.shape)) x_t = self.add_noise(x, noise, timestep) model_output = self.model(x_t.to(self.model.dtype), timestep, text_embedding.to(self.model.dtype)).to( diff --git a/examples/moviegen/tests/parallel/run_test_rflow_parallel.sh b/examples/moviegen/tests/parallel/run_test_rflow_parallel.sh new file mode 100755 index 0000000000..88ad571cac --- /dev/null +++ b/examples/moviegen/tests/parallel/run_test_rflow_parallel.sh @@ -0,0 +1,13 @@ +#!/bin/sh + +SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" +PROJECT_DIR="$(dirname $(dirname "${SCRIPT_DIR}"))" +EXAMPLE_DIR="$(dirname "${PROJECT_DIR}")" +PACKAGE_DIR="$(dirname "${EXAMPLE_DIR}")" + +export PYTHONPATH="${PROJECT_DIR}:${PACKAGE_DIR}:${PYTHONPATH}" + +LOGDIR="./log_test_rflow_parallel_graph" +echo "Graph Mode:" +msrun --master_port=1234 --worker_num=2 --local_worker_num=2 --log_dir=$LOGDIR --join True ${SCRIPT_DIR}/test_rflow_parallel.py --mode 0 +echo "Done. Check the log at '$LOGDIR'." 
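For context, the rank-0 broadcast added to RFlowLossWrapper above only takes effect once a model-parallel group has been registered beforehand; the sketch below shows the expected setup order with a placeholder backbone (the test file added next follows the same pattern) and is illustrative rather than part of the patch.

    # Illustrative sketch; PlaceholderNet stands in for a real LlamaModel backbone.
    # Run under msrun/mpirun so that communication can be initialized.
    import mindspore as ms
    import mindspore.nn as nn
    from mindspore import Tensor
    from mindspore.communication import get_group_size, init
    from moviegen.parallel import create_parallel_group
    from moviegen.schedulers import RFlowLossWrapper

    class PlaceholderNet(nn.Cell):
        @property
        def dtype(self):
            return ms.float32

        def construct(self, x: Tensor, timestep: Tensor, text_embedding: Tensor) -> Tensor:
            return x.to(ms.float32)

    init()
    create_parallel_group(model_parallel_shards=get_group_size())  # register the MP group first
    loss_fn = RFlowLossWrapper(PlaceholderNet())  # sampled timesteps/noise are now broadcast from group rank 0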
diff --git a/examples/moviegen/tests/parallel/test_rflow_parallel.py b/examples/moviegen/tests/parallel/test_rflow_parallel.py new file mode 100644 index 0000000000..6ffd4b254b --- /dev/null +++ b/examples/moviegen/tests/parallel/test_rflow_parallel.py @@ -0,0 +1,59 @@ +import argparse +from typing import Tuple + +from moviegen.parallel import create_parallel_group +from moviegen.schedulers import RFlowLossWrapper + +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore import Tensor +from mindspore.communication import get_group_size, init + +from mindone.utils.seed import set_random_seed + + +class SimpleNet(nn.Cell): + def construct(self, x: Tensor, timestamp: Tensor, text_embedding: Tensor): + return x.to(ms.float32) + + @property + def dtype(self): + return ms.float32 + + +def get_sample_data(dtype: ms.Type = ms.float32) -> Tuple[Tensor, Tensor]: + latent_embedding = ops.rand([1, 16, 8, 24, 44], dtype=dtype) + text_embedding = ops.rand([1, 64, 4096], dtype=dtype) + return latent_embedding, text_embedding + + +def run_network(mode: int = 0): + ms.set_context(mode=mode) + init() + + # prepare data + set_random_seed(1024) + data = get_sample_data() + + # prepare group + create_parallel_group(model_parallel_shards=get_group_size()) + + model = SimpleNet() + + # parallel netowrk + network = RFlowLossWrapper(model) + + loss = network(*data) + loss = ops.AllGather()(ops.unsqueeze(loss, 0)).asnumpy() + assert loss[0] == loss[1], f"expected two elements to be same, but get `{loss}`." + print("Test 1: Passed.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--mode", default=0, type=int, choices=[0, 1], help="Mode to test. (0: Graph Mode; 1: Pynative mode)" + ) + args = parser.parse_args() + run_network(mode=args.mode) From c9bb319de8e24a2249b6d5128d67397ca8259071 Mon Sep 17 00:00:00 2001 From: Mike Cheung Date: Tue, 29 Oct 2024 14:16:39 +0800 Subject: [PATCH 010/122] add train script --- .../configs/train/moviegen-256x256-t2i.yaml | 27 ++ .../moviegen/moviegen/dataset/__init__.py | 2 + examples/moviegen/moviegen/dataset/image.py | 69 ++++ examples/moviegen/moviegen/dataset/video.py | 18 ++ .../moviegen/pipelines/train_pipeline.py | 2 +- examples/moviegen/moviegen/utils/__init__.py | 3 + examples/moviegen/moviegen/utils/callback.py | 137 ++++++++ examples/moviegen/moviegen/utils/misc.py | 56 ++++ .../moviegen/moviegen/utils/model_utils.py | 32 ++ examples/moviegen/train.py | 299 ++++++++++++++++++ 10 files changed, 644 insertions(+), 1 deletion(-) create mode 100644 examples/moviegen/moviegen/configs/train/moviegen-256x256-t2i.yaml create mode 100644 examples/moviegen/moviegen/dataset/__init__.py create mode 100644 examples/moviegen/moviegen/dataset/image.py create mode 100644 examples/moviegen/moviegen/dataset/video.py create mode 100644 examples/moviegen/moviegen/utils/__init__.py create mode 100644 examples/moviegen/moviegen/utils/callback.py create mode 100644 examples/moviegen/moviegen/utils/misc.py create mode 100644 examples/moviegen/moviegen/utils/model_utils.py create mode 100644 examples/moviegen/train.py diff --git a/examples/moviegen/moviegen/configs/train/moviegen-256x256-t2i.yaml b/examples/moviegen/moviegen/configs/train/moviegen-256x256-t2i.yaml new file mode 100644 index 0000000000..219d29453b --- /dev/null +++ b/examples/moviegen/moviegen/configs/train/moviegen-256x256-t2i.yaml @@ -0,0 +1,27 @@ +# model +model_version: llama-1B +batch_size: 64 +checkpoint: 
"models/PixArt-Sigma-XL-2-256x256.ckpt" +vae_root: "models/vae" +text_encoder_root: "models/text_encoder" +tokenizer_root: "models/tokenizer" +scale_factor: 0.13025 +enable_flash_attention: True +dtype: "bf16" + +# training hyper-parameters +epochs: 100 +scheduler: "constant" +start_learning_rate: 1.0e-4 +optim: "adamw" +weight_decay: 0.1 +loss_scaler_type: "static" +init_loss_scale: 1.0 +gradient_accumulation_steps: 1 +clip_grad: True +max_grad_norm: 1.0 +ckpt_save_interval: 1 +log_loss_interval: 1 +recompute: True +text_drop_prob: 0.2 +warmup_steps: 2000 diff --git a/examples/moviegen/moviegen/dataset/__init__.py b/examples/moviegen/moviegen/dataset/__init__.py new file mode 100644 index 0000000000..31b6c8daed --- /dev/null +++ b/examples/moviegen/moviegen/dataset/__init__.py @@ -0,0 +1,2 @@ +from .image import ImageDataset +from .video import VideoDataset diff --git a/examples/moviegen/moviegen/dataset/image.py b/examples/moviegen/moviegen/dataset/image.py new file mode 100644 index 0000000000..88b4610e99 --- /dev/null +++ b/examples/moviegen/moviegen/dataset/image.py @@ -0,0 +1,69 @@ +import json +import logging +import os +import random +from typing import Tuple + +import numpy as np +from PIL import Image +from transformers import AutoTokenizer + +from mindspore.dataset.transforms import Compose, vision + +logger = logging.getLogger(__name__) + + +class ImageDataset: + def __init__( + self, + json_path: str, + image_dir: str, + image_size: int, + tokenizer: AutoTokenizer, + text_drop_prob: float = 0.2, + ) -> None: + logger.info(f"loading annotations from `{json_path}`.") + with open(json_path, "r") as f: + self.dataset = json.load(f) + + self.length = len(self.dataset) + + self.image_dir = image_dir + self.tokenizer = tokenizer + self.text_drop_prob = text_drop_prob + self.interpolation_mode = vision.Inter.BILINEAR + self.transform = self.create_transform(image_size, self.interpolation_mode) + + def __len__(self) -> int: + return self.length + + def __getitem__(self, idx: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + record = self.dataset[idx] + image_path = os.path.join(self.image_dir, record["path"]) + + if random.random() < self.text_drop_prob: + text = "" + else: + text = record["prompt"] + + # process text + encoding = self.tokenizer(text, padding="max_length", truncation=True, return_tensors="np") + text_ids = encoding.input_ids[0] + + # process image + image = Image.open(image_path).convert("RGB") + + image = self.transform(image)[0] + image = np.expand_dims(image, axis=0) # 1, C, H, W + return image, text_ids + + @staticmethod + def create_transform(image_size: int, interpolation: vision.Inter) -> Compose: + return Compose( + [ + vision.Resize(image_size, interpolation=interpolation), + vision.CenterCrop(image_size), + vision.ToTensor(), + vision.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], is_hwc=False), + ] + ) diff --git a/examples/moviegen/moviegen/dataset/video.py b/examples/moviegen/moviegen/dataset/video.py new file mode 100644 index 0000000000..50e49d9128 --- /dev/null +++ b/examples/moviegen/moviegen/dataset/video.py @@ -0,0 +1,18 @@ +from typing import Tuple + +import numpy as np + + +class VideoDataset: + def __len__(self) -> int: + return NotImplementedError() + + def __getitem__(self, idx: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Returns: + video/video caching + text embedding 1 + text embedding 1 + text embedding 1 + """ + raise NotImplementedError() diff --git a/examples/moviegen/moviegen/pipelines/train_pipeline.py 
b/examples/moviegen/moviegen/pipelines/train_pipeline.py index d1c2d83260..17ca52b490 100644 --- a/examples/moviegen/moviegen/pipelines/train_pipeline.py +++ b/examples/moviegen/moviegen/pipelines/train_pipeline.py @@ -13,7 +13,7 @@ def __init__( network: RFlowLossWrapper, vae: Optional[nn.Cell] = None, text_encoder: Optional[nn.Cell] = None, - scale_factor: float = 0.18215, + scale_factor: float = 0.13025, text_emb_cached: bool = True, video_emb_cached: bool = False, ): diff --git a/examples/moviegen/moviegen/utils/__init__.py b/examples/moviegen/moviegen/utils/__init__.py new file mode 100644 index 0000000000..d980a5b445 --- /dev/null +++ b/examples/moviegen/moviegen/utils/__init__.py @@ -0,0 +1,3 @@ +from .callback import * +from .misc import * +from .model_utils import * diff --git a/examples/moviegen/moviegen/utils/callback.py b/examples/moviegen/moviegen/utils/callback.py new file mode 100644 index 0000000000..0578deb7b7 --- /dev/null +++ b/examples/moviegen/moviegen/utils/callback.py @@ -0,0 +1,137 @@ +import logging +import os +import time +from typing import List, Optional + +import numpy as np + +import mindspore.ops as ops +from mindspore import Tensor +from mindspore.train import Callback, RunContext + +from mindone.trainers.checkpoint import CheckpointManager + +__all__ = ["LossMonitor", "SaveCkptCallback", "TimeMonitor"] + +logger = logging.getLogger(__name__) + + +class LossMonitor(Callback): + def __init__(self, log_interval: int = 1, log_overflow: bool = True) -> None: + self.log_interval = log_interval + self.log_overflow = log_overflow + self.step_num = 0 + + def on_train_step_begin(self, run_context: RunContext) -> None: + self.step_num += 1 + + def on_train_epoch_end(self, run_context: RunContext) -> None: + self.step_num = 0 + + def on_train_step_end(self, run_context: RunContext) -> None: + cb_params = run_context.original_args() + cur_step = cb_params.cur_step_num + + if cur_step % self.log_interval == 0: + cur_lr = self._fetch_optimizer_lr(cb_params) + cur_loss = self._fetch_loss(cb_params) + cur_loss_scale = self._fetch_loss_scale(cb_params) + + logger.info( + "epoch: %d step: %d, lr: %.7f, loss: %.6f, loss scale: %d.", + cb_params.cur_epoch_num, + self.step_num, + cur_lr.item(), + cur_loss.item(), + cur_loss_scale.item(), + ) + + if self.log_overflow: + overflow = cb_params.net_outputs[1] + if overflow: + logger.warning(f"overflow detected in epoch {cb_params.cur_epoch_num} step {self.step_num}.") + + def _get_optimizer_from_cbp(self, cb_params): + if cb_params.optimizer is not None: + optimizer = cb_params.optimizer + elif cb_params.dataset_sink_mode: + optimizer = cb_params.train_network.network.optimizer + else: + optimizer = cb_params.train_network.optimizer + return optimizer + + def _fetch_loss_scale(self, cb_params) -> Tensor: + if cb_params.dataset_sink_mode: + return cb_params.train_network.network.scale_sense + else: + return cb_params.train_network.scale_sense + + def _fetch_optimizer_lr(self, cb_params) -> Tensor: + opt = self._get_optimizer_from_cbp(cb_params) + lr = opt.learning_rate + if opt.dynamic_lr: + lr = opt.learning_rate(ops.clip(opt.global_step - 1, min=0))[0] + return lr + + def _fetch_loss(self, cb_params) -> Tensor: + loss = cb_params.net_outputs[0] + return loss + + +class SaveCkptCallback(Callback): + def __init__( + self, + output_dir: str = "./output", + ckpt_max_keep: int = 5, + ckpt_save_interval: int = 1, + rank_id: Optional[int] = None, + ) -> None: + self.rank_id = 0 if rank_id is None else rank_id + if self.rank_id != 0: + 
return + + self.ckpt_save_interval = ckpt_save_interval + + ckpt_save_dir = os.path.join(output_dir, f"rank_{rank_id}") + if not os.path.isdir(ckpt_save_dir): + os.makedirs(ckpt_save_dir) + self.ckpt_manager = CheckpointManager(ckpt_save_dir, ckpt_save_policy="latest_k", k=ckpt_max_keep) + + def on_train_epoch_end(self, run_context: RunContext) -> None: + if self.rank_id != 0: + return + + cb_params = run_context.original_args() + cur_epoch = cb_params.cur_epoch_num + epoch_num = cb_params.epoch_num + + if cur_epoch % self.ckpt_save_interval != 0 and cur_epoch != epoch_num: + return + + ckpt_name = f"epoch_{cur_epoch}.ckpt" + network = cb_params.train_network.network + self.ckpt_manager.save(network=network.trainable_params(), ckpt_name=ckpt_name) + + +class TimeMonitor(Callback): + def __init__(self) -> None: + self.epoch_start_time = 0 + self.step_start_time = 0 + self.durations: List[int] = list() + + def on_train_epoch_begin(self, run_context: RunContext) -> None: + self.epoch_start_time = time.time() + + def on_train_step_begin(self, run_context: RunContext) -> None: + self.step_start_time = time.time() + + def on_train_step_end(self, run_context: RunContext) -> None: + duration = time.time() - self.step_start_time + self.durations.append(duration) + + def on_train_epoch_end(self, run_context: RunContext) -> None: + epoch_duration = time.time() - self.epoch_start_time + avg_time = np.mean(self.durations) + self.durations = list() + logger.info(f"Total training time for single epoch: {epoch_duration:.3f} seconds") + logger.info(f"Average step time: {avg_time:.3f} seconds") diff --git a/examples/moviegen/moviegen/utils/misc.py b/examples/moviegen/moviegen/utils/misc.py new file mode 100644 index 0000000000..9d1f56984e --- /dev/null +++ b/examples/moviegen/moviegen/utils/misc.py @@ -0,0 +1,56 @@ +import argparse +import logging +from typing import Tuple + +from moviegen.models import llama3_1B, llama3_5B, llama3_30B + +import mindspore as ms +from mindspore.communication import get_group_size, get_rank, init + +from mindone.utils.seed import set_random_seed + +__all__ = ["MODEL_SPEC", "MODEL_DTYPE", "str2bool", "check_cfgs_in_parser", "init_env"] + + +logger = logging.getLogger(__name__) + + +MODEL_SPEC = {"llama-1B": llama3_1B, "llama-5B": llama3_5B, "llama-30B": llama3_30B} + +MODEL_DTYPE = { + "fp32": ms.float32, + "fp16": ms.float16, + "bf16": ms.bfloat16, +} + + +def str2bool(b: str) -> bool: + if b.lower() not in ["false", "true"]: + raise Exception("Invalid Bool Value") + if b.lower() in ["false"]: + return False + return True + + +def check_cfgs_in_parser(cfgs: dict, parser: argparse.ArgumentParser) -> None: + actions_dest = [action.dest for action in parser._actions] + defaults_key = parser._defaults.keys() + for k in cfgs.keys(): + if k not in actions_dest and k not in defaults_key: + raise KeyError(f"{k} does not exist in ArgumentParser!") + + +def init_env(args) -> Tuple[int, int]: + set_random_seed(args.seed) + ms.set_context(mode=args.mode, device_target=args.device_target, jit_config=dict(jit_level=args.jit_level)) + if args.use_parallel: + init() + device_num = get_group_size() + rank_id = get_rank() + ms.set_auto_parallel_context( + parallel_mode=ms.ParallelMode.DATA_PARALLEL, gradients_mean=True, device_num=device_num + ) + else: + device_num, rank_id = 1, 0 + + return device_num, rank_id diff --git a/examples/moviegen/moviegen/utils/model_utils.py b/examples/moviegen/moviegen/utils/model_utils.py new file mode 100644 index 0000000000..d70d940b5d --- /dev/null +++ 
b/examples/moviegen/moviegen/utils/model_utils.py @@ -0,0 +1,32 @@ +import logging +from typing import Dict, Tuple, Union + +import mindspore as ms +import mindspore.nn as nn + +__all__ = ["load_ckpt_params", "count_params"] + +logger = logging.getLogger(__name__) + + +def load_ckpt_params(model: nn.Cell, ckpt: Union[str, Dict]) -> nn.Cell: + if isinstance(ckpt, str): + logger.info(f"Loading {ckpt} params into network...") + param_dict = ms.load_checkpoint(ckpt) + else: + param_dict = ckpt + + param_not_load, ckpt_not_load = ms.load_param_into_net(model, param_dict) + if not (len(param_not_load) == len(ckpt_not_load) == 0): + logger.warning( + "Exist ckpt params not loaded: {} (total: {}), or net params not loaded: {} (total: {})".format( + ckpt_not_load, len(ckpt_not_load), param_not_load, len(param_not_load) + ) + ) + return model + + +def count_params(model: nn.Cell) -> Tuple[int, int]: + total_params = sum([param.size for param in model.get_parameters()]) + trainable_params = sum([param.size for param in model.trainable_params()]) + return total_params, trainable_params diff --git a/examples/moviegen/train.py b/examples/moviegen/train.py new file mode 100644 index 0000000000..582ca635f6 --- /dev/null +++ b/examples/moviegen/train.py @@ -0,0 +1,299 @@ +#!/usr/bin/env python +import argparse +import logging +import os +import sys + +import yaml + +import mindspore as ms +import mindspore.nn as nn +from mindspore import Model +from mindspore.dataset import GeneratorDataset + +# TODO: remove in future when mindone is ready for install +__dir__ = os.path.dirname(os.path.abspath(__file__)) +mindone_lib_path = os.path.abspath(os.path.join(__dir__, "../../")) +sys.path.insert(0, mindone_lib_path) + +from moviegen.dataset import ImageDataset +from moviegen.pipelines import DiffusionWithLoss +from moviegen.schedulers import RFlowLossWrapper +from moviegen.utils import ( + MODEL_DTYPE, + MODEL_SPEC, + LossMonitor, + SaveCkptCallback, + TimeMonitor, + check_cfgs_in_parser, + count_params, + init_env, + load_ckpt_params, + str2bool, +) +from transformers import AutoTokenizer + +from mindone.diffusers import AutoencoderKL +from mindone.trainers.optim import create_optimizer +from mindone.trainers.train_step import TrainOneStepWrapper +from mindone.transformers import T5EncoderModel +from mindone.utils.logger import set_logger + +logger = logging.getLogger(__name__) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Movie-Gen Training script", formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument( + "-c", + "--config", + help="Path to load a config yaml file that describes the setting which will override the default arguments.", + ) + parser.add_argument("--json_path", required=True, help="path to json annotation file.") + parser.add_argument("--model_version", default="llama-1B", choices=["llama-1B", "llama-5B", "llama-30B"]) + parser.add_argument("--image_dir", required=True, help="Directory storing the image directory.") + parser.add_argument("--output_path", default="./output", help="Output directory to save the training result.") + + parser.add_argument("--batch_size", default=64, type=int, help="Training batch size.") + parser.add_argument("--num_parallel_workers", default=4, type=int, help="Number of workers for data loading.") + parser.add_argument("--checkpoint", default="", help="The path to the PixArt checkpoint.") + parser.add_argument("--vae_root", default="models/vae", help="Path storing the VAE checkpoint and configure file.") + 
parser.add_argument( + "--tokenizer_root", default="models/tokenizer", help="Path storing the T5 checkpoint and configure file." + ) + parser.add_argument( + "--text_encoder_root", default="models/text_encoder", help="Path storing the T5 tokenizer and configure file." + ) + parser.add_argument("--t5_max_length", default=300, type=int, help="T5's embedded sequence length.") + parser.add_argument( + "--scale_factor", default=0.13025, type=float, help="VAE scale factor of Stable Diffusion network." + ) + parser.add_argument( + "--text_drop_prob", + default=0.2, + type=float, + help="The probability of using drop text label", + ) + + parser.add_argument("--device_target", default="Ascend", choices=["Ascend"], help="Device target.") + parser.add_argument("--mode", default=0, type=int, help="Running in GRAPH_MODE(0) or PYNATIVE_MODE(1).") + parser.add_argument("--jit_level", default="O0", choices=["O0", "O1"], help="Jit Level") + parser.add_argument("--seed", default=42, type=int, help="Training seed.") + + parser.add_argument( + "--enable_flash_attention", default=True, type=str2bool, help="whether to enable flash attention." + ) + parser.add_argument( + "--dtype", default="bf16", choices=["bf16", "fp16", "fp32"], help="what data type to use for network." + ) + parser.add_argument("--scheduler", default="constant", choices=["constant"], help="LR scheduler.") + parser.add_argument("--start_learning_rate", default=1e-4, type=float, help="The learning rate.") + parser.add_argument("--warmup_steps", default=1000, type=int, help="Warmup steps.") + parser.add_argument("--epochs", default=200, type=int, help="Number of total training epochs.") + parser.add_argument("--optim", default="adamw", type=str, choices=["adamw"], help="Optimizer name.") + parser.add_argument("--weight_decay", default=0.1, type=float, help="Weight decay.") + parser.add_argument( + "--loss_scaler_type", + default="static", + choices=["static", "dynamic"], + help="Use dynamic or static loss scaler.", + ) + parser.add_argument("--init_loss_scale", default=1.0, type=float, help="Loss scale.") + parser.add_argument("--scale_window", default=1000, type=int, help="Loss scale window.") + parser.add_argument("--loss_scale_factor", default=2.0, type=float, help="Loss scale factor.") + parser.add_argument("--use_ema", default=False, type=str2bool, help="Whether to use EMA") + parser.add_argument("--ema_rate", default=0.9999, type=float, help="EMA Rate.") + parser.add_argument("--drop_overflow_update", default=True, type=str2bool, help="Drop overflow update.") + parser.add_argument("--gradient_accumulation_steps", default=1, type=int, help="Gradient accumulation steps.") + parser.add_argument("--clip_grad", default=True, type=str2bool, help="Whether apply gradient clipping.") + parser.add_argument( + "--max_grad_norm", + default=1.0, + type=float, + help="Max gradient norm for clipping, effective when `clip_grad` enabled.", + ) + parser.add_argument("--ckpt_max_keep", default=3, type=int, help="Maximum number of checkpoints to keep") + parser.add_argument("--ckpt_save_interval", default=1, type=int, help="Save checkpoint every this epochs or steps.") + parser.add_argument("--log_loss_interval", default=1, type=int, help="Log interval of loss value.") + parser.add_argument("--recompute", default=False, type=str2bool, help="Use recompute during training.") + parser.add_argument("--use_parallel", default=False, type=str2bool, help="use parallel training.") + default_args = parser.parse_args() + abs_path = 
os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "")) + if default_args.config: + logger.info(f"Overwrite default arguments with configuration file {default_args.config}") + default_args.config = os.path.join(abs_path, default_args.config) + with open(default_args.config, "r") as f: + cfg = yaml.safe_load(f) + check_cfgs_in_parser(cfg, parser) + parser.set_defaults(**cfg) + args = parser.parse_args() + return args + + +def main(args): + if not os.path.isdir(args.output_path): + os.makedirs(args.output_path) + + # 1. init env + device_num, rank_id = init_env(args) + set_logger(output_dir=os.path.join(args.output_path, "logs"), rank=rank_id) + + # 2. model initialize and weight loading + # 2.1 PixArt + image_size = args.sample_size * 8 + logger.info(f"{image_size}x{image_size} init") + + attn_implementation = "flash_attention" if args.enable_flash_attention else "eager" + + network = MODEL_SPEC[args.model_version]( + gradient_checkpointing=args.recompute, attn_implementation=attn_implementation, dtype=MODEL_DTYPE[args.dtype] + ) + + if args.checkpoint: + network = load_ckpt_params(network, args.checkpoint) + else: + logger.info("Initialize network randomly.") + + # 2.2 VAE + logger.info("vae init") + vae = AutoencoderKL.from_pretrained(args.vae_root, mindspore_dtype=MODEL_DTYPE[args.dtype]) + + # 2.3 T5 + logger.info("text encoder init") + text_tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_root, model_max_length=args.t5_max_length) + text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_root, mindspore_dtype=MODEL_DTYPE[args.dtype]) + + # 2.4 LossWrapper + rflow_loss_wrapper = RFlowLossWrapper(network) + + # 3. build training network + latent_diffusion_with_loss = DiffusionWithLoss( + rflow_loss_wrapper, vae, text_encoder, scale_factor=args.scale_factor + ) + + # 4. build dataset + dataset = ImageDataset( + args.json_path, + args.image_dir, + image_size, + text_tokenizer, + text_drop_prob=args.text_drop_prob, + ) + data_generator = GeneratorDataset( + dataset, + column_names=["image", "text"], + column_types=[ms.float32, ms.int64], + shuffle=True, + num_parallel_workers=args.num_parallel_workers, + num_shards=device_num, + shard_id=rank_id, + max_rowsize=-1, + ) + data_generator = data_generator.batch(args.batch_size, drop_remainder=True) + + # 5. 
build training utils: lr, optim, callbacks, trainer + # 5.1 LR + lr = nn.WarmUpLR(learning_rate=args.start_learning_rate, warmup_steps=args.warmup_steps) + + # 5.2 optimizer + optim = "adamw_re" if args.optim == "adamw" else args.optim + eps = args.adamw_eps if args.optim == "adamw" else args.came_eps + betas = None if args.optim == "adamw" else args.came_betas + optimizer = create_optimizer( + latent_diffusion_with_loss.trainable_params(), + name=optim, + lr=lr, + weight_decay=args.weight_decay, + betas=betas, + eps=eps, + ) + + if args.loss_scaler_type == "dynamic": + loss_scaler = nn.DynamicLossScaleUpdateCell( + loss_scale_value=args.init_loss_scale, scale_factor=args.loss_scale_factor, scale_window=args.scale_window + ) + else: + loss_scaler = nn.FixedLossScaleUpdateCell(args.init_loss_scale) + + # 5.3 trainer (standalone and distributed) + if args.use_ema: + raise NotImplementedError("`EMA` does not support yet.") + # ema = EMA(latent_diffusion_with_loss.network, ema_decay=args.ema_rate) + else: + ema = None + + net_with_grads = TrainOneStepWrapper( + latent_diffusion_with_loss, + optimizer=optimizer, + scale_sense=loss_scaler, + drop_overflow_update=args.drop_overflow_update, + gradient_accumulation_steps=args.gradient_accumulation_steps, + clip_grad=args.clip_grad, + clip_norm=args.max_grad_norm, + ema=ema, + ) + + model = Model(net_with_grads) + + # 5.4 callbacks + callbacks = [ + TimeMonitor(), + LossMonitor(log_interval=args.log_loss_interval), + SaveCkptCallback( + output_dir=os.path.join(args.output_path, "ckpt"), + ckpt_max_keep=args.ckpt_max_keep, + ckpt_save_interval=args.ckpt_save_interval, + save_ema=args.use_ema, + rank_id=rank_id, + ), + ] + + if rank_id == 0: + num_params_vae, num_params_trainable_vae = count_params(vae) + num_params_network, num_params_trainable_network = count_params(network) + num_params_text_encoder, num_params_trainable_text_encoder = count_params(text_encoder) + num_params = num_params_vae + num_params_network + num_params_text_encoder + num_params_trainable = ( + num_params_trainable_vae + num_params_trainable_network + num_params_trainable_text_encoder + ) + key_info = "Key Settings:\n" + "=" * 50 + "\n" + key_info += "\n".join( + [ + f"MindSpore mode[GRAPH(0)/PYNATIVE(1)]: {args.mode}", + f"JIT level: {args.jit_level}", + f"Distributed mode: {args.use_parallel}", + f"Data path: {args.json_path}", + f"Number of samples: {len(dataset)}", + f"Num params: {num_params:,} (network: {num_params_network:,}, vae: {num_params_vae:,}, text_encoder: {num_params_text_encoder:,})", + f"Num trainable params: {num_params_trainable:,}", + f"Model type: {args.dtype}", + f"Learning rate: {args.start_learning_rate:.7f}", + f"Batch size: {args.batch_size}", + f"Image size: {image_size}", + f"Weight decay: {args.weight_decay}", + f"Grad accumulation steps: {args.gradient_accumulation_steps}", + f"Num epochs: {args.epochs}", + f"Loss scaler: {args.loss_scaler_type}", + f"Init loss scale: {args.init_loss_scale}", + f"Grad clipping: {args.clip_grad}", + f"Max grad norm: {args.max_grad_norm}", + f"EMA: {args.use_ema}", + f"Enable flash attention: {args.enable_flash_attention}", + ] + ) + key_info += "\n" + "=" * 50 + print(key_info) + + with open(os.path.join(args.output_path, "args.yaml"), "w") as f: + yaml.safe_dump(vars(args), stream=f, default_flow_style=False, sort_keys=False) + + # 6. 
train + logger.info("Start training...") + model.train(args.epochs, data_generator, callbacks=callbacks) + + +if __name__ == "__main__": + args = parse_args() + main(args) From bb01342a0efe97189119c1a1f06fe769cfd2fde2 Mon Sep 17 00:00:00 2001 From: Mike Cheung Date: Tue, 29 Oct 2024 18:45:42 +0800 Subject: [PATCH 011/122] move config file outside the folder --- .../{moviegen => }/configs/train/moviegen-256x256-t2i.yaml | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename examples/moviegen/{moviegen => }/configs/train/moviegen-256x256-t2i.yaml (100%) diff --git a/examples/moviegen/moviegen/configs/train/moviegen-256x256-t2i.yaml b/examples/moviegen/configs/train/moviegen-256x256-t2i.yaml similarity index 100% rename from examples/moviegen/moviegen/configs/train/moviegen-256x256-t2i.yaml rename to examples/moviegen/configs/train/moviegen-256x256-t2i.yaml From 3d279be894880590895176c1315861be58544cd0 Mon Sep 17 00:00:00 2001 From: Samit <285365963@qq.com> Date: Wed, 30 Oct 2024 14:48:53 +0800 Subject: [PATCH 012/122] temp save --- examples/movie_gen/mg/models/tae/modules.py | 49 ++++++++++++---- examples/movie_gen/mg/models/tae/tae.py | 17 +++--- examples/movie_gen/tests/test_tae.py | 58 +++++++++++++++---- .../opensora_hpcai/opensora/models/vae/vae.py | 2 - .../opensora_hpcai/tools/mem_monitor/plot.py | 17 +++--- 5 files changed, 102 insertions(+), 41 deletions(-) diff --git a/examples/movie_gen/mg/models/tae/modules.py b/examples/movie_gen/mg/models/tae/modules.py index b59cf0c0d1..c1e2a4a8ad 100644 --- a/examples/movie_gen/mg/models/tae/modules.py +++ b/examples/movie_gen/mg/models/tae/modules.py @@ -1,7 +1,7 @@ import logging -from packaging import version import numpy as np +from packaging import version import mindspore as ms from mindspore import nn, ops @@ -188,7 +188,7 @@ def construct(self, x): return x -class Upsample(nn.Cell): +class SpatialUpsample(nn.Cell): def __init__(self, in_channels, with_conv): super().__init__() self.with_conv = with_conv @@ -198,12 +198,27 @@ def __init__(self, in_channels, with_conv): ) def construct(self, x): + ''' + x: (b c t h w) + return: (b c t h w) + ''' + B, Ci, T, Hi, Wi = x.shape + # (b c t h w) -> (b t c h w) + x = ops.transpose(x, (0, 2, 1, 3, 4)) + # (b t c h w) -> (b*t c h w) + x = ops.reshape(x, (B*T, Ci, Hi, Wi)) + in_shape = x.shape[-2:] out_shape = tuple(2 * x for x in in_shape) x = ops.ResizeNearestNeighbor(out_shape)(x) if self.with_conv: x = self.conv(x) + + _, Co, Ho, Wo = x.shape + x = ops.reshape(x, (B, T, Co, Ho, Wo)) + x = ops.transpose(x, (0, 2, 1, 3, 4)) + return x @@ -593,11 +608,14 @@ def __init__( down.block = block down.attn = attn if i_level != self.num_resolutions - 1: - down.downsample_spat = SpatialDownsample(block_in, resamp_with_conv) - down.downsample_temp = TemporalDownsample(block_in) + # down.downsample_spat = SpatialDownsample(block_in, resamp_with_conv) + # down.downsample_temp = TemporalDownsample(block_in) + down.downsample = nn.SequentialCell( + SpatialDownsample(block_in, resamp_with_conv), + TemporalDownsample(block_in), + ) else: - down.downsample_spat = nn.Identity() - down.downsample_temp = nn.Identity() + down.downsample = nn.Identity() curr_res = curr_res // 2 down.update_parameters_name(prefix=self.param_prefix + f"down.{i_level}.") self.down.append(down) @@ -645,8 +663,7 @@ def construct(self, x): h = self.down[i_level].attn[i_block](h) hs = h if i_level != self.num_resolutions - 1: - hs = self.down[i_level].downsample_spat(hs) - hs = self.down[i_level].downsample_temp(hs) + hs = 
self.down[i_level].downsample(hs) # middle h = hs @@ -701,7 +718,7 @@ def __init__( _logger.debug("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) # z to block_in - self.conv_in = nn.Conv2d( + self.conv_in = Conv2_5d( z_channels, block_in, kernel_size=3, stride=1, pad_mode="pad", padding=1, has_bias=True ) @@ -736,7 +753,10 @@ def __init__( up.block = block up.attn = attn if i_level != 0: - up.upsample = Upsample(block_in, resamp_with_conv) + up.upsample = nn.SequentialCell( + SpatialUpsample(block_in, resamp_with_conv), + TemporalUpsample(block_in), + ) else: up.upsample = nn.Identity() curr_res = curr_res * 2 @@ -748,9 +768,16 @@ def __init__( # end self.norm_out = Normalize(block_in) - self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, pad_mode="pad", padding=1, has_bias=True) + self.conv_out = Conv2_5d(block_in, out_ch, kernel_size=3, stride=1, pad_mode="pad", padding=1, has_bias=True) def construct(self, z): + ''' + Args: + x: (b c t h w) + Returns: + (b c t h w) + ''' + # z to block_in h = self.conv_in(z) diff --git a/examples/movie_gen/mg/models/tae/tae.py b/examples/movie_gen/mg/models/tae/tae.py index 8288c66229..1467359753 100644 --- a/examples/movie_gen/mg/models/tae/tae.py +++ b/examples/movie_gen/mg/models/tae/tae.py @@ -1,7 +1,6 @@ import mindspore as ms from mindspore import nn, ops - SDXL_CONFIG = { "double_z": True, "z_channels": 4, @@ -24,29 +23,28 @@ class VideoAutoencoder(nn.Cell): config (`dict`): config dict pretrained (`str`): checkpoint path """ + def __init__( - self, - config: dict=SDXL_CONFIG, - pretrained: str=None, - ): + self, + config: dict = SDXL_CONFIG, + pretrained: str = None, + ): super().__init__() # encoder self.encoder = Encoder(**config) # quant and post quant - self.quant_conv = nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1, pad_mode="valid", has_bias=True) - self.post_quant_conv = nn.Conv2d(embed_dim, ddconfig["z_channels"], 1, pad_mode="valid", has_bias=True) + self.quant_conv = Conv2_5d(2 * config["z_channels"], 2 * embed_dim, 1, pad_mode="valid", has_bias=True) + self.post_quant_conv = Conv2_5d(embed_dim, config["z_channels"], 1, pad_mode="valid", has_bias=True) # decoder self.decoder = Decoder(**config) def encode(self, x: ms.Tensor) -> ms.Tensor: - return x def decode(self, x: ms.Tensor) -> ms.Tensor: - return x def construct(self, x: ms.Tensor) -> ms.Tensor: @@ -57,4 +55,3 @@ def construct(self, x: ms.Tensor) -> ms.Tensor: """ return x - diff --git a/examples/movie_gen/tests/test_tae.py b/examples/movie_gen/tests/test_tae.py index ffe166e65d..e16c0ba27a 100644 --- a/examples/movie_gen/tests/test_tae.py +++ b/examples/movie_gen/tests/test_tae.py @@ -1,8 +1,21 @@ import numpy as np +from mg.models.tae.modules import ( + Conv2_5d, + Decoder, + Encoder, + ResnetBlock, + SpatialAttnBlock, + SpatialAttnBlockV2, + SpatialDownsample, + SpatialUpsample, + TemporalAttnBlock, + TemporalDownsample, + TemporalUpsample, +) +from mg.models.tae.tae import SDXL_CONFIG + import mindspore as ms -from mg.models.tae.modules import Conv2_5d, ResnetBlock, SpatialAttnBlock, SpatialAttnBlockV2, TemporalAttnBlock, TemporalUpsample, TemporalDownsample, SpatialDownsample, Encoder -from mg.models.tae.tae import SDXL_CONFIG def test_conv25d(): in_shape = (B, C, T, H, W) = (2, 3, 16, 256, 256) @@ -17,16 +30,17 @@ def test_conv25d(): print(y.shape) + def test_resnetblock(): in_shape = (B, C, T, H, W) = (1, 64, 4, 32, 32) cout = C x = np.random.normal(size=in_shape).astype(np.float32) rb = ResnetBlock( - 
in_channels=C, - out_channels=cout, - dropout=0., - ) + in_channels=C, + out_channels=cout, + dropout=0.0, + ) ms.set_context(mode=0) x = ms.Tensor(x) @@ -70,6 +84,7 @@ def test_temporal_attn(): print(y.shape) print(y.mean(), y.std()) + def test_spatial_downsample(): # in_shape = (B, C, T, H, W) = (1, 64, 1, 32, 32) in_shape = (B, C, T, H, W) = (1, 64, 4, 32, 32) @@ -82,6 +97,18 @@ def test_spatial_downsample(): print(y.shape) +def test_spatial_upsample(): + # in_shape = (B, C, T, H, W) = (1, 64, 1, 32, 32) + in_shape = (B, C, T, H, W) = (1, 64, 4, 32, 32) + x = np.random.normal(size=in_shape).astype(np.float32) + su = SpatialUpsample(C, True) + + x = ms.Tensor(x) + y = su(x) + + print(y.shape) + + def test_temporal_downsample(): # in_shape = (B, C, T, H, W) = (1, 64, 1, 32, 32) in_shape = (B, C, T, H, W) = (1, 64, 1, 32, 32) @@ -96,7 +123,6 @@ def test_temporal_downsample(): print(y.shape) - def test_temporal_upsample(): # in_shape = (B, C, T, H, W) = (1, 64, 1, 32, 32) in_shape = (B, C, T, H, W) = (1, 64, 4, 32, 32) @@ -123,6 +149,17 @@ def test_encoder(): print(y.shape) +def test_decoder(): + # in_shape = (B, C, T, H, W) = (1, 64, 1, 32, 32) + in_shape = (B, C, T, H, W) = (1, 4, 4, 16, 16) + x = np.random.normal(size=in_shape).astype(np.float32) + dec = Decoder(**SDXL_CONFIG) + + x = ms.Tensor(x) + y = dec(x) + + print(y.shape) + if __name__ == "__main__": # test_conv25d() @@ -131,7 +168,8 @@ def test_encoder(): # test_temporal_attn() # test_spatial_downsample() # test_temporal_downsample() - # test_temporal_upsample() - test_encoder() - + # test_encoder() + # test_temporal_upsample() + # test_spatial_upsample() + test_decoder() diff --git a/examples/opensora_hpcai/opensora/models/vae/vae.py b/examples/opensora_hpcai/opensora/models/vae/vae.py index d26b89d4e5..195f0dc36e 100644 --- a/examples/opensora_hpcai/opensora/models/vae/vae.py +++ b/examples/opensora_hpcai/opensora/models/vae/vae.py @@ -483,5 +483,3 @@ def encode_with_moments_output(self, x): std = self.exp(0.5 * logvar) return mean, std - - diff --git a/examples/opensora_hpcai/tools/mem_monitor/plot.py b/examples/opensora_hpcai/tools/mem_monitor/plot.py index fa76dffe47..bb5d4588a6 100644 --- a/examples/opensora_hpcai/tools/mem_monitor/plot.py +++ b/examples/opensora_hpcai/tools/mem_monitor/plot.py @@ -1,20 +1,21 @@ -import pandas as pd -import matplotlib.pyplot as plt import sys +import matplotlib.pyplot as plt +import pandas as pd + # Read the log file -data = pd.read_csv("memory_usage.log", parse_dates=['Timestamp']) +data = pd.read_csv("memory_usage.log", parse_dates=["Timestamp"]) # Plotting the memory usage plt.figure(figsize=(10, 5)) -plt.plot(data['Timestamp'], data['Memory_Usage_Percentage'], label='Memory Usage (%)', color='blue') -plt.title('Memory Usage Percentage Over Time') -plt.xlabel('Time') -plt.ylabel('Memory Usage (%)') +plt.plot(data["Timestamp"], data["Memory_Usage_Percentage"], label="Memory Usage (%)", color="blue") +plt.title("Memory Usage Percentage Over Time") +plt.xlabel("Time") +plt.ylabel("Memory Usage (%)") plt.xticks(rotation=45) plt.ylim(0, 100) # Set y-axis limits from 0 to 100% plt.grid() plt.legend() plt.tight_layout() -plt.savefig('memory_usage_plot.png') +plt.savefig("memory_usage_plot.png") plt.show() From 2cc907caffca304d26fcd71c0d9119e9b83f3b0b Mon Sep 17 00:00:00 2001 From: Mike Cheung Date: Wed, 30 Oct 2024 18:32:16 +0800 Subject: [PATCH 013/122] change some ops to mint --- examples/moviegen/moviegen/models/llama/block.py | 14 +++++++++----- 
examples/moviegen/moviegen/models/llama/network.py | 8 ++++++-- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/examples/moviegen/moviegen/models/llama/block.py b/examples/moviegen/moviegen/models/llama/block.py index 7f10826401..8d01673a32 100644 --- a/examples/moviegen/moviegen/models/llama/block.py +++ b/examples/moviegen/moviegen/models/llama/block.py @@ -197,8 +197,8 @@ def construct(self, hidden_states: Tensor, encoder_hidden_states: Optional[Tenso # upcast attention to fp32 attn_weights = attn_weights.to(ms.float32) - attn_weights = ops.softmax(attn_weights, axis=-1).to(query_states.dtype) - attn_weights = ops.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_weights = mint.softmax(attn_weights, dim=-1).to(query_states.dtype) + attn_weights = mint.dropout(attn_weights, p=self.attention_dropout, training=self.training) attn_output = mint.matmul(attn_weights, value_states) attn_output = mint.permute(attn_output, (0, 2, 1, 3)) @@ -272,8 +272,8 @@ def construct(self, hidden_states: Tensor, encoder_hidden_states: Optional[Tenso # upcast attention to fp32 attn_weights = attn_weights.to(ms.float32) - attn_weights = ops.softmax(attn_weights, axis=-1).to(query_states.dtype) - attn_weights = ops.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_weights = mint.softmax(attn_weights, dim=-1).to(query_states.dtype) + attn_weights = mint.dropout(attn_weights, p=self.attention_dropout, training=self.training) attn_output = mint.matmul(attn_weights, value_states) attn_output = mint.permute(attn_output, (0, 2, 1, 3)) @@ -475,7 +475,11 @@ def __init__( mint.nn.Linear(hidden_size, hidden_size, bias=False, dtype=dtype), ) self.frequency_embedding_size = frequency_embedding_size - self.dtype = dtype + self._dtype = dtype + + @property + def dtype(self): + return self._dtype @staticmethod def timestep_embedding(t: Tensor, dim: int, max_period: int = 10000) -> Tensor: diff --git a/examples/moviegen/moviegen/models/llama/network.py b/examples/moviegen/moviegen/models/llama/network.py index f1a3daf485..8fe8fac923 100644 --- a/examples/moviegen/moviegen/models/llama/network.py +++ b/examples/moviegen/moviegen/models/llama/network.py @@ -103,7 +103,9 @@ def construct( hidden_states = hidden_states + position_embedding # 3.1.3 Adaptive Layer Norm - modulation_parameters = self.scale_shift_table.to(hidden_states.dtype) + modulation_parameters.reshape(B, 6, -1) + modulation_parameters = self.scale_shift_table.to(hidden_states.dtype) + ops.reshape( + modulation_parameters, (B, 6, -1) + ) shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mint.chunk(modulation_parameters, 6, dim=1) # Self Attention (Bi-Directional Attention) @@ -210,7 +212,9 @@ def construct( hidden_states = hidden_states + position_embedding # 3.1.3 Adaptive Layer Norm - modulation_parameters = self.scale_shift_table.to(hidden_states.dtype) + modulation_parameters.reshape(B, 6, -1) + modulation_parameters = self.scale_shift_table.to(hidden_states.dtype) + ops.reshape( + modulation_parameters, (B, 6, -1) + ) shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mint.chunk(modulation_parameters, 6, dim=1) # Self Attention (Bi-Directional Attention) From 888a2134b1bf3c1ecc49c7f3c848760318c8013b Mon Sep 17 00:00:00 2001 From: Mike Cheung Date: Wed, 30 Oct 2024 18:40:26 +0800 Subject: [PATCH 014/122] add init for text projector --- .../models/text_encoders/text_projector.py | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git 
a/examples/moviegen/moviegen/models/text_encoders/text_projector.py b/examples/moviegen/moviegen/models/text_encoders/text_projector.py index 2b6b5b1945..3cf4d7dadd 100644 --- a/examples/moviegen/moviegen/models/text_encoders/text_projector.py +++ b/examples/moviegen/moviegen/models/text_encoders/text_projector.py @@ -3,6 +3,8 @@ import mindspore as ms from mindspore import Tensor, mint, nn +from mindone.models.utils import normal_, zeros_ + class TextProjector(nn.Cell): def __init__( @@ -13,6 +15,8 @@ def __init__( out_features: int = 6144, layer_norm: Type[nn.Cell] = mint.nn.LayerNorm, norm_eps: float = 1e-5, + initializer_range: float = 0.02, + post_init_weight: bool = True, dtype: ms.Type = ms.float32, ): super().__init__() @@ -34,6 +38,23 @@ def __init__( layer_norm((out_features,), eps=norm_eps, dtype=dtype), ] ) + self.initializer_range = initializer_range + + # post-init + if post_init_weight: + self.initializer_range = initializer_range + self.init_weights() + + def init_weights(self): + std = self.initializer_range + + def _init_weights(module): + if isinstance(module, mint.nn.Linear): + normal_(module.weight, mean=0.0, std=std) + if module.bias is not None: + zeros_(module.weight) + + self.apply(_init_weights) def construct(self, ul2_text: Tensor, metaclip_text: Tensor, byt5_text: Tensor) -> Tensor: ul2_hidden_states = self.ul2_projector(ul2_text) From d5a6406a9ae6d410bada86c5d822bd5f39aa1082 Mon Sep 17 00:00:00 2001 From: Mike Cheung Date: Thu, 31 Oct 2024 09:56:42 +0800 Subject: [PATCH 015/122] fix mint --- examples/moviegen/moviegen/models/llama/block.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/examples/moviegen/moviegen/models/llama/block.py b/examples/moviegen/moviegen/models/llama/block.py index 8d01673a32..e867c2f1cf 100644 --- a/examples/moviegen/moviegen/models/llama/block.py +++ b/examples/moviegen/moviegen/models/llama/block.py @@ -11,6 +11,7 @@ import mindspore as ms import mindspore.mint as mint +import mindspore.mint.nn.functional as F import mindspore.nn as nn import mindspore.ops as ops from mindspore import Parameter, Tensor @@ -197,8 +198,8 @@ def construct(self, hidden_states: Tensor, encoder_hidden_states: Optional[Tenso # upcast attention to fp32 attn_weights = attn_weights.to(ms.float32) - attn_weights = mint.softmax(attn_weights, dim=-1).to(query_states.dtype) - attn_weights = mint.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_weights = F.softmax(attn_weights, dim=-1).to(query_states.dtype) + attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training) attn_output = mint.matmul(attn_weights, value_states) attn_output = mint.permute(attn_output, (0, 2, 1, 3)) @@ -272,8 +273,8 @@ def construct(self, hidden_states: Tensor, encoder_hidden_states: Optional[Tenso # upcast attention to fp32 attn_weights = attn_weights.to(ms.float32) - attn_weights = mint.softmax(attn_weights, dim=-1).to(query_states.dtype) - attn_weights = mint.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_weights = F.softmax(attn_weights, dim=-1).to(query_states.dtype) + attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training) attn_output = mint.matmul(attn_weights, value_states) attn_output = mint.permute(attn_output, (0, 2, 1, 3)) From 14f0a34a765ee406e43a8891ab0e497fd1a268e1 Mon Sep 17 00:00:00 2001 From: Mike Cheung Date: Thu, 31 Oct 2024 10:39:48 +0800 Subject: [PATCH 016/122] fix type --- 
examples/moviegen/moviegen/models/llama/block.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/moviegen/moviegen/models/llama/block.py b/examples/moviegen/moviegen/models/llama/block.py index e867c2f1cf..607b9b37aa 100644 --- a/examples/moviegen/moviegen/models/llama/block.py +++ b/examples/moviegen/moviegen/models/llama/block.py @@ -1,5 +1,5 @@ import logging -from typing import Optional, Tuple +from typing import Optional, Sequence, Tuple, Union from moviegen.parallel import ( ColumnParallelLinear, @@ -24,7 +24,7 @@ class LlamaRMSNorm(nn.Cell): - def __init__(self, hidden_size: int, eps: float = 1e-6, dtype: ms.Type = ms.float32) -> None: + def __init__(self, hidden_size: Union[int, Sequence[int]], eps: float = 1e-6, dtype: ms.Type = ms.float32) -> None: super().__init__() self.weight = Parameter(mint.ones(hidden_size, dtype=dtype)) self.variance_epsilon = eps From 231be441e5c12222bd8dac703e0fbe35c136e0d8 Mon Sep 17 00:00:00 2001 From: Samit <285365963@qq.com> Date: Mon, 4 Nov 2024 10:30:05 +0800 Subject: [PATCH 017/122] encoder ok --- examples/movie_gen/mg/models/tae/tae.py | 35 +++++++++++++++++++++++-- examples/movie_gen/tests/test_tae.py | 16 +++++++++-- 2 files changed, 47 insertions(+), 4 deletions(-) diff --git a/examples/movie_gen/mg/models/tae/tae.py b/examples/movie_gen/mg/models/tae/tae.py index 1467359753..208ef428ea 100644 --- a/examples/movie_gen/mg/models/tae/tae.py +++ b/examples/movie_gen/mg/models/tae/tae.py @@ -1,5 +1,6 @@ import mindspore as ms from mindspore import nn, ops +from .modules import Conv2_5d, Encoder, Decoder SDXL_CONFIG = { "double_z": True, @@ -35,17 +36,47 @@ def __init__( self.encoder = Encoder(**config) # quant and post quant + embed_dim = config['z_channels'] self.quant_conv = Conv2_5d(2 * config["z_channels"], 2 * embed_dim, 1, pad_mode="valid", has_bias=True) self.post_quant_conv = Conv2_5d(embed_dim, config["z_channels"], 1, pad_mode="valid", has_bias=True) # decoder self.decoder = Decoder(**config) + self.exp = ops.Exp() + self.stdnormal = ops.StandardNormal() + self.split = ms.ops.split + self.sample_deterministic=False + + def _encode(self, x): + # return latent distribution, N(mean, logvar) + h = self.encoder(x) + moments = self.quant_conv(h) + mean, logvar = self.split(moments, moments.shape[1] // 2, 1) + + return mean, logvar + + def sample(self, mean, logvar): + # sample z from latent distribution + logvar = ops.clip_by_value(logvar, -30.0, 20.0) + std = self.exp(0.5 * logvar) + z = mean + std * self.stdnormal(mean.shape) + + return z + def encode(self, x: ms.Tensor) -> ms.Tensor: - return x + # embedding, get latent representation z + posterior_mean, posterior_logvar = self._encode(x) + if self.sample_deterministic: + return posterior_mean + z = self.sample(posterior_mean, posterior_logvar) + + return z def decode(self, x: ms.Tensor) -> ms.Tensor: - return x + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec def construct(self, x: ms.Tensor) -> ms.Tensor: """ diff --git a/examples/movie_gen/tests/test_tae.py b/examples/movie_gen/tests/test_tae.py index e16c0ba27a..c1e3b3764c 100644 --- a/examples/movie_gen/tests/test_tae.py +++ b/examples/movie_gen/tests/test_tae.py @@ -12,7 +12,7 @@ TemporalDownsample, TemporalUpsample, ) -from mg.models.tae.tae import SDXL_CONFIG +from mg.models.tae.tae import SDXL_CONFIG, VideoAutoencoder import mindspore as ms @@ -161,6 +161,17 @@ def test_decoder(): print(y.shape) +def test_tae_encode(): + in_shape = (B, C, T, H, W) = (1, 3, 1, 64, 64) + x = 
np.random.normal(size=in_shape).astype(np.float32) + x = ms.Tensor(x) + + tae = VideoAutoencoder(config=SDXL_CONFIG) + y = tae.encode(x) + + print(y.shape) + + if __name__ == "__main__": # test_conv25d() # test_resnetblock() @@ -172,4 +183,5 @@ def test_decoder(): # test_temporal_upsample() # test_spatial_upsample() - test_decoder() + # test_decoder() + test_tae_encode() From 539437a5002b5f9bc827a9f89eff7d602486d531 Mon Sep 17 00:00:00 2001 From: Rustam Khadipash <16683750+hadipash@users.noreply.github.com> Date: Wed, 30 Oct 2024 17:36:18 +0800 Subject: [PATCH 018/122] add image support to OS data loader --- .../opensora/datasets/transforms.py | 110 +++--------------- .../datasets/video_dataset_refactored.py | 93 +++++++-------- 2 files changed, 55 insertions(+), 148 deletions(-) diff --git a/examples/opensora_hpcai/opensora/datasets/transforms.py b/examples/opensora_hpcai/opensora/datasets/transforms.py index c37fc37435..d69a01f2f2 100644 --- a/examples/opensora_hpcai/opensora/datasets/transforms.py +++ b/examples/opensora_hpcai/opensora/datasets/transforms.py @@ -1,102 +1,24 @@ -from typing import Tuple +from typing import Optional, Tuple import cv2 import numpy as np -from mindspore.dataset.transforms import Compose -from mindspore.dataset.vision import CenterCrop, Inter -from mindspore.dataset.vision import Resize as MSResize -from .bucket import Bucket - - -class Resize: - def __init__(self, size: Tuple[int, int], interpolation=Inter.BILINEAR): - self._h, self._w = size +class ResizeCrop: + def __init__(self, size: Optional[Tuple[int, int]] = None, interpolation=cv2.INTER_LINEAR): + self._size = size self._inter = interpolation - def __call__(self, x: np.ndarray) -> np.ndarray: - img_h, img_w = x.shape[-3:-1] # support images and videos - scale = max(self._h / img_h, self._w / img_w) - if scale != 1: - x = MSResize((round(img_h * scale), round(img_w * scale)), self._inter)(x) + def __call__(self, x: np.ndarray, size: Optional[Tuple[int, int]] = None) -> np.ndarray: + h, w = x.shape[-3:-1] # support images and videos + th, tw = size or self._size + scale = max(th / h, tw / w) + if scale != 1: # resize + if x.ndim == 3: # if image + x = cv2.resize(x, None, fx=scale, fy=scale, interpolation=self._inter) + else: # if video + x = np.array([cv2.resize(i, None, fx=scale, fy=scale, interpolation=self._inter) for i in x]) + if x.shape[-3:-1] != (th, tw): # crop + i, j = round((x.shape[-3] - th) / 2.0), round((x.shape[-2] - tw) / 2.0) + x = x[..., i : i + th, j : j + tw, :] return x - - -class BucketResizeCrop: - def __init__(self, buckets: Bucket): - self._transforms = {} # is this reasonable? There are 350+ buckets - for name, lengths in buckets.ar_criteria.items(): - self._transforms[name] = {} - for length, ars in lengths.items(): - self._transforms[name][str(length)] = {} - for ar, hw in ars.items(): - self._transforms[name][str(length)][ar] = Compose( - [MSResize(min(hw), interpolation=Inter.BILINEAR), CenterCrop(hw)] - ) - - def __call__(self, x, bucket_id): - return self._transforms[bucket_id[0]][bucket_id[1]][bucket_id[2]](x) - - -class ResizeAndCrop: - """Resize an RGB image to a target size while preserving the aspect ratio and cropping it. - Align to resize_crop_to_fill in torch. Ensure no black surrounding produced. 
- """ - - def __init__(self, target_height, target_width): - super(ResizeAndCrop, self).__init__() - self.tar_h = target_height - self.tar_w = target_width - - def __call__(self, img): - # Ensure the image is in RGB format - if img.shape[2] != 3: - raise ValueError("Input image must be in RGB format with 3 channels.") - - h, w = img.shape[:2] - th, tw = self.tar_h, self.tar_w # target - rh, rw = th / h, tw / w # ratio - - if rh > rw: - # target image is thinner than the original image - new_h, new_w = th, round(w * rh) - start_y = 0 - start_x = int(round(new_w - tw) / 2.0) - else: - # target image is fatter than the original image - new_h, new_w = round(h * rw), tw - start_y = int(round(new_h - th) / 2.0) - start_x = 0 - - if rh > rw: - new_h, new_w = th, round(w * rh) - start_y = 0 - start_x = int(round(new_w - tw)) - - # Resize the image - # NOTE: for opensora v1.2, HD videos are mainly downsampled according to buckets. The best choice for down-sample interpolation is INTER_AREA. - resized_img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA) - - # Crop the image to the target size - cropped_img = resized_img[start_y : start_y + self.tar_h, start_x : start_x + self.tar_w] - - return cropped_img - - -class BucketResizeAndCrop(object): - """According to bucket config, resize an RGB image to a target size while preserving the aspect ratio and cropping it.""" - - def __init__(self, buckets): - super().__init__() - self._transforms = {} # is this reasonable? There are 350+ buckets - for name, lengths in buckets.ar_criteria.items(): - self._transforms[name] = {} - for length, ars in lengths.items(): - self._transforms[name][str(length)] = {} - for ar, hw in ars.items(): - self._transforms[name][str(length)][ar] = ResizeAndCrop(hw[0], hw[1]) - - def __call__(self, image, bucket_id=None): - resized_img = self._transforms[bucket_id[0]][str(bucket_id[1])][bucket_id[2]](image) - return resized_img diff --git a/examples/opensora_hpcai/opensora/datasets/video_dataset_refactored.py b/examples/opensora_hpcai/opensora/datasets/video_dataset_refactored.py index 0cf7de0dcf..d35653dc0e 100644 --- a/examples/opensora_hpcai/opensora/datasets/video_dataset_refactored.py +++ b/examples/opensora_hpcai/opensora/datasets/video_dataset_refactored.py @@ -14,12 +14,11 @@ import mindspore as ms from mindspore.dataset.transforms import Compose -from mindspore.dataset.vision import CenterCrop, Inter, Normalize from mindone.data.video_reader import VideoReader as VideoReader_CV2 from .bucket import Bucket -from .transforms import BucketResizeAndCrop, BucketResizeCrop, Resize, ResizeAndCrop +from .transforms import ResizeCrop # FIXME: remove in future when mindone is ready for install sys.path.append(os.path.join(os.path.dirname(__file__), "../../../..")) @@ -30,33 +29,20 @@ _logger = logging.getLogger(__name__) +IMAGE_EXT = (".jpg", ".jpeg", ".png", ".gif", ".webp") -def create_infer_transforms(target_size: Tuple[int, int], interpolation=Inter.BILINEAR): + +def create_infer_transforms(target_size: Tuple[int, int], interpolation=cv2.INTER_LINEAR): return Compose( [ - Resize(target_size, interpolation=interpolation), - CenterCrop(target_size), - lambda x: (x / 255.0).astype(np.float32), # ms.ToTensor() doesn't support 4D data - Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ResizeCrop(target_size, interpolation=interpolation), + lambda x: x.astype(np.float32) / 127.5 - 1, lambda x: x[None, ...] 
if x.ndim == 3 else x, # if image - lambda x: np.transpose(x, (0, 3, 1, 2)), # ms.HWC2CHW() doesn't support 4D data + lambda x: np.transpose(x, (0, 3, 1, 2)), ] ) -def create_train_transforms(target_size, buckets=None): - """ - expect rgb image in range 0-255, shape (h w c) - """ - - if buckets is None: - transforms = ResizeAndCrop(target_size[0], target_size[1]) - else: - transforms = BucketResizeAndCrop(buckets) - - return transforms - - class VideoDatasetRefactored(BaseDataset): def __init__( self, @@ -143,7 +129,7 @@ def __init__( self.apply_train_transforms = apply_train_transforms if self.apply_train_transforms: - self.pixel_transforms = create_train_transforms(target_size, buckets=buckets) + self.pixel_transforms = ResizeCrop(target_size, interpolation=cv2.INTER_AREA) if "bucket_id" in self.output_columns: self.output_columns.remove("bucket_id") assert not pre_patchify, "transforms for prepatchify not implemented yet" @@ -300,28 +286,33 @@ def _get_item(self, idx: int) -> Tuple[Any, ...]: ) # / self._stride # FIXME: OS v1.1 incorrect del reader elif self.video_backend == "cv2": - with VideoReader_CV2(video_path) as reader: - min_length = self._min_length - if self._buckets: - data["bucket_id"] = self._buckets.get_bucket_id( - T=len(reader), - H=reader.shape[1], - W=reader.shape[0], - frame_interval=self._stride, - ) - if data["bucket_id"] is None: - raise ValueError( - f"Couldn't assign a bucket to {data['video']}" - f" (T={len(reader)}, H={reader.shape[1]}, W={reader.shape[0]})." + if video_path.lower().endswith(IMAGE_EXT): + num_frames = 1 + data["fps"] = np.array(120, dtype=np.float32) # FIXME: extract as IMG_FPS + video = cv2.cvtColor(cv2.imread(data["video"]), cv2.COLOR_BGR2RGB) + else: + with VideoReader_CV2(video_path) as reader: + min_length = self._min_length + if self._buckets: + data["bucket_id"] = self._buckets.get_bucket_id( + T=len(reader), + H=reader.shape[1], + W=reader.shape[0], + frame_interval=self._stride, ) - num_frames, *_ = self._buckets.get_thw(data["bucket_id"]) - min_length = (num_frames - 1) * self._stride + 1 - - if len(reader) < min_length: - raise ValueError(f"Video is too short: {video_path}") - start_pos = random.randint(0, len(reader) - min_length) - video = reader.fetch_frames(num=num_frames, start_pos=start_pos, step=self._stride) - data["fps"] = np.array(reader.fps, dtype=np.float32) + if data["bucket_id"] is None: + raise ValueError( + f"Couldn't assign a bucket to {data['video']}" + f" (T={len(reader)}, H={reader.shape[1]}, W={reader.shape[0]})." 
+ ) + num_frames, *_ = self._buckets.get_thw(data["bucket_id"]) + min_length = (num_frames - 1) * self._stride + 1 + + if len(reader) < min_length: + raise ValueError(f"Video is too short: {video_path}") + start_pos = random.randint(0, len(reader) - min_length) + video = reader.fetch_frames(num=num_frames, start_pos=start_pos, step=self._stride) + data["fps"] = np.array(reader.fps, dtype=np.float32) else: # TODO: add pyav backend and test raise NotImplementedError @@ -337,14 +328,9 @@ def _get_item(self, idx: int) -> Tuple[Any, ...]: # apply transforms on video frames here if self.apply_train_transforms: # variable resize and crop, frame-wise - clip = [] - for i in range(num_frames): - if self._buckets: - resized_img = self.pixel_transforms(video[i], bucket_id=data["bucket_id"]) - else: - resized_img = self.pixel_transforms(video[i]) - clip.append(resized_img) - clip = np.stack(clip, axis=0) + clip = self.pixel_transforms(video) + if clip.ndim == 3: + clip = np.expand_dims(clip, 0) # transpose and norm, clip-wise clip = np.transpose(clip, (0, 3, 1, 2)) @@ -424,7 +410,7 @@ def train_transforms( transforms.extend( [ { - "operations": BucketResizeCrop(self._buckets), + "operations": ResizeCrop(interpolation=cv2.INTER_AREA), "input_columns": ["video", "bucket_id"], "output_columns": ["video"], # drop `bucket_id` column }, @@ -444,8 +430,7 @@ def train_transforms( transforms.append( { "operations": [ - Resize(target_size, interpolation=Inter.BILINEAR), - CenterCrop(target_size), + ResizeCrop(target_size, interpolation=cv2.INTER_AREA), lambda x: np.divide(x, 127.5, dtype=np.float32), lambda x: np.subtract(x, 1.0, dtype=np.float32), lambda x: np.transpose(x, (0, 3, 1, 2)), From da5576bedc39c2b7a6333987f2296127cc5640fe Mon Sep 17 00:00:00 2001 From: Rustam Khadipash <16683750+hadipash@users.noreply.github.com> Date: Tue, 29 Oct 2024 11:23:26 +0800 Subject: [PATCH 019/122] update convert script --- .../moviegen/tools/download_convert_st.py | 73 ++++++++++++------- 1 file changed, 45 insertions(+), 28 deletions(-) diff --git a/examples/moviegen/tools/download_convert_st.py b/examples/moviegen/tools/download_convert_st.py index 47684496e6..6e32186e1e 100644 --- a/examples/moviegen/tools/download_convert_st.py +++ b/examples/moviegen/tools/download_convert_st.py @@ -10,7 +10,7 @@ import requests import torch -from huggingface_hub import HfApi, configure_http_backend, hf_hub_download +from huggingface_hub import HfApi, configure_http_backend, hf_hub_download, snapshot_download from safetensors.torch import _find_shared_tensors, _is_complete, load_file, save_file @@ -138,7 +138,7 @@ def convert_multi( filenames = set(data["weight_map"].values()) for filename in filenames: pt_filename = hf_hub_download( - repo_id=model_id, filename=filename, token=token, cache_dir=folder, endpoint=endpoint + model_id, revision=revision, filename=filename, token=token, cache_dir=folder, endpoint=endpoint ) sf_filename = rename(pt_filename) sf_filename = os.path.join(save_path, sf_filename) @@ -247,35 +247,52 @@ def convert( library_name = getattr(info, "library_name", None) if any(filename.endswith(".safetensors") for filename in filenames) and not force: - raise AlreadyExists(f"Model {model_id} is already converted, skipping..") - elif library_name == "transformers": - discard_names = get_discard_names( - model_id, revision=revision, folder=folder, token=api.token, endpoint=endpoint + print(f"Model {model_id} is already converted. 
Downloading safetensors...") + save_path = snapshot_download( # Download an entire directory, including the tokenizer config + model_id, + revision=revision, + allow_patterns=["*.safetensors", "*.json", "*.model"], + token=api.token, + cache_dir=folder, + endpoint=endpoint, ) - if "pytorch_model.bin" in filenames: - save_path = convert_single( - model_id, - revision=revision, - folder=folder, - token=api.token, - discard_names=discard_names, - endpoint=endpoint, - ) - elif "pytorch_model.bin.index.json" in filenames: - save_path = convert_multi( - model_id, - revision=revision, - folder=folder, - token=api.token, - discard_names=discard_names, - endpoint=endpoint, - ) - else: - raise RuntimeError(f"Model {model_id} doesn't seem to be a valid pytorch model. Cannot convert") else: - save_path = convert_generic( - model_id, revision=revision, folder=folder, filenames=filenames, token=api.token, endpoint=endpoint + snapshot_download( # Download an entire directory, including the tokenizer config + model_id, + revision=revision, + allow_patterns=["*.bin", "*.json", "*.model"], + token=api.token, + cache_dir=folder, + endpoint=endpoint, ) + if library_name == "transformers": + discard_names = get_discard_names( + model_id, revision=revision, folder=folder, token=api.token, endpoint=endpoint + ) + if "pytorch_model.bin" in filenames: + save_path = convert_single( + model_id, + revision=revision, + folder=folder, + token=api.token, + discard_names=discard_names, + endpoint=endpoint, + ) + elif "pytorch_model.bin.index.json" in filenames: + save_path = convert_multi( + model_id, + revision=revision, + folder=folder, + token=api.token, + discard_names=discard_names, + endpoint=endpoint, + ) + else: + raise RuntimeError(f"Model {model_id} doesn't seem to be a valid pytorch model. 
Cannot convert") + else: + save_path = convert_generic( + model_id, revision=revision, folder=folder, filenames=filenames, token=api.token, endpoint=endpoint + ) return save_path From ee2acef9f93ece7e12a69eb9e73463ca28ba0da8 Mon Sep 17 00:00:00 2001 From: Rustam Khadipash <16683750+hadipash@users.noreply.github.com> Date: Tue, 29 Oct 2024 12:06:46 +0800 Subject: [PATCH 020/122] add recompute support in PyNative --- .../moviegen/moviegen/models/llama/network.py | 3 ++- examples/moviegen/tools/patch_pynative.sh | 24 +++++++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) create mode 100644 examples/moviegen/tools/patch_pynative.sh diff --git a/examples/moviegen/moviegen/models/llama/network.py b/examples/moviegen/moviegen/models/llama/network.py index 8fe8fac923..293015d9a7 100644 --- a/examples/moviegen/moviegen/models/llama/network.py +++ b/examples/moviegen/moviegen/models/llama/network.py @@ -383,7 +383,8 @@ def __init__( # recompute if gradient_checkpointing: - self.layers.recompute() + for layer in self.layers: # Explicitly recompute each block for PyNative + layer.recompute() @property def dtype(self): diff --git a/examples/moviegen/tools/patch_pynative.sh b/examples/moviegen/tools/patch_pynative.sh new file mode 100644 index 0000000000..516685b0b5 --- /dev/null +++ b/examples/moviegen/tools/patch_pynative.sh @@ -0,0 +1,24 @@ +# Patch MindSpore to add support for recompute in PyNative mode + +# Find the site-packages path +SITE_PACKAGES=$(python -c "import site; print(site.getsitepackages()[0])") + +# Define the file path and the line to insert +FILE_PATH="$SITE_PACKAGES/mindspore/common/recompute.py" +LINE_AFTER=" self.wrap_cell = _WrapCell(block)" +LINE_TO_INSERT=" self.wrap_cell.set_inputs()" + +# Check if the file has already been modified +if grep -qF "$LINE_TO_INSERT" "$FILE_PATH"; then + echo "File $FILE_PATH has already been patched. No changes made." 
+ exit 0 +fi + +# Use sed to insert the line after the specified pattern +if sed -i "/$LINE_AFTER/a \\$LINE_TO_INSERT" "$FILE_PATH" +then + echo "Successfully patched $FILE_PATH" +else + echo "Error: Failed to patch $FILE_PATH" + exit 1 +fi From 59ed0d56dd3512a213fe81ba11d7ca257c4c8c0c Mon Sep 17 00:00:00 2001 From: Rustam Khadipash <16683750+hadipash@users.noreply.github.com> Date: Tue, 29 Oct 2024 14:19:31 +0800 Subject: [PATCH 021/122] add dataloader --- .../moviegen/moviegen/dataset/__init__.py | 3 +- examples/moviegen/moviegen/dataset/dataset.py | 289 ++++++++++++++++++ examples/moviegen/moviegen/dataset/image.py | 69 ----- .../moviegen/moviegen/dataset/transforms.py | 24 ++ examples/moviegen/moviegen/dataset/video.py | 18 -- 5 files changed, 314 insertions(+), 89 deletions(-) create mode 100644 examples/moviegen/moviegen/dataset/dataset.py delete mode 100644 examples/moviegen/moviegen/dataset/image.py create mode 100644 examples/moviegen/moviegen/dataset/transforms.py delete mode 100644 examples/moviegen/moviegen/dataset/video.py diff --git a/examples/moviegen/moviegen/dataset/__init__.py b/examples/moviegen/moviegen/dataset/__init__.py index 31b6c8daed..54fc7d4725 100644 --- a/examples/moviegen/moviegen/dataset/__init__.py +++ b/examples/moviegen/moviegen/dataset/__init__.py @@ -1,2 +1 @@ -from .image import ImageDataset -from .video import VideoDataset +from .dataset import ImageVideoDataset diff --git a/examples/moviegen/moviegen/dataset/dataset.py b/examples/moviegen/moviegen/dataset/dataset.py new file mode 100644 index 0000000000..9e8ed8eb64 --- /dev/null +++ b/examples/moviegen/moviegen/dataset/dataset.py @@ -0,0 +1,289 @@ +import csv +import logging +import os +import random +import sys +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import cv2 +import numpy as np +from tqdm import tqdm + +from mindspore.dataset.transforms import Compose + +from mindone.data.video_reader import VideoReader + +from .transforms import ResizeCrop + +# FIXME: remove in future when mindone is ready for install +sys.path.append(os.path.join(os.path.dirname(__file__), "../../../..")) +from mindone.data import BaseDataset + +_logger = logging.getLogger(__name__) + + +IMAGE_EXT = (".jpg", ".jpeg", ".png", ".gif", ".webp") + + +def create_infer_transforms(target_size: Tuple[int, int], interpolation=cv2.INTER_LINEAR): + return Compose( + [ + ResizeCrop(target_size, interpolation=interpolation), + lambda x: x.astype(np.float32) / 127.5 - 1, + lambda x: x[None, ...] if x.ndim == 3 else x, # if image + lambda x: np.transpose(x, (0, 3, 1, 2)), + ] + ) + + +class ImageVideoDataset(BaseDataset): + def __init__( + self, + csv_path: str, + video_folder: str, + text_emb_folder: Optional[Union[str, Dict[str, str]]] = None, + vae_latent_folder: Optional[str] = None, + vae_downsample_rate: float = 8.0, + vae_scale_factor: float = 0.18215, + target_size: Optional[Tuple[int, int]] = None, + sample_n_frames: int = 17, + sample_stride: int = 1, + frames_mask_generator: Optional[Callable[[int], np.ndarray]] = None, + t_compress_func: Optional[Callable[[int], int]] = None, + filter_data: bool = False, + apply_transforms_dataset: bool = False, + *, + output_columns: List[str], + ): + if text_emb_folder is None: + raise NotImplementedError( + "Text embedding during training is not supported, please provide `text_emb_folder`." 
+ ) + + self._data = self._read_data(video_folder, csv_path, text_emb_folder, vae_latent_folder, filter_data) + self._frames = sample_n_frames + self._stride = sample_stride + self._min_length = (self._frames - 1) * self._stride + 1 + self._text_emb_folder = text_emb_folder + self._vae_latent_folder = vae_latent_folder + self._vae_downsample_rate = vae_downsample_rate + self._vae_scale_factor = vae_scale_factor + self._fmask_gen = frames_mask_generator + self._t_compress_func = t_compress_func or (lambda x: x) + + self.output_columns = output_columns + + self._transforms = ( + self.train_transforms(target_size, interpolation=cv2.INTER_AREA) if apply_transforms_dataset else None + ) + + # prepare replacement data in case the loading of a sample fails + self._prev_ok_sample = self._get_replacement() + self._require_update_prev = False + + @staticmethod + def _read_data( + data_dir: str, + csv_path: str, + text_emb_folder: Optional[Union[str, Dict[str, str]]] = None, + vae_latent_folder: Optional[str] = None, + filter_data: bool = False, + ) -> List[dict]: + def _filter_data(sample_): + if not os.path.isfile(sample_["video"]): + _logger.warning(f"Video not found: {sample_['video']}") + return None + if "text_emb" in sample_: + if isinstance(sample_["text_emb"], str) and not os.path.isfile(sample_["text_emb"]): + _logger.warning(f"Text embedding not found: {sample_['text_emb']}") + return None + else: + for name, path in sample_["text_emb"].items(): + if not os.path.isfile(sample_["text_emb"][name]): + _logger.warning(f"Text embedding not found: {sample_['text_emb'][name]}") + return None + if "vae_latent" in sample_ and not os.path.isfile(sample_["vae_latent"]): + _logger.warning(f"Text embedding not found: {sample_['vae_latent']}") + return None + return sample_ + + with open(csv_path, "r") as csv_file: + try: + data = [] + for item in csv.DictReader(csv_file): + sample = {**item, "video": os.path.join(data_dir, item["video"])} + if text_emb_folder: + if isinstance(text_emb_folder, str): + sample["text_emb"] = os.path.join(text_emb_folder, Path(item["video"]).with_suffix(".npz")) + else: + sample["text_emb"] = { + name: os.path.join(path, Path(item["video"]).with_suffix(".npz")) + for name, path in text_emb_folder.items() + } + if vae_latent_folder: + sample["vae_latent"] = os.path.join(vae_latent_folder, Path(item["video"]).with_suffix(".npz")) + data.append(sample) + except KeyError as e: + _logger.error(f"CSV file requires `video` (file paths) column, but got {list(item.keys())}") + raise e + + if filter_data: + with ThreadPoolExecutor(max_workers=10) as executor: + data = [ + item + for item in tqdm(executor.map(_filter_data, data), total=len(data), desc="Filtering data") + if item is not None + ] + + _logger.info(f"Number of data samples: {len(data)}") + return data + + def _get_replacement(self, max_attempts: int = 100) -> Tuple[Any, ...]: + attempts, error = min(max_attempts, len(self)), None + for idx in range(attempts): + try: + return self._get_item(idx) + except Exception as e: + error = e + _logger.debug(f"Failed to load a replacement sample: {repr(e)}") + + raise RuntimeError(f"Fail to load a replacement sample in {attempts} attempts. 
Error: {repr(error)}") + + def _get_item(self, idx: int, thw: Optional[Tuple[int, int, int]] = None) -> Tuple[Any, ...]: + data = self._data[idx].copy() + num_frames = self._frames + + if self._text_emb_folder: + if isinstance(data["text_emb"], str): + with np.load(data["text_emb"]) as td: + data.update({"caption": td["text_emb"], "mask": td["mask"]}) + else: + for enc_name, path in data["text_emb"].items(): + with np.load(path) as td: + data.update({enc_name + "_caption": td["text_emb"], enc_name + "_mask": td["mask"]}) + + if self._vae_latent_folder: + # TODO: add support for images + vae_latent_data = np.load(data["vae_latent"]) + latent_mean, latent_std = vae_latent_data["latent_mean"], vae_latent_data["latent_std"] + if len(latent_mean) < self._min_length: + raise ValueError(f"Video is too short: {data['video']}") + + if "fps" not in data: + if "fps" in vae_latent_data: + data["fps"] = vae_latent_data["fps"] + else: + with VideoReader(data["video"]) as reader: + data["fps"] = reader.fps + data["fps"] = np.array(data["fps"] / self._stride, dtype=np.float32) + + start_pos = random.randint(0, len(latent_mean) - self._min_length) + batch_index = np.linspace(start_pos, start_pos + self._min_length - 1, num_frames, dtype=int) + + latent_mean, latent_std = latent_mean[batch_index], latent_std[batch_index] + vae_latent = latent_mean + latent_std * np.random.standard_normal(latent_mean.shape) + data["video"] = vae_latent * self._vae_scale_factor + + else: + if data["video"].lower().endswith(IMAGE_EXT): + num_frames = 1 + data["fps"] = np.array(120, dtype=np.float32) # FIXME: extract as IMG_FPS + data["video"] = cv2.cvtColor(cv2.imread(data["video"]), cv2.COLOR_BGR2RGB) + else: + with VideoReader(data["video"]) as reader: + min_length = self._min_length + if thw is not None: + num_frames, *data["size"] = thw + min_length = (num_frames - 1) * self._stride + 1 + if len(reader) < min_length: + raise ValueError(f"Video is too short: {data['video']}") + start_pos = random.randint(0, len(reader) - min_length) + data["video"] = reader.fetch_frames(num=num_frames, start_pos=start_pos, step=self._stride) + data["fps"] = np.array(reader.fps / self._stride, dtype=np.float32) + + data["num_frames"] = np.array(num_frames, dtype=np.float32) + + if self._fmask_gen is not None: + # return frames mask with respect to the VAE's latent temporal compression + data["frames_mask"] = self._fmask_gen(self._t_compress_func(num_frames)) + + if self._transforms: + data = self._apply_transforms(data) + + return tuple(data[c] for c in self.output_columns) + + def get_bucket(self, thw: Tuple[int, int, int], sample_ids: List[int]) -> Tuple[Any, ...]: + batch = [self._get_item(sample_id, thw) for sample_id in sample_ids] + return tuple(np.stack(item) for item in map(list, zip(*batch))) + + def __getitem__(self, idx: int) -> Tuple[Any, ...]: + try: + sample = self._get_item(idx) + if self._require_update_prev: + self._prev_ok_sample = sample + self._require_update_prev = False + except Exception as e: + _logger.warning(f"Failed to fetch sample #{idx}, the video will be replaced. 
Error: {e}") + sample = self._prev_ok_sample + self._require_update_prev = True + + return sample + + def __len__(self): + return len(self._data) + + def _apply_transforms(self, data: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: + for transform in self._transforms: + input_data = tuple(data[column] for column in transform["input_columns"]) + for op in transform["operations"]: + input_data = op(*input_data) + if not isinstance(input_data, tuple): # wrap numpy array in a tuple + input_data = (input_data,) + data.update(zip(transform.get("output_columns", transform["input_columns"]), input_data)) + return data + + def train_transforms( + self, + target_size: Tuple[int, int], + interpolation: int = cv2.INTER_LINEAR, + tokenizer: Optional[Callable[[str], np.ndarray]] = None, + ) -> List[dict]: + transforms = [] + vae_downsample_rate = self._vae_downsample_rate + + if not self._vae_latent_folder: + vae_downsample_rate = 1 + transforms.append( + { + "operations": [ + ResizeCrop(target_size, interpolation=interpolation), + lambda x: x.astype(np.float32) / 127.5 - 1, + lambda x: np.transpose(x, (0, 3, 1, 2)), + ], + "input_columns": ["video"], + } + ) + # the followings are not transformation for video frames, can be excluded + transforms.append( + { + "operations": [ + lambda video: ( + video, # need to return the video itself to preserve the column + np.array(video.shape[-2] * vae_downsample_rate, dtype=np.float32), + np.array(video.shape[-1] * vae_downsample_rate, dtype=np.float32), + np.array(video.shape[-2] / video.shape[-1], dtype=np.float32), + ) + ], + "input_columns": ["video"], + "output_columns": ["video", "height", "width", "ar"], + } + ) + + if "caption" in self.output_columns and not self._text_emb_folder: + if tokenizer is None: + raise RuntimeError("Please provide a tokenizer for text data in `train_transforms()`.") + transforms.append({"operations": [tokenizer], "input_columns": ["caption"]}) + + return transforms diff --git a/examples/moviegen/moviegen/dataset/image.py b/examples/moviegen/moviegen/dataset/image.py deleted file mode 100644 index 88b4610e99..0000000000 --- a/examples/moviegen/moviegen/dataset/image.py +++ /dev/null @@ -1,69 +0,0 @@ -import json -import logging -import os -import random -from typing import Tuple - -import numpy as np -from PIL import Image -from transformers import AutoTokenizer - -from mindspore.dataset.transforms import Compose, vision - -logger = logging.getLogger(__name__) - - -class ImageDataset: - def __init__( - self, - json_path: str, - image_dir: str, - image_size: int, - tokenizer: AutoTokenizer, - text_drop_prob: float = 0.2, - ) -> None: - logger.info(f"loading annotations from `{json_path}`.") - with open(json_path, "r") as f: - self.dataset = json.load(f) - - self.length = len(self.dataset) - - self.image_dir = image_dir - self.tokenizer = tokenizer - self.text_drop_prob = text_drop_prob - self.interpolation_mode = vision.Inter.BILINEAR - self.transform = self.create_transform(image_size, self.interpolation_mode) - - def __len__(self) -> int: - return self.length - - def __getitem__(self, idx: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: - record = self.dataset[idx] - image_path = os.path.join(self.image_dir, record["path"]) - - if random.random() < self.text_drop_prob: - text = "" - else: - text = record["prompt"] - - # process text - encoding = self.tokenizer(text, padding="max_length", truncation=True, return_tensors="np") - text_ids = encoding.input_ids[0] - - # process image - image = Image.open(image_path).convert("RGB") 
- - image = self.transform(image)[0] - image = np.expand_dims(image, axis=0) # 1, C, H, W - return image, text_ids - - @staticmethod - def create_transform(image_size: int, interpolation: vision.Inter) -> Compose: - return Compose( - [ - vision.Resize(image_size, interpolation=interpolation), - vision.CenterCrop(image_size), - vision.ToTensor(), - vision.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], is_hwc=False), - ] - ) diff --git a/examples/moviegen/moviegen/dataset/transforms.py b/examples/moviegen/moviegen/dataset/transforms.py new file mode 100644 index 0000000000..d69a01f2f2 --- /dev/null +++ b/examples/moviegen/moviegen/dataset/transforms.py @@ -0,0 +1,24 @@ +from typing import Optional, Tuple + +import cv2 +import numpy as np + + +class ResizeCrop: + def __init__(self, size: Optional[Tuple[int, int]] = None, interpolation=cv2.INTER_LINEAR): + self._size = size + self._inter = interpolation + + def __call__(self, x: np.ndarray, size: Optional[Tuple[int, int]] = None) -> np.ndarray: + h, w = x.shape[-3:-1] # support images and videos + th, tw = size or self._size + scale = max(th / h, tw / w) + if scale != 1: # resize + if x.ndim == 3: # if image + x = cv2.resize(x, None, fx=scale, fy=scale, interpolation=self._inter) + else: # if video + x = np.array([cv2.resize(i, None, fx=scale, fy=scale, interpolation=self._inter) for i in x]) + if x.shape[-3:-1] != (th, tw): # crop + i, j = round((x.shape[-3] - th) / 2.0), round((x.shape[-2] - tw) / 2.0) + x = x[..., i : i + th, j : j + tw, :] + return x diff --git a/examples/moviegen/moviegen/dataset/video.py b/examples/moviegen/moviegen/dataset/video.py deleted file mode 100644 index 50e49d9128..0000000000 --- a/examples/moviegen/moviegen/dataset/video.py +++ /dev/null @@ -1,18 +0,0 @@ -from typing import Tuple - -import numpy as np - - -class VideoDataset: - def __len__(self) -> int: - return NotImplementedError() - - def __getitem__(self, idx: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: - """ - Returns: - video/video caching - text embedding 1 - text embedding 1 - text embedding 1 - """ - raise NotImplementedError() From 2fe103b077c8577208e26fadccd44816c1b5f8a2 Mon Sep 17 00:00:00 2001 From: Rustam Khadipash <16683750+hadipash@users.noreply.github.com> Date: Wed, 30 Oct 2024 10:50:45 +0800 Subject: [PATCH 022/122] update train script --- .../configs/train/moviegen-256x256-t2i.yaml | 27 -- .../configs/train/moviegen_t2i_256x256.yaml | 70 ++++ .../moviegen/moviegen/models/llama/block.py | 22 +- .../moviegen/moviegen/models/llama/network.py | 20 +- .../models/text_encoders/text_projector.py | 8 +- .../moviegen/pipelines/train_pipeline.py | 19 +- .../moviegen/schedulers/rectified_flow.py | 17 +- examples/moviegen/moviegen/utils/__init__.py | 1 + examples/moviegen/moviegen/utils/ema.py | 24 ++ examples/moviegen/moviegen/utils/misc.py | 45 +-- .../moviegen/moviegen/utils/model_utils.py | 27 +- examples/moviegen/requirements.txt | 1 + .../moviegen/scripts/train_t2i_256x256.sh | 26 ++ examples/moviegen/train.py | 375 +++++++----------- .../opensora/models/stdit/stdit_llama3.py | 21 +- mindone/data/loader.py | 3 +- mindone/trainers/ema.py | 11 +- mindone/trainers/train_step.py | 5 +- mindone/trainers/zero.py | 3 +- mindone/utils/env.py | 18 +- 20 files changed, 366 insertions(+), 377 deletions(-) delete mode 100644 examples/moviegen/configs/train/moviegen-256x256-t2i.yaml create mode 100644 examples/moviegen/configs/train/moviegen_t2i_256x256.yaml create mode 100644 examples/moviegen/moviegen/utils/ema.py create mode 100644 
examples/moviegen/requirements.txt create mode 100644 examples/moviegen/scripts/train_t2i_256x256.sh diff --git a/examples/moviegen/configs/train/moviegen-256x256-t2i.yaml b/examples/moviegen/configs/train/moviegen-256x256-t2i.yaml deleted file mode 100644 index 219d29453b..0000000000 --- a/examples/moviegen/configs/train/moviegen-256x256-t2i.yaml +++ /dev/null @@ -1,27 +0,0 @@ -# model -model_version: llama-1B -batch_size: 64 -checkpoint: "models/PixArt-Sigma-XL-2-256x256.ckpt" -vae_root: "models/vae" -text_encoder_root: "models/text_encoder" -tokenizer_root: "models/tokenizer" -scale_factor: 0.13025 -enable_flash_attention: True -dtype: "bf16" - -# training hyper-parameters -epochs: 100 -scheduler: "constant" -start_learning_rate: 1.0e-4 -optim: "adamw" -weight_decay: 0.1 -loss_scaler_type: "static" -init_loss_scale: 1.0 -gradient_accumulation_steps: 1 -clip_grad: True -max_grad_norm: 1.0 -ckpt_save_interval: 1 -log_loss_interval: 1 -recompute: True -text_drop_prob: 0.2 -warmup_steps: 2000 diff --git a/examples/moviegen/configs/train/moviegen_t2i_256x256.yaml b/examples/moviegen/configs/train/moviegen_t2i_256x256.yaml new file mode 100644 index 0000000000..3e334cec84 --- /dev/null +++ b/examples/moviegen/configs/train/moviegen_t2i_256x256.yaml @@ -0,0 +1,70 @@ +env: + mode: 0 + jit_level: O0 + seed: 42 + distributed: False + debug: False + +model: + name: llama-1B + pretrained_model_path: + enable_flash_attention: True + recompute: True + dtype: bf16 + +vae: + pretrained_model_name_or_path: stabilityai/sd-vae-ft-ema + dtype: fp16 + +dataset: + csv_path: CSV_PATH + video_folder: VIDEO_FOLDER + text_emb_folder: + ul2: UL2_FOLDER + byt5: BYT5_FOLDER + target_size: [ 256, 256 ] + apply_transforms_dataset: True + output_columns: ["video", "ul2_caption", "byt5_caption"] + +dataloader: + batch_size: 64 + shuffle: True + num_workers_dataset: 4 + +train: + epochs: 1000 + output_path: output/moviegen_t2i_256x256 + + lr_scheduler: + name: constant + lr: 1.0e-4 + warmup_steps: 1000 + + optimizer: + name: adamw_re + eps: 1e-15 + betas: [0.9, 0.999] + weight_decay: 0.1 + + loss_scaler: + class_path: mindspore.nn.FixedLossScaleUpdateCell # or DynamicLossScaleUpdateCell + init_args: + loss_scale_value: 1 + + ema: + ema_decay: 0.9999 + offloading: True + + settings: + zero_stage: 2 + gradient_accumulation_steps: 1 + clip_grad: True + clip_norm: 1.0 + + save: + ckpt_save_policy: latest_k + ckpt_max_keep: 10 + ckpt_save_interval: 50 + log_interval: 1 + save_ema_only: False + record_lr: False diff --git a/examples/moviegen/moviegen/models/llama/block.py b/examples/moviegen/moviegen/models/llama/block.py index 607b9b37aa..de71142ec0 100644 --- a/examples/moviegen/moviegen/models/llama/block.py +++ b/examples/moviegen/moviegen/models/llama/block.py @@ -1,6 +1,7 @@ import logging from typing import Optional, Sequence, Tuple, Union +import numpy as np from moviegen.parallel import ( ColumnParallelLinear, FusedColumnParallelLinear, @@ -26,7 +27,7 @@ class LlamaRMSNorm(nn.Cell): def __init__(self, hidden_size: Union[int, Sequence[int]], eps: float = 1e-6, dtype: ms.Type = ms.float32) -> None: super().__init__() - self.weight = Parameter(mint.ones(hidden_size, dtype=dtype)) + self.weight = Parameter(Tensor(np.ones(hidden_size), dtype=dtype)) self.variance_epsilon = eps def construct(self, hidden_states: Tensor) -> Tensor: @@ -496,22 +497,3 @@ def construct(self, t: Tensor) -> Tensor: t_freq = self.timestep_embedding(t, self.frequency_embedding_size) t_emb = self.mlp(t_freq.to(self.dtype)) return t_emb - - -class 
CaptionEmbedder(nn.Cell): - def __init__( - self, - in_channels: int, - hidden_size: int, - eps: float = 1e-6, - dtype: ms.Type = ms.float32, - ) -> None: - super().__init__() - self.proj = nn.SequentialCell( - mint.nn.Linear(in_channels, hidden_size, bias=False, dtype=dtype), - LlamaRMSNorm((hidden_size,), eps=eps, dtype=dtype), - ) - - def construct(self, caption: Tensor) -> Tensor: - caption_emb = self.proj(caption) - return caption_emb diff --git a/examples/moviegen/moviegen/models/llama/network.py b/examples/moviegen/moviegen/models/llama/network.py index 293015d9a7..de8c3989a5 100644 --- a/examples/moviegen/moviegen/models/llama/network.py +++ b/examples/moviegen/moviegen/models/llama/network.py @@ -15,9 +15,9 @@ from mindone.models.utils import normal_, zeros_ +from ..text_encoders import TextProjector from .activation import ACT2FN from .block import ( - CaptionEmbedder, ContextParallelLlamaAttention, ContextParallelLlamaFlashAttention, FusedTensorParallelLlamaMLP, @@ -86,7 +86,7 @@ def __init__( intermediate_size=intermediate_size, hidden_size=hidden_size, hidden_act=hidden_act, dtype=dtype ) - self.scale_shift_table = Parameter(Tensor(np.random.randn(1, 6, hidden_size), dtype=dtype) / hidden_size**0.5) + self.scale_shift_table = Parameter(Tensor(np.random.randn(1, 6, hidden_size) / hidden_size**0.5, dtype=dtype)) self.input_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype) self.post_attention_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype) @@ -188,7 +188,7 @@ def __init__( dtype=dtype, ) - self.scale_shift_table = Parameter(Tensor(np.random.randn(1, 6, hidden_size), dtype=dtype) / hidden_size**0.5) + self.scale_shift_table = Parameter(Tensor(np.random.randn(1, 6, hidden_size) / hidden_size**0.5, dtype=dtype)) self.input_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype) self.post_attention_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype) @@ -257,7 +257,7 @@ def __init__( self.proj = nn.Dense( hidden_size, patch_size[0] * patch_size[1] * patch_size[2] * out_channels, has_bias=False, dtype=dtype ) - self.scale_shift_table = Parameter(Tensor(np.random.randn(2, hidden_size), dtype=dtype) / hidden_size**0.5) + self.scale_shift_table = Parameter(Tensor(np.random.randn(2, hidden_size) / hidden_size**0.5, dtype=dtype)) def construct(self, hidden_states: Tensor, timestep_embedding: Tensor): shift, scale = mint.chunk( @@ -368,8 +368,9 @@ def __init__( ACT2FN[hidden_act], mint.nn.Linear(self.hidden_size, 6 * self.hidden_size, bias=False, dtype=dtype) ) - # TODO: drop this - self.caption_embedder = CaptionEmbedder(caption_channels, self.hidden_size, eps=rms_norm_eps, dtype=dtype) + self.text_projector = TextProjector( + out_features=self.hidden_size, layer_norm=LlamaRMSNorm, norm_eps=self.rms_norm_eps, dtype=dtype + ) if self.model_parallelism: self.group_size = get_group_size(mp_group) @@ -460,10 +461,7 @@ def unpatchify(self, hidden_states: Tensor, t: int, h: int, w: int) -> Tensor: return output def construct( - self, - latent_embedding: Tensor, - timestep: Tensor, - text_embedding: Tensor, + self, latent_embedding: Tensor, timestep: Tensor, ul2_emb: Tensor, metaclip_emb: Tensor, byt5_emb: Tensor ) -> Tensor: """ latent_embedding: (N, T, C, H, W) tensor of inputs (latent representations of video) @@ -484,7 +482,7 @@ def construct( modulation_parameters = self.adaLN_modulation(timestep_embedding) # 3.1.4 text embedding - text_embedding = self.caption_embedder(text_embedding) + text_embedding = 
self.text_projector(ul2_emb, metaclip_emb, byt5_emb) # main blocks hidden_states = latent_embedding diff --git a/examples/moviegen/moviegen/models/text_encoders/text_projector.py b/examples/moviegen/moviegen/models/text_encoders/text_projector.py index 3cf4d7dadd..0920cf5de5 100644 --- a/examples/moviegen/moviegen/models/text_encoders/text_projector.py +++ b/examples/moviegen/moviegen/models/text_encoders/text_projector.py @@ -56,9 +56,9 @@ def _init_weights(module): self.apply(_init_weights) - def construct(self, ul2_text: Tensor, metaclip_text: Tensor, byt5_text: Tensor) -> Tensor: - ul2_hidden_states = self.ul2_projector(ul2_text) - metaclip_hidden_states = self.metaclip_projector(metaclip_text) - byt5_hidden_states = self.byt5_projector(byt5_text) + def construct(self, ul2_emb: Tensor, metaclip_emb: Tensor, byt5_emb: Tensor) -> Tensor: + ul2_hidden_states = self.ul2_projector(ul2_emb) + metaclip_hidden_states = self.metaclip_projector(metaclip_emb) + byt5_hidden_states = self.byt5_projector(byt5_emb) return mint.cat((ul2_hidden_states, metaclip_hidden_states, byt5_hidden_states), dim=1) diff --git a/examples/moviegen/moviegen/pipelines/train_pipeline.py b/examples/moviegen/moviegen/pipelines/train_pipeline.py index 17ca52b490..a152585e02 100644 --- a/examples/moviegen/moviegen/pipelines/train_pipeline.py +++ b/examples/moviegen/moviegen/pipelines/train_pipeline.py @@ -3,6 +3,7 @@ from mindspore import Tensor, nn, ops from ..schedulers import RFlowLossWrapper +from ..utils.model_utils import no_grad __all__ = ["DiffusionWithLoss"] @@ -42,16 +43,24 @@ def __init__( def get_condition_embeddings(self, text_tokens: Tensor) -> Tensor: if self.text_emb_cached: return text_tokens - text_emb = ops.stop_gradient(self.text_encoder(text_tokens)) + with no_grad(): + text_emb = ops.stop_gradient(self.text_encoder(text_tokens)) return text_emb def get_latents(self, video_tokens: Tensor) -> Tensor: if self.video_emb_cached: return video_tokens - video_emb = ops.stop_gradient(self.vae.encode(video_tokens) * self.scale_factor) + with no_grad(): + b, f, *_ = video_tokens.shape # FIXME: no VideoAutoencoderKL in mindone.differs + video_tokens = video_tokens.reshape(-1, *video_tokens.shape[2:]) + video_emb = ops.stop_gradient(self.vae.encode(video_tokens.astype(self.vae.dtype))[0] * self.scale_factor) + video_emb = video_emb.reshape(b, f, *video_emb.shape[1:]) return video_emb - def construct(self, video_tokens: Tensor, text_tokens: Tensor) -> Tensor: + def construct(self, video_tokens: Tensor, ul2_tokens: Tensor, byt5_tokens: Tensor) -> Tensor: latent_embedding = self.get_latents(video_tokens) - text_embedding = self.get_condition_embeddings(text_tokens) - return self.network(latent_embedding, text_embedding) + ul2_emb = self.get_condition_embeddings(ul2_tokens) + byt5_emb = self.get_condition_embeddings(byt5_tokens) + # FIXME: add metaclip + metaclip_emb = ops.ones((byt5_emb.shape[0], 300, 1280), dtype=byt5_emb.dtype) + return self.network(latent_embedding, ul2_emb, metaclip_emb, byt5_emb) diff --git a/examples/moviegen/moviegen/schedulers/rectified_flow.py b/examples/moviegen/moviegen/schedulers/rectified_flow.py index 234a86f63e..541cb8b4e5 100644 --- a/examples/moviegen/moviegen/schedulers/rectified_flow.py +++ b/examples/moviegen/moviegen/schedulers/rectified_flow.py @@ -125,7 +125,9 @@ def _broadcast(self, x: Tensor) -> Tensor: return x return self.broadcast((x,))[0] - def construct(self, x: Tensor, text_embedding: Tensor, timestep: Optional[Tensor] = None) -> Tensor: + def construct( + self, x: 
Tensor, ul2_emb: Tensor, metaclip_emb: Tensor, byt5_emb: Tensor, timestep: Optional[Tensor] = None + ) -> Tensor: """Calculate the training loss for the corresponding timestep. x: (N, T, C, H, W) tensor of inputs (latent representations of video) text_embedding: (N, L, C') tensor of the text embedding @@ -139,13 +141,18 @@ def construct(self, x: Tensor, text_embedding: Tensor, timestep: Optional[Tensor noise = self._broadcast(mint.normal(size=x.shape)) x_t = self.add_noise(x, noise, timestep) - model_output = self.model(x_t.to(self.model.dtype), timestep, text_embedding.to(self.model.dtype)).to( - ms.float32 - ) + model_output = self.model( + x_t.to(self.model.dtype), + timestep, + ul2_emb.to(self.model.dtype), + metaclip_emb.to(self.model.dtype), + byt5_emb.to(self.model.dtype), + ).to(ms.float32) + velocity_pred = mint.chunk(model_output, 2, dim=2)[0] v_t = x - (1 - self.eps) * noise # 3.1.2 Eqa (2) - loss = self.criteria(model_output, v_t) + loss = self.criteria(velocity_pred, v_t) return loss def add_noise(self, x: Tensor, noise: Tensor, timesteps: Tensor) -> Tensor: diff --git a/examples/moviegen/moviegen/utils/__init__.py b/examples/moviegen/moviegen/utils/__init__.py index d980a5b445..afbf2a621d 100644 --- a/examples/moviegen/moviegen/utils/__init__.py +++ b/examples/moviegen/moviegen/utils/__init__.py @@ -1,3 +1,4 @@ from .callback import * +from .ema import * from .misc import * from .model_utils import * diff --git a/examples/moviegen/moviegen/utils/ema.py b/examples/moviegen/moviegen/utils/ema.py new file mode 100644 index 0000000000..9ff3db69f9 --- /dev/null +++ b/examples/moviegen/moviegen/utils/ema.py @@ -0,0 +1,24 @@ +from mindspore.ops import composite as C +from mindspore.ops import functional as F + +from mindone.trainers.ema import EMA as EMA_ + +__all__ = ["EMA"] + +_ema_op = C.MultitypeFuncGraph("grad_ema_op") + + +@_ema_op.register("Number", "Tensor", "Tensor") +def _ema_weights(factor, ema_weight, weight): + return F.assign(ema_weight, ema_weight * factor + weight * (1 - factor)) + + +class EMA(EMA_): + def ema_update(self): + """Update EMA parameters.""" + self.updates += 1 + # update trainable parameters + success = self.hyper_map(F.partial(_ema_op, self.ema_decay), self.ema_weight, self.net_weight) + self.updates = F.depend(self.updates, success) + + return self.updates diff --git a/examples/moviegen/moviegen/utils/misc.py b/examples/moviegen/moviegen/utils/misc.py index 9d1f56984e..4235d3f807 100644 --- a/examples/moviegen/moviegen/utils/misc.py +++ b/examples/moviegen/moviegen/utils/misc.py @@ -1,19 +1,8 @@ -import argparse -import logging -from typing import Tuple - from moviegen.models import llama3_1B, llama3_5B, llama3_30B import mindspore as ms -from mindspore.communication import get_group_size, get_rank, init - -from mindone.utils.seed import set_random_seed - -__all__ = ["MODEL_SPEC", "MODEL_DTYPE", "str2bool", "check_cfgs_in_parser", "init_env"] - - -logger = logging.getLogger(__name__) +__all__ = ["MODEL_SPEC", "MODEL_DTYPE"] MODEL_SPEC = {"llama-1B": llama3_1B, "llama-5B": llama3_5B, "llama-30B": llama3_30B} @@ -22,35 +11,3 @@ "fp16": ms.float16, "bf16": ms.bfloat16, } - - -def str2bool(b: str) -> bool: - if b.lower() not in ["false", "true"]: - raise Exception("Invalid Bool Value") - if b.lower() in ["false"]: - return False - return True - - -def check_cfgs_in_parser(cfgs: dict, parser: argparse.ArgumentParser) -> None: - actions_dest = [action.dest for action in parser._actions] - defaults_key = parser._defaults.keys() - for k in cfgs.keys(): - if 
k not in actions_dest and k not in defaults_key: - raise KeyError(f"{k} does not exist in ArgumentParser!") - - -def init_env(args) -> Tuple[int, int]: - set_random_seed(args.seed) - ms.set_context(mode=args.mode, device_target=args.device_target, jit_config=dict(jit_level=args.jit_level)) - if args.use_parallel: - init() - device_num = get_group_size() - rank_id = get_rank() - ms.set_auto_parallel_context( - parallel_mode=ms.ParallelMode.DATA_PARALLEL, gradients_mean=True, device_num=device_num - ) - else: - device_num, rank_id = 1, 0 - - return device_num, rank_id diff --git a/examples/moviegen/moviegen/utils/model_utils.py b/examples/moviegen/moviegen/utils/model_utils.py index d70d940b5d..1852643255 100644 --- a/examples/moviegen/moviegen/utils/model_utils.py +++ b/examples/moviegen/moviegen/utils/model_utils.py @@ -1,10 +1,10 @@ import logging -from typing import Dict, Tuple, Union +from typing import Dict, Union import mindspore as ms -import mindspore.nn as nn +from mindspore import _no_grad, jit_class, nn -__all__ = ["load_ckpt_params", "count_params"] +__all__ = ["load_ckpt_params", "no_grad"] logger = logging.getLogger(__name__) @@ -26,7 +26,20 @@ def load_ckpt_params(model: nn.Cell, ckpt: Union[str, Dict]) -> nn.Cell: return model -def count_params(model: nn.Cell) -> Tuple[int, int]: - total_params = sum([param.size for param in model.get_parameters()]) - trainable_params = sum([param.size for param in model.trainable_params()]) - return total_params, trainable_params +@jit_class +class no_grad(_no_grad): + """ + A context manager that suppresses gradient memory allocation in PyNative mode. + """ + + def __init__(self): + super().__init__() + self._pynative = ms.get_context("mode") == ms.PYNATIVE_MODE + + def __enter__(self): + if self._pynative: + super().__enter__() + + def __exit__(self, *args): + if self._pynative: + super().__exit__(*args) diff --git a/examples/moviegen/requirements.txt b/examples/moviegen/requirements.txt new file mode 100644 index 0000000000..3dc0560cd1 --- /dev/null +++ b/examples/moviegen/requirements.txt @@ -0,0 +1 @@ +jsonargparse[signatures,omegaconf,urls]>=4.33.0 diff --git a/examples/moviegen/scripts/train_t2i_256x256.sh b/examples/moviegen/scripts/train_t2i_256x256.sh new file mode 100644 index 0000000000..07c74b4e18 --- /dev/null +++ b/examples/moviegen/scripts/train_t2i_256x256.sh @@ -0,0 +1,26 @@ +export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +# improve data loading performance for distributed training: 1 +export MS_ENABLE_NUMA=0 +# plot memory usage, feature/model: 1 +export MS_MEMORY_STATISTIC=0 +export MS_DATASET_SINK_QUEUE=4 + +# log level +export GLOG_v=2 + +output_dir=output/moviegen_t2i_256x256/$(date +"%Y.%m.%d-%H.%M.%S") + +msrun --bind_core=True --worker_num=8 --local_worker_num=8 --log_dir="$output_dir" \ +python train.py \ + --config configs/train/moviegen_t2i_256x256.yaml \ + --env.mode 0 \ + --env.jit_level O0 \ + --env.max_device_memory 59GB \ + --env.distributed=True \ + --model.name llama-1B \ + --dataset.csv_path CSV_PATH \ + --dataset.video_folder VIDEO_FOLDER \ + --dataset.text_emb_folder.ul2 UL2_FOLDER \ + --dataset.text_emb_folder.byt5 BYT5_FOLDER \ + --train.output_path=$output_dir \ + --train.ema "" # turn off ema diff --git a/examples/moviegen/train.py b/examples/moviegen/train.py index 582ca635f6..1c600e79b9 100644 --- a/examples/moviegen/train.py +++ b/examples/moviegen/train.py @@ -1,299 +1,212 @@ -#!/usr/bin/env python -import argparse import logging import os import sys +from typing import Literal, Optional 
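
The `no_grad` helper added to `model_utils.py` above wraps `mindspore._no_grad` so that it is a no-op in graph mode and suppresses gradient bookkeeping in PyNative mode. A minimal sketch of the intended usage, mirroring how the training pipeline wraps its frozen sub-networks; `frozen_vae` and `video` are placeholder names, not symbols from this patch:

    from mindspore import ops
    from moviegen.utils import no_grad

    def encode_with_frozen_vae(frozen_vae, video):
        # no_grad() avoids allocating gradient buffers in PyNative mode; stop_gradient
        # additionally cuts the autograd edge so the frozen VAE never receives gradients
        with no_grad():
            latents = ops.stop_gradient(frozen_vae.encode(video))
        return latents
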
-import yaml +from jsonargparse import ActionConfigFile, ArgumentParser +from jsonargparse.typing import Path_fr, path_type -import mindspore as ms -import mindspore.nn as nn -from mindspore import Model -from mindspore.dataset import GeneratorDataset +from mindspore import Model, nn +from mindspore.train.callback import TimeMonitor # TODO: remove in future when mindone is ready for install __dir__ = os.path.dirname(os.path.abspath(__file__)) mindone_lib_path = os.path.abspath(os.path.join(__dir__, "../../")) sys.path.insert(0, mindone_lib_path) -from moviegen.dataset import ImageDataset +from moviegen.dataset import ImageVideoDataset +from moviegen.models.llama import LlamaModel from moviegen.pipelines import DiffusionWithLoss from moviegen.schedulers import RFlowLossWrapper -from moviegen.utils import ( - MODEL_DTYPE, - MODEL_SPEC, - LossMonitor, - SaveCkptCallback, - TimeMonitor, - check_cfgs_in_parser, - count_params, - init_env, - load_ckpt_params, - str2bool, -) -from transformers import AutoTokenizer +from moviegen.utils import EMA, MODEL_DTYPE, MODEL_SPEC, load_ckpt_params +from mindone.data import create_dataloader from mindone.diffusers import AutoencoderKL -from mindone.trainers.optim import create_optimizer -from mindone.trainers.train_step import TrainOneStepWrapper -from mindone.transformers import T5EncoderModel -from mindone.utils.logger import set_logger +from mindone.trainers import create_optimizer, create_scheduler +from mindone.trainers.callback import EvalSaveCallback, OverflowMonitor +from mindone.trainers.zero import prepare_train_network +from mindone.utils import count_params, init_train_env, set_logger logger = logging.getLogger(__name__) - -def parse_args(): - parser = argparse.ArgumentParser( - description="Movie-Gen Training script", formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - parser.add_argument( - "-c", - "--config", - help="Path to load a config yaml file that describes the setting which will override the default arguments.", - ) - parser.add_argument("--json_path", required=True, help="path to json annotation file.") - parser.add_argument("--model_version", default="llama-1B", choices=["llama-1B", "llama-5B", "llama-30B"]) - parser.add_argument("--image_dir", required=True, help="Directory storing the image directory.") - parser.add_argument("--output_path", default="./output", help="Output directory to save the training result.") - - parser.add_argument("--batch_size", default=64, type=int, help="Training batch size.") - parser.add_argument("--num_parallel_workers", default=4, type=int, help="Number of workers for data loading.") - parser.add_argument("--checkpoint", default="", help="The path to the PixArt checkpoint.") - parser.add_argument("--vae_root", default="models/vae", help="Path storing the VAE checkpoint and configure file.") - parser.add_argument( - "--tokenizer_root", default="models/tokenizer", help="Path storing the T5 checkpoint and configure file." - ) - parser.add_argument( - "--text_encoder_root", default="models/text_encoder", help="Path storing the T5 tokenizer and configure file." - ) - parser.add_argument("--t5_max_length", default=300, type=int, help="T5's embedded sequence length.") - parser.add_argument( - "--scale_factor", default=0.13025, type=float, help="VAE scale factor of Stable Diffusion network." 
- ) - parser.add_argument( - "--text_drop_prob", - default=0.2, - type=float, - help="The probability of using drop text label", - ) - - parser.add_argument("--device_target", default="Ascend", choices=["Ascend"], help="Device target.") - parser.add_argument("--mode", default=0, type=int, help="Running in GRAPH_MODE(0) or PYNATIVE_MODE(1).") - parser.add_argument("--jit_level", default="O0", choices=["O0", "O1"], help="Jit Level") - parser.add_argument("--seed", default=42, type=int, help="Training seed.") - - parser.add_argument( - "--enable_flash_attention", default=True, type=str2bool, help="whether to enable flash attention." - ) - parser.add_argument( - "--dtype", default="bf16", choices=["bf16", "fp16", "fp32"], help="what data type to use for network." - ) - parser.add_argument("--scheduler", default="constant", choices=["constant"], help="LR scheduler.") - parser.add_argument("--start_learning_rate", default=1e-4, type=float, help="The learning rate.") - parser.add_argument("--warmup_steps", default=1000, type=int, help="Warmup steps.") - parser.add_argument("--epochs", default=200, type=int, help="Number of total training epochs.") - parser.add_argument("--optim", default="adamw", type=str, choices=["adamw"], help="Optimizer name.") - parser.add_argument("--weight_decay", default=0.1, type=float, help="Weight decay.") - parser.add_argument( - "--loss_scaler_type", - default="static", - choices=["static", "dynamic"], - help="Use dynamic or static loss scaler.", - ) - parser.add_argument("--init_loss_scale", default=1.0, type=float, help="Loss scale.") - parser.add_argument("--scale_window", default=1000, type=int, help="Loss scale window.") - parser.add_argument("--loss_scale_factor", default=2.0, type=float, help="Loss scale factor.") - parser.add_argument("--use_ema", default=False, type=str2bool, help="Whether to use EMA") - parser.add_argument("--ema_rate", default=0.9999, type=float, help="EMA Rate.") - parser.add_argument("--drop_overflow_update", default=True, type=str2bool, help="Drop overflow update.") - parser.add_argument("--gradient_accumulation_steps", default=1, type=int, help="Gradient accumulation steps.") - parser.add_argument("--clip_grad", default=True, type=str2bool, help="Whether apply gradient clipping.") - parser.add_argument( - "--max_grad_norm", - default=1.0, - type=float, - help="Max gradient norm for clipping, effective when `clip_grad` enabled.", - ) - parser.add_argument("--ckpt_max_keep", default=3, type=int, help="Maximum number of checkpoints to keep") - parser.add_argument("--ckpt_save_interval", default=1, type=int, help="Save checkpoint every this epochs or steps.") - parser.add_argument("--log_loss_interval", default=1, type=int, help="Log interval of loss value.") - parser.add_argument("--recompute", default=False, type=str2bool, help="Use recompute during training.") - parser.add_argument("--use_parallel", default=False, type=str2bool, help="use parallel training.") - default_args = parser.parse_args() - abs_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "")) - if default_args.config: - logger.info(f"Overwrite default arguments with configuration file {default_args.config}") - default_args.config = os.path.join(abs_path, default_args.config) - with open(default_args.config, "r") as f: - cfg = yaml.safe_load(f) - check_cfgs_in_parser(cfg, parser) - parser.set_defaults(**cfg) - args = parser.parse_args() - return args +Path_dcc = path_type("dcc") # path to a directory that can be created if it does not exist + + 
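
For reference, the `train.loss_scaler` entry in the new YAML config is resolved by `parser.instantiate_classes()` into an ordinary MindSpore cell, replacing the `--loss_scaler_type`/`--init_loss_scale` flags removed here. A rough hand-written equivalent is sketched below; the dynamic-scaler values are illustrative, not taken from the shipped config:

    from mindspore import nn

    # fixed variant, matching the shipped config (init_args: loss_scale_value: 1)
    fixed_scaler = nn.FixedLossScaleUpdateCell(loss_scale_value=1)

    # dynamic variant, equivalent to the removed `--loss_scaler_type dynamic` code path
    dynamic_scaler = nn.DynamicLossScaleUpdateCell(loss_scale_value=65536, scale_factor=2, scale_window=1000)
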
+def init_model( + name: Literal["llama-1B", "llama-5B", "llama-30B"], + pretrained_model_path: Optional[Path_fr] = None, + enable_flash_attention: bool = True, + recompute: bool = False, + dtype: Literal["fp32", "fp16", "bf16"] = "fp32", +) -> LlamaModel: + attn_implementation = "flash_attention" if enable_flash_attention else "eager" + model = MODEL_SPEC[name]( + in_channels=4, + out_channels=8, + attn_implementation=attn_implementation, + gradient_checkpointing=recompute, + dtype=MODEL_DTYPE[dtype], + ) + if pretrained_model_path: + model = load_ckpt_params(model, pretrained_model_path) + else: + logger.info("Initialize network randomly.") + return model def main(args): - if not os.path.isdir(args.output_path): - os.makedirs(args.output_path) - # 1. init env - device_num, rank_id = init_env(args) - set_logger(output_dir=os.path.join(args.output_path, "logs"), rank=rank_id) - - # 2. model initialize and weight loading - # 2.1 PixArt - image_size = args.sample_size * 8 - logger.info(f"{image_size}x{image_size} init") - - attn_implementation = "flash_attention" if args.enable_flash_attention else "eager" + args.train.output_path = os.path.join(__dir__, args.train.output_path.relative) + os.makedirs(args.train.output_path, exist_ok=True) + device_id, rank_id, device_num = init_train_env(**args.env) + set_logger("", output_dir=args.train.output_path, rank=rank_id) - network = MODEL_SPEC[args.model_version]( - gradient_checkpointing=args.recompute, attn_implementation=attn_implementation, dtype=MODEL_DTYPE[args.dtype] - ) + # instantiate classes only after initializing training environment + initializer = parser.instantiate_classes(cfg) - if args.checkpoint: - network = load_ckpt_params(network, args.checkpoint) - else: - logger.info("Initialize network randomly.") + # 2. model initialize and weight loading + # 2.1 Llama 3 + network = init_model(**args.model) # 2.2 VAE logger.info("vae init") - vae = AutoencoderKL.from_pretrained(args.vae_root, mindspore_dtype=MODEL_DTYPE[args.dtype]) - - # 2.3 T5 - logger.info("text encoder init") - text_tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_root, model_max_length=args.t5_max_length) - text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_root, mindspore_dtype=MODEL_DTYPE[args.dtype]) + # TODO: add support of training with latents + vae_args = args.vae.as_dict() + vae_args["mindspore_dtype"] = MODEL_DTYPE[vae_args.pop("dtype")] # Replace non-standard key + vae = AutoencoderKL.from_pretrained(**vae_args, local_files_only=True) # 2.4 LossWrapper rflow_loss_wrapper = RFlowLossWrapper(network) # 3. build training network - latent_diffusion_with_loss = DiffusionWithLoss( - rflow_loss_wrapper, vae, text_encoder, scale_factor=args.scale_factor - ) + latent_diffusion_with_loss = DiffusionWithLoss(rflow_loss_wrapper, vae, scale_factor=vae.config.scaling_factor) # 4. 
build dataset - dataset = ImageDataset( - args.json_path, - args.image_dir, - image_size, - text_tokenizer, - text_drop_prob=args.text_drop_prob, + dataset = ImageVideoDataset(**args.dataset) + transforms = ( + dataset.train_transforms(args.dataset.target_size) if not args.dataset.apply_transforms_dataset else None ) - data_generator = GeneratorDataset( - dataset, - column_names=["image", "text"], - column_types=[ms.float32, ms.int64], - shuffle=True, - num_parallel_workers=args.num_parallel_workers, - num_shards=device_num, - shard_id=rank_id, - max_rowsize=-1, + dataloader = create_dataloader( + dataset, transforms=transforms, device_num=device_num, rank_id=rank_id, **args.dataloader ) - data_generator = data_generator.batch(args.batch_size, drop_remainder=True) # 5. build training utils: lr, optim, callbacks, trainer # 5.1 LR - lr = nn.WarmUpLR(learning_rate=args.start_learning_rate, warmup_steps=args.warmup_steps) + lr = create_scheduler(steps_per_epoch=dataloader.get_dataset_size(), **args.train.lr_scheduler) # 5.2 optimizer - optim = "adamw_re" if args.optim == "adamw" else args.optim - eps = args.adamw_eps if args.optim == "adamw" else args.came_eps - betas = None if args.optim == "adamw" else args.came_betas - optimizer = create_optimizer( - latent_diffusion_with_loss.trainable_params(), - name=optim, - lr=lr, - weight_decay=args.weight_decay, - betas=betas, - eps=eps, - ) + optimizer = create_optimizer(latent_diffusion_with_loss.trainable_params(), lr=lr, **args.train.optimizer) - if args.loss_scaler_type == "dynamic": - loss_scaler = nn.DynamicLossScaleUpdateCell( - loss_scale_value=args.init_loss_scale, scale_factor=args.loss_scale_factor, scale_window=args.scale_window - ) - else: - loss_scaler = nn.FixedLossScaleUpdateCell(args.init_loss_scale) + loss_scaler = initializer.train.loss_scaler # 5.3 trainer (standalone and distributed) - if args.use_ema: - raise NotImplementedError("`EMA` does not support yet.") - # ema = EMA(latent_diffusion_with_loss.network, ema_decay=args.ema_rate) - else: - ema = None - - net_with_grads = TrainOneStepWrapper( - latent_diffusion_with_loss, - optimizer=optimizer, - scale_sense=loss_scaler, - drop_overflow_update=args.drop_overflow_update, - gradient_accumulation_steps=args.gradient_accumulation_steps, - clip_grad=args.clip_grad, - clip_norm=args.max_grad_norm, - ema=ema, + ema = EMA(latent_diffusion_with_loss.network, **args.train.ema.init_args) if args.train.ema else None + net_with_grads = prepare_train_network( + latent_diffusion_with_loss, optimizer=optimizer, scale_sense=loss_scaler, ema=ema, **args.train.settings ) model = Model(net_with_grads) # 5.4 callbacks - callbacks = [ - TimeMonitor(), - LossMonitor(log_interval=args.log_loss_interval), - SaveCkptCallback( - output_dir=os.path.join(args.output_path, "ckpt"), - ckpt_max_keep=args.ckpt_max_keep, - ckpt_save_interval=args.ckpt_save_interval, - save_ema=args.use_ema, - rank_id=rank_id, - ), - ] + callbacks = [OverflowMonitor()] if rank_id == 0: + callbacks.extend( + [ + TimeMonitor(args.train.save.log_interval), + EvalSaveCallback( + network=latent_diffusion_with_loss.network, + model_name=args.model.name, + rank_id=rank_id, + ckpt_save_dir=os.path.join(args.train.output_path, "ckpt"), + ema=ema, + **args.train.save, + ), + ] + ) num_params_vae, num_params_trainable_vae = count_params(vae) num_params_network, num_params_trainable_network = count_params(network) - num_params_text_encoder, num_params_trainable_text_encoder = count_params(text_encoder) - num_params = num_params_vae + 
num_params_network + num_params_text_encoder - num_params_trainable = ( - num_params_trainable_vae + num_params_trainable_network + num_params_trainable_text_encoder - ) + num_params = num_params_vae + num_params_network + num_params_trainable = num_params_trainable_vae + num_params_trainable_network key_info = "Key Settings:\n" + "=" * 50 + "\n" key_info += "\n".join( [ - f"MindSpore mode[GRAPH(0)/PYNATIVE(1)]: {args.mode}", - f"JIT level: {args.jit_level}", - f"Distributed mode: {args.use_parallel}", - f"Data path: {args.json_path}", + f"MindSpore mode[GRAPH(0)/PYNATIVE(1)]: {args.env.mode}", + f"JIT level: {args.env.jit_level}", + f"Distributed mode: {args.env.distributed}", + f"Data path: {args.dataset.csv_path}", f"Number of samples: {len(dataset)}", - f"Num params: {num_params:,} (network: {num_params_network:,}, vae: {num_params_vae:,}, text_encoder: {num_params_text_encoder:,})", + f"Num params: {num_params:,} (network: {num_params_network:,}, vae: {num_params_vae:,})", f"Num trainable params: {num_params_trainable:,}", - f"Model type: {args.dtype}", - f"Learning rate: {args.start_learning_rate:.7f}", - f"Batch size: {args.batch_size}", - f"Image size: {image_size}", - f"Weight decay: {args.weight_decay}", - f"Grad accumulation steps: {args.gradient_accumulation_steps}", - f"Num epochs: {args.epochs}", - f"Loss scaler: {args.loss_scaler_type}", - f"Init loss scale: {args.init_loss_scale}", - f"Grad clipping: {args.clip_grad}", - f"Max grad norm: {args.max_grad_norm}", - f"EMA: {args.use_ema}", - f"Enable flash attention: {args.enable_flash_attention}", + f"Model dtype: {args.model.dtype}", + f"VAE dtype: {args.vae.dtype}", + f"Learning rate: {args.train.lr_scheduler.lr:.0e}", + f"Batch size: {args.dataloader.batch_size}", + f"Image size: {args.dataset.target_size}", + f"Frames: {args.dataset.sample_n_frames}", + f"Weight decay: {args.train.optimizer.weight_decay}", + f"Grad accumulation steps: {args.train.settings.gradient_accumulation_steps}", + f"Num epochs: {args.train.epochs}", + f"Loss scaler: {args.train.loss_scaler.class_path}", + f"Init loss scale: {args.train.loss_scaler.init_args.loss_scale_value}", + f"Grad clipping: {args.train.settings.clip_grad}", + f"Max grad norm: {args.train.settings.clip_norm}", + f"EMA: {ema is not None}", + f"Enable flash attention: {args.model.enable_flash_attention}", ] ) key_info += "\n" + "=" * 50 print(key_info) - - with open(os.path.join(args.output_path, "args.yaml"), "w") as f: - yaml.safe_dump(vars(args), stream=f, default_flow_style=False, sort_keys=False) + parser.save(args, args.train.output_path + "/config.yaml", format="yaml", overwrite=True) # 6. train logger.info("Start training...") - model.train(args.epochs, data_generator, callbacks=callbacks) + model.train(args.train.epochs, dataloader, callbacks=callbacks) if __name__ == "__main__": - args = parse_args() - main(args) + parser = ArgumentParser(description="Movie Gen training script.") + parser.add_argument( + "-c", + "--config", + action=ActionConfigFile, + help="Path to load a config yaml file that describes the setting which will override the default arguments.", + ) + parser.add_function_arguments(init_train_env, "env") + parser.add_function_arguments(init_model, "model") + parser.add_method_arguments(AutoencoderKL, "from_pretrained", "vae", fail_untyped=True) + parser.add_argument( + "--vae.dtype", default="fp32", type=str, choices=["fp32", "fp16", "bf16"], help="VAE model precision." 
+ ) + parser.add_class_arguments( + ImageVideoDataset, "dataset", skip={"frames_mask_generator", "t_compress_func"}, instantiate=False + ) + parser.add_function_arguments( + create_dataloader, "dataloader", skip={"dataset", "transforms", "device_num", "rank_id"} + ) + parser.link_arguments("env.debug", "dataloader.debug", apply_on="parse") + parser.add_function_arguments(create_scheduler, "train.lr_scheduler", skip={"steps_per_epoch"}) + parser.add_function_arguments(create_optimizer, "train.optimizer", skip={"params", "lr"}) + parser.add_subclass_arguments( + nn.Cell, + "train.loss_scaler", + fail_untyped=False, # no typing in mindspore + help="mindspore.nn.FixedLossScaleUpdateCell or mindspore.nn.DynamicLossScaleUpdateCell", + ) + parser.add_function_arguments( + prepare_train_network, "train.settings", skip={"network", "optimizer", "scale_sense", "ema"} + ) + parser.add_subclass_arguments(EMA, "train.ema", skip={"network"}, required=False, instantiate=False) + parser.add_argument( + "--train.output_path", default="output/", type=Path_dcc, help="Output directory to save training results." + ) + parser.add_argument("--train.epochs", default=10, type=int, help="Number of epochs to train. Default: 100.") + parser.link_arguments("train.epochs", "train.lr_scheduler.num_epochs", apply_on="parse") + parser.add_class_arguments( + EvalSaveCallback, + "train.save", + skip={"network", "rank_id", "ckpt_save_dir", "output_dir", "ema", "start_epoch", "model_name"}, + instantiate=False, + ) + cfg = parser.parse_args() + main(cfg) diff --git a/examples/opensora_hpcai/opensora/models/stdit/stdit_llama3.py b/examples/opensora_hpcai/opensora/models/stdit/stdit_llama3.py index 0eeb1883b3..a2e9f559ee 100644 --- a/examples/opensora_hpcai/opensora/models/stdit/stdit_llama3.py +++ b/examples/opensora_hpcai/opensora/models/stdit/stdit_llama3.py @@ -2,8 +2,6 @@ from typing import Literal, Optional from moviegen import llama3_1B, llama3_5B, llama3_30B -from moviegen.models import TextProjector -from moviegen.models.llama.block import LlamaRMSNorm import mindspore.nn as nn import mindspore.ops as ops @@ -33,13 +31,6 @@ def __init__(self, model_size: Literal["1B", "5B", "30B"] = "1B", **kwargs): else: self.llama = llama3_30B(**model_kwargs) - self.text_projector = TextProjector( - out_features=self.llama.hidden_size, - layer_norm=LlamaRMSNorm, - norm_eps=self.llama.rms_norm_eps, - dtype=self.llama.dtype, - ) - self.patch_size = self.llama.patch_size self.hidden_size = self.llama.hidden_size self.num_heads = self.llama.num_attention_heads @@ -62,16 +53,12 @@ def construct( ) -> Tensor: x = ops.transpose(x, (0, 2, 1, 3, 4)) - if extra_text_embed1 is not None: - y = ops.squeeze(y, axis=1) - # FIXME: placeholder for MetaCLIP - metaclip_text_embed = ops.ones((extra_text_embed1.shape[0], 100, 1280), dtype=extra_text_embed1.dtype) - text_embedding = self.text_projector(y, metaclip_text_embed, extra_text_embed1) - else: - text_embedding = ops.squeeze(y, axis=1) + ul2_emb = ops.squeeze(y, axis=1) + metaclip_emb = ops.ones((extra_text_embed1.shape[0], 100, 1280), dtype=extra_text_embed1.dtype) + byt5_emb = extra_text_embed1 latent_embedding = x - output = self.llama(latent_embedding, timestep, text_embedding) + output = self.llama(latent_embedding, timestep, ul2_emb, metaclip_emb, byt5_emb) output = ops.transpose(output, (0, 2, 1, 3, 4)) return output diff --git a/mindone/data/loader.py b/mindone/data/loader.py index dd394f5193..6032141d30 100644 --- a/mindone/data/loader.py +++ b/mindone/data/loader.py @@ -3,6 +3,7 @@ import 
mindspore as ms from mindspore.communication import get_local_rank, get_local_rank_size +from ..utils.version_control import MS_VERSION from .dataset import BaseDataset @@ -89,7 +90,7 @@ def create_dataloader( **transform, python_multiprocessing=python_multiprocessing, num_parallel_workers=num_workers, - max_rowsize=max_rowsize, + max_rowsize=max_rowsize if MS_VERSION < "2.3" else -1, # MS 2.3 and above: allocate memory dynamically ) if project_columns: diff --git a/mindone/trainers/ema.py b/mindone/trainers/ema.py index 45fd95c2b5..ca702e5c47 100644 --- a/mindone/trainers/ema.py +++ b/mindone/trainers/ema.py @@ -3,6 +3,8 @@ from mindspore.ops import composite as C from mindspore.ops import functional as F +__all__ = ["EMA"] + _ema_op = C.MultitypeFuncGraph("grad_ema_op") @@ -18,7 +20,14 @@ class EMA(nn.Cell): offloading: if True, offload the assign computation to CPU to avoid OOM issue. """ - def __init__(self, network, ema_decay=0.9999, updates=0, trainable_only=True, offloading=True): + def __init__( + self, + network: nn.Cell, + ema_decay: float = 0.9999, + updates: int = 0, + trainable_only: bool = True, + offloading: bool = True, + ): super().__init__() # TODO: net.trainable_params() is more reasonable? if trainable_only: diff --git a/mindone/trainers/train_step.py b/mindone/trainers/train_step.py index 519347948c..385c26f7b9 100644 --- a/mindone/trainers/train_step.py +++ b/mindone/trainers/train_step.py @@ -1,4 +1,5 @@ """Train step wrapper supporting setting drop overflow update, ema etc""" +from typing import Optional from packaging import version @@ -13,6 +14,8 @@ from mindspore.ops import functional as F from mindspore.ops import operations as P +from .ema import EMA + _grad_scale = C.MultitypeFuncGraph("grad_scale") reciprocal = P.Reciprocal() _grad_overflow = C.MultitypeFuncGraph("_grad_overflow") @@ -55,7 +58,7 @@ def __init__( network, optimizer, scale_sense=1.0, - ema=None, + ema: Optional[EMA] = None, updates=0, drop_overflow_update=True, gradient_accumulation_steps=1, diff --git a/mindone/trainers/zero.py b/mindone/trainers/zero.py index 8d5e1c5353..7f5c0adaff 100644 --- a/mindone/trainers/zero.py +++ b/mindone/trainers/zero.py @@ -1,6 +1,7 @@ import json import logging import os +from typing import Literal import mindspore as ms from mindspore import nn, ops @@ -554,7 +555,7 @@ def prepare_train_network( clip_grad: bool = False, clip_norm: float = 1.0, verbose: bool = False, - zero_stage: int = 0, + zero_stage: Literal[0, 1, 2, 3] = 0, optimizer_offload: bool = False, op_group: str = None, dp_group: str = None, diff --git a/mindone/utils/env.py b/mindone/utils/env.py index 12c51a8204..b459898d06 100644 --- a/mindone/utils/env.py +++ b/mindone/utils/env.py @@ -10,6 +10,8 @@ import mindspore as ms from mindspore.communication import get_group_size, get_rank, init +from .version_control import MS_VERSION + _logger = logging.getLogger(__name__) @@ -22,6 +24,7 @@ def init_train_env( cache_path: str = "./cache", distributed: bool = False, ascend_config: Optional[dict] = None, + jit_level: Optional[Literal["O0", "O1", "O2"]] = None, enable_modelarts: bool = False, max_device_memory: str = None, num_workers: int = 1, @@ -41,6 +44,8 @@ def init_train_env( cache_path: The path to save or load the saved computation graph. distributed: Whether to enable distributed training. Default is False. ascend_config: Parameters specific to the Ascend hardware platform. + jit_level: The compilation optimization level. Options: "O0", "O1", "O2". 
+ Default is None and the level selected based on the device. enable_modelarts: Whether to enable modelarts (OpenI) support. Default is False. max_device_memory (str, default: None): The maximum amount of memory that can be allocated on the Ascend device. num_workers: The number of modelarts workers. Used only when `enable_modelarts` is True. Default is 1. @@ -58,9 +63,18 @@ def init_train_env( mode = ms.PYNATIVE_MODE if max_device_memory is not None: ms.set_context(max_device_memory=max_device_memory) + if jit_level: + if MS_VERSION < "2.3": + _logger.warning("Compilation optimization (JIT Level) is supported only in MindSpore 2.3 or later.") + else: + ms.set_context(jit_config={"jit_level": jit_level}) + if distributed: - device_id = int(os.getenv("DEVICE_ID")) - ms.set_context(mode=mode, device_target=device_target, device_id=device_id, ascend_config=ascend_config or {}) + device_id, kwargs = None, {} # if no rank table + if os.getenv("DEVICE_ID"): + device_id = int(os.getenv("DEVICE_ID")) + kwargs = {"device_id": int(os.getenv("DEVICE_ID"))} + ms.set_context(mode=mode, device_target=device_target, ascend_config=ascend_config or {}, **kwargs) init() device_num = get_group_size() rank_id = get_rank() From 52822b04b9659698a99fb09413f92dff74eeff1d Mon Sep 17 00:00:00 2001 From: Rustam Khadipash <16683750+hadipash@users.noreply.github.com> Date: Mon, 4 Nov 2024 11:06:03 +0800 Subject: [PATCH 023/122] add OSv1.2 VAE --- .../configs/train/moviegen_t2i_256x256.yaml | 2 +- .../moviegen/pipelines/train_pipeline.py | 12 +++++------- examples/moviegen/train.py | 19 +++++++++++++------ .../opensora_hpcai/opensora/models/vae/vae.py | 2 +- 4 files changed, 20 insertions(+), 15 deletions(-) diff --git a/examples/moviegen/configs/train/moviegen_t2i_256x256.yaml b/examples/moviegen/configs/train/moviegen_t2i_256x256.yaml index 3e334cec84..6a0bb886e6 100644 --- a/examples/moviegen/configs/train/moviegen_t2i_256x256.yaml +++ b/examples/moviegen/configs/train/moviegen_t2i_256x256.yaml @@ -13,7 +13,7 @@ model: dtype: bf16 vae: - pretrained_model_name_or_path: stabilityai/sd-vae-ft-ema + ckpt_path: models/OpenSora-VAE-v1.2/model.ckpt dtype: fp16 dataset: diff --git a/examples/moviegen/moviegen/pipelines/train_pipeline.py b/examples/moviegen/moviegen/pipelines/train_pipeline.py index a152585e02..c3477fa514 100644 --- a/examples/moviegen/moviegen/pipelines/train_pipeline.py +++ b/examples/moviegen/moviegen/pipelines/train_pipeline.py @@ -1,6 +1,6 @@ from typing import Optional -from mindspore import Tensor, nn, ops +from mindspore import Tensor, nn, ops, float32 from ..schedulers import RFlowLossWrapper from ..utils.model_utils import no_grad @@ -14,7 +14,6 @@ def __init__( network: RFlowLossWrapper, vae: Optional[nn.Cell] = None, text_encoder: Optional[nn.Cell] = None, - scale_factor: float = 0.13025, text_emb_cached: bool = True, video_emb_cached: bool = False, ): @@ -28,7 +27,6 @@ def __init__( self.network = network self.vae = vae self.text_encoder = text_encoder - self.scale_factor = scale_factor self.text_emb_cached = text_emb_cached self.video_emb_cached = video_emb_cached @@ -51,10 +49,10 @@ def get_latents(self, video_tokens: Tensor) -> Tensor: if self.video_emb_cached: return video_tokens with no_grad(): - b, f, *_ = video_tokens.shape # FIXME: no VideoAutoencoderKL in mindone.differs - video_tokens = video_tokens.reshape(-1, *video_tokens.shape[2:]) - video_emb = ops.stop_gradient(self.vae.encode(video_tokens.astype(self.vae.dtype))[0] * self.scale_factor) - video_emb = video_emb.reshape(b, f, 
*video_emb.shape[1:]) + # (b c f h w) shape is expected. FIXME: remove this redundancy + video_tokens = ops.transpose(video_tokens, (0, 2, 1, 3, 4)) + video_emb = ops.stop_gradient(self.vae.encode(video_tokens)).astype(float32) + video_emb = ops.transpose(video_emb, (0, 2, 1, 3, 4)) # FIXME return video_emb def construct(self, video_tokens: Tensor, ul2_tokens: Tensor, byt5_tokens: Tensor) -> Tensor: diff --git a/examples/moviegen/train.py b/examples/moviegen/train.py index 1c600e79b9..7946ebeefe 100644 --- a/examples/moviegen/train.py +++ b/examples/moviegen/train.py @@ -6,7 +6,7 @@ from jsonargparse import ActionConfigFile, ArgumentParser from jsonargparse.typing import Path_fr, path_type -from mindspore import Model, nn +from mindspore import Model, amp, nn from mindspore.train.callback import TimeMonitor # TODO: remove in future when mindone is ready for install @@ -21,12 +21,15 @@ from moviegen.utils import EMA, MODEL_DTYPE, MODEL_SPEC, load_ckpt_params from mindone.data import create_dataloader -from mindone.diffusers import AutoencoderKL from mindone.trainers import create_optimizer, create_scheduler from mindone.trainers.callback import EvalSaveCallback, OverflowMonitor from mindone.trainers.zero import prepare_train_network from mindone.utils import count_params, init_train_env, set_logger +# TODO: remove when VAE is added to the project +sys.path.append(os.path.join(__dir__, "../opensora_hpcai/")) +from opensora.models.vae.vae import OpenSoraVAE_V1_2 + logger = logging.getLogger(__name__) Path_dcc = path_type("dcc") # path to a directory that can be created if it does not exist @@ -72,14 +75,18 @@ def main(args): logger.info("vae init") # TODO: add support of training with latents vae_args = args.vae.as_dict() - vae_args["mindspore_dtype"] = MODEL_DTYPE[vae_args.pop("dtype")] # Replace non-standard key - vae = AutoencoderKL.from_pretrained(**vae_args, local_files_only=True) + vae_dtype = vae_args.pop("dtype") + vae = OpenSoraVAE_V1_2(**vae_args).set_train(False) + if vae_dtype != "fp32": + vae_dtype = MODEL_DTYPE[vae_dtype] + # FIXME: remove AMP and add custom dtype conversion support for better compatibility with PyNative + amp.custom_mixed_precision(vae, black_list=amp.get_black_list() + [nn.GroupNorm], dtype=vae_dtype) # 2.4 LossWrapper rflow_loss_wrapper = RFlowLossWrapper(network) # 3. build training network - latent_diffusion_with_loss = DiffusionWithLoss(rflow_loss_wrapper, vae, scale_factor=vae.config.scaling_factor) + latent_diffusion_with_loss = DiffusionWithLoss(rflow_loss_wrapper, vae) # 4. build dataset dataset = ImageVideoDataset(**args.dataset) @@ -174,7 +181,7 @@ def main(args): ) parser.add_function_arguments(init_train_env, "env") parser.add_function_arguments(init_model, "model") - parser.add_method_arguments(AutoencoderKL, "from_pretrained", "vae", fail_untyped=True) + parser.add_function_arguments(OpenSoraVAE_V1_2, "vae", fail_untyped=False) parser.add_argument( "--vae.dtype", default="fp32", type=str, choices=["fp32", "fp16", "bf16"], help="VAE model precision." 
) diff --git a/examples/opensora_hpcai/opensora/models/vae/vae.py b/examples/opensora_hpcai/opensora/models/vae/vae.py index d846d2fdac..676b75aa7f 100644 --- a/examples/opensora_hpcai/opensora/models/vae/vae.py +++ b/examples/opensora_hpcai/opensora/models/vae/vae.py @@ -13,7 +13,7 @@ from .autoencoder_kl import AutoencoderKL as AutoencoderKL_SD from .vae_temporal import VAE_Temporal_SD # noqa: F401 -__all__ = ["AutoencoderKL"] +__all__ = ["AutoencoderKL", "OpenSoraVAE_V1_2"] _logger = logging.getLogger(__name__) From d3ae9e3582901a8bbabc95722f75bfc9505877c6 Mon Sep 17 00:00:00 2001 From: Rustam Khadipash <16683750+hadipash@users.noreply.github.com> Date: Tue, 5 Nov 2024 10:10:48 +0800 Subject: [PATCH 024/122] fixes --- .../configs/train/moviegen_t2i_256x256.yaml | 7 +- examples/moviegen/moviegen/dataset/dataset.py | 29 +--- .../moviegen/moviegen/models/llama/network.py | 6 +- .../moviegen/pipelines/train_pipeline.py | 14 +- .../moviegen/schedulers/rectified_flow.py | 12 +- examples/moviegen/moviegen/utils/__init__.py | 1 - examples/moviegen/moviegen/utils/callback.py | 137 ------------------ .../moviegen/scripts/train_t2i_256x256.sh | 7 +- examples/moviegen/train.py | 29 ++-- 9 files changed, 46 insertions(+), 196 deletions(-) delete mode 100644 examples/moviegen/moviegen/utils/callback.py diff --git a/examples/moviegen/configs/train/moviegen_t2i_256x256.yaml b/examples/moviegen/configs/train/moviegen_t2i_256x256.yaml index 6a0bb886e6..797f917dfa 100644 --- a/examples/moviegen/configs/train/moviegen_t2i_256x256.yaml +++ b/examples/moviegen/configs/train/moviegen_t2i_256x256.yaml @@ -36,9 +36,10 @@ train: output_path: output/moviegen_t2i_256x256 lr_scheduler: - name: constant - lr: 1.0e-4 - warmup_steps: 1000 + class_path: mindspore.nn.WarmUpLR + init_args: + learning_rate: 1.0e-4 + warmup_steps: 1000 optimizer: name: adamw_re diff --git a/examples/moviegen/moviegen/dataset/dataset.py b/examples/moviegen/moviegen/dataset/dataset.py index 9e8ed8eb64..5915acfc79 100644 --- a/examples/moviegen/moviegen/dataset/dataset.py +++ b/examples/moviegen/moviegen/dataset/dataset.py @@ -2,42 +2,25 @@ import logging import os import random -import sys from concurrent.futures import ThreadPoolExecutor from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Dict, List, Optional, Tuple, Union import cv2 import numpy as np from tqdm import tqdm -from mindspore.dataset.transforms import Compose - +from mindone.data import BaseDataset from mindone.data.video_reader import VideoReader from .transforms import ResizeCrop -# FIXME: remove in future when mindone is ready for install -sys.path.append(os.path.join(os.path.dirname(__file__), "../../../..")) -from mindone.data import BaseDataset - _logger = logging.getLogger(__name__) IMAGE_EXT = (".jpg", ".jpeg", ".png", ".gif", ".webp") -def create_infer_transforms(target_size: Tuple[int, int], interpolation=cv2.INTER_LINEAR): - return Compose( - [ - ResizeCrop(target_size, interpolation=interpolation), - lambda x: x.astype(np.float32) / 127.5 - 1, - lambda x: x[None, ...] 
if x.ndim == 3 else x, # if image - lambda x: np.transpose(x, (0, 3, 1, 2)), - ] - ) - - class ImageVideoDataset(BaseDataset): def __init__( self, @@ -140,7 +123,7 @@ def _filter_data(sample_): _logger.info(f"Number of data samples: {len(data)}") return data - def _get_replacement(self, max_attempts: int = 100) -> Tuple[Any, ...]: + def _get_replacement(self, max_attempts: int = 100) -> Tuple[np.ndarray, ...]: attempts, error = min(max_attempts, len(self)), None for idx in range(attempts): try: @@ -151,7 +134,7 @@ def _get_replacement(self, max_attempts: int = 100) -> Tuple[Any, ...]: raise RuntimeError(f"Fail to load a replacement sample in {attempts} attempts. Error: {repr(error)}") - def _get_item(self, idx: int, thw: Optional[Tuple[int, int, int]] = None) -> Tuple[Any, ...]: + def _get_item(self, idx: int, thw: Optional[Tuple[int, int, int]] = None) -> Tuple[np.ndarray, ...]: data = self._data[idx].copy() num_frames = self._frames @@ -214,11 +197,11 @@ def _get_item(self, idx: int, thw: Optional[Tuple[int, int, int]] = None) -> Tup return tuple(data[c] for c in self.output_columns) - def get_bucket(self, thw: Tuple[int, int, int], sample_ids: List[int]) -> Tuple[Any, ...]: + def get_bucket(self, thw: Tuple[int, int, int], sample_ids: List[int]) -> Tuple[np.ndarray, ...]: batch = [self._get_item(sample_id, thw) for sample_id in sample_ids] return tuple(np.stack(item) for item in map(list, zip(*batch))) - def __getitem__(self, idx: int) -> Tuple[Any, ...]: + def __getitem__(self, idx: int) -> Tuple[np.ndarray, ...]: try: sample = self._get_item(idx) if self._require_update_prev: diff --git a/examples/moviegen/moviegen/models/llama/network.py b/examples/moviegen/moviegen/models/llama/network.py index de8c3989a5..ce66e554a0 100644 --- a/examples/moviegen/moviegen/models/llama/network.py +++ b/examples/moviegen/moviegen/models/llama/network.py @@ -465,8 +465,10 @@ def construct( ) -> Tensor: """ latent_embedding: (N, T, C, H, W) tensor of inputs (latent representations of video) - timestep: (N,) tensor to indicate denoising step - text_embedding: (N, L, C') tensor of the text embedding + timestep: (N,) tensor to indicate a denoising step + ul2_emb: (N, L1, 4096) UL2 text embeddings + metaclip_emb: (N, L2, 1280) MetaCLIP text embeddings + byt5_emb: (N, L3, 1472) ByT5 text embeddings """ _, t, _, h, w = latent_embedding.shape diff --git a/examples/moviegen/moviegen/pipelines/train_pipeline.py b/examples/moviegen/moviegen/pipelines/train_pipeline.py index c3477fa514..8fe6e3f0a0 100644 --- a/examples/moviegen/moviegen/pipelines/train_pipeline.py +++ b/examples/moviegen/moviegen/pipelines/train_pipeline.py @@ -1,6 +1,7 @@ from typing import Optional -from mindspore import Tensor, nn, ops, float32 +import mindspore as ms +from mindspore import Tensor, mint, nn, ops from ..schedulers import RFlowLossWrapper from ..utils.model_utils import no_grad @@ -14,6 +15,7 @@ def __init__( network: RFlowLossWrapper, vae: Optional[nn.Cell] = None, text_encoder: Optional[nn.Cell] = None, + scale_factor: float = 0.13025, text_emb_cached: bool = True, video_emb_cached: bool = False, ): @@ -27,6 +29,7 @@ def __init__( self.network = network self.vae = vae self.text_encoder = text_encoder + self.scale_factor = scale_factor self.text_emb_cached = text_emb_cached self.video_emb_cached = video_emb_cached @@ -50,9 +53,10 @@ def get_latents(self, video_tokens: Tensor) -> Tensor: return video_tokens with no_grad(): # (b c f h w) shape is expected. 
FIXME: remove this redundancy - video_tokens = ops.transpose(video_tokens, (0, 2, 1, 3, 4)) - video_emb = ops.stop_gradient(self.vae.encode(video_tokens)).astype(float32) - video_emb = ops.transpose(video_emb, (0, 2, 1, 3, 4)) # FIXME + video_tokens = mint.permute(video_tokens, (0, 2, 1, 3, 4)) + # FIXME: extract scale_factor from VAE and use it here + video_emb = ops.stop_gradient(self.vae.encode(video_tokens)).to(ms.float32) + video_emb = mint.permute(video_emb, (0, 2, 1, 3, 4)) # FIXME return video_emb def construct(self, video_tokens: Tensor, ul2_tokens: Tensor, byt5_tokens: Tensor) -> Tensor: @@ -60,5 +64,5 @@ def construct(self, video_tokens: Tensor, ul2_tokens: Tensor, byt5_tokens: Tenso ul2_emb = self.get_condition_embeddings(ul2_tokens) byt5_emb = self.get_condition_embeddings(byt5_tokens) # FIXME: add metaclip - metaclip_emb = ops.ones((byt5_emb.shape[0], 300, 1280), dtype=byt5_emb.dtype) + metaclip_emb = mint.ones((byt5_emb.shape[0], 300, 1280), dtype=byt5_emb.dtype) return self.network(latent_embedding, ul2_emb, metaclip_emb, byt5_emb) diff --git a/examples/moviegen/moviegen/schedulers/rectified_flow.py b/examples/moviegen/moviegen/schedulers/rectified_flow.py index 541cb8b4e5..2c58196237 100644 --- a/examples/moviegen/moviegen/schedulers/rectified_flow.py +++ b/examples/moviegen/moviegen/schedulers/rectified_flow.py @@ -128,10 +128,13 @@ def _broadcast(self, x: Tensor) -> Tensor: def construct( self, x: Tensor, ul2_emb: Tensor, metaclip_emb: Tensor, byt5_emb: Tensor, timestep: Optional[Tensor] = None ) -> Tensor: - """Calculate the training loss for the corresponding timestep. + """ + Calculate the training loss for the corresponding timestep. x: (N, T, C, H, W) tensor of inputs (latent representations of video) - text_embedding: (N, L, C') tensor of the text embedding - timestep: (N,) tensor to indicate denoising step + ul2_emb: (N, L1, 4096) UL2 text embeddings + metaclip_emb: (N, L2, 1280) MetaCLIP text embeddings + byt5_emb: (N, L3, 1472) ByT5 text embeddings + timestep: (N,) tensor to indicate a denoising step """ x = x.to(ms.float32) @@ -148,11 +151,10 @@ def construct( metaclip_emb.to(self.model.dtype), byt5_emb.to(self.model.dtype), ).to(ms.float32) - velocity_pred = mint.chunk(model_output, 2, dim=2)[0] v_t = x - (1 - self.eps) * noise # 3.1.2 Eqa (2) - loss = self.criteria(velocity_pred, v_t) + loss = self.criteria(model_output, v_t) return loss def add_noise(self, x: Tensor, noise: Tensor, timesteps: Tensor) -> Tensor: diff --git a/examples/moviegen/moviegen/utils/__init__.py b/examples/moviegen/moviegen/utils/__init__.py index afbf2a621d..01392be094 100644 --- a/examples/moviegen/moviegen/utils/__init__.py +++ b/examples/moviegen/moviegen/utils/__init__.py @@ -1,4 +1,3 @@ -from .callback import * from .ema import * from .misc import * from .model_utils import * diff --git a/examples/moviegen/moviegen/utils/callback.py b/examples/moviegen/moviegen/utils/callback.py deleted file mode 100644 index 0578deb7b7..0000000000 --- a/examples/moviegen/moviegen/utils/callback.py +++ /dev/null @@ -1,137 +0,0 @@ -import logging -import os -import time -from typing import List, Optional - -import numpy as np - -import mindspore.ops as ops -from mindspore import Tensor -from mindspore.train import Callback, RunContext - -from mindone.trainers.checkpoint import CheckpointManager - -__all__ = ["LossMonitor", "SaveCkptCallback", "TimeMonitor"] - -logger = logging.getLogger(__name__) - - -class LossMonitor(Callback): - def __init__(self, log_interval: int = 1, log_overflow: bool = 
True) -> None: - self.log_interval = log_interval - self.log_overflow = log_overflow - self.step_num = 0 - - def on_train_step_begin(self, run_context: RunContext) -> None: - self.step_num += 1 - - def on_train_epoch_end(self, run_context: RunContext) -> None: - self.step_num = 0 - - def on_train_step_end(self, run_context: RunContext) -> None: - cb_params = run_context.original_args() - cur_step = cb_params.cur_step_num - - if cur_step % self.log_interval == 0: - cur_lr = self._fetch_optimizer_lr(cb_params) - cur_loss = self._fetch_loss(cb_params) - cur_loss_scale = self._fetch_loss_scale(cb_params) - - logger.info( - "epoch: %d step: %d, lr: %.7f, loss: %.6f, loss scale: %d.", - cb_params.cur_epoch_num, - self.step_num, - cur_lr.item(), - cur_loss.item(), - cur_loss_scale.item(), - ) - - if self.log_overflow: - overflow = cb_params.net_outputs[1] - if overflow: - logger.warning(f"overflow detected in epoch {cb_params.cur_epoch_num} step {self.step_num}.") - - def _get_optimizer_from_cbp(self, cb_params): - if cb_params.optimizer is not None: - optimizer = cb_params.optimizer - elif cb_params.dataset_sink_mode: - optimizer = cb_params.train_network.network.optimizer - else: - optimizer = cb_params.train_network.optimizer - return optimizer - - def _fetch_loss_scale(self, cb_params) -> Tensor: - if cb_params.dataset_sink_mode: - return cb_params.train_network.network.scale_sense - else: - return cb_params.train_network.scale_sense - - def _fetch_optimizer_lr(self, cb_params) -> Tensor: - opt = self._get_optimizer_from_cbp(cb_params) - lr = opt.learning_rate - if opt.dynamic_lr: - lr = opt.learning_rate(ops.clip(opt.global_step - 1, min=0))[0] - return lr - - def _fetch_loss(self, cb_params) -> Tensor: - loss = cb_params.net_outputs[0] - return loss - - -class SaveCkptCallback(Callback): - def __init__( - self, - output_dir: str = "./output", - ckpt_max_keep: int = 5, - ckpt_save_interval: int = 1, - rank_id: Optional[int] = None, - ) -> None: - self.rank_id = 0 if rank_id is None else rank_id - if self.rank_id != 0: - return - - self.ckpt_save_interval = ckpt_save_interval - - ckpt_save_dir = os.path.join(output_dir, f"rank_{rank_id}") - if not os.path.isdir(ckpt_save_dir): - os.makedirs(ckpt_save_dir) - self.ckpt_manager = CheckpointManager(ckpt_save_dir, ckpt_save_policy="latest_k", k=ckpt_max_keep) - - def on_train_epoch_end(self, run_context: RunContext) -> None: - if self.rank_id != 0: - return - - cb_params = run_context.original_args() - cur_epoch = cb_params.cur_epoch_num - epoch_num = cb_params.epoch_num - - if cur_epoch % self.ckpt_save_interval != 0 and cur_epoch != epoch_num: - return - - ckpt_name = f"epoch_{cur_epoch}.ckpt" - network = cb_params.train_network.network - self.ckpt_manager.save(network=network.trainable_params(), ckpt_name=ckpt_name) - - -class TimeMonitor(Callback): - def __init__(self) -> None: - self.epoch_start_time = 0 - self.step_start_time = 0 - self.durations: List[int] = list() - - def on_train_epoch_begin(self, run_context: RunContext) -> None: - self.epoch_start_time = time.time() - - def on_train_step_begin(self, run_context: RunContext) -> None: - self.step_start_time = time.time() - - def on_train_step_end(self, run_context: RunContext) -> None: - duration = time.time() - self.step_start_time - self.durations.append(duration) - - def on_train_epoch_end(self, run_context: RunContext) -> None: - epoch_duration = time.time() - self.epoch_start_time - avg_time = np.mean(self.durations) - self.durations = list() - logger.info(f"Total training time for 
single epoch: {epoch_duration:.3f} seconds") - logger.info(f"Average step time: {avg_time:.3f} seconds") diff --git a/examples/moviegen/scripts/train_t2i_256x256.sh b/examples/moviegen/scripts/train_t2i_256x256.sh index 07c74b4e18..5d4a2d9afe 100644 --- a/examples/moviegen/scripts/train_t2i_256x256.sh +++ b/examples/moviegen/scripts/train_t2i_256x256.sh @@ -1,9 +1,6 @@ export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 -# improve data loading performance for distributed training: 1 -export MS_ENABLE_NUMA=0 # plot memory usage, feature/model: 1 export MS_MEMORY_STATISTIC=0 -export MS_DATASET_SINK_QUEUE=4 # log level export GLOG_v=2 @@ -16,11 +13,11 @@ python train.py \ --env.mode 0 \ --env.jit_level O0 \ --env.max_device_memory 59GB \ - --env.distributed=True \ + --env.distributed True \ --model.name llama-1B \ --dataset.csv_path CSV_PATH \ --dataset.video_folder VIDEO_FOLDER \ --dataset.text_emb_folder.ul2 UL2_FOLDER \ --dataset.text_emb_folder.byt5 BYT5_FOLDER \ - --train.output_path=$output_dir \ + --train.output_path $output_dir \ --train.ema "" # turn off ema diff --git a/examples/moviegen/train.py b/examples/moviegen/train.py index 7946ebeefe..21ab7436bb 100644 --- a/examples/moviegen/train.py +++ b/examples/moviegen/train.py @@ -21,7 +21,7 @@ from moviegen.utils import EMA, MODEL_DTYPE, MODEL_SPEC, load_ckpt_params from mindone.data import create_dataloader -from mindone.trainers import create_optimizer, create_scheduler +from mindone.trainers import create_optimizer from mindone.trainers.callback import EvalSaveCallback, OverflowMonitor from mindone.trainers.zero import prepare_train_network from mindone.utils import count_params, init_train_env, set_logger @@ -37,6 +37,7 @@ def init_model( name: Literal["llama-1B", "llama-5B", "llama-30B"], + in_channels: int = 4, pretrained_model_path: Optional[Path_fr] = None, enable_flash_attention: bool = True, recompute: bool = False, @@ -44,8 +45,7 @@ def init_model( ) -> LlamaModel: attn_implementation = "flash_attention" if enable_flash_attention else "eager" model = MODEL_SPEC[name]( - in_channels=4, - out_channels=8, + in_channels=in_channels, attn_implementation=attn_implementation, gradient_checkpointing=recompute, dtype=MODEL_DTYPE[dtype], @@ -68,10 +68,7 @@ def main(args): initializer = parser.instantiate_classes(cfg) # 2. model initialize and weight loading - # 2.1 Llama 3 - network = init_model(**args.model) - - # 2.2 VAE + # 2.1 VAE logger.info("vae init") # TODO: add support of training with latents vae_args = args.vae.as_dict() @@ -82,7 +79,9 @@ def main(args): # FIXME: remove AMP and add custom dtype conversion support for better compatibility with PyNative amp.custom_mixed_precision(vae, black_list=amp.get_black_list() + [nn.GroupNorm], dtype=vae_dtype) - # 2.4 LossWrapper + # 2.2 Llama 3 + network = init_model(in_channels=vae.out_channels, **args.model) + # 2.3 LossWrapper rflow_loss_wrapper = RFlowLossWrapper(network) # 3. build training network @@ -99,15 +98,14 @@ def main(args): # 5. 
build training utils: lr, optim, callbacks, trainer # 5.1 LR - lr = create_scheduler(steps_per_epoch=dataloader.get_dataset_size(), **args.train.lr_scheduler) + lr = initializer.train.lr_scheduler # 5.2 optimizer optimizer = create_optimizer(latent_diffusion_with_loss.trainable_params(), lr=lr, **args.train.optimizer) - loss_scaler = initializer.train.loss_scaler - # 5.3 trainer (standalone and distributed) ema = EMA(latent_diffusion_with_loss.network, **args.train.ema.init_args) if args.train.ema else None + loss_scaler = initializer.train.loss_scaler net_with_grads = prepare_train_network( latent_diffusion_with_loss, optimizer=optimizer, scale_sense=loss_scaler, ema=ema, **args.train.settings ) @@ -147,7 +145,7 @@ def main(args): f"Num trainable params: {num_params_trainable:,}", f"Model dtype: {args.model.dtype}", f"VAE dtype: {args.vae.dtype}", - f"Learning rate: {args.train.lr_scheduler.lr:.0e}", + f"Learning rate: {args.train.lr_scheduler.init_args.learning_rate:.0e}", f"Batch size: {args.dataloader.batch_size}", f"Image size: {args.dataset.target_size}", f"Frames: {args.dataset.sample_n_frames}", @@ -180,7 +178,7 @@ def main(args): help="Path to load a config yaml file that describes the setting which will override the default arguments.", ) parser.add_function_arguments(init_train_env, "env") - parser.add_function_arguments(init_model, "model") + parser.add_function_arguments(init_model, "model", skip={"in_channels"}) parser.add_function_arguments(OpenSoraVAE_V1_2, "vae", fail_untyped=False) parser.add_argument( "--vae.dtype", default="fp32", type=str, choices=["fp32", "fp16", "bf16"], help="VAE model precision." @@ -192,7 +190,9 @@ def main(args): create_dataloader, "dataloader", skip={"dataset", "transforms", "device_num", "rank_id"} ) parser.link_arguments("env.debug", "dataloader.debug", apply_on="parse") - parser.add_function_arguments(create_scheduler, "train.lr_scheduler", skip={"steps_per_epoch"}) + parser.add_subclass_arguments( + nn.learning_rate_schedule.LearningRateSchedule, "train.lr_scheduler", fail_untyped=False + ) parser.add_function_arguments(create_optimizer, "train.optimizer", skip={"params", "lr"}) parser.add_subclass_arguments( nn.Cell, @@ -208,7 +208,6 @@ def main(args): "--train.output_path", default="output/", type=Path_dcc, help="Output directory to save training results." ) parser.add_argument("--train.epochs", default=10, type=int, help="Number of epochs to train. 
Default: 100.") - parser.link_arguments("train.epochs", "train.lr_scheduler.num_epochs", apply_on="parse") parser.add_class_arguments( EvalSaveCallback, "train.save", From 798698cd4dba07cda4c0a7999e7a480bc32a4556 Mon Sep 17 00:00:00 2001 From: Samit <285365963@qq.com> Date: Tue, 5 Nov 2024 11:54:56 +0800 Subject: [PATCH 025/122] reconstruct tested --- examples/movie_gen/mg/models/tae/tae.py | 11 ++++++++-- examples/movie_gen/tests/test_tae.py | 28 +++++++++++++++++++++++-- 2 files changed, 35 insertions(+), 4 deletions(-) diff --git a/examples/movie_gen/mg/models/tae/tae.py b/examples/movie_gen/mg/models/tae/tae.py index 208ef428ea..b2469fbc7f 100644 --- a/examples/movie_gen/mg/models/tae/tae.py +++ b/examples/movie_gen/mg/models/tae/tae.py @@ -73,7 +73,7 @@ def encode(self, x: ms.Tensor) -> ms.Tensor: return z - def decode(self, x: ms.Tensor) -> ms.Tensor: + def decode(self, z: ms.Tensor) -> ms.Tensor: z = self.post_quant_conv(z) dec = self.decoder(z) return dec @@ -85,4 +85,11 @@ def construct(self, x: ms.Tensor) -> ms.Tensor: x: (b c t h w) """ - return x + posterior_mean, posterior_logvar = self._encode(x) + z = self.sample(posterior_mean, posterior_logvar) + recons = self.decode(z) + + # TODO: discard supurious frames + + return recons, posterior_mean, posterior_logvar + diff --git a/examples/movie_gen/tests/test_tae.py b/examples/movie_gen/tests/test_tae.py index c1e3b3764c..892200d142 100644 --- a/examples/movie_gen/tests/test_tae.py +++ b/examples/movie_gen/tests/test_tae.py @@ -162,7 +162,7 @@ def test_decoder(): def test_tae_encode(): - in_shape = (B, C, T, H, W) = (1, 3, 1, 64, 64) + in_shape = (B, C, T, H, W) = (1, 3, 8, 64, 64) x = np.random.normal(size=in_shape).astype(np.float32) x = ms.Tensor(x) @@ -171,6 +171,28 @@ def test_tae_encode(): print(y.shape) +def test_tae_decode(): + # in_shape = (B, C, T, H, W) = (1, 3, 1, 64, 64) + in_shape = (B, C, T, H, W) = (1, 4, 1, 8, 8) + x = np.random.normal(size=in_shape).astype(np.float32) + x = ms.Tensor(x) + + tae = VideoAutoencoder(config=SDXL_CONFIG) + y = tae.decode(x) + + print(y.shape) + + +def test_tae_rec(): + in_shape = (B, C, T, H, W) = (1, 3, 8, 64, 64) + x = np.random.normal(size=in_shape).astype(np.float32) + x = ms.Tensor(x) + + tae = VideoAutoencoder(config=SDXL_CONFIG) + y = tae(x) + + print(y[0].shape) + if __name__ == "__main__": # test_conv25d() @@ -184,4 +206,6 @@ def test_tae_encode(): # test_temporal_upsample() # test_spatial_upsample() # test_decoder() - test_tae_encode() + # test_tae_encode() + # test_tae_decode() + test_tae_rec() From 8648df14a0bcb35a689fa0fe2369db8bf491ff9e Mon Sep 17 00:00:00 2001 From: Samit <285365963@qq.com> Date: Tue, 5 Nov 2024 15:35:01 +0800 Subject: [PATCH 026/122] update readme --- README.md | 18 ++++++++++++++---- examples/README.md | 3 +++ 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 438aff6a82..9f40d3503f 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,16 @@ This repository contains SoTA algorithms, models, and interesting projects in the area of multimodal understanding and content generation ONE is short for "ONE for all" + ## News +- [2024.11.06] MindONE v0.2.0 is released + +## Quick tour + +To install MindONE, please checkout [Installation](https://mindspore-lab.github.io/mindone/latest/diffusers/installation/#installation) + + +[mindone/diffusers](mindone/diffusers) supports state-of-the-art diffusion models for generating images, audio, and video. 
Let's get started using [Stable Diffusion 3](https://huggingface.co/stabilityai/stable-diffusion-3-medium) as an example. **Hello MindSpore** from **Stable Diffusion 3**! @@ -11,8 +20,6 @@ ONE is short for "ONE for all" sd3 -- [mindone/diffusers](mindone/diffusers) now supports [Stable Diffusion 3](https://huggingface.co/stabilityai/stable-diffusion-3-medium). Give it a try yourself! - ```py import mindspore from mindone.diffusers import StableDiffusion3Pipeline @@ -45,15 +52,18 @@ ONE is short for "ONE for all" | [video composer](https://github.com/mindspore-lab/mindone/tree/master/examples/videocomposer) | support conditional video generation with motion transfer and etc.| | [ip adapter](https://github.com/mindspore-lab/mindone/blob/master/examples/ip_adapter) | refactoring | | [t2i-adapter](https://github.com/mindspore-lab/mindone/blob/master/examples/t2i_adapter) | refactoring | +| [dynamicrafter](https://github.com/mindspore-lab/mindone/blob/master/examples/dynamicrafter) | support image to video generation | +| [hunyuan_dit](https://github.com/mindspore-lab/mindone/blob/master/examples/hunyuan_dit) | support text to image fine tune | +| [pixart_sigma](https://github.com/mindspore-lab/mindone/blob/master/examples/pixart_sigma) | suuport text to image fine tune at different aspect ratio | ### run hf diffusers on mindspore -mindone diffusers is under active development, most tasks were tested with mindspore 2.2.10 and ascend 910 hardware. +mindone diffusers is under active development, most tasks were tested with mindspore 2.3+ and ascend 910 hardware. | component | features | :--- | :-- | [pipeline](https://github.com/mindspore-lab/mindone/tree/master/mindone/diffusers/pipelines) | support text2image,text2video,text2audio tasks 30+ | [models](https://github.com/mindspore-lab/mindone/tree/master/mindone/diffusers/models) | support audoencoder & transformers base models same as hf diffusers | [schedulers](https://github.com/mindspore-lab/mindone/tree/master/mindone/diffusers/schedulers) | support ddpm & dpm solver 10+ schedulers same as hf diffusers + #### TODO -* [ ] mindspore 2.3.0 version adaption * [ ] hf diffusers 0.30.0 version adaption diff --git a/examples/README.md b/examples/README.md index a9a238a856..838b30d2ae 100644 --- a/examples/README.md +++ b/examples/README.md @@ -21,3 +21,6 @@ | [llava](https://github.com/mindspore-lab/mindone/blob/master/examples/llava) | Haotian-Liu official | https://github.com/haotian-liu/LLaVA | [vila](https://github.com/mindspore-lab/mindone/blob/master/examples/vila) | Nvidia Lab official | https://github.com/NVlabs/VILA | [pllava](https://github.com/mindspore-lab/mindone/blob/master/examples/pllava) | Magic Research official | https://github.com/magic-research/PLLaVA +| [dynamicrafter](https://github.com/mindspore-lab/mindone/blob/master/examples/dynamicrafter) | Tencent Research official | https://github.com/Doubiiu/DynamiCrafter +| [hunyuan_dit](https://github.com/mindspore-lab/mindone/blob/master/examples/hunyuan_dit) | Tencent Research official | https://github.com/Tencent/HunyuanDiT +| [pixart_sigma](https://github.com/mindspore-lab/mindone/blob/master/examples/pixart_sigma) | Noah Lab official | https://github.com/PixArt-alpha/PixArt-sigma \ No newline at end of file From a6b5a49f4d1b2c53044ea88bf4674b98d295f35d Mon Sep 17 00:00:00 2001 From: Samit <285365963@qq.com> Date: Wed, 6 Nov 2024 10:51:37 +0800 Subject: [PATCH 027/122] discard spurious frames --- examples/movie_gen/mg/models/tae/tae.py | 8 ++++++-- 
examples/movie_gen/tests/test_tae.py | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/examples/movie_gen/mg/models/tae/tae.py b/examples/movie_gen/mg/models/tae/tae.py index b2469fbc7f..5d040939da 100644 --- a/examples/movie_gen/mg/models/tae/tae.py +++ b/examples/movie_gen/mg/models/tae/tae.py @@ -46,7 +46,10 @@ def __init__( self.exp = ops.Exp() self.stdnormal = ops.StandardNormal() self.split = ms.ops.split - self.sample_deterministic=False + + self.sample_deterministic = False + self.discard_spurious_frames = True + def _encode(self, x): # return latent distribution, N(mean, logvar) @@ -89,7 +92,8 @@ def construct(self, x: ms.Tensor) -> ms.Tensor: z = self.sample(posterior_mean, posterior_logvar) recons = self.decode(z) - # TODO: discard supurious frames + if self.discard_spurious_frames and (recons.shape[-3] != x.shape[-3]): + recons = recons[:, :, :x.shape[-3], :, :] return recons, posterior_mean, posterior_logvar diff --git a/examples/movie_gen/tests/test_tae.py b/examples/movie_gen/tests/test_tae.py index 892200d142..efbee65760 100644 --- a/examples/movie_gen/tests/test_tae.py +++ b/examples/movie_gen/tests/test_tae.py @@ -184,7 +184,7 @@ def test_tae_decode(): def test_tae_rec(): - in_shape = (B, C, T, H, W) = (1, 3, 8, 64, 64) + in_shape = (B, C, T, H, W) = (1, 3, 9, 64, 64) x = np.random.normal(size=in_shape).astype(np.float32) x = ms.Tensor(x) From 3f166727098e329c46ccd43d9e49265d2dcc2211 Mon Sep 17 00:00:00 2001 From: Samit <285365963@qq.com> Date: Wed, 6 Nov 2024 11:36:21 +0800 Subject: [PATCH 028/122] rename --- examples/movie_gen/mg/models/tae/tae.py | 3 ++- examples/movie_gen/tests/test_tae.py | 8 ++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/examples/movie_gen/mg/models/tae/tae.py b/examples/movie_gen/mg/models/tae/tae.py index 5d040939da..3a11118903 100644 --- a/examples/movie_gen/mg/models/tae/tae.py +++ b/examples/movie_gen/mg/models/tae/tae.py @@ -2,6 +2,7 @@ from mindspore import nn, ops from .modules import Conv2_5d, Encoder, Decoder +# TODO: set z_channels to 16 SDXL_CONFIG = { "double_z": True, "z_channels": 4, @@ -16,7 +17,7 @@ } -class VideoAutoencoder(nn.Cell): +class TemporalAutoencoder(nn.Cell): r""" TAE diff --git a/examples/movie_gen/tests/test_tae.py b/examples/movie_gen/tests/test_tae.py index efbee65760..4236e0f3d0 100644 --- a/examples/movie_gen/tests/test_tae.py +++ b/examples/movie_gen/tests/test_tae.py @@ -12,7 +12,7 @@ TemporalDownsample, TemporalUpsample, ) -from mg.models.tae.tae import SDXL_CONFIG, VideoAutoencoder +from mg.models.tae.tae import SDXL_CONFIG, TemporalAutoencoder import mindspore as ms @@ -166,7 +166,7 @@ def test_tae_encode(): x = np.random.normal(size=in_shape).astype(np.float32) x = ms.Tensor(x) - tae = VideoAutoencoder(config=SDXL_CONFIG) + tae = TemporalAutoencoder(config=SDXL_CONFIG) y = tae.encode(x) print(y.shape) @@ -177,7 +177,7 @@ def test_tae_decode(): x = np.random.normal(size=in_shape).astype(np.float32) x = ms.Tensor(x) - tae = VideoAutoencoder(config=SDXL_CONFIG) + tae = TemporalAutoencoder(config=SDXL_CONFIG) y = tae.decode(x) print(y.shape) @@ -188,7 +188,7 @@ def test_tae_rec(): x = np.random.normal(size=in_shape).astype(np.float32) x = ms.Tensor(x) - tae = VideoAutoencoder(config=SDXL_CONFIG) + tae = TemporalAutoencoder(config=SDXL_CONFIG) y = tae(x) print(y[0].shape) From df2f01ca533350371fe04e2984240d0790f47587 Mon Sep 17 00:00:00 2001 From: Samit <285365963@qq.com> Date: Wed, 6 Nov 2024 17:39:30 +0800 Subject: [PATCH 029/122] add train --- 
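Note: a minimal usage sketch for the pieces added in this patch (illustrative only; it assumes the TAE forward
pass returns (recons, z, mean, logvar) exactly as GeneratorWithLoss unpacks it, and it uses the
TemporalAutoencoder class name from the earlier rename rather than the TemporalAutoEncoder spelling
imported in train_vae.py):

    import numpy as np
    import mindspore as ms
    from mg.models.tae.tae import SDXL_CONFIG, TemporalAutoencoder
    from mg.models.tae.losses import GeneratorWithLoss

    # Generator loss: L1/LPIPS reconstruction plus a KL penalty on the TAE posterior.
    # The outlier penalty term is left off here; LPIPS downloads pretrained VGG16 weights when constructed.
    tae = TemporalAutoencoder(config=SDXL_CONFIG)
    loss_fn = GeneratorWithLoss(tae, kl_weight=1.0e-6, perceptual_weight=1.0, use_outlier_penalty_loss=False)

    x = ms.Tensor(np.random.rand(1, 3, 8, 64, 64).astype(np.float32))  # (b c t h w) video clip
    loss = loss_fn(x)  # scalar generator loss

The kl_weight above matches the --kl_loss_weight default (1.0e-6) in args_train_vae.py; the batch shape and the
remaining hyper-parameters are placeholders.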
.../movie_gen/mg/models/tae/autoencoder_kl.py | 120 ----- examples/movie_gen/mg/models/tae/losses.py | 222 ++++++++ examples/movie_gen/mg/models/tae/lpips.py | 139 +++++ examples/movie_gen/mg/models/tae/vae.py | 485 ------------------ examples/movie_gen/scripts/args_train_vae.py | 294 +++++++++++ examples/movie_gen/scripts/train_vae.py | 438 ++++++++++++++++ 6 files changed, 1093 insertions(+), 605 deletions(-) delete mode 100644 examples/movie_gen/mg/models/tae/autoencoder_kl.py create mode 100644 examples/movie_gen/mg/models/tae/losses.py create mode 100644 examples/movie_gen/mg/models/tae/lpips.py delete mode 100644 examples/movie_gen/mg/models/tae/vae.py create mode 100644 examples/movie_gen/scripts/args_train_vae.py create mode 100644 examples/movie_gen/scripts/train_vae.py diff --git a/examples/movie_gen/mg/models/tae/autoencoder_kl.py b/examples/movie_gen/mg/models/tae/autoencoder_kl.py deleted file mode 100644 index 779838b71b..0000000000 --- a/examples/movie_gen/mg/models/tae/autoencoder_kl.py +++ /dev/null @@ -1,120 +0,0 @@ -import mindspore as ms -from mindspore import nn, ops - -from ..layers.operation_selector import get_split_op -from .modules import Decoder, Encoder - -__all__ = ["AutoencoderKL"] - - -class AutoencoderKL(nn.Cell): - def __init__( - self, - ddconfig, - embed_dim, - ckpt_path=None, - ignore_keys=[], - image_key="image", - monitor=None, - use_recompute=False, - sample_deterministic=False, - ): - super().__init__() - self.image_key = image_key - self.sample_deterministic = sample_deterministic - self.encoder = Encoder(**ddconfig) - self.decoder = Decoder(**ddconfig) - # assert ddconfig["double_z"] - self.quant_conv = nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1, pad_mode="valid", has_bias=True) - self.post_quant_conv = nn.Conv2d(embed_dim, ddconfig["z_channels"], 1, pad_mode="valid", has_bias=True) - self.embed_dim = embed_dim - - if monitor is not None: - self.monitor = monitor - if ckpt_path is not None: - self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) - - self.exp = ops.Exp() - self.stdnormal = ops.StandardNormal() - self.split = get_split_op() - - if use_recompute: - self.recompute(self.encoder) - self.recompute(self.quant_conv) - self.recompute(self.post_quant_conv) - self.recompute(self.decoder) - - def recompute(self, b): - if not b._has_config_recompute: - b.recompute() - if isinstance(b, nn.CellList): - self.recompute(b[-1]) - else: - b.add_flags(output_no_recompute=True) - - def init_from_ckpt( - self, path, ignore_keys=list(), remove_prefix=["first_stage_model.", "autoencoder.", "spatial_vae.module."] - ): - # TODO: support auto download pretrained checkpoints - sd = ms.load_checkpoint(path) - keys = list(sd.keys()) - for k in keys: - for ik in ignore_keys: - if k.startswith(ik): - print("Deleting key {} from state_dict.".format(k)) - del sd[k] - vae_prefix = ["encoder.", "decoder.", "quant_conv.", "post_quant_conv."] - for pname in keys: - is_vae_param = False - for pf in remove_prefix: - if pname.startswith(pf): - sd[pname.replace(pf, "")] = sd.pop(pname) - is_vae_param = True - for pf in vae_prefix: - if pname.startswith(pf): - is_vae_param = True - if not is_vae_param: - sd.pop(pname) - pu, cu = ms.load_param_into_net(self, sd, strict_load=False) - print(f"Net param not loaded : {pu}") - print(f"Checkpoint param not loaded : {cu}") - print(f"Restored from {path}") - - def _encode(self, x): - # return latent distribution, N(mean, logvar) - h = self.encoder(x) - moments = self.quant_conv(h) - mean, logvar = self.split(moments, 
moments.shape[1] // 2, 1) - - return mean, logvar - - def sample(self, mean, logvar): - # sample z from latent distribution - logvar = ops.clip_by_value(logvar, -30.0, 20.0) - std = self.exp(0.5 * logvar) - z = mean + std * self.stdnormal(mean.shape) - - return z - - def encode(self, x): - # embedding, get latent representation z - posterior_mean, posterior_logvar = self._encode(x) - if self.sample_deterministic: - return posterior_mean - z = self.sample(posterior_mean, posterior_logvar) - - return z - - def decode(self, z): - z = self.post_quant_conv(z) - dec = self.decoder(z) - return dec - - def construct(self, input): - # overall pass, mostly for training - posterior_mean, posterior_logvar = self._encode(input) - z = self.sample(posterior_mean, posterior_logvar) - - recons = self.decode(z) - - return recons, posterior_mean, posterior_logvar diff --git a/examples/movie_gen/mg/models/tae/losses.py b/examples/movie_gen/mg/models/tae/losses.py new file mode 100644 index 0000000000..dfbc959492 --- /dev/null +++ b/examples/movie_gen/mg/models/tae/losses.py @@ -0,0 +1,222 @@ +import mindspore as ms +from mindspore import nn, ops + +from .lpips import LPIPS + + +def _rearrange_in(x): + b, c, t, h, w = x.shape + x = x.permute(0, 2, 1, 3, 4) + x = ops.reshape(x, (b * t, c, h, w)) + + return x + + +class GeneratorWithLoss(nn.Cell): + def __init__( + self, + autoencoder, + kl_weight=1.0e-06, + perceptual_weight=1.0, + logvar_init=0.0, + use_outlier_penalty_loss=True, + opl_weight=1e5, + dtype=ms.float32, + ): + super().__init__() + + # build perceptual models for loss compute + self.autoencoder = autoencoder + # TODO: set dtype for LPIPS ? + self.perceptual_loss = LPIPS() # freeze params inside + + # self.l1 = nn.L1Loss(reduction="none") + # TODO: is self.logvar trainable? + self.logvar = ms.Parameter(ms.Tensor([logvar_init], dtype=ms.float32)) + + self.kl_weight = kl_weight + self.perceptual_weight = perceptual_weight + self.use_outlier_penalty_loss = use_outlier_penalty_loss + self.opl_weight = opl_weight + + def kl(self, mean, logvar): + # cast to fp32 to avoid overflow in exp and sum ops. 
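+ # KL(N(mean, exp(logvar)) || N(0, I)) = 0.5 * sum(mean^2 + var - 1 - logvar),
+ # reduced over the feature dims so each sample keeps its own KL value.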
+ mean = mean.astype(ms.float32) + logvar = logvar.astype(ms.float32) + + var = ops.exp(logvar) + kl_loss = 0.5 * ops.sum( + ops.pow(mean, 2) + var - 1.0 - logvar, + dim=[1, 2, 3], + ) + return kl_loss + + def vae_loss_fn( + self, x, recons, mean, logvar, nll_weights=None, no_perceptual=False, no_kl=False, pixelwise_mean=False + ): + """ + return: + nll_loss: weighted sum of pixel reconstruction loss and perceptual loss + weighted_nll_loss: weighted mean of nll_loss + weighted_kl_loss: KL divergence on posterior + """ + bs = x.shape[0] + # (b c t h w) -> (b*t c h w) + x = _rearrange_in(x) + recons = _rearrange_in(recons) + + # reconstruction loss in pixels + # FIXME: debugging: use pixelwise mean to reduce loss scale + if pixelwise_mean: + rec_loss = ((x - recons) ** 2).mean() + else: + rec_loss = ops.abs(x - recons) + + # perceptual loss + if (self.perceptual_weight > 0) and (not no_perceptual): + p_loss = self.perceptual_loss(x, recons) + rec_loss = rec_loss + self.perceptual_weight * p_loss + + nll_loss = rec_loss / ops.exp(self.logvar) + self.logvar + if nll_weights is not None: + weighted_nll_loss = nll_weights * nll_loss + weighted_nll_loss = weighted_nll_loss.sum() / bs + else: + weighted_nll_loss = nll_loss.sum() / bs + + # kl loss + # TODO: FIXME: it may not fit for graph mode training + if (self.kl_weight > 0) and (not no_kl): + kl_loss = self.kl(mean, logvar) + kl_loss = kl_loss.sum() / bs + weighted_kl_loss = self.kl_weight * kl_loss + else: + weighted_kl_loss = 0 + + return nll_loss, weighted_nll_loss, weighted_kl_loss + + def construct(self, x: ms.Tensor, global_step: ms.Tensor = -1, weights: ms.Tensor = None, cond=None): + """ + x: input images or videos, images: (b c 1 h w), videos: (b c t h w) + weights: sample weights + global_step: global training step + """ + print("D--: x shape: ", x.shape) + x_rec, z, posterior_mean, posterior_logvar = self.autoencoder(x) + # FIXME: debugging + x_rec, z, posterior_mean, posterior_logvar = ( + x_rec.to(ms.float32), + z.to(ms.float32), + posterior_mean.to(ms.float32), + posterior_logvar.to(ms.float32), + ) + + frames = x.shape[2] + + # Loss compute + # video frames x reconstruction loss + # TODO: loss dtype setting + # x: (b 3 t h w) + _, weighted_nll_loss, weighted_kl_loss = self.vae_loss_fn( + x, x_rec, posterior_mean, posterior_logvar, no_perceptual=False + ) + loss = weighted_nll_loss + weighted_kl_loss + + if self.use_outlier_penalty_loss and self.opl_weight > 0: + # (b c t h w) -> (b*t c h w) + z = _rearrange_in(z) + z_mean = ops.mean(z, axis=(-1, -2), keep_dims=True) + z_std = ops.std(z, axis=(-1, -2), keep_dims=True) + + std_scale = 3 # r=3 + opl_loss = ops.max((ops.abs(z - z_mean) - std_scale * z_std), 0) + opl_loss = ops.mean(opl_loss) + + loss += self.opl_weight + opl_loss + + return loss + + +# Discriminator is not used in opensora v1.2 +class DiscriminatorWithLoss(nn.Cell): + """ + Training logic: + For training step i, input data x: + 1. AE generator takes input x, feedforward to get posterior/latent and reconstructed data, and compute ae loss + 2. AE optimizer updates AE trainable params + 3. D takes the same input x, feed x to AE again **again** to get + the new posterior and reconstructions (since AE params has updated), feed x and recons to D, and compute D loss + 4. 
D optimizer updates D trainable params + --> Go to next training step + Ref: sd-vae training + """ + + def __init__( + self, + autoencoder, + discriminator, + disc_start=50001, + disc_factor=1.0, + disc_loss="hinge", + ): + super().__init__() + self.autoencoder = autoencoder + self.discriminator = discriminator + self.disc_start = disc_start + self.disc_factor = disc_factor + + assert disc_loss in ["hinge", "vanilla"] + if disc_loss == "hinge": + self.disc_loss = self.hinge_loss + else: + self.softplus = ops.Softplus() + self.disc_loss = self.vanilla_d_loss + + def hinge_loss(self, logits_real, logits_fake): + loss_real = ops.mean(ops.relu(1.0 - logits_real)) + loss_fake = ops.mean(ops.relu(1.0 + logits_fake)) + d_loss = 0.5 * (loss_real + loss_fake) + return d_loss + + def vanilla_d_loss(self, logits_real, logits_fake): + d_loss = 0.5 * (ops.mean(self.softplus(-logits_real)) + ops.mean(self.softplus(logits_fake))) + return d_loss + + def construct(self, x: ms.Tensor, global_step=-1, cond=None): + """ + Second pass + Args: + x: input image/video, (bs c h w) + weights: sample weights + """ + + # 1. AE forward, get posterior (mean, logvar) and recons + recons, mean, logvar = ops.stop_gradient(self.autoencoder(x)) + + if x.ndim >= 5: + # TODO: use 3D discriminator + # x: b c t h w -> (b*t c h w), shape for image perceptual loss + x = _rearrange_in(x) + recons = _rearrange_in(recons) + + # 2. Disc forward to get class prediction on real input and reconstrucions + if cond is None: + logits_real = self.discriminator(x) + logits_fake = self.discriminator(recons) + else: + logits_real = self.discriminator(ops.concat((x, cond), dim=1)) + logits_fake = self.discriminator(ops.concat((recons, cond), dim=1)) + + if global_step >= self.disc_start: + disc_factor = self.disc_factor + else: + disc_factor = 0.0 + + d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) + + # log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), + # "{}/logits_real".format(split): logits_real.detach().mean(), + # "{}/logits_fake".format(split): logits_fake.detach().mean() + # } + + return d_loss diff --git a/examples/movie_gen/mg/models/tae/lpips.py b/examples/movie_gen/mg/models/tae/lpips.py new file mode 100644 index 0000000000..ca1fbb4442 --- /dev/null +++ b/examples/movie_gen/mg/models/tae/lpips.py @@ -0,0 +1,139 @@ +import logging +import os + +import mindcv +from opensora.utils.load_models import load_from_pretrained + +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops + +_logger = logging.getLogger(__name__) + + +class LPIPS(nn.Cell): + # Learned perceptual metric + def __init__(self, use_dropout=True): + super().__init__() + self.scaling_layer = ScalingLayer() + self.chns = [64, 128, 256, 512, 512] # vgg16 features + self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) + self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) + self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) + self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) + self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) + # load NetLin metric layers + self.load_lpips() + + self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] + self.lins = nn.CellList(self.lins) + + # create vision backbone and load pretrained weights + self.net = vgg16(pretrained=True, requires_grad=False) + + self.set_train(False) + for param in self.trainable_params(): + param.requires_grad = False + + def load_lpips(self, ckpt_path="models/lpips_vgg-426bf45c.ckpt"): 
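+ # Restore the pretrained LPIPS calibration (NetLin) weights; if the local checkpoint
+ # is missing, fall back to the hosted lpips_vgg checkpoint URL below.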
+ if not os.path.exists(ckpt_path): + ckpt_path = "https://download-mindspore.osinfra.cn/toolkits/mindone/autoencoders/lpips_vgg-426bf45c.ckpt" + load_from_pretrained(self, ckpt_path) + + _logger.info("loaded pretrained LPIPS loss from {}".format(ckpt_path)) + + def construct(self, input, target): + in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) + outs0, outs1 = self.net(in0_input), self.net(in1_input) + val = 0 # ms.Tensor(0, dtype=input.dtype) + for kk in range(len(self.chns)): + diff = (normalize_tensor(outs0[kk]) - normalize_tensor(outs1[kk])) ** 2 + # res += spatial_average(lins[kk](diff), keepdim=True) + # lin_layer = lins[kk] + val += ops.mean(self.lins[kk](diff), axis=[2, 3], keep_dims=True) + return val + + +class ScalingLayer(nn.Cell): + def __init__(self): + super(ScalingLayer, self).__init__() + self.shift = ms.Tensor([-0.030, -0.088, -0.188])[None, :, None, None] + self.scale = ms.Tensor([0.458, 0.448, 0.450])[None, :, None, None] + + def construct(self, inp): + return (inp - self.shift) / self.scale + + +class NetLinLayer(nn.Cell): + """A single linear layer which does a 1x1 conv""" + + def __init__(self, chn_in, chn_out=1, use_dropout=False, dtype=ms.float32): + super(NetLinLayer, self).__init__() + # TODO: can parse dtype=dtype in ms2.3 + layers = ( + [ + nn.Dropout(p=0.5).to_float(dtype), + ] + if (use_dropout) + else [] + ) + layers += [ + nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, has_bias=False).to_float(dtype), + ] + self.model = nn.SequentialCell(layers) + + def construct(self, x): + return self.model(x) + + +class vgg16(nn.Cell): + def __init__(self, requires_grad=False, pretrained=True): + super(vgg16, self).__init__() + # FIXME: add bias in vgg. use the same model weights in PT. + model = mindcv.create_model("vgg16", pretrained=pretrained) + model.set_train(False) + vgg_pretrained_features = model.features + self.slice1 = nn.SequentialCell() + self.slice2 = nn.SequentialCell() + self.slice3 = nn.SequentialCell() + self.slice4 = nn.SequentialCell() + self.slice5 = nn.SequentialCell() + self.N_slices = 5 + for x in range(4): + self.slice1.append(vgg_pretrained_features[x]) + for x in range(4, 9): + self.slice2.append(vgg_pretrained_features[x]) + for x in range(9, 16): + self.slice3.append(vgg_pretrained_features[x]) + for x in range(16, 23): + self.slice4.append(vgg_pretrained_features[x]) + for x in range(23, 30): + self.slice5.append(vgg_pretrained_features[x]) + if not requires_grad: + for param in self.trainable_params(): + param.requires_grad = False + for param in model.trainable_params(): + param.requires_grad = False + + def construct(self, X): + h = self.slice1(X) + h_relu1_2 = h + h = self.slice2(h) + h_relu2_2 = h + h = self.slice3(h) + h_relu3_3 = h + h = self.slice4(h) + h_relu4_3 = h + h = self.slice5(h) + h_relu5_3 = h + out = (h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) + return out + + +def normalize_tensor(x, eps=1e-10): + norm_factor = ops.sqrt((x**2).sum(1, keepdims=True)) + return x / (norm_factor + eps) + + +def spatial_average(x, keepdim=True): + return x.mean([2, 3], keep_dims=keepdim) diff --git a/examples/movie_gen/mg/models/tae/vae.py b/examples/movie_gen/mg/models/tae/vae.py deleted file mode 100644 index d846d2fdac..0000000000 --- a/examples/movie_gen/mg/models/tae/vae.py +++ /dev/null @@ -1,485 +0,0 @@ -import logging -import os - -from transformers import PretrainedConfig - -import mindspore as ms -from mindspore import mint, nn, ops -from mindspore.communication import get_group_size 
- -from ...acceleration.communications import GatherFowardSplitBackward, SplitFowardGatherBackward -from ...acceleration.parallel_states import get_sequence_parallel_group -from ..layers.operation_selector import get_split_op -from .autoencoder_kl import AutoencoderKL as AutoencoderKL_SD -from .vae_temporal import VAE_Temporal_SD # noqa: F401 - -__all__ = ["AutoencoderKL"] - - -_logger = logging.getLogger(__name__) -SD_CONFIG = { - "double_z": True, - "z_channels": 4, - "resolution": 256, - "in_channels": 3, - "out_ch": 3, - "ch": 128, - "ch_mult": [1, 2, 4, 4], - "num_res_blocks": 2, - "attn_resolutions": [], - "dropout": 0.0, -} -SDXL_CONFIG = SD_CONFIG.copy() -SDXL_CONFIG.update({"resolution": 512}) - - -class AutoencoderKL(AutoencoderKL_SD): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.split = get_split_op() - - def init_from_ckpt(self, path, ignore_keys=list()): - if not os.path.exists(path): - raise ValueError( - "Maybe download failed. Please download the VAE encoder from https://huggingface.co/stabilityai/sd-vae-ft-ema" - ) - param_dict = ms.load_checkpoint(path) - param_not_load, ckpt_not_load = ms.load_param_into_net(self, param_dict, strict_load=True) - if param_not_load or ckpt_not_load: - _logger.warning( - f"{param_not_load} in network is not loaded or {ckpt_not_load} in checkpoint is not loaded!" - ) - - def encode_with_moments_output(self, x): - """For latent caching usage""" - h = self.encoder(x) - moments = self.quant_conv(h) - mean, logvar = self.split(moments, moments.shape[1] // 2, 1) - logvar = ops.clip_by_value(logvar, -30.0, 20.0) - std = self.exp(0.5 * logvar) - - return mean, std - - -class VideoAutoencoderKL(nn.Cell): - """ - Spatial VAE - """ - - def __init__( - self, - config=SDXL_CONFIG, - ckpt_path=None, - micro_batch_size=None, - scale_factor=0.18215, - use_recompute=False, - micro_batch_parallel=False, - sample_deterministic=False, - ): - super().__init__() - - self.module = AutoencoderKL_SD( - ddconfig=config, - embed_dim=config["z_channels"], - ckpt_path=ckpt_path, - use_recompute=use_recompute, - sample_deterministic=sample_deterministic, - ) - - self.out_channels = config["z_channels"] # self.module.config.latent_channels - self.patch_size = (1, 8, 8) - self.micro_batch_size = micro_batch_size - self.micro_batch_parallel = micro_batch_parallel - if self.micro_batch_parallel: - sp_group = get_sequence_parallel_group() - _logger.info(f"Initialize Spatial VAE model with parallel group `{sp_group}`.") - self.sp_size = get_group_size(sp_group) - self.split_forward_gather_backward = SplitFowardGatherBackward(dim=0, grad_scale="down", group=sp_group) - self.gather_forward_split_backward = GatherFowardSplitBackward(dim=0, grad_scale="up", group=sp_group) - # TODO: drop the assertion once conv3d support fp32, test with test suites - assert self.micro_batch_size == 1 - - # FIXME: "scaling_factor": 0.13025 is set in - # https://huggingface.co/PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers/blob/main/vae/config.json. - # This is a mistake made during the training of OpenSora v1.2. - # To re-use the trained model, we need to keep this mistake. - # For training, we should refine to 0.13025. 
- self.scale_factor = scale_factor - self.split = get_split_op() - self.scale_factor = 0.18215 - - @staticmethod - def rearrange_in(x): - B, C, T, H, W = x.shape - # (b c t h w) -> (b t c h w) - x = ops.transpose(x, (0, 2, 1, 3, 4)) - x = ops.reshape(x, (B * T, C, H, W)) - - return x - - @staticmethod - def rearrange_out(x, B): - # x = rearrange(x, "(B T) C H W -> B C T H W", B=B) - BT, C, H, W = x.shape - T = BT // B - x = ops.reshape(x, (B, T, C, H, W)) - x = ops.transpose(x, (0, 2, 1, 3, 4)) - - return x - - def encode(self, x): - """ - Args: - x: (B, C, T, H, W) - Return: - (B C T H W) - - NOTE: remind to use stop gradient when invoke it - """ - # is_video = (x.ndim == 5) - - B = x.shape[0] - # B C T H W -> (B T) C H W - x = self.rearrange_in(x) - - pad_num = None - if self.micro_batch_parallel: - # select part of x for micro_batch - pad_num = self.get_pad_num(x.shape[0]) - if pad_num > 0: - x = mint.nn.functional.pad(x, (0, 0, 0, 0, 0, 0, 0, pad_num)) - x = self.split_forward_gather_backward(x) - - if self.micro_batch_size is None: - x_out = self.module.encode(x) * self.scale_factor - else: - bs = self.micro_batch_size - x_out = self.module.encode(x[:bs]) * self.scale_factor - for i in range(bs, x.shape[0], bs): - x_cur = self.module.encode(x[i : i + bs]) * self.scale_factor - x_out = ops.cat((x_out, x_cur), axis=0) - - if self.micro_batch_parallel: - x_out = self.gather_forward_split_backward(x_out) - if pad_num > 0: - x_out = x_out.narrow(0, 0, x_out.shape[0] - pad_num) - - # (B T) C H W -> B C T H W - x_out = self.rearrange_out(x_out, B=B) - - return x_out - - def decode(self, x, **kwargs): - # is_video = (x.ndim == 5) - - B = x.shape[0] - # x: (B, Z, T, H, W) - # B Z T H W -> (B T) Z H W - x = self.rearrange_in(x) - - if self.micro_batch_size is None: - x_out = self.module.decode(x / self.scale_factor) - else: - mbs = self.micro_batch_size - - x_out = self.module.decode(x[:mbs] / self.scale_factor) - for i in range(mbs, x.shape[0], mbs): - x_cur = self.module.decode(x[i : i + mbs] / self.scale_factor) - x_out = ops.cat((x_out, x_cur), axis=0) - - # (B T) Z H W -> B Z T H W - x_out = self.rearrange_out(x_out, B=B) - - return x_out - - def get_latent_size(self, input_size): - latent_size = [] - for i in range(3): - # assert ( - # input_size[i] is None or input_size[i] % self.patch_size[i] == 0 - # ), "Input size must be divisible by patch size" - latent_size.append(input_size[i] // self.patch_size[i] if input_size[i] is not None else None) - return latent_size - - def get_pad_num(self, dim_size: int) -> int: - pad = (self.sp_size - (dim_size % self.sp_size)) % self.sp_size - return pad - - -class VideoAutoencoderPipelineConfig(PretrainedConfig): - model_type = "VideoAutoencoderPipeline" - - def __init__( - self, - vae_2d=None, - vae_temporal=None, - from_pretrained=None, - freeze_vae_2d=False, - cal_loss=False, - micro_frame_size=None, - concat_posterior=False, - shift=0.0, - scale=1.0, - micro_frame_parallel=False, - sample_deterministic=False, - **kwargs, - ): - self.vae_2d = vae_2d - self.vae_temporal = vae_temporal - self.from_pretrained = from_pretrained - self.freeze_vae_2d = freeze_vae_2d - self.cal_loss = cal_loss - self.micro_frame_size = micro_frame_size - self.shift = shift - self.scale = scale - self.concat_posterior = (concat_posterior,) - self.micro_frame_parallel = micro_frame_parallel - self.sample_deterministic = sample_deterministic - super().__init__(**kwargs) - - -def build_module_from_config(config): - """ - config dict format: - - type: model class name - - 
others: model init args - """ - cfg = config.copy() - name = cfg.pop("type") - kwargs = cfg - - # FIXME: use importlib with path - module = eval(name)(**kwargs) - return module - - -class VideoAutoencoderPipeline(nn.Cell): - """ - Main model for spatial vae + tempral vae - """ - - # config_class = VideoAutoencoderPipelineConfig - def __init__(self, config: VideoAutoencoderPipelineConfig): - super().__init__() - self.spatial_vae = build_module_from_config(config.vae_2d) - self.temporal_vae = build_module_from_config(config.vae_temporal) - - self.cal_loss = config.cal_loss - self.micro_frame_size = config.micro_frame_size - self.micro_z_frame_size = self.temporal_vae.get_latent_size([config.micro_frame_size, None, None])[0] - print(f"micro_frame_size: {self.micro_frame_size}, micro_z_frame_size: {self.micro_z_frame_size}") - self.micro_frame_parallel = config.micro_frame_parallel - self.sample_deterministic = config.sample_deterministic - - if config.freeze_vae_2d: - for param in self.spatial_vae.get_parameters(): - param.requires_grad = False - - self.out_channels = self.temporal_vae.out_channels - self.split = get_split_op() - - # normalization parameters - scale = ms.Tensor(config.scale) - shift = ms.Tensor(config.shift) - if len(scale.shape) > 0: - scale = scale[None, :, None, None, None] - if len(shift.shape) > 0: - shift = shift[None, :, None, None, None] - self.scale = ms.Parameter(scale, requires_grad=False) - self.shift = ms.Parameter(shift, requires_grad=False) - self.freeze_vae_2d = config.freeze_vae_2d - self.concat_posterior = config.concat_posterior - - if self.micro_frame_parallel: - sp_group = get_sequence_parallel_group() - _logger.info(f"Initialize Temporal VAE model with parallel group `{sp_group}`.") - self.sp_size = get_group_size(sp_group) - self.split_forward_gather_backward = SplitFowardGatherBackward(dim=2, grad_scale="down", group=sp_group) - self.gather_forward_split_backward = GatherFowardSplitBackward(dim=2, grad_scale="up", group=sp_group) - if self.cal_loss: - raise NotImplementedError("Not Supported yet.") - - def encode(self, x): - if self.freeze_vae_2d: - x_z = ops.stop_gradient(self.spatial_vae.encode(x)) - else: - x_z = self.spatial_vae.encode(x) - - if self.micro_frame_parallel: - # TODO: drop assertion and add padding - assert x_z.shape[2] % self.sp_size == 0 - if self.micro_frame_size is not None: - assert x_z.shape[2] % self.micro_frame_size == 0 - x_z = self.split_forward_gather_backward(x_z) - - if self.micro_frame_size is None: - posterior_mean, posterior_logvar = self.temporal_vae._encode(x_z) - if self.sample_deterministic: - z_out = posterior_mean - else: - z_out = self.temporal_vae.sample(posterior_mean, posterior_logvar) - - if self.cal_loss: - return z_out, posterior_mean, posterior_logvar, x_z - else: - if self.micro_frame_parallel: - z_out = self.gather_forward_split_backward(z_out) - return (z_out - self.shift) / self.scale - else: - # x_z: (b z t h w) - mfs = self.micro_frame_size - if self.cal_loss: - # TODO: fix the bug in torch, output concat of the splitted posteriors instead of the last split - posterior_mean, posterior_logvar = self.temporal_vae._encode(x_z[:, :, :mfs]) - if self.sample_deterministic: - z_out = posterior_mean - else: - z_out = self.temporal_vae.sample(posterior_mean, posterior_logvar) - for i in range(mfs, x_z.shape[2], mfs): - posterior_mean, posterior_logvar = self.temporal_vae._encode(x_z[:, :, i : i + mfs]) - if self.sample_deterministic: - z_cur = posterior_mean - else: - z_cur = 
self.temporal_vae.sample(posterior_mean, posterior_logvar) - z_out = ops.cat((z_out, z_cur), axis=2) - - return z_out, posterior_mean, posterior_logvar, x_z - else: - # no posterior cache to reduce memory in inference - z_out = self.temporal_vae.encode(x_z[:, :, :mfs]) - for i in range(mfs, x_z.shape[2], mfs): - z_cur = self.temporal_vae.encode(x_z[:, :, i : i + mfs]) - z_out = ops.cat((z_out, z_cur), axis=2) - - if self.micro_frame_parallel: - z_out = self.gather_forward_split_backward(z_out) - - return (z_out - self.shift) / self.scale - - def decode(self, z, num_frames=None): - if not self.cal_loss: - z = z * self.scale.to(z.dtype) + self.shift.to(z.dtype) - - if self.micro_frame_size is None: - x_z_out = self.temporal_vae.decode(z, num_frames=num_frames) - x = self.spatial_vae.decode(x_z_out) - if self.cal_loss: - return x, x_z_out - else: - return x - else: - mz = self.micro_z_frame_size - remain_frames = num_frames if self.micro_frame_size > num_frames else self.micro_frame_size - x_z_out = self.temporal_vae.decode(z[:, :, :mz], num_frames=remain_frames) - num_frames -= self.micro_frame_size - - for i in range(mz, z.shape[2], mz): - remain_frames = num_frames if self.micro_frame_size > num_frames else self.micro_frame_size - x_z_cur = self.temporal_vae.decode(z[:, :, i : i + mz], num_frames=remain_frames) - x_z_out = ops.cat((x_z_out, x_z_cur), axis=2) - num_frames -= self.micro_frame_size - - x = self.spatial_vae.decode(x_z_out) - - if self.cal_loss: - return x, x_z_out - else: - return x - - def construct(self, x): - # assert self.cal_loss, "This method is only available when cal_loss is True" - z, posterior_mean, posterior_logvar, x_z = self.encode(x) - x_rec, x_z_rec = self.decode(z, num_frames=x_z.shape[2]) - return x_rec, x_z_rec, z, posterior_mean, posterior_logvar, x_z - - def get_latent_size(self, input_size): - if self.micro_frame_size is None or input_size[0] is None: - return self.temporal_vae.get_latent_size(self.spatial_vae.get_latent_size(input_size)) - else: - sub_input_size = [self.micro_frame_size, input_size[1], input_size[2]] - sub_latent_size = self.temporal_vae.get_latent_size(self.spatial_vae.get_latent_size(sub_input_size)) - sub_latent_size[0] = sub_latent_size[0] * (input_size[0] // self.micro_frame_size) - remain_temporal_size = [input_size[0] % self.micro_frame_size, None, None] - if remain_temporal_size[0] > 0: - remain_size = self.temporal_vae.get_latent_size(remain_temporal_size) - sub_latent_size[0] += remain_size[0] - return sub_latent_size - - def get_temporal_last_layer(self): - return self.temporal_vae.decoder.conv_out.conv.weight - - -def OpenSoraVAE_V1_2( - micro_batch_size=4, - micro_frame_size=17, - micro_batch_parallel=False, - micro_frame_parallel=False, - ckpt_path=None, - vae2d_ckpt_path=None, - freeze_vae_2d=False, - cal_loss=False, - use_recompute=False, - sample_deterministic=False, -): - """ - ckpt_path: path to the checkpoint of the overall model (vae2d + temporal vae) - vae_2d_ckpt_path: path to the checkpoint of the vae 2d model. It will only be loaded when `ckpt_path` not provided. 
- """ - - if isinstance(micro_batch_size, int): - if micro_batch_size <= 0: - micro_batch_size = None - if isinstance(micro_frame_size, int): - if micro_frame_size <= 0: - micro_frame_size = None - - vae_2d = dict( - type="VideoAutoencoderKL", - config=SDXL_CONFIG, - micro_batch_size=micro_batch_size, - micro_batch_parallel=micro_batch_parallel, - use_recompute=use_recompute, - sample_deterministic=sample_deterministic, - ) - vae_temporal = dict( - type="VAE_Temporal_SD", - from_pretrained=None, - use_recompute=use_recompute, - sample_deterministic=sample_deterministic, - ) - shift = (-0.10, 0.34, 0.27, 0.98) - scale = (3.85, 2.32, 2.33, 3.06) - kwargs = dict( - vae_2d=vae_2d, - vae_temporal=vae_temporal, - freeze_vae_2d=freeze_vae_2d, - cal_loss=cal_loss, - micro_frame_size=micro_frame_size, - shift=shift, - scale=scale, - micro_frame_parallel=micro_frame_parallel, - sample_deterministic=sample_deterministic, - ) - - config = VideoAutoencoderPipelineConfig(**kwargs) - model = VideoAutoencoderPipeline(config) - - # load model weights - if (ckpt_path is not None) and (os.path.exists(ckpt_path)): - sd = ms.load_checkpoint(ckpt_path) - - # remove the added prefix in the trained checkpoint - pnames = list(sd.keys()) - for pn in pnames: - new_pn = pn.replace("autoencoder.", "").replace("_backbone.", "") - sd[new_pn] = sd.pop(pn) - - pu, cu = ms.load_param_into_net(model, sd, strict_load=False) - print(f"Net param not loaded : {pu}") - print(f"Checkpoint param not loaded : {cu}") - elif (vae2d_ckpt_path is not None) and (os.path.exists(vae2d_ckpt_path)): - sd = ms.load_checkpoint(vae2d_ckpt_path) - # TODO: add spatial_vae prefix to the param name - pu, cu = ms.load_param_into_net(model.spatial_vae, sd, strict_load=False) - - return model diff --git a/examples/movie_gen/scripts/args_train_vae.py b/examples/movie_gen/scripts/args_train_vae.py new file mode 100644 index 0000000000..23d14c0408 --- /dev/null +++ b/examples/movie_gen/scripts/args_train_vae.py @@ -0,0 +1,294 @@ +import argparse +import logging +import os +import sys + +import yaml + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +mindone_lib_path = os.path.abspath(os.path.join(__dir__, "../../../")) +sys.path.insert(0, mindone_lib_path) + +from opensora.utils.model_utils import _check_cfgs_in_parser, str2bool + +from mindone.utils.misc import to_abspath + +logger = logging.getLogger() + + +def parse_train_args(parser): + parser.add_argument( + "--config", + "-c", + default="", + type=str, + help="path to load a config yaml file that describes the training recipes which will override the default arguments", + ) + # the following args's defualt value will be overrided if specified in config yaml + + # data + parser.add_argument("--dataset_name", default="", type=str, help="dataset name") + parser.add_argument( + "--csv_path", + default="", + type=str, + help="path to csv annotation file. columns: video, caption. \ + video indicates the relative path of video file in video_folder. 
caption - the text caption for video", + ) + parser.add_argument("--video_column", default="video", type=str, help="name of column for videos saved in csv file") + parser.add_argument("--random_crop", default=False, type=str2bool, help="randonly crop the image") + parser.add_argument("--flip", default=False, type=str2bool, help="flip the image") + + parser.add_argument( + "--caption_column", default="caption", type=str, help="name of column for captions saved in csv file" + ) + parser.add_argument("--video_folder", default="", type=str, help="root dir for the video data") + parser.add_argument("--output_path", default="output/", type=str, help="output directory to save training results") + parser.add_argument( + "--add_datetime", default=True, type=str, help="If True, add datetime subfolder under output_path" + ) + # model + parser.add_argument("--model_type", default="OpenSora-VAE-v1.2", type=str, help="VAE model type") + parser.add_argument("--freeze_vae_2d", default=True, type=str2bool, help="Freeze 2d vae") + parser.add_argument( + "--use_discriminator", default=False, type=str2bool, help="Use discriminator for adversarial training." + ) + parser.add_argument( + "--pretrained_model_path", + default="", + type=str, + help="Specify the pretrained model path", + ) + parser.add_argument("--perceptual_loss_weight", default=0.1, type=float, help="perceptual (lpips) loss weight") + parser.add_argument("--kl_loss_weight", default=1.0e-6, type=float, help="KL loss weight") + parser.add_argument( + "--use_outlier_penalty_loss", + default=False, + type=str2bool, + help="use outlier penalty loss", + ) + # data + parser.add_argument("--mixed_strategy", type=str, default=None, help="video and image mixed strategy") + parser.add_argument( + "--mixed_image_ratio", default=0.0, type=float, help="image ratio in mixed video and image data training" + ) + + # ms + parser.add_argument("--debug", type=str2bool, default=False, help="Execute inference in debug mode.") + parser.add_argument("--device_target", type=str, default="Ascend", help="Ascend or GPU") + parser.add_argument("--max_device_memory", type=str, default=None, help="e.g. `30GB` for 910a, `59GB` for 910b") + parser.add_argument("--mode", default=0, type=int, help="Specify the mode: 0 for graph mode, 1 for pynative mode") + parser.add_argument("--use_parallel", default=False, type=str2bool, help="use parallel") + parser.add_argument( + "--parallel_mode", default="data", type=str, choices=["data", "optim"], help="parallel mode: data, optim" + ) + parser.add_argument("--jit_level", default="O0", type=str, help="O0 kbk, O1 dvm, O2 ge") + + # training hyper-params + parser.add_argument( + "--resume", + default=False, + type=str, + help="It can be a string for path to resume checkpoint, or a bool False for not resuming.(default=False)", + ) + parser.add_argument("--optim", default="adamw_re", type=str, help="optimizer") + parser.add_argument( + "--betas", + type=float, + nargs="+", + default=[0.9, 0.999], + help="Specify the [beta1, beta2] parameter for the AdamW optimizer.", + ) + parser.add_argument( + "--optim_eps", type=float, default=1e-8, help="Specify the eps parameter for the AdamW optimizer." + ) + parser.add_argument( + "--group_strategy", + type=str, + default=None, + help="Grouping strategy for weight decay. If `norm_and_bias`, weight decay filter list is [beta, gamma, bias]. 
\ + If None, filter list is [layernorm, bias], Default: None", + ) + parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay.") + parser.add_argument("--seed", default=3407, type=int, help="data path") + parser.add_argument("--warmup_steps", default=1000, type=int, help="warmup steps") + parser.add_argument("--batch_size", default=10, type=int, help="batch size") + parser.add_argument( + "--micro_batch_size", + type=int, + default=4, + help="If not None, split batch_size*num_frames into smaller ones for VAE encoding to reduce memory limitation", + ) + parser.add_argument( + "--micro_frame_size", + type=int, + default=17, + help="If not None, split batch_size*num_frames into smaller ones for VAE encoding to reduce memory limitation. Used by temporal vae", + ) + parser.add_argument("--start_learning_rate", default=1e-5, type=float, help="The initial learning rate for Adam.") + parser.add_argument("--end_learning_rate", default=1e-7, type=float, help="The end learning rate for Adam.") + parser.add_argument( + "--scale_lr", default=False, type=str2bool, help="scale base-lr by ngpu * batch_size * n_accumulate" + ) + parser.add_argument("--decay_steps", default=0, type=int, help="lr decay steps.") + parser.add_argument("--scheduler", default="cosine_decay", type=str, help="scheduler.") + parser.add_argument("--pre_patchify", default=False, type=str2bool, help="Training with patchified latent.") + parser.add_argument( + "--max_image_size", default=512, type=int, help="Max image size for patchified latent training." + ) + + # dataloader params + parser.add_argument("--dataset_sink_mode", default=False, type=str2bool, help="sink mode") + parser.add_argument("--sink_size", default=-1, type=int, help="dataset sink size. If -1, sink size = dataset size.") + parser.add_argument( + "--epochs", + default=10, + type=int, + help="epochs. If dataset_sink_mode is on, epochs is with respect to dataset sink size. Otherwise, it's w.r.t the dataset size.", + ) + parser.add_argument( + "--train_steps", default=-1, type=int, help="If not -1, limit the number of training steps to the set value" + ) + parser.add_argument("--init_loss_scale", default=65536, type=float, help="loss scale") + parser.add_argument("--loss_scale_factor", default=2, type=float, help="loss scale factor") + parser.add_argument("--scale_window", default=2000, type=float, help="scale window") + parser.add_argument("--gradient_accumulation_steps", default=1, type=int, help="gradient accumulation steps") + # parser.add_argument("--cond_stage_trainable", default=False, type=str2bool, help="whether text encoder is trainable") + parser.add_argument("--use_ema", default=False, type=str2bool, help="whether use EMA") + parser.add_argument("--ema_decay", default=0.9999, type=float, help="ema decay ratio") + parser.add_argument("--clip_grad", default=False, type=str2bool, help="whether apply gradient clipping") + parser.add_argument( + "--use_recompute", + default=False, + type=str2bool, + help="whether use recompute.", + ) + parser.add_argument( + "--num_recompute_blocks", + default=None, + type=int, + help="If None, all stdit blocks will be applied with recompute (gradient checkpointing). If int, the first N blocks will be applied with recompute", + ) + parser.add_argument( + "--dtype", + default="fp16", + type=str, + choices=["bf16", "fp16", "fp32"], + help="what computation data type to use for latte. 
Default is `fp16`, which corresponds to ms.float16", + ) + parser.add_argument( + "--vae_keep_gn_fp32", + default=True, + type=str2bool, + help="whether keep GroupNorm in fp32.", + ) + parser.add_argument( + "--global_bf16", + default=False, + type=str2bool, + help="Experimental. If True, dtype will be overrided, operators will be computered in bf16 if they are supported by CANN", + ) + parser.add_argument( + "--vae_param_dtype", + default="fp32", + type=str, + choices=["bf16", "fp16", "fp32"], + help="what param data type to use for vae. Default is `fp32`, which corresponds to ms.float32", + ) + parser.add_argument( + "--amp_level", + default="O2", + type=str, + help="mindspore amp level, O1: most fp32, only layers in whitelist compute in fp16 (dense, conv, etc); \ + O2: most fp16, only layers in blacklist compute in fp32 (batch norm etc)", + ) + parser.add_argument("--vae_amp_level", default="O2", type=str, help="O2 or O3") + parser.add_argument( + "--vae_checkpoint", + type=str, + default="models/sd-vae-ft-ema.ckpt", + help="VAE checkpoint file path which is used to load vae weight.", + ) + parser.add_argument( + "--sd_scale_factor", type=float, default=0.18215, help="VAE scale factor of Stable Diffusion model." + ) + parser.add_argument("--image_size", default=256, type=int, nargs="+", help="the image size used to initiate model") + parser.add_argument("--num_frames", default=16, type=int, help="the num of frames used to initiate model") + parser.add_argument("--frame_stride", default=3, type=int, help="frame sampling stride") + parser.add_argument("--mask_ratios", type=dict, help="Masking ratios") + parser.add_argument("--bucket_config", type=dict, help="Multi-resolution bucketing configuration") + parser.add_argument("--num_parallel_workers", default=12, type=int, help="num workers for data loading") + parser.add_argument( + "--data_multiprocessing", + default=False, + type=str2bool, + help="If True, use multiprocessing for data processing. Default: multithreading.", + ) + parser.add_argument("--max_rowsize", default=64, type=int, help="max rowsize for data loading") + parser.add_argument( + "--enable_flash_attention", + default=None, + type=str2bool, + help="whether to enable flash attention.", + ) + parser.add_argument("--drop_overflow_update", default=True, type=str2bool, help="drop overflow update") + parser.add_argument("--loss_scaler_type", default="dynamic", type=str, help="dynamic or static") + parser.add_argument( + "--max_grad_norm", + default=1.0, + type=float, + help="max gradient norm for clipping, effective when `clip_grad` enabled.", + ) + parser.add_argument("--ckpt_save_interval", default=1, type=int, help="save checkpoint every this epochs") + parser.add_argument( + "--ckpt_save_steps", + default=-1, + type=int, + help="save checkpoint every this steps. If -1, use ckpt_save_interval will be used.", + ) + parser.add_argument("--ckpt_max_keep", default=10, type=int, help="Maximum number of checkpoints to keep") + parser.add_argument( + "--step_mode", + default=False, + type=str2bool, + help="whether save ckpt by steps. If False, save ckpt by epochs.", + ) + parser.add_argument("--profile", default=False, type=str2bool, help="Profile or not") + parser.add_argument( + "--log_level", + type=str, + default="logging.INFO", + help="log level, options: logging.DEBUG, logging.INFO, logging.WARNING, logging.ERROR", + ) + parser.add_argument( + "--log_interval", + default=1, + type=int, + help="log interval in the unit of data sink size.. E.g. 
if data sink size = 10, log_inteval=2, log every 20 steps", + ) + return parser + + +def parse_args(): + parser = argparse.ArgumentParser() + parser = parse_train_args(parser) + + __dir__ = os.path.dirname(os.path.abspath(__file__)) + abs_path = os.path.abspath(os.path.join(__dir__, "..")) + default_args = parser.parse_args() + if default_args.config: + default_args.config = to_abspath(abs_path, default_args.config) + with open(default_args.config, "r") as f: + cfg = yaml.safe_load(f) + _check_cfgs_in_parser(cfg, parser) + parser.set_defaults(**cfg) + args = parser.parse_args() + # convert to absolute path, necessary for modelarts + args.csv_path = to_abspath(abs_path, args.csv_path) + args.video_folder = to_abspath(abs_path, args.video_folder) + args.output_path = to_abspath(abs_path, args.output_path) + args.pretrained_model_path = to_abspath(abs_path, args.pretrained_model_path) + args.vae_checkpoint = to_abspath(abs_path, args.vae_checkpoint) + print(args) + + return args diff --git a/examples/movie_gen/scripts/train_vae.py b/examples/movie_gen/scripts/train_vae.py new file mode 100644 index 0000000000..60565a8230 --- /dev/null +++ b/examples/movie_gen/scripts/train_vae.py @@ -0,0 +1,438 @@ +import logging +import os +import shutil +import sys +import time +from typing import Tuple + +import yaml + +import mindspore as ms +from mindspore import Model, nn +from mindspore.communication.management import get_group_size, get_rank, init +from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell +from mindspore.train.callback import TimeMonitor + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +mindone_lib_path = os.path.abspath(os.path.join(__dir__, "../../../")) +sys.path.insert(0, mindone_lib_path) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, ".."))) + +from args_train_vae import parse_args +from mg.datasets.vae_dataset import create_dataloader +from mg.models.tae.losses import GeneratorWithLoss +from mg.models.tae.tae import TemporalAutoEncoder + +from mindone.trainers.callback import EvalSaveCallback, OverflowMonitor, ProfilerCallback +from mindone.trainers.checkpoint import CheckpointManager, resume_train_network +from mindone.trainers.ema import EMA +from mindone.trainers.lr_schedule import create_scheduler +from mindone.trainers.optim import create_optimizer +from mindone.trainers.train_step import TrainOneStepWrapper +from mindone.utils.amp import auto_mixed_precision +from mindone.utils.logger import set_logger +from mindone.utils.params import count_params +from mindone.utils.seed import set_random_seed + +os.environ["HCCL_CONNECT_TIMEOUT"] = "6000" +os.environ["MS_ASCEND_CHECK_OVERFLOW_MODE"] = "INFNAN_MODE" + +logger = logging.getLogger(__name__) + + +def create_loss_scaler(loss_scaler_type, init_loss_scale, loss_scale_factor=2, scale_window=1000): + if args.loss_scaler_type == "dynamic": + loss_scaler = DynamicLossScaleUpdateCell( + loss_scale_value=args.init_loss_scale, scale_factor=args.loss_scale_factor, scale_window=args.scale_window + ) + elif args.loss_scaler_type == "static": + loss_scaler = nn.FixedLossScaleUpdateCell(args.init_loss_scale) + else: + raise ValueError + + return loss_scaler + + +def init_env( + mode: int = ms.GRAPH_MODE, + seed: int = 42, + distributed: bool = False, + max_device_memory: str = None, + device_target: str = "Ascend", + parallel_mode: str = "data", + jit_level: str = "O2", + global_bf16: bool = False, + dynamic_shape: bool = False, + debug: bool = False, +) -> Tuple[int, int]: + """ + Initialize MindSpore 
environment. + + Args: + mode: MindSpore execution mode. Default is 0 (ms.GRAPH_MODE). + seed: The seed value for reproducibility. Default is 42. + distributed: Whether to enable distributed training. Default is False. + Returns: + A tuple containing the device ID, rank ID and number of devices. + """ + set_random_seed(seed) + + if debug and mode == ms.GRAPH_MODE: # force PyNative mode when debugging + logger.warning("Debug mode is on, switching execution mode to PyNative.") + mode = ms.PYNATIVE_MODE + + if max_device_memory is not None: + ms.set_context(max_device_memory=max_device_memory) + + # ms.set_context(mempool_block_size="55GB") + # ms.set_context(pynative_synchronize=True) + if distributed: + ms.set_context( + mode=mode, + device_target=device_target, + ) + if parallel_mode == "optim": + print("use optim parallel") + ms.set_auto_parallel_context( + parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL, + enable_parallel_optimizer=True, + ) + init() + device_num = get_group_size() + rank_id = get_rank() + else: + init() + device_num = get_group_size() + rank_id = get_rank() + logger.debug(f"rank_id: {rank_id}, device_num: {device_num}") + ms.reset_auto_parallel_context() + + ms.set_auto_parallel_context( + parallel_mode=ms.ParallelMode.DATA_PARALLEL, + gradients_mean=True, + device_num=device_num, + ) + + var_info = ["device_num", "rank_id", "device_num / 8", "rank_id / 8"] + var_value = [device_num, rank_id, int(device_num / 8), int(rank_id / 8)] + logger.info(dict(zip(var_info, var_value))) + + else: + device_num = 1 + rank_id = 0 + ms.set_context( + mode=mode, + device_target=device_target, + pynative_synchronize=debug, + ) + + if mode == 0: + ms.set_context(jit_config={"jit_level": jit_level}) + + if global_bf16: + # only effective in GE mode, i.e. jit_level: O2 + ms.set_context(ascend_config={"precision_mode": "allow_mix_precision_bf16"}) + + if dynamic_shape: + print("Dynamic shape mode enabled, repeat_interleave/split/chunk will be called from mint module") + set_dynamic_mode(True) + + return rank_id, device_num + + +def main(args): + # 1. init + rank_id, device_num = init_env( + args.mode, + seed=args.seed, + distributed=args.use_parallel, + device_target=args.device_target, + max_device_memory=args.max_device_memory, + parallel_mode=args.parallel_mode, + jit_level=args.jit_level, + global_bf16=args.global_bf16, + dynamic_shape=(args.mixed_strategy == "mixed_video_random"), + debug=args.debug, + ) + set_logger(name="", output_dir=args.output_path, rank=rank_id, log_level=eval(args.log_level)) + + # 2. build data loader + if isinstance(args.image_size, int): + image_size = args.image_size + else: + if len(args.image_size) == 2: + assert args.image_size[0] == args.image_size[1], "Currently only h==w is supported" + image_size = args.image_size[0] + + ds_config = dict( + csv_path=args.csv_path, + data_folder=args.video_folder, + size=image_size, + crop_size=image_size, + sample_n_frames=args.num_frames, + sample_stride=args.frame_stride, + video_column=args.video_column, + random_crop=args.random_crop, + flip=args.flip, + ) + dataloader = create_dataloader( + ds_config, + args.batch_size, + mixed_strategy=args.mixed_strategy, + mixed_image_ratio=args.mixed_image_ratio, + num_parallel_workers=args.num_parallel_workers, + max_rowsize=256, + shuffle=True, + device_num=device_num, + rank_id=rank_id, + drop_remainder=True, + ) + dataset_size = dataloader.get_dataset_size() + logger.info(f"Num batches: {dataset_size}") + + # 3. 
build models + ae = TemporalAutoEncoder( + micro_batch_size=args.micro_batch_size, + micro_frame_size=args.micro_frame_size, + ckpt_path=args.pretrained_model_path, + freeze_vae_2d=args.freeze_vae_2d, + cal_loss=True, + use_recompute=args.use_recompute, + ) + + if args.use_discriminator: + logging.error("Discriminator is not used or supported in OpenSora v1.2") + + # mixed precision + # TODO: set softmax, sigmoid computed in FP32. manually set inside network since they are ops, instead of layers whose precision will be set by AMP level. + if args.dtype in ["fp16", "bf16"]: + dtype = {"fp16": ms.float16, "bf16": ms.bfloat16}[args.dtype] + ae = auto_mixed_precision( + ae, + args.amp_level, + dtype, + custom_fp32_cells=[nn.GroupNorm] if args.vae_keep_gn_fp32 else [], + ) + + # 4. build net with loss + ae_with_loss = GeneratorWithLoss( + ae, + kl_weight=args.kl_loss_weight, + perceptual_weight=args.perceptual_loss_weight, + use_image_identity_loss=args.use_outlier_penalty_loss, + dtype=args.dtype, + ) + + tot_params, trainable_params = count_params(ae_with_loss) + logger.info("Total params {:,}; Trainable params {:,}".format(tot_params, trainable_params)) + + # 5. build training utils + # torch scale lr by: model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr + if args.scale_lr: + learning_rate = args.start_learning_rate * args.batch_size * args.gradient_accumulation_steps * device_num + logger.info(f"Learning rate is scaled to {learning_rate}") + else: + learning_rate = args.start_learning_rate + if not args.decay_steps: + args.decay_steps = max(1, args.epochs * dataset_size - args.warmup_steps) + + if args.scheduler != "constant": + assert ( + args.optim != "adamw_exp" + ), "For dynamic LR, mindspore.experimental.optim.AdamW needs to work with LRScheduler" + lr = create_scheduler( + steps_per_epoch=dataset_size, + name=args.scheduler, + lr=learning_rate, + end_lr=args.end_learning_rate, + warmup_steps=args.warmup_steps, + decay_steps=args.decay_steps, + num_epochs=args.epochs, + ) + else: + lr = learning_rate + + # build optimizer + update_logvar = False # in torch, ae_with_loss.logvar is not updated. 
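+    # note: `logvar` refers to the (optionally learnable) log-variance term of the
+    # reconstruction loss, as in LDM-style VAE training; keeping update_logvar=False matches
+    # the torch reference, so only the autoencoder parameters are handed to the optimizer below.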
+ if update_logvar: + ae_params_to_update = [ae_with_loss.autoencoder.trainable_params(), ae_with_loss.logvar] + else: + ae_params_to_update = ae_with_loss.autoencoder.trainable_params() + optim_ae = create_optimizer( + ae_params_to_update, + name=args.optim, + betas=args.betas, + group_strategy=args.group_strategy, + weight_decay=args.weight_decay, + lr=lr, + ) + loss_scaler_ae = create_loss_scaler( + args.loss_scaler_type, args.init_loss_scale, args.loss_scale_factor, args.scale_window + ) + + ema = ( + EMA( + ae, + ema_decay=args.ema_decay, + offloading=False, + ) + if args.use_ema + else None + ) + + # resume training states + # TODO: resume Discriminator if used + ckpt_dir = os.path.join(args.output_path, "ckpt") + os.makedirs(ckpt_dir, exist_ok=True) + start_epoch = 0 + if args.resume: + resume_ckpt = os.path.join(ckpt_dir, "train_resume.ckpt") if isinstance(args.resume, bool) else args.resume + + start_epoch, loss_scale, cur_iter, last_overflow_iter = resume_train_network( + ae_with_loss, optim_ae, resume_ckpt + ) + loss_scaler_ae.loss_scale_value = loss_scale + loss_scaler_ae.cur_iter = cur_iter + loss_scaler_ae.last_overflow_iter = last_overflow_iter + logger.info(f"Resume training from {resume_ckpt}") + + # training step + training_step_ae = TrainOneStepWrapper( + ae_with_loss, + optimizer=optim_ae, + scale_sense=loss_scaler_ae, + drop_overflow_update=args.drop_overflow_update, + gradient_accumulation_steps=args.gradient_accumulation_steps, + clip_grad=args.clip_grad, + clip_norm=args.max_grad_norm, + ema=ema, + ) + + # support dynamic shape in graph mode + if args.mode == 0 and args.mixed_strategy == "mixed_video_random": + # (b c t h w), drop_remainder so bs fixed + # videos = ms.Tensor(shape=[args.batch_size, 3, None, image_size, image_size], dtype=ms.float32) + videos = ms.Tensor(shape=[None, 3, None, image_size, image_size], dtype=ms.float32) + training_step_ae.set_inputs(videos) + logger.info("Dynamic inputs are initialized for mixed_video_random training in Graph mode!") + + if rank_id == 0: + key_info = "Key Settings:\n" + "=" * 50 + "\n" + key_info += "\n".join( + [ + f"MindSpore mode[GRAPH(0)/PYNATIVE(1)]: {args.mode}", + f"Distributed mode: {args.use_parallel}", + f"amp level: {args.amp_level}", + f"dtype: {args.dtype}", + f"csv path: {args.csv_path}", + f"Video folder: {args.video_folder}", + f"Learning rate: {learning_rate}", + f"Batch size: {args.batch_size}", + f"Rescale size: {args.image_size}", + f"Weight decay: {args.weight_decay}", + f"Grad accumulation steps: {args.gradient_accumulation_steps}", + f"Num epochs: {args.epochs}", + f"Loss scaler: {args.loss_scaler_type}", + f"Init loss scale: {args.init_loss_scale}", + f"Grad clipping: {args.clip_grad}", + f"Max grad norm: {args.max_grad_norm}", + f"EMA: {args.use_ema}", + ] + ) + key_info += "\n" + "=" * 50 + logger.info(key_info) + + # 6. 
training process + use_flexible_train = False + if not use_flexible_train: + model = Model(training_step_ae) + + # callbacks + callback = [TimeMonitor(args.log_interval)] + ofm_cb = OverflowMonitor() + callback.append(ofm_cb) + + if rank_id == 0: + save_cb = EvalSaveCallback( + network=ae, + rank_id=rank_id, + ckpt_save_dir=ckpt_dir, + ema=ema, + ckpt_save_policy="latest_k", + ckpt_max_keep=args.ckpt_max_keep, + ckpt_save_interval=args.ckpt_save_interval, + log_interval=args.log_interval, + start_epoch=start_epoch, + model_name="vae_3d", + record_lr=False, + ) + callback.append(save_cb) + if args.profile: + callback.append(ProfilerCallback()) + + logger.info("Start training...") + # backup config files + shutil.copyfile(args.config, os.path.join(args.output_path, os.path.basename(args.config))) + + with open(os.path.join(args.output_path, "args.yaml"), "w") as f: + yaml.safe_dump(vars(args), stream=f, default_flow_style=False, sort_keys=False) + + model.train( + args.epochs, + dataloader, + callbacks=callback, + dataset_sink_mode=args.dataset_sink_mode, + # sink_size=args.sink_size, + initial_epoch=start_epoch, + ) + else: + if rank_id == 0: + ckpt_manager = CheckpointManager(ckpt_dir, "latest_k", k=args.ckpt_max_keep) + # output_numpy=True ? + ds_iter = dataloader.create_dict_iterator(args.epochs - start_epoch) + + for epoch in range(start_epoch, args.epochs): + start_time_e = time.time() + for step, data in enumerate(ds_iter): + start_time_s = time.time() + x = data["video"] + + global_step = epoch * dataset_size + step + global_step = ms.Tensor(global_step, dtype=ms.int64) + + # NOTE: inputs must match the order in GeneratorWithLoss.construct + loss_ae_t, overflow, scaling_sens = training_step_ae(x, global_step) + + cur_global_step = epoch * dataset_size + step + 1 # starting from 1 for logging + if overflow: + logger.warning(f"Overflow occurs in step {cur_global_step}") + + # log + step_time = time.time() - start_time_s + if step % args.log_interval == 0: + loss_ae = float(loss_ae_t.asnumpy()) + logger.info(f"E: {epoch+1}, S: {step+1}, Loss ae: {loss_ae:.4f}, Step time: {step_time*1000:.2f}ms") + + epoch_cost = time.time() - start_time_e + per_step_time = epoch_cost / dataset_size + cur_epoch = epoch + 1 + logger.info( + f"Epoch:[{int(cur_epoch):>3d}/{int(args.epochs):>3d}], " + f"epoch time:{epoch_cost:.2f}s, per step time:{per_step_time*1000:.2f}ms, " + ) + if rank_id == 0: + if (cur_epoch % args.ckpt_save_interval == 0) or (cur_epoch == args.epochs): + ckpt_name = f"vae_kl_f8-e{cur_epoch}.ckpt" + if ema is not None: + ema.swap_before_eval() + + ckpt_manager.save(ae, None, ckpt_name=ckpt_name, append_dict=None) + if ema is not None: + ema.swap_after_eval() + + # TODO: eval while training + + +if __name__ == "__main__": + args = parse_args() + main(args) From 69865d41bd713db88e9e267cd00503976456ef01 Mon Sep 17 00:00:00 2001 From: Samit <285365963@qq.com> Date: Wed, 6 Nov 2024 18:05:13 +0800 Subject: [PATCH 030/122] add train config --- .../movie_gen/configs/vae/train/video_ft.yaml | 46 +++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 examples/movie_gen/configs/vae/train/video_ft.yaml diff --git a/examples/movie_gen/configs/vae/train/video_ft.yaml b/examples/movie_gen/configs/vae/train/video_ft.yaml new file mode 100644 index 0000000000..c6e9330f36 --- /dev/null +++ b/examples/movie_gen/configs/vae/train/video_ft.yaml @@ -0,0 +1,46 @@ +# model +freeze_vae_2d: False +pretrained_model_path: "" + +# loss +perceptual_loss_weight: 0.1 +kl_loss_weight: 1.e-6 
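+# outlier penalty loss (OPL, see mg/models/tae/losses.py) penalizes latent values lying more
+# than r standard deviations from their spatial mean (r=3 in the loss code).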
+use_outlier_penalty_loss: True +mixed_strategy: "mixed_video_image" +mixed_image_ratio: 0.2 + +# data +dataset_name: "video" +csv_path: "../videocomposer/datasets/webvid5_copy.csv" +video_folder: "../videocomposer/datasets/webvid5" +frame_stride: 1 +num_frames: 16 +image_size: 256 + +# micro_frame_size: 17 +# micro_batch_size: 4 +# flip: True + +# training recipe +seed: 42 +use_discriminator: False +dtype: "bf16" +batch_size: 1 +clip_grad: True +max_grad_norm: 1.0 +start_learning_rate: 1.e-5 +scale_lr: False +weight_decay: 0. +use_recompute: True + +epochs: 400 +ckpt_save_interval: 100 +init_loss_scale: 1. + +scheduler: "constant" +use_ema: False + +output_path: "outputs/tae_train" + +# ms settting +jit_level: O1 From 7bfba9006bae876979479909b89b5c522f5c8478 Mon Sep 17 00:00:00 2001 From: samithuang <285365963@qq.com> Date: Wed, 6 Nov 2024 18:08:00 +0800 Subject: [PATCH 031/122] rename --- examples/movie_gen/configs/{vae => tae}/train/video_ft.yaml | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename examples/movie_gen/configs/{vae => tae}/train/video_ft.yaml (100%) diff --git a/examples/movie_gen/configs/vae/train/video_ft.yaml b/examples/movie_gen/configs/tae/train/video_ft.yaml similarity index 100% rename from examples/movie_gen/configs/vae/train/video_ft.yaml rename to examples/movie_gen/configs/tae/train/video_ft.yaml From 3db54f08db79ac96de307fd04535da6deb658c89 Mon Sep 17 00:00:00 2001 From: samithuang <285365963@qq.com> Date: Wed, 6 Nov 2024 18:09:48 +0800 Subject: [PATCH 032/122] rename --- .../movie_gen/scripts/{args_train_vae.py => args_train_tae.py} | 0 examples/movie_gen/scripts/{train_vae.py => train_tae.py} | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename examples/movie_gen/scripts/{args_train_vae.py => args_train_tae.py} (100%) rename examples/movie_gen/scripts/{train_vae.py => train_tae.py} (100%) diff --git a/examples/movie_gen/scripts/args_train_vae.py b/examples/movie_gen/scripts/args_train_tae.py similarity index 100% rename from examples/movie_gen/scripts/args_train_vae.py rename to examples/movie_gen/scripts/args_train_tae.py diff --git a/examples/movie_gen/scripts/train_vae.py b/examples/movie_gen/scripts/train_tae.py similarity index 100% rename from examples/movie_gen/scripts/train_vae.py rename to examples/movie_gen/scripts/train_tae.py From 5c2913fcc75e3fefc0a484d1c942e834b0a1da59 Mon Sep 17 00:00:00 2001 From: Samit <285365963@qq.com> Date: Wed, 6 Nov 2024 18:10:08 +0800 Subject: [PATCH 033/122] add dataset --- examples/movie_gen/mg/datasets/tae_dataset.py | 358 ++++++++++++++++++ 1 file changed, 358 insertions(+) create mode 100644 examples/movie_gen/mg/datasets/tae_dataset.py diff --git a/examples/movie_gen/mg/datasets/tae_dataset.py b/examples/movie_gen/mg/datasets/tae_dataset.py new file mode 100644 index 0000000000..e09310460a --- /dev/null +++ b/examples/movie_gen/mg/datasets/tae_dataset.py @@ -0,0 +1,358 @@ +import copy +import csv +import glob +import logging +import os +import random + +import albumentations +import cv2 +import imageio +import numpy as np +from decord import VideoReader + +import mindspore as ms + +logger = logging.getLogger() + + +def create_video_transforms( + size=384, crop_size=256, interpolation="bicubic", backend="al", random_crop=False, flip=False, num_frames=None +): + if backend == "al": + # expect rgb image in range 0-255, shape (h w c) + from albumentations import CenterCrop, HorizontalFlip, RandomCrop, SmallestMaxSize + + # NOTE: to ensure augment all frames in a video in the same way. 
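+        # albumentations samples one set of random parameters per call and applies it to the main
+        # "image" plus every key in `additional_targets`, e.g. Compose(..., additional_targets=
+        # {"image0": "image"})(image=frame0, image0=frame1) crops/flips both frames identically
+        # (see https://albumentations.ai/docs/examples/example_multi_target/).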
+ assert num_frames is not None, "num_frames must be parsed" + targets = {"image{}".format(i): "image" for i in range(num_frames)} + mapping = {"bilinear": cv2.INTER_LINEAR, "bicubic": cv2.INTER_CUBIC} + transforms = [ + SmallestMaxSize(max_size=size, interpolation=mapping[interpolation]), + CenterCrop(crop_size, crop_size) if not random_crop else RandomCrop(crop_size, crop_size), + ] + if flip: + transforms += [HorizontalFlip(p=0.5)] + + pixel_transforms = albumentations.Compose( + transforms, + additional_targets=targets, + ) + else: + raise NotImplementedError + + return pixel_transforms + + +def get_video_path_list(folder): + # TODO: find recursively + fmts = ["avi", "mp4", "gif"] + out = [] + for fmt in fmts: + out += glob.glob(os.path.join(folder, f"*.{fmt}")) + return sorted(out) + + +class VideoDataset: + def __init__( + self, + csv_path=None, + data_folder=None, + size=384, + crop_size=256, + random_crop=False, + flip=False, + sample_stride=4, + sample_n_frames=16, + return_image=False, + transform_backend="al", + video_column="video", + ): + """ + size: image resize size + crop_size: crop size after resize operation + """ + logger.info(f"loading annotations from {csv_path} ...") + + if csv_path is not None: + with open(csv_path, "r") as csvfile: + self.dataset = list(csv.DictReader(csvfile)) + self.read_from_csv = True + else: + self.dataset = get_video_path_list(data_folder) + self.read_from_csv = False + + self.length = len(self.dataset) + logger.info(f"Num data samples: {self.length}") + logger.info(f"sample_n_frames: {sample_n_frames}") + + self.data_folder = data_folder + self.sample_stride = sample_stride + self.sample_n_frames = sample_n_frames + self.return_image = return_image + + self.pixel_transforms = create_video_transforms( + size=size, + crop_size=crop_size, + random_crop=random_crop, + flip=flip, + num_frames=sample_n_frames, + ) + self.transform_backend = transform_backend + self.video_column = video_column + + # prepare replacement data + max_attempts = 100 + self.prev_ok_sample = self.get_replace_data(max_attempts) + self.require_update_prev = False + + def get_replace_data(self, max_attempts=100): + replace_data = None + attempts = min(max_attempts, self.length) + for idx in range(attempts): + try: + pixel_values = self.get_batch(idx) + replace_data = copy.deepcopy(pixel_values) + break + except Exception as e: + print("\tError msg: {}".format(e)) + + assert replace_data is not None, f"Fail to preload sample in {attempts} attempts." 
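+        # the preloaded sample is cached as `prev_ok_sample` and returned whenever a video later
+        # fails to decode, so a single corrupted file does not interrupt training.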
+ + return replace_data + + def get_batch(self, idx): + # get video raw pixels (batch of frame) and its caption + if self.read_from_csv: + video_dict = self.dataset[idx] + video_fn = video_dict[list(video_dict.keys())[0]] + video_path = os.path.join(self.data_folder, video_fn) + else: + video_path = self.dataset[idx] + + video_reader = VideoReader(video_path) + + video_length = len(video_reader) + + if not self.return_image: + clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1) + start_idx = random.randint(0, video_length - clip_length) + batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int) + else: + batch_index = [random.randint(0, video_length - 1)] + + if video_path.endswith(".gif"): + pixel_values = video_reader[batch_index] # shape: (f, h, w, c) + else: + pixel_values = video_reader.get_batch(batch_index).asnumpy() # shape: (f, h, w, c) + + del video_reader + + return pixel_values + + def __len__(self): + return self.length + + def __getitem__(self, idx): + """ + Returns: + video: preprocessed video frames in shape (f, c, h, w), normalized to [-1, 1] + """ + try: + pixel_values = self.get_batch(idx) + if (self.prev_ok_sample is None) or (self.require_update_prev): + self.prev_ok_sample = copy.deepcopy(pixel_values) + self.require_update_prev = False + except Exception as e: + logger.warning(f"Fail to get sample of idx {idx}. The corrupted video will be replaced.") + print("\tError msg: {}".format(e), flush=True) + assert self.prev_ok_sample is not None + pixel_values = self.prev_ok_sample # unless the first sample is already not ok + self.require_update_prev = True + + if idx >= self.length: + raise IndexError # needed for checking the end of dataset iteration + + num_frames = len(pixel_values) + # pixel value: (f, h, w, 3) -> transforms -> (f 3 h' w') + if self.transform_backend == "al": + # NOTE:it's to ensure augment all frames in a video in the same way. + # ref: https://albumentations.ai/docs/examples/example_multi_target/ + + inputs = {"image": pixel_values[0]} + for i in range(num_frames - 1): + inputs[f"image{i}"] = pixel_values[i + 1] + + output = self.pixel_transforms(**inputs) + + pixel_values = np.stack(list(output.values()), axis=0) + # (t h w c) -> (c t h w) + pixel_values = np.transpose(pixel_values, (3, 0, 1, 2)) + else: + raise NotImplementedError + + if self.return_image: + pixel_values = pixel_values[1] + + pixel_values = (pixel_values / 127.5 - 1.0).astype(np.float32) + + return pixel_values + + +# TODO: parse in config dict +def check_sanity(x, save_fp="./tmp.gif"): + # reverse normalization and visulaize the transformed video + # (c, t, h, w) -> (t, h, w, c) + if len(x.shape) == 3: + x = np.expand_dims(x, axis=0) + x = np.transpose(x, (1, 2, 3, 0)) + + x = (x + 1.0) / 2.0 # -1,1 -> 0,1 + x = (x * 255).astype(np.uint8) + + imageio.mimsave(save_fp, x, duration=1 / 8.0, loop=1) + + +class BatchTransform: + def __init__(self, mixed_strategy, mixed_image_ratio=0.2): + self.mixed_strategy = mixed_strategy + self.mixed_image_ratio = mixed_image_ratio + + def __call__(self, x): + # x: (bs, c, t, h, w) + if self.mixed_strategy == "mixed_video_image": + if random.random() < self.mixed_image_ratio: + x = x[:, :, :1, :, :] + elif self.mixed_strategy == "mixed_video_random": + # TODO: somehow it's slow. 
consider do it with tensor in NetWithLoss + length = random.randint(1, x.shape[2]) + x = x[:, :, :length, :, :] + elif self.mixed_strategy == "image_only": + x = x[:, :, :1, :, :] + else: + raise ValueError + return x + + +def create_dataloader( + ds_config, + batch_size, + mixed_strategy=None, + mixed_image_ratio=0.0, + num_parallel_workers=12, + max_rowsize=32, + shuffle=True, + device_num=1, + rank_id=0, + drop_remainder=True, +): + """ + Args: + mixed_strategy: + None - all output batches are videoes [bs, c, T, h, w] + mixed_video_image - with prob of mixed_image_ratio, output batch are images [b, c, 1, h, w] + mixed_video_random - output batch has a random number of frames [bs, c, t, h, w], t is the same of samples in a batch + mixed_image_ratio: + ds_config, dataset config, args for ImageDataset or VideoDataset + ds_name: dataset name, image or video + """ + dataset = VideoDataset(**ds_config) + print("Total number of samples: ", len(dataset)) + + # Larger value leads to more memory consumption. Default: 16 + # prefetch_size = config.get("prefetch_size", 16) + # ms.dataset.config.set_prefetch_size(prefetch_size) + + dataloader = ms.dataset.GeneratorDataset( + source=dataset, + column_names=["video"], + num_shards=device_num, + shard_id=rank_id, + python_multiprocessing=True, + shuffle=shuffle, + num_parallel_workers=num_parallel_workers, + max_rowsize=max_rowsize, + ) + + dl = dataloader.batch( + batch_size, + drop_remainder=drop_remainder, + ) + + if mixed_strategy is not None: + batch_map_fn = BatchTransform(mixed_strategy, mixed_image_ratio) + dl = dl.map( + operations=batch_map_fn, + input_columns=["video"], + num_parallel_workers=1, + ) + + return dl + + +if __name__ == "__main__": + test = "dl" + if test == "dataset": + ds_config = dict( + data_folder="../videocomposer/datasets/webvid5", + random_crop=True, + flip=True, + ) + # test source dataset + ds = VideoDataset(**ds_config) + sample = ds.__getitem__(0) + print(sample.shape) + + check_sanity(sample) + else: + import math + import time + + from tqdm import tqdm + + ds_config = dict( + csv_path="../videocomposer/datasets/webvid5_copy.csv", + data_folder="../videocomposer/datasets/webvid5", + sample_n_frames=17, + size=128, + crop_size=128, + ) + + # test loader + dl = create_dataloader( + ds_config, + 4, + mixed_strategy="mixed_video_random", + mixed_image_ratio=0.2, + ) + + num_batches = dl.get_dataset_size() + # ms.set_context(mode=0) + print(num_batches) + + steps = 50 + iterator = dl.create_dict_iterator(100) # create 100 repeats + tot = 0 + + progress_bar = tqdm(range(steps)) + progress_bar.set_description("Steps") + + start = time.time() + for epoch in range(math.ceil(steps / num_batches)): + for i, batch in enumerate(iterator): + print("epoch", epoch, "step", i) + dur = time.time() - start + tot += dur + + if epoch * num_batches + i < 50: + for k in batch: + print(k, batch[k].shape, batch[k].dtype) # , batch[k].min(), batch[k].max()) + print(f"time cost: {dur * 1000} ms") + + progress_bar.update(1) + if i + 1 > steps: # in case the data size is too large + break + start = time.time() + + mean = tot / steps + print("Avg batch loading time: ", mean) From 4b706fb70d06c8d90a0832f4bfddfdfa13efa3e3 Mon Sep 17 00:00:00 2001 From: samithuang <285365963@qq.com> Date: Thu, 7 Nov 2024 14:44:04 +0800 Subject: [PATCH 034/122] trainable --- examples/movie_gen/mg/models/tae/lpips.py | 2 +- examples/movie_gen/mg/models/tae/tae.py | 2 +- examples/movie_gen/scripts/args_train_tae.py | 2 +- examples/movie_gen/scripts/train_tae.py | 17 
++++++----------- .../opensora_hpcai/scripts/inference_vae.py | 3 +++ 5 files changed, 12 insertions(+), 14 deletions(-) diff --git a/examples/movie_gen/mg/models/tae/lpips.py b/examples/movie_gen/mg/models/tae/lpips.py index ca1fbb4442..be34fb4b6b 100644 --- a/examples/movie_gen/mg/models/tae/lpips.py +++ b/examples/movie_gen/mg/models/tae/lpips.py @@ -2,7 +2,7 @@ import os import mindcv -from opensora.utils.load_models import load_from_pretrained +from mg.utils.load_models import load_from_pretrained import mindspore as ms import mindspore.nn as nn diff --git a/examples/movie_gen/mg/models/tae/tae.py b/examples/movie_gen/mg/models/tae/tae.py index 3a11118903..26a0a6943e 100644 --- a/examples/movie_gen/mg/models/tae/tae.py +++ b/examples/movie_gen/mg/models/tae/tae.py @@ -96,5 +96,5 @@ def construct(self, x: ms.Tensor) -> ms.Tensor: if self.discard_spurious_frames and (recons.shape[-3] != x.shape[-3]): recons = recons[:, :, :x.shape[-3], :, :] - return recons, posterior_mean, posterior_logvar + return recons, z, posterior_mean, posterior_logvar diff --git a/examples/movie_gen/scripts/args_train_tae.py b/examples/movie_gen/scripts/args_train_tae.py index 23d14c0408..801286d8b0 100644 --- a/examples/movie_gen/scripts/args_train_tae.py +++ b/examples/movie_gen/scripts/args_train_tae.py @@ -9,7 +9,7 @@ mindone_lib_path = os.path.abspath(os.path.join(__dir__, "../../../")) sys.path.insert(0, mindone_lib_path) -from opensora.utils.model_utils import _check_cfgs_in_parser, str2bool +from mg.utils.parser import _check_cfgs_in_parser, str2bool from mindone.utils.misc import to_abspath diff --git a/examples/movie_gen/scripts/train_tae.py b/examples/movie_gen/scripts/train_tae.py index 60565a8230..a5f1aa0b2f 100644 --- a/examples/movie_gen/scripts/train_tae.py +++ b/examples/movie_gen/scripts/train_tae.py @@ -18,10 +18,10 @@ sys.path.insert(0, mindone_lib_path) sys.path.insert(0, os.path.abspath(os.path.join(__dir__, ".."))) -from args_train_vae import parse_args -from mg.datasets.vae_dataset import create_dataloader +from args_train_tae import parse_args +from mg.datasets.tae_dataset import create_dataloader from mg.models.tae.losses import GeneratorWithLoss -from mg.models.tae.tae import TemporalAutoEncoder +from mg.models.tae.tae import TemporalAutoencoder from mindone.trainers.callback import EvalSaveCallback, OverflowMonitor, ProfilerCallback from mindone.trainers.checkpoint import CheckpointManager, resume_train_network @@ -191,13 +191,8 @@ def main(args): logger.info(f"Num batches: {dataset_size}") # 3. 
build models - ae = TemporalAutoEncoder( - micro_batch_size=args.micro_batch_size, - micro_frame_size=args.micro_frame_size, - ckpt_path=args.pretrained_model_path, - freeze_vae_2d=args.freeze_vae_2d, - cal_loss=True, - use_recompute=args.use_recompute, + ae = TemporalAutoencoder( + pretrained=args.pretrained_model_path, ) if args.use_discriminator: @@ -219,7 +214,7 @@ def main(args): ae, kl_weight=args.kl_loss_weight, perceptual_weight=args.perceptual_loss_weight, - use_image_identity_loss=args.use_outlier_penalty_loss, + use_outlier_penalty_loss=args.use_outlier_penalty_loss, dtype=args.dtype, ) diff --git a/examples/opensora_hpcai/scripts/inference_vae.py b/examples/opensora_hpcai/scripts/inference_vae.py index 264982d9ef..8fc76cad14 100644 --- a/examples/opensora_hpcai/scripts/inference_vae.py +++ b/examples/opensora_hpcai/scripts/inference_vae.py @@ -190,6 +190,9 @@ def main(args): mean_ssim += sum(ssim_cur) num_samples += x_rgb.shape[0] + logger.info(f"cur psnr: {psnr_cur[-1]:.4f}, mean psnr:{mean_psnr/num_samples:.4f}") + logger.info(f"cur ssim: {ssim_cur[-1]:.4f}, mean ssim:{mean_ssim/num_samples:.4f}") + if args.eval_loss: recon_loss = np.abs((x - recons).asnumpy()) lpips_loss = lpips_loss_fn(x, recons).asnumpy() From 9557e5970305c4aadf6c25257fe6647b4d60e008 Mon Sep 17 00:00:00 2001 From: Rustam Khadipash <16683750+hadipash@users.noreply.github.com> Date: Tue, 5 Nov 2024 15:23:39 +0800 Subject: [PATCH 035/122] add inference --- .../inference/moviegen_t2i_256x256.yaml | 33 +++ examples/moviegen/inference.py | 215 ++++++++++++++++++ .../moviegen/moviegen/pipelines/__init__.py | 1 + .../moviegen/pipelines/infer_pipeline.py | 92 ++++++++ .../moviegen/schedulers/rectified_flow.py | 16 +- examples/moviegen/moviegen/utils/__init__.py | 1 - examples/moviegen/moviegen/utils/misc.py | 13 -- .../moviegen/moviegen/utils/model_utils.py | 37 ++- examples/moviegen/train.py | 42 +--- .../opensora/pipelines/infer_pipeline.py | 5 +- 10 files changed, 398 insertions(+), 57 deletions(-) create mode 100644 examples/moviegen/configs/inference/moviegen_t2i_256x256.yaml create mode 100644 examples/moviegen/inference.py create mode 100644 examples/moviegen/moviegen/pipelines/infer_pipeline.py delete mode 100644 examples/moviegen/moviegen/utils/misc.py diff --git a/examples/moviegen/configs/inference/moviegen_t2i_256x256.yaml b/examples/moviegen/configs/inference/moviegen_t2i_256x256.yaml new file mode 100644 index 0000000000..18e80a066e --- /dev/null +++ b/examples/moviegen/configs/inference/moviegen_t2i_256x256.yaml @@ -0,0 +1,33 @@ +env: + mode: 0 + jit_level: O0 + seed: 42 + distributed: False + debug: False + +model: + name: llama-1B + pretrained_model_path: + enable_flash_attention: True + dtype: bf16 + +vae: + ckpt_path: models/OpenSora-VAE-v1.2/model.ckpt + dtype: fp16 + +# Inference parameters +num_sampling_steps: 50 +sample_method: linear-quadratic +image_size: [ 256, 256 ] +num_frames: 1 # image +text_emb: + ul2_dir: + metaclip_dir: + byt5_dir: +batch_size: 10 + +# Saving options +output_path: samples +append_timestamp: True +save_format: png +save_latent: False diff --git a/examples/moviegen/inference.py b/examples/moviegen/inference.py new file mode 100644 index 0000000000..808d2d85b2 --- /dev/null +++ b/examples/moviegen/inference.py @@ -0,0 +1,215 @@ +import datetime +import glob +import logging +import os +import sys +import time +from typing import List, Tuple + +import numpy as np +from jsonargparse import ActionConfigFile, ArgumentParser +from jsonargparse.typing import path_type + 
+import mindspore as ms +from mindspore import Tensor, amp, nn + +# TODO: remove in future when mindone is ready for install +__dir__ = os.path.dirname(os.path.abspath(__file__)) +mindone_lib_path = os.path.abspath(os.path.join(__dir__, "../../")) +sys.path.append(mindone_lib_path) + +from moviegen.pipelines import InferPipeline +from moviegen.utils.model_utils import MODEL_DTYPE, init_model + +from mindone.utils import init_train_env, set_logger +from mindone.visualize.videos import save_videos + +# TODO: remove when VAE is added to the project +sys.path.append(os.path.join(__dir__, "../opensora_hpcai/")) +from opensora.models.vae.vae import OpenSoraVAE_V1_2 + +logger = logging.getLogger(__name__) + +Path_dr = path_type("dr", docstring="path to a directory that exists and is readable") + + +def to_numpy(x: Tensor) -> np.ndarray: + if x.dtype == ms.bfloat16: + x = x.astype(ms.float32) + return x.asnumpy() + + +def prepare_captions( + ul2_dir: Path_dr, metaclip_dir: Path_dr, byt5_dir: Path_dr, rank_id: int, device_num: int +) -> Tuple[List[str], List[str], List[str]]: + ul2_emb = sorted(glob.glob(os.path.join(ul2_dir, "*.npz"))) + metaclip_emb = sorted(glob.glob(os.path.join(metaclip_dir, "*.npz"))) + byt5_emb = sorted(glob.glob(os.path.join(byt5_dir, "*.npz"))) + if len(ul2_emb) != len(metaclip_emb) or len(ul2_emb) != len(byt5_emb): + raise ValueError( + f"ul2_dir ({len(ul2_emb)}), metaclip_dir ({len(metaclip_emb)})," + f" and byt5_dir ({len(byt5_emb)}) must contain the same number of files" + ) + ul2_emb = ul2_emb[rank_id::device_num] + logger.info(f"Number of captions for rank {rank_id}: {len(ul2_emb)}") + return ul2_emb, metaclip_emb[rank_id::device_num], byt5_emb[rank_id::device_num] + + +def main(args): + # TODO: CFG error + save_dir = os.path.join(__dir__, args.output_path.relative) + if args.append_timestamp: + time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") + save_dir = os.path.join(save_dir, time_str) + os.makedirs(save_dir, exist_ok=True) + set_logger(name="", output_dir=save_dir) + + latent_dir = os.path.join(save_dir, "denoised_latents") + if args.save_latent: + os.makedirs(latent_dir, exist_ok=True) + + # 1. init env + _, rank_id, device_num = init_train_env(**args.env) # TODO: rename as train and infer are identical? + + # 1.1 read caption embeddings + ul2_emb, metaclip_emb, byt5_emb = prepare_captions(**args.text_emb, rank_id=rank_id, device_num=device_num) + + # 2. 
model initiate and weight loading + # 2.1 vae + logger.info("vae init") + vae_args = args.vae.as_dict() + vae_dtype = vae_args.pop("dtype") + vae = OpenSoraVAE_V1_2(**vae_args).set_train(False) + if vae_dtype != "fp32": + # FIXME: remove AMP and add custom dtype conversion support for better compatibility with PyNative + amp.custom_mixed_precision(vae, black_list=amp.get_black_list() + [nn.GroupNorm], dtype=MODEL_DTYPE[vae_dtype]) + + img_h, img_w = args.image_size if isinstance(args.image_size, list) else (args.image_size, args.image_size) + num_frames = args.num_frames + latent_size = vae.get_latent_size((num_frames, img_h, img_w)) + + # 2.2 Llama 3 + model = init_model(in_channels=vae.out_channels, **args.model).set_train(False) + + # 2.3 text embeddings + prompt_prefix = [os.path.basename(emb)[:-4] for emb in ul2_emb] + ul2_emb = ms.Tensor([np.load(emb)["text_emb"] for emb in ul2_emb], dtype=ms.float32) + metaclip_emb = ms.Tensor([np.load(emb)["text_emb"] for emb in metaclip_emb], dtype=ms.float32) + byt5_emb = ms.Tensor([np.load(emb)["text_emb"] for emb in byt5_emb], dtype=ms.float32) + num_prompts = ul2_emb.shape[0] + + # 3. build inference pipeline + pipeline = InferPipeline( + model, + vae, + latent_size, + scale_factor=args.scale_factor, # FIXME: refactor + guidance_scale=args.guidance_scale, + num_sampling_steps=args.num_sampling_steps, + sample_method=args.sample_method, + micro_batch_size=args.micro_batch_size, + ) + + # 4. print key info + key_info = "Key Settings:\n" + "=" * 50 + "\n" + key_info += "\n".join( + [ + f"MindSpore mode[GRAPH(0)/PYNATIVE(1)]: {args.env.mode}", + f"Num of captions: {num_prompts}", + f"Model dtype: {args.model.dtype}", + f"VAE dtype: {vae_dtype}", + f"Image size: {(img_h, img_w)}", + f"Num frames: {num_frames}", + f"Sampling steps {args.num_sampling_steps}", + f"CFG guidance scale: {args.guidance_scale}", + ] + ) + key_info += "\n" + "=" * 50 + logger.info(key_info) + + for i in range(0, num_prompts, args.batch_size): + end_i = min(i + args.batch_size, num_prompts) + logger.info("Sampling captions:") + for j in range(i, end_i): + logger.info(prompt_prefix[j]) + + # infer + start_time = time.perf_counter() + sample, latent = pipeline( + ul2_emb=ul2_emb[i:end_i], + metaclip_emb=metaclip_emb[i:end_i], + byt5_emb=byt5_emb[i:end_i], + num_frames=num_frames, + ) + batch_time = time.perf_counter() - start_time + logger.info( + f"Batch time cost: {batch_time:.3f}s," + f" sampling speed: {args.num_sampling_steps * (end_i - i) / batch_time:.2f} step/s" + ) + + # save result + for j in range(0, end_i - i): + fn = prompt_prefix[i + j] + save_fp = f"{save_dir}/{fn}.{args.save_format}" + latent_save_fp = f"{latent_dir}/{fn}.npy" + + # save videos + if sample is not None: + save_videos(to_numpy(sample[j]), save_fp, fps=args.fps) + logger.info(f"Video saved in {save_fp}") + # save decoded latents + if args.save_latent: + np.save(latent_save_fp, to_numpy(latent[j : j + 1])) + logger.info(f"Denoised latents saved in {latent_save_fp}") + + +if __name__ == "__main__": + parser = ArgumentParser(description="Movie Gen inference script.") + parser.add_argument( + "-c", + "--config", + action=ActionConfigFile, + help="Path to load a config yaml file that describes the setting which will override the default arguments.", + ) + parser.add_function_arguments(init_train_env, "env") + parser.add_function_arguments(init_model, "model", skip={"in_channels"}) + vae_group = parser.add_argument_group("VAE parameters") + vae_group.add_function_arguments(OpenSoraVAE_V1_2, "vae", 
fail_untyped=False) + vae_group.add_argument( + "--vae.dtype", default="fp32", type=str, choices=["fp32", "fp16", "bf16"], help="VAE model precision." + ) + infer_group = parser.add_argument_group("Inference parameters") + infer_group.add_class_arguments(InferPipeline, skip={"model", "vae", "latent_size"}, instantiate=False) + infer_group.add_argument("--image_size", type=int, nargs="+", help="Output video size") + infer_group.add_argument("--num_frames", type=int, default=17, help="number of frames") + infer_group.add_argument("--fps", type=int, default=16, help="FPS in the saved video") + infer_group.add_function_arguments(prepare_captions, "text_emb", skip={"rank_id", "device_num"}) + infer_group.add_argument("--batch_size", type=int, default=1) + save_group = parser.add_argument_group("Saving options") + save_group.add_argument( + "--save_format", + default="mp4", + choices=["gif", "mp4", "png"], + type=str, + help="video format for saving the sampling output: gif, mp4 or png", + ) + save_group.add_argument( + "--output_path", + default="output/", + type=path_type("dcc"), # path to a directory that can be created if it does not exist + help="Output directory to save training results.", + ) + save_group.add_argument( + "--append_timestamp", + type=bool, + default=True, + help="If true, a subfolder named with timestamp under output_path will be created to save the sampling results", + ) + save_group.add_argument( + "--save_latent", + type=bool, + default=False, + help="Save denoised video latent. If True, the denoised latents will be saved in $output_path/denoised_latents", + ) + cfg = parser.parse_args() + main(cfg) diff --git a/examples/moviegen/moviegen/pipelines/__init__.py b/examples/moviegen/moviegen/pipelines/__init__.py index 8cf855d610..93ba177d16 100644 --- a/examples/moviegen/moviegen/pipelines/__init__.py +++ b/examples/moviegen/moviegen/pipelines/__init__.py @@ -1 +1,2 @@ +from .infer_pipeline import InferPipeline from .train_pipeline import DiffusionWithLoss diff --git a/examples/moviegen/moviegen/pipelines/infer_pipeline.py b/examples/moviegen/moviegen/pipelines/infer_pipeline.py new file mode 100644 index 0000000000..e18e09c0e2 --- /dev/null +++ b/examples/moviegen/moviegen/pipelines/infer_pipeline.py @@ -0,0 +1,92 @@ +from typing import Literal, Optional, Tuple, Union + +import numpy as np + +import mindspore as ms +from mindspore import Tensor, mint, nn, ops + +from ..schedulers.rectified_flow import RFLOW + +__all__ = ["InferPipeline"] + + +class InferPipeline: + """An Inference pipeline for diffusion model + + Args: + model (nn.Cell): A noise prediction model to denoise the encoded image latents. + vae (nn.Cell): Variational Auto-Encoder (VAE) Model to encode and decode images or videos to and from latent representations. + scale_factor (float): scale_factor for vae. + guidance_scale (float): A higher guidance scale value for noise rescale. + num_sampling_steps: (int): The number of denoising steps. 
+ """ + + def __init__( + self, + model: nn.Cell, + vae: nn.Cell, + latent_size: Tuple[int, int, int] = (1, 64, 64), + scale_factor: float = 1.0, + guidance_scale: float = 1.0, + num_sampling_steps: int = 50, + sample_method: Literal["linear", "linear-quadratic"] = "linear", + micro_batch_size: Optional[int] = None, + ): + super().__init__() + self.model = model + self.vae = vae + self.latent_size = latent_size + self.micro_batch_size = micro_batch_size + self.scale_factor = scale_factor + self.guidance_rescale = guidance_scale + self.use_cfg = guidance_scale > 1.0 + self.rflow = RFLOW(num_sampling_steps, sample_method=sample_method) + + def vae_decode_video(self, x, num_frames=None): + """ + Args: + x: (b t c h w), denoised latent + Return: + y: (b f H W 3), batch of images, normalized to [0, 1] + """ + x = mint.permute(x, (0, 2, 1, 3, 4)) # FIXME: remove this redundancy + y = self.vae.decode(x, num_frames=num_frames) # FIXME: extract scale_factor from VAE and use it here + y = ops.clip_by_value((y + 1.0) / 2.0, clip_value_min=0.0, clip_value_max=1.0) + # (b 3 t h w) -> (b t h w 3) + y = mint.permute(y, (0, 2, 3, 4, 1)) + return y + + def __call__( + self, ul2_emb: Tensor, metaclip_emb: Tensor, byt5_emb: Tensor, num_frames: int = None + ) -> Tuple[Union[Tensor, None], Tensor]: + """ + args: + inputs: dict + + return: + images (b H W 3) + """ + z = ms.Tensor( + np.random.randn( + ul2_emb.shape[0], self.latent_size[0], self.vae.out_channels, self.latent_size[1], self.latent_size[2] + ).astype(np.float32), + dtype=self.model.dtype, + ) + if self.use_cfg: + raise NotImplementedError("Condition-free guidance is not supported yet.") + + latents = self.rflow( + self.model, + z, + ul2_emb.to(self.model.dtype), + metaclip_emb.to(self.model.dtype), + byt5_emb.to(self.model.dtype), + ).to(ms.float32) + + if self.vae is not None: + # latents: (b t c h w) + # out: (b T H W C) + images = self.vae_decode_video(latents, num_frames=num_frames) + return images, latents + else: + return None, latents diff --git a/examples/moviegen/moviegen/schedulers/rectified_flow.py b/examples/moviegen/moviegen/schedulers/rectified_flow.py index 2c58196237..7768f75022 100644 --- a/examples/moviegen/moviegen/schedulers/rectified_flow.py +++ b/examples/moviegen/moviegen/schedulers/rectified_flow.py @@ -1,4 +1,5 @@ import logging +from math import ceil from typing import Literal, Optional, Tuple import numpy as np @@ -50,7 +51,7 @@ def __init__( self.num_timesteps = num_timesteps self.sample_method = sample_method - def __call__(self, model: nn.Cell, x: Tensor, text_embedding: Tensor) -> Tensor: + def __call__(self, model: nn.Cell, x: Tensor, ul2_emb: Tensor, metaclip_emb: Tensor, byt5_emb: Tensor) -> Tensor: """ x: (N, T, C, H, W) tensor of inputs (latent representations of video) text_embedding: (N, L, C') tensor of the text embedding @@ -59,13 +60,18 @@ def __call__(self, model: nn.Cell, x: Tensor, text_embedding: Tensor) -> Tensor: if self.sample_method == "linear": timesteps = (1.0 - np.arange(self.num_sampling_steps) / self.num_sampling_steps) * self.num_timesteps else: - raise NotImplementedError("Not supported yet.") + first_half = ceil(self.num_sampling_steps / 2) + second_half = self.num_sampling_steps - first_half # in the case of an odd number of sampling steps + linear = np.arange(first_half, 0, -1) + quadratic = (np.arange(second_half, 0, -1) ** 2) / (second_half**2) + quadratic = (self.num_timesteps - first_half) * quadratic + first_half # scale and shift + timesteps = np.concatenate([quadratic, linear]) - 
timesteps = np.tile(timesteps[None, ...], (x.shape[0], 1)) - timesteps = Tensor(timesteps, dtype=ms.int64) + timesteps = np.tile(timesteps[..., None], (1, x.shape[0])) + timesteps = Tensor(timesteps, dtype=model.dtype) # FIXME: avoid calculations on tensors outside `construct` for i, timestep in tqdm(enumerate(timesteps), total=self.num_sampling_steps): - pred = model(x, timestep, text_embedding) + pred = model(x, timestep, ul2_emb, metaclip_emb, byt5_emb) # update z dt = timesteps[i] - timesteps[i + 1] if i < len(timesteps) - 1 else timesteps[i] diff --git a/examples/moviegen/moviegen/utils/__init__.py b/examples/moviegen/moviegen/utils/__init__.py index 01392be094..33a0e314cd 100644 --- a/examples/moviegen/moviegen/utils/__init__.py +++ b/examples/moviegen/moviegen/utils/__init__.py @@ -1,3 +1,2 @@ from .ema import * -from .misc import * from .model_utils import * diff --git a/examples/moviegen/moviegen/utils/misc.py b/examples/moviegen/moviegen/utils/misc.py deleted file mode 100644 index 4235d3f807..0000000000 --- a/examples/moviegen/moviegen/utils/misc.py +++ /dev/null @@ -1,13 +0,0 @@ -from moviegen.models import llama3_1B, llama3_5B, llama3_30B - -import mindspore as ms - -__all__ = ["MODEL_SPEC", "MODEL_DTYPE"] - -MODEL_SPEC = {"llama-1B": llama3_1B, "llama-5B": llama3_5B, "llama-30B": llama3_30B} - -MODEL_DTYPE = { - "fp32": ms.float32, - "fp16": ms.float16, - "bf16": ms.bfloat16, -} diff --git a/examples/moviegen/moviegen/utils/model_utils.py b/examples/moviegen/moviegen/utils/model_utils.py index 1852643255..6f687790c3 100644 --- a/examples/moviegen/moviegen/utils/model_utils.py +++ b/examples/moviegen/moviegen/utils/model_utils.py @@ -1,13 +1,24 @@ import logging -from typing import Dict, Union +from typing import Dict, Literal, Optional, Union + +from jsonargparse.typing import Path_fr +from moviegen import LlamaModel, llama3_1B, llama3_5B, llama3_30B import mindspore as ms from mindspore import _no_grad, jit_class, nn -__all__ = ["load_ckpt_params", "no_grad"] +__all__ = ["MODEL_DTYPE", "load_ckpt_params", "no_grad", "init_model"] logger = logging.getLogger(__name__) +MODEL_SPEC = {"llama-1B": llama3_1B, "llama-5B": llama3_5B, "llama-30B": llama3_30B} + +MODEL_DTYPE = { + "fp32": ms.float32, + "fp16": ms.float16, + "bf16": ms.bfloat16, +} + def load_ckpt_params(model: nn.Cell, ckpt: Union[str, Dict]) -> nn.Cell: if isinstance(ckpt, str): @@ -43,3 +54,25 @@ def __enter__(self): def __exit__(self, *args): if self._pynative: super().__exit__(*args) + + +def init_model( + name: Literal["llama-1B", "llama-5B", "llama-30B"], + in_channels: int = 4, + pretrained_model_path: Optional[Path_fr] = None, + enable_flash_attention: bool = True, + recompute: bool = False, + dtype: Literal["fp32", "fp16", "bf16"] = "fp32", +) -> LlamaModel: + attn_implementation = "flash_attention" if enable_flash_attention else "eager" + model = MODEL_SPEC[name]( + in_channels=in_channels, + attn_implementation=attn_implementation, + gradient_checkpointing=recompute, + dtype=MODEL_DTYPE[dtype], + ) + if pretrained_model_path: + model = load_ckpt_params(model, pretrained_model_path.absolute) + else: + logger.info(f"Initialize {name} model randomly.") + return model diff --git a/examples/moviegen/train.py b/examples/moviegen/train.py index 21ab7436bb..2c1639de69 100644 --- a/examples/moviegen/train.py +++ b/examples/moviegen/train.py @@ -1,10 +1,9 @@ import logging import os import sys -from typing import Literal, Optional from jsonargparse import ActionConfigFile, ArgumentParser -from jsonargparse.typing 
import Path_fr, path_type +from jsonargparse.typing import path_type from mindspore import Model, amp, nn from mindspore.train.callback import TimeMonitor @@ -12,13 +11,13 @@ # TODO: remove in future when mindone is ready for install __dir__ = os.path.dirname(os.path.abspath(__file__)) mindone_lib_path = os.path.abspath(os.path.join(__dir__, "../../")) -sys.path.insert(0, mindone_lib_path) +sys.path.append(mindone_lib_path) from moviegen.dataset import ImageVideoDataset -from moviegen.models.llama import LlamaModel from moviegen.pipelines import DiffusionWithLoss from moviegen.schedulers import RFlowLossWrapper -from moviegen.utils import EMA, MODEL_DTYPE, MODEL_SPEC, load_ckpt_params +from moviegen.utils import EMA +from moviegen.utils.model_utils import MODEL_DTYPE, init_model from mindone.data import create_dataloader from mindone.trainers import create_optimizer @@ -32,30 +31,6 @@ logger = logging.getLogger(__name__) -Path_dcc = path_type("dcc") # path to a directory that can be created if it does not exist - - -def init_model( - name: Literal["llama-1B", "llama-5B", "llama-30B"], - in_channels: int = 4, - pretrained_model_path: Optional[Path_fr] = None, - enable_flash_attention: bool = True, - recompute: bool = False, - dtype: Literal["fp32", "fp16", "bf16"] = "fp32", -) -> LlamaModel: - attn_implementation = "flash_attention" if enable_flash_attention else "eager" - model = MODEL_SPEC[name]( - in_channels=in_channels, - attn_implementation=attn_implementation, - gradient_checkpointing=recompute, - dtype=MODEL_DTYPE[dtype], - ) - if pretrained_model_path: - model = load_ckpt_params(model, pretrained_model_path) - else: - logger.info("Initialize network randomly.") - return model - def main(args): # 1. init env @@ -75,9 +50,8 @@ def main(args): vae_dtype = vae_args.pop("dtype") vae = OpenSoraVAE_V1_2(**vae_args).set_train(False) if vae_dtype != "fp32": - vae_dtype = MODEL_DTYPE[vae_dtype] # FIXME: remove AMP and add custom dtype conversion support for better compatibility with PyNative - amp.custom_mixed_precision(vae, black_list=amp.get_black_list() + [nn.GroupNorm], dtype=vae_dtype) + amp.custom_mixed_precision(vae, black_list=amp.get_black_list() + [nn.GroupNorm], dtype=MODEL_DTYPE[vae_dtype]) # 2.2 Llama 3 network = init_model(in_channels=vae.out_channels, **args.model) @@ -114,7 +88,6 @@ def main(args): # 5.4 callbacks callbacks = [OverflowMonitor()] - if rank_id == 0: callbacks.extend( [ @@ -205,7 +178,10 @@ def main(args): ) parser.add_subclass_arguments(EMA, "train.ema", skip={"network"}, required=False, instantiate=False) parser.add_argument( - "--train.output_path", default="output/", type=Path_dcc, help="Output directory to save training results." + "--train.output_path", + default="output/", + type=path_type("dcc"), # path to a directory that can be created if it does not exist + help="Output directory to save training results.", ) parser.add_argument("--train.epochs", default=10, type=int, help="Number of epochs to train. 
Default: 100.") parser.add_class_arguments( diff --git a/examples/opensora_hpcai/opensora/pipelines/infer_pipeline.py b/examples/opensora_hpcai/opensora/pipelines/infer_pipeline.py index c123e9174d..fc2c02e9b2 100644 --- a/examples/opensora_hpcai/opensora/pipelines/infer_pipeline.py +++ b/examples/opensora_hpcai/opensora/pipelines/infer_pipeline.py @@ -58,12 +58,11 @@ def __init__( self.use_cfg = False self.text_encoder = text_encoder - self.diffusion = create_diffusion(str(num_inference_steps)) if sampling.lower() == "ddim": - self.sampling_func = self.diffusion.ddim_sample_loop + self.sampling_func = create_diffusion(str(num_inference_steps)).ddim_sample_loop elif sampling.lower() == "ddpm": - self.sampling_func = self.diffusion.p_sample_loop + self.sampling_func = create_diffusion(str(num_inference_steps)).p_sample_loop elif sampling.lower() == "rflow": self.sampling_func = RFLOW(num_inference_steps, cfg_scale=guidance_rescale, use_timestep_transform=True) else: From 602d00ee4db538f182680473dbf2d49a59f0717e Mon Sep 17 00:00:00 2001 From: samithuang <285365963@qq.com> Date: Thu, 7 Nov 2024 18:40:27 +0800 Subject: [PATCH 036/122] fix opl loss --- examples/movie_gen/mg/models/tae/losses.py | 11 +++--- examples/movie_gen/mg/models/tae/modules.py | 40 +++++++++++++++++++-- examples/movie_gen/mg/models/tae/tae.py | 16 +++++++++ examples/movie_gen/scripts/train_tae.py | 1 + examples/movie_gen/tests/test_tae.py | 11 +++--- 5 files changed, 67 insertions(+), 12 deletions(-) diff --git a/examples/movie_gen/mg/models/tae/losses.py b/examples/movie_gen/mg/models/tae/losses.py index dfbc959492..0a97e3e9b3 100644 --- a/examples/movie_gen/mg/models/tae/losses.py +++ b/examples/movie_gen/mg/models/tae/losses.py @@ -124,15 +124,18 @@ def construct(self, x: ms.Tensor, global_step: ms.Tensor = -1, weights: ms.Tenso if self.use_outlier_penalty_loss and self.opl_weight > 0: # (b c t h w) -> (b*t c h w) + # import pdb; pdb.set_trace() z = _rearrange_in(z) z_mean = ops.mean(z, axis=(-1, -2), keep_dims=True) - z_std = ops.std(z, axis=(-1, -2), keep_dims=True) + z_std = ops.std(z, axis=(-1, -2), keepdims=True) std_scale = 3 # r=3 - opl_loss = ops.max((ops.abs(z - z_mean) - std_scale * z_std), 0) - opl_loss = ops.mean(opl_loss) + # opl_loss = ops.max((ops.abs(z - z_mean) - std_scale * z_std), 0) + outlier_penalty = ops.abs(z - z_mean) - std_scale * z_std + outlier_penalty = ops.where(outlier_penalty > 0, outlier_penalty, 0) + opl_loss = ops.mean(outlier_penalty) - loss += self.opl_weight + opl_loss + loss += self.opl_weight * opl_loss return loss diff --git a/examples/movie_gen/mg/models/tae/modules.py b/examples/movie_gen/mg/models/tae/modules.py index c1e2a4a8ad..5faabf8efd 100644 --- a/examples/movie_gen/mg/models/tae/modules.py +++ b/examples/movie_gen/mg/models/tae/modules.py @@ -1,4 +1,5 @@ import logging +import functools import numpy as np from packaging import version @@ -117,6 +118,7 @@ def construct(self, x): return x + class Conv2_5d(nn.Cell): r""" Conv2.5d, a 2D spatial convolution followed by 1D temporal convolution @@ -138,13 +140,44 @@ def __init__( assert dilation==1 # spatial conv self.conv_spat = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, pad_mode=pad_mode, padding=padding, has_bias=has_bias) + + # temp_pad_mode = 'zero' + # temp_pad = 'mint_rep' + temp_pad = 'manual' + # temporal conv if kernel_size > 1: - self.pad = nn.Pad(paddings=((0, 0), (0, 0), ((kernel_size-1)//2, (kernel_size-1)//2)), mode='SYMMETRIC') - self.use_pad = True + # FIXME: debugging to see 
how symmetric padding influences performance
+            if temp_pad == 'zero':
+                self.conv_temp = nn.Conv1d(out_channels, out_channels, kernel_size=kernel_size, stride=stride, pad_mode="same", has_bias=has_bias)
+                self.use_pad = False
+                self.pad = nn.Identity()
+            elif temp_pad == 'mint_rep':
+                assert kernel_size == 3
+                self.conv_temp = nn.Conv1d(out_channels, out_channels, kernel_size=kernel_size, stride=stride, pad_mode="valid", has_bias=has_bias)
+                self.pad = nn.ReplicationPad1d(((kernel_size-1)//2, (kernel_size-1)//2))
+                self.use_pad = True
+            elif temp_pad == 'manual':
+                assert kernel_size == 3, 'symmetric padding currently only supports kernel size 3'
+                self.conv_temp = nn.Conv1d(out_channels, out_channels, kernel_size=kernel_size, stride=stride, pad_mode="valid", has_bias=has_bias)
+                self.pad = self.symmetric_pad1d
+                self.use_pad = True
+            elif temp_pad == 'nn_pad':
+                self.pad = nn.Pad(paddings=((0, 0), (0, 0), ((kernel_size-1)//2, (kernel_size-1)//2)), mode='SYMMETRIC')
+                self.use_pad = True
+
         else:
             self.use_pad = False
-            self.conv_temp = nn.Conv1d(out_channels, out_channels, kernel_size=kernel_size, stride=stride, pad_mode="valid", has_bias=has_bias)
+            self.conv_temp = nn.Conv1d(out_channels, out_channels, kernel_size=kernel_size, stride=stride, pad_mode="valid", has_bias=has_bias)
+
+    @staticmethod
+    def symmetric_pad1d(x):
+        first_frame = x[:, :, :1]
+        last_frame = x[:, :, -1:]
+        # last_frame_pad = ops.cat([last_frame] * self.time_pad, axis=2)
+        x = ops.concat((first_frame, x, last_frame), axis=2)
+
+        return x
 
     def construct(self, x):
         '''
@@ -174,6 +207,7 @@ def construct(self, x):
         x = ops.reshape(x, (B*Ho*Wo, Co, T))
 
         if self.use_pad:
+            # import pdb; pdb.set_trace()
             x = self.pad(x)
 
         x = self.conv_temp(x)
diff --git a/examples/movie_gen/mg/models/tae/tae.py b/examples/movie_gen/mg/models/tae/tae.py
index 26a0a6943e..2915676481 100644
--- a/examples/movie_gen/mg/models/tae/tae.py
+++ b/examples/movie_gen/mg/models/tae/tae.py
@@ -30,6 +30,7 @@ def __init__(
         self,
         config: dict = SDXL_CONFIG,
         pretrained: str = None,
+        use_recompute: bool=False,
     ):
         super().__init__()
 
@@ -51,6 +52,21 @@ def __init__(
         self.sample_deterministic = False
         self.discard_spurious_frames = True
 
+        if use_recompute:
+            # self.recompute(self.encoder)
+            # self.recompute(self.quant_conv)
+            # self.recompute(self.post_quant_conv)
+            self.recompute(self.decoder)
+
+
+    def recompute(self, b):
+        if not b._has_config_recompute:
+            b.recompute()
+        if isinstance(b, nn.CellList):
+            self.recompute(b[-1])
+        else:
+            b.add_flags(output_no_recompute=True)
+
     def _encode(self, x):
         # return latent distribution, N(mean, logvar)
 
diff --git a/examples/movie_gen/scripts/train_tae.py b/examples/movie_gen/scripts/train_tae.py
index a5f1aa0b2f..39faef5996 100644
--- a/examples/movie_gen/scripts/train_tae.py
+++ b/examples/movie_gen/scripts/train_tae.py
@@ -193,6 +193,7 @@ def main(args):
     # 3. 
build models ae = TemporalAutoencoder( pretrained=args.pretrained_model_path, + use_recompute=args.use_recompute, ) if args.use_discriminator: diff --git a/examples/movie_gen/tests/test_tae.py b/examples/movie_gen/tests/test_tae.py index 4236e0f3d0..49fc206c11 100644 --- a/examples/movie_gen/tests/test_tae.py +++ b/examples/movie_gen/tests/test_tae.py @@ -1,4 +1,7 @@ import numpy as np +import sys +sys.path.insert(0, '.') + from mg.models.tae.modules import ( Conv2_5d, Decoder, @@ -22,7 +25,6 @@ def test_conv25d(): cout = 128 x = np.random.normal(size=in_shape).astype(np.float32) - ms.set_context(mode=0) x = ms.Tensor(x) conv2d = Conv2_5d(C, cout, 3) @@ -42,7 +44,6 @@ def test_resnetblock(): dropout=0.0, ) - ms.set_context(mode=0) x = ms.Tensor(x) y = rb(x) @@ -59,7 +60,6 @@ def test_spatial_attn(): # sa = SpatialAttnBlock(C) sa = SpatialAttnBlockV2(C) - ms.set_context(mode=0) x = ms.Tensor(x) y = sa(x) @@ -76,7 +76,6 @@ def test_temporal_attn(): # TODO: compare time cost for v1 and v2 ta = TemporalAttnBlock(C) - ms.set_context(mode=0) x = ms.Tensor(x) y = ta(x) @@ -184,7 +183,7 @@ def test_tae_decode(): def test_tae_rec(): - in_shape = (B, C, T, H, W) = (1, 3, 9, 64, 64) + in_shape = (B, C, T, H, W) = (1, 3, 16, 64, 64) x = np.random.normal(size=in_shape).astype(np.float32) x = ms.Tensor(x) @@ -195,6 +194,8 @@ def test_tae_rec(): if __name__ == "__main__": + ms.set_context(mode=1) + # test_conv25d() # test_resnetblock() # test_spatial_attn() From 612efba88cf7d3fd8811d6a9e5d65c3256551974 Mon Sep 17 00:00:00 2001 From: samithuang <285365963@qq.com> Date: Fri, 8 Nov 2024 15:20:31 +0800 Subject: [PATCH 037/122] z 16 --- examples/movie_gen/mg/models/tae/tae.py | 2 +- examples/movie_gen/scripts/train_tae.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/movie_gen/mg/models/tae/tae.py b/examples/movie_gen/mg/models/tae/tae.py index 2915676481..1eae28b18c 100644 --- a/examples/movie_gen/mg/models/tae/tae.py +++ b/examples/movie_gen/mg/models/tae/tae.py @@ -5,7 +5,7 @@ # TODO: set z_channels to 16 SDXL_CONFIG = { "double_z": True, - "z_channels": 4, + "z_channels": 16, "resolution": 256, "in_channels": 3, "out_ch": 3, diff --git a/examples/movie_gen/scripts/train_tae.py b/examples/movie_gen/scripts/train_tae.py index 39faef5996..87c506f5ff 100644 --- a/examples/movie_gen/scripts/train_tae.py +++ b/examples/movie_gen/scripts/train_tae.py @@ -359,7 +359,7 @@ def main(args): ckpt_save_interval=args.ckpt_save_interval, log_interval=args.log_interval, start_epoch=start_epoch, - model_name="vae_3d", + model_name="tae", record_lr=False, ) callback.append(save_cb) From 5414b0e397e9004203280c7c340e236988a1f60b Mon Sep 17 00:00:00 2001 From: Rustam Khadipash <16683750+hadipash@users.noreply.github.com> Date: Mon, 11 Nov 2024 11:49:32 +0800 Subject: [PATCH 038/122] fix linear-quadratic sampling --- examples/moviegen/moviegen/schedulers/rectified_flow.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/examples/moviegen/moviegen/schedulers/rectified_flow.py b/examples/moviegen/moviegen/schedulers/rectified_flow.py index 7768f75022..18a4f6cc50 100644 --- a/examples/moviegen/moviegen/schedulers/rectified_flow.py +++ b/examples/moviegen/moviegen/schedulers/rectified_flow.py @@ -62,10 +62,11 @@ def __call__(self, model: nn.Cell, x: Tensor, ul2_emb: Tensor, metaclip_emb: Ten else: first_half = ceil(self.num_sampling_steps / 2) second_half = self.num_sampling_steps - first_half # in the case of an odd number of sampling steps - linear = np.arange(first_half, 
0, -1) - quadratic = (np.arange(second_half, 0, -1) ** 2) / (second_half**2) - quadratic = (self.num_timesteps - first_half) * quadratic + first_half # scale and shift - timesteps = np.concatenate([quadratic, linear]) + linear = self.num_timesteps - np.arange(first_half) + quadratic = (np.arange(1, second_half + 1) ** 2) / ((second_half + 1) ** 2) + quadratic = (self.num_timesteps - (first_half - 1)) * quadratic + (first_half - 1) # scale and shift + quadratic = self.num_timesteps - quadratic + timesteps = np.concatenate([linear, quadratic]) timesteps = np.tile(timesteps[..., None], (1, x.shape[0])) timesteps = Tensor(timesteps, dtype=model.dtype) # FIXME: avoid calculations on tensors outside `construct` From 89c3dbc575c2537b91d51bb2cdeb18df0994cca1 Mon Sep 17 00:00:00 2001 From: Rustam Khadipash <16683750+hadipash@users.noreply.github.com> Date: Mon, 11 Nov 2024 16:15:10 +0800 Subject: [PATCH 039/122] text encoders inference --- examples/moviegen/inference.py | 10 +- examples/moviegen/inference_text_enc.py | 128 +++++++++++++++++++ examples/moviegen/moviegen/utils/__init__.py | 1 + examples/moviegen/moviegen/utils/utils.py | 11 ++ 4 files changed, 142 insertions(+), 8 deletions(-) create mode 100644 examples/moviegen/inference_text_enc.py create mode 100644 examples/moviegen/moviegen/utils/utils.py diff --git a/examples/moviegen/inference.py b/examples/moviegen/inference.py index 808d2d85b2..ce25de023d 100644 --- a/examples/moviegen/inference.py +++ b/examples/moviegen/inference.py @@ -11,7 +11,7 @@ from jsonargparse.typing import path_type import mindspore as ms -from mindspore import Tensor, amp, nn +from mindspore import amp, nn # TODO: remove in future when mindone is ready for install __dir__ = os.path.dirname(os.path.abspath(__file__)) @@ -19,7 +19,7 @@ sys.path.append(mindone_lib_path) from moviegen.pipelines import InferPipeline -from moviegen.utils.model_utils import MODEL_DTYPE, init_model +from moviegen.utils import MODEL_DTYPE, init_model, to_numpy from mindone.utils import init_train_env, set_logger from mindone.visualize.videos import save_videos @@ -33,12 +33,6 @@ Path_dr = path_type("dr", docstring="path to a directory that exists and is readable") -def to_numpy(x: Tensor) -> np.ndarray: - if x.dtype == ms.bfloat16: - x = x.astype(ms.float32) - return x.asnumpy() - - def prepare_captions( ul2_dir: Path_dr, metaclip_dir: Path_dr, byt5_dir: Path_dr, rank_id: int, device_num: int ) -> Tuple[List[str], List[str], List[str]]: diff --git a/examples/moviegen/inference_text_enc.py b/examples/moviegen/inference_text_enc.py new file mode 100644 index 0000000000..6b3a68c1ad --- /dev/null +++ b/examples/moviegen/inference_text_enc.py @@ -0,0 +1,128 @@ +import logging +import os +import sys +from csv import DictReader +from pathlib import Path +from typing import List, Tuple + +import numpy as np +from jsonargparse import ArgumentParser +from jsonargparse.typing import Path_fr, path_type +from tqdm import trange +from transformers import AutoTokenizer + +import mindspore as ms + +# TODO: remove in future when mindone is ready for install +__dir__ = os.path.dirname(os.path.abspath(__file__)) + +mindone_lib_path = os.path.abspath(os.path.join(__dir__, "../../")) +sys.path.append(mindone_lib_path) + +from moviegen.utils import MODEL_DTYPE, to_numpy + +from mindone.transformers.models.t5.modeling_t5 import T5EncoderModel +from mindone.utils import init_train_env, set_logger + +logger = logging.getLogger(__name__) + +Path_dcc = path_type("dcc") # path to a directory that can be created if 
it does not exist + + +def prepare_captions( + prompts_file: Path_fr, + output_path: Path_dcc, + column_names: Tuple[str, str] = ("video", "caption"), + rank_id: int = 0, + device_num: int = 1, +) -> Tuple[List[Path], List[str]]: + """ + Reads prompts from a file and returns a list of saving paths and a list of captions. + + Args: + prompts_file: Path to the prompt file. Can be a csv file or a txt file. + output_path: Path to the output directory where the embeddings will be saved. + column_names: [CSV only] Tuple of column names for video paths and captions. + rank_id: Current rank id for distributed inference. + device_num: Number of devices used for distributed inference. + + Returns: + A tuple containing a list of saving paths and a list of captions. + """ + prompts_file = prompts_file.absolute + output_path = Path(output_path.absolute) + with open(prompts_file, "r", encoding="utf-8") as file: + if prompts_file.endswith(".csv"): + paths, captions = zip( + *[ + (output_path / Path(row[column_names[0]]).with_suffix(".npz"), row[column_names[1]]) + for row in DictReader(file) + ] + ) + return paths[rank_id::device_num], captions[rank_id::device_num] + else: + captions = [line.strip() for line in file] # preserve empty lines + paths = [ + output_path / (f"{i:03d}-" + "-".join(Path(cap).stem.split(" ")[:10]) + ".npz") + for i, cap in enumerate(captions) + ] + return paths[rank_id::device_num], captions[rank_id::device_num] + + +def main(args): + save_dir = args.output_path.absolute + os.makedirs(save_dir, exist_ok=True) + set_logger(name="", output_dir=save_dir) + + _, rank_id, device_num = init_train_env(**args.env) # TODO: rename as train and infer are identical? + + paths, captions = prepare_captions(args.prompts_file, args.output_path, args.column_names, rank_id, device_num) + + # model initiate and weight loading + tokenizer = AutoTokenizer.from_pretrained( + args.model_name, local_files_only=True, clean_up_tokenization_spaces=False + ) + model = T5EncoderModel.from_pretrained( + args.model_name, mindspore_dtype=MODEL_DTYPE[args.dtype.lower()], local_files_only=True + ).set_train(False) + + logger.info(f"Number of devices: {device_num} | Rank ID: {rank_id} | Number of captions: {len(captions)}") + logger.info( + f"Model name: {args.model_name} | Precision: {args.dtype} | Embedded sequence length: {args.model_max_length}" + ) + + for i in trange(0, len(captions), args.batch_size): + batch = captions[i : i + args.batch_size] + inputs = tokenizer( + batch, + max_length=args.model_max_length, + padding="max_length", + return_attention_mask=True, + truncation=True, + return_tensors="np", + ) + tokens = inputs.input_ids + masks = inputs.attention_mask + output = model(ms.Tensor(inputs.input_ids, dtype=ms.int32), ms.Tensor(inputs.attention_mask, dtype=ms.uint8))[0] + output = to_numpy(output).astype(np.float32) + + for j in range(len(output)): + paths[i + j].parent.mkdir(parents=True, exist_ok=True) + with open(os.path.join(save_dir, paths[i + j]), "wb") as f: + np.savez(f, mask=masks[j], text_emb=output[j], tokens=tokens[j]) + + logger.info(f"Finished. Embeddings saved to {save_dir}") + + +if __name__ == "__main__": + parser = ArgumentParser(description="Text embeddings generation script.") + parser.add_function_arguments(init_train_env, "env") + parser.add_argument("--model_name", type=str, default="google/byt5-small", help="Text encoder model name.") + parser.add_argument( + "--dtype", default="fp32", type=str, choices=["fp32", "fp16", "bf16"], help="Text encoder model precision." 
+ ) + parser.add_function_arguments(prepare_captions, as_group=False, skip={"rank_id", "device_num"}) + parser.add_argument("--batch_size", default=10, type=int, help="Inference batch size.") + parser.add_argument("--model_max_length", type=int, default=300, help="Model's embedded sequence length.") + cfg = parser.parse_args() + main(cfg) diff --git a/examples/moviegen/moviegen/utils/__init__.py b/examples/moviegen/moviegen/utils/__init__.py index 33a0e314cd..62934d58bc 100644 --- a/examples/moviegen/moviegen/utils/__init__.py +++ b/examples/moviegen/moviegen/utils/__init__.py @@ -1,2 +1,3 @@ from .ema import * from .model_utils import * +from .utils import * diff --git a/examples/moviegen/moviegen/utils/utils.py b/examples/moviegen/moviegen/utils/utils.py new file mode 100644 index 0000000000..70bf86c777 --- /dev/null +++ b/examples/moviegen/moviegen/utils/utils.py @@ -0,0 +1,11 @@ +import numpy as np + +import mindspore as ms + +__all__ = ["to_numpy"] + + +def to_numpy(x: ms.Tensor) -> np.ndarray: + if x.dtype == ms.bfloat16: + x = x.astype(ms.float32) + return x.asnumpy() From 90359e9842f125bc75163841b437455abc84fe6a Mon Sep 17 00:00:00 2001 From: Samit <285365963@qq.com> Date: Wed, 13 Nov 2024 16:34:00 +0800 Subject: [PATCH 040/122] allow loading sd3.5 vae pretrained weights --- examples/movie_gen/mg/models/tae/modules.py | 142 ++++++---- examples/movie_gen/mg/models/tae/tae.py | 66 ++++- examples/movie_gen/tests/test_tae.py | 64 ++++- examples/movie_gen/tools/inflate_sd3.5_vae.py | 87 +++++++ .../movie_gen/tools/ms_pnames_sd3.5_vae.txt | 244 ++++++++++++++++++ .../movie_gen/tools/ms_pnames_tae_vae.txt | 244 ++++++++++++++++++ .../movie_gen/tools/pt_pnames_sd3.5_vae.txt | 244 ++++++++++++++++++ 7 files changed, 1032 insertions(+), 59 deletions(-) create mode 100644 examples/movie_gen/tools/inflate_sd3.5_vae.py create mode 100644 examples/movie_gen/tools/ms_pnames_sd3.5_vae.txt create mode 100644 examples/movie_gen/tools/ms_pnames_tae_vae.txt create mode 100644 examples/movie_gen/tools/pt_pnames_sd3.5_vae.txt diff --git a/examples/movie_gen/mg/models/tae/modules.py b/examples/movie_gen/mg/models/tae/modules.py index 5faabf8efd..d0637a27fb 100644 --- a/examples/movie_gen/mg/models/tae/modules.py +++ b/examples/movie_gen/mg/models/tae/modules.py @@ -140,36 +140,24 @@ def __init__( assert dilation==1 # spatial conv self.conv_spat = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, pad_mode=pad_mode, padding=padding, has_bias=has_bias) - + # temp_pad_mode = 'zero' # temp_pad = 'mint_rep' temp_pad = 'manual' # temporal conv if kernel_size > 1: - # FIXME: debugging to see how symmetric padding influence performance - if temp_pad == 'zero': - self.conv_temp = nn.Conv1d(out_channels, out_channels, kernel_size=kernel_size, stride=stride, pad_mode="same", has_bias=has_bias) - self.use_pad = False - self.pad = nn.Identity() - elif temp_pad == 'mint_rep': - assert kernel_size == 3 - self.conv_temp = nn.Conv1d(out_channels, out_channels, kernel_size=kernel_size, stride=stride, pad_mode="valid", has_bias=has_bias) - self.pad = nn.ReplicationPad1d(((kernel_size-1)//2, (kernel_size-1)//2)) - self.use_pad = True - elif temp_pad == 'manual': - assert kernel_size == 3, 'symmetric padding currently only support kernel size 3' - self.conv_temp = nn.Conv1d(out_channels, out_channels, kernel_size=kernel_size, stride=stride, pad_mode="valid", has_bias=has_bias) - self.pad = self.symmetric_pad1d - self.use_pad = True - elif temp_pad == 'nn_pad': - self.pad = nn.Pad(paddings=((0, 0), 
(0, 0), ((kernel_size-1)//2, (kernel_size-1)//2)), mode='SYMMETRIC') - self.use_pad = True - + # symmetric padding + conv1d + assert kernel_size == 3, 'symmetric padding currently only support kernel size 3' + self.conv_temp = nn.Conv1d(out_channels, out_channels, kernel_size=kernel_size, stride=stride, pad_mode="valid", has_bias=has_bias, bias_init='zeros') + self.pad = self.symmetric_pad1d + self.use_pad = True else: self.use_pad = False - self.conv_temp = nn.Conv1d(out_channels, out_channels, kernel_size=kernel_size, stride=stride, pad_mode="valid", has_bias=has_bias) - + self.conv_temp = nn.Conv1d(out_channels, out_channels, kernel_size=kernel_size, stride=stride, pad_mode="valid", has_bias=has_bias, bias_init='zeros') + + self.init_temporal_weight() + @staticmethod def symmetric_pad1d(x): first_frame = x[:, :, :1] @@ -210,6 +198,7 @@ def construct(self, x): # import pdb; pdb.set_trace() x = self.pad(x) + # import pdb; pdb.set_trace() x = self.conv_temp(x) # (b*h*w c t) -> (b t c h w) @@ -221,6 +210,22 @@ def construct(self, x): return x + def init_temporal_weight(self): + # temporal conv kernel: (cout, cin, 1, ks) + # ks=1 or 3, cin == cout + # import pdb; pdb.set_trace() + w = self.conv_temp.weight + ch = int(w.shape[0]) + ks = int(w.shape[-1]) + value = np.zeros(tuple(w.shape)) + + # only the middle element of the kernel is 1 so that the output is the same input in initialization + for i in range(ch): + value[i, i, 0, ks//2] = 1 + w.set_data(ms.Tensor(value, dtype=ms.float32)) + + # bias is initialized to zero in layer def + class SpatialUpsample(nn.Cell): def __init__(self, in_channels, with_conv): @@ -299,14 +304,27 @@ def __init__(self, in_channels): ) # tail padding, pad with last frame self.time_pad = self.ks - 1 - self.init_weight() + self.init_weight("mean") - def init_weight(self): + def init_weight(self, method='mean'): + if method == 'normal': + # default conv init + return + + # no way to reserve complete input since stride 2 w = self.conv.weight value = np.zeros(tuple(w.shape)) - # TODO: ablate with normal init - for i in range(self.ch): - value[i, i, 0, :] = 1/self.ks # (cout, cin, 1, ks) + if method == 'mean': + # initially, it's a mean filter for temporal downsampling + for i in range(self.ch): + value[i, i, 0, :] = 1/self.ks # (cout, cin, 1, ks) + elif method == 'median': + # a median filter for temporal downsampling + for i in range(self.ch): + value[i, i, 0, self.ks//2] = 1 # (cout, cin, 1, ks) + else: + raise NotImplementedError + w.set_data(ms.Tensor(value, dtype=ms.float32)) @@ -341,14 +359,20 @@ def __init__(self, in_channels): self.ch = in_channels self.init_weight() - def init_weight(self): + def init_weight(self, method='median'): + if method == 'normal': + return + + # init so that the output is the same as vae2d for image input w = self.conv.weight value = np.zeros(tuple(w.shape)) - # TODO: ablate with normal init - # consider image input, make sure it's the same - for i in range(self.ch): - value[i, i, 0, 1] = 1 # (cout, cin, 1, ks) - w.set_data(ms.Tensor(value, dtype=ms.float32)) + if method == 'median': + # consider image input, make sure it's the same + for i in range(self.ch): + value[i, i, 0, 1] = 1 # (cout, cin, 1, ks) + w.set_data(ms.Tensor(value, dtype=ms.float32)) + else: + raise NotImplementedError def construct(self, x): # x (b c t h w) @@ -421,6 +445,7 @@ def construct(self, x): return x + h + class SpatialAttnBlock(nn.Cell): def __init__(self, in_channels): super().__init__() @@ -577,12 +602,18 @@ def construct(self, x): def 
make_attn(in_channels, attn_type="vanilla"): # assert attn_type in ["vanilla", "vanilla3D"], f"attn_type {attn_type} not supported" - _logger.debug(f"making attention of type '{attn_type}' with {in_channels} in_channels") + # _logger.debug(f"making attention of type '{attn_type}' with {in_channels} in_channels") + print(f"making attention of type '{attn_type}' with {in_channels} in_channels") if attn_type == "vanilla": return nn.SequentialCell( SpatialAttnBlock(in_channels), TemporalAttnBlock(in_channels), ) + elif attn_type == 'spat_only': + # to ensure naming consistency + return nn.SequentialCell( + SpatialAttnBlock(in_channels), + ) else: raise NotImplementedError @@ -604,6 +635,8 @@ def __init__( double_z=True, use_linear_attn=False, attn_type="vanilla", + temporal_downsample_level=(0, 1, 2), # same as spatial + **kwargs, ): super().__init__() # if use_linear_attn: attn_type = "linear" @@ -612,6 +645,7 @@ def __init__( self.num_res_blocks = num_res_blocks self.resolution = resolution self.in_channels = in_channels + self.temporal_downsample_level = temporal_downsample_level # downsampling self.conv_in = Conv2_5d( @@ -642,14 +676,15 @@ def __init__( down.block = block down.attn = attn if i_level != self.num_resolutions - 1: - # down.downsample_spat = SpatialDownsample(block_in, resamp_with_conv) - # down.downsample_temp = TemporalDownsample(block_in) - down.downsample = nn.SequentialCell( - SpatialDownsample(block_in, resamp_with_conv), - TemporalDownsample(block_in), - ) + down.downsample_spat = SpatialDownsample(block_in, resamp_with_conv) + else: + down.downsample_spat = nn.Identity() + + if i_level in self.temporal_downsample_level: + down.downsample_temp = TemporalDownsample(block_in) else: - down.downsample = nn.Identity() + down.downsample_temp = nn.Identity() + curr_res = curr_res // 2 down.update_parameters_name(prefix=self.param_prefix + f"down.{i_level}.") self.down.append(down) @@ -696,8 +731,9 @@ def construct(self, x): if len(self.down[i_level].attn) > 0: h = self.down[i_level].attn[i_block](h) hs = h - if i_level != self.num_resolutions - 1: - hs = self.down[i_level].downsample(hs) + # if i_level != self.num_resolutions - 1: + hs = self.down[i_level].downsample_spat(hs) + hs = self.down[i_level].downsample_temp(hs) # middle h = hs @@ -732,6 +768,7 @@ def __init__( tanh_out=False, use_linear_attn=False, attn_type="vanilla", + temporal_upsample_level=(1,2,3), # same as spatial **ignorekwargs, ): super().__init__() @@ -743,6 +780,7 @@ def __init__( self.in_channels = in_channels self.give_pre_end = give_pre_end self.tanh_out = tanh_out + self.temporal_upsample_level = temporal_upsample_level # compute in_ch_mult, block_in and curr_res at lowest res # in_ch_mult = (1,) + tuple(ch_mult) @@ -787,12 +825,15 @@ def __init__( up.block = block up.attn = attn if i_level != 0: - up.upsample = nn.SequentialCell( - SpatialUpsample(block_in, resamp_with_conv), - TemporalUpsample(block_in), - ) + up.upsample_spat = SpatialUpsample(block_in, resamp_with_conv) + else: + up.upsample_spat = nn.Identity() + + if i_level in self.temporal_upsample_level: + up.upsample_temp = TemporalUpsample(block_in) else: - up.upsample = nn.Identity() + up.upsample_temp = nn.Identity() + curr_res = curr_res * 2 up.update_parameters_name(prefix=self.param_prefix + f"up.{i_level}.") if len(self.up) != 0: @@ -828,8 +869,9 @@ def construct(self, z): h = self.up[i_level].block[i_block](h) if len(self.up[i_level].attn) > 0: h = self.up[i_level].attn[i_block](h) - if i_level != 0: - h = self.up[i_level].upsample(h) 
+ + h = self.up[i_level].upsample_spat(h) + h = self.up[i_level].upsample_temp(h) # end if self.give_pre_end: diff --git a/examples/movie_gen/mg/models/tae/tae.py b/examples/movie_gen/mg/models/tae/tae.py index 2915676481..cb3601dd38 100644 --- a/examples/movie_gen/mg/models/tae/tae.py +++ b/examples/movie_gen/mg/models/tae/tae.py @@ -3,6 +3,7 @@ from .modules import Conv2_5d, Encoder, Decoder # TODO: set z_channels to 16 + SDXL_CONFIG = { "double_z": True, "z_channels": 4, @@ -14,6 +15,30 @@ "num_res_blocks": 2, "attn_resolutions": [], "dropout": 0.0, + "use_post_quant_conv": True, + "use_quant_conv": True +} + +# modify based on SD3d5_CONFIG +TAE_CONFIG = { + "double_z": True, + "z_channels": 16, + "resolution": 256, + "in_channels": 3, + "out_ch": 3, + "ch": 128, + "ch_mult": [1, 2, 4, 4], + "num_res_blocks": 2, + "attn_resolutions": [], + "dropout": 0.0, + "scaling_factor": 1.5305, + "shift_factor": 0.0609, + "use_post_quant_conv": False, + "use_quant_conv": False, + "attn_type": "vanilla", + "temporal_downsample_level": [0, 1, 2], + "temporal_upsample_level": [3, 2, 1], + } @@ -28,9 +53,10 @@ class TemporalAutoencoder(nn.Cell): def __init__( self, - config: dict = SDXL_CONFIG, + config: dict = TAE_CONFIG, pretrained: str = None, use_recompute: bool=False, + sample_deterministic: bool=False, ): super().__init__() @@ -39,8 +65,13 @@ def __init__( # quant and post quant embed_dim = config['z_channels'] - self.quant_conv = Conv2_5d(2 * config["z_channels"], 2 * embed_dim, 1, pad_mode="valid", has_bias=True) - self.post_quant_conv = Conv2_5d(embed_dim, config["z_channels"], 1, pad_mode="valid", has_bias=True) + if config['use_quant_conv']: + self.quant_conv = Conv2_5d(2 * embed_dim, 2 * embed_dim, 1, pad_mode="valid", has_bias=True) + if config['use_post_quant_conv']: + self.post_quant_conv = Conv2_5d(embed_dim, embed_dim, 1, pad_mode="valid", has_bias=True) + + self.use_quant_conv = config['use_quant_conv'] + self.use_post_quant_conv = config['use_post_quant_conv'] # decoder self.decoder = Decoder(**config) @@ -49,7 +80,7 @@ def __init__( self.stdnormal = ops.StandardNormal() self.split = ms.ops.split - self.sample_deterministic = False + self.sample_deterministic = sample_deterministic self.discard_spurious_frames = True if use_recompute: @@ -71,7 +102,10 @@ def recompute(self, b): def _encode(self, x): # return latent distribution, N(mean, logvar) h = self.encoder(x) - moments = self.quant_conv(h) + if self.use_quant_conv: + moments = self.quant_conv(h) + else: + moments = h mean, logvar = self.split(moments, moments.shape[1] // 2, 1) return mean, logvar @@ -94,7 +128,8 @@ def encode(self, x: ms.Tensor) -> ms.Tensor: return z def decode(self, z: ms.Tensor) -> ms.Tensor: - z = self.post_quant_conv(z) + if self.use_post_quant_conv: + z = self.post_quant_conv(z) dec = self.decoder(z) return dec @@ -114,3 +149,22 @@ def construct(self, x: ms.Tensor) -> ms.Tensor: return recons, z, posterior_mean, posterior_logvar + + def load_pretrained(self, ckpt_path:str): + if ckpt_path.endswith('safetensors'): + # load vae parameters from safetensors into my mindspore model + import safetensors + ckpt = safetensors.safe_open(ckpt_path, framework="pt") + state_dict = {} + for key in ckpt.keys(): + state_dict[key] = ckpt.get_tensor(key) + raise NotImplementedError + else: + param_dict = ms.load_checkpoint(ckpt_path) + param_not_load, ckpt_not_load = ms.load_param_into_net(self, param_dict, strict_load=True) + if param_not_load or ckpt_not_load: + print(f"{param_not_load} in network is not loaded") + 
print(f"{ckpt_not_load} in checkpoint is not loaded!") + print('tae checkpoint loaded') + + diff --git a/examples/movie_gen/tests/test_tae.py b/examples/movie_gen/tests/test_tae.py index 49fc206c11..3eafaf2203 100644 --- a/examples/movie_gen/tests/test_tae.py +++ b/examples/movie_gen/tests/test_tae.py @@ -1,5 +1,6 @@ import numpy as np import sys +from PIL import Image sys.path.insert(0, '.') from mg.models.tae.modules import ( @@ -15,11 +16,38 @@ TemporalDownsample, TemporalUpsample, ) -from mg.models.tae.tae import SDXL_CONFIG, TemporalAutoencoder +from mg.models.tae.tae import SDXL_CONFIG, TAE_CONFIG, TemporalAutoencoder +from mg.models.tae.sd3_vae import SD3d5_CONFIG, SD3d5_VAE import mindspore as ms +def get_input_image(img_path="../videocomposer/demo_video/moon_on_water.jpg", + W=128, + H=128): + target_size = (H, W) + + # read image using PIL and preprocess + image = Image.open(img_path).convert('RGB') + image = image.resize(target_size, Image.ANTIALIAS) + pixel_values = np.array(image, dtype=np.float32) + pixel_values = (pixel_values / 127.5 - 1.0).astype(np.float32) + + pixel_values = pixel_values.transpose(2, 0, 1) + + return pixel_values + +def save_output_image(image_array, output_path='tests/tmp_output.png'): + image_array = image_array.transpose((1, 2, 0)) + image_array = ((image_array + 1) * 127.5).astype(np.uint8) + image_array = np.clip(image_array, 0, 255) + + image = Image.fromarray(image_array) + + image.save(output_path) + print(f'image saved in {output_path}') + + def test_conv25d(): in_shape = (B, C, T, H, W) = (2, 3, 16, 256, 256) cout = 128 @@ -183,14 +211,42 @@ def test_tae_decode(): def test_tae_rec(): - in_shape = (B, C, T, H, W) = (1, 3, 16, 64, 64) + TAE_CONFIG['attn_type'] = 'spat_only' + tae = TemporalAutoencoder(config=TAE_CONFIG) + tae.load_pretrained("models/tae_vae2d.ckpt") + + # in_shape = (B, C, T, H, W) = (1, 3, 16, 64, 64) + in_shape = (B, C, T, H, W) = (1, 3, 1, 128, 128) x = np.random.normal(size=in_shape).astype(np.float32) + img = get_input_image(H=H, W=W) + x[0, :, 0, :, :] = img x = ms.Tensor(x) - tae = TemporalAutoencoder(config=SDXL_CONFIG) y = tae(x) print(y[0].shape) + save_output_image(y[0].numpy()[0, :, 0, :, :], 'tests/tmp_tae_output.png') + +def test_sd3d5_vae(): + vae = SD3d5_VAE(sample_deterministic=True) + vae.load_pretrained("models/sd3.5_vae.ckpt") + + in_shape = (BT, C, H, W) = (1, 3, 128, 128) + x = np.random.normal(size=in_shape).astype(np.float32) + img = get_input_image(H=H, W=W) + x[0] = img + + x = ms.Tensor(x) + + outputs = vae(x) + recons = outputs[0] + print(recons.shape) + + # save to image + # TODO: there are some noise here + save_output_image(recons.numpy()[0]) + + print(recons.sum()) if __name__ == "__main__": @@ -210,3 +266,5 @@ def test_tae_rec(): # test_tae_encode() # test_tae_decode() test_tae_rec() + + # test_sd3d5_vae() diff --git a/examples/movie_gen/tools/inflate_sd3.5_vae.py b/examples/movie_gen/tools/inflate_sd3.5_vae.py new file mode 100644 index 0000000000..2da89b1498 --- /dev/null +++ b/examples/movie_gen/tools/inflate_sd3.5_vae.py @@ -0,0 +1,87 @@ +from safetensors import safe_open +import os +import numpy as np +import mindspore as ms + + +def get_shape_from_str(shape): + shape = shape.replace("(", "").replace(")", "").split(",") + shape = [int(s) for s in shape if len(s) > 0] + + return shape + +def get_pname_shape(ckpt_path): + with safe_open(ckpt_path, framework="pt", device='cpu') as fp: + for key in fp.keys(): + val = fp.get_tensor(key) + shape = tuple(val.shape) + dtype = val.dtype + 
print(f"{key}#{shape}#{dtype}") + +def load_torch_ckpt(ckpt_path): + pt_state_dict = {} + with safe_open(ckpt_path, framework="pt", device='cpu') as fp: + for key in fp.keys(): + pt_state_dict[key] = fp.get_tensor(key) + # print(key) + return pt_state_dict + +def plot_ms_vae2d5(): + from mg.models.tae.tae import SD3d5_CONFIG, TemporalAutoencoder + tae = TemporalAutoencoder(config=SD3d5_CONFIG) + + sd = tae.parameters_dict() + pnames = list(sd.keys()) + for pname in pnames: + shape = tuple(sd[pname].shape) + print(f"{pname}#{shape}") + + +def convert_vae2d(source_fp, target_fp, target_model='vae2d'): + # read param mapping files + ms_pnames_file = "tools/ms_pnames_sd3.5_vae.txt" if target_model == 'vae2d' else "tools/ms_pnames_tae_vae.txt" + print('target ms pnames is annotated in ', ms_pnames_file) + with open(ms_pnames_file) as file_ms: + lines_ms = list(file_ms.readlines()) + with open("tools/pt_pnames_sd3.5_vae.txt") as file_pt: + lines_pt = list(file_pt.readlines()) + + # if "from_vae2d": + # lines_ms = [line for line in lines_ms if line.startswith("spatial_vae")] + # lines_pt = [line for line in lines_pt if line.startswith("spatial_vae")] + + assert len(lines_ms) == len(lines_pt) + + # convert and save + sd_pt = load_torch_ckpt(source_fp) # state dict + num_params_pt = len(list(sd_pt.keys())) + print("Total params in pt ckpt: ", num_params_pt) + target_data = [] + for i in range(len(lines_pt)): + name_pt, shape_pt = lines_pt[i].strip().split("#") + shape_pt = get_shape_from_str(shape_pt) + name_ms, shape_ms = lines_ms[i].strip().split("#") + shape_ms = get_shape_from_str(shape_ms) + assert np.prod(shape_pt) == np.prod( + shape_ms + ), f"Mismatch param: PT: {name_pt}, {shape_pt} vs MS: {name_ms}, {shape_ms}" + + # if "from_vae2d": + # name_pt = name_pt.replace("spatial_vae.module.", "") + # param can be saved in bf16 + data = sd_pt[name_pt].cpu().detach().float().numpy().reshape(shape_ms) + + data = ms.Tensor(input_data=data.astype(np.float32), dtype=ms.float32) + target_data.append({"name": name_ms, "data": data}) # ms.Tensor(data, dtype=ms.float32)}) + + print("Total params converted: ", len(target_data)) + ms.save_checkpoint(target_data, target_fp) + + +if __name__ == "__main__": + ckpt_path = "/Users/Samit/Downloads/sd3.5_vae/diffusion_pytorch_model.safetensors" + # get_pname_shape(ckpt_path) + # convert_vae2d(ckpt_path, "models/sd3.5_vae.ckpt") + convert_vae2d(ckpt_path, "models/tae_vae2d.ckpt", target_model='tae') + + # plot_ms_vae2d5() diff --git a/examples/movie_gen/tools/ms_pnames_sd3.5_vae.txt b/examples/movie_gen/tools/ms_pnames_sd3.5_vae.txt new file mode 100644 index 0000000000..82d5f1917f --- /dev/null +++ b/examples/movie_gen/tools/ms_pnames_sd3.5_vae.txt @@ -0,0 +1,244 @@ +encoder.conv_in.weight#(128, 3, 3, 3) +encoder.conv_in.bias#(128,) +encoder.down.0.block.0.norm1.gamma#(128,) +encoder.down.0.block.0.norm1.beta#(128,) +encoder.down.0.block.0.conv1.weight#(128, 128, 3, 3) +encoder.down.0.block.0.conv1.bias#(128,) +encoder.down.0.block.0.norm2.gamma#(128,) +encoder.down.0.block.0.norm2.beta#(128,) +encoder.down.0.block.0.conv2.weight#(128, 128, 3, 3) +encoder.down.0.block.0.conv2.bias#(128,) +encoder.down.0.block.1.norm1.gamma#(128,) +encoder.down.0.block.1.norm1.beta#(128,) +encoder.down.0.block.1.conv1.weight#(128, 128, 3, 3) +encoder.down.0.block.1.conv1.bias#(128,) +encoder.down.0.block.1.norm2.gamma#(128,) +encoder.down.0.block.1.norm2.beta#(128,) +encoder.down.0.block.1.conv2.weight#(128, 128, 3, 3) +encoder.down.0.block.1.conv2.bias#(128,) 
+encoder.down.0.downsample.conv.weight#(128, 128, 3, 3) +encoder.down.0.downsample.conv.bias#(128,) +encoder.down.1.block.0.norm1.gamma#(128,) +encoder.down.1.block.0.norm1.beta#(128,) +encoder.down.1.block.0.conv1.weight#(256, 128, 3, 3) +encoder.down.1.block.0.conv1.bias#(256,) +encoder.down.1.block.0.norm2.gamma#(256,) +encoder.down.1.block.0.norm2.beta#(256,) +encoder.down.1.block.0.conv2.weight#(256, 256, 3, 3) +encoder.down.1.block.0.conv2.bias#(256,) +encoder.down.1.block.0.nin_shortcut.weight#(256, 128, 1, 1) +encoder.down.1.block.0.nin_shortcut.bias#(256,) +encoder.down.1.block.1.norm1.gamma#(256,) +encoder.down.1.block.1.norm1.beta#(256,) +encoder.down.1.block.1.conv1.weight#(256, 256, 3, 3) +encoder.down.1.block.1.conv1.bias#(256,) +encoder.down.1.block.1.norm2.gamma#(256,) +encoder.down.1.block.1.norm2.beta#(256,) +encoder.down.1.block.1.conv2.weight#(256, 256, 3, 3) +encoder.down.1.block.1.conv2.bias#(256,) +encoder.down.1.downsample.conv.weight#(256, 256, 3, 3) +encoder.down.1.downsample.conv.bias#(256,) +encoder.down.2.block.0.norm1.gamma#(256,) +encoder.down.2.block.0.norm1.beta#(256,) +encoder.down.2.block.0.conv1.weight#(512, 256, 3, 3) +encoder.down.2.block.0.conv1.bias#(512,) +encoder.down.2.block.0.norm2.gamma#(512,) +encoder.down.2.block.0.norm2.beta#(512,) +encoder.down.2.block.0.conv2.weight#(512, 512, 3, 3) +encoder.down.2.block.0.conv2.bias#(512,) +encoder.down.2.block.0.nin_shortcut.weight#(512, 256, 1, 1) +encoder.down.2.block.0.nin_shortcut.bias#(512,) +encoder.down.2.block.1.norm1.gamma#(512,) +encoder.down.2.block.1.norm1.beta#(512,) +encoder.down.2.block.1.conv1.weight#(512, 512, 3, 3) +encoder.down.2.block.1.conv1.bias#(512,) +encoder.down.2.block.1.norm2.gamma#(512,) +encoder.down.2.block.1.norm2.beta#(512,) +encoder.down.2.block.1.conv2.weight#(512, 512, 3, 3) +encoder.down.2.block.1.conv2.bias#(512,) +encoder.down.2.downsample.conv.weight#(512, 512, 3, 3) +encoder.down.2.downsample.conv.bias#(512,) +encoder.down.3.block.0.norm1.gamma#(512,) +encoder.down.3.block.0.norm1.beta#(512,) +encoder.down.3.block.0.conv1.weight#(512, 512, 3, 3) +encoder.down.3.block.0.conv1.bias#(512,) +encoder.down.3.block.0.norm2.gamma#(512,) +encoder.down.3.block.0.norm2.beta#(512,) +encoder.down.3.block.0.conv2.weight#(512, 512, 3, 3) +encoder.down.3.block.0.conv2.bias#(512,) +encoder.down.3.block.1.norm1.gamma#(512,) +encoder.down.3.block.1.norm1.beta#(512,) +encoder.down.3.block.1.conv1.weight#(512, 512, 3, 3) +encoder.down.3.block.1.conv1.bias#(512,) +encoder.down.3.block.1.norm2.gamma#(512,) +encoder.down.3.block.1.norm2.beta#(512,) +encoder.down.3.block.1.conv2.weight#(512, 512, 3, 3) +encoder.down.3.block.1.conv2.bias#(512,) +encoder.mid.block_1.norm1.gamma#(512,) +encoder.mid.block_1.norm1.beta#(512,) +encoder.mid.block_1.conv1.weight#(512, 512, 3, 3) +encoder.mid.block_1.conv1.bias#(512,) +encoder.mid.block_1.norm2.gamma#(512,) +encoder.mid.block_1.norm2.beta#(512,) +encoder.mid.block_1.conv2.weight#(512, 512, 3, 3) +encoder.mid.block_1.conv2.bias#(512,) +encoder.mid.attn_1.norm.gamma#(512,) +encoder.mid.attn_1.norm.beta#(512,) +encoder.mid.attn_1.q.weight#(512, 512, 1, 1) +encoder.mid.attn_1.q.bias#(512,) +encoder.mid.attn_1.k.weight#(512, 512, 1, 1) +encoder.mid.attn_1.k.bias#(512,) +encoder.mid.attn_1.v.weight#(512, 512, 1, 1) +encoder.mid.attn_1.v.bias#(512,) +encoder.mid.attn_1.proj_out.weight#(512, 512, 1, 1) +encoder.mid.attn_1.proj_out.bias#(512,) +encoder.mid.block_2.norm1.gamma#(512,) +encoder.mid.block_2.norm1.beta#(512,) 
+encoder.mid.block_2.conv1.weight#(512, 512, 3, 3) +encoder.mid.block_2.conv1.bias#(512,) +encoder.mid.block_2.norm2.gamma#(512,) +encoder.mid.block_2.norm2.beta#(512,) +encoder.mid.block_2.conv2.weight#(512, 512, 3, 3) +encoder.mid.block_2.conv2.bias#(512,) +encoder.norm_out.gamma#(512,) +encoder.norm_out.beta#(512,) +encoder.conv_out.weight#(32, 512, 3, 3) +encoder.conv_out.bias#(32,) +decoder.conv_in.weight#(512, 16, 3, 3) +decoder.conv_in.bias#(512,) +decoder.mid.block_1.norm1.gamma#(512,) +decoder.mid.block_1.norm1.beta#(512,) +decoder.mid.block_1.conv1.weight#(512, 512, 3, 3) +decoder.mid.block_1.conv1.bias#(512,) +decoder.mid.block_1.norm2.gamma#(512,) +decoder.mid.block_1.norm2.beta#(512,) +decoder.mid.block_1.conv2.weight#(512, 512, 3, 3) +decoder.mid.block_1.conv2.bias#(512,) +decoder.mid.attn_1.norm.gamma#(512,) +decoder.mid.attn_1.norm.beta#(512,) +decoder.mid.attn_1.q.weight#(512, 512, 1, 1) +decoder.mid.attn_1.q.bias#(512,) +decoder.mid.attn_1.k.weight#(512, 512, 1, 1) +decoder.mid.attn_1.k.bias#(512,) +decoder.mid.attn_1.v.weight#(512, 512, 1, 1) +decoder.mid.attn_1.v.bias#(512,) +decoder.mid.attn_1.proj_out.weight#(512, 512, 1, 1) +decoder.mid.attn_1.proj_out.bias#(512,) +decoder.mid.block_2.norm1.gamma#(512,) +decoder.mid.block_2.norm1.beta#(512,) +decoder.mid.block_2.conv1.weight#(512, 512, 3, 3) +decoder.mid.block_2.conv1.bias#(512,) +decoder.mid.block_2.norm2.gamma#(512,) +decoder.mid.block_2.norm2.beta#(512,) +decoder.mid.block_2.conv2.weight#(512, 512, 3, 3) +decoder.mid.block_2.conv2.bias#(512,) +decoder.up.0.block.0.norm1.gamma#(256,) +decoder.up.0.block.0.norm1.beta#(256,) +decoder.up.0.block.0.conv1.weight#(128, 256, 3, 3) +decoder.up.0.block.0.conv1.bias#(128,) +decoder.up.0.block.0.norm2.gamma#(128,) +decoder.up.0.block.0.norm2.beta#(128,) +decoder.up.0.block.0.conv2.weight#(128, 128, 3, 3) +decoder.up.0.block.0.conv2.bias#(128,) +decoder.up.0.block.0.nin_shortcut.weight#(128, 256, 1, 1) +decoder.up.0.block.0.nin_shortcut.bias#(128,) +decoder.up.0.block.1.norm1.gamma#(128,) +decoder.up.0.block.1.norm1.beta#(128,) +decoder.up.0.block.1.conv1.weight#(128, 128, 3, 3) +decoder.up.0.block.1.conv1.bias#(128,) +decoder.up.0.block.1.norm2.gamma#(128,) +decoder.up.0.block.1.norm2.beta#(128,) +decoder.up.0.block.1.conv2.weight#(128, 128, 3, 3) +decoder.up.0.block.1.conv2.bias#(128,) +decoder.up.0.block.2.norm1.gamma#(128,) +decoder.up.0.block.2.norm1.beta#(128,) +decoder.up.0.block.2.conv1.weight#(128, 128, 3, 3) +decoder.up.0.block.2.conv1.bias#(128,) +decoder.up.0.block.2.norm2.gamma#(128,) +decoder.up.0.block.2.norm2.beta#(128,) +decoder.up.0.block.2.conv2.weight#(128, 128, 3, 3) +decoder.up.0.block.2.conv2.bias#(128,) +decoder.up.1.block.0.norm1.gamma#(512,) +decoder.up.1.block.0.norm1.beta#(512,) +decoder.up.1.block.0.conv1.weight#(256, 512, 3, 3) +decoder.up.1.block.0.conv1.bias#(256,) +decoder.up.1.block.0.norm2.gamma#(256,) +decoder.up.1.block.0.norm2.beta#(256,) +decoder.up.1.block.0.conv2.weight#(256, 256, 3, 3) +decoder.up.1.block.0.conv2.bias#(256,) +decoder.up.1.block.0.nin_shortcut.weight#(256, 512, 1, 1) +decoder.up.1.block.0.nin_shortcut.bias#(256,) +decoder.up.1.block.1.norm1.gamma#(256,) +decoder.up.1.block.1.norm1.beta#(256,) +decoder.up.1.block.1.conv1.weight#(256, 256, 3, 3) +decoder.up.1.block.1.conv1.bias#(256,) +decoder.up.1.block.1.norm2.gamma#(256,) +decoder.up.1.block.1.norm2.beta#(256,) +decoder.up.1.block.1.conv2.weight#(256, 256, 3, 3) +decoder.up.1.block.1.conv2.bias#(256,) +decoder.up.1.block.2.norm1.gamma#(256,) 
+decoder.up.1.block.2.norm1.beta#(256,) +decoder.up.1.block.2.conv1.weight#(256, 256, 3, 3) +decoder.up.1.block.2.conv1.bias#(256,) +decoder.up.1.block.2.norm2.gamma#(256,) +decoder.up.1.block.2.norm2.beta#(256,) +decoder.up.1.block.2.conv2.weight#(256, 256, 3, 3) +decoder.up.1.block.2.conv2.bias#(256,) +decoder.up.1.upsample.conv.weight#(256, 256, 3, 3) +decoder.up.1.upsample.conv.bias#(256,) +decoder.up.2.block.0.norm1.gamma#(512,) +decoder.up.2.block.0.norm1.beta#(512,) +decoder.up.2.block.0.conv1.weight#(512, 512, 3, 3) +decoder.up.2.block.0.conv1.bias#(512,) +decoder.up.2.block.0.norm2.gamma#(512,) +decoder.up.2.block.0.norm2.beta#(512,) +decoder.up.2.block.0.conv2.weight#(512, 512, 3, 3) +decoder.up.2.block.0.conv2.bias#(512,) +decoder.up.2.block.1.norm1.gamma#(512,) +decoder.up.2.block.1.norm1.beta#(512,) +decoder.up.2.block.1.conv1.weight#(512, 512, 3, 3) +decoder.up.2.block.1.conv1.bias#(512,) +decoder.up.2.block.1.norm2.gamma#(512,) +decoder.up.2.block.1.norm2.beta#(512,) +decoder.up.2.block.1.conv2.weight#(512, 512, 3, 3) +decoder.up.2.block.1.conv2.bias#(512,) +decoder.up.2.block.2.norm1.gamma#(512,) +decoder.up.2.block.2.norm1.beta#(512,) +decoder.up.2.block.2.conv1.weight#(512, 512, 3, 3) +decoder.up.2.block.2.conv1.bias#(512,) +decoder.up.2.block.2.norm2.gamma#(512,) +decoder.up.2.block.2.norm2.beta#(512,) +decoder.up.2.block.2.conv2.weight#(512, 512, 3, 3) +decoder.up.2.block.2.conv2.bias#(512,) +decoder.up.2.upsample.conv.weight#(512, 512, 3, 3) +decoder.up.2.upsample.conv.bias#(512,) +decoder.up.3.block.0.norm1.gamma#(512,) +decoder.up.3.block.0.norm1.beta#(512,) +decoder.up.3.block.0.conv1.weight#(512, 512, 3, 3) +decoder.up.3.block.0.conv1.bias#(512,) +decoder.up.3.block.0.norm2.gamma#(512,) +decoder.up.3.block.0.norm2.beta#(512,) +decoder.up.3.block.0.conv2.weight#(512, 512, 3, 3) +decoder.up.3.block.0.conv2.bias#(512,) +decoder.up.3.block.1.norm1.gamma#(512,) +decoder.up.3.block.1.norm1.beta#(512,) +decoder.up.3.block.1.conv1.weight#(512, 512, 3, 3) +decoder.up.3.block.1.conv1.bias#(512,) +decoder.up.3.block.1.norm2.gamma#(512,) +decoder.up.3.block.1.norm2.beta#(512,) +decoder.up.3.block.1.conv2.weight#(512, 512, 3, 3) +decoder.up.3.block.1.conv2.bias#(512,) +decoder.up.3.block.2.norm1.gamma#(512,) +decoder.up.3.block.2.norm1.beta#(512,) +decoder.up.3.block.2.conv1.weight#(512, 512, 3, 3) +decoder.up.3.block.2.conv1.bias#(512,) +decoder.up.3.block.2.norm2.gamma#(512,) +decoder.up.3.block.2.norm2.beta#(512,) +decoder.up.3.block.2.conv2.weight#(512, 512, 3, 3) +decoder.up.3.block.2.conv2.bias#(512,) +decoder.up.3.upsample.conv.weight#(512, 512, 3, 3) +decoder.up.3.upsample.conv.bias#(512,) +decoder.norm_out.gamma#(128,) +decoder.norm_out.beta#(128,) +decoder.conv_out.weight#(3, 128, 3, 3) +decoder.conv_out.bias#(3,) diff --git a/examples/movie_gen/tools/ms_pnames_tae_vae.txt b/examples/movie_gen/tools/ms_pnames_tae_vae.txt new file mode 100644 index 0000000000..a9caabe583 --- /dev/null +++ b/examples/movie_gen/tools/ms_pnames_tae_vae.txt @@ -0,0 +1,244 @@ +encoder.conv_in.conv_spat.weight#(128, 3, 3, 3) +encoder.conv_in.conv_spat.bias#(128,) +encoder.down.0.block.0.norm1.gamma#(128,) +encoder.down.0.block.0.norm1.beta#(128,) +encoder.down.0.block.0.conv1.conv_spat.weight#(128, 128, 3, 3) +encoder.down.0.block.0.conv1.conv_spat.bias#(128,) +encoder.down.0.block.0.norm2.gamma#(128,) +encoder.down.0.block.0.norm2.beta#(128,) +encoder.down.0.block.0.conv2.conv_spat.weight#(128, 128, 3, 3) +encoder.down.0.block.0.conv2.conv_spat.bias#(128,) 
+encoder.down.0.block.1.norm1.gamma#(128,) +encoder.down.0.block.1.norm1.beta#(128,) +encoder.down.0.block.1.conv1.conv_spat.weight#(128, 128, 3, 3) +encoder.down.0.block.1.conv1.conv_spat.bias#(128,) +encoder.down.0.block.1.norm2.gamma#(128,) +encoder.down.0.block.1.norm2.beta#(128,) +encoder.down.0.block.1.conv2.conv_spat.weight#(128, 128, 3, 3) +encoder.down.0.block.1.conv2.conv_spat.bias#(128,) +encoder.down.0.downsample_spat.conv.weight#(128, 128, 3, 3) +encoder.down.0.downsample_spat.conv.bias#(128,) +encoder.down.1.block.0.norm1.gamma#(128,) +encoder.down.1.block.0.norm1.beta#(128,) +encoder.down.1.block.0.conv1.conv_spat.weight#(256, 128, 3, 3) +encoder.down.1.block.0.conv1.conv_spat.bias#(256,) +encoder.down.1.block.0.norm2.gamma#(256,) +encoder.down.1.block.0.norm2.beta#(256,) +encoder.down.1.block.0.conv2.conv_spat.weight#(256, 256, 3, 3) +encoder.down.1.block.0.conv2.conv_spat.bias#(256,) +encoder.down.1.block.0.nin_shortcut.conv_spat.weight#(256, 128, 1, 1) +encoder.down.1.block.0.nin_shortcut.conv_spat.bias#(256,) +encoder.down.1.block.1.norm1.gamma#(256,) +encoder.down.1.block.1.norm1.beta#(256,) +encoder.down.1.block.1.conv1.conv_spat.weight#(256, 256, 3, 3) +encoder.down.1.block.1.conv1.conv_spat.bias#(256,) +encoder.down.1.block.1.norm2.gamma#(256,) +encoder.down.1.block.1.norm2.beta#(256,) +encoder.down.1.block.1.conv2.conv_spat.weight#(256, 256, 3, 3) +encoder.down.1.block.1.conv2.conv_spat.bias#(256,) +encoder.down.1.downsample_spat.conv.weight#(256, 256, 3, 3) +encoder.down.1.downsample_spat.conv.bias#(256,) +encoder.down.2.block.0.norm1.gamma#(256,) +encoder.down.2.block.0.norm1.beta#(256,) +encoder.down.2.block.0.conv1.conv_spat.weight#(512, 256, 3, 3) +encoder.down.2.block.0.conv1.conv_spat.bias#(512,) +encoder.down.2.block.0.norm2.gamma#(512,) +encoder.down.2.block.0.norm2.beta#(512,) +encoder.down.2.block.0.conv2.conv_spat.weight#(512, 512, 3, 3) +encoder.down.2.block.0.conv2.conv_spat.bias#(512,) +encoder.down.2.block.0.nin_shortcut.conv_spat.weight#(512, 256, 1, 1) +encoder.down.2.block.0.nin_shortcut.conv_spat.bias#(512,) +encoder.down.2.block.1.norm1.gamma#(512,) +encoder.down.2.block.1.norm1.beta#(512,) +encoder.down.2.block.1.conv1.conv_spat.weight#(512, 512, 3, 3) +encoder.down.2.block.1.conv1.conv_spat.bias#(512,) +encoder.down.2.block.1.norm2.gamma#(512,) +encoder.down.2.block.1.norm2.beta#(512,) +encoder.down.2.block.1.conv2.conv_spat.weight#(512, 512, 3, 3) +encoder.down.2.block.1.conv2.conv_spat.bias#(512,) +encoder.down.2.downsample_spat.conv.weight#(512, 512, 3, 3) +encoder.down.2.downsample_spat.conv.bias#(512,) +encoder.down.3.block.0.norm1.gamma#(512,) +encoder.down.3.block.0.norm1.beta#(512,) +encoder.down.3.block.0.conv1.conv_spat.weight#(512, 512, 3, 3) +encoder.down.3.block.0.conv1.conv_spat.bias#(512,) +encoder.down.3.block.0.norm2.gamma#(512,) +encoder.down.3.block.0.norm2.beta#(512,) +encoder.down.3.block.0.conv2.conv_spat.weight#(512, 512, 3, 3) +encoder.down.3.block.0.conv2.conv_spat.bias#(512,) +encoder.down.3.block.1.norm1.gamma#(512,) +encoder.down.3.block.1.norm1.beta#(512,) +encoder.down.3.block.1.conv1.conv_spat.weight#(512, 512, 3, 3) +encoder.down.3.block.1.conv1.conv_spat.bias#(512,) +encoder.down.3.block.1.norm2.gamma#(512,) +encoder.down.3.block.1.norm2.beta#(512,) +encoder.down.3.block.1.conv2.conv_spat.weight#(512, 512, 3, 3) +encoder.down.3.block.1.conv2.conv_spat.bias#(512,) +encoder.mid.block_1.norm1.gamma#(512,) +encoder.mid.block_1.norm1.beta#(512,) +encoder.mid.block_1.conv1.conv_spat.weight#(512, 512, 3, 3) 
+encoder.mid.block_1.conv1.conv_spat.bias#(512,) +encoder.mid.block_1.norm2.gamma#(512,) +encoder.mid.block_1.norm2.beta#(512,) +encoder.mid.block_1.conv2.conv_spat.weight#(512, 512, 3, 3) +encoder.mid.block_1.conv2.conv_spat.bias#(512,) +encoder.mid.attn_1.0.norm.gamma#(512,) +encoder.mid.attn_1.0.norm.beta#(512,) +encoder.mid.attn_1.0.q.weight#(512, 512, 1, 1) +encoder.mid.attn_1.0.q.bias#(512,) +encoder.mid.attn_1.0.k.weight#(512, 512, 1, 1) +encoder.mid.attn_1.0.k.bias#(512,) +encoder.mid.attn_1.0.v.weight#(512, 512, 1, 1) +encoder.mid.attn_1.0.v.bias#(512,) +encoder.mid.attn_1.0.proj_out.weight#(512, 512, 1, 1) +encoder.mid.attn_1.0.proj_out.bias#(512,) +encoder.mid.block_2.norm1.gamma#(512,) +encoder.mid.block_2.norm1.beta#(512,) +encoder.mid.block_2.conv1.conv_spat.weight#(512, 512, 3, 3) +encoder.mid.block_2.conv1.conv_spat.bias#(512,) +encoder.mid.block_2.norm2.gamma#(512,) +encoder.mid.block_2.norm2.beta#(512,) +encoder.mid.block_2.conv2.conv_spat.weight#(512, 512, 3, 3) +encoder.mid.block_2.conv2.conv_spat.bias#(512,) +encoder.norm_out.gamma#(512,) +encoder.norm_out.beta#(512,) +encoder.conv_out.conv_spat.weight#(32, 512, 3, 3) +encoder.conv_out.conv_spat.bias#(32,) +decoder.conv_in.conv_spat.weight#(512, 16, 3, 3) +decoder.conv_in.conv_spat.bias#(512,) +decoder.mid.block_1.norm1.gamma#(512,) +decoder.mid.block_1.norm1.beta#(512,) +decoder.mid.block_1.conv1.conv_spat.weight#(512, 512, 3, 3) +decoder.mid.block_1.conv1.conv_spat.bias#(512,) +decoder.mid.block_1.norm2.gamma#(512,) +decoder.mid.block_1.norm2.beta#(512,) +decoder.mid.block_1.conv2.conv_spat.weight#(512, 512, 3, 3) +decoder.mid.block_1.conv2.conv_spat.bias#(512,) +decoder.mid.attn_1.0.norm.gamma#(512,) +decoder.mid.attn_1.0.norm.beta#(512,) +decoder.mid.attn_1.0.q.weight#(512, 512, 1, 1) +decoder.mid.attn_1.0.q.bias#(512,) +decoder.mid.attn_1.0.k.weight#(512, 512, 1, 1) +decoder.mid.attn_1.0.k.bias#(512,) +decoder.mid.attn_1.0.v.weight#(512, 512, 1, 1) +decoder.mid.attn_1.0.v.bias#(512,) +decoder.mid.attn_1.0.proj_out.weight#(512, 512, 1, 1) +decoder.mid.attn_1.0.proj_out.bias#(512,) +decoder.mid.block_2.norm1.gamma#(512,) +decoder.mid.block_2.norm1.beta#(512,) +decoder.mid.block_2.conv1.conv_spat.weight#(512, 512, 3, 3) +decoder.mid.block_2.conv1.conv_spat.bias#(512,) +decoder.mid.block_2.norm2.gamma#(512,) +decoder.mid.block_2.norm2.beta#(512,) +decoder.mid.block_2.conv2.conv_spat.weight#(512, 512, 3, 3) +decoder.mid.block_2.conv2.conv_spat.bias#(512,) +decoder.up.0.block.0.norm1.gamma#(256,) +decoder.up.0.block.0.norm1.beta#(256,) +decoder.up.0.block.0.conv1.conv_spat.weight#(128, 256, 3, 3) +decoder.up.0.block.0.conv1.conv_spat.bias#(128,) +decoder.up.0.block.0.norm2.gamma#(128,) +decoder.up.0.block.0.norm2.beta#(128,) +decoder.up.0.block.0.conv2.conv_spat.weight#(128, 128, 3, 3) +decoder.up.0.block.0.conv2.conv_spat.bias#(128,) +decoder.up.0.block.0.nin_shortcut.conv_spat.weight#(128, 256, 1, 1) +decoder.up.0.block.0.nin_shortcut.conv_spat.bias#(128,) +decoder.up.0.block.1.norm1.gamma#(128,) +decoder.up.0.block.1.norm1.beta#(128,) +decoder.up.0.block.1.conv1.conv_spat.weight#(128, 128, 3, 3) +decoder.up.0.block.1.conv1.conv_spat.bias#(128,) +decoder.up.0.block.1.norm2.gamma#(128,) +decoder.up.0.block.1.norm2.beta#(128,) +decoder.up.0.block.1.conv2.conv_spat.weight#(128, 128, 3, 3) +decoder.up.0.block.1.conv2.conv_spat.bias#(128,) +decoder.up.0.block.2.norm1.gamma#(128,) +decoder.up.0.block.2.norm1.beta#(128,) +decoder.up.0.block.2.conv1.conv_spat.weight#(128, 128, 3, 3) 
+decoder.up.0.block.2.conv1.conv_spat.bias#(128,) +decoder.up.0.block.2.norm2.gamma#(128,) +decoder.up.0.block.2.norm2.beta#(128,) +decoder.up.0.block.2.conv2.conv_spat.weight#(128, 128, 3, 3) +decoder.up.0.block.2.conv2.conv_spat.bias#(128,) +decoder.up.1.block.0.norm1.gamma#(512,) +decoder.up.1.block.0.norm1.beta#(512,) +decoder.up.1.block.0.conv1.conv_spat.weight#(256, 512, 3, 3) +decoder.up.1.block.0.conv1.conv_spat.bias#(256,) +decoder.up.1.block.0.norm2.gamma#(256,) +decoder.up.1.block.0.norm2.beta#(256,) +decoder.up.1.block.0.conv2.conv_spat.weight#(256, 256, 3, 3) +decoder.up.1.block.0.conv2.conv_spat.bias#(256,) +decoder.up.1.block.0.nin_shortcut.conv_spat.weight#(256, 512, 1, 1) +decoder.up.1.block.0.nin_shortcut.conv_spat.bias#(256,) +decoder.up.1.block.1.norm1.gamma#(256,) +decoder.up.1.block.1.norm1.beta#(256,) +decoder.up.1.block.1.conv1.conv_spat.weight#(256, 256, 3, 3) +decoder.up.1.block.1.conv1.conv_spat.bias#(256,) +decoder.up.1.block.1.norm2.gamma#(256,) +decoder.up.1.block.1.norm2.beta#(256,) +decoder.up.1.block.1.conv2.conv_spat.weight#(256, 256, 3, 3) +decoder.up.1.block.1.conv2.conv_spat.bias#(256,) +decoder.up.1.block.2.norm1.gamma#(256,) +decoder.up.1.block.2.norm1.beta#(256,) +decoder.up.1.block.2.conv1.conv_spat.weight#(256, 256, 3, 3) +decoder.up.1.block.2.conv1.conv_spat.bias#(256,) +decoder.up.1.block.2.norm2.gamma#(256,) +decoder.up.1.block.2.norm2.beta#(256,) +decoder.up.1.block.2.conv2.conv_spat.weight#(256, 256, 3, 3) +decoder.up.1.block.2.conv2.conv_spat.bias#(256,) +decoder.up.1.upsample_spat.conv.weight#(256, 256, 3, 3) +decoder.up.1.upsample_spat.conv.bias#(256,) +decoder.up.2.block.0.norm1.gamma#(512,) +decoder.up.2.block.0.norm1.beta#(512,) +decoder.up.2.block.0.conv1.conv_spat.weight#(512, 512, 3, 3) +decoder.up.2.block.0.conv1.conv_spat.bias#(512,) +decoder.up.2.block.0.norm2.gamma#(512,) +decoder.up.2.block.0.norm2.beta#(512,) +decoder.up.2.block.0.conv2.conv_spat.weight#(512, 512, 3, 3) +decoder.up.2.block.0.conv2.conv_spat.bias#(512,) +decoder.up.2.block.1.norm1.gamma#(512,) +decoder.up.2.block.1.norm1.beta#(512,) +decoder.up.2.block.1.conv1.conv_spat.weight#(512, 512, 3, 3) +decoder.up.2.block.1.conv1.conv_spat.bias#(512,) +decoder.up.2.block.1.norm2.gamma#(512,) +decoder.up.2.block.1.norm2.beta#(512,) +decoder.up.2.block.1.conv2.conv_spat.weight#(512, 512, 3, 3) +decoder.up.2.block.1.conv2.conv_spat.bias#(512,) +decoder.up.2.block.2.norm1.gamma#(512,) +decoder.up.2.block.2.norm1.beta#(512,) +decoder.up.2.block.2.conv1.conv_spat.weight#(512, 512, 3, 3) +decoder.up.2.block.2.conv1.conv_spat.bias#(512,) +decoder.up.2.block.2.norm2.gamma#(512,) +decoder.up.2.block.2.norm2.beta#(512,) +decoder.up.2.block.2.conv2.conv_spat.weight#(512, 512, 3, 3) +decoder.up.2.block.2.conv2.conv_spat.bias#(512,) +decoder.up.2.upsample_spat.conv.weight#(512, 512, 3, 3) +decoder.up.2.upsample_spat.conv.bias#(512,) +decoder.up.3.block.0.norm1.gamma#(512,) +decoder.up.3.block.0.norm1.beta#(512,) +decoder.up.3.block.0.conv1.conv_spat.weight#(512, 512, 3, 3) +decoder.up.3.block.0.conv1.conv_spat.bias#(512,) +decoder.up.3.block.0.norm2.gamma#(512,) +decoder.up.3.block.0.norm2.beta#(512,) +decoder.up.3.block.0.conv2.conv_spat.weight#(512, 512, 3, 3) +decoder.up.3.block.0.conv2.conv_spat.bias#(512,) +decoder.up.3.block.1.norm1.gamma#(512,) +decoder.up.3.block.1.norm1.beta#(512,) +decoder.up.3.block.1.conv1.conv_spat.weight#(512, 512, 3, 3) +decoder.up.3.block.1.conv1.conv_spat.bias#(512,) +decoder.up.3.block.1.norm2.gamma#(512,) +decoder.up.3.block.1.norm2.beta#(512,) 
+decoder.up.3.block.1.conv2.conv_spat.weight#(512, 512, 3, 3) +decoder.up.3.block.1.conv2.conv_spat.bias#(512,) +decoder.up.3.block.2.norm1.gamma#(512,) +decoder.up.3.block.2.norm1.beta#(512,) +decoder.up.3.block.2.conv1.conv_spat.weight#(512, 512, 3, 3) +decoder.up.3.block.2.conv1.conv_spat.bias#(512,) +decoder.up.3.block.2.norm2.gamma#(512,) +decoder.up.3.block.2.norm2.beta#(512,) +decoder.up.3.block.2.conv2.conv_spat.weight#(512, 512, 3, 3) +decoder.up.3.block.2.conv2.conv_spat.bias#(512,) +decoder.up.3.upsample_spat.conv.weight#(512, 512, 3, 3) +decoder.up.3.upsample_spat.conv.bias#(512,) +decoder.norm_out.gamma#(128,) +decoder.norm_out.beta#(128,) +decoder.conv_out.conv_spat.weight#(3, 128, 3, 3) +decoder.conv_out.conv_spat.bias#(3,) diff --git a/examples/movie_gen/tools/pt_pnames_sd3.5_vae.txt b/examples/movie_gen/tools/pt_pnames_sd3.5_vae.txt new file mode 100644 index 0000000000..56fe5e8b88 --- /dev/null +++ b/examples/movie_gen/tools/pt_pnames_sd3.5_vae.txt @@ -0,0 +1,244 @@ +encoder.conv_in.weight#(128, 3, 3, 3) +encoder.conv_in.bias#(128,) +encoder.down_blocks.0.resnets.0.norm1.weight#(128,) +encoder.down_blocks.0.resnets.0.norm1.bias#(128,) +encoder.down_blocks.0.resnets.0.conv1.weight#(128, 128, 3, 3) +encoder.down_blocks.0.resnets.0.conv1.bias#(128,) +encoder.down_blocks.0.resnets.0.norm2.weight#(128,) +encoder.down_blocks.0.resnets.0.norm2.bias#(128,) +encoder.down_blocks.0.resnets.0.conv2.weight#(128, 128, 3, 3) +encoder.down_blocks.0.resnets.0.conv2.bias#(128,) +encoder.down_blocks.0.resnets.1.norm1.weight#(128,) +encoder.down_blocks.0.resnets.1.norm1.bias#(128,) +encoder.down_blocks.0.resnets.1.conv1.weight#(128, 128, 3, 3) +encoder.down_blocks.0.resnets.1.conv1.bias#(128,) +encoder.down_blocks.0.resnets.1.norm2.weight#(128,) +encoder.down_blocks.0.resnets.1.norm2.bias#(128,) +encoder.down_blocks.0.resnets.1.conv2.weight#(128, 128, 3, 3) +encoder.down_blocks.0.resnets.1.conv2.bias#(128,) +encoder.down_blocks.0.downsamplers.0.conv.weight#(128, 128, 3, 3) +encoder.down_blocks.0.downsamplers.0.conv.bias#(128,) +encoder.down_blocks.1.resnets.0.norm1.weight#(128,) +encoder.down_blocks.1.resnets.0.norm1.bias#(128,) +encoder.down_blocks.1.resnets.0.conv1.weight#(256, 128, 3, 3) +encoder.down_blocks.1.resnets.0.conv1.bias#(256,) +encoder.down_blocks.1.resnets.0.norm2.weight#(256,) +encoder.down_blocks.1.resnets.0.norm2.bias#(256,) +encoder.down_blocks.1.resnets.0.conv2.weight#(256, 256, 3, 3) +encoder.down_blocks.1.resnets.0.conv2.bias#(256,) +encoder.down_blocks.1.resnets.0.conv_shortcut.weight#(256, 128, 1, 1) +encoder.down_blocks.1.resnets.0.conv_shortcut.bias#(256,) +encoder.down_blocks.1.resnets.1.norm1.weight#(256,) +encoder.down_blocks.1.resnets.1.norm1.bias#(256,) +encoder.down_blocks.1.resnets.1.conv1.weight#(256, 256, 3, 3) +encoder.down_blocks.1.resnets.1.conv1.bias#(256,) +encoder.down_blocks.1.resnets.1.norm2.weight#(256,) +encoder.down_blocks.1.resnets.1.norm2.bias#(256,) +encoder.down_blocks.1.resnets.1.conv2.weight#(256, 256, 3, 3) +encoder.down_blocks.1.resnets.1.conv2.bias#(256,) +encoder.down_blocks.1.downsamplers.0.conv.weight#(256, 256, 3, 3) +encoder.down_blocks.1.downsamplers.0.conv.bias#(256,) +encoder.down_blocks.2.resnets.0.norm1.weight#(256,) +encoder.down_blocks.2.resnets.0.norm1.bias#(256,) +encoder.down_blocks.2.resnets.0.conv1.weight#(512, 256, 3, 3) +encoder.down_blocks.2.resnets.0.conv1.bias#(512,) +encoder.down_blocks.2.resnets.0.norm2.weight#(512,) +encoder.down_blocks.2.resnets.0.norm2.bias#(512,) 
+encoder.down_blocks.2.resnets.0.conv2.weight#(512, 512, 3, 3) +encoder.down_blocks.2.resnets.0.conv2.bias#(512,) +encoder.down_blocks.2.resnets.0.conv_shortcut.weight#(512, 256, 1, 1) +encoder.down_blocks.2.resnets.0.conv_shortcut.bias#(512,) +encoder.down_blocks.2.resnets.1.norm1.weight#(512,) +encoder.down_blocks.2.resnets.1.norm1.bias#(512,) +encoder.down_blocks.2.resnets.1.conv1.weight#(512, 512, 3, 3) +encoder.down_blocks.2.resnets.1.conv1.bias#(512,) +encoder.down_blocks.2.resnets.1.norm2.weight#(512,) +encoder.down_blocks.2.resnets.1.norm2.bias#(512,) +encoder.down_blocks.2.resnets.1.conv2.weight#(512, 512, 3, 3) +encoder.down_blocks.2.resnets.1.conv2.bias#(512,) +encoder.down_blocks.2.downsamplers.0.conv.weight#(512, 512, 3, 3) +encoder.down_blocks.2.downsamplers.0.conv.bias#(512,) +encoder.down_blocks.3.resnets.0.norm1.weight#(512,) +encoder.down_blocks.3.resnets.0.norm1.bias#(512,) +encoder.down_blocks.3.resnets.0.conv1.weight#(512, 512, 3, 3) +encoder.down_blocks.3.resnets.0.conv1.bias#(512,) +encoder.down_blocks.3.resnets.0.norm2.weight#(512,) +encoder.down_blocks.3.resnets.0.norm2.bias#(512,) +encoder.down_blocks.3.resnets.0.conv2.weight#(512, 512, 3, 3) +encoder.down_blocks.3.resnets.0.conv2.bias#(512,) +encoder.down_blocks.3.resnets.1.norm1.weight#(512,) +encoder.down_blocks.3.resnets.1.norm1.bias#(512,) +encoder.down_blocks.3.resnets.1.conv1.weight#(512, 512, 3, 3) +encoder.down_blocks.3.resnets.1.conv1.bias#(512,) +encoder.down_blocks.3.resnets.1.norm2.weight#(512,) +encoder.down_blocks.3.resnets.1.norm2.bias#(512,) +encoder.down_blocks.3.resnets.1.conv2.weight#(512, 512, 3, 3) +encoder.down_blocks.3.resnets.1.conv2.bias#(512,) +encoder.mid_block.resnets.0.norm1.weight#(512,) +encoder.mid_block.resnets.0.norm1.bias#(512,) +encoder.mid_block.resnets.0.conv1.weight#(512, 512, 3, 3) +encoder.mid_block.resnets.0.conv1.bias#(512,) +encoder.mid_block.resnets.0.norm2.weight#(512,) +encoder.mid_block.resnets.0.norm2.bias#(512,) +encoder.mid_block.resnets.0.conv2.weight#(512, 512, 3, 3) +encoder.mid_block.resnets.0.conv2.bias#(512,) +encoder.mid_block.attentions.0.group_norm.weight#(512,) +encoder.mid_block.attentions.0.group_norm.bias#(512,) +encoder.mid_block.attentions.0.to_q.weight#(512, 512) +encoder.mid_block.attentions.0.to_q.bias#(512,) +encoder.mid_block.attentions.0.to_k.weight#(512, 512) +encoder.mid_block.attentions.0.to_k.bias#(512,) +encoder.mid_block.attentions.0.to_v.weight#(512, 512) +encoder.mid_block.attentions.0.to_v.bias#(512,) +encoder.mid_block.attentions.0.to_out.0.weight#(512, 512) +encoder.mid_block.attentions.0.to_out.0.bias#(512,) +encoder.mid_block.resnets.1.norm1.weight#(512,) +encoder.mid_block.resnets.1.norm1.bias#(512,) +encoder.mid_block.resnets.1.conv1.weight#(512, 512, 3, 3) +encoder.mid_block.resnets.1.conv1.bias#(512,) +encoder.mid_block.resnets.1.norm2.weight#(512,) +encoder.mid_block.resnets.1.norm2.bias#(512,) +encoder.mid_block.resnets.1.conv2.weight#(512, 512, 3, 3) +encoder.mid_block.resnets.1.conv2.bias#(512,) +encoder.conv_norm_out.weight#(512,) +encoder.conv_norm_out.bias#(512,) +encoder.conv_out.weight#(32, 512, 3, 3) +encoder.conv_out.bias#(32,) +decoder.conv_in.weight#(512, 16, 3, 3) +decoder.conv_in.bias#(512,) +decoder.mid_block.resnets.0.norm1.weight#(512,) +decoder.mid_block.resnets.0.norm1.bias#(512,) +decoder.mid_block.resnets.0.conv1.weight#(512, 512, 3, 3) +decoder.mid_block.resnets.0.conv1.bias#(512,) +decoder.mid_block.resnets.0.norm2.weight#(512,) +decoder.mid_block.resnets.0.norm2.bias#(512,) 
+decoder.mid_block.resnets.0.conv2.weight#(512, 512, 3, 3) +decoder.mid_block.resnets.0.conv2.bias#(512,) +decoder.mid_block.attentions.0.group_norm.weight#(512,) +decoder.mid_block.attentions.0.group_norm.bias#(512,) +decoder.mid_block.attentions.0.to_q.weight#(512, 512) +decoder.mid_block.attentions.0.to_q.bias#(512,) +decoder.mid_block.attentions.0.to_k.weight#(512, 512) +decoder.mid_block.attentions.0.to_k.bias#(512,) +decoder.mid_block.attentions.0.to_v.weight#(512, 512) +decoder.mid_block.attentions.0.to_v.bias#(512,) +decoder.mid_block.attentions.0.to_out.0.weight#(512, 512) +decoder.mid_block.attentions.0.to_out.0.bias#(512,) +decoder.mid_block.resnets.1.norm1.weight#(512,) +decoder.mid_block.resnets.1.norm1.bias#(512,) +decoder.mid_block.resnets.1.conv1.weight#(512, 512, 3, 3) +decoder.mid_block.resnets.1.conv1.bias#(512,) +decoder.mid_block.resnets.1.norm2.weight#(512,) +decoder.mid_block.resnets.1.norm2.bias#(512,) +decoder.mid_block.resnets.1.conv2.weight#(512, 512, 3, 3) +decoder.mid_block.resnets.1.conv2.bias#(512,) +decoder.up_blocks.3.resnets.0.norm1.weight#(256,) +decoder.up_blocks.3.resnets.0.norm1.bias#(256,) +decoder.up_blocks.3.resnets.0.conv1.weight#(128, 256, 3, 3) +decoder.up_blocks.3.resnets.0.conv1.bias#(128,) +decoder.up_blocks.3.resnets.0.norm2.weight#(128,) +decoder.up_blocks.3.resnets.0.norm2.bias#(128,) +decoder.up_blocks.3.resnets.0.conv2.weight#(128, 128, 3, 3) +decoder.up_blocks.3.resnets.0.conv2.bias#(128,) +decoder.up_blocks.3.resnets.0.conv_shortcut.weight#(128, 256, 1, 1) +decoder.up_blocks.3.resnets.0.conv_shortcut.bias#(128,) +decoder.up_blocks.3.resnets.1.norm1.weight#(128,) +decoder.up_blocks.3.resnets.1.norm1.bias#(128,) +decoder.up_blocks.3.resnets.1.conv1.weight#(128, 128, 3, 3) +decoder.up_blocks.3.resnets.1.conv1.bias#(128,) +decoder.up_blocks.3.resnets.1.norm2.weight#(128,) +decoder.up_blocks.3.resnets.1.norm2.bias#(128,) +decoder.up_blocks.3.resnets.1.conv2.weight#(128, 128, 3, 3) +decoder.up_blocks.3.resnets.1.conv2.bias#(128,) +decoder.up_blocks.3.resnets.2.norm1.weight#(128,) +decoder.up_blocks.3.resnets.2.norm1.bias#(128,) +decoder.up_blocks.3.resnets.2.conv1.weight#(128, 128, 3, 3) +decoder.up_blocks.3.resnets.2.conv1.bias#(128,) +decoder.up_blocks.3.resnets.2.norm2.weight#(128,) +decoder.up_blocks.3.resnets.2.norm2.bias#(128,) +decoder.up_blocks.3.resnets.2.conv2.weight#(128, 128, 3, 3) +decoder.up_blocks.3.resnets.2.conv2.bias#(128,) +decoder.up_blocks.2.resnets.0.norm1.weight#(512,) +decoder.up_blocks.2.resnets.0.norm1.bias#(512,) +decoder.up_blocks.2.resnets.0.conv1.weight#(256, 512, 3, 3) +decoder.up_blocks.2.resnets.0.conv1.bias#(256,) +decoder.up_blocks.2.resnets.0.norm2.weight#(256,) +decoder.up_blocks.2.resnets.0.norm2.bias#(256,) +decoder.up_blocks.2.resnets.0.conv2.weight#(256, 256, 3, 3) +decoder.up_blocks.2.resnets.0.conv2.bias#(256,) +decoder.up_blocks.2.resnets.0.conv_shortcut.weight#(256, 512, 1, 1) +decoder.up_blocks.2.resnets.0.conv_shortcut.bias#(256,) +decoder.up_blocks.2.resnets.1.norm1.weight#(256,) +decoder.up_blocks.2.resnets.1.norm1.bias#(256,) +decoder.up_blocks.2.resnets.1.conv1.weight#(256, 256, 3, 3) +decoder.up_blocks.2.resnets.1.conv1.bias#(256,) +decoder.up_blocks.2.resnets.1.norm2.weight#(256,) +decoder.up_blocks.2.resnets.1.norm2.bias#(256,) +decoder.up_blocks.2.resnets.1.conv2.weight#(256, 256, 3, 3) +decoder.up_blocks.2.resnets.1.conv2.bias#(256,) +decoder.up_blocks.2.resnets.2.norm1.weight#(256,) +decoder.up_blocks.2.resnets.2.norm1.bias#(256,) +decoder.up_blocks.2.resnets.2.conv1.weight#(256, 256, 
3, 3) +decoder.up_blocks.2.resnets.2.conv1.bias#(256,) +decoder.up_blocks.2.resnets.2.norm2.weight#(256,) +decoder.up_blocks.2.resnets.2.norm2.bias#(256,) +decoder.up_blocks.2.resnets.2.conv2.weight#(256, 256, 3, 3) +decoder.up_blocks.2.resnets.2.conv2.bias#(256,) +decoder.up_blocks.2.upsamplers.0.conv.weight#(256, 256, 3, 3) +decoder.up_blocks.2.upsamplers.0.conv.bias#(256,) +decoder.up_blocks.1.resnets.0.norm1.weight#(512,) +decoder.up_blocks.1.resnets.0.norm1.bias#(512,) +decoder.up_blocks.1.resnets.0.conv1.weight#(512, 512, 3, 3) +decoder.up_blocks.1.resnets.0.conv1.bias#(512,) +decoder.up_blocks.1.resnets.0.norm2.weight#(512,) +decoder.up_blocks.1.resnets.0.norm2.bias#(512,) +decoder.up_blocks.1.resnets.0.conv2.weight#(512, 512, 3, 3) +decoder.up_blocks.1.resnets.0.conv2.bias#(512,) +decoder.up_blocks.1.resnets.1.norm1.weight#(512,) +decoder.up_blocks.1.resnets.1.norm1.bias#(512,) +decoder.up_blocks.1.resnets.1.conv1.weight#(512, 512, 3, 3) +decoder.up_blocks.1.resnets.1.conv1.bias#(512,) +decoder.up_blocks.1.resnets.1.norm2.weight#(512,) +decoder.up_blocks.1.resnets.1.norm2.bias#(512,) +decoder.up_blocks.1.resnets.1.conv2.weight#(512, 512, 3, 3) +decoder.up_blocks.1.resnets.1.conv2.bias#(512,) +decoder.up_blocks.1.resnets.2.norm1.weight#(512,) +decoder.up_blocks.1.resnets.2.norm1.bias#(512,) +decoder.up_blocks.1.resnets.2.conv1.weight#(512, 512, 3, 3) +decoder.up_blocks.1.resnets.2.conv1.bias#(512,) +decoder.up_blocks.1.resnets.2.norm2.weight#(512,) +decoder.up_blocks.1.resnets.2.norm2.bias#(512,) +decoder.up_blocks.1.resnets.2.conv2.weight#(512, 512, 3, 3) +decoder.up_blocks.1.resnets.2.conv2.bias#(512,) +decoder.up_blocks.1.upsamplers.0.conv.weight#(512, 512, 3, 3) +decoder.up_blocks.1.upsamplers.0.conv.bias#(512,) +decoder.up_blocks.0.resnets.0.norm1.weight#(512,) +decoder.up_blocks.0.resnets.0.norm1.bias#(512,) +decoder.up_blocks.0.resnets.0.conv1.weight#(512, 512, 3, 3) +decoder.up_blocks.0.resnets.0.conv1.bias#(512,) +decoder.up_blocks.0.resnets.0.norm2.weight#(512,) +decoder.up_blocks.0.resnets.0.norm2.bias#(512,) +decoder.up_blocks.0.resnets.0.conv2.weight#(512, 512, 3, 3) +decoder.up_blocks.0.resnets.0.conv2.bias#(512,) +decoder.up_blocks.0.resnets.1.norm1.weight#(512,) +decoder.up_blocks.0.resnets.1.norm1.bias#(512,) +decoder.up_blocks.0.resnets.1.conv1.weight#(512, 512, 3, 3) +decoder.up_blocks.0.resnets.1.conv1.bias#(512,) +decoder.up_blocks.0.resnets.1.norm2.weight#(512,) +decoder.up_blocks.0.resnets.1.norm2.bias#(512,) +decoder.up_blocks.0.resnets.1.conv2.weight#(512, 512, 3, 3) +decoder.up_blocks.0.resnets.1.conv2.bias#(512,) +decoder.up_blocks.0.resnets.2.norm1.weight#(512,) +decoder.up_blocks.0.resnets.2.norm1.bias#(512,) +decoder.up_blocks.0.resnets.2.conv1.weight#(512, 512, 3, 3) +decoder.up_blocks.0.resnets.2.conv1.bias#(512,) +decoder.up_blocks.0.resnets.2.norm2.weight#(512,) +decoder.up_blocks.0.resnets.2.norm2.bias#(512,) +decoder.up_blocks.0.resnets.2.conv2.weight#(512, 512, 3, 3) +decoder.up_blocks.0.resnets.2.conv2.bias#(512,) +decoder.up_blocks.0.upsamplers.0.conv.weight#(512, 512, 3, 3) +decoder.up_blocks.0.upsamplers.0.conv.bias#(512,) +decoder.conv_norm_out.weight#(128,) +decoder.conv_norm_out.bias#(128,) +decoder.conv_out.weight#(3, 128, 3, 3) +decoder.conv_out.bias#(3,) From 0f75248d6cd2edf0e464f22ad827d58ff163354b Mon Sep 17 00:00:00 2001 From: Samit <285365963@qq.com> Date: Wed, 13 Nov 2024 16:44:38 +0800 Subject: [PATCH 041/122] update convert script --- examples/movie_gen/tools/inflate_sd3.5_vae.py | 24 ++++++++++++++++--- 1 file changed, 21 
insertions(+), 3 deletions(-) diff --git a/examples/movie_gen/tools/inflate_sd3.5_vae.py b/examples/movie_gen/tools/inflate_sd3.5_vae.py index 2da89b1498..f09bf4a2a4 100644 --- a/examples/movie_gen/tools/inflate_sd3.5_vae.py +++ b/examples/movie_gen/tools/inflate_sd3.5_vae.py @@ -1,4 +1,5 @@ from safetensors import safe_open +import argparse import os import numpy as np import mindspore as ms @@ -79,9 +80,26 @@ def convert_vae2d(source_fp, target_fp, target_model='vae2d'): if __name__ == "__main__": - ckpt_path = "/Users/Samit/Downloads/sd3.5_vae/diffusion_pytorch_model.safetensors" + parser = argparse.ArgumentParser() + parser.add_argument( + "--src", + "-s", + type=str, + help="path to vae torch checkpoint", + ) + parser.add_argument( + "--target", + "-t", + type=str, + default='models/tae_vae2d.ckpt', + help="Filename to save. Specify folder, e.g., ./models, or file path which ends with .ckpt, e.g., ./models/vae.ckpt", + ) + args = parser.parse_args() + + # ckpt_path = "/Users/Samit/Downloads/sd3.5_vae/diffusion_pytorch_model.safetensors" # get_pname_shape(ckpt_path) + # plot_ms_vae2d5() + # convert_vae2d(ckpt_path, "models/sd3.5_vae.ckpt") - convert_vae2d(ckpt_path, "models/tae_vae2d.ckpt", target_model='tae') + convert_vae2d(args.src, args.target, target_model='tae') - # plot_ms_vae2d5() From 82ed8f7acacb60b007500576f2c4415459352871 Mon Sep 17 00:00:00 2001 From: Samit <285365963@qq.com> Date: Wed, 13 Nov 2024 16:47:28 +0800 Subject: [PATCH 042/122] add sd3 vae --- examples/movie_gen/mg/models/tae/sd3_vae.py | 143 ++++++++++++++++++++ 1 file changed, 143 insertions(+) create mode 100644 examples/movie_gen/mg/models/tae/sd3_vae.py diff --git a/examples/movie_gen/mg/models/tae/sd3_vae.py b/examples/movie_gen/mg/models/tae/sd3_vae.py new file mode 100644 index 0000000000..37e337e5c5 --- /dev/null +++ b/examples/movie_gen/mg/models/tae/sd3_vae.py @@ -0,0 +1,143 @@ +import mindspore as ms +from mindspore import nn, ops +from .modules_2d import Encoder, Decoder + +# TODO: set z_channels to 16 +SD3d5_CONFIG = { + "double_z": True, + "z_channels": 16, + "resolution": 256, + "in_channels": 3, + "out_ch": 3, + "ch": 128, + "ch_mult": [1, 2, 4, 4], + "num_res_blocks": 2, + "attn_resolutions": [], + "dropout": 0.0, + "scaling_factor": 1.5305, + "shift_factor": 0.0609, + "use_post_quant_conv": False, + "use_quant_conv": False +} + + +class SD3d5_VAE(nn.Cell): + r""" + TAE + + Parameters: + config (`dict`): config dict + pretrained (`str`): checkpoint path + """ + + def __init__( + self, + config: dict = SD3d5_CONFIG, + pretrained: str = None, + use_recompute: bool=False, + sample_deterministic: bool=False, + ): + super().__init__() + + # encoder + self.encoder = Encoder(**config) + + # quant and post quant + embed_dim = config['z_channels'] + if config['use_quant_conv']: + self.quant_conv = nn.Conv2d(2 * embed_dim, 2 * embed_dim, 1, pad_mode="valid", has_bias=True) + if config['use_post_quant_conv']: + self.post_quant_conv = nn.Conv2d(embed_dim, embed_dim, 1, pad_mode="valid", has_bias=True) + + self.use_quant_conv = config['use_quant_conv'] + self.use_post_quant_conv = config['use_post_quant_conv'] + + # decoder + self.decoder = Decoder(**config) + + self.exp = ops.Exp() + self.stdnormal = ops.StandardNormal() + self.split = ms.ops.split + + self.sample_deterministic = sample_deterministic + + if use_recompute: + # self.recompute(self.encoder) + # self.recompute(self.quant_conv) + # self.recompute(self.post_quant_conv) + self.recompute(self.decoder) + + + def recompute(self, b): + if not 
b._has_config_recompute: + b.recompute() + if isinstance(b, nn.CellList): + self.recompute(b[-1]) + else: + b.add_flags(output_no_recompute=True) + + + def _encode(self, x): + # return latent distribution, N(mean, logvar) + h = self.encoder(x) + if self.use_quant_conv: + moments = self.quant_conv(h) + else: + moments = h + mean, logvar = self.split(moments, moments.shape[1] // 2, 1) + + return mean, logvar + + def sample(self, mean, logvar): + # sample z from latent distribution + logvar = ops.clip_by_value(logvar, -30.0, 20.0) + std = self.exp(0.5 * logvar) + z = mean + std * self.stdnormal(mean.shape) + + return z + + def encode(self, x: ms.Tensor) -> ms.Tensor: + # embedding, get latent representation z + posterior_mean, posterior_logvar = self._encode(x) + if self.sample_deterministic: + return posterior_mean + z = self.sample(posterior_mean, posterior_logvar) + + return z + + def decode(self, z: ms.Tensor) -> ms.Tensor: + if self.use_post_quant_conv: + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def construct(self, x: ms.Tensor) -> ms.Tensor: + """ + video reconstruction + + x: (b c h w) + """ + + posterior_mean, posterior_logvar = self._encode(x) + z = self.sample(posterior_mean, posterior_logvar) + recons = self.decode(z) + + return recons, z, posterior_mean, posterior_logvar + + def load_pretrained(self, ckpt_path:str): + if ckpt_path.endswith('safetensors'): + # load vae parameters from safetensors into my mindspore model + import safetensors + ckpt = safetensors.safe_open(ckpt_path, framework="pt") + state_dict = {} + for key in ckpt.keys(): + state_dict[key] = ckpt.get_tensor(key) + raise NotImplementedError + else: + param_dict = ms.load_checkpoint(ckpt_path) + param_not_load, ckpt_not_load = ms.load_param_into_net(self, param_dict, strict_load=True) + if param_not_load or ckpt_not_load: + print(f"{param_not_load} in network is not loaded or {ckpt_not_load} in checkpoint is not loaded!") + print('vae checkpoint loaded') + + From 9c0512f583eebd5eedb7816c3dde24cb2cfcfc42 Mon Sep 17 00:00:00 2001 From: Samit <285365963@qq.com> Date: Wed, 13 Nov 2024 16:48:51 +0800 Subject: [PATCH 043/122] add moduels for sd3 vae --- .../movie_gen/mg/models/tae/modules_2d.py | 423 ++++++++++++++++++ 1 file changed, 423 insertions(+) create mode 100644 examples/movie_gen/mg/models/tae/modules_2d.py diff --git a/examples/movie_gen/mg/models/tae/modules_2d.py b/examples/movie_gen/mg/models/tae/modules_2d.py new file mode 100644 index 0000000000..b98a7ca45a --- /dev/null +++ b/examples/movie_gen/mg/models/tae/modules_2d.py @@ -0,0 +1,423 @@ +import logging + +import numpy as np + +# import mindspore as ms +from mindspore import nn, ops + +_logger = logging.getLogger(__name__) + + +def divisible_by(num, den): + return (num % den) == 0 + + +def is_odd(n): + return not divisible_by(n, 2) + + +def cast_tuple(t, length=1): + return t if isinstance(t, tuple) else ((t,) * length) + + +def nonlinearity(x): + return x * (ops.sigmoid(x)) + + +def Normalize(in_channels, num_groups=32): + return nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + + +class Upsample(nn.Cell): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=1, pad_mode="pad", padding=1, has_bias=True + ) + + def construct(self, x): + in_shape = x.shape[-2:] + out_shape = tuple(2 * x for x in in_shape) + x = ops.ResizeNearestNeighbor(out_shape)(x) + + 
if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Cell): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=2, pad_mode="valid", padding=0, has_bias=True + ) + + def construct(self, x): + if self.with_conv: + pad = ((0, 0), (0, 0), (0, 1), (0, 1)) + x = nn.Pad(paddings=pad)(x) + x = self.conv(x) + else: + x = ops.AvgPool(kernel_size=2, stride=2)(x) + return x + + +# used in vae +class ResnetBlock(nn.Cell): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout, + temb_channels=512, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, pad_mode="pad", padding=1, has_bias=True + ) + if temb_channels > 0: + self.temb_proj = nn.Dense(temb_channels, out_channels, bias_init="normal") + self.norm2 = Normalize(out_channels) + self.dropout = nn.Dropout(p=dropout) + self.conv2 = nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, pad_mode="pad", padding=1, has_bias=True + ) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, pad_mode="pad", padding=1, has_bias=True + ) + else: + self.nin_shortcut = nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, pad_mode="valid", has_bias=True + ) + + def construct(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h + + +class AttnBlock(nn.Cell): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + self.bmm = ops.BatchMatMul() + self.norm = Normalize(in_channels) + self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, pad_mode="valid", has_bias=True) + self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, pad_mode="valid", has_bias=True) + self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, pad_mode="valid", has_bias=True) + self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, pad_mode="valid", has_bias=True) + + def construct(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = ops.reshape(q, (b, c, h * w)) + q = ops.transpose(q, (0, 2, 1)) # b,hw,c + k = ops.reshape(k, (b, c, h * w)) # b,c,hw + w_ = self.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + + w_ = w_ * (int(c) ** (-0.5)) + w_ = ops.Softmax(axis=2)(w_) + + # attend to values + v = ops.reshape(v, (b, c, h * w)) + w_ = ops.transpose(w_, (0, 2, 1)) # b,hw,hw (first hw of k, second of q) + h_ = self.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = ops.reshape(h_, (b, c, h, w)) + + h_ = self.proj_out(h_) + + return x + h_ + + +def make_attn(in_channels, 
attn_type="vanilla"): + # assert attn_type in ["vanilla", "vanilla3D"], f"attn_type {attn_type} not supported" + _logger.debug(f"making attention of type '{attn_type}' with {in_channels} in_channels") + if attn_type == "vanilla": + return AttnBlock(in_channels) + else: + raise NotImplementedError + + +# used in vae +class Encoder(nn.Cell): + # @ms.lazy_inline() + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + double_z=True, + use_linear_attn=False, + attn_type="vanilla", + **ignore_kwargs, + ): + super().__init__() + # if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = nn.Conv2d( + in_channels, self.ch, kernel_size=3, stride=1, pad_mode="pad", padding=1, has_bias=True + ) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.CellList(auto_prefix=False) + for i_level in range(self.num_resolutions): + block = nn.CellList() + attn = nn.CellList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Cell() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv) + else: + down.downsample = nn.Identity() + curr_res = curr_res // 2 + down.update_parameters_name(prefix=self.param_prefix + f"down.{i_level}.") + self.down.append(down) + + # middle + self.mid = nn.Cell() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = nn.Conv2d( + block_in, + 2 * z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + pad_mode="pad", + padding=1, + has_bias=True, + ) + + def construct(self, x): + # timestep embedding + temb = None + + # downsampling + hs = self.conv_in(x) + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs, temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs = h + if i_level != self.num_resolutions - 1: + hs = self.down[i_level].downsample(hs) + + # middle + h = hs + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Cell): + # @ms.lazy_inline() + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + give_pre_end=False, + tanh_out=False, + use_linear_attn=False, + attn_type="vanilla", + **ignorekwargs, + ): + super().__init__() + # if 
use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + + # compute in_ch_mult, block_in and curr_res at lowest res + # in_ch_mult = (1,) + tuple(ch_mult) + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + _logger.debug("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = nn.Conv2d( + z_channels, block_in, kernel_size=3, stride=1, pad_mode="pad", padding=1, has_bias=True + ) + + # middle + self.mid = nn.Cell() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) + + # upsampling + self.up = nn.CellList(auto_prefix=False) + for i_level in reversed(range(self.num_resolutions)): + block = nn.CellList() + attn = nn.CellList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Cell() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + else: + up.upsample = nn.Identity() + curr_res = curr_res * 2 + up.update_parameters_name(prefix=self.param_prefix + f"up.{i_level}.") + if len(self.up) != 0: + self.up.insert(0, up) + else: + self.up.append(up) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, pad_mode="pad", padding=1, has_bias=True) + + def construct(self, z): + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + i_level = self.num_resolutions + while i_level > 0: + i_level -= 1 + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + if self.tanh_out: + h = ops.tanh(h) + return h From a6718167ac0a53297d37695fc09cc0c5820310bc Mon Sep 17 00:00:00 2001 From: Rustam Khadipash <16683750+hadipash@users.noreply.github.com> Date: Wed, 13 Nov 2024 12:03:36 +0800 Subject: [PATCH 044/122] update configs --- .../inference/moviegen_t2i_256x256.yaml | 2 +- ...i_256x256.yaml => stage1_t2i_256x256.yaml} | 12 ++-- .../configs/train/stage2_t2iv_256x256.yaml | 72 +++++++++++++++++++ examples/moviegen/inference_text_enc.py | 7 +- .../moviegen/moviegen/models/llama/network.py | 2 +- .../models/text_encoders/text_projector.py | 34 ++++----- .../moviegen/schedulers/rectified_flow.py | 2 +- .../{train_t2i_256x256.sh => stage1_train.sh} | 11 ++- examples/moviegen/scripts/stage2_train.sh | 22 ++++++ 9 files changed, 
125 insertions(+), 39 deletions(-) rename examples/moviegen/configs/train/{moviegen_t2i_256x256.yaml => stage1_t2i_256x256.yaml} (89%) create mode 100644 examples/moviegen/configs/train/stage2_t2iv_256x256.yaml rename examples/moviegen/scripts/{train_t2i_256x256.sh => stage1_train.sh} (67%) create mode 100644 examples/moviegen/scripts/stage2_train.sh diff --git a/examples/moviegen/configs/inference/moviegen_t2i_256x256.yaml b/examples/moviegen/configs/inference/moviegen_t2i_256x256.yaml index 18e80a066e..37a380854f 100644 --- a/examples/moviegen/configs/inference/moviegen_t2i_256x256.yaml +++ b/examples/moviegen/configs/inference/moviegen_t2i_256x256.yaml @@ -6,7 +6,7 @@ env: debug: False model: - name: llama-1B + name: llama-5B pretrained_model_path: enable_flash_attention: True dtype: bf16 diff --git a/examples/moviegen/configs/train/moviegen_t2i_256x256.yaml b/examples/moviegen/configs/train/stage1_t2i_256x256.yaml similarity index 89% rename from examples/moviegen/configs/train/moviegen_t2i_256x256.yaml rename to examples/moviegen/configs/train/stage1_t2i_256x256.yaml index 797f917dfa..310cfd6c56 100644 --- a/examples/moviegen/configs/train/moviegen_t2i_256x256.yaml +++ b/examples/moviegen/configs/train/stage1_t2i_256x256.yaml @@ -6,7 +6,7 @@ env: debug: False model: - name: llama-1B + name: llama-5B pretrained_model_path: enable_flash_attention: True recompute: True @@ -27,13 +27,13 @@ dataset: output_columns: ["video", "ul2_caption", "byt5_caption"] dataloader: - batch_size: 64 + batch_size: 70 shuffle: True num_workers_dataset: 4 train: - epochs: 1000 - output_path: output/moviegen_t2i_256x256 + epochs: 2 + output_path: output/stage1_t2i_256x256 lr_scheduler: class_path: mindspore.nn.WarmUpLR @@ -48,7 +48,7 @@ train: weight_decay: 0.1 loss_scaler: - class_path: mindspore.nn.FixedLossScaleUpdateCell # or DynamicLossScaleUpdateCell + class_path: mindspore.nn.FixedLossScaleUpdateCell # or DynamicLossScaleUpdateCell in FP16 init_args: loss_scale_value: 1 @@ -57,7 +57,7 @@ train: offloading: True settings: - zero_stage: 2 + zero_stage: 0 gradient_accumulation_steps: 1 clip_grad: True clip_norm: 1.0 diff --git a/examples/moviegen/configs/train/stage2_t2iv_256x256.yaml b/examples/moviegen/configs/train/stage2_t2iv_256x256.yaml new file mode 100644 index 0000000000..3fa18470c6 --- /dev/null +++ b/examples/moviegen/configs/train/stage2_t2iv_256x256.yaml @@ -0,0 +1,72 @@ +env: + mode: 0 + jit_level: O0 + seed: 42 + distributed: False + debug: False + +model: + name: llama-5B + pretrained_model_path: + enable_flash_attention: True + recompute: True + dtype: bf16 + +vae: + ckpt_path: models/OpenSora-VAE-v1.2/model.ckpt + dtype: fp16 + +dataset: + csv_path: CSV_PATH + video_folder: VIDEO_FOLDER + text_emb_folder: + ul2: UL2_FOLDER + byt5: BYT5_FOLDER + target_size: [ 256, 256 ] + sample_n_frames: 272 # FIXME: add variable frames support. 
FIXME: 17 * 16 = 272 frames of OSv1.2 VAE + apply_transforms_dataset: True + output_columns: ["video", "ul2_caption", "byt5_caption"] + +dataloader: + batch_size: 64 + shuffle: True + num_workers_dataset: 4 + +train: + epochs: 1000 + output_path: output/stage2_t2iv_256x256 + + lr_scheduler: + class_path: mindspore.nn.WarmUpLR + init_args: + learning_rate: 6.0e-5 + warmup_steps: 1000 + + optimizer: + name: adamw_re + eps: 1e-15 + betas: [0.9, 0.999] + weight_decay: 0.1 + + loss_scaler: + class_path: mindspore.nn.FixedLossScaleUpdateCell # or DynamicLossScaleUpdateCell in FP16 + init_args: + loss_scale_value: 1 + + ema: + ema_decay: 0.9999 + offloading: True + + settings: + zero_stage: 0 + gradient_accumulation_steps: 1 + clip_grad: True + clip_norm: 1.0 + + save: + ckpt_save_policy: latest_k + ckpt_max_keep: 10 + ckpt_save_interval: 50 + log_interval: 1 + save_ema_only: False + record_lr: False diff --git a/examples/moviegen/inference_text_enc.py b/examples/moviegen/inference_text_enc.py index 6b3a68c1ad..5a500bffe4 100644 --- a/examples/moviegen/inference_text_enc.py +++ b/examples/moviegen/inference_text_enc.py @@ -86,10 +86,11 @@ def main(args): args.model_name, mindspore_dtype=MODEL_DTYPE[args.dtype.lower()], local_files_only=True ).set_train(False) - logger.info(f"Number of devices: {device_num} | Rank ID: {rank_id} | Number of captions: {len(captions)}") - logger.info( - f"Model name: {args.model_name} | Precision: {args.dtype} | Embedded sequence length: {args.model_max_length}" + info = ( + f"Model name: {args.model_name}\nPrecision: {args.dtype}\nEmbedded sequence length: {args.model_max_length}" + f"\nNumber of devices: {device_num}\nRank ID: {rank_id}\nNumber of captions: {len(captions)}" ) + logger.info(info) for i in trange(0, len(captions), args.batch_size): batch = captions[i : i + args.batch_size] diff --git a/examples/moviegen/moviegen/models/llama/network.py b/examples/moviegen/moviegen/models/llama/network.py index ce66e554a0..5b40dc1c19 100644 --- a/examples/moviegen/moviegen/models/llama/network.py +++ b/examples/moviegen/moviegen/models/llama/network.py @@ -108,7 +108,7 @@ def construct( ) shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mint.chunk(modulation_parameters, 6, dim=1) - # Self Attention (Bi-Directional Attention) + # Self-Attention (Bi-Directional Attention) residual = hidden_states hidden_states = self.input_layernorm(hidden_states) hidden_states = t2i_modulate(hidden_states, shift_msa, scale_msa) diff --git a/examples/moviegen/moviegen/models/text_encoders/text_projector.py b/examples/moviegen/moviegen/models/text_encoders/text_projector.py index 0920cf5de5..157f027e00 100644 --- a/examples/moviegen/moviegen/models/text_encoders/text_projector.py +++ b/examples/moviegen/moviegen/models/text_encoders/text_projector.py @@ -20,24 +20,16 @@ def __init__( dtype: ms.Type = ms.float32, ): super().__init__() - self.ul2_projector = nn.SequentialCell( - [ - mint.nn.Linear(ul2_in_features, out_features, bias=False, dtype=dtype), - layer_norm((out_features,), eps=norm_eps, dtype=dtype), - ] - ) - self.metaclip_projector = nn.SequentialCell( - [ - mint.nn.Linear(metaclip_in_features, out_features, bias=False, dtype=dtype), - layer_norm((out_features,), eps=norm_eps, dtype=dtype), - ] - ) - self.byt5_projector = nn.SequentialCell( - [ - mint.nn.Linear(byt5_in_features, out_features, bias=False, dtype=dtype), - layer_norm((out_features,), eps=norm_eps, dtype=dtype), - ] - ) + # split layers for easier exclusion from weight decay + self.ul2_linear = 
mint.nn.Linear(ul2_in_features, out_features, bias=False, dtype=dtype) + self.ul2_layernorm = layer_norm((out_features,), eps=norm_eps, dtype=dtype) + + self.metaclip_linear = mint.nn.Linear(metaclip_in_features, out_features, bias=False, dtype=dtype) + self.metaclip_layernorm = layer_norm((out_features,), eps=norm_eps, dtype=dtype) + + self.byt5_linear = mint.nn.Linear(byt5_in_features, out_features, bias=False, dtype=dtype) + self.byt5_layernorm = layer_norm((out_features,), eps=norm_eps, dtype=dtype) + self.initializer_range = initializer_range # post-init @@ -57,8 +49,8 @@ def _init_weights(module): self.apply(_init_weights) def construct(self, ul2_emb: Tensor, metaclip_emb: Tensor, byt5_emb: Tensor) -> Tensor: - ul2_hidden_states = self.ul2_projector(ul2_emb) - metaclip_hidden_states = self.metaclip_projector(metaclip_emb) - byt5_hidden_states = self.byt5_projector(byt5_emb) + ul2_hidden_states = self.ul2_layernorm(self.ul2_linear(ul2_emb)) + metaclip_hidden_states = self.metaclip_layernorm(self.metaclip_linear(metaclip_emb)) + byt5_hidden_states = self.byt5_layernorm(self.byt5_linear(byt5_emb)) return mint.cat((ul2_hidden_states, metaclip_hidden_states, byt5_hidden_states), dim=1) diff --git a/examples/moviegen/moviegen/schedulers/rectified_flow.py b/examples/moviegen/moviegen/schedulers/rectified_flow.py index 18a4f6cc50..fa4fc92ea1 100644 --- a/examples/moviegen/moviegen/schedulers/rectified_flow.py +++ b/examples/moviegen/moviegen/schedulers/rectified_flow.py @@ -174,4 +174,4 @@ def add_noise(self, x: Tensor, noise: Tensor, timesteps: Tensor) -> Tensor: timesteps = timesteps[:, None, None, None, None] # 3.1.2 First Eqa. - return timesteps * x + (1 - (1 - self.eps) * timesteps) * noise + return timesteps * x + (1 - (1 - self.eps) * timesteps) * noise # TODO: check for zero SNR diff --git a/examples/moviegen/scripts/train_t2i_256x256.sh b/examples/moviegen/scripts/stage1_train.sh similarity index 67% rename from examples/moviegen/scripts/train_t2i_256x256.sh rename to examples/moviegen/scripts/stage1_train.sh index 5d4a2d9afe..7831f4b7fc 100644 --- a/examples/moviegen/scripts/train_t2i_256x256.sh +++ b/examples/moviegen/scripts/stage1_train.sh @@ -5,19 +5,18 @@ export MS_MEMORY_STATISTIC=0 # log level export GLOG_v=2 -output_dir=output/moviegen_t2i_256x256/$(date +"%Y.%m.%d-%H.%M.%S") +output_dir=output/stage1_t2i_256x256/$(date +"%Y.%m.%d-%H.%M.%S") msrun --bind_core=True --worker_num=8 --local_worker_num=8 --log_dir="$output_dir" \ python train.py \ - --config configs/train/moviegen_t2i_256x256.yaml \ + --config configs/train/stage1_t2i_256x256.yaml \ --env.mode 0 \ - --env.jit_level O0 \ + --env.jit_level O1 \ --env.max_device_memory 59GB \ --env.distributed True \ - --model.name llama-1B \ + --train.settings.zero_stage 2 \ --dataset.csv_path CSV_PATH \ --dataset.video_folder VIDEO_FOLDER \ --dataset.text_emb_folder.ul2 UL2_FOLDER \ --dataset.text_emb_folder.byt5 BYT5_FOLDER \ - --train.output_path $output_dir \ - --train.ema "" # turn off ema + --train.output_path "$output_dir" diff --git a/examples/moviegen/scripts/stage2_train.sh b/examples/moviegen/scripts/stage2_train.sh new file mode 100644 index 0000000000..ac6b6855fb --- /dev/null +++ b/examples/moviegen/scripts/stage2_train.sh @@ -0,0 +1,22 @@ +export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +# plot memory usage, feature/model: 1 +export MS_MEMORY_STATISTIC=0 + +# log level +export GLOG_v=2 + +output_dir=output/stage2_t2iv_256x256/$(date +"%Y.%m.%d-%H.%M.%S") + +msrun --bind_core=True --worker_num=8 --local_worker_num=8 
--log_dir="$output_dir" \ +python train.py \ + --config configs/train/stage2_t2iv_256x256.yaml \ + --env.mode 0 \ + --env.jit_level O1 \ + --env.max_device_memory 59GB \ + --env.distributed True \ + --train.settings.zero_stage 2 \ + --dataset.csv_path CSV_PATH \ + --dataset.video_folder VIDEO_FOLDER \ + --dataset.text_emb_folder.ul2 UL2_FOLDER \ + --dataset.text_emb_folder.byt5 BYT5_FOLDER \ + --train.output_path "$output_dir" From 835f9f017306f899ec8e7eb8b9385329a1e2ecef Mon Sep 17 00:00:00 2001 From: samithuang <285365963@qq.com> Date: Fri, 15 Nov 2024 15:29:22 +0800 Subject: [PATCH 045/122] temporal median init, 1p train psnr ok --- .../movie_gen/configs/tae/train/video_ft.yaml | 8 ++-- examples/movie_gen/mg/models/tae/modules.py | 40 +++++++++++-------- examples/movie_gen/mg/models/tae/tae.py | 11 +++++ .../movie_gen/scripts/run/run_train_tae.sh | 27 +++++++++++++ examples/movie_gen/tests/test_tae.py | 2 +- 5 files changed, 66 insertions(+), 22 deletions(-) create mode 100755 examples/movie_gen/scripts/run/run_train_tae.sh diff --git a/examples/movie_gen/configs/tae/train/video_ft.yaml b/examples/movie_gen/configs/tae/train/video_ft.yaml index c6e9330f36..ac78edb15d 100644 --- a/examples/movie_gen/configs/tae/train/video_ft.yaml +++ b/examples/movie_gen/configs/tae/train/video_ft.yaml @@ -1,9 +1,8 @@ # model -freeze_vae_2d: False -pretrained_model_path: "" +pretrained_model_path: "models/tae_vae2d.ckpt" # loss -perceptual_loss_weight: 0.1 +perceptual_loss_weight: 1.0 kl_loss_weight: 1.e-6 use_outlier_penalty_loss: True mixed_strategy: "mixed_video_image" @@ -35,7 +34,8 @@ use_recompute: True epochs: 400 ckpt_save_interval: 100 -init_loss_scale: 1. +init_loss_scale: 1024. +loss_scaler_type: dynamic scheduler: "constant" use_ema: False diff --git a/examples/movie_gen/mg/models/tae/modules.py b/examples/movie_gen/mg/models/tae/modules.py index d0637a27fb..a43a3af979 100644 --- a/examples/movie_gen/mg/models/tae/modules.py +++ b/examples/movie_gen/mg/models/tae/modules.py @@ -156,7 +156,7 @@ def __init__( self.use_pad = False self.conv_temp = nn.Conv1d(out_channels, out_channels, kernel_size=kernel_size, stride=stride, pad_mode="valid", has_bias=has_bias, bias_init='zeros') - self.init_temporal_weight() + self.init_temporal_weight('median') @staticmethod def symmetric_pad1d(x): @@ -175,6 +175,7 @@ def construct(self, x): (b c t h w) ''' + B, Ci, T, Hi, Wi = x.shape # (b c t h w) -> (b t c h w) x = ops.transpose(x, (0, 2, 1, 3, 4)) @@ -210,21 +211,26 @@ def construct(self, x): return x - def init_temporal_weight(self): - # temporal conv kernel: (cout, cin, 1, ks) - # ks=1 or 3, cin == cout - # import pdb; pdb.set_trace() - w = self.conv_temp.weight - ch = int(w.shape[0]) - ks = int(w.shape[-1]) - value = np.zeros(tuple(w.shape)) + def init_temporal_weight(self, method='median'): + if method == 'normal': + return - # only the middle element of the kernel is 1 so that the output is the same input in initialization - for i in range(ch): - value[i, i, 0, ks//2] = 1 - w.set_data(ms.Tensor(value, dtype=ms.float32)) + elif method == 'median': + # temporal conv kernel: (cout, cin, 1, ks) + # ks=1 or 3, cin == cout + w = self.conv_temp.weight + ch = int(w.shape[0]) + ks = int(w.shape[-1]) + value = np.zeros(tuple(w.shape)) + + # only the middle element of the kernel is 1 so that the output is the same input in initialization + for i in range(ch): + value[i, i, 0, ks//2] = 1 + w.set_data(ms.Tensor(value, dtype=ms.float32)) - # bias is initialized to zero in layer def + # bias is initialized to zero 
in layer def + else: + raise NotImplementedError class SpatialUpsample(nn.Cell): @@ -304,13 +310,13 @@ def __init__(self, in_channels): ) # tail padding, pad with last frame self.time_pad = self.ks - 1 - self.init_weight("mean") + self.init_weight("median") def init_weight(self, method='mean'): if method == 'normal': # default conv init return - + # no way to reserve complete input since stride 2 w = self.conv.weight value = np.zeros(tuple(w.shape)) @@ -357,7 +363,7 @@ def __init__(self, in_channels): self.conv = nn.Conv1d(in_channels, in_channels, kernel_size=3, stride=1, pad_mode="same", has_bias=True, bias_init='zeros') # TODO: init conv weight so that it pass in image mode self.ch = in_channels - self.init_weight() + self.init_weight('median') def init_weight(self, method='median'): if method == 'normal': diff --git a/examples/movie_gen/mg/models/tae/tae.py b/examples/movie_gen/mg/models/tae/tae.py index 0057a8ddfe..986d13a3e4 100644 --- a/examples/movie_gen/mg/models/tae/tae.py +++ b/examples/movie_gen/mg/models/tae/tae.py @@ -89,6 +89,9 @@ def __init__( # self.recompute(self.post_quant_conv) self.recompute(self.decoder) + if pretrained is not None: + self.load_pretrained(pretrained) + def recompute(self, b): if not b._has_config_recompute: @@ -161,7 +164,15 @@ def load_pretrained(self, ckpt_path:str): raise NotImplementedError else: param_dict = ms.load_checkpoint(ckpt_path) + + # remove the added prefix in the trained checkpoint + pnames = list(param_dict.keys()) + for pn in pnames: + new_pn = pn.replace("autoencoder.", "").replace("_backbone.", "") + param_dict[new_pn] = param_dict.pop(pn) + param_not_load, ckpt_not_load = ms.load_param_into_net(self, param_dict, strict_load=True) + if param_not_load or ckpt_not_load: print(f"{param_not_load} in network is not loaded") print(f"{ckpt_not_load} in checkpoint is not loaded!") diff --git a/examples/movie_gen/scripts/run/run_train_tae.sh b/examples/movie_gen/scripts/run/run_train_tae.sh new file mode 100755 index 0000000000..01faaed3b4 --- /dev/null +++ b/examples/movie_gen/scripts/run/run_train_tae.sh @@ -0,0 +1,27 @@ +export ASCEND_RT_VISIBLE_DEVICES=7 +# improve data loading performance for distributed training: 1 +export MS_ENABLE_NUMA=0 +# plot memory usage, feature/model: 1 +export MS_MEMORY_STATISTIC=0 +export MS_DATASET_SINK_QUEUE=8 + +# operation/graph fusion for dynamic shape +export MS_DEV_ENABLE_KERNEL_PACKET=on + +# log level +export GLOG_v=2 + +output_dir=outputs/train_tae_1p_sd3.5vaeInit_noOpl + +python scripts/train_tae.py \ +--mode=0 \ +--jit_level O0 \ +--amp_level O0 \ +--use_outlier_penalty_loss False \ +--dtype fp32 \ +--config configs/tae/train/video_ft.yaml \ +--output_path=$output_dir \ +--epochs=2000 --ckpt_save_interval=50 \ + +# --use_parallel=True \ + diff --git a/examples/movie_gen/tests/test_tae.py b/examples/movie_gen/tests/test_tae.py index 3eafaf2203..711b02476c 100644 --- a/examples/movie_gen/tests/test_tae.py +++ b/examples/movie_gen/tests/test_tae.py @@ -29,7 +29,7 @@ def get_input_image(img_path="../videocomposer/demo_video/moon_on_water.jpg", # read image using PIL and preprocess image = Image.open(img_path).convert('RGB') - image = image.resize(target_size, Image.ANTIALIAS) + image = image.resize(target_size) pixel_values = np.array(image, dtype=np.float32) pixel_values = (pixel_values / 127.5 - 1.0).astype(np.float32) From 5710f424adefba1687a28ec134420e807f42724f Mon Sep 17 00:00:00 2001 From: samithuang <285365963@qq.com> Date: Fri, 15 Nov 2024 15:31:51 +0800 Subject: [PATCH 046/122] add files 
--- examples/movie_gen/mg/utils/load_models.py | 69 +++++ examples/movie_gen/mg/utils/parser.py | 30 ++ examples/movie_gen/scripts/inference_vae.py | 324 ++++++++++++++++++++ 3 files changed, 423 insertions(+) create mode 100644 examples/movie_gen/mg/utils/load_models.py create mode 100644 examples/movie_gen/mg/utils/parser.py create mode 100644 examples/movie_gen/scripts/inference_vae.py diff --git a/examples/movie_gen/mg/utils/load_models.py b/examples/movie_gen/mg/utils/load_models.py new file mode 100644 index 0000000000..19af59ab56 --- /dev/null +++ b/examples/movie_gen/mg/utils/load_models.py @@ -0,0 +1,69 @@ +import logging +import os +import re +from typing import Union + +from mindcv.utils.download import DownLoad + +import mindspore as ms +from mindspore import nn + +from mindone.utils.params import load_param_into_net_with_filter + +logger = logging.getLogger() + + +def is_url(string): + # Regex to check for URL patterns + url_pattern = re.compile(r"^(http|https|ftp)://") + return bool(url_pattern.match(string)) + + +def load_from_pretrained( + net: nn.Cell, + checkpoint: Union[str, dict], + ignore_net_params_not_loaded=False, + ensure_all_ckpt_params_loaded=False, + cache_dir: str = None, +): + """load checkpoint into network. + + Args: + net: network + checkpoint: local file path to checkpoint, or url to download checkpoint, or a dict for network parameters + ignore_net_params_not_loaded: set True for inference if only a part of network needs to be loaded, the flushing net-not-loaded warnings will disappear. + ensure_all_ckpt_params_loaded : set True for inference if you want to ensure no checkpoint param is missed in loading + cache_dir: directory to cache the downloaded checkpoint, only effective when `checkpoint` is a url. + """ + if isinstance(checkpoint, str): + if is_url(checkpoint): + url = checkpoint + cache_dir = os.path.join(os.path.expanduser("~"), ".mindspore/models") if cache_dir is None else cache_dir + os.makedirs(cache_dir, exist_ok=True) + DownLoad().download_url(url, path=cache_dir) + checkpoint = os.path.join(cache_dir, os.path.basename(url)) + if os.path.exists(checkpoint): + param_dict = ms.load_checkpoint(checkpoint) + else: + raise FileNotFoundError(f"{checkpoint} doesn't exist") + elif isinstance(checkpoint, dict): + param_dict = checkpoint + else: + raise TypeError(f"unknown checkpoint type: {checkpoint}") + + if param_dict: + if ignore_net_params_not_loaded: + filter = param_dict.keys() + else: + filter = None + param_not_load, ckpt_not_load = load_param_into_net_with_filter(net, param_dict, filter=filter) + + if ensure_all_ckpt_params_loaded: + assert ( + len(ckpt_not_load) == 0 + ), f"All params in checkpoint must be loaded. 
but got these not loaded {ckpt_not_load}" + + if not ignore_net_params_not_loaded: + if len(param_not_load) > 0: + logger.info("Net params not loaded: {}".format([p for p in param_not_load if not p.startswith("adam")])) + logger.info("Checkpoint params not loaded: {}".format([p for p in ckpt_not_load if not p.startswith("adam")])) diff --git a/examples/movie_gen/mg/utils/parser.py b/examples/movie_gen/mg/utils/parser.py new file mode 100644 index 0000000000..96d431f8eb --- /dev/null +++ b/examples/movie_gen/mg/utils/parser.py @@ -0,0 +1,30 @@ +import argparse +import logging + + +def remove_pname_prefix(param_dict, prefix="network."): + # replace the prefix of param dict + new_param_dict = {} + for pname in param_dict: + if pname.startswith(prefix): + new_pname = pname[len(prefix) :] + else: + new_pname = pname + new_param_dict[new_pname] = param_dict[pname] + return new_param_dict + + +def str2bool(b): + if b.lower() not in ["false", "true"]: + raise Exception("Invalid Bool Value") + if b.lower() in ["false"]: + return False + return True + + +def _check_cfgs_in_parser(cfgs: dict, parser: argparse.ArgumentParser): + actions_dest = [action.dest for action in parser._actions] + defaults_key = parser._defaults.keys() + for k in cfgs.keys(): + if k not in actions_dest and k not in defaults_key: + raise KeyError(f"{k} does not exist in ArgumentParser!") diff --git a/examples/movie_gen/scripts/inference_vae.py b/examples/movie_gen/scripts/inference_vae.py new file mode 100644 index 0000000000..2885790f09 --- /dev/null +++ b/examples/movie_gen/scripts/inference_vae.py @@ -0,0 +1,324 @@ +# flake8: noqa +""" +Infer and evaluate autoencoders +""" +import argparse +import logging +import os +import sys +import time + +import imageio +import numpy as np + +from mindspore import nn, ops + +# mindone_dir = '/home/mindocr/yx/mindone' +mindone_dir = "/home_host/yx/mindone" +sys.path.insert(0, mindone_dir) + + +from omegaconf import OmegaConf +from PIL import Image +from skimage.metrics import peak_signal_noise_ratio as calc_psnr +from skimage.metrics import structural_similarity as calc_ssim +from tqdm import tqdm + +import mindspore as ms + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +mindone_lib_path = os.path.abspath(os.path.join(__dir__, "../../../")) +sys.path.insert(0, mindone_lib_path) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, ".."))) + +from mg.datasets.tae_dataset import create_dataloader +from mg.models.tae.tae import TemporalAutoencoder +from mg.models.tae.lpips import LPIPS + +from mindone.utils.amp import auto_mixed_precision +from mindone.utils.config import instantiate_from_config, str2bool +from mindone.utils.logger import set_logger + +logger = logging.getLogger(__name__) + + +def postprocess(x, trim=True): + # postprocess for computing metrics + pixels = (x + 1) * 127.5 + pixels = np.clip(pixels, 0, 255).astype(np.uint8) + + if len(pixels.shape) == 4: + # b, c, h, w -> b h w c + return np.transpose(pixels, (0, 2, 3, 1)) + else: + # b c t h w -> b t h w c -> b*t h w c + b, c, t, h, w = pixels.shape + pixels = np.transpose(pixels, (0, 2, 3, 4, 1)) + pixels = np.reshape(pixels, (b * t, h, w, c)) + return pixels + + +def visualize_image(recons, x=None, save_fn="tmp_vae_recons"): + # x: (b h w c) + for i in range(recons.shape[0]): + if x is not None: + out = np.concatenate((x[i], recons[i]), axis=-2) + else: + out = recons[i] + Image.fromarray(out).save(f"{save_fn}-{i:02d}.png") + + +def visualize_video(recons, x=None, save_fn="tmp_vae3d_recons", fps=15): + # x: (b 
t h w c) + for i in range(recons.shape[0]): + if x is not None: + out = np.concatenate((x[i], recons[i]), axis=-2) + else: + out = recons[i] + save_fp = f"{save_fn}-{i:02d}.gif" + imageio.mimsave(save_fp, out, duration=1 / fps, loop=0) + + +def rearrange_in(x): + b, c, t, h, w = x.shape + x = ops.transpose(x, (0, 2, 3, 4, 1)) + x = ops.reshape(x, (b * t, h, w, c)) + return x + +def rearrange_out(x, t): + bt, c, h, w = x.shape + b = bt // t + x = ops.reshape(x, (b, t, h, w, c)) + x = ops.transpose(x, (0, 4, 1, 2, 3)) + return x + + +def main(args): + ascend_config = {"precision_mode": "must_keep_origin_dtype"} + ms.set_context(mode=args.mode, ascend_config=ascend_config) + set_logger(name="", output_dir=args.output_path, rank=0) + + # build model + model = TemporalAutoencoder( + pretrained=args.ckpt_path, + ) + + model.set_train(False) + logger.info(f"Loaded checkpoint from {args.ckpt_path}") + + if args.eval_loss: + lpips_loss_fn = LPIPS() + + if args.dtype != "fp32": + amp_level = "O2" + dtype = {"fp16": ms.float16, "bf16": ms.bfloat16}[args.dtype] + # FIXME: due to AvgPool and ops.interpolate doesn't support bf16, we add them to fp32 cells + custom_fp32_cells = [nn.GroupNorm, nn.AvgPool2d, nn.Upsample] + model = auto_mixed_precision(model, amp_level, dtype, custom_fp32_cells) + logger.info(f"Set mixed precision to O2 with dtype={args.dtype}") + else: + amp_level = "O0" + + # build dataset + if isinstance(args.image_size, int): + image_size = args.image_size + else: + if len(args.image_size) == 2: + assert args.image_size[0] == args.image_size[1], "Currently only h==w is supported" + image_size = args.image_size[0] + + ds_config = dict( + csv_path=args.csv_path, + data_folder=args.video_folder, + size=image_size, + crop_size=image_size, + sample_n_frames=args.num_frames, + sample_stride=args.frame_stride, + video_column=args.video_column, + random_crop=False, + flip=False, + ) + dataset = create_dataloader( + ds_config, + args.batch_size, + mixed_strategy=None, + mixed_image_ratio=0.0, + num_parallel_workers=8, + max_rowsize=256, + shuffle=False, + device_num=1, + rank_id=0, + drop_remainder=False, + ) + num_batches = dataset.get_dataset_size() + + ds_iter = dataset.create_dict_iterator(1) + + logger.info("Inferene begins") + mean_infer_time = 0 + mean_psnr = 0 + mean_ssim = 0 + mean_lpips = 0 + mean_recon = 0 + num_samples = 0 + for step, data in tqdm(enumerate(ds_iter)): + x = data["video"] + start_time = time.time() + + if args.encode_only: + z = model.encode(x) + else: + # recons = model.decode(z) + recons, z, posterior_mean, posterior_logvar = model(x) + + # adapt to bf16 + recons = recons.to(ms.float32) + + infer_time = time.time() - start_time + mean_infer_time += infer_time + logger.info(f"Infer time: {infer_time}") + + if not args.encode_only: + # if args.dataset_name == 'image' and args.expand_dim_t: + # # b c t h w -> b c h w + # x = x[:,:,0,:,:] + # recons= recons[:,:,0,:,:] + is_video = len(recons.shape) == 5 and (recons.shape[-3] > 1) + t = recons.shape[-3] if is_video else 1 + + recons_rgb = postprocess(recons.asnumpy()) + x_rgb = postprocess(x.asnumpy()) + + psnr_cur = [calc_psnr(x_rgb[i], recons_rgb[i]) for i in range(x_rgb.shape[0])] + ssim_cur = [ + calc_ssim(x_rgb[i], recons_rgb[i], data_range=255, channel_axis=-1, multichannel=True) + for i in range(x_rgb.shape[0]) + ] + mean_psnr += sum(psnr_cur) + mean_ssim += sum(ssim_cur) + num_samples += x_rgb.shape[0] + + logger.info(f"cur psnr: {psnr_cur[-1]:.4f}, mean psnr:{mean_psnr/num_samples:.4f}") + logger.info(f"cur 
ssim: {ssim_cur[-1]:.4f}, mean ssim:{mean_ssim/num_samples:.4f}") + + if args.eval_loss: + print("D--: ", x.shape, recons.shape) + recon_loss = np.abs((x - recons).asnumpy()) + + t = x.shape[2] + x = rearrange_in(x) + # lpips_loss = lpips_loss_fn(x, recons).asnumpy() + + mean_recon += recon_loss.mean() + # mean_lpips += lpips_loss.mean() + logger.info(f"mean recon loss: {mean_recon/num_batches:.4f}") + + if args.save_vis: + save_fn = os.path.join( + args.output_path, "{}-{}".format(os.path.basename(args.video_folder), f"step{step:03d}") + ) + if not is_video: + visualize_image(recons_rgb, x_rgb, save_fn=save_fn) + else: + bt, h, w, c = recons_rgb.shape + recons_rgb_vis = np.reshape(recons_rgb, (bt // t, t, h, w, c)) + x_rgb_vis = np.reshape(x_rgb, (bt // t, t, h, w, c)) + visualize_video(recons_rgb_vis, x_rgb_vis, save_fn=save_fn) + + mean_infer_time /= num_batches + logger.info(f"Mean infer time: {mean_infer_time}") + logger.info(f"Done. Results saved in {args.output_path}") + + if not args.encode_only: + mean_psnr /= num_samples + mean_ssim /= num_samples + logger.info(f"mean psnr:{mean_psnr:.4f}") + logger.info(f"mean ssim:{mean_ssim:.4f}") + + if args.eval_loss: + mean_recon /= num_batches + # mean_lpips /= num_batches + logger.info(f"mean recon loss: {mean_recon:.4f}") + # logger.info(f"mean lpips loss: {mean_lpips:.4f}") + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_config", + default="configs/autoencoder_kl_f8.yaml", + type=str, + help="model architecture config", + ) + parser.add_argument( + "--ckpt_path", default="outputs/vae_train/ckpt/vae_kl_f8-e10.ckpt", type=str, help="checkpoint path" + ) + parser.add_argument( + "--csv_path", + default=None, + type=str, + help="path to csv annotation file. 
If None, will get videos from the folder of `data_path`", + ) + parser.add_argument("--video_folder", default=None, type=str, help="folder of videos") + parser.add_argument( + "--output_path", default="samples/vae_recons", type=str, help="output directory to save inference results" + ) + parser.add_argument("--num_frames", default=17, type=int, help="num frames") + parser.add_argument("--frame_stride", default=1, type=int, help="frame sampling stride") + parser.add_argument( + "--expand_dim_t", + default=False, + type=str2bool, + help="expand temporal axis for image data, used for vae 3d inference with image data", + ) + parser.add_argument("--image_size", default=256, type=int, help="image rescale size") + # parser.add_argument("--crop_size", default=256, type=int, help="image crop size") + + parser.add_argument("--batch_size", default=1, type=int, help="batch size") + parser.add_argument("--num_parallel_workers", default=8, type=int, help="num workers for data loading") + parser.add_argument( + "--eval_loss", + default=False, + type=str2bool, + help="whether measure loss including reconstruction, kl, perceptual loss", + ) + parser.add_argument("--save_vis", default=True, type=str2bool, help="whether save reconstructed images") + parser.add_argument("--use_temporal_vae", default=True, type=str2bool, help="if False, just use spatial vae") + parser.add_argument("--encode_only", default=False, type=str2bool, help="only encode to save z or distribution") + parser.add_argument("--video_column", default="video", type=str, help="name of column for videos saved in csv file") + parser.add_argument( + "--mixed_strategy", + type=str, + default=None, + choices=[None, "mixed_video_image", "image_only"], + help="video and image mixed strategy.", + ) + parser.add_argument( + "--mixed_image_ratio", default=0.0, type=float, help="image ratio in mixed video and image data training" + ) + parser.add_argument( + "--save_z_dist", + default=False, + type=str2bool, + help="If True, save z distribution, mean and logvar. 
Otherwise, save z after sampling.", + ) + # ms related + parser.add_argument("--mode", default=0, type=int, help="Specify the mode: 0 for graph mode, 1 for pynative mode") + parser.add_argument( + "--dtype", + default="fp32", + type=str, + choices=["fp32", "fp16", "bf16"], + help="mixed precision type, if fp32, all layer precision is float32 (amp_level=O0), \ + if bf16 or fp16, amp_level==O2, part of layers will compute in bf16 or fp16 such as matmul, dense, conv.", + ) + parser.add_argument("--device_target", type=str, default="Ascend", help="Ascend or GPU") + + args = parser.parse_args() + + return args + + +if __name__ == "__main__": + args = parse_args() + main(args) From 164f0c9cc68e991a048f3ae1d7f28be61e9d5ab9 Mon Sep 17 00:00:00 2001 From: samithuang <285365963@qq.com> Date: Fri, 15 Nov 2024 15:53:37 +0800 Subject: [PATCH 047/122] fix rt id --- examples/movie_gen/scripts/run/run_train_tae.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/movie_gen/scripts/run/run_train_tae.sh b/examples/movie_gen/scripts/run/run_train_tae.sh index 01faaed3b4..8396e3e4da 100755 --- a/examples/movie_gen/scripts/run/run_train_tae.sh +++ b/examples/movie_gen/scripts/run/run_train_tae.sh @@ -1,4 +1,4 @@ -export ASCEND_RT_VISIBLE_DEVICES=7 +# export ASCEND_RT_VISIBLE_DEVICES=7 # improve data loading performance for distributed training: 1 export MS_ENABLE_NUMA=0 # plot memory usage, feature/model: 1 From 4f7e3ef3c0766fbe9ecc2a6c8cb1a2d2d76a5c2b Mon Sep 17 00:00:00 2001 From: samithuang <285365963@qq.com> Date: Sat, 16 Nov 2024 14:32:12 +0800 Subject: [PATCH 048/122] set image and crop size --- examples/movie_gen/scripts/args_train_tae.py | 6 ++---- examples/movie_gen/scripts/train_tae.py | 2 +- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/examples/movie_gen/scripts/args_train_tae.py b/examples/movie_gen/scripts/args_train_tae.py index 801286d8b0..d2962498c6 100644 --- a/examples/movie_gen/scripts/args_train_tae.py +++ b/examples/movie_gen/scripts/args_train_tae.py @@ -133,9 +133,6 @@ def parse_train_args(parser): parser.add_argument("--decay_steps", default=0, type=int, help="lr decay steps.") parser.add_argument("--scheduler", default="cosine_decay", type=str, help="scheduler.") parser.add_argument("--pre_patchify", default=False, type=str2bool, help="Training with patchified latent.") - parser.add_argument( - "--max_image_size", default=512, type=int, help="Max image size for patchified latent training." - ) # dataloader params parser.add_argument("--dataset_sink_mode", default=False, type=str2bool, help="sink mode") @@ -212,7 +209,8 @@ def parse_train_args(parser): parser.add_argument( "--sd_scale_factor", type=float, default=0.18215, help="VAE scale factor of Stable Diffusion model." 
) - parser.add_argument("--image_size", default=256, type=int, nargs="+", help="the image size used to initiate model") + parser.add_argument("--image_size", default=256, type=int, nargs="+", help="image size for resizing the input image") + parser.add_argument("--crop_size", default=256, type=int, help="crop size after resize") parser.add_argument("--num_frames", default=16, type=int, help="the num of frames used to initiate model") parser.add_argument("--frame_stride", default=3, type=int, help="frame sampling stride") parser.add_argument("--mask_ratios", type=dict, help="Masking ratios") diff --git a/examples/movie_gen/scripts/train_tae.py b/examples/movie_gen/scripts/train_tae.py index 87c506f5ff..f3e6d47c21 100644 --- a/examples/movie_gen/scripts/train_tae.py +++ b/examples/movie_gen/scripts/train_tae.py @@ -168,7 +168,7 @@ def main(args): csv_path=args.csv_path, data_folder=args.video_folder, size=image_size, - crop_size=image_size, + crop_size=args.crop_size, sample_n_frames=args.num_frames, sample_stride=args.frame_stride, video_column=args.video_column, From d830f5af25146be114a0c940e0224fce56e57554 Mon Sep 17 00:00:00 2001 From: Rustam Khadipash <16683750+hadipash@users.noreply.github.com> Date: Thu, 14 Nov 2024 15:32:08 +0800 Subject: [PATCH 049/122] add train step mode --- .../configs/train/stage1_t2i_256x256.yaml | 4 +-- .../configs/train/stage2_t2iv_256x256.yaml | 4 +-- examples/moviegen/train.py | 28 +++++++++++++++---- 3 files changed, 26 insertions(+), 10 deletions(-) diff --git a/examples/moviegen/configs/train/stage1_t2i_256x256.yaml b/examples/moviegen/configs/train/stage1_t2i_256x256.yaml index 310cfd6c56..b6db71133e 100644 --- a/examples/moviegen/configs/train/stage1_t2i_256x256.yaml +++ b/examples/moviegen/configs/train/stage1_t2i_256x256.yaml @@ -32,7 +32,7 @@ dataloader: num_workers_dataset: 4 train: - epochs: 2 + steps: 20000 output_path: output/stage1_t2i_256x256 lr_scheduler: @@ -64,8 +64,8 @@ train: save: ckpt_save_policy: latest_k + ckpt_save_interval: 500 ckpt_max_keep: 10 - ckpt_save_interval: 50 log_interval: 1 save_ema_only: False record_lr: False diff --git a/examples/moviegen/configs/train/stage2_t2iv_256x256.yaml b/examples/moviegen/configs/train/stage2_t2iv_256x256.yaml index 3fa18470c6..860f37e3ce 100644 --- a/examples/moviegen/configs/train/stage2_t2iv_256x256.yaml +++ b/examples/moviegen/configs/train/stage2_t2iv_256x256.yaml @@ -33,7 +33,7 @@ dataloader: num_workers_dataset: 4 train: - epochs: 1000 + steps: 20000 output_path: output/stage2_t2iv_256x256 lr_scheduler: @@ -65,8 +65,8 @@ train: save: ckpt_save_policy: latest_k + ckpt_save_interval: 500 ckpt_max_keep: 10 - ckpt_save_interval: 50 log_interval: 1 save_ema_only: False record_lr: False diff --git a/examples/moviegen/train.py b/examples/moviegen/train.py index 2c1639de69..cba5ecdc64 100644 --- a/examples/moviegen/train.py +++ b/examples/moviegen/train.py @@ -1,6 +1,7 @@ import logging import os import sys +from math import ceil from jsonargparse import ActionConfigFile, ArgumentParser from jsonargparse.typing import path_type @@ -21,7 +22,7 @@ from mindone.data import create_dataloader from mindone.trainers import create_optimizer -from mindone.trainers.callback import EvalSaveCallback, OverflowMonitor +from mindone.trainers.callback import EvalSaveCallback, OverflowMonitor, StopAtStepCallback from mindone.trainers.zero import prepare_train_network from mindone.utils import count_params, init_train_env, set_logger @@ -72,6 +73,7 @@ def main(args): # 5. 
build training utils: lr, optim, callbacks, trainer # 5.1 LR + epochs = ceil(args.train.steps / dataloader.get_dataset_size()) lr = initializer.train.lr_scheduler # 5.2 optimizer @@ -87,7 +89,7 @@ def main(args): model = Model(net_with_grads) # 5.4 callbacks - callbacks = [OverflowMonitor()] + callbacks = [OverflowMonitor(), StopAtStepCallback(train_steps=args.train.steps)] if rank_id == 0: callbacks.extend( [ @@ -98,6 +100,9 @@ def main(args): rank_id=rank_id, ckpt_save_dir=os.path.join(args.train.output_path, "ckpt"), ema=ema, + step_mode=True, + use_step_unit=True, + train_steps=args.train.steps, **args.train.save, ), ] @@ -124,7 +129,7 @@ def main(args): f"Frames: {args.dataset.sample_n_frames}", f"Weight decay: {args.train.optimizer.weight_decay}", f"Grad accumulation steps: {args.train.settings.gradient_accumulation_steps}", - f"Num epochs: {args.train.epochs}", + f"Number of training steps: {args.train.steps}", f"Loss scaler: {args.train.loss_scaler.class_path}", f"Init loss scale: {args.train.loss_scaler.init_args.loss_scale_value}", f"Grad clipping: {args.train.settings.clip_grad}", @@ -139,7 +144,7 @@ def main(args): # 6. train logger.info("Start training...") - model.train(args.train.epochs, dataloader, callbacks=callbacks) + model.train(epochs, dataloader, callbacks=callbacks) if __name__ == "__main__": @@ -183,11 +188,22 @@ def main(args): type=path_type("dcc"), # path to a directory that can be created if it does not exist help="Output directory to save training results.", ) - parser.add_argument("--train.epochs", default=10, type=int, help="Number of epochs to train. Default: 100.") + parser.add_argument("--train.steps", default=100, type=int, help="Number of steps to train. Default: 100.") parser.add_class_arguments( EvalSaveCallback, "train.save", - skip={"network", "rank_id", "ckpt_save_dir", "output_dir", "ema", "start_epoch", "model_name"}, + skip={ + "network", + "rank_id", + "ckpt_save_dir", + "output_dir", + "ema", + "start_epoch", + "model_name", + "step_mode", + "use_step_unit", + "train_steps", + }, instantiate=False, ) cfg = parser.parse_args() From 9372beb8036435242507ed2d6e2b6502332a3ee2 Mon Sep 17 00:00:00 2001 From: samithuang <285365963@qq.com> Date: Tue, 19 Nov 2024 16:26:25 +0800 Subject: [PATCH 050/122] replace interpolate for bf16 support --- examples/movie_gen/mg/models/tae/modules.py | 25 +++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/examples/movie_gen/mg/models/tae/modules.py b/examples/movie_gen/mg/models/tae/modules.py index a43a3af979..426f111f27 100644 --- a/examples/movie_gen/mg/models/tae/modules.py +++ b/examples/movie_gen/mg/models/tae/modules.py @@ -380,6 +380,30 @@ def init_weight(self, method='median'): else: raise NotImplementedError + def construct(self, x): + # x (b c t h w) + B, C, T0, H, W = x.shape + x = ops.reshape(x, (B, C, T0, H*W)) + + # NOTE: bf16 only support 4D interpolate + # x = ops.interpolate(x, scale_factor=(2.0, 1.0), mode="nearest") + out_shape = (T0 * 2, H * W) + x = ops.ResizeNearestNeighbor(out_shape)(x) + + # x (b c t hw) -> (bhw c t) + T = T0 * 2 + x = ops.transpose(x, (0, 3, 1, 2)) + x = ops.reshape(x, (B*H*W, C, T)) + + x = self.conv(x) + + # x (bhw c t) -> (b c t h w) + x = ops.reshape(x, (B, H, W, C, T)) + x = ops.transpose(x, (0, 3, 4, 1, 2)) + + return x + + ''' def construct(self, x): # x (b c t h w) x = ops.interpolate(x, scale_factor=(2.0, 1.0, 1.0), mode="nearest") @@ -396,6 +420,7 @@ def construct(self, x): x = ops.transpose(x, (0, 3, 4, 1, 2)) return x + ''' # used in vae From 
3eb6cb1454e98342495b2e3114b8ebd050480e4e Mon Sep 17 00:00:00 2001 From: Rustam Khadipash <16683750+hadipash@users.noreply.github.com> Date: Thu, 14 Nov 2024 17:28:49 +0800 Subject: [PATCH 051/122] add validation support --- .../inference/moviegen_t2i_256x256.yaml | 2 +- .../configs/train/stage1_t2i_256x256.yaml | 23 +++- .../configs/train/stage2_t2iv_256x256.yaml | 2 +- examples/moviegen/inference.py | 2 +- examples/moviegen/moviegen/dataset/dataset.py | 19 +-- .../moviegen/pipelines/train_pipeline.py | 4 + .../moviegen/schedulers/rectified_flow.py | 21 ++- examples/moviegen/moviegen/utils/__init__.py | 1 + examples/moviegen/moviegen/utils/callbacks.py | 122 ++++++++++++++++++ examples/moviegen/train.py | 66 +++++++++- mindone/trainers/train_step.py | 4 + 11 files changed, 234 insertions(+), 32 deletions(-) create mode 100644 examples/moviegen/moviegen/utils/callbacks.py diff --git a/examples/moviegen/configs/inference/moviegen_t2i_256x256.yaml b/examples/moviegen/configs/inference/moviegen_t2i_256x256.yaml index 37a380854f..803e28da75 100644 --- a/examples/moviegen/configs/inference/moviegen_t2i_256x256.yaml +++ b/examples/moviegen/configs/inference/moviegen_t2i_256x256.yaml @@ -27,7 +27,7 @@ text_emb: batch_size: 10 # Saving options -output_path: samples +output_path: ../../samples # the path is relative to this config append_timestamp: True save_format: png save_latent: False diff --git a/examples/moviegen/configs/train/stage1_t2i_256x256.yaml b/examples/moviegen/configs/train/stage1_t2i_256x256.yaml index b6db71133e..eac569e618 100644 --- a/examples/moviegen/configs/train/stage1_t2i_256x256.yaml +++ b/examples/moviegen/configs/train/stage1_t2i_256x256.yaml @@ -33,7 +33,7 @@ dataloader: train: steps: 20000 - output_path: output/stage1_t2i_256x256 + output_path: ../../output/stage1_t2i_256x256 # the path is relative to this config lr_scheduler: class_path: mindspore.nn.WarmUpLR @@ -64,8 +64,27 @@ train: save: ckpt_save_policy: latest_k - ckpt_save_interval: 500 + ckpt_save_interval: &save_interval 500 ckpt_max_keep: 10 log_interval: 1 save_ema_only: False record_lr: False + +valid: + sampling_steps: 10 + frequency: *save_interval + + dataset: + csv_path: CSV_PATH + video_folder: VIDEO_FOLDER + text_emb_folder: + ul2: UL2_FOLDER + byt5: BYT5_FOLDER + target_size: [ 256, 256 ] + apply_transforms_dataset: True + output_columns: [ "video", "ul2_caption", "byt5_caption" ] + + dataloader: + batch_size: 50 + shuffle: False + num_workers_dataset: 4 diff --git a/examples/moviegen/configs/train/stage2_t2iv_256x256.yaml b/examples/moviegen/configs/train/stage2_t2iv_256x256.yaml index 860f37e3ce..becfc16898 100644 --- a/examples/moviegen/configs/train/stage2_t2iv_256x256.yaml +++ b/examples/moviegen/configs/train/stage2_t2iv_256x256.yaml @@ -34,7 +34,7 @@ dataloader: train: steps: 20000 - output_path: output/stage2_t2iv_256x256 + output_path: ../../output/stage2_t2iv_256x256 # the path is relative to this config lr_scheduler: class_path: mindspore.nn.WarmUpLR diff --git a/examples/moviegen/inference.py b/examples/moviegen/inference.py index ce25de023d..7c80206e77 100644 --- a/examples/moviegen/inference.py +++ b/examples/moviegen/inference.py @@ -51,7 +51,7 @@ def prepare_captions( def main(args): # TODO: CFG error - save_dir = os.path.join(__dir__, args.output_path.relative) + save_dir = args.output_path.absolute if args.append_timestamp: time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") save_dir = os.path.join(save_dir, time_str) diff --git 
a/examples/moviegen/moviegen/dataset/dataset.py b/examples/moviegen/moviegen/dataset/dataset.py index 5915acfc79..a729bc6607 100644 --- a/examples/moviegen/moviegen/dataset/dataset.py +++ b/examples/moviegen/moviegen/dataset/dataset.py @@ -234,35 +234,18 @@ def train_transforms( tokenizer: Optional[Callable[[str], np.ndarray]] = None, ) -> List[dict]: transforms = [] - vae_downsample_rate = self._vae_downsample_rate - if not self._vae_latent_folder: - vae_downsample_rate = 1 transforms.append( { "operations": [ ResizeCrop(target_size, interpolation=interpolation), lambda x: x.astype(np.float32) / 127.5 - 1, + lambda x: x[None, ...] if x.ndim == 3 else x, # if image lambda x: np.transpose(x, (0, 3, 1, 2)), ], "input_columns": ["video"], } ) - # the followings are not transformation for video frames, can be excluded - transforms.append( - { - "operations": [ - lambda video: ( - video, # need to return the video itself to preserve the column - np.array(video.shape[-2] * vae_downsample_rate, dtype=np.float32), - np.array(video.shape[-1] * vae_downsample_rate, dtype=np.float32), - np.array(video.shape[-2] / video.shape[-1], dtype=np.float32), - ) - ], - "input_columns": ["video"], - "output_columns": ["video", "height", "width", "ar"], - } - ) if "caption" in self.output_columns and not self._text_emb_folder: if tokenizer is None: diff --git a/examples/moviegen/moviegen/pipelines/train_pipeline.py b/examples/moviegen/moviegen/pipelines/train_pipeline.py index 8fe6e3f0a0..7829448fa3 100644 --- a/examples/moviegen/moviegen/pipelines/train_pipeline.py +++ b/examples/moviegen/moviegen/pipelines/train_pipeline.py @@ -59,6 +59,10 @@ def get_latents(self, video_tokens: Tensor) -> Tensor: video_emb = mint.permute(video_emb, (0, 2, 1, 3, 4)) # FIXME return video_emb + def set_train(self, mode=True): + # Set the diffusion model only to train or eval mode + self.network.set_train(mode) + def construct(self, video_tokens: Tensor, ul2_tokens: Tensor, byt5_tokens: Tensor) -> Tensor: latent_embedding = self.get_latents(video_tokens) ul2_emb = self.get_condition_embeddings(ul2_tokens) diff --git a/examples/moviegen/moviegen/schedulers/rectified_flow.py b/examples/moviegen/moviegen/schedulers/rectified_flow.py index fa4fc92ea1..0a38dc26b2 100644 --- a/examples/moviegen/moviegen/schedulers/rectified_flow.py +++ b/examples/moviegen/moviegen/schedulers/rectified_flow.py @@ -15,7 +15,7 @@ logger = logging.getLogger(__name__) -__all__ = ["RFLOW", "RFlowLossWrapper"] +__all__ = ["RFLOW", "RFlowLossWrapper", "RFlowEvalLoss"] class LogisticNormal(nn.Cell): @@ -174,4 +174,21 @@ def add_noise(self, x: Tensor, noise: Tensor, timesteps: Tensor) -> Tensor: timesteps = timesteps[:, None, None, None, None] # 3.1.2 First Eqa. 
- return timesteps * x + (1 - (1 - self.eps) * timesteps) * noise # TODO: check for zero SNR + return timesteps * x + (1 - (1 - self.eps) * timesteps) * noise # TODO: check for zero SNR + + +class RFlowEvalLoss(nn.Cell): + def __init__(self, network: RFlowLossWrapper, num_sampling_steps: int = 10): + super().__init__() + self.network = network + self.timesteps = Tensor( + np.linspace(0, network.num_timesteps, num_sampling_steps + 2)[1:-1].reshape(-1, 1), dtype=ms.float32 + ) + + def construct(self, x: Tensor, ul2_emb: Tensor, metaclip_emb: Tensor, byt5_emb: Tensor, **kwargs) -> Tensor: + loss = Tensor(0, dtype=ms.float32) + timesteps = mint.tile(self.timesteps, (1, x.shape[0])) + for t in timesteps: + loss += self.network(x, ul2_emb, metaclip_emb, byt5_emb, t) + + return loss / len(self.timesteps) diff --git a/examples/moviegen/moviegen/utils/__init__.py b/examples/moviegen/moviegen/utils/__init__.py index 62934d58bc..73fb65477b 100644 --- a/examples/moviegen/moviegen/utils/__init__.py +++ b/examples/moviegen/moviegen/utils/__init__.py @@ -1,3 +1,4 @@ +from .callbacks import * from .ema import * from .model_utils import * from .utils import * diff --git a/examples/moviegen/moviegen/utils/callbacks.py b/examples/moviegen/moviegen/utils/callbacks.py new file mode 100644 index 0000000000..9e3c5bda76 --- /dev/null +++ b/examples/moviegen/moviegen/utils/callbacks.py @@ -0,0 +1,122 @@ +import logging +import os +import time +from typing import List, Optional + +import numpy as np + +from mindspore import Callback, RunContext, nn, ops +from mindspore.communication import GlobalComm +from mindspore.dataset import GeneratorDataset + +from mindone.trainers.ema import EMA + +__all__ = ["ValidationCallback", "PerfRecorderCallback"] + +_logger = logging.getLogger(__name__) + + +class ValidationCallback(Callback): + """ + A callback for performing validation during training on a per-step basis. + + Args: + network (nn.Cell): The neural network model to be validated. + dataset (GeneratorDataset): The dataset to use for validation. + rank_id (int): The rank ID of the current process. Defaults to 0. + valid_frequency (int, optional): The frequency of validation in terms of training steps. + Defaults to 100. + ema (Optional[EMA], optional): An Exponential Moving Average object for the model weights. + If provided, it will be used during validation. Defaults to None. 
+ + Example: + >>> model = MyModel() + >>> val_dataset = MyValidationDataset() + >>> val_callback = ValidationCallback(model, val_dataset, valid_frequency=500) + >>> model.train(num_epochs, train_dataset, callbacks=[val_callback]) + """ + + def __init__( + self, + network: nn.Cell, + dataset: GeneratorDataset, + rank_id: int = 0, + valid_frequency: int = 100, + ema: Optional[EMA] = None, + ): + super().__init__() + self.network = network + self.dataset = dataset + self.rank_id = rank_id + self.valid_frequency = valid_frequency + self.ema = ema + self.reduce = ops.AllReduce() if GlobalComm.INITED else None + + def on_train_step_end(self, run_context: RunContext): + cb_params = run_context.original_args() + cb_params.eval_results = {} # Erase previous validation results + cur_step = cb_params.cur_step_num + + if cur_step % self.valid_frequency == 0: + if self.ema is not None: + self.ema.swap_before_eval() + self.network.set_train(False) + + loss = 0 + for data in self.dataset.create_tuple_iterator(num_epochs=1): + loss += self.network(*data) + loss = loss / self.dataset.get_dataset_size() + if self.reduce is not None: + loss = self.reduce(loss) + loss = loss.item() + + cb_params.eval_results = {"eval_loss": loss} + _logger.info(f"Step: {cur_step}, Validation Loss: {loss}.") + + self.network.set_train(True) + if self.ema is not None: + self.ema.swap_after_eval() + + +class PerfRecorderCallback(Callback): + """ + Improved version of `mindone.trainers.recorder.PerfRecorder` that tracks validation metrics as well. + Used here first for testing. + """ + + def __init__( + self, + save_dir: str, + file_name: str = "result.log", + metric_names: List[str] = None, + separator: str = "\t", + resume: bool = False, + ): + super().__init__() + self._sep = separator + self._metrics = metric_names or [] + + if not os.path.exists(save_dir): + os.makedirs(save_dir, exist_ok=True) + self._log_file = os.path.join(save_dir, file_name) + if not resume: + header = separator.join([f"{'step':<7}", f"{'loss':<10}", "train_time(s)"] + self._metrics) + with open(self._log_file, "w", encoding="utf-8") as fp: + fp.write(header + "\n") + + def on_train_step_begin(self, run_context: RunContext): + self._step_time = time.perf_counter() + + def on_train_step_end(self, run_context: RunContext): + step_time = time.perf_counter() - self._step_time + cb_params = run_context.original_args() + cur_step = cb_params.cur_step_num + loss = cb_params.net_outputs + loss = loss[0].asnumpy() if isinstance(loss, tuple) else np.mean(loss.asnumpy()) + eval_loss = cb_params.get("eval_results", []) + metrics = (self._sep + self._sep.join([f"{eval_loss[m]:.6f}" for m in self._metrics])) if eval_loss else "" + + with open(self._log_file, "a", encoding="utf-8") as fp: + fp.write( + self._sep.join([f"{cur_step:<7}", f"{loss.item():<10.6f}", f"{step_time:<13.3f}"]) + metrics + "\n" + ) diff --git a/examples/moviegen/train.py b/examples/moviegen/train.py index cba5ecdc64..00f639ecd7 100644 --- a/examples/moviegen/train.py +++ b/examples/moviegen/train.py @@ -16,9 +16,8 @@ from moviegen.dataset import ImageVideoDataset from moviegen.pipelines import DiffusionWithLoss -from moviegen.schedulers import RFlowLossWrapper -from moviegen.utils import EMA -from moviegen.utils.model_utils import MODEL_DTYPE, init_model +from moviegen.schedulers import RFlowEvalLoss, RFlowLossWrapper +from moviegen.utils import EMA, MODEL_DTYPE, PerfRecorderCallback, ValidationCallback, init_model from mindone.data import create_dataloader from mindone.trainers import 
create_optimizer @@ -35,7 +34,7 @@ def main(args): # 1. init env - args.train.output_path = os.path.join(__dir__, args.train.output_path.relative) + args.train.output_path = args.train.output_path.absolute os.makedirs(args.train.output_path, exist_ok=True) device_id, rank_id, device_num = init_train_env(**args.env) set_logger("", output_dir=args.train.output_path, rank=rank_id) @@ -71,6 +70,18 @@ def main(args): dataset, transforms=transforms, device_num=device_num, rank_id=rank_id, **args.dataloader ) + eval_diffusion_with_loss, val_dataloader = None, None + if args.valid.dataset is not None: + val_dataset = ImageVideoDataset(**args.valid.dataset.init_args) + transforms = None + if not args.valid.dataset.init_args.apply_transforms_dataset: + transforms = val_dataset.train_transforms(args.valid.dataset.init_args.target_size) + val_dataloader = create_dataloader( + val_dataset, transforms=transforms, device_num=device_num, rank_id=rank_id, **args.valid.dataloader + ) + eval_rflow_loss = RFlowEvalLoss(rflow_loss_wrapper, num_sampling_steps=args.valid.sampling_steps) + eval_diffusion_with_loss = DiffusionWithLoss(eval_rflow_loss, vae) + # 5. build training utils: lr, optim, callbacks, trainer # 5.1 LR epochs = ceil(args.train.steps / dataloader.get_dataset_size()) @@ -89,7 +100,7 @@ def main(args): model = Model(net_with_grads) # 5.4 callbacks - callbacks = [OverflowMonitor(), StopAtStepCallback(train_steps=args.train.steps)] + callbacks = [OverflowMonitor()] if rank_id == 0: callbacks.extend( [ @@ -107,6 +118,26 @@ def main(args): ), ] ) + + if val_dataloader is not None: + callbacks.append( + ValidationCallback( + network=eval_diffusion_with_loss, + dataset=val_dataloader, + rank_id=rank_id, + valid_frequency=args.valid.frequency, + ema=ema, + ) + ) + callbacks.extend( + [ + PerfRecorderCallback(args.train.output_path, file_name="result_val.log", metric_names=["eval_loss"]), + StopAtStepCallback(train_steps=args.train.steps), + ] + ) + + # 5.5 print out key info and save config + if rank_id == 0: num_params_vae, num_params_trainable_vae = count_params(vae) num_params_network, num_params_trainable_network = count_params(network) num_params = num_params_vae + num_params_network @@ -115,14 +146,16 @@ def main(args): key_info += "\n".join( [ f"MindSpore mode[GRAPH(0)/PYNATIVE(1)]: {args.env.mode}", + f"Debug mode: {args.env.debug}", f"JIT level: {args.env.jit_level}", f"Distributed mode: {args.env.distributed}", f"Data path: {args.dataset.csv_path}", f"Number of samples: {len(dataset)}", - f"Num params: {num_params:,} (network: {num_params_network:,}, vae: {num_params_vae:,})", - f"Num trainable params: {num_params_trainable:,}", + f"Model name: {args.model.name}", f"Model dtype: {args.model.dtype}", f"VAE dtype: {args.vae.dtype}", + f"Num params: {num_params:,} (network: {num_params_network:,}, vae: {num_params_vae:,})", + f"Num trainable params: {num_params_trainable:,}", f"Learning rate: {args.train.lr_scheduler.init_args.learning_rate:.0e}", f"Batch size: {args.dataloader.batch_size}", f"Image size: {args.dataset.target_size}", @@ -206,5 +239,24 @@ def main(args): }, instantiate=False, ) + + # validation + val_group = parser.add_argument_group("Validation") + val_group.add_argument( + "valid.sampling_steps", type=int, default=10, help="Number of sampling steps for validation." 
+ ) + val_group.add_argument("valid.frequency", type=int, default=1, help="Frequency of validation in steps.") + val_group.add_subclass_arguments( + ImageVideoDataset, + "valid.dataset", + skip={"frames_mask_generator", "t_compress_func"}, + instantiate=False, + required=False, + ) + val_group.add_function_arguments( + create_dataloader, "valid.dataloader", skip={"dataset", "transforms", "device_num", "rank_id"} + ) + parser.link_arguments("env.debug", "valid.dataloader.debug", apply_on="parse") + cfg = parser.parse_args() main(cfg) diff --git a/mindone/trainers/train_step.py b/mindone/trainers/train_step.py index 385c26f7b9..18d0ccc0b8 100644 --- a/mindone/trainers/train_step.py +++ b/mindone/trainers/train_step.py @@ -101,6 +101,10 @@ def __init__( if gradient_accumulation_steps > 1: self.accumulated_grads = optimizer.parameters.clone(prefix="grad_accumulated_", init="zeros") + def set_train(self, mode=True): + # Delegate the setting of training mode behavior to the network. + self.network.set_train(mode) + def construct(self, *inputs): # compute loss weights = self.weights From 1410f37c203f490fe79cf7d9e801df96286b47bc Mon Sep 17 00:00:00 2001 From: Rustam Khadipash <16683750+hadipash@users.noreply.github.com> Date: Tue, 19 Nov 2024 17:14:49 +0800 Subject: [PATCH 052/122] add ReduceLROnPlateau --- .../configs/train/stage1_t2i_256x256.yaml | 14 +++- .../configs/train/stage2_t2iv_256x256.yaml | 15 +++- examples/moviegen/moviegen/utils/callbacks.py | 83 ++++++++++++++++++- examples/moviegen/moviegen/utils/utils.py | 9 +- examples/moviegen/train.py | 37 +++++---- 5 files changed, 127 insertions(+), 31 deletions(-) diff --git a/examples/moviegen/configs/train/stage1_t2i_256x256.yaml b/examples/moviegen/configs/train/stage1_t2i_256x256.yaml index eac569e618..ae513aa907 100644 --- a/examples/moviegen/configs/train/stage1_t2i_256x256.yaml +++ b/examples/moviegen/configs/train/stage1_t2i_256x256.yaml @@ -36,10 +36,16 @@ train: output_path: ../../output/stage1_t2i_256x256 # the path is relative to this config lr_scheduler: - class_path: mindspore.nn.WarmUpLR - init_args: - learning_rate: 1.0e-4 - warmup_steps: 1000 + name: constant + lr: 1.0e-4 + warmup_steps: 1000 + + lr_reduce_on_plateau: + factor: 0.5 + patience: 10 # in the number of validation steps, i.e., valid.frequency * patience steps + mode: min + min_delta: 0.01 + min_lr: 1.0e-6 optimizer: name: adamw_re diff --git a/examples/moviegen/configs/train/stage2_t2iv_256x256.yaml b/examples/moviegen/configs/train/stage2_t2iv_256x256.yaml index becfc16898..d5aa1631a0 100644 --- a/examples/moviegen/configs/train/stage2_t2iv_256x256.yaml +++ b/examples/moviegen/configs/train/stage2_t2iv_256x256.yaml @@ -37,10 +37,17 @@ train: output_path: ../../output/stage2_t2iv_256x256 # the path is relative to this config lr_scheduler: - class_path: mindspore.nn.WarmUpLR - init_args: - learning_rate: 6.0e-5 - warmup_steps: 1000 + name: constant + lr: 6.0e-5 + warmup_steps: 1000 + + lr_reduce_on_plateau: + alpha_smooth: 0.01 + factor: 0.5 + patience: 5000 + mode: min + min_delta: 0.01 + min_lr: 1.0e-6 optimizer: name: adamw_re diff --git a/examples/moviegen/moviegen/utils/callbacks.py b/examples/moviegen/moviegen/utils/callbacks.py index 9e3c5bda76..b088f14394 100644 --- a/examples/moviegen/moviegen/utils/callbacks.py +++ b/examples/moviegen/moviegen/utils/callbacks.py @@ -1,17 +1,21 @@ import logging import os import time -from typing import List, Optional +from typing import List, Literal, Optional import numpy as np +import pandas as pd -from mindspore 
import Callback, RunContext, nn, ops +from mindspore import Callback, Parameter, ReduceLROnPlateau, RunContext, Tensor +from mindspore import dtype as mstype +from mindspore import mint, nn, ops from mindspore.communication import GlobalComm from mindspore.dataset import GeneratorDataset +from mindspore.ops import functional as F from mindone.trainers.ema import EMA -__all__ = ["ValidationCallback", "PerfRecorderCallback"] +__all__ = ["ValidationCallback", "PerfRecorderCallback", "ReduceLROnPlateauByStep"] _logger = logging.getLogger(__name__) @@ -41,6 +45,7 @@ def __init__( network: nn.Cell, dataset: GeneratorDataset, rank_id: int = 0, + alpha_smooth: float = 0.01, valid_frequency: int = 100, ema: Optional[EMA] = None, ): @@ -48,9 +53,11 @@ def __init__( self.network = network self.dataset = dataset self.rank_id = rank_id + self.alpha_smooth = alpha_smooth self.valid_frequency = valid_frequency self.ema = ema self.reduce = ops.AllReduce() if GlobalComm.INITED else None + self.data = pd.Series(dtype=np.float32) def on_train_step_end(self, run_context: RunContext): cb_params = run_context.original_args() @@ -70,7 +77,10 @@ def on_train_step_end(self, run_context: RunContext): loss = self.reduce(loss) loss = loss.item() - cb_params.eval_results = {"eval_loss": loss} + self.data = pd.concat([self.data, pd.Series(loss)], ignore_index=True) + loss_smoothed = self.data.ewm(alpha=self.alpha_smooth).mean().iloc[-1] + + cb_params.eval_results = {"eval_loss": loss, "eval_loss_smoothed": loss_smoothed} _logger.info(f"Step: {cur_step}, Validation Loss: {loss}.") self.network.set_train(True) @@ -120,3 +130,68 @@ def on_train_step_end(self, run_context: RunContext): fp.write( self._sep.join([f"{cur_step:<7}", f"{loss.item():<10.6f}", f"{step_time:<13.3f}"]) + metrics + "\n" ) + + +class ReduceLROnPlateauByStep(ReduceLROnPlateau): + """ + Extends ReduceLROnPlateau to reduce the learning rate at the end of a step and incorporates loss smoothing. + """ + + def __init__( + self, + optimizer, + monitor: str = "eval_loss_smoothed", + factor: float = 0.1, + patience: int = 10, + mode: Literal["auto", "min", "max"] = "auto", + min_delta: float = 1e-4, + cooldown: int = 0, + min_lr: float = 0.0, + ): + super().__init__(monitor, factor, patience, mode=mode, min_delta=min_delta, cooldown=cooldown, min_lr=min_lr) + self.optimizer = optimizer + self.min_lr = Tensor(self.min_lr, dtype=mstype.float32) + + def on_train_step_end(self, run_context): + """ + monitors the training process and if no improvement is seen for a 'patience' number + of epochs, the learning rate is reduced. + + Copy of the original `on_train_step_end()` with changes to add loss alpha smoothing. + + Args: + run_context (RunContext): Context information of the model. For more details, + please refer to :class:`mindspore.train.RunContext`. 
+ """ + cb_params = run_context.original_args() + cur_step = cb_params.cur_step_num + lrs = self.optimizer.learning_rate.learning_rate + if not isinstance(lrs, Parameter): + raise ValueError("ReduceLROnPlateau does not support dynamic learning rate and group learning rate now.") + + current_monitor_value = cb_params.get("eval_results") + if current_monitor_value: + current_monitor_value = current_monitor_value[self.monitor] + + if self.cooldown_counter > 0: + self.cooldown_counter -= 1 + self.wait = 0 + + if self.is_improvement(current_monitor_value, self.best): + self.best = current_monitor_value + self.wait = 0 + elif self.cooldown_counter <= 0: + self.wait += 1 + if self.wait >= self.patience: + if lrs[cur_step] > self.min_lr: # FIXME: doesn't hold for future LRs + new_lr = lrs * self.factor + min_lr = mint.tile(self.min_lr, lrs.shape) + new_lr = mint.where(new_lr < min_lr, min_lr, new_lr) + F.assign(self.optimizer.learning_rate.learning_rate, new_lr) + _logger.info(f"Step {cur_step}: reducing learning rate to {new_lr[cur_step]}.") + self.cooldown_counter = self.cooldown + self.wait = 0 + + def on_train_epoch_end(self, run_context): + # Use `on_train_step_end` instead + pass diff --git a/examples/moviegen/moviegen/utils/utils.py b/examples/moviegen/moviegen/utils/utils.py index 70bf86c777..93682df59e 100644 --- a/examples/moviegen/moviegen/utils/utils.py +++ b/examples/moviegen/moviegen/utils/utils.py @@ -1,11 +1,12 @@ import numpy as np -import mindspore as ms +from mindspore import Tensor +from mindspore import dtype as mstype __all__ = ["to_numpy"] -def to_numpy(x: ms.Tensor) -> np.ndarray: - if x.dtype == ms.bfloat16: - x = x.astype(ms.float32) +def to_numpy(x: Tensor) -> np.ndarray: + if x.dtype == mstype.bfloat16: + x = x.astype(mstype.float32) return x.asnumpy() diff --git a/examples/moviegen/train.py b/examples/moviegen/train.py index 00f639ecd7..487033d694 100644 --- a/examples/moviegen/train.py +++ b/examples/moviegen/train.py @@ -6,7 +6,7 @@ from jsonargparse import ActionConfigFile, ArgumentParser from jsonargparse.typing import path_type -from mindspore import Model, amp, nn +from mindspore import Model, amp, nn, set_seed from mindspore.train.callback import TimeMonitor # TODO: remove in future when mindone is ready for install @@ -17,10 +17,11 @@ from moviegen.dataset import ImageVideoDataset from moviegen.pipelines import DiffusionWithLoss from moviegen.schedulers import RFlowEvalLoss, RFlowLossWrapper -from moviegen.utils import EMA, MODEL_DTYPE, PerfRecorderCallback, ValidationCallback, init_model +from moviegen.utils import EMA, MODEL_DTYPE, init_model +from moviegen.utils.callbacks import PerfRecorderCallback, ReduceLROnPlateauByStep, ValidationCallback from mindone.data import create_dataloader -from mindone.trainers import create_optimizer +from mindone.trainers import create_optimizer, create_scheduler from mindone.trainers.callback import EvalSaveCallback, OverflowMonitor, StopAtStepCallback from mindone.trainers.zero import prepare_train_network from mindone.utils import count_params, init_train_env, set_logger @@ -37,6 +38,7 @@ def main(args): args.train.output_path = args.train.output_path.absolute os.makedirs(args.train.output_path, exist_ok=True) device_id, rank_id, device_num = init_train_env(**args.env) + set_seed(args.env.seed + rank_id) # TODO: do it better set_logger("", output_dir=args.train.output_path, rank=rank_id) # instantiate classes only after initializing training environment @@ -85,7 +87,7 @@ def main(args): # 5. 
build training utils: lr, optim, callbacks, trainer # 5.1 LR epochs = ceil(args.train.steps / dataloader.get_dataset_size()) - lr = initializer.train.lr_scheduler + lr = create_scheduler(steps_per_epoch=0, **args.train.lr_scheduler) # 5.2 optimizer optimizer = create_optimizer(latent_diffusion_with_loss.trainable_params(), lr=lr, **args.train.optimizer) @@ -120,14 +122,17 @@ def main(args): ) if val_dataloader is not None: - callbacks.append( - ValidationCallback( - network=eval_diffusion_with_loss, - dataset=val_dataloader, - rank_id=rank_id, - valid_frequency=args.valid.frequency, - ema=ema, - ) + callbacks.extend( + [ + ValidationCallback( + network=eval_diffusion_with_loss, + dataset=val_dataloader, + rank_id=rank_id, + valid_frequency=args.valid.frequency, + ema=ema, + ), + ReduceLROnPlateauByStep(optimizer, **args.train.lr_reduce_on_plateau), + ] ) callbacks.extend( [ @@ -156,7 +161,7 @@ def main(args): f"VAE dtype: {args.vae.dtype}", f"Num params: {num_params:,} (network: {num_params_network:,}, vae: {num_params_vae:,})", f"Num trainable params: {num_params_trainable:,}", - f"Learning rate: {args.train.lr_scheduler.init_args.learning_rate:.0e}", + f"Learning rate: {args.train.lr_scheduler.lr:.0e}", f"Batch size: {args.dataloader.batch_size}", f"Image size: {args.dataset.target_size}", f"Frames: {args.dataset.sample_n_frames}", @@ -201,8 +206,9 @@ def main(args): create_dataloader, "dataloader", skip={"dataset", "transforms", "device_num", "rank_id"} ) parser.link_arguments("env.debug", "dataloader.debug", apply_on="parse") - parser.add_subclass_arguments( - nn.learning_rate_schedule.LearningRateSchedule, "train.lr_scheduler", fail_untyped=False + parser.add_function_arguments(create_scheduler, "train.lr_scheduler", skip={"steps_per_epoch", "num_epochs"}) + parser.add_class_arguments( + ReduceLROnPlateauByStep, "train.lr_reduce_on_plateau", skip={"optimizer"}, instantiate=False ) parser.add_function_arguments(create_optimizer, "train.optimizer", skip={"params", "lr"}) parser.add_subclass_arguments( @@ -222,6 +228,7 @@ def main(args): help="Output directory to save training results.", ) parser.add_argument("--train.steps", default=100, type=int, help="Number of steps to train. 
Default: 100.") + parser.link_arguments("train.steps", "train.lr_scheduler.total_steps", apply_on="parse") parser.add_class_arguments( EvalSaveCallback, "train.save", From 625ee0d97747e88781328cc418ad33925802d89f Mon Sep 17 00:00:00 2001 From: Rustam Khadipash <16683750+hadipash@users.noreply.github.com> Date: Tue, 19 Nov 2024 18:22:26 +0800 Subject: [PATCH 053/122] save top K checkpoints --- .../configs/train/stage1_t2i_256x256.yaml | 10 +++---- examples/moviegen/train.py | 28 +++++++++--------- mindone/trainers/callback.py | 29 ++++++++++--------- 3 files changed, 36 insertions(+), 31 deletions(-) diff --git a/examples/moviegen/configs/train/stage1_t2i_256x256.yaml b/examples/moviegen/configs/train/stage1_t2i_256x256.yaml index ae513aa907..b076be1153 100644 --- a/examples/moviegen/configs/train/stage1_t2i_256x256.yaml +++ b/examples/moviegen/configs/train/stage1_t2i_256x256.yaml @@ -32,7 +32,7 @@ dataloader: num_workers_dataset: 4 train: - steps: 20000 + steps: 30000 output_path: ../../output/stage1_t2i_256x256 # the path is relative to this config lr_scheduler: @@ -42,7 +42,7 @@ train: lr_reduce_on_plateau: factor: 0.5 - patience: 10 # in the number of validation steps, i.e., valid.frequency * patience steps + patience: 50 # in the number of validation steps, i.e., valid.frequency * patience steps mode: min min_delta: 0.01 min_lr: 1.0e-6 @@ -69,8 +69,8 @@ train: clip_norm: 1.0 save: - ckpt_save_policy: latest_k - ckpt_save_interval: &save_interval 500 + ckpt_save_policy: top_k + ckpt_save_interval: &save_interval 100 ckpt_max_keep: 10 log_interval: 1 save_ema_only: False @@ -78,7 +78,7 @@ train: valid: sampling_steps: 10 - frequency: *save_interval + frequency: *save_interval # train.save.ckpt_save_interval should be divisible by the frequency dataset: csv_path: CSV_PATH diff --git a/examples/moviegen/train.py b/examples/moviegen/train.py index 487033d694..1342a1e7c6 100644 --- a/examples/moviegen/train.py +++ b/examples/moviegen/train.py @@ -103,6 +103,21 @@ def main(args): # 5.4 callbacks callbacks = [OverflowMonitor()] + if val_dataloader is not None: + callbacks.extend( + [ + ValidationCallback( + network=eval_diffusion_with_loss, + dataset=val_dataloader, + rank_id=rank_id, + alpha_smooth=0.01, # FIXME + valid_frequency=args.valid.frequency, + ema=ema, + ), + ReduceLROnPlateauByStep(optimizer, **args.train.lr_reduce_on_plateau), + ] + ) + if rank_id == 0: callbacks.extend( [ @@ -121,19 +136,6 @@ def main(args): ] ) - if val_dataloader is not None: - callbacks.extend( - [ - ValidationCallback( - network=eval_diffusion_with_loss, - dataset=val_dataloader, - rank_id=rank_id, - valid_frequency=args.valid.frequency, - ema=ema, - ), - ReduceLROnPlateauByStep(optimizer, **args.train.lr_reduce_on_plateau), - ] - ) callbacks.extend( [ PerfRecorderCallback(args.train.output_path, file_name="result_val.log", metric_names=["eval_loss"]), diff --git a/mindone/trainers/callback.py b/mindone/trainers/callback.py index 08b39c0751..3172f2bdd7 100755 --- a/mindone/trainers/callback.py +++ b/mindone/trainers/callback.py @@ -158,19 +158,22 @@ def on_train_step_end(self, run_context): ) append_dict = {"lora_rank": self.lora_rank} if self.use_lora else None - if self.ema is not None: - if not self.save_ema_only: - self.ckpt_manager.save( - self.net_to_save, - None, - ckpt_name=ckpt_name.replace(".ckpt", "_nonema.ckpt"), - append_dict=append_dict, - ) - # swap ema weight and network weight - self.ema.swap_before_eval() - - # save history checkpoints - self.ckpt_manager.save(self.net_to_save, None, 
ckpt_name=ckpt_name, append_dict=append_dict) + perf = cb_params.get("eval_results") + if perf: + perf = perf["eval_loss_smoothed"] + if self.ema is not None: + if not self.save_ema_only: + self.ckpt_manager.save( + self.net_to_save, + perf, + ckpt_name=ckpt_name.replace(".ckpt", "_nonema.ckpt"), + append_dict=append_dict, + ) + # swap ema weight and network weight + self.ema.swap_before_eval() + + # save history checkpoints + self.ckpt_manager.save(self.net_to_save, perf, ckpt_name=ckpt_name, append_dict=append_dict) if self.save_training_resume: # TODO: resume training for step. From bf6988eb3c827197018f158cf4c28eb456de63db Mon Sep 17 00:00:00 2001 From: Rustam Khadipash <16683750+hadipash@users.noreply.github.com> Date: Wed, 20 Nov 2024 16:55:07 +0800 Subject: [PATCH 054/122] add drop text conditioning for training --- .../configs/train/stage1_t2i_256x256.yaml | 4 ++++ .../configs/train/stage2_t2iv_256x256.yaml | 4 ++++ examples/moviegen/moviegen/dataset/dataset.py | 15 +++++++++++++++ 3 files changed, 23 insertions(+) diff --git a/examples/moviegen/configs/train/stage1_t2i_256x256.yaml b/examples/moviegen/configs/train/stage1_t2i_256x256.yaml index b076be1153..a41f6b141e 100644 --- a/examples/moviegen/configs/train/stage1_t2i_256x256.yaml +++ b/examples/moviegen/configs/train/stage1_t2i_256x256.yaml @@ -22,6 +22,10 @@ dataset: text_emb_folder: ul2: UL2_FOLDER byt5: BYT5_FOLDER + empty_text_emb: + ul2: EMPTY_TEXT_EMB + byt5: EMPTY_TEXT_EMB + text_drop_prob: 0.2 target_size: [ 256, 256 ] apply_transforms_dataset: True output_columns: ["video", "ul2_caption", "byt5_caption"] diff --git a/examples/moviegen/configs/train/stage2_t2iv_256x256.yaml b/examples/moviegen/configs/train/stage2_t2iv_256x256.yaml index d5aa1631a0..043b92db2c 100644 --- a/examples/moviegen/configs/train/stage2_t2iv_256x256.yaml +++ b/examples/moviegen/configs/train/stage2_t2iv_256x256.yaml @@ -22,6 +22,10 @@ dataset: text_emb_folder: ul2: UL2_FOLDER byt5: BYT5_FOLDER + empty_text_emb: + ul2: EMPTY_TEXT_EMB + byt5: EMPTY_TEXT_EMB + text_drop_prob: 0.2 target_size: [ 256, 256 ] sample_n_frames: 272 # FIXME: add variable frames support. 
FIXME: 17 * 16 = 272 frames of OSv1.2 VAE apply_transforms_dataset: True diff --git a/examples/moviegen/moviegen/dataset/dataset.py b/examples/moviegen/moviegen/dataset/dataset.py index a729bc6607..5f763e71d7 100644 --- a/examples/moviegen/moviegen/dataset/dataset.py +++ b/examples/moviegen/moviegen/dataset/dataset.py @@ -27,6 +27,8 @@ def __init__( csv_path: str, video_folder: str, text_emb_folder: Optional[Union[str, Dict[str, str]]] = None, + empty_text_emb: Optional[Union[str, Dict[str, str]]] = None, + text_drop_prob: float = 0.2, vae_latent_folder: Optional[str] = None, vae_downsample_rate: float = 8.0, vae_scale_factor: float = 0.18215, @@ -49,7 +51,17 @@ def __init__( self._frames = sample_n_frames self._stride = sample_stride self._min_length = (self._frames - 1) * self._stride + 1 + self._text_emb_folder = text_emb_folder + self._empty_text_emb = empty_text_emb if text_drop_prob > 0 else None + if self._empty_text_emb: + if isinstance(self._empty_text_emb, str): + assert os.path.exists(self._empty_text_emb), f"Empty text embedding not found: {self._empty_text_emb}" + else: + for path in self._empty_text_emb.values(): + assert os.path.exists(path), f"Empty text embedding not found: {path}" + self._text_drop_prob = text_drop_prob + self._vae_latent_folder = vae_latent_folder self._vae_downsample_rate = vae_downsample_rate self._vae_scale_factor = vae_scale_factor @@ -139,6 +151,9 @@ def _get_item(self, idx: int, thw: Optional[Tuple[int, int, int]] = None) -> Tup num_frames = self._frames if self._text_emb_folder: + if self._empty_text_emb and random.random() <= self._text_drop_prob: + data["text_emb"] = self._empty_text_emb + if isinstance(data["text_emb"], str): with np.load(data["text_emb"]) as td: data.update({"caption": td["text_emb"], "mask": td["mask"]}) From 1a2bfb790ed868fc68206d788ac1e7b7b28e59f0 Mon Sep 17 00:00:00 2001 From: Rustam Khadipash <16683750+hadipash@users.noreply.github.com> Date: Wed, 20 Nov 2024 17:01:59 +0800 Subject: [PATCH 055/122] fix eval loss calculation --- examples/moviegen/moviegen/utils/callbacks.py | 13 +++++++------ examples/moviegen/scripts/stage1_train.sh | 2 +- examples/moviegen/train.py | 11 ++++------- 3 files changed, 12 insertions(+), 14 deletions(-) diff --git a/examples/moviegen/moviegen/utils/callbacks.py b/examples/moviegen/moviegen/utils/callbacks.py index b088f14394..5c5587c6af 100644 --- a/examples/moviegen/moviegen/utils/callbacks.py +++ b/examples/moviegen/moviegen/utils/callbacks.py @@ -9,7 +9,7 @@ from mindspore import Callback, Parameter, ReduceLROnPlateau, RunContext, Tensor from mindspore import dtype as mstype from mindspore import mint, nn, ops -from mindspore.communication import GlobalComm +from mindspore.communication import GlobalComm, get_group_size from mindspore.dataset import GeneratorDataset from mindspore.ops import functional as F @@ -27,7 +27,7 @@ class ValidationCallback(Callback): Args: network (nn.Cell): The neural network model to be validated. dataset (GeneratorDataset): The dataset to use for validation. - rank_id (int): The rank ID of the current process. Defaults to 0. + alpha_smooth (float, optional): The smoothing factor for the loss. Defaults to 0.01. valid_frequency (int, optional): The frequency of validation in terms of training steps. Defaults to 100. ema (Optional[EMA], optional): An Exponential Moving Average object for the model weights. 
@@ -44,7 +44,6 @@ def __init__( self, network: nn.Cell, dataset: GeneratorDataset, - rank_id: int = 0, alpha_smooth: float = 0.01, valid_frequency: int = 100, ema: Optional[EMA] = None, @@ -52,11 +51,13 @@ def __init__( super().__init__() self.network = network self.dataset = dataset - self.rank_id = rank_id self.alpha_smooth = alpha_smooth self.valid_frequency = valid_frequency self.ema = ema - self.reduce = ops.AllReduce() if GlobalComm.INITED else None + self.reduce, self.rank_size = None, 1 + if GlobalComm.INITED: + self.reduce = ops.AllReduce(op=ops.ReduceOp.SUM) + self.rank_size = get_group_size() self.data = pd.Series(dtype=np.float32) def on_train_step_end(self, run_context: RunContext): @@ -75,7 +76,7 @@ def on_train_step_end(self, run_context: RunContext): loss = loss / self.dataset.get_dataset_size() if self.reduce is not None: loss = self.reduce(loss) - loss = loss.item() + loss = loss.item() / self.rank_size self.data = pd.concat([self.data, pd.Series(loss)], ignore_index=True) loss_smoothed = self.data.ewm(alpha=self.alpha_smooth).mean().iloc[-1] diff --git a/examples/moviegen/scripts/stage1_train.sh b/examples/moviegen/scripts/stage1_train.sh index 7831f4b7fc..ae1402acde 100644 --- a/examples/moviegen/scripts/stage1_train.sh +++ b/examples/moviegen/scripts/stage1_train.sh @@ -7,7 +7,7 @@ export GLOG_v=2 output_dir=output/stage1_t2i_256x256/$(date +"%Y.%m.%d-%H.%M.%S") -msrun --bind_core=True --worker_num=8 --local_worker_num=8 --log_dir="$output_dir" \ +msrun --bind_core=True --master_port=8200 --worker_num=8 --local_worker_num=8 --log_dir="$output_dir" \ python train.py \ --config configs/train/stage1_t2i_256x256.yaml \ --env.mode 0 \ diff --git a/examples/moviegen/train.py b/examples/moviegen/train.py index 1342a1e7c6..8194b08335 100644 --- a/examples/moviegen/train.py +++ b/examples/moviegen/train.py @@ -109,7 +109,6 @@ def main(args): ValidationCallback( network=eval_diffusion_with_loss, dataset=val_dataloader, - rank_id=rank_id, alpha_smooth=0.01, # FIXME valid_frequency=args.valid.frequency, ema=ema, @@ -133,15 +132,13 @@ def main(args): train_steps=args.train.steps, **args.train.save, ), + PerfRecorderCallback( + args.train.output_path, file_name="result_val.log", metric_names=["eval_loss", "eval_loss_smoothed"] + ), ] ) - callbacks.extend( - [ - PerfRecorderCallback(args.train.output_path, file_name="result_val.log", metric_names=["eval_loss"]), - StopAtStepCallback(train_steps=args.train.steps), - ] - ) + callbacks.append(StopAtStepCallback(train_steps=args.train.steps)) # 5.5 print out key info and save config if rank_id == 0: From 04e11f34056291b192578ddfc9ab8d411206146d Mon Sep 17 00:00:00 2001 From: Rustam Khadipash <16683750+hadipash@users.noreply.github.com> Date: Thu, 21 Nov 2024 11:34:48 +0800 Subject: [PATCH 056/122] add model parallel --- examples/moviegen/scripts/stage2_train.sh | 3 ++- .../tests/parallel/test_llama3_parallel.py | 8 +++++--- .../tests/parallel/test_rflow_parallel.py | 16 +++++++++------- examples/moviegen/train.py | 16 +++++++++++++--- 4 files changed, 29 insertions(+), 14 deletions(-) diff --git a/examples/moviegen/scripts/stage2_train.sh b/examples/moviegen/scripts/stage2_train.sh index ac6b6855fb..921f7c6023 100644 --- a/examples/moviegen/scripts/stage2_train.sh +++ b/examples/moviegen/scripts/stage2_train.sh @@ -14,7 +14,8 @@ python train.py \ --env.jit_level O1 \ --env.max_device_memory 59GB \ --env.distributed True \ - --train.settings.zero_stage 2 \ + --model.model_parallelism True \ + --train.model_parallel.model_parallel_shards 8 
\ --dataset.csv_path CSV_PATH \ --dataset.video_folder VIDEO_FOLDER \ --dataset.text_emb_folder.ul2 UL2_FOLDER \ diff --git a/examples/moviegen/tests/parallel/test_llama3_parallel.py b/examples/moviegen/tests/parallel/test_llama3_parallel.py index 59b47dd951..500cbcd3cd 100644 --- a/examples/moviegen/tests/parallel/test_llama3_parallel.py +++ b/examples/moviegen/tests/parallel/test_llama3_parallel.py @@ -25,11 +25,13 @@ def construct(self, *inputs): return output.mean() * 1024.0 -def get_sample_data(dtype: ms.Type = ms.float32) -> Tuple[Tensor, Tensor, Tensor]: +def get_sample_data(dtype: ms.Type = ms.float32) -> Tuple[Tensor, ...]: latent_embedding = ops.rand([1, 16, 8, 24, 44], dtype=dtype) timestep = ms.Tensor([35], dtype=ms.int64) - text_embedding = ops.rand([1, 64, 4096], dtype=dtype) - return latent_embedding, timestep, text_embedding + ul2_emb = ops.rand([1, 64, 4096], dtype=dtype) + metaclip_emb = ops.rand([1, 64, 1280], dtype=dtype) + byt5_emb = ops.rand([1, 64, 1472], dtype=dtype) + return latent_embedding, timestep, ul2_emb, metaclip_emb, byt5_emb def get_network_config(model_parallelism=False, fused_tensor_parallel=False): diff --git a/examples/moviegen/tests/parallel/test_rflow_parallel.py b/examples/moviegen/tests/parallel/test_rflow_parallel.py index 6ffd4b254b..eec776fba7 100644 --- a/examples/moviegen/tests/parallel/test_rflow_parallel.py +++ b/examples/moviegen/tests/parallel/test_rflow_parallel.py @@ -5,16 +5,16 @@ from moviegen.schedulers import RFlowLossWrapper import mindspore as ms -import mindspore.nn as nn -import mindspore.ops as ops -from mindspore import Tensor +from mindspore import Tensor, nn, ops from mindspore.communication import get_group_size, init from mindone.utils.seed import set_random_seed class SimpleNet(nn.Cell): - def construct(self, x: Tensor, timestamp: Tensor, text_embedding: Tensor): + def construct( + self, x: Tensor, timestamp: Tensor, ul2_emb: Tensor, metaclip_emb: Tensor, byt5_emb: Tensor + ) -> Tensor: return x.to(ms.float32) @property @@ -22,10 +22,12 @@ def dtype(self): return ms.float32 -def get_sample_data(dtype: ms.Type = ms.float32) -> Tuple[Tensor, Tensor]: +def get_sample_data(dtype: ms.Type = ms.float32) -> Tuple[Tensor, ...]: latent_embedding = ops.rand([1, 16, 8, 24, 44], dtype=dtype) - text_embedding = ops.rand([1, 64, 4096], dtype=dtype) - return latent_embedding, text_embedding + ul2_emb = ops.rand([1, 64, 4096], dtype=dtype) + metaclip_emb = ops.rand([1, 64, 1280], dtype=dtype) + byt5_emb = ops.rand([1, 64, 1472], dtype=dtype) + return latent_embedding, ul2_emb, metaclip_emb, byt5_emb def run_network(mode: int = 0): diff --git a/examples/moviegen/train.py b/examples/moviegen/train.py index 8194b08335..64808d6ca5 100644 --- a/examples/moviegen/train.py +++ b/examples/moviegen/train.py @@ -15,6 +15,7 @@ sys.path.append(mindone_lib_path) from moviegen.dataset import ImageVideoDataset +from moviegen.parallel import create_parallel_group from moviegen.pipelines import DiffusionWithLoss from moviegen.schedulers import RFlowEvalLoss, RFlowLossWrapper from moviegen.utils import EMA, MODEL_DTYPE, init_model @@ -38,7 +39,15 @@ def main(args): args.train.output_path = args.train.output_path.absolute os.makedirs(args.train.output_path, exist_ok=True) device_id, rank_id, device_num = init_train_env(**args.env) - set_seed(args.env.seed + rank_id) # TODO: do it better + + # 1.1 init model parallel + shard_rank_id = rank_id + if (shards := args.train.model_parallel.model_parallel_shards) > 1: + 
create_parallel_group(**args.train.model_parallel) + device_num = device_num // shards + shard_rank_id = rank_id // shards + + set_seed(args.env.seed + shard_rank_id) # TODO: do it better set_logger("", output_dir=args.train.output_path, rank=rank_id) # instantiate classes only after initializing training environment @@ -69,7 +78,7 @@ def main(args): dataset.train_transforms(args.dataset.target_size) if not args.dataset.apply_transforms_dataset else None ) dataloader = create_dataloader( - dataset, transforms=transforms, device_num=device_num, rank_id=rank_id, **args.dataloader + dataset, transforms=transforms, device_num=device_num, rank_id=shard_rank_id, **args.dataloader ) eval_diffusion_with_loss, val_dataloader = None, None @@ -79,7 +88,7 @@ def main(args): if not args.valid.dataset.init_args.apply_transforms_dataset: transforms = val_dataset.train_transforms(args.valid.dataset.init_args.target_size) val_dataloader = create_dataloader( - val_dataset, transforms=transforms, device_num=device_num, rank_id=rank_id, **args.valid.dataloader + val_dataset, transforms=transforms, device_num=device_num, rank_id=shard_rank_id, **args.valid.dataloader ) eval_rflow_loss = RFlowEvalLoss(rflow_loss_wrapper, num_sampling_steps=args.valid.sampling_steps) eval_diffusion_with_loss = DiffusionWithLoss(eval_rflow_loss, vae) @@ -205,6 +214,7 @@ def main(args): create_dataloader, "dataloader", skip={"dataset", "transforms", "device_num", "rank_id"} ) parser.link_arguments("env.debug", "dataloader.debug", apply_on="parse") + parser.add_function_arguments(create_parallel_group, "train.model_parallel") parser.add_function_arguments(create_scheduler, "train.lr_scheduler", skip={"steps_per_epoch", "num_epochs"}) parser.add_class_arguments( ReduceLROnPlateauByStep, "train.lr_reduce_on_plateau", skip={"optimizer"}, instantiate=False From 8af74372444239639e933b514de77750e846f372 Mon Sep 17 00:00:00 2001 From: Mike Cheung Date: Tue, 29 Oct 2024 11:41:57 +0800 Subject: [PATCH 057/122] hack for model parallel --- mindone/trainers/train_step.py | 40 +++++++++++++++++++++++++++++++++- 1 file changed, 39 insertions(+), 1 deletion(-) diff --git a/mindone/trainers/train_step.py b/mindone/trainers/train_step.py index 18d0ccc0b8..b33532c848 100644 --- a/mindone/trainers/train_step.py +++ b/mindone/trainers/train_step.py @@ -1,6 +1,8 @@ """Train step wrapper supporting setting drop overflow update, ema etc""" from typing import Optional +from typing import Callable, Tuple + from packaging import version import mindspore as ms @@ -10,6 +12,7 @@ from mindspore.boost.grad_accumulation import gradient_clear_op as _grad_clear_op from mindspore.common import RowTensor from mindspore.common import dtype as mstype +from mindspore.communication import get_group_size from mindspore.ops import composite as C from mindspore.ops import functional as F from mindspore.ops import operations as P @@ -35,6 +38,38 @@ def tensor_grad_scale_row_tensor(scale, grad): ) +communicate_opt = C.MultitypeFuncGraph("communicate_opt") + + +@communicate_opt.register("Function", "Number", "Tensor", "Bool") +def _communicate_opt(func: Callable[[Tensor], Tensor], num: int, grad: Tensor, need_reduce: bool): + if not need_reduce: + return grad + grad = func(grad) + grad = grad / num + return grad + + +class GradReducer(nn.Cell): + def __init__(self): + super().__init__() + self.hypermap = C.HyperMap() + self.is_single = False + try: + self.num = get_group_size() + except RuntimeError: + self.is_single = True + + if not self.is_single: + self.reduce = 
ops.AllReduce() + + def construct(self, grads: Tuple[Tensor], need_reduce: Tuple[bool]): + if self.is_single: + return grads + grads = self.hypermap(ops.partial(communicate_opt, self.reduce, self.num), grads, need_reduce) + return grads + + class TrainOneStepWrapper(nn.TrainOneStepWithLossScaleCell): """TrainStep with ema and clip grad. @@ -91,6 +126,9 @@ def __init__( self.map = ops.Map() self.partial = ops.Partial() + self.grad_reducer = GradReducer() + self.need_reduce = tuple([2048 in x.shape for x in self.weights]) + # zero init self.zero_helper = zero_helper self.zero_stage = zero_helper.zero_stage if zero_helper is not None else 0 @@ -130,7 +168,7 @@ def construct(self, *inputs): grads = self.zero_helper.cal_gradients(grads) if self.accum_steps == 1: - grads = self.grad_reducer(grads) + grads = self.grad_reducer(grads, self.need_reduce) scaling_sens = ops.depend(scaling_sens, grads) # 2. down-scale gradients by loss_scale. grads = grads / scaling_sense / grad_accum_steps From bf505d4f308a4f02b2c5fbfa97eb6e251c90b830 Mon Sep 17 00:00:00 2001 From: Rustam Khadipash <16683750+hadipash@users.noreply.github.com> Date: Thu, 21 Nov 2024 15:25:33 +0800 Subject: [PATCH 058/122] fix hack --- examples/moviegen/train.py | 10 ++++++++-- mindone/trainers/train_step.py | 7 +++---- mindone/trainers/zero.py | 5 ++++- 3 files changed, 15 insertions(+), 7 deletions(-) diff --git a/examples/moviegen/train.py b/examples/moviegen/train.py index 64808d6ca5..2f251f27e3 100644 --- a/examples/moviegen/train.py +++ b/examples/moviegen/train.py @@ -1,5 +1,6 @@ import logging import os +import re import sys from math import ceil @@ -105,7 +106,12 @@ def main(args): ema = EMA(latent_diffusion_with_loss.network, **args.train.ema.init_args) if args.train.ema else None loss_scaler = initializer.train.loss_scaler net_with_grads = prepare_train_network( - latent_diffusion_with_loss, optimizer=optimizer, scale_sense=loss_scaler, ema=ema, **args.train.settings + latent_diffusion_with_loss, + optimizer=optimizer, + scale_sense=loss_scaler, + ema=ema, + need_reduce=tuple(bool(re.search(r"layers\.(\d+)\.mlp", param.name)) for param in optimizer.parameters), + **args.train.settings, ) model = Model(net_with_grads) @@ -227,7 +233,7 @@ def main(args): help="mindspore.nn.FixedLossScaleUpdateCell or mindspore.nn.DynamicLossScaleUpdateCell", ) parser.add_function_arguments( - prepare_train_network, "train.settings", skip={"network", "optimizer", "scale_sense", "ema"} + prepare_train_network, "train.settings", skip={"network", "optimizer", "scale_sense", "ema", "need_reduce"} ) parser.add_subclass_arguments(EMA, "train.ema", skip={"network"}, required=False, instantiate=False) parser.add_argument( diff --git a/mindone/trainers/train_step.py b/mindone/trainers/train_step.py index b33532c848..b33eca33f3 100644 --- a/mindone/trainers/train_step.py +++ b/mindone/trainers/train_step.py @@ -1,7 +1,5 @@ """Train step wrapper supporting setting drop overflow update, ema etc""" -from typing import Optional - -from typing import Callable, Tuple +from typing import Callable, Optional, Tuple from packaging import version @@ -101,6 +99,7 @@ def __init__( clip_norm=1.0, verbose=False, zero_helper=None, + need_reduce: Optional[Tuple[bool]] = None, ): super().__init__(network, optimizer, scale_sense) self.ema = ema @@ -127,7 +126,7 @@ def __init__( self.partial = ops.Partial() self.grad_reducer = GradReducer() - self.need_reduce = tuple([2048 in x.shape for x in self.weights]) + self.need_reduce = need_reduce # zero init self.zero_helper = 
zero_helper diff --git a/mindone/trainers/zero.py b/mindone/trainers/zero.py index 7f5c0adaff..9ccda9c629 100644 --- a/mindone/trainers/zero.py +++ b/mindone/trainers/zero.py @@ -1,7 +1,7 @@ import json import logging import os -from typing import Literal +from typing import Literal, Optional, Tuple import mindspore as ms from mindspore import nn, ops @@ -561,6 +561,7 @@ def prepare_train_network( dp_group: str = None, comm_fusion: dict = None, parallel_modules=None, + need_reduce: Optional[Tuple[bool, ...]] = None, ): """ Prepare network and optimizer for distributed training. @@ -599,6 +600,7 @@ def prepare_train_network( clip_grad=clip_grad, clip_norm=clip_norm, verbose=verbose, + need_reduce=need_reduce, ) return train_network @@ -628,6 +630,7 @@ def prepare_train_network( clip_norm=clip_norm, verbose=verbose, zero_helper=zero_helper, + need_reduce=need_reduce, ) return train_network From a048efb8bd00d6e8fa0b7d1d312ce3fe29bbb1f6 Mon Sep 17 00:00:00 2001 From: Rustam Khadipash <16683750+hadipash@users.noreply.github.com> Date: Thu, 21 Nov 2024 17:52:46 +0800 Subject: [PATCH 059/122] small fixes --- .../configs/train/stage2_t2iv_256x256.yaml | 9 ++++----- examples/moviegen/moviegen/utils/model_utils.py | 2 ++ .../tests/parallel/test_llama3_parallel.py | 16 ++++++++++------ 3 files changed, 16 insertions(+), 11 deletions(-) diff --git a/examples/moviegen/configs/train/stage2_t2iv_256x256.yaml b/examples/moviegen/configs/train/stage2_t2iv_256x256.yaml index 043b92db2c..e262e898e2 100644 --- a/examples/moviegen/configs/train/stage2_t2iv_256x256.yaml +++ b/examples/moviegen/configs/train/stage2_t2iv_256x256.yaml @@ -32,7 +32,7 @@ dataset: output_columns: ["video", "ul2_caption", "byt5_caption"] dataloader: - batch_size: 64 + batch_size: 1 shuffle: True num_workers_dataset: 4 @@ -46,9 +46,8 @@ train: warmup_steps: 1000 lr_reduce_on_plateau: - alpha_smooth: 0.01 factor: 0.5 - patience: 5000 + patience: 50 # in the number of validation steps, i.e., valid.frequency * patience steps mode: min min_delta: 0.01 min_lr: 1.0e-6 @@ -75,8 +74,8 @@ train: clip_norm: 1.0 save: - ckpt_save_policy: latest_k - ckpt_save_interval: 500 + ckpt_save_policy: top_k + ckpt_save_interval: &save_interval 100 ckpt_max_keep: 10 log_interval: 1 save_ema_only: False diff --git a/examples/moviegen/moviegen/utils/model_utils.py b/examples/moviegen/moviegen/utils/model_utils.py index 6f687790c3..7a5723da80 100644 --- a/examples/moviegen/moviegen/utils/model_utils.py +++ b/examples/moviegen/moviegen/utils/model_utils.py @@ -61,6 +61,7 @@ def init_model( in_channels: int = 4, pretrained_model_path: Optional[Path_fr] = None, enable_flash_attention: bool = True, + model_parallelism: bool = False, recompute: bool = False, dtype: Literal["fp32", "fp16", "bf16"] = "fp32", ) -> LlamaModel: @@ -68,6 +69,7 @@ def init_model( model = MODEL_SPEC[name]( in_channels=in_channels, attn_implementation=attn_implementation, + model_parallelism=model_parallelism, gradient_checkpointing=recompute, dtype=MODEL_DTYPE[dtype], ) diff --git a/examples/moviegen/tests/parallel/test_llama3_parallel.py b/examples/moviegen/tests/parallel/test_llama3_parallel.py index 500cbcd3cd..08ec46c59d 100644 --- a/examples/moviegen/tests/parallel/test_llama3_parallel.py +++ b/examples/moviegen/tests/parallel/test_llama3_parallel.py @@ -45,7 +45,7 @@ def get_network_config(model_parallelism=False, fused_tensor_parallel=False): return config -def run_network(mode: int = 0, fused_tensor_parallel: bool = False, dtype: ms.Type = ms.float32): +def run_network(mode: int = 
0, dtype: ms.Type = ms.float32): ms.set_context(mode=mode) init() @@ -56,6 +56,14 @@ def run_network(mode: int = 0, fused_tensor_parallel: bool = False, dtype: ms.Ty # prepare group create_parallel_group(model_parallel_shards=get_group_size()) + print("Non-fused tensor parallel:", flush=True) + run_parallel_network(data, fused_tensor_parallel=False) + + print("Fused tensor parallel:", flush=True) + run_parallel_network(data, fused_tensor_parallel=True) + + +def run_parallel_network(data: Tuple[Tensor, ...], fused_tensor_parallel: bool = False, dtype: ms.Type = ms.float32): # non parallel network set_random_seed(1024) non_parallel_network_cfg = get_network_config(model_parallelism=False, fused_tensor_parallel=fused_tensor_parallel) @@ -102,8 +110,4 @@ def run_network(mode: int = 0, fused_tensor_parallel: bool = False, dtype: ms.Ty "--mode", default=0, type=int, choices=[0, 1], help="Mode to test. (0: Graph Mode; 1: Pynative mode)" ) args = parser.parse_args() - print("Non-fused tensor parallel:", flush=True) - run_network(mode=args.mode, fused_tensor_parallel=False) - - print("Fused tensor parallel:", flush=True) - run_network(mode=args.mode, fused_tensor_parallel=True) + run_network(mode=args.mode) From 03e4271430e819a05d25f3c247822d0d2dc84718 Mon Sep 17 00:00:00 2001 From: samithuang <285365963@qq.com> Date: Fri, 22 Nov 2024 16:02:16 +0800 Subject: [PATCH 060/122] add temporal tile --- examples/movie_gen/mg/models/tae/losses.py | 4 +- examples/movie_gen/mg/models/tae/tae.py | 140 +++++++++++++++++- examples/movie_gen/scripts/inference_vae.py | 2 + .../movie_gen/scripts/run/run_train_tae.sh | 2 +- examples/movie_gen/scripts/train_tae.py | 6 +- examples/movie_gen/tests/test_tae.py | 24 ++- 6 files changed, 166 insertions(+), 12 deletions(-) diff --git a/examples/movie_gen/mg/models/tae/losses.py b/examples/movie_gen/mg/models/tae/losses.py index 0a97e3e9b3..181afbd7c1 100644 --- a/examples/movie_gen/mg/models/tae/losses.py +++ b/examples/movie_gen/mg/models/tae/losses.py @@ -52,7 +52,7 @@ def kl(self, mean, logvar): return kl_loss def vae_loss_fn( - self, x, recons, mean, logvar, nll_weights=None, no_perceptual=False, no_kl=False, pixelwise_mean=False + self, x, recons, mean, logvar, nll_weights=None, no_perceptual=False, no_kl=False, pixelwise_mean=False, ): """ return: @@ -60,10 +60,10 @@ def vae_loss_fn( weighted_nll_loss: weighted mean of nll_loss weighted_kl_loss: KL divergence on posterior """ - bs = x.shape[0] # (b c t h w) -> (b*t c h w) x = _rearrange_in(x) recons = _rearrange_in(recons) + bs = x.shape[0] # reconstruction loss in pixels # FIXME: debugging: use pixelwise mean to reduce loss scale diff --git a/examples/movie_gen/mg/models/tae/tae.py b/examples/movie_gen/mg/models/tae/tae.py index 986d13a3e4..81ad7434de 100644 --- a/examples/movie_gen/mg/models/tae/tae.py +++ b/examples/movie_gen/mg/models/tae/tae.py @@ -1,7 +1,9 @@ import mindspore as ms +import math from mindspore import nn, ops from .modules import Conv2_5d, Encoder, Decoder + # TODO: set z_channels to 16 SDXL_CONFIG = { @@ -57,6 +59,11 @@ def __init__( pretrained: str = None, use_recompute: bool=False, sample_deterministic: bool=False, + use_tile: bool=False, + encode_tile: int=32, + encode_overlap: int=0, + decode_tile: int=32, + decode_overlap: int=16, ): super().__init__() @@ -82,9 +89,25 @@ def __init__( self.sample_deterministic = sample_deterministic self.discard_spurious_frames = True + + # tile + self.encode_tile = encode_tile + self.time_compress = 2**len(config['temporal_downsample_level']) # 8 + 
self.encode_overlap = encode_overlap + self.use_tile = use_tile + if use_tile: + assert (self.encode_tile % self.time_compress == 0) and (self.encode_tile > 0), f'num tile frames should be divisable by {self.time_compress} and non-zero' + assert self.encode_overlap % self.time_compress == 0, f'overlap frames should be divisable by {self.time_compress}' + # TODO: support encode overlap + assert self.encode_overlap == 0, 'not supported' + + self.decode_tile = decode_tile + self.decode_overlap = decode_overlap + if use_recompute: - # self.recompute(self.encoder) + # TODO: uncomment if OOM + self.recompute(self.encoder) # self.recompute(self.quant_conv) # self.recompute(self.post_quant_conv) self.recompute(self.decoder) @@ -127,27 +150,132 @@ def encode(self, x: ms.Tensor) -> ms.Tensor: if self.sample_deterministic: return posterior_mean z = self.sample(posterior_mean, posterior_logvar) - - return z + + # TODO: align interface + return z, posterior_mean, posterior_logvar def decode(self, z: ms.Tensor) -> ms.Tensor: if self.use_post_quant_conv: z = self.post_quant_conv(z) dec = self.decoder(z) + + # TODO: consider decoding latent without knowing encode input frame length return dec + def encode_with_tile(self, x: ms.Tensor) -> ms.Tensor: + tf = self.encode_tile + # of = self.encode_overlap + + z_out, mean, logvar = self.encode(x[:, :, :tf]) + + print('D--: use encode tile ', tf) + # import pdb; pdb.set_trace() + # for i in range(tf - of, x.shape[2], tf): + for i in range(tf, x.shape[2], tf): + z_cur, mean, logvar = self.encode(x[:, :, i : i + tf]) + z_out = ops.cat((z_out, z_cur), axis=2) + + # TODO: merge mean, logvar for different slices? + return z_out, mean, logvar + + def decode_with_tile(self, z: ms.Tensor) -> ms.Tensor: + # + tl = self.decode_tile // self.time_compress # tile len + ol = self.decode_overlap // self.time_compress # overlap len + stride = tl - ol + in_len = z.shape[2] + num_slices = (in_len - tl) // stride + 1 + if (in_len - tl) % stride != 0 and (in_len - tl) + stride < in_len: + num_slices += 1 + + # ms graph mode requires an init x_out + x_out = self.decode(z[:, :, :tl]) + + print('D--: use decode tile ', tl, ol) + # import pdb; pdb.set_trace() + # FIXME: the end idx is not right + # for i in range(stride, z.shape[2], stride): + visited = tl + i = stride # start position + while visited < in_len: + x_cur = self.decode(z[:, :, i : i + tl]) + x_out = ops.cat((x_out, x_cur), axis=2) + + visited = i + tl + i += stride + + # linear blend the overlapp part + if self.decode_overlap > 0: + x_out = self.blend_slices(x_out, self.decode_tile, self.decode_overlap) + + return x_out + + def blend_slices(self, x: ms.Tensor, slice_len=32, overlap_len=16, use_numpy=False): + """ + Use with decode_with_tile + + Args: + x: (b c t h w) is the concatenation of the decoded slices, + slice_len: slice length; for decoding, it's the latent tile size mulitplied by temporal upsampling ratio. default is 4*8 for moviegen tae. + overlap_len: overlap between slices. for decoding, default is 2*8 for movie gen tae + + Note that the length of the last slice can be shorter than slice_len. 
+ + Returns: + numpy if use_numpy is True, otherwise return ms.Tensor + """ + B, C, in_len, H, W = x.shape + num_slices = math.ceil(in_len / slice_len) + stride = slice_len - overlap_len + + out_len = ((num_slices-1) * slice_len) - (num_slices - 2) * overlap_len + last_slice_len = in_len - (num_slices -1 ) * slice_len + out_len += last_slice_len - overlap_len + ''' + if use_numpy: + x = x.asnumpy() + out_tensor = np.zeros((B, C, out_len, H, W), np.float32) + out_cnt = np.zeros((B, C, out_len, H, W), np.float32) + ''' + # TODO: can it work in graph mode? + out_tensor = ops.zeros((B, C, out_len, H, W), ms.float32) + out_cnt = ops.zeros((B, C, out_len, H, W), ms.float32) + + # import pdb; pdb.set_trace() + for i in range(num_slices): + # get the slice form the concatnated latent + cur_slice = x[:, :, i*slice_len:(i+1)*slice_len] + cur_len = cur_slice.shape[2] + + # put the slice into the right position of output tensor + start = i * stride + out_tensor[:, :, start:start+cur_len] += cur_slice + out_cnt[:, :, start:start+cur_len] += 1 + + out_tensor = out_tensor / out_cnt + + return out_tensor + + def construct(self, x: ms.Tensor) -> ms.Tensor: """ video reconstruction x: (b c t h w) """ + if self.use_tile: + z, posterior_mean, posterior_logvar = self.encode_with_tile(x) + else: + posterior_mean, posterior_logvar = self._encode(x) + z = self.sample(posterior_mean, posterior_logvar) - posterior_mean, posterior_logvar = self._encode(x) - z = self.sample(posterior_mean, posterior_logvar) - recons = self.decode(z) + if self.use_tile: + recons = self.decode_with_tile(z) + else: + recons = self.decode(z) if self.discard_spurious_frames and (recons.shape[-3] != x.shape[-3]): + print("WARNING: discard suprious frames, ", recons.shape[-3], x.shape[-3]) recons = recons[:, :, :x.shape[-3], :, :] return recons, z, posterior_mean, posterior_logvar diff --git a/examples/movie_gen/scripts/inference_vae.py b/examples/movie_gen/scripts/inference_vae.py index 2885790f09..81f02496c5 100644 --- a/examples/movie_gen/scripts/inference_vae.py +++ b/examples/movie_gen/scripts/inference_vae.py @@ -101,6 +101,7 @@ def main(args): # build model model = TemporalAutoencoder( pretrained=args.ckpt_path, + use_tile=args.enable_tile, ) model.set_train(False) @@ -285,6 +286,7 @@ def parse_args(): parser.add_argument("--save_vis", default=True, type=str2bool, help="whether save reconstructed images") parser.add_argument("--use_temporal_vae", default=True, type=str2bool, help="if False, just use spatial vae") parser.add_argument("--encode_only", default=False, type=str2bool, help="only encode to save z or distribution") + parser.add_argument("--enable_tile", default=False, type=str2bool, help="enable temporal tiling with linear blending for decoder") parser.add_argument("--video_column", default="video", type=str, help="name of column for videos saved in csv file") parser.add_argument( "--mixed_strategy", diff --git a/examples/movie_gen/scripts/run/run_train_tae.sh b/examples/movie_gen/scripts/run/run_train_tae.sh index 8396e3e4da..f056ef4566 100755 --- a/examples/movie_gen/scripts/run/run_train_tae.sh +++ b/examples/movie_gen/scripts/run/run_train_tae.sh @@ -11,7 +11,7 @@ export MS_DEV_ENABLE_KERNEL_PACKET=on # log level export GLOG_v=2 -output_dir=outputs/train_tae_1p_sd3.5vaeInit_noOpl +output_dir=outputs/debug_train_tae_1p_sd3.5vaeInit_noOpl python scripts/train_tae.py \ --mode=0 \ diff --git a/examples/movie_gen/scripts/train_tae.py b/examples/movie_gen/scripts/train_tae.py index f3e6d47c21..e847228be5 100644 --- 
a/examples/movie_gen/scripts/train_tae.py +++ b/examples/movie_gen/scripts/train_tae.py @@ -22,6 +22,7 @@ from mg.datasets.tae_dataset import create_dataloader from mg.models.tae.losses import GeneratorWithLoss from mg.models.tae.tae import TemporalAutoencoder +from mg.models.tae.modules import SpatialUpsample, SpatialDownsample, TemporalUpsample, TemporalDownsample from mindone.trainers.callback import EvalSaveCallback, OverflowMonitor, ProfilerCallback from mindone.trainers.checkpoint import CheckpointManager, resume_train_network @@ -203,11 +204,14 @@ def main(args): # TODO: set softmax, sigmoid computed in FP32. manually set inside network since they are ops, instead of layers whose precision will be set by AMP level. if args.dtype in ["fp16", "bf16"]: dtype = {"fp16": ms.float16, "bf16": ms.bfloat16}[args.dtype] + # TODO: check ResizeNearest bf16 support for ms>2.3.1 ae = auto_mixed_precision( ae, args.amp_level, dtype, - custom_fp32_cells=[nn.GroupNorm] if args.vae_keep_gn_fp32 else [], + custom_fp32_cells= [SpatialDownsample, SpatialUpsample, TemporalDownsample, TemporalUpsample] + \ + ([nn.GroupNorm] if args.vae_keep_gn_fp32 else []), + # custom_fp32_cells=[nn.GroupNorm, SpatialUpsample] if args.vae_keep_gn_fp32 else [SpatialUpsample], ) # 4. build net with loss diff --git a/examples/movie_gen/tests/test_tae.py b/examples/movie_gen/tests/test_tae.py index 711b02476c..3157b927cd 100644 --- a/examples/movie_gen/tests/test_tae.py +++ b/examples/movie_gen/tests/test_tae.py @@ -249,8 +249,27 @@ def test_sd3d5_vae(): print(recons.sum()) +def test_tae_tile(): + tae = TemporalAutoencoder(config=TAE_CONFIG, use_tile=True, + encode_tile=32, decode_tile=32, decode_overlap=16) + + # in_shape = (B, C, T, H, W) = (1, 3, 16, 64, 64) + # in_shape = (B, C, T, H, W) = (1, 3, 96, 32, 32) + in_shape = (B, C, T, H, W) = (1, 3, 64+16, 64, 64) + + x = np.random.normal(size=in_shape).astype(np.float32) + x = ms.Tensor(x) + + y = tae(x) + + print(y[0].shape) + + # check correctness of blend + + + if __name__ == "__main__": - ms.set_context(mode=1) + ms.set_context(mode=0) # test_conv25d() # test_resnetblock() @@ -265,6 +284,7 @@ def test_sd3d5_vae(): # test_decoder() # test_tae_encode() # test_tae_decode() - test_tae_rec() + # test_tae_rec() + test_tae_tile() # test_sd3d5_vae() From 640871ad5122d1677e220cf48bec63f237d34c33 Mon Sep 17 00:00:00 2001 From: samithuang <285365963@qq.com> Date: Fri, 22 Nov 2024 20:30:39 +0800 Subject: [PATCH 061/122] rm comments --- examples/movie_gen/mg/models/tae/tae.py | 7 +------ examples/movie_gen/scripts/args_train_tae.py | 6 ++++++ examples/movie_gen/scripts/train_tae.py | 2 +- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/examples/movie_gen/mg/models/tae/tae.py b/examples/movie_gen/mg/models/tae/tae.py index 81ad7434de..dbcb7f80de 100644 --- a/examples/movie_gen/mg/models/tae/tae.py +++ b/examples/movie_gen/mg/models/tae/tae.py @@ -168,7 +168,6 @@ def encode_with_tile(self, x: ms.Tensor) -> ms.Tensor: z_out, mean, logvar = self.encode(x[:, :, :tf]) - print('D--: use encode tile ', tf) # import pdb; pdb.set_trace() # for i in range(tf - of, x.shape[2], tf): for i in range(tf, x.shape[2], tf): @@ -191,10 +190,6 @@ def decode_with_tile(self, z: ms.Tensor) -> ms.Tensor: # ms graph mode requires an init x_out x_out = self.decode(z[:, :, :tl]) - print('D--: use decode tile ', tl, ol) - # import pdb; pdb.set_trace() - # FIXME: the end idx is not right - # for i in range(stride, z.shape[2], stride): visited = tl i = stride # start position while visited < in_len: 
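
The tiled-decoding loop shown in the hunk above walks the latent with a fixed stride of `tile - overlap` frames. As a reference sketch (not part of this patch), the slice scheduling can be reproduced in a few lines of plain Python; `tile_starts` is a made-up helper name used only for illustration:

```python
# Illustrative sketch of the slice scheduling used by decode_with_tile;
# `tile_starts` is a hypothetical helper, not part of the TAE code.
def tile_starts(in_len: int, tile: int, overlap: int) -> list:
    stride = tile - overlap
    starts = [0]      # the first slice always begins at latent frame 0
    visited = tile    # number of latent frames covered so far
    i = stride
    while visited < in_len:
        starts.append(i)
        visited = i + tile
        i += stride
    return starts

# 12 latent frames with tile=4 and overlap=2 -> slices start at [0, 2, 4, 6, 8]
print(tile_starts(12, 4, 2))
```

Each slice is decoded independently and the overlapping regions are merged afterwards by `blend_slices`.
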
@@ -275,7 +270,7 @@ def construct(self, x: ms.Tensor) -> ms.Tensor: recons = self.decode(z) if self.discard_spurious_frames and (recons.shape[-3] != x.shape[-3]): - print("WARNING: discard suprious frames, ", recons.shape[-3], x.shape[-3]) + # print("WARNING: discard suprious frames, ", recons.shape[-3], x.shape[-3]) recons = recons[:, :, :x.shape[-3], :, :] return recons, z, posterior_mean, posterior_logvar diff --git a/examples/movie_gen/scripts/args_train_tae.py b/examples/movie_gen/scripts/args_train_tae.py index d2962498c6..e06e73a068 100644 --- a/examples/movie_gen/scripts/args_train_tae.py +++ b/examples/movie_gen/scripts/args_train_tae.py @@ -179,6 +179,12 @@ def parse_train_args(parser): type=str2bool, help="whether keep GroupNorm in fp32.", ) + parser.add_argument( + "--vae_keep_updown_fp32", + default=True, + type=str2bool, + help="whether keep spatial/temporal upsample and downsample in fp32.", + ) parser.add_argument( "--global_bf16", default=False, diff --git a/examples/movie_gen/scripts/train_tae.py b/examples/movie_gen/scripts/train_tae.py index e847228be5..44152782b6 100644 --- a/examples/movie_gen/scripts/train_tae.py +++ b/examples/movie_gen/scripts/train_tae.py @@ -209,7 +209,7 @@ def main(args): ae, args.amp_level, dtype, - custom_fp32_cells= [SpatialDownsample, SpatialUpsample, TemporalDownsample, TemporalUpsample] + \ + custom_fp32_cells= [SpatialDownsample, SpatialUpsample, TemporalDownsample, TemporalUpsample] if args.vae_keep_updown_fp32 else [] + \ ([nn.GroupNorm] if args.vae_keep_gn_fp32 else []), # custom_fp32_cells=[nn.GroupNorm, SpatialUpsample] if args.vae_keep_gn_fp32 else [SpatialUpsample], ) From 1cd636e4c3cefbb63f6ca9f74ae3a670c4a9c011 Mon Sep 17 00:00:00 2001 From: samithuang <285365963@qq.com> Date: Fri, 22 Nov 2024 22:36:39 +0800 Subject: [PATCH 062/122] clean code --- examples/movie_gen/mg/models/tae/tae.py | 99 +++++++++++++------ examples/movie_gen/scripts/inference_vae.py | 6 +- examples/movie_gen/tests/test_tae.py | 21 +++- .../opensora_hpcai/scripts/inference_vae.py | 4 +- 4 files changed, 90 insertions(+), 40 deletions(-) diff --git a/examples/movie_gen/mg/models/tae/tae.py b/examples/movie_gen/mg/models/tae/tae.py index dbcb7f80de..5812f6a32f 100644 --- a/examples/movie_gen/mg/models/tae/tae.py +++ b/examples/movie_gen/mg/models/tae/tae.py @@ -4,8 +4,6 @@ from .modules import Conv2_5d, Encoder, Decoder -# TODO: set z_channels to 16 - SDXL_CONFIG = { "double_z": True, "z_channels": 16, @@ -98,18 +96,14 @@ def __init__( if use_tile: assert (self.encode_tile % self.time_compress == 0) and (self.encode_tile > 0), f'num tile frames should be divisable by {self.time_compress} and non-zero' assert self.encode_overlap % self.time_compress == 0, f'overlap frames should be divisable by {self.time_compress}' - # TODO: support encode overlap assert self.encode_overlap == 0, 'not supported' self.decode_tile = decode_tile self.decode_overlap = decode_overlap - - + + # recompute if use_recompute: - # TODO: uncomment if OOM self.recompute(self.encoder) - # self.recompute(self.quant_conv) - # self.recompute(self.post_quant_conv) self.recompute(self.decoder) if pretrained is not None: @@ -145,40 +139,77 @@ def sample(self, mean, logvar): return z def encode(self, x: ms.Tensor) -> ms.Tensor: - # embedding, get latent representation z + r""" + Encode a batch of videos into latents + + Args: + x (Tensor): input video tensor of shape (b c t h w) + + Returns: + z (Tensor): the sampled latent tensor, shape (b z t' h' w') + posterior_mean (Tensor): mean of latent 
distribution + posterior_logvar (Tensor): logvar of latent distribution + """ + posterior_mean, posterior_logvar = self._encode(x) if self.sample_deterministic: return posterior_mean z = self.sample(posterior_mean, posterior_logvar) - # TODO: align interface return z, posterior_mean, posterior_logvar def decode(self, z: ms.Tensor) -> ms.Tensor: + r""" + Decode a batch of latents to videos + + Args: + x (Tensor): input latent tensor of shape (b z t' h' w') + + Returns: + z (Tensor): the decoded videos of shape (b c t h w) + """ + if self.use_post_quant_conv: z = self.post_quant_conv(z) dec = self.decoder(z) - # TODO: consider decoding latent without knowing encode input frame length return dec def encode_with_tile(self, x: ms.Tensor) -> ms.Tensor: + r""" + Encode a batch of videos into latents with tiling + + Args: + x (Tensor): input video tensor of shape (b c t h w) + + Returns: + z (Tensor): the sampled latent tensor, shape (b z t/8 h/8 w/8) + posterior_mean (Tensor): mean of latent distribution + posterior_logvar (Tensor): logvar of latent distribution + """ + tf = self.encode_tile - # of = self.encode_overlap z_out, mean, logvar = self.encode(x[:, :, :tf]) - # import pdb; pdb.set_trace() - # for i in range(tf - of, x.shape[2], tf): for i in range(tf, x.shape[2], tf): z_cur, mean, logvar = self.encode(x[:, :, i : i + tf]) z_out = ops.cat((z_out, z_cur), axis=2) - # TODO: merge mean, logvar for different slices? + # TODO: merge mean, logvar for different slices for training tae with tile return z_out, mean, logvar def decode_with_tile(self, z: ms.Tensor) -> ms.Tensor: - # + r""" + Decode a batch of latents to videos with tiling + + Args: + x (Tensor): input latent tensor of shape (b z t' h' w') + + Returns: + z (Tensor): the decoded videos of shape (b c t h w) + """ + tl = self.decode_tile // self.time_compress # tile len ol = self.decode_overlap // self.time_compress # overlap len stride = tl - ol @@ -205,9 +236,9 @@ def decode_with_tile(self, z: ms.Tensor) -> ms.Tensor: return x_out - def blend_slices(self, x: ms.Tensor, slice_len=32, overlap_len=16, use_numpy=False): + def blend_slices(self, x: ms.Tensor, slice_len=32, overlap_len=16): """ - Use with decode_with_tile + Blend decoded latent slices, used with decode_with_tile Args: x: (b c t h w) is the concatenation of the decoded slices, @@ -217,8 +248,9 @@ def blend_slices(self, x: ms.Tensor, slice_len=32, overlap_len=16, use_numpy=Fal Note that the length of the last slice can be shorter than slice_len. Returns: - numpy if use_numpy is True, otherwise return ms.Tensor + ms.Tensor """ + B, C, in_len, H, W = x.shape num_slices = math.ceil(in_len / slice_len) stride = slice_len - overlap_len @@ -226,17 +258,11 @@ def blend_slices(self, x: ms.Tensor, slice_len=32, overlap_len=16, use_numpy=Fal out_len = ((num_slices-1) * slice_len) - (num_slices - 2) * overlap_len last_slice_len = in_len - (num_slices -1 ) * slice_len out_len += last_slice_len - overlap_len - ''' - if use_numpy: - x = x.asnumpy() - out_tensor = np.zeros((B, C, out_len, H, W), np.float32) - out_cnt = np.zeros((B, C, out_len, H, W), np.float32) - ''' - # TODO: can it work in graph mode? 
+ out_tensor = ops.zeros((B, C, out_len, H, W), ms.float32) out_cnt = ops.zeros((B, C, out_len, H, W), ms.float32) - - # import pdb; pdb.set_trace() + + import pdb; pdb.set_trace() for i in range(num_slices): # get the slice form the concatnated latent cur_slice = x[:, :, i*slice_len:(i+1)*slice_len] @@ -253,11 +279,19 @@ def blend_slices(self, x: ms.Tensor, slice_len=32, overlap_len=16, use_numpy=Fal def construct(self, x: ms.Tensor) -> ms.Tensor: - """ - video reconstruction + r""" + Video reconstruction - x: (b c t h w) + Args: + x: a batch of videos of shape (b c t h w) + + Returns: + recons (Tensor): the reconstructed videos of shape (b c t h w) + z (Tensor): the latent tensor, shape (b z t' h' w') + posterior_mean (Tensor): mean of latent distribution + posterior_logvar (Tensor): logvar of latent distribution """ + if self.use_tile: z, posterior_mean, posterior_logvar = self.encode_with_tile(x) else: @@ -277,6 +311,7 @@ def construct(self, x: ms.Tensor) -> ms.Tensor: def load_pretrained(self, ckpt_path:str): + if ckpt_path.endswith('safetensors'): # load vae parameters from safetensors into my mindspore model import safetensors @@ -299,6 +334,6 @@ def load_pretrained(self, ckpt_path:str): if param_not_load or ckpt_not_load: print(f"{param_not_load} in network is not loaded") print(f"{ckpt_not_load} in checkpoint is not loaded!") - print('tae checkpoint loaded') + print('TAE checkpoint loaded') diff --git a/examples/movie_gen/scripts/inference_vae.py b/examples/movie_gen/scripts/inference_vae.py index 81f02496c5..ba9fe34d6d 100644 --- a/examples/movie_gen/scripts/inference_vae.py +++ b/examples/movie_gen/scripts/inference_vae.py @@ -13,8 +13,9 @@ from mindspore import nn, ops -# mindone_dir = '/home/mindocr/yx/mindone' -mindone_dir = "/home_host/yx/mindone" + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +mindone_dir = os.path.abspath(os.path.join(__dir__, "../../../")) sys.path.insert(0, mindone_dir) @@ -203,7 +204,6 @@ def main(args): logger.info(f"cur ssim: {ssim_cur[-1]:.4f}, mean ssim:{mean_ssim/num_samples:.4f}") if args.eval_loss: - print("D--: ", x.shape, recons.shape) recon_loss = np.abs((x - recons).asnumpy()) t = x.shape[2] diff --git a/examples/movie_gen/tests/test_tae.py b/examples/movie_gen/tests/test_tae.py index 3157b927cd..a426ed3654 100644 --- a/examples/movie_gen/tests/test_tae.py +++ b/examples/movie_gen/tests/test_tae.py @@ -248,14 +248,27 @@ def test_sd3d5_vae(): print(recons.sum()) +def test_blend(): + ms.set_context(mode=1) + tae = TemporalAutoencoder(config=TAE_CONFIG, use_tile=True, + encode_tile=32, decode_tile=32, decode_overlap=16) + + in_shape = (B, C, T, H, W) = (1, 1, 12, 1, 1) + x = np.random.normal(size=in_shape).astype(np.float32) + x = ms.Tensor(x) + + out = tae.blend_slices(x, slice_len=4, overlap_len=2) + + print(out.shape) + def test_tae_tile(): tae = TemporalAutoencoder(config=TAE_CONFIG, use_tile=True, encode_tile=32, decode_tile=32, decode_overlap=16) # in_shape = (B, C, T, H, W) = (1, 3, 16, 64, 64) - # in_shape = (B, C, T, H, W) = (1, 3, 96, 32, 32) - in_shape = (B, C, T, H, W) = (1, 3, 64+16, 64, 64) + in_shape = (B, C, T, H, W) = (1, 3, 96, 32, 32) + # in_shape = (B, C, T, H, W) = (1, 3, 64+16, 64, 64) x = np.random.normal(size=in_shape).astype(np.float32) x = ms.Tensor(x) @@ -265,6 +278,7 @@ def test_tae_tile(): print(y[0].shape) # check correctness of blend + @@ -285,6 +299,7 @@ def test_tae_tile(): # test_tae_encode() # test_tae_decode() # test_tae_rec() - test_tae_tile() + # test_tae_tile() + test_blend() # test_sd3d5_vae() 
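
For readers checking the new `test_blend` case, the merge performed by `blend_slices` amounts to accumulating each decoded slice at its timeline position and averaging wherever slices overlap. A minimal NumPy sketch of that behaviour (an illustrative reconstruction, not code from this repository) is:

```python
import math
import numpy as np

def blend_slices_ref(x: np.ndarray, slice_len: int, overlap_len: int) -> np.ndarray:
    # x is the temporal concatenation of decoded slices, shape (b, c, t, h, w);
    # the last slice may be shorter than slice_len.
    b, c, in_len, h, w = x.shape
    stride = slice_len - overlap_len
    num_slices = math.ceil(in_len / slice_len)
    last_len = in_len - (num_slices - 1) * slice_len
    out_len = (num_slices - 1) * stride + last_len

    out = np.zeros((b, c, out_len, h, w), np.float32)
    cnt = np.zeros_like(out)
    for i in range(num_slices):
        cur = x[:, :, i * slice_len : (i + 1) * slice_len]
        start = i * stride  # timeline position of this slice in the blended output
        out[:, :, start : start + cur.shape[2]] += cur
        cnt[:, :, start : start + cur.shape[2]] += 1
    return out / cnt

# mirrors test_blend: 12 concatenated frames, slice_len=4, overlap_len=2 -> 8 output frames
x = np.random.rand(1, 1, 12, 1, 1).astype(np.float32)
print(blend_slices_ref(x, 4, 2).shape)  # (1, 1, 8, 1, 1)
```
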
diff --git a/examples/opensora_hpcai/scripts/inference_vae.py b/examples/opensora_hpcai/scripts/inference_vae.py index 8fc76cad14..e1c0fe96d3 100644 --- a/examples/opensora_hpcai/scripts/inference_vae.py +++ b/examples/opensora_hpcai/scripts/inference_vae.py @@ -13,8 +13,8 @@ from mindspore import nn, ops -# mindone_dir = '/home/mindocr/yx/mindone' -mindone_dir = "/home_host/yx/mindone" +__dir__ = os.path.dirname(os.path.abspath(__file__)) +mindone_dir = os.path.abspath(os.path.join(__dir__, "../../../")) sys.path.insert(0, mindone_dir) # from ae.models.lpips import LPIPS From cae87f579bd9da273abac8ba502d6c685b01dff6 Mon Sep 17 00:00:00 2001 From: samithuang <285365963@qq.com> Date: Fri, 22 Nov 2024 23:12:40 +0800 Subject: [PATCH 063/122] draft readme and update decode --- examples/movie_gen/mg/models/tae/tae.py | 13 ++++++++++--- .../{inflate_sd3.5_vae.py => inflate_vae_to_tae.py} | 0 2 files changed, 10 insertions(+), 3 deletions(-) rename examples/movie_gen/tools/{inflate_sd3.5_vae.py => inflate_vae_to_tae.py} (100%) diff --git a/examples/movie_gen/mg/models/tae/tae.py b/examples/movie_gen/mg/models/tae/tae.py index 5812f6a32f..318a7bc7cd 100644 --- a/examples/movie_gen/mg/models/tae/tae.py +++ b/examples/movie_gen/mg/models/tae/tae.py @@ -158,12 +158,13 @@ def encode(self, x: ms.Tensor) -> ms.Tensor: return z, posterior_mean, posterior_logvar - def decode(self, z: ms.Tensor) -> ms.Tensor: + def decode(self, z: ms.Tensor, target_num_frames: int=None) -> ms.Tensor: r""" Decode a batch of latents to videos Args: x (Tensor): input latent tensor of shape (b z t' h' w') + target_num_frames (int): target number of frames for output, if None, all the decoded frames will be reserved. Otherwise, the previous this number of frames will be reserved. Returns: z (Tensor): the decoded videos of shape (b c t h w) @@ -173,6 +174,9 @@ def decode(self, z: ms.Tensor) -> ms.Tensor: z = self.post_quant_conv(z) dec = self.decoder(z) + if target_num_frames is not None: + dec = dec[:, :, :target_num_frames] + return dec def encode_with_tile(self, x: ms.Tensor) -> ms.Tensor: @@ -199,12 +203,13 @@ def encode_with_tile(self, x: ms.Tensor) -> ms.Tensor: # TODO: merge mean, logvar for different slices for training tae with tile return z_out, mean, logvar - def decode_with_tile(self, z: ms.Tensor) -> ms.Tensor: + def decode_with_tile(self, z: ms.Tensor, target_num_frames: int=None) -> ms.Tensor: r""" Decode a batch of latents to videos with tiling Args: x (Tensor): input latent tensor of shape (b z t' h' w') + target_num_frames (int): target number of frames for output, if None, all the decoded frames will be reserved. Otherwise, the previous this number of frames will be reserved. 
Returns: z (Tensor): the decoded videos of shape (b c t h w) @@ -234,6 +239,9 @@ def decode_with_tile(self, z: ms.Tensor) -> ms.Tensor: if self.decode_overlap > 0: x_out = self.blend_slices(x_out, self.decode_tile, self.decode_overlap) + if target_num_frames is not None: + x_out = x_out[:, :, :target_num_frames] + return x_out def blend_slices(self, x: ms.Tensor, slice_len=32, overlap_len=16): @@ -304,7 +312,6 @@ def construct(self, x: ms.Tensor) -> ms.Tensor: recons = self.decode(z) if self.discard_spurious_frames and (recons.shape[-3] != x.shape[-3]): - # print("WARNING: discard suprious frames, ", recons.shape[-3], x.shape[-3]) recons = recons[:, :, :x.shape[-3], :, :] return recons, z, posterior_mean, posterior_logvar diff --git a/examples/movie_gen/tools/inflate_sd3.5_vae.py b/examples/movie_gen/tools/inflate_vae_to_tae.py similarity index 100% rename from examples/movie_gen/tools/inflate_sd3.5_vae.py rename to examples/movie_gen/tools/inflate_vae_to_tae.py From f64fa00981b0c166c8408b00f262c8b6eb27b33d Mon Sep 17 00:00:00 2001 From: samithuang <285365963@qq.com> Date: Fri, 22 Nov 2024 23:13:28 +0800 Subject: [PATCH 064/122] add config --- .../configs/tae/train/mixed_256x256x16.yaml | 46 +++++++++++++++++++ .../{video_ft.yaml => mixed_256x256x32.yaml} | 18 ++++---- 2 files changed, 55 insertions(+), 9 deletions(-) create mode 100644 examples/movie_gen/configs/tae/train/mixed_256x256x16.yaml rename examples/movie_gen/configs/tae/train/{video_ft.yaml => mixed_256x256x32.yaml} (77%) diff --git a/examples/movie_gen/configs/tae/train/mixed_256x256x16.yaml b/examples/movie_gen/configs/tae/train/mixed_256x256x16.yaml new file mode 100644 index 0000000000..af58e62ee1 --- /dev/null +++ b/examples/movie_gen/configs/tae/train/mixed_256x256x16.yaml @@ -0,0 +1,46 @@ +# model +pretrained_model_path: "models/tae_vae2d.ckpt" + +# loss +perceptual_loss_weight: 1.0 +kl_loss_weight: 1.e-6 +use_outlier_penalty_loss: False # OPL bring no benefit in our experiments +mixed_strategy: "mixed_video_image" +mixed_image_ratio: 0.2 + +# data +dataset_name: "video" +csv_path: "../videocomposer/datasets/webvid5_copy.csv" +video_folder: "../videocomposer/datasets/webvid5" +frame_stride: 1 +num_frames: 16 +image_size: 256 +crop_size: 256 +# flip: True + +# training recipe +seed: 42 +use_discriminator: False +batch_size: 1 +clip_grad: True +max_grad_norm: 1.0 +start_learning_rate: 1.e-5 +scale_lr: False +weight_decay: 0. + +dtype: "fp32" +amp_level: "O0" +use_recompute: False + +epochs: 2000 +ckpt_save_interval: 50 +init_loss_scale: 1024. 
+loss_scaler_type: dynamic + +scheduler: "constant" +use_ema: False + +output_path: "outputs/tae_train" + +# ms settting +jit_level: O0 diff --git a/examples/movie_gen/configs/tae/train/video_ft.yaml b/examples/movie_gen/configs/tae/train/mixed_256x256x32.yaml similarity index 77% rename from examples/movie_gen/configs/tae/train/video_ft.yaml rename to examples/movie_gen/configs/tae/train/mixed_256x256x32.yaml index ac78edb15d..990ec83d72 100644 --- a/examples/movie_gen/configs/tae/train/video_ft.yaml +++ b/examples/movie_gen/configs/tae/train/mixed_256x256x32.yaml @@ -4,7 +4,7 @@ pretrained_model_path: "models/tae_vae2d.ckpt" # loss perceptual_loss_weight: 1.0 kl_loss_weight: 1.e-6 -use_outlier_penalty_loss: True +use_outlier_penalty_loss: False # OPL bring no benefit in our experiments mixed_strategy: "mixed_video_image" mixed_image_ratio: 0.2 @@ -13,27 +13,27 @@ dataset_name: "video" csv_path: "../videocomposer/datasets/webvid5_copy.csv" video_folder: "../videocomposer/datasets/webvid5" frame_stride: 1 -num_frames: 16 +num_frames: 32 image_size: 256 - -# micro_frame_size: 17 -# micro_batch_size: 4 +crop_size: 256 # flip: True # training recipe seed: 42 use_discriminator: False -dtype: "bf16" batch_size: 1 clip_grad: True max_grad_norm: 1.0 start_learning_rate: 1.e-5 scale_lr: False weight_decay: 0. + +dtype: "bf16" +amp_level: "O2" # reduce memory cost use_recompute: True -epochs: 400 -ckpt_save_interval: 100 +epochs: 2000 +ckpt_save_interval: 50 init_loss_scale: 1024. loss_scaler_type: dynamic @@ -43,4 +43,4 @@ use_ema: False output_path: "outputs/tae_train" # ms settting -jit_level: O1 +jit_level: O0 From 961514b065088d8f83588b680a566dc62c0d066b Mon Sep 17 00:00:00 2001 From: Samit <285365963@qq.com> Date: Fri, 22 Nov 2024 23:39:47 +0800 Subject: [PATCH 065/122] add readme draft --- examples/movie_gen/README.md | 102 ++++++++++++++++++++++++ examples/movie_gen/mg/models/tae/tae.py | 33 ++++---- 2 files changed, 118 insertions(+), 17 deletions(-) create mode 100644 examples/movie_gen/README.md diff --git a/examples/movie_gen/README.md b/examples/movie_gen/README.md new file mode 100644 index 0000000000..32510c2f30 --- /dev/null +++ b/examples/movie_gen/README.md @@ -0,0 +1,102 @@ +# Movie Gen Video + + +## TAE + +### Requirements + +ms2.3.1 + +### Prepare weights + +We use SD3.5 VAE to initialize the spatial layers of TAE, since both have a latent channel of 16. + +1. Download SD3.5 VAE from https://huggingface.co/stabilityai/stable-diffusion-3.5-large/tree/main/vae + +2. Convert VAE checkpoint for TAE loading + +```shell +python inflate_vae_to_tae.py --src /path/to/sd3.5_vae/diffusion_pytorch_model.safetensors --target models/tae_vae2d.ckpt +``` + + +### Training + +```shell +output_dir=outputs/train_tae_256x256x16 + +python scripts/train_tae.py \ +--config configs/tae/train/mixed_256x256x16.yaml \ +--output_path=$output_dir \ +--csv_path ../opensora_hpcai/datasets/mixkit-100videos/video_caption_train.csv \ +--video_folder ../opensora_hpcai/datasets/mixkit-100videos/mixkit \ + +``` + +OPL - outlier penality loss is found to be not beneficial in our experiment (PSNR decreased). Thus we set it to False by default. + +Change mixed_256x256x16.yaml to mixed_256x256x32.yaml for training on 32 frames. 
+ + +#### Performance + +Train on 80 samples of mixkit-100 (train set), test on the other 20 samples (test set) + +256x256x16, 1p, FP32, 1.99 s/step, test set psnr 28.5 + +256x256x32, 1p, BF16, 2.49 s/step, test set psnr 28.3 + + +### Inference + + +#### Video Reconstruction + +```shell +python scripts/inference_vae.py \ +--ckpt_path /path/to/tae.ckpt \ +--batch_size 2 \ +--num_frames=16 \ +--image_size 256 \ +--csv_path ../opensora_hpcai/datasets/mixkit-100videos/video_caption_test.csv \ +--video_folder ../opensora_hpcai/datasets/mixkit-100videos/mixkit \ +--enable_tile=False \ +``` + +#### Encoding video + +```python +from mg.models.tae.tae import TemporalAutoencoder, TAE_CONFIG + +# may set use_tile=True to save memory +tae = TemporalAutoencoder( + pretrained='/path/to/tae.ckpt', + use_tile=False, + ) + +# x - a batch of videos, shape (b c t h w) +z, _, _ = tae.encode(x) + + +# you may scale z by: +# z = TAE_CONFIG['scaling_factor'] * z + TAE_CONFIG['shift_factor'] + + +``` + +For detailed arguments, please refer to the docstring in [tae.py](mg/models/tae/tae.py) + +#### Decoding video latent + +```python + +# if z is scaled, you should unscale at first: +# z = (z - TAE_CONFIG['shift_factor']) / TAE_CONFIG['scaling_factor'] + +# z - a batch of video latent, shape (b c t h w) +x = tae.decode(z) + +# for image decoding, set num_target_frames to discard the spurious frames +x = tae.decode(z, num_target_frames=1) +``` + diff --git a/examples/movie_gen/mg/models/tae/tae.py b/examples/movie_gen/mg/models/tae/tae.py index 318a7bc7cd..2de0767006 100644 --- a/examples/movie_gen/mg/models/tae/tae.py +++ b/examples/movie_gen/mg/models/tae/tae.py @@ -87,7 +87,7 @@ def __init__( self.sample_deterministic = sample_deterministic self.discard_spurious_frames = True - + # tile self.encode_tile = encode_tile self.time_compress = 2**len(config['temporal_downsample_level']) # 8 @@ -100,7 +100,7 @@ def __init__( self.decode_tile = decode_tile self.decode_overlap = decode_overlap - + # recompute if use_recompute: self.recompute(self.encoder) @@ -155,13 +155,13 @@ def encode(self, x: ms.Tensor) -> ms.Tensor: if self.sample_deterministic: return posterior_mean z = self.sample(posterior_mean, posterior_logvar) - + return z, posterior_mean, posterior_logvar def decode(self, z: ms.Tensor, target_num_frames: int=None) -> ms.Tensor: r""" Decode a batch of latents to videos - + Args: x (Tensor): input latent tensor of shape (b z t' h' w') target_num_frames (int): target number of frames for output, if None, all the decoded frames will be reserved. Otherwise, the previous this number of frames will be reserved. @@ -199,14 +199,14 @@ def encode_with_tile(self, x: ms.Tensor) -> ms.Tensor: for i in range(tf, x.shape[2], tf): z_cur, mean, logvar = self.encode(x[:, :, i : i + tf]) z_out = ops.cat((z_out, z_cur), axis=2) - + # TODO: merge mean, logvar for different slices for training tae with tile return z_out, mean, logvar def decode_with_tile(self, z: ms.Tensor, target_num_frames: int=None) -> ms.Tensor: r""" Decode a batch of latents to videos with tiling - + Args: x (Tensor): input latent tensor of shape (b z t' h' w') target_num_frames (int): target number of frames for output, if None, all the decoded frames will be reserved. Otherwise, the previous this number of frames will be reserved. 
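
The `target_num_frames` argument exists because the decoder always upsamples by the fixed temporal factor (2 ** len(temporal_downsample_level) = 8), so the decoded clip can be longer than the original input; the extra frames are the "spurious frames" mentioned in the README. A rough sketch of the frame bookkeeping, assuming the simple t // 8 latent-length heuristic used elsewhere in this patch series:

```python
# Illustrative only; assumes the t // 8 latent-length heuristic (floor, minimum 1).
TIME_COMPRESS = 8  # 2 ** len(temporal_downsample_level)

def latent_frames(num_frames: int) -> int:
    return max(num_frames // TIME_COMPRESS, 1)

def decoded_frames(latent_t: int) -> int:
    return latent_t * TIME_COMPRESS

t = 1                       # a single image
lt = latent_frames(t)       # 1 latent frame
print(decoded_frames(lt))   # 8 decoded frames -> pass target_num_frames=1 to keep only the first
```
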
@@ -222,11 +222,11 @@ def decode_with_tile(self, z: ms.Tensor, target_num_frames: int=None) -> ms.Tens num_slices = (in_len - tl) // stride + 1 if (in_len - tl) % stride != 0 and (in_len - tl) + stride < in_len: num_slices += 1 - + # ms graph mode requires an init x_out x_out = self.decode(z[:, :, :tl]) - - visited = tl + + visited = tl i = stride # start position while visited < in_len: x_cur = self.decode(z[:, :, i : i + tl]) @@ -249,7 +249,7 @@ def blend_slices(self, x: ms.Tensor, slice_len=32, overlap_len=16): Blend decoded latent slices, used with decode_with_tile Args: - x: (b c t h w) is the concatenation of the decoded slices, + x: (b c t h w) is the concatenation of the decoded slices, slice_len: slice length; for decoding, it's the latent tile size mulitplied by temporal upsampling ratio. default is 4*8 for moviegen tae. overlap_len: overlap between slices. for decoding, default is 2*8 for movie gen tae @@ -263,28 +263,27 @@ def blend_slices(self, x: ms.Tensor, slice_len=32, overlap_len=16): num_slices = math.ceil(in_len / slice_len) stride = slice_len - overlap_len - out_len = ((num_slices-1) * slice_len) - (num_slices - 2) * overlap_len + out_len = ((num_slices-1) * slice_len) - (num_slices - 2) * overlap_len last_slice_len = in_len - (num_slices -1 ) * slice_len out_len += last_slice_len - overlap_len out_tensor = ops.zeros((B, C, out_len, H, W), ms.float32) out_cnt = ops.zeros((B, C, out_len, H, W), ms.float32) - - import pdb; pdb.set_trace() + for i in range(num_slices): # get the slice form the concatnated latent cur_slice = x[:, :, i*slice_len:(i+1)*slice_len] cur_len = cur_slice.shape[2] - + # put the slice into the right position of output tensor - start = i * stride + start = i * stride out_tensor[:, :, start:start+cur_len] += cur_slice out_cnt[:, :, start:start+cur_len] += 1 out_tensor = out_tensor / out_cnt return out_tensor - + def construct(self, x: ms.Tensor) -> ms.Tensor: r""" @@ -294,7 +293,7 @@ def construct(self, x: ms.Tensor) -> ms.Tensor: x: a batch of videos of shape (b c t h w) Returns: - recons (Tensor): the reconstructed videos of shape (b c t h w) + recons (Tensor): the reconstructed videos of shape (b c t h w) z (Tensor): the latent tensor, shape (b z t' h' w') posterior_mean (Tensor): mean of latent distribution posterior_logvar (Tensor): logvar of latent distribution From f0836209db62dfbb7e77d38b0bd7b370c57c3cd8 Mon Sep 17 00:00:00 2001 From: Rustam Khadipash <16683750+hadipash@users.noreply.github.com> Date: Mon, 25 Nov 2024 10:34:44 +0800 Subject: [PATCH 066/122] add TAE to Movie Gen --- examples/movie_gen/mg/models/tae/tae.py | 123 ++++++++++-------- .../inference/moviegen_t2i_256x256.yaml | 7 +- .../configs/train/stage1_t2i_256x256.yaml | 7 +- .../configs/train/stage2_t2iv_256x256.yaml | 7 +- examples/moviegen/inference.py | 44 ++++--- .../moviegen/pipelines/infer_pipeline.py | 18 +-- .../moviegen/pipelines/train_pipeline.py | 2 +- examples/moviegen/train.py | 46 ++++--- 8 files changed, 143 insertions(+), 111 deletions(-) diff --git a/examples/movie_gen/mg/models/tae/tae.py b/examples/movie_gen/mg/models/tae/tae.py index 2de0767006..9556fa3a7a 100644 --- a/examples/movie_gen/mg/models/tae/tae.py +++ b/examples/movie_gen/mg/models/tae/tae.py @@ -1,8 +1,10 @@ -import mindspore as ms import math +from typing import Optional, Tuple + +import mindspore as ms from mindspore import nn, ops -from .modules import Conv2_5d, Encoder, Decoder +from .modules import Conv2_5d, Decoder, Encoder SDXL_CONFIG = { "double_z": True, @@ -16,7 +18,7 @@ 
"attn_resolutions": [], "dropout": 0.0, "use_post_quant_conv": True, - "use_quant_conv": True + "use_quant_conv": True, } # modify based on SD3d5_CONFIG @@ -38,7 +40,6 @@ "attn_type": "vanilla", "temporal_downsample_level": [0, 1, 2], "temporal_upsample_level": [3, 2, 1], - } @@ -54,29 +55,30 @@ class TemporalAutoencoder(nn.Cell): def __init__( self, config: dict = TAE_CONFIG, - pretrained: str = None, - use_recompute: bool=False, - sample_deterministic: bool=False, - use_tile: bool=False, - encode_tile: int=32, - encode_overlap: int=0, - decode_tile: int=32, - decode_overlap: int=16, + pretrained: Optional[str] = None, + use_recompute: bool = False, + sample_deterministic: bool = False, + use_tile: bool = False, + encode_tile: int = 32, + encode_overlap: int = 0, + decode_tile: int = 32, + decode_overlap: int = 16, ): super().__init__() + self.out_channels = config["z_channels"] # encoder self.encoder = Encoder(**config) # quant and post quant - embed_dim = config['z_channels'] - if config['use_quant_conv']: + embed_dim = config["z_channels"] + if config["use_quant_conv"]: self.quant_conv = Conv2_5d(2 * embed_dim, 2 * embed_dim, 1, pad_mode="valid", has_bias=True) - if config['use_post_quant_conv']: + if config["use_post_quant_conv"]: self.post_quant_conv = Conv2_5d(embed_dim, embed_dim, 1, pad_mode="valid", has_bias=True) - self.use_quant_conv = config['use_quant_conv'] - self.use_post_quant_conv = config['use_post_quant_conv'] + self.use_quant_conv = config["use_quant_conv"] + self.use_post_quant_conv = config["use_post_quant_conv"] # decoder self.decoder = Decoder(**config) @@ -90,13 +92,17 @@ def __init__( # tile self.encode_tile = encode_tile - self.time_compress = 2**len(config['temporal_downsample_level']) # 8 + self.time_compress = 2 ** len(config["temporal_downsample_level"]) # 8 self.encode_overlap = encode_overlap self.use_tile = use_tile if use_tile: - assert (self.encode_tile % self.time_compress == 0) and (self.encode_tile > 0), f'num tile frames should be divisable by {self.time_compress} and non-zero' - assert self.encode_overlap % self.time_compress == 0, f'overlap frames should be divisable by {self.time_compress}' - assert self.encode_overlap == 0, 'not supported' + assert (self.encode_tile % self.time_compress == 0) and ( + self.encode_tile > 0 + ), f"num tile frames should be divisable by {self.time_compress} and non-zero" + assert ( + self.encode_overlap % self.time_compress == 0 + ), f"overlap frames should be divisable by {self.time_compress}" + assert self.encode_overlap == 0, "not supported" self.decode_tile = decode_tile self.decode_overlap = decode_overlap @@ -106,10 +112,9 @@ def __init__( self.recompute(self.encoder) self.recompute(self.decoder) - if pretrained is not None: + if pretrained: self.load_pretrained(pretrained) - def recompute(self, b): if not b._has_config_recompute: b.recompute() @@ -118,7 +123,6 @@ def recompute(self, b): else: b.add_flags(output_no_recompute=True) - def _encode(self, x): # return latent distribution, N(mean, logvar) h = self.encoder(x) @@ -138,7 +142,7 @@ def sample(self, mean, logvar): return z - def encode(self, x: ms.Tensor) -> ms.Tensor: + def encode(self, x: ms.Tensor) -> Tuple[ms.Tensor, ms.Tensor, ms.Tensor]: r""" Encode a batch of videos into latents @@ -150,7 +154,12 @@ def encode(self, x: ms.Tensor) -> ms.Tensor: posterior_mean (Tensor): mean of latent distribution posterior_logvar (Tensor): logvar of latent distribution """ + if self.use_tile: + return self.encode_with_tile(x) + else: + return 
self._encode_no_tile(x) + def _encode_no_tile(self, x: ms.Tensor) -> Tuple[ms.Tensor, ms.Tensor, ms.Tensor]: posterior_mean, posterior_logvar = self._encode(x) if self.sample_deterministic: return posterior_mean @@ -158,18 +167,25 @@ def encode(self, x: ms.Tensor) -> ms.Tensor: return z, posterior_mean, posterior_logvar - def decode(self, z: ms.Tensor, target_num_frames: int=None) -> ms.Tensor: + def decode(self, z: ms.Tensor, target_num_frames: int = None) -> ms.Tensor: r""" Decode a batch of latents to videos Args: - x (Tensor): input latent tensor of shape (b z t' h' w') - target_num_frames (int): target number of frames for output, if None, all the decoded frames will be reserved. Otherwise, the previous this number of frames will be reserved. + z (Tensor): input latent tensor of shape (b z t' h' w') + target_num_frames (int): target number of frames for output. + If None, all the decoded frames will be reserved. + Otherwise, the previous this number of frames will be reserved. Returns: z (Tensor): the decoded videos of shape (b c t h w) """ + if self.use_tile: + return self.decode_with_tile(z, target_num_frames) + else: + return self._decode_no_tile(z, target_num_frames) + def _decode_no_tile(self, z: ms.Tensor, target_num_frames: int = None) -> ms.Tensor: if self.use_post_quant_conv: z = self.post_quant_conv(z) dec = self.decoder(z) @@ -179,7 +195,7 @@ def decode(self, z: ms.Tensor, target_num_frames: int=None) -> ms.Tensor: return dec - def encode_with_tile(self, x: ms.Tensor) -> ms.Tensor: + def encode_with_tile(self, x: ms.Tensor) -> Tuple[ms.Tensor, ms.Tensor, ms.Tensor]: r""" Encode a batch of videos into latents with tiling @@ -194,22 +210,24 @@ def encode_with_tile(self, x: ms.Tensor) -> ms.Tensor: tf = self.encode_tile - z_out, mean, logvar = self.encode(x[:, :, :tf]) + z_out, mean, logvar = self._encode_no_tile(x[:, :, :tf]) for i in range(tf, x.shape[2], tf): - z_cur, mean, logvar = self.encode(x[:, :, i : i + tf]) + z_cur, mean, logvar = self._encode_no_tile(x[:, :, i : i + tf]) z_out = ops.cat((z_out, z_cur), axis=2) # TODO: merge mean, logvar for different slices for training tae with tile return z_out, mean, logvar - def decode_with_tile(self, z: ms.Tensor, target_num_frames: int=None) -> ms.Tensor: + def decode_with_tile(self, z: ms.Tensor, target_num_frames: int = None) -> ms.Tensor: r""" Decode a batch of latents to videos with tiling Args: x (Tensor): input latent tensor of shape (b z t' h' w') - target_num_frames (int): target number of frames for output, if None, all the decoded frames will be reserved. Otherwise, the previous this number of frames will be reserved. + target_num_frames (int): target number of frames for output. + If None, all the decoded frames will be reserved. + Otherwise, the previous this number of frames will be reserved. 
Returns: z (Tensor): the decoded videos of shape (b c t h w) @@ -221,15 +239,15 @@ def decode_with_tile(self, z: ms.Tensor, target_num_frames: int=None) -> ms.Tens in_len = z.shape[2] num_slices = (in_len - tl) // stride + 1 if (in_len - tl) % stride != 0 and (in_len - tl) + stride < in_len: - num_slices += 1 + num_slices += 1 # ms graph mode requires an init x_out - x_out = self.decode(z[:, :, :tl]) + x_out = self._decode_no_tile(z[:, :, :tl]) visited = tl i = stride # start position while visited < in_len: - x_cur = self.decode(z[:, :, i : i + tl]) + x_cur = self._decode_no_tile(z[:, :, i : i + tl]) x_out = ops.cat((x_out, x_cur), axis=2) visited = i + tl @@ -263,8 +281,8 @@ def blend_slices(self, x: ms.Tensor, slice_len=32, overlap_len=16): num_slices = math.ceil(in_len / slice_len) stride = slice_len - overlap_len - out_len = ((num_slices-1) * slice_len) - (num_slices - 2) * overlap_len - last_slice_len = in_len - (num_slices -1 ) * slice_len + out_len = ((num_slices - 1) * slice_len) - (num_slices - 2) * overlap_len + last_slice_len = in_len - (num_slices - 1) * slice_len out_len += last_slice_len - overlap_len out_tensor = ops.zeros((B, C, out_len, H, W), ms.float32) @@ -272,20 +290,19 @@ def blend_slices(self, x: ms.Tensor, slice_len=32, overlap_len=16): for i in range(num_slices): # get the slice form the concatnated latent - cur_slice = x[:, :, i*slice_len:(i+1)*slice_len] + cur_slice = x[:, :, i * slice_len : (i + 1) * slice_len] cur_len = cur_slice.shape[2] # put the slice into the right position of output tensor start = i * stride - out_tensor[:, :, start:start+cur_len] += cur_slice - out_cnt[:, :, start:start+cur_len] += 1 + out_tensor[:, :, start : start + cur_len] += cur_slice + out_cnt[:, :, start : start + cur_len] += 1 out_tensor = out_tensor / out_cnt return out_tensor - - def construct(self, x: ms.Tensor) -> ms.Tensor: + def construct(self, x: ms.Tensor) -> Tuple[ms.Tensor, ms.Tensor, ms.Tensor, ms.Tensor]: r""" Video reconstruction @@ -305,22 +322,18 @@ def construct(self, x: ms.Tensor) -> ms.Tensor: posterior_mean, posterior_logvar = self._encode(x) z = self.sample(posterior_mean, posterior_logvar) - if self.use_tile: - recons = self.decode_with_tile(z) - else: - recons = self.decode(z) + recons = self.decode(z) if self.discard_spurious_frames and (recons.shape[-3] != x.shape[-3]): - recons = recons[:, :, :x.shape[-3], :, :] + recons = recons[:, :, : x.shape[-3], :, :] return recons, z, posterior_mean, posterior_logvar - - def load_pretrained(self, ckpt_path:str): - - if ckpt_path.endswith('safetensors'): + def load_pretrained(self, ckpt_path: str): + if ckpt_path.endswith("safetensors"): # load vae parameters from safetensors into my mindspore model import safetensors + ckpt = safetensors.safe_open(ckpt_path, framework="pt") state_dict = {} for key in ckpt.keys(): @@ -341,5 +354,9 @@ def load_pretrained(self, ckpt_path:str): print(f"{param_not_load} in network is not loaded") print(f"{ckpt_not_load} in checkpoint is not loaded!") - print('TAE checkpoint loaded') + print("TAE checkpoint loaded") + @staticmethod + def get_latent_size(input_size: Tuple[int, int, int]) -> Tuple[int, int, int]: + # FIXME: validate + return max(input_size[0] // 8, 1), input_size[1] // 8, input_size[2] // 8 diff --git a/examples/moviegen/configs/inference/moviegen_t2i_256x256.yaml b/examples/moviegen/configs/inference/moviegen_t2i_256x256.yaml index 803e28da75..2aaf76293f 100644 --- a/examples/moviegen/configs/inference/moviegen_t2i_256x256.yaml +++ 
b/examples/moviegen/configs/inference/moviegen_t2i_256x256.yaml @@ -11,9 +11,10 @@ model: enable_flash_attention: True dtype: bf16 -vae: - ckpt_path: models/OpenSora-VAE-v1.2/model.ckpt - dtype: fp16 +tae: + pretrained: "" + use_tile: True + dtype: bf16 # Inference parameters num_sampling_steps: 50 diff --git a/examples/moviegen/configs/train/stage1_t2i_256x256.yaml b/examples/moviegen/configs/train/stage1_t2i_256x256.yaml index a41f6b141e..0027c5391e 100644 --- a/examples/moviegen/configs/train/stage1_t2i_256x256.yaml +++ b/examples/moviegen/configs/train/stage1_t2i_256x256.yaml @@ -12,9 +12,10 @@ model: recompute: True dtype: bf16 -vae: - ckpt_path: models/OpenSora-VAE-v1.2/model.ckpt - dtype: fp16 +tae: + pretrained: "" + use_tile: True + dtype: bf16 dataset: csv_path: CSV_PATH diff --git a/examples/moviegen/configs/train/stage2_t2iv_256x256.yaml b/examples/moviegen/configs/train/stage2_t2iv_256x256.yaml index e262e898e2..e26f2a2f60 100644 --- a/examples/moviegen/configs/train/stage2_t2iv_256x256.yaml +++ b/examples/moviegen/configs/train/stage2_t2iv_256x256.yaml @@ -12,9 +12,10 @@ model: recompute: True dtype: bf16 -vae: - ckpt_path: models/OpenSora-VAE-v1.2/model.ckpt - dtype: fp16 +tae: + pretrained: "" + use_tile: True + dtype: bf16 dataset: csv_path: CSV_PATH diff --git a/examples/moviegen/inference.py b/examples/moviegen/inference.py index 7c80206e77..a903447668 100644 --- a/examples/moviegen/inference.py +++ b/examples/moviegen/inference.py @@ -24,9 +24,10 @@ from mindone.utils import init_train_env, set_logger from mindone.visualize.videos import save_videos -# TODO: remove when VAE is added to the project -sys.path.append(os.path.join(__dir__, "../opensora_hpcai/")) -from opensora.models.vae.vae import OpenSoraVAE_V1_2 +# TODO: remove when TAE is added to the project +sys.path.append(os.path.join(__dir__, "../movie_gen/")) +from mg.models.tae.modules import SpatialDownsample, SpatialUpsample, TemporalDownsample, TemporalUpsample +from mg.models.tae.tae import TemporalAutoencoder logger = logging.getLogger(__name__) @@ -69,21 +70,26 @@ def main(args): ul2_emb, metaclip_emb, byt5_emb = prepare_captions(**args.text_emb, rank_id=rank_id, device_num=device_num) # 2. 
model initiate and weight loading - # 2.1 vae - logger.info("vae init") - vae_args = args.vae.as_dict() - vae_dtype = vae_args.pop("dtype") - vae = OpenSoraVAE_V1_2(**vae_args).set_train(False) - if vae_dtype != "fp32": + # 2.1 tae + logger.info("TAE init") + tae_args = args.tae.as_dict() + tae_dtype = tae_args.pop("dtype") + tae = TemporalAutoencoder(**tae_args).set_train(False) + if tae_dtype != "fp32": # FIXME: remove AMP and add custom dtype conversion support for better compatibility with PyNative - amp.custom_mixed_precision(vae, black_list=amp.get_black_list() + [nn.GroupNorm], dtype=MODEL_DTYPE[vae_dtype]) + amp.custom_mixed_precision( + tae, + black_list=amp.get_black_list() + + [SpatialDownsample, SpatialUpsample, TemporalDownsample, TemporalUpsample, nn.GroupNorm], + dtype=MODEL_DTYPE[tae_dtype], + ) img_h, img_w = args.image_size if isinstance(args.image_size, list) else (args.image_size, args.image_size) num_frames = args.num_frames - latent_size = vae.get_latent_size((num_frames, img_h, img_w)) + latent_size = tae.get_latent_size((num_frames, img_h, img_w)) # 2.2 Llama 3 - model = init_model(in_channels=vae.out_channels, **args.model).set_train(False) + model = init_model(in_channels=tae.out_channels, **args.model).set_train(False) # 2.3 text embeddings prompt_prefix = [os.path.basename(emb)[:-4] for emb in ul2_emb] @@ -95,7 +101,7 @@ def main(args): # 3. build inference pipeline pipeline = InferPipeline( model, - vae, + tae, latent_size, scale_factor=args.scale_factor, # FIXME: refactor guidance_scale=args.guidance_scale, @@ -111,7 +117,7 @@ def main(args): f"MindSpore mode[GRAPH(0)/PYNATIVE(1)]: {args.env.mode}", f"Num of captions: {num_prompts}", f"Model dtype: {args.model.dtype}", - f"VAE dtype: {vae_dtype}", + f"TAE dtype: {tae_dtype}", f"Image size: {(img_h, img_w)}", f"Num frames: {num_frames}", f"Sampling steps {args.num_sampling_steps}", @@ -167,13 +173,13 @@ def main(args): ) parser.add_function_arguments(init_train_env, "env") parser.add_function_arguments(init_model, "model", skip={"in_channels"}) - vae_group = parser.add_argument_group("VAE parameters") - vae_group.add_function_arguments(OpenSoraVAE_V1_2, "vae", fail_untyped=False) - vae_group.add_argument( - "--vae.dtype", default="fp32", type=str, choices=["fp32", "fp16", "bf16"], help="VAE model precision." + tae_group = parser.add_argument_group("TAE parameters") + parser.add_class_arguments(TemporalAutoencoder, "tae", instantiate=False) + parser.add_argument( + "--tae.dtype", default="fp32", type=str, choices=["fp32", "fp16", "bf16"], help="TAE model precision." ) infer_group = parser.add_argument_group("Inference parameters") - infer_group.add_class_arguments(InferPipeline, skip={"model", "vae", "latent_size"}, instantiate=False) + infer_group.add_class_arguments(InferPipeline, skip={"model", "tae", "latent_size"}, instantiate=False) infer_group.add_argument("--image_size", type=int, nargs="+", help="Output video size") infer_group.add_argument("--num_frames", type=int, default=17, help="number of frames") infer_group.add_argument("--fps", type=int, default=16, help="FPS in the saved video") diff --git a/examples/moviegen/moviegen/pipelines/infer_pipeline.py b/examples/moviegen/moviegen/pipelines/infer_pipeline.py index e18e09c0e2..cf419fd014 100644 --- a/examples/moviegen/moviegen/pipelines/infer_pipeline.py +++ b/examples/moviegen/moviegen/pipelines/infer_pipeline.py @@ -15,8 +15,8 @@ class InferPipeline: Args: model (nn.Cell): A noise prediction model to denoise the encoded image latents. 
- vae (nn.Cell): Variational Auto-Encoder (VAE) Model to encode and decode images or videos to and from latent representations. - scale_factor (float): scale_factor for vae. + tae (nn.Cell): Temporal Auto-Encoder (TAE) Model to encode and decode images or videos to and from latent representations. + scale_factor (float): scale_factor for TAE. guidance_scale (float): A higher guidance scale value for noise rescale. num_sampling_steps: (int): The number of denoising steps. """ @@ -24,7 +24,7 @@ class InferPipeline: def __init__( self, model: nn.Cell, - vae: nn.Cell, + tae: nn.Cell, latent_size: Tuple[int, int, int] = (1, 64, 64), scale_factor: float = 1.0, guidance_scale: float = 1.0, @@ -34,7 +34,7 @@ def __init__( ): super().__init__() self.model = model - self.vae = vae + self.tae = tae self.latent_size = latent_size self.micro_batch_size = micro_batch_size self.scale_factor = scale_factor @@ -42,7 +42,7 @@ def __init__( self.use_cfg = guidance_scale > 1.0 self.rflow = RFLOW(num_sampling_steps, sample_method=sample_method) - def vae_decode_video(self, x, num_frames=None): + def tae_decode_video(self, x, num_frames=None): """ Args: x: (b t c h w), denoised latent @@ -50,7 +50,7 @@ def vae_decode_video(self, x, num_frames=None): y: (b f H W 3), batch of images, normalized to [0, 1] """ x = mint.permute(x, (0, 2, 1, 3, 4)) # FIXME: remove this redundancy - y = self.vae.decode(x, num_frames=num_frames) # FIXME: extract scale_factor from VAE and use it here + y = self.tae.decode(x, target_num_frames=num_frames) # FIXME: extract scale_factor from TAE and use it here y = ops.clip_by_value((y + 1.0) / 2.0, clip_value_min=0.0, clip_value_max=1.0) # (b 3 t h w) -> (b t h w 3) y = mint.permute(y, (0, 2, 3, 4, 1)) @@ -68,7 +68,7 @@ def __call__( """ z = ms.Tensor( np.random.randn( - ul2_emb.shape[0], self.latent_size[0], self.vae.out_channels, self.latent_size[1], self.latent_size[2] + ul2_emb.shape[0], self.latent_size[0], self.tae.out_channels, self.latent_size[1], self.latent_size[2] ).astype(np.float32), dtype=self.model.dtype, ) @@ -83,10 +83,10 @@ def __call__( byt5_emb.to(self.model.dtype), ).to(ms.float32) - if self.vae is not None: + if self.tae is not None: # latents: (b t c h w) # out: (b T H W C) - images = self.vae_decode_video(latents, num_frames=num_frames) + images = self.tae_decode_video(latents, num_frames=num_frames) return images, latents else: return None, latents diff --git a/examples/moviegen/moviegen/pipelines/train_pipeline.py b/examples/moviegen/moviegen/pipelines/train_pipeline.py index 7829448fa3..359d79c9ce 100644 --- a/examples/moviegen/moviegen/pipelines/train_pipeline.py +++ b/examples/moviegen/moviegen/pipelines/train_pipeline.py @@ -55,7 +55,7 @@ def get_latents(self, video_tokens: Tensor) -> Tensor: # (b c f h w) shape is expected. 
FIXME: remove this redundancy video_tokens = mint.permute(video_tokens, (0, 2, 1, 3, 4)) # FIXME: extract scale_factor from VAE and use it here - video_emb = ops.stop_gradient(self.vae.encode(video_tokens)).to(ms.float32) + video_emb = ops.stop_gradient(self.vae.encode(video_tokens)[0]).to(ms.float32) video_emb = mint.permute(video_emb, (0, 2, 1, 3, 4)) # FIXME return video_emb diff --git a/examples/moviegen/train.py b/examples/moviegen/train.py index 2f251f27e3..43f8c498bc 100644 --- a/examples/moviegen/train.py +++ b/examples/moviegen/train.py @@ -28,9 +28,10 @@ from mindone.trainers.zero import prepare_train_network from mindone.utils import count_params, init_train_env, set_logger -# TODO: remove when VAE is added to the project -sys.path.append(os.path.join(__dir__, "../opensora_hpcai/")) -from opensora.models.vae.vae import OpenSoraVAE_V1_2 +# TODO: remove when TAE is added to the project +sys.path.append(os.path.join(__dir__, "../movie_gen/")) +from mg.models.tae.modules import SpatialDownsample, SpatialUpsample, TemporalDownsample, TemporalUpsample +from mg.models.tae.tae import TemporalAutoencoder logger = logging.getLogger(__name__) @@ -55,23 +56,28 @@ def main(args): initializer = parser.instantiate_classes(cfg) # 2. model initialize and weight loading - # 2.1 VAE - logger.info("vae init") + # 2.1 TAE + logger.info("TAE init") # TODO: add support of training with latents - vae_args = args.vae.as_dict() - vae_dtype = vae_args.pop("dtype") - vae = OpenSoraVAE_V1_2(**vae_args).set_train(False) - if vae_dtype != "fp32": + tae_args = args.tae.as_dict() + tae_dtype = tae_args.pop("dtype") + tae = TemporalAutoencoder(**tae_args).set_train(False) + if tae_dtype != "fp32": # FIXME: remove AMP and add custom dtype conversion support for better compatibility with PyNative - amp.custom_mixed_precision(vae, black_list=amp.get_black_list() + [nn.GroupNorm], dtype=MODEL_DTYPE[vae_dtype]) + amp.custom_mixed_precision( + tae, + black_list=amp.get_black_list() + + [SpatialDownsample, SpatialUpsample, TemporalDownsample, TemporalUpsample, nn.GroupNorm], + dtype=MODEL_DTYPE[tae_dtype], + ) # 2.2 Llama 3 - network = init_model(in_channels=vae.out_channels, **args.model) + network = init_model(in_channels=tae.out_channels, **args.model) # 2.3 LossWrapper rflow_loss_wrapper = RFlowLossWrapper(network) # 3. build training network - latent_diffusion_with_loss = DiffusionWithLoss(rflow_loss_wrapper, vae) + latent_diffusion_with_loss = DiffusionWithLoss(rflow_loss_wrapper, tae) # 4. build dataset dataset = ImageVideoDataset(**args.dataset) @@ -92,7 +98,7 @@ def main(args): val_dataset, transforms=transforms, device_num=device_num, rank_id=shard_rank_id, **args.valid.dataloader ) eval_rflow_loss = RFlowEvalLoss(rflow_loss_wrapper, num_sampling_steps=args.valid.sampling_steps) - eval_diffusion_with_loss = DiffusionWithLoss(eval_rflow_loss, vae) + eval_diffusion_with_loss = DiffusionWithLoss(eval_rflow_loss, tae) # 5. 
build training utils: lr, optim, callbacks, trainer # 5.1 LR @@ -157,10 +163,10 @@ def main(args): # 5.5 print out key info and save config if rank_id == 0: - num_params_vae, num_params_trainable_vae = count_params(vae) + num_params_tae, num_params_trainable_tae = count_params(tae) num_params_network, num_params_trainable_network = count_params(network) - num_params = num_params_vae + num_params_network - num_params_trainable = num_params_trainable_vae + num_params_trainable_network + num_params = num_params_tae + num_params_network + num_params_trainable = num_params_trainable_tae + num_params_trainable_network key_info = "Key Settings:\n" + "=" * 50 + "\n" key_info += "\n".join( [ @@ -172,8 +178,8 @@ def main(args): f"Number of samples: {len(dataset)}", f"Model name: {args.model.name}", f"Model dtype: {args.model.dtype}", - f"VAE dtype: {args.vae.dtype}", - f"Num params: {num_params:,} (network: {num_params_network:,}, vae: {num_params_vae:,})", + f"TAE dtype: {args.tae.dtype}", + f"Num params: {num_params:,} (network: {num_params_network:,}, tae: {num_params_tae:,})", f"Num trainable params: {num_params_trainable:,}", f"Learning rate: {args.train.lr_scheduler.lr:.0e}", f"Batch size: {args.dataloader.batch_size}", @@ -209,9 +215,9 @@ def main(args): ) parser.add_function_arguments(init_train_env, "env") parser.add_function_arguments(init_model, "model", skip={"in_channels"}) - parser.add_function_arguments(OpenSoraVAE_V1_2, "vae", fail_untyped=False) + parser.add_class_arguments(TemporalAutoencoder, "tae", instantiate=False) parser.add_argument( - "--vae.dtype", default="fp32", type=str, choices=["fp32", "fp16", "bf16"], help="VAE model precision." + "--tae.dtype", default="fp32", type=str, choices=["fp32", "fp16", "bf16"], help="TAE model precision." 
) parser.add_class_arguments( ImageVideoDataset, "dataset", skip={"frames_mask_generator", "t_compress_func"}, instantiate=False From 7bdac420d167005d99533c0f3d710ad8609d9feb Mon Sep 17 00:00:00 2001 From: Rustam Khadipash <16683750+hadipash@users.noreply.github.com> Date: Mon, 25 Nov 2024 11:44:55 +0800 Subject: [PATCH 067/122] add buckets and dynamic graph support --- examples/movie_gen/mg/models/tae/modules.py | 6 +- .../configs/train/stage2_t2iv_256x256.yaml | 6 +- .../moviegen/moviegen/dataset/__init__.py | 1 + examples/moviegen/moviegen/dataset/buckets.py | 13 ++++ .../moviegen/moviegen/models/llama/network.py | 1 + examples/moviegen/moviegen/utils/callbacks.py | 8 +- examples/moviegen/train.py | 76 ++++++++++++++----- 7 files changed, 81 insertions(+), 30 deletions(-) create mode 100644 examples/moviegen/moviegen/dataset/buckets.py diff --git a/examples/movie_gen/mg/models/tae/modules.py b/examples/movie_gen/mg/models/tae/modules.py index 426f111f27..572b576017 100644 --- a/examples/movie_gen/mg/models/tae/modules.py +++ b/examples/movie_gen/mg/models/tae/modules.py @@ -5,7 +5,7 @@ from packaging import version import mindspore as ms -from mindspore import nn, ops +from mindspore import nn, ops, lazy_inline _logger = logging.getLogger(__name__) @@ -650,7 +650,7 @@ def make_attn(in_channels, attn_type="vanilla"): # used in vae class Encoder(nn.Cell): - # @ms.lazy_inline() + @lazy_inline() def __init__( self, ch=128, @@ -781,7 +781,7 @@ def construct(self, x): class Decoder(nn.Cell): - # @ms.lazy_inline() + @ms.lazy_inline() def __init__( self, *, diff --git a/examples/moviegen/configs/train/stage2_t2iv_256x256.yaml b/examples/moviegen/configs/train/stage2_t2iv_256x256.yaml index e26f2a2f60..4087d33ac5 100644 --- a/examples/moviegen/configs/train/stage2_t2iv_256x256.yaml +++ b/examples/moviegen/configs/train/stage2_t2iv_256x256.yaml @@ -28,12 +28,14 @@ dataset: byt5: EMPTY_TEXT_EMB text_drop_prob: 0.2 target_size: [ 256, 256 ] - sample_n_frames: 272 # FIXME: add variable frames support. FIXME: 17 * 16 = 272 frames of OSv1.2 VAE + sample_n_frames: 256 # FIXME: add variable frames support. 
apply_transforms_dataset: True output_columns: ["video", "ul2_caption", "byt5_caption"] dataloader: - batch_size: 1 + batch_size: + image_batch_size: 70 + video_batch_size: 1 shuffle: True num_workers_dataset: 4 diff --git a/examples/moviegen/moviegen/dataset/__init__.py b/examples/moviegen/moviegen/dataset/__init__.py index 54fc7d4725..d968ae874a 100644 --- a/examples/moviegen/moviegen/dataset/__init__.py +++ b/examples/moviegen/moviegen/dataset/__init__.py @@ -1 +1,2 @@ +from .buckets import bucket_split_function from .dataset import ImageVideoDataset diff --git a/examples/moviegen/moviegen/dataset/buckets.py b/examples/moviegen/moviegen/dataset/buckets.py new file mode 100644 index 0000000000..e8d4970f36 --- /dev/null +++ b/examples/moviegen/moviegen/dataset/buckets.py @@ -0,0 +1,13 @@ +from typing import Callable, List, Tuple + +import numpy as np + + +def bucket_split_function( + image_batch_size: int, video_batch_size: int +) -> Tuple[Callable[[np.ndarray], int], List[int], List[int]]: + return ( + lambda x: int(x.shape[0] > 1), # image or video + [1], # 2 buckets for now: image and videos of fixed length + [image_batch_size, video_batch_size], + ) diff --git a/examples/moviegen/moviegen/models/llama/network.py b/examples/moviegen/moviegen/models/llama/network.py index 5b40dc1c19..b59e487d85 100644 --- a/examples/moviegen/moviegen/models/llama/network.py +++ b/examples/moviegen/moviegen/models/llama/network.py @@ -49,6 +49,7 @@ def t2i_modulate(x: Tensor, shift: Tensor, scale: Tensor) -> Tensor: class LlamaDecoderLayer(nn.Cell): + @ms.lazy_inline(policy="front") def __init__( self, hidden_size: int = 4096, diff --git a/examples/moviegen/moviegen/utils/callbacks.py b/examples/moviegen/moviegen/utils/callbacks.py index 5c5587c6af..07cb1934e7 100644 --- a/examples/moviegen/moviegen/utils/callbacks.py +++ b/examples/moviegen/moviegen/utils/callbacks.py @@ -1,7 +1,7 @@ import logging import os import time -from typing import List, Literal, Optional +from typing import List, Literal, Optional, Union import numpy as np import pandas as pd @@ -10,7 +10,7 @@ from mindspore import dtype as mstype from mindspore import mint, nn, ops from mindspore.communication import GlobalComm, get_group_size -from mindspore.dataset import GeneratorDataset +from mindspore.dataset import BatchDataset, BucketBatchByLengthDataset, GeneratorDataset from mindspore.ops import functional as F from mindone.trainers.ema import EMA @@ -26,7 +26,7 @@ class ValidationCallback(Callback): Args: network (nn.Cell): The neural network model to be validated. - dataset (GeneratorDataset): The dataset to use for validation. + dataset (BatchDataset, BucketBatchByLengthDataset, GeneratorDataset): The dataset to use for validation. alpha_smooth (float, optional): The smoothing factor for the loss. Defaults to 0.01. valid_frequency (int, optional): The frequency of validation in terms of training steps. Defaults to 100. 
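A minimal standalone sketch of how the bucket split introduced above routes samples (illustrative values, not part of the patch): single-frame samples map to the image bucket and multi-frame samples to the video bucket, so the per-bucket batch sizes from the stage-2 config (70 images vs. 1 video) can be applied.

import numpy as np

def bucket_split_function(image_batch_size: int, video_batch_size: int):
    # same idea as mg/dataset/buckets.py above: bucket id 0 for single-frame (image)
    # samples, bucket id 1 for multi-frame (video) samples
    return (
        lambda x: int(x.shape[0] > 1),          # element length: image -> 0, video -> 1
        [1],                                    # one boundary -> two buckets
        [image_batch_size, video_batch_size],   # per-bucket batch sizes
    )

hash_func, boundaries, batch_sizes = bucket_split_function(image_batch_size=70, video_batch_size=1)
image_sample = np.zeros((1, 3, 256, 256), dtype=np.float32)    # (t, c, h, w) with t == 1
video_sample = np.zeros((256, 3, 256, 256), dtype=np.float32)  # t == 256 frames
print(hash_func(image_sample), hash_func(video_sample), batch_sizes)  # 0 1 [70, 1]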
@@ -43,7 +43,7 @@ class ValidationCallback(Callback): def __init__( self, network: nn.Cell, - dataset: GeneratorDataset, + dataset: Union[BatchDataset, BucketBatchByLengthDataset, GeneratorDataset], alpha_smooth: float = 0.01, valid_frequency: int = 100, ema: Optional[EMA] = None, diff --git a/examples/moviegen/train.py b/examples/moviegen/train.py index 43f8c498bc..8c81b1197d 100644 --- a/examples/moviegen/train.py +++ b/examples/moviegen/train.py @@ -2,12 +2,15 @@ import os import re import sys -from math import ceil +from typing import Tuple, Union from jsonargparse import ActionConfigFile, ArgumentParser from jsonargparse.typing import path_type -from mindspore import Model, amp, nn, set_seed +from mindspore import GRAPH_MODE, Model, Symbol, Tensor, amp +from mindspore import dtype as mstype +from mindspore import get_context, nn, set_seed +from mindspore.dataset import BatchDataset, BucketBatchByLengthDataset from mindspore.train.callback import TimeMonitor # TODO: remove in future when mindone is ready for install @@ -15,7 +18,7 @@ mindone_lib_path = os.path.abspath(os.path.join(__dir__, "../../")) sys.path.append(mindone_lib_path) -from moviegen.dataset import ImageVideoDataset +from moviegen.dataset import ImageVideoDataset, bucket_split_function from moviegen.parallel import create_parallel_group from moviegen.pipelines import DiffusionWithLoss from moviegen.schedulers import RFlowEvalLoss, RFlowLossWrapper @@ -36,11 +39,41 @@ logger = logging.getLogger(__name__) +def initialize_dataset( + dataset_args, dataloader_args, device_num: int, shard_rank_id: int +) -> Tuple[Union[BatchDataset, BucketBatchByLengthDataset], int]: + dataset = ImageVideoDataset(**dataset_args) + transforms = ( + dataset.train_transforms(dataset_args.target_size) if not dataset_args.apply_transforms_dataset else None + ) + dataloader_args = dataloader_args.as_dict() + batch_size = dataloader_args.pop("batch_size") + dataloader = create_dataloader( + dataset, + batch_size=batch_size if isinstance(batch_size, int) else 0, # Turn off batching if using buckets + transforms=transforms, + device_num=device_num, + rank_id=shard_rank_id, + **dataloader_args, + ) + if isinstance(batch_size, dict): # if buckets are used + hash_func, bucket_boundaries, bucket_batch_sizes = bucket_split_function(**batch_size) + dataloader = dataloader.bucket_batch_by_length( + ["video"], + bucket_boundaries, + bucket_batch_sizes, + element_length_function=hash_func, + drop_remainder=dataloader_args.drop_remainder, + ) + return dataloader, len(dataset) + + def main(args): # 1. init env args.train.output_path = args.train.output_path.absolute os.makedirs(args.train.output_path, exist_ok=True) device_id, rank_id, device_num = init_train_env(**args.env) + mode = get_context("mode") # `init_train_env()` may change the mode during debugging # 1.1 init model parallel shard_rank_id = rank_id @@ -79,30 +112,19 @@ def main(args): # 3. build training network latent_diffusion_with_loss = DiffusionWithLoss(rflow_loss_wrapper, tae) - # 4. build dataset - dataset = ImageVideoDataset(**args.dataset) - transforms = ( - dataset.train_transforms(args.dataset.target_size) if not args.dataset.apply_transforms_dataset else None - ) - dataloader = create_dataloader( - dataset, transforms=transforms, device_num=device_num, rank_id=shard_rank_id, **args.dataloader - ) + # 4. 
build train & val datasets + dataloader, dataset_len = initialize_dataset(args.dataset, args.dataloader, device_num, shard_rank_id) eval_diffusion_with_loss, val_dataloader = None, None if args.valid.dataset is not None: - val_dataset = ImageVideoDataset(**args.valid.dataset.init_args) - transforms = None - if not args.valid.dataset.init_args.apply_transforms_dataset: - transforms = val_dataset.train_transforms(args.valid.dataset.init_args.target_size) - val_dataloader = create_dataloader( - val_dataset, transforms=transforms, device_num=device_num, rank_id=shard_rank_id, **args.valid.dataloader + val_dataloader, _ = initialize_dataset( + args.valid.dataset.init_args, args.valid.dataloader, device_num, shard_rank_id ) eval_rflow_loss = RFlowEvalLoss(rflow_loss_wrapper, num_sampling_steps=args.valid.sampling_steps) eval_diffusion_with_loss = DiffusionWithLoss(eval_rflow_loss, tae) # 5. build training utils: lr, optim, callbacks, trainer # 5.1 LR - epochs = ceil(args.train.steps / dataloader.get_dataset_size()) lr = create_scheduler(steps_per_epoch=0, **args.train.lr_scheduler) # 5.2 optimizer @@ -120,6 +142,17 @@ def main(args): **args.train.settings, ) + # TODO: validation graph? + # if bucketing is used in Graph mode, activate dynamic inputs + if mode == GRAPH_MODE and isinstance(args.dataset.batch_size, dict): + bs = Symbol(unique=True) + video = Tensor(shape=[bs, None, 3, None, None], dtype=mstype.float32) + # FIXME: fix sequence length + ul2_emb = Tensor(shape=[bs, 300, 4096], dtype=mstype.float32) + byt5_emb = Tensor(shape=[bs, 100, 1472], dtype=mstype.float32) + net_with_grads.set_inputs(video, ul2_emb, byt5_emb) + logger.info("Dynamic inputs are initialized for bucket config training in Graph mode.") + model = Model(net_with_grads) # 5.4 callbacks @@ -170,12 +203,12 @@ def main(args): key_info = "Key Settings:\n" + "=" * 50 + "\n" key_info += "\n".join( [ - f"MindSpore mode[GRAPH(0)/PYNATIVE(1)]: {args.env.mode}", + f"MindSpore mode[GRAPH(0)/PYNATIVE(1)]: {mode}", f"Debug mode: {args.env.debug}", f"JIT level: {args.env.jit_level}", f"Distributed mode: {args.env.distributed}", f"Data path: {args.dataset.csv_path}", - f"Number of samples: {len(dataset)}", + f"Number of samples: {dataset_len}", f"Model name: {args.model.name}", f"Model dtype: {args.model.dtype}", f"TAE dtype: {args.tae.dtype}", @@ -202,7 +235,8 @@ def main(args): # 6. 
train logger.info("Start training...") - model.train(epochs, dataloader, callbacks=callbacks) + # train() uses epochs, so the training will be terminated by the StopAtStepCallback + model.train(args.train.steps, dataloader, callbacks=callbacks) if __name__ == "__main__": From 37346d002f5e209a2c7871923cf0b7d427c9f818 Mon Sep 17 00:00:00 2001 From: samithuang <285365963@qq.com> Date: Tue, 26 Nov 2024 16:37:39 +0800 Subject: [PATCH 068/122] fix dynamic shape: defualt manual pad for conv1d same pad --- examples/movie_gen/mg/models/tae/modules.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/examples/movie_gen/mg/models/tae/modules.py b/examples/movie_gen/mg/models/tae/modules.py index 426f111f27..f133bfb851 100644 --- a/examples/movie_gen/mg/models/tae/modules.py +++ b/examples/movie_gen/mg/models/tae/modules.py @@ -358,9 +358,16 @@ def construct(self, x): class TemporalUpsample(nn.Cell): - def __init__(self, in_channels): + def __init__(self, in_channels, manual_pad=True): super().__init__() - self.conv = nn.Conv1d(in_channels, in_channels, kernel_size=3, stride=1, pad_mode="same", has_bias=True, bias_init='zeros') + # to support danamic shape in graph mode + self.manual_pad = manual_pad + if not self.manual_pad: + self.conv = nn.Conv1d(in_channels, in_channels, kernel_size=3, stride=1, pad_mode="same", has_bias=True, bias_init='zeros') + else: + self.conv = nn.Conv1d(in_channels, in_channels, kernel_size=3, stride=1, pad_mode="valid", has_bias=True, bias_init='zeros') + + # TODO: init conv weight so that it pass in image mode self.ch = in_channels self.init_weight('median') @@ -394,6 +401,12 @@ def construct(self, x): T = T0 * 2 x = ops.transpose(x, (0, 3, 1, 2)) x = ops.reshape(x, (B*H*W, C, T)) + + if self.manual_pad: + # work with pad_mode = valid, kernel_size=1 + pad_t_l = ops.zeros((B*H*W, C, 1), x.dtype) + pad_t_r = ops.zeros((B*H*W, C, 1), x.dtype) + x = ops.cat([pad_t_l, x, pad_t_r], 2) x = self.conv(x) From ef3fa0844db48932e86196a55f37d468271af82d Mon Sep 17 00:00:00 2001 From: Rustam Khadipash <16683750+hadipash@users.noreply.github.com> Date: Mon, 25 Nov 2024 14:45:01 +0800 Subject: [PATCH 069/122] fix save callback and TAE scaling --- examples/movie_gen/mg/models/tae/tae.py | 2 ++ .../configs/train/stage1_t2i_256x256.yaml | 1 + .../configs/train/stage2_t2iv_256x256.yaml | 2 +- .../moviegen/pipelines/infer_pipeline.py | 7 ++++-- .../moviegen/pipelines/train_pipeline.py | 22 ++++++++++--------- examples/moviegen/train.py | 4 ++-- mindone/trainers/callback.py | 11 ++++++---- 7 files changed, 30 insertions(+), 19 deletions(-) diff --git a/examples/movie_gen/mg/models/tae/tae.py b/examples/movie_gen/mg/models/tae/tae.py index 9556fa3a7a..15f9ca0316 100644 --- a/examples/movie_gen/mg/models/tae/tae.py +++ b/examples/movie_gen/mg/models/tae/tae.py @@ -66,6 +66,8 @@ def __init__( ): super().__init__() self.out_channels = config["z_channels"] + self.scale_factor = config["scaling_factor"] + self.shift_factor = config["shift_factor"] # encoder self.encoder = Encoder(**config) diff --git a/examples/moviegen/configs/train/stage1_t2i_256x256.yaml b/examples/moviegen/configs/train/stage1_t2i_256x256.yaml index 0027c5391e..1c1e5b1bf9 100644 --- a/examples/moviegen/configs/train/stage1_t2i_256x256.yaml +++ b/examples/moviegen/configs/train/stage1_t2i_256x256.yaml @@ -75,6 +75,7 @@ train: save: ckpt_save_policy: top_k + monitor_metric: eval_loss_smoothed ckpt_save_interval: &save_interval 100 ckpt_max_keep: 10 log_interval: 1 diff --git 
a/examples/moviegen/configs/train/stage2_t2iv_256x256.yaml b/examples/moviegen/configs/train/stage2_t2iv_256x256.yaml index 4087d33ac5..add1eb20a4 100644 --- a/examples/moviegen/configs/train/stage2_t2iv_256x256.yaml +++ b/examples/moviegen/configs/train/stage2_t2iv_256x256.yaml @@ -77,7 +77,7 @@ train: clip_norm: 1.0 save: - ckpt_save_policy: top_k + ckpt_save_policy: latest_k ckpt_save_interval: &save_interval 100 ckpt_max_keep: 10 log_interval: 1 diff --git a/examples/moviegen/moviegen/pipelines/infer_pipeline.py b/examples/moviegen/moviegen/pipelines/infer_pipeline.py index cf419fd014..9147bd0cc2 100644 --- a/examples/moviegen/moviegen/pipelines/infer_pipeline.py +++ b/examples/moviegen/moviegen/pipelines/infer_pipeline.py @@ -26,7 +26,8 @@ def __init__( model: nn.Cell, tae: nn.Cell, latent_size: Tuple[int, int, int] = (1, 64, 64), - scale_factor: float = 1.0, + scale_factor: float = 1.5305, + shift_factor: float = 0.0609, guidance_scale: float = 1.0, num_sampling_steps: int = 50, sample_method: Literal["linear", "linear-quadratic"] = "linear", @@ -37,7 +38,8 @@ def __init__( self.tae = tae self.latent_size = latent_size self.micro_batch_size = micro_batch_size - self.scale_factor = scale_factor + self.scale_factor = scale_factor if tae is None else tae.scale_factor + self.shift_factor = shift_factor if tae is None else tae.shift_factor self.guidance_rescale = guidance_scale self.use_cfg = guidance_scale > 1.0 self.rflow = RFLOW(num_sampling_steps, sample_method=sample_method) @@ -50,6 +52,7 @@ def tae_decode_video(self, x, num_frames=None): y: (b f H W 3), batch of images, normalized to [0, 1] """ x = mint.permute(x, (0, 2, 1, 3, 4)) # FIXME: remove this redundancy + x = x / self.scale_factor + self.shift_factor y = self.tae.decode(x, target_num_frames=num_frames) # FIXME: extract scale_factor from TAE and use it here y = ops.clip_by_value((y + 1.0) / 2.0, clip_value_min=0.0, clip_value_max=1.0) # (b 3 t h w) -> (b t h w 3) diff --git a/examples/moviegen/moviegen/pipelines/train_pipeline.py b/examples/moviegen/moviegen/pipelines/train_pipeline.py index 359d79c9ce..3ebacf3429 100644 --- a/examples/moviegen/moviegen/pipelines/train_pipeline.py +++ b/examples/moviegen/moviegen/pipelines/train_pipeline.py @@ -13,9 +13,10 @@ class DiffusionWithLoss(nn.Cell): def __init__( self, network: RFlowLossWrapper, - vae: Optional[nn.Cell] = None, + tae: Optional[nn.Cell] = None, text_encoder: Optional[nn.Cell] = None, - scale_factor: float = 0.13025, + scale_factor: float = 1.5305, + shift_factor: float = 0.0609, text_emb_cached: bool = True, video_emb_cached: bool = False, ): @@ -23,18 +24,19 @@ def __init__( if not text_emb_cached and text_encoder is None: raise ValueError("`text_encoder` must be provided when `text_emb_cached=False`.") - if not video_emb_cached and vae is None: - raise ValueError("`vae` must be provided when `video_emb_cached=False`.") + if not video_emb_cached and tae is None: + raise ValueError("`TAE` must be provided when `video_emb_cached=False`.") self.network = network - self.vae = vae + self.tae = tae self.text_encoder = text_encoder - self.scale_factor = scale_factor + self.scale_factor = scale_factor if tae is None else tae.scale_factor + self.shift_factor = shift_factor if tae is None else tae.shift_factor self.text_emb_cached = text_emb_cached self.video_emb_cached = video_emb_cached - if self.vae is not None: - for param in self.vae.trainable_params(): + if self.tae is not None: + for param in self.tae.trainable_params(): param.requires_grad = False if 
self.text_encoder is not None: @@ -54,8 +56,8 @@ def get_latents(self, video_tokens: Tensor) -> Tensor: with no_grad(): # (b c f h w) shape is expected. FIXME: remove this redundancy video_tokens = mint.permute(video_tokens, (0, 2, 1, 3, 4)) - # FIXME: extract scale_factor from VAE and use it here - video_emb = ops.stop_gradient(self.vae.encode(video_tokens)[0]).to(ms.float32) + video_emb = ops.stop_gradient(self.tae.encode(video_tokens)[0]).to(ms.float32) + video_emb = (video_emb - self.shift_factor) * self.scale_factor video_emb = mint.permute(video_emb, (0, 2, 1, 3, 4)) # FIXME return video_emb diff --git a/examples/moviegen/train.py b/examples/moviegen/train.py index 8c81b1197d..e1a25efecb 100644 --- a/examples/moviegen/train.py +++ b/examples/moviegen/train.py @@ -63,7 +63,7 @@ def initialize_dataset( bucket_boundaries, bucket_batch_sizes, element_length_function=hash_func, - drop_remainder=dataloader_args.drop_remainder, + drop_remainder=dataloader_args["drop_remainder"], ) return dataloader, len(dataset) @@ -144,7 +144,7 @@ def main(args): # TODO: validation graph? # if bucketing is used in Graph mode, activate dynamic inputs - if mode == GRAPH_MODE and isinstance(args.dataset.batch_size, dict): + if mode == GRAPH_MODE and isinstance(args.dataloader.batch_size, dict): bs = Symbol(unique=True) video = Tensor(shape=[bs, None, 3, None, None], dtype=mstype.float32) # FIXME: fix sequence length diff --git a/mindone/trainers/callback.py b/mindone/trainers/callback.py index 3172f2bdd7..fe8e6d8c13 100755 --- a/mindone/trainers/callback.py +++ b/mindone/trainers/callback.py @@ -1,7 +1,7 @@ import logging import os import time -from typing import List +from typing import List, Optional import mindspore as ms from mindspore.train.callback._callback import Callback, _handle_loss @@ -35,7 +35,8 @@ def __init__( output_dir=None, ema=None, save_ema_only=True, - ckpt_save_policy="lastest_k", + ckpt_save_policy="latest_k", + monitor_metric: Optional[str] = None, ckpt_max_keep=10, step_mode=False, ckpt_save_interval=1, @@ -85,6 +86,7 @@ def __init__( if self.is_main_device: self.ckpt_save_policy = ckpt_save_policy + self.monitor_metric = monitor_metric self.ckpt_manager = CheckpointManager( ckpt_save_dir, ckpt_save_policy, @@ -159,8 +161,9 @@ def on_train_step_end(self, run_context): append_dict = {"lora_rank": self.lora_rank} if self.use_lora else None perf = cb_params.get("eval_results") - if perf: - perf = perf["eval_loss_smoothed"] + if perf or self.ckpt_save_policy != "top_k": + if perf: + perf = perf[self.monitor_metric] if self.ema is not None: if not self.save_ema_only: self.ckpt_manager.save( From 0054700f3bd05c483942b5fd4561949625593f55 Mon Sep 17 00:00:00 2001 From: Rustam Khadipash <16683750+hadipash@users.noreply.github.com> Date: Tue, 26 Nov 2024 16:03:51 +0800 Subject: [PATCH 070/122] Revert "fix hack" This reverts commit bf505d4f308a4f02b2c5fbfa97eb6e251c90b830. 
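A small numpy sketch of the latent normalization added in this commit (encode: subtract shift, multiply by scale; decode: invert), using the default scale_factor=1.5305 and shift_factor=0.0609 from the pipelines above. Illustrative only; the latent shape below is made up.

import numpy as np

scale_factor, shift_factor = 1.5305, 0.0609  # defaults in the pipelines above

def normalize_latent(z):
    # applied to the TAE.encode() output before it enters the diffusion transformer
    return (z - shift_factor) * scale_factor

def denormalize_latent(z):
    # applied to the denoised latent before TAE.decode()
    return z / scale_factor + shift_factor

z = np.random.randn(1, 16, 8, 32, 32).astype(np.float32)  # hypothetical (b, z, t', h', w') latent
assert np.allclose(denormalize_latent(normalize_latent(z)), z, atol=1e-5)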
--- examples/moviegen/train.py | 10 ++-------- mindone/trainers/train_step.py | 7 ++++--- mindone/trainers/zero.py | 5 +---- 3 files changed, 7 insertions(+), 15 deletions(-) diff --git a/examples/moviegen/train.py b/examples/moviegen/train.py index e1a25efecb..aa486332c2 100644 --- a/examples/moviegen/train.py +++ b/examples/moviegen/train.py @@ -1,6 +1,5 @@ import logging import os -import re import sys from typing import Tuple, Union @@ -134,12 +133,7 @@ def main(args): ema = EMA(latent_diffusion_with_loss.network, **args.train.ema.init_args) if args.train.ema else None loss_scaler = initializer.train.loss_scaler net_with_grads = prepare_train_network( - latent_diffusion_with_loss, - optimizer=optimizer, - scale_sense=loss_scaler, - ema=ema, - need_reduce=tuple(bool(re.search(r"layers\.(\d+)\.mlp", param.name)) for param in optimizer.parameters), - **args.train.settings, + latent_diffusion_with_loss, optimizer=optimizer, scale_sense=loss_scaler, ema=ema, **args.train.settings ) # TODO: validation graph? @@ -273,7 +267,7 @@ def main(args): help="mindspore.nn.FixedLossScaleUpdateCell or mindspore.nn.DynamicLossScaleUpdateCell", ) parser.add_function_arguments( - prepare_train_network, "train.settings", skip={"network", "optimizer", "scale_sense", "ema", "need_reduce"} + prepare_train_network, "train.settings", skip={"network", "optimizer", "scale_sense", "ema"} ) parser.add_subclass_arguments(EMA, "train.ema", skip={"network"}, required=False, instantiate=False) parser.add_argument( diff --git a/mindone/trainers/train_step.py b/mindone/trainers/train_step.py index b33eca33f3..b33532c848 100644 --- a/mindone/trainers/train_step.py +++ b/mindone/trainers/train_step.py @@ -1,5 +1,7 @@ """Train step wrapper supporting setting drop overflow update, ema etc""" -from typing import Callable, Optional, Tuple +from typing import Optional + +from typing import Callable, Tuple from packaging import version @@ -99,7 +101,6 @@ def __init__( clip_norm=1.0, verbose=False, zero_helper=None, - need_reduce: Optional[Tuple[bool]] = None, ): super().__init__(network, optimizer, scale_sense) self.ema = ema @@ -126,7 +127,7 @@ def __init__( self.partial = ops.Partial() self.grad_reducer = GradReducer() - self.need_reduce = need_reduce + self.need_reduce = tuple([2048 in x.shape for x in self.weights]) # zero init self.zero_helper = zero_helper diff --git a/mindone/trainers/zero.py b/mindone/trainers/zero.py index 9ccda9c629..7f5c0adaff 100644 --- a/mindone/trainers/zero.py +++ b/mindone/trainers/zero.py @@ -1,7 +1,7 @@ import json import logging import os -from typing import Literal, Optional, Tuple +from typing import Literal import mindspore as ms from mindspore import nn, ops @@ -561,7 +561,6 @@ def prepare_train_network( dp_group: str = None, comm_fusion: dict = None, parallel_modules=None, - need_reduce: Optional[Tuple[bool, ...]] = None, ): """ Prepare network and optimizer for distributed training. @@ -600,7 +599,6 @@ def prepare_train_network( clip_grad=clip_grad, clip_norm=clip_norm, verbose=verbose, - need_reduce=need_reduce, ) return train_network @@ -630,7 +628,6 @@ def prepare_train_network( clip_norm=clip_norm, verbose=verbose, zero_helper=zero_helper, - need_reduce=need_reduce, ) return train_network From 28349a5e05f9a24d9f2b0468be8f90716c0eaa9f Mon Sep 17 00:00:00 2001 From: Rustam Khadipash <16683750+hadipash@users.noreply.github.com> Date: Tue, 26 Nov 2024 16:03:51 +0800 Subject: [PATCH 071/122] Revert "hack for model parallel" This reverts commit 8af74372444239639e933b514de77750e846f372. 
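For context on these two reverts: the removed "hack" averaged the gradients of selected parameters across devices (an all-reduce followed by division by the world size), gated by a per-parameter need_reduce flag. A single-process numpy sketch of that arithmetic follows; names, shapes, and values are made up for illustration.

import numpy as np

def average_selected_grads(per_device_grads, need_reduce):
    # per_device_grads: list over devices, each a list of parameter gradients
    # need_reduce: one bool per parameter; True -> all-reduce mean, False -> keep local grad
    world_size = len(per_device_grads)
    out = []
    for p, reduce_this in enumerate(need_reduce):
        if reduce_this:
            out.append(sum(dev[p] for dev in per_device_grads) / world_size)  # all-reduce mean
        else:
            out.append(per_device_grads[0][p])  # gradient left as-is on each device
    return out

grads_dev0 = [np.ones(3), np.full(3, 5.0)]
grads_dev1 = [np.full(3, 3.0), np.full(3, 4.0)]
print(average_selected_grads([grads_dev0, grads_dev1], need_reduce=(True, False)))
# [array([2., 2., 2.]), array([5., 5., 5.])]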
--- mindone/trainers/train_step.py | 40 +--------------------------------- 1 file changed, 1 insertion(+), 39 deletions(-) diff --git a/mindone/trainers/train_step.py b/mindone/trainers/train_step.py index b33532c848..18d0ccc0b8 100644 --- a/mindone/trainers/train_step.py +++ b/mindone/trainers/train_step.py @@ -1,8 +1,6 @@ """Train step wrapper supporting setting drop overflow update, ema etc""" from typing import Optional -from typing import Callable, Tuple - from packaging import version import mindspore as ms @@ -12,7 +10,6 @@ from mindspore.boost.grad_accumulation import gradient_clear_op as _grad_clear_op from mindspore.common import RowTensor from mindspore.common import dtype as mstype -from mindspore.communication import get_group_size from mindspore.ops import composite as C from mindspore.ops import functional as F from mindspore.ops import operations as P @@ -38,38 +35,6 @@ def tensor_grad_scale_row_tensor(scale, grad): ) -communicate_opt = C.MultitypeFuncGraph("communicate_opt") - - -@communicate_opt.register("Function", "Number", "Tensor", "Bool") -def _communicate_opt(func: Callable[[Tensor], Tensor], num: int, grad: Tensor, need_reduce: bool): - if not need_reduce: - return grad - grad = func(grad) - grad = grad / num - return grad - - -class GradReducer(nn.Cell): - def __init__(self): - super().__init__() - self.hypermap = C.HyperMap() - self.is_single = False - try: - self.num = get_group_size() - except RuntimeError: - self.is_single = True - - if not self.is_single: - self.reduce = ops.AllReduce() - - def construct(self, grads: Tuple[Tensor], need_reduce: Tuple[bool]): - if self.is_single: - return grads - grads = self.hypermap(ops.partial(communicate_opt, self.reduce, self.num), grads, need_reduce) - return grads - - class TrainOneStepWrapper(nn.TrainOneStepWithLossScaleCell): """TrainStep with ema and clip grad. @@ -126,9 +91,6 @@ def __init__( self.map = ops.Map() self.partial = ops.Partial() - self.grad_reducer = GradReducer() - self.need_reduce = tuple([2048 in x.shape for x in self.weights]) - # zero init self.zero_helper = zero_helper self.zero_stage = zero_helper.zero_stage if zero_helper is not None else 0 @@ -168,7 +130,7 @@ def construct(self, *inputs): grads = self.zero_helper.cal_gradients(grads) if self.accum_steps == 1: - grads = self.grad_reducer(grads, self.need_reduce) + grads = self.grad_reducer(grads) scaling_sens = ops.depend(scaling_sens, grads) # 2. down-scale gradients by loss_scale. 
grads = grads / scaling_sense / grad_accum_steps From 84a25dd6e2ee3b703858d9a2464b1e5cf2c85c8c Mon Sep 17 00:00:00 2001 From: Rustam Khadipash <16683750+hadipash@users.noreply.github.com> Date: Wed, 27 Nov 2024 15:43:20 +0800 Subject: [PATCH 072/122] revert it later --- examples/movie_gen/mg/models/tae/modules.py | 4 ++-- examples/moviegen/moviegen/models/llama/block.py | 12 ++++++------ examples/moviegen/moviegen/models/llama/network.py | 10 +++++----- .../moviegen/moviegen/schedulers/rectified_flow.py | 2 +- 4 files changed, 14 insertions(+), 14 deletions(-) diff --git a/examples/movie_gen/mg/models/tae/modules.py b/examples/movie_gen/mg/models/tae/modules.py index e7d95cbdd4..e09e037d27 100644 --- a/examples/movie_gen/mg/models/tae/modules.py +++ b/examples/movie_gen/mg/models/tae/modules.py @@ -663,7 +663,7 @@ def make_attn(in_channels, attn_type="vanilla"): # used in vae class Encoder(nn.Cell): - @lazy_inline() + # @lazy_inline() def __init__( self, ch=128, @@ -794,7 +794,7 @@ def construct(self, x): class Decoder(nn.Cell): - @ms.lazy_inline() + # @ms.lazy_inline() def __init__( self, *, diff --git a/examples/moviegen/moviegen/models/llama/block.py b/examples/moviegen/moviegen/models/llama/block.py index de71142ec0..a53db68e03 100644 --- a/examples/moviegen/moviegen/models/llama/block.py +++ b/examples/moviegen/moviegen/models/llama/block.py @@ -422,9 +422,9 @@ def __init__( def construct(self, x: Tensor) -> Tensor: _, t, _, h, w = x.shape - assert t % self.patch_size[0] == 0 - assert h % self.patch_size[1] == 0 - assert w % self.patch_size[2] == 0 + # assert t % self.patch_size[0] == 0 + # assert h % self.patch_size[1] == 0 + # assert w % self.patch_size[2] == 0 x = mint.permute(x, (0, 2, 1, 3, 4)) x = self.proj(x) # (B C T H W) @@ -449,9 +449,9 @@ def __init__( def construct(self, x: Tensor) -> Tensor: b, t, c, h, w = x.shape - assert t % self.patch_size[0] == 0 - assert h % self.patch_size[1] == 0 - assert w % self.patch_size[2] == 0 + # assert t % self.patch_size[0] == 0 + # assert h % self.patch_size[1] == 0 + # assert w % self.patch_size[2] == 0 p0, p1, p2 = self.patch_size[0], self.patch_size[1], self.patch_size[2] nt, nh, nw = t // p0, h // p1, w // p2 diff --git a/examples/moviegen/moviegen/models/llama/network.py b/examples/moviegen/moviegen/models/llama/network.py index b59e487d85..8834b9e63e 100644 --- a/examples/moviegen/moviegen/models/llama/network.py +++ b/examples/moviegen/moviegen/models/llama/network.py @@ -49,7 +49,7 @@ def t2i_modulate(x: Tensor, shift: Tensor, scale: Tensor) -> Tensor: class LlamaDecoderLayer(nn.Cell): - @ms.lazy_inline(policy="front") + # @ms.lazy_inline(policy="front") def __init__( self, hidden_size: int = 4096, @@ -426,9 +426,9 @@ def learnable_position_embedding(self, latent_embedding: Tensor) -> Tensor: p0, p1, p2 = self.patch_size[0], self.patch_size[1], self.patch_size[2] nt, nh, nw = t // p0, h // p1, w // p2 - assert nt < self.max_length[0] - assert nh < self.max_length[1] - assert nw < self.max_length[2] + # assert nt < self.max_length[0] + # assert nh < self.max_length[1] + # assert nw < self.max_length[2] t_inds = mint.arange(nt, dtype=ms.int64) h_inds = mint.arange(nh, dtype=ms.int64) @@ -492,7 +492,7 @@ def construct( # 3.1.6 Sequence Parallelism Start if self.model_parallelism: - assert hidden_states.shape[1] % self.group_size == 0 + # assert hidden_states.shape[1] % self.group_size == 0 hidden_states = self.split_forward_gather_backward(hidden_states) position_embedding = self.split_forward_gather_backward(position_embedding) 
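The asserts commented out above guard the patchify and position-embedding arithmetic: every latent dimension must be divisible by its patch size before it is folded into the token grid. A minimal sketch of that check, with a hypothetical patch size (not necessarily the model default):

def token_grid(t: int, h: int, w: int, patch_size=(1, 2, 2)):
    # (t, h, w) latent dims -> (nt, nh, nw) token grid, as in the patchify/unpatchify code above
    p0, p1, p2 = patch_size
    if t % p0 or h % p1 or w % p2:
        raise ValueError("latent dims must be divisible by the patch size")
    return t // p0, h // p1, w // p2

print(token_grid(32, 32, 32))  # (32, 16, 16) -> 8192 tokens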
diff --git a/examples/moviegen/moviegen/schedulers/rectified_flow.py b/examples/moviegen/moviegen/schedulers/rectified_flow.py index 0a38dc26b2..a80bd7dd44 100644 --- a/examples/moviegen/moviegen/schedulers/rectified_flow.py +++ b/examples/moviegen/moviegen/schedulers/rectified_flow.py @@ -27,7 +27,7 @@ def __init__(self, loc: float = 0.0, scale: float = 1.0) -> None: self._max = Tensor(1.0 - np.finfo(np.float32).eps, dtype=ms.float32) def construct(self, shape: Tuple[int, ...]) -> Tensor: - assert shape[-1] == 1 + # assert shape[-1] == 1 x = mint.normal(mean=self.mean, std=self.std, size=shape) offset = x.shape[-1] + 1 - mint.cumsum(mint.ones(x.shape[-1]), dim=-1) z = self._clipped_sigmoid(x - mint.log(offset)) From a260a1e223f0b2d98bf5328297b126d7eaee2b36 Mon Sep 17 00:00:00 2001 From: Rustam Khadipash <16683750+hadipash@users.noreply.github.com> Date: Wed, 27 Nov 2024 15:48:31 +0800 Subject: [PATCH 073/122] small fixes --- examples/moviegen/inference.py | 1 + examples/moviegen/moviegen/utils/model_utils.py | 1 + examples/moviegen/train.py | 1 + 3 files changed, 3 insertions(+) diff --git a/examples/moviegen/inference.py b/examples/moviegen/inference.py index a903447668..24d6448257 100644 --- a/examples/moviegen/inference.py +++ b/examples/moviegen/inference.py @@ -89,6 +89,7 @@ def main(args): latent_size = tae.get_latent_size((num_frames, img_h, img_w)) # 2.2 Llama 3 + logger.info("Transformer init") model = init_model(in_channels=tae.out_channels, **args.model).set_train(False) # 2.3 text embeddings diff --git a/examples/moviegen/moviegen/utils/model_utils.py b/examples/moviegen/moviegen/utils/model_utils.py index 7a5723da80..b139158036 100644 --- a/examples/moviegen/moviegen/utils/model_utils.py +++ b/examples/moviegen/moviegen/utils/model_utils.py @@ -24,6 +24,7 @@ def load_ckpt_params(model: nn.Cell, ckpt: Union[str, Dict]) -> nn.Cell: if isinstance(ckpt, str): logger.info(f"Loading {ckpt} params into network...") param_dict = ms.load_checkpoint(ckpt) + param_dict = {k.replace("network.model.", ""): v for k, v in param_dict.items()} else: param_dict = ckpt diff --git a/examples/moviegen/train.py b/examples/moviegen/train.py index aa486332c2..6204c10dae 100644 --- a/examples/moviegen/train.py +++ b/examples/moviegen/train.py @@ -104,6 +104,7 @@ def main(args): ) # 2.2 Llama 3 + logger.info("Transformer init") network = init_model(in_channels=tae.out_channels, **args.model) # 2.3 LossWrapper rflow_loss_wrapper = RFlowLossWrapper(network) From ef72175dd137c841da344b223b82ccd066fc5e92 Mon Sep 17 00:00:00 2001 From: Rustam Khadipash <16683750+hadipash@users.noreply.github.com> Date: Thu, 28 Nov 2024 13:39:50 +0800 Subject: [PATCH 074/122] refactoring --- .../scripts => moviegen}/args_train_tae.py | 0 .../configs/tae/train/mixed_256x256x16.yaml | 0 .../configs/tae/train/mixed_256x256x32.yaml | 0 examples/moviegen/inference.py | 11 ++++------ examples/moviegen/inference_text_enc.py | 2 +- .../scripts => moviegen}/inference_vae.py | 6 ++---- .../moviegen/{moviegen => mg}/__init__.py | 0 .../{moviegen => mg}/dataset/__init__.py | 0 .../{moviegen => mg}/dataset/buckets.py | 0 .../{moviegen => mg}/dataset/dataset.py | 0 .../mg/dataset}/tae_dataset.py | 0 .../{moviegen => mg}/dataset/transforms.py | 0 .../{moviegen => mg}/models/__init__.py | 0 .../{moviegen => mg}/models/llama/__init__.py | 0 .../models/llama/activation.py | 0 .../{moviegen => mg}/models/llama/block.py | 2 +- .../{moviegen => mg}/models/llama/network.py | 4 ++-- examples/moviegen/mg/models/tae/__init__.py | 1 + 
.../mg/models/tae/losses.py | 0 .../mg/models/tae/lpips.py | 0 .../mg/models/tae/modules.py | 0 .../mg/models/tae/modules_2d.py | 0 .../mg/models/tae/sd3_vae.py | 0 .../mg/models/tae/tae.py | 0 .../models/text_encoders/__init__.py | 0 .../models/text_encoders/text_projector.py | 0 .../{moviegen => mg}/parallel/__init__.py | 0 .../{moviegen => mg}/parallel/layers.py | 0 .../parallel/parallel_states.py | 0 .../{moviegen => mg}/pipelines/__init__.py | 0 .../pipelines/infer_pipeline.py | 0 .../pipelines/train_pipeline.py | 0 .../{moviegen => mg}/schedulers/__init__.py | 0 .../schedulers/rectified_flow.py | 0 .../{moviegen => mg}/utils/__init__.py | 0 .../{moviegen => mg}/utils/callbacks.py | 0 .../moviegen/{moviegen => mg}/utils/ema.py | 0 .../mg/utils/load_models.py | 0 .../{moviegen => mg}/utils/model_utils.py | 2 +- .../mg/utils/parser.py | 0 .../moviegen/{moviegen => mg}/utils/utils.py | 0 .../run => moviegen/scripts}/run_train_tae.sh | 2 +- .../tests/parallel/test_llama3_parallel.py | 4 ++-- .../parallel/test_llama3_parallel_block.py | 4 ++-- .../parallel/test_llama3_parallel_layer.py | 2 +- .../tests/parallel/test_rflow_parallel.py | 4 ++-- .../tests => moviegen/tests/ut}/test_gn.py | 0 .../moviegen/tests/ut/test_llama3_forward.py | 2 +- examples/moviegen/tests/ut/test_rflow.py | 2 +- .../tests => moviegen/tests/ut}/test_tae.py | 5 ++--- .../tools/inflate_vae_to_tae.py | 1 - .../tools/ms_pnames_sd3.5_vae.txt | 0 .../tools/ms_pnames_tae_vae.txt | 0 .../tools/pt_pnames_sd3.5_vae.txt | 0 examples/moviegen/train.py | 21 +++++++------------ .../scripts => moviegen}/train_tae.py | 4 ++-- 56 files changed, 34 insertions(+), 45 deletions(-) rename examples/{movie_gen/scripts => moviegen}/args_train_tae.py (100%) rename examples/{movie_gen => moviegen}/configs/tae/train/mixed_256x256x16.yaml (100%) rename examples/{movie_gen => moviegen}/configs/tae/train/mixed_256x256x32.yaml (100%) rename examples/{movie_gen/scripts => moviegen}/inference_vae.py (98%) rename examples/moviegen/{moviegen => mg}/__init__.py (100%) rename examples/moviegen/{moviegen => mg}/dataset/__init__.py (100%) rename examples/moviegen/{moviegen => mg}/dataset/buckets.py (100%) rename examples/moviegen/{moviegen => mg}/dataset/dataset.py (100%) rename examples/{movie_gen/mg/datasets => moviegen/mg/dataset}/tae_dataset.py (100%) rename examples/moviegen/{moviegen => mg}/dataset/transforms.py (100%) rename examples/moviegen/{moviegen => mg}/models/__init__.py (100%) rename examples/moviegen/{moviegen => mg}/models/llama/__init__.py (100%) rename examples/moviegen/{moviegen => mg}/models/llama/activation.py (100%) rename examples/moviegen/{moviegen => mg}/models/llama/block.py (99%) rename examples/moviegen/{moviegen => mg}/models/llama/network.py (99%) create mode 100644 examples/moviegen/mg/models/tae/__init__.py rename examples/{movie_gen => moviegen}/mg/models/tae/losses.py (100%) rename examples/{movie_gen => moviegen}/mg/models/tae/lpips.py (100%) rename examples/{movie_gen => moviegen}/mg/models/tae/modules.py (100%) rename examples/{movie_gen => moviegen}/mg/models/tae/modules_2d.py (100%) rename examples/{movie_gen => moviegen}/mg/models/tae/sd3_vae.py (100%) rename examples/{movie_gen => moviegen}/mg/models/tae/tae.py (100%) rename examples/moviegen/{moviegen => mg}/models/text_encoders/__init__.py (100%) rename examples/moviegen/{moviegen => mg}/models/text_encoders/text_projector.py (100%) rename examples/moviegen/{moviegen => mg}/parallel/__init__.py (100%) rename examples/moviegen/{moviegen => mg}/parallel/layers.py 
(100%) rename examples/moviegen/{moviegen => mg}/parallel/parallel_states.py (100%) rename examples/moviegen/{moviegen => mg}/pipelines/__init__.py (100%) rename examples/moviegen/{moviegen => mg}/pipelines/infer_pipeline.py (100%) rename examples/moviegen/{moviegen => mg}/pipelines/train_pipeline.py (100%) rename examples/moviegen/{moviegen => mg}/schedulers/__init__.py (100%) rename examples/moviegen/{moviegen => mg}/schedulers/rectified_flow.py (100%) rename examples/moviegen/{moviegen => mg}/utils/__init__.py (100%) rename examples/moviegen/{moviegen => mg}/utils/callbacks.py (100%) rename examples/moviegen/{moviegen => mg}/utils/ema.py (100%) rename examples/{movie_gen => moviegen}/mg/utils/load_models.py (100%) rename examples/moviegen/{moviegen => mg}/utils/model_utils.py (97%) rename examples/{movie_gen => moviegen}/mg/utils/parser.py (100%) rename examples/moviegen/{moviegen => mg}/utils/utils.py (100%) rename examples/{movie_gen/scripts/run => moviegen/scripts}/run_train_tae.sh (95%) mode change 100755 => 100644 rename examples/{movie_gen/tests => moviegen/tests/ut}/test_gn.py (100%) rename examples/{movie_gen/tests => moviegen/tests/ut}/test_tae.py (98%) rename examples/{movie_gen => moviegen}/tools/inflate_vae_to_tae.py (99%) rename examples/{movie_gen => moviegen}/tools/ms_pnames_sd3.5_vae.txt (100%) rename examples/{movie_gen => moviegen}/tools/ms_pnames_tae_vae.txt (100%) rename examples/{movie_gen => moviegen}/tools/pt_pnames_sd3.5_vae.txt (100%) rename examples/{movie_gen/scripts => moviegen}/train_tae.py (99%) diff --git a/examples/movie_gen/scripts/args_train_tae.py b/examples/moviegen/args_train_tae.py similarity index 100% rename from examples/movie_gen/scripts/args_train_tae.py rename to examples/moviegen/args_train_tae.py diff --git a/examples/movie_gen/configs/tae/train/mixed_256x256x16.yaml b/examples/moviegen/configs/tae/train/mixed_256x256x16.yaml similarity index 100% rename from examples/movie_gen/configs/tae/train/mixed_256x256x16.yaml rename to examples/moviegen/configs/tae/train/mixed_256x256x16.yaml diff --git a/examples/movie_gen/configs/tae/train/mixed_256x256x32.yaml b/examples/moviegen/configs/tae/train/mixed_256x256x32.yaml similarity index 100% rename from examples/movie_gen/configs/tae/train/mixed_256x256x32.yaml rename to examples/moviegen/configs/tae/train/mixed_256x256x32.yaml diff --git a/examples/moviegen/inference.py b/examples/moviegen/inference.py index 24d6448257..cc705d0b27 100644 --- a/examples/moviegen/inference.py +++ b/examples/moviegen/inference.py @@ -18,17 +18,14 @@ mindone_lib_path = os.path.abspath(os.path.join(__dir__, "../../")) sys.path.append(mindone_lib_path) -from moviegen.pipelines import InferPipeline -from moviegen.utils import MODEL_DTYPE, init_model, to_numpy +from mg.models.tae import TemporalAutoencoder +from mg.models.tae.modules import SpatialDownsample, SpatialUpsample, TemporalDownsample, TemporalUpsample +from mg.pipelines import InferPipeline +from mg.utils import MODEL_DTYPE, init_model, to_numpy from mindone.utils import init_train_env, set_logger from mindone.visualize.videos import save_videos -# TODO: remove when TAE is added to the project -sys.path.append(os.path.join(__dir__, "../movie_gen/")) -from mg.models.tae.modules import SpatialDownsample, SpatialUpsample, TemporalDownsample, TemporalUpsample -from mg.models.tae.tae import TemporalAutoencoder - logger = logging.getLogger(__name__) Path_dr = path_type("dr", docstring="path to a directory that exists and is readable") diff --git 
a/examples/moviegen/inference_text_enc.py b/examples/moviegen/inference_text_enc.py index 5a500bffe4..05ced5b108 100644 --- a/examples/moviegen/inference_text_enc.py +++ b/examples/moviegen/inference_text_enc.py @@ -19,7 +19,7 @@ mindone_lib_path = os.path.abspath(os.path.join(__dir__, "../../")) sys.path.append(mindone_lib_path) -from moviegen.utils import MODEL_DTYPE, to_numpy +from mg.utils import MODEL_DTYPE, to_numpy from mindone.transformers.models.t5.modeling_t5 import T5EncoderModel from mindone.utils import init_train_env, set_logger diff --git a/examples/movie_gen/scripts/inference_vae.py b/examples/moviegen/inference_vae.py similarity index 98% rename from examples/movie_gen/scripts/inference_vae.py rename to examples/moviegen/inference_vae.py index ba9fe34d6d..3c30ca67bd 100644 --- a/examples/movie_gen/scripts/inference_vae.py +++ b/examples/moviegen/inference_vae.py @@ -18,8 +18,6 @@ mindone_dir = os.path.abspath(os.path.join(__dir__, "../../../")) sys.path.insert(0, mindone_dir) - -from omegaconf import OmegaConf from PIL import Image from skimage.metrics import peak_signal_noise_ratio as calc_psnr from skimage.metrics import structural_similarity as calc_ssim @@ -32,12 +30,12 @@ sys.path.insert(0, mindone_lib_path) sys.path.insert(0, os.path.abspath(os.path.join(__dir__, ".."))) -from mg.datasets.tae_dataset import create_dataloader +from mg.dataset.tae_dataset import create_dataloader from mg.models.tae.tae import TemporalAutoencoder from mg.models.tae.lpips import LPIPS from mindone.utils.amp import auto_mixed_precision -from mindone.utils.config import instantiate_from_config, str2bool +from mindone.utils.config import str2bool from mindone.utils.logger import set_logger logger = logging.getLogger(__name__) diff --git a/examples/moviegen/moviegen/__init__.py b/examples/moviegen/mg/__init__.py similarity index 100% rename from examples/moviegen/moviegen/__init__.py rename to examples/moviegen/mg/__init__.py diff --git a/examples/moviegen/moviegen/dataset/__init__.py b/examples/moviegen/mg/dataset/__init__.py similarity index 100% rename from examples/moviegen/moviegen/dataset/__init__.py rename to examples/moviegen/mg/dataset/__init__.py diff --git a/examples/moviegen/moviegen/dataset/buckets.py b/examples/moviegen/mg/dataset/buckets.py similarity index 100% rename from examples/moviegen/moviegen/dataset/buckets.py rename to examples/moviegen/mg/dataset/buckets.py diff --git a/examples/moviegen/moviegen/dataset/dataset.py b/examples/moviegen/mg/dataset/dataset.py similarity index 100% rename from examples/moviegen/moviegen/dataset/dataset.py rename to examples/moviegen/mg/dataset/dataset.py diff --git a/examples/movie_gen/mg/datasets/tae_dataset.py b/examples/moviegen/mg/dataset/tae_dataset.py similarity index 100% rename from examples/movie_gen/mg/datasets/tae_dataset.py rename to examples/moviegen/mg/dataset/tae_dataset.py diff --git a/examples/moviegen/moviegen/dataset/transforms.py b/examples/moviegen/mg/dataset/transforms.py similarity index 100% rename from examples/moviegen/moviegen/dataset/transforms.py rename to examples/moviegen/mg/dataset/transforms.py diff --git a/examples/moviegen/moviegen/models/__init__.py b/examples/moviegen/mg/models/__init__.py similarity index 100% rename from examples/moviegen/moviegen/models/__init__.py rename to examples/moviegen/mg/models/__init__.py diff --git a/examples/moviegen/moviegen/models/llama/__init__.py b/examples/moviegen/mg/models/llama/__init__.py similarity index 100% rename from 
examples/moviegen/moviegen/models/llama/__init__.py rename to examples/moviegen/mg/models/llama/__init__.py diff --git a/examples/moviegen/moviegen/models/llama/activation.py b/examples/moviegen/mg/models/llama/activation.py similarity index 100% rename from examples/moviegen/moviegen/models/llama/activation.py rename to examples/moviegen/mg/models/llama/activation.py diff --git a/examples/moviegen/moviegen/models/llama/block.py b/examples/moviegen/mg/models/llama/block.py similarity index 99% rename from examples/moviegen/moviegen/models/llama/block.py rename to examples/moviegen/mg/models/llama/block.py index a53db68e03..474f3d821e 100644 --- a/examples/moviegen/moviegen/models/llama/block.py +++ b/examples/moviegen/mg/models/llama/block.py @@ -2,7 +2,7 @@ from typing import Optional, Sequence, Tuple, Union import numpy as np -from moviegen.parallel import ( +from mg.parallel import ( ColumnParallelLinear, FusedColumnParallelLinear, FusedRowParallelLinear, diff --git a/examples/moviegen/moviegen/models/llama/network.py b/examples/moviegen/mg/models/llama/network.py similarity index 99% rename from examples/moviegen/moviegen/models/llama/network.py rename to examples/moviegen/mg/models/llama/network.py index 8834b9e63e..5269f70491 100644 --- a/examples/moviegen/moviegen/models/llama/network.py +++ b/examples/moviegen/mg/models/llama/network.py @@ -3,8 +3,8 @@ from typing import Literal, Optional, Tuple, Union import numpy as np -from moviegen.parallel import GatherForwardSplitBackward, SplitForwardGatherBackward -from moviegen.parallel.parallel_states import get_model_parallel_group +from mg.parallel import GatherForwardSplitBackward, SplitForwardGatherBackward +from mg.parallel.parallel_states import get_model_parallel_group import mindspore as ms import mindspore.mint as mint diff --git a/examples/moviegen/mg/models/tae/__init__.py b/examples/moviegen/mg/models/tae/__init__.py new file mode 100644 index 0000000000..75d32e29fc --- /dev/null +++ b/examples/moviegen/mg/models/tae/__init__.py @@ -0,0 +1 @@ +from .tae import TemporalAutoencoder diff --git a/examples/movie_gen/mg/models/tae/losses.py b/examples/moviegen/mg/models/tae/losses.py similarity index 100% rename from examples/movie_gen/mg/models/tae/losses.py rename to examples/moviegen/mg/models/tae/losses.py diff --git a/examples/movie_gen/mg/models/tae/lpips.py b/examples/moviegen/mg/models/tae/lpips.py similarity index 100% rename from examples/movie_gen/mg/models/tae/lpips.py rename to examples/moviegen/mg/models/tae/lpips.py diff --git a/examples/movie_gen/mg/models/tae/modules.py b/examples/moviegen/mg/models/tae/modules.py similarity index 100% rename from examples/movie_gen/mg/models/tae/modules.py rename to examples/moviegen/mg/models/tae/modules.py diff --git a/examples/movie_gen/mg/models/tae/modules_2d.py b/examples/moviegen/mg/models/tae/modules_2d.py similarity index 100% rename from examples/movie_gen/mg/models/tae/modules_2d.py rename to examples/moviegen/mg/models/tae/modules_2d.py diff --git a/examples/movie_gen/mg/models/tae/sd3_vae.py b/examples/moviegen/mg/models/tae/sd3_vae.py similarity index 100% rename from examples/movie_gen/mg/models/tae/sd3_vae.py rename to examples/moviegen/mg/models/tae/sd3_vae.py diff --git a/examples/movie_gen/mg/models/tae/tae.py b/examples/moviegen/mg/models/tae/tae.py similarity index 100% rename from examples/movie_gen/mg/models/tae/tae.py rename to examples/moviegen/mg/models/tae/tae.py diff --git a/examples/moviegen/moviegen/models/text_encoders/__init__.py 
b/examples/moviegen/mg/models/text_encoders/__init__.py similarity index 100% rename from examples/moviegen/moviegen/models/text_encoders/__init__.py rename to examples/moviegen/mg/models/text_encoders/__init__.py diff --git a/examples/moviegen/moviegen/models/text_encoders/text_projector.py b/examples/moviegen/mg/models/text_encoders/text_projector.py similarity index 100% rename from examples/moviegen/moviegen/models/text_encoders/text_projector.py rename to examples/moviegen/mg/models/text_encoders/text_projector.py diff --git a/examples/moviegen/moviegen/parallel/__init__.py b/examples/moviegen/mg/parallel/__init__.py similarity index 100% rename from examples/moviegen/moviegen/parallel/__init__.py rename to examples/moviegen/mg/parallel/__init__.py diff --git a/examples/moviegen/moviegen/parallel/layers.py b/examples/moviegen/mg/parallel/layers.py similarity index 100% rename from examples/moviegen/moviegen/parallel/layers.py rename to examples/moviegen/mg/parallel/layers.py diff --git a/examples/moviegen/moviegen/parallel/parallel_states.py b/examples/moviegen/mg/parallel/parallel_states.py similarity index 100% rename from examples/moviegen/moviegen/parallel/parallel_states.py rename to examples/moviegen/mg/parallel/parallel_states.py diff --git a/examples/moviegen/moviegen/pipelines/__init__.py b/examples/moviegen/mg/pipelines/__init__.py similarity index 100% rename from examples/moviegen/moviegen/pipelines/__init__.py rename to examples/moviegen/mg/pipelines/__init__.py diff --git a/examples/moviegen/moviegen/pipelines/infer_pipeline.py b/examples/moviegen/mg/pipelines/infer_pipeline.py similarity index 100% rename from examples/moviegen/moviegen/pipelines/infer_pipeline.py rename to examples/moviegen/mg/pipelines/infer_pipeline.py diff --git a/examples/moviegen/moviegen/pipelines/train_pipeline.py b/examples/moviegen/mg/pipelines/train_pipeline.py similarity index 100% rename from examples/moviegen/moviegen/pipelines/train_pipeline.py rename to examples/moviegen/mg/pipelines/train_pipeline.py diff --git a/examples/moviegen/moviegen/schedulers/__init__.py b/examples/moviegen/mg/schedulers/__init__.py similarity index 100% rename from examples/moviegen/moviegen/schedulers/__init__.py rename to examples/moviegen/mg/schedulers/__init__.py diff --git a/examples/moviegen/moviegen/schedulers/rectified_flow.py b/examples/moviegen/mg/schedulers/rectified_flow.py similarity index 100% rename from examples/moviegen/moviegen/schedulers/rectified_flow.py rename to examples/moviegen/mg/schedulers/rectified_flow.py diff --git a/examples/moviegen/moviegen/utils/__init__.py b/examples/moviegen/mg/utils/__init__.py similarity index 100% rename from examples/moviegen/moviegen/utils/__init__.py rename to examples/moviegen/mg/utils/__init__.py diff --git a/examples/moviegen/moviegen/utils/callbacks.py b/examples/moviegen/mg/utils/callbacks.py similarity index 100% rename from examples/moviegen/moviegen/utils/callbacks.py rename to examples/moviegen/mg/utils/callbacks.py diff --git a/examples/moviegen/moviegen/utils/ema.py b/examples/moviegen/mg/utils/ema.py similarity index 100% rename from examples/moviegen/moviegen/utils/ema.py rename to examples/moviegen/mg/utils/ema.py diff --git a/examples/movie_gen/mg/utils/load_models.py b/examples/moviegen/mg/utils/load_models.py similarity index 100% rename from examples/movie_gen/mg/utils/load_models.py rename to examples/moviegen/mg/utils/load_models.py diff --git a/examples/moviegen/moviegen/utils/model_utils.py 
b/examples/moviegen/mg/utils/model_utils.py similarity index 97% rename from examples/moviegen/moviegen/utils/model_utils.py rename to examples/moviegen/mg/utils/model_utils.py index b139158036..91ef90d0fc 100644 --- a/examples/moviegen/moviegen/utils/model_utils.py +++ b/examples/moviegen/mg/utils/model_utils.py @@ -2,7 +2,7 @@ from typing import Dict, Literal, Optional, Union from jsonargparse.typing import Path_fr -from moviegen import LlamaModel, llama3_1B, llama3_5B, llama3_30B +from mg import LlamaModel, llama3_1B, llama3_5B, llama3_30B import mindspore as ms from mindspore import _no_grad, jit_class, nn diff --git a/examples/movie_gen/mg/utils/parser.py b/examples/moviegen/mg/utils/parser.py similarity index 100% rename from examples/movie_gen/mg/utils/parser.py rename to examples/moviegen/mg/utils/parser.py diff --git a/examples/moviegen/moviegen/utils/utils.py b/examples/moviegen/mg/utils/utils.py similarity index 100% rename from examples/moviegen/moviegen/utils/utils.py rename to examples/moviegen/mg/utils/utils.py diff --git a/examples/movie_gen/scripts/run/run_train_tae.sh b/examples/moviegen/scripts/run_train_tae.sh old mode 100755 new mode 100644 similarity index 95% rename from examples/movie_gen/scripts/run/run_train_tae.sh rename to examples/moviegen/scripts/run_train_tae.sh index f056ef4566..babd24a99e --- a/examples/movie_gen/scripts/run/run_train_tae.sh +++ b/examples/moviegen/scripts/run_train_tae.sh @@ -13,7 +13,7 @@ export GLOG_v=2 output_dir=outputs/debug_train_tae_1p_sd3.5vaeInit_noOpl -python scripts/train_tae.py \ +python train_tae.py \ --mode=0 \ --jit_level O0 \ --amp_level O0 \ diff --git a/examples/moviegen/tests/parallel/test_llama3_parallel.py b/examples/moviegen/tests/parallel/test_llama3_parallel.py index 08ec46c59d..542e44d007 100644 --- a/examples/moviegen/tests/parallel/test_llama3_parallel.py +++ b/examples/moviegen/tests/parallel/test_llama3_parallel.py @@ -2,8 +2,8 @@ from typing import Tuple import numpy as np -from moviegen.models.llama.network import LlamaModel -from moviegen.parallel import create_parallel_group +from mg.models.llama.network import LlamaModel +from mg.parallel import create_parallel_group from utils import gather_or_reduce_parallel_gradient import mindspore as ms diff --git a/examples/moviegen/tests/parallel/test_llama3_parallel_block.py b/examples/moviegen/tests/parallel/test_llama3_parallel_block.py index f9b3a765a8..82141a1d31 100644 --- a/examples/moviegen/tests/parallel/test_llama3_parallel_block.py +++ b/examples/moviegen/tests/parallel/test_llama3_parallel_block.py @@ -1,8 +1,8 @@ import argparse import numpy as np -from moviegen.models.llama.block import LlamaMLP, TensorParallelLlamaMLP -from moviegen.parallel import create_parallel_group +from mg.models.llama.block import LlamaMLP, TensorParallelLlamaMLP +from mg.parallel import create_parallel_group from utils import gather_or_reduce_parallel_gradient import mindspore as ms diff --git a/examples/moviegen/tests/parallel/test_llama3_parallel_layer.py b/examples/moviegen/tests/parallel/test_llama3_parallel_layer.py index a2e35a0576..a4c5afb140 100644 --- a/examples/moviegen/tests/parallel/test_llama3_parallel_layer.py +++ b/examples/moviegen/tests/parallel/test_llama3_parallel_layer.py @@ -2,7 +2,7 @@ from typing import Literal import numpy as np -from moviegen.parallel import ColumnParallelLinear, RowParallelLinear, create_parallel_group, get_model_parallel_group +from mg.parallel import ColumnParallelLinear, RowParallelLinear, create_parallel_group, 
get_model_parallel_group from utils import gather_or_reduce_parallel_gradient import mindspore as ms diff --git a/examples/moviegen/tests/parallel/test_rflow_parallel.py b/examples/moviegen/tests/parallel/test_rflow_parallel.py index eec776fba7..a3d6302e3a 100644 --- a/examples/moviegen/tests/parallel/test_rflow_parallel.py +++ b/examples/moviegen/tests/parallel/test_rflow_parallel.py @@ -1,8 +1,8 @@ import argparse from typing import Tuple -from moviegen.parallel import create_parallel_group -from moviegen.schedulers import RFlowLossWrapper +from mg.parallel import create_parallel_group +from mg.schedulers import RFlowLossWrapper import mindspore as ms from mindspore import Tensor, nn, ops diff --git a/examples/movie_gen/tests/test_gn.py b/examples/moviegen/tests/ut/test_gn.py similarity index 100% rename from examples/movie_gen/tests/test_gn.py rename to examples/moviegen/tests/ut/test_gn.py diff --git a/examples/moviegen/tests/ut/test_llama3_forward.py b/examples/moviegen/tests/ut/test_llama3_forward.py index f582962559..260e65cb04 100644 --- a/examples/moviegen/tests/ut/test_llama3_forward.py +++ b/examples/moviegen/tests/ut/test_llama3_forward.py @@ -1,5 +1,5 @@ import numpy as np -from moviegen import llama3_1B +from mg import llama3_1B import mindspore as ms diff --git a/examples/moviegen/tests/ut/test_rflow.py b/examples/moviegen/tests/ut/test_rflow.py index a5bc12da89..7bd3c1c8c0 100644 --- a/examples/moviegen/tests/ut/test_rflow.py +++ b/examples/moviegen/tests/ut/test_rflow.py @@ -1,5 +1,5 @@ import numpy as np -from moviegen.schedulers import RFlowLossWrapper +from mg.schedulers import RFlowLossWrapper import mindspore as ms import mindspore.nn as nn diff --git a/examples/movie_gen/tests/test_tae.py b/examples/moviegen/tests/ut/test_tae.py similarity index 98% rename from examples/movie_gen/tests/test_tae.py rename to examples/moviegen/tests/ut/test_tae.py index a426ed3654..45e24c4ec7 100644 --- a/examples/movie_gen/tests/test_tae.py +++ b/examples/moviegen/tests/ut/test_tae.py @@ -1,14 +1,13 @@ import numpy as np import sys from PIL import Image -sys.path.insert(0, '.') +sys.path.insert(0, '..') from mg.models.tae.modules import ( Conv2_5d, Decoder, Encoder, ResnetBlock, - SpatialAttnBlock, SpatialAttnBlockV2, SpatialDownsample, SpatialUpsample, @@ -17,7 +16,7 @@ TemporalUpsample, ) from mg.models.tae.tae import SDXL_CONFIG, TAE_CONFIG, TemporalAutoencoder -from mg.models.tae.sd3_vae import SD3d5_CONFIG, SD3d5_VAE +from mg.models.tae.sd3_vae import SD3d5_VAE import mindspore as ms diff --git a/examples/movie_gen/tools/inflate_vae_to_tae.py b/examples/moviegen/tools/inflate_vae_to_tae.py similarity index 99% rename from examples/movie_gen/tools/inflate_vae_to_tae.py rename to examples/moviegen/tools/inflate_vae_to_tae.py index f09bf4a2a4..6893fdc633 100644 --- a/examples/movie_gen/tools/inflate_vae_to_tae.py +++ b/examples/moviegen/tools/inflate_vae_to_tae.py @@ -1,6 +1,5 @@ from safetensors import safe_open import argparse -import os import numpy as np import mindspore as ms diff --git a/examples/movie_gen/tools/ms_pnames_sd3.5_vae.txt b/examples/moviegen/tools/ms_pnames_sd3.5_vae.txt similarity index 100% rename from examples/movie_gen/tools/ms_pnames_sd3.5_vae.txt rename to examples/moviegen/tools/ms_pnames_sd3.5_vae.txt diff --git a/examples/movie_gen/tools/ms_pnames_tae_vae.txt b/examples/moviegen/tools/ms_pnames_tae_vae.txt similarity index 100% rename from examples/movie_gen/tools/ms_pnames_tae_vae.txt rename to examples/moviegen/tools/ms_pnames_tae_vae.txt diff --git 
a/examples/movie_gen/tools/pt_pnames_sd3.5_vae.txt b/examples/moviegen/tools/pt_pnames_sd3.5_vae.txt similarity index 100% rename from examples/movie_gen/tools/pt_pnames_sd3.5_vae.txt rename to examples/moviegen/tools/pt_pnames_sd3.5_vae.txt diff --git a/examples/moviegen/train.py b/examples/moviegen/train.py index 6204c10dae..1596708567 100644 --- a/examples/moviegen/train.py +++ b/examples/moviegen/train.py @@ -10,19 +10,20 @@ from mindspore import dtype as mstype from mindspore import get_context, nn, set_seed from mindspore.dataset import BatchDataset, BucketBatchByLengthDataset -from mindspore.train.callback import TimeMonitor # TODO: remove in future when mindone is ready for install __dir__ = os.path.dirname(os.path.abspath(__file__)) mindone_lib_path = os.path.abspath(os.path.join(__dir__, "../../")) sys.path.append(mindone_lib_path) -from moviegen.dataset import ImageVideoDataset, bucket_split_function -from moviegen.parallel import create_parallel_group -from moviegen.pipelines import DiffusionWithLoss -from moviegen.schedulers import RFlowEvalLoss, RFlowLossWrapper -from moviegen.utils import EMA, MODEL_DTYPE, init_model -from moviegen.utils.callbacks import PerfRecorderCallback, ReduceLROnPlateauByStep, ValidationCallback +from mg.dataset import ImageVideoDataset, bucket_split_function +from mg.models.tae import TemporalAutoencoder +from mg.models.tae.modules import SpatialDownsample, SpatialUpsample, TemporalDownsample, TemporalUpsample +from mg.parallel import create_parallel_group +from mg.pipelines import DiffusionWithLoss +from mg.schedulers import RFlowEvalLoss, RFlowLossWrapper +from mg.utils import EMA, MODEL_DTYPE, init_model +from mg.utils.callbacks import PerfRecorderCallback, ReduceLROnPlateauByStep, ValidationCallback from mindone.data import create_dataloader from mindone.trainers import create_optimizer, create_scheduler @@ -30,11 +31,6 @@ from mindone.trainers.zero import prepare_train_network from mindone.utils import count_params, init_train_env, set_logger -# TODO: remove when TAE is added to the project -sys.path.append(os.path.join(__dir__, "../movie_gen/")) -from mg.models.tae.modules import SpatialDownsample, SpatialUpsample, TemporalDownsample, TemporalUpsample -from mg.models.tae.tae import TemporalAutoencoder - logger = logging.getLogger(__name__) @@ -169,7 +165,6 @@ def main(args): if rank_id == 0: callbacks.extend( [ - TimeMonitor(args.train.save.log_interval), EvalSaveCallback( network=latent_diffusion_with_loss.network, model_name=args.model.name, diff --git a/examples/movie_gen/scripts/train_tae.py b/examples/moviegen/train_tae.py similarity index 99% rename from examples/movie_gen/scripts/train_tae.py rename to examples/moviegen/train_tae.py index 44152782b6..0496157d0f 100644 --- a/examples/movie_gen/scripts/train_tae.py +++ b/examples/moviegen/train_tae.py @@ -19,10 +19,10 @@ sys.path.insert(0, os.path.abspath(os.path.join(__dir__, ".."))) from args_train_tae import parse_args -from mg.datasets.tae_dataset import create_dataloader +from mg.dataset.tae_dataset import create_dataloader from mg.models.tae.losses import GeneratorWithLoss from mg.models.tae.tae import TemporalAutoencoder -from mg.models.tae.modules import SpatialUpsample, SpatialDownsample, TemporalUpsample, TemporalDownsample +from mg.models.tae.modules import SpatialUpsample, SpatialDownsample, TemporalUpsample, TemporalDownsample from mindone.trainers.callback import EvalSaveCallback, OverflowMonitor, ProfilerCallback from mindone.trainers.checkpoint import CheckpointManager, 
resume_train_network From 5ef55ef8800a78cbee46c4792e311ce0aa90dc4a Mon Sep 17 00:00:00 2001 From: Rustam Khadipash <16683750+hadipash@users.noreply.github.com> Date: Thu, 28 Nov 2024 16:37:06 +0800 Subject: [PATCH 075/122] linting --- examples/movie_gen/README.md | 13 +- examples/moviegen/args_train_tae.py | 4 +- examples/moviegen/inference_vae.py | 16 +- examples/moviegen/mg/models/tae/losses.py | 12 +- examples/moviegen/mg/models/tae/modules.py | 240 +++++++++++------- examples/moviegen/mg/models/tae/sd3_vae.py | 30 +-- examples/moviegen/mg/utils/parser.py | 1 - examples/moviegen/scripts/run_train_tae.sh | 1 - examples/moviegen/tests/ut/test_tae.py | 38 ++- examples/moviegen/tools/inflate_vae_to_tae.py | 23 +- examples/moviegen/train_tae.py | 15 +- .../opensora_hpcai/tools/mem_monitor/plot.py | 2 - 12 files changed, 227 insertions(+), 168 deletions(-) diff --git a/examples/movie_gen/README.md b/examples/movie_gen/README.md index 32510c2f30..bb5ae59175 100644 --- a/examples/movie_gen/README.md +++ b/examples/movie_gen/README.md @@ -16,7 +16,7 @@ We use SD3.5 VAE to initialize the spatial layers of TAE, since both have a late 2. Convert VAE checkpoint for TAE loading ```shell -python inflate_vae_to_tae.py --src /path/to/sd3.5_vae/diffusion_pytorch_model.safetensors --target models/tae_vae2d.ckpt +python inflate_vae_to_tae.py --src /path/to/sd3.5_vae/diffusion_pytorch_model.safetensors --target models/tae_vae2d.ckpt ``` @@ -33,7 +33,7 @@ python scripts/train_tae.py \ ``` -OPL - outlier penality loss is found to be not beneficial in our experiment (PSNR decreased). Thus we set it to False by default. +OPL - outlier penality loss is found to be not beneficial in our experiment (PSNR decreased). Thus we set it to False by default. Change mixed_256x256x16.yaml to mixed_256x256x32.yaml for training on 32 frames. @@ -63,7 +63,7 @@ python scripts/inference_vae.py \ --enable_tile=False \ ``` -#### Encoding video +#### Encoding video ```python from mg.models.tae.tae import TemporalAutoencoder, TAE_CONFIG @@ -79,7 +79,7 @@ z, _, _ = tae.encode(x) # you may scale z by: -# z = TAE_CONFIG['scaling_factor'] * z + TAE_CONFIG['shift_factor'] +# z = TAE_CONFIG['scaling_factor'] * z + TAE_CONFIG['shift_factor'] ``` @@ -91,12 +91,11 @@ For detailed arguments, please refer to the docstring in [tae.py](mg/models/tae/ ```python # if z is scaled, you should unscale at first: -# z = (z - TAE_CONFIG['shift_factor']) / TAE_CONFIG['scaling_factor'] +# z = (z - TAE_CONFIG['shift_factor']) / TAE_CONFIG['scaling_factor'] # z - a batch of video latent, shape (b c t h w) x = tae.decode(z) -# for image decoding, set num_target_frames to discard the spurious frames +# for image decoding, set num_target_frames to discard the spurious frames x = tae.decode(z, num_target_frames=1) ``` - diff --git a/examples/moviegen/args_train_tae.py b/examples/moviegen/args_train_tae.py index e06e73a068..8b1634613d 100644 --- a/examples/moviegen/args_train_tae.py +++ b/examples/moviegen/args_train_tae.py @@ -215,7 +215,9 @@ def parse_train_args(parser): parser.add_argument( "--sd_scale_factor", type=float, default=0.18215, help="VAE scale factor of Stable Diffusion model." 
) - parser.add_argument("--image_size", default=256, type=int, nargs="+", help="image size for resizing the input image") + parser.add_argument( + "--image_size", default=256, type=int, nargs="+", help="image size for resizing the input image" + ) parser.add_argument("--crop_size", default=256, type=int, help="crop size after resize") parser.add_argument("--num_frames", default=16, type=int, help="the num of frames used to initiate model") parser.add_argument("--frame_stride", default=3, type=int, help="frame sampling stride") diff --git a/examples/moviegen/inference_vae.py b/examples/moviegen/inference_vae.py index 3c30ca67bd..59b0e54e22 100644 --- a/examples/moviegen/inference_vae.py +++ b/examples/moviegen/inference_vae.py @@ -13,7 +13,6 @@ from mindspore import nn, ops - __dir__ = os.path.dirname(os.path.abspath(__file__)) mindone_dir = os.path.abspath(os.path.join(__dir__, "../../../")) sys.path.insert(0, mindone_dir) @@ -31,8 +30,8 @@ sys.path.insert(0, os.path.abspath(os.path.join(__dir__, ".."))) from mg.dataset.tae_dataset import create_dataloader -from mg.models.tae.tae import TemporalAutoencoder from mg.models.tae.lpips import LPIPS +from mg.models.tae.tae import TemporalAutoencoder from mindone.utils.amp import auto_mixed_precision from mindone.utils.config import str2bool @@ -84,6 +83,7 @@ def rearrange_in(x): x = ops.reshape(x, (b * t, h, w, c)) return x + def rearrange_out(x, t): bt, c, h, w = x.shape b = bt // t @@ -101,13 +101,13 @@ def main(args): model = TemporalAutoencoder( pretrained=args.ckpt_path, use_tile=args.enable_tile, - ) + ) model.set_train(False) logger.info(f"Loaded checkpoint from {args.ckpt_path}") if args.eval_loss: - lpips_loss_fn = LPIPS() + lpips_loss_fn = LPIPS() if args.dtype != "fp32": amp_level = "O2" @@ -164,7 +164,7 @@ def main(args): for step, data in tqdm(enumerate(ds_iter)): x = data["video"] start_time = time.time() - + if args.encode_only: z = model.encode(x) else: @@ -203,7 +203,7 @@ def main(args): if args.eval_loss: recon_loss = np.abs((x - recons).asnumpy()) - + t = x.shape[2] x = rearrange_in(x) # lpips_loss = lpips_loss_fn(x, recons).asnumpy() @@ -284,7 +284,9 @@ def parse_args(): parser.add_argument("--save_vis", default=True, type=str2bool, help="whether save reconstructed images") parser.add_argument("--use_temporal_vae", default=True, type=str2bool, help="if False, just use spatial vae") parser.add_argument("--encode_only", default=False, type=str2bool, help="only encode to save z or distribution") - parser.add_argument("--enable_tile", default=False, type=str2bool, help="enable temporal tiling with linear blending for decoder") + parser.add_argument( + "--enable_tile", default=False, type=str2bool, help="enable temporal tiling with linear blending for decoder" + ) parser.add_argument("--video_column", default="video", type=str, help="name of column for videos saved in csv file") parser.add_argument( "--mixed_strategy", diff --git a/examples/moviegen/mg/models/tae/losses.py b/examples/moviegen/mg/models/tae/losses.py index 181afbd7c1..cd41cb8be4 100644 --- a/examples/moviegen/mg/models/tae/losses.py +++ b/examples/moviegen/mg/models/tae/losses.py @@ -52,7 +52,15 @@ def kl(self, mean, logvar): return kl_loss def vae_loss_fn( - self, x, recons, mean, logvar, nll_weights=None, no_perceptual=False, no_kl=False, pixelwise_mean=False, + self, + x, + recons, + mean, + logvar, + nll_weights=None, + no_perceptual=False, + no_kl=False, + pixelwise_mean=False, ): """ return: @@ -111,8 +119,6 @@ def construct(self, x: ms.Tensor, global_step: 
ms.Tensor = -1, weights: ms.Tenso posterior_logvar.to(ms.float32), ) - frames = x.shape[2] - # Loss compute # video frames x reconstruction loss # TODO: loss dtype setting diff --git a/examples/moviegen/mg/models/tae/modules.py b/examples/moviegen/mg/models/tae/modules.py index e09e037d27..cd2aaa0a0b 100644 --- a/examples/moviegen/mg/models/tae/modules.py +++ b/examples/moviegen/mg/models/tae/modules.py @@ -1,11 +1,10 @@ import logging -import functools import numpy as np from packaging import version import mindspore as ms -from mindspore import nn, ops, lazy_inline +from mindspore import nn, ops _logger = logging.getLogger(__name__) @@ -25,6 +24,7 @@ def cast_tuple(t, length=1): def nonlinearity(x): return x * (ops.sigmoid(x)) + class GroupNorm5d(nn.GroupNorm): def construct(self, x): # x (b c t h w) @@ -42,6 +42,7 @@ def construct(self, x): return out + def Normalize(in_channels, num_groups=32): if version.parse(ms.__version__) >= version.parse("2.3.1"): return nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) @@ -52,29 +53,32 @@ def Normalize(in_channels, num_groups=32): def rearrange_in_spatial(x): # (b t c h w) -> (b*t c h w) B, T, C, H, W = x.shape - x = ops.reshape(x, (B*T, C, H, W)) + x = ops.reshape(x, (B * T, C, H, W)) return x + def rearrange_out_spatial(x, T): # (b*t c h w) -> (b t c h w) BT, C, H, W = x.shape - x = ops.reshape(x, (BT//T, T, C, H, W)) + x = ops.reshape(x, (BT // T, T, C, H, W)) return x + def rearrange_in_temporal(x): # (b t c h w) -> (b*h*w c t) B, C, T, H, W = x.shape # (b t c h w) -> (b h w c t) x = ops.transpose(x, (0, 3, 4, 2, 1)) # (b h w c t) -> (b*h*w c t) - x = ops.reshape(x, (B*H*W, C, T)) + x = ops.reshape(x, (B * H * W, C, T)) return x + def rearrange_out_temporal(x, H, W): # (b*h*w c t) -> (b t c h w) BHW, C, T = x.shape # (b*h*w c t) -> (b h w c t) - x = ops.reshape(x, (BHW // (H*W), H, W, C, T)) + x = ops.reshape(x, (BHW // (H * W), H, W, C, T)) # (b h w c t) -> (b t c h w) x = ops.transpose(x, (0, 4, 3, 1, 2)) return x @@ -84,7 +88,9 @@ class TemporalConv1d(nn.Cell): r""" Temporal conv1d with symmetrical replicate padding """ - def __init__(self, + + def __init__( + self, in_channels, out_channels, kernel_size, @@ -94,12 +100,14 @@ def __init__(self, # dilation=1, has_bias=True, **kwargs, - ): + ): # assert dilation ==1 - assert stride == 1, 'not supported for stride > 1' + assert stride == 1, "not supported for stride > 1" # TODO; consider stride - self.pad = nn.Pad(paddings=((2, (kernel_size-1)//2)), mode="SYMMETRIC") - self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, pad_mode="valid", has_bias=True) + self.pad = nn.Pad(paddings=((2, (kernel_size - 1) // 2)), mode="SYMMETRIC") + self.conv = nn.Conv1d( + in_channels, out_channels, kernel_size=kernel_size, stride=stride, pad_mode="valid", has_bias=True + ) def construct(self, x): r""" @@ -123,6 +131,7 @@ class Conv2_5d(nn.Cell): r""" Conv2.5d, a 2D spatial convolution followed by 1D temporal convolution """ + def __init__( self, in_channels, @@ -136,27 +145,51 @@ def __init__( **kwargs, ): super().__init__() - assert stride==1 - assert dilation==1 + assert stride == 1 + assert dilation == 1 # spatial conv - self.conv_spat = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, pad_mode=pad_mode, padding=padding, has_bias=has_bias) + self.conv_spat = nn.Conv2d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + pad_mode=pad_mode, + padding=padding, + has_bias=has_bias, + ) 
# temp_pad_mode = 'zero' # temp_pad = 'mint_rep' - temp_pad = 'manual' + # temp_pad = "manual" # temporal conv if kernel_size > 1: # symmetric padding + conv1d - assert kernel_size == 3, 'symmetric padding currently only support kernel size 3' - self.conv_temp = nn.Conv1d(out_channels, out_channels, kernel_size=kernel_size, stride=stride, pad_mode="valid", has_bias=has_bias, bias_init='zeros') + assert kernel_size == 3, "symmetric padding currently only support kernel size 3" + self.conv_temp = nn.Conv1d( + out_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + pad_mode="valid", + has_bias=has_bias, + bias_init="zeros", + ) self.pad = self.symmetric_pad1d self.use_pad = True else: self.use_pad = False - self.conv_temp = nn.Conv1d(out_channels, out_channels, kernel_size=kernel_size, stride=stride, pad_mode="valid", has_bias=has_bias, bias_init='zeros') + self.conv_temp = nn.Conv1d( + out_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + pad_mode="valid", + has_bias=has_bias, + bias_init="zeros", + ) - self.init_temporal_weight('median') + self.init_temporal_weight("median") @staticmethod def symmetric_pad1d(x): @@ -168,13 +201,12 @@ def symmetric_pad1d(x): return x def construct(self, x): - ''' + """ Parameters: x: (b c t h w) Returns: (b c t h w) - ''' - + """ B, Ci, T, Hi, Wi = x.shape # (b c t h w) -> (b t c h w) @@ -182,7 +214,7 @@ def construct(self, x): # spatial conv2d # (b t c h w) -> (b*t c h w) - x = ops.reshape(x, (B*T, Ci, Hi, Wi)) + x = ops.reshape(x, (B * T, Ci, Hi, Wi)) x = self.conv_spat(x) @@ -193,7 +225,7 @@ def construct(self, x): # temporal conv1d # (b t c h w) -> (b*h*w c t) x = ops.transpose(x, (0, 3, 4, 2, 1)) # (b t c h w) -> (b h w c t) - x = ops.reshape(x, (B*Ho*Wo, Co, T)) + x = ops.reshape(x, (B * Ho * Wo, Co, T)) if self.use_pad: # import pdb; pdb.set_trace() @@ -211,11 +243,11 @@ def construct(self, x): return x - def init_temporal_weight(self, method='median'): - if method == 'normal': + def init_temporal_weight(self, method="median"): + if method == "normal": return - elif method == 'median': + elif method == "median": # temporal conv kernel: (cout, cin, 1, ks) # ks=1 or 3, cin == cout w = self.conv_temp.weight @@ -225,7 +257,7 @@ def init_temporal_weight(self, method='median'): # only the middle element of the kernel is 1 so that the output is the same input in initialization for i in range(ch): - value[i, i, 0, ks//2] = 1 + value[i, i, 0, ks // 2] = 1 w.set_data(ms.Tensor(value, dtype=ms.float32)) # bias is initialized to zero in layer def @@ -243,15 +275,15 @@ def __init__(self, in_channels, with_conv): ) def construct(self, x): - ''' + """ x: (b c t h w) return: (b c t h w) - ''' + """ B, Ci, T, Hi, Wi = x.shape # (b c t h w) -> (b t c h w) x = ops.transpose(x, (0, 2, 1, 3, 4)) # (b t c h w) -> (b*t c h w) - x = ops.reshape(x, (B*T, Ci, Hi, Wi)) + x = ops.reshape(x, (B * T, Ci, Hi, Wi)) in_shape = x.shape[-2:] out_shape = tuple(2 * x for x in in_shape) @@ -284,7 +316,7 @@ def construct(self, x): # TODO: reduce transpose and reshape op B, C, T, H, W = x.shape x = ops.transpose(x, (0, 2, 1, 3, 4)) - x = ops.reshape(x, (B*T, C, H, W)) + x = ops.reshape(x, (B * T, C, H, W)) if self.with_conv: x = self.pad(x) @@ -306,41 +338,47 @@ def __init__(self, in_channels): self.ks = 3 self.ch = in_channels self.conv = nn.Conv1d( - in_channels, in_channels, kernel_size=self.ks, stride=2, pad_mode="valid", padding=0, has_bias=True, bias_init='zeros', + in_channels, + in_channels, + kernel_size=self.ks, + stride=2, + 
pad_mode="valid", + padding=0, + has_bias=True, + bias_init="zeros", ) # tail padding, pad with last frame self.time_pad = self.ks - 1 self.init_weight("median") - def init_weight(self, method='mean'): - if method == 'normal': + def init_weight(self, method="mean"): + if method == "normal": # default conv init return - + # no way to reserve complete input since stride 2 w = self.conv.weight value = np.zeros(tuple(w.shape)) - if method == 'mean': + if method == "mean": # initially, it's a mean filter for temporal downsampling for i in range(self.ch): - value[i, i, 0, :] = 1/self.ks # (cout, cin, 1, ks) - elif method == 'median': + value[i, i, 0, :] = 1 / self.ks # (cout, cin, 1, ks) + elif method == "median": # a median filter for temporal downsampling for i in range(self.ch): - value[i, i, 0, self.ks//2] = 1 # (cout, cin, 1, ks) + value[i, i, 0, self.ks // 2] = 1 # (cout, cin, 1, ks) else: raise NotImplementedError w.set_data(ms.Tensor(value, dtype=ms.float32)) - def construct(self, x): # x (b c t h w) # -> (bhw c t) B, C, T, H, W = x.shape x = ops.transpose(x, (0, 3, 4, 1, 2)) - x = ops.reshape(x, (B*H*W, C, T)) + x = ops.reshape(x, (B * H * W, C, T)) # tail padding last_frame = x[:, :, -1:] @@ -363,26 +401,29 @@ def __init__(self, in_channels, manual_pad=True): # to support danamic shape in graph mode self.manual_pad = manual_pad if not self.manual_pad: - self.conv = nn.Conv1d(in_channels, in_channels, kernel_size=3, stride=1, pad_mode="same", has_bias=True, bias_init='zeros') + self.conv = nn.Conv1d( + in_channels, in_channels, kernel_size=3, stride=1, pad_mode="same", has_bias=True, bias_init="zeros" + ) else: - self.conv = nn.Conv1d(in_channels, in_channels, kernel_size=3, stride=1, pad_mode="valid", has_bias=True, bias_init='zeros') + self.conv = nn.Conv1d( + in_channels, in_channels, kernel_size=3, stride=1, pad_mode="valid", has_bias=True, bias_init="zeros" + ) - # TODO: init conv weight so that it pass in image mode self.ch = in_channels - self.init_weight('median') + self.init_weight("median") - def init_weight(self, method='median'): - if method == 'normal': + def init_weight(self, method="median"): + if method == "normal": return # init so that the output is the same as vae2d for image input w = self.conv.weight value = np.zeros(tuple(w.shape)) - if method == 'median': + if method == "median": # consider image input, make sure it's the same for i in range(self.ch): - value[i, i, 0, 1] = 1 # (cout, cin, 1, ks) + value[i, i, 0, 1] = 1 # (cout, cin, 1, ks) w.set_data(ms.Tensor(value, dtype=ms.float32)) else: raise NotImplementedError @@ -390,8 +431,8 @@ def init_weight(self, method='median'): def construct(self, x): # x (b c t h w) B, C, T0, H, W = x.shape - x = ops.reshape(x, (B, C, T0, H*W)) - + x = ops.reshape(x, (B, C, T0, H * W)) + # NOTE: bf16 only support 4D interpolate # x = ops.interpolate(x, scale_factor=(2.0, 1.0), mode="nearest") out_shape = (T0 * 2, H * W) @@ -400,12 +441,12 @@ def construct(self, x): # x (b c t hw) -> (bhw c t) T = T0 * 2 x = ops.transpose(x, (0, 3, 1, 2)) - x = ops.reshape(x, (B*H*W, C, T)) - + x = ops.reshape(x, (B * H * W, C, T)) + if self.manual_pad: # work with pad_mode = valid, kernel_size=1 - pad_t_l = ops.zeros((B*H*W, C, 1), x.dtype) - pad_t_r = ops.zeros((B*H*W, C, 1), x.dtype) + pad_t_l = ops.zeros((B * H * W, C, 1), x.dtype) + pad_t_r = ops.zeros((B * H * W, C, 1), x.dtype) x = ops.cat([pad_t_l, x, pad_t_r], 2) x = self.conv(x) @@ -416,7 +457,7 @@ def construct(self, x): return x - ''' + """ def construct(self, x): # x (b c t h w) x = 
ops.interpolate(x, scale_factor=(2.0, 1.0, 1.0), mode="nearest") @@ -433,7 +474,7 @@ def construct(self, x): x = ops.transpose(x, (0, 3, 4, 1, 2)) return x - ''' + """ # used in vae @@ -451,11 +492,17 @@ def __init__( self.in_channels = in_channels out_channels = in_channels if out_channels is None else out_channels self.out_channels = out_channels - assert conv_shortcut==False + assert not conv_shortcut self.norm1 = Normalize(in_channels) self.conv1 = Conv2_5d( - in_channels, out_channels, kernel_size=3, stride=1, pad_mode="pad", padding=1, has_bias=True, + in_channels, + out_channels, + kernel_size=3, + stride=1, + pad_mode="pad", + padding=1, + has_bias=True, ) if temb_channels > 0: @@ -463,7 +510,13 @@ def __init__( self.norm2 = Normalize(out_channels) self.dropout = nn.Dropout(p=dropout) self.conv2 = Conv2_5d( - out_channels, out_channels, kernel_size=3, stride=1, pad_mode="pad", padding=1, has_bias=True, + out_channels, + out_channels, + kernel_size=3, + stride=1, + pad_mode="pad", + padding=1, + has_bias=True, ) if self.in_channels != self.out_channels: # TODO: @@ -502,7 +555,7 @@ def __init__(self, in_channels): self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, pad_mode="valid", has_bias=True) self.hidden_dim = in_channels - self.scale = ms.Tensor(self.hidden_dim**(-0.5), dtype=ms.float32) + self.scale = ms.Tensor(self.hidden_dim ** (-0.5), dtype=ms.float32) def construct(self, x): # x (b c t h w) @@ -512,7 +565,7 @@ def construct(self, x): # rearrange to spatial sequence (b c t h w) -> (bt c h w) T = x.shape[2] h_ = ops.transpose(h_, (0, 2, 1, 3, 4)) - h_ = ops.reshape(h_, (h_.shape[0]*h_.shape[1], h_.shape[2], h_.shape[3], h_.shape[4] )) + h_ = ops.reshape(h_, (h_.shape[0] * h_.shape[1], h_.shape[2], h_.shape[3], h_.shape[4])) q = self.q(h_) k = self.k(h_) @@ -539,7 +592,7 @@ def construct(self, x): # rearrange back # -> (b t c h w) - h_ = ops.reshape(h_, (b//T, T, c, h, w)) + h_ = ops.reshape(h_, (b // T, T, c, h, w)) h_ = ops.transpose(h_, (0, 2, 1, 3, 4)) return x + h_ @@ -558,7 +611,7 @@ def __init__(self, in_channels): self.v = nn.Dense(in_channels, in_channels, has_bias=True) self.proj_out = nn.Dense(in_channels, in_channels, has_bias=True) - self.scale = ms.Tensor(in_channels**(-0.5), dtype=ms.float32) # hidden_dim = in_channels + self.scale = ms.Tensor(in_channels ** (-0.5), dtype=ms.float32) # hidden_dim = in_channels def construct(self, x): # x (b c t h w) @@ -568,7 +621,7 @@ def construct(self, x): # rearrange h_ to (b*t h*w c) B, C, T, H, W = h_.shape h_ = ops.transpose(h_, (0, 2, 3, 4, 1)) - h_ = ops.reshape(h_, (B*T, H*W, C)) + h_ = ops.reshape(h_, (B * T, H * W, C)) q = self.q(h_) k = self.k(h_) @@ -583,7 +636,7 @@ def construct(self, x): attn = ops.softmax(m, axis=-1).astype(v.dtype) # (bt nq nk) # attend to values (nk = nv) - h_ = self.bmm(attn, v) # (bt nq c) = (bt hw c) + h_ = self.bmm(attn, v) # (bt nq c) = (bt hw c) h_ = self.proj_out(h_) # rearrange back to input shape @@ -601,12 +654,12 @@ def __init__(self, in_channels, has_bias=True): # TODO: instead of GroupNorm, LayerNorm is better for tiling self.norm = Normalize(in_channels) # TODO: compare conv1d with Dense on performance - self.to_q = nn.Dense(in_channels, in_channels, has_bias=has_bias) - self.to_k = nn.Dense(in_channels, in_channels, has_bias=has_bias) - self.to_v = nn.Dense(in_channels, in_channels, has_bias=has_bias) + self.to_q = nn.Dense(in_channels, in_channels, has_bias=has_bias) + self.to_k = nn.Dense(in_channels, in_channels, has_bias=has_bias) + self.to_v = 
nn.Dense(in_channels, in_channels, has_bias=has_bias) self.proj_out = nn.Dense(in_channels, in_channels, has_bias=has_bias) - self.scale = ms.Tensor(in_channels**(-0.5), dtype=ms.float32) # hidden_dim = in_channels + self.scale = ms.Tensor(in_channels ** (-0.5), dtype=ms.float32) # hidden_dim = in_channels def construct(self, x): # x (b c t h w) @@ -617,7 +670,7 @@ def construct(self, x): # (b c t h w) -> (b*h*w t c) = (B S H) B, C, T, H, W = h_.shape h_ = ops.transpose(h_, (0, 3, 4, 2, 1)) - h_ = ops.reshape(h_, (B*H*W, T, C)) + h_ = ops.reshape(h_, (B * H * W, T, C)) # projection q = self.to_q(h_) # (bhw t c) @@ -650,17 +703,18 @@ def make_attn(in_channels, attn_type="vanilla"): print(f"making attention of type '{attn_type}' with {in_channels} in_channels") if attn_type == "vanilla": return nn.SequentialCell( - SpatialAttnBlock(in_channels), - TemporalAttnBlock(in_channels), - ) - elif attn_type == 'spat_only': + SpatialAttnBlock(in_channels), + TemporalAttnBlock(in_channels), + ) + elif attn_type == "spat_only": # to ensure naming consistency return nn.SequentialCell( - SpatialAttnBlock(in_channels), - ) + SpatialAttnBlock(in_channels), + ) else: raise NotImplementedError + # used in vae class Encoder(nn.Cell): # @lazy_inline() @@ -693,7 +747,13 @@ def __init__( # downsampling self.conv_in = Conv2_5d( - in_channels, ch, kernel_size=3, stride=1, pad_mode='pad', padding=1, has_bias=True, + in_channels, + ch, + kernel_size=3, + stride=1, + pad_mode="pad", + padding=1, + has_bias=True, ) curr_res = resolution @@ -760,12 +820,12 @@ def __init__( ) def construct(self, x): - ''' + """ Args: x: (b c t h w) Returns: (b c t h w) - ''' + """ # spatial and temporal conv hs = self.conv_in(x) @@ -812,7 +872,7 @@ def __init__( tanh_out=False, use_linear_attn=False, attn_type="vanilla", - temporal_upsample_level=(1,2,3), # same as spatial + temporal_upsample_level=(1, 2, 3), # same as spatial **ignorekwargs, ): super().__init__() @@ -834,19 +894,13 @@ def __init__( _logger.debug("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) # z to block_in - self.conv_in = Conv2_5d( - z_channels, block_in, kernel_size=3, stride=1, pad_mode="pad", padding=1, has_bias=True - ) + self.conv_in = Conv2_5d(z_channels, block_in, kernel_size=3, stride=1, pad_mode="pad", padding=1, has_bias=True) # middle self.mid = nn.Cell() - self.mid.block_1 = ResnetBlock( - in_channels=block_in, out_channels=block_in, dropout=dropout - ) + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout) self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) - self.mid.block_2 = ResnetBlock( - in_channels=block_in, out_channels=block_in, dropout=dropout - ) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout) # upsampling self.up = nn.CellList(auto_prefix=False) @@ -890,12 +944,12 @@ def __init__( self.conv_out = Conv2_5d(block_in, out_ch, kernel_size=3, stride=1, pad_mode="pad", padding=1, has_bias=True) def construct(self, z): - ''' + """ Args: x: (b c t h w) Returns: (b c t h w) - ''' + """ # z to block_in h = self.conv_in(z) diff --git a/examples/moviegen/mg/models/tae/sd3_vae.py b/examples/moviegen/mg/models/tae/sd3_vae.py index 37e337e5c5..0b1b66e36f 100644 --- a/examples/moviegen/mg/models/tae/sd3_vae.py +++ b/examples/moviegen/mg/models/tae/sd3_vae.py @@ -1,6 +1,7 @@ import mindspore as ms from mindspore import nn, ops -from .modules_2d import Encoder, Decoder + +from .modules_2d import Decoder, Encoder # TODO: set 
z_channels to 16 SD3d5_CONFIG = { @@ -17,7 +18,7 @@ "scaling_factor": 1.5305, "shift_factor": 0.0609, "use_post_quant_conv": False, - "use_quant_conv": False + "use_quant_conv": False, } @@ -34,8 +35,8 @@ def __init__( self, config: dict = SD3d5_CONFIG, pretrained: str = None, - use_recompute: bool=False, - sample_deterministic: bool=False, + use_recompute: bool = False, + sample_deterministic: bool = False, ): super().__init__() @@ -43,14 +44,14 @@ def __init__( self.encoder = Encoder(**config) # quant and post quant - embed_dim = config['z_channels'] - if config['use_quant_conv']: + embed_dim = config["z_channels"] + if config["use_quant_conv"]: self.quant_conv = nn.Conv2d(2 * embed_dim, 2 * embed_dim, 1, pad_mode="valid", has_bias=True) - if config['use_post_quant_conv']: + if config["use_post_quant_conv"]: self.post_quant_conv = nn.Conv2d(embed_dim, embed_dim, 1, pad_mode="valid", has_bias=True) - self.use_quant_conv = config['use_quant_conv'] - self.use_post_quant_conv = config['use_post_quant_conv'] + self.use_quant_conv = config["use_quant_conv"] + self.use_post_quant_conv = config["use_post_quant_conv"] # decoder self.decoder = Decoder(**config) @@ -67,7 +68,6 @@ def __init__( # self.recompute(self.post_quant_conv) self.recompute(self.decoder) - def recompute(self, b): if not b._has_config_recompute: b.recompute() @@ -76,7 +76,6 @@ def recompute(self, b): else: b.add_flags(output_no_recompute=True) - def _encode(self, x): # return latent distribution, N(mean, logvar) h = self.encoder(x) @@ -124,10 +123,11 @@ def construct(self, x: ms.Tensor) -> ms.Tensor: return recons, z, posterior_mean, posterior_logvar - def load_pretrained(self, ckpt_path:str): - if ckpt_path.endswith('safetensors'): + def load_pretrained(self, ckpt_path: str): + if ckpt_path.endswith("safetensors"): # load vae parameters from safetensors into my mindspore model import safetensors + ckpt = safetensors.safe_open(ckpt_path, framework="pt") state_dict = {} for key in ckpt.keys(): @@ -138,6 +138,4 @@ def load_pretrained(self, ckpt_path:str): param_not_load, ckpt_not_load = ms.load_param_into_net(self, param_dict, strict_load=True) if param_not_load or ckpt_not_load: print(f"{param_not_load} in network is not loaded or {ckpt_not_load} in checkpoint is not loaded!") - print('vae checkpoint loaded') - - + print("vae checkpoint loaded") diff --git a/examples/moviegen/mg/utils/parser.py b/examples/moviegen/mg/utils/parser.py index 96d431f8eb..25874cd2bb 100644 --- a/examples/moviegen/mg/utils/parser.py +++ b/examples/moviegen/mg/utils/parser.py @@ -1,5 +1,4 @@ import argparse -import logging def remove_pname_prefix(param_dict, prefix="network."): diff --git a/examples/moviegen/scripts/run_train_tae.sh b/examples/moviegen/scripts/run_train_tae.sh index babd24a99e..d28300d363 100644 --- a/examples/moviegen/scripts/run_train_tae.sh +++ b/examples/moviegen/scripts/run_train_tae.sh @@ -24,4 +24,3 @@ python train_tae.py \ --epochs=2000 --ckpt_save_interval=50 \ # --use_parallel=True \ - diff --git a/examples/moviegen/tests/ut/test_tae.py b/examples/moviegen/tests/ut/test_tae.py index 45e24c4ec7..ae5c3798e7 100644 --- a/examples/moviegen/tests/ut/test_tae.py +++ b/examples/moviegen/tests/ut/test_tae.py @@ -1,7 +1,9 @@ -import numpy as np import sys + +import numpy as np from PIL import Image -sys.path.insert(0, '..') + +sys.path.insert(0, "..") from mg.models.tae.modules import ( Conv2_5d, @@ -15,19 +17,17 @@ TemporalDownsample, TemporalUpsample, ) -from mg.models.tae.tae import SDXL_CONFIG, TAE_CONFIG, TemporalAutoencoder 
from mg.models.tae.sd3_vae import SD3d5_VAE +from mg.models.tae.tae import SDXL_CONFIG, TAE_CONFIG, TemporalAutoencoder import mindspore as ms -def get_input_image(img_path="../videocomposer/demo_video/moon_on_water.jpg", - W=128, - H=128): +def get_input_image(img_path="../videocomposer/demo_video/moon_on_water.jpg", W=128, H=128): target_size = (H, W) # read image using PIL and preprocess - image = Image.open(img_path).convert('RGB') + image = Image.open(img_path).convert("RGB") image = image.resize(target_size) pixel_values = np.array(image, dtype=np.float32) pixel_values = (pixel_values / 127.5 - 1.0).astype(np.float32) @@ -36,7 +36,8 @@ def get_input_image(img_path="../videocomposer/demo_video/moon_on_water.jpg", return pixel_values -def save_output_image(image_array, output_path='tests/tmp_output.png'): + +def save_output_image(image_array, output_path="tests/tmp_output.png"): image_array = image_array.transpose((1, 2, 0)) image_array = ((image_array + 1) * 127.5).astype(np.uint8) image_array = np.clip(image_array, 0, 255) @@ -44,7 +45,7 @@ def save_output_image(image_array, output_path='tests/tmp_output.png'): image = Image.fromarray(image_array) image.save(output_path) - print(f'image saved in {output_path}') + print(f"image saved in {output_path}") def test_conv25d(): @@ -80,14 +81,12 @@ def test_resnetblock(): def test_spatial_attn(): in_shape = (B, C, T, H, W) = (1, 64, 4, 32, 32) - cout = C x = np.random.normal(size=in_shape).astype(np.float32) # TODO: compare time cost for v1 and v2 # sa = SpatialAttnBlock(C) sa = SpatialAttnBlockV2(C) - x = ms.Tensor(x) y = sa(x) @@ -97,13 +96,11 @@ def test_spatial_attn(): def test_temporal_attn(): in_shape = (B, C, T, H, W) = (1, 64, 4, 32, 32) - cout = C x = np.random.normal(size=in_shape).astype(np.float32) # TODO: compare time cost for v1 and v2 ta = TemporalAttnBlock(C) - x = ms.Tensor(x) y = ta(x) @@ -197,6 +194,7 @@ def test_tae_encode(): print(y.shape) + def test_tae_decode(): # in_shape = (B, C, T, H, W) = (1, 3, 1, 64, 64) in_shape = (B, C, T, H, W) = (1, 4, 1, 8, 8) @@ -210,7 +208,7 @@ def test_tae_decode(): def test_tae_rec(): - TAE_CONFIG['attn_type'] = 'spat_only' + TAE_CONFIG["attn_type"] = "spat_only" tae = TemporalAutoencoder(config=TAE_CONFIG) tae.load_pretrained("models/tae_vae2d.ckpt") @@ -224,7 +222,8 @@ def test_tae_rec(): y = tae(x) print(y[0].shape) - save_output_image(y[0].numpy()[0, :, 0, :, :], 'tests/tmp_tae_output.png') + save_output_image(y[0].numpy()[0, :, 0, :, :], "tests/tmp_tae_output.png") + def test_sd3d5_vae(): vae = SD3d5_VAE(sample_deterministic=True) @@ -247,10 +246,10 @@ def test_sd3d5_vae(): print(recons.sum()) + def test_blend(): ms.set_context(mode=1) - tae = TemporalAutoencoder(config=TAE_CONFIG, use_tile=True, - encode_tile=32, decode_tile=32, decode_overlap=16) + tae = TemporalAutoencoder(config=TAE_CONFIG, use_tile=True, encode_tile=32, decode_tile=32, decode_overlap=16) in_shape = (B, C, T, H, W) = (1, 1, 12, 1, 1) x = np.random.normal(size=in_shape).astype(np.float32) @@ -262,8 +261,7 @@ def test_blend(): def test_tae_tile(): - tae = TemporalAutoencoder(config=TAE_CONFIG, use_tile=True, - encode_tile=32, decode_tile=32, decode_overlap=16) + tae = TemporalAutoencoder(config=TAE_CONFIG, use_tile=True, encode_tile=32, decode_tile=32, decode_overlap=16) # in_shape = (B, C, T, H, W) = (1, 3, 16, 64, 64) in_shape = (B, C, T, H, W) = (1, 3, 96, 32, 32) @@ -277,8 +275,6 @@ def test_tae_tile(): print(y[0].shape) # check correctness of blend - - if __name__ == "__main__": diff --git 
a/examples/moviegen/tools/inflate_vae_to_tae.py b/examples/moviegen/tools/inflate_vae_to_tae.py index 6893fdc633..8542a17110 100644 --- a/examples/moviegen/tools/inflate_vae_to_tae.py +++ b/examples/moviegen/tools/inflate_vae_to_tae.py @@ -1,6 +1,8 @@ -from safetensors import safe_open import argparse + import numpy as np +from safetensors import safe_open + import mindspore as ms @@ -10,24 +12,28 @@ def get_shape_from_str(shape): return shape + def get_pname_shape(ckpt_path): - with safe_open(ckpt_path, framework="pt", device='cpu') as fp: + with safe_open(ckpt_path, framework="pt", device="cpu") as fp: for key in fp.keys(): val = fp.get_tensor(key) shape = tuple(val.shape) dtype = val.dtype print(f"{key}#{shape}#{dtype}") + def load_torch_ckpt(ckpt_path): pt_state_dict = {} - with safe_open(ckpt_path, framework="pt", device='cpu') as fp: + with safe_open(ckpt_path, framework="pt", device="cpu") as fp: for key in fp.keys(): pt_state_dict[key] = fp.get_tensor(key) # print(key) return pt_state_dict + def plot_ms_vae2d5(): from mg.models.tae.tae import SD3d5_CONFIG, TemporalAutoencoder + tae = TemporalAutoencoder(config=SD3d5_CONFIG) sd = tae.parameters_dict() @@ -37,10 +43,10 @@ def plot_ms_vae2d5(): print(f"{pname}#{shape}") -def convert_vae2d(source_fp, target_fp, target_model='vae2d'): +def convert_vae2d(source_fp, target_fp, target_model="vae2d"): # read param mapping files - ms_pnames_file = "tools/ms_pnames_sd3.5_vae.txt" if target_model == 'vae2d' else "tools/ms_pnames_tae_vae.txt" - print('target ms pnames is annotated in ', ms_pnames_file) + ms_pnames_file = "tools/ms_pnames_sd3.5_vae.txt" if target_model == "vae2d" else "tools/ms_pnames_tae_vae.txt" + print("target ms pnames is annotated in ", ms_pnames_file) with open(ms_pnames_file) as file_ms: lines_ms = list(file_ms.readlines()) with open("tools/pt_pnames_sd3.5_vae.txt") as file_pt: @@ -90,7 +96,7 @@ def convert_vae2d(source_fp, target_fp, target_model='vae2d'): "--target", "-t", type=str, - default='models/tae_vae2d.ckpt', + default="models/tae_vae2d.ckpt", help="Filename to save. Specify folder, e.g., ./models, or file path which ends with .ckpt, e.g., ./models/vae.ckpt", ) args = parser.parse_args() @@ -100,5 +106,4 @@ def convert_vae2d(source_fp, target_fp, target_model='vae2d'): # plot_ms_vae2d5() # convert_vae2d(ckpt_path, "models/sd3.5_vae.ckpt") - convert_vae2d(args.src, args.target, target_model='tae') - + convert_vae2d(args.src, args.target, target_model="tae") diff --git a/examples/moviegen/train_tae.py b/examples/moviegen/train_tae.py index 0496157d0f..fb97c6a1ba 100644 --- a/examples/moviegen/train_tae.py +++ b/examples/moviegen/train_tae.py @@ -21,8 +21,8 @@ from args_train_tae import parse_args from mg.dataset.tae_dataset import create_dataloader from mg.models.tae.losses import GeneratorWithLoss +from mg.models.tae.modules import SpatialDownsample, SpatialUpsample, TemporalDownsample, TemporalUpsample from mg.models.tae.tae import TemporalAutoencoder -from mg.models.tae.modules import SpatialUpsample, SpatialDownsample, TemporalUpsample, TemporalDownsample from mindone.trainers.callback import EvalSaveCallback, OverflowMonitor, ProfilerCallback from mindone.trainers.checkpoint import CheckpointManager, resume_train_network @@ -134,9 +134,9 @@ def init_env( # only effective in GE mode, i.e. 
jit_level: O2 ms.set_context(ascend_config={"precision_mode": "allow_mix_precision_bf16"}) - if dynamic_shape: - print("Dynamic shape mode enabled, repeat_interleave/split/chunk will be called from mint module") - set_dynamic_mode(True) + # if dynamic_shape: + # print("Dynamic shape mode enabled, repeat_interleave/split/chunk will be called from mint module") + # set_dynamic_mode(True) return rank_id, device_num @@ -195,7 +195,7 @@ def main(args): ae = TemporalAutoencoder( pretrained=args.pretrained_model_path, use_recompute=args.use_recompute, - ) + ) if args.use_discriminator: logging.error("Discriminator is not used or supported in OpenSora v1.2") @@ -209,8 +209,9 @@ def main(args): ae, args.amp_level, dtype, - custom_fp32_cells= [SpatialDownsample, SpatialUpsample, TemporalDownsample, TemporalUpsample] if args.vae_keep_updown_fp32 else [] + \ - ([nn.GroupNorm] if args.vae_keep_gn_fp32 else []), + custom_fp32_cells=[SpatialDownsample, SpatialUpsample, TemporalDownsample, TemporalUpsample] + if args.vae_keep_updown_fp32 + else [] + ([nn.GroupNorm] if args.vae_keep_gn_fp32 else []), # custom_fp32_cells=[nn.GroupNorm, SpatialUpsample] if args.vae_keep_gn_fp32 else [SpatialUpsample], ) diff --git a/examples/opensora_hpcai/tools/mem_monitor/plot.py b/examples/opensora_hpcai/tools/mem_monitor/plot.py index bb5d4588a6..6b2e33ae07 100644 --- a/examples/opensora_hpcai/tools/mem_monitor/plot.py +++ b/examples/opensora_hpcai/tools/mem_monitor/plot.py @@ -1,5 +1,3 @@ -import sys - import matplotlib.pyplot as plt import pandas as pd From ccf943a209fd6897cbbed0991008719de8aeee46 Mon Sep 17 00:00:00 2001 From: Rustam Khadipash <16683750+hadipash@users.noreply.github.com> Date: Thu, 28 Nov 2024 15:05:25 +0800 Subject: [PATCH 076/122] add docs --- examples/movie_gen/README.md | 101 ------------- examples/moviegen/README.md | 280 +++++++++++++++++++++++++++++++++++ 2 files changed, 280 insertions(+), 101 deletions(-) delete mode 100644 examples/movie_gen/README.md create mode 100644 examples/moviegen/README.md diff --git a/examples/movie_gen/README.md b/examples/movie_gen/README.md deleted file mode 100644 index bb5ae59175..0000000000 --- a/examples/movie_gen/README.md +++ /dev/null @@ -1,101 +0,0 @@ -# Movie Gen Video - - -## TAE - -### Requirements - -ms2.3.1 - -### Prepare weights - -We use SD3.5 VAE to initialize the spatial layers of TAE, since both have a latent channel of 16. - -1. Download SD3.5 VAE from https://huggingface.co/stabilityai/stable-diffusion-3.5-large/tree/main/vae - -2. Convert VAE checkpoint for TAE loading - -```shell -python inflate_vae_to_tae.py --src /path/to/sd3.5_vae/diffusion_pytorch_model.safetensors --target models/tae_vae2d.ckpt -``` - - -### Training - -```shell -output_dir=outputs/train_tae_256x256x16 - -python scripts/train_tae.py \ ---config configs/tae/train/mixed_256x256x16.yaml \ ---output_path=$output_dir \ ---csv_path ../opensora_hpcai/datasets/mixkit-100videos/video_caption_train.csv \ ---video_folder ../opensora_hpcai/datasets/mixkit-100videos/mixkit \ - -``` - -OPL - outlier penality loss is found to be not beneficial in our experiment (PSNR decreased). Thus we set it to False by default. - -Change mixed_256x256x16.yaml to mixed_256x256x32.yaml for training on 32 frames. 
- - -#### Performance - -Train on 80 samples of mixkit-100 (train set), test on the other 20 samples (test set) - -256x256x16, 1p, FP32, 1.99 s/step, test set psnr 28.5 - -256x256x32, 1p, BF16, 2.49 s/step, test set psnr 28.3 - - -### Inference - - -#### Video Reconstruction - -```shell -python scripts/inference_vae.py \ ---ckpt_path /path/to/tae.ckpt \ ---batch_size 2 \ ---num_frames=16 \ ---image_size 256 \ ---csv_path ../opensora_hpcai/datasets/mixkit-100videos/video_caption_test.csv \ ---video_folder ../opensora_hpcai/datasets/mixkit-100videos/mixkit \ ---enable_tile=False \ -``` - -#### Encoding video - -```python -from mg.models.tae.tae import TemporalAutoencoder, TAE_CONFIG - -# may set use_tile=True to save memory -tae = TemporalAutoencoder( - pretrained='/path/to/tae.ckpt', - use_tile=False, - ) - -# x - a batch of videos, shape (b c t h w) -z, _, _ = tae.encode(x) - - -# you may scale z by: -# z = TAE_CONFIG['scaling_factor'] * z + TAE_CONFIG['shift_factor'] - - -``` - -For detailed arguments, please refer to the docstring in [tae.py](mg/models/tae/tae.py) - -#### Decoding video latent - -```python - -# if z is scaled, you should unscale at first: -# z = (z - TAE_CONFIG['shift_factor']) / TAE_CONFIG['scaling_factor'] - -# z - a batch of video latent, shape (b c t h w) -x = tae.decode(z) - -# for image decoding, set num_target_frames to discard the spurious frames -x = tae.decode(z, num_target_frames=1) -``` diff --git a/examples/moviegen/README.md b/examples/moviegen/README.md new file mode 100644 index 0000000000..fb42ffbeec --- /dev/null +++ b/examples/moviegen/README.md @@ -0,0 +1,280 @@ +# Movie Gen + +This repository implements the [Movie Gen](https://arxiv.org/abs/2410.13720) model presented by Meta. + +Movie Gen is a family of foundation models that can natively generate high-fidelity images and videos +while also possessing the abilities to edit and personalize the videos. + +Meta researchers found that scaling the training data, compute, and model parameters of a simple +Transformer-based ([LLaMa3](https://arxiv.org/abs/2407.21783)) model trained with +[Flow Matching](https://arxiv.org/abs/2210.02747) yields high quality generative models for video or audio. + +## Features: + +1. :white_check_mark: Text-to-Video synthesis +2. \[Coming soon] Video personalization +3. \[Coming soon] Video editing + +### TODO + +- [ ] Fix EMA. +- [ ] Use ByT5 for encoding visual text only (i.e., text within quotes). +- [ ] CFG inference. +- [ ] Multi-aspect and variable length video training (including PE interpolation). +- [ ] Fix Model Parallel training. +- [ ] Add FPS conditioning. + +# Demo + +Coming soon. + +# Architecture + +
+Architecture details + +## Transformer Backbone + +The Movie Gen family of models contains the following variations: 1B, 5B, and 30B parameters. +It uses the [LLaMa3](https://arxiv.org/abs/2407.21783) backbone architecture for the joint image-video generation model, +enabling confident scaling of the model size while maintaining efficient training. + +There are three changes to the LLaMa3 Transformer block for the use case of video generation using Flow Matching: + +1. Add a cross-attention module between the self-attention module and the feed forward network (FFN) + to each Transformer block to incorporate text conditioning based on the text prompt embedding **P**. + Multiple different text encoders are leveraged due to their complementary strengths + (see [Text Encoders](#text-encoders)). +2. Add adaptive layer norm blocks to incorporate the time-step t to the Transformer, as used in prior work + ([DiT](https://arxiv.org/abs/2212.09748)). +3. Use full bidirectional attention instead of causal attention used in language modeling. + +## TAE + +[//]: # (TODO) + +## Text Encoders + +Movie Gen uses a combination of [UL2](https://arxiv.org/abs/2205.05131), [ByT5](https://arxiv.org/abs/2105.13626), and +Long-prompt [MetaCLIP](https://arxiv.org/abs/2309.16671) as text encoders to provide both semantic-level and +character-level text understanding for the backbone: + +- **UL2** is trained using massive text-only data and potentially provides strong text reasoning abilities in its + features. +- **Long-prompt MetaCLIP** provides text representations that are aligned with visual representations that are + beneficial + for cross-modal generation. +- **ByT5** encoder is only used to encode visual text, i.e., the part of the text prompt that explicitly asks for a + character string to be generated in the output image / video. + +
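Putting the three modifications together, a single block looks roughly like the minimal sketch below. This is an illustration only, with assumed names (`SketchBlock`, `attention`) and single-head attention for brevity; it is not the repository's actual llama-1B/5B/30B implementation.

```python
# Minimal sketch of a Movie Gen-style Transformer block: a LLaMa3-like block with
# (1) cross-attention to the text embedding, (2) adaLN time-step conditioning, and
# (3) full bidirectional self-attention (no causal mask). Illustrative assumptions only.
import mindspore.nn as nn
import mindspore.ops as ops


def attention(q, k, v):
    # plain bidirectional scaled dot-product attention: no causal mask is applied
    scale = q.shape[-1] ** -0.5
    attn = ops.softmax(ops.matmul(q, k.swapaxes(-1, -2)) * scale, axis=-1)
    return ops.matmul(attn, v)


class SketchBlock(nn.Cell):
    def __init__(self, dim: int, text_dim: int):
        super().__init__()
        self.norm1 = nn.LayerNorm((dim,))
        self.q1, self.k1, self.v1 = nn.Dense(dim, dim), nn.Dense(dim, dim), nn.Dense(dim, dim)

        # cross-attention inserted between self-attention and the FFN (change 1)
        self.norm2 = nn.LayerNorm((dim,))
        self.q2 = nn.Dense(dim, dim)
        self.k2, self.v2 = nn.Dense(text_dim, dim), nn.Dense(text_dim, dim)

        self.norm3 = nn.LayerNorm((dim,))
        self.ffn = nn.SequentialCell(nn.Dense(dim, 4 * dim), nn.SiLU(), nn.Dense(4 * dim, dim))

        # adaLN: time-step embedding -> per-block shift/scale/gate (change 2)
        self.ada_ln = nn.SequentialCell(nn.SiLU(), nn.Dense(dim, 6 * dim))

    def construct(self, x, text_emb, t_emb):
        # x: (B, N, dim) video tokens, text_emb: (B, L, text_dim), t_emb: (B, dim)
        shift1, scale1, gate1, shift2, scale2, gate2 = ops.chunk(self.ada_ln(t_emb), 6, axis=-1)

        h = self.norm1(x) * (1 + scale1[:, None]) + shift1[:, None]
        x = x + gate1[:, None] * attention(self.q1(h), self.k1(h), self.v1(h))  # bidirectional (change 3)

        h = self.norm2(x)
        x = x + attention(self.q2(h), self.k2(text_emb), self.v2(text_emb))  # text conditioning (change 1)

        h = self.norm3(x) * (1 + scale2[:, None]) + shift2[:, None]
        return x + gate2[:, None] * self.ffn(h)
```

The key points are the extra cross-attention step conditioned on the text embedding **P**, the adaLN-style shift/scale/gate derived from the time-step embedding, and the absence of any causal mask in self-attention.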
+ +# Installation + +1. Install MindSpore according to the [official instructions](https://www.mindspore.cn/install). + For Ascend devices, please install + [CANN8.0.RC2.beta1](https://www.hiascend.com/developer/download/community/result?module=cann&cann=8.0.RC2.beta1) + and [MindSpore 2.3.1](https://www.mindspore.cn/install). +2. Install requirements + ```shell + pip install -r requirements.txt + ``` + +# Model Weights + +
+TAE + +We use SD3.5 VAE to initialize the spatial layers of TAE since both have a latent channel of 16. + +1. Download SD3.5 VAE from https://huggingface.co/stabilityai/stable-diffusion-3.5-large/tree/main/vae + +2. Convert VAE checkpoint for TAE loading + ```shell + python inflate_vae_to_tae.py --src /path/to/sd3.5_vae/diffusion_pytorch_model.safetensors --target models/tae_vae2d.ckpt + ``` + +
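Conceptually, the conversion only remaps the 2D VAE parameters onto the TAE's spatial layers and re-saves them as a MindSpore checkpoint; the temporal layers keep their fresh initialization. A rough sketch of that idea follows (the rename rule below is a placeholder assumption; `inflate_vae_to_tae.py` is the authoritative script):

```python
# Conceptual sketch of "inflating" a 2D VAE checkpoint into TAE spatial weights.
# The real name mapping lives in inflate_vae_to_tae.py; the rename rule here is illustrative only.
import numpy as np
from safetensors.numpy import load_file

import mindspore as ms


def inflate(src_safetensors: str, target_ckpt: str):
    vae_params = load_file(src_safetensors)  # {name: np.ndarray}

    records = []
    for name, value in vae_params.items():
        # hypothetical rename: 2D VAE params become the TAE's spatial sub-layers,
        # while the TAE's temporal layers are left to their fresh initialization
        tae_name = "spatial." + name
        records.append({"name": tae_name, "data": ms.Tensor(np.asarray(value))})

    ms.save_checkpoint(records, target_ckpt)


if __name__ == "__main__":
    inflate("diffusion_pytorch_model.safetensors", "models/tae_vae2d.ckpt")
```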
+ +
+Text Encoders + +Downloading and conversion of the text encoders' weights to the `.safetensors` format can be done automatically by using +the following commands: + +```shell +python tools/download_convert_st.py "google/byt5-small" +python tools/download_convert_st.py "google/ul2" +``` + +If you face an SSL certificate verification error, you can add `--disable_ssl_verify` option. + +
+ +# Generating Text Embeddings + +Due to the large memory footprint of the text encoders, the inference and training pipelines do not support generating +text embeddings online. Therefore, you need to prepare them in advance by running the following command: + +```shell +python inference_text_enc.py \ +--model_name google/ul2 \ +--prompts_file /path/to/prompts.csv \ +--output_path /path/to/output/directory \ +--model_max_length 512 +``` + +> [!TIP] +> We use the sequence length of 512 tokens for UL2, 256 for MetaCLIP, and 100 for ByT5. + +# Inference + +## Text-to-Video + +```shell +python inference.py \ +--config configs/inference/moviegen_t2i_256x256.yaml \ +--model.name llama-5B +--model.pretrained_model_path /path/to/llama-5B.ckpt \ +--text_emb.ul2_dir /path/to/ul2_embeddings \ +--text_emb.metaclip_dir /path/to/metaclip_embeddings \ +--text_emb.byt5_dir /path/to/byt5_embeddings \ +--image_size 256 455 +``` + +## Text-to-Image + +```shell +python inference.py \ +--config configs/inference/moviegen_t2i_256x256.yaml \ +--model.name llama-5B \ +--model.pretrained_model_path /path/to/llama-5B.ckpt \ +--text_emb.ul2_dir /path/to/ul2_embeddings \ +--text_emb.metaclip_dir /path/to/metaclip_embeddings \ +--text_emb.byt5_dir /path/to/byt5_embeddings \ +--image_size 256 455 \ +--num_frames 32 \ +--batch_size 2 \ +--save_format mp4 +``` + +## TAE + +#### Video Reconstruction + +```shell +python inference_vae.py \ +--ckpt_path /path/to/tae.ckpt \ +--batch_size 2 \ +--num_frames 16 \ +--image_size 256 \ +--csv_path ../opensora_hpcai/datasets/mixkit-100videos/video_caption_test.csv \ +--video_folder ../opensora_hpcai/datasets/mixkit-100videos/mixkit \ +--enable_tile False \ +``` + +#### Encoding video + +```python +from mg.models.tae import TemporalAutoencoder + +# may set use_tile=True to save memory +tae = TemporalAutoencoder( + pretrained='/path/to/tae.ckpt', + use_tile=False, +) + +# x - a batch of videos, shape (b c t h w) +z, _, _ = tae.encode(x) + +# you may scale z by: +z = (z - tae.shift_factor) * tae.scale_factor +``` + +For detailed arguments, please refer to the docstring in [tae.py](mg/models/tae/tae.py) + +#### Decoding video latent + +```python +# if z is scaled, you should unscale at first: +z = z / tae.scale_factor + tae.shift_factor + +# z - a batch of video latent, shape (b c t h w) +x = tae.decode(z) + +# for image decoding, set num_target_frames to discard the spurious frames +x = tae.decode(z, num_target_frames=1) +``` + +# Training + +Movie Gen is trained jointly on images and videos in 4 stages: + +1. Training on images at 256 px resolution. +2. Joint training on images and videos at 256 px resolution. +3. Joint training at 768 px resolution. +4. Fine-tune the model on high quality videos. + +Images are treated as single frame videos, enabling the use of the same model to generate both images and videos. +Compared to video data, paired image-text datasets are easier to scale with diverse concepts and styles, +and thus joint modeling of image and video leads to better generalization. 
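In practice this only means giving images a singleton temporal axis so that they pass through the same `(B, C, T, H, W)` pipeline as videos, for example:

```python
# Treating an image batch as a batch of single-frame videos (illustrative).
import numpy as np

images = np.random.rand(4, 3, 256, 256).astype(np.float32)       # (B, C, H, W)
videos = np.random.rand(4, 3, 16, 256, 256).astype(np.float32)   # (B, C, T, H, W)

images_as_videos = images[:, :, None]  # (B, C, 1, H, W): a "video" with T=1
assert images_as_videos.shape[2] == 1

# Both tensors can now be fed through the same (B, C, T, H, W) training pipeline.
```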
+ +## Movie Gen + +To train Movie Gen, run the following command: + +```shell +scripts/stage1_train.sh # for stage 1 training +scripts/stage2_train.sh # for stage 2 training +``` + +### Performance + +| Model | Context | Jit level | Stage | Precision | Resolution | Batch size | NPUs | Time (s/step) | Config | +|-------|-------------------|-----------|---------|-----------|----------------|------------|------|---------------|--------------------------------------------------------------------| +| 5B | D910*-C18-MS2.3.1 | O1 | 1 (T2I) | BF16 | 256x455 (16:9) | 20 | 4 | 4.47 | [stage1_t2i_256x256.yaml](./configs/train/stage1_t2i_256x256.yaml) | + +### Validation During Training + +Validation can be enabled by either setting parameters in the `valid` field of the configuration file +([example](configs/train/stage1_t2i_256x256.yaml)) or by supplying the following arguments to `train.py`: + +```shell +--valid.sampling_steps 10 \ +--valid.frequency 100 \ +--valid.dataset.csv_path /path/to/valid_dataset.csv \ +--valid.dataset.video_folder /path/to/videos \ +--valid.dataset.text_emb_folder.ul2 /path/to/ul2_embeddings \ +--valid.dataset.text_emb_folder.metaclip /path/to/metaclip_embeddings \ +--valid.dataset.text_emb_folder.byt5 /path/to/byt5_embeddings +``` + +## TAE + +```shell +output_dir=outputs/train_tae_256x256x16 + +python train_tae.py \ +--config configs/tae/train/mixed_256x256x16.yaml \ +--output_path $output_dir \ +--csv_path ../opensora_hpcai/datasets/mixkit-100videos/video_caption_train.csv \ +--video_folder ../opensora_hpcai/datasets/mixkit-100videos/mixkit \ +``` + +OPL - outlier penalty loss is found to be not beneficial in our experiment (PSNR decreased). +Thus, we set it to False by default. + +Change mixed_256x256x16.yaml to mixed_256x256x32.yaml for training on 32 frames. + +### Performance + +Train on 80 samples of mixkit-100 (train set), test on the other 20 samples (test set) + +| Resolution | NPUs | Precision | Time (s/step) | PSNR (test set) | +|------------|------|-----------|---------------|-----------------| +| 256x256x16 | 1 | FP32 | 1.99 | 28.5 | +| 256x256x32 | 1 | BF16 | 2.49 | 28.3 | + +# Evaluation + +Coming soon. 
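For context on the outlier penalty loss (OPL) mentioned in the TAE training section above: it penalizes latent values that deviate strongly from the latent statistics. One plausible formulation is sketched below as an illustrative assumption; see `mg/models/tae/losses.py` for the implementation actually used in this example.

```python
# One plausible outlier-penalty formulation on TAE latents (illustrative assumption,
# not necessarily the exact variant implemented in mg/models/tae/losses.py).
import numpy as np


def outlier_penalty_loss(z: np.ndarray, r: float = 3.0) -> float:
    # z: latent tensor, e.g. (B, C, T, H, W). Penalize values further than
    # r standard deviations from the latent mean; inliers contribute zero.
    mean, std = z.mean(), z.std()
    excess = np.maximum(np.abs(z - mean) - r * std, 0.0)
    return float(excess.mean())


z = np.random.randn(1, 16, 8, 32, 32).astype(np.float32)
print(outlier_penalty_loss(z))  # ~0 for well-behaved Gaussian latents
```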
From 3f7d207ea12d0b42f77c6d34a1119c7b5cb001ad Mon Sep 17 00:00:00 2001 From: Rustam Khadipash <16683750+hadipash@users.noreply.github.com> Date: Tue, 3 Dec 2024 11:16:30 +0800 Subject: [PATCH 077/122] refactor TAE add latents generation other small changes --- examples/moviegen/README.md | 12 +- examples/moviegen/args_train_tae.py | 194 +++----------- .../configs/tae/train/mixed_256x256x16.yaml | 11 +- .../configs/tae/train/mixed_256x256x32.yaml | 11 +- .../{inference_vae.py => eval_tae.py} | 186 ++++---------- examples/moviegen/inference.py | 12 +- examples/moviegen/inference_tae_enc.py | 123 +++++++++ examples/moviegen/mg/dataset/dataset.py | 4 +- examples/moviegen/mg/dataset/tae_dataset.py | 239 ++++++++---------- .../moviegen/mg/pipelines/infer_pipeline.py | 2 +- examples/moviegen/scripts/stage2_train.sh | 10 +- examples/moviegen/train.py | 23 +- examples/moviegen/train_tae.py | 152 +++-------- examples/opensora_hpcai/scripts/train.py | 2 +- examples/svd/train.py | 2 +- examples/t2i_adapter/train_t2i_adapter_sd.py | 2 +- mindone/data/loader.py | 21 +- mindone/trainers/callback.py | 23 +- 18 files changed, 425 insertions(+), 604 deletions(-) rename examples/moviegen/{inference_vae.py => eval_tae.py} (55%) create mode 100644 examples/moviegen/inference_tae_enc.py diff --git a/examples/moviegen/README.md b/examples/moviegen/README.md index fb42ffbeec..284a564f82 100644 --- a/examples/moviegen/README.md +++ b/examples/moviegen/README.md @@ -163,14 +163,14 @@ python inference.py \ #### Video Reconstruction ```shell -python inference_vae.py \ ---ckpt_path /path/to/tae.ckpt \ +python eval_tae.py \ +--pretrained /path/to/tae.ckpt \ --batch_size 2 \ ---num_frames 16 \ ---image_size 256 \ +--sample_n_frames 16 \ +--size 256 \ --csv_path ../opensora_hpcai/datasets/mixkit-100videos/video_caption_test.csv \ ---video_folder ../opensora_hpcai/datasets/mixkit-100videos/mixkit \ ---enable_tile False \ +--folder ../opensora_hpcai/datasets/mixkit-100videos/mixkit \ +--use_tile False ``` #### Encoding video diff --git a/examples/moviegen/args_train_tae.py b/examples/moviegen/args_train_tae.py index 8b1634613d..2a82b9e28b 100644 --- a/examples/moviegen/args_train_tae.py +++ b/examples/moviegen/args_train_tae.py @@ -1,89 +1,57 @@ -import argparse import logging import os import sys -import yaml +from jsonargparse import ActionConfigFile, ArgumentParser +# TODO: remove in future when mindone is ready for install __dir__ = os.path.dirname(os.path.abspath(__file__)) -mindone_lib_path = os.path.abspath(os.path.join(__dir__, "../../../")) -sys.path.insert(0, mindone_lib_path) +mindone_lib_path = os.path.abspath(os.path.join(__dir__, "../../")) +sys.path.append(mindone_lib_path) -from mg.utils.parser import _check_cfgs_in_parser, str2bool +from mg.dataset.tae_dataset import BatchTransform, VideoDataset +from mg.models.tae import TemporalAutoencoder +from mindone.data import create_dataloader +from mindone.utils import init_train_env from mindone.utils.misc import to_abspath logger = logging.getLogger() -def parse_train_args(parser): +def parse_train_args(): + parser = ArgumentParser(description="Temporal Autoencoder training script.") parser.add_argument( - "--config", "-c", - default="", - type=str, - help="path to load a config yaml file that describes the training recipes which will override the default arguments", + action=ActionConfigFile, + help="Path to load a config yaml file that describes the setting which will override the default arguments.", ) - # the following args's defualt value will be overrided 
if specified in config yaml - - # data - parser.add_argument("--dataset_name", default="", type=str, help="dataset name") - parser.add_argument( - "--csv_path", - default="", - type=str, - help="path to csv annotation file. columns: video, caption. \ - video indicates the relative path of video file in video_folder. caption - the text caption for video", + parser.add_function_arguments( + init_train_env, skip={"ascend_config", "num_workers", "json_data_path", "enable_modelarts"} ) - parser.add_argument("--video_column", default="video", type=str, help="name of column for videos saved in csv file") - parser.add_argument("--random_crop", default=False, type=str2bool, help="randonly crop the image") - parser.add_argument("--flip", default=False, type=str2bool, help="flip the image") - + parser.add_class_arguments(TemporalAutoencoder, instantiate=False) parser.add_argument( - "--caption_column", default="caption", type=str, help="name of column for captions saved in csv file" + "--dtype", default="fp32", type=str, choices=["fp32", "fp16", "bf16"], help="TAE model precision." + ) + parser.add_class_arguments(VideoDataset, skip={"output_columns"}, instantiate=False) + parser.add_class_arguments(BatchTransform, instantiate=False) + parser.add_function_arguments( + create_dataloader, + skip={"dataset", "transforms", "batch_transforms", "device_num", "rank_id", "debug", "enable_modelarts"}, ) - parser.add_argument("--video_folder", default="", type=str, help="root dir for the video data") parser.add_argument("--output_path", default="output/", type=str, help="output directory to save training results") parser.add_argument( "--add_datetime", default=True, type=str, help="If True, add datetime subfolder under output_path" ) # model - parser.add_argument("--model_type", default="OpenSora-VAE-v1.2", type=str, help="VAE model type") - parser.add_argument("--freeze_vae_2d", default=True, type=str2bool, help="Freeze 2d vae") - parser.add_argument( - "--use_discriminator", default=False, type=str2bool, help="Use discriminator for adversarial training." - ) - parser.add_argument( - "--pretrained_model_path", - default="", - type=str, - help="Specify the pretrained model path", - ) parser.add_argument("--perceptual_loss_weight", default=0.1, type=float, help="perceptual (lpips) loss weight") parser.add_argument("--kl_loss_weight", default=1.0e-6, type=float, help="KL loss weight") parser.add_argument( "--use_outlier_penalty_loss", default=False, - type=str2bool, + type=bool, help="use outlier penalty loss", ) - # data - parser.add_argument("--mixed_strategy", type=str, default=None, help="video and image mixed strategy") - parser.add_argument( - "--mixed_image_ratio", default=0.0, type=float, help="image ratio in mixed video and image data training" - ) - - # ms - parser.add_argument("--debug", type=str2bool, default=False, help="Execute inference in debug mode.") - parser.add_argument("--device_target", type=str, default="Ascend", help="Ascend or GPU") - parser.add_argument("--max_device_memory", type=str, default=None, help="e.g. 
`30GB` for 910a, `59GB` for 910b") - parser.add_argument("--mode", default=0, type=int, help="Specify the mode: 0 for graph mode, 1 for pynative mode") - parser.add_argument("--use_parallel", default=False, type=str2bool, help="use parallel") - parser.add_argument( - "--parallel_mode", default="data", type=str, choices=["data", "optim"], help="parallel mode: data, optim" - ) - parser.add_argument("--jit_level", default="O0", type=str, help="O0 kbk, O1 dvm, O2 ge") - # training hyper-params parser.add_argument( "--resume", @@ -110,32 +78,18 @@ def parse_train_args(parser): If None, filter list is [layernorm, bias], Default: None", ) parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay.") - parser.add_argument("--seed", default=3407, type=int, help="data path") parser.add_argument("--warmup_steps", default=1000, type=int, help="warmup steps") - parser.add_argument("--batch_size", default=10, type=int, help="batch size") - parser.add_argument( - "--micro_batch_size", - type=int, - default=4, - help="If not None, split batch_size*num_frames into smaller ones for VAE encoding to reduce memory limitation", - ) - parser.add_argument( - "--micro_frame_size", - type=int, - default=17, - help="If not None, split batch_size*num_frames into smaller ones for VAE encoding to reduce memory limitation. Used by temporal vae", - ) parser.add_argument("--start_learning_rate", default=1e-5, type=float, help="The initial learning rate for Adam.") parser.add_argument("--end_learning_rate", default=1e-7, type=float, help="The end learning rate for Adam.") parser.add_argument( - "--scale_lr", default=False, type=str2bool, help="scale base-lr by ngpu * batch_size * n_accumulate" + "--scale_lr", default=False, type=bool, help="scale base-lr by ngpu * batch_size * n_accumulate" ) parser.add_argument("--decay_steps", default=0, type=int, help="lr decay steps.") parser.add_argument("--scheduler", default="cosine_decay", type=str, help="scheduler.") - parser.add_argument("--pre_patchify", default=False, type=str2bool, help="Training with patchified latent.") + parser.add_argument("--pre_patchify", default=False, type=bool, help="Training with patchified latent.") # dataloader params - parser.add_argument("--dataset_sink_mode", default=False, type=str2bool, help="sink mode") + parser.add_argument("--dataset_sink_mode", default=False, type=bool, help="sink mode") parser.add_argument("--sink_size", default=-1, type=int, help="dataset sink size. 
If -1, sink size = dataset size.") parser.add_argument( "--epochs", @@ -150,94 +104,29 @@ def parse_train_args(parser): parser.add_argument("--loss_scale_factor", default=2, type=float, help="loss scale factor") parser.add_argument("--scale_window", default=2000, type=float, help="scale window") parser.add_argument("--gradient_accumulation_steps", default=1, type=int, help="gradient accumulation steps") - # parser.add_argument("--cond_stage_trainable", default=False, type=str2bool, help="whether text encoder is trainable") - parser.add_argument("--use_ema", default=False, type=str2bool, help="whether use EMA") + # parser.add_argument("--cond_stage_trainable", default=False, type=bool, help="whether text encoder is trainable") + parser.add_argument("--use_ema", default=False, type=bool, help="whether use EMA") parser.add_argument("--ema_decay", default=0.9999, type=float, help="ema decay ratio") - parser.add_argument("--clip_grad", default=False, type=str2bool, help="whether apply gradient clipping") - parser.add_argument( - "--use_recompute", - default=False, - type=str2bool, - help="whether use recompute.", - ) - parser.add_argument( - "--num_recompute_blocks", - default=None, - type=int, - help="If None, all stdit blocks will be applied with recompute (gradient checkpointing). If int, the first N blocks will be applied with recompute", - ) - parser.add_argument( - "--dtype", - default="fp16", - type=str, - choices=["bf16", "fp16", "fp32"], - help="what computation data type to use for latte. Default is `fp16`, which corresponds to ms.float16", - ) + parser.add_argument("--clip_grad", default=False, type=bool, help="whether apply gradient clipping") parser.add_argument( "--vae_keep_gn_fp32", default=True, - type=str2bool, + type=bool, help="whether keep GroupNorm in fp32.", ) parser.add_argument( "--vae_keep_updown_fp32", default=True, - type=str2bool, + type=bool, help="whether keep spatial/temporal upsample and downsample in fp32.", ) - parser.add_argument( - "--global_bf16", - default=False, - type=str2bool, - help="Experimental. If True, dtype will be overrided, operators will be computered in bf16 if they are supported by CANN", - ) - parser.add_argument( - "--vae_param_dtype", - default="fp32", - type=str, - choices=["bf16", "fp16", "fp32"], - help="what param data type to use for vae. Default is `fp32`, which corresponds to ms.float32", - ) - parser.add_argument( - "--amp_level", - default="O2", - type=str, - help="mindspore amp level, O1: most fp32, only layers in whitelist compute in fp16 (dense, conv, etc); \ - O2: most fp16, only layers in blacklist compute in fp32 (batch norm etc)", - ) - parser.add_argument("--vae_amp_level", default="O2", type=str, help="O2 or O3") - parser.add_argument( - "--vae_checkpoint", - type=str, - default="models/sd-vae-ft-ema.ckpt", - help="VAE checkpoint file path which is used to load vae weight.", - ) - parser.add_argument( - "--sd_scale_factor", type=float, default=0.18215, help="VAE scale factor of Stable Diffusion model." 
- ) - parser.add_argument( - "--image_size", default=256, type=int, nargs="+", help="image size for resizing the input image" - ) - parser.add_argument("--crop_size", default=256, type=int, help="crop size after resize") - parser.add_argument("--num_frames", default=16, type=int, help="the num of frames used to initiate model") - parser.add_argument("--frame_stride", default=3, type=int, help="frame sampling stride") - parser.add_argument("--mask_ratios", type=dict, help="Masking ratios") - parser.add_argument("--bucket_config", type=dict, help="Multi-resolution bucketing configuration") - parser.add_argument("--num_parallel_workers", default=12, type=int, help="num workers for data loading") - parser.add_argument( - "--data_multiprocessing", - default=False, - type=str2bool, - help="If True, use multiprocessing for data processing. Default: multithreading.", - ) - parser.add_argument("--max_rowsize", default=64, type=int, help="max rowsize for data loading") parser.add_argument( "--enable_flash_attention", default=None, - type=str2bool, + type=bool, help="whether to enable flash attention.", ) - parser.add_argument("--drop_overflow_update", default=True, type=str2bool, help="drop overflow update") + parser.add_argument("--drop_overflow_update", default=True, type=bool, help="drop overflow update") parser.add_argument("--loss_scaler_type", default="dynamic", type=str, help="dynamic or static") parser.add_argument( "--max_grad_norm", @@ -256,10 +145,10 @@ def parse_train_args(parser): parser.add_argument( "--step_mode", default=False, - type=str2bool, + type=bool, help="whether save ckpt by steps. If False, save ckpt by epochs.", ) - parser.add_argument("--profile", default=False, type=str2bool, help="Profile or not") + parser.add_argument("--profile", default=False, type=bool, help="Profile or not") parser.add_argument( "--log_level", type=str, @@ -276,19 +165,12 @@ def parse_train_args(parser): def parse_args(): - parser = argparse.ArgumentParser() - parser = parse_train_args(parser) + parser = parse_train_args() + args = parser.parse_args() __dir__ = os.path.dirname(os.path.abspath(__file__)) abs_path = os.path.abspath(os.path.join(__dir__, "..")) - default_args = parser.parse_args() - if default_args.config: - default_args.config = to_abspath(abs_path, default_args.config) - with open(default_args.config, "r") as f: - cfg = yaml.safe_load(f) - _check_cfgs_in_parser(cfg, parser) - parser.set_defaults(**cfg) - args = parser.parse_args() + # convert to absolute path, necessary for modelarts args.csv_path = to_abspath(abs_path, args.csv_path) args.video_folder = to_abspath(abs_path, args.video_folder) diff --git a/examples/moviegen/configs/tae/train/mixed_256x256x16.yaml b/examples/moviegen/configs/tae/train/mixed_256x256x16.yaml index af58e62ee1..6c9b44d6b9 100644 --- a/examples/moviegen/configs/tae/train/mixed_256x256x16.yaml +++ b/examples/moviegen/configs/tae/train/mixed_256x256x16.yaml @@ -1,5 +1,5 @@ # model -pretrained_model_path: "models/tae_vae2d.ckpt" +pretrained: "models/tae_vae2d.ckpt" # loss perceptual_loss_weight: 1.0 @@ -9,18 +9,16 @@ mixed_strategy: "mixed_video_image" mixed_image_ratio: 0.2 # data -dataset_name: "video" csv_path: "../videocomposer/datasets/webvid5_copy.csv" -video_folder: "../videocomposer/datasets/webvid5" -frame_stride: 1 -num_frames: 16 +folder: "../videocomposer/datasets/webvid5" +sample_stride: 1 +sample_n_frames: 16 image_size: 256 crop_size: 256 # flip: True # training recipe seed: 42 -use_discriminator: False batch_size: 1 clip_grad: True 
max_grad_norm: 1.0 @@ -29,7 +27,6 @@ scale_lr: False weight_decay: 0. dtype: "fp32" -amp_level: "O0" use_recompute: False epochs: 2000 diff --git a/examples/moviegen/configs/tae/train/mixed_256x256x32.yaml b/examples/moviegen/configs/tae/train/mixed_256x256x32.yaml index 990ec83d72..822e09c676 100644 --- a/examples/moviegen/configs/tae/train/mixed_256x256x32.yaml +++ b/examples/moviegen/configs/tae/train/mixed_256x256x32.yaml @@ -1,5 +1,5 @@ # model -pretrained_model_path: "models/tae_vae2d.ckpt" +pretrained: "models/tae_vae2d.ckpt" # loss perceptual_loss_weight: 1.0 @@ -9,18 +9,16 @@ mixed_strategy: "mixed_video_image" mixed_image_ratio: 0.2 # data -dataset_name: "video" csv_path: "../videocomposer/datasets/webvid5_copy.csv" -video_folder: "../videocomposer/datasets/webvid5" -frame_stride: 1 -num_frames: 32 +folder: "../videocomposer/datasets/webvid5" +sample_stride: 1 +sample_n_frames: 32 image_size: 256 crop_size: 256 # flip: True # training recipe seed: 42 -use_discriminator: False batch_size: 1 clip_grad: True max_grad_norm: 1.0 @@ -29,7 +27,6 @@ scale_lr: False weight_decay: 0. dtype: "bf16" -amp_level: "O2" # reduce memory cost use_recompute: True epochs: 2000 diff --git a/examples/moviegen/inference_vae.py b/examples/moviegen/eval_tae.py similarity index 55% rename from examples/moviegen/inference_vae.py rename to examples/moviegen/eval_tae.py index 59b0e54e22..91fbb15389 100644 --- a/examples/moviegen/inference_vae.py +++ b/examples/moviegen/eval_tae.py @@ -1,8 +1,6 @@ -# flake8: noqa """ Infer and evaluate autoencoders """ -import argparse import logging import os import sys @@ -10,32 +8,28 @@ import imageio import numpy as np - -from mindspore import nn, ops - -__dir__ = os.path.dirname(os.path.abspath(__file__)) -mindone_dir = os.path.abspath(os.path.join(__dir__, "../../../")) -sys.path.insert(0, mindone_dir) - +from jsonargparse import ArgumentParser from PIL import Image from skimage.metrics import peak_signal_noise_ratio as calc_psnr from skimage.metrics import structural_similarity as calc_ssim from tqdm import tqdm import mindspore as ms +from mindspore import amp, nn, ops +# TODO: remove in future when mindone is ready for install __dir__ = os.path.dirname(os.path.abspath(__file__)) -mindone_lib_path = os.path.abspath(os.path.join(__dir__, "../../../")) -sys.path.insert(0, mindone_lib_path) -sys.path.insert(0, os.path.abspath(os.path.join(__dir__, ".."))) +mindone_lib_path = os.path.abspath(os.path.join(__dir__, "../../")) +sys.path.append(mindone_lib_path) -from mg.dataset.tae_dataset import create_dataloader -from mg.models.tae.lpips import LPIPS -from mg.models.tae.tae import TemporalAutoencoder +from mg.dataset.tae_dataset import VideoDataset +from mg.models.tae import TemporalAutoencoder + +from mindone.data import create_dataloader +from mindone.utils import init_train_env, set_logger + +# from mg.models.tae.lpips import LPIPS -from mindone.utils.amp import auto_mixed_precision -from mindone.utils.config import str2bool -from mindone.utils.logger import set_logger logger = logging.getLogger(__name__) @@ -93,88 +87,61 @@ def rearrange_out(x, t): def main(args): - ascend_config = {"precision_mode": "must_keep_origin_dtype"} - ms.set_context(mode=args.mode, ascend_config=ascend_config) - set_logger(name="", output_dir=args.output_path, rank=0) + # set env + # TODO: rename as train and infer are identical? 
+ _, rank_id, device_num = init_train_env(mode=args.mode, ascend_config={"precision_mode": "must_keep_origin_dtype"}) + set_logger(name="", output_dir=args.output_path, rank=rank_id) # build model - model = TemporalAutoencoder( - pretrained=args.ckpt_path, - use_tile=args.enable_tile, - ) - - model.set_train(False) - logger.info(f"Loaded checkpoint from {args.ckpt_path}") - - if args.eval_loss: - lpips_loss_fn = LPIPS() - + model = TemporalAutoencoder(pretrained=args.pretrained, use_tile=args.use_tile).set_train(False) if args.dtype != "fp32": - amp_level = "O2" dtype = {"fp16": ms.float16, "bf16": ms.bfloat16}[args.dtype] # FIXME: due to AvgPool and ops.interpolate doesn't support bf16, we add them to fp32 cells custom_fp32_cells = [nn.GroupNorm, nn.AvgPool2d, nn.Upsample] - model = auto_mixed_precision(model, amp_level, dtype, custom_fp32_cells) + model = amp.custom_mixed_precision(model, black_list=amp.get_black_list() + custom_fp32_cells, dtype=dtype) logger.info(f"Set mixed precision to O2 with dtype={args.dtype}") - else: - amp_level = "O0" - # build dataset - if isinstance(args.image_size, int): - image_size = args.image_size - else: - if len(args.image_size) == 2: - assert args.image_size[0] == args.image_size[1], "Currently only h==w is supported" - image_size = args.image_size[0] + # if args.eval_loss: + # lpips_loss_fn = LPIPS() - ds_config = dict( + # build dataset + dataset = VideoDataset( csv_path=args.csv_path, - data_folder=args.video_folder, - size=image_size, - crop_size=image_size, - sample_n_frames=args.num_frames, - sample_stride=args.frame_stride, + folder=args.folder, + size=args.image_size, + crop_size=args.image_size, + sample_n_frames=args.sample_n_frames, + sample_stride=args.sample_stride, video_column=args.video_column, random_crop=False, flip=False, + output_columns=["video"], ) dataset = create_dataloader( - ds_config, + dataset, args.batch_size, - mixed_strategy=None, - mixed_image_ratio=0.0, - num_parallel_workers=8, + num_workers=8, max_rowsize=256, shuffle=False, - device_num=1, - rank_id=0, + device_num=device_num, + rank_id=rank_id, drop_remainder=False, ) num_batches = dataset.get_dataset_size() - ds_iter = dataset.create_dict_iterator(1) + ds_iter = dataset.create_dict_iterator(num_epochs=1) - logger.info("Inferene begins") - mean_infer_time = 0 - mean_psnr = 0 - mean_ssim = 0 - mean_lpips = 0 - mean_recon = 0 - num_samples = 0 + mean_infer_time, mean_psnr, mean_ssim, mean_lpips, mean_recon, num_samples = (0,) * 6 for step, data in tqdm(enumerate(ds_iter)): x = data["video"] - start_time = time.time() + start_time = time.perf_counter() if args.encode_only: - z = model.encode(x) + z, posterior_mean, posterior_logvar = model.encode(x) else: - # recons = model.decode(z) recons, z, posterior_mean, posterior_logvar = model(x) - # adapt to bf16 - recons = recons.to(ms.float32) - - infer_time = time.time() - start_time + infer_time = time.perf_counter() - start_time mean_infer_time += infer_time logger.info(f"Infer time: {infer_time}") @@ -213,9 +180,7 @@ def main(args): logger.info(f"mean recon loss: {mean_recon/num_batches:.4f}") if args.save_vis: - save_fn = os.path.join( - args.output_path, "{}-{}".format(os.path.basename(args.video_folder), f"step{step:03d}") - ) + save_fn = os.path.join(args.output_path, f"{os.path.basename(args.video_folder)}-{f'step{step:03d}'}") if not is_video: visualize_image(recons_rgb, x_rgb, save_fn=save_fn) else: @@ -241,86 +206,43 @@ def main(args): # logger.info(f"mean lpips loss: {mean_lpips:.4f}") -def parse_args(): - 
parser = argparse.ArgumentParser() - parser.add_argument( - "--model_config", - default="configs/autoencoder_kl_f8.yaml", - type=str, - help="model architecture config", +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_function_arguments( + init_train_env, skip={"ascend_config", "num_workers", "json_data_path", "enable_modelarts"} ) + parser.add_class_arguments(TemporalAutoencoder, instantiate=False) parser.add_argument( - "--ckpt_path", default="outputs/vae_train/ckpt/vae_kl_f8-e10.ckpt", type=str, help="checkpoint path" + "--dtype", default="fp32", type=str, choices=["fp32", "fp16", "bf16"], help="TAE model precision." ) - parser.add_argument( - "--csv_path", - default=None, - type=str, - help="path to csv annotation file. If None, will get videos from the folder of `data_path`", + parser.add_class_arguments(VideoDataset, skip={"output_columns"}, instantiate=False) + parser.add_function_arguments( + create_dataloader, + skip={"dataset", "transforms", "batch_transforms", "device_num", "rank_id", "debug", "enable_modelarts"}, ) - parser.add_argument("--video_folder", default=None, type=str, help="folder of videos") parser.add_argument( "--output_path", default="samples/vae_recons", type=str, help="output directory to save inference results" ) - parser.add_argument("--num_frames", default=17, type=int, help="num frames") - parser.add_argument("--frame_stride", default=1, type=int, help="frame sampling stride") parser.add_argument( "--expand_dim_t", default=False, - type=str2bool, + type=bool, help="expand temporal axis for image data, used for vae 3d inference with image data", ) - parser.add_argument("--image_size", default=256, type=int, help="image rescale size") - # parser.add_argument("--crop_size", default=256, type=int, help="image crop size") - - parser.add_argument("--batch_size", default=1, type=int, help="batch size") - parser.add_argument("--num_parallel_workers", default=8, type=int, help="num workers for data loading") parser.add_argument( "--eval_loss", default=False, - type=str2bool, + type=bool, help="whether measure loss including reconstruction, kl, perceptual loss", ) - parser.add_argument("--save_vis", default=True, type=str2bool, help="whether save reconstructed images") - parser.add_argument("--use_temporal_vae", default=True, type=str2bool, help="if False, just use spatial vae") - parser.add_argument("--encode_only", default=False, type=str2bool, help="only encode to save z or distribution") - parser.add_argument( - "--enable_tile", default=False, type=str2bool, help="enable temporal tiling with linear blending for decoder" - ) - parser.add_argument("--video_column", default="video", type=str, help="name of column for videos saved in csv file") - parser.add_argument( - "--mixed_strategy", - type=str, - default=None, - choices=[None, "mixed_video_image", "image_only"], - help="video and image mixed strategy.", - ) - parser.add_argument( - "--mixed_image_ratio", default=0.0, type=float, help="image ratio in mixed video and image data training" - ) + parser.add_argument("--save_vis", default=True, type=bool, help="whether save reconstructed images") + parser.add_argument("--use_temporal_vae", default=True, type=bool, help="if False, just use spatial vae") + parser.add_argument("--encode_only", default=False, type=bool, help="only encode to save z or distribution") parser.add_argument( "--save_z_dist", default=False, - type=str2bool, + type=bool, help="If True, save z distribution, mean and logvar. 
Otherwise, save z after sampling.", ) - # ms related - parser.add_argument("--mode", default=0, type=int, help="Specify the mode: 0 for graph mode, 1 for pynative mode") - parser.add_argument( - "--dtype", - default="fp32", - type=str, - choices=["fp32", "fp16", "bf16"], - help="mixed precision type, if fp32, all layer precision is float32 (amp_level=O0), \ - if bf16 or fp16, amp_level==O2, part of layers will compute in bf16 or fp16 such as matmul, dense, conv.", - ) - parser.add_argument("--device_target", type=str, default="Ascend", help="Ascend or GPU") - args = parser.parse_args() - - return args - - -if __name__ == "__main__": - args = parse_args() main(args) diff --git a/examples/moviegen/inference.py b/examples/moviegen/inference.py index cc705d0b27..79d8f55897 100644 --- a/examples/moviegen/inference.py +++ b/examples/moviegen/inference.py @@ -19,12 +19,11 @@ sys.path.append(mindone_lib_path) from mg.models.tae import TemporalAutoencoder -from mg.models.tae.modules import SpatialDownsample, SpatialUpsample, TemporalDownsample, TemporalUpsample from mg.pipelines import InferPipeline from mg.utils import MODEL_DTYPE, init_model, to_numpy from mindone.utils import init_train_env, set_logger -from mindone.visualize.videos import save_videos +from mindone.visualize import save_videos logger = logging.getLogger(__name__) @@ -76,8 +75,7 @@ def main(args): # FIXME: remove AMP and add custom dtype conversion support for better compatibility with PyNative amp.custom_mixed_precision( tae, - black_list=amp.get_black_list() - + [SpatialDownsample, SpatialUpsample, TemporalDownsample, TemporalUpsample, nn.GroupNorm], + black_list=amp.get_black_list() + [nn.GroupNorm, nn.AvgPool2d, nn.Upsample], dtype=MODEL_DTYPE[tae_dtype], ) @@ -172,14 +170,14 @@ def main(args): parser.add_function_arguments(init_train_env, "env") parser.add_function_arguments(init_model, "model", skip={"in_channels"}) tae_group = parser.add_argument_group("TAE parameters") - parser.add_class_arguments(TemporalAutoencoder, "tae", instantiate=False) - parser.add_argument( + tae_group.add_class_arguments(TemporalAutoencoder, "tae", instantiate=False) + tae_group.add_argument( "--tae.dtype", default="fp32", type=str, choices=["fp32", "fp16", "bf16"], help="TAE model precision." 
) infer_group = parser.add_argument_group("Inference parameters") infer_group.add_class_arguments(InferPipeline, skip={"model", "tae", "latent_size"}, instantiate=False) infer_group.add_argument("--image_size", type=int, nargs="+", help="Output video size") - infer_group.add_argument("--num_frames", type=int, default=17, help="number of frames") + infer_group.add_argument("--num_frames", type=int, default=16, help="number of frames") infer_group.add_argument("--fps", type=int, default=16, help="FPS in the saved video") infer_group.add_function_arguments(prepare_captions, "text_emb", skip={"rank_id", "device_num"}) infer_group.add_argument("--batch_size", type=int, default=1) diff --git a/examples/moviegen/inference_tae_enc.py b/examples/moviegen/inference_tae_enc.py new file mode 100644 index 0000000000..94ebc99e70 --- /dev/null +++ b/examples/moviegen/inference_tae_enc.py @@ -0,0 +1,123 @@ +import logging +import os +import sys +from pathlib import Path + +import numpy as np +from jsonargparse import ArgumentParser +from jsonargparse.typing import path_type +from tqdm import tqdm + +from mindspore import amp, get_context, nn + +# TODO: remove in future when mindone is ready for install +__dir__ = os.path.dirname(os.path.abspath(__file__)) +mindone_lib_path = os.path.abspath(os.path.join(__dir__, "../../")) +sys.path.append(mindone_lib_path) + +from mg.dataset.tae_dataset import VideoDataset +from mg.models.tae import TemporalAutoencoder +from mg.utils import MODEL_DTYPE, to_numpy + +from mindone.data import create_dataloader +from mindone.utils import init_train_env, set_logger + +logger = logging.getLogger(__name__) + + +def main(args): + # 1. init env + _, rank_id, device_num = init_train_env(**args.env) # TODO: rename as train and infer are identical? + mode = get_context("mode") # `init_train_env()` may change the mode during debugging + + save_dir = Path(args.output_path.absolute) + save_dir.mkdir(parents=True, exist_ok=True) + set_logger(name="", output_dir=str(save_dir), rank=rank_id) + + # 2 build dataset + dataset = VideoDataset( + **args.data, + sample_n_frames=10**5, # read the full video, limitation of `albumentations` (i.e., `additional_targets`) + output_columns=["video", "rel_path"], + ) + dataloader = create_dataloader( + dataset, drop_remainder=False, device_num=device_num, rank_id=rank_id, **args.dataloader + ) + + # 3. TAE initiate and weight loading + logger.info("TAE init") + tae_args = args.tae.as_dict() + tae_dtype = tae_args.pop("dtype") + tae = TemporalAutoencoder(**tae_args).set_train(False) + if tae_dtype != "fp32": + # FIXME: remove AMP and add custom dtype conversion support for better compatibility with PyNative + amp.custom_mixed_precision( + tae, + black_list=amp.get_black_list() + [nn.GroupNorm, nn.AvgPool2d, nn.Upsample], + dtype=MODEL_DTYPE[tae_dtype], + ) + + # 4. 
print key info + key_info = "Key Settings:\n" + "=" * 50 + "\n" + key_info += "\n".join( + [ + f"MindSpore mode[GRAPH(0)/PYNATIVE(1)]: {mode}", + f"Debug mode: {args.env.debug}", + f"Num of batches: {dataloader.get_dataset_size()}", + f"TAE dtype: {tae_dtype}", + f"Image size: {args.data.size}", + f"Crop size: {args.data.crop_size}", + ] + ) + key_info += "\n" + "=" * 50 + logger.info(key_info) + + for samples in tqdm(dataloader.create_tuple_iterator(num_epochs=1), total=dataloader.get_dataset_size()): + z, _, _ = tae.encode(samples[0]) + z = to_numpy(z) + for latent, path in zip(z, samples[1].tolist()): + out_path = save_dir / path + out_path.parent.mkdir(parents=True, exist_ok=True) + np.save(out_path.with_suffix(".npy"), latent) + logger.info(f"Completed, Denoised latents saved in {save_dir}") + + +if __name__ == "__main__": + parser = ArgumentParser(description="TAE inference script.") + parser.add_function_arguments(init_train_env, "env") + tae_group = parser.add_argument_group("TAE parameters") + tae_group.add_class_arguments(TemporalAutoencoder, "tae", instantiate=False) + tae_group.add_argument( + "--tae.dtype", default="fp32", type=str, choices=["fp32", "fp16", "bf16"], help="TAE model precision." + ) + parser.add_class_arguments( + VideoDataset, + "data", + skip={"random_crop", "flip", "sample_n_frames", "return_image", "output_columns"}, + instantiate=False, + ) + parser.add_function_arguments( + create_dataloader, + "dataloader", + skip={ + "dataset", + "transforms", + "batch_transforms", + "project_columns", + "shuffle", + "num_workers", # no transformations inside `.map()` + "drop_remainder", + "device_num", + "rank_id", + "enable_modelarts", + }, + ) + parser.link_arguments("env.debug", "dataloader.debug", apply_on="parse") + parser.add_argument( + "--output_path", + default="output/", + type=path_type("dcc"), # path to a directory that can be created if it does not exist + help="Output directory to save training results.", + ) + cfg = parser.parse_args() + main(cfg) diff --git a/examples/moviegen/mg/dataset/dataset.py b/examples/moviegen/mg/dataset/dataset.py index 5f763e71d7..30001b9452 100644 --- a/examples/moviegen/mg/dataset/dataset.py +++ b/examples/moviegen/mg/dataset/dataset.py @@ -104,7 +104,7 @@ def _filter_data(sample_): return None return sample_ - with open(csv_path, "r") as csv_file: + with open(csv_path, "r", encoding="utf-8") as csv_file: try: data = [] for item in csv.DictReader(csv_file): @@ -118,7 +118,7 @@ def _filter_data(sample_): for name, path in text_emb_folder.items() } if vae_latent_folder: - sample["vae_latent"] = os.path.join(vae_latent_folder, Path(item["video"]).with_suffix(".npz")) + sample["vae_latent"] = os.path.join(vae_latent_folder, Path(item["video"]).with_suffix(".npy")) data.append(sample) except KeyError as e: _logger.error(f"CSV file requires `video` (file paths) column, but got {list(item.keys())}") diff --git a/examples/moviegen/mg/dataset/tae_dataset.py b/examples/moviegen/mg/dataset/tae_dataset.py index e09310460a..7be42dbd48 100644 --- a/examples/moviegen/mg/dataset/tae_dataset.py +++ b/examples/moviegen/mg/dataset/tae_dataset.py @@ -1,17 +1,19 @@ import copy import csv -import glob import logging import os import random +from pathlib import Path +from typing import Dict, List, Literal, Optional, Tuple, Union -import albumentations import cv2 import imageio import numpy as np from decord import VideoReader -import mindspore as ms +from mindone.data import BaseDataset + +__all__ = ["VideoDataset", "BatchTransform"] logger = 
logging.getLogger() @@ -20,21 +22,25 @@ def create_video_transforms( size=384, crop_size=256, interpolation="bicubic", backend="al", random_crop=False, flip=False, num_frames=None ): if backend == "al": + os.environ["NO_ALBUMENTATIONS_UPDATE"] = "1" # prevent albumentations from being annoying # expect rgb image in range 0-255, shape (h w c) - from albumentations import CenterCrop, HorizontalFlip, RandomCrop, SmallestMaxSize + from albumentations import CenterCrop, Compose, HorizontalFlip, RandomCrop, SmallestMaxSize + + if isinstance(crop_size, int): + crop_size = (crop_size, crop_size) # NOTE: to ensure augment all frames in a video in the same way. assert num_frames is not None, "num_frames must be parsed" - targets = {"image{}".format(i): "image" for i in range(num_frames)} + targets = {f"image{i}": "image" for i in range(num_frames)} mapping = {"bilinear": cv2.INTER_LINEAR, "bicubic": cv2.INTER_CUBIC} transforms = [ SmallestMaxSize(max_size=size, interpolation=mapping[interpolation]), - CenterCrop(crop_size, crop_size) if not random_crop else RandomCrop(crop_size, crop_size), + CenterCrop(*crop_size) if not random_crop else RandomCrop(*crop_size), ] if flip: transforms += [HorizontalFlip(p=0.5)] - pixel_transforms = albumentations.Compose( + pixel_transforms = Compose( transforms, additional_targets=targets, ) @@ -44,29 +50,40 @@ def create_video_transforms( return pixel_transforms -def get_video_path_list(folder): - # TODO: find recursively - fmts = ["avi", "mp4", "gif"] - out = [] - for fmt in fmts: - out += glob.glob(os.path.join(folder, f"*.{fmt}")) - return sorted(out) +def get_video_path_list(folder: str, video_column: str) -> List[Dict[str, str]]: + """ + Constructs a list of images and videos in the given directory (recursively). + + Args: + folder: path to a directory containing images and videos. + video_column: name of the column to store video paths. + Returns: + A list of paths to images and videos in the given directory (absolute and relative). 
+ """ + exts = (".jpg", ".jpeg", ".png", ".gif", ".mp4", ".avi") + data = [ + {video_column: str(item), "rel_path": str(item.relative_to(folder))} + for item in Path(folder).rglob("*") + if (item.is_file() and item.suffix.lower() in exts) + ] + return sorted(data, key=lambda x: x[video_column]) -class VideoDataset: +class VideoDataset(BaseDataset): def __init__( self, - csv_path=None, - data_folder=None, - size=384, - crop_size=256, - random_crop=False, - flip=False, - sample_stride=4, - sample_n_frames=16, - return_image=False, - transform_backend="al", - video_column="video", + csv_path: Optional[str], + folder: str, + size: int = 384, + crop_size: Union[int, Tuple[int, int]] = 256, + random_crop: bool = False, + flip: bool = False, + sample_stride: int = 1, + sample_n_frames: int = 16, + return_image: bool = False, + video_column: str = "video", + *, + output_columns: List[str], ): """ size: image resize size @@ -76,17 +93,18 @@ def __init__( if csv_path is not None: with open(csv_path, "r") as csvfile: - self.dataset = list(csv.DictReader(csvfile)) - self.read_from_csv = True + self.dataset = [ + {**item, video_column: os.path.join(folder, item[video_column]), "rel_path": item[video_column]} + for item in csv.DictReader(csvfile) + ] else: - self.dataset = get_video_path_list(data_folder) - self.read_from_csv = False + self.dataset = get_video_path_list(folder, video_column) self.length = len(self.dataset) logger.info(f"Num data samples: {self.length}") logger.info(f"sample_n_frames: {sample_n_frames}") - self.data_folder = data_folder + self.folder = folder self.sample_stride = sample_stride self.sample_n_frames = sample_n_frames self.return_image = return_image @@ -98,8 +116,8 @@ def __init__( flip=flip, num_frames=sample_n_frames, ) - self.transform_backend = transform_backend self.video_column = video_column + self.output_columns = output_columns # prepare replacement data max_attempts = 100 @@ -123,12 +141,8 @@ def get_replace_data(self, max_attempts=100): def get_batch(self, idx): # get video raw pixels (batch of frame) and its caption - if self.read_from_csv: - video_dict = self.dataset[idx] - video_fn = video_dict[list(video_dict.keys())[0]] - video_path = os.path.join(self.data_folder, video_fn) - else: - video_path = self.dataset[idx] + video_dict = self.dataset[idx].copy() + video_path = video_dict[self.video_column] video_reader = VideoReader(video_path) @@ -142,13 +156,13 @@ def get_batch(self, idx): batch_index = [random.randint(0, video_length - 1)] if video_path.endswith(".gif"): - pixel_values = video_reader[batch_index] # shape: (f, h, w, c) + video_dict[self.video_column] = video_reader[batch_index] # shape: (f, h, w, c) else: - pixel_values = video_reader.get_batch(batch_index).asnumpy() # shape: (f, h, w, c) + video_dict[self.video_column] = video_reader.get_batch(batch_index).asnumpy() # shape: (f, h, w, c) del video_reader - return pixel_values + return tuple(video_dict[c] for c in self.output_columns) def __len__(self): return self.length @@ -159,44 +173,47 @@ def __getitem__(self, idx): video: preprocessed video frames in shape (f, c, h, w), normalized to [-1, 1] """ try: - pixel_values = self.get_batch(idx) + data = self.get_batch(idx) if (self.prev_ok_sample is None) or (self.require_update_prev): - self.prev_ok_sample = copy.deepcopy(pixel_values) + self.prev_ok_sample = copy.deepcopy(data) self.require_update_prev = False except Exception as e: logger.warning(f"Fail to get sample of idx {idx}. 
The corrupted video will be replaced.") print("\tError msg: {}".format(e), flush=True) assert self.prev_ok_sample is not None - pixel_values = self.prev_ok_sample # unless the first sample is already not ok + data = self.prev_ok_sample # unless the first sample is already not ok self.require_update_prev = True if idx >= self.length: raise IndexError # needed for checking the end of dataset iteration + pixel_values = data[0] num_frames = len(pixel_values) # pixel value: (f, h, w, 3) -> transforms -> (f 3 h' w') - if self.transform_backend == "al": - # NOTE:it's to ensure augment all frames in a video in the same way. - # ref: https://albumentations.ai/docs/examples/example_multi_target/ + # NOTE:it's to ensure augment all frames in a video in the same way. + # ref: https://albumentations.ai/docs/examples/example_multi_target/ - inputs = {"image": pixel_values[0]} - for i in range(num_frames - 1): - inputs[f"image{i}"] = pixel_values[i + 1] + inputs = {"image": pixel_values[0]} + for i in range(num_frames - 1): + inputs[f"image{i}"] = pixel_values[i + 1] - output = self.pixel_transforms(**inputs) + output = self.pixel_transforms(**inputs) - pixel_values = np.stack(list(output.values()), axis=0) - # (t h w c) -> (c t h w) - pixel_values = np.transpose(pixel_values, (3, 0, 1, 2)) - else: - raise NotImplementedError + pixel_values = np.stack(list(output.values()), axis=0) + # (t h w c) -> (c t h w) + pixel_values = np.transpose(pixel_values, (3, 0, 1, 2)) if self.return_image: pixel_values = pixel_values[1] pixel_values = (pixel_values / 127.5 - 1.0).astype(np.float32) - return pixel_values + return pixel_values, *data[1:] + + @staticmethod + def train_transforms(**kwargs) -> List[dict]: + # train transforms are performed during data reading + pass # TODO: parse in config dict @@ -214,87 +231,48 @@ def check_sanity(x, save_fp="./tmp.gif"): class BatchTransform: - def __init__(self, mixed_strategy, mixed_image_ratio=0.2): - self.mixed_strategy = mixed_strategy + def __init__( + self, + mixed_strategy: Literal["mixed_video_image", "mixed_video_random", "image_only"], + mixed_image_ratio: float = 0.2, + ): + if mixed_strategy == "mixed_video_image": + self._trans_fn = self._mixed_video_image + elif mixed_strategy == "mixed_video_random": + self._trans_fn = self._mixed_video_random + elif mixed_strategy == "image_only": + self._trans_fn = self._image_only + else: + raise NotImplementedError(f"Unknown mixed_strategy: {mixed_strategy}") self.mixed_image_ratio = mixed_image_ratio - def __call__(self, x): - # x: (bs, c, t, h, w) - if self.mixed_strategy == "mixed_video_image": - if random.random() < self.mixed_image_ratio: - x = x[:, :, :1, :, :] - elif self.mixed_strategy == "mixed_video_random": - # TODO: somehow it's slow. consider do it with tensor in NetWithLoss - length = random.randint(1, x.shape[2]) - x = x[:, :, :length, :, :] - elif self.mixed_strategy == "image_only": + def _mixed_video_image(self, x: np.ndarray) -> np.ndarray: + if random.random() < self.mixed_image_ratio: x = x[:, :, :1, :, :] - else: - raise ValueError return x + @staticmethod + def _mixed_video_random(x: np.ndarray) -> np.ndarray: + # TODO: somehow it's slow. 
consider do it with tensor in NetWithLoss + length = random.randint(1, x.shape[2]) + return x[:, :, :length, :, :] -def create_dataloader( - ds_config, - batch_size, - mixed_strategy=None, - mixed_image_ratio=0.0, - num_parallel_workers=12, - max_rowsize=32, - shuffle=True, - device_num=1, - rank_id=0, - drop_remainder=True, -): - """ - Args: - mixed_strategy: - None - all output batches are videoes [bs, c, T, h, w] - mixed_video_image - with prob of mixed_image_ratio, output batch are images [b, c, 1, h, w] - mixed_video_random - output batch has a random number of frames [bs, c, t, h, w], t is the same of samples in a batch - mixed_image_ratio: - ds_config, dataset config, args for ImageDataset or VideoDataset - ds_name: dataset name, image or video - """ - dataset = VideoDataset(**ds_config) - print("Total number of samples: ", len(dataset)) - - # Larger value leads to more memory consumption. Default: 16 - # prefetch_size = config.get("prefetch_size", 16) - # ms.dataset.config.set_prefetch_size(prefetch_size) - - dataloader = ms.dataset.GeneratorDataset( - source=dataset, - column_names=["video"], - num_shards=device_num, - shard_id=rank_id, - python_multiprocessing=True, - shuffle=shuffle, - num_parallel_workers=num_parallel_workers, - max_rowsize=max_rowsize, - ) - - dl = dataloader.batch( - batch_size, - drop_remainder=drop_remainder, - ) - - if mixed_strategy is not None: - batch_map_fn = BatchTransform(mixed_strategy, mixed_image_ratio) - dl = dl.map( - operations=batch_map_fn, - input_columns=["video"], - num_parallel_workers=1, - ) + @staticmethod + def _image_only(x: np.ndarray) -> np.ndarray: + return x[:, :, :1, :, :] - return dl + def __call__(self, x): + # x: (bs, c, t, h, w) + return self._trans_fn(x) if __name__ == "__main__": + from mindone.data import create_dataloader + test = "dl" if test == "dataset": ds_config = dict( - data_folder="../videocomposer/datasets/webvid5", + folder="../videocomposer/datasets/webvid5", random_crop=True, flip=True, ) @@ -312,19 +290,16 @@ def create_dataloader( ds_config = dict( csv_path="../videocomposer/datasets/webvid5_copy.csv", - data_folder="../videocomposer/datasets/webvid5", + folder="../videocomposer/datasets/webvid5", sample_n_frames=17, size=128, crop_size=128, ) + ds = VideoDataset(**ds_config) + bt = BatchTransform(mixed_strategy="mixed_video_random", mixed_image_ratio=0.2) # test loader - dl = create_dataloader( - ds_config, - 4, - mixed_strategy="mixed_video_random", - mixed_image_ratio=0.2, - ) + dl = create_dataloader(ds, batch_size=4, batch_transforms={"operations": bt, "input_columns": ["video"]}) num_batches = dl.get_dataset_size() # ms.set_context(mode=0) diff --git a/examples/moviegen/mg/pipelines/infer_pipeline.py b/examples/moviegen/mg/pipelines/infer_pipeline.py index 9147bd0cc2..21541dbe25 100644 --- a/examples/moviegen/mg/pipelines/infer_pipeline.py +++ b/examples/moviegen/mg/pipelines/infer_pipeline.py @@ -53,7 +53,7 @@ def tae_decode_video(self, x, num_frames=None): """ x = mint.permute(x, (0, 2, 1, 3, 4)) # FIXME: remove this redundancy x = x / self.scale_factor + self.shift_factor - y = self.tae.decode(x, target_num_frames=num_frames) # FIXME: extract scale_factor from TAE and use it here + y = self.tae.decode(x, target_num_frames=num_frames) y = ops.clip_by_value((y + 1.0) / 2.0, clip_value_min=0.0, clip_value_max=1.0) # (b 3 t h w) -> (b t h w 3) y = mint.permute(y, (0, 2, 3, 4, 1)) diff --git a/examples/moviegen/scripts/stage2_train.sh b/examples/moviegen/scripts/stage2_train.sh index 
921f7c6023..5182f302d2 100644 --- a/examples/moviegen/scripts/stage2_train.sh +++ b/examples/moviegen/scripts/stage2_train.sh @@ -2,20 +2,22 @@ export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 # plot memory usage, feature/model: 1 export MS_MEMORY_STATISTIC=0 +# operation/graph fusion for dynamic shape +export MS_DEV_ENABLE_KERNEL_PACKET=on + # log level export GLOG_v=2 output_dir=output/stage2_t2iv_256x256/$(date +"%Y.%m.%d-%H.%M.%S") -msrun --bind_core=True --worker_num=8 --local_worker_num=8 --log_dir="$output_dir" \ +msrun --bind_core=True --master_port=8200 --worker_num=8 --local_worker_num=8 --log_dir="$output_dir" \ python train.py \ --config configs/train/stage2_t2iv_256x256.yaml \ --env.mode 0 \ - --env.jit_level O1 \ + --env.jit_level O0 \ --env.max_device_memory 59GB \ --env.distributed True \ - --model.model_parallelism True \ - --train.model_parallel.model_parallel_shards 8 \ + --train.settings.zero_stage 2 \ --dataset.csv_path CSV_PATH \ --dataset.video_folder VIDEO_FOLDER \ --dataset.text_emb_folder.ul2 UL2_FOLDER \ diff --git a/examples/moviegen/train.py b/examples/moviegen/train.py index 1596708567..00444cc237 100644 --- a/examples/moviegen/train.py +++ b/examples/moviegen/train.py @@ -6,10 +6,10 @@ from jsonargparse import ActionConfigFile, ArgumentParser from jsonargparse.typing import path_type +import mindspore.dataset as ds from mindspore import GRAPH_MODE, Model, Symbol, Tensor, amp from mindspore import dtype as mstype -from mindspore import get_context, nn, set_seed -from mindspore.dataset import BatchDataset, BucketBatchByLengthDataset +from mindspore import get_context, nn, set_context, set_seed # TODO: remove in future when mindone is ready for install __dir__ = os.path.dirname(os.path.abspath(__file__)) @@ -18,7 +18,6 @@ from mg.dataset import ImageVideoDataset, bucket_split_function from mg.models.tae import TemporalAutoencoder -from mg.models.tae.modules import SpatialDownsample, SpatialUpsample, TemporalDownsample, TemporalUpsample from mg.parallel import create_parallel_group from mg.pipelines import DiffusionWithLoss from mg.schedulers import RFlowEvalLoss, RFlowLossWrapper @@ -36,7 +35,7 @@ def initialize_dataset( dataset_args, dataloader_args, device_num: int, shard_rank_id: int -) -> Tuple[Union[BatchDataset, BucketBatchByLengthDataset], int]: +) -> Tuple[Union[ds.BatchDataset, ds.BucketBatchByLengthDataset], int]: dataset = ImageVideoDataset(**dataset_args) transforms = ( dataset.train_transforms(dataset_args.target_size) if not dataset_args.apply_transforms_dataset else None @@ -70,6 +69,10 @@ def main(args): device_id, rank_id, device_num = init_train_env(**args.env) mode = get_context("mode") # `init_train_env()` may change the mode during debugging + # if bucketing is used in Graph mode, activate dynamic mode + if mode == GRAPH_MODE and isinstance(args.dataloader.batch_size, dict): + set_context(graph_kernel_flags="--disable_packet_ops=Reshape") + # 1.1 init model parallel shard_rank_id = rank_id if (shards := args.train.model_parallel.model_parallel_shards) > 1: @@ -77,7 +80,10 @@ def main(args): device_num = device_num // shards shard_rank_id = rank_id // shards - set_seed(args.env.seed + shard_rank_id) # TODO: do it better + # FIXME: Improve seed setting + set_seed(args.env.seed + shard_rank_id) # set different seeds per NPU for sampling different timesteps + ds.set_seed(args.env.seed) # keep MS.dataset's seed consistent as datasets first shuffled and then distributed + set_logger("", output_dir=args.train.output_path, rank=rank_id) # 
instantiate classes only after initializing training environment @@ -94,8 +100,7 @@ def main(args): # FIXME: remove AMP and add custom dtype conversion support for better compatibility with PyNative amp.custom_mixed_precision( tae, - black_list=amp.get_black_list() - + [SpatialDownsample, SpatialUpsample, TemporalDownsample, TemporalUpsample, nn.GroupNorm], + black_list=amp.get_black_list() + [nn.GroupNorm, nn.AvgPool2d, nn.Upsample], dtype=MODEL_DTYPE[tae_dtype], ) @@ -174,6 +179,7 @@ def main(args): step_mode=True, use_step_unit=True, train_steps=args.train.steps, + resume_prefix_blacklist=("tae.", "swap."), **args.train.save, ), PerfRecorderCallback( @@ -247,7 +253,7 @@ def main(args): ImageVideoDataset, "dataset", skip={"frames_mask_generator", "t_compress_func"}, instantiate=False ) parser.add_function_arguments( - create_dataloader, "dataloader", skip={"dataset", "transforms", "device_num", "rank_id"} + create_dataloader, "dataloader", skip={"dataset", "transforms", "batch_transforms", "device_num", "rank_id"} ) parser.link_arguments("env.debug", "dataloader.debug", apply_on="parse") parser.add_function_arguments(create_parallel_group, "train.model_parallel") @@ -288,6 +294,7 @@ def main(args): "step_mode", "use_step_unit", "train_steps", + "resume_prefix_blacklist", }, instantiate=False, ) diff --git a/examples/moviegen/train_tae.py b/examples/moviegen/train_tae.py index fb97c6a1ba..00a03f113b 100644 --- a/examples/moviegen/train_tae.py +++ b/examples/moviegen/train_tae.py @@ -3,13 +3,11 @@ import shutil import sys import time -from typing import Tuple import yaml import mindspore as ms -from mindspore import Model, nn -from mindspore.communication.management import get_group_size, get_rank, init +from mindspore import Model, amp, nn from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell from mindspore.train.callback import TimeMonitor @@ -19,21 +17,21 @@ sys.path.insert(0, os.path.abspath(os.path.join(__dir__, ".."))) from args_train_tae import parse_args -from mg.dataset.tae_dataset import create_dataloader +from mg.dataset.tae_dataset import BatchTransform, VideoDataset +from mg.models.tae import TemporalAutoencoder from mg.models.tae.losses import GeneratorWithLoss from mg.models.tae.modules import SpatialDownsample, SpatialUpsample, TemporalDownsample, TemporalUpsample -from mg.models.tae.tae import TemporalAutoencoder +from mindone.data import create_dataloader from mindone.trainers.callback import EvalSaveCallback, OverflowMonitor, ProfilerCallback from mindone.trainers.checkpoint import CheckpointManager, resume_train_network from mindone.trainers.ema import EMA from mindone.trainers.lr_schedule import create_scheduler from mindone.trainers.optim import create_optimizer from mindone.trainers.train_step import TrainOneStepWrapper -from mindone.utils.amp import auto_mixed_precision +from mindone.utils import init_train_env from mindone.utils.logger import set_logger from mindone.utils.params import count_params -from mindone.utils.seed import set_random_seed os.environ["HCCL_CONNECT_TIMEOUT"] = "6000" os.environ["MS_ASCEND_CHECK_OVERFLOW_MODE"] = "INFNAN_MODE" @@ -54,105 +52,15 @@ def create_loss_scaler(loss_scaler_type, init_loss_scale, loss_scale_factor=2, s return loss_scaler -def init_env( - mode: int = ms.GRAPH_MODE, - seed: int = 42, - distributed: bool = False, - max_device_memory: str = None, - device_target: str = "Ascend", - parallel_mode: str = "data", - jit_level: str = "O2", - global_bf16: bool = False, - dynamic_shape: bool = False, - debug: bool = 
False, -) -> Tuple[int, int]: - """ - Initialize MindSpore environment. - - Args: - mode: MindSpore execution mode. Default is 0 (ms.GRAPH_MODE). - seed: The seed value for reproducibility. Default is 42. - distributed: Whether to enable distributed training. Default is False. - Returns: - A tuple containing the device ID, rank ID and number of devices. - """ - set_random_seed(seed) - - if debug and mode == ms.GRAPH_MODE: # force PyNative mode when debugging - logger.warning("Debug mode is on, switching execution mode to PyNative.") - mode = ms.PYNATIVE_MODE - - if max_device_memory is not None: - ms.set_context(max_device_memory=max_device_memory) - - # ms.set_context(mempool_block_size="55GB") - # ms.set_context(pynative_synchronize=True) - if distributed: - ms.set_context( - mode=mode, - device_target=device_target, - ) - if parallel_mode == "optim": - print("use optim parallel") - ms.set_auto_parallel_context( - parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL, - enable_parallel_optimizer=True, - ) - init() - device_num = get_group_size() - rank_id = get_rank() - else: - init() - device_num = get_group_size() - rank_id = get_rank() - logger.debug(f"rank_id: {rank_id}, device_num: {device_num}") - ms.reset_auto_parallel_context() - - ms.set_auto_parallel_context( - parallel_mode=ms.ParallelMode.DATA_PARALLEL, - gradients_mean=True, - device_num=device_num, - ) - - var_info = ["device_num", "rank_id", "device_num / 8", "rank_id / 8"] - var_value = [device_num, rank_id, int(device_num / 8), int(rank_id / 8)] - logger.info(dict(zip(var_info, var_value))) - - else: - device_num = 1 - rank_id = 0 - ms.set_context( - mode=mode, - device_target=device_target, - pynative_synchronize=debug, - ) - - if mode == 0: - ms.set_context(jit_config={"jit_level": jit_level}) - - if global_bf16: - # only effective in GE mode, i.e. jit_level: O2 - ms.set_context(ascend_config={"precision_mode": "allow_mix_precision_bf16"}) - - # if dynamic_shape: - # print("Dynamic shape mode enabled, repeat_interleave/split/chunk will be called from mint module") - # set_dynamic_mode(True) - - return rank_id, device_num - - def main(args): # 1. 
init - rank_id, device_num = init_env( - args.mode, + _, rank_id, device_num = init_train_env( + mode=args.mode, seed=args.seed, - distributed=args.use_parallel, + distributed=args.distributed, device_target=args.device_target, max_device_memory=args.max_device_memory, - parallel_mode=args.parallel_mode, jit_level=args.jit_level, - global_bf16=args.global_bf16, - dynamic_shape=(args.mixed_strategy == "mixed_video_random"), debug=args.debug, ) set_logger(name="", output_dir=args.output_path, rank=rank_id, log_level=eval(args.log_level)) @@ -165,23 +73,25 @@ def main(args): assert args.image_size[0] == args.image_size[1], "Currently only h==w is supported" image_size = args.image_size[0] - ds_config = dict( + dataset = VideoDataset( csv_path=args.csv_path, - data_folder=args.video_folder, - size=image_size, + folder=args.folder, + size=args.image_size, crop_size=args.crop_size, - sample_n_frames=args.num_frames, - sample_stride=args.frame_stride, + sample_n_frames=args.sample_n_frames, + sample_stride=args.sample_stride, video_column=args.video_column, random_crop=args.random_crop, flip=args.flip, + output_columns=["video"], ) + transform = BatchTransform(mixed_strategy=args.mixed_strategy, mixed_image_ratio=args.mixed_image_ratio) + transform = {"operations": transform, "input_columns": ["video"]} dataloader = create_dataloader( - ds_config, - args.batch_size, - mixed_strategy=args.mixed_strategy, - mixed_image_ratio=args.mixed_image_ratio, - num_parallel_workers=args.num_parallel_workers, + dataset=dataset, + batch_size=args.batch_size, + batch_transforms=transform, + num_workers=args.num_workers, max_rowsize=256, shuffle=True, device_num=device_num, @@ -193,26 +103,24 @@ def main(args): # 3. build models ae = TemporalAutoencoder( - pretrained=args.pretrained_model_path, + pretrained=args.pretrained, use_recompute=args.use_recompute, ) - if args.use_discriminator: - logging.error("Discriminator is not used or supported in OpenSora v1.2") - # mixed precision # TODO: set softmax, sigmoid computed in FP32. manually set inside network since they are ops, instead of layers whose precision will be set by AMP level. - if args.dtype in ["fp16", "bf16"]: + if args.dtype != "fp32": dtype = {"fp16": ms.float16, "bf16": ms.bfloat16}[args.dtype] # TODO: check ResizeNearest bf16 support for ms>2.3.1 - ae = auto_mixed_precision( + ae = amp.custom_mixed_precision( ae, - args.amp_level, - dtype, - custom_fp32_cells=[SpatialDownsample, SpatialUpsample, TemporalDownsample, TemporalUpsample] - if args.vae_keep_updown_fp32 - else [] + ([nn.GroupNorm] if args.vae_keep_gn_fp32 else []), - # custom_fp32_cells=[nn.GroupNorm, SpatialUpsample] if args.vae_keep_gn_fp32 else [SpatialUpsample], + black_list=amp.get_black_list() + + ( + [SpatialDownsample, SpatialUpsample, TemporalDownsample, TemporalUpsample] + if args.vae_keep_updown_fp32 + else [] + ([nn.GroupNorm] if args.vae_keep_gn_fp32 else []) + ), + dtype=dtype, ) # 4. 
build net with loss diff --git a/examples/opensora_hpcai/scripts/train.py b/examples/opensora_hpcai/scripts/train.py index af14e1a155..3246f38366 100644 --- a/examples/opensora_hpcai/scripts/train.py +++ b/examples/opensora_hpcai/scripts/train.py @@ -805,7 +805,7 @@ def main(args): log_interval=args.log_interval, start_epoch=start_epoch, model_name=model_name, - resume_prefix_blacklist=["vae.", "swap."], + resume_prefix_blacklist=("vae.", "swap."), record_lr=False, train_steps=args.train_steps, ) diff --git a/examples/svd/train.py b/examples/svd/train.py index c46dd2e7e5..0da71a749a 100644 --- a/examples/svd/train.py +++ b/examples/svd/train.py @@ -166,7 +166,7 @@ def main(args, initializer): parser.add_function_arguments( create_dataloader, "train.dataloader", - skip={"dataset", "transforms", "device_num", "rank_id", "debug", "enable_modelarts"}, + skip={"dataset", "transforms", "batch_transforms", "device_num", "rank_id", "debug", "enable_modelarts"}, ) parser.add_function_arguments(create_scheduler, "train.scheduler", skip={"steps_per_epoch", "num_epochs"}) parser.add_function_arguments(create_optimizer, "train.optimizer", skip={"params", "lr"}) diff --git a/examples/t2i_adapter/train_t2i_adapter_sd.py b/examples/t2i_adapter/train_t2i_adapter_sd.py index 1dffe8bb47..a831c7406c 100644 --- a/examples/t2i_adapter/train_t2i_adapter_sd.py +++ b/examples/t2i_adapter/train_t2i_adapter_sd.py @@ -142,7 +142,7 @@ def main(args, initializer): parser.add_function_arguments( create_dataloader, "train.dataloader", - skip={"dataset", "transforms", "device_num", "rank_id", "debug", "enable_modelarts"}, + skip={"dataset", "transforms", "batch_transforms", "device_num", "rank_id", "debug", "enable_modelarts"}, ) parser.add_function_arguments(build_optimizer, "train.optimizer", skip={"model"}) parser.add_class_arguments( diff --git a/mindone/data/loader.py b/mindone/data/loader.py index 6032141d30..8ff9fd197e 100644 --- a/mindone/data/loader.py +++ b/mindone/data/loader.py @@ -11,6 +11,7 @@ def create_dataloader( dataset: BaseDataset, batch_size: int = 1, transforms: Optional[Union[List[dict], dict]] = None, + batch_transforms: Optional[Union[List[dict], dict]] = None, project_columns: Optional[List[str]] = None, shuffle: bool = False, num_workers: int = 4, @@ -38,6 +39,8 @@ def create_dataloader( "input_columns": [List of columns to apply transforms to], # Optional "output_columns": [List of output columns] # Optional, only used if different from the `input columns` } + batch_transforms: Optional transformations to apply to the dataset. Identical to `transforms` but applied to + batches. project_columns: Optional list of output columns names from transformations. These names can be used for column selection or sorting in a specific order. shuffle: Whether to randomly sample data. Default is False. @@ -49,8 +52,9 @@ def create_dataloader( python_multiprocessing: Whether to use Python multiprocessing for data transformations. This option could be beneficial if the Python operation is computational heavy. Default is True. prefetch_size: The number of samples to prefetch (per device). Default is 16. - max_rowsize: Maximum size of row in MB that is used for shared memory allocation to copy data between processes. - This is only used if `python_multiprocessing` is set to `True`. Default is 64. + max_rowsize: (MindSpore 2.2 and lower only) Maximum size of row in MB that is used for shared memory allocation + to copy data between processes. This is only used if `python_multiprocessing` is set to `True`. 
+ Default is 64. device_num: The number of devices to distribute the dataset across. Default is 1. rank_id: The rank ID of the current device. Default is 0. debug: Whether to enable debug mode. Default is False. @@ -82,7 +86,7 @@ def create_dataloader( ) if transforms is not None: - if not isinstance(transforms, list): + if isinstance(transforms, dict): transforms = [transforms] for transform in transforms: @@ -109,5 +113,16 @@ def create_dataloader( dataloader = dataloader.batch( batch_size, drop_remainder=drop_remainder, num_parallel_workers=num_workers_batch ) + if batch_transforms is not None: + if isinstance(batch_transforms, dict): + batch_transforms = [batch_transforms] + + for batch_transform in batch_transforms: + dataloader = dataloader.map( + **batch_transform, + python_multiprocessing=python_multiprocessing, + num_parallel_workers=num_workers, + max_rowsize=max_rowsize if MS_VERSION < "2.3" else -1, + ) return dataloader diff --git a/mindone/trainers/callback.py b/mindone/trainers/callback.py index fe8e6d8c13..e92c279e94 100755 --- a/mindone/trainers/callback.py +++ b/mindone/trainers/callback.py @@ -1,7 +1,7 @@ import logging import os import time -from typing import List, Optional +from typing import List, Optional, Tuple, Union import mindspore as ms from mindspore.train.callback._callback import Callback, _handle_loss @@ -49,7 +49,7 @@ def __init__( model_name="sd", save_trainable_only: bool = False, param_save_filter: List[str] = None, - resume_prefix_blacklist: List[str] = None, + resume_prefix_blacklist: Optional[Union[str, Tuple[str, ...]]] = None, integrated_save=False, save_training_resume=True, train_steps=-1, @@ -59,7 +59,8 @@ def __init__( step_mode: if True, ckpt_save_interval is counted in steps. otherwise, in epochs. param_save_filter: indicates what parameters to save in checkpoint. If None, save all parameters in network. \ Otherwise, only params that contain one of the keyword in param_save_filter list will be saved. - resume_prefix_blacklist: exclude parameters with one of these prefixes to be saved in resume checkpoint. e.g. ['swap.', 'vae.']. + resume_prefix_blacklist: exclude parameters with one of these prefixes to be saved in resume checkpoint, + e.g. ('swap.', 'vae.'). 
""" self.rank_id = rank_id self.is_main_device = rank_id in [0, None] @@ -123,17 +124,11 @@ def __init__( self.use_step_unit = use_step_unit self.train_steps = train_steps self.save_training_resume = save_training_resume - if resume_prefix_blacklist is not None: - - def choice_func(x): - for prefix in resume_prefix_blacklist: - if x.startswith("vae."): - return False - return True - - self.choice_func = choice_func - else: - self.choice_func = None + self.choice_func = None + if resume_prefix_blacklist: + if isinstance(resume_prefix_blacklist, str): + resume_prefix_blacklist = (resume_prefix_blacklist,) + self.choice_func = lambda x: x.startswith(resume_prefix_blacklist) def on_train_step_end(self, run_context): cb_params = run_context.original_args() From 2b0673cbe78cde680a95a569577bd53dc47a42bd Mon Sep 17 00:00:00 2001 From: Rustam Khadipash <16683750+hadipash@users.noreply.github.com> Date: Mon, 9 Dec 2024 15:37:32 +0800 Subject: [PATCH 078/122] fix training with TAE latents --- examples/moviegen/inference_tae_enc.py | 12 +++--- examples/moviegen/mg/dataset/dataset.py | 50 ++++++++++------------ examples/moviegen/mg/utils/model_utils.py | 2 +- examples/moviegen/train.py | 51 +++++++++++++++-------- mindone/trainers/callback.py | 2 +- 5 files changed, 63 insertions(+), 54 deletions(-) diff --git a/examples/moviegen/inference_tae_enc.py b/examples/moviegen/inference_tae_enc.py index 94ebc99e70..7301c2ce1b 100644 --- a/examples/moviegen/inference_tae_enc.py +++ b/examples/moviegen/inference_tae_enc.py @@ -73,13 +73,15 @@ def main(args): logger.info(key_info) for samples in tqdm(dataloader.create_tuple_iterator(num_epochs=1), total=dataloader.get_dataset_size()): - z, _, _ = tae.encode(samples[0]) - z = to_numpy(z) - for latent, path in zip(z, samples[1].tolist()): + _, mean, logvar = tae.encode(samples[0]) + mean, logvar = to_numpy(mean), to_numpy(logvar) + std = np.exp(0.5 * np.clip(logvar, -30.0, 20.0)) + + for m, s, path in zip(mean, std, samples[1].tolist()): out_path = save_dir / path out_path.parent.mkdir(parents=True, exist_ok=True) - np.save(out_path.with_suffix(".npy"), latent) - logger.info(f"Completed, Denoised latents saved in {save_dir}") + np.savez(out_path.with_suffix(".npz"), latent_mean=m, latent_std=s) + logger.info(f"Completed. Latents saved in {save_dir}") if __name__ == "__main__": diff --git a/examples/moviegen/mg/dataset/dataset.py b/examples/moviegen/mg/dataset/dataset.py index 30001b9452..a4f33bdb19 100644 --- a/examples/moviegen/mg/dataset/dataset.py +++ b/examples/moviegen/mg/dataset/dataset.py @@ -29,9 +29,9 @@ def __init__( text_emb_folder: Optional[Union[str, Dict[str, str]]] = None, empty_text_emb: Optional[Union[str, Dict[str, str]]] = None, text_drop_prob: float = 0.2, - vae_latent_folder: Optional[str] = None, - vae_downsample_rate: float = 8.0, - vae_scale_factor: float = 0.18215, + tae_latent_folder: Optional[str] = None, + tae_scale_factor: float = 1.5305, + tae_shift_factor: float = 0.0609, target_size: Optional[Tuple[int, int]] = None, sample_n_frames: int = 17, sample_stride: int = 1, @@ -47,7 +47,7 @@ def __init__( "Text embedding during training is not supported, please provide `text_emb_folder`." 
) - self._data = self._read_data(video_folder, csv_path, text_emb_folder, vae_latent_folder, filter_data) + self._data = self._read_data(video_folder, csv_path, text_emb_folder, tae_latent_folder, filter_data) self._frames = sample_n_frames self._stride = sample_stride self._min_length = (self._frames - 1) * self._stride + 1 @@ -62,9 +62,9 @@ def __init__( assert os.path.exists(path), f"Empty text embedding not found: {path}" self._text_drop_prob = text_drop_prob - self._vae_latent_folder = vae_latent_folder - self._vae_downsample_rate = vae_downsample_rate - self._vae_scale_factor = vae_scale_factor + self._tae_latent_folder = tae_latent_folder + self._tae_scale_factor = tae_scale_factor + self._tae_shift_factor = tae_shift_factor self._fmask_gen = frames_mask_generator self._t_compress_func = t_compress_func or (lambda x: x) @@ -83,7 +83,7 @@ def _read_data( data_dir: str, csv_path: str, text_emb_folder: Optional[Union[str, Dict[str, str]]] = None, - vae_latent_folder: Optional[str] = None, + tae_latent_folder: Optional[str] = None, filter_data: bool = False, ) -> List[dict]: def _filter_data(sample_): @@ -99,8 +99,8 @@ def _filter_data(sample_): if not os.path.isfile(sample_["text_emb"][name]): _logger.warning(f"Text embedding not found: {sample_['text_emb'][name]}") return None - if "vae_latent" in sample_ and not os.path.isfile(sample_["vae_latent"]): - _logger.warning(f"Text embedding not found: {sample_['vae_latent']}") + if "tae_latent" in sample_ and not os.path.isfile(sample_["tae_latent"]): + _logger.warning(f"Text embedding not found: {sample_['tae_latent']}") return None return sample_ @@ -117,8 +117,8 @@ def _filter_data(sample_): name: os.path.join(path, Path(item["video"]).with_suffix(".npz")) for name, path in text_emb_folder.items() } - if vae_latent_folder: - sample["vae_latent"] = os.path.join(vae_latent_folder, Path(item["video"]).with_suffix(".npy")) + if tae_latent_folder: + sample["tae_latent"] = os.path.join(tae_latent_folder, Path(item["video"]).with_suffix(".npz")) data.append(sample) except KeyError as e: _logger.error(f"CSV file requires `video` (file paths) column, but got {list(item.keys())}") @@ -162,27 +162,19 @@ def _get_item(self, idx: int, thw: Optional[Tuple[int, int, int]] = None) -> Tup with np.load(path) as td: data.update({enc_name + "_caption": td["text_emb"], enc_name + "_mask": td["mask"]}) - if self._vae_latent_folder: - # TODO: add support for images - vae_latent_data = np.load(data["vae_latent"]) - latent_mean, latent_std = vae_latent_data["latent_mean"], vae_latent_data["latent_std"] - if len(latent_mean) < self._min_length: + if self._tae_latent_folder: + tae_latent_data = np.load(data["tae_latent"]) + latent_mean, latent_std = tae_latent_data["latent_mean"], tae_latent_data["latent_std"] + if len(latent_mean) < self._min_length: # TODO: add support for images and buckets raise ValueError(f"Video is too short: {data['video']}") - if "fps" not in data: - if "fps" in vae_latent_data: - data["fps"] = vae_latent_data["fps"] - else: - with VideoReader(data["video"]) as reader: - data["fps"] = reader.fps - data["fps"] = np.array(data["fps"] / self._stride, dtype=np.float32) - start_pos = random.randint(0, len(latent_mean) - self._min_length) batch_index = np.linspace(start_pos, start_pos + self._min_length - 1, num_frames, dtype=int) latent_mean, latent_std = latent_mean[batch_index], latent_std[batch_index] - vae_latent = latent_mean + latent_std * np.random.standard_normal(latent_mean.shape) - data["video"] = vae_latent * self._vae_scale_factor 
+ tae_latent = np.random.normal(latent_mean, latent_std).astype(np.float32) + tae_latent = (tae_latent - self._tae_shift_factor) * self._tae_scale_factor + data["video"] = np.transpose(tae_latent, (1, 0, 2, 3)) # FIXME: remove unnecessary transpose else: if data["video"].lower().endswith(IMAGE_EXT): @@ -204,7 +196,7 @@ def _get_item(self, idx: int, thw: Optional[Tuple[int, int, int]] = None) -> Tup data["num_frames"] = np.array(num_frames, dtype=np.float32) if self._fmask_gen is not None: - # return frames mask with respect to the VAE's latent temporal compression + # return frames mask with respect to the TAE's latent temporal compression data["frames_mask"] = self._fmask_gen(self._t_compress_func(num_frames)) if self._transforms: @@ -249,7 +241,7 @@ def train_transforms( tokenizer: Optional[Callable[[str], np.ndarray]] = None, ) -> List[dict]: transforms = [] - if not self._vae_latent_folder: + if not self._tae_latent_folder: transforms.append( { "operations": [ diff --git a/examples/moviegen/mg/utils/model_utils.py b/examples/moviegen/mg/utils/model_utils.py index 91ef90d0fc..a6ca1572b6 100644 --- a/examples/moviegen/mg/utils/model_utils.py +++ b/examples/moviegen/mg/utils/model_utils.py @@ -59,7 +59,7 @@ def __exit__(self, *args): def init_model( name: Literal["llama-1B", "llama-5B", "llama-30B"], - in_channels: int = 4, + in_channels: int = 16, pretrained_model_path: Optional[Path_fr] = None, enable_flash_attention: bool = True, model_parallelism: bool = False, diff --git a/examples/moviegen/train.py b/examples/moviegen/train.py index 00444cc237..2beb1c5cc1 100644 --- a/examples/moviegen/train.py +++ b/examples/moviegen/train.py @@ -91,27 +91,40 @@ def main(args): # 2. model initialize and weight loading # 2.1 TAE - logger.info("TAE init") - # TODO: add support of training with latents - tae_args = args.tae.as_dict() - tae_dtype = tae_args.pop("dtype") - tae = TemporalAutoencoder(**tae_args).set_train(False) - if tae_dtype != "fp32": - # FIXME: remove AMP and add custom dtype conversion support for better compatibility with PyNative - amp.custom_mixed_precision( - tae, - black_list=amp.get_black_list() + [nn.GroupNorm, nn.AvgPool2d, nn.Upsample], - dtype=MODEL_DTYPE[tae_dtype], - ) + if not args.dataset.tae_latent_folder or ( + args.valid.dataset and not args.valid.dataset.init_args.tae_latent_folder + ): + logger.info("TAE init") + tae_args = args.tae.as_dict() + tae_dtype = tae_args.pop("dtype") + tae = TemporalAutoencoder(**tae_args).set_train(False) + if tae_dtype != "fp32": + # FIXME: remove AMP and add custom dtype conversion support for better compatibility with PyNative + amp.custom_mixed_precision( + tae, + black_list=amp.get_black_list() + [nn.GroupNorm, nn.AvgPool2d, nn.Upsample], + dtype=MODEL_DTYPE[tae_dtype], + ) + if args.model.in_channels != tae.out_channels: + logger.warning( + f"The number of model input channels ({args.model.in_channels}) doesn't match the number of TAE output" + f" channels ({tae.out_channels}). Setting it to {tae.out_channels}." + ) + args.model.in_channels = tae.out_channels + else: + logger.info("TAE latent folder found. Skipping TAE initialization.") + tae = None # 2.2 Llama 3 logger.info("Transformer init") - network = init_model(in_channels=tae.out_channels, **args.model) + network = init_model(**args.model) # 2.3 LossWrapper rflow_loss_wrapper = RFlowLossWrapper(network) # 3. 
build training network - latent_diffusion_with_loss = DiffusionWithLoss(rflow_loss_wrapper, tae) + latent_diffusion_with_loss = DiffusionWithLoss( + rflow_loss_wrapper, tae, video_emb_cached=bool(args.dataset.tae_latent_folder) + ) # 4. build train & val datasets dataloader, dataset_len = initialize_dataset(args.dataset, args.dataloader, device_num, shard_rank_id) @@ -122,7 +135,9 @@ def main(args): args.valid.dataset.init_args, args.valid.dataloader, device_num, shard_rank_id ) eval_rflow_loss = RFlowEvalLoss(rflow_loss_wrapper, num_sampling_steps=args.valid.sampling_steps) - eval_diffusion_with_loss = DiffusionWithLoss(eval_rflow_loss, tae) + eval_diffusion_with_loss = DiffusionWithLoss( + eval_rflow_loss, tae, video_emb_cached=bool(args.valid.dataset.init_args.tae_latent_folder) + ) # 5. build training utils: lr, optim, callbacks, trainer # 5.1 LR @@ -142,7 +157,7 @@ def main(args): # if bucketing is used in Graph mode, activate dynamic inputs if mode == GRAPH_MODE and isinstance(args.dataloader.batch_size, dict): bs = Symbol(unique=True) - video = Tensor(shape=[bs, None, 3, None, None], dtype=mstype.float32) + video = Tensor(shape=[bs, None, args.model.in_channels, None, None], dtype=mstype.float32) # FIXME: fix sequence length ul2_emb = Tensor(shape=[bs, 300, 4096], dtype=mstype.float32) byt5_emb = Tensor(shape=[bs, 100, 1472], dtype=mstype.float32) @@ -192,7 +207,7 @@ def main(args): # 5.5 print out key info and save config if rank_id == 0: - num_params_tae, num_params_trainable_tae = count_params(tae) + num_params_tae, num_params_trainable_tae = count_params(tae) if tae is not None else (0, 0) num_params_network, num_params_trainable_network = count_params(network) num_params = num_params_tae + num_params_network num_params_trainable = num_params_trainable_tae + num_params_trainable_network @@ -244,7 +259,7 @@ def main(args): help="Path to load a config yaml file that describes the setting which will override the default arguments.", ) parser.add_function_arguments(init_train_env, "env") - parser.add_function_arguments(init_model, "model", skip={"in_channels"}) + parser.add_function_arguments(init_model, "model") parser.add_class_arguments(TemporalAutoencoder, "tae", instantiate=False) parser.add_argument( "--tae.dtype", default="fp32", type=str, choices=["fp32", "fp16", "bf16"], help="TAE model precision." 
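
For context on the train-with-cached-latents flow added in this patch (inference_tae_enc.py now saves per-frame latent_mean/latent_std arrays to .npz, and the dataset re-samples a latent from those statistics instead of encoding video on the fly), here is a minimal, self-contained sketch of the sampling step. It is an illustration only, not part of the patch: the .npz layout is assumed to match what inference_tae_enc.py writes, the scale/shift constants mirror the ImageVideoDataset defaults above, frame striding and bucketing are ignored, and the function name and example path are hypothetical.

import numpy as np

# Defaults mirroring ImageVideoDataset(tae_scale_factor=1.5305, tae_shift_factor=0.0609) above.
TAE_SCALE_FACTOR = 1.5305
TAE_SHIFT_FACTOR = 0.0609


def sample_cached_tae_latent(npz_path: str, num_frames: int, rng: np.random.Generator) -> np.ndarray:
    """Draw a training latent from cached (mean, std) statistics saved by the TAE encoder."""
    with np.load(npz_path) as cache:
        mean, std = cache["latent_mean"], cache["latent_std"]  # assumed shape: (t, c, h, w)
    if len(mean) < num_frames:
        raise ValueError(f"Cached latent is too short: {npz_path}")

    # Pick a random contiguous clip (the actual dataset additionally supports a frame stride).
    start = int(rng.integers(0, len(mean) - num_frames + 1))
    mean, std = mean[start : start + num_frames], std[start : start + num_frames]

    # Reparameterized sample z ~ N(mean, std), then shift/scale as the diffusion model expects.
    z = rng.normal(mean, std).astype(np.float32)
    z = (z - TAE_SHIFT_FACTOR) * TAE_SCALE_FACTOR

    # (t, c, h, w) -> (c, t, h, w), matching the layout returned by the dataset above.
    return np.transpose(z, (1, 0, 2, 3))


# Hypothetical usage:
# latent = sample_cached_tae_latent("tae_latents/video_0001.npz", num_frames=17, rng=np.random.default_rng(0))
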
diff --git a/mindone/trainers/callback.py b/mindone/trainers/callback.py index e92c279e94..14b76811a7 100755 --- a/mindone/trainers/callback.py +++ b/mindone/trainers/callback.py @@ -128,7 +128,7 @@ def __init__( if resume_prefix_blacklist: if isinstance(resume_prefix_blacklist, str): resume_prefix_blacklist = (resume_prefix_blacklist,) - self.choice_func = lambda x: x.startswith(resume_prefix_blacklist) + self.choice_func = lambda x: not x.startswith(resume_prefix_blacklist) def on_train_step_end(self, run_context): cb_params = run_context.original_args() From d763d6bd42af3501c70dddc8c3026786e4af90c1 Mon Sep 17 00:00:00 2001 From: Rustam Khadipash <16683750+hadipash@users.noreply.github.com> Date: Tue, 10 Dec 2024 18:15:00 +0800 Subject: [PATCH 079/122] revert changes to OpenSora --- examples/README.md | 3 - .../animatediff/ad/models/diffusion/ddpm.py | 5 +- .../opensora-v1-1/train/train_stage2.yaml | 18 +-- .../opensora-v1-1/train/train_stage3.yaml | 18 +-- .../opensora/acceleration/parallel_states.py | 42 ++----- .../opensora/datasets/transforms.py | 110 +++++++++++++--- .../datasets/video_dataset_refactored.py | 119 +++++++++--------- .../opensora/models/stdit/__init__.py | 1 - .../opensora/models/stdit/stdit_llama3.py | 74 ----------- .../opensora/models/text_encoder/t5.py | 15 +-- .../opensora_hpcai/opensora/models/vae/vae.py | 58 ++++----- .../opensora/pipelines/train_pipeline.py | 4 - .../opensora/schedulers/rectified_flow.py | 2 - examples/opensora_hpcai/scripts/args_train.py | 20 +-- examples/opensora_hpcai/scripts/infer_t5.py | 56 ++------- examples/opensora_hpcai/scripts/inference.py | 95 ++++---------- examples/opensora_hpcai/scripts/train.py | 61 ++------- 17 files changed, 262 insertions(+), 439 deletions(-) delete mode 100644 examples/opensora_hpcai/opensora/models/stdit/stdit_llama3.py diff --git a/examples/README.md b/examples/README.md index 838b30d2ae..a9a238a856 100644 --- a/examples/README.md +++ b/examples/README.md @@ -21,6 +21,3 @@ | [llava](https://github.com/mindspore-lab/mindone/blob/master/examples/llava) | Haotian-Liu official | https://github.com/haotian-liu/LLaVA | [vila](https://github.com/mindspore-lab/mindone/blob/master/examples/vila) | Nvidia Lab official | https://github.com/NVlabs/VILA | [pllava](https://github.com/mindspore-lab/mindone/blob/master/examples/pllava) | Magic Research official | https://github.com/magic-research/PLLaVA -| [dynamicrafter](https://github.com/mindspore-lab/mindone/blob/master/examples/dynamicrafter) | Tencent Research official | https://github.com/Doubiiu/DynamiCrafter -| [hunyuan_dit](https://github.com/mindspore-lab/mindone/blob/master/examples/hunyuan_dit) | Tencent Research official | https://github.com/Tencent/HunyuanDiT -| [pixart_sigma](https://github.com/mindspore-lab/mindone/blob/master/examples/pixart_sigma) | Noah Lab official | https://github.com/PixArt-alpha/PixArt-sigma \ No newline at end of file diff --git a/examples/animatediff/ad/models/diffusion/ddpm.py b/examples/animatediff/ad/models/diffusion/ddpm.py index ae23a9d6ca..8010e347d4 100644 --- a/examples/animatediff/ad/models/diffusion/ddpm.py +++ b/examples/animatediff/ad/models/diffusion/ddpm.py @@ -365,10 +365,7 @@ def construct(self, x: ms.Tensor, text_tokens: ms.Tensor, control=None, **kwargs """ # 1. get image/video latents z using vae - if self.emb_cache: - z = x - else: - z = self.get_latents(x) + z = self.get_latents(x) # 2. 
sample timestep and add noise to latents t = self.uniform_int( diff --git a/examples/opensora_hpcai/configs/opensora-v1-1/train/train_stage2.yaml b/examples/opensora_hpcai/configs/opensora-v1-1/train/train_stage2.yaml index 093c520c38..661b76b627 100644 --- a/examples/opensora_hpcai/configs/opensora-v1-1/train/train_stage2.yaml +++ b/examples/opensora_hpcai/configs/opensora-v1-1/train/train_stage2.yaml @@ -43,15 +43,15 @@ epochs: 2000 ckpt_save_interval: 100 mask_ratios: - identity: 0.75 - quarter_random: 0.025 - quarter_head: 0.025 - quarter_tail: 0.025 - quarter_head_tail: 0.05 - image_random: 0.025 - image_head: 0.025 - image_tail: 0.025 - image_head_tail: 0.05 + mask_no: 0.75 + mask_quarter_random: 0.025 + mask_quarter_head: 0.025 + mask_quarter_tail: 0.025 + mask_quarter_head_tail: 0.05 + mask_image_random: 0.025 + mask_image_head: 0.025 + mask_image_tail: 0.025 + mask_image_head_tail: 0.05 bucket_config: # Structure: "resolution": { num_frames: [ keep_prob, batch_size ] } diff --git a/examples/opensora_hpcai/configs/opensora-v1-1/train/train_stage3.yaml b/examples/opensora_hpcai/configs/opensora-v1-1/train/train_stage3.yaml index dd085233ce..8463e37a51 100644 --- a/examples/opensora_hpcai/configs/opensora-v1-1/train/train_stage3.yaml +++ b/examples/opensora_hpcai/configs/opensora-v1-1/train/train_stage3.yaml @@ -43,15 +43,15 @@ epochs: 2000 ckpt_save_interval: 100 mask_ratios: - identity: 0.75 - quarter_random: 0.025 - quarter_head: 0.025 - quarter_tail: 0.025 - quarter_head_tail: 0.05 - image_random: 0.025 - image_head: 0.025 - image_tail: 0.025 - image_head_tail: 0.05 + mask_no: 0.75 + mask_quarter_random: 0.025 + mask_quarter_head: 0.025 + mask_quarter_tail: 0.025 + mask_quarter_head_tail: 0.05 + mask_image_random: 0.025 + mask_image_head: 0.025 + mask_image_tail: 0.025 + mask_image_head_tail: 0.05 bucket_config: # Structure: "resolution": { num_frames: [ keep_prob, batch_size ] } diff --git a/examples/opensora_hpcai/opensora/acceleration/parallel_states.py b/examples/opensora_hpcai/opensora/acceleration/parallel_states.py index b7e5f1e7f3..c60b9c3932 100644 --- a/examples/opensora_hpcai/opensora/acceleration/parallel_states.py +++ b/examples/opensora_hpcai/opensora/acceleration/parallel_states.py @@ -13,41 +13,23 @@ def get_sequence_parallel_group() -> Optional[str]: return _GLOBAL_PARALLEL_GROUPS.get("sequence", None) -def set_model_parallel_group(group: str) -> None: - _GLOBAL_PARALLEL_GROUPS["model"] = group - - -def get_model_parallel_group() -> Optional[str]: - return _GLOBAL_PARALLEL_GROUPS.get("model", None) - - -def create_parallel_group(sequence_parallel_shards: int = 1, model_parallel_shards: int = 1) -> None: - if sequence_parallel_shards <= 1 and model_parallel_shards <= 1: +def create_parallel_group(sequence_parallel_shards: int) -> None: + if sequence_parallel_shards <= 1: raise ValueError( - f"`sequence_parallel_shards`/`model_parallel_shards` must be larger than 1 " - f"to enable sequence/model parallel, but get `{sequence_parallel_shards}` and `{model_parallel_shards}`." + f"`sequence_parallel_shards` must be larger than 1 to enable sequence parallel, but get `{sequence_parallel_shards}`." 
) device_num = get_group_size() - if device_num % sequence_parallel_shards != 0 or device_num % model_parallel_shards != 0: + if device_num % sequence_parallel_shards != 0: raise ValueError( - f"Total number of devices ({device_num}) must be divisible by the number of " - f"sequence parallel shards ({sequence_parallel_shards}) and model parallel shards ({model_parallel_shards})." + f"Total number of devices ({device_num}) must be devisible by the number of sequence parallel shards ({sequence_parallel_shards})." ) rank_id = get_rank() - - if sequence_parallel_shards > 1: - sp_group_id = rank_id // sequence_parallel_shards - sp_group_rank_ids = list( - range(sp_group_id * sequence_parallel_shards, (sp_group_id + 1) * sequence_parallel_shards) - ) - sp_group_name = f"sp_group_{sp_group_id}" - create_group(sp_group_name, sp_group_rank_ids) - set_sequence_parallel_group(sp_group_name) - elif model_parallel_shards > 1: # not compatible with SP currently - mp_group_id = rank_id // model_parallel_shards - mp_group_rank_ids = list(range(mp_group_id * model_parallel_shards, (mp_group_id + 1) * model_parallel_shards)) - mp_group_name = f"mp_group_{mp_group_id}" - create_group(mp_group_name, mp_group_rank_ids) - set_model_parallel_group(mp_group_name) + sp_group_id = rank_id // sequence_parallel_shards + sp_group_rank_ids = list( + range(sp_group_id * sequence_parallel_shards, (sp_group_id + 1) * sequence_parallel_shards) + ) + sp_group_name = f"sp_group_{sp_group_id}" + create_group(sp_group_name, sp_group_rank_ids) + set_sequence_parallel_group(sp_group_name) diff --git a/examples/opensora_hpcai/opensora/datasets/transforms.py b/examples/opensora_hpcai/opensora/datasets/transforms.py index d69a01f2f2..c37fc37435 100644 --- a/examples/opensora_hpcai/opensora/datasets/transforms.py +++ b/examples/opensora_hpcai/opensora/datasets/transforms.py @@ -1,24 +1,102 @@ -from typing import Optional, Tuple +from typing import Tuple import cv2 import numpy as np +from mindspore.dataset.transforms import Compose +from mindspore.dataset.vision import CenterCrop, Inter +from mindspore.dataset.vision import Resize as MSResize -class ResizeCrop: - def __init__(self, size: Optional[Tuple[int, int]] = None, interpolation=cv2.INTER_LINEAR): - self._size = size +from .bucket import Bucket + + +class Resize: + def __init__(self, size: Tuple[int, int], interpolation=Inter.BILINEAR): + self._h, self._w = size self._inter = interpolation - def __call__(self, x: np.ndarray, size: Optional[Tuple[int, int]] = None) -> np.ndarray: - h, w = x.shape[-3:-1] # support images and videos - th, tw = size or self._size - scale = max(th / h, tw / w) - if scale != 1: # resize - if x.ndim == 3: # if image - x = cv2.resize(x, None, fx=scale, fy=scale, interpolation=self._inter) - else: # if video - x = np.array([cv2.resize(i, None, fx=scale, fy=scale, interpolation=self._inter) for i in x]) - if x.shape[-3:-1] != (th, tw): # crop - i, j = round((x.shape[-3] - th) / 2.0), round((x.shape[-2] - tw) / 2.0) - x = x[..., i : i + th, j : j + tw, :] + def __call__(self, x: np.ndarray) -> np.ndarray: + img_h, img_w = x.shape[-3:-1] # support images and videos + scale = max(self._h / img_h, self._w / img_w) + if scale != 1: + x = MSResize((round(img_h * scale), round(img_w * scale)), self._inter)(x) return x + + +class BucketResizeCrop: + def __init__(self, buckets: Bucket): + self._transforms = {} # is this reasonable? 
There are 350+ buckets + for name, lengths in buckets.ar_criteria.items(): + self._transforms[name] = {} + for length, ars in lengths.items(): + self._transforms[name][str(length)] = {} + for ar, hw in ars.items(): + self._transforms[name][str(length)][ar] = Compose( + [MSResize(min(hw), interpolation=Inter.BILINEAR), CenterCrop(hw)] + ) + + def __call__(self, x, bucket_id): + return self._transforms[bucket_id[0]][bucket_id[1]][bucket_id[2]](x) + + +class ResizeAndCrop: + """Resize an RGB image to a target size while preserving the aspect ratio and cropping it. + Align to resize_crop_to_fill in torch. Ensure no black surrounding produced. + """ + + def __init__(self, target_height, target_width): + super(ResizeAndCrop, self).__init__() + self.tar_h = target_height + self.tar_w = target_width + + def __call__(self, img): + # Ensure the image is in RGB format + if img.shape[2] != 3: + raise ValueError("Input image must be in RGB format with 3 channels.") + + h, w = img.shape[:2] + th, tw = self.tar_h, self.tar_w # target + rh, rw = th / h, tw / w # ratio + + if rh > rw: + # target image is thinner than the original image + new_h, new_w = th, round(w * rh) + start_y = 0 + start_x = int(round(new_w - tw) / 2.0) + else: + # target image is fatter than the original image + new_h, new_w = round(h * rw), tw + start_y = int(round(new_h - th) / 2.0) + start_x = 0 + + if rh > rw: + new_h, new_w = th, round(w * rh) + start_y = 0 + start_x = int(round(new_w - tw)) + + # Resize the image + # NOTE: for opensora v1.2, HD videos are mainly downsampled according to buckets. The best choice for down-sample interpolation is INTER_AREA. + resized_img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA) + + # Crop the image to the target size + cropped_img = resized_img[start_y : start_y + self.tar_h, start_x : start_x + self.tar_w] + + return cropped_img + + +class BucketResizeAndCrop(object): + """According to bucket config, resize an RGB image to a target size while preserving the aspect ratio and cropping it.""" + + def __init__(self, buckets): + super().__init__() + self._transforms = {} # is this reasonable? 
There are 350+ buckets + for name, lengths in buckets.ar_criteria.items(): + self._transforms[name] = {} + for length, ars in lengths.items(): + self._transforms[name][str(length)] = {} + for ar, hw in ars.items(): + self._transforms[name][str(length)][ar] = ResizeAndCrop(hw[0], hw[1]) + + def __call__(self, image, bucket_id=None): + resized_img = self._transforms[bucket_id[0]][str(bucket_id[1])][bucket_id[2]](image) + return resized_img diff --git a/examples/opensora_hpcai/opensora/datasets/video_dataset_refactored.py b/examples/opensora_hpcai/opensora/datasets/video_dataset_refactored.py index d35653dc0e..8c91f917c0 100644 --- a/examples/opensora_hpcai/opensora/datasets/video_dataset_refactored.py +++ b/examples/opensora_hpcai/opensora/datasets/video_dataset_refactored.py @@ -6,7 +6,7 @@ import sys from concurrent.futures import ThreadPoolExecutor from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, List, Optional, Tuple import cv2 import numpy as np @@ -14,11 +14,12 @@ import mindspore as ms from mindspore.dataset.transforms import Compose +from mindspore.dataset.vision import CenterCrop, Inter, Normalize from mindone.data.video_reader import VideoReader as VideoReader_CV2 from .bucket import Bucket -from .transforms import ResizeCrop +from .transforms import BucketResizeAndCrop, BucketResizeCrop, Resize, ResizeAndCrop # FIXME: remove in future when mindone is ready for install sys.path.append(os.path.join(os.path.dirname(__file__), "../../../..")) @@ -29,26 +30,39 @@ _logger = logging.getLogger(__name__) -IMAGE_EXT = (".jpg", ".jpeg", ".png", ".gif", ".webp") - -def create_infer_transforms(target_size: Tuple[int, int], interpolation=cv2.INTER_LINEAR): +def create_infer_transforms(target_size: Tuple[int, int], interpolation=Inter.BILINEAR): return Compose( [ - ResizeCrop(target_size, interpolation=interpolation), - lambda x: x.astype(np.float32) / 127.5 - 1, + Resize(target_size, interpolation=interpolation), + CenterCrop(target_size), + lambda x: (x / 255.0).astype(np.float32), # ms.ToTensor() doesn't support 4D data + Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), lambda x: x[None, ...] 
if x.ndim == 3 else x, # if image - lambda x: np.transpose(x, (0, 3, 1, 2)), + lambda x: np.transpose(x, (0, 3, 1, 2)), # ms.HWC2CHW() doesn't support 4D data ] ) +def create_train_transforms(target_size, buckets=None): + """ + expect rgb image in range 0-255, shape (h w c) + """ + + if buckets is None: + transforms = ResizeAndCrop(target_size[0], target_size[1]) + else: + transforms = BucketResizeAndCrop(buckets) + + return transforms + + class VideoDatasetRefactored(BaseDataset): def __init__( self, csv_path: str, video_folder: str, - text_emb_folder: Optional[Union[str, Dict[str, str]]] = None, + text_emb_folder: Optional[str] = None, vae_latent_folder: Optional[str] = None, vae_downsample_rate: float = 8.0, vae_scale_factor: float = 0.18215, @@ -129,7 +143,7 @@ def __init__( self.apply_train_transforms = apply_train_transforms if self.apply_train_transforms: - self.pixel_transforms = ResizeCrop(target_size, interpolation=cv2.INTER_AREA) + self.pixel_transforms = create_train_transforms(target_size, buckets=buckets) if "bucket_id" in self.output_columns: self.output_columns.remove("bucket_id") assert not pre_patchify, "transforms for prepatchify not implemented yet" @@ -142,7 +156,7 @@ def __init__( def _read_data( data_dir: str, csv_path: str, - text_emb_folder: Optional[Union[str, Dict[str, str]]] = None, + text_emb_folder: Optional[str] = None, vae_latent_folder: Optional[str] = None, filter_data: bool = False, ) -> List[dict]: @@ -164,13 +178,7 @@ def _filter_data(sample_): for item in csv.DictReader(csv_file): sample = {**item, "video": os.path.join(data_dir, item["video"])} if text_emb_folder: - if isinstance(text_emb_folder, str): - sample["text_emb"] = os.path.join(text_emb_folder, Path(item["video"]).with_suffix(".npz")) - else: - sample["text_emb"] = { - name: os.path.join(path, Path(item["video"]).with_suffix(".npz")) - for name, path in text_emb_folder.items() - } + sample["text_emb"] = os.path.join(text_emb_folder, Path(item["video"]).with_suffix(".npz")) if vae_latent_folder: sample["vae_latent"] = os.path.join(vae_latent_folder, Path(item["video"]).with_suffix(".npz")) data.append(sample) @@ -209,15 +217,9 @@ def _get_item(self, idx: int) -> Tuple[Any, ...]: num_frames = self._frames if self._text_emb_folder: - if isinstance(self._text_emb_folder, str): - with np.load(text_emb_path) as td: - data["caption"] = td["text_emb"] - data["mask"] = td["mask"].astype(np.uint8) - else: - for enc_name, path in text_emb_path.items(): - with np.load(path) as td: - data[enc_name + "_caption"] = td["text_emb"] - data[enc_name + "_mask"] = td["mask"].astype(np.uint8) + with np.load(text_emb_path) as td: + data["caption"] = td["text_emb"] + data["mask"] = td["mask"].astype(np.uint8) if self._vae_latent_folder: # pick a resolution randomly if there are multi-resolution latents in vae folder @@ -286,33 +288,28 @@ def _get_item(self, idx: int) -> Tuple[Any, ...]: ) # / self._stride # FIXME: OS v1.1 incorrect del reader elif self.video_backend == "cv2": - if video_path.lower().endswith(IMAGE_EXT): - num_frames = 1 - data["fps"] = np.array(120, dtype=np.float32) # FIXME: extract as IMG_FPS - video = cv2.cvtColor(cv2.imread(data["video"]), cv2.COLOR_BGR2RGB) - else: - with VideoReader_CV2(video_path) as reader: - min_length = self._min_length - if self._buckets: - data["bucket_id"] = self._buckets.get_bucket_id( - T=len(reader), - H=reader.shape[1], - W=reader.shape[0], - frame_interval=self._stride, + with VideoReader_CV2(video_path) as reader: + min_length = self._min_length + if 
self._buckets: + data["bucket_id"] = self._buckets.get_bucket_id( + T=len(reader), + H=reader.shape[1], + W=reader.shape[0], + frame_interval=self._stride, + ) + if data["bucket_id"] is None: + raise ValueError( + f"Couldn't assign a bucket to {data['video']}" + f" (T={len(reader)}, H={reader.shape[1]}, W={reader.shape[0]})." ) - if data["bucket_id"] is None: - raise ValueError( - f"Couldn't assign a bucket to {data['video']}" - f" (T={len(reader)}, H={reader.shape[1]}, W={reader.shape[0]})." - ) - num_frames, *_ = self._buckets.get_thw(data["bucket_id"]) - min_length = (num_frames - 1) * self._stride + 1 - - if len(reader) < min_length: - raise ValueError(f"Video is too short: {video_path}") - start_pos = random.randint(0, len(reader) - min_length) - video = reader.fetch_frames(num=num_frames, start_pos=start_pos, step=self._stride) - data["fps"] = np.array(reader.fps, dtype=np.float32) + num_frames, *_ = self._buckets.get_thw(data["bucket_id"]) + min_length = (num_frames - 1) * self._stride + 1 + + if len(reader) < min_length: + raise ValueError(f"Video is too short: {video_path}") + start_pos = random.randint(0, len(reader) - min_length) + video = reader.fetch_frames(num=num_frames, start_pos=start_pos, step=self._stride) + data["fps"] = np.array(reader.fps, dtype=np.float32) else: # TODO: add pyav backend and test raise NotImplementedError @@ -328,9 +325,14 @@ def _get_item(self, idx: int) -> Tuple[Any, ...]: # apply transforms on video frames here if self.apply_train_transforms: # variable resize and crop, frame-wise - clip = self.pixel_transforms(video) - if clip.ndim == 3: - clip = np.expand_dims(clip, 0) + clip = [] + for i in range(num_frames): + if self._buckets: + resized_img = self.pixel_transforms(video[i], bucket_id=data["bucket_id"]) + else: + resized_img = self.pixel_transforms(video[i]) + clip.append(resized_img) + clip = np.stack(clip, axis=0) # transpose and norm, clip-wise clip = np.transpose(clip, (0, 3, 1, 2)) @@ -410,7 +412,7 @@ def train_transforms( transforms.extend( [ { - "operations": ResizeCrop(interpolation=cv2.INTER_AREA), + "operations": BucketResizeCrop(self._buckets), "input_columns": ["video", "bucket_id"], "output_columns": ["video"], # drop `bucket_id` column }, @@ -430,7 +432,8 @@ def train_transforms( transforms.append( { "operations": [ - ResizeCrop(target_size, interpolation=cv2.INTER_AREA), + Resize(target_size, interpolation=Inter.BILINEAR), + CenterCrop(target_size), lambda x: np.divide(x, 127.5, dtype=np.float32), lambda x: np.subtract(x, 1.0, dtype=np.float32), lambda x: np.transpose(x, (0, 3, 1, 2)), diff --git a/examples/opensora_hpcai/opensora/models/stdit/__init__.py b/examples/opensora_hpcai/opensora/models/stdit/__init__.py index dc2c63cb06..7957e9000f 100644 --- a/examples/opensora_hpcai/opensora/models/stdit/__init__.py +++ b/examples/opensora_hpcai/opensora/models/stdit/__init__.py @@ -1,4 +1,3 @@ from .stdit import STDiT_XL_2 from .stdit2 import STDiT2_XL_2 from .stdit3 import STDiT3_3B_2, STDiT3_XL_2 -from .stdit_llama3 import STDiTLlama3Wrapper diff --git a/examples/opensora_hpcai/opensora/models/stdit/stdit_llama3.py b/examples/opensora_hpcai/opensora/models/stdit/stdit_llama3.py deleted file mode 100644 index a2e9f559ee..0000000000 --- a/examples/opensora_hpcai/opensora/models/stdit/stdit_llama3.py +++ /dev/null @@ -1,74 +0,0 @@ -import os -from typing import Literal, Optional - -from moviegen import llama3_1B, llama3_5B, llama3_30B - -import mindspore.nn as nn -import mindspore.ops as ops -from mindspore import Tensor, 
load_checkpoint, load_param_into_net - - -class STDiTLlama3Wrapper(nn.Cell): - def __init__(self, model_size: Literal["1B", "5B", "30B"] = "1B", **kwargs): - super().__init__(auto_prefix=False) - - attn_implementation = "flash_attention" if kwargs.get("enable_flashattn", False) else "eager" - gradient_checkpointing = kwargs.get("use_recompute", False) - model_parallelism = kwargs.get("enable_model_parallelism", False) - - model_kwargs = dict( - in_channels=4, - out_channels=8, - attn_implementation=attn_implementation, - gradient_checkpointing=gradient_checkpointing, - model_parallelism=model_parallelism, - ) - - if model_size == "1B": - self.llama = llama3_1B(**model_kwargs) - elif model_size == "5B": - self.llama = llama3_5B(**model_kwargs) - else: - self.llama = llama3_30B(**model_kwargs) - - self.patch_size = self.llama.patch_size - self.hidden_size = self.llama.hidden_size - self.num_heads = self.llama.num_attention_heads - self.input_sq_size = None - self.in_channels = self.llama.in_channels - - def construct( - self, - x: Tensor, - timestep: Tensor, - y: Tensor, - mask: Optional[Tensor] = None, - frames_mask: Optional[Tensor] = None, - fps: Optional[Tensor] = None, - height: Optional[Tensor] = None, - width: Optional[Tensor] = None, - extra_text_embed1: Optional[Tensor] = None, - extra_mask1: Optional[Tensor] = None, - **kwargs, - ) -> Tensor: - x = ops.transpose(x, (0, 2, 1, 3, 4)) - - ul2_emb = ops.squeeze(y, axis=1) - metaclip_emb = ops.ones((extra_text_embed1.shape[0], 100, 1280), dtype=extra_text_embed1.dtype) - byt5_emb = extra_text_embed1 - - latent_embedding = x - output = self.llama(latent_embedding, timestep, ul2_emb, metaclip_emb, byt5_emb) - output = ops.transpose(output, (0, 2, 1, 3, 4)) - return output - - def load_from_checkpoint(self, ckpt_path): - if not os.path.exists(ckpt_path): - print(f"WARNING: {ckpt_path} not found. 
No checkpoint loaded!!") - else: - sd = load_checkpoint(ckpt_path) - sd = {k.replace("network.llama.", "").replace("_backbone.", ""): v for k, v in sd.items()} - - m, u = load_param_into_net(self, sd, strict_load=True) - print("net param not load: ", m, len(m)) - print("ckpt param not load: ", u, len(u)) diff --git a/examples/opensora_hpcai/opensora/models/text_encoder/t5.py b/examples/opensora_hpcai/opensora/models/text_encoder/t5.py index 29609a98a6..6c79b68352 100644 --- a/examples/opensora_hpcai/opensora/models/text_encoder/t5.py +++ b/examples/opensora_hpcai/opensora/models/text_encoder/t5.py @@ -3,7 +3,6 @@ import logging import os import re -import sys import urllib.parse as ul import ftfy @@ -15,10 +14,6 @@ from .flan_t5_large.t5 import get_t5_encoder -# FIXME: remove in future when mindone is ready for install -sys.path.append(os.path.join(os.path.dirname(__file__), "../../..")) -from mindone.transformers import T5EncoderModel - logger = logging.getLogger(__name__) @@ -233,15 +228,7 @@ def get_text_encoder_and_tokenizer(name, ckpt_path, **kwargs): logger.info("T5 init") text_encoder = T5Embedder(cache_dir=ckpt_path, pretrained_ckpt=os.path.join(ckpt_path, "model.ckpt"), **kwargs) tokenizer = text_encoder.tokenizer - elif name.lower() == "ul2": - logger.info("UL2 init") - tokenizer = AutoTokenizer.from_pretrained("google/ul2", local_files_only=True, cache_dir=ckpt_path) - text_encoder = T5EncoderModel.from_pretrained("google/ul2", local_files_only=True, cache_dir=ckpt_path) - elif name.lower() == "byt5": - logger.info("ByT5 init") - tokenizer = AutoTokenizer.from_pretrained("google/byt5-small", local_files_only=True, cache_dir=ckpt_path) - text_encoder = T5EncoderModel.from_pretrained("google/byt5-small", local_files_only=True, cache_dir=ckpt_path) else: - raise NotImplementedError(f"Unknown text encoder: {name}") + raise NotImplementedError return text_encoder, tokenizer diff --git a/examples/opensora_hpcai/opensora/models/vae/vae.py b/examples/opensora_hpcai/opensora/models/vae/vae.py index 982cf8dd1b..676b75aa7f 100644 --- a/examples/opensora_hpcai/opensora/models/vae/vae.py +++ b/examples/opensora_hpcai/opensora/models/vae/vae.py @@ -19,7 +19,7 @@ _logger = logging.getLogger(__name__) SD_CONFIG = { "double_z": True, - "z_channels": 4, # TODO: set 16 + "z_channels": 4, "resolution": 256, "in_channels": 3, "out_ch": 3, @@ -33,6 +33,34 @@ SDXL_CONFIG.update({"resolution": 512}) +class AutoencoderKL(AutoencoderKL_SD): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.split = get_split_op() + + def init_from_ckpt(self, path, ignore_keys=list()): + if not os.path.exists(path): + raise ValueError( + "Maybe download failed. Please download the VAE encoder from https://huggingface.co/stabilityai/sd-vae-ft-ema" + ) + param_dict = ms.load_checkpoint(path) + param_not_load, ckpt_not_load = ms.load_param_into_net(self, param_dict, strict_load=True) + if param_not_load or ckpt_not_load: + _logger.warning( + f"{param_not_load} in network is not loaded or {ckpt_not_load} in checkpoint is not loaded!" 
+ ) + + def encode_with_moments_output(self, x): + """For latent caching usage""" + h = self.encoder(x) + moments = self.quant_conv(h) + mean, logvar = self.split(moments, moments.shape[1] // 2, 1) + logvar = ops.clip_by_value(logvar, -30.0, 20.0) + std = self.exp(0.5 * logvar) + + return mean, std + + class VideoAutoencoderKL(nn.Cell): """ Spatial VAE @@ -455,31 +483,3 @@ def OpenSoraVAE_V1_2( pu, cu = ms.load_param_into_net(model.spatial_vae, sd, strict_load=False) return model - - -class AutoencoderKL(AutoencoderKL_SD): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.split = get_split_op() - - def init_from_ckpt(self, path, ignore_keys=list()): - if not os.path.exists(path): - raise ValueError( - "Maybe download failed. Please download the VAE encoder from https://huggingface.co/stabilityai/sd-vae-ft-ema" - ) - param_dict = ms.load_checkpoint(path) - param_not_load, ckpt_not_load = ms.load_param_into_net(self, param_dict, strict_load=True) - if param_not_load or ckpt_not_load: - _logger.warning( - f"{param_not_load} in network is not loaded or {ckpt_not_load} in checkpoint is not loaded!" - ) - - def encode_with_moments_output(self, x): - """For latent caching usage""" - h = self.encoder(x) - moments = self.quant_conv(h) - mean, logvar = self.split(moments, moments.shape[1] // 2, 1) - logvar = ops.clip_by_value(logvar, -30.0, 20.0) - std = self.exp(0.5 * logvar) - - return mean, std diff --git a/examples/opensora_hpcai/opensora/pipelines/train_pipeline.py b/examples/opensora_hpcai/opensora/pipelines/train_pipeline.py index 6b0fe50680..b49f025afa 100644 --- a/examples/opensora_hpcai/opensora/pipelines/train_pipeline.py +++ b/examples/opensora_hpcai/opensora/pipelines/train_pipeline.py @@ -125,8 +125,6 @@ def construct( width: Optional[Tensor] = None, fps: Optional[Tensor] = None, ar: Optional[Tensor] = None, - extra_text_tokens1: Optional[Tensor] = None, - extra_mask1: Optional[Tensor] = None, ): """ Video diffusion model forward and loss computation for training @@ -168,8 +166,6 @@ def construct( width=width, fps=fps, ar=ar, - extra_text_embed1=extra_text_tokens1, - extra_mask1=extra_mask1, ) return loss diff --git a/examples/opensora_hpcai/opensora/schedulers/rectified_flow.py b/examples/opensora_hpcai/opensora/schedulers/rectified_flow.py index 25522c0b4f..8a481cde8c 100644 --- a/examples/opensora_hpcai/opensora/schedulers/rectified_flow.py +++ b/examples/opensora_hpcai/opensora/schedulers/rectified_flow.py @@ -87,8 +87,6 @@ def __call__( noise_added = mask_t_upper pred = model(z, t, **model_kwargs) - # FIXME: a tmp solution for inference with cfg==1.0 - pred = pred[:, :4] # update z dt = timesteps[i] - timesteps[i + 1] if i < len(timesteps) - 1 else timesteps[i] diff --git a/examples/opensora_hpcai/scripts/args_train.py b/examples/opensora_hpcai/scripts/args_train.py index 1dd144848d..dc32ad427c 100644 --- a/examples/opensora_hpcai/scripts/args_train.py +++ b/examples/opensora_hpcai/scripts/args_train.py @@ -41,8 +41,6 @@ def parse_train_args(parser): ) parser.add_argument("--video_folder", required=True, type=str, help="root dir for the video data") parser.add_argument("--text_embed_folder", type=str, help="root dir for the text embedding data") - parser.add_argument("--ul2_text_embed_folder", type=str, help="root dir for the text embedding data") - parser.add_argument("--byt5_text_embed_folder", type=str, help="root dir for the text embedding data") parser.add_argument("--vae_latent_folder", type=str, help="root dir for the vae latent data") 
parser.add_argument("--filter_data", default=False, type=str2bool, help="Filter non-existing videos.") parser.add_argument("--output_path", default="output/", type=str, help="output directory to save training results") @@ -51,11 +49,7 @@ def parse_train_args(parser): ) # model parser.add_argument( - "--model_version", - default="v1", - type=str, - choices=["v1", "v1.1", "v1.2", "llama3_1b", "llama3_5b"], - help="OpenSora model version.", + "--model_version", default="v1", type=str, choices=["v1", "v1.1", "v1.2"], help="OpenSora model version." ) parser.add_argument( "--pretrained_model_path", @@ -336,18 +330,6 @@ def parse_train_args(parser): type=int, help="The number of shards in sequence parallel. Default is 1.", ) - parser.add_argument( - "--enable_model_parallelism", - default=False, - type=str2bool, - help="whether to enable model parallelism. Default is False. Only for LLama3 strcture,", - ) - parser.add_argument( - "--model_parallel_shards", - default=1, - type=int, - help="The number of shards in model parallel. Default is 1.", - ) parser.add_argument("--drop_overflow_update", default=True, type=str2bool, help="drop overflow update") parser.add_argument("--loss_scaler_type", default="dynamic", type=str, help="dynamic or static") parser.add_argument( diff --git a/examples/opensora_hpcai/scripts/infer_t5.py b/examples/opensora_hpcai/scripts/infer_t5.py index 974474e16d..7a04f5a4de 100644 --- a/examples/opensora_hpcai/scripts/infer_t5.py +++ b/examples/opensora_hpcai/scripts/infer_t5.py @@ -22,7 +22,6 @@ from opensora.utils.cond_data import read_captions_from_csv, read_captions_from_txt from opensora.utils.model_utils import str2bool # _check_cfgs_in_parser -from mindone.transformers.models.t5.modeling_t5 import T5LayerNorm from mindone.utils.amp import auto_mixed_precision from mindone.utils.logger import set_logger from mindone.utils.misc import to_abspath @@ -128,19 +127,15 @@ def main(args): logger.info(f"Num batches: {dataset_size}") # model initiate and weight loading - ckpt_path = args.model_dir - text_encoder, tokenizer = get_text_encoder_and_tokenizer( - args.model, ckpt_path, model_max_length=args.model_max_length - ) + ckpt_path = args.t5_model_dir + text_encoder, tokenizer = get_text_encoder_and_tokenizer("t5", ckpt_path, model_max_length=args.model_max_length) text_encoder.set_train(False) for param in text_encoder.get_parameters(): # freeze latte_model param.requires_grad = False dtype_map = {"fp16": ms.float16, "bf16": ms.bfloat16} if args.dtype in ["fp16", "bf16"]: - text_encoder = auto_mixed_precision( - text_encoder, amp_level=args.amp_level, custom_fp32_cells=[T5LayerNorm], dtype=dtype_map[args.dtype] - ) + text_encoder = auto_mixed_precision(text_encoder, amp_level=args.amp_level, dtype=dtype_map[args.dtype]) # infer if args.csv_path is not None: @@ -160,22 +155,8 @@ def main(args): captions = [str(captions[i]) for i in range(len(captions))] # print(captions) - if args.model.lower() == "t5": - text_tokens, mask = text_encoder.get_text_tokens_and_mask(captions, return_tensor=True) - text_emb = text_encoder(text_tokens, mask) - else: - text_tokens_and_mask = tokenizer( - captions, - max_length=args.model_max_length, - padding="max_length", - truncation=True, - return_attention_mask=True, - add_special_tokens=True, - return_tensors="np", - ) - text_tokens = ms.Tensor(text_tokens_and_mask["input_ids"], dtype=ms.int32) - mask = ms.Tensor(text_tokens_and_mask["attention_mask"], dtype=ms.float32) - text_emb = text_encoder(input_ids=text_tokens, attention_mask=mask)[0] + 
text_tokens, mask = text_encoder.get_text_tokens_and_mask(captions, return_tensor=True) + text_emb = text_encoder(text_tokens, mask) end_time = time.time() time_cost = end_time - start_time @@ -218,22 +199,8 @@ def main(args): batch_prompts = captions[i : i + args.batch_size] ns = len(batch_prompts) - if args.model.lower() == "t5": - batch_text_tokens, batch_mask = text_encoder.get_text_tokens_and_mask(batch_prompts, return_tensor=True) - batch_text_emb = text_encoder(batch_text_tokens, batch_mask) - else: - text_tokens_and_mask = tokenizer( - batch_prompts, - max_length=args.model_max_length, - padding="max_length", - truncation=True, - return_attention_mask=True, - add_special_tokens=True, - return_tensors="np", - ) - batch_text_tokens = ms.Tensor(text_tokens_and_mask["input_ids"], dtype=ms.int32) - batch_mask = ms.Tensor(text_tokens_and_mask["attention_mask"], dtype=ms.float32) - batch_text_emb = text_encoder(input_ids=batch_text_tokens, attention_mask=batch_mask)[0] + batch_text_tokens, batch_mask = text_encoder.get_text_tokens_and_mask(batch_prompts, return_tensor=True) + batch_text_emb = text_encoder(batch_text_tokens, batch_mask) # save result batch_mask = batch_mask.asnumpy().astype(np.uint8) @@ -278,9 +245,8 @@ def parse_args(): help="output dir to save the embeddings, if None, will treat the parent dir of csv_path as output dir.", ) parser.add_argument("--caption_column", type=str, default="caption", help="caption column num in csv") - parser.add_argument("--model", default="t5", type=str, choices=["t5", "ul2", "byt5"], help="Name of the model.") - parser.add_argument("--model_dir", type=str, help="the T5 cache folder path") - parser.add_argument("--model_max_length", type=int, default=120, help="Model's embedded sequence length.") + parser.add_argument("--t5_model_dir", default="models/t5-v1_1-xxl", type=str, help="the T5 cache folder path") + parser.add_argument("--model_max_length", type=int, default=120, help="T5's embedded sequence length.") # MS new args parser.add_argument("--device_target", type=str, default="Ascend", help="Ascend or GPU") parser.add_argument("--mode", type=int, default=0, help="Running in GRAPH_MODE(0) or PYNATIVE_MODE(1) (default=0)") @@ -338,7 +304,7 @@ def parse_args(): parser.set_defaults( **dict( captions=cfg["captions"], - model_dir=cfg["model_dir"], + t5_model_dir=cfg["t5_model_dir"], ) ) args = parser.parse_args() @@ -346,7 +312,7 @@ def parse_args(): args.csv_path = to_abspath(abs_path, args.csv_path) args.prompt_path = to_abspath(abs_path, args.prompt_path) args.output_path = to_abspath(abs_path, args.output_path) - args.model_dir = to_abspath(abs_path, args.model_dir) + args.t5_model_dir = to_abspath(abs_path, args.t5_model_dir) return args diff --git a/examples/opensora_hpcai/scripts/inference.py b/examples/opensora_hpcai/scripts/inference.py index f607d48d50..3e0defe0a1 100644 --- a/examples/opensora_hpcai/scripts/inference.py +++ b/examples/opensora_hpcai/scripts/inference.py @@ -17,11 +17,10 @@ mindone_lib_path = os.path.abspath(os.path.join(__dir__, "../../../")) sys.path.insert(0, mindone_lib_path) sys.path.insert(0, os.path.abspath(os.path.join(__dir__, ".."))) -sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../moviegen"))) from opensora.acceleration.parallel_states import set_sequence_parallel_group from opensora.datasets.aspect import ASPECT_RATIO_MAP, ASPECT_RATIOS, get_image_size, get_num_frames -from opensora.models.stdit import STDiT2_XL_2, STDiT3_XL_2, STDiT_XL_2, STDiTLlama3Wrapper +from opensora.models.stdit 
import STDiT2_XL_2, STDiT3_XL_2, STDiT_XL_2 from opensora.models.text_encoder.t5 import get_text_encoder_and_tokenizer from opensora.models.vae.vae import SD_CONFIG, OpenSoraVAE_V1_2, VideoAutoencoderKL from opensora.pipelines import InferPipeline, InferPipelineFiTLike @@ -164,24 +163,16 @@ def main(args): latent_condition_frame_length = round(latent_condition_frame_length / 17 * 5) captions = process_prompts(captions, args.loop) # in v1.1 and above, each loop can have a different caption - start_idx, end_idx = 0, len(captions) - if args.text_embed_folder: - end_idx = len(glob.glob(os.path.join(args.text_embed_folder, "*.npz"))) - elif args.ul2_text_embed_folder: - end_idx = len(glob.glob(os.path.join(args.ul2_text_embed_folder, "*.npz"))) if not args.enable_sequence_parallelism: # split samples to NPUs as even as possible - start_idx, end_idx = distribute_samples(end_idx, rank_id, device_num) - if args.reference_path is not None: - args.reference_path = args.reference_path[start_idx:end_idx] - if args.mask_strategy is not None: - args.mask_strategy = args.mask_strategy[start_idx:end_idx] + start_idx, end_idx = distribute_samples(len(captions), rank_id, device_num) + captions = captions[start_idx:end_idx] base_data_idx = start_idx else: base_data_idx = 0 if args.use_parallel and not args.enable_sequence_parallelism: - print(f"Num captions for rank {rank_id}: {end_idx - start_idx}") + print(f"Num captions for rank {rank_id}: {len(captions)}") # 2. model initiate and weight loading # 2.1 vae @@ -262,12 +253,6 @@ def main(args): model_extra_args["qk_norm"] = True logger.info(f"{model_name} init") latte_model = STDiT3_XL_2(**model_extra_args) - elif args.model_version == "llama3_1b": - model_name = "Llama3-1B" - latte_model = STDiTLlama3Wrapper(model_size="1B", **model_extra_args) - elif args.model_version == "llama3_5b": - model_name = "Llama3-5B" - latte_model = STDiTLlama3Wrapper(model_size="5B", **model_extra_args) else: raise ValueError(f"Unknown model version: {args.model_version}") @@ -295,13 +280,10 @@ def main(args): logger.warning(f"{model_name} uses random initialization!") # 2.3 text encoder - if not args.text_embed_folder and not (args.ul2_text_embed_folder and args.byt5_text_embed_folder): - if args.model_version in ["llama3_1b", "llama3_5b"]: - raise ValueError("UL2 and ByT5 text embedding folders are required for MovieGen.") + if args.text_embed_folder is None: text_encoder, tokenizer = get_text_encoder_and_tokenizer( "t5", args.t5_model_dir, model_max_length=args.model_max_length ) - captions = captions[start_idx:end_idx] num_prompts = len(captions) text_tokens, mask = zip( *[text_encoder.get_text_tokens_and_mask(caption, return_tensor=False) for caption in captions] @@ -319,44 +301,28 @@ def main(args): ) logger.info(f"Num tokens: {mask.asnumpy().sum(2)}") else: + assert not args.use_parallel, "parallel inference is not supported for t5 cached sampling currently." if args.model_version != "v1": logger.warning("For embedded captions, only one prompt per video is supported at this moment.") - extra_embed_paths1 = None - if args.text_embed_folder: - assert args.model_version not in [ - "llama3_1b", - "llama3_5b", - ], "UL2 and ByT5 text embedding folders are required for MovieGen." 
- main_embed_paths = sorted(glob.glob(os.path.join(args.text_embed_folder, "*.npz")))[start_idx:end_idx] - elif args.ul2_text_embed_folder and args.byt5_text_embed_folder: - main_embed_paths = sorted(glob.glob(os.path.join(args.ul2_text_embed_folder, "*.npz")))[start_idx:end_idx] - extra_embed_paths1 = sorted(glob.glob(os.path.join(args.byt5_text_embed_folder, "*.npz")))[ - start_idx:end_idx - ] - else: - raise NotImplementedError("T5 or UL2 and ByT5 text embedding should be provided.") - - def read_embeddings(embed_paths): - prefix = [] - _mask, _text_emb = [], [] - for fp in embed_paths: - prefix.append(os.path.basename(fp)[:-4]) - with np.load(fp) as dat: - _mask.append(dat["mask"]) - _text_emb.append(dat["text_emb"]) - return ( - ms.Tensor(np.concatenate(_mask), dtype=ms.uint8), - ms.Tensor(np.concatenate(_text_emb), dtype=ms.float32), - prefix, - ) - - mask, text_emb, prompt_prefix = read_embeddings(main_embed_paths) - extra_mask1, extra_text_emb1, _ = ( - read_embeddings(extra_embed_paths1) if extra_embed_paths1 else (None, None, None) - ) + embed_paths = sorted(glob.glob(os.path.join(args.text_embed_folder, "*.npz"))) + prompt_prefix = [] + text_tokens, mask, text_emb = [], [], [] + for fp in embed_paths: + prompt_prefix.append(os.path.basename(fp)[:-4]) + dat = np.load(fp) + text_tokens.append(dat["tokens"]) + mask.append(dat["mask"]) + text_emb.append(dat["text_emb"]) + text_tokens = np.concatenate(text_tokens) + mask = np.concatenate(mask) + text_emb = np.concatenate(text_emb) logger.info(f"Num tokens: {mask.sum(1)}") + num_prompts = text_emb.shape[0] + text_tokens = ms.Tensor(text_tokens) + mask = ms.Tensor(mask, dtype=ms.uint8) + text_emb = ms.Tensor(text_emb, dtype=ms.float32) text_encoder = None if (args.model_version == "v1" or args.reference_path is None) and num_prompts < 1: @@ -491,9 +457,6 @@ def read_embeddings(embed_paths): inputs["text_tokens"] = None inputs["text_emb"] = text_emb[i : i + ns] inputs["mask"] = mask[i : i + ns] - if extra_text_emb1 is not None: - model_args["extra_text_embed1"] = extra_text_emb1[i : i + ns] - model_args["extra_mask1"] = extra_mask1[i : i + ns] logger.info("Sampling captions:") for j in range(ns): @@ -526,13 +489,13 @@ def read_embeddings(embed_paths): # save result for j in range(ns): - if not args.text_embed_folder and not (args.ul2_text_embed_folder and args.byt5_text_embed_folder): - global_idx = base_data_idx + i + j + global_idx = base_data_idx + i + j + if args.text_embed_folder is None: prompt = "-".join((batch_prompts[j][0].replace("/", "").split(" ")[:10])) save_fp = f"{save_dir}/{global_idx:03d}-{prompt}.{args.save_format}" latent_save_fp = f"{latent_dir}/{global_idx:03d}-{prompt}.npy" else: - fn = prompt_prefix[i + j] + fn = prompt_prefix[global_idx] save_fp = f"{save_dir}/{fn}.{args.save_format}" latent_save_fp = f"{latent_dir}/{fn}.npy" @@ -557,11 +520,7 @@ def parse_args(): help="path to load a config yaml file that describes the setting which will override the default arguments", ) parser.add_argument( - "--model_version", - default="v1", - type=str, - choices=["v1", "v1.1", "v1.2", "llama3_1b", "llama3_5b"], - help="OpenSora model version.", + "--model_version", default="v1", type=str, choices=["v1", "v1.1", "v1.2"], help="OpenSora model version." 
) parser.add_argument("--image_size", type=int, nargs="+", help="image size in [256, 512]") parser.add_argument("--resolution", type=str, help=f"Supported video resolutions: {list(ASPECT_RATIOS.keys())}") @@ -737,8 +696,6 @@ def parse_args(): parser.add_argument("--fps", type=int, default=8, help="FPS in the saved video") parser.add_argument("--batch_size", default=4, type=int, help="infer batch size") parser.add_argument("--text_embed_folder", type=str, default=None, help="path to t5 embedding") - parser.add_argument("--ul2_text_embed_folder", type=str, help="path to ul2 embedding") - parser.add_argument("--byt5_text_embed_folder", type=str, help="path to byt5 embedding") parser.add_argument( "--save_latent", type=str2bool, diff --git a/examples/opensora_hpcai/scripts/train.py b/examples/opensora_hpcai/scripts/train.py index 3246f38366..19819c24f3 100644 --- a/examples/opensora_hpcai/scripts/train.py +++ b/examples/opensora_hpcai/scripts/train.py @@ -22,13 +22,11 @@ mindone_lib_path = os.path.abspath(os.path.join(__dir__, "../../../")) sys.path.insert(0, mindone_lib_path) sys.path.insert(0, os.path.abspath(os.path.join(__dir__, ".."))) -sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../moviegen"))) - from args_train import parse_args from opensora.acceleration.parallel_states import create_parallel_group from opensora.datasets.aspect import ASPECT_RATIOS, get_image_size from opensora.models.layers.operation_selector import set_dynamic_mode -from opensora.models.stdit import STDiT2_XL_2, STDiT3_XL_2, STDiT_XL_2, STDiTLlama3Wrapper +from opensora.models.stdit import STDiT2_XL_2, STDiT3_XL_2, STDiT_XL_2 from opensora.models.vae.vae import SD_CONFIG, OpenSoraVAE_V1_2, VideoAutoencoderKL from opensora.pipelines import ( DiffusionWithLoss, @@ -71,9 +69,7 @@ def init_env( global_bf16: bool = False, dynamic_shape: bool = False, enable_sequence_parallelism: bool = False, - enable_model_parallelism: bool = False, sequence_parallel_shards: int = 1, - model_parallel_shards: int = 1, debug: bool = False, ) -> Tuple[int, int]: """ @@ -88,16 +84,12 @@ def init_env( """ set_random_seed(seed) - if enable_sequence_parallelism or enable_model_parallelism: + if enable_sequence_parallelism: if parallel_mode != "data" or not distributed: raise ValueError( - "sequence parallel / tensor parallel can only be used in data parallel mode, " + "sequence parallel can only be used in data parallel mode, " f"but get parallel_mode=`{parallel_mode}` with distributed=`{distributed}`." ) - if enable_sequence_parallelism and enable_model_parallelism: - raise ValueError( - "Cannot turn on sequence parallel (Non-Llama structure) / model paralell (Llama structure) in the same time." 
- ) if debug and mode == ms.GRAPH_MODE: # force PyNative mode when debugging logger.warning("Debug mode is on, switching execution mode to PyNative.") @@ -134,12 +126,9 @@ def init_env( ) if enable_sequence_parallelism: - create_parallel_group(sequence_parallel_shards=sequence_parallel_shards) + create_parallel_group(sequence_parallel_shards) ms.set_auto_parallel_context(enable_alltoall=True) - if enable_model_parallelism: - create_parallel_group(model_parallel_shards=model_parallel_shards) - var_info = ["device_num", "rank_id", "device_num / 8", "rank_id / 8"] var_value = [device_num, rank_id, int(device_num / 8), int(rank_id / 8)] logger.info(dict(zip(var_info, var_value))) @@ -197,6 +186,7 @@ def initialize_dataset( args, csv_path, video_folder, + text_embed_folder, vae_latent_folder, batch_size, img_h, @@ -214,7 +204,7 @@ def initialize_dataset( ds_config = dict( csv_path=csv_path, video_folder=video_folder, - text_emb_folder=args.text_embed_folder, + text_emb_folder=text_embed_folder, return_text_emb=True, vae_latent_folder=vae_latent_folder, return_vae_latent=args.train_with_vae_latent, @@ -265,24 +255,6 @@ def initialize_dataset( if args.pre_patchify: output_columns.extend(["spatial_pos", "spatial_mask", "temporal_pos", "temporal_mask"]) - text_embed_folder = {} - if args.text_embed_folder: - text_embed_folder["t5"] = args.text_embed_folder - if args.ul2_text_embed_folder: - text_embed_folder["ul2"] = args.ul2_text_embed_folder - if args.byt5_text_embed_folder: - text_embed_folder["byt5"] = args.byt5_text_embed_folder - - if not len(text_embed_folder): - text_embed_folder = None - elif len(text_embed_folder) == 1: - text_embed_folder = list(text_embed_folder.values())[0] - else: - # FIXME: hardcoding - output_columns[1] = "ul2_caption" - output_columns[2] = "ul2_mask" - output_columns.extend(["byt5_caption", "byt5_mask"]) - datasets = [ VideoDatasetRefactored( csv_path=csv_path, @@ -387,9 +359,7 @@ def main(args): global_bf16=args.global_bf16, dynamic_shape=(args.bucket_config is not None), enable_sequence_parallelism=args.enable_sequence_parallelism, - enable_model_parallelism=args.enable_model_parallelism, sequence_parallel_shards=args.sequence_parallel_shards, - model_parallel_shards=args.model_parallel_shards, debug=args.debug, ) set_logger(name="", output_dir=args.output_path, rank=rank_id, log_level=eval(args.log_level)) @@ -460,7 +430,6 @@ def main(args): manual_pad=args.manual_pad, enable_flashattn=args.enable_flash_attention, enable_sequence_parallelism=args.enable_sequence_parallelism, - enable_model_parallelism=args.enable_model_parallelism, use_recompute=args.use_recompute, num_recompute_blocks=args.num_recompute_blocks, ) @@ -486,12 +455,6 @@ def main(args): model_extra_args["qk_norm"] = True model_extra_args["freeze_y_embedder"] = args.freeze_y_embedder latte_model = STDiT3_XL_2(**model_extra_args) - elif args.model_version == "llama3_1b": - model_name = "Llama3-1B" - latte_model = STDiTLlama3Wrapper(model_size="1B", **model_extra_args) - elif args.model_version == "llama3_5b": - model_name = "Llama3-5B" - latte_model = STDiTLlama3Wrapper(model_size="5B", **model_extra_args) else: raise ValueError(f"Unknown model version: {args.model_version}") logger.info(f"{model_name} input size: {latent_size if args.bucket_config is None else 'Variable'}") @@ -582,10 +545,6 @@ def main(args): data_device_num = device_num // args.sequence_parallel_shards data_rank_id = rank_id // args.sequence_parallel_shards logger.info(f"Creating dataloader: ID={rank_id}, group={data_rank_id}, 
num_groups={data_device_num}") - elif args.enable_model_parallelism: - data_device_num = device_num // args.model_parallel_shards - data_rank_id = rank_id // args.model_parallel_shards - logger.info(f"Creating dataloader: ID={rank_id}, group={data_rank_id}, num_groups={data_device_num}") else: data_device_num = device_num data_rank_id = rank_id @@ -594,6 +553,7 @@ def main(args): args, args.csv_path, args.video_folder, + args.text_embed_folder, args.vae_latent_folder, args.batch_size, img_h, @@ -787,10 +747,7 @@ def main(args): logger.info( "As steps per epoch are inaccurate with bucket config, TimeMonitor is disabled. See result.log for the actual step time" ) - if rank_id == 0 or args.enable_model_parallelism: - if args.enable_model_parallelism: - ckpt_dir = os.path.join(ckpt_dir, f"rank_{rank_id}") - + if rank_id == 0: save_cb = EvalSaveCallback( network=latent_diffusion_with_loss.network, rank_id=rank_id, @@ -809,8 +766,6 @@ def main(args): record_lr=False, train_steps=args.train_steps, ) - - if rank_id == 0: rec_cb = PerfRecorderCallback( save_dir=args.output_path, file_name="result_val.log", From 09118322c60f2db23c14febfc469574699809270 Mon Sep 17 00:00:00 2001 From: Rustam Khadipash <16683750+hadipash@users.noreply.github.com> Date: Wed, 11 Dec 2024 10:35:06 +0800 Subject: [PATCH 080/122] merge with PR #778 --- .../animatediff/ad/models/diffusion/ddpm.py | 5 +- examples/moviegen/README.md | 96 ++++++++++++------- examples/moviegen/inference_tae_enc.py | 1 + examples/moviegen/mg/models/tae/modules.py | 74 +++++--------- examples/moviegen/tests/ut/test_gn.py | 29 ------ examples/moviegen/tests/ut/test_tae.py | 8 +- examples/moviegen/tools/inflate_vae_to_tae.py | 5 +- 7 files changed, 99 insertions(+), 119 deletions(-) delete mode 100644 examples/moviegen/tests/ut/test_gn.py diff --git a/examples/animatediff/ad/models/diffusion/ddpm.py b/examples/animatediff/ad/models/diffusion/ddpm.py index 8010e347d4..ae23a9d6ca 100644 --- a/examples/animatediff/ad/models/diffusion/ddpm.py +++ b/examples/animatediff/ad/models/diffusion/ddpm.py @@ -365,7 +365,10 @@ def construct(self, x: ms.Tensor, text_tokens: ms.Tensor, control=None, **kwargs """ # 1. get image/video latents z using vae - z = self.get_latents(x) + if self.emb_cache: + z = x + else: + z = self.get_latents(x) # 2. sample timestep and add noise to latents t = self.uniform_int( diff --git a/examples/moviegen/README.md b/examples/moviegen/README.md index 284a564f82..0bfead2150 100644 --- a/examples/moviegen/README.md +++ b/examples/moviegen/README.md @@ -71,6 +71,10 @@ character-level text understanding for the backbone: # Installation +| MindSpore | Ascend Driver | Firmware | CANN toolkit/kernel | +|:---------:|:-------------:|:-----------:|:-------------------:| +| 2.3.1 | 24.1.RC2 | 7.3.0.1.231 | 8.0.RC2.beta1 | + 1. Install MindSpore according to the [official instructions](https://www.mindspore.cn/install). For Ascend devices, please install [CANN8.0.RC2.beta1](https://www.hiascend.com/developer/download/community/result?module=cann&cann=8.0.RC2.beta1) @@ -85,11 +89,11 @@ character-level text understanding for the backbone:
 TAE

-We use SD3.5 VAE to initialize the spatial layers of TAE since both have a latent channel of 16.
+We use SD3.5 VAE to initialize the spatial layers of TAE since both have the same number of latent channels, i.e., 16.

-1. Download SD3.5 VAE from https://huggingface.co/stabilityai/stable-diffusion-3.5-large/tree/main/vae
+1. Download SD3.5 VAE from [huggingface](https://huggingface.co/stabilityai/stable-diffusion-3.5-large/tree/main/vae)

-2. Convert VAE checkpoint for TAE loading
+2. Inflate VAE checkpoint for TAE initialization by
 ```shell
 python inflate_vae_to_tae.py --src /path/to/sd3.5_vae/diffusion_pytorch_model.safetensors --target models/tae_vae2d.ckpt
 ```
@@ -113,7 +117,7 @@ If you face an SSL certificate verification error, you can add `--disable_ssl_ve

 # Generating Text Embeddings

-Due to the large memory footprint of the text encoders, the inference and training pipelines do not support generating
+Due to the large memory footprint of the text encoders, the inference and training pipelines don't support generating
 text embeddings online. Therefore, you need to prepare them in advance by running the following command:

 ```shell
@@ -160,20 +164,7 @@ python inference.py \

 ## TAE

-#### Video Reconstruction
-
-```shell
-python eval_tae.py \
---pretrained /path/to/tae.ckpt \
---batch_size 2 \
---sample_n_frames 16 \
---size 256 \
---csv_path ../opensora_hpcai/datasets/mixkit-100videos/video_caption_test.csv \
---folder ../opensora_hpcai/datasets/mixkit-100videos/mixkit \
---use_tile False
-```
-
-#### Encoding video
+### Encoding video

 ```python
 from mg.models.tae import TemporalAutoencoder
@@ -193,7 +184,7 @@ z = (z - tae.shift_factor) * tae.scale_factor

 For detailed arguments, please refer to the docstring in [tae.py](mg/models/tae/tae.py)

-#### Decoding video latent
+### Decoding video latent

 ```python
 # if z is scaled, you should unscale at first:
@@ -251,29 +242,68 @@ Validation can be enabled by either setting parameters in the `valid` field of t

 ## TAE

-```shell
-output_dir=outputs/train_tae_256x256x16
+### Prepare datasets
+
+We need to prepare a csv annotation file listing the path to each input video relative to the root folder, indicated by
+the `video_folder` argument. An example is
+
+```
+video
+dance/vid001.mp4
+dance/vid002.mp4
+dance/vid003.mp4
+...
+```
+
+Taking UCF-101 as an example, please download the [UCF-101](https://www.crcv.ucf.edu/data/UCF101.php) dataset and extract
+it to the `datasets/UCF-101` folder.
+
+### Training
+
+TAE is trained to optimize the reconstruction loss, perceptual loss, and the outlier penalty loss (OPL) proposed in the
+MovieGen paper.
+
+To launch training, please run
+
+```shell
 python train_tae.py \
---config configs/tae/train/mixed_256x256x16.yaml \
---output_path $output_dir \
---csv_path ../opensora_hpcai/datasets/mixkit-100videos/video_caption_train.csv \
---video_folder ../opensora_hpcai/datasets/mixkit-100videos/mixkit \
+--config configs/tae/train/mixed_256x256x32.yaml \
+--output_path /path/to/save_ckpt_and_log \
+--csv_path /path/to/video_train.csv \
+--folder /path/to/video_root_folder \
 ```

-OPL - outlier penalty loss is found to be not beneficial in our experiment (PSNR decreased).
-Thus, we set it to False by default.
+Different from the paper, we found that the OPL loss doesn't benefit the training outcome in our ablation study (it
+results in lower PSNR). Thus, we disable the OPL loss by default. You may enable it by appending
+`--use_outlier_penalty_loss True`
-Change mixed_256x256x16.yaml to mixed_256x256x32.yaml for training on 32 frames.
+For more details on the arguments, please run `python scripts/train_tae.py --help`
+
+### Evaluation
+
+To run video reconstruction with the trained TAE model and evaluate the PSNR and SSIM on the test set, please run
+
+```shell
+python eval_tae.py \
+--ckpt_path /path/to/tae.ckpt \
+--batch_size 2 \
+--num_frames 32 \
+--image_size 256 \
+--csv_path /path/to/video_test.csv \
+--folder /path/to/video_root_folder \
+```
+
+The reconstructed videos will be saved in `samples/recons`.
+
 ### Performance

-Train on 80 samples of mixkit-100 (train set), test on the other 20 samples (test set)
+Here, we report the training performance and evaluation results on the UCF-101 dataset.
+
+Experiments are tested on Ascend 910* with MindSpore 2.3.1 in graph mode.

-| Resolution | NPUs | Precision | Time (s/step) | PSNR (test set) |
-|------------|------|-----------|---------------|-----------------|
-| 256x256x16 | 1 | FP32 | 1.99 | 28.5 |
-| 256x256x32 | 1 | BF16 | 2.49 | 28.3 |
+| model name | cards | batch size | resolution | precision | jit level | graph compile | s/step | PSNR | SSIM | recipe |
+|:----------:|:-----:|:----------:|:----------:|:---------:|:---------:|:-------------:|:------:|:-----:|:----:|:-------------------------------------------------:|
+| TAE | 1 | 1 | 256x256x32 | bf16 | O0 | 2 min | 2.18 | 31.35 | 0.92 | [config](configs/tae/train/mixed_256x256x32.yaml) |

 # Evaluation
diff --git a/examples/moviegen/inference_tae_enc.py b/examples/moviegen/inference_tae_enc.py
index 7301c2ce1b..d9ab6de579 100644
--- a/examples/moviegen/inference_tae_enc.py
+++ b/examples/moviegen/inference_tae_enc.py
@@ -56,6 +56,7 @@ def main(args):
         black_list=amp.get_black_list() + [nn.GroupNorm, nn.AvgPool2d, nn.Upsample],
         dtype=MODEL_DTYPE[tae_dtype],
     )
+    # TODO: add dynamic shape support

     # 4.
print key info key_info = "Key Settings:\n" + "=" * 50 + "\n" diff --git a/examples/moviegen/mg/models/tae/modules.py b/examples/moviegen/mg/models/tae/modules.py index cd2aaa0a0b..9dd6f6f86a 100644 --- a/examples/moviegen/mg/models/tae/modules.py +++ b/examples/moviegen/mg/models/tae/modules.py @@ -25,6 +25,16 @@ def nonlinearity(x): return x * (ops.sigmoid(x)) +def symmetric_pad1d(x): + # x: (B C T), work with kernel size = 1 + first_frame = x[:, :, :1] + last_frame = x[:, :, -1:] + # last_frame_pad = ops.cat([last_frame] * self.time_pad, axis=2) + x = ops.concat((first_frame, x, last_frame), axis=2) + + return x + + class GroupNorm5d(nn.GroupNorm): def construct(self, x): # x (b c t h w) @@ -158,10 +168,6 @@ def __init__( has_bias=has_bias, ) - # temp_pad_mode = 'zero' - # temp_pad = 'mint_rep' - # temp_pad = "manual" - # temporal conv if kernel_size > 1: # symmetric padding + conv1d @@ -175,7 +181,7 @@ def __init__( has_bias=has_bias, bias_init="zeros", ) - self.pad = self.symmetric_pad1d + self.pad = symmetric_pad1d self.use_pad = True else: self.use_pad = False @@ -189,16 +195,7 @@ def __init__( bias_init="zeros", ) - self.init_temporal_weight("median") - - @staticmethod - def symmetric_pad1d(x): - first_frame = x[:, :, :1] - last_frame = x[:, :, -1:] - # last_frame_pad = ops.cat([last_frame] * self.time_pad, axis=2) - x = ops.concat((first_frame, x, last_frame), axis=2) - - return x + self.init_temporal_weight("centric") def construct(self, x): """ @@ -243,11 +240,11 @@ def construct(self, x): return x - def init_temporal_weight(self, method="median"): + def init_temporal_weight(self, method="centric"): if method == "normal": return - elif method == "median": + elif method == "centric": # temporal conv kernel: (cout, cin, 1, ks) # ks=1 or 3, cin == cout w = self.conv_temp.weight @@ -347,9 +344,9 @@ def __init__(self, in_channels): has_bias=True, bias_init="zeros", ) - # tail padding, pad with last frame + # tail padding, pad with self.time_pad = self.ks - 1 - self.init_weight("median") + self.init_weight("centric") def init_weight(self, method="mean"): if method == "normal": @@ -363,8 +360,8 @@ def init_weight(self, method="mean"): # initially, it's a mean filter for temporal downsampling for i in range(self.ch): value[i, i, 0, :] = 1 / self.ks # (cout, cin, 1, ks) - elif method == "median": - # a median filter for temporal downsampling + elif method == "centric": + # a centric filter for temporal downsampling for i in range(self.ch): value[i, i, 0, self.ks // 2] = 1 # (cout, cin, 1, ks) else: @@ -380,10 +377,8 @@ def construct(self, x): x = ops.transpose(x, (0, 3, 4, 1, 2)) x = ops.reshape(x, (B * H * W, C, T)) - # tail padding - last_frame = x[:, :, -1:] - last_frame_pad = ops.cat([last_frame] * self.time_pad, axis=2) - x = ops.concat((x, last_frame_pad), axis=2) + # symmetric padding + x = symmetric_pad1d(x) x = self.conv(x) @@ -398,8 +393,8 @@ def construct(self, x): class TemporalUpsample(nn.Cell): def __init__(self, in_channels, manual_pad=True): super().__init__() - # to support danamic shape in graph mode self.manual_pad = manual_pad + # to support danamic shape in graph mode if not self.manual_pad: self.conv = nn.Conv1d( in_channels, in_channels, kernel_size=3, stride=1, pad_mode="same", has_bias=True, bias_init="zeros" @@ -411,16 +406,16 @@ def __init__(self, in_channels, manual_pad=True): # TODO: init conv weight so that it pass in image mode self.ch = in_channels - self.init_weight("median") + self.init_weight("centric") - def init_weight(self, method="median"): + def 
init_weight(self, method="centric"): if method == "normal": return # init so that the output is the same as vae2d for image input w = self.conv.weight value = np.zeros(tuple(w.shape)) - if method == "median": + if method == "centric": # consider image input, make sure it's the same for i in range(self.ch): value[i, i, 0, 1] = 1 # (cout, cin, 1, ks) @@ -457,25 +452,6 @@ def construct(self, x): return x - """ - def construct(self, x): - # x (b c t h w) - x = ops.interpolate(x, scale_factor=(2.0, 1.0, 1.0), mode="nearest") - - # x (b c t h w) -> (bhw c t) - B, C, T, H, W = x.shape - x = ops.transpose(x, (0, 3, 4, 1, 2)) - x = ops.reshape(x, (B*H*W, C, T)) - - x = self.conv(x) - - # x (bhw c t) -> (b c t h w) - x = ops.reshape(x, (B, H, W, C, T)) - x = ops.transpose(x, (0, 3, 4, 1, 2)) - - return x - """ - # used in vae class ResnetBlock(nn.Cell): @@ -717,7 +693,7 @@ def make_attn(in_channels, attn_type="vanilla"): # used in vae class Encoder(nn.Cell): - # @lazy_inline() + # @ms.lazy_inline() def __init__( self, ch=128, diff --git a/examples/moviegen/tests/ut/test_gn.py b/examples/moviegen/tests/ut/test_gn.py deleted file mode 100644 index 8d55ef9a13..0000000000 --- a/examples/moviegen/tests/ut/test_gn.py +++ /dev/null @@ -1,29 +0,0 @@ -import numpy as np -import torch -import torch.nn as nn - -# 定义输入形状 -B, C, T, H, W = 2, 3, 16, 256, 256 -x = np.random.normal(size=(B, C, T, H, W)).astype(np.float32) -x_tensor = torch.tensor(x) - -# 定义 GroupNorm 层 -group_norm = nn.GroupNorm(num_groups=3, num_channels=C) - -# 第一次 GroupNorm 操作 -y1 = group_norm(x_tensor) - -# 重新排列形状 -x_rearranged = x_tensor.permute(0, 3, 4, 1, 2).contiguous().view(B * H * W, C, T) - -# 第二次 GroupNorm 操作 -y2 = group_norm(x_rearranged) - -# 恢复形状 -# y1 = y1.view(B, C, T, H, W).permute(0, 2, 1, 3, 4).contiguous() -y2 = y2.view(B, H, W, C, T).permute(0, 3, 4, 1, 2).contiguous() - -# 比较 y1 和 y2 -print(y1.sum()) -print(y2.sum()) -print(torch.allclose(y1, y2)) diff --git a/examples/moviegen/tests/ut/test_tae.py b/examples/moviegen/tests/ut/test_tae.py index ae5c3798e7..f2d2715151 100644 --- a/examples/moviegen/tests/ut/test_tae.py +++ b/examples/moviegen/tests/ut/test_tae.py @@ -134,7 +134,7 @@ def test_spatial_upsample(): def test_temporal_downsample(): # in_shape = (B, C, T, H, W) = (1, 64, 1, 32, 32) - in_shape = (B, C, T, H, W) = (1, 64, 1, 32, 32) + in_shape = (B, C, T, H, W) = (1, 16, 5, 32, 32) x = np.random.normal(size=in_shape).astype(np.float32) td = TemporalDownsample(C) @@ -278,7 +278,7 @@ def test_tae_tile(): if __name__ == "__main__": - ms.set_context(mode=0) + ms.set_context(mode=1) # test_conv25d() # test_resnetblock() @@ -295,6 +295,6 @@ def test_tae_tile(): # test_tae_decode() # test_tae_rec() # test_tae_tile() - test_blend() + # test_blend() - # test_sd3d5_vae() + test_sd3d5_vae() diff --git a/examples/moviegen/tools/inflate_vae_to_tae.py b/examples/moviegen/tools/inflate_vae_to_tae.py index 8542a17110..6775a58c05 100644 --- a/examples/moviegen/tools/inflate_vae_to_tae.py +++ b/examples/moviegen/tools/inflate_vae_to_tae.py @@ -4,6 +4,7 @@ from safetensors import safe_open import mindspore as ms +from mg.models.tae.sd3_vae import SD3d5_CONFIG,SD3d5_VAE def get_shape_from_str(shape): @@ -32,9 +33,7 @@ def load_torch_ckpt(ckpt_path): def plot_ms_vae2d5(): - from mg.models.tae.tae import SD3d5_CONFIG, TemporalAutoencoder - - tae = TemporalAutoencoder(config=SD3d5_CONFIG) + tae = SD3d5_VAE(config=SD3d5_CONFIG) sd = tae.parameters_dict() pnames = list(sd.keys()) From 657628e32769d43b030e6fe33b23daa09689d2ca Mon Sep 17 
00:00:00 2001 From: Rustam Khadipash <16683750+hadipash@users.noreply.github.com> Date: Wed, 11 Dec 2024 11:09:32 +0800 Subject: [PATCH 081/122] small fix --- examples/moviegen/tools/inflate_vae_to_tae.py | 2 +- mindone/utils/env.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/moviegen/tools/inflate_vae_to_tae.py b/examples/moviegen/tools/inflate_vae_to_tae.py index 6775a58c05..08e09bbc7d 100644 --- a/examples/moviegen/tools/inflate_vae_to_tae.py +++ b/examples/moviegen/tools/inflate_vae_to_tae.py @@ -1,10 +1,10 @@ import argparse import numpy as np +from mg.models.tae.sd3_vae import SD3d5_CONFIG, SD3d5_VAE from safetensors import safe_open import mindspore as ms -from mg.models.tae.sd3_vae import SD3d5_CONFIG,SD3d5_VAE def get_shape_from_str(shape): diff --git a/mindone/utils/env.py b/mindone/utils/env.py index b459898d06..6a4f986c87 100644 --- a/mindone/utils/env.py +++ b/mindone/utils/env.py @@ -70,11 +70,11 @@ def init_train_env( ms.set_context(jit_config={"jit_level": jit_level}) if distributed: - device_id, kwargs = None, {} # if no rank table - if os.getenv("DEVICE_ID"): - device_id = int(os.getenv("DEVICE_ID")) - kwargs = {"device_id": int(os.getenv("DEVICE_ID"))} - ms.set_context(mode=mode, device_target=device_target, ascend_config=ascend_config or {}, **kwargs) + ms.set_context(mode=mode, device_target=device_target, ascend_config=ascend_config or {}) + device_id = os.getenv("DEVICE_ID", None) + if device_id: + ms.set_context(device_id=int(device_id)) + init() device_num = get_group_size() rank_id = get_rank() From 82cff9e3cc7fe50bce73300961808eecde0313cd Mon Sep 17 00:00:00 2001 From: Rustam Khadipash <16683750+hadipash@users.noreply.github.com> Date: Wed, 11 Dec 2024 14:25:20 +0800 Subject: [PATCH 082/122] PR fixes: - remove forced dynamic memory allocation for data transformations - purge Model Parallel functionality until it's fully tested --- examples/moviegen/mg/models/llama/block.py | 214 ---------- examples/moviegen/mg/models/llama/network.py | 161 +------ examples/moviegen/mg/parallel/__init__.py | 2 - examples/moviegen/mg/parallel/layers.py | 398 ------------------ .../moviegen/mg/parallel/parallel_states.py | 38 -- .../moviegen/mg/schedulers/rectified_flow.py | 9 +- .../parallel/run_test_llama3_parallel.sh | 13 - .../run_test_llama3_parallel_block.sh | 13 - .../run_test_llama3_parallel_layer.sh | 13 - .../tests/parallel/run_test_rflow_parallel.sh | 13 - .../tests/parallel/test_llama3_parallel.py | 113 ----- .../parallel/test_llama3_parallel_block.py | 107 ----- .../parallel/test_llama3_parallel_layer.py | 125 ------ .../tests/parallel/test_rflow_parallel.py | 61 --- examples/moviegen/tests/parallel/utils.py | 32 -- examples/moviegen/train.py | 12 +- mindone/data/loader.py | 23 +- mindone/trainers/train_step.py | 2 +- 18 files changed, 25 insertions(+), 1324 deletions(-) delete mode 100644 examples/moviegen/mg/parallel/__init__.py delete mode 100644 examples/moviegen/mg/parallel/layers.py delete mode 100644 examples/moviegen/mg/parallel/parallel_states.py delete mode 100755 examples/moviegen/tests/parallel/run_test_llama3_parallel.sh delete mode 100755 examples/moviegen/tests/parallel/run_test_llama3_parallel_block.sh delete mode 100755 examples/moviegen/tests/parallel/run_test_llama3_parallel_layer.sh delete mode 100755 examples/moviegen/tests/parallel/run_test_rflow_parallel.sh delete mode 100644 examples/moviegen/tests/parallel/test_llama3_parallel.py delete mode 100644 
examples/moviegen/tests/parallel/test_llama3_parallel_block.py delete mode 100644 examples/moviegen/tests/parallel/test_llama3_parallel_layer.py delete mode 100644 examples/moviegen/tests/parallel/test_rflow_parallel.py delete mode 100644 examples/moviegen/tests/parallel/utils.py diff --git a/examples/moviegen/mg/models/llama/block.py b/examples/moviegen/mg/models/llama/block.py index 474f3d821e..b9645da307 100644 --- a/examples/moviegen/mg/models/llama/block.py +++ b/examples/moviegen/mg/models/llama/block.py @@ -2,13 +2,6 @@ from typing import Optional, Sequence, Tuple, Union import numpy as np -from mg.parallel import ( - ColumnParallelLinear, - FusedColumnParallelLinear, - FusedRowParallelLinear, - GatherForwardReduceScatterBackward, - RowParallelLinear, -) import mindspore as ms import mindspore.mint as mint @@ -16,7 +9,6 @@ import mindspore.nn as nn import mindspore.ops as ops from mindspore import Parameter, Tensor -from mindspore.communication import GlobalComm from mindspore.ops.operations.nn_ops import FlashAttentionScore from .activation import ACT2FN @@ -59,77 +51,6 @@ def construct(self, hidden_state: Tensor) -> Tensor: return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) -class TensorParallelLlamaMLP(nn.Cell): - def __init__( - self, - intermediate_size: int = 8192, - hidden_size: int = 3072, - hidden_act: str = "silu", - group: str = GlobalComm.WORLD_COMM_GROUP, - dtype: ms.Type = ms.float32, - ) -> None: - super().__init__() - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.gate_proj = ColumnParallelLinear( - self.hidden_size, self.intermediate_size, bias=False, gather_output=False, group=group, dtype=dtype - ) - self.up_proj = ColumnParallelLinear( - self.hidden_size, self.intermediate_size, bias=False, gather_output=False, group=group, dtype=dtype - ) - self.down_proj = RowParallelLinear( - self.intermediate_size, self.hidden_size, bias=False, input_is_parallel=True, group=group, dtype=dtype - ) - self.act_fn = ACT2FN[hidden_act] - - def construct(self, hidden_state: Tensor) -> Tensor: - return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) - - def load_weight_from_non_parallel_cell(self, target: LlamaMLP): - self.gate_proj.load_weight_from_non_parallel_cell(target.gate_proj) - self.up_proj.load_weight_from_non_parallel_cell(target.up_proj) - self.down_proj.load_weight_from_non_parallel_cell(target.down_proj) - - -class FusedTensorParallelLlamaMLP(nn.Cell): - def __init__( - self, - intermediate_size: int = 8192, - hidden_size: int = 3072, - hidden_act: str = "silu", - dim: int = 1, - group: str = GlobalComm.WORLD_COMM_GROUP, - dtype: ms.Type = ms.float32, - ) -> None: - super().__init__() - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.gate_proj = FusedColumnParallelLinear( - self.hidden_size, self.intermediate_size, bias=False, gather_output=False, dim=dim, group=group, dtype=dtype - ) - self.up_proj = FusedColumnParallelLinear( - self.hidden_size, self.intermediate_size, bias=False, gather_output=False, dim=dim, group=group, dtype=dtype - ) - self.down_proj = FusedRowParallelLinear( - self.intermediate_size, - self.hidden_size, - bias=False, - input_is_parallel=True, - dim=dim, - group=group, - dtype=dtype, - ) - self.act_fn = ACT2FN[hidden_act] - - def construct(self, hidden_state: Tensor) -> Tensor: - return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) - - def 
load_weight_from_non_parallel_cell(self, target: LlamaMLP): - self.gate_proj.load_weight_from_non_parallel_cell(target.gate_proj) - self.up_proj.load_weight_from_non_parallel_cell(target.up_proj) - self.down_proj.load_weight_from_non_parallel_cell(target.down_proj) - - def repeat_kv(hidden_states: Tensor, n_rep: int) -> Tensor: if n_rep == 1: return hidden_states @@ -210,81 +131,6 @@ def construct(self, hidden_states: Tensor, encoder_hidden_states: Optional[Tenso return attn_output -class ContextParallelLlamaAttention(nn.Cell): - def __init__( - self, - hidden_size: int = 4096, - num_attention_heads: int = 32, - num_key_value_heads: int = 8, - attention_dropout: float = 0.0, - attention_bias: bool = False, - group: str = GlobalComm.WORLD_COMM_GROUP, - dtype: ms.Type = ms.float32, - ) -> None: - super().__init__() - self.attention_dropout = attention_dropout - self.hidden_size = hidden_size - self.num_heads = num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." - ) - self.q_proj = mint.nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=attention_bias, dtype=dtype) - self.k_proj = mint.nn.Linear( - self.hidden_size, self.num_key_value_heads * self.head_dim, bias=attention_bias, dtype=dtype - ) - self.v_proj = mint.nn.Linear( - self.hidden_size, self.num_key_value_heads * self.head_dim, bias=attention_bias, dtype=dtype - ) - self.o_proj = mint.nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=attention_bias, dtype=dtype) - - self.gather_forward_reduce_scatter_backward = GatherForwardReduceScatterBackward(dim=1, group=group) - - def construct(self, hidden_states: Tensor, encoder_hidden_states: Optional[Tensor] = None) -> Tensor: - bsz, q_len, _ = hidden_states.shape - - kv_hidden_states = hidden_states if encoder_hidden_states is None else encoder_hidden_states - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(kv_hidden_states) - value_states = self.v_proj(kv_hidden_states) - - key_states = self.gather_forward_reduce_scatter_backward(key_states) - value_states = self.gather_forward_reduce_scatter_backward(value_states) - - query_states = ops.reshape(query_states, (bsz, -1, self.num_heads, self.head_dim)) - query_states = mint.permute(query_states, (0, 2, 1, 3)) - - key_states = ops.reshape(key_states, (bsz, -1, self.num_key_value_heads, self.head_dim)) - key_states = mint.permute(key_states, (0, 2, 1, 3)) - - value_states = ops.reshape(value_states, (bsz, -1, self.num_key_value_heads, self.head_dim)) - value_states = mint.permute(value_states, (0, 2, 1, 3)) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - key_states = mint.permute(key_states, (0, 1, 3, 2)) - attn_weights = mint.matmul(query_states, key_states) / mint.sqrt(Tensor(self.head_dim)) - - # upcast attention to fp32 - attn_weights = attn_weights.to(ms.float32) - attn_weights = F.softmax(attn_weights, dim=-1).to(query_states.dtype) - attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = mint.matmul(attn_weights, value_states) - - attn_output = mint.permute(attn_output, (0, 2, 1, 3)) - attn_output = 
ops.reshape(attn_output, (bsz, q_len, -1)) - attn_output = self.o_proj(attn_output) - - return attn_output - - class LlamaFlashAttention(LlamaAttention): def __init__( self, @@ -340,66 +186,6 @@ def construct(self, hidden_states: Tensor, encoder_hidden_states: Optional[Tenso return attn_output -class ContextParallelLlamaFlashAttention(ContextParallelLlamaAttention): - def __init__( - self, - hidden_size: int = 4096, - num_attention_heads: int = 32, - num_key_value_heads: int = 8, - attention_dropout: float = 0.0, - attention_bias: bool = False, - group: str = GlobalComm.WORLD_COMM_GROUP, - dtype: ms.Type = ms.float32, - ) -> None: - super().__init__( - hidden_size=hidden_size, - num_attention_heads=num_attention_heads, - num_key_value_heads=num_key_value_heads, - attention_dropout=attention_dropout, - attention_bias=attention_bias, - group=group, - dtype=dtype, - ) - self.flash_attention = FlashAttentionScore( - self.num_heads, keep_prob=1 - self.attention_dropout, scale_value=self.head_dim**-0.5, input_layout="BSND" - ) - - def construct(self, hidden_states: Tensor, encoder_hidden_states: Optional[Tensor] = None) -> Tensor: - bsz, q_len, _ = hidden_states.shape - - kv_hidden_states = hidden_states if encoder_hidden_states is None else encoder_hidden_states - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(kv_hidden_states) - value_states = self.v_proj(kv_hidden_states) - - key_states = self.gather_forward_reduce_scatter_backward(key_states) - value_states = self.gather_forward_reduce_scatter_backward(value_states) - - query_states = ops.reshape(query_states, (bsz, -1, self.num_heads, self.head_dim)) - query_states = mint.permute(query_states, (0, 2, 1, 3)) - - key_states = ops.reshape(key_states, (bsz, -1, self.num_key_value_heads, self.head_dim)) - key_states = mint.permute(key_states, (0, 2, 1, 3)) - - value_states = ops.reshape(value_states, (bsz, -1, self.num_key_value_heads, self.head_dim)) - value_states = mint.permute(value_states, (0, 2, 1, 3)) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - # Reshape to the expected shape and dtype for Flash Attention - query_states = mint.permute(query_states, (0, 2, 1, 3)) - key_states = mint.permute(key_states, (0, 2, 1, 3)) - value_states = mint.permute(value_states, (0, 2, 1, 3)) - - _, _, _, attn_output = self.flash_attention(query_states, key_states, value_states, None, None, None, None) - attn_output = ops.reshape(attn_output, (bsz, q_len, -1)) - attn_output = self.o_proj(attn_output) - - return attn_output - - class PatchEmbed3D(nn.Cell): def __init__( self, diff --git a/examples/moviegen/mg/models/llama/network.py b/examples/moviegen/mg/models/llama/network.py index 5269f70491..b6f3faa179 100644 --- a/examples/moviegen/mg/models/llama/network.py +++ b/examples/moviegen/mg/models/llama/network.py @@ -3,31 +3,24 @@ from typing import Literal, Optional, Tuple, Union import numpy as np -from mg.parallel import GatherForwardSplitBackward, SplitForwardGatherBackward -from mg.parallel.parallel_states import get_model_parallel_group import mindspore as ms import mindspore.mint as mint import mindspore.nn as nn import mindspore.ops as ops from mindspore import Parameter, Tensor, load_checkpoint, load_param_into_net -from mindspore.communication import GlobalComm, get_group_size from mindone.models.utils import normal_, zeros_ from ..text_encoders import TextProjector from .activation import ACT2FN from .block import ( - 
ContextParallelLlamaAttention, - ContextParallelLlamaFlashAttention, - FusedTensorParallelLlamaMLP, LinearPatchEmbed3D, LlamaAttention, LlamaFlashAttention, LlamaMLP, LlamaRMSNorm, PatchEmbed3D, - TensorParallelLlamaMLP, TimestepEmbedder, ) @@ -38,11 +31,6 @@ "flash_attention": LlamaFlashAttention, } -CONTEXT_PARALLEL_Llama_ATTENTION_CLASSES = { - "eager": ContextParallelLlamaAttention, - "flash_attention": ContextParallelLlamaFlashAttention, -} - def t2i_modulate(x: Tensor, shift: Tensor, scale: Tensor) -> Tensor: return x * (1 + scale) + shift @@ -133,117 +121,6 @@ def construct( return hidden_states -class ModelParallelLlamaDecoderLayer(nn.Cell): - def __init__( - self, - hidden_size: int = 4096, - intermediate_size: int = 14336, - num_attention_heads: int = 32, - num_key_value_heads: int = 8, - rms_norm_eps: float = 1e-5, - attention_dropout: float = 0.0, - attention_bias: bool = False, - hidden_act: str = "silu", - attn_implementation: Literal["eager", "flash_attention"] = "eager", - fused_tensor_parallel: bool = True, - group: str = GlobalComm.WORLD_COMM_GROUP, - dtype: ms.Type = ms.float32, - ) -> None: - super().__init__() - - # 3.1.6 Context Parallelism - self.self_attn = CONTEXT_PARALLEL_Llama_ATTENTION_CLASSES[attn_implementation]( - hidden_size=hidden_size, - num_attention_heads=num_attention_heads, - num_key_value_heads=num_key_value_heads, - attention_dropout=attention_dropout, - attention_bias=attention_bias, - dtype=dtype, - ) - - self.cross_attn = Llama_ATTENTION_CLASSES[attn_implementation]( - hidden_size=hidden_size, - num_attention_heads=num_attention_heads, - num_key_value_heads=num_key_value_heads, - attention_dropout=attention_dropout, - attention_bias=attention_bias, - dtype=dtype, - ) - - # 3.1.6 Tensor Parallelism - if fused_tensor_parallel: - self.mlp = FusedTensorParallelLlamaMLP( - intermediate_size=intermediate_size, - hidden_size=hidden_size, - hidden_act=hidden_act, - dim=1, - group=group, - dtype=dtype, - ) - else: - self.mlp = TensorParallelLlamaMLP( - intermediate_size=intermediate_size, - hidden_size=hidden_size, - hidden_act=hidden_act, - group=group, - dtype=dtype, - ) - - self.scale_shift_table = Parameter(Tensor(np.random.randn(1, 6, hidden_size) / hidden_size**0.5, dtype=dtype)) - self.input_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype) - self.post_attention_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype) - - if not fused_tensor_parallel: - self.split_forward_gather_backward = SplitForwardGatherBackward(dim=1, grad_scale="down", group=group) - self.gather_forward_split_backward = GatherForwardSplitBackward(dim=1, grad_scale="up", group=group) - else: - self.split_forward_gather_backward = nn.Identity() - self.gather_forward_split_backward = nn.Identity() - - def construct( - self, - hidden_states: Tensor, - encoder_hidden_states: Tensor, - modulation_parameters: Tensor, - position_embedding: Tensor, - ) -> Tensor: - B = hidden_states.shape[0] - - # 3.1.3 Positional Embedding - hidden_states = hidden_states + position_embedding - - # 3.1.3 Adaptive Layer Norm - modulation_parameters = self.scale_shift_table.to(hidden_states.dtype) + ops.reshape( - modulation_parameters, (B, 6, -1) - ) - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mint.chunk(modulation_parameters, 6, dim=1) - - # Self Attention (Bi-Directional Attention) - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - hidden_states = t2i_modulate(hidden_states, shift_msa, scale_msa) - hidden_states = 
self.self_attn(hidden_states) - hidden_states = gate_msa * hidden_states - hidden_states = residual + hidden_states - - # 3.1.3 Cross Attention - residual = hidden_states - hidden_states = self.cross_attn(hidden_states, encoder_hidden_states=encoder_hidden_states) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = t2i_modulate(hidden_states, shift_mlp, scale_mlp) - hidden_states = self.gather_forward_split_backward(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = self.split_forward_gather_backward(hidden_states) - hidden_states = gate_mlp * hidden_states - hidden_states = residual + hidden_states - - return hidden_states - - class LlamaFinalLayer(nn.Cell): def __init__( self, @@ -286,12 +163,10 @@ def __init__( initializer_range: float = 0.02, patch_size: Tuple[int, int, int] = (1, 2, 2), max_length: Tuple[int, int, int] = (128, 64, 64), - caption_channels: int = 4096, attn_implementation: Literal["eager", "flash_attention"] = "eager", gradient_checkpointing: bool = False, use_linear_patch_embedder: bool = True, model_parallelism: bool = False, - fused_tensor_parallel: bool = True, post_init_weight: bool = True, dtype: ms.Type = ms.float32, ) -> None: @@ -306,28 +181,9 @@ def __init__( self.max_length = max_length self.model_parallelism = model_parallelism self._dtype = dtype - mp_group = get_model_parallel_group() if self.model_parallelism: - self.layers = nn.CellList( - [ - ModelParallelLlamaDecoderLayer( - hidden_size=self.hidden_size, - intermediate_size=intermediate_size, - num_attention_heads=self.num_attention_heads, - num_key_value_heads=self.num_key_value_heads, - rms_norm_eps=rms_norm_eps, - attention_dropout=attention_dropout, - attention_bias=attention_bias, - hidden_act=hidden_act, - attn_implementation=attn_implementation, - fused_tensor_parallel=fused_tensor_parallel, - group=mp_group, - dtype=dtype, - ) - for _ in range(num_hidden_layers) - ] - ) + raise NotImplementedError("Model parallelism is not supported yet.") else: self.layers = nn.CellList( [ @@ -373,11 +229,6 @@ def __init__( out_features=self.hidden_size, layer_norm=LlamaRMSNorm, norm_eps=self.rms_norm_eps, dtype=dtype ) - if self.model_parallelism: - self.group_size = get_group_size(mp_group) - self.split_forward_gather_backward = SplitForwardGatherBackward(dim=1, grad_scale="down", group=mp_group) - self.gather_forward_split_backward = GatherForwardSplitBackward(dim=1, grad_scale="up", group=mp_group) - # post-init if post_init_weight: self.initializer_range = initializer_range @@ -490,19 +341,9 @@ def construct( # main blocks hidden_states = latent_embedding - # 3.1.6 Sequence Parallelism Start - if self.model_parallelism: - # assert hidden_states.shape[1] % self.group_size == 0 - hidden_states = self.split_forward_gather_backward(hidden_states) - position_embedding = self.split_forward_gather_backward(position_embedding) - for decoder_layer in self.layers: hidden_states = decoder_layer(hidden_states, text_embedding, modulation_parameters, position_embedding) - # 3.1.6 Sequence Parallelism End - if self.model_parallelism: - hidden_states = self.gather_forward_split_backward(hidden_states) - # final block hidden_states = self.final_layer(hidden_states, timestep_embedding) diff --git a/examples/moviegen/mg/parallel/__init__.py b/examples/moviegen/mg/parallel/__init__.py deleted file mode 100644 index de133abd08..0000000000 --- 
a/examples/moviegen/mg/parallel/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .layers import * -from .parallel_states import * diff --git a/examples/moviegen/mg/parallel/layers.py b/examples/moviegen/mg/parallel/layers.py deleted file mode 100644 index d238d47391..0000000000 --- a/examples/moviegen/mg/parallel/layers.py +++ /dev/null @@ -1,398 +0,0 @@ -import numbers -from typing import Callable, Literal, Optional, Tuple, Union - -import mindspore as ms -import mindspore.mint as mint -import mindspore.nn as nn -import mindspore.ops as ops -from mindspore import Tensor -from mindspore.common.initializer import Initializer -from mindspore.communication import GlobalComm, get_group_size, get_rank - -__all__ = [ - "SplitForwardGatherBackward", - "GatherForwardSplitBackward", - "GatherForwardReduceScatterBackward", - "ColumnParallelLinear", - "RowParallelLinear", - "FusedColumnParallelLinear", - "FusedRowParallelLinear", -] - - -def _communicate_along_dim(x: Tensor, dim: int, func: Callable[[Tensor], Tensor]) -> Tensor: - x = x.swapaxes(0, dim) - x = func(x) - x = x.swapaxes(dim, 0) - return x - - -def _split(x: Tensor, dim: int, rank: int, world_size: int) -> Tensor: - dim_size = x.shape[dim] - tensor_list = x.split(dim_size // world_size, axis=dim) - x = tensor_list[rank] - return x - - -class _CopyToModelParallelRegion(nn.Cell): - def __init__(self, group: str = GlobalComm.WORLD_COMM_GROUP) -> None: - super().__init__(auto_prefix=False) - self.reduce = ops.AllReduce(op=ops.ReduceOp.SUM, group=group) - - def construct(self, x: Tensor) -> Tensor: - return x - - def bprop(self, x: Tensor, out: Tensor, dout: Tensor) -> Tuple[Tensor]: - dout = self.reduce(dout) - return (dout,) - - -class _ReduceFromModelParallelRegion(nn.Cell): - def __init__(self, group: str = GlobalComm.WORLD_COMM_GROUP) -> None: - super().__init__(auto_prefix=False) - self.reduce = ops.AllReduce(op=ops.ReduceOp.SUM, group=group) - - def construct(self, x: Tensor) -> Tensor: - return self.reduce(x) - - def bprop(self, x: Tensor, out: Tensor, dout: Tensor) -> Tuple[Tensor]: - return (dout,) - - -class _ScatterToModelParallelRegion(nn.Cell): - def __init__(self, group: str = GlobalComm.WORLD_COMM_GROUP) -> None: - super().__init__(auto_prefix=False) - self.gather = ops.AllGather(group=group) - self.rank = get_rank(group) - self.world_size = get_group_size(group) - - def construct(self, x: Tensor) -> Tensor: - return _split(x, -1, self.rank, self.world_size) - - def bprop(self, x: Tensor, out: Tensor, dout: Tensor) -> Tuple[Tensor]: - dout = _communicate_along_dim(dout, -1, self.gather) - return (dout,) - - -class _GatherFromModelParallelRegion(nn.Cell): - def __init__(self, group: str = GlobalComm.WORLD_COMM_GROUP) -> None: - super().__init__(auto_prefix=False) - self.gather = ops.AllGather(group=group) - self.rank = get_rank(group) - self.world_size = get_group_size(group) - - def construct(self, x: Tensor) -> Tensor: - return _communicate_along_dim(x, -1, self.gather) - - def bprop(self, x: Tensor, out: Tensor, dout: Tensor) -> Tuple[Tensor]: - dout = _split(dout, -1, self.rank, self.world_size) - return (dout,) - - -class _GatherToModelParallelRegion(nn.Cell): - def __init__(self, dim: int = 1, group: str = GlobalComm.WORLD_COMM_GROUP) -> None: - super().__init__(auto_prefix=False) - self.dim = dim - self.gather = ops.AllGather(group=group) - self.reduce_scatter = ops.ReduceScatter(op=ops.ReduceOp.SUM, group=group) - self.rank = get_rank(group) - self.world_size = get_group_size(group) - self.scale = self.world_size - - def 
construct(self, x: Tensor) -> Tensor: - return _communicate_along_dim(x, self.dim, self.gather) - - def bprop(self, x: Tensor, out: Tensor, dout: Tensor) -> Tuple[Tensor]: - dout = dout * self.scale - dout = _communicate_along_dim(dout, self.dim, self.reduce_scatter) - return (dout,) - - -class _ReduceScatterFromModelParallelRegion(nn.Cell): - def __init__(self, dim: int = 1, group: str = GlobalComm.WORLD_COMM_GROUP) -> None: - super().__init__(auto_prefix=False) - self.dim = dim - self.gather = ops.AllGather(group=group) - self.reduce_scatter = ops.ReduceScatter(op=ops.ReduceOp.SUM, group=group) - self.rank = get_rank(group) - self.world_size = get_group_size(group) - self.scale = 1 / self.world_size - - def construct(self, x: Tensor) -> Tensor: - return _communicate_along_dim(x, self.dim, self.reduce_scatter) - - def bprop(self, x: Tensor, out: Tensor, dout: Tensor) -> Tuple[Tensor]: - dout = dout * self.scale - dout = _communicate_along_dim(dout, self.dim, self.gather) - return (dout,) - - -class SplitForwardGatherBackward(nn.Cell): - def __init__( - self, dim: int = 0, grad_scale: Literal["up", "down"] = "down", group: str = GlobalComm.WORLD_COMM_GROUP - ) -> None: - super().__init__() - self.dim = dim - self.rank = get_rank(group) - self.world_size = get_group_size(group) - self.gather = ops.AllGather(group=group) - - if grad_scale == "up": - self.scale = self.world_size - else: - self.scale = 1 / self.world_size - - def construct(self, x: Tensor) -> Tensor: - return _split(x, self.dim, self.rank, self.world_size) - - def bprop(self, x: Tensor, out: Tensor, dout: Tensor) -> Tuple[Tensor]: - dout = dout * self.scale - dout = _communicate_along_dim(dout, self.dim, self.gather) - return (dout,) - - -class GatherForwardSplitBackward(nn.Cell): - def __init__( - self, dim: int = 0, grad_scale: Literal["up", "down"] = "up", group: str = GlobalComm.WORLD_COMM_GROUP - ) -> None: - super().__init__() - self.dim = dim - self.rank = get_rank(group) - self.world_size = get_group_size(group) - self.gather = ops.AllGather(group=group) - - if grad_scale == "up": - self.scale = self.world_size - else: - self.scale = 1 / self.world_size - - def construct(self, x: Tensor) -> Tensor: - x = _communicate_along_dim(x, self.dim, self.gather) - return x - - def bprop(self, x: Tensor, out: Tensor, dout: Tensor) -> Tuple[Tensor]: - dout = dout * self.scale - dout = _split(dout, self.dim, self.rank, self.world_size) - return (dout,) - - -class GatherForwardReduceScatterBackward(nn.Cell): - def __init__(self, dim: int = 0, group: str = GlobalComm.WORLD_COMM_GROUP) -> None: - super().__init__() - self.dim = dim - self.rank = get_rank(group) - self.world_size = get_group_size(group) - self.gather = ops.AllGather(group=group) - self.reduce_scatter = ops.ReduceScatter(op=ops.ReduceOp.SUM, group=group) - - def construct(self, x: Tensor) -> Tensor: - x = _communicate_along_dim(x, self.dim, self.gather) - return x - - def bprop(self, x: Tensor, out: Tensor, dout: Tensor) -> Tuple[Tensor]: - dout = _communicate_along_dim(dout, self.dim, self.reduce_scatter) - return (dout,) - - -class ColumnParallelLinear(nn.Cell): - def __init__( - self, - in_features: int, - out_features: int, - bias: bool = True, - weight_init: Union[None, Tensor, str, Initializer, numbers.Number] = None, - bias_init: Union[None, Tensor, str, Initializer, numbers.Number] = None, - gather_output: bool = True, - group: str = GlobalComm.WORLD_COMM_GROUP, - dtype: Optional[ms.Type] = None, - ): - super().__init__(auto_prefix=False) - - self.rank = 
get_rank(group) - self.world_size = get_group_size(group) - assert out_features % self.world_size == 0 - self.out_features_per_partition = out_features // self.world_size - self.gather_output = gather_output - - self.copy_to_tensor_parallel_region = _CopyToModelParallelRegion(group=group) - self.linear = mint.nn.Linear( - in_features, - self.out_features_per_partition, - bias=bias, - weight_init=weight_init, - bias_init=bias_init, - dtype=dtype, - ) - if self.gather_output: - self.gather_from_tensor_parallel_region = _GatherFromModelParallelRegion(group=group) - - def construct(self, x: Tensor) -> Tensor: - x = self.copy_to_tensor_parallel_region(x) - x = self.linear(x) - if self.gather_output: - x = self.gather_from_tensor_parallel_region(x) - return x - - def load_weight_from_non_parallel_cell(self, target: mint.nn.Linear): - weight = mint.chunk(target.weight, self.world_size, dim=0)[self.rank] - self.linear.weight.set_data(weight) - - if target.bias is not None: - bias = mint.chunk(target.bias, self.world_size, dim=0)[self.rank] - self.linear.bias.set_data(bias) - - -class FusedColumnParallelLinear(nn.Cell): - """For tensor parallel using sequence parallel input - It is a fused operation of gather_forward_split_backward & allreduce backward - """ - - def __init__( - self, - in_features: int, - out_features: int, - bias: bool = True, - weight_init: Union[None, Tensor, str, Initializer, numbers.Number] = None, - bias_init: Union[None, Tensor, str, Initializer, numbers.Number] = None, - gather_output: bool = True, - dim: int = 1, - group: str = GlobalComm.WORLD_COMM_GROUP, - dtype: Optional[ms.Type] = None, - ): - super().__init__(auto_prefix=False) - - self.rank = get_rank(group) - self.world_size = get_group_size(group) - assert out_features % self.world_size == 0 - self.out_features_per_partition = out_features // self.world_size - self.gather_output = gather_output - - self.gather_to_tensor_parallel_region = _GatherToModelParallelRegion(dim=dim, group=group) - self.linear = mint.nn.Linear( - in_features, - self.out_features_per_partition, - bias=bias, - weight_init=weight_init, - bias_init=bias_init, - dtype=dtype, - ) - if self.gather_output: - self.gather_from_tensor_parallel_region = _GatherFromModelParallelRegion(group=group) - - def construct(self, x: Tensor) -> Tensor: - x = self.gather_to_tensor_parallel_region(x) - x = self.linear(x) - if self.gather_output: - x = self.gather_from_tensor_parallel_region(x) - return x - - def load_weight_from_non_parallel_cell(self, target: mint.nn.Linear): - weight = mint.chunk(target.weight, self.world_size, dim=0)[self.rank] - self.linear.weight.set_data(weight) - - if target.bias is not None: - bias = mint.chunk(target.bias, self.world_size, dim=0)[self.rank] - self.linear.bias.set_data(bias) - - -class RowParallelLinear(nn.Cell): - def __init__( - self, - in_features: int, - out_features: int, - bias: bool = False, - weight_init: Union[None, Tensor, str, Initializer, numbers.Number] = None, - bias_init: Union[None, Tensor, str, Initializer, numbers.Number] = None, - input_is_parallel: bool = False, - group: str = GlobalComm.WORLD_COMM_GROUP, - dtype: Optional[ms.Type] = None, - ): - super().__init__(auto_prefix=False) - - self.rank = get_rank(group) - self.world_size = get_group_size(group) - assert in_features % self.world_size == 0 - self.in_features_per_partition = in_features // self.world_size - self.input_is_parallel = input_is_parallel - - self.reduce_from_tensor_parallel_region = _ReduceFromModelParallelRegion(group=group) - 
self.linear = mint.nn.Linear( - self.in_features_per_partition, - out_features, - bias=bias, - weight_init=weight_init, - bias_init=bias_init, - dtype=dtype, - ) - if not self.input_is_parallel: - self.scatter_to_tensor_parallel_region = _ScatterToModelParallelRegion(group=group) - - def construct(self, x: Tensor) -> Tensor: - if not self.input_is_parallel: - x = self.scatter_to_tensor_parallel_region(x) - x = self.linear.dense(x, self.linear.weight) - x = self.reduce_from_tensor_parallel_region(x) - if self.linear.bias is not None: - x = x + self.linear.bias - return x - - def load_weight_from_non_parallel_cell(self, target: mint.nn.Linear): - weight = mint.chunk(target.weight, self.world_size, dim=1)[self.rank] - self.linear.weight.set_data(weight) - - if target.bias is not None: - self.linear.bias.set_data(target.bias) - - -class FusedRowParallelLinear(nn.Cell): - """For tensor parallel to sequence parallel output - It is a fused operation of split_forward_gather_backward & allreduce forward - """ - - def __init__( - self, - in_features: int, - out_features: int, - bias: bool = False, - weight_init: Union[None, Tensor, str, Initializer, numbers.Number] = None, - bias_init: Union[None, Tensor, str, Initializer, numbers.Number] = None, - input_is_parallel: bool = False, - dim: int = 1, - group: str = GlobalComm.WORLD_COMM_GROUP, - dtype: Optional[ms.Type] = None, - ): - super().__init__(auto_prefix=False) - - self.rank = get_rank(group) - self.world_size = get_group_size(group) - assert in_features % self.world_size == 0 - self.in_features_per_partition = in_features // self.world_size - self.input_is_parallel = input_is_parallel - - self.reduce_from_tensor_parallel_region = _ReduceScatterFromModelParallelRegion(dim=dim, group=group) - self.linear = mint.nn.Linear( - self.in_features_per_partition, - out_features, - bias=bias, - weight_init=weight_init, - bias_init=bias_init, - dtype=dtype, - ) - if not self.input_is_parallel: - self.scatter_to_tensor_parallel_region = _ScatterToModelParallelRegion(group=group) - - def construct(self, x: Tensor) -> Tensor: - if not self.input_is_parallel: - x = self.scatter_to_tensor_parallel_region(x) - x = self.linear.dense(x, self.linear.weight) - x = self.reduce_from_tensor_parallel_region(x) - if self.linear.bias is not None: - x = x + self.linear.bias - return x - - def load_weight_from_non_parallel_cell(self, target: mint.nn.Linear): - weight = mint.chunk(target.weight, self.world_size, dim=1)[self.rank] - self.linear.weight.set_data(weight) - - if target.bias is not None: - self.linear.bias.set_data(target.bias) diff --git a/examples/moviegen/mg/parallel/parallel_states.py b/examples/moviegen/mg/parallel/parallel_states.py deleted file mode 100644 index 2a8d9c0a0c..0000000000 --- a/examples/moviegen/mg/parallel/parallel_states.py +++ /dev/null @@ -1,38 +0,0 @@ -from typing import Optional - -from mindspore.communication import create_group, get_group_size, get_rank - -__all__ = ["set_model_parallel_group", "get_model_parallel_group", "create_parallel_group"] - - -_GLOBAL_PARALLEL_GROUPS = dict() - - -def set_model_parallel_group(group: str) -> None: - _GLOBAL_PARALLEL_GROUPS["model"] = group - - -def get_model_parallel_group() -> Optional[str]: - return _GLOBAL_PARALLEL_GROUPS.get("model", None) - - -def create_parallel_group(model_parallel_shards: int = 1) -> None: - if model_parallel_shards <= 1: - raise ValueError( - f"`model_parallel_shards` must be larger than 1 to enable model parallel, but get `{model_parallel_shards}`." 
- ) - - device_num = get_group_size() - if device_num % model_parallel_shards != 0: - raise ValueError( - f"Total number of devices ({device_num}) must be divisible by the number of model parallel shards ({model_parallel_shards})." - ) - - rank_id = get_rank() - - if model_parallel_shards > 1: - mp_group_id = rank_id // model_parallel_shards - mp_group_rank_ids = list(range(mp_group_id * model_parallel_shards, (mp_group_id + 1) * model_parallel_shards)) - mp_group_name = f"mp_group_{mp_group_id}" - create_group(mp_group_name, mp_group_rank_ids) - set_model_parallel_group(mp_group_name) diff --git a/examples/moviegen/mg/schedulers/rectified_flow.py b/examples/moviegen/mg/schedulers/rectified_flow.py index a80bd7dd44..89b3834baf 100644 --- a/examples/moviegen/mg/schedulers/rectified_flow.py +++ b/examples/moviegen/mg/schedulers/rectified_flow.py @@ -8,10 +8,8 @@ import mindspore as ms import mindspore.mint.nn.functional as F from mindspore import Tensor, mint, nn, ops -from mindspore.communication import get_rank from ..models import LlamaModel -from ..parallel import get_model_parallel_group logger = logging.getLogger(__name__) @@ -111,12 +109,7 @@ def __init__( self.model = model self.criteria = nn.MSELoss() - self.mp_group = get_model_parallel_group() - if self.mp_group is not None: - logging.info( - f"Broadcasting all random variables from rank (0) to current rank ({get_rank(self.mp_group)}) in group `{self.mp_group}`." - ) - self.broadcast = ops.Broadcast(0, group=self.mp_group) + self.mp_group = None def _discrete_sample(self, size: int) -> Tensor: return ops.randint(0, self.num_timesteps, (size,), dtype=ms.int64) diff --git a/examples/moviegen/tests/parallel/run_test_llama3_parallel.sh b/examples/moviegen/tests/parallel/run_test_llama3_parallel.sh deleted file mode 100755 index b532dad534..0000000000 --- a/examples/moviegen/tests/parallel/run_test_llama3_parallel.sh +++ /dev/null @@ -1,13 +0,0 @@ -#!/bin/sh - -SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" -PROJECT_DIR="$(dirname $(dirname "${SCRIPT_DIR}"))" -EXAMPLE_DIR="$(dirname "${PROJECT_DIR}")" -PACKAGE_DIR="$(dirname "${EXAMPLE_DIR}")" - -export PYTHONPATH="${PROJECT_DIR}:${PACKAGE_DIR}:${PYTHONPATH}" - -LOGDIR="./log_test_llama3_parallel_graph" -echo "Graph Mode:" -msrun --master_port=1234 --worker_num=2 --local_worker_num=2 --log_dir=$LOGDIR --join True ${SCRIPT_DIR}/test_llama3_parallel.py --mode 0 -echo "Done. Check the log at '$LOGDIR'." diff --git a/examples/moviegen/tests/parallel/run_test_llama3_parallel_block.sh b/examples/moviegen/tests/parallel/run_test_llama3_parallel_block.sh deleted file mode 100755 index 603aac9fce..0000000000 --- a/examples/moviegen/tests/parallel/run_test_llama3_parallel_block.sh +++ /dev/null @@ -1,13 +0,0 @@ -#!/bin/sh - -SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" -PROJECT_DIR="$(dirname $(dirname "${SCRIPT_DIR}"))" -EXAMPLE_DIR="$(dirname "${PROJECT_DIR}")" -PACKAGE_DIR="$(dirname "${EXAMPLE_DIR}")" - -export PYTHONPATH="${PROJECT_DIR}:${PACKAGE_DIR}:${PYTHONPATH}" - -LOGDIR="./log_test_llama3_parallel_block_graph" -echo "Graph Mode:" -msrun --master_port=1234 --worker_num=2 --local_worker_num=2 --log_dir=$LOGDIR --join True ${SCRIPT_DIR}/test_llama3_parallel_block.py --mode 0 -echo "Done. Check the log at '$LOGDIR'." 
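For reference, the deleted `layers.py` implements the standard Megatron-style tensor-parallel decomposition of a linear layer: `ColumnParallelLinear` shards the weight along the output dimension and (optionally) all-gathers the per-rank slices, while `RowParallelLinear` shards the weight along the input dimension and all-reduces the per-rank partial products. The weight layout `(out_features, in_features)` is consistent with the `mint.chunk(target.weight, ..., dim=0 / dim=1)` calls in `load_weight_from_non_parallel_cell` above. A minimal single-process NumPy sketch of the same arithmetic — illustrative only, not part of this patch, with the collectives replaced by plain concatenation and summation:

```python
import numpy as np

rng = np.random.default_rng(0)
world_size = 2
x = rng.standard_normal((4, 8))   # (batch, in_features)
w = rng.standard_normal((6, 8))   # (out_features, in_features)
ref = x @ w.T                     # dense reference output, shape (4, 6)

# ColumnParallelLinear: each rank holds a slice of the output features;
# the per-rank outputs are AllGather-ed along the last dimension.
w_col = np.split(w, world_size, axis=0)                      # two (3, 8) shards
col_out = np.concatenate([x @ wi.T for wi in w_col], axis=1)

# RowParallelLinear: each rank holds a slice of the input features and sees the
# matching slice of x; the per-rank partial products are AllReduce-summed.
w_row = np.split(w, world_size, axis=1)                      # two (6, 4) shards
x_row = np.split(x, world_size, axis=1)                      # two (4, 4) shards
row_out = sum(xi @ wi.T for xi, wi in zip(x_row, w_row))

assert np.allclose(col_out, ref) and np.allclose(row_out, ref)
```

In the usual Megatron layout (and presumably in the removed `TensorParallelLlamaMLP` exercised by the tests below), a column-parallel projection feeds a row-parallel projection, so the intermediate activation stays sharded and only one all-reduce is needed per MLP.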
diff --git a/examples/moviegen/tests/parallel/run_test_llama3_parallel_layer.sh b/examples/moviegen/tests/parallel/run_test_llama3_parallel_layer.sh deleted file mode 100755 index ecf23ff9a8..0000000000 --- a/examples/moviegen/tests/parallel/run_test_llama3_parallel_layer.sh +++ /dev/null @@ -1,13 +0,0 @@ -#!/bin/sh - -SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" -PROJECT_DIR="$(dirname $(dirname "${SCRIPT_DIR}"))" -EXAMPLE_DIR="$(dirname "${PROJECT_DIR}")" -PACKAGE_DIR="$(dirname "${EXAMPLE_DIR}")" - -export PYTHONPATH="${PROJECT_DIR}:${PACKAGE_DIR}:${PYTHONPATH}" - -LOGDIR="./log_test_llama3_parallel_layer_graph" -echo "Graph Mode:" -msrun --master_port=1234 --worker_num=2 --local_worker_num=2 --log_dir=$LOGDIR --join True ${SCRIPT_DIR}/test_llama3_parallel_layer.py --mode 0 -echo "Done. Check the log at '$LOGDIR'." diff --git a/examples/moviegen/tests/parallel/run_test_rflow_parallel.sh b/examples/moviegen/tests/parallel/run_test_rflow_parallel.sh deleted file mode 100755 index 88ad571cac..0000000000 --- a/examples/moviegen/tests/parallel/run_test_rflow_parallel.sh +++ /dev/null @@ -1,13 +0,0 @@ -#!/bin/sh - -SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" -PROJECT_DIR="$(dirname $(dirname "${SCRIPT_DIR}"))" -EXAMPLE_DIR="$(dirname "${PROJECT_DIR}")" -PACKAGE_DIR="$(dirname "${EXAMPLE_DIR}")" - -export PYTHONPATH="${PROJECT_DIR}:${PACKAGE_DIR}:${PYTHONPATH}" - -LOGDIR="./log_test_rflow_parallel_graph" -echo "Graph Mode:" -msrun --master_port=1234 --worker_num=2 --local_worker_num=2 --log_dir=$LOGDIR --join True ${SCRIPT_DIR}/test_rflow_parallel.py --mode 0 -echo "Done. Check the log at '$LOGDIR'." diff --git a/examples/moviegen/tests/parallel/test_llama3_parallel.py b/examples/moviegen/tests/parallel/test_llama3_parallel.py deleted file mode 100644 index 542e44d007..0000000000 --- a/examples/moviegen/tests/parallel/test_llama3_parallel.py +++ /dev/null @@ -1,113 +0,0 @@ -import argparse -from typing import Tuple - -import numpy as np -from mg.models.llama.network import LlamaModel -from mg.parallel import create_parallel_group -from utils import gather_or_reduce_parallel_gradient - -import mindspore as ms -import mindspore.nn as nn -import mindspore.ops as ops -from mindspore import Tensor -from mindspore.communication import get_group_size, init - -from mindone.utils.seed import set_random_seed - - -class MeanNet(nn.Cell): - def __init__(self, net: nn.Cell) -> None: - super().__init__() - self.net = net - - def construct(self, *inputs): - output = self.net(*inputs) - return output.mean() * 1024.0 - - -def get_sample_data(dtype: ms.Type = ms.float32) -> Tuple[Tensor, ...]: - latent_embedding = ops.rand([1, 16, 8, 24, 44], dtype=dtype) - timestep = ms.Tensor([35], dtype=ms.int64) - ul2_emb = ops.rand([1, 64, 4096], dtype=dtype) - metaclip_emb = ops.rand([1, 64, 1280], dtype=dtype) - byt5_emb = ops.rand([1, 64, 1472], dtype=dtype) - return latent_embedding, timestep, ul2_emb, metaclip_emb, byt5_emb - - -def get_network_config(model_parallelism=False, fused_tensor_parallel=False): - config = dict( - num_hidden_layers=2, - attn_implementation="eager", - model_parallelism=model_parallelism, - fused_tensor_parallel=fused_tensor_parallel, - post_init_weight=False, - ) - return config - - -def run_network(mode: int = 0, dtype: ms.Type = ms.float32): - ms.set_context(mode=mode) - init() - - # prepare data - set_random_seed(1024) - data = get_sample_data(dtype=dtype) - - # prepare group - 
create_parallel_group(model_parallel_shards=get_group_size()) - - print("Non-fused tensor parallel:", flush=True) - run_parallel_network(data, fused_tensor_parallel=False) - - print("Fused tensor parallel:", flush=True) - run_parallel_network(data, fused_tensor_parallel=True) - - -def run_parallel_network(data: Tuple[Tensor, ...], fused_tensor_parallel: bool = False, dtype: ms.Type = ms.float32): - # non parallel network - set_random_seed(1024) - non_parallel_network_cfg = get_network_config(model_parallelism=False, fused_tensor_parallel=fused_tensor_parallel) - non_parallel_network = LlamaModel(**non_parallel_network_cfg, dtype=dtype) - - # parallel netowrk - parallel_network_cfg = get_network_config(model_parallelism=True, fused_tensor_parallel=fused_tensor_parallel) - parallel_network = LlamaModel(**parallel_network_cfg, dtype=dtype) - - # load weight - parallel_network.load_weight_from_non_parallel_cell(non_parallel_network) - - # test forward - non_parallel_out = non_parallel_network(*data).asnumpy() - parallel_out = parallel_network(*data).asnumpy() - - assert np.count_nonzero(non_parallel_out) > 0 - np.testing.assert_equal(non_parallel_out.shape, parallel_out.shape) - np.testing.assert_allclose(non_parallel_out, parallel_out, rtol=1.3e-6, atol=1e-5) - print("Test 1 (Forward): Passed.", flush=True) - - # test backward - non_parallel_mean_net = MeanNet(non_parallel_network) - parallel_mean_net = MeanNet(parallel_network) - - # check the parameter gradient - grad_fn = ops.grad(non_parallel_mean_net, grad_position=None, weights=non_parallel_mean_net.trainable_params()) - non_parallel_grads = grad_fn(*data) - - grad_fn = ops.grad(parallel_mean_net, grad_position=None, weights=parallel_mean_net.trainable_params()) - parallel_grads = grad_fn(*data) - - for grad_0, grad_1 in zip(non_parallel_grads, parallel_grads): - grad_1 = gather_or_reduce_parallel_gradient(grad_1, grad_0.shape) - grad_0, grad_1 = grad_0.asnumpy(), grad_1.asnumpy() - assert np.count_nonzero(grad_0) > 0 - np.testing.assert_allclose(grad_0, grad_1, rtol=1.3e-6, atol=2e-5) - print("Test 2 (Backward: Parameter Gradient): Passed.", flush=True) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--mode", default=0, type=int, choices=[0, 1], help="Mode to test. 
(0: Graph Mode; 1: Pynative mode)" - ) - args = parser.parse_args() - run_network(mode=args.mode) diff --git a/examples/moviegen/tests/parallel/test_llama3_parallel_block.py b/examples/moviegen/tests/parallel/test_llama3_parallel_block.py deleted file mode 100644 index 82141a1d31..0000000000 --- a/examples/moviegen/tests/parallel/test_llama3_parallel_block.py +++ /dev/null @@ -1,107 +0,0 @@ -import argparse - -import numpy as np -from mg.models.llama.block import LlamaMLP, TensorParallelLlamaMLP -from mg.parallel import create_parallel_group -from utils import gather_or_reduce_parallel_gradient - -import mindspore as ms -import mindspore.nn as nn -import mindspore.ops as ops -from mindspore import Tensor -from mindspore.communication import get_group_size, init - -from mindone.utils.seed import set_random_seed - - -class MeanNet(nn.Cell): - def __init__(self, net: nn.Cell) -> None: - super().__init__() - self.net = net - - def construct(self, *inputs): - output = self.net(*inputs) - return output.mean() * 1024.0 - - -def get_sample_data(dtype: ms.Type = ms.float32) -> Tensor: - x = ops.rand([4, 64, 3072], dtype=dtype) # (N, T, H) - return x - - -def get_block_config(): - config = dict(intermediate_size=8192, hidden_size=3072, hidden_act="silu") - return config - - -def run_block(mode: int = 0, dtype: ms.Type = ms.float32): - ms.set_context(mode=mode) - init() - - # prepare data - set_random_seed(1024) - data = get_sample_data(dtype=dtype) - - # prepare group - create_parallel_group(model_parallel_shards=get_group_size()) - - # non parallel block - set_random_seed(1024) - non_parallel_block_cfg = get_block_config() - non_parallel_block = LlamaMLP(**non_parallel_block_cfg, dtype=dtype) - - # parallel block - parallel_block_cfg = get_block_config() - parallel_block = TensorParallelLlamaMLP(**parallel_block_cfg, dtype=dtype) - - # load weight - parallel_block.load_weight_from_non_parallel_cell(non_parallel_block) - - # test forward - non_parallel_out = non_parallel_block(data).asnumpy() - parallel_out = parallel_block(data).asnumpy() - - assert np.count_nonzero(non_parallel_out) > 0 - np.testing.assert_equal(non_parallel_out.shape, parallel_out.shape) - np.testing.assert_allclose(non_parallel_out, parallel_out, rtol=1.3e-6, atol=1e-5) - print("Test 1 (Forward): Passed.") - - # test backward - non_parallel_mean_net = MeanNet(non_parallel_block) - parallel_mean_net = MeanNet(parallel_block) - - # check the parameter gradient - grad_fn = ops.grad(non_parallel_mean_net, grad_position=None, weights=non_parallel_mean_net.trainable_params()) - non_parallel_grads = grad_fn(data) - - grad_fn = ops.grad(parallel_mean_net, grad_position=None, weights=parallel_mean_net.trainable_params()) - parallel_grads = grad_fn(data) - - for grad_0, grad_1 in zip(non_parallel_grads, parallel_grads): - grad_1 = gather_or_reduce_parallel_gradient(grad_1, grad_0.shape) - grad_0, grad_1 = grad_0.asnumpy(), grad_1.asnumpy() - assert np.count_nonzero(grad_0) > 0 - np.testing.assert_allclose(grad_0, grad_1, rtol=1.3e-6, atol=1e-5) - print("Test 2 (Backward: Parameter Gradient): Passed.") - - # check the input gradient - grad_fn = ops.grad(non_parallel_mean_net, grad_position=0) - non_parallel_grads = grad_fn(data) - - grad_fn = ops.grad(parallel_mean_net, grad_position=0) - parallel_grads = grad_fn(data) - - for grad_0, grad_1 in zip(non_parallel_grads, parallel_grads): - grad_0, grad_1 = grad_0.asnumpy(), grad_1.asnumpy() - assert np.count_nonzero(grad_0) > 0 - np.testing.assert_allclose(grad_0, grad_1, rtol=1.3e-6, 
atol=1e-5) - print("Test 3 (Backward: Input Gradient): Passed.") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--mode", default=0, type=int, choices=[0, 1], help="Mode to test. (0: Graph Mode; 1: Pynative mode)" - ) - args = parser.parse_args() - run_block(mode=args.mode) diff --git a/examples/moviegen/tests/parallel/test_llama3_parallel_layer.py b/examples/moviegen/tests/parallel/test_llama3_parallel_layer.py deleted file mode 100644 index a4c5afb140..0000000000 --- a/examples/moviegen/tests/parallel/test_llama3_parallel_layer.py +++ /dev/null @@ -1,125 +0,0 @@ -import argparse -from typing import Literal - -import numpy as np -from mg.parallel import ColumnParallelLinear, RowParallelLinear, create_parallel_group, get_model_parallel_group -from utils import gather_or_reduce_parallel_gradient - -import mindspore as ms -import mindspore.mint as mint -import mindspore.nn as nn -import mindspore.ops as ops -from mindspore import Tensor -from mindspore.communication import get_group_size, init - -from mindone.utils.seed import set_random_seed - - -class MeanNet(nn.Cell): - def __init__(self, net: nn.Cell) -> None: - super().__init__() - self.net = net - - def construct(self, *inputs): - output = self.net(*inputs) - return output.mean() * 1024.0 - - -def get_sample_data(dtype: ms.Type = ms.float32) -> Tensor: - x = ops.rand([4, 64, 256], dtype=dtype) # (N, T, H) - return x - - -def get_layer_config(bias: bool = False): - config = dict(in_features=256, out_features=32, bias=bias) - return config - - -def run_layer(mode: int = 0, dtype: ms.Type = ms.float32): - ms.set_context(mode=mode) - init() - - # prepare data - set_random_seed(1024) - data = get_sample_data(dtype=dtype) - - # prepare group - create_parallel_group(model_parallel_shards=get_group_size()) - - print("Column Parallel Linear (Bias=True):") - run_parallel_linear(data, type="column_parallel", bias=True, dtype=dtype) - print("Column Parallel Linear (Bias=False):") - run_parallel_linear(data, type="column_parallel", bias=False, dtype=dtype) - print("Row Parallel Linear (Bias=True):") - run_parallel_linear(data, type="row_parallel", bias=True, dtype=dtype) - print("Row Parallel Linear (Bias=False):") - run_parallel_linear(data, type="row_parallel", bias=False, dtype=dtype) - - -def run_parallel_linear( - data: Tensor, type: Literal["column_parallel", "row_parallel"], bias: bool = False, dtype: ms.Type = ms.float32 -): - # non parallel layer - set_random_seed(1024) - non_parallel_layer_cfg = get_layer_config(bias=bias) - non_parallel_layer = mint.nn.Linear(**non_parallel_layer_cfg, dtype=dtype) - - # parallel layer - group = get_model_parallel_group() - parallel_layer_cfg = get_layer_config(bias=bias) - if type == "column_parallel": - parallel_layer = ColumnParallelLinear(**parallel_layer_cfg, gather_output=True, group=group, dtype=dtype) - else: - parallel_layer = RowParallelLinear(**parallel_layer_cfg, input_is_parallel=False, group=group, dtype=dtype) - - # load weight - parallel_layer.load_weight_from_non_parallel_cell(non_parallel_layer) - - # test forward - non_parallel_out = non_parallel_layer(data).asnumpy() - parallel_out = parallel_layer(data).asnumpy() - - assert np.count_nonzero(non_parallel_out) > 0 - np.testing.assert_equal(non_parallel_out.shape, parallel_out.shape) - np.testing.assert_allclose(non_parallel_out, parallel_out, rtol=1.3e-6, atol=1e-5) - print("Test 1 (Forward): Passed.") - - # test backward - non_parallel_mean_net = MeanNet(non_parallel_layer) - 
parallel_mean_net = MeanNet(parallel_layer) - - # check the parameter gradient - grad_fn = ops.grad(non_parallel_mean_net, grad_position=None, weights=non_parallel_mean_net.trainable_params()) - non_parallel_grads = grad_fn(data) - - grad_fn = ops.grad(parallel_mean_net, grad_position=None, weights=parallel_mean_net.trainable_params()) - parallel_grads = grad_fn(data) - - for grad_0, grad_1 in zip(non_parallel_grads, parallel_grads): - grad_1 = gather_or_reduce_parallel_gradient(grad_1, grad_0.shape) - grad_0, grad_1 = grad_0.asnumpy(), grad_1.asnumpy() - assert np.count_nonzero(grad_0) > 0 - np.testing.assert_allclose(grad_0, grad_1, rtol=1.3e-6, atol=1e-5) - print("Test 2 (Backward: Parameter Gradient): Passed.") - - # check the input gradient - grad_fn = ops.grad(non_parallel_mean_net, grad_position=0) - non_parallel_grads = grad_fn(data) - - grad_fn = ops.grad(parallel_mean_net, grad_position=0) - parallel_grads = grad_fn(data) - - for grad_0, grad_1 in zip(non_parallel_grads, parallel_grads): - grad_0, grad_1 = grad_0.asnumpy(), grad_1.asnumpy() - assert np.count_nonzero(grad_0) > 0 - np.testing.assert_allclose(grad_0, grad_1, rtol=1.3e-6, atol=1e-5) - print("Test 3 (Backward: Input Gradient): Passed.") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--mode", default=0, type=int, choices=[0, 1], help="Mode to test. (0: Graph Mode; 1: Pynative mode)" - ) - args = parser.parse_args() - run_layer(mode=args.mode) diff --git a/examples/moviegen/tests/parallel/test_rflow_parallel.py b/examples/moviegen/tests/parallel/test_rflow_parallel.py deleted file mode 100644 index a3d6302e3a..0000000000 --- a/examples/moviegen/tests/parallel/test_rflow_parallel.py +++ /dev/null @@ -1,61 +0,0 @@ -import argparse -from typing import Tuple - -from mg.parallel import create_parallel_group -from mg.schedulers import RFlowLossWrapper - -import mindspore as ms -from mindspore import Tensor, nn, ops -from mindspore.communication import get_group_size, init - -from mindone.utils.seed import set_random_seed - - -class SimpleNet(nn.Cell): - def construct( - self, x: Tensor, timestamp: Tensor, ul2_emb: Tensor, metaclip_emb: Tensor, byt5_emb: Tensor - ) -> Tensor: - return x.to(ms.float32) - - @property - def dtype(self): - return ms.float32 - - -def get_sample_data(dtype: ms.Type = ms.float32) -> Tuple[Tensor, ...]: - latent_embedding = ops.rand([1, 16, 8, 24, 44], dtype=dtype) - ul2_emb = ops.rand([1, 64, 4096], dtype=dtype) - metaclip_emb = ops.rand([1, 64, 1280], dtype=dtype) - byt5_emb = ops.rand([1, 64, 1472], dtype=dtype) - return latent_embedding, ul2_emb, metaclip_emb, byt5_emb - - -def run_network(mode: int = 0): - ms.set_context(mode=mode) - init() - - # prepare data - set_random_seed(1024) - data = get_sample_data() - - # prepare group - create_parallel_group(model_parallel_shards=get_group_size()) - - model = SimpleNet() - - # parallel netowrk - network = RFlowLossWrapper(model) - - loss = network(*data) - loss = ops.AllGather()(ops.unsqueeze(loss, 0)).asnumpy() - assert loss[0] == loss[1], f"expected two elements to be same, but get `{loss}`." - print("Test 1: Passed.") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--mode", default=0, type=int, choices=[0, 1], help="Mode to test. 
(0: Graph Mode; 1: Pynative mode)" - ) - args = parser.parse_args() - run_network(mode=args.mode) diff --git a/examples/moviegen/tests/parallel/utils.py b/examples/moviegen/tests/parallel/utils.py deleted file mode 100644 index 2f8d19e2d5..0000000000 --- a/examples/moviegen/tests/parallel/utils.py +++ /dev/null @@ -1,32 +0,0 @@ -from typing import Callable, Tuple - -import numpy as np - -import mindspore.ops as ops -from mindspore import Tensor -from mindspore.communication import GlobalComm, get_group_size - - -def _communicate_along_dim(x: Tensor, dim: int, func: Callable[[Tensor], Tensor]) -> Tensor: - x = x.swapaxes(0, dim) - x = func(x) - x = x.swapaxes(dim, 0) - return x - - -def gather_or_reduce_parallel_gradient( - parallel_gradient: Tensor, non_parallel_gradient_shape: Tuple[int, ...], group: str = GlobalComm.WORLD_COMM_GROUP -) -> Tensor: - if parallel_gradient.shape == non_parallel_gradient_shape: - # Sequence Parallel / Context Parallel - allreduce = ops.AllReduce(op=ops.ReduceOp.SUM, group=group) - parallel_gradient = allreduce(parallel_gradient) / get_group_size(group) - return parallel_gradient - - scales = np.array(non_parallel_gradient_shape) / np.array(parallel_gradient.shape) - assert np.count_nonzero(scales - 1) == 1 - assert np.prod(scales) == get_group_size(group) - dim = np.argmax(scales).item() - allgather = ops.AllGather(group=group) - parallel_gradient = _communicate_along_dim(parallel_gradient, dim, allgather) - return parallel_gradient diff --git a/examples/moviegen/train.py b/examples/moviegen/train.py index 2beb1c5cc1..f65d83d2ec 100644 --- a/examples/moviegen/train.py +++ b/examples/moviegen/train.py @@ -18,7 +18,6 @@ from mg.dataset import ImageVideoDataset, bucket_split_function from mg.models.tae import TemporalAutoencoder -from mg.parallel import create_parallel_group from mg.pipelines import DiffusionWithLoss from mg.schedulers import RFlowEvalLoss, RFlowLossWrapper from mg.utils import EMA, MODEL_DTYPE, init_model @@ -75,10 +74,10 @@ def main(args): # 1.1 init model parallel shard_rank_id = rank_id - if (shards := args.train.model_parallel.model_parallel_shards) > 1: - create_parallel_group(**args.train.model_parallel) - device_num = device_num // shards - shard_rank_id = rank_id // shards + # if (shards := args.train.model_parallel.model_parallel_shards) > 1: + # create_parallel_group(**args.train.model_parallel) + # device_num = device_num // shards + # shard_rank_id = rank_id // shards # FIXME: Improve seed setting set_seed(args.env.seed + shard_rank_id) # set different seeds per NPU for sampling different timesteps @@ -112,7 +111,7 @@ def main(args): ) args.model.in_channels = tae.out_channels else: - logger.info("TAE latent folder found. Skipping TAE initialization.") + logger.info("TAE latent folder provided. 
Skipping TAE initialization.") tae = None # 2.2 Llama 3 @@ -271,7 +270,6 @@ def main(args): create_dataloader, "dataloader", skip={"dataset", "transforms", "batch_transforms", "device_num", "rank_id"} ) parser.link_arguments("env.debug", "dataloader.debug", apply_on="parse") - parser.add_function_arguments(create_parallel_group, "train.model_parallel") parser.add_function_arguments(create_scheduler, "train.lr_scheduler", skip={"steps_per_epoch", "num_epochs"}) parser.add_class_arguments( ReduceLROnPlateauByStep, "train.lr_reduce_on_plateau", skip={"optimizer"}, instantiate=False diff --git a/mindone/data/loader.py b/mindone/data/loader.py index 8ff9fd197e..be469e867d 100644 --- a/mindone/data/loader.py +++ b/mindone/data/loader.py @@ -20,7 +20,7 @@ def create_dataloader( drop_remainder: bool = True, python_multiprocessing: bool = True, prefetch_size: int = 16, - max_rowsize: int = 64, + max_rowsize: Optional[int] = None, device_num: int = 1, rank_id: int = 0, debug: bool = False, @@ -52,9 +52,14 @@ def create_dataloader( python_multiprocessing: Whether to use Python multiprocessing for data transformations. This option could be beneficial if the Python operation is computational heavy. Default is True. prefetch_size: The number of samples to prefetch (per device). Default is 16. - max_rowsize: (MindSpore 2.2 and lower only) Maximum size of row in MB that is used for shared memory allocation - to copy data between processes. This is only used if `python_multiprocessing` is set to `True`. - Default is 64. + max_rowsize: Maximum size of row in MB for shared memory allocation to copy data among processes. + This is only used if `python_multiprocessing` is set to `True`. + Values: + - `None` (default): + - For MindSpore 2.3 and above: Uses -1 (dynamic allocation). + - For MindSpore 2.2 and below: Uses 64MB. + - `-1`: (MindSpore 2.3+ only) Allocates memory dynamically. + - Positive integer: Sets a specific maximum row size in MB. device_num: The number of devices to distribute the dataset across. Default is 1. rank_id: The rank ID of the current device. Default is 0. debug: Whether to enable debug mode. Default is False. 
@@ -85,6 +90,12 @@ def create_dataloader( shuffle=shuffle, ) + if max_rowsize is None: + # MS 2.3 and above: allocate memory dynamically + max_rowsize = -1 if MS_VERSION >= "2.3" else 64 + if MS_VERSION < "2.3" and max_rowsize <= 0: + raise ValueError(f"`max_rowsize` must be a positive integer, got {max_rowsize}") + if transforms is not None: if isinstance(transforms, dict): transforms = [transforms] @@ -94,7 +105,7 @@ def create_dataloader( **transform, python_multiprocessing=python_multiprocessing, num_parallel_workers=num_workers, - max_rowsize=max_rowsize if MS_VERSION < "2.3" else -1, # MS 2.3 and above: allocate memory dynamically + max_rowsize=max_rowsize, ) if project_columns: @@ -122,7 +133,7 @@ def create_dataloader( **batch_transform, python_multiprocessing=python_multiprocessing, num_parallel_workers=num_workers, - max_rowsize=max_rowsize if MS_VERSION < "2.3" else -1, + max_rowsize=max_rowsize, ) return dataloader diff --git a/mindone/trainers/train_step.py b/mindone/trainers/train_step.py index 18d0ccc0b8..db86a98a1c 100644 --- a/mindone/trainers/train_step.py +++ b/mindone/trainers/train_step.py @@ -101,7 +101,7 @@ def __init__( if gradient_accumulation_steps > 1: self.accumulated_grads = optimizer.parameters.clone(prefix="grad_accumulated_", init="zeros") - def set_train(self, mode=True): + def set_train(self, mode: bool = True): # Delegate the setting of training mode behavior to the network. self.network.set_train(mode) From a7d8a3718c680789b8f6cc9a7454ed088bcf6fae Mon Sep 17 00:00:00 2001 From: Rustam Khadipash <16683750+hadipash@users.noreply.github.com> Date: Wed, 11 Dec 2024 16:49:50 +0800 Subject: [PATCH 083/122] Update docs --- examples/moviegen/README.md | 97 +++++++++++++++++++++++++------------ 1 file changed, 65 insertions(+), 32 deletions(-) diff --git a/examples/moviegen/README.md b/examples/moviegen/README.md index 0bfead2150..66db6ca51c 100644 --- a/examples/moviegen/README.md +++ b/examples/moviegen/README.md @@ -9,13 +9,14 @@ Meta researchers found that scaling the training data, compute, and model parame Transformer-based ([LLaMa3](https://arxiv.org/abs/2407.21783)) model trained with [Flow Matching](https://arxiv.org/abs/2210.02747) yields high quality generative models for video or audio. -## Features: +### Features 1. :white_check_mark: Text-to-Video synthesis 2. \[Coming soon] Video personalization 3. \[Coming soon] Video editing -### TODO +
+TODO - [ ] Fix EMA. - [ ] Use ByT5 for encoding visual text only (i.e., text within quotes). @@ -24,16 +25,18 @@ Transformer-based ([LLaMa3](https://arxiv.org/abs/2407.21783)) model trained wit - [ ] Fix Model Parallel training. - [ ] Add FPS conditioning. -# Demo +
+ +## Demo Coming soon. -# Architecture +## Architecture
Architecture details -## Transformer Backbone +### Transformer Backbone The Movie Gen family of models contains the following variations: 1B, 5B, and 30B parameters. It uses the [LLaMa3](https://arxiv.org/abs/2407.21783) backbone architecture for the joint image-video generation model, @@ -49,11 +52,11 @@ There are three changes to the LLaMa3 Transformer block for the use case of vide ([DiT](https://arxiv.org/abs/2212.09748)). 3. Use full bidirectional attention instead of causal attention used in language modeling. -## TAE +### TAE [//]: # (TODO) -## Text Encoders +### Text Encoders Movie Gen uses a combination of [UL2](https://arxiv.org/abs/2205.05131), [ByT5](https://arxiv.org/abs/2105.13626), and Long-prompt [MetaCLIP](https://arxiv.org/abs/2309.16671) as text encoders to provide both semantic-level and @@ -69,7 +72,7 @@ character-level text understanding for the backbone:
-# Installation +## Installation | MindSpore | Ascend Driver | Firmware | CANN toolkit/kernel | |:---------:|:-------------:|:-----------:|:-------------------:| @@ -84,7 +87,7 @@ character-level text understanding for the backbone: pip install -r requirements.txt ``` -# Model Weights +## Model Weights
TAE @@ -115,7 +118,7 @@ If you face an SSL certificate verification error, you can add `--disable_ssl_ve
-# Generating Text Embeddings +## Generating Text Embeddings Due to the large memory footprint of the text encoders, the inference and training pipelines don't support generating text embeddings online. Therefore, you need to prepare them in advance by running the following command: @@ -128,25 +131,28 @@ python inference_text_enc.py \ --model_max_length 512 ``` -> [!TIP] +> [!NOTE] > We use the sequence length of 512 tokens for UL2, 256 for MetaCLIP, and 100 for ByT5. -# Inference +## Inference -## Text-to-Video +For more detailed instructions, please run `python inference.py --help`. + +### Text-to-Image ```shell python inference.py \ --config configs/inference/moviegen_t2i_256x256.yaml \ ---model.name llama-5B +--model.name llama-5B \ --model.pretrained_model_path /path/to/llama-5B.ckpt \ --text_emb.ul2_dir /path/to/ul2_embeddings \ --text_emb.metaclip_dir /path/to/metaclip_embeddings \ --text_emb.byt5_dir /path/to/byt5_embeddings \ ---image_size 256 455 +--image_size 256 455 \ +--batch_size 2 ``` -## Text-to-Image +### Text-to-Video ```shell python inference.py \ @@ -162,9 +168,9 @@ python inference.py \ --save_format mp4 ``` -## TAE +### TAE -### Encoding video +#### Encoding Video ```python from mg.models.tae import TemporalAutoencoder @@ -184,7 +190,7 @@ z = (z - tae.shift_factor) * tae.scale_factor For detailed arguments, please refer to the docstring in [tae.py](mg/models/tae/tae.py) -### Decoding video latent +#### Decoding Video Latent ```python # if z is scaled, you should unscale at first: @@ -197,7 +203,7 @@ x = tae.decode(z) x = tae.decode(z, num_target_frames=1) ``` -# Training +## Training Movie Gen is trained jointly on images and videos in 4 stages: @@ -210,20 +216,47 @@ Images are treated as single frame videos, enabling the use of the same model to Compared to video data, paired image-text datasets are easier to scale with diverse concepts and styles, and thus joint modeling of image and video leads to better generalization. -## Movie Gen - -To train Movie Gen, run the following command: +To train Movie Gen, run the following commands: ```shell scripts/stage1_train.sh # for stage 1 training scripts/stage2_train.sh # for stage 2 training ``` +### Dataset Preparation + +Paths to videos and their corresponding captions should be stored in a CSV file with two columns: `video` and `caption`. 
+For example: + +```text +video,caption +video_folder/part01/vid001.mp4,a cartoon character is walking through +video_folder/part01/vid002.mp4,a red and white ball with an angry look on its face +``` + +### Cache Video Embedding (Optional) + +If you have sufficient storage budget, you can cache the video embeddings to speed up training by using the following +command: + +```shell +python inference_tae_enc.py \ +--tae.pretrained=/path/to/tae.ckpt \ +--tae.dtype=bf16 \ +--data.folder=/path/to/folder/with/videos/ \ +--output_path=/path/to/output/directory/ \ +--data.size=256 \ +--data.crop_size=[256,455] +``` + ### Performance -| Model | Context | Jit level | Stage | Precision | Resolution | Batch size | NPUs | Time (s/step) | Config | -|-------|-------------------|-----------|---------|-----------|----------------|------------|------|---------------|--------------------------------------------------------------------| -| 5B | D910*-C18-MS2.3.1 | O1 | 1 (T2I) | BF16 | 256x455 (16:9) | 20 | 4 | 4.47 | [stage1_t2i_256x256.yaml](./configs/train/stage1_t2i_256x256.yaml) | +| Model | Context | Jit level | Stage | Precision | Resolution | TAE Cache | Batch size | NPUs | Time (s/step) | Config | +|:-----:|:-----------------:|:---------:|:---------:|:---------:|:----------------------------:|:---------:|:-----------------------:|:----:|:-------------:|:-----------------------------------------------------------------:| +| 5B | D910*-C18-MS2.3.1 | O1 | 1 (T2I) | BF16 | 256x455 (16:9) | No | 20 | 4 | 4.47 | [stage1_t2i_256x256.yaml](configs/train/stage1_t2i_256x256.yaml) | +| 5B | D910*-C18-MS2.3.1 | O0 | 2 (T2I/V) | BF16 | 256x455 (16:9)
32 frames | No | Image: 10
Video: 5 | 8 | 5.26 | [stage1_t2i_256x256.yaml](configs/train/stage2_t2iv_256x256.yaml) | +| 1B | D910*-C18-MS2.3.1 | O1 | 1 (T2I) | BF16 | 256x455 (16:9) | Yes | 10 | 8 | 0.53 | [stage1_t2i_256x256.yaml](configs/train/stage1_t2i_256x256.yaml) | +| 1B | D910*-C18-MS2.3.1 | O0 | 2 (T2I/V) | BF16 | 256x455 (16:9)
32 frames | Yes | Image: 10
Video: 10 | 8 | 2.08 | [stage1_t2i_256x256.yaml](configs/train/stage2_t2iv_256x256.yaml) | ### Validation During Training @@ -240,9 +273,13 @@ Validation can be enabled by either setting parameters in the `valid` field of t --valid.dataset.text_emb_folder.byt5 /path/to/byt5_embeddings ``` -## TAE +## Evaluation -### Prepare datasets +Coming soon. + +## TAE Training & Evaluation + +### Dataset Preparation We need to prepare a csv annotation file listing the path to each input video related to the root folder, indicated by the `video_folder` argument. An example is @@ -304,7 +341,3 @@ Experiments are tested on ascend 910* with mindspore 2.3.1 graph mode. | model name | cards | batch size | resolution | precision | jit level | graph compile | s/step | PSNR | SSIM | recipe | |:----------:|:-----:|:----------:|:----------:|:---------:|:---------:|:-------------:|:------:|:-----:|:----:|:-------------------------------------------------:| | TAE | 1 | 1 | 256x256x32 | bf16 | O0 | 2 min | 2.18 | 31.35 | 0.92 | [config](configs/tae/train/mixed_256x256x32.yaml) | - -# Evaluation - -Coming soon. From 0b77dfe14b20ae71b00766321f7f4bfb8eeaf0b5 Mon Sep 17 00:00:00 2001 From: Rustam Khadipash <16683750+hadipash@users.noreply.github.com> Date: Wed, 11 Dec 2024 17:29:52 +0800 Subject: [PATCH 084/122] Update docs --- examples/moviegen/README.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/moviegen/README.md b/examples/moviegen/README.md index 66db6ca51c..88b8b9091d 100644 --- a/examples/moviegen/README.md +++ b/examples/moviegen/README.md @@ -251,12 +251,12 @@ python inference_tae_enc.py \ ### Performance -| Model | Context | Jit level | Stage | Precision | Resolution | TAE Cache | Batch size | NPUs | Time (s/step) | Config | -|:-----:|:-----------------:|:---------:|:---------:|:---------:|:----------------------------:|:---------:|:-----------------------:|:----:|:-------------:|:-----------------------------------------------------------------:| -| 5B | D910*-C18-MS2.3.1 | O1 | 1 (T2I) | BF16 | 256x455 (16:9) | No | 20 | 4 | 4.47 | [stage1_t2i_256x256.yaml](configs/train/stage1_t2i_256x256.yaml) | -| 5B | D910*-C18-MS2.3.1 | O0 | 2 (T2I/V) | BF16 | 256x455 (16:9)
32 frames | No | Image: 10
Video: 5 | 8 | 5.26 | [stage1_t2i_256x256.yaml](configs/train/stage2_t2iv_256x256.yaml) | -| 1B | D910*-C18-MS2.3.1 | O1 | 1 (T2I) | BF16 | 256x455 (16:9) | Yes | 10 | 8 | 0.53 | [stage1_t2i_256x256.yaml](configs/train/stage1_t2i_256x256.yaml) | -| 1B | D910*-C18-MS2.3.1 | O0 | 2 (T2I/V) | BF16 | 256x455 (16:9)
32 frames | Yes | Image: 10
Video: 10 | 8 | 2.08 | [stage1_t2i_256x256.yaml](configs/train/stage2_t2iv_256x256.yaml) | +| Model | Context | Jit level | Stage | Precision | Resolution | TAE Cache | Batch size | NPUs | Time (s/step) | Config | +|:-----:|:-----------------:|:---------:|:---------:|:---------:|:----------------------------:|:---------:|:-----------------------:|:----:|:-------------:|:------------------------------------------------------------------:| +| 5B | D910*-C18-MS2.3.1 | O1 | 1 (T2I) | BF16 | 256x455 (16:9) | No | 20 | 4 | 4.47 | [stage1_t2i_256x256.yaml](configs/train/stage1_t2i_256x256.yaml) | +| 5B | D910*-C18-MS2.3.1 | O0 | 2 (T2I/V) | BF16 | 256x455 (16:9)
32 frames | No | Image: 10
Video: 5 | 8 | 5.26 | [stage1_t2iv_256x256.yaml](configs/train/stage2_t2iv_256x256.yaml) | +| 1B | D910*-C18-MS2.3.1 | O1 | 1 (T2I) | BF16 | 256x455 (16:9) | Yes | 10 | 8 | 0.53 | [stage1_t2i_256x256.yaml](configs/train/stage1_t2i_256x256.yaml) | +| 1B | D910*-C18-MS2.3.1 | O0 | 2 (T2I/V) | BF16 | 256x455 (16:9)
32 frames | Yes | Image: 10
Video: 10 | 8 | 2.08 | [stage1_t2iv_256x256.yaml](configs/train/stage2_t2iv_256x256.yaml) | ### Validation During Training From e69f04d4ba5b2e9d07e6116f8f3178709711aff4 Mon Sep 17 00:00:00 2001 From: Rustam Khadipash <16683750+hadipash@users.noreply.github.com> Date: Thu, 12 Dec 2024 10:46:30 +0800 Subject: [PATCH 085/122] update docs and small fixes --- examples/moviegen/README.md | 7 ++++++- examples/moviegen/inference.py | 2 +- examples/moviegen/inference_text_enc.py | 2 +- examples/moviegen/train.py | 11 ++++++++--- 4 files changed, 16 insertions(+), 6 deletions(-) diff --git a/examples/moviegen/README.md b/examples/moviegen/README.md index 88b8b9091d..af57c6615b 100644 --- a/examples/moviegen/README.md +++ b/examples/moviegen/README.md @@ -29,7 +29,12 @@ Transformer-based ([LLaMa3](https://arxiv.org/abs/2407.21783)) model trained wit ## Demo -Coming soon. +| 32x256x455 | 32x256x455 | +|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|:------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:| +|
-## Generating Text Embeddings +## Inference + +### Generating Text Embeddings Due to the large memory footprint of the text encoders, the inference and training pipelines don't support generating text embeddings online. Therefore, you need to prepare them in advance by running the following command: @@ -139,12 +96,10 @@ python inference_text_enc.py \ > [!NOTE] > We use the sequence length of 512 tokens for UL2, 256 for MetaCLIP, and 100 for ByT5. -## Inference +### Text-to-Image For more detailed instructions, please run `python inference.py --help`. -### Text-to-Image - ```shell python inference.py \ --config configs/inference/moviegen_t2i_256x256.yaml \ @@ -173,41 +128,6 @@ python inference.py \ --save_format mp4 ``` -### TAE - -#### Encoding Video - -```python -from mg.models.tae import TemporalAutoencoder - -# may set use_tile=True to save memory -tae = TemporalAutoencoder( - pretrained='/path/to/tae.ckpt', - use_tile=False, -) - -# x - a batch of videos, shape (b c t h w) -z, _, _ = tae.encode(x) - -# you may scale z by: -z = (z - tae.shift_factor) * tae.scale_factor -``` - -For detailed arguments, please refer to the docstring in [tae.py](mg/models/tae/tae.py) - -#### Decoding Video Latent - -```python -# if z is scaled, you should unscale at first: -z = z / tae.scale_factor + tae.shift_factor - -# z - a batch of video latent, shape (b c t h w) -x = tae.decode(z) - -# for image decoding, set num_target_frames to discard the spurious frames -x = tae.decode(z, num_target_frames=1) -``` - ## Training Movie Gen is trained jointly on images and videos in 4 stages: @@ -239,6 +159,12 @@ video_folder/part01/vid001.mp4,a cartoon character is walking through video_folder/part01/vid002.mp4,a red and white ball with an angry look on its face ``` +### Generating Text Embeddings + +Due to the large memory footprint of the text encoders, the inference and training pipelines don't support generating +text embeddings online. Please refer to the [Generating Text Embeddings](#generating-text-embeddings) section under the +Inference section for details. 
+ ### Cache Video Embedding (Optional) If you have sufficient storage budget, you can cache the video embeddings to speed up training by using the following diff --git a/examples/opensora_hpcai/opensora/models/vae/vae.py b/examples/opensora_hpcai/opensora/models/vae/vae.py index b79a94dadf..d8154c0beb 100644 --- a/examples/opensora_hpcai/opensora/models/vae/vae.py +++ b/examples/opensora_hpcai/opensora/models/vae/vae.py @@ -13,7 +13,7 @@ from .autoencoder_kl import AutoencoderKL as AutoencoderKL_SD from .vae_temporal import VAE_Temporal_SD # noqa: F401 -__all__ = ["AutoencoderKL", "OpenSoraVAE_V1_2"] +__all__ = ["AutoencoderKL"] _logger = logging.getLogger(__name__) diff --git a/examples/opensora_hpcai/opensora/pipelines/infer_pipeline.py b/examples/opensora_hpcai/opensora/pipelines/infer_pipeline.py index fc2c02e9b2..09d08dca4d 100644 --- a/examples/opensora_hpcai/opensora/pipelines/infer_pipeline.py +++ b/examples/opensora_hpcai/opensora/pipelines/infer_pipeline.py @@ -58,11 +58,12 @@ def __init__( self.use_cfg = False self.text_encoder = text_encoder + self.diffusion = create_diffusion(str(num_inference_steps)) if sampling.lower() == "ddim": - self.sampling_func = create_diffusion(str(num_inference_steps)).ddim_sample_loop + self.sampling_func = self.diffusion.ddim_sample_loop elif sampling.lower() == "ddpm": - self.sampling_func = create_diffusion(str(num_inference_steps)).p_sample_loop + self.sampling_func = self.diffusion.p_sample_loop elif sampling.lower() == "rflow": self.sampling_func = RFLOW(num_inference_steps, cfg_scale=guidance_rescale, use_timestep_transform=True) else: @@ -127,9 +128,9 @@ def data_prepare(self, inputs): # for token/text drop in caption embedder for condition-free guidance training. The null mask is the same as text mask. n = x.shape[0] # (n_tokens, dim_emb) -> (b n_tokens dim_emb) + null_emb = self.model.y_embedder.y_embedding[None, :, :].repeat(n, axis=0) if self.use_cfg: - null_emb = self.model.y_embedder.y_embedding[None, :, :].repeat(n, axis=0) y = ops.cat([text_emb, null_emb], axis=0) x_in = ops.concat([x] * 2, axis=0) assert y.shape[0] == x_in.shape[0], "shape mismatch!" diff --git a/examples/opensora_hpcai/scripts/args_train.py b/examples/opensora_hpcai/scripts/args_train.py index fddc0f03f7..68bc42196e 100644 --- a/examples/opensora_hpcai/scripts/args_train.py +++ b/examples/opensora_hpcai/scripts/args_train.py @@ -40,7 +40,7 @@ def parse_train_args(parser): "--caption_column", default="caption", type=str, help="name of column for captions saved in csv file" ) parser.add_argument("--video_folder", required=True, type=str, help="root dir for the video data") - parser.add_argument("--text_embed_folder", type=str, help="root dir for the text embedding data") + parser.add_argument("--text_embed_folder", type=str, help="root dir for the text embeding data") parser.add_argument("--vae_latent_folder", type=str, help="root dir for the vae latent data") parser.add_argument("--filter_data", default=False, type=str2bool, help="Filter non-existing videos.") parser.add_argument("--output_path", default="output/", type=str, help="output directory to save training results") @@ -49,7 +49,7 @@ def parse_train_args(parser): ) # model parser.add_argument( - "--model_version", default="v1", type=str, choices=["v1", "v1.1", "v1.2"], help="OpenSora model version." + "--model_version", default="v1", type=str, choices=["v1", "v1.1"], help="OpenSora model version." 
) parser.add_argument( "--pretrained_model_path", diff --git a/examples/opensora_hpcai/scripts/inference_vae.py b/examples/opensora_hpcai/scripts/inference_vae.py index d90a7f9080..c2a4d30601 100644 --- a/examples/opensora_hpcai/scripts/inference_vae.py +++ b/examples/opensora_hpcai/scripts/inference_vae.py @@ -185,9 +185,6 @@ def main(args): mean_ssim += sum(ssim_cur) num_samples += x_rgb.shape[0] - logger.info(f"cur psnr: {psnr_cur[-1]:.4f}, mean psnr:{mean_psnr/num_samples:.4f}") - logger.info(f"cur ssim: {ssim_cur[-1]:.4f}, mean ssim:{mean_ssim/num_samples:.4f}") - if args.eval_loss: recon_loss = np.abs((x - recons).asnumpy()) lpips_loss = lpips_loss_fn(x, recons).asnumpy() diff --git a/examples/opensora_hpcai/tools/mem_monitor/monitor.sh b/examples/opensora_hpcai/tools/mem_monitor/monitor.sh deleted file mode 100644 index 040480f68c..0000000000 --- a/examples/opensora_hpcai/tools/mem_monitor/monitor.sh +++ /dev/null @@ -1,50 +0,0 @@ -#!/bin/bash - -# Check if a PID is provided -if [ "$#" -ne 2 ]; then - echo "Usage: $0 " - exit 1 -fi - -PID=$1 -LOG_FILE="memory_usage.log" - -# Check if the process with the given PID exists -if ! ps -p $PID > /dev/null; then - echo "Process with PID $PID does not exist." - exit 1 -fi - -# Initialize the log file -echo "Timestamp,Memory_Usage_Percentage" > "$LOG_FILE" - -# Monitor memory usage -echo "Monitoring memory usage for PID: $PID. Logging to $LOG_FILE" -echo "Press [CTRL+C] to stop." - -# Loop to continuously monitor memory usage -while true; do - # Get the total memory in KB - TOTAL_MEM=$(grep MemTotal /proc/meminfo | awk '{print $2}') - - # Get the RSS memory of the process in KB - MEMORY_INFO=$(pmap -x $PID | tail -n 1) - RSS_MEMORY=$(echo $MEMORY_INFO | awk '{print $3}') # Get the total RSS memory - - # Calculate memory usage percentage - if [ -n "$RSS_MEMORY" ]; then - MEMORY_USAGE_PERCENTAGE=$(echo "scale=2; ($RSS_MEMORY / $TOTAL_MEM) * 100" | bc) - TIMESTAMP=$(date +"%Y-%m-%d %H:%M:%S") - - # Log the timestamp and memory usage percentage - echo "$TIMESTAMP,$MEMORY_USAGE_PERCENTAGE" >> "$LOG_FILE" - - # Print the memory usage percentage to the console - echo "[$TIMESTAMP] Memory Usage: $MEMORY_USAGE_PERCENTAGE%" - else - echo "Unable to retrieve memory usage for PID $PID." 
- fi - - # Sleep for a specified interval (e.g., 1 second) - sleep 10 -done diff --git a/examples/opensora_hpcai/tools/mem_monitor/plot.py b/examples/opensora_hpcai/tools/mem_monitor/plot.py deleted file mode 100644 index 6b2e33ae07..0000000000 --- a/examples/opensora_hpcai/tools/mem_monitor/plot.py +++ /dev/null @@ -1,19 +0,0 @@ -import matplotlib.pyplot as plt -import pandas as pd - -# Read the log file -data = pd.read_csv("memory_usage.log", parse_dates=["Timestamp"]) - -# Plotting the memory usage -plt.figure(figsize=(10, 5)) -plt.plot(data["Timestamp"], data["Memory_Usage_Percentage"], label="Memory Usage (%)", color="blue") -plt.title("Memory Usage Percentage Over Time") -plt.xlabel("Time") -plt.ylabel("Memory Usage (%)") -plt.xticks(rotation=45) -plt.ylim(0, 100) # Set y-axis limits from 0 to 100% -plt.grid() -plt.legend() -plt.tight_layout() -plt.savefig("memory_usage_plot.png") -plt.show() diff --git a/mindone/data/loader.py b/mindone/data/loader.py index be469e867d..ce4912a1b3 100644 --- a/mindone/data/loader.py +++ b/mindone/data/loader.py @@ -93,8 +93,6 @@ def create_dataloader( if max_rowsize is None: # MS 2.3 and above: allocate memory dynamically max_rowsize = -1 if MS_VERSION >= "2.3" else 64 - if MS_VERSION < "2.3" and max_rowsize <= 0: - raise ValueError(f"`max_rowsize` must be a positive integer, got {max_rowsize}") if transforms is not None: if isinstance(transforms, dict): From ce10e7fa00be5cb6167c40fa499b06b442480adf Mon Sep 17 00:00:00 2001 From: Rustam Khadipash <16683750+hadipash@users.noreply.github.com> Date: Fri, 13 Dec 2024 17:30:06 +0800 Subject: [PATCH 088/122] small inference fix --- examples/moviegen/inference.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/moviegen/inference.py b/examples/moviegen/inference.py index 043d124dbc..5f331b611e 100644 --- a/examples/moviegen/inference.py +++ b/examples/moviegen/inference.py @@ -90,7 +90,8 @@ def main(args): # 2.3 text embeddings prompt_prefix = [os.path.basename(emb)[:-4] for emb in ul2_emb] ul2_emb = ms.Tensor([np.load(emb)["text_emb"] for emb in ul2_emb], dtype=ms.float32) - metaclip_emb = ms.Tensor([np.load(emb)["text_emb"] for emb in metaclip_emb], dtype=ms.float32) + # metaclip_emb = ms.Tensor([np.load(emb)["text_emb"] for emb in metaclip_emb], dtype=ms.float32) + metaclip_emb = ms.Tensor(np.ones((ul2_emb.shape[0], 300, 1280)), dtype=ms.float32) # FIXME: replace with actual byt5_emb = ms.Tensor([np.load(emb)["text_emb"] for emb in byt5_emb], dtype=ms.float32) num_prompts = ul2_emb.shape[0] From 3076725c6051b4dcfdf14b7d242d94204a2e22be Mon Sep 17 00:00:00 2001 From: Rustam Khadipash <16683750+hadipash@users.noreply.github.com> Date: Mon, 16 Dec 2024 16:25:13 +0800 Subject: [PATCH 089/122] enable `lazy_inline` enable jit_level `O2` support --- examples/moviegen/mg/models/llama/network.py | 22 ++++----- .../moviegen/mg/schedulers/rectified_flow.py | 48 ++++++++----------- 2 files changed, 30 insertions(+), 40 deletions(-) diff --git a/examples/moviegen/mg/models/llama/network.py b/examples/moviegen/mg/models/llama/network.py index b6f3faa179..e067bad62a 100644 --- a/examples/moviegen/mg/models/llama/network.py +++ b/examples/moviegen/mg/models/llama/network.py @@ -4,11 +4,9 @@ import numpy as np -import mindspore as ms -import mindspore.mint as mint -import mindspore.nn as nn -import mindspore.ops as ops -from mindspore import Parameter, Tensor, load_checkpoint, load_param_into_net +from mindspore import Parameter, Tensor +from mindspore import dtype as mstype +from mindspore 
import lazy_inline, load_checkpoint, load_param_into_net, mint, nn, ops from mindone.models.utils import normal_, zeros_ @@ -37,7 +35,7 @@ def t2i_modulate(x: Tensor, shift: Tensor, scale: Tensor) -> Tensor: class LlamaDecoderLayer(nn.Cell): - # @ms.lazy_inline(policy="front") + @lazy_inline(policy="front") def __init__( self, hidden_size: int = 4096, @@ -49,7 +47,7 @@ def __init__( attention_bias: bool = False, hidden_act: str = "silu", attn_implementation: Literal["eager", "flash_attention"] = "eager", - dtype: ms.Type = ms.float32, + dtype: mstype = mstype.float32, ) -> None: super().__init__() @@ -128,7 +126,7 @@ def __init__( patch_size: Tuple[int, int, int] = (1, 2, 2), out_channels: int = 8, rms_norm_eps: float = 1e-5, - dtype: ms.Type = ms.float32, + dtype: mstype = mstype.float32, ) -> None: super().__init__() self.input_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype) @@ -168,7 +166,7 @@ def __init__( use_linear_patch_embedder: bool = True, model_parallelism: bool = False, post_init_weight: bool = True, - dtype: ms.Type = ms.float32, + dtype: mstype.Type = mstype.float32, ) -> None: super().__init__() self.patch_size = patch_size @@ -281,9 +279,9 @@ def learnable_position_embedding(self, latent_embedding: Tensor) -> Tensor: # assert nh < self.max_length[1] # assert nw < self.max_length[2] - t_inds = mint.arange(nt, dtype=ms.int64) - h_inds = mint.arange(nh, dtype=ms.int64) - w_inds = mint.arange(nw, dtype=ms.int64) + t_inds = mint.arange(nt, dtype=mstype.int64) + h_inds = mint.arange(nh, dtype=mstype.int64) + w_inds = mint.arange(nw, dtype=mstype.int64) position_ids = ops.meshgrid(t_inds, h_inds, w_inds, indexing="ij") position_ids = ops.stack(position_ids, axis=-1) diff --git a/examples/moviegen/mg/schedulers/rectified_flow.py b/examples/moviegen/mg/schedulers/rectified_flow.py index 89b3834baf..a8a696b246 100644 --- a/examples/moviegen/mg/schedulers/rectified_flow.py +++ b/examples/moviegen/mg/schedulers/rectified_flow.py @@ -5,9 +5,9 @@ import numpy as np from tqdm import tqdm -import mindspore as ms -import mindspore.mint.nn.functional as F -from mindspore import Tensor, mint, nn, ops +from mindspore import Tensor +from mindspore import dtype as mstype +from mindspore import mint, nn, ops from ..models import LlamaModel @@ -17,25 +17,17 @@ class LogisticNormal(nn.Cell): - def __init__(self, loc: float = 0.0, scale: float = 1.0) -> None: + def __init__(self, loc: float = 0.0, scale: float = 1.0): super().__init__() + self.stdnormal = ops.StandardNormal() self.mean = loc self.std = scale - self._min = Tensor(np.finfo(np.float32).tiny, dtype=ms.float32) - self._max = Tensor(1.0 - np.finfo(np.float32).eps, dtype=ms.float32) - - def construct(self, shape: Tuple[int, ...]) -> Tensor: - # assert shape[-1] == 1 - x = mint.normal(mean=self.mean, std=self.std, size=shape) - offset = x.shape[-1] + 1 - mint.cumsum(mint.ones(x.shape[-1]), dim=-1) - z = self._clipped_sigmoid(x - mint.log(offset)) - z_cumprod = ops.cumprod((1 - z), dim=-1) - y = F.pad(z, [0, 1], value=1) * F.pad(z_cumprod, [1, 0], value=1) - return y[:, 0] - - def _clipped_sigmoid(self, x: Tensor) -> Tensor: - x = mint.clamp(mint.sigmoid(x), min=self._min, max=self._max) - return x + self._min = Tensor(np.finfo(np.float32).tiny, dtype=mstype.float32) + self._max = Tensor(1.0 - np.finfo(np.float32).eps, dtype=mstype.float32) + + def construct(self, shape: Tuple[int]) -> Tensor: + x = self.mean + self.std * self.stdnormal(shape) + return ops.clamp(ops.sigmoid(x), self._min, self._max) class RFLOW: @@ 
-112,13 +104,13 @@ def __init__( self.mp_group = None def _discrete_sample(self, size: int) -> Tensor: - return ops.randint(0, self.num_timesteps, (size,), dtype=ms.int64) + return ops.randint(0, self.num_timesteps, (size,), dtype=mstype.int64) def _uniform_sample(self, size: int) -> Tensor: - return mint.rand((size,), dtype=ms.float32) * self.num_timesteps + return mint.rand((size,), dtype=mstype.float32) * self.num_timesteps def _logit_normal_sample(self, size: int) -> Tensor: - return self.distribution((size, 1)) * self.num_timesteps + return self.distribution((size,)) * self.num_timesteps def _broadcast(self, x: Tensor) -> Tensor: if self.mp_group is None: @@ -136,12 +128,12 @@ def construct( byt5_emb: (N, L3, 1472) ByT5 text embeddings timestep: (N,) tensor to indicate a denoising step """ - x = x.to(ms.float32) + x = x.to(mstype.float32) if timestep is None: timestep = self._broadcast(self._sample_func(x.shape[0])) - noise = self._broadcast(mint.normal(size=x.shape)) + noise = self._broadcast(ops.rand_like(x.shape)) x_t = self.add_noise(x, noise, timestep) model_output = self.model( @@ -150,7 +142,7 @@ def construct( ul2_emb.to(self.model.dtype), metaclip_emb.to(self.model.dtype), byt5_emb.to(self.model.dtype), - ).to(ms.float32) + ).to(mstype.float32) v_t = x - (1 - self.eps) * noise # 3.1.2 Eqa (2) @@ -163,7 +155,7 @@ def add_noise(self, x: Tensor, noise: Tensor, timesteps: Tensor) -> Tensor: noise: (N, T, C, H, W) tensor of white noise timesteps: (N,) tensor of timestamps with range [0, num_timesteps) """ - timesteps = 1 - timesteps.to(ms.float32) / self.num_timesteps + timesteps = 1 - timesteps.to(mstype.float32) / self.num_timesteps timesteps = timesteps[:, None, None, None, None] # 3.1.2 First Eqa. @@ -175,11 +167,11 @@ def __init__(self, network: RFlowLossWrapper, num_sampling_steps: int = 10): super().__init__() self.network = network self.timesteps = Tensor( - np.linspace(0, network.num_timesteps, num_sampling_steps + 2)[1:-1].reshape(-1, 1), dtype=ms.float32 + np.linspace(0, network.num_timesteps, num_sampling_steps + 2)[1:-1].reshape(-1, 1), dtype=mstype.float32 ) def construct(self, x: Tensor, ul2_emb: Tensor, metaclip_emb: Tensor, byt5_emb: Tensor, **kwargs) -> Tensor: - loss = Tensor(0, dtype=ms.float32) + loss = Tensor(0, dtype=mstype.float32) timesteps = mint.tile(self.timesteps, (1, x.shape[0])) for t in timesteps: loss += self.network(x, ul2_emb, metaclip_emb, byt5_emb, t) From 8ba45df45ab63ff9144140dfba620c79263e0a91 Mon Sep 17 00:00:00 2001 From: Rustam Khadipash <16683750+hadipash@users.noreply.github.com> Date: Tue, 17 Dec 2024 17:43:17 +0800 Subject: [PATCH 090/122] small fix --- examples/moviegen/mg/schedulers/rectified_flow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/moviegen/mg/schedulers/rectified_flow.py b/examples/moviegen/mg/schedulers/rectified_flow.py index a8a696b246..47a2337540 100644 --- a/examples/moviegen/mg/schedulers/rectified_flow.py +++ b/examples/moviegen/mg/schedulers/rectified_flow.py @@ -133,7 +133,7 @@ def construct( if timestep is None: timestep = self._broadcast(self._sample_func(x.shape[0])) - noise = self._broadcast(ops.rand_like(x.shape)) + noise = self._broadcast(ops.rand_like(x)) x_t = self.add_noise(x, noise, timestep) model_output = self.model( From 8a9abfa9e697be87b721bb1c127ca997e20b42dc Mon Sep 17 00:00:00 2001 From: Rustam Khadipash <16683750+hadipash@users.noreply.github.com> Date: Tue, 17 Dec 2024 17:44:58 +0800 Subject: [PATCH 091/122] small fix --- 
examples/moviegen/mg/schedulers/rectified_flow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/moviegen/mg/schedulers/rectified_flow.py b/examples/moviegen/mg/schedulers/rectified_flow.py index 47a2337540..6ed5a776cb 100644 --- a/examples/moviegen/mg/schedulers/rectified_flow.py +++ b/examples/moviegen/mg/schedulers/rectified_flow.py @@ -133,7 +133,7 @@ def construct( if timestep is None: timestep = self._broadcast(self._sample_func(x.shape[0])) - noise = self._broadcast(ops.rand_like(x)) + noise = self._broadcast(ops.randn_like(x)) x_t = self.add_noise(x, noise, timestep) model_output = self.model( From 0392703cdb382477f308f9d56931db07b6f0d847 Mon Sep 17 00:00:00 2001 From: Rustam Khadipash <16683750+hadipash@users.noreply.github.com> Date: Wed, 18 Dec 2024 13:55:49 +0800 Subject: [PATCH 092/122] enable flexible recompute --- .../moviegen/configs/train/stage1_t2i_256x256.yaml | 2 +- .../moviegen/configs/train/stage2_t2iv_256x256.yaml | 2 +- examples/moviegen/mg/models/llama/network.py | 10 +++++----- examples/moviegen/mg/utils/model_utils.py | 4 ++-- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/examples/moviegen/configs/train/stage1_t2i_256x256.yaml b/examples/moviegen/configs/train/stage1_t2i_256x256.yaml index 1c1e5b1bf9..cad6a1bc95 100644 --- a/examples/moviegen/configs/train/stage1_t2i_256x256.yaml +++ b/examples/moviegen/configs/train/stage1_t2i_256x256.yaml @@ -9,7 +9,7 @@ model: name: llama-5B pretrained_model_path: enable_flash_attention: True - recompute: True + recompute_every_nth_block: 1 dtype: bf16 tae: diff --git a/examples/moviegen/configs/train/stage2_t2iv_256x256.yaml b/examples/moviegen/configs/train/stage2_t2iv_256x256.yaml index add1eb20a4..844a0f1493 100644 --- a/examples/moviegen/configs/train/stage2_t2iv_256x256.yaml +++ b/examples/moviegen/configs/train/stage2_t2iv_256x256.yaml @@ -9,7 +9,7 @@ model: name: llama-5B pretrained_model_path: enable_flash_attention: True - recompute: True + recompute_every_nth_block: 1 dtype: bf16 tae: diff --git a/examples/moviegen/mg/models/llama/network.py b/examples/moviegen/mg/models/llama/network.py index e067bad62a..dc6b85b0b1 100644 --- a/examples/moviegen/mg/models/llama/network.py +++ b/examples/moviegen/mg/models/llama/network.py @@ -162,7 +162,7 @@ def __init__( patch_size: Tuple[int, int, int] = (1, 2, 2), max_length: Tuple[int, int, int] = (128, 64, 64), attn_implementation: Literal["eager", "flash_attention"] = "eager", - gradient_checkpointing: bool = False, + recompute_every_nth_block: Optional[int] = None, use_linear_patch_embedder: bool = True, model_parallelism: bool = False, post_init_weight: bool = True, @@ -232,10 +232,10 @@ def __init__( self.initializer_range = initializer_range self.init_weights() - # recompute - if gradient_checkpointing: - for layer in self.layers: # Explicitly recompute each block for PyNative - layer.recompute() + if recompute_every_nth_block is not None: + for i, layer in enumerate(self.layers): + if i % recompute_every_nth_block == 0: + layer.recompute() @property def dtype(self): diff --git a/examples/moviegen/mg/utils/model_utils.py b/examples/moviegen/mg/utils/model_utils.py index a6ca1572b6..8c39adad2e 100644 --- a/examples/moviegen/mg/utils/model_utils.py +++ b/examples/moviegen/mg/utils/model_utils.py @@ -63,7 +63,7 @@ def init_model( pretrained_model_path: Optional[Path_fr] = None, enable_flash_attention: bool = True, model_parallelism: bool = False, - recompute: bool = False, + recompute_every_nth_block: Optional[int] = None, dtype: 
Literal["fp32", "fp16", "bf16"] = "fp32", ) -> LlamaModel: attn_implementation = "flash_attention" if enable_flash_attention else "eager" @@ -71,7 +71,7 @@ def init_model( in_channels=in_channels, attn_implementation=attn_implementation, model_parallelism=model_parallelism, - gradient_checkpointing=recompute, + recompute_every_nth_block=recompute_every_nth_block, dtype=MODEL_DTYPE[dtype], ) if pretrained_model_path: From 3c33e66e14c2716a780cf2479328fccb7395418e Mon Sep 17 00:00:00 2001 From: Rustam Khadipash <16683750+hadipash@users.noreply.github.com> Date: Wed, 18 Dec 2024 14:07:11 +0800 Subject: [PATCH 093/122] enable flexible recompute --- examples/moviegen/mg/models/llama/network.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/examples/moviegen/mg/models/llama/network.py b/examples/moviegen/mg/models/llama/network.py index dc6b85b0b1..e33af9214b 100644 --- a/examples/moviegen/mg/models/llama/network.py +++ b/examples/moviegen/mg/models/llama/network.py @@ -1,5 +1,6 @@ from __future__ import annotations +import logging from typing import Literal, Optional, Tuple, Union import numpy as np @@ -24,6 +25,8 @@ __all__ = ["LlamaModel", "llama3_1B", "llama3_5B", "llama3_30B"] +_logger = logging.getLogger(__name__) + Llama_ATTENTION_CLASSES = { "eager": LlamaAttention, "flash_attention": LlamaFlashAttention, @@ -233,6 +236,7 @@ def __init__( self.init_weights() if recompute_every_nth_block is not None: + _logger.info(f"Recomputing every {recompute_every_nth_block} block.") for i, layer in enumerate(self.layers): if i % recompute_every_nth_block == 0: layer.recompute() From dc711f9bfd0f8c07cd1de8089dbf2eb02993684e Mon Sep 17 00:00:00 2001 From: Rustam Khadipash <16683750+hadipash@users.noreply.github.com> Date: Wed, 18 Dec 2024 14:57:36 +0800 Subject: [PATCH 094/122] - add train resume feature - preserve image / video orientation in transformations --- examples/moviegen/inference.py | 2 +- examples/moviegen/mg/dataset/transforms.py | 25 ++++++++++- examples/moviegen/mg/utils/callbacks.py | 9 ++-- examples/moviegen/mg/utils/model_utils.py | 51 +++++++++++++++++----- examples/moviegen/train.py | 18 +++++--- mindone/trainers/zero.py | 2 +- 6 files changed, 82 insertions(+), 25 deletions(-) diff --git a/examples/moviegen/inference.py b/examples/moviegen/inference.py index 5f331b611e..472fef157f 100644 --- a/examples/moviegen/inference.py +++ b/examples/moviegen/inference.py @@ -169,7 +169,7 @@ def main(args): help="Path to load a config yaml file that describes the setting which will override the default arguments.", ) parser.add_function_arguments(init_train_env, "env") - parser.add_function_arguments(init_model, "model", skip={"in_channels"}) + parser.add_function_arguments(init_model, "model", skip={"in_channels", "resume"}) tae_group = parser.add_argument_group("TAE parameters") tae_group.add_class_arguments(TemporalAutoencoder, "tae", instantiate=False) tae_group.add_argument( diff --git a/examples/moviegen/mg/dataset/transforms.py b/examples/moviegen/mg/dataset/transforms.py index d69a01f2f2..27ab33255a 100644 --- a/examples/moviegen/mg/dataset/transforms.py +++ b/examples/moviegen/mg/dataset/transforms.py @@ -5,20 +5,41 @@ class ResizeCrop: - def __init__(self, size: Optional[Tuple[int, int]] = None, interpolation=cv2.INTER_LINEAR): + """ + Resize and center crop the input image or video to a target size while preserving the aspect ratio. + + Args: + size (Optional[Tuple[int, int]], optional): The target size. If None, the target size should be passed during the call. 
+ interpolation (cv2.InterpolationFlags, optional): The interpolation method. Defaults to cv2.INTER_LINEAR. + preserve_orientation (bool, optional): Whether to preserve the orientation of the image/video. Defaults to True. + """ + + def __init__( + self, + size: Optional[Tuple[int, int]] = None, + interpolation: int = cv2.INTER_LINEAR, + preserve_orientation: bool = True, + ): self._size = size self._inter = interpolation + self._po = preserve_orientation def __call__(self, x: np.ndarray, size: Optional[Tuple[int, int]] = None) -> np.ndarray: h, w = x.shape[-3:-1] # support images and videos th, tw = size or self._size + scale = max(th / h, tw / w) + if self._po: # preserve orientation + scale = min(scale, max(th / w, tw / h)) + if scale != 1: # resize if x.ndim == 3: # if image x = cv2.resize(x, None, fx=scale, fy=scale, interpolation=self._inter) else: # if video x = np.array([cv2.resize(i, None, fx=scale, fy=scale, interpolation=self._inter) for i in x]) - if x.shape[-3:-1] != (th, tw): # crop + + if x.shape[-3:-1] != (th, tw): # center crop i, j = round((x.shape[-3] - th) / 2.0), round((x.shape[-2] - tw) / 2.0) x = x[..., i : i + th, j : j + tw, :] + return x diff --git a/examples/moviegen/mg/utils/callbacks.py b/examples/moviegen/mg/utils/callbacks.py index 07cb1934e7..66f5ff6ca7 100644 --- a/examples/moviegen/mg/utils/callbacks.py +++ b/examples/moviegen/mg/utils/callbacks.py @@ -101,7 +101,6 @@ def __init__( file_name: str = "result.log", metric_names: List[str] = None, separator: str = "\t", - resume: bool = False, ): super().__init__() self._sep = separator @@ -110,10 +109,10 @@ def __init__( if not os.path.exists(save_dir): os.makedirs(save_dir, exist_ok=True) self._log_file = os.path.join(save_dir, file_name) - if not resume: - header = separator.join([f"{'step':<7}", f"{'loss':<10}", "train_time(s)"] + self._metrics) - with open(self._log_file, "w", encoding="utf-8") as fp: - fp.write(header + "\n") + + header = separator.join([f"{'step':<7}", f"{'loss':<10}", "train_time(s)"] + self._metrics) + with open(self._log_file, "w", encoding="utf-8") as fp: + fp.write(header + "\n") def on_train_step_begin(self, run_context: RunContext): self._step_time = time.perf_counter() diff --git a/examples/moviegen/mg/utils/model_utils.py b/examples/moviegen/mg/utils/model_utils.py index 8c39adad2e..86daab6ad9 100644 --- a/examples/moviegen/mg/utils/model_utils.py +++ b/examples/moviegen/mg/utils/model_utils.py @@ -1,5 +1,5 @@ import logging -from typing import Dict, Literal, Optional, Union +from typing import Dict, Literal, Optional, Tuple, Union from jsonargparse.typing import Path_fr from mg import LlamaModel, llama3_1B, llama3_5B, llama3_30B @@ -7,7 +7,10 @@ import mindspore as ms from mindspore import _no_grad, jit_class, nn -__all__ = ["MODEL_DTYPE", "load_ckpt_params", "no_grad", "init_model"] +from mindone.trainers.train_step import TrainOneStepWrapper +from mindone.utils.params import load_param_into_net_with_filter + +__all__ = ["MODEL_DTYPE", "no_grad", "init_model", "resume_train_net"] logger = logging.getLogger(__name__) @@ -20,7 +23,7 @@ } -def load_ckpt_params(model: nn.Cell, ckpt: Union[str, Dict]) -> nn.Cell: +def load_ckpt_params(model: nn.Cell, ckpt: Union[str, Dict]) -> None: if isinstance(ckpt, str): logger.info(f"Loading {ckpt} params into network...") param_dict = ms.load_checkpoint(ckpt) @@ -29,13 +32,11 @@ def load_ckpt_params(model: nn.Cell, ckpt: Union[str, Dict]) -> nn.Cell: param_dict = ckpt param_not_load, ckpt_not_load = ms.load_param_into_net(model, param_dict) 
- if not (len(param_not_load) == len(ckpt_not_load) == 0): + if param_not_load or ckpt_not_load: logger.warning( - "Exist ckpt params not loaded: {} (total: {}), or net params not loaded: {} (total: {})".format( - ckpt_not_load, len(ckpt_not_load), param_not_load, len(param_not_load) - ) + f"Exist ckpt params not loaded: {ckpt_not_load} (total: {len(ckpt_not_load)}),\n" + f"or net params not loaded: {param_not_load} (total: {len(param_not_load)})" ) - return model @jit_class @@ -61,6 +62,7 @@ def init_model( name: Literal["llama-1B", "llama-5B", "llama-30B"], in_channels: int = 16, pretrained_model_path: Optional[Path_fr] = None, + resume: bool = False, enable_flash_attention: bool = True, model_parallelism: bool = False, recompute_every_nth_block: Optional[int] = None, @@ -74,8 +76,37 @@ def init_model( recompute_every_nth_block=recompute_every_nth_block, dtype=MODEL_DTYPE[dtype], ) - if pretrained_model_path: - model = load_ckpt_params(model, pretrained_model_path.absolute) + + if resume: + logger.info("Resume training checkpoint provided, skipping weight loading.") + elif pretrained_model_path: + load_ckpt_params(model, pretrained_model_path.absolute) else: logger.info(f"Initialize {name} model randomly.") return model + + +def resume_train_net( + train_net: TrainOneStepWrapper, resume_ckpt: Optional[Path_fr] = None +) -> Tuple[Union[int, None], Union[int, None]]: + if resume_ckpt is None: + return None, None + + state_dict = ms.load_checkpoint(resume_ckpt) + if "epoch_num" not in state_dict or "cur_step" not in state_dict or "loss_scale" not in state_dict: + raise ValueError("Resume training checkpoint is invalid. Please check the checkpoint file.") + + start_epoch = state_dict.pop("epoch_num").item() + global_step = state_dict.pop("cur_step").item() + logger.info(f"Resuming training of network from {resume_ckpt} at global step {global_step}") + + # FIXME: `EvalSaveCallback` renames `scale_sense` to `loss_scale` when saving the resume checkpoint + train_net.scale_sense = ms.Parameter(state_dict.pop("loss_scale"), name="scale_sense") + param_not_load, ckpt_not_load = load_param_into_net_with_filter(train_net, state_dict, filter=state_dict.keys()) + if param_not_load or ckpt_not_load: + logger.warning( + f"Exist ckpt params not loaded: {ckpt_not_load} (total: {len(ckpt_not_load)}),\n" + f"or net params not loaded: {param_not_load} (total: {len(param_not_load)})" + ) + + return start_epoch, global_step diff --git a/examples/moviegen/train.py b/examples/moviegen/train.py index 9476b9c59e..1f79f2a373 100644 --- a/examples/moviegen/train.py +++ b/examples/moviegen/train.py @@ -20,7 +20,7 @@ from mg.models.tae import TemporalAutoencoder from mg.pipelines import DiffusionWithLoss from mg.schedulers import RFlowEvalLoss, RFlowLossWrapper -from mg.utils import EMA, MODEL_DTYPE, init_model +from mg.utils import EMA, MODEL_DTYPE, init_model, resume_train_net from mg.utils.callbacks import PerfRecorderCallback, ReduceLROnPlateauByStep, ValidationCallback from mindone.data import create_dataloader @@ -116,7 +116,7 @@ def main(args): # 2.2 Llama 3 logger.info("Transformer init") - network = init_model(**args.model) + network = init_model(resume=args.train.resume_ckpt is not None, **args.model) # 2.3 LossWrapper rflow_loss_wrapper = RFlowLossWrapper(network) @@ -152,6 +152,10 @@ def main(args): latent_diffusion_with_loss, optimizer=optimizer, scale_sense=loss_scaler, ema=ema, **args.train.settings ) + start_epoch, global_step = 0, 0 + if args.train.resume_ckpt is not None: + start_epoch, global_step = 
resume_train_net(net_with_grads, resume_ckpt=os.path.abspath(args.train.resume_ckpt)) + # TODO: validation graph? # if bucketing is used in Graph mode, activate dynamic inputs if mode == GRAPH_MODE and isinstance(args.dataloader.batch_size, dict): @@ -192,8 +196,9 @@ def main(args): ema=ema, step_mode=True, use_step_unit=True, - train_steps=args.train.steps, + start_epoch=start_epoch, resume_prefix_blacklist=("tae.", "swap."), + train_steps=args.train.steps, **args.train.save, ), PerfRecorderCallback( @@ -202,7 +207,7 @@ def main(args): ] ) - callbacks.append(StopAtStepCallback(train_steps=args.train.steps)) + callbacks.append(StopAtStepCallback(train_steps=args.train.steps, global_step=global_step)) # 5.5 print out key info and save config if rank_id == 0: @@ -246,7 +251,7 @@ def main(args): # 6. train logger.info("Start training...") # train() uses epochs, so the training will be terminated by the StopAtStepCallback - model.train(args.train.steps, dataloader, callbacks=callbacks) + model.train(args.train.steps, dataloader, callbacks=callbacks, initial_epoch=start_epoch) if __name__ == "__main__": @@ -258,7 +263,7 @@ def main(args): help="Path to load a config yaml file that describes the setting which will override the default arguments.", ) parser.add_function_arguments(init_train_env, "env") - parser.add_function_arguments(init_model, "model") + parser.add_function_arguments(init_model, "model", skip={"resume"}) parser.add_class_arguments(TemporalAutoencoder, "tae", instantiate=False) parser.add_argument( "--tae.dtype", default="fp32", type=str, choices=["fp32", "fp16", "bf16"], help="TAE model precision." @@ -290,6 +295,7 @@ def main(args): prepare_train_network, "train.settings", skip={"network", "optimizer", "scale_sense", "ema"} ) parser.add_subclass_arguments(EMA, "train.ema", skip={"network"}, required=False, instantiate=False) + parser.add_function_arguments(resume_train_net, "train", skip={"train_net"}) parser.add_argument( "--train.output_path", default="output/", diff --git a/mindone/trainers/zero.py b/mindone/trainers/zero.py index 16e328499d..42bbc4e326 100644 --- a/mindone/trainers/zero.py +++ b/mindone/trainers/zero.py @@ -561,7 +561,7 @@ def prepare_train_network( dp_group: str = None, comm_fusion: dict = None, parallel_modules=None, -): +) -> TrainOneStepWrapper: """ Prepare network and optimizer for distributed training. From 0776fdcc7511e7a3834b954f791e7e2318033284 Mon Sep 17 00:00:00 2001 From: Rustam Khadipash <16683750+hadipash@users.noreply.github.com> Date: Thu, 19 Dec 2024 15:17:20 +0800 Subject: [PATCH 095/122] ResizeCrop fix --- examples/moviegen/mg/__init__.py | 1 - examples/moviegen/mg/dataset/transforms.py | 5 +++-- examples/moviegen/mg/utils/model_utils.py | 2 +- examples/moviegen/tests/ut/test_transforms.py | 16 ++++++++++++++++ 4 files changed, 20 insertions(+), 4 deletions(-) create mode 100644 examples/moviegen/tests/ut/test_transforms.py diff --git a/examples/moviegen/mg/__init__.py b/examples/moviegen/mg/__init__.py index aed4fa323c..e69de29bb2 100644 --- a/examples/moviegen/mg/__init__.py +++ b/examples/moviegen/mg/__init__.py @@ -1 +0,0 @@ -from .models import * diff --git a/examples/moviegen/mg/dataset/transforms.py b/examples/moviegen/mg/dataset/transforms.py index 27ab33255a..27a6263c81 100644 --- a/examples/moviegen/mg/dataset/transforms.py +++ b/examples/moviegen/mg/dataset/transforms.py @@ -29,8 +29,9 @@ def __call__(self, x: np.ndarray, size: Optional[Tuple[int, int]] = None) -> np. 
th, tw = size or self._size scale = max(th / h, tw / w) - if self._po: # preserve orientation - scale = min(scale, max(th / w, tw / h)) + if self._po and (new_scale := max(th / w, tw / h)) < scale: # preserve orientation + scale = new_scale + th, tw = tw, th if scale != 1: # resize if x.ndim == 3: # if image diff --git a/examples/moviegen/mg/utils/model_utils.py b/examples/moviegen/mg/utils/model_utils.py index 86daab6ad9..767fe3104a 100644 --- a/examples/moviegen/mg/utils/model_utils.py +++ b/examples/moviegen/mg/utils/model_utils.py @@ -2,7 +2,7 @@ from typing import Dict, Literal, Optional, Tuple, Union from jsonargparse.typing import Path_fr -from mg import LlamaModel, llama3_1B, llama3_5B, llama3_30B +from mg.models import LlamaModel, llama3_1B, llama3_5B, llama3_30B import mindspore as ms from mindspore import _no_grad, jit_class, nn diff --git a/examples/moviegen/tests/ut/test_transforms.py b/examples/moviegen/tests/ut/test_transforms.py new file mode 100644 index 0000000000..d988d97bfd --- /dev/null +++ b/examples/moviegen/tests/ut/test_transforms.py @@ -0,0 +1,16 @@ +import numpy as np +from mg.dataset.transforms import ResizeCrop + + +def test_horizontal_image_crop(): + image = np.random.randint(0, 256, (150, 250, 3), dtype=np.uint8) + rc = ResizeCrop((100, 200)) + image = rc(image) + assert image.shape == (100, 200, 3) + + +def test_vertical_image_crop(): + image = np.random.randint(0, 256, (250, 150, 3), dtype=np.uint8) + rc = ResizeCrop((100, 200)) + image = rc(image) + assert image.shape == (200, 100, 3) From 333faa33d30f82aec1c7358419cb2072e05ad469 Mon Sep 17 00:00:00 2001 From: Rustam Khadipash <16683750+hadipash@users.noreply.github.com> Date: Thu, 19 Dec 2024 16:32:33 +0800 Subject: [PATCH 096/122] update docs --- examples/moviegen/README.md | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/examples/moviegen/README.md b/examples/moviegen/README.md index 76d9b0f57a..13bb66d73e 100644 --- a/examples/moviegen/README.md +++ b/examples/moviegen/README.md @@ -182,12 +182,18 @@ python inference_tae_enc.py \ ### Performance -| Model | Context | Jit level | Stage | Precision | Resolution | TAE Cache | Batch size | NPUs | Time (s/step) | Config | -|:-----:|:-----------------:|:---------:|:---------:|:---------:|:----------------------------:|:---------:|:-----------------------:|:----:|:-------------:|:------------------------------------------------------------------:| -| 5B | D910*-C18-MS2.3.1 | O1 | 1 (T2I) | BF16 | 256x455 (16:9) | No | 20 | 4 | 4.47 | [stage1_t2i_256x256.yaml](configs/train/stage1_t2i_256x256.yaml) | -| 5B | D910*-C18-MS2.3.1 | O0 | 2 (T2I/V) | BF16 | 256x455 (16:9)
32 frames | No | Image: 10<br>Video: 5 | 8 | 5.26 | [stage1_t2iv_256x256.yaml](configs/train/stage2_t2iv_256x256.yaml) |
-| 1B | D910*-C18-MS2.3.1 | O1 | 1 (T2I) | BF16 | 256x455 (16:9) | Yes | 10 | 8 | 0.53 | [stage1_t2i_256x256.yaml](configs/train/stage1_t2i_256x256.yaml) |
-| 1B | D910*-C18-MS2.3.1 | O0 | 2 (T2I/V) | BF16 | 256x455 (16:9)<br>32 frames | Yes | Image: 10<br>Video: 10 | 8 | 2.08 | [stage1_t2iv_256x256.yaml](configs/train/stage2_t2iv_256x256.yaml) |
+Experiments were conducted on Ascend 910* using MindSpore 2.3.1 in Graph mode.
+
+> [!NOTE]
+> We trained all the models using BF16 precision.
+
+| Model | NPUs | Stage | Batch size | Resolution | Jit level | Compile time | Recompute | Gradient Acc | TAE Cache | Time (s/step) | Config |
+|:-----:|:----:|:---------:|:-----------------------:|:-----------------------:|:---------:|:------------:|:-----------------------:|:------------:|:---------:|:-------------:|:------------------------------------------------------------------:|
+| 5B | 8 | 1 (T2I) | 10 | 256x455 | O1 | 3m 40s | ON | 1 | Yes | 1.29 | [stage1_t2i_256x256.yaml](configs/train/stage1_t2i_256x256.yaml) |
+| 5B | 8 | 2 (T2I/V) | Image: 1<br>Video: 1 | 256x455<br>256 frames | O1 | 6m | ON<br>(Every 2 blocks) | 5 | Yes | 5.09 | [stage2_t2iv_256x256.yaml](configs/train/stage2_t2iv_256x256.yaml) |
+| 5B | 8 | 3 (T2I/V) | Image: 1<br>Video: 1 | 576x1024<br>256 frames | O1 | 7m 30s | ON | 5 | Yes | 88.5 | [stage3_t2iv_768px.yaml](configs/train/stage2_t2iv_256x256.yaml) |
+| 1B | 8 | 1 (T2I) | 10 | 256x455 | O1 | 2m 15s | ON | 1 | Yes | 0.53 | [stage1_t2i_256x256.yaml](configs/train/stage1_t2i_256x256.yaml) |
+| 1B | 8 | 2 (T2I/V) | Image: 10<br>Video: 10 | 256x455<br>
32 frames | O0 | 1m 55s | ON | 1 | Yes | 2.07 | [stage2_t2iv_256x256.yaml](configs/train/stage2_t2iv_256x256.yaml) | ### Validation During Training From d237f8f7e8e4a12a2de2d652f7fe5b6e9d42957a Mon Sep 17 00:00:00 2001 From: Mike Cheung Date: Thu, 19 Dec 2024 18:02:48 +0800 Subject: [PATCH 097/122] support SP and change rms to ops.rms --- examples/moviegen/inference.py | 9 ++ examples/moviegen/mg/acceleration/__init__.py | 2 + .../mg/acceleration/communications.py | 71 +++++++++++ .../mg/acceleration/parallel_states.py | 37 ++++++ examples/moviegen/mg/models/llama/block.py | 33 ++++-- examples/moviegen/mg/models/llama/network.py | 78 ++++++------- .../moviegen/mg/schedulers/rectified_flow.py | 13 ++- examples/moviegen/mg/utils/model_utils.py | 2 - .../run_test_llama_sequence_parallel.sh | 18 +++ .../parallel/test_llama_sequence_parallel.py | 110 ++++++++++++++++++ examples/moviegen/train.py | 11 +- 11 files changed, 328 insertions(+), 56 deletions(-) create mode 100644 examples/moviegen/mg/acceleration/__init__.py create mode 100644 examples/moviegen/mg/acceleration/communications.py create mode 100644 examples/moviegen/mg/acceleration/parallel_states.py create mode 100755 examples/moviegen/tests/parallel/run_test_llama_sequence_parallel.sh create mode 100644 examples/moviegen/tests/parallel/test_llama_sequence_parallel.py diff --git a/examples/moviegen/inference.py b/examples/moviegen/inference.py index 472fef157f..543d96e5e8 100644 --- a/examples/moviegen/inference.py +++ b/examples/moviegen/inference.py @@ -9,9 +9,11 @@ import numpy as np from jsonargparse import ActionConfigFile, ArgumentParser from jsonargparse.typing import path_type +from mg.acceleration import create_parallel_group import mindspore as ms from mindspore import amp, nn +from mindspore.communication import GlobalComm # TODO: remove in future when mindone is ready for install __dir__ = os.path.dirname(os.path.abspath(__file__)) @@ -62,6 +64,9 @@ def main(args): # 1. init env _, rank_id, device_num = init_train_env(**args.env) # TODO: rename as train and infer are identical? 
+ if args.enable_sequence_paralell: + create_parallel_group(GlobalComm.WORLD_COMM_GROUP) + # 1.1 read caption embeddings ul2_emb, metaclip_emb, byt5_emb = prepare_captions(**args.text_emb, rank_id=rank_id, device_num=device_num) @@ -144,6 +149,9 @@ def main(args): f" sampling speed: {args.num_sampling_steps * (end_i - i) / batch_time:.2f} step/s" ) + if args.enable_sequence_paralell and rank_id > 1: + continue + # save result for j in range(0, end_i - i): fn = prompt_prefix[i + j] @@ -182,6 +190,7 @@ def main(args): infer_group.add_argument("--fps", type=int, default=16, help="FPS in the saved video") infer_group.add_function_arguments(prepare_captions, "text_emb", skip={"rank_id", "device_num"}) infer_group.add_argument("--batch_size", type=int, default=1) + infer_group.add_argument("--enable_sequence_paralell", type=bool, default=False, help="enable sequence parallel.") save_group = parser.add_argument_group("Saving options") save_group.add_argument( "--save_format", diff --git a/examples/moviegen/mg/acceleration/__init__.py b/examples/moviegen/mg/acceleration/__init__.py new file mode 100644 index 0000000000..51f2a4fd22 --- /dev/null +++ b/examples/moviegen/mg/acceleration/__init__.py @@ -0,0 +1,2 @@ +from .communications import * +from .parallel_states import * diff --git a/examples/moviegen/mg/acceleration/communications.py b/examples/moviegen/mg/acceleration/communications.py new file mode 100644 index 0000000000..127a8c3b5e --- /dev/null +++ b/examples/moviegen/mg/acceleration/communications.py @@ -0,0 +1,71 @@ +from typing import Callable, Literal, Tuple + +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore import Tensor +from mindspore.communication import GlobalComm, get_group_size, get_rank + +__all__ = ["SplitFowardGatherBackward", "GatherFowardSplitBackward"] + + +def _split(x: Tensor, dim: int, rank: int, world_size: int) -> Tensor: + dim_size = x.shape[dim] + tensor_list = x.split(dim_size // world_size, axis=dim) + x = tensor_list[rank] + return x + + +def _communicate_along_dim(x: Tensor, dim: int, func: Callable[[Tensor], Tensor]) -> Tensor: + x = x.swapaxes(0, dim) + x = func(x) + x = x.swapaxes(dim, 0) + return x + + +class SplitFowardGatherBackward(nn.Cell): + def __init__( + self, dim: int = 0, grad_scale: Literal["up", "down"] = "down", group: str = GlobalComm.WORLD_COMM_GROUP + ) -> None: + super().__init__() + self.dim = dim + self.rank = get_rank(group) + self.world_size = get_group_size(group) + self.gather = ops.AllGather(group=group) + + if grad_scale == "up": + self.scale = self.world_size + else: + self.scale = 1 / self.world_size + + def construct(self, x: Tensor) -> Tensor: + return _split(x, self.dim, self.rank, self.world_size) + + def bprop(self, x: Tensor, out: Tensor, dout: Tensor) -> Tuple[Tensor]: + dout = dout * self.scale + dout = _communicate_along_dim(dout, self.dim, self.gather) + return (dout,) + + +class GatherFowardSplitBackward(nn.Cell): + def __init__( + self, dim: int = 0, grad_scale: Literal["up", "down"] = "up", group: str = GlobalComm.WORLD_COMM_GROUP + ) -> None: + super().__init__() + self.dim = dim + self.rank = get_rank(group) + self.world_size = get_group_size(group) + self.gather = ops.AllGather(group=group) + + if grad_scale == "up": + self.scale = self.world_size + else: + self.scale = 1 / self.world_size + + def construct(self, x: Tensor) -> Tensor: + x = _communicate_along_dim(x, self.dim, self.gather) + return x + + def bprop(self, x: Tensor, out: Tensor, dout: Tensor) -> Tuple[Tensor]: + dout = dout * 
self.scale + dout = _split(dout, self.dim, self.rank, self.world_size) + return (dout,) diff --git a/examples/moviegen/mg/acceleration/parallel_states.py b/examples/moviegen/mg/acceleration/parallel_states.py new file mode 100644 index 0000000000..2fedbd47e2 --- /dev/null +++ b/examples/moviegen/mg/acceleration/parallel_states.py @@ -0,0 +1,37 @@ +from typing import Optional + +from mindspore.communication import create_group, get_group_size, get_rank + +__all__ = ["set_sequence_parallel_group", "get_sequence_parallel_group", "create_parallel_group"] + +_GLOBAL_PARALLEL_GROUPS = dict() + + +def set_sequence_parallel_group(group: str) -> None: + _GLOBAL_PARALLEL_GROUPS["sequence"] = group + + +def get_sequence_parallel_group() -> Optional[str]: + return _GLOBAL_PARALLEL_GROUPS.get("sequence", None) + + +def create_parallel_group(sequence_parallel_shards: int) -> None: + if sequence_parallel_shards <= 1: + raise ValueError( + f"`sequence_parallel_shards` must be larger than 1 to enable sequence parallel, but get `{sequence_parallel_shards}`." + ) + + device_num = get_group_size() + if device_num % sequence_parallel_shards != 0: + raise ValueError( + f"Total number of devices `{device_num}` must be divisible by the number of sequence parallel shards `{sequence_parallel_shards}`." + ) + + rank_id = get_rank() + sp_group_id = rank_id // sequence_parallel_shards + sp_group_rank_ids = list( + range(sp_group_id * sequence_parallel_shards, (sp_group_id + 1) * sequence_parallel_shards) + ) + sp_group_name = f"sp_group_{sp_group_id}" + create_group(sp_group_name, sp_group_rank_ids) + set_sequence_parallel_group(sp_group_name) diff --git a/examples/moviegen/mg/models/llama/block.py b/examples/moviegen/mg/models/llama/block.py index b9645da307..1796380dee 100644 --- a/examples/moviegen/mg/models/llama/block.py +++ b/examples/moviegen/mg/models/llama/block.py @@ -1,4 +1,3 @@ -import logging from typing import Optional, Sequence, Tuple, Union import numpy as np @@ -9,12 +8,12 @@ import mindspore.nn as nn import mindspore.ops as ops from mindspore import Parameter, Tensor +from mindspore.communication import get_group_size from mindspore.ops.operations.nn_ops import FlashAttentionScore +from ...acceleration import get_sequence_parallel_group from .activation import ACT2FN -logger = logging.getLogger(__name__) - class LlamaRMSNorm(nn.Cell): def __init__(self, hidden_size: Union[int, Sequence[int]], eps: float = 1e-6, dtype: ms.Type = ms.float32) -> None: @@ -24,11 +23,10 @@ def __init__(self, hidden_size: Union[int, Sequence[int]], eps: float = 1e-6, dt def construct(self, hidden_states: Tensor) -> Tensor: input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(ms.float32) - variance = mint.pow(hidden_states, 2) - variance = mint.mean(variance, dim=-1, keepdim=True) - hidden_states = hidden_states * mint.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + hidden_states, _ = ops.rms_norm( + hidden_states.to(ms.float32), self.weight.to(ms.float32), epsilon=self.variance_epsilon + ) + return hidden_states.to(input_dtype) class LlamaMLP(nn.Cell): @@ -94,6 +92,14 @@ def __init__( ) self.o_proj = mint.nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=attention_bias, dtype=dtype) + sp_group = get_sequence_parallel_group() + if sp_group is not None: + self.sp_group_size = get_group_size(sp_group) + self.alltoall = ops.AlltoAll(self.sp_group_size, 1, 2, group=sp_group) + else: + self.sp_group_size = None + self.alltoall = mint.nn.Identity() + 
def construct(self, hidden_states: Tensor, encoder_hidden_states: Optional[Tensor] = None) -> Tensor: bsz, q_len, _ = hidden_states.shape @@ -105,12 +111,15 @@ def construct(self, hidden_states: Tensor, encoder_hidden_states: Optional[Tenso query_states = ops.reshape(query_states, (bsz, -1, self.num_heads, self.head_dim)) query_states = mint.permute(query_states, (0, 2, 1, 3)) + query_states = self.alltoall(query_states) key_states = ops.reshape(key_states, (bsz, -1, self.num_key_value_heads, self.head_dim)) key_states = mint.permute(key_states, (0, 2, 1, 3)) + key_states = self.alltoall(key_states) value_states = ops.reshape(value_states, (bsz, -1, self.num_key_value_heads, self.head_dim)) value_states = mint.permute(value_states, (0, 2, 1, 3)) + value_states = self.alltoall(value_states) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) @@ -125,6 +134,7 @@ def construct(self, hidden_states: Tensor, encoder_hidden_states: Optional[Tenso attn_output = mint.matmul(attn_weights, value_states) attn_output = mint.permute(attn_output, (0, 2, 1, 3)) + attn_output = self.alltoall(attn_output) attn_output = ops.reshape(attn_output, (bsz, q_len, -1)) attn_output = self.o_proj(attn_output) @@ -149,8 +159,9 @@ def __init__( attention_bias=attention_bias, dtype=dtype, ) + num_heads = self.num_heads // self.sp_group_size if self.sp_group_size is not None else self.num_heads self.flash_attention = FlashAttentionScore( - self.num_heads, keep_prob=1 - self.attention_dropout, scale_value=self.head_dim**-0.5, input_layout="BSND" + num_heads, keep_prob=1 - self.attention_dropout, scale_value=self.head_dim**-0.5, input_layout="BSND" ) def construct(self, hidden_states: Tensor, encoder_hidden_states: Optional[Tensor] = None) -> Tensor: @@ -164,12 +175,15 @@ def construct(self, hidden_states: Tensor, encoder_hidden_states: Optional[Tenso query_states = ops.reshape(query_states, (bsz, -1, self.num_heads, self.head_dim)) query_states = mint.permute(query_states, (0, 2, 1, 3)) + query_states = self.alltoall(query_states) key_states = ops.reshape(key_states, (bsz, -1, self.num_key_value_heads, self.head_dim)) key_states = mint.permute(key_states, (0, 2, 1, 3)) + key_states = self.alltoall(key_states) value_states = ops.reshape(value_states, (bsz, -1, self.num_key_value_heads, self.head_dim)) value_states = mint.permute(value_states, (0, 2, 1, 3)) + value_states = self.alltoall(value_states) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) @@ -180,6 +194,7 @@ def construct(self, hidden_states: Tensor, encoder_hidden_states: Optional[Tenso value_states = mint.permute(value_states, (0, 2, 1, 3)) _, _, _, attn_output = self.flash_attention(query_states, key_states, value_states, None, None, None, None) + attn_output = self.alltoall(attn_output) attn_output = ops.reshape(attn_output, (bsz, q_len, -1)) attn_output = self.o_proj(attn_output) diff --git a/examples/moviegen/mg/models/llama/network.py b/examples/moviegen/mg/models/llama/network.py index e33af9214b..78ce27b716 100644 --- a/examples/moviegen/mg/models/llama/network.py +++ b/examples/moviegen/mg/models/llama/network.py @@ -7,10 +7,11 @@ from mindspore import Parameter, Tensor from mindspore import dtype as mstype -from mindspore import lazy_inline, load_checkpoint, load_param_into_net, mint, nn, ops +from mindspore import lazy_inline, load_checkpoint, mint, nn, ops from mindone.models.utils import normal_, 
zeros_ +from ...acceleration import GatherFowardSplitBackward, SplitFowardGatherBackward, get_sequence_parallel_group from ..text_encoders import TextProjector from .activation import ACT2FN from .block import ( @@ -38,7 +39,7 @@ def t2i_modulate(x: Tensor, shift: Tensor, scale: Tensor) -> Tensor: class LlamaDecoderLayer(nn.Cell): - @lazy_inline(policy="front") + @lazy_inline def __init__( self, hidden_size: int = 4096, @@ -133,8 +134,8 @@ def __init__( ) -> None: super().__init__() self.input_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype) - self.proj = nn.Dense( - hidden_size, patch_size[0] * patch_size[1] * patch_size[2] * out_channels, has_bias=False, dtype=dtype + self.proj = mint.nn.Linear( + hidden_size, patch_size[0] * patch_size[1] * patch_size[2] * out_channels, bias=False, dtype=dtype ) self.scale_shift_table = Parameter(Tensor(np.random.randn(2, hidden_size) / hidden_size**0.5, dtype=dtype)) @@ -167,7 +168,6 @@ def __init__( attn_implementation: Literal["eager", "flash_attention"] = "eager", recompute_every_nth_block: Optional[int] = None, use_linear_patch_embedder: bool = True, - model_parallelism: bool = False, post_init_weight: bool = True, dtype: mstype.Type = mstype.float32, ) -> None: @@ -180,29 +180,25 @@ def __init__( self.num_key_value_heads = num_key_value_heads self.rms_norm_eps = rms_norm_eps self.max_length = max_length - self.model_parallelism = model_parallelism self._dtype = dtype - if self.model_parallelism: - raise NotImplementedError("Model parallelism is not supported yet.") - else: - self.layers = nn.CellList( - [ - LlamaDecoderLayer( - hidden_size=self.hidden_size, - intermediate_size=intermediate_size, - num_attention_heads=self.num_attention_heads, - num_key_value_heads=self.num_key_value_heads, - rms_norm_eps=rms_norm_eps, - attention_dropout=attention_dropout, - attention_bias=attention_bias, - hidden_act=hidden_act, - attn_implementation=attn_implementation, - dtype=dtype, - ) - for _ in range(num_hidden_layers) - ] - ) + self.layers = nn.CellList( + [ + LlamaDecoderLayer( + hidden_size=self.hidden_size, + intermediate_size=intermediate_size, + num_attention_heads=self.num_attention_heads, + num_key_value_heads=self.num_key_value_heads, + rms_norm_eps=rms_norm_eps, + attention_dropout=attention_dropout, + attention_bias=attention_bias, + hidden_act=hidden_act, + attn_implementation=attn_implementation, + dtype=dtype, + ) + for _ in range(num_hidden_layers) + ] + ) self.final_layer = LlamaFinalLayer( hidden_size=self.hidden_size, @@ -230,6 +226,16 @@ def __init__( out_features=self.hidden_size, layer_norm=LlamaRMSNorm, norm_eps=self.rms_norm_eps, dtype=dtype ) + # init sequence parallel + sp_group = get_sequence_parallel_group() + if sp_group is not None: + _logger.info(f"Initialize Llama model with sequence parallel group `{sp_group}`.") + self.split_forward_gather_backward = SplitFowardGatherBackward(dim=1, grad_scale="down", group=sp_group) + self.gather_forward_split_backward = GatherFowardSplitBackward(dim=1, grad_scale="up", group=sp_group) + else: + self.split_forward_gather_backward = mint.nn.Identity() + self.gather_forward_split_backward = mint.nn.Identity() + # post-init if post_init_weight: self.initializer_range = initializer_range @@ -340,12 +346,18 @@ def construct( # 3.1.4 text embedding text_embedding = self.text_projector(ul2_emb, metaclip_emb, byt5_emb) + # sequence parallel start + latent_embedding = self.split_forward_gather_backward(latent_embedding) + position_embedding = 
self.split_forward_gather_backward(position_embedding) + # main blocks hidden_states = latent_embedding - for decoder_layer in self.layers: hidden_states = decoder_layer(hidden_states, text_embedding, modulation_parameters, position_embedding) + # sequence parallel end + hidden_states = self.gather_forward_split_backward(hidden_states) + # final block hidden_states = self.final_layer(hidden_states, timestep_embedding) @@ -372,18 +384,6 @@ def construct_with_cfg( model_out = mint.tile(model_out, (2, 1, 1, 1, 1)) return model_out - def load_weight_from_non_parallel_cell(self, target: LlamaModel): - param_dict = target.parameters_dict() - - # filter tensor-parallel block - names = ["gate_proj", "up_proj", "down_proj"] - param_dict = {k: v for k, v in param_dict.items() if not any([name in k for name in names])} - load_param_into_net(self, param_dict) - - # load tensor-parallel block - for layer, target_layer in zip(self.layers, target.layers): - layer.mlp.load_weight_from_non_parallel_cell(target_layer.mlp) - def llama3_1B(from_pretrained=None, **kwargs): model = LlamaModel( diff --git a/examples/moviegen/mg/schedulers/rectified_flow.py b/examples/moviegen/mg/schedulers/rectified_flow.py index 6ed5a776cb..e2be0009d1 100644 --- a/examples/moviegen/mg/schedulers/rectified_flow.py +++ b/examples/moviegen/mg/schedulers/rectified_flow.py @@ -8,7 +8,9 @@ from mindspore import Tensor from mindspore import dtype as mstype from mindspore import mint, nn, ops +from mindspore.communication import get_rank +from ..acceleration import get_sequence_parallel_group from ..models import LlamaModel logger = logging.getLogger(__name__) @@ -101,7 +103,14 @@ def __init__( self.model = model self.criteria = nn.MSELoss() - self.mp_group = None + self.sp_group = get_sequence_parallel_group() + if self.sp_group is not None: + logging.info( + f"Broadcasting all random variables from rank (0) to current rank ({get_rank(self.sp_group)}) in group `{self.sp_group}`." 
+ ) + self.broadcast = ops.Broadcast(0, group=self.sp_group) + else: + self.broadcast = None def _discrete_sample(self, size: int) -> Tensor: return ops.randint(0, self.num_timesteps, (size,), dtype=mstype.int64) @@ -113,7 +122,7 @@ def _logit_normal_sample(self, size: int) -> Tensor: return self.distribution((size,)) * self.num_timesteps def _broadcast(self, x: Tensor) -> Tensor: - if self.mp_group is None: + if self.sp_group is None: return x return self.broadcast((x,))[0] diff --git a/examples/moviegen/mg/utils/model_utils.py b/examples/moviegen/mg/utils/model_utils.py index 86daab6ad9..c55af37a71 100644 --- a/examples/moviegen/mg/utils/model_utils.py +++ b/examples/moviegen/mg/utils/model_utils.py @@ -64,7 +64,6 @@ def init_model( pretrained_model_path: Optional[Path_fr] = None, resume: bool = False, enable_flash_attention: bool = True, - model_parallelism: bool = False, recompute_every_nth_block: Optional[int] = None, dtype: Literal["fp32", "fp16", "bf16"] = "fp32", ) -> LlamaModel: @@ -72,7 +71,6 @@ def init_model( model = MODEL_SPEC[name]( in_channels=in_channels, attn_implementation=attn_implementation, - model_parallelism=model_parallelism, recompute_every_nth_block=recompute_every_nth_block, dtype=MODEL_DTYPE[dtype], ) diff --git a/examples/moviegen/tests/parallel/run_test_llama_sequence_parallel.sh b/examples/moviegen/tests/parallel/run_test_llama_sequence_parallel.sh new file mode 100755 index 0000000000..e86f871cd8 --- /dev/null +++ b/examples/moviegen/tests/parallel/run_test_llama_sequence_parallel.sh @@ -0,0 +1,18 @@ +#!/bin/sh +set -e + +SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" +PROJECT_DIR="$(dirname $(dirname "${SCRIPT_DIR}"))" +EXAMPLE_DIR="$(dirname "${PROJECT_DIR}")" +PACKAGE_DIR="$(dirname "${EXAMPLE_DIR}")" + +export PYTHONPATH="${PROJECT_DIR}:${PACKAGE_DIR}:${PYTHONPATH}" + +echo "******** Graph Mode ********" +msrun --master_port=1234 --worker_num=2 --local_worker_num=2 --log_dir="./log_test_sp_graph" --join True ${SCRIPT_DIR}/test_llama_sequence_parallel.py --mode 0 +echo "Done. Check the log at './log_test_sp_graph'." +echo "=========================================================================" + +echo "******** Pynative Mode ********" +msrun --master_port=1235 --worker_num=2 --local_worker_num=2 --log_dir="./log_test_sp_pynative" --join True ${SCRIPT_DIR}/test_llama_sequence_parallel.py --mode 1 +echo "Done. Check the log at './log_test_sp_pynative'." 
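For orientation, here is a minimal, hypothetical sketch of how the sequence-parallel helpers introduced in this patch (`create_parallel_group`, `SplitFowardGatherBackward`, `GatherFowardSplitBackward`) are meant to compose. It is not part of the patch itself: the file name, tensor shapes, and two-shard world size are illustrative assumptions, and, like the test above, it has to be launched with `msrun` so that a communication group exists.

```python
# sp_sketch.py -- illustrative only (not part of this patch series); launch with:
#   msrun --worker_num=2 --local_worker_num=2 sp_sketch.py
from mindspore import ops
from mindspore.communication import get_group_size, init

from mg.acceleration import (
    GatherFowardSplitBackward,
    SplitFowardGatherBackward,
    create_parallel_group,
    get_sequence_parallel_group,
)

init()  # set up the collective communication backend (requires msrun/distributed env)
create_parallel_group(sequence_parallel_shards=get_group_size())
sp_group = get_sequence_parallel_group()

# Shard along the sequence dimension (dim=1), mirroring what LlamaModel does around its decoder layers.
split = SplitFowardGatherBackward(dim=1, grad_scale="down", group=sp_group)
gather = GatherFowardSplitBackward(dim=1, grad_scale="up", group=sp_group)

x = ops.rand([2, 8, 16])  # (batch, sequence, hidden); sequence length must be divisible by the shard count
x_local = split(x)        # each rank keeps a sequence slice of length 8 // world_size
# ... per-rank decoder layers would run here ...
x_full = gather(x_local)  # the full sequence is reassembled before the final layer
print(x_local.shape, x_full.shape)
```

Inside the attention blocks the patch additionally re-shards with `ops.AlltoAll` (sequence shards are exchanged for head shards and back), so each rank attends over the full sequence with only a subset of heads; the sketch above only illustrates the outer split/gather wrappers.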
diff --git a/examples/moviegen/tests/parallel/test_llama_sequence_parallel.py b/examples/moviegen/tests/parallel/test_llama_sequence_parallel.py new file mode 100644 index 0000000000..afbb963779 --- /dev/null +++ b/examples/moviegen/tests/parallel/test_llama_sequence_parallel.py @@ -0,0 +1,110 @@ +import argparse +from typing import Tuple + +import numpy as np +from mg.acceleration import create_parallel_group, get_sequence_parallel_group +from mg.models.llama.network import LlamaModel + +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore import Tensor +from mindspore.communication import get_group_size, init + + +class MeanNet(nn.Cell): + def __init__(self, net: nn.Cell) -> None: + super().__init__() + self.net = net + + def construct(self, *inputs): + output = self.net(*inputs) + return output.mean() * 1024.0 + + +def get_sample_data(dtype: ms.Type = ms.float32) -> Tuple[Tensor, ...]: + latent_embedding = ops.rand([1, 16, 8, 24, 44], dtype=dtype) + timestep = ms.Tensor([35], dtype=ms.int64) + ul2_emb = ops.rand([1, 64, 4096], dtype=dtype) + metaclip_emb = ops.rand([1, 64, 1280], dtype=dtype) + byt5_emb = ops.rand([1, 64, 1472], dtype=dtype) + return latent_embedding, timestep, ul2_emb, metaclip_emb, byt5_emb + + +def get_network_config(): + config = dict(num_hidden_layers=1, attn_implementation="eager", post_init_weight=False) + return config + + +def run_network(mode: int = 0, dtype: ms.Type = ms.float32): + ms.set_context(mode=mode) + init() + + # prepare data + ms.set_seed(1024) + data = get_sample_data(dtype=dtype) + + run_parallel_network(data, dtype=dtype) + + +def run_parallel_network(data: Tuple[Tensor, ...], dtype: ms.Type = ms.float32): + # non parallel network + ms.set_seed(1024) + non_parallel_network_cfg = get_network_config() + non_parallel_network = LlamaModel(**non_parallel_network_cfg, dtype=dtype) + + # parallel netowrk + ms.set_seed(1024) + create_parallel_group(sequence_parallel_shards=get_group_size()) + parallel_network_cfg = get_network_config() + parallel_network = LlamaModel(**parallel_network_cfg, dtype=dtype) + + # load weight + for (_, w0), (_, w1) in zip(non_parallel_network.parameters_and_names(), parallel_network.parameters_and_names()): + w1.set_data(w0) # FIXME: seed does not work + np.testing.assert_allclose(w0.value().asnumpy(), w1.value().asnumpy()) + + # test forward + non_parallel_out = non_parallel_network(*data).asnumpy() + parallel_out = parallel_network(*data).asnumpy() + + assert np.count_nonzero(non_parallel_out) > 0 + np.testing.assert_equal(non_parallel_out.shape, parallel_out.shape) + np.testing.assert_allclose(non_parallel_out, parallel_out, rtol=1.3e-6, atol=1e-5) + print("Test 1 (Forward): Passed.", flush=True) + + # test backward + non_parallel_mean_net = MeanNet(non_parallel_network) + parallel_mean_net = MeanNet(parallel_network) + + # check the parameter gradient + grad_fn = ms.grad(non_parallel_mean_net, grad_position=None, weights=non_parallel_mean_net.trainable_params()) + non_parallel_grads = grad_fn(*data) + + grad_fn = ms.grad(parallel_mean_net, grad_position=None, weights=parallel_mean_net.trainable_params()) + parallel_grads = grad_fn(*data) + + # take mean around different ranks + sp_group = get_sequence_parallel_group() + reduce = ops.AllReduce(op=ops.ReduceOp.SUM, group=sp_group) + num = get_group_size() + syn_parallel_grads = list() + for x in parallel_grads: + syn_parallel_grads.append(reduce(x) / num) + + pass_grads = [] + for grad_0, grad_1 in zip(non_parallel_grads, 
+        is_passed = np.allclose(grad_0.asnumpy(), grad_1.asnumpy(), rtol=1.3e-6, atol=1e-5)
+        pass_grads.append(is_passed)
+    assert all(pass_grads), f"Pass rate ({sum(pass_grads)/len(pass_grads) * 100:.3f} %) is not 100 %"
+
+    print("Test 2 (Backward: Parameter Gradient): Passed.", flush=True)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--mode", default=0, type=int, choices=[0, 1], help="Mode to test. (0: Graph Mode; 1: Pynative mode)"
+    )
+    args = parser.parse_args()
+    run_network(mode=args.mode)
diff --git a/examples/moviegen/train.py b/examples/moviegen/train.py
index 1f79f2a373..365045ee2c 100644
--- a/examples/moviegen/train.py
+++ b/examples/moviegen/train.py
@@ -16,6 +16,7 @@
 mindone_lib_path = os.path.abspath(os.path.join(__dir__, "../../"))
 sys.path.append(mindone_lib_path)
 
+from mg.acceleration import create_parallel_group
 from mg.dataset import ImageVideoDataset, bucket_split_function
 from mg.models.tae import TemporalAutoencoder
 from mg.pipelines import DiffusionWithLoss
@@ -41,6 +42,7 @@ def initialize_dataset(
     )
     dataloader_args = dataloader_args.as_dict()
     batch_size = dataloader_args.pop("batch_size")
+    logger.info(f"Initializing the dataloader with shard id `{shard_rank_id}` out of `{device_num}` total shards.")
     dataloader = create_dataloader(
         dataset,
         batch_size=batch_size if isinstance(batch_size, int) else 0,  # Turn off batching if using buckets
@@ -74,10 +76,10 @@ def main(args):
 
     # 1.1 init model parallel
     shard_rank_id = rank_id
-    # if (shards := args.train.model_parallel.model_parallel_shards) > 1:
-    #     create_parallel_group(**args.train.model_parallel)
-    #     device_num = device_num // shards
-    #     shard_rank_id = rank_id // shards
+    if (shards := args.train.parallel.sequence_parallel_shards) > 1:
+        create_parallel_group(**args.train.parallel)
+        device_num = device_num // shards
+        shard_rank_id = rank_id // shards
 
     # FIXME: Improve seed setting
     set_seed(args.env.seed + shard_rank_id)  # set different seeds per NPU for sampling different timesteps
@@ -280,6 +282,7 @@ def main(args):
         "--dataloader.batch_size", default=1, type=Union[int, Dict[str, int]], help="Number of samples per batch"
     )
     parser.link_arguments("env.debug", "dataloader.debug", apply_on="parse")
+    parser.add_function_arguments(create_parallel_group, "train.parallel")
     parser.add_function_arguments(create_scheduler, "train.lr_scheduler", skip={"steps_per_epoch", "num_epochs"})
     parser.add_class_arguments(
         ReduceLROnPlateauByStep, "train.lr_reduce_on_plateau", skip={"optimizer"}, instantiate=False

From 302df959f7af0404bce22d86f070090d388cf8f3 Mon Sep 17 00:00:00 2001
From: Nguyen Truong Hai <47595486+itruonghai@users.noreply.github.com>
Date: Fri, 20 Dec 2024 09:52:42 +0800
Subject: [PATCH 098/122] Gradio demo for MovieGen (#6)

---
 examples/moviegen/README.md      |  25 ++++
 examples/moviegen/gradio_demo.py | 228 +++++++++++++++++++++++++++++++
 2 files changed, 253 insertions(+)
 create mode 100644 examples/moviegen/gradio_demo.py

diff --git a/examples/moviegen/README.md b/examples/moviegen/README.md
index 13bb66d73e..4047d8ef6a 100644
--- a/examples/moviegen/README.md
+++ b/examples/moviegen/README.md
@@ -128,6 +128,31 @@ python inference.py \
   --save_format mp4
 ```
 
+### Gradio Demo
+To launch the web demo, follow these steps:
+
+1. Install Gradio:
+```bash
+pip install gradio
+```
+
+2. Run the demo script with the following configuration. The demo provides 80 pre-computed text prompts to choose from:
+
+```shell
+python gradio_demo.py \
+--config configs/inference/moviegen_t2i_256x256.yaml \
+--model.name llama-5B \
+--model.pretrained_model_path /path/to/llama-5B.ckpt \
+--text_emb.ul2_dir /path/to/ul2-embedding.ckpt \
+--text_emb.metaclip_dir /path/to/metaclip-embedding.ckpt \
+--text_emb.byt5_dir /path/to/byt5-embedding.ckpt \
+--image_size 256 455 \
+--num_frames 32 \
+--save_format mp4
+```
+Note: Make sure to replace the `/path/to/` placeholders with your actual model and embedding paths.
+
+
 ## Training
 
 Movie Gen is trained jointly on images and videos in 4 stages:
diff --git a/examples/moviegen/gradio_demo.py b/examples/moviegen/gradio_demo.py
new file mode 100644
index 0000000000..99ad02db43
--- /dev/null
+++ b/examples/moviegen/gradio_demo.py
@@ -0,0 +1,228 @@
+import datetime
+import glob
+import logging
+import os
+import time
+from typing import List, Tuple
+
+import gradio as gr
+import numpy as np
+from jsonargparse import ActionConfigFile, ArgumentParser
+from jsonargparse.typing import path_type
+from mg.models.tae import TemporalAutoencoder
+from mg.pipelines import InferPipeline
+from mg.utils import MODEL_DTYPE, init_model, to_numpy
+
+import mindspore as ms
+from mindspore import amp, nn
+
+from mindone.utils import init_train_env, set_logger
+from mindone.visualize import save_videos
+
+logger = logging.getLogger(__name__)
+
+
+def prepare_captions(
+    ul2_dir: str, metaclip_dir: str, byt5_dir: str, rank_id: int = 0, device_num: int = 1
+) -> Tuple[List[str], List[str], List[str]]:
+    """Prepare caption embeddings from specified directories"""
+    ul2_emb = sorted(glob.glob(os.path.join(ul2_dir, "*.npz")))
+    metaclip_emb = sorted(glob.glob(os.path.join(metaclip_dir, "*.npz")))
+    byt5_emb = sorted(glob.glob(os.path.join(byt5_dir, "*.npz")))
+
+    if len(ul2_emb) != len(byt5_emb):
+        raise ValueError(
+            f"ul2_dir ({len(ul2_emb)}), metaclip_dir ({len(metaclip_emb)}), "
+            f" and byt5_dir ({len(byt5_emb)}) must contain the same number of files"
+        )
+
+    ul2_emb = ul2_emb[rank_id::device_num]
+    logger.info(f"Number of captions for rank {rank_id}: {len(ul2_emb)}")
+    return ul2_emb, metaclip_emb[rank_id::device_num], byt5_emb[rank_id::device_num]
+
+
+def load_embeddings(selected_prompts: List[str], args) -> Tuple[ms.Tensor, ms.Tensor, ms.Tensor]:
+    """Load embeddings for selected prompts matching original implementation"""
+    # Get full paths for selected prompts
+    # print(selected_prompts)
+    ul2_files = os.path.join(args.text_emb.ul2_dir, f"{selected_prompts}.npz")
+    byt5_files = os.path.join(args.text_emb.byt5_dir, f"{selected_prompts}.npz")
+
+    # Load embeddings in batch
+    ul2_emb = ms.Tensor(np.load(ul2_files)["text_emb"], dtype=ms.float32)
+    byt5_emb = ms.Tensor(np.load(byt5_files)["text_emb"], dtype=ms.float32)
+    ul2_emb = ul2_emb.unsqueeze(0)
+    byt5_emb = byt5_emb.unsqueeze(0)
+
+    # Create placeholder metaclip embedding matching batch size
+    metaclip_emb = ms.Tensor(np.ones((ul2_emb.shape[0], 300, 1280)), dtype=ms.float32)
+    return ul2_emb, metaclip_emb, byt5_emb
+
+
+def init_models(args):
+    """Initialize MovieGen models with specified configurations"""
+    # Initialize TAE
+    logger.info("Initializing TAE...")
+    tae_args = args.tae.as_dict()
+    tae_dtype = tae_args.pop("dtype")
+    tae = TemporalAutoencoder(**tae_args).set_train(False)
+
+    if tae_dtype != "fp32":
+        amp.custom_mixed_precision(
+            tae,
+            black_list=amp.get_black_list() + [nn.GroupNorm, nn.AvgPool2d, nn.Upsample],
+            dtype=MODEL_DTYPE[tae_dtype],
+        )
+
+    #
Initialize Transformer model + logger.info("Initializing Transformer model...") + model = init_model(in_channels=tae.out_channels, **args.model).set_train(False) + + return model, tae + + +def create_pipeline(model, tae, args): + """Create MovieGen inference pipeline""" + img_h, img_w = args.image_size if isinstance(args.image_size, list) else (args.image_size, args.image_size) + latent_size = tae.get_latent_size((args.num_frames, img_h, img_w)) + + return InferPipeline( + model, + tae, + latent_size, + scale_factor=args.scale_factor, + guidance_scale=args.guidance_scale, + num_sampling_steps=args.num_sampling_steps, + sample_method=args.sample_method, + micro_batch_size=args.micro_batch_size, + ) + + +def generate_video(selected_prompts: List[str], args, pipeline, progress=gr.Progress()) -> List[str]: + """Generate videos for selected prompts""" + progress(0.1, "Loading embeddings...") + ul2_emb, metaclip_emb, byt5_emb = load_embeddings(selected_prompts, args) + + progress(0.2, "Generating videos...") + start_time = time.perf_counter() + sample, latent = pipeline( + ul2_emb=ul2_emb, + metaclip_emb=metaclip_emb, + byt5_emb=byt5_emb, + num_frames=args.num_frames, + ) + # import pdb + # pdb.set_trace() + generation_time = time.perf_counter() - start_time + + progress(0.8, "Saving videos...") + save_dir = os.path.join(args.output_path, "gradio_samples") + if args.append_timestamp: + time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") + save_dir = os.path.join(save_dir, time_str) + os.makedirs(save_dir, exist_ok=True) + + output_files = [] + # for i, prompt in enumerate(selected_prompts): + output_file = os.path.join(save_dir, f"{selected_prompts}.{args.save_format}") + save_videos(to_numpy(sample[0]), output_file, fps=args.fps) + output_files.append(output_file) + + logger.info( + f"Videos generated in {generation_time: .2f}s " + f"({args.num_sampling_steps * len(selected_prompts) / generation_time: .2f} steps/s)" + ) + + return output_files + + +def create_demo(args): + """Create and configure Gradio interface""" + # Initialize models and pipeline + model, tae = init_models(args) + pipeline = create_pipeline(model, tae, args) + + # Get available prompts + ul2_emb, _, _ = prepare_captions(**args.text_emb) + prompts = [os.path.basename(p)[:-4] for p in ul2_emb] + + # Create Gradio interface + with gr.Blocks() as demo: + gr.Markdown("# MovieGen Video Generation Demo") + gr.Markdown(f"Model: {args.model.name}") + + with gr.Row(): + with gr.Column(): + prompt = gr.Dropdown( + choices=prompts, + label="Select Pre-computed Prompt", + info="Choose from available pre-computed prompts", + ) + generate_btn = gr.Button("Generate Video", variant="primary") + + with gr.Column(): + video_output = gr.Video(label="Generated Video") + info_box = gr.Textbox(label="Generation Info", interactive=False) + + def generate_and_log(prompt_name): + print("Prompt name ", prompt_name) + output_file = generate_video(prompt_name, args, pipeline) + info = f"Successfully generated video for prompt: {prompt_name}" + return output_file[0], info + + generate_btn.click( + fn=generate_and_log, + inputs=[prompt], + outputs=[video_output, info_box], + ) + + return demo + + +if __name__ == "__main__": + parser = ArgumentParser(description="MovieGen Gradio demo") + parser.add_argument( + "-c", + "--config", + action=ActionConfigFile, + help="Path to MovieGen config file", + ) + + # Add all necessary arguments + parser.add_function_arguments(init_train_env, "env") + parser.add_function_arguments(init_model, "model", 
skip={"in_channels"}) + + # TAE parameters + tae_group = parser.add_argument_group("TAE parameters") + tae_group.add_class_arguments(TemporalAutoencoder, "tae", instantiate=False) + tae_group.add_argument("--tae.dtype", default="fp32", type=str, choices=["fp32", "fp16", "bf16"]) + + # Inference parameters + infer_group = parser.add_argument_group("Inference parameters") + infer_group.add_class_arguments(InferPipeline, skip={"model", "tae", "latent_size"}, instantiate=False) + infer_group.add_argument("--image_size", type=int, nargs="+", default=[256, 455]) + infer_group.add_argument("--num_frames", type=int, default=32) + infer_group.add_argument("--fps", type=int, default=16) + infer_group.add_function_arguments(prepare_captions, "text_emb", skip={"rank_id", "device_num"}) + infer_group.add_argument("--batch_size", type=int, default=2) + + # Save options + save_group = parser.add_argument_group("Saving options") + save_group.add_argument("--save_format", default="mp4", choices=["gif", "mp4", "png"]) + save_group.add_argument("--output_path", default="output/", type=path_type("dcc")) + save_group.add_argument("--append_timestamp", type=bool, default=True) + save_group.add_argument( + "--save_latent", + type=bool, + default=False, + help="Save denoised video latent. If True, the denoised latents will be saved in $output_path/denoised_latents", + ) + args = parser.parse_args() + + # Set up logging + os.makedirs(os.path.join(args.output_path, "logs"), exist_ok=True) + set_logger(name="", output_dir=os.path.join(args.output_path, "logs")) + + # Create and launch demo + demo = create_demo(args) + demo.launch() From 7ee81181ee318a671970022eb362fd95adf2289a Mon Sep 17 00:00:00 2001 From: Rustam Khadipash <16683750+hadipash@users.noreply.github.com> Date: Fri, 20 Dec 2024 15:26:04 +0800 Subject: [PATCH 099/122] update docs and add stage 3 configs --- examples/moviegen/README.md | 17 ++-- ...t2i_256x256.yaml => stage1_t2i_256px.yaml} | 4 +- ...iv_256x256.yaml => stage2_t2iv_256px.yaml} | 6 +- .../configs/train/stage3_t2iv_768px.yaml | 85 +++++++++++++++++++ examples/moviegen/scripts/stage1_train.sh | 4 +- examples/moviegen/scripts/stage2_train.sh | 6 +- examples/moviegen/scripts/stage3_train.sh | 25 ++++++ 7 files changed, 129 insertions(+), 18 deletions(-) rename examples/moviegen/configs/train/{stage1_t2i_256x256.yaml => stage1_t2i_256px.yaml} (94%) rename examples/moviegen/configs/train/{stage2_t2iv_256x256.yaml => stage2_t2iv_256px.yaml} (91%) create mode 100644 examples/moviegen/configs/train/stage3_t2iv_768px.yaml create mode 100644 examples/moviegen/scripts/stage3_train.sh diff --git a/examples/moviegen/README.md b/examples/moviegen/README.md index 78077f6e44..acf9d11a78 100644 --- a/examples/moviegen/README.md +++ b/examples/moviegen/README.md @@ -173,6 +173,7 @@ To train Movie Gen, run the following commands: ```shell scripts/stage1_train.sh # for stage 1 training scripts/stage2_train.sh # for stage 2 training +scripts/stage3_train.sh # for stage 3 training (currently under verification) ``` ### Dataset Preparation @@ -214,18 +215,18 @@ Experiments were conducted on Ascend 910* using MindSpore 2.3.1 in Graph mode. > [!NOTE] > We trained all the models using BF16 precision. 
-| Model | NPUs | Stage | Batch size | Resolution | Jit level | Compile time | Recompute | Gradient Acc | TAE Cache | Time (s/step) | Config |
-|:-----:|:----:|:---------:|:-----------------------:|:-----------------------:|:---------:|:------------:|:-----------------------:|:------------:|:---------:|:-------------:|:------------------------------------------------------------------:|
-| 5B | 8 | 1 (T2I) | 10 | 256x455 | O1 | 3m 40s | ON | 1 | Yes | 1.29 | [stage1_t2i_256x256.yaml](configs/train/stage1_t2i_256x256.yaml) |
-| 5B | 8 | 2 (T2I/V) | Image: 1<br>Video: 1 | 256x455<br>256 frames | O1 | 6m | ON<br>(Every 2 blocks) | 5 | Yes | 5.09 | [stage2_t2iv_256x256.yaml](configs/train/stage2_t2iv_256x256.yaml) |
-| 5B | 8 | 3 (T2I/V) | Image: 1<br>Video: 1 | 576x1024<br>256 frames | O1 | 7m 30s | ON | 5 | Yes | 88.5 | [stage3_t2iv_768px.yaml](configs/train/stage2_t2iv_256x256.yaml) |
-| 1B | 8 | 1 (T2I) | 10 | 256x455 | O1 | 2m 15s | ON | 1 | Yes | 0.53 | [stage1_t2i_256x256.yaml](configs/train/stage1_t2i_256x256.yaml) |
-| 1B | 8 | 2 (T2I/V) | Image: 10<br>Video: 10 | 256x455<br>32 frames | O0 | 1m 55s | ON | 1 | Yes | 2.07 | [stage2_t2iv_256x256.yaml](configs/train/stage2_t2iv_256x256.yaml) |
+| Model | Cards | Stage | Batch size | Resolution | Jit level | Compile time | Recompute | Gradient Acc | TAE Cache | Time (s/step) | Config |
+|:-----:|:-----:|:---------:|:-----------------------:|:-----------------------:|:---------:|:------------:|:-----------------------:|:------------:|:---------:|:-------------:|:--------------------------------------------------------------:|
+| 5B | 8 | 1 (T2I) | 10 | 256x455 | O1 | 3m 40s | ON | 1 | Yes | 1.29 | [stage1_t2i_256px.yaml](configs/train/stage1_t2i_256px.yaml) |
+| 5B | 8 | 2 (T2I/V) | Image: 1<br>Video: 1 | 256x455<br>256 frames | O1 | 6m | ON<br>(Every 2 blocks) | 5 | Yes | 5.09 | [stage2_t2iv_256px.yaml](configs/train/stage2_t2iv_256px.yaml) |
+| 5B | 8 | 3 (T2I/V) | Image: 1<br>Video: 1 | 576x1024<br>256 frames | O1 | 7m 30s | ON | 5 | Yes | 88.5 | [stage3_t2iv_768px.yaml](configs/train/stage3_t2iv_768px.yaml) |
+| 1B | 8 | 1 (T2I) | 10 | 256x455 | O1 | 2m 15s | ON | 1 | Yes | 0.53 | [stage1_t2i_256px.yaml](configs/train/stage1_t2i_256px.yaml) |
+| 1B | 8 | 2 (T2I/V) | Image: 10<br>Video: 10 | 256x455<br>
32 frames | O0 | 1m 55s | ON | 1 | Yes | 2.07 | [stage2_t2iv_256px.yaml](configs/train/stage2_t2iv_256px.yaml) | ### Validation During Training Validation can be enabled by either setting parameters in the `valid` field of the configuration file -([example](configs/train/stage1_t2i_256x256.yaml)) or by supplying the following arguments to `train.py`: +([example](configs/train/stage1_t2i_256px.yaml)) or by supplying the following arguments to `train.py`: ```shell --valid.sampling_steps 10 \ diff --git a/examples/moviegen/configs/train/stage1_t2i_256x256.yaml b/examples/moviegen/configs/train/stage1_t2i_256px.yaml similarity index 94% rename from examples/moviegen/configs/train/stage1_t2i_256x256.yaml rename to examples/moviegen/configs/train/stage1_t2i_256px.yaml index cad6a1bc95..4c140a7e2f 100644 --- a/examples/moviegen/configs/train/stage1_t2i_256x256.yaml +++ b/examples/moviegen/configs/train/stage1_t2i_256px.yaml @@ -27,7 +27,7 @@ dataset: ul2: EMPTY_TEXT_EMB byt5: EMPTY_TEXT_EMB text_drop_prob: 0.2 - target_size: [ 256, 256 ] + target_size: [ 256, 455 ] apply_transforms_dataset: True output_columns: ["video", "ul2_caption", "byt5_caption"] @@ -38,7 +38,7 @@ dataloader: train: steps: 30000 - output_path: ../../output/stage1_t2i_256x256 # the path is relative to this config + output_path: ../../output/stage1_t2i_256px # the path is relative to this config lr_scheduler: name: constant diff --git a/examples/moviegen/configs/train/stage2_t2iv_256x256.yaml b/examples/moviegen/configs/train/stage2_t2iv_256px.yaml similarity index 91% rename from examples/moviegen/configs/train/stage2_t2iv_256x256.yaml rename to examples/moviegen/configs/train/stage2_t2iv_256px.yaml index 844a0f1493..b07c32af2d 100644 --- a/examples/moviegen/configs/train/stage2_t2iv_256x256.yaml +++ b/examples/moviegen/configs/train/stage2_t2iv_256px.yaml @@ -27,21 +27,21 @@ dataset: ul2: EMPTY_TEXT_EMB byt5: EMPTY_TEXT_EMB text_drop_prob: 0.2 - target_size: [ 256, 256 ] + target_size: [ 256, 455 ] sample_n_frames: 256 # FIXME: add variable frames support. apply_transforms_dataset: True output_columns: ["video", "ul2_caption", "byt5_caption"] dataloader: batch_size: - image_batch_size: 70 + image_batch_size: 1 video_batch_size: 1 shuffle: True num_workers_dataset: 4 train: steps: 20000 - output_path: ../../output/stage2_t2iv_256x256 # the path is relative to this config + output_path: ../../output/stage2_t2iv_256px # the path is relative to this config lr_scheduler: name: constant diff --git a/examples/moviegen/configs/train/stage3_t2iv_768px.yaml b/examples/moviegen/configs/train/stage3_t2iv_768px.yaml new file mode 100644 index 0000000000..a3e12bbb15 --- /dev/null +++ b/examples/moviegen/configs/train/stage3_t2iv_768px.yaml @@ -0,0 +1,85 @@ +env: + mode: 0 + jit_level: O0 + seed: 42 + distributed: False + debug: False + +model: + name: llama-5B + pretrained_model_path: + enable_flash_attention: True + recompute_every_nth_block: 1 + dtype: bf16 + +tae: + pretrained: "" + use_tile: True + dtype: bf16 + +dataset: + csv_path: CSV_PATH + video_folder: VIDEO_FOLDER + text_emb_folder: + ul2: UL2_FOLDER + byt5: BYT5_FOLDER + empty_text_emb: + ul2: EMPTY_TEXT_EMB + byt5: EMPTY_TEXT_EMB + text_drop_prob: 0.2 + target_size: [ 576, 1024 ] + sample_n_frames: 256 # FIXME: add variable frames support. 
+  apply_transforms_dataset: True
+  output_columns: ["video", "ul2_caption", "byt5_caption"]
+
+dataloader:
+  batch_size:
+    image_batch_size: 1
+    video_batch_size: 1
+  shuffle: True
+  num_workers_dataset: 4
+
+train:
+  steps: 20000
+  output_path: ../../output/stage3_t2iv_768px # the path is relative to this config
+
+  lr_scheduler:
+    name: constant
+    lr: 6.0e-5
+    warmup_steps: 1000
+
+  lr_reduce_on_plateau:
+    factor: 0.5
+    patience: 50 # in the number of validation steps, i.e., valid.frequency * patience steps
+    mode: min
+    min_delta: 0.01
+    min_lr: 1.0e-6
+
+  optimizer:
+    name: adamw_re
+    eps: 1e-15
+    betas: [0.9, 0.999]
+    weight_decay: 0.1
+
+  loss_scaler:
+    class_path: mindspore.nn.FixedLossScaleUpdateCell # or DynamicLossScaleUpdateCell in FP16
+    init_args:
+      loss_scale_value: 1
+
+  ema:
+    ema_decay: 0.9999
+    offloading: True
+
+  settings:
+    zero_stage: 0
+    gradient_accumulation_steps: 1
+    clip_grad: True
+    clip_norm: 1.0
+
+  save:
+    ckpt_save_policy: latest_k
+    ckpt_save_interval: &save_interval 100
+    ckpt_max_keep: 10
+    log_interval: 1
+    save_ema_only: False
+    record_lr: False
diff --git a/examples/moviegen/scripts/stage1_train.sh b/examples/moviegen/scripts/stage1_train.sh
index ae1402acde..ebdc19901d 100644
--- a/examples/moviegen/scripts/stage1_train.sh
+++ b/examples/moviegen/scripts/stage1_train.sh
@@ -5,11 +5,11 @@ export MS_MEMORY_STATISTIC=0
 # log level
 export GLOG_v=2
 
-output_dir=output/stage1_t2i_256x256/$(date +"%Y.%m.%d-%H.%M.%S")
+output_dir=output/stage1_t2i_256px/$(date +"%Y.%m.%d-%H.%M.%S")
 
 msrun --bind_core=True --master_port=8200 --worker_num=8 --local_worker_num=8 --log_dir="$output_dir" \
 python train.py \
-  --config configs/train/stage1_t2i_256x256.yaml \
+  --config configs/train/stage1_t2i_256px.yaml \
   --env.mode 0 \
   --env.jit_level O1 \
   --env.max_device_memory 59GB \
diff --git a/examples/moviegen/scripts/stage2_train.sh b/examples/moviegen/scripts/stage2_train.sh
index 5182f302d2..5f047378fa 100644
--- a/examples/moviegen/scripts/stage2_train.sh
+++ b/examples/moviegen/scripts/stage2_train.sh
@@ -8,13 +8,13 @@ export MS_DEV_ENABLE_KERNEL_PACKET=on
 # log level
 export GLOG_v=2
 
-output_dir=output/stage2_t2iv_256x256/$(date +"%Y.%m.%d-%H.%M.%S")
+output_dir=output/stage2_t2iv_256px/$(date +"%Y.%m.%d-%H.%M.%S")
 
 msrun --bind_core=True --master_port=8200 --worker_num=8 --local_worker_num=8 --log_dir="$output_dir" \
 python train.py \
-  --config configs/train/stage2_t2iv_256x256.yaml \
+  --config configs/train/stage2_t2iv_256px.yaml \
   --env.mode 0 \
-  --env.jit_level O0 \
+  --env.jit_level O1 \
   --env.max_device_memory 59GB \
   --env.distributed True \
   --train.settings.zero_stage 2 \
diff --git a/examples/moviegen/scripts/stage3_train.sh b/examples/moviegen/scripts/stage3_train.sh
new file mode 100644
index 0000000000..de84c1c119
--- /dev/null
+++ b/examples/moviegen/scripts/stage3_train.sh
@@ -0,0 +1,25 @@
+export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+# plot memory usage, feature/model: 1
+export MS_MEMORY_STATISTIC=0
+
+# operation/graph fusion for dynamic shape
+export MS_DEV_ENABLE_KERNEL_PACKET=on
+
+# log level
+export GLOG_v=2
+
+output_dir=output/stage3_t2iv_768px/$(date +"%Y.%m.%d-%H.%M.%S")
+
+msrun --bind_core=True --master_port=8200 --worker_num=8 --local_worker_num=8 --log_dir="$output_dir" \
+python train.py \
+  --config configs/train/stage3_t2iv_768px.yaml \
+  --env.mode 0 \
+  --env.jit_level O1 \
+  --env.max_device_memory 59GB \
+  --env.distributed True \
+  --train.settings.zero_stage 2 \
+  --dataset.csv_path CSV_PATH \
+  --dataset.video_folder
VIDEO_FOLDER \ + --dataset.text_emb_folder.ul2 UL2_FOLDER \ + --dataset.text_emb_folder.byt5 BYT5_FOLDER \ + --train.output_path "$output_dir" From 06ce9b4046103ce6639875270fb5c11f212fba1a Mon Sep 17 00:00:00 2001 From: Rustam Khadipash <16683750+hadipash@users.noreply.github.com> Date: Tue, 24 Dec 2024 09:54:00 +0800 Subject: [PATCH 100/122] add ZeRO-3 support to Movie Gen --- mindone/models/modules/parallel/__init__.py | 8 +++++--- mindone/models/modules/parallel/dense.py | 21 ++++++++++++++++++--- 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/mindone/models/modules/parallel/__init__.py b/mindone/models/modules/parallel/__init__.py index e3f77a9537..101c1a958a 100644 --- a/mindone/models/modules/parallel/__init__.py +++ b/mindone/models/modules/parallel/__init__.py @@ -1,7 +1,7 @@ -from mindspore import nn +from mindspore import mint, nn from .conv import Conv1d, Conv2d, Conv3d -from .dense import Dense +from .dense import Dense, Linear # {Original MindSpore Cell: New Cell in ZeRO3} PARALLEL_MODULES = { @@ -9,5 +9,7 @@ nn.Conv2d: Conv2d, nn.Conv3d: Conv3d, nn.Dense: Dense, + mint.nn.Linear: Linear, } -__all__ = ["Conv1d", "Conv2d", "Conv3d", "Dense"] + +__all__ = ["Conv1d", "Conv2d", "Conv3d", "Dense", "Linear"] diff --git a/mindone/models/modules/parallel/dense.py b/mindone/models/modules/parallel/dense.py index 66ef7fef71..4db47ede8a 100644 --- a/mindone/models/modules/parallel/dense.py +++ b/mindone/models/modules/parallel/dense.py @@ -1,4 +1,8 @@ -from mindspore import nn, ops +from typing import Literal, Optional, Union + +from mindspore import Tensor +from mindspore import dtype as mstype +from mindspore import mint, nn, ops from mindspore.communication import get_group_size, get_rank from mindspore.communication.management import GlobalComm from mindspore.context import ParallelMode @@ -8,8 +12,14 @@ class Dense(nn.Cell): - def __init__(self, net, zero_stage: int = 0, op_group: str = GlobalComm.WORLD_COMM_GROUP, cell_type=None): - super(Dense, self).__init__(auto_prefix=False) + def __init__( + self, + net: Union[nn.Dense, mint.nn.Linear], + zero_stage: Literal[0, 1, 2, 3] = 0, + op_group: str = GlobalComm.WORLD_COMM_GROUP, + cell_type: Optional[mstype] = None, + ): + super().__init__(auto_prefix=False) self.net = net self.set_param_wrapper(zero_stage, op_group, cell_type) @@ -43,3 +53,8 @@ def construct(self, x): out_shape = x_shape[:-1] + (x.shape[-1],) x = x.reshape(out_shape) return x + + +class Linear(Dense): + def construct(self, x: Tensor) -> Tensor: + return self.net.dense(x, self.param_wrapper_w(self.net.weight), self.param_wrapper_b(self.net.bias)) From fbe4e31095b9b42952ffbc837c8fa0a4a3444834 Mon Sep 17 00:00:00 2001 From: Rustam Khadipash <16683750+hadipash@users.noreply.github.com> Date: Tue, 24 Dec 2024 11:43:40 +0800 Subject: [PATCH 101/122] add Model Parallel --- examples/moviegen/README.md | 1 + examples/moviegen/mg/models/llama/block.py | 217 +++++++++- examples/moviegen/mg/models/llama/network.py | 160 ++++++- examples/moviegen/mg/parallel/__init__.py | 2 + examples/moviegen/mg/parallel/layers.py | 398 ++++++++++++++++++ .../moviegen/mg/parallel/parallel_states.py | 38 ++ .../moviegen/mg/schedulers/rectified_flow.py | 9 +- examples/moviegen/scripts/30B_stage2_train.sh | 30 ++ .../parallel/run_test_llama3_parallel.sh | 13 + .../run_test_llama3_parallel_block.sh | 13 + .../run_test_llama3_parallel_layer.sh | 13 + .../tests/parallel/run_test_rflow_parallel.sh | 13 + .../tests/parallel/test_llama3_parallel.py | 113 +++++ 
.../parallel/test_llama3_parallel_block.py | 107 +++++ .../parallel/test_llama3_parallel_layer.py | 125 ++++++ .../tests/parallel/test_rflow_parallel.py | 61 +++ examples/moviegen/tests/parallel/utils.py | 32 ++ examples/moviegen/train.py | 10 +- mindone/models/modules/parallel/dense.py | 2 +- 19 files changed, 1347 insertions(+), 10 deletions(-) create mode 100644 examples/moviegen/mg/parallel/__init__.py create mode 100644 examples/moviegen/mg/parallel/layers.py create mode 100644 examples/moviegen/mg/parallel/parallel_states.py create mode 100644 examples/moviegen/scripts/30B_stage2_train.sh create mode 100644 examples/moviegen/tests/parallel/run_test_llama3_parallel.sh create mode 100644 examples/moviegen/tests/parallel/run_test_llama3_parallel_block.sh create mode 100644 examples/moviegen/tests/parallel/run_test_llama3_parallel_layer.sh create mode 100644 examples/moviegen/tests/parallel/run_test_rflow_parallel.sh create mode 100644 examples/moviegen/tests/parallel/test_llama3_parallel.py create mode 100644 examples/moviegen/tests/parallel/test_llama3_parallel_block.py create mode 100644 examples/moviegen/tests/parallel/test_llama3_parallel_layer.py create mode 100644 examples/moviegen/tests/parallel/test_rflow_parallel.py create mode 100644 examples/moviegen/tests/parallel/utils.py diff --git a/examples/moviegen/README.md b/examples/moviegen/README.md index acf9d11a78..5a95d96caf 100644 --- a/examples/moviegen/README.md +++ b/examples/moviegen/README.md @@ -217,6 +217,7 @@ Experiments were conducted on Ascend 910* using MindSpore 2.3.1 in Graph mode. | Model | Cards | Stage | Batch size | Resolution | Jit level | Compile time | Recompute | Gradient Acc | TAE Cache | Time (s/step) | Config | |:-----:|:-----:|:---------:|:-----------------------:|:-----------------------:|:---------:|:------------:|:-----------------------:|:------------:|:---------:|:-------------:|:--------------------------------------------------------------:| +| 30B | 8 | 2 (T2V) | Video: 1 | 256x256x455 | O1 | 6m | ON | 1 | Yes | 23.8 | [stage2_t2iv_256px.yaml](configs/train/stage2_t2iv_256px.yaml) | | 5B | 8 | 1 (T2I) | 10 | 256x455 | O1 | 3m 40s | ON | 1 | Yes | 1.29 | [stage1_t2i_256px.yaml](configs/train/stage1_t2i_256px.yaml) | | 5B | 8 | 2 (T2I/V) | Image: 1
Video: 1 | 256x455
256 frames | O1 | 6m | ON
(Every 2 blocks) | 5 | Yes | 5.09 | [stage2_t2iv_256px.yaml](configs/train/stage2_t2iv_256px.yaml) | | 5B | 8 | 3 (T2I/V) | Image: 1
Video: 1 | 576x1024
256 frames | O1 | 7m 30s | ON | 5 | Yes | 88.5 | [stage3_t2iv_768px.yaml](configs/train/stage3_t2iv_768px.yaml) | diff --git a/examples/moviegen/mg/models/llama/block.py b/examples/moviegen/mg/models/llama/block.py index b9645da307..703b4cef7d 100644 --- a/examples/moviegen/mg/models/llama/block.py +++ b/examples/moviegen/mg/models/llama/block.py @@ -1,7 +1,13 @@ -import logging from typing import Optional, Sequence, Tuple, Union import numpy as np +from mg.parallel import ( + ColumnParallelLinear, + FusedColumnParallelLinear, + FusedRowParallelLinear, + GatherForwardReduceScatterBackward, + RowParallelLinear, +) import mindspore as ms import mindspore.mint as mint @@ -9,12 +15,11 @@ import mindspore.nn as nn import mindspore.ops as ops from mindspore import Parameter, Tensor +from mindspore.communication import GlobalComm from mindspore.ops.operations.nn_ops import FlashAttentionScore from .activation import ACT2FN -logger = logging.getLogger(__name__) - class LlamaRMSNorm(nn.Cell): def __init__(self, hidden_size: Union[int, Sequence[int]], eps: float = 1e-6, dtype: ms.Type = ms.float32) -> None: @@ -51,6 +56,77 @@ def construct(self, hidden_state: Tensor) -> Tensor: return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) +class TensorParallelLlamaMLP(nn.Cell): + def __init__( + self, + intermediate_size: int = 8192, + hidden_size: int = 3072, + hidden_act: str = "silu", + group: str = GlobalComm.WORLD_COMM_GROUP, + dtype: ms.Type = ms.float32, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.gate_proj = ColumnParallelLinear( + self.hidden_size, self.intermediate_size, bias=False, gather_output=False, group=group, dtype=dtype + ) + self.up_proj = ColumnParallelLinear( + self.hidden_size, self.intermediate_size, bias=False, gather_output=False, group=group, dtype=dtype + ) + self.down_proj = RowParallelLinear( + self.intermediate_size, self.hidden_size, bias=False, input_is_parallel=True, group=group, dtype=dtype + ) + self.act_fn = ACT2FN[hidden_act] + + def construct(self, hidden_state: Tensor) -> Tensor: + return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) + + def load_weight_from_non_parallel_cell(self, target: LlamaMLP): + self.gate_proj.load_weight_from_non_parallel_cell(target.gate_proj) + self.up_proj.load_weight_from_non_parallel_cell(target.up_proj) + self.down_proj.load_weight_from_non_parallel_cell(target.down_proj) + + +class FusedTensorParallelLlamaMLP(nn.Cell): + def __init__( + self, + intermediate_size: int = 8192, + hidden_size: int = 3072, + hidden_act: str = "silu", + dim: int = 1, + group: str = GlobalComm.WORLD_COMM_GROUP, + dtype: ms.Type = ms.float32, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.gate_proj = FusedColumnParallelLinear( + self.hidden_size, self.intermediate_size, bias=False, gather_output=False, dim=dim, group=group, dtype=dtype + ) + self.up_proj = FusedColumnParallelLinear( + self.hidden_size, self.intermediate_size, bias=False, gather_output=False, dim=dim, group=group, dtype=dtype + ) + self.down_proj = FusedRowParallelLinear( + self.intermediate_size, + self.hidden_size, + bias=False, + input_is_parallel=True, + dim=dim, + group=group, + dtype=dtype, + ) + self.act_fn = ACT2FN[hidden_act] + + def construct(self, hidden_state: Tensor) -> Tensor: + return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * 
self.up_proj(hidden_state)) + + def load_weight_from_non_parallel_cell(self, target: LlamaMLP): + self.gate_proj.load_weight_from_non_parallel_cell(target.gate_proj) + self.up_proj.load_weight_from_non_parallel_cell(target.up_proj) + self.down_proj.load_weight_from_non_parallel_cell(target.down_proj) + + def repeat_kv(hidden_states: Tensor, n_rep: int) -> Tensor: if n_rep == 1: return hidden_states @@ -131,6 +207,81 @@ def construct(self, hidden_states: Tensor, encoder_hidden_states: Optional[Tenso return attn_output +class ContextParallelLlamaAttention(nn.Cell): + def __init__( + self, + hidden_size: int = 4096, + num_attention_heads: int = 32, + num_key_value_heads: int = 8, + attention_dropout: float = 0.0, + attention_bias: bool = False, + group: str = GlobalComm.WORLD_COMM_GROUP, + dtype: ms.Type = ms.float32, + ) -> None: + super().__init__() + self.attention_dropout = attention_dropout + self.hidden_size = hidden_size + self.num_heads = num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + self.q_proj = mint.nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=attention_bias, dtype=dtype) + self.k_proj = mint.nn.Linear( + self.hidden_size, self.num_key_value_heads * self.head_dim, bias=attention_bias, dtype=dtype + ) + self.v_proj = mint.nn.Linear( + self.hidden_size, self.num_key_value_heads * self.head_dim, bias=attention_bias, dtype=dtype + ) + self.o_proj = mint.nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=attention_bias, dtype=dtype) + + self.gather_forward_reduce_scatter_backward = GatherForwardReduceScatterBackward(dim=1, group=group) + + def construct(self, hidden_states: Tensor, encoder_hidden_states: Optional[Tensor] = None) -> Tensor: + bsz, q_len, _ = hidden_states.shape + + kv_hidden_states = hidden_states if encoder_hidden_states is None else encoder_hidden_states + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(kv_hidden_states) + value_states = self.v_proj(kv_hidden_states) + + key_states = self.gather_forward_reduce_scatter_backward(key_states) + value_states = self.gather_forward_reduce_scatter_backward(value_states) + + query_states = ops.reshape(query_states, (bsz, -1, self.num_heads, self.head_dim)) + query_states = mint.permute(query_states, (0, 2, 1, 3)) + + key_states = ops.reshape(key_states, (bsz, -1, self.num_key_value_heads, self.head_dim)) + key_states = mint.permute(key_states, (0, 2, 1, 3)) + + value_states = ops.reshape(value_states, (bsz, -1, self.num_key_value_heads, self.head_dim)) + value_states = mint.permute(value_states, (0, 2, 1, 3)) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + key_states = mint.permute(key_states, (0, 1, 3, 2)) + attn_weights = mint.matmul(query_states, key_states) / mint.sqrt(Tensor(self.head_dim)) + + # upcast attention to fp32 + attn_weights = attn_weights.to(ms.float32) + attn_weights = F.softmax(attn_weights, dim=-1).to(query_states.dtype) + attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = mint.matmul(attn_weights, value_states) + + attn_output = 
mint.permute(attn_output, (0, 2, 1, 3)) + attn_output = ops.reshape(attn_output, (bsz, q_len, -1)) + attn_output = self.o_proj(attn_output) + + return attn_output + + class LlamaFlashAttention(LlamaAttention): def __init__( self, @@ -186,6 +337,66 @@ def construct(self, hidden_states: Tensor, encoder_hidden_states: Optional[Tenso return attn_output +class ContextParallelLlamaFlashAttention(ContextParallelLlamaAttention): + def __init__( + self, + hidden_size: int = 4096, + num_attention_heads: int = 32, + num_key_value_heads: int = 8, + attention_dropout: float = 0.0, + attention_bias: bool = False, + group: str = GlobalComm.WORLD_COMM_GROUP, + dtype: ms.Type = ms.float32, + ) -> None: + super().__init__( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + attention_dropout=attention_dropout, + attention_bias=attention_bias, + group=group, + dtype=dtype, + ) + self.flash_attention = FlashAttentionScore( + self.num_heads, keep_prob=1 - self.attention_dropout, scale_value=self.head_dim**-0.5, input_layout="BSND" + ) + + def construct(self, hidden_states: Tensor, encoder_hidden_states: Optional[Tensor] = None) -> Tensor: + bsz, q_len, _ = hidden_states.shape + + kv_hidden_states = hidden_states if encoder_hidden_states is None else encoder_hidden_states + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(kv_hidden_states) + value_states = self.v_proj(kv_hidden_states) + + key_states = self.gather_forward_reduce_scatter_backward(key_states) + value_states = self.gather_forward_reduce_scatter_backward(value_states) + + query_states = ops.reshape(query_states, (bsz, -1, self.num_heads, self.head_dim)) + query_states = mint.permute(query_states, (0, 2, 1, 3)) + + key_states = ops.reshape(key_states, (bsz, -1, self.num_key_value_heads, self.head_dim)) + key_states = mint.permute(key_states, (0, 2, 1, 3)) + + value_states = ops.reshape(value_states, (bsz, -1, self.num_key_value_heads, self.head_dim)) + value_states = mint.permute(value_states, (0, 2, 1, 3)) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + # Reshape to the expected shape and dtype for Flash Attention + query_states = mint.permute(query_states, (0, 2, 1, 3)) + key_states = mint.permute(key_states, (0, 2, 1, 3)) + value_states = mint.permute(value_states, (0, 2, 1, 3)) + + _, _, _, attn_output = self.flash_attention(query_states, key_states, value_states, None, None, None, None) + attn_output = ops.reshape(attn_output, (bsz, q_len, -1)) + attn_output = self.o_proj(attn_output) + + return attn_output + + class PatchEmbed3D(nn.Cell): def __init__( self, diff --git a/examples/moviegen/mg/models/llama/network.py b/examples/moviegen/mg/models/llama/network.py index e33af9214b..c732b8a5ed 100644 --- a/examples/moviegen/mg/models/llama/network.py +++ b/examples/moviegen/mg/models/llama/network.py @@ -4,22 +4,29 @@ from typing import Literal, Optional, Tuple, Union import numpy as np +from mg.parallel import GatherForwardSplitBackward, SplitForwardGatherBackward +from mg.parallel.parallel_states import get_model_parallel_group from mindspore import Parameter, Tensor from mindspore import dtype as mstype from mindspore import lazy_inline, load_checkpoint, load_param_into_net, mint, nn, ops +from mindspore.communication import GlobalComm, get_group_size from mindone.models.utils import normal_, zeros_ from ..text_encoders import TextProjector from .activation import ACT2FN 
from .block import ( + ContextParallelLlamaAttention, + ContextParallelLlamaFlashAttention, + FusedTensorParallelLlamaMLP, LinearPatchEmbed3D, LlamaAttention, LlamaFlashAttention, LlamaMLP, LlamaRMSNorm, PatchEmbed3D, + TensorParallelLlamaMLP, TimestepEmbedder, ) @@ -32,6 +39,11 @@ "flash_attention": LlamaFlashAttention, } +CONTEXT_PARALLEL_Llama_ATTENTION_CLASSES = { + "eager": ContextParallelLlamaAttention, + "flash_attention": ContextParallelLlamaFlashAttention, +} + def t2i_modulate(x: Tensor, shift: Tensor, scale: Tensor) -> Tensor: return x * (1 + scale) + shift @@ -122,6 +134,117 @@ def construct( return hidden_states +class ModelParallelLlamaDecoderLayer(nn.Cell): + def __init__( + self, + hidden_size: int = 4096, + intermediate_size: int = 14336, + num_attention_heads: int = 32, + num_key_value_heads: int = 8, + rms_norm_eps: float = 1e-5, + attention_dropout: float = 0.0, + attention_bias: bool = False, + hidden_act: str = "silu", + attn_implementation: Literal["eager", "flash_attention"] = "eager", + fused_tensor_parallel: bool = True, + group: str = GlobalComm.WORLD_COMM_GROUP, + dtype: mstype.Type = mstype.float32, + ) -> None: + super().__init__() + + # 3.1.6 Context Parallelism + self.self_attn = CONTEXT_PARALLEL_Llama_ATTENTION_CLASSES[attn_implementation]( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + attention_dropout=attention_dropout, + attention_bias=attention_bias, + dtype=dtype, + ) + + self.cross_attn = Llama_ATTENTION_CLASSES[attn_implementation]( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + attention_dropout=attention_dropout, + attention_bias=attention_bias, + dtype=dtype, + ) + + # 3.1.6 Tensor Parallelism + if fused_tensor_parallel: + self.mlp = FusedTensorParallelLlamaMLP( + intermediate_size=intermediate_size, + hidden_size=hidden_size, + hidden_act=hidden_act, + dim=1, + group=group, + dtype=dtype, + ) + else: + self.mlp = TensorParallelLlamaMLP( + intermediate_size=intermediate_size, + hidden_size=hidden_size, + hidden_act=hidden_act, + group=group, + dtype=dtype, + ) + + self.scale_shift_table = Parameter(Tensor(np.random.randn(1, 6, hidden_size) / hidden_size**0.5, dtype=dtype)) + self.input_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype) + self.post_attention_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype) + + if not fused_tensor_parallel: + self.split_forward_gather_backward = SplitForwardGatherBackward(dim=1, grad_scale="down", group=group) + self.gather_forward_split_backward = GatherForwardSplitBackward(dim=1, grad_scale="up", group=group) + else: + self.split_forward_gather_backward = nn.Identity() + self.gather_forward_split_backward = nn.Identity() + + def construct( + self, + hidden_states: Tensor, + encoder_hidden_states: Tensor, + modulation_parameters: Tensor, + position_embedding: Tensor, + ) -> Tensor: + B = hidden_states.shape[0] + + # 3.1.3 Positional Embedding + hidden_states = hidden_states + position_embedding + + # 3.1.3 Adaptive Layer Norm + modulation_parameters = self.scale_shift_table.to(hidden_states.dtype) + ops.reshape( + modulation_parameters, (B, 6, -1) + ) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mint.chunk(modulation_parameters, 6, dim=1) + + # Self Attention (Bi-Directional Attention) + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = t2i_modulate(hidden_states, shift_msa, 
scale_msa) + hidden_states = self.self_attn(hidden_states) + hidden_states = gate_msa * hidden_states + hidden_states = residual + hidden_states + + # 3.1.3 Cross Attention + residual = hidden_states + hidden_states = self.cross_attn(hidden_states, encoder_hidden_states=encoder_hidden_states) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = t2i_modulate(hidden_states, shift_mlp, scale_mlp) + hidden_states = self.gather_forward_split_backward(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = self.split_forward_gather_backward(hidden_states) + hidden_states = gate_mlp * hidden_states + hidden_states = residual + hidden_states + + return hidden_states + + class LlamaFinalLayer(nn.Cell): def __init__( self, @@ -168,6 +291,7 @@ def __init__( recompute_every_nth_block: Optional[int] = None, use_linear_patch_embedder: bool = True, model_parallelism: bool = False, + fused_tensor_parallel: bool = True, post_init_weight: bool = True, dtype: mstype.Type = mstype.float32, ) -> None: @@ -182,9 +306,28 @@ def __init__( self.max_length = max_length self.model_parallelism = model_parallelism self._dtype = dtype + mp_group = get_model_parallel_group() if self.model_parallelism: - raise NotImplementedError("Model parallelism is not supported yet.") + self.layers = nn.CellList( + [ + ModelParallelLlamaDecoderLayer( + hidden_size=self.hidden_size, + intermediate_size=intermediate_size, + num_attention_heads=self.num_attention_heads, + num_key_value_heads=self.num_key_value_heads, + rms_norm_eps=rms_norm_eps, + attention_dropout=attention_dropout, + attention_bias=attention_bias, + hidden_act=hidden_act, + attn_implementation=attn_implementation, + fused_tensor_parallel=fused_tensor_parallel, + group=mp_group, + dtype=dtype, + ) + for _ in range(num_hidden_layers) + ] + ) else: self.layers = nn.CellList( [ @@ -230,6 +373,11 @@ def __init__( out_features=self.hidden_size, layer_norm=LlamaRMSNorm, norm_eps=self.rms_norm_eps, dtype=dtype ) + if self.model_parallelism: + self.group_size = get_group_size(mp_group) + self.split_forward_gather_backward = SplitForwardGatherBackward(dim=1, grad_scale="down", group=mp_group) + self.gather_forward_split_backward = GatherForwardSplitBackward(dim=1, grad_scale="up", group=mp_group) + # post-init if post_init_weight: self.initializer_range = initializer_range @@ -343,9 +491,19 @@ def construct( # main blocks hidden_states = latent_embedding + # 3.1.6 Sequence Parallelism Start + if self.model_parallelism: + # assert hidden_states.shape[1] % self.group_size == 0 + hidden_states = self.split_forward_gather_backward(hidden_states) + position_embedding = self.split_forward_gather_backward(position_embedding) + for decoder_layer in self.layers: hidden_states = decoder_layer(hidden_states, text_embedding, modulation_parameters, position_embedding) + # 3.1.6 Sequence Parallelism End + if self.model_parallelism: + hidden_states = self.gather_forward_split_backward(hidden_states) + # final block hidden_states = self.final_layer(hidden_states, timestep_embedding) diff --git a/examples/moviegen/mg/parallel/__init__.py b/examples/moviegen/mg/parallel/__init__.py new file mode 100644 index 0000000000..de133abd08 --- /dev/null +++ b/examples/moviegen/mg/parallel/__init__.py @@ -0,0 +1,2 @@ +from .layers import * +from .parallel_states import * diff --git a/examples/moviegen/mg/parallel/layers.py b/examples/moviegen/mg/parallel/layers.py new 
file mode 100644 index 0000000000..d238d47391 --- /dev/null +++ b/examples/moviegen/mg/parallel/layers.py @@ -0,0 +1,398 @@ +import numbers +from typing import Callable, Literal, Optional, Tuple, Union + +import mindspore as ms +import mindspore.mint as mint +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore import Tensor +from mindspore.common.initializer import Initializer +from mindspore.communication import GlobalComm, get_group_size, get_rank + +__all__ = [ + "SplitForwardGatherBackward", + "GatherForwardSplitBackward", + "GatherForwardReduceScatterBackward", + "ColumnParallelLinear", + "RowParallelLinear", + "FusedColumnParallelLinear", + "FusedRowParallelLinear", +] + + +def _communicate_along_dim(x: Tensor, dim: int, func: Callable[[Tensor], Tensor]) -> Tensor: + x = x.swapaxes(0, dim) + x = func(x) + x = x.swapaxes(dim, 0) + return x + + +def _split(x: Tensor, dim: int, rank: int, world_size: int) -> Tensor: + dim_size = x.shape[dim] + tensor_list = x.split(dim_size // world_size, axis=dim) + x = tensor_list[rank] + return x + + +class _CopyToModelParallelRegion(nn.Cell): + def __init__(self, group: str = GlobalComm.WORLD_COMM_GROUP) -> None: + super().__init__(auto_prefix=False) + self.reduce = ops.AllReduce(op=ops.ReduceOp.SUM, group=group) + + def construct(self, x: Tensor) -> Tensor: + return x + + def bprop(self, x: Tensor, out: Tensor, dout: Tensor) -> Tuple[Tensor]: + dout = self.reduce(dout) + return (dout,) + + +class _ReduceFromModelParallelRegion(nn.Cell): + def __init__(self, group: str = GlobalComm.WORLD_COMM_GROUP) -> None: + super().__init__(auto_prefix=False) + self.reduce = ops.AllReduce(op=ops.ReduceOp.SUM, group=group) + + def construct(self, x: Tensor) -> Tensor: + return self.reduce(x) + + def bprop(self, x: Tensor, out: Tensor, dout: Tensor) -> Tuple[Tensor]: + return (dout,) + + +class _ScatterToModelParallelRegion(nn.Cell): + def __init__(self, group: str = GlobalComm.WORLD_COMM_GROUP) -> None: + super().__init__(auto_prefix=False) + self.gather = ops.AllGather(group=group) + self.rank = get_rank(group) + self.world_size = get_group_size(group) + + def construct(self, x: Tensor) -> Tensor: + return _split(x, -1, self.rank, self.world_size) + + def bprop(self, x: Tensor, out: Tensor, dout: Tensor) -> Tuple[Tensor]: + dout = _communicate_along_dim(dout, -1, self.gather) + return (dout,) + + +class _GatherFromModelParallelRegion(nn.Cell): + def __init__(self, group: str = GlobalComm.WORLD_COMM_GROUP) -> None: + super().__init__(auto_prefix=False) + self.gather = ops.AllGather(group=group) + self.rank = get_rank(group) + self.world_size = get_group_size(group) + + def construct(self, x: Tensor) -> Tensor: + return _communicate_along_dim(x, -1, self.gather) + + def bprop(self, x: Tensor, out: Tensor, dout: Tensor) -> Tuple[Tensor]: + dout = _split(dout, -1, self.rank, self.world_size) + return (dout,) + + +class _GatherToModelParallelRegion(nn.Cell): + def __init__(self, dim: int = 1, group: str = GlobalComm.WORLD_COMM_GROUP) -> None: + super().__init__(auto_prefix=False) + self.dim = dim + self.gather = ops.AllGather(group=group) + self.reduce_scatter = ops.ReduceScatter(op=ops.ReduceOp.SUM, group=group) + self.rank = get_rank(group) + self.world_size = get_group_size(group) + self.scale = self.world_size + + def construct(self, x: Tensor) -> Tensor: + return _communicate_along_dim(x, self.dim, self.gather) + + def bprop(self, x: Tensor, out: Tensor, dout: Tensor) -> Tuple[Tensor]: + dout = dout * self.scale + dout = 
_communicate_along_dim(dout, self.dim, self.reduce_scatter) + return (dout,) + + +class _ReduceScatterFromModelParallelRegion(nn.Cell): + def __init__(self, dim: int = 1, group: str = GlobalComm.WORLD_COMM_GROUP) -> None: + super().__init__(auto_prefix=False) + self.dim = dim + self.gather = ops.AllGather(group=group) + self.reduce_scatter = ops.ReduceScatter(op=ops.ReduceOp.SUM, group=group) + self.rank = get_rank(group) + self.world_size = get_group_size(group) + self.scale = 1 / self.world_size + + def construct(self, x: Tensor) -> Tensor: + return _communicate_along_dim(x, self.dim, self.reduce_scatter) + + def bprop(self, x: Tensor, out: Tensor, dout: Tensor) -> Tuple[Tensor]: + dout = dout * self.scale + dout = _communicate_along_dim(dout, self.dim, self.gather) + return (dout,) + + +class SplitForwardGatherBackward(nn.Cell): + def __init__( + self, dim: int = 0, grad_scale: Literal["up", "down"] = "down", group: str = GlobalComm.WORLD_COMM_GROUP + ) -> None: + super().__init__() + self.dim = dim + self.rank = get_rank(group) + self.world_size = get_group_size(group) + self.gather = ops.AllGather(group=group) + + if grad_scale == "up": + self.scale = self.world_size + else: + self.scale = 1 / self.world_size + + def construct(self, x: Tensor) -> Tensor: + return _split(x, self.dim, self.rank, self.world_size) + + def bprop(self, x: Tensor, out: Tensor, dout: Tensor) -> Tuple[Tensor]: + dout = dout * self.scale + dout = _communicate_along_dim(dout, self.dim, self.gather) + return (dout,) + + +class GatherForwardSplitBackward(nn.Cell): + def __init__( + self, dim: int = 0, grad_scale: Literal["up", "down"] = "up", group: str = GlobalComm.WORLD_COMM_GROUP + ) -> None: + super().__init__() + self.dim = dim + self.rank = get_rank(group) + self.world_size = get_group_size(group) + self.gather = ops.AllGather(group=group) + + if grad_scale == "up": + self.scale = self.world_size + else: + self.scale = 1 / self.world_size + + def construct(self, x: Tensor) -> Tensor: + x = _communicate_along_dim(x, self.dim, self.gather) + return x + + def bprop(self, x: Tensor, out: Tensor, dout: Tensor) -> Tuple[Tensor]: + dout = dout * self.scale + dout = _split(dout, self.dim, self.rank, self.world_size) + return (dout,) + + +class GatherForwardReduceScatterBackward(nn.Cell): + def __init__(self, dim: int = 0, group: str = GlobalComm.WORLD_COMM_GROUP) -> None: + super().__init__() + self.dim = dim + self.rank = get_rank(group) + self.world_size = get_group_size(group) + self.gather = ops.AllGather(group=group) + self.reduce_scatter = ops.ReduceScatter(op=ops.ReduceOp.SUM, group=group) + + def construct(self, x: Tensor) -> Tensor: + x = _communicate_along_dim(x, self.dim, self.gather) + return x + + def bprop(self, x: Tensor, out: Tensor, dout: Tensor) -> Tuple[Tensor]: + dout = _communicate_along_dim(dout, self.dim, self.reduce_scatter) + return (dout,) + + +class ColumnParallelLinear(nn.Cell): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + weight_init: Union[None, Tensor, str, Initializer, numbers.Number] = None, + bias_init: Union[None, Tensor, str, Initializer, numbers.Number] = None, + gather_output: bool = True, + group: str = GlobalComm.WORLD_COMM_GROUP, + dtype: Optional[ms.Type] = None, + ): + super().__init__(auto_prefix=False) + + self.rank = get_rank(group) + self.world_size = get_group_size(group) + assert out_features % self.world_size == 0 + self.out_features_per_partition = out_features // self.world_size + self.gather_output = gather_output + + 
self.copy_to_tensor_parallel_region = _CopyToModelParallelRegion(group=group) + self.linear = mint.nn.Linear( + in_features, + self.out_features_per_partition, + bias=bias, + weight_init=weight_init, + bias_init=bias_init, + dtype=dtype, + ) + if self.gather_output: + self.gather_from_tensor_parallel_region = _GatherFromModelParallelRegion(group=group) + + def construct(self, x: Tensor) -> Tensor: + x = self.copy_to_tensor_parallel_region(x) + x = self.linear(x) + if self.gather_output: + x = self.gather_from_tensor_parallel_region(x) + return x + + def load_weight_from_non_parallel_cell(self, target: mint.nn.Linear): + weight = mint.chunk(target.weight, self.world_size, dim=0)[self.rank] + self.linear.weight.set_data(weight) + + if target.bias is not None: + bias = mint.chunk(target.bias, self.world_size, dim=0)[self.rank] + self.linear.bias.set_data(bias) + + +class FusedColumnParallelLinear(nn.Cell): + """For tensor parallel using sequence parallel input + It is a fused operation of gather_forward_split_backward & allreduce backward + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + weight_init: Union[None, Tensor, str, Initializer, numbers.Number] = None, + bias_init: Union[None, Tensor, str, Initializer, numbers.Number] = None, + gather_output: bool = True, + dim: int = 1, + group: str = GlobalComm.WORLD_COMM_GROUP, + dtype: Optional[ms.Type] = None, + ): + super().__init__(auto_prefix=False) + + self.rank = get_rank(group) + self.world_size = get_group_size(group) + assert out_features % self.world_size == 0 + self.out_features_per_partition = out_features // self.world_size + self.gather_output = gather_output + + self.gather_to_tensor_parallel_region = _GatherToModelParallelRegion(dim=dim, group=group) + self.linear = mint.nn.Linear( + in_features, + self.out_features_per_partition, + bias=bias, + weight_init=weight_init, + bias_init=bias_init, + dtype=dtype, + ) + if self.gather_output: + self.gather_from_tensor_parallel_region = _GatherFromModelParallelRegion(group=group) + + def construct(self, x: Tensor) -> Tensor: + x = self.gather_to_tensor_parallel_region(x) + x = self.linear(x) + if self.gather_output: + x = self.gather_from_tensor_parallel_region(x) + return x + + def load_weight_from_non_parallel_cell(self, target: mint.nn.Linear): + weight = mint.chunk(target.weight, self.world_size, dim=0)[self.rank] + self.linear.weight.set_data(weight) + + if target.bias is not None: + bias = mint.chunk(target.bias, self.world_size, dim=0)[self.rank] + self.linear.bias.set_data(bias) + + +class RowParallelLinear(nn.Cell): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = False, + weight_init: Union[None, Tensor, str, Initializer, numbers.Number] = None, + bias_init: Union[None, Tensor, str, Initializer, numbers.Number] = None, + input_is_parallel: bool = False, + group: str = GlobalComm.WORLD_COMM_GROUP, + dtype: Optional[ms.Type] = None, + ): + super().__init__(auto_prefix=False) + + self.rank = get_rank(group) + self.world_size = get_group_size(group) + assert in_features % self.world_size == 0 + self.in_features_per_partition = in_features // self.world_size + self.input_is_parallel = input_is_parallel + + self.reduce_from_tensor_parallel_region = _ReduceFromModelParallelRegion(group=group) + self.linear = mint.nn.Linear( + self.in_features_per_partition, + out_features, + bias=bias, + weight_init=weight_init, + bias_init=bias_init, + dtype=dtype, + ) + if not self.input_is_parallel: + 
self.scatter_to_tensor_parallel_region = _ScatterToModelParallelRegion(group=group) + + def construct(self, x: Tensor) -> Tensor: + if not self.input_is_parallel: + x = self.scatter_to_tensor_parallel_region(x) + x = self.linear.dense(x, self.linear.weight) + x = self.reduce_from_tensor_parallel_region(x) + if self.linear.bias is not None: + x = x + self.linear.bias + return x + + def load_weight_from_non_parallel_cell(self, target: mint.nn.Linear): + weight = mint.chunk(target.weight, self.world_size, dim=1)[self.rank] + self.linear.weight.set_data(weight) + + if target.bias is not None: + self.linear.bias.set_data(target.bias) + + +class FusedRowParallelLinear(nn.Cell): + """For tensor parallel to sequence parallel output + It is a fused operation of split_forward_gather_backward & allreduce forward + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = False, + weight_init: Union[None, Tensor, str, Initializer, numbers.Number] = None, + bias_init: Union[None, Tensor, str, Initializer, numbers.Number] = None, + input_is_parallel: bool = False, + dim: int = 1, + group: str = GlobalComm.WORLD_COMM_GROUP, + dtype: Optional[ms.Type] = None, + ): + super().__init__(auto_prefix=False) + + self.rank = get_rank(group) + self.world_size = get_group_size(group) + assert in_features % self.world_size == 0 + self.in_features_per_partition = in_features // self.world_size + self.input_is_parallel = input_is_parallel + + self.reduce_from_tensor_parallel_region = _ReduceScatterFromModelParallelRegion(dim=dim, group=group) + self.linear = mint.nn.Linear( + self.in_features_per_partition, + out_features, + bias=bias, + weight_init=weight_init, + bias_init=bias_init, + dtype=dtype, + ) + if not self.input_is_parallel: + self.scatter_to_tensor_parallel_region = _ScatterToModelParallelRegion(group=group) + + def construct(self, x: Tensor) -> Tensor: + if not self.input_is_parallel: + x = self.scatter_to_tensor_parallel_region(x) + x = self.linear.dense(x, self.linear.weight) + x = self.reduce_from_tensor_parallel_region(x) + if self.linear.bias is not None: + x = x + self.linear.bias + return x + + def load_weight_from_non_parallel_cell(self, target: mint.nn.Linear): + weight = mint.chunk(target.weight, self.world_size, dim=1)[self.rank] + self.linear.weight.set_data(weight) + + if target.bias is not None: + self.linear.bias.set_data(target.bias) diff --git a/examples/moviegen/mg/parallel/parallel_states.py b/examples/moviegen/mg/parallel/parallel_states.py new file mode 100644 index 0000000000..2a8d9c0a0c --- /dev/null +++ b/examples/moviegen/mg/parallel/parallel_states.py @@ -0,0 +1,38 @@ +from typing import Optional + +from mindspore.communication import create_group, get_group_size, get_rank + +__all__ = ["set_model_parallel_group", "get_model_parallel_group", "create_parallel_group"] + + +_GLOBAL_PARALLEL_GROUPS = dict() + + +def set_model_parallel_group(group: str) -> None: + _GLOBAL_PARALLEL_GROUPS["model"] = group + + +def get_model_parallel_group() -> Optional[str]: + return _GLOBAL_PARALLEL_GROUPS.get("model", None) + + +def create_parallel_group(model_parallel_shards: int = 1) -> None: + if model_parallel_shards <= 1: + raise ValueError( + f"`model_parallel_shards` must be larger than 1 to enable model parallel, but get `{model_parallel_shards}`." 
+ ) + + device_num = get_group_size() + if device_num % model_parallel_shards != 0: + raise ValueError( + f"Total number of devices ({device_num}) must be divisible by the number of model parallel shards ({model_parallel_shards})." + ) + + rank_id = get_rank() + + if model_parallel_shards > 1: + mp_group_id = rank_id // model_parallel_shards + mp_group_rank_ids = list(range(mp_group_id * model_parallel_shards, (mp_group_id + 1) * model_parallel_shards)) + mp_group_name = f"mp_group_{mp_group_id}" + create_group(mp_group_name, mp_group_rank_ids) + set_model_parallel_group(mp_group_name) diff --git a/examples/moviegen/mg/schedulers/rectified_flow.py b/examples/moviegen/mg/schedulers/rectified_flow.py index 6ed5a776cb..9b2457fd4f 100644 --- a/examples/moviegen/mg/schedulers/rectified_flow.py +++ b/examples/moviegen/mg/schedulers/rectified_flow.py @@ -8,8 +8,10 @@ from mindspore import Tensor from mindspore import dtype as mstype from mindspore import mint, nn, ops +from mindspore.communication import get_rank from ..models import LlamaModel +from ..parallel import get_model_parallel_group logger = logging.getLogger(__name__) @@ -101,7 +103,12 @@ def __init__( self.model = model self.criteria = nn.MSELoss() - self.mp_group = None + self.mp_group = get_model_parallel_group() + if self.mp_group is not None: + logging.info( + f"Broadcasting all random variables from rank (0) to current rank ({get_rank(self.mp_group)}) in group `{self.mp_group}`." + ) + self.broadcast = ops.Broadcast(0, group=self.mp_group) def _discrete_sample(self, size: int) -> Tensor: return ops.randint(0, self.num_timesteps, (size,), dtype=mstype.int64) diff --git a/examples/moviegen/scripts/30B_stage2_train.sh b/examples/moviegen/scripts/30B_stage2_train.sh new file mode 100644 index 0000000000..03575ae4f7 --- /dev/null +++ b/examples/moviegen/scripts/30B_stage2_train.sh @@ -0,0 +1,30 @@ +export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +# plot memory usage, feature/model: 1 +export MS_MEMORY_STATISTIC=0 + +# operation/graph fusion for dynamic shape +# export MS_DEV_ENABLE_KERNEL_PACKET=on # TODO: add dynamic shape support + +# log level +export GLOG_v=2 + +output_dir=output/stage2_t2iv_256px/$(date +"%Y.%m.%d-%H.%M.%S") + +msrun --bind_core=True --master_port=8200 --worker_num=8 --local_worker_num=8 --log_dir="$output_dir" \ +python train.py \ + --config configs/train/stage2_t2iv_256px.yaml \ + --env.mode 0 \ + --env.jit_level O1 \ + --env.max_device_memory 59GB \ + --env.distributed True \ + --model.name=llama-30B \ + --train.settings.zero_stage 3 \ + --train.model_parallel.model_parallel_shards=8 \ + --train.ema="" \ + --dataset.csv_path CSV_PATH \ + --dataset.video_folder VIDEO_FOLDER \ + --dataset.tae_latent_folder TAE_LATENT_FOLDER \ + --dataset.text_emb_folder.ul2 UL2_FOLDER \ + --dataset.text_emb_folder.byt5 BYT5_FOLDER \ + --dataloader.batch_size=1 \ + --train.output_path "$output_dir" diff --git a/examples/moviegen/tests/parallel/run_test_llama3_parallel.sh b/examples/moviegen/tests/parallel/run_test_llama3_parallel.sh new file mode 100644 index 0000000000..b532dad534 --- /dev/null +++ b/examples/moviegen/tests/parallel/run_test_llama3_parallel.sh @@ -0,0 +1,13 @@ +#!/bin/sh + +SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" +PROJECT_DIR="$(dirname $(dirname "${SCRIPT_DIR}"))" +EXAMPLE_DIR="$(dirname "${PROJECT_DIR}")" +PACKAGE_DIR="$(dirname "${EXAMPLE_DIR}")" + +export PYTHONPATH="${PROJECT_DIR}:${PACKAGE_DIR}:${PYTHONPATH}" + +LOGDIR="./log_test_llama3_parallel_graph" +echo 
"Graph Mode:" +msrun --master_port=1234 --worker_num=2 --local_worker_num=2 --log_dir=$LOGDIR --join True ${SCRIPT_DIR}/test_llama3_parallel.py --mode 0 +echo "Done. Check the log at '$LOGDIR'." diff --git a/examples/moviegen/tests/parallel/run_test_llama3_parallel_block.sh b/examples/moviegen/tests/parallel/run_test_llama3_parallel_block.sh new file mode 100644 index 0000000000..603aac9fce --- /dev/null +++ b/examples/moviegen/tests/parallel/run_test_llama3_parallel_block.sh @@ -0,0 +1,13 @@ +#!/bin/sh + +SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" +PROJECT_DIR="$(dirname $(dirname "${SCRIPT_DIR}"))" +EXAMPLE_DIR="$(dirname "${PROJECT_DIR}")" +PACKAGE_DIR="$(dirname "${EXAMPLE_DIR}")" + +export PYTHONPATH="${PROJECT_DIR}:${PACKAGE_DIR}:${PYTHONPATH}" + +LOGDIR="./log_test_llama3_parallel_block_graph" +echo "Graph Mode:" +msrun --master_port=1234 --worker_num=2 --local_worker_num=2 --log_dir=$LOGDIR --join True ${SCRIPT_DIR}/test_llama3_parallel_block.py --mode 0 +echo "Done. Check the log at '$LOGDIR'." diff --git a/examples/moviegen/tests/parallel/run_test_llama3_parallel_layer.sh b/examples/moviegen/tests/parallel/run_test_llama3_parallel_layer.sh new file mode 100644 index 0000000000..ecf23ff9a8 --- /dev/null +++ b/examples/moviegen/tests/parallel/run_test_llama3_parallel_layer.sh @@ -0,0 +1,13 @@ +#!/bin/sh + +SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" +PROJECT_DIR="$(dirname $(dirname "${SCRIPT_DIR}"))" +EXAMPLE_DIR="$(dirname "${PROJECT_DIR}")" +PACKAGE_DIR="$(dirname "${EXAMPLE_DIR}")" + +export PYTHONPATH="${PROJECT_DIR}:${PACKAGE_DIR}:${PYTHONPATH}" + +LOGDIR="./log_test_llama3_parallel_layer_graph" +echo "Graph Mode:" +msrun --master_port=1234 --worker_num=2 --local_worker_num=2 --log_dir=$LOGDIR --join True ${SCRIPT_DIR}/test_llama3_parallel_layer.py --mode 0 +echo "Done. Check the log at '$LOGDIR'." diff --git a/examples/moviegen/tests/parallel/run_test_rflow_parallel.sh b/examples/moviegen/tests/parallel/run_test_rflow_parallel.sh new file mode 100644 index 0000000000..88ad571cac --- /dev/null +++ b/examples/moviegen/tests/parallel/run_test_rflow_parallel.sh @@ -0,0 +1,13 @@ +#!/bin/sh + +SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" +PROJECT_DIR="$(dirname $(dirname "${SCRIPT_DIR}"))" +EXAMPLE_DIR="$(dirname "${PROJECT_DIR}")" +PACKAGE_DIR="$(dirname "${EXAMPLE_DIR}")" + +export PYTHONPATH="${PROJECT_DIR}:${PACKAGE_DIR}:${PYTHONPATH}" + +LOGDIR="./log_test_rflow_parallel_graph" +echo "Graph Mode:" +msrun --master_port=1234 --worker_num=2 --local_worker_num=2 --log_dir=$LOGDIR --join True ${SCRIPT_DIR}/test_rflow_parallel.py --mode 0 +echo "Done. Check the log at '$LOGDIR'." 
diff --git a/examples/moviegen/tests/parallel/test_llama3_parallel.py b/examples/moviegen/tests/parallel/test_llama3_parallel.py new file mode 100644 index 0000000000..542e44d007 --- /dev/null +++ b/examples/moviegen/tests/parallel/test_llama3_parallel.py @@ -0,0 +1,113 @@ +import argparse +from typing import Tuple + +import numpy as np +from mg.models.llama.network import LlamaModel +from mg.parallel import create_parallel_group +from utils import gather_or_reduce_parallel_gradient + +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore import Tensor +from mindspore.communication import get_group_size, init + +from mindone.utils.seed import set_random_seed + + +class MeanNet(nn.Cell): + def __init__(self, net: nn.Cell) -> None: + super().__init__() + self.net = net + + def construct(self, *inputs): + output = self.net(*inputs) + return output.mean() * 1024.0 + + +def get_sample_data(dtype: ms.Type = ms.float32) -> Tuple[Tensor, ...]: + latent_embedding = ops.rand([1, 16, 8, 24, 44], dtype=dtype) + timestep = ms.Tensor([35], dtype=ms.int64) + ul2_emb = ops.rand([1, 64, 4096], dtype=dtype) + metaclip_emb = ops.rand([1, 64, 1280], dtype=dtype) + byt5_emb = ops.rand([1, 64, 1472], dtype=dtype) + return latent_embedding, timestep, ul2_emb, metaclip_emb, byt5_emb + + +def get_network_config(model_parallelism=False, fused_tensor_parallel=False): + config = dict( + num_hidden_layers=2, + attn_implementation="eager", + model_parallelism=model_parallelism, + fused_tensor_parallel=fused_tensor_parallel, + post_init_weight=False, + ) + return config + + +def run_network(mode: int = 0, dtype: ms.Type = ms.float32): + ms.set_context(mode=mode) + init() + + # prepare data + set_random_seed(1024) + data = get_sample_data(dtype=dtype) + + # prepare group + create_parallel_group(model_parallel_shards=get_group_size()) + + print("Non-fused tensor parallel:", flush=True) + run_parallel_network(data, fused_tensor_parallel=False) + + print("Fused tensor parallel:", flush=True) + run_parallel_network(data, fused_tensor_parallel=True) + + +def run_parallel_network(data: Tuple[Tensor, ...], fused_tensor_parallel: bool = False, dtype: ms.Type = ms.float32): + # non parallel network + set_random_seed(1024) + non_parallel_network_cfg = get_network_config(model_parallelism=False, fused_tensor_parallel=fused_tensor_parallel) + non_parallel_network = LlamaModel(**non_parallel_network_cfg, dtype=dtype) + + # parallel netowrk + parallel_network_cfg = get_network_config(model_parallelism=True, fused_tensor_parallel=fused_tensor_parallel) + parallel_network = LlamaModel(**parallel_network_cfg, dtype=dtype) + + # load weight + parallel_network.load_weight_from_non_parallel_cell(non_parallel_network) + + # test forward + non_parallel_out = non_parallel_network(*data).asnumpy() + parallel_out = parallel_network(*data).asnumpy() + + assert np.count_nonzero(non_parallel_out) > 0 + np.testing.assert_equal(non_parallel_out.shape, parallel_out.shape) + np.testing.assert_allclose(non_parallel_out, parallel_out, rtol=1.3e-6, atol=1e-5) + print("Test 1 (Forward): Passed.", flush=True) + + # test backward + non_parallel_mean_net = MeanNet(non_parallel_network) + parallel_mean_net = MeanNet(parallel_network) + + # check the parameter gradient + grad_fn = ops.grad(non_parallel_mean_net, grad_position=None, weights=non_parallel_mean_net.trainable_params()) + non_parallel_grads = grad_fn(*data) + + grad_fn = ops.grad(parallel_mean_net, grad_position=None, 
weights=parallel_mean_net.trainable_params()) + parallel_grads = grad_fn(*data) + + for grad_0, grad_1 in zip(non_parallel_grads, parallel_grads): + grad_1 = gather_or_reduce_parallel_gradient(grad_1, grad_0.shape) + grad_0, grad_1 = grad_0.asnumpy(), grad_1.asnumpy() + assert np.count_nonzero(grad_0) > 0 + np.testing.assert_allclose(grad_0, grad_1, rtol=1.3e-6, atol=2e-5) + print("Test 2 (Backward: Parameter Gradient): Passed.", flush=True) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--mode", default=0, type=int, choices=[0, 1], help="Mode to test. (0: Graph Mode; 1: Pynative mode)" + ) + args = parser.parse_args() + run_network(mode=args.mode) diff --git a/examples/moviegen/tests/parallel/test_llama3_parallel_block.py b/examples/moviegen/tests/parallel/test_llama3_parallel_block.py new file mode 100644 index 0000000000..82141a1d31 --- /dev/null +++ b/examples/moviegen/tests/parallel/test_llama3_parallel_block.py @@ -0,0 +1,107 @@ +import argparse + +import numpy as np +from mg.models.llama.block import LlamaMLP, TensorParallelLlamaMLP +from mg.parallel import create_parallel_group +from utils import gather_or_reduce_parallel_gradient + +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore import Tensor +from mindspore.communication import get_group_size, init + +from mindone.utils.seed import set_random_seed + + +class MeanNet(nn.Cell): + def __init__(self, net: nn.Cell) -> None: + super().__init__() + self.net = net + + def construct(self, *inputs): + output = self.net(*inputs) + return output.mean() * 1024.0 + + +def get_sample_data(dtype: ms.Type = ms.float32) -> Tensor: + x = ops.rand([4, 64, 3072], dtype=dtype) # (N, T, H) + return x + + +def get_block_config(): + config = dict(intermediate_size=8192, hidden_size=3072, hidden_act="silu") + return config + + +def run_block(mode: int = 0, dtype: ms.Type = ms.float32): + ms.set_context(mode=mode) + init() + + # prepare data + set_random_seed(1024) + data = get_sample_data(dtype=dtype) + + # prepare group + create_parallel_group(model_parallel_shards=get_group_size()) + + # non parallel block + set_random_seed(1024) + non_parallel_block_cfg = get_block_config() + non_parallel_block = LlamaMLP(**non_parallel_block_cfg, dtype=dtype) + + # parallel block + parallel_block_cfg = get_block_config() + parallel_block = TensorParallelLlamaMLP(**parallel_block_cfg, dtype=dtype) + + # load weight + parallel_block.load_weight_from_non_parallel_cell(non_parallel_block) + + # test forward + non_parallel_out = non_parallel_block(data).asnumpy() + parallel_out = parallel_block(data).asnumpy() + + assert np.count_nonzero(non_parallel_out) > 0 + np.testing.assert_equal(non_parallel_out.shape, parallel_out.shape) + np.testing.assert_allclose(non_parallel_out, parallel_out, rtol=1.3e-6, atol=1e-5) + print("Test 1 (Forward): Passed.") + + # test backward + non_parallel_mean_net = MeanNet(non_parallel_block) + parallel_mean_net = MeanNet(parallel_block) + + # check the parameter gradient + grad_fn = ops.grad(non_parallel_mean_net, grad_position=None, weights=non_parallel_mean_net.trainable_params()) + non_parallel_grads = grad_fn(data) + + grad_fn = ops.grad(parallel_mean_net, grad_position=None, weights=parallel_mean_net.trainable_params()) + parallel_grads = grad_fn(data) + + for grad_0, grad_1 in zip(non_parallel_grads, parallel_grads): + grad_1 = gather_or_reduce_parallel_gradient(grad_1, grad_0.shape) + grad_0, grad_1 = grad_0.asnumpy(), grad_1.asnumpy() + 
assert np.count_nonzero(grad_0) > 0 + np.testing.assert_allclose(grad_0, grad_1, rtol=1.3e-6, atol=1e-5) + print("Test 2 (Backward: Parameter Gradient): Passed.") + + # check the input gradient + grad_fn = ops.grad(non_parallel_mean_net, grad_position=0) + non_parallel_grads = grad_fn(data) + + grad_fn = ops.grad(parallel_mean_net, grad_position=0) + parallel_grads = grad_fn(data) + + for grad_0, grad_1 in zip(non_parallel_grads, parallel_grads): + grad_0, grad_1 = grad_0.asnumpy(), grad_1.asnumpy() + assert np.count_nonzero(grad_0) > 0 + np.testing.assert_allclose(grad_0, grad_1, rtol=1.3e-6, atol=1e-5) + print("Test 3 (Backward: Input Gradient): Passed.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--mode", default=0, type=int, choices=[0, 1], help="Mode to test. (0: Graph Mode; 1: Pynative mode)" + ) + args = parser.parse_args() + run_block(mode=args.mode) diff --git a/examples/moviegen/tests/parallel/test_llama3_parallel_layer.py b/examples/moviegen/tests/parallel/test_llama3_parallel_layer.py new file mode 100644 index 0000000000..a4c5afb140 --- /dev/null +++ b/examples/moviegen/tests/parallel/test_llama3_parallel_layer.py @@ -0,0 +1,125 @@ +import argparse +from typing import Literal + +import numpy as np +from mg.parallel import ColumnParallelLinear, RowParallelLinear, create_parallel_group, get_model_parallel_group +from utils import gather_or_reduce_parallel_gradient + +import mindspore as ms +import mindspore.mint as mint +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore import Tensor +from mindspore.communication import get_group_size, init + +from mindone.utils.seed import set_random_seed + + +class MeanNet(nn.Cell): + def __init__(self, net: nn.Cell) -> None: + super().__init__() + self.net = net + + def construct(self, *inputs): + output = self.net(*inputs) + return output.mean() * 1024.0 + + +def get_sample_data(dtype: ms.Type = ms.float32) -> Tensor: + x = ops.rand([4, 64, 256], dtype=dtype) # (N, T, H) + return x + + +def get_layer_config(bias: bool = False): + config = dict(in_features=256, out_features=32, bias=bias) + return config + + +def run_layer(mode: int = 0, dtype: ms.Type = ms.float32): + ms.set_context(mode=mode) + init() + + # prepare data + set_random_seed(1024) + data = get_sample_data(dtype=dtype) + + # prepare group + create_parallel_group(model_parallel_shards=get_group_size()) + + print("Column Parallel Linear (Bias=True):") + run_parallel_linear(data, type="column_parallel", bias=True, dtype=dtype) + print("Column Parallel Linear (Bias=False):") + run_parallel_linear(data, type="column_parallel", bias=False, dtype=dtype) + print("Row Parallel Linear (Bias=True):") + run_parallel_linear(data, type="row_parallel", bias=True, dtype=dtype) + print("Row Parallel Linear (Bias=False):") + run_parallel_linear(data, type="row_parallel", bias=False, dtype=dtype) + + +def run_parallel_linear( + data: Tensor, type: Literal["column_parallel", "row_parallel"], bias: bool = False, dtype: ms.Type = ms.float32 +): + # non parallel layer + set_random_seed(1024) + non_parallel_layer_cfg = get_layer_config(bias=bias) + non_parallel_layer = mint.nn.Linear(**non_parallel_layer_cfg, dtype=dtype) + + # parallel layer + group = get_model_parallel_group() + parallel_layer_cfg = get_layer_config(bias=bias) + if type == "column_parallel": + parallel_layer = ColumnParallelLinear(**parallel_layer_cfg, gather_output=True, group=group, dtype=dtype) + else: + parallel_layer = 
RowParallelLinear(**parallel_layer_cfg, input_is_parallel=False, group=group, dtype=dtype) + + # load weight + parallel_layer.load_weight_from_non_parallel_cell(non_parallel_layer) + + # test forward + non_parallel_out = non_parallel_layer(data).asnumpy() + parallel_out = parallel_layer(data).asnumpy() + + assert np.count_nonzero(non_parallel_out) > 0 + np.testing.assert_equal(non_parallel_out.shape, parallel_out.shape) + np.testing.assert_allclose(non_parallel_out, parallel_out, rtol=1.3e-6, atol=1e-5) + print("Test 1 (Forward): Passed.") + + # test backward + non_parallel_mean_net = MeanNet(non_parallel_layer) + parallel_mean_net = MeanNet(parallel_layer) + + # check the parameter gradient + grad_fn = ops.grad(non_parallel_mean_net, grad_position=None, weights=non_parallel_mean_net.trainable_params()) + non_parallel_grads = grad_fn(data) + + grad_fn = ops.grad(parallel_mean_net, grad_position=None, weights=parallel_mean_net.trainable_params()) + parallel_grads = grad_fn(data) + + for grad_0, grad_1 in zip(non_parallel_grads, parallel_grads): + grad_1 = gather_or_reduce_parallel_gradient(grad_1, grad_0.shape) + grad_0, grad_1 = grad_0.asnumpy(), grad_1.asnumpy() + assert np.count_nonzero(grad_0) > 0 + np.testing.assert_allclose(grad_0, grad_1, rtol=1.3e-6, atol=1e-5) + print("Test 2 (Backward: Parameter Gradient): Passed.") + + # check the input gradient + grad_fn = ops.grad(non_parallel_mean_net, grad_position=0) + non_parallel_grads = grad_fn(data) + + grad_fn = ops.grad(parallel_mean_net, grad_position=0) + parallel_grads = grad_fn(data) + + for grad_0, grad_1 in zip(non_parallel_grads, parallel_grads): + grad_0, grad_1 = grad_0.asnumpy(), grad_1.asnumpy() + assert np.count_nonzero(grad_0) > 0 + np.testing.assert_allclose(grad_0, grad_1, rtol=1.3e-6, atol=1e-5) + print("Test 3 (Backward: Input Gradient): Passed.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--mode", default=0, type=int, choices=[0, 1], help="Mode to test. 
(0: Graph Mode; 1: Pynative mode)" + ) + args = parser.parse_args() + run_layer(mode=args.mode) diff --git a/examples/moviegen/tests/parallel/test_rflow_parallel.py b/examples/moviegen/tests/parallel/test_rflow_parallel.py new file mode 100644 index 0000000000..a3d6302e3a --- /dev/null +++ b/examples/moviegen/tests/parallel/test_rflow_parallel.py @@ -0,0 +1,61 @@ +import argparse +from typing import Tuple + +from mg.parallel import create_parallel_group +from mg.schedulers import RFlowLossWrapper + +import mindspore as ms +from mindspore import Tensor, nn, ops +from mindspore.communication import get_group_size, init + +from mindone.utils.seed import set_random_seed + + +class SimpleNet(nn.Cell): + def construct( + self, x: Tensor, timestamp: Tensor, ul2_emb: Tensor, metaclip_emb: Tensor, byt5_emb: Tensor + ) -> Tensor: + return x.to(ms.float32) + + @property + def dtype(self): + return ms.float32 + + +def get_sample_data(dtype: ms.Type = ms.float32) -> Tuple[Tensor, ...]: + latent_embedding = ops.rand([1, 16, 8, 24, 44], dtype=dtype) + ul2_emb = ops.rand([1, 64, 4096], dtype=dtype) + metaclip_emb = ops.rand([1, 64, 1280], dtype=dtype) + byt5_emb = ops.rand([1, 64, 1472], dtype=dtype) + return latent_embedding, ul2_emb, metaclip_emb, byt5_emb + + +def run_network(mode: int = 0): + ms.set_context(mode=mode) + init() + + # prepare data + set_random_seed(1024) + data = get_sample_data() + + # prepare group + create_parallel_group(model_parallel_shards=get_group_size()) + + model = SimpleNet() + + # parallel netowrk + network = RFlowLossWrapper(model) + + loss = network(*data) + loss = ops.AllGather()(ops.unsqueeze(loss, 0)).asnumpy() + assert loss[0] == loss[1], f"expected two elements to be same, but get `{loss}`." + print("Test 1: Passed.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--mode", default=0, type=int, choices=[0, 1], help="Mode to test. 
(0: Graph Mode; 1: Pynative mode)" + ) + args = parser.parse_args() + run_network(mode=args.mode) diff --git a/examples/moviegen/tests/parallel/utils.py b/examples/moviegen/tests/parallel/utils.py new file mode 100644 index 0000000000..2f8d19e2d5 --- /dev/null +++ b/examples/moviegen/tests/parallel/utils.py @@ -0,0 +1,32 @@ +from typing import Callable, Tuple + +import numpy as np + +import mindspore.ops as ops +from mindspore import Tensor +from mindspore.communication import GlobalComm, get_group_size + + +def _communicate_along_dim(x: Tensor, dim: int, func: Callable[[Tensor], Tensor]) -> Tensor: + x = x.swapaxes(0, dim) + x = func(x) + x = x.swapaxes(dim, 0) + return x + + +def gather_or_reduce_parallel_gradient( + parallel_gradient: Tensor, non_parallel_gradient_shape: Tuple[int, ...], group: str = GlobalComm.WORLD_COMM_GROUP +) -> Tensor: + if parallel_gradient.shape == non_parallel_gradient_shape: + # Sequence Parallel / Context Parallel + allreduce = ops.AllReduce(op=ops.ReduceOp.SUM, group=group) + parallel_gradient = allreduce(parallel_gradient) / get_group_size(group) + return parallel_gradient + + scales = np.array(non_parallel_gradient_shape) / np.array(parallel_gradient.shape) + assert np.count_nonzero(scales - 1) == 1 + assert np.prod(scales) == get_group_size(group) + dim = np.argmax(scales).item() + allgather = ops.AllGather(group=group) + parallel_gradient = _communicate_along_dim(parallel_gradient, dim, allgather) + return parallel_gradient diff --git a/examples/moviegen/train.py b/examples/moviegen/train.py index 1f79f2a373..f2c703fe0a 100644 --- a/examples/moviegen/train.py +++ b/examples/moviegen/train.py @@ -18,6 +18,7 @@ from mg.dataset import ImageVideoDataset, bucket_split_function from mg.models.tae import TemporalAutoencoder +from mg.parallel import create_parallel_group from mg.pipelines import DiffusionWithLoss from mg.schedulers import RFlowEvalLoss, RFlowLossWrapper from mg.utils import EMA, MODEL_DTYPE, init_model, resume_train_net @@ -74,10 +75,10 @@ def main(args): # 1.1 init model parallel shard_rank_id = rank_id - # if (shards := args.train.model_parallel.model_parallel_shards) > 1: - # create_parallel_group(**args.train.model_parallel) - # device_num = device_num // shards - # shard_rank_id = rank_id // shards + if (shards := args.train.model_parallel.model_parallel_shards) > 1: + create_parallel_group(**args.train.model_parallel) + device_num = device_num // shards + shard_rank_id = rank_id // shards # FIXME: Improve seed setting set_seed(args.env.seed + shard_rank_id) # set different seeds per NPU for sampling different timesteps @@ -280,6 +281,7 @@ def main(args): "--dataloader.batch_size", default=1, type=Union[int, Dict[str, int]], help="Number of samples per batch" ) parser.link_arguments("env.debug", "dataloader.debug", apply_on="parse") + parser.add_function_arguments(create_parallel_group, "train.model_parallel") parser.add_function_arguments(create_scheduler, "train.lr_scheduler", skip={"steps_per_epoch", "num_epochs"}) parser.add_class_arguments( ReduceLROnPlateauByStep, "train.lr_reduce_on_plateau", skip={"optimizer"}, instantiate=False diff --git a/mindone/models/modules/parallel/dense.py b/mindone/models/modules/parallel/dense.py index 4db47ede8a..8d31690fff 100644 --- a/mindone/models/modules/parallel/dense.py +++ b/mindone/models/modules/parallel/dense.py @@ -17,7 +17,7 @@ def __init__( net: Union[nn.Dense, mint.nn.Linear], zero_stage: Literal[0, 1, 2, 3] = 0, op_group: str = GlobalComm.WORLD_COMM_GROUP, - cell_type: Optional[mstype] = 
None, + cell_type: Optional[mstype.Type] = None, ): super().__init__(auto_prefix=False) self.net = net From 5aa1e4dc91d71934905319ba984704d4d4a62f8b Mon Sep 17 00:00:00 2001 From: Rustam Khadipash <16683750+hadipash@users.noreply.github.com> Date: Tue, 24 Dec 2024 12:18:52 +0800 Subject: [PATCH 102/122] add technical report --- examples/moviegen/docs/report.md | 213 +++++++++++++++++++++++++++++++ 1 file changed, 213 insertions(+) create mode 100644 examples/moviegen/docs/report.md diff --git a/examples/moviegen/docs/report.md b/examples/moviegen/docs/report.md new file mode 100644 index 0000000000..d160df8b89 --- /dev/null +++ b/examples/moviegen/docs/report.md @@ -0,0 +1,213 @@ +# Movie Gen + +Movie Gen is a family of foundation models that can natively generate high-fidelity images and videos +while also possessing the abilities to edit and personalize the videos. + +Meta researchers found that scaling the training data, compute, and model parameters of a simple +Transformer-based ([LLaMa3](https://arxiv.org/abs/2407.21783)) model trained with +[Flow Matching](https://arxiv.org/abs/2210.02747) yields high-quality generative models for video or audio. + +Movie Gen supports the following features: + +1. Text-to-Video synthesis +2. Video personalization +3. Video editing + +# Detailed Technical Report + +## Architecture Overview + +[//]: # (TODO: Figure 3 Overall architecture) + +### TAE + +For improved training and inference efficiency, we perform generation in a spatio-temporally compressed latent space. + +### Transformer Backbone + +[//]: # (TODO: Figure 8) + +Movie Gen uses the [LLaMa3](https://arxiv.org/abs/2407.21783) backbone architecture for the joint image-video generation +model, enabling confident scaling of the model size while maintaining efficient training. It can directly generate video +at different aspect ratios (e.g., 1:1, 9:16, 16:9) and multiple lengths (4 – 16 seconds) at 768 px resolution. + +There are three changes to the LLaMa3 Transformer block for the use case of video generation using Flow Matching: + +1. Add a cross-attention module between the self-attention module and the feed forward network (FFN) + to each Transformer block to incorporate text conditioning based on the text prompt embedding **P**. + Multiple different text encoders are leveraged due to their complementary strengths + (see [Text Encoders](#text-encoders)). +2. Add adaptive layer norm blocks to incorporate the time-step $t$ to the Transformer, as used in prior work + ([DiT](https://arxiv.org/abs/2212.09748)). +3. Use full bidirectional attention instead of causal attention used in language modeling. + +#### Differences Among Models + +The Movie Gen family of models contains the following variations: 1B, 5B, and 30B parameters. + +| Model | Layers | Model Dimension | FFN Dimension | Attention Heads | +|:-----:|:------:|:---------------:|:-------------:|:---------------:| +| 1B | 24 | 1536 | 4096 | 16 | +| 5B | 32 | 3072 | 8192 | 24 | +| 30B | 48 | 6144 | 16384 | 48 | + +#### Patchifying Inputs + +To prepare inputs for the Transformer backbone, the video latent code ($T \times C \times H \times W$) is first +'patchified' using a 3D convolutional layer (as in [here](https://arxiv.org/abs/2010.11929)) and then flattened to yield +a 1D sequence. The 3D convolutional layer uses a kernel size of $k_t \times k_h \times k_w$ with a stride equal to the +kernel size and projects it into the same dimensions as needed by the Transformer backbone. 
Thus, the total number of +input tokens to the Transformer backbone is $THW/(k_tk_hk_w)$. We use $k_t=1$ and $k_h=k_w=2$, i.e., we produce +$2 \times 2$ spatial patches. + +#### Learnable Positional Embedding (PE) + +Movie Gen uses a factorized learnable positional embedding to enable arbitrary size, aspect ratio, and video +length (as in [NaViT](https://arxiv.org/abs/2307.06304)) inputs to the Transformer. +The 'patchified' tokens, i.e., output of the 3D convolutional layer, are converted into separate embeddings $\phi_h$, +$\phi_w$ and $\phi_t$ of spatial $h$, $w$, and temporal $t$ coordinates. The final positional embeddings are calculated +by adding all the factorized positional embeddings together. Finally, the final positional embeddings **are added +to the input for all the Transformer layers**. Compared with adding the positional embeddings to the first layer only, +adding to all layers can effectively reduce the distortion and morphing artifacts, especially in the temporal dimension. + +#### Model Parallelism + +Movie Gen employs 3D parallelism to support model-level scaling across three axes: number of parameters, input tokens, +and dataset size, while also allowing horizontal scale-out to more NPUs. It utilizes a combination of [fully sharded +data parallelism](https://arxiv.org/abs/2304.11277), [tensor parallelism](https://arxiv.org/abs/1909.08053), +[sequence parallelism](https://arxiv.org/abs/2105.13120), and context parallelism. + +Different parallelization strategies are depicted in the Transformer block figure. + +[//]: # (TODO: add reference to the figure.) + +- **Tensor-parallelism (TP)** shards the weights of linear layers either along columns or rows, and results in each NPU + involved in the sharding performing _tp-size_ less work (FLOPs) and generating _tp-size_ fewer activations for + column-parallel shards and consuming _tp-size_ fewer activations for row-parallel shards. The cost of performing such + a sharding is the addition of all-reduce communication overheads in both the forward (row-parallel) and backward + (column-parallel) passes. +- **Sequence-parallelism (SP)** builds upon TP to also allow the sharding of the input over the sequence dimension for + layers which are replicated and in which each sequence element can be treated independently. Such layers, e.g., + LayerNorm, would otherwise perform duplicate compute and generate identical (and thus replicated) activations across + the TP-group. +- **Context-parallelism (CP)** enables a partial sharding over the sequence dimension for the _sequence-dependent + softmax-attention operation_. CP leverages the insight that for any given (_source_ (_context_), _target_ (_query_)) + sequences pair, _softmax-attention is only sequence-dependent over the context and not the query_. Therefore, in the + case of self-attention where the input source and target sequences are identical, CP allows the attention computation + to be performed with only an all-gather for the $K$ and $V$ projections (instead of $Q$, $K$, and $V$) in the forward + pass, and a reduce-scatter for their associated gradients in the backward. +- **Fully sharded data parallel (FSDP)** shards the model, optimizer, and gradients across all data-parallel NPUs, + synchronously gathering and scattering parameters and gradients throughout each training step. 
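+
+To make the tensor-parallel sharding above concrete, the toy sketch below simulates a column-parallel linear followed
+by a row-parallel linear on a single process with plain NumPy (two simulated shards; the final `sum` stands in for the
+all-reduce). It is illustrative only and is not the MindSpore implementation used in this repository; all names in it
+are made up for the example.
+
+```python
+import numpy as np
+
+# Toy model-parallel MLP: y = relu(x @ w1) @ w2, sharded over two "ranks".
+rng = np.random.default_rng(0)
+x = rng.standard_normal((4, 8))       # (batch, hidden)
+w1 = rng.standard_normal((8, 16))     # column-parallel: split along the output dim
+w2 = rng.standard_normal((16, 8))     # row-parallel: split along the input dim
+
+tp = 2
+w1_shards = np.split(w1, tp, axis=1)  # each rank holds (8, 16 // tp)
+w2_shards = np.split(w2, tp, axis=0)  # each rank holds (16 // tp, 8)
+
+# Each rank computes a partial result; the row-parallel matmul produces partial
+# sums that an all-reduce would combine across ranks.
+partials = [np.maximum(x @ w1_s, 0) @ w2_s for w1_s, w2_s in zip(w1_shards, w2_shards)]
+y_parallel = sum(partials)            # stands in for the all-reduce
+
+y_reference = np.maximum(x @ w1, 0) @ w2
+np.testing.assert_allclose(y_parallel, y_reference, rtol=1e-6, atol=1e-6)
+print("column-parallel + row-parallel matches the unsharded MLP")
+```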
+
+### Text Encoders
+
+Movie Gen uses a combination of [UL2](https://arxiv.org/abs/2205.05131), [ByT5](https://arxiv.org/abs/2105.13626), and
+Long-prompt [MetaCLIP](https://arxiv.org/abs/2309.16671) as text encoders to provide both semantic-level and
+character-level text understanding for the backbone:
+
+- **UL2** is trained on massive text-only data and potentially provides strong text reasoning abilities in its
+  features.
+- **Long-prompt MetaCLIP** provides text representations that are aligned with visual representations, which is
+  beneficial for cross-modal generation.
+- **ByT5** encoder is only used to encode visual text, i.e., the part of the text prompt that explicitly asks for a
+  character string to be generated in the output image / video.
+
+The text embeddings from the three text encoders are concatenated after adding separate linear projection and LayerNorm
+layers that project them into the same 6144-dimensional space and normalize the embeddings.
+
+## Training Details
+
+Movie Gen is trained jointly on images and videos. Images are treated as single-frame videos, enabling the use of the
+same model to generate both images and videos. Compared to video data, paired image-text datasets are easier to scale
+with diverse concepts and styles, and thus joint modeling of images and videos leads to better generalization.
+
+Training is performed in multiple stages for better efficiency:
+
+1. Pre-training on low-resolution 256 px images only.
+   Meta researchers observed that directly training T2I/V models from scratch results in slower convergence than
+   initializing them from a T2I model.
+2. Joint training on low-resolution 256 px images and videos.
+   To enable the joint training, we double the spatial [PE](#learnable-positional-embedding-pe) layers to accommodate
+   various aspect ratios, add new temporal PE layers to support up to 32 latent frames, and initialize spatial PE layers
+   from the T2I model with 2x expansion.
+3. Joint training at 768 px resolution.
+   For this stage, we expand the spatial PE layers by 3x.
+4. Fine-tuning on high-quality videos.
+   The motion and aesthetic quality of the generated videos are improved by fine-tuning the pre-trained model on a small
+   set of manually selected videos. During this stage, multiple models are trained and combined to form the final model
+   through a model averaging approach ([LLaMa3](https://arxiv.org/abs/2407.21783)).
+
+### Training Objective
+
+Movie Gen is trained with the [Flow Matching](https://arxiv.org/abs/2210.02747) framework,
+i.e., it is trained to predict the velocity $V_t = \frac{dX_t}{dt}$, which teaches it to 'move' the sample $X_t$
+in the direction of the video sample $X_1$.
+Movie Gen uses a simple linear interpolation, i.e., the optimal transport path (Lipman et al., 2023):
+$$X_t=tX_1+(1-(1-\sigma_{min})t)X_0$$
+where $\sigma_{min}=10^{-5}$. Thus, the ground-truth velocity can be derived as:
+$$V_t = \frac{dX_t}{dt} = X_1 - (1-\sigma_{min})X_0$$
+
+The model parameters are denoted by $\theta$, the embedding of text prompts by **P**, and the predicted velocity
+by $u(X_t, P, t)$.
+The model is trained by minimizing the mean squared error between the ground-truth velocity and the model prediction:
+$$E_{t,X_0,X_1,P}\|u(X_t, P, t;\theta)-V_t\|^2$$
+As in prior work ([SD3](https://arxiv.org/abs/2403.03206)), $t$ is sampled from a logit-normal distribution whose
+underlying Gaussian distribution has zero mean and unit standard deviation.
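+
+As a sanity check of the objective above, the following minimal NumPy sketch builds $X_t$, the target velocity $V_t$,
+and the mean squared error for one batch. It is illustrative only: the shapes, the zero "predictor" standing in for
+$u(X_t, P, t;\theta)$, and the logit-normal sampling helper are assumptions for the example, not the training code of
+this repository.
+
+```python
+import numpy as np
+
+rng = np.random.default_rng(42)
+sigma_min = 1e-5
+
+def sample_timesteps(batch: int) -> np.ndarray:
+    # Logit-normal t in (0, 1): sigmoid of a standard normal sample.
+    return 1.0 / (1.0 + np.exp(-rng.standard_normal(batch)))
+
+x1 = rng.standard_normal((2, 16, 8, 24, 44))  # latent video sample X_1
+x0 = rng.standard_normal(x1.shape)            # Gaussian noise X_0
+t = sample_timesteps(x1.shape[0]).reshape(-1, 1, 1, 1, 1)
+
+x_t = t * x1 + (1.0 - (1.0 - sigma_min) * t) * x0  # interpolated sample X_t
+v_t = x1 - (1.0 - sigma_min) * x0                  # ground-truth velocity V_t
+
+u_pred = np.zeros_like(v_t)          # stand-in for the Transformer's velocity prediction
+loss = np.mean((u_pred - v_t) ** 2)  # Flow Matching MSE objective
+print(f"flow-matching loss: {loss:.4f}")
+```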
+
+### Signal-to-Noise Ratio (SNR)
+
+Choosing the right diffusion noise scheduler with a zero terminal signal-to-noise ratio is
+[particularly important](https://arxiv.org/abs/2305.08891) for video generation.
+The Flow Matching formulation naturally ensures zero terminal SNR (i.e., at $t=0$).
+This guarantees that, during training, the model receives pure Gaussian noise samples and is trained to predict the
+velocity for them.
+Thus, at inference, when the model receives pure Gaussian noise at $t = 0$, it can make a reasonable prediction.
+
+### Bucketization for Variable Duration and Size
+
+To accommodate diverse video lengths and aspect ratios, we bucketize the training data according to aspect ratio and
+length. The videos in each bucket lead to exactly the same latent shape, which allows for easy batching of training data.
+
+### Controlling FPS
+
+The model is trained by prepending the sampling FPS value of each training video to the input text prompt
+(e.g., “FPS-16”).
+
+### Validation During Training
+
+Meta researchers observed that the validation loss is well correlated with human evaluation results: later
+checkpoints with lower validation loss perform better in human evaluations. This suggests that the Flow Matching
+validation loss can serve as a useful proxy for evaluations during model development. A similar observation was made by
+the authors of [SD3](https://arxiv.org/abs/2403.03206). For this reason, we maintain a validation set of unseen videos
+and monitor the validation loss throughout training.
+
+### Learning Rate Reduction
+
+We halve the learning rate whenever the validation loss plateaus to continue improving the model performance.
+
+### Personalization
+
+Personalization enables the video generation model to condition on a text prompt as well as an image of a person to
+generate a video featuring the chosen person.
+The generated personalized video maintains the identity of the person while following the text prompt.
+
+### Editing
+
+Editing allows users to effortlessly perform precise and imaginative edits on both real and generated videos using a
+textual instruction. Since large-scale supervised video editing data is harder to obtain, the researchers show a novel
+approach to training such a video editing model without supervised video editing data.
+
+## Inference
+
+Movie Gen uses a simple first-order Euler ODE solver with a unique t-schedule tailored to the model. Specifically, the
+quality of an N-step video generation process can be closely approximated with merely 50 steps by implementing a
+**linear-quadratic t-schedule**.
+
+[//]: # (TODO: Add figure 10 visualization)
+
+This approach follows the first 25 steps of an $N$-step linear schedule and then approximates the remaining $N-25$ steps
+with 25 quadratically placed steps. The linear-quadratic strategy is predicated on the observation that the first
+inference steps are pivotal in setting up the scene and motion of the video, since most changes occur in the first
+solver steps (the left figure above).
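+
+For illustration, one possible construction of such a 50-step schedule is sketched below. The exact quadratic placement
+used by Movie Gen is not spelled out here, so the formula (and the default $N=1000$) should be read as an assumption
+rather than the reference implementation.
+
+```python
+import numpy as np
+
+def linear_quadratic_schedule(num_steps: int = 50, linear_steps: int = 25, n: int = 1000) -> np.ndarray:
+    """Approximate an n-step linear t-schedule with `num_steps` solver steps."""
+    # The first `linear_steps` steps follow the n-step linear schedule exactly.
+    linear_part = np.arange(linear_steps) / n
+    # The remaining steps are placed quadratically so that the schedule still ends at t = 1.
+    frac = np.arange(1, num_steps - linear_steps + 1) / (num_steps - linear_steps)
+    quadratic_part = linear_steps / n + frac**2 * (1.0 - linear_steps / n)
+    return np.concatenate([linear_part, quadratic_part])
+
+ts = linear_quadratic_schedule()
+print(ts[:5], ts[-3:])  # dense near t = 0, coarse near t = 1
+```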
+ +[//]: # (TODO: replace (the left figure above\) with the fig number) From de047db3689c01333f83a32c82e1936cb7cc56ce Mon Sep 17 00:00:00 2001 From: Rustam Khadipash <16683750+hadipash@users.noreply.github.com> Date: Tue, 24 Dec 2024 12:52:27 +0800 Subject: [PATCH 103/122] update technical report --- examples/moviegen/README.md | 12 +- examples/moviegen/docs/report.md | 224 +++++++++++++++++++++++++++---- 2 files changed, 204 insertions(+), 32 deletions(-) diff --git a/examples/moviegen/README.md b/examples/moviegen/README.md index 5a95d96caf..50e6c19665 100644 --- a/examples/moviegen/README.md +++ b/examples/moviegen/README.md @@ -29,12 +29,12 @@ Transformer-based ([LLaMa3](https://arxiv.org/abs/2407.21783)) model trained wit ## Demo -| 32x256x455 | 32x256x455 | -|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|:------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:| -|