diff --git a/README.md b/README.md
index 6e3be95e..14b519e7 100644
--- a/README.md
+++ b/README.md
@@ -4,6 +4,8 @@
 [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/human-motion-diffusion-model/motion-synthesis-on-humanml3d)](https://paperswithcode.com/sota/motion-synthesis-on-humanml3d?p=human-motion-diffusion-model)
 [![arXiv](https://img.shields.io/badge/arXiv-<2209.14916>-.svg)](https://arxiv.org/abs/2209.14916)
+
+
 
 The official PyTorch implementation of the paper [**"Human Motion Diffusion Model"**](https://arxiv.org/abs/2209.14916).
 
 Please visit our [**webpage**](https://guytevet.github.io/mdm-page/) for more details.
diff --git a/cog.yaml b/cog.yaml
new file mode 100644
index 00000000..3d5da633
--- /dev/null
+++ b/cog.yaml
@@ -0,0 +1,38 @@
+build:
+  gpu: true
+  cuda: "11.3"
+  python_version: 3.8
+  system_packages:
+    - libgl1-mesa-glx
+    - libglib2.0-0
+
+  python_packages:
+    - imageio==2.22.2
+    - matplotlib==3.1.3
+    - spacy==3.3.1
+    - smplx==0.1.28
+    - chumpy==0.70
+    - blis==0.7.8
+    - click==8.1.3
+    - confection==0.0.2
+    - ftfy==6.1.1
+    - importlib-metadata==5.0.0
+    - lxml==4.9.1
+    - murmurhash==1.0.8
+    - preshed==3.0.7
+    - pycryptodomex==3.15.0
+    - regex==2022.9.13
+    - srsly==2.4.4
+    - thinc==8.0.17
+    - typing-extensions==4.1.1
+    - urllib3==1.26.12
+    - wasabi==0.10.1
+    - wcwidth==0.2.5
+
+  run:
+    - apt update -y && apt-get install ffmpeg -y
+#    - python -m spacy download en_core_web_sm
+    - git clone https://github.com/openai/CLIP.git sub_modules/CLIP
+    - pip install -e sub_modules/CLIP
+
+predict: "sample/predict.py:Predictor"
diff --git a/data_loaders/humanml/data/dataset.py b/data_loaders/humanml/data/dataset.py
index 4f0fbbaf..df8ee956 100644
--- a/data_loaders/humanml/data/dataset.py
+++ b/data_loaders/humanml/data/dataset.py
@@ -719,25 +719,23 @@ def __getitem__(self, item):
 
 
 # A wrapper class for t2m original dataset for MDM purposes
 class HumanML3D(data.Dataset):
-    def __init__(self, mode, datapath='./dataset/humanml_opt.txt', split="train", **kwargs):
+    def __init__(self, mode, datapath='./dataset/humanml_opt.txt', base_path='.', split="train", **kwargs):
         self.mode = mode
         
         self.dataset_name = 't2m'
         self.dataname = 't2m'
 
         # Configurations of T2M dataset and KIT dataset is almost the same
-        abs_base_path = f'.'
-        dataset_opt_path = pjoin(abs_base_path, datapath)
+        dataset_opt_path = pjoin(base_path, datapath)
         device = None  # torch.device('cuda:4') # This param is not in use in this context
         opt = get_opt(dataset_opt_path, device)
-        opt.meta_dir = pjoin(abs_base_path, opt.meta_dir)
-        opt.motion_dir = pjoin(abs_base_path, opt.motion_dir)
-        opt.text_dir = pjoin(abs_base_path, opt.text_dir)
-        opt.model_dir = pjoin(abs_base_path, opt.model_dir)
-        opt.checkpoints_dir = pjoin(abs_base_path, opt.checkpoints_dir)
-        opt.data_root = pjoin(abs_base_path, opt.data_root)
-        opt.save_root = pjoin(abs_base_path, opt.save_root)
-        opt.meta_dir = './dataset'
+        opt.meta_dir = os.path.dirname(dataset_opt_path)
+        opt.motion_dir = pjoin(base_path, opt.motion_dir)
+        opt.text_dir = pjoin(base_path, opt.text_dir)
+        opt.model_dir = pjoin(base_path, opt.model_dir)
+        opt.checkpoints_dir = pjoin(base_path, opt.checkpoints_dir)
+        opt.data_root = pjoin(base_path, opt.data_root)
+        opt.save_root = pjoin(base_path, opt.save_root)
         self.opt = opt
         print('Loading dataset %s ...' % opt.dataset_name)
@@ -760,7 +758,7 @@ def __init__(self, mode, datapath='./dataset/humanml_opt.txt', split="train", **
         if mode == 'text_only':
             self.t2m_dataset = TextOnlyDataset(self.opt, self.mean, self.std, self.split_file)
         else:
-            self.w_vectorizer = WordVectorizer(pjoin(abs_base_path, 'glove'), 'our_vab')
+            self.w_vectorizer = WordVectorizer(pjoin(base_path, 'glove'), 'our_vab')
             self.t2m_dataset = Text2MotionDatasetV2(self.opt, self.mean, self.std, self.split_file, self.w_vectorizer)
             self.num_actions = 1  # dummy placeholder
 
diff --git a/diffusion/respace.py b/diffusion/respace.py
index 13a3c066..949cde22 100644
--- a/diffusion/respace.py
+++ b/diffusion/respace.py
@@ -29,13 +29,12 @@ def space_timesteps(num_timesteps, section_counts):
     """
     if isinstance(section_counts, str):
         if section_counts.startswith("ddim"):
-            desired_count = int(section_counts[len("ddim") :])
-            for i in range(1, num_timesteps):
-                if len(range(0, num_timesteps, i)) == desired_count:
-                    return set(range(0, num_timesteps, i))
-            raise ValueError(
-                f"cannot create exactly {num_timesteps} steps with an integer stride"
-            )
+            desired_count = int(section_counts[len("ddim"):])
+            if desired_count > num_timesteps:
+                raise ValueError(
+                    f"cannot create {desired_count} sampling steps from only {num_timesteps} timesteps"
+                )
+            return set(np.rint(np.arange(0, num_timesteps, num_timesteps / desired_count)).astype(int))
         section_counts = [int(x) for x in section_counts.split(",")]
     size_per = num_timesteps // len(section_counts)
     extra = num_timesteps % len(section_counts)
diff --git a/model/mdm.py b/model/mdm.py
index 14fd5bda..18e27655 100644
--- a/model/mdm.py
+++ b/model/mdm.py
@@ -11,7 +11,7 @@ class MDM(nn.Module):
     def __init__(self, modeltype, njoints, nfeats, num_actions, translation, pose_rep, glob, glob_rot,
                  latent_dim=256, ff_size=1024, num_layers=8, num_heads=4, dropout=0.1,
                  ablation=None, activation="gelu", legacy=False, data_rep='rot6d', dataset='amass', clip_dim=512,
-                 arch='trans_enc', emb_trans_dec=False, clip_version=None, **kargs):
+                 arch='trans_enc', emb_trans_dec=False, clip_version=None, clip_download_root=None, smpl_model_path=None, joint_regressor_train_extra_path=None, **kargs):
         super().__init__()
 
         self.legacy = legacy
@@ -38,6 +38,7 @@ def __init__(self, modeltype, njoints, nfeats, num_actions, translation, pose_re
         self.activation = activation
         self.clip_dim = clip_dim
         self.action_emb = kargs.get('action_emb', None)
+        self.device = kargs.get('device', None if torch.cuda.is_available() else 'cpu')
 
         self.input_feats = self.njoints * self.nfeats
 
@@ -85,7 +86,7 @@ def __init__(self, modeltype, njoints, nfeats, num_actions, translation, pose_re
                 print('EMBED TEXT')
                 print('Loading CLIP...')
                 self.clip_version = clip_version
-                self.clip_model = self.load_and_freeze_clip(clip_version)
+                self.clip_model = self.load_and_freeze_clip(clip_version, clip_download_root)
             if 'action' in self.cond_mode:
                 self.embed_action = EmbedAction(self.num_actions, self.latent_dim)
                 print('EMBED ACTION')
@@ -93,16 +94,21 @@ def __init__(self, modeltype, njoints, nfeats, num_actions, translation, pose_re
 
         self.output_process = OutputProcess(self.data_rep, self.input_feats, self.latent_dim, self.njoints,
                                             self.nfeats)
 
-        self.rot2xyz = Rotation2xyz(device='cpu', dataset=self.dataset)
+        self.rot2xyz = Rotation2xyz(
+            device='cpu', dataset=self.dataset,
+            smpl_model_path=smpl_model_path,
+            joint_regressor_train_extra_path=joint_regressor_train_extra_path
+        )
 
     def parameters_wo_clip(self):
         return [p for name, p in self.named_parameters() if not name.startswith('clip_model.')]
-    def load_and_freeze_clip(self, clip_version):
-        clip_model, clip_preprocess = clip.load(clip_version, device='cpu',
+    def load_and_freeze_clip(self, clip_version, clip_download_root):
+        clip_model, clip_preprocess = clip.load(clip_version, device='cpu', download_root=clip_download_root,
                                                 jit=False)  # Must set jit=False for training
-        clip.model.convert_weights(
-            clip_model)  # Actually this line is unnecessary since clip by default already on float16
+        if str(self.device) != 'cpu':
+            clip.model.convert_weights(
+                clip_model)  # Actually this line is unnecessary since clip by default already on float16
 
         # Freeze CLIP weights
         clip_model.eval()
diff --git a/model/rotation2xyz.py b/model/rotation2xyz.py
index 9746c7d7..3f738834 100644
--- a/model/rotation2xyz.py
+++ b/model/rotation2xyz.py
@@ -9,10 +9,10 @@
 
 
 class Rotation2xyz:
-    def __init__(self, device, dataset='amass'):
+    def __init__(self, device, dataset='amass', smpl_model_path=None, joint_regressor_train_extra_path=None):
         self.device = device
         self.dataset = dataset
-        self.smpl_model = SMPL().eval().to(device)
+        self.smpl_model = SMPL(model_path=smpl_model_path, joint_regressor_train_extra_path=joint_regressor_train_extra_path).eval().to(device)
 
     def __call__(self, x, mask, pose_rep, translation, glob,
                  jointstype, vertstrans, betas=None, beta=0,
diff --git a/model/smpl.py b/model/smpl.py
index 587f5419..e1f2e134 100644
--- a/model/smpl.py
+++ b/model/smpl.py
@@ -64,14 +64,14 @@ class SMPL(_SMPLLayer):
     """ Extension of the official SMPL implementation to support more joints """
 
-    def __init__(self, model_path=SMPL_MODEL_PATH, **kwargs):
-        kwargs["model_path"] = model_path
+    def __init__(self, model_path=None, joint_regressor_train_extra_path=None, **kwargs):
+        kwargs["model_path"] = model_path or SMPL_MODEL_PATH
 
         # remove the verbosity for the 10-shapes beta parameters
         with contextlib.redirect_stdout(None):
             super(SMPL, self).__init__(**kwargs)
 
-        J_regressor_extra = np.load(JOINT_REGRESSOR_TRAIN_EXTRA)
+        J_regressor_extra = np.load(joint_regressor_train_extra_path or JOINT_REGRESSOR_TRAIN_EXTRA)
         self.register_buffer('J_regressor_extra', torch.tensor(J_regressor_extra, dtype=torch.float32))
         vibe_indexes = np.array([JOINT_MAP[i] for i in JOINT_NAMES])
         a2m_indexes = vibe_indexes[action2motion_joints]
diff --git a/sample/predict.py b/sample/predict.py
new file mode 100644
index 00000000..d99688fe
--- /dev/null
+++ b/sample/predict.py
@@ -0,0 +1,151 @@
+import os
+import subprocess
+import typing
+from argparse import Namespace
+
+import torch
+from cog import BasePredictor, Input, Path
+
+import data_loaders.humanml.utils.paramUtil as paramUtil
+from data_loaders.get_data import get_dataset_loader
+from data_loaders.humanml.scripts.motion_process import recover_from_ric
+from data_loaders.humanml.utils.plot_script import plot_3d_motion
+from data_loaders.tensors import collate
+from model.cfg_sampler import ClassifierFreeSampleModel
+from utils import dist_util
+from utils.model_util import create_model_and_diffusion, load_model_wo_clip
+from sample.generate import construct_template_variables
+
+"""
+In case of matplotlib issues, it may be necessary to delete lines 89-92 of data_loaders/humanml/utils/plot_script.py, as
+suggested in https://github.com/GuyTevet/motion-diffusion-model/issues/6
+"""
+
+
+def get_args():
+    args = Namespace()
+    args.fps = 20
+    args.model_path = './save/humanml_trans_enc_512/model000200000.pt'
+    args.guidance_param = 2.5
+    args.unconstrained = False
+    args.dataset = 'humanml'
+
+    args.cond_mask_prob = 1
+
+    args.emb_trans_dec = False
+    args.latent_dim = 512
+    args.layers = 8
+    args.arch = 'trans_enc'
+
+    args.noise_schedule = 'cosine'
+    args.sigma_small = True
+    args.lambda_vel = 0.0
+    args.lambda_rcxyz = 0.0
+    args.lambda_fc = 0.0
+    return args
+
+
+class Predictor(BasePredictor):
+    def setup(self):
+        subprocess.run(["mkdir", "/root/.cache/clip"])
+        subprocess.run(["cp", "-r", "ViT-B-32.pt", "/root/.cache/clip"])
+
+        self.args = get_args()
+        self.num_frames = self.args.fps * 6
+        print('Loading dataset...')
+
+        # temporary data
+        self.data = get_dataset_loader(name=self.args.dataset,
+                                       batch_size=1,
+                                       num_frames=196,
+                                       split='test',
+                                       hml_mode='text_only')
+
+        self.data.fixed_length = float(self.num_frames)
+
+        print("Creating model and diffusion...")
+        self.model, self.diffusion = create_model_and_diffusion(self.args, self.data)
+
+        print("Loading checkpoints from...")
+        state_dict = torch.load(self.args.model_path, map_location='cpu')
+        load_model_wo_clip(self.model, state_dict)
+
+        if self.args.guidance_param != 1:
+            self.model = ClassifierFreeSampleModel(self.model)  # wrapping model with the classifier-free sampler
+        self.model.to(dist_util.dev())
+        self.model.eval()  # disable random masking
+
+    def predict(
+            self,
+            prompt: str = Input(default="the person walked forward and is picking up his toolbox."),
+            num_repetitions: int = Input(default=3, description="How many repetitions of the motion to generate"),
+
+    ) -> typing.List[Path]:
+        args = self.args
+        args.num_repetitions = int(num_repetitions)
+
+        self.data = get_dataset_loader(name=self.args.dataset,
+                                       batch_size=args.num_repetitions,
+                                       num_frames=self.num_frames,
+                                       split='test',
+                                       hml_mode='text_only')
+
+        collate_args = [{'inp': torch.zeros(self.num_frames), 'tokens': None, 'lengths': self.num_frames, 'text': str(prompt)}]
+        _, model_kwargs = collate(collate_args)
+
+        # add CFG scale to batch
+        if args.guidance_param != 1:
+            model_kwargs['y']['scale'] = torch.ones(args.num_repetitions, device=dist_util.dev()) * args.guidance_param
+
+        sample_fn = self.diffusion.p_sample_loop
+        sample = sample_fn(
+            self.model,
+            (args.num_repetitions, self.model.njoints, self.model.nfeats, self.num_frames),
+            clip_denoised=False,
+            model_kwargs=model_kwargs,
+            skip_timesteps=0,  # 0 is the default value - i.e. don't skip any step
+            init_image=None,
+            progress=True,
+            dump_steps=None,
+            noise=None,
+            const_noise=False,
+        )
+
+        # Recover XYZ *positions* from HumanML3D vector representation
+        if self.model.data_rep == 'hml_vec':
+            n_joints = 22 if sample.shape[1] == 263 else 21
+            sample = self.data.dataset.t2m_dataset.inv_transform(sample.cpu().permute(0, 2, 3, 1)).float()
+            sample = recover_from_ric(sample, n_joints)
+            sample = sample.view(-1, *sample.shape[2:]).permute(0, 2, 3, 1)
+
+        rot2xyz_pose_rep = 'xyz' if self.model.data_rep in ['xyz', 'hml_vec'] else self.model.data_rep
+        rot2xyz_mask = None if rot2xyz_pose_rep == 'xyz' else model_kwargs['y']['mask'].reshape(args.num_repetitions,
+                                                                                                self.num_frames).bool()
+        sample = self.model.rot2xyz(x=sample, mask=rot2xyz_mask, pose_rep=rot2xyz_pose_rep, glob=True, translation=True,
+                                    jointstype='smpl', vertstrans=True, betas=None, beta=0, glob_rot=None,
+                                    get_rotations_back=False)
+
+        all_motions = sample.cpu().numpy()
+
+        caption = str(prompt)
+
+        skeleton = paramUtil.t2m_kinematic_chain
+
+        sample_print_template, row_print_template, all_print_template, \
+            sample_file_template, row_file_template, all_file_template = construct_template_variables(
+                args.unconstrained)
+
+        rep_files = []
+        replicate_fnames = []
+        for rep_i in range(args.num_repetitions):
+            motion = all_motions[rep_i].transpose(2, 0, 1)[:self.num_frames]
+            save_file = sample_file_template.format(1, rep_i)
+            print(sample_print_template.format(caption, 1, rep_i, save_file))
+            plot_3d_motion(save_file, skeleton, motion, dataset=args.dataset, title=caption, fps=args.fps)
+            # Credit for visualization: https://github.com/EricGuo5513/text-to-motion
+            rep_files.append(save_file)
+
+            replicate_fnames.append(Path(save_file))
+
+        return replicate_fnames
diff --git a/utils/model_util.py b/utils/model_util.py
index fd697b07..57f673b5 100644
--- a/utils/model_util.py
+++ b/utils/model_util.py
@@ -56,9 +56,9 @@ def get_model_args(args, data):
 
 def create_gaussian_diffusion(args):
     # default params
     predict_xstart = True  # we always predict x_start (a.k.a. x0), that's our deal!
-    steps = 1000
+    steps = args.diffusion_steps
     scale_beta = 1.  # no scaling
-    timestep_respacing = ''  # can be used for ddim sampling, we don't use it.
+    timestep_respacing = f'ddim{args.diffusion_sampling_steps}'  # can be used for ddim sampling
     learn_sigma = False
     rescale_timesteps = False
diff --git a/utils/parser_util.py b/utils/parser_util.py
index 3e5c27c3..1aff0d65 100644
--- a/utils/parser_util.py
+++ b/utils/parser_util.py
@@ -69,6 +69,8 @@ def add_diffusion_options(parser):
                        help="Noise schedule type")
     group.add_argument("--diffusion_steps", default=1000, type=int,
                        help="Number of diffusion steps (denoted T in the paper)")
+    group.add_argument("--diffusion_sampling_steps", default=1000, type=int,
+                       help="Number of sampling timesteps (DDIM-style respacing for faster inference; see Song et al., 2020)")
     group.add_argument("--sigma_small", default=True, type=bool,
                        help="Use smaller sigma values.")
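For reference, here is a minimal standalone sketch (not part of the patch; the helper name ddim_timesteps is illustrative) of the timestep selection that the diffusion/respace.py change above performs for a "ddimN" schedule: instead of searching for an integer stride, the kept timesteps are computed directly with numpy.

import numpy as np

def ddim_timesteps(num_timesteps, desired_count):
    # Timesteps kept when respacing with f'ddim{desired_count}', mirroring the patched space_timesteps().
    if desired_count > num_timesteps:
        raise ValueError(f"cannot create {desired_count} sampling steps from only {num_timesteps} timesteps")
    return set(np.rint(np.arange(0, num_timesteps, num_timesteps / desired_count)).astype(int))

print(sorted(ddim_timesteps(1000, 5)))   # [0, 200, 400, 600, 800]
print(len(ddim_timesteps(1000, 100)))    # 100 evenly spaced timesteps out of 1000

With --diffusion_steps 1000 and --diffusion_sampling_steps 100, the denoising loop therefore runs on every 10th timestep, i.e. 100 model evaluations instead of 1000.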