-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathexperiment.py
319 lines (284 loc) · 12.9 KB
/
experiment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
import gym
import numpy as np
import torch
import wandb
import argparse
import pickle
import random
import sys
from decision_transformer.evaluation.evaluate_episodes import evaluate_episode, evaluate_episode_rtg
from decision_transformer.models.decision_transformer import DecisionTransformer
from decision_transformer.models.mlp_bc import MLPBCModel
from decision_transformer.training.act_trainer import ActTrainer
from decision_transformer.training.seq_trainer import SequenceTrainer
from tqdm import tqdm
import random
def discount_cumsum(x, gamma):
    """Return the discounted reverse cumulative sum of ``x`` along axis 0.

    ``out[t] = x[t] + gamma * x[t+1] + gamma**2 * x[t+2] + ...``; with
    ``gamma=1`` this is the undiscounted return-to-go of a reward sequence.

    Args:
        x: numpy array indexed by time along its first axis.
        gamma: discount factor applied per step.

    Returns:
        Array with the same shape and dtype as ``x``. An empty ``x`` yields an
        empty array instead of raising ``IndexError``.
    """
    out = np.zeros_like(x)
    if out.shape[0] == 0:
        # Nothing to accumulate; x[-1] below would raise on an empty array.
        return out
    out[-1] = x[-1]
    # Walk backwards so each entry can reuse the already-computed tail sum.
    for t in reversed(range(x.shape[0] - 1)):
        out[t] = x[t] + gamma * out[t + 1]
    return out
def _make_env(env_name):
    """Build the environment and its per-task settings for ``env_name``.

    Returns:
        Tuple ``(env, max_ep_len, env_targets, scale)``: ``env_targets`` are the
        target returns the model is conditioned on at evaluation time and
        ``scale`` is the divisor used to normalize rewards/returns.

    Raises:
        NotImplementedError: if ``env_name`` is not a supported environment.
    """
    if env_name == 'hopper':
        return gym.make('Hopper-v4'), 1000, [3600, 1800], 1000.
    if env_name == 'halfcheetah':
        return gym.make('HalfCheetah-v3'), 1000, [12000, 6000], 1000.
    if env_name == 'walker2d':
        return gym.make('Walker2d-v3'), 1000, [5000, 2500], 1000.
    if env_name == 'reacher2d':
        from decision_transformer.envs.reacher_2d import Reacher2dEnv
        return Reacher2dEnv(), 100, [76, 40], 10.
    if env_name == 'pusher':
        return gym.make('Pusher-v4'), 100, [-20, -30], 10.
    raise NotImplementedError(f'unsupported env: {env_name}')


def experiment(
        exp_prefix,
        variant,
):
    """Train a Decision Transformer ('dt') or MLP behavior-cloning ('bc') model.

    Loads ``data/{env}-{dataset}-v2.pkl``, randomly subsamples trajectories,
    keeps the highest-return ones covering ``pct_traj`` of the timesteps, and
    runs ``max_iters`` training iterations. Evaluation callbacks (one per
    entry of the env's target-return list) are handed to the trainer.

    Args:
        exp_prefix: prefix for the wandb group/run names.
        variant: dict of settings, typically ``vars(args)`` from the CLI.
            Always required: 'env', 'dataset', 'model_type', 'K', 'batch_size',
            'num_eval_episodes', 'embed_dim', 'n_layer', 'learning_rate',
            'weight_decay', 'warmup_steps', 'max_iters', 'num_steps_per_iter'.
            Additionally required for 'dt': 'n_head', 'activation_function',
            'dropout'. Optional: 'device' (default 'cuda'), 'log_to_wandb'
            (False), 'mode' ('normal'; 'delayed' moves all reward to the last
            step), 'pct_traj' (1.0), 'max_trajectories' (500, size of the
            random trajectory subsample).

    Raises:
        NotImplementedError: for an unsupported env name or model type.
    """
    device = variant.get('device', 'cuda')
    log_to_wandb = variant.get('log_to_wandb', False)

    env_name, dataset = variant['env'], variant['dataset']
    model_type = variant['model_type']
    group_name = f'{exp_prefix}-{env_name}-{dataset}'
    exp_prefix = f'{group_name}-{random.randint(int(1e5), int(1e6) - 1)}'

    env, max_ep_len, env_targets, scale = _make_env(env_name)

    if model_type == 'bc':
        env_targets = env_targets[:1]  # since BC ignores target, no need for different evaluations

    state_dim = env.observation_space.shape[0]
    act_dim = env.action_space.shape[0]

    # load dataset
    # NOTE: pickle.load can execute arbitrary code; only load trusted dataset files.
    dataset_path = f'data/{env_name}-{dataset}-v2.pkl'
    with open(dataset_path, 'rb') as f:
        trajectories = pickle.load(f)
    # random.sample raises ValueError if asked for more items than exist,
    # so clamp the subsample size to the dataset size.
    max_trajectories = variant.get('max_trajectories', 500)
    trajectories = random.sample(trajectories, min(max_trajectories, len(trajectories)))

    # save all path information into separate lists
    mode = variant.get('mode', 'normal')
    states, traj_lens, returns = [], [], []
    for path in trajectories:
        if mode == 'delayed':  # delayed: all rewards moved to end of trajectory
            path['rewards'][-1] = path['rewards'].sum()
            path['rewards'][:-1] = 0.
        states.append(path['observations'])
        traj_lens.append(len(path['observations']))
        returns.append(path['rewards'].sum())
    traj_lens, returns = np.array(traj_lens), np.array(returns)

    # used for input normalization
    states = np.concatenate(states, axis=0)
    state_mean, state_std = np.mean(states, axis=0), np.std(states, axis=0) + 1e-6

    num_timesteps = sum(traj_lens)

    print('=' * 50)
    print(f'Starting new experiment: {env_name} {dataset}')
    print(f'{len(traj_lens)} trajectories, {num_timesteps} timesteps found')
    print(f'Average return: {np.mean(returns):.2f}, std: {np.std(returns):.2f}')
    print(f'Max return: {np.max(returns):.2f}, min: {np.min(returns):.2f}')
    print('=' * 50)

    K = variant['K']
    batch_size = variant['batch_size']
    num_eval_episodes = variant['num_eval_episodes']
    pct_traj = variant.get('pct_traj', 1.)

    # only train on top pct_traj trajectories (for %BC experiment)
    num_timesteps = max(int(pct_traj * num_timesteps), 1)
    sorted_inds = np.argsort(returns)  # lowest to highest
    num_trajectories = 1
    timesteps = traj_lens[sorted_inds[-1]]
    ind = len(trajectories) - 2
    # Greedily add the next-best trajectory while the timestep budget allows.
    while ind >= 0 and timesteps + traj_lens[sorted_inds[ind]] <= num_timesteps:
        timesteps += traj_lens[sorted_inds[ind]]
        num_trajectories += 1
        ind -= 1
    sorted_inds = sorted_inds[-num_trajectories:]

    # used to reweight sampling so we sample according to timesteps instead of trajectories
    p_sample = traj_lens[sorted_inds] / sum(traj_lens[sorted_inds])

    def get_batch(batch_size=256, max_len=K):
        """Sample a left-padded batch of ``max_len``-step windows as tensors.

        Returns ``(s, a, r, d, rtg, timesteps, mask)``, where ``mask`` is 0
        over left padding and 1 over real steps.
        """
        batch_inds = np.random.choice(
            np.arange(num_trajectories),
            size=batch_size,
            replace=True,
            p=p_sample,  # reweights so we sample according to timesteps
        )

        s, a, r, d, rtg, timesteps, mask = [], [], [], [], [], [], []
        for i in range(batch_size):
            traj = trajectories[int(sorted_inds[batch_inds[i]])]
            si = random.randint(0, traj['rewards'].shape[0] - 1)

            # get sequences from dataset
            s.append(traj['observations'][si:si + max_len].reshape(1, -1, state_dim))
            a.append(traj['actions'][si:si + max_len].reshape(1, -1, act_dim))
            r.append(traj['rewards'][si:si + max_len].reshape(1, -1, 1))
            if 'terminals' in traj:
                d.append(traj['terminals'][si:si + max_len].reshape(1, -1))
            else:
                d.append(traj['dones'][si:si + max_len].reshape(1, -1))
            timesteps.append(np.arange(si, si + s[-1].shape[1]).reshape(1, -1))
            timesteps[-1][timesteps[-1] >= max_ep_len] = max_ep_len - 1  # padding cutoff
            rtg.append(discount_cumsum(traj['rewards'][si:], gamma=1.)[:s[-1].shape[1] + 1].reshape(1, -1, 1))
            # Ensure rtg has one more step than the states (zero after the end).
            if rtg[-1].shape[1] <= s[-1].shape[1]:
                rtg[-1] = np.concatenate([rtg[-1], np.zeros((1, 1, 1))], axis=1)

            # padding and state + reward normalization
            tlen = s[-1].shape[1]
            s[-1] = np.concatenate([np.zeros((1, max_len - tlen, state_dim)), s[-1]], axis=1)
            s[-1] = (s[-1] - state_mean) / state_std
            a[-1] = np.concatenate([np.ones((1, max_len - tlen, act_dim)) * -10., a[-1]], axis=1)
            r[-1] = np.concatenate([np.zeros((1, max_len - tlen, 1)), r[-1]], axis=1)
            d[-1] = np.concatenate([np.ones((1, max_len - tlen)) * 2, d[-1]], axis=1)
            rtg[-1] = np.concatenate([np.zeros((1, max_len - tlen, 1)), rtg[-1]], axis=1) / scale
            timesteps[-1] = np.concatenate([np.zeros((1, max_len - tlen)), timesteps[-1]], axis=1)
            mask.append(np.concatenate([np.zeros((1, max_len - tlen)), np.ones((1, tlen))], axis=1))

        s = torch.from_numpy(np.concatenate(s, axis=0)).to(dtype=torch.float32, device=device)
        a = torch.from_numpy(np.concatenate(a, axis=0)).to(dtype=torch.float32, device=device)
        r = torch.from_numpy(np.concatenate(r, axis=0)).to(dtype=torch.float32, device=device)
        d = torch.from_numpy(np.concatenate(d, axis=0)).to(dtype=torch.long, device=device)
        rtg = torch.from_numpy(np.concatenate(rtg, axis=0)).to(dtype=torch.float32, device=device)
        timesteps = torch.from_numpy(np.concatenate(timesteps, axis=0)).to(dtype=torch.long, device=device)
        mask = torch.from_numpy(np.concatenate(mask, axis=0)).to(device=device)

        return s, a, r, d, rtg, timesteps, mask

    def eval_episodes(target_rew):
        """Return an eval callback that rolls out the model conditioned on ``target_rew``."""
        def fn(model):
            returns, lengths = [], []
            for _ in range(num_eval_episodes):
                with torch.no_grad():
                    if model_type == 'dt':
                        ret, length = evaluate_episode_rtg(
                            env,
                            state_dim,
                            act_dim,
                            model,
                            max_ep_len=max_ep_len,
                            scale=scale,
                            target_return=target_rew / scale,
                            mode=mode,
                            state_mean=state_mean,
                            state_std=state_std,
                            device=device,
                        )
                    else:
                        ret, length = evaluate_episode(
                            env,
                            state_dim,
                            act_dim,
                            model,
                            max_ep_len=max_ep_len,
                            target_return=target_rew / scale,
                            mode=mode,
                            state_mean=state_mean,
                            state_std=state_std,
                            device=device,
                        )
                returns.append(ret)
                lengths.append(length)
            return {
                f'target_{target_rew}_return_mean': np.mean(returns),
                f'target_{target_rew}_return_std': np.std(returns),
                f'target_{target_rew}_length_mean': np.mean(lengths),
                f'target_{target_rew}_length_std': np.std(lengths),
            }
        return fn

    if model_type == 'dt':
        model = DecisionTransformer(
            state_dim=state_dim,
            act_dim=act_dim,
            max_length=K,
            max_ep_len=max_ep_len,
            hidden_size=variant['embed_dim'],
            n_layer=variant['n_layer'],
            n_head=variant['n_head'],
            n_inner=4 * variant['embed_dim'],
            activation_function=variant['activation_function'],
            n_positions=1024,
            resid_pdrop=variant['dropout'],
            attn_pdrop=variant['dropout'],
        )
    elif model_type == 'bc':
        model = MLPBCModel(
            state_dim=state_dim,
            act_dim=act_dim,
            max_length=K,
            hidden_size=variant['embed_dim'],
            n_layer=variant['n_layer'],
        )
    else:
        raise NotImplementedError(f'unsupported model_type: {model_type}')

    model = model.to(device=device)
    # To resume from a checkpoint:
    # model.load_state_dict(torch.load(f'models/{env_name}-model.pth'))

    # Guard against warmup_steps == 0, which would divide by zero in the lambda;
    # for warmup_steps >= 1 the schedule is unchanged.
    warmup_steps = max(variant['warmup_steps'], 1)
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=variant['learning_rate'],
        weight_decay=variant['weight_decay'],
    )
    # Linear warmup from 1/warmup_steps up to the full learning rate.
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lambda steps: min((steps + 1) / warmup_steps, 1)
    )

    # model_type is 'dt' or 'bc' here; anything else raised above.
    trainer_cls = SequenceTrainer if model_type == 'dt' else ActTrainer
    trainer = trainer_cls(
        model=model,
        optimizer=optimizer,
        batch_size=batch_size,
        get_batch=get_batch,
        scheduler=scheduler,
        loss_fn=lambda s_hat, a_hat, r_hat, s, a, r: torch.mean((a_hat - a) ** 2),
        eval_fns=[eval_episodes(tar) for tar in env_targets],
    )

    if log_to_wandb:
        wandb.init(
            name=exp_prefix,
            group=group_name,
            project='decision-transformer',
            config=variant
        )
        # wandb.watch(model)  # wandb has some bug

    for iter_idx in tqdm(range(variant['max_iters'])):
        outputs = trainer.train_iteration(
            num_steps=variant['num_steps_per_iter'],
            iter_num=iter_idx + 1,
            print_logs=True,
        )
        if log_to_wandb:
            wandb.log(outputs)
def parse_bool(value):
    """Convert a command-line value such as 'True'/'false'/'1'/'0' to a bool.

    argparse's ``type=bool`` treats every non-empty string (including
    'False') as True, so boolean options need an explicit converter.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--env', type=str, default='hopper')
    parser.add_argument('--dataset', type=str, default='medium')  # medium, medium-replay, medium-expert, expert
    parser.add_argument('--mode', type=str, default='normal')  # normal for standard setting, delayed for sparse
    parser.add_argument('--K', type=int, default=20)
    parser.add_argument('--pct_traj', type=float, default=1.)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--model_type', type=str, default='dt')  # dt for decision transformer, bc for behavior cloning
    parser.add_argument('--embed_dim', type=int, default=128)
    parser.add_argument('--n_layer', type=int, default=3)
    parser.add_argument('--n_head', type=int, default=1)
    parser.add_argument('--activation_function', type=str, default='relu')
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--learning_rate', '-lr', type=float, default=1e-4)
    parser.add_argument('--weight_decay', '-wd', type=float, default=1e-4)
    parser.add_argument('--warmup_steps', type=int, default=10000)
    parser.add_argument('--num_eval_episodes', type=int, default=100)
    parser.add_argument('--max_iters', type=int, default=10)
    parser.add_argument('--num_steps_per_iter', type=int, default=10000)
    parser.add_argument('--device', type=str, default='cuda')
    # Accepts "-w", "-w True" and "-w False"; type=bool would map "False" to True.
    parser.add_argument('--log_to_wandb', '-w', type=parse_bool, nargs='?', const=True, default=False)

    args = parser.parse_args()

    experiment('gym-experiment', variant=vars(args))