Skip to content

Commit

Permalink
fix inflate
Browse files Browse the repository at this point in the history
  • Loading branch information
SamitHuang committed Dec 11, 2024
1 parent 21e06bf commit 77af428
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions examples/movie_gen/tools/inflate_vae_to_tae.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from safetensors import safe_open

import mindspore as ms
from mg.models.tae.sd3_vae import SD3d5_CONFIG,SD3d5_VAE


def get_shape_from_str(shape):
Expand Down Expand Up @@ -32,9 +33,7 @@ def load_torch_ckpt(ckpt_path):


def plot_ms_vae2d5():
from mg.models.tae.tae import SD3d5_CONFIG, TemporalAutoencoder

tae = TemporalAutoencoder(config=SD3d5_CONFIG)
tae = SD3d5_VAE(config=SD3d5_CONFIG)

sd = tae.parameters_dict()
pnames = list(sd.keys())
Expand Down

0 comments on commit 77af428

Please sign in to comment.