# benchmark_dynamic.py
# (forked from real-stanford/decentralized-multiarm)
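# Benchmarks motion planning on dynamic multi-arm tasks. By default it scores
# the RRT baseline across a pool of Ray workers; when a trained policy is
# loaded (args.load) it scores the multiarm_motion_planner policy in a single
# RealTimeEnv instance. Scores are pickled periodically so an interrupted
# benchmark can be resumed.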
import ray
from environment.rrt import RRTWrapper
from environment import utils
from environment import RealTimeEnv
from utils import (
    parse_args,
    load_config,
    create_policies,
    exit_handler
)
from environment import TaskLoader
import pickle
from signal import signal, SIGINT
from numpy import mean
from distribute import Pool
from os.path import exists
from tqdm import tqdm


if __name__ == "__main__":
    args = parse_args()
    args.gui = True
    config = load_config(args.config)
    env_conf = config['environment']
    training_conf = config['training']
    env_conf['min_ur5s_count'] = 1
    env_conf['max_ur5s_count'] = 10
    env_conf['task']['type'] = 'dynamic'
    ray.init()
    signal(SIGINT, lambda sig, frame: exit())
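    # Results go to the RRT baseline file by default, or to the policy file
    # when a trained policy is loaded.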
    output_path = 'rrt_dynamic_benchmark_score.pkl'
    if args.load:
        output_path = 'policy_dynamic_benchmark_score.pkl'
    benchmark_results = []
    continue_benchmark = False
    if exists(output_path):
        # continue benchmark
        benchmark_results = pickle.load(open(output_path, 'rb'))
        continue_benchmark = True
        finished_task_paths = [r['task']['task_path']
                               for r in benchmark_results]
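    # Load benchmark tasks from disk; shuffled, one pass over the set.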
    task_loader = TaskLoader(
        root_dir=args.tasks_path,
        shuffle=True,
        repeat=False)
    training_conf['task_loader'] = task_loader
    # set up policy if loaded
    if args.load:
        obs_dim = utils.get_observation_dimensions(
            training_conf['observations'])
        action_dim = 6
        policy_manager = create_policies(
            args=args,
            training_config=config['training'],
            action_dim=action_dim,
            actor_obs_dim=obs_dim,
            critic_obs_dim=obs_dim,
            training=args.mode == 'train',
            logger=None,
            device='cpu')
        policy = policy_manager.get_inference_nodes()[
            'multiarm_motion_planner']
        policy.policy.to('cpu')
        training_conf['policy'] = policy
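        # Run the policy in a single local environment.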
        env = RealTimeEnv(
            env_config=env_conf,
            training_config=training_conf,
            gui=args.gui,
            logger=None)
        env.set_memory_cluster_map(policy_manager.memory_map)
    else:
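        # No policy loaded: benchmark the RRT baseline by turning the
        # environment into a Ray actor and spreading tasks across
        # args.num_processes workers.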
        RealTimeEnv = ray.remote(RealTimeEnv)
        envs = [RealTimeEnv.remote(
            env_config=env_conf,
            training_config=training_conf,
            gui=args.gui,
            logger=None)
            for _ in range(args.num_processes)]
        env_pool = Pool(envs)
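
    # Collect each finished task's result and checkpoint the scores to disk
    # every 100 results.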
    def callback(result):
        benchmark_results.append(result)
        if len(benchmark_results) % 100 == 0\
                and len(benchmark_results) > 0:
            print('Saving benchmark scores to ',
                  output_path)
            with open(output_path, 'wb') as f:
                pickle.dump(benchmark_results, f)
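
    # Show the running mean success rate in the progress bar.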
    def pbar_update(pbar):
        pbar.set_description(
            'Average Success Rate : {:.04f}'.format(
                mean([r['success_rate']
                      for r in benchmark_results])))
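
    # When resuming, skip tasks that already have recorded results.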
    tasks = [t for t in task_loader
             if not continue_benchmark
             or t.task_path not in finished_task_paths]
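    # Policy run: solve tasks sequentially in the local environment.
    # Baseline run: distribute tasks across the Ray worker pool.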
    if args.load:
        with tqdm(tasks, dynamic_ncols=True, smoothing=0.01) as pbar:
            for task in pbar:
                callback(env.solve_task(task))
                pbar_update(pbar)
    else:
        benchmark_results = env_pool.map(
            exec_fn=lambda env, task: env.solve_task.remote(task),
            iterable=tasks,
            pbar_update=pbar_update,
            callback_fn=callback
        )