Skip to content

Commit

Permalink
Merge pull request #159 from WilfChen/c51-range-error-fix
Browse files Browse the repository at this point in the history
fix c51 range type error
  • Loading branch information
WilfChen authored Jan 13, 2024
2 parents 4734caa + 848f03f commit 7fea931
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions mindspore_rl/algorithm/c51/c51.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,10 +241,9 @@ def next_distribution(self, next_observation, batch_size):
next_action = self.get_max_index(next_target_q_values)[0]
next_qt_argmax = self.expand_dims(next_action, 1)
next_qt_argmax = self.cast(next_qt_argmax, ms.int32)
batch_indices = self.get_range(
Tensor(0, ms.int32), Tensor(batch_size, ms.int32), Tensor(1, ms.int32)
)
batch_indices = self.get_range(0, batch_size, 1)
batch_indices = self.expand_dims(batch_indices, 1).reshape(batch_size, 1)
batch_indices = ops.cast(batch_indices, ms.int32)
next_qt_index = self.concat((batch_indices, next_qt_argmax))
return self.gather_nd(next_target_probabilities, next_qt_index)

Expand Down

0 comments on commit 7fea931

Please sign in to comment.