forked from benjiebob/SMPL
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
72 lines (56 loc) · 1.99 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import smpl_tf
import smpl_np
from smpl_torch import SMPLModel
import numpy as np
import tensorflow as tf
import torch
import os
def compute_diff(a, b):
"""
Compute the max relative difference between ndarray a and b element-wisely.
Parameters:
----------
a, b: ndarrays to be compared of same shape.
Return:
------
The max relative difference.
"""
return np.max(np.abs(a - b) / np.minimum(a, b))
def pytorch_wrapper(beta, pose, trans):
device = torch.device('cuda')
pose = torch.from_numpy(pose).type(torch.float64).to(device)
beta = torch.from_numpy(beta).type(torch.float64).to(device)
trans = torch.from_numpy(trans).type(torch.float64).to(device)
model = SMPLModel(device=device)
with torch.no_grad():
result = model(beta, pose, trans)
return result.cpu().numpy()
def tf_wrapper(beta, pose, trans):
beta = tf.constant(beta, dtype=tf.float64)
trans = tf.constant(trans, dtype=tf.float64)
pose = tf.constant(pose, dtype=tf.float64)
output, _ = smpl_tf.smpl_model('./model.pkl', beta, pose, trans)
with tf.Session() as sess:
result = sess.run(output)
return result
def np_wrapper(beta, pose, trans):
smpl = smpl_np.SMPLModel('./model.pkl')
result = smpl.set_params(pose=pose, beta=beta, trans=trans)
return result
if __name__ == '__main__':
pose_size = 72
beta_size = 10
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
np.random.seed(9608)
pose = (np.random.rand(pose_size) - 0.5) * 0.4
beta = (np.random.rand(beta_size) - 0.5) * 0.06
trans = np.zeros(3)
tf_result = tf_wrapper(beta, pose, trans)
np_result = np_wrapper(beta, pose, trans)
torch_result = pytorch_wrapper(beta, pose, trans)
if np.allclose(np_result, tf_result) and np.allclose(np_result, torch_result):
print('Bingo!')
else:
print('Failed')
print('tf - np: ', compute_diff(tf_result, np_result))
print('torch - np: ', compute_diff(torch_result, np_result))