Commit 182f9c2

[Feature] Documentation
ghstack-source-id: d0d0cf5 Pull-Request: #3192
1 parent f7b32c1 commit 182f9c2

1 file changed: +242 −0

docs/source/reference/collectors.rst

Lines changed: 242 additions & 0 deletions

@@ -169,6 +169,248 @@ transformed, and applied, ensuring seamless integration with their existing infr
    RPCWeightUpdater
    DistributedWeightUpdater

Weight Synchronization API
~~~~~~~~~~~~~~~~~~~~~~~~~~

The weight synchronization API provides a simple, modular approach to updating model weights across
distributed collectors. This system is designed to handle the complexities of modern RL setups, where multiple
models may need to be synchronized independently.

Overview
^^^^^^^^

In reinforcement learning, particularly with multi-process data collection, it is essential to keep the inference
policies synchronized with the latest trained weights. The API addresses this challenge through a clean
separation of concerns, where four classes are involved:

- **Configuration**: :class:`~torchrl.weight_update.weight_sync_schemes.WeightSyncScheme` objects define *what* to synchronize and *how*. For data collectors, this is
  your main entry point for configuring weight synchronization (see the snippet below).
- **Sending**: :class:`~torchrl.weight_update.weight_sync_schemes.WeightSender` handles distributing weights from the main process to the workers.
- **Receiving**: :class:`~torchrl.weight_update.weight_sync_schemes.WeightReceiver` handles applying weights in the worker processes.
- **Transport**: Backend-specific communication mechanisms (pipes, shared memory, Ray, RPC).

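Configuring a scheme is a one-liner; the snippet below mirrors the usage example further down this page:

.. code-block:: python

    from torchrl.weight_update.weight_sync_schemes import MultiProcessWeightSyncScheme

    # One scheme per model to synchronize; the collector derives the matching
    # sender/receiver pair and the transport from it.
    scheme = MultiProcessWeightSyncScheme(strategy="tensordict")
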
The following diagram shows the different classes involved in the weight synchronization process:

.. aafig::
    :aspect: 60
    :scale: 130
    :proportional:

    INITIALIZATION PHASE
    ====================

                 WeightSyncScheme
                +------------------+
                |                  |
                | Configuration:   |
                | - strategy       |
                | - transport_type |
                |                  |
                +--------+---------+
                         |
            +------------+-------------+
            |                          |
         creates                    creates
            |                          |
            v                          v
       Main Process              Worker Process
      +--------------+          +---------------+
      | WeightSender |          | WeightReceiver|
      |              |          |               |
      | - strategy   |          | - strategy    |
      | - transports |          | - transport   |
      | - model_ref  |          | - model_ref   |
      |              |          |               |
      | Registers:   |          | Registers:    |
      | - model      |          | - model       |
      | - workers    |          | - transport   |
      +------+-------+          +-------+-------+
             |                          |
             |      Transport Layer     |
             |     +----------------+   |
             +---->| MPTransport    |<--+
             |     | (pipes)        |   |
             |     +----------------+   |
             |     +----------------+   |
             +---->| SharedMemTrans |<--+
             |     | (shared mem)   |   |
             |     +----------------+   |
             |     +----------------+   |
             +---->| RayTransport   |<--+
                   | (Ray store)    |
                   +----------------+


    SYNCHRONIZATION PHASE
    =====================

     Main Process                               Worker Process

    +-------------------+                     +-------------------+
    | WeightSender      |                     | WeightReceiver    |
    |                   |                     |                   |
    | 1. Extract        |                     | 4. Poll transport |
    |    weights from   |                     |    for weights    |
    |    model using    |                     |                   |
    |    strategy       |                     |                   |
    |                   |  2. Send via        |                   |
    | +-------------+   |     Transport       | +--------------+  |
    | | Strategy    |   |  +------------+     | | Strategy     |  |
    | | extract()   |   |  |            |     | | apply()      |  |
    | +-------------+   +--+ Transport  +---->| +--------------+  |
    |       |           |  |            |     |        |          |
    |       v           |  +------------+     |        v          |
    | +-------------+   |                     | +--------------+  |
    | | Model       |   |                     | | Model        |  |
    | | (source)    |   | 3. Ack (optional)   | | (dest)       |  |
    | +-------------+   |<--------------------+ +--------------+  |
    |                   |                     |                   |
    +-------------------+                     | 5. Apply weights  |
                                              |    to model using |
                                              |    strategy       |
                                              +-------------------+

Key Challenges Addressed
^^^^^^^^^^^^^^^^^^^^^^^^

Modern RL training often involves multiple models that need independent synchronization:

1. **Multiple Models Per Collector**: A collector might need to update:

   - The main policy network
   - A value network in a Ray actor within the replay buffer
   - Models embedded in the environment itself
   - Separate world models or auxiliary networks

2. **Different Update Strategies**: Each model may require a different synchronization approach (the two weight formats are sketched after this list):

   - Full ``state_dict`` transfer vs. TensorDict-based updates
   - Different transport mechanisms (multiprocessing pipes, shared memory, Ray object store, collective communication, RDMA, etc.)
   - Varied update frequencies

3. **Worker-Agnostic Updates**: Some models (like those in shared Ray actors) shouldn't be tied to
   specific worker indices, requiring a more flexible update mechanism.

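The difference between the two weight formats can be sketched as follows (illustrative only; the actual extraction is performed by the strategy objects used by senders and receivers):

.. code-block:: python

    import torch.nn as nn
    from tensordict import TensorDict

    model = nn.Linear(4, 2)

    # "state_dict" strategy: a flat mapping from parameter names to tensors
    flat_weights = {k: v.detach().cpu() for k, v in model.state_dict().items()}

    # "tensordict" strategy: a nested TensorDict mirroring the module tree
    td_weights = TensorDict.from_module(model).detach().cpu()
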
Architecture
^^^^^^^^^^^^

The API follows a scheme-based design where users specify synchronization requirements upfront,
and the collector handles the orchestration transparently:

.. aafig::
    :aspect: 60
    :scale: 130
    :proportional:

     Main Process              Worker Process 1         Worker Process 2

    +-----------------+       +---------------+        +---------------+
    | Collector       |       | Collector     |        | Collector     |
    |                 |       |               |        |               |
    | Models:         |       | Models:       |        | Models:       |
    | +----------+    |       | +--------+    |        | +--------+    |
    | | Policy A |    |       | |Policy A|    |        | |Policy A|    |
    | +----------+    |       | +--------+    |        | +--------+    |
    | +----------+    |       | +--------+    |        | +--------+    |
    | | Model B  |    |       | |Model B |    |        | |Model B |    |
    | +----------+    |       | +--------+    |        | +--------+    |
    |                 |       |               |        |               |
    | Weight Senders: |       | Weight        |        | Weight        |
    | +----------+    |       | Receivers:    |        | Receivers:    |
    | | Sender A +----+-------+-> Receiver A  |        |   Receiver A  |
    | +----------+    |       |               |        |               |
    | +----------+    |       |               |        |               |
    | | Sender B +----+-------+-> Receiver B  |        |   Receiver B  |
    | +----------+    | Pipes |               | Pipes  |               |
    +-----------------+       +-------+-------+        +-------+-------+
            ^                         ^                        ^
            |                         |                        |
            | update_policy_weights_()|     Apply weights      |
            |                         |                        |
     +------+-------+                 |                        |
     | User Code    |                 |                        |
     | (Training)   |                 |                        |
     +--------------+                 +------------------------+

The weight synchronization flow (a minimal end-to-end sketch follows the list):

1. **Initialization**: The user creates a ``weight_sync_schemes`` dict mapping model IDs to schemes
2. **Registration**: The collector creates a ``WeightSender`` for each model in the main process
3. **Worker Setup**: Each worker creates the corresponding ``WeightReceiver`` instances
4. **Synchronization**: Calling ``update_policy_weights_()`` triggers all senders to push weights
5. **Application**: Receivers automatically apply the weights to their registered models

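For intuition, the numbered steps map onto a bare-bones multiprocessing exchange like the sketch below. This is a conceptual illustration only, not TorchRL's actual transport code; in practice the transports and the sender/receiver classes encapsulate this logic:

.. code-block:: python

    import torch.nn as nn
    from torch import multiprocessing as mp


    def worker(pipe, model):
        # Steps 4-5: poll the transport for weights, then apply them locally
        weights = pipe.recv()
        model.load_state_dict(weights)
        pipe.send("ack")  # step 3: optional acknowledgement


    if __name__ == "__main__":
        source, dest = nn.Linear(4, 2), nn.Linear(4, 2)
        parent_end, child_end = mp.Pipe()
        proc = mp.Process(target=worker, args=(child_end, dest))
        proc.start()
        # Steps 1-2: extract the weights and push them through the transport
        parent_end.send({k: v.detach() for k, v in source.state_dict().items()})
        assert parent_end.recv() == "ack"
        proc.join()
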
Available Classes
^^^^^^^^^^^^^^^^^

**Synchronization Schemes** (user-facing configuration; see the snippet after this list):

- :class:`~torchrl.weight_update.weight_sync_schemes.WeightSyncScheme`: Base class for schemes
- :class:`~torchrl.weight_update.weight_sync_schemes.MultiProcessWeightSyncScheme`: For multiprocessing with pipes
- :class:`~torchrl.weight_update.weight_sync_schemes.SharedMemWeightSyncScheme`: For shared-memory synchronization
- :class:`~torchrl.weight_update.weight_sync_schemes.RayWeightSyncScheme`: For Ray-based distribution
- :class:`~torchrl.weight_update.weight_sync_schemes.NoWeightSyncScheme`: Dummy scheme that performs no synchronization

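A hedged sketch of how several schemes might be combined in one ``weight_sync_schemes`` dict. Only ``MultiProcessWeightSyncScheme(strategy=...)`` appears in the usage example below; the argument-free constructors for the other schemes are an assumption here:

.. code-block:: python

    from torchrl.weight_update.weight_sync_schemes import (
        MultiProcessWeightSyncScheme,
        NoWeightSyncScheme,
        RayWeightSyncScheme,
    )

    weight_sync_schemes = {
        "policy": MultiProcessWeightSyncScheme(strategy="tensordict"),
        "value_net": RayWeightSyncScheme(),   # e.g., a model living in a Ray actor
        "world_model": NoWeightSyncScheme(),  # explicitly never synchronized
    }
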
**Internal Classes** (automatically managed):

- :class:`~torchrl.weight_update.weight_sync_schemes.WeightSender`: Sends weights to all workers for one model
- :class:`~torchrl.weight_update.weight_sync_schemes.WeightReceiver`: Receives and applies weights in a worker
- :class:`~torchrl.weight_update.weight_sync_schemes.TransportBackend`: Communication-layer abstraction

Usage Example
^^^^^^^^^^^^^

.. code-block:: python

    from torchrl.collectors import MultiSyncDataCollector
    from torchrl.weight_update.weight_sync_schemes import MultiProcessWeightSyncScheme

    # Define synchronization for multiple models
    weight_sync_schemes = {
        "policy": MultiProcessWeightSyncScheme(strategy="tensordict"),
        "value_net": MultiProcessWeightSyncScheme(strategy="state_dict"),
    }

    collector = MultiSyncDataCollector(
        create_env_fn=[make_env] * 4,
        policy=policy,
        frames_per_batch=1000,
        weight_sync_schemes=weight_sync_schemes,  # pass the schemes dict
    )

    # A single call updates all registered models across all workers
    for i, batch in enumerate(collector):
        # Training step
        loss = train(batch)

        # Sync all models with one call
        collector.update_policy_weights_(policy)

    collector.shutdown()

The collector automatically:

- Creates ``WeightSender`` instances in the main process for each model
- Creates ``WeightReceiver`` instances in each worker process
- Resolves models by ID (e.g., ``"policy"`` → ``collector.policy``; sketched below)
- Handles transport setup and communication
- Applies weights using the appropriate strategy (``state_dict`` vs. ``tensordict``)

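The ID-based resolution can be pictured as a plain attribute lookup. The helper below is hypothetical and only illustrates the mapping; it is not the collector's actual implementation:

.. code-block:: python

    def resolve_model(collector, model_id: str):
        """Hypothetical sketch: map an ID such as "policy" to ``collector.policy``."""
        obj = collector
        for attr in model_id.split("."):  # dotted IDs would walk nested attributes
            obj = getattr(obj, attr)
        return obj
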
API Reference
^^^^^^^^^^^^^

.. currentmodule:: torchrl.weight_update.weight_sync_schemes

.. autosummary::
    :toctree: generated/
    :template: rl_template.rst

    WeightSyncScheme
    MultiProcessWeightSyncScheme
    SharedMemWeightSyncScheme
    RayWeightSyncScheme
    NoWeightSyncScheme
    WeightSender
    WeightReceiver

Collectors and replay buffers interoperability
----------------------------------------------
