from argparse import ArgumentParser
import pickle
import torch as th
import numpy as np
import gym
from stable_baselines3 import PPO
from imitation.algorithms import preference_comparisons
from imitation.rewards.reward_nets import (
    CnnRewardNet,
    NormalizedRewardNet,
)
from imitation.util.networks import RunningNorm
from imitation.util.util import make_vec_env
import minerl # noqa: F401
from openai_vpt.agent import MineRLAgent
import sb3_minerl_envs # noqa: F401
from sb3_policy_wrapper import MinecraftActorCriticPolicy
from gym_wrappers import ObservationToInfos
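# Note: `minerl` and `sb3_minerl_envs` appear to be imported only for their side
# effects (registering the MineRLBasalt* and *SB3 gym environments used below),
# hence the `noqa: F401` markers.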


def load_model_parameters(path_to_model_file):
    # Use a context manager so the .model file is closed after loading
    with open(path_to_model_file, "rb") as f:
        agent_parameters = pickle.load(f)
    policy_kwargs = agent_parameters["model"]["args"]["net"]["args"]
    pi_head_kwargs = agent_parameters["model"]["args"]["pi_head_opts"]
    pi_head_kwargs["temperature"] = float(pi_head_kwargs["temperature"])
    return policy_kwargs, pi_head_kwargs
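

# Sketch of the .model pickle layout assumed by load_model_parameters above.
# Only the keys read there are shown; the real VPT .model files may contain more:
#
#   {
#       "model": {
#           "args": {
#               "net": {"args": {...}},   # -> policy_kwargs
#               "pi_head_opts": {...},    # -> pi_head_kwargs
#           }
#       }
#   }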


def preference_based_RL_train(env_str, in_model, in_weights, out_weights):
    # Set up the MineRL environment
    minerl_env_str = "MineRLBasalt" + env_str
    env = gym.make(minerl_env_str + "-v0")

    # Set up the MineRL agent
    agent_policy_kwargs, agent_pi_head_kwargs = load_model_parameters(in_model)
    minerl_agent = MineRLAgent(
        env,
        device="cuda",
        policy_kwargs=agent_policy_kwargs,
        pi_head_kwargs=agent_pi_head_kwargs,
    )
    minerl_agent.load_weights(in_weights)

    # Freeze most parameters when finetuning on a small dataset
    for param in minerl_agent.policy.parameters():
        param.requires_grad = False
    # Unfreeze the final layer and the policy and value heads
    for param in minerl_agent.policy.net.lastlayer.parameters():
        param.requires_grad = True
    for param in minerl_agent.policy.pi_head.parameters():
        param.requires_grad = True
    for param in minerl_agent.policy.value_head.parameters():
        param.requires_grad = True
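    # After the freezing above, only `net.lastlayer` and the policy/value heads
    # should still receive gradients. Optional sanity check (sketch):
    #   n_trainable = sum(
    #       p.numel() for p in minerl_agent.policy.parameters() if p.requires_grad
    #   )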

    # Set up the MineRL VecEnv
    venv = make_vec_env(
        minerl_env_str + "SB3-v0",
        # Keep this at 1 since we are not keeping track of multiple hidden states
        n_envs=1,
        # This should be sufficiently high for the given task
        max_episode_steps=20,
        env_make_kwargs={"minerl_agent": minerl_agent},
    )

    # Set up preference-based reinforcement learning using the imitation package
    # TODO In general, check whether algorithm hyperparams make sense and tune
    # TODO reuse ImpalaCNN from the VPT models with a regression head
    image_obs_space = gym.spaces.Box(0, 255, shape=(128, 128, 3), dtype=np.uint8)
    reward_net = CnnRewardNet(image_obs_space, venv.action_space, use_action=False)
    normalized_reward_net = NormalizedRewardNet(reward_net, RunningNorm)
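    # NormalizedRewardNet wraps the CNN reward net so its predicted rewards are
    # standardized with running statistics (RunningNorm). The wrapped net is the
    # one fitted by the reward trainer below, while the unwrapped reward_net is
    # what the trajectory generator uses as its reward function.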

    preference_model = preference_comparisons.PreferenceModel(normalized_reward_net)

    # TODO design a more useful fragmenter for MineRL trajectories,
    # e.g. only compare the last parts of episodes
    fragmenter = preference_comparisons.MineRLFragmenter(warning_threshold=0, seed=0)

    gatherer = preference_comparisons.PrefCollectGatherer(
        pref_collect_address="http://127.0.0.1:8000",
        video_output_dir="/home/aicrowd/pref-collect/videofiles/",
    )
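    # The gatherer presumably renders each pair of trajectory fragments to video
    # files under video_output_dir and queries the pref-collect web app at
    # pref_collect_address for the human preference labels.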

    # TODO imitation also provides EnsembleTrainer
    # (which requires a RewardEnsemble), which trainer should we use?
    reward_trainer = preference_comparisons.BasicRewardTrainer(
        model=normalized_reward_net,
        loss=preference_comparisons.CrossEntropyRewardLoss(preference_model),
        epochs=3,
    )
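    # The cross-entropy loss treats each human label as the target for the
    # preference probability that the preference model derives from the reward
    # net's predicted returns on the two fragments (a Bradley-Terry-style model).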

    agent = PPO(
        policy=MinecraftActorCriticPolicy,
        policy_kwargs={
            "minerl_agent": minerl_agent,
            "optimizer_class": th.optim.Adam,
        },
        env=venv,
        seed=0,
        n_steps=512 // venv.num_envs,
        batch_size=32,
        ent_coef=0.0,
        learning_rate=0.0003,
        n_epochs=10,
    )
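    # MinecraftActorCriticPolicy receives the (partially unfrozen) VPT agent via
    # policy_kwargs so that PPO can finetune those weights directly; n_steps is
    # divided by num_envs so the rollout buffer stays at roughly 512 transitions.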

    trajectory_generator = preference_comparisons.MineRLAgentTrainer(
        algorithm=agent,
        reward_fn=reward_net,
        venv=venv,
        exploration_frac=0.0,
        seed=0,
    )

    pref_comparisons = preference_comparisons.PreferenceComparisons(
        trajectory_generator,
        reward_net,
        num_iterations=5,
        fragmenter=fragmenter,
        preference_gatherer=gatherer,
        reward_trainer=reward_trainer,
        fragment_length=10,
        transition_oversampling=1,
        initial_comparison_frac=0.1,
        allow_variable_horizon=True,
        seed=0,
        initial_epoch_multiplier=1,
    )
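    # Each of the num_iterations rounds roughly follows the usual preference-based
    # RL loop: roll out the PPO agent, cut the trajectories into fragments, gather
    # human preferences over fragment pairs, update the reward net, and continue
    # PPO training against the updated learned reward.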

    # Run training
    pref_comparisons.train(
        total_timesteps=500,  # For good performance this should be 1_000_000
        total_comparisons=10,  # For good performance this should be 5_000
    )

    venv.close()

    # Save finetuned weights
    state_dict = minerl_agent.policy.state_dict()
    th.save(state_dict, out_weights)

    print("Finished")


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument(
        "--env",
        type=str,
        help="Environment name from [FindCave, MakeWaterfall, \
            CreateVillageAnimalPen, BuildVillageHouse]",
        default="FindCave",
    )
    parser.add_argument(
        "--in-model",
        type=str,
        help="Path to the .model file to be finetuned",
        default="data/VPT-models/foundation-model-1x.model",
    )
    parser.add_argument(
        "--in-weights",
        type=str,
        help="Path to the .weights file to be finetuned",
        default="data/VPT-models/foundation-model-1x.weights",
    )
    parser.add_argument(
        "--out-weights",
        type=str,
        help="Path where finetuned weights will be saved",
        default="train/PrefRLFinetuned.weights",
    )
    args = parser.parse_args()

    preference_based_RL_train(
        args.env, args.in_model, args.in_weights, args.out_weights
    )
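
# Example invocation (paths are the defaults above), assuming the VPT model and
# weights exist under data/VPT-models/ and the pref-collect app is reachable at
# the address configured in the gatherer:
#
#   python preference_based_RL.py --env FindCave \
#       --in-model data/VPT-models/foundation-model-1x.model \
#       --in-weights data/VPT-models/foundation-model-1x.weights \
#       --out-weights train/PrefRLFinetuned.weights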