-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdecision_transformer.py
85 lines (65 loc) · 3.28 KB
/
decision_transformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import torch
import torch.nn as nn
import numpy as np
from transformer.llama import LlamaModel
class DecisionTransformer(nn.Module):
    """Transformer that autoregressively scores discretized action components.

    Every one of the ``action_dim`` action components is discretized into
    ``action_bins`` bins.  The token sequence fed to the transformer is
    ``seq_len`` state tokens (state embedding + timestep embedding) followed by
    embeddings of the first ``action_dim - 1`` action components.  The hidden
    states at the last ``action_dim`` positions are projected to per-bin scores
    ("q_values").

    Args:
        state_dim: Size of the last dimension of ``states``.
        action_dim: Number of action components predicted per step.
        hidden_dim: Transformer hidden size.
        action_bins: Number of discrete bins per action component.
        seq_len: Number of state tokens per input sequence.
        n_layers: Number of transformer layers.
        n_heads: Number of attention heads.
        max_ep_len: Size of the timestep embedding table; timesteps must lie in
            ``[0, max_ep_len)``.
        dueling: Accepted for API compatibility; currently unused.
        device: Optional device.  When given, the mask / position buffers are
            created on it and ``predict_action`` moves its inputs there.
    """

    def __init__(self, state_dim, action_dim, hidden_dim, action_bins, seq_len, n_layers=3, n_heads=4,
                 max_ep_len=1000, dueling=False, device=None):
        super().__init__()
        # Full token count: seq_len state tokens + (action_dim - 1) action tokens.
        mask_size = action_dim + seq_len - 1
        self.transformer = LlamaModel(hidden_dim, num_heads=n_heads, max_position=mask_size,
                                      num_layers=n_layers)
        self.action_dim = action_dim
        self.state_emb = nn.Linear(state_dim, hidden_dim)
        self.out = nn.Linear(hidden_dim, action_bins)
        self.action_emb = nn.Embedding(action_bins, hidden_dim)
        # NOTE(review): return_emb / reward_emb are never used by forward(); they are
        # kept so existing checkpoints (state_dict keys) still load.
        self.return_emb = nn.Embedding(1, hidden_dim)
        self.reward_emb = nn.Embedding(1, hidden_dim)
        self.time_emb = nn.Embedding(max_ep_len, hidden_dim)
        self.device = device

        attn_mask = self._build_causal_mask(mask_size)
        pos_ids = torch.arange(mask_size, dtype=torch.long).unsqueeze(0)
        if device:
            attn_mask = attn_mask.to(device)
            pos_ids = pos_ids.to(device)
        # Non-persistent buffers: they follow model.to(...)/.cuda() together with the
        # parameters, but are not written to (or expected in) the state_dict.
        self.register_buffer("attn_mask", attn_mask, persistent=False)
        self.register_buffer("pos_ids", pos_ids, persistent=False)
        self.apply(self._init_weights)

    @staticmethod
    def _build_causal_mask(size):
        """Return an additive causal attention mask of shape ``(1, 1, size, size)``.

        Entries strictly above the diagonal (future positions) hold the most
        negative float32 value; all other entries are zero.
        """
        future = torch.ones(size, size, dtype=torch.bool).triu(diagonal=1)
        return (future * torch.finfo(torch.float32).min).unsqueeze(0).unsqueeze(0)

    def _init_weights(self, module):
        """Initialize the weights of ``module`` in place.

        Linear and Embedding weights are drawn from N(0, 0.02); Linear biases and
        Embedding padding rows are zeroed; LayerNorm is reset to identity
        (weight 1, bias 0).  Other module types are left untouched.
        """
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, states, actions, rewards, returns, timesteps):
        """Compute per-bin scores for every action component.

        Args:
            states: Float tensor; presumably ``(batch, seq_len, state_dim)``.
            actions: Integer tensor ``(batch, action_dim)`` of bin indices.  Only
                the first ``action_dim - 1`` components are embedded.
            rewards: Unused.
            returns: Unused.
            timesteps: Integer tensor indexing ``time_emb``; must match the
                leading dimensions of ``states``.

        Returns:
            Scores for the last ``action_dim`` positions, i.e.
            ``(batch, action_dim, action_bins)`` assuming the transformer returns
            hidden states shaped ``(batch, tokens, hidden_dim)``.
        """
        state_tokens = self.state_emb(states) + self.time_emb(timesteps)
        if self.action_dim > 1:
            # With the causal mask, output position i can only see action
            # components < i, so component i is scored conditioned on earlier ones.
            action_tokens = self.action_emb(actions[:, :-1])
            tokens = torch.cat((state_tokens, action_tokens), dim=1)
        else:
            tokens = state_tokens
        n_tokens = tokens.shape[1]
        # Slice the precomputed mask / positions to the real token count so shorter
        # inputs also work; for full-length input this is the whole buffer.
        hidden = self.transformer(tokens,
                                  attention_mask=self.attn_mask[:, :, :n_tokens, :n_tokens],
                                  position_ids=self.pos_ids[:, :n_tokens])
        return self.out(hidden[:, -self.action_dim:])

    def predict_action(self, states, actions, rewards, returns, timesteps):
        """Greedily decode one action, component by component.

        The ``actions`` argument is ignored: decoding starts from an all-zero
        action and, for ``i = 0 .. action_dim - 1``, sets component ``i`` to the
        argmax bin of the scores at output position ``i``.

        Returns:
            ``numpy.ndarray`` of int32 bin indices, shape ``(batch, action_dim)``.
        """
        device = self.device
        if device is None:
            # No explicit device: run wherever the model's parameters live.
            device = next(self.parameters()).device
        states = states.to(device)
        timesteps = timesteps.to(device)
        decoded = torch.zeros((states.shape[0], self.action_dim), dtype=torch.int32, device=device)
        with torch.no_grad():
            for i in range(self.action_dim):
                # forward() needs rewards and returns as well; omitting them used to
                # raise TypeError on every call.
                q_values = self.forward(states, decoded, rewards, returns, timesteps)
                # q_values has exactly action_dim positions on dim 1, so position i
                # is the same as the original index -action_dim + i.
                decoded[:, i] = torch.argmax(q_values[:, i], dim=-1)
        return decoded.cpu().numpy()