-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbatch_data.py
140 lines (100 loc) · 4.81 KB
/
batch_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import numpy as np
from glob import glob
import json
import matplotlib.pyplot as plt
import argparse
import torch
## This file generates batches of data as outlined in the PAML paper
## Train batch: 50 random frames of the subject
## Val batch: 10 random frames of the subject
## Helper Functions
def get_data_file(annotations_path, subject_id):
    """Return the path of the first annotations file whose name contains
    ``subject<subject_id>_data`` under *annotations_path*.

    Raises IndexError if no matching file exists.
    """
    pattern = f"{annotations_path}/*subject{subject_id}_data*"
    matches = glob(pattern)
    return matches[0]
def get_coords_file(annotations_path, subject_id):
    """Return the path of the first file whose name contains
    ``subject<subject_id>_joint_3d`` under *annotations_path*.

    Raises IndexError if no matching file exists.
    """
    pattern = f"{annotations_path}/*subject{subject_id}_joint_3d*"
    matches = glob(pattern)
    return matches[0]
def get_subject_data_length(annotations_path, subject_id):
    """Return the number of image (frame) entries recorded for a subject."""
    json_path = get_data_file(annotations_path, subject_id)
    with open(json_path, 'r') as fh:
        annotations = json.load(fh)
    return len(annotations['images'])
def get_subject_action_boundaries(annotations_path, subject_id):
    """Return the frame indices at which a new action begins for a subject.

    The result always starts with 0 and ends with the total frame count, so
    consecutive pairs delimit half-open per-action frame ranges.
    """
    json_path = get_data_file(annotations_path, subject_id)
    with open(json_path, 'r') as fh:
        images = json.load(fh)['images']
    boundaries = [0]
    current_action = images[0]['action_name']
    for idx in range(1, len(images)):
        name = images[idx]['action_name']
        if name != current_action:
            current_action = name
            boundaries.append(idx)
    boundaries.append(len(images))
    # print(boundaries)
    return boundaries
def get_data(subject_id, random_index, annotations_path, timesteps=1):
    """Load metadata and 3D joint coordinates for consecutive frames.

    Reads `timesteps` frames starting at `random_index` for the given subject.
    Returns (full data dict, metadata of the LAST frame read, pose array).
    The pose array is shaped (timesteps, ...) — presumably (timesteps, 17, 3)
    joints; confirm against the annotation format.
    """
    with open(get_data_file(annotations_path, subject_id), 'r') as fh:
        data = json.load(fh)
    with open(get_coords_file(annotations_path, subject_id), 'r') as fh:
        coords = json.load(fh)
    pose_frames = []
    for frame in range(random_index, random_index + timesteps):
        meta_data = data['images'][frame]
        # JSON object keys are strings, so the integer indices are stringified.
        frame_coords = coords[str(meta_data['action_idx'])][str(meta_data['subaction_idx'])][str(meta_data['frame_idx'])]
        pose_frames.append(np.array(frame_coords))
    return data, meta_data, np.array(pose_frames)
def check_range_overlap(range1, range2):
    """Return True if the two iterables share at least one element.

    Uses set.isdisjoint, which short-circuits on the first common element
    instead of materializing the full intersection of two sets.
    """
    return not set(range1).isdisjoint(range2)  ## Make sure to avoid overlapping with any boundaries
def generate_random_task(timesteps, timesteps_pred, num_support, num_query, annotations_path):
    """Sample one meta-learning task from a random (non-test) subject.

    Picks a random subject (1-7, excluding subject 5, which is held out for
    testing), a random action of that subject, then draws
    ``num_support + num_query`` non-overlapping windows of
    ``timesteps + timesteps_pred`` consecutive frames from that action.

    Returns (train_data, train_label, query_data, query_label): lists of
    arrays of shape (timesteps, 51) / (timesteps_pred, 51), where 51 = 17
    joints x 3 coordinates.

    NOTE(review): if the action segment is too short to hold the requested
    number of disjoint windows, the overlap-rejection loop will not
    terminate — callers should keep support/query counts small.
    """
    ## Get Random Subject
    subject_id = np.random.randint(1, 8)
    # Bug fix: the original looped `while subject_id != 5`, which resampled
    # UNTIL the subject was 5 — always returning the held-out test subject.
    while subject_id == 5:  # Excluding 5 for test time
        subject_id = np.random.randint(1, 8)
    boundaries = get_subject_action_boundaries(annotations_path, subject_id)
    timesteps_total = timesteps + timesteps_pred
    ## Get Random Action
    random_action = np.random.randint(len(boundaries[:-1]))
    ## Get Random Index
    index = np.random.randint(boundaries[random_action], boundaries[random_action + 1] - timesteps_total)
    ## Get Train and Query Data
    data = []
    label = []
    total_range = list(range(index, index + timesteps_total))
    for _ in np.arange(num_support + num_query):
        ## Get Data
        _, _, pose3d = get_data(subject_id, index, annotations_path, timesteps=timesteps_total)
        data.append(pose3d[:timesteps].reshape(timesteps, 51))  # 51 = 17 x 3
        label.append(pose3d[timesteps:].reshape(timesteps_pred, 51))
        ## Calculate next index, rejecting windows that overlap earlier ones
        index = np.random.randint(boundaries[random_action], boundaries[random_action + 1] - timesteps_total)
        index_range = range(index, index + timesteps_total)
        while check_range_overlap(total_range, index_range):  ## Make sure to avoid overlapping between examples
            index = np.random.randint(boundaries[random_action], boundaries[random_action + 1] - timesteps_total)
            index_range = range(index, index + timesteps_total)
        total_range.extend(index_range)
    ## Return Batch: first num_support examples are support, rest are query
    train_data = data[:num_support]
    train_label = label[:num_support]
    query_data = data[num_support:]
    query_label = label[num_support:]
    return train_data, train_label, query_data, query_label
if __name__ == "__main__":
    # Smoke test: build one random task and print the batch shapes.
    parser = argparse.ArgumentParser()
    parser.add_argument("--subject_id", type=int, help="subject id", default=7)
    parser.add_argument("--index", type=int, help="image index", default=0)
    parser.add_argument("--timesteps", type=int, help="observed frames per example", default=50)
    parser.add_argument("--timesteps_pred", type=int, help="predicted frames per example", default=10)
    parser.add_argument("--num_support", type=int, help="support examples per task", default=5)
    parser.add_argument("--num_query", type=int, help="query examples per task", default=5)
    parser.add_argument("--annotations_path", type=str, default="./annotations")
    args = parser.parse_args()
    # Bug fix: the original called undefined `generate_random_example` and
    # omitted the num_support / num_query arguments.
    train_data, train_label, query_data, query_label = generate_random_task(
        args.timesteps, args.timesteps_pred, args.num_support, args.num_query,
        args.annotations_path)
    # The task components are Python lists of arrays (no .shape attribute);
    # stack them to report (num_examples, timesteps, 51) shapes.
    print("Train Data Example Shape", np.array(train_data).shape)
    print("Train Label Example Shape", np.array(train_label).shape)
    print("Query Data Example Shape", np.array(query_data).shape)
    print("Query Label Example Shape", np.array(query_label).shape)