-
Notifications
You must be signed in to change notification settings - Fork 1
/
run_PointEnv.py
138 lines (119 loc) · 5.86 KB
/
run_PointEnv.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
from pud.dependencies import *
from pud.utils import set_global_seed, set_env_seed, AttrDict
# Load the experiment configuration. The config file path is the last
# command-line argument; the file's contents are evaluated as a Python
# expression that must produce a mapping, which is passed as keyword
# arguments to AttrDict.
if len(sys.argv) < 2:
    # Without an explicit argument, sys.argv[-1] would be this script itself
    # and eval() would fail with an unhelpful SyntaxError.
    sys.exit(f'usage: python {sys.argv[0]} <config_file>')
cfg_file = sys.argv[-1]
# SECURITY NOTE(review): eval() executes arbitrary code from the config file,
# so only run this script with trusted config files. ast.literal_eval would be
# safer but rejects configs that use names/expressions — confirm before switching.
with open(cfg_file, 'r') as cfg_fh:
    cfg = AttrDict(**eval(cfg_fh.read()))
print(cfg)
set_global_seed(cfg.seed)
from pud.envs.simple_navigation_env import env_load_fn

# Both environments are built with the same name, step limit, resize factor and
# `thin` setting; they differ only in the terminate_on_timeout flag and seed.
_env_kwargs = dict(
    resize_factor=cfg.env.resize_factor,
    thin=cfg.env.thin,
)

# Training environment (terminate_on_timeout=False), seeded with cfg.seed + 1.
env = env_load_fn(
    cfg.env.env_name,
    cfg.env.max_episode_steps,
    terminate_on_timeout=False,
    **_env_kwargs,
)
set_env_seed(env, cfg.seed + 1)

# Evaluation environment (terminate_on_timeout=True), seeded with cfg.seed + 2
# so its randomness is independent of the training environment's.
eval_env = env_load_fn(
    cfg.env.env_name,
    cfg.env.max_episode_steps,
    terminate_on_timeout=True,
    **_env_kwargs,
)
set_env_seed(eval_env, cfg.seed + 2)
# Dimensions used to build the agent and replay buffer. Goals are given the
# same dimensionality as observations, and the agent's input is the
# concatenation of an observation and a goal.
obs_space = env.observation_space['observation']
obs_dim = obs_space.shape[0]
goal_dim = obs_dim
state_dim = obs_dim + goal_dim

act_space = env.action_space
action_dim = act_space.shape[0]
# max_action is taken from the first component of the action upper bound.
max_action = float(act_space.high[0])
print(f'obs dim: {obs_dim}, goal dim: {goal_dim}, state dim: {state_dim}, action dim: {action_dim}, max action: {max_action}')
from pud.ddpg import UVFDDPG

# The agent is constructed with state_dim (observation and goal concatenated),
# the action dimension and the action bound; extra options come from cfg.agent.
agent_kwargs = cfg.agent
agent = UVFDDPG(state_dim, action_dim, max_action, **agent_kwargs)
print(agent)

from pud.buffer import ReplayBuffer

# Unlike the agent, the buffer receives observation and goal dimensions
# separately; remaining keyword options come from cfg.replay_buffer.
replay_buffer = ReplayBuffer(
    obs_dim,
    goal_dim,
    action_dim,
    **cfg.replay_buffer,
)
# Set to True to train the agent from scratch and save a checkpoint to
# cfg.ckpt_dir; leave False to load a previously saved checkpoint instead.
TRAIN_AGENT = False

ckpt_file = os.path.join(cfg.ckpt_dir, 'agent.pth')
if TRAIN_AGENT:
    from pud.policies import GaussianPolicy
    policy = GaussianPolicy(agent)

    from pud.runner import train_eval, eval_pointenv_dists
    train_eval(
        policy,
        agent,
        replay_buffer,
        env,
        eval_env,
        eval_func=eval_pointenv_dists,
        **cfg.runner,
    )

    # Make sure the checkpoint directory exists so a long training run does
    # not fail at the very last step.
    os.makedirs(cfg.ckpt_dir, exist_ok=True)
    torch.save(agent.state_dict(), ckpt_file)
else:
    # map_location='cpu' allows a checkpoint saved on a GPU to be loaded on a
    # CPU-only machine; load_state_dict then copies the tensors into the
    # agent's existing parameters on whatever device they live on.
    agent.load_state_dict(torch.load(ckpt_file, map_location='cpu'))

# The graph-search code below uses the agent for inference only, so switch it
# to evaluation mode whether it was just trained or loaded from disk.
agent.eval()

# Optional: visualize rollouts of the goal-conditioned policy.
# from pud.visualize import visualize_trajectory
# eval_env.duration = 100 # We'll give the agent lots of time to try to find the goal.
# visualize_trajectory(agent, eval_env, difficulty=0.5)
# Next we build the search policy, which finds waypoints automatically via
# graph search over a set of sampled states.
# Step 1: sample replay_buffer.max_size states from eval_env. Note that the
# states go into rb_vec; nothing is inserted into replay_buffer itself, whose
# max_size is only used as the sample count.
#
from pud.collector import Collector
# Relax goal sampling: prob_constraint=0.0 with a distance range of [0, inf),
# presumably meaning unconstrained goals — confirm in the env implementation.
# NOTE(review): these args are set on `env`, but the states below are sampled
# from `eval_env`; confirm which environment was intended.
env.set_sample_goal_args(prob_constraint=0.0, min_dist=0, max_dist=np.inf)
rb_vec = Collector.sample_initial_states(eval_env, replay_buffer.max_size)
# Optional: visualize the sampled states.
# from pud.visualize import visualize_buffer
# visualize_buffer(rb_vec, eval_env)
# Predicted pairwise distances between all sampled states. aggregate=None
# presumably keeps each ensemble critic's prediction separate instead of
# reducing them — TODO confirm in UVFDDPG.get_pairwise_dist.
pdist = agent.get_pairwise_dist(rb_vec, aggregate=None)
# from scipy.spatial import distance
# euclidean_dists = distance.pdist(rb_vec)
# As a sanity check, you can plot the pairwise distances between all sampled
# observations (uncomment the call below). Distributional RL implicitly caps
# the maximum predicted distance at the largest bin, so with 20 bins we expect
# values from 1 to 20, and the critic predicts 20 for all states that are at
# least 20 steps away from one another.
#
# from pud.visualize import visualize_pairwise_dists
# visualize_pairwise_dists(pdist)
# With these distances, we can construct a graph. Nodes in the graph are
# the sampled observations, and each edge between two observations has a
# length equal to the predicted distance between them. Because edge lengths
# are hard to visualize, the interactive version of this walkthrough offered a
# slider to show only edges whose predicted length is below some threshold;
# no such slider is available when running this script.
# ---
# Our method learns a collection of critics, each of which makes an independent
# prediction for the distance between two states. Because each network may make
# bad predictions for pairs of states it hasn't seen before, we act in
# a *risk-averse* manner by using the maximum predicted distance across our
# ensemble. That is, we act pessimistically, only adding an edge
# if *all* critics think that this pair of states is nearby.
#
# from pud.visualize import visualize_graph
# visualize_graph(rb_vec, eval_env, pdist)
# We can also visualize the predictions from each critic.
# Note that while an individual critic may make incorrect predictions
# for distant states, their aggregated predictions are much more reliable.
#
# from pud.visualize import visualize_graph_ensemble
# visualize_graph_ensemble(rb_vec, eval_env, pdist)
# from pud.policies import SearchPolicy
# search_policy = SearchPolicy(agent, rb_vec, pdist=pdist, open_loop=True)
# eval_env.duration = 300 # We'll give the agent lots of time to try to find the goal.
# Sparse graphical memory
#
from pud.policies import SparseSearchPolicy

# Search settings for the sparse graphical memory policy.
SEARCH_MAX_STEPS = 10  # forwarded as max_search_steps
EVAL_DURATION = 300    # give the agent lots of time to try to find the goal

# Build the search policy over the sampled states and their precomputed
# pairwise distances (cache_pdist=True).
search_policy = SparseSearchPolicy(
    agent,
    rb_vec,
    pdist=pdist,
    cache_pdist=True,
    max_search_steps=SEARCH_MAX_STEPS,
)
eval_env.duration = EVAL_DURATION
#
from pud.runner import cleanup_and_eval_search_policy

# The helper returns three pairs. Judging by the variable names, these are the
# (graph, buffer) pairs for the initial, filtered and cleaned graphs — confirm
# in pud.runner.
initial, filtered, cleaned = cleanup_and_eval_search_policy(search_policy, eval_env)
initial_g, initial_rb = initial
filtered_g, filtered_rb = filtered
cleaned_g, cleaned_rb = cleaned
#
from pud.visualize import visualize_full_graph

# Draw the cleaned graph on top of the evaluation environment.
visualize_full_graph(cleaned_g, cleaned_rb, eval_env)
# Plot the search path found by the search policy
#
# from pud.visualize import visualize_search_path
# visualize_search_path(search_policy, eval_env, difficulty=0.9)
# Now, we'll use that path to guide the agent towards the goal.
# The left plot shows rollouts from the baseline goal-conditioned policy;
# the right plot uses that same policy to reach each of the waypoints
# leading to the goal. There is no interactive slider in this script; the
# distance to the goal is presumably set by the `difficulty` argument.
# Note that only the search policy is able to reach distant goals.
#
# from pud.visualize import visualize_compare_search
# visualize_compare_search(agent, search_policy, eval_env, difficulty=0.9)