Skip to content

Commit

Permalink
Allow passing HumanML3D base_path
Browse files Browse the repository at this point in the history
  • Loading branch information
UuuNyaa committed Nov 17, 2022
1 parent e70fa57 commit 7c5a825
Showing 1 changed file with 10 additions and 12 deletions.
22 changes: 10 additions & 12 deletions data_loaders/humanml/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -719,25 +719,23 @@ def __getitem__(self, item):

# A wrapper class for t2m original dataset for MDM purposes
class HumanML3D(data.Dataset):
def __init__(self, mode, datapath='./dataset/humanml_opt.txt', split="train", **kwargs):
def __init__(self, mode, datapath='./dataset/humanml_opt.txt', base_path='.', split="train", **kwargs):
self.mode = mode

self.dataset_name = 't2m'
self.dataname = 't2m'

# Configurations of T2M dataset and KIT dataset is almost the same
abs_base_path = f'.'
dataset_opt_path = pjoin(abs_base_path, datapath)
dataset_opt_path = pjoin(base_path, datapath)
device = None # torch.device('cuda:4') # This param is not in use in this context
opt = get_opt(dataset_opt_path, device)
opt.meta_dir = pjoin(abs_base_path, opt.meta_dir)
opt.motion_dir = pjoin(abs_base_path, opt.motion_dir)
opt.text_dir = pjoin(abs_base_path, opt.text_dir)
opt.model_dir = pjoin(abs_base_path, opt.model_dir)
opt.checkpoints_dir = pjoin(abs_base_path, opt.checkpoints_dir)
opt.data_root = pjoin(abs_base_path, opt.data_root)
opt.save_root = pjoin(abs_base_path, opt.save_root)
opt.meta_dir = './dataset'
opt.meta_dir = os.path.dirname(dataset_opt_path)
opt.motion_dir = pjoin(base_path, opt.motion_dir)
opt.text_dir = pjoin(base_path, opt.text_dir)
opt.model_dir = pjoin(base_path, opt.model_dir)
opt.checkpoints_dir = pjoin(base_path, opt.checkpoints_dir)
opt.data_root = pjoin(base_path, opt.data_root)
opt.save_root = pjoin(base_path, opt.save_root)
self.opt = opt
print('Loading dataset %s ...' % opt.dataset_name)

Expand All @@ -760,7 +758,7 @@ def __init__(self, mode, datapath='./dataset/humanml_opt.txt', split="train", **
if mode == 'text_only':
self.t2m_dataset = TextOnlyDataset(self.opt, self.mean, self.std, self.split_file)
else:
self.w_vectorizer = WordVectorizer(pjoin(abs_base_path, 'glove'), 'our_vab')
self.w_vectorizer = WordVectorizer(pjoin(base_path, 'glove'), 'our_vab')
self.t2m_dataset = Text2MotionDatasetV2(self.opt, self.mean, self.std, self.split_file, self.w_vectorizer)
self.num_actions = 1 # dummy placeholder

Expand Down

0 comments on commit 7c5a825

Please sign in to comment.