diff --git a/model/mdm.py b/model/mdm.py index 14fd5bda..6548dda6 100644 --- a/model/mdm.py +++ b/model/mdm.py @@ -11,7 +11,7 @@ class MDM(nn.Module): def __init__(self, modeltype, njoints, nfeats, num_actions, translation, pose_rep, glob, glob_rot, latent_dim=256, ff_size=1024, num_layers=8, num_heads=4, dropout=0.1, ablation=None, activation="gelu", legacy=False, data_rep='rot6d', dataset='amass', clip_dim=512, - arch='trans_enc', emb_trans_dec=False, clip_version=None, **kargs): + arch='trans_enc', emb_trans_dec=False, clip_version=None, smpl_model_path=None, joint_regressor_train_extra_path=None, **kargs): super().__init__() self.legacy = legacy @@ -93,7 +93,11 @@ def __init__(self, modeltype, njoints, nfeats, num_actions, translation, pose_re self.output_process = OutputProcess(self.data_rep, self.input_feats, self.latent_dim, self.njoints, self.nfeats) - self.rot2xyz = Rotation2xyz(device='cpu', dataset=self.dataset) + self.rot2xyz = Rotation2xyz( + device='cpu', dataset=self.dataset, + smpl_model_path=smpl_model_path, + joint_regressor_train_extra_path=joint_regressor_train_extra_path + ) def parameters_wo_clip(self): return [p for name, p in self.named_parameters() if not name.startswith('clip_model.')] diff --git a/model/rotation2xyz.py b/model/rotation2xyz.py index 9746c7d7..3f738834 100644 --- a/model/rotation2xyz.py +++ b/model/rotation2xyz.py @@ -9,10 +9,10 @@ class Rotation2xyz: - def __init__(self, device, dataset='amass'): + def __init__(self, device, dataset='amass', smpl_model_path=None, joint_regressor_train_extra_path=None): self.device = device self.dataset = dataset - self.smpl_model = SMPL().eval().to(device) + self.smpl_model = SMPL(model_path=smpl_model_path, joint_regressor_train_extra_path=joint_regressor_train_extra_path).eval().to(device) def __call__(self, x, mask, pose_rep, translation, glob, jointstype, vertstrans, betas=None, beta=0, diff --git a/model/smpl.py b/model/smpl.py index 587f5419..e1f2e134 100644 --- a/model/smpl.py +++ b/model/smpl.py @@ -64,14 +64,14 @@ class SMPL(_SMPLLayer): """ Extension of the official SMPL implementation to support more joints """ - def __init__(self, model_path=SMPL_MODEL_PATH, **kwargs): - kwargs["model_path"] = model_path + def __init__(self, model_path=None, joint_regressor_train_extra_path=None, **kwargs): + kwargs["model_path"] = model_path or SMPL_MODEL_PATH # remove the verbosity for the 10-shapes beta parameters with contextlib.redirect_stdout(None): super(SMPL, self).__init__(**kwargs) - J_regressor_extra = np.load(JOINT_REGRESSOR_TRAIN_EXTRA) + J_regressor_extra = np.load(joint_regressor_train_extra_path or JOINT_REGRESSOR_TRAIN_EXTRA) self.register_buffer('J_regressor_extra', torch.tensor(J_regressor_extra, dtype=torch.float32)) vibe_indexes = np.array([JOINT_MAP[i] for i in JOINT_NAMES]) a2m_indexes = vibe_indexes[action2motion_joints]