train.py
import os
import numpy as np
import torch
import pdb
import trajectory.utils as utils
import trajectory.datasets as datasets
from trajectory.models.transformers import GPT
class Parser(utils.Parser):
    dataset: str = 'halfcheetah-medium-expert-v2'
    config: str = 'config.offline'
#######################
######## setup ########
#######################
args = Parser().parse_args('train')
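## utils.Parser pulls per-dataset defaults from the `config.offline` module and
## overrides them with command-line flags, so a typical invocation looks
## something like: python train.py --dataset halfcheetah-medium-v2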
#######################
####### dataset #######
#######################
env = datasets.load_environment(args.dataset)
sequence_length = args.subsampled_sequence_length * args.step
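## subsampling keeps every `step`-th transition, so `sequence_length` counts raw
## environment steps: e.g. subsampled_sequence_length=10 with step=3 spans 30 raw steps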
dataset_config = utils.Config(
    datasets.DiscretizedDataset,
    savepath=(args.savepath, 'data_config.pkl'),
    env=args.dataset,
    N=args.N,
    penalty=args.termination_penalty,
    sequence_length=sequence_length,
    step=args.step,
    discount=args.discount,
    discretizer=args.discretizer,
)
dataset = dataset_config()
obs_dim = dataset.observation_dim
act_dim = dataset.action_dim
transition_dim = dataset.joined_dim
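## joined_dim is the full width of one transition row: observation and action
## dimensions plus the scalar reward and reward-to-go terms appended by the
## discretized dataset (presumably obs_dim + act_dim + 2)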
#######################
######## model ########
#######################
block_size = args.subsampled_sequence_length * transition_dim - 1
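## -1 because inputs and targets are offset by one token: the model reads
## tokens[:-1] of a subsampled sequence and predicts tokens[1:], each of
## which is block_size tokens long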
print(
    f'Dataset size: {len(dataset)} | '
    f'Joined dim: {transition_dim} '
    f'(observation: {obs_dim}, action: {act_dim}) | Block size: {block_size}'
)
model_config = utils.Config(
    GPT,
    savepath=(args.savepath, 'model_config.pkl'),
    ## discretization
    vocab_size=args.N, block_size=block_size,
    ## architecture
    n_layer=args.n_layer, n_head=args.n_head, n_embd=args.n_embd*args.n_head,
    ## dimensions
    observation_dim=obs_dim, action_dim=act_dim, transition_dim=transition_dim,
    ## loss weighting
    action_weight=args.action_weight, reward_weight=args.reward_weight, value_weight=args.value_weight,
    ## dropout probabilities
    embd_pdrop=args.embd_pdrop, resid_pdrop=args.resid_pdrop, attn_pdrop=args.attn_pdrop,
)
model = model_config()
model.to(args.device)
#######################
####### trainer #######
#######################
warmup_tokens = len(dataset) * block_size ## number of tokens seen per epoch
final_tokens = 20 * warmup_tokens
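## minGPT-style token-based schedule: warm up over one epoch's worth of tokens,
## then cosine-decay until 20 epochs' worth of tokens have been seen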
trainer_config = utils.Config(
    utils.Trainer,
    savepath=(args.savepath, 'trainer_config.pkl'),
    # optimization parameters
    batch_size=args.batch_size,
    learning_rate=args.learning_rate,
    betas=(0.9, 0.95),
    grad_norm_clip=1.0,
    weight_decay=0.1, # only applied on matmul weights
    # learning rate decay: linear warmup followed by cosine decay to 10% of original
    lr_decay=args.lr_decay,
    warmup_tokens=warmup_tokens,
    final_tokens=final_tokens,
    ## dataloader
    num_workers=0,
    device=args.device,
)
trainer = trainer_config()
#######################
###### main loop ######
#######################
## scale number of epochs to keep number of updates constant
n_epochs = int(1e6 / len(dataset) * args.n_epochs_ref)
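## e.g. with n_epochs_ref=50: a 1M-transition dataset trains for 50 epochs,
## a 100k-transition dataset for 500, so the total update count stays comparable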
save_freq = max(n_epochs // args.n_saves, 1) ## clamp to 1 so small datasets (n_epochs < n_saves) don't divide by zero below
for epoch in range(n_epochs):
    print(f'\nEpoch: {epoch} / {n_epochs} | {args.dataset} | {args.exp_name}')

    trainer.train(model, dataset)

    ## get greatest multiple of `save_freq` less than or equal to `epoch + 1`
    save_epoch = (epoch + 1) // save_freq * save_freq
    statepath = os.path.join(args.savepath, f'state_{save_epoch}.pt')
    print(f'Saving model to {statepath}')

    ## save state to disk
    state = model.state_dict()
    torch.save(state, statepath)
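## each checkpoint path state_{k * save_freq}.pt is rewritten every epoch until
## the counter reaches the next multiple of save_freq, leaving roughly n_saves
## checkpoints evenly spaced over training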