Skip to content

Commit 03ffc82

Browse files
committed
[Quality] Fix flaky test
ghstack-source-id: 25d4a5e Pull-Request: #3211
1 parent 01d2801 commit 03ffc82

File tree

4 files changed

+40
-7
lines changed

4 files changed

+40
-7
lines changed

test/test_collector.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ def forward(self, observation):
162162
output = self.linear(observation)
163163
if self.multiple_outputs:
164164
return output, output.sum(), output.min(), output.max()
165-
return self.linear(observation)
165+
return output
166166

167167

168168
class UnwrappablePolicy(nn.Module):
@@ -1512,6 +1512,7 @@ def create_env():
15121512
cudagraph_policy=cudagraph,
15131513
weight_sync_schemes={"policy": MultiProcessWeightSyncScheme()},
15141514
)
1515+
assert "policy" in collector._weight_senders, collector._weight_senders.keys()
15151516
try:
15161517
# collect state_dict
15171518
state_dict = collector.state_dict()

test/test_env.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3836,6 +3836,8 @@ def test_parallel(self, bwad, use_buffers, maybe_fork_ParallelEnv):
38363836
finally:
38373837
env.close(raise_if_closed=False)
38383838
del env
3839+
time.sleep(0.1)
3840+
gc.collect()
38393841

38403842
class AddString(Transform):
38413843
def __init__(self):

torchrl/collectors/collectors.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,19 @@ def _extract_weights_if_needed(self, weights: Any, model_id: str) -> Any:
307307
else None
308308
)
309309

310+
# If no weights were provided and a sync scheme exists, extract the latest
311+
# weights from the current model using the scheme strategy (state_dict or tensordict).
312+
# This ensures we don't return stale cached weights.
313+
if weights is None and scheme is not None:
314+
from torchrl.weight_update.weight_sync_schemes import (
315+
_resolve_model,
316+
WeightStrategy,
317+
)
318+
319+
strategy = WeightStrategy(extract_as=scheme.strategy)
320+
model = _resolve_model(self, model_id)
321+
return strategy.extract_weights(model)
322+
310323
if weights is None:
311324
if model_id == "policy" and hasattr(self, "policy_weights"):
312325
return self.policy_weights
@@ -462,6 +475,21 @@ def update_policy_weights_(
462475
# Apply to local policy
463476
if hasattr(self, "policy") and isinstance(self.policy, nn.Module):
464477
strategy.apply_weights(self.policy, weights)
478+
elif (
479+
hasattr(self, "_original_policy")
480+
and isinstance(self._original_policy, nn.Module)
481+
and hasattr(self, "policy")
482+
and isinstance(self.policy, nn.Module)
483+
):
484+
# If no weights were provided, mirror weights from the original (trainer) policy
485+
from torchrl.weight_update.weight_sync_schemes import WeightStrategy
486+
487+
strategy = WeightStrategy(extract_as="tensordict")
488+
weights = strategy.extract_weights(self._original_policy)
489+
# Cast weights to the policy device before applying
490+
if self.policy_device is not None:
491+
weights = weights.to(self.policy_device)
492+
strategy.apply_weights(self.policy, weights)
465493
# Otherwise, no action needed - policy is local and changes are immediately visible
466494

467495
def __iter__(self) -> Iterator[TensorDictBase]:

torchrl/envs/batched_envs.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2489,14 +2489,15 @@ def look_for_cuda(tensor, has_cuda=has_cuda):
24892489
# Make sure the root is updated
24902490
root_shared_tensordict.update_(env._step_mdp(input))
24912491

2492+
# Set event before sending non-tensor data so parent knows worker is done
2493+
# The recv() call itself will provide synchronization for the pipe
2494+
mp_event.set()
2495+
24922496
if _non_tensor_keys:
24932497
child_pipe.send(
24942498
("non_tensor", next_td.select(*_non_tensor_keys, strict=False))
24952499
)
24962500

2497-
# Set event only after non-tensor data is sent to avoid race condition
2498-
mp_event.set()
2499-
25002501
del next_td
25012502

25022503
elif cmd == "step_and_maybe_reset":
@@ -2530,14 +2531,15 @@ def look_for_cuda(tensor, has_cuda=has_cuda):
25302531
event.record()
25312532
event.synchronize()
25322533

2534+
# Set event before sending non-tensor data so parent knows worker is done
2535+
# The recv() call itself will provide synchronization for the pipe
2536+
mp_event.set()
2537+
25332538
if _non_tensor_keys:
25342539
ntd = root_next_td.select(*_non_tensor_keys)
25352540
ntd.set("next", td_next.select(*_non_tensor_keys))
25362541
child_pipe.send(("non_tensor", ntd))
25372542

2538-
# Set event only after non-tensor data is sent to avoid race condition
2539-
mp_event.set()
2540-
25412543
del td, root_next_td
25422544

25432545
elif cmd == "close":

0 commit comments

Comments (0)