trainer.py
"""
Copyright (c) Meta Platforms, Inc. and affiliates.
This source code is licensed under the CC BY-NC license found in the
LICENSE.md file in the root directory of this source tree.
"""
import time

import numpy as np
import torch
class SequenceTrainer:
    def __init__(
        self,
        model,
        optimizer,
        log_temperature_optimizer,
        scheduler=None,
        device="cuda",
    ):
        self.model = model
        self.optimizer = optimizer
        self.log_temperature_optimizer = log_temperature_optimizer
        self.scheduler = scheduler
        self.device = device
        self.start_time = time.time()
    def train_iteration(
        self,
        loss_fn,
        dataloader,
    ):
        losses, nlls, entropies = [], [], []
        logs = dict()
        train_start = time.time()

        self.model.train()
        for _, trajs in enumerate(dataloader):
            loss, nll, entropy = self.train_step_stochastic(loss_fn, trajs)
            losses.append(loss)
            nlls.append(nll)
            entropies.append(entropy)

        logs["time/training"] = time.time() - train_start
        logs["training/train_loss_mean"] = np.mean(losses)
        logs["training/train_loss_std"] = np.std(losses)
        logs["training/nll"] = nlls[-1]
        logs["training/entropy"] = entropies[-1]
        logs["training/temp_value"] = self.model.temperature().detach().cpu().item()

        return logs
    def train_step_stochastic(self, loss_fn, trajs):
        (
            states,
            actions,
            rewards,
            dones,
            rtg,
            timesteps,
            ordering,
            padding_mask,
        ) = trajs

        states = states.to(self.device)
        actions = actions.to(self.device)
        rewards = rewards.to(self.device)
        dones = dones.to(self.device)
        rtg = rtg.to(self.device)
        timesteps = timesteps.to(self.device)
        ordering = ordering.to(self.device)
        padding_mask = padding_mask.to(self.device)

        action_target = torch.clone(actions)

        _, action_preds, _ = self.model.forward(
            states,
            actions,
            rewards,
            rtg[:, :-1],
            timesteps,
            ordering,
            padding_mask=padding_mask,
        )

        loss, nll, entropy = loss_fn(
            action_preds,  # a_hat_dist
            action_target,
            padding_mask,
            self.model.temperature().detach(),  # no gradient taken here
        )

        # Update the model parameters with gradient clipping.
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.25)
        self.optimizer.step()

        # Dual update of the entropy temperature toward the target entropy;
        # the entropy term is detached so only the temperature receives gradient.
        self.log_temperature_optimizer.zero_grad()
        temperature_loss = (
            self.model.temperature() * (entropy - self.model.target_entropy).detach()
        )
        temperature_loss.backward()
        self.log_temperature_optimizer.step()

        if self.scheduler is not None:
            self.scheduler.step()

        return (
            loss.detach().cpu().item(),
            nll.detach().cpu().item(),
            entropy.detach().cpu().item(),
        )
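

# ---------------------------------------------------------------------------
# Illustrative usage sketch: a minimal smoke test showing how SequenceTrainer
# might be driven. The stub model and loss below are hypothetical stand-ins
# that only mimic the interface the trainer expects -- forward(...) returning
# (_, action_preds, _), temperature(), and target_entropy -- and are not the
# Decision Transformer model or stochastic loss used in this repository.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch.nn as nn

    state_dim, act_dim, seq_len, batch_size = 4, 2, 8, 3

    class _StubModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.head = nn.Linear(state_dim, act_dim)
            self.log_temperature = nn.Parameter(torch.zeros(1))
            self.target_entropy = -act_dim

        def temperature(self):
            return self.log_temperature.exp()

        def forward(
            self, states, actions, rewards, rtg, timesteps, ordering, padding_mask=None
        ):
            # Predict actions directly from states; ignore the other inputs.
            return None, self.head(states), None

    def _stub_loss_fn(action_preds, action_target, padding_mask, temperature):
        # Masked MSE as a stand-in for the stochastic (NLL + entropy) loss.
        mask = padding_mask.unsqueeze(-1)
        nll = (((action_preds - action_target) ** 2) * mask).sum() / mask.sum()
        entropy = torch.tensor(0.0)  # placeholder entropy estimate
        loss = nll - temperature * entropy
        return loss, nll, entropy

    batch = (
        torch.randn(batch_size, seq_len, state_dim),  # states
        torch.randn(batch_size, seq_len, act_dim),    # actions
        torch.randn(batch_size, seq_len, 1),          # rewards
        torch.zeros(batch_size, seq_len),             # dones
        torch.randn(batch_size, seq_len + 1, 1),      # returns-to-go
        torch.arange(seq_len).repeat(batch_size, 1),  # timesteps
        torch.arange(seq_len).repeat(batch_size, 1),  # ordering
        torch.ones(batch_size, seq_len),              # padding_mask
    )

    model = _StubModel()
    trainer = SequenceTrainer(
        model=model,
        optimizer=torch.optim.AdamW(model.head.parameters(), lr=1e-3),
        log_temperature_optimizer=torch.optim.Adam([model.log_temperature], lr=1e-3),
        device="cpu",
    )
    print(trainer.train_iteration(loss_fn=_stub_loss_fn, dataloader=[batch]))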