Skip to content

Commit

Permalink
fix non-nstep
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Nov 28, 2023
1 parent 641c1de commit e624c4a
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
8 changes: 5 additions & 3 deletions q_transformer/q_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@
def exists(val):
    """Return True when *val* carries a value (i.e. is not None)."""
    return not (val is None)

def default(val, d):
    """Return *val* when it is set (not None), otherwise the fallback *d*."""
    if exists(val):
        return val
    return d

def is_divisible(num, den):
    """Return True when *num* divides evenly by *den* (zero remainder)."""
    _, remainder = divmod(num, den)
    return remainder == 0

Expand Down Expand Up @@ -266,7 +269,6 @@ def q_learn(
# the max Q value is taken as the optimal action is implicitly the one with the highest Q score

q_next = self.ema_model(next_states, instructions).amax(dim = -1)

q_next = q_next.clamp(min = default(monte_carlo_return, 1e4))

# Bellman's equation. most important line of code, hopefully done correctly
Expand All @@ -277,10 +279,10 @@ def q_learn(

loss = F.mse_loss(q_pred, q_target)

# that's it. 4 loc for the heart of q-learning
# that's it. ~5 loc for the heart of q-learning
# return loss and some of the intermediates for logging

return loss, QIntermediates(q_pred, q_next, q_target)
return loss, QIntermediates(q_pred_all_actions, q_pred, q_next, q_target)

def n_step_q_learn(
self,
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'q-transformer',
packages = find_packages(exclude=[]),
version = '0.0.15',
version = '0.0.16',
license='MIT',
description = 'Q-Transformer',
author = 'Phil Wang',
Expand Down

0 comments on commit e624c4a

Please sign in to comment.