Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
markub3327 committed Nov 17, 2023
1 parent c8ce6f0 commit 0f4f7ca
Showing 1 changed file with 7 additions and 8 deletions.
15 changes: 7 additions & 8 deletions rl_toolkit/networks/models/dueling.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,9 +200,7 @@ def _compute_n_step_rewards(
n = tf.shape(rewards)[1]

# Create a discount factor tensor
discounts = tf.constant(
[discount_factor ** float(j) for j in tf.range(n)], dtype=rewards.dtype
)
discounts = discount_factor ** tf.range(n + 1, dtype=rewards.dtype)
print(f"discounts: {discounts}")

# Pad the rewards tensor to ensure proper handling of the last elements in each sequence
Expand All @@ -213,7 +211,7 @@ def _compute_n_step_rewards(
windows = tf.TensorArray(
dtype=rewards.dtype,
size=n,
element_shape=(tf.shape(rewards)[0], tf.shape(rewards)[1]),
# element_shape=(tf.shape(rewards)[0], tf.shape(rewards)[1]),
)

for i in tf.range(n):
Expand All @@ -226,7 +224,7 @@ def _compute_n_step_rewards(
print(f"rewards_windows: {rewards_windows}")

# Multiply each window by the corresponding discount factor
discounted_windows = rewards_windows * discounts
discounted_windows = rewards_windows * discounts[:-1]
print(f"discounted_windows: {discounted_windows}")

# Sum along the time axis to get the n-step rewards
Expand All @@ -235,9 +233,10 @@ def _compute_n_step_rewards(

# Add the next state value with discount
n_step_rewards += (
(1.0 - is_terminal)
* discount_factor ** float(n)
* next_state_value[:, tf.newaxis]
(1.0 - is_terminal[:, tf.newaxis])
* tf.reverse(discounts[1:], axis=[0])[
tf.newaxis, :
] * next_state_value[:, tf.newaxis]
)

return n_step_rewards
Expand Down

0 comments on commit 0f4f7ca

Please sign in to comment.