You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: docs/source/reference/collectors.rst
+242Lines changed: 242 additions & 0 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -169,6 +169,248 @@ transformed, and applied, ensuring seamless integration with their existing infr
169
169
RPCWeightUpdater
170
170
DistributedWeightUpdater
171
171
172
+
Weight Synchronization API
173
+
~~~~~~~~~~~~~~~~~~~~~~~~~~
174
+
175
+
The weight synchronization API provides a simple, modular approach to updating model weights across
176
+
distributed collectors. This system is designed to handle the complexities of modern RL setups where multiple
177
+
models may need to be synchronized independently.
178
+
179
+
Overview
180
+
^^^^^^^^
181
+
182
+
In reinforcement learning, particularly with multi-process data collection, it's essential to keep the inference
183
+
policies synchronized with the latest trained weights. The API addresses this challenge through a clean
184
+
separation of concerns, where four classes are involved:
185
+
186
+
- **Configuration**: :class:`~torchrl.weight_update.weight_sync_schemes.WeightSyncScheme` objects define *what* to synchronize and *how*. For DataCollectors, this is
187
+
your main entrypoint to configure the weight synchronization.
188
+
- **Sending**: :class:`~torchrl.weight_update.weight_sync_schemes.WeightSender` handles distributing weights from the main process to workers.
189
+
- **Receiving**: :class:`~torchrl.weight_update.weight_sync_schemes.WeightReceiver` handles applying weights in worker processes.
190
+
- **Transport**: Backend-specific communication mechanisms (pipes, shared memory, Ray, RPC)
191
+
192
+
The following diagram shows the different classes involved in the weight synchronization process:
193
+
194
+
.. aafig::
195
+
:aspect: 60
196
+
:scale: 130
197
+
:proportional:
198
+
199
+
INITIALIZATION PHASE
200
+
====================
201
+
202
+
WeightSyncScheme
203
+
+------------------+
204
+
||
205
+
| Configuration: |
206
+
| - strategy |
207
+
| - transport_type |
208
+
||
209
+
+--------+---------+
210
+
|
211
+
+------------+-------------+
212
+
||
213
+
creates creates
214
+
||
215
+
v v
216
+
Main Process Worker Process
217
+
+--------------++---------------+
218
+
| WeightSender || WeightReceiver|
219
+
||||
220
+
| - strategy || - strategy |
221
+
| - transports || - transport |
222
+
| - model_ref || - model_ref |
223
+
||||
224
+
| Registers: || Registers: |
225
+
| - model || - model |
226
+
| - workers || - transport |
227
+
+--------------++---------------+
228
+
||
229
+
| Transport Layer |
230
+
|+----------------+|
231
+
+-->+ MPTransport |<------+
232
+
|| (pipes) ||
233
+
|+----------------+|
234
+
|+----------------+|
235
+
+-->+ SharedMemTrans |<------+
236
+
|| (shared mem) ||
237
+
|+----------------+|
238
+
|+----------------+|
239
+
+-->+ RayTransport |<------+
240
+
| (Ray store) |
241
+
+----------------+
242
+
243
+
244
+
SYNCHRONIZATION PHASE
245
+
=====================
246
+
247
+
Main Process Worker Process
248
+
249
+
+-------------------+ +-------------------+
250
+
|WeightSender | | WeightReceiver |
251
+
|| | |
252
+
|1. Extract | | 4. Poll transport |
253
+
|weights from | | for weights |
254
+
|model using | | |
255
+
|strategy | | |
256
+
|| 2. Send via | |
257
+
|+-------------+ | Transport | +--------------+ |
258
+
|| Strategy | | +------------+ | | Strategy | |
259
+
|| extract() | | | | | | apply() | |
260
+
|+-------------+ +----+ Transport +-------->+ +--------------+ |
0 commit comments