# Basic behavioural cloning.
# Note: training uses gradient accumulation with per-sample (batch size 1)
# forward/backward passes, taking one optimizer step per batch.
# This fits even on smaller GPUs (tested on an 8GB one), but is slow.
import pickle
import time
import wandb
import gym
import minerl
import torch as th
import numpy as np
import os
from argparse import ArgumentParser
from openai_vpt.agent import PI_HEAD_KWARGS, MineRLAgent
from data_loader import DataLoader
from openai_vpt.lib.tree_util import tree_map
from utils.logs import Logging
# Originally this code was designed for a small dataset of ~20 demonstrations per task.
# The settings might not be the best for the full BASALT dataset (thousands of demonstrations).
# Use this flag to switch between the two settings
USING_FULL_DATASET = True
EPOCHS = int(os.getenv("EPOCHS", 1))
# Needs to be <= number of videos
BATCH_SIZE = int(os.getenv("BATCH_SIZE", 32))
# Ideally more than batch size to create
# variation in datasets (otherwise, you will
# get a bunch of consecutive samples)
# Decrease this (and batch_size) if you run out of memory
N_WORKERS = int(os.getenv("N_WORKERS", 50))
DEVICE = "cuda"
LOSS_REPORT_RATE = 100
SAVE_WEIGHTS = False
# Tuned with a bit of trial and error
LEARNING_RATE = float(os.getenv("LEARNING_RATE", 0.000181))
# OpenAI VPT BC weight decay
# WEIGHT_DECAY = 0.039428
WEIGHT_DECAY = float(os.getenv("WEIGHT_DECAY", 0.0))
# KL loss to the original model was not used in OpenAI VPT
KL_LOSS_WEIGHT = float(os.getenv("KL_LOSS_WEIGHT", 1.0))
MAX_GRAD_NORM = float(os.getenv("MAX_GRAD_NORM", 5.0))
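# Stop training after this many batches, even if EPOCHS worth of data remains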
MAX_BATCHES = int(os.getenv("MAX_BATCHES", 2700))
def load_model_parameters(path_to_model_file):
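    """Load policy and pi-head kwargs from a pickled OpenAI VPT .model file."""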
    agent_parameters = pickle.load(open(path_to_model_file, "rb"))
    policy_kwargs = agent_parameters["model"]["args"]["net"]["args"]
    pi_head_kwargs = agent_parameters["model"]["args"]["pi_head_opts"]
    pi_head_kwargs["temperature"] = float(pi_head_kwargs["temperature"])
    return policy_kwargs, pi_head_kwargs

def behavioural_cloning_train(data_dir, in_model, in_weights, out_weights):
    # save config
    wandb.config.model = in_model
    wandb.config.epochs = EPOCHS
    wandb.config.batch_size = BATCH_SIZE
    wandb.config.n_workers = N_WORKERS
    wandb.config.learning_rate = LEARNING_RATE
    wandb.config.weight_decay = WEIGHT_DECAY
    wandb.config.kl_loss_weight = KL_LOSS_WEIGHT
    wandb.config.max_grad_norm = MAX_GRAD_NORM
    wandb.config.max_batches = MAX_BATCHES
    agent_policy_kwargs, agent_pi_head_kwargs = load_model_parameters(in_model)

    # To create the model with the right environment.
    # All BASALT environments have the same settings, so any of them works here.
    env = gym.make("MineRLBasaltFindCave-v0")
    agent = MineRLAgent(env, device=DEVICE, policy_kwargs=agent_policy_kwargs, pi_head_kwargs=agent_pi_head_kwargs)
    agent.load_weights(in_weights)

    # Create a copy which will keep the original parameters (used for the KL term)
    original_agent = MineRLAgent(env, device=DEVICE, policy_kwargs=agent_policy_kwargs, pi_head_kwargs=agent_pi_head_kwargs)
    original_agent.load_weights(in_weights)
    env.close()

    policy = agent.policy
    original_policy = original_agent.policy

    # Freeze most params; only the final layers are fine-tuned
    for param in policy.parameters():
        param.requires_grad = False
    # Unfreeze final layers
    trainable_parameters = []
    for param in policy.net.lastlayer.parameters():
        param.requires_grad = True
        trainable_parameters.append(param)
    for param in policy.pi_head.parameters():
        param.requires_grad = True
        trainable_parameters.append(param)
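    # Only the unfrozen parameters are handed to the optimizer below,
    # so the bulk of the VPT network is untouched during training.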
    # Parameters taken from the OpenAI VPT paper
    optimizer = th.optim.Adam(
        trainable_parameters,
        lr=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY
    )

    data_loader = DataLoader(
        dataset_dir=data_dir,
        n_workers=N_WORKERS,
        batch_size=BATCH_SIZE,
        n_epochs=EPOCHS,
    )
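    # The loader yields batches of (frame, action, episode_id) triples; a
    # (None, None) frame/action pair signals that a trajectory has finished.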
    start_time = time.time()

    # Keep track of the hidden state per episode/trajectory.
    # DataLoader provides a unique id for each episode, which will
    # be different even for the same trajectory when it is loaded
    # up again.
    episode_hidden_states = {}
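    # `dummy_first` is the "first timestep of an episode" flag expected by the
    # recurrent policy; it stays False here because episode boundaries are
    # handled explicitly via the per-episode hidden states above.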
    dummy_first = th.from_numpy(np.array((False,))).to(DEVICE)

    loss_sum = 0
    for batch_i, (batch_images, batch_actions, batch_episode_id) in enumerate(data_loader):
        batch_loss = 0
        for image, action, episode_id in zip(batch_images, batch_actions, batch_episode_id):
            if image is None and action is None:
                # The trajectory has finished; remove its hidden state
                if episode_id in episode_hidden_states:
                    removed_hidden_state = episode_hidden_states.pop(episode_id)
                    del removed_hidden_state
                continue

            agent_action = agent._env_action_to_agent(action, to_torch=True, check_if_null=True)
            if agent_action is None:
                # Action was null
                continue

            agent_obs = agent._env_obs_to_agent({"pov": image})
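            # First time this episode id is seen: start from the policy's
            # initial recurrent state (batch size 1)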
            if episode_id not in episode_hidden_states:
                episode_hidden_states[episode_id] = policy.initial_state(1)
            agent_state = episode_hidden_states[episode_id]

            pi_distribution, _, new_agent_state = policy.get_output_for_observation(
                agent_obs,
                agent_state,
                dummy_first
            )
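            # Forward pass through the frozen original policy (no gradients);
            # its action distribution is only used for the KL term below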
            with th.no_grad():
                original_pi_distribution, _, _ = original_policy.get_output_for_observation(
                    agent_obs,
                    agent_state,
                    dummy_first
                )
            log_prob = policy.get_logprob_of_action(pi_distribution, agent_action)
            kl_div = policy.get_kl_of_action_dists(pi_distribution, original_pi_distribution)

            # Detach the recurrent state so we do not backprop through the
            # sequence (this fails with the current gradient accumulation)
            new_agent_state = tree_map(lambda x: x.detach(), new_agent_state)
            episode_hidden_states[episode_id] = new_agent_state

            # Behavioural cloning objective: increase the probability of the
            # demonstrated action, with a KL penalty towards the original policy.
            # Dividing by BATCH_SIZE makes the accumulated gradient a batch mean.
            loss = (-log_prob + KL_LOSS_WEIGHT * kl_div) / BATCH_SIZE
            batch_loss += loss.item()
            loss.backward()
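        # One optimizer step per batch of accumulated per-sample gradients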
        th.nn.utils.clip_grad_norm_(trainable_parameters, MAX_GRAD_NORM)
        optimizer.step()
        optimizer.zero_grad()

        loss_sum += batch_loss
        wandb.log({'loss': batch_loss})
        if batch_i % LOSS_REPORT_RATE == 0:
            time_since_start = time.time() - start_time
            Logging.info(f"Time: {time_since_start:.2f}, Batches: {batch_i}, Avg loss: {loss_sum / LOSS_REPORT_RATE:.4f}")
            loss_sum = 0
        if batch_i > MAX_BATCHES:
            break

        if SAVE_WEIGHTS and batch_i % 100 == 0:
            Logging.info(f"Save weights to .tmp.{batch_i}")
            state_dict = policy.state_dict()
            th.save(state_dict, out_weights + f".tmp.{batch_i}")

    state_dict = policy.state_dict()
    th.save(state_dict, out_weights)
if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--data-dir", type=str, required=True, help="Path to the directory containing recordings to be trained on")
    parser.add_argument("--in-model", required=True, type=str, help="Path to the .model file to be finetuned")
    parser.add_argument("--in-weights", required=True, type=str, help="Path to the .weights file to be finetuned")
    parser.add_argument("--out-weights", required=True, type=str, help="Path where finetuned weights will be saved")
    args = parser.parse_args()

    # wandb.config / wandb.log used above expect an active run; start one here
    # when the script is run directly (assumes project settings come from the
    # environment or wandb defaults)
    wandb.init()
    behavioural_cloning_train(args.data_dir, args.in_model, args.in_weights, args.out_weights)