Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Jun 18, 2024
1 parent e15b1da commit 49f77f6
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion torchrl/envs/libs/brax.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,18 +663,22 @@ def _make_none(key, val):

# call vjp to get gradients
grad_state, grad_action = ctx.vjp_fn(grad_next_state_flat)
# assert grad_action.device == ctx.env.device

# reshape batch size
grad_state = _tree_reshape(grad_state, ctx.env.batch_size)
grad_action = _tree_reshape(grad_action, ctx.env.batch_size)
# assert grad_action.device == ctx.env.device

# convert ndarrays to tensors
grad_state_qp = _object_to_tensordict(
grad_state.pipeline_state,
device=ctx.env.device,
batch_size=ctx.env.batch_size,
)
grad_action = _ndarray_to_tensor(grad_action)
grad_action_np = _ndarray_to_tensor(grad_action)
assert grad_action.device == ctx.env.device, grad_action
grad_action = grad_action_np
grad_state_qp = {
key: val if key not in none_keys else None
for key, val in grad_state_qp.items()
Expand Down

0 comments on commit 49f77f6

Please sign in to comment.