Skip to content

Commit

Permalink
finally approaching the main contribution of the paper. to be finishe…
Browse files Browse the repository at this point in the history
…d tomorrow morning
  • Loading branch information
lucidrains committed Nov 29, 2023
1 parent 8bfb82c commit 4669c7c
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 1 deletion.
37 changes: 36 additions & 1 deletion q_transformer/q_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ def __init__(
):
super().__init__()

self.is_multiple_actions = model.num_actions > 1

# q-learning related hyperparameters

self.discount_factor_gamma = discount_factor_gamma
Expand Down Expand Up @@ -358,6 +360,34 @@ def n_step_q_learn(

return loss, QIntermediates(q_pred_all_actions, q_pred, q_next, q_target)

def autoregressive_q_learn(
self,
instructions: Tuple[str],
states: TensorType['b', 't', 'c', 'f', 'h', 'w', float],
actions: TensorType['b', 't', 'n', int],
next_states: TensorType['b', 'c', 'f', 'h', 'w', float],
rewards: TensorType['b', 't', float],
dones: TensorType['b', 't', bool],
*,
monte_carlo_return = None

) -> Tuple[Tensor, QIntermediates]:
"""
einops
b - batch
c - channels
f - frames
h - height
w - width
t - timesteps
n - number of actions
a - action bins
q - q values
"""

raise NotImplementedError

def learn(
self,
*args,
Expand All @@ -374,9 +404,14 @@ def learn(

# main q-learning loss, whether single or n-step

if self.n_step_q_learning:
if self.is_multiple_actions:
td_loss, q_intermediates = self.autoregressive_q_learn(*args, **q_learn_kwargs)
num_timesteps = actions.shape[1]

elif self.n_step_q_learning:
td_loss, q_intermediates = self.n_step_q_learn(*args, **q_learn_kwargs)
num_timesteps = actions.shape[1]

else:
td_loss, q_intermediates = self.q_learn(*args, **q_learn_kwargs)
num_timesteps = 1
Expand Down
1 change: 1 addition & 0 deletions q_transformer/robotic_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -807,6 +807,7 @@ def __init__(

assert num_actions >= 1

self.num_actions = num_actions
self.is_single_action = num_actions == 1
self.action_bins = action_bins

Expand Down

0 comments on commit 4669c7c

Please sign in to comment.