
Added functionality to map batches directly to episode lists in an iterator that can prefetch.

Signed-off-by: Simon Zehnder <[email protected]>
simonsays1980 committed May 27, 2024
1 parent da83264 commit 82ae5bd
Showing 3 changed files with 80 additions and 58 deletions.
4 changes: 3 additions & 1 deletion rllib/algorithms/bc/bc.py
@@ -160,7 +160,9 @@ def training_step(self) -> ResultDict:
         # for sampling. Only in online evaluation
         # `RolloutWorker/EnvRunner` should be used.
         episodes = self.offline_data.sample(
-            num_samples=self.config.train_batch_size
+            num_samples=self.config.train_batch_size,
+            num_shards=self.config.num_learners,
+            return_iterator=True if self.config.num_learners > 1 else False,
         )
 
         with self.metrics.log_time((TIMERS, LEARNER_UPDATE_TIMER)):
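With this change, `training_step()` asks `OfflineData.sample()` for one iterator shard per Learner whenever more than one Learner is configured, and for a plain episode list otherwise. A self-contained toy model of that contract, assuming the return types implied by this diff (all names below are stand-ins, not RLlib APIs):

```python
from typing import Iterator, List, Union

Episode = dict  # Stand-in for RLlib's SingleAgentEpisode.

def sample(
    num_samples: int, num_shards: int = 1, return_iterator: bool = False
) -> Union[List[Episode], List[Iterator[List[Episode]]]]:
    data = [{"id": i} for i in range(num_samples * max(num_shards, 1))]
    if return_iterator:
        # One shard of episode batches per Learner (round-robin split).
        return [iter([data[s::num_shards]]) for s in range(num_shards)]
    # Single-Learner path: a plain list of episodes.
    return data[:num_samples]

episodes = sample(4)                                    # local Learner path
shards = sample(4, num_shards=2, return_iterator=True)  # two remote Learners
assert len(episodes) == 4 and len(shards) == 2
```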
85 changes: 50 additions & 35 deletions rllib/core/learner/learner_group.py
@@ -393,42 +393,57 @@ def _learner_update(
         # Single- or MultiAgentEpisodes: Shard into equal pieces (only roughly equal
         # in case of multi-agent).
         else:
-            eps_shards = list(ShardEpisodesIterator(episodes, len(self._workers)))
-            # In the multi-agent case AND `minibatch_size` AND num_workers > 1, we
-            # compute a max iteration counter such that the different Learners will
-            # not go through a different number of iterations.
-            min_total_mini_batches = 0
-            if (
-                isinstance(episodes[0], MultiAgentEpisode)
-                and minibatch_size
-                and len(self._workers) > 1
-            ):
-                # Find episode w/ the largest single-agent episode in it, then
-                # compute this single-agent episode's total number of mini batches
-                # (if we iterated over it num_sgd_iter times with the mini batch
-                # size).
-                longest_ts = 0
-                per_mod_ts = defaultdict(int)
-                for i, shard in enumerate(eps_shards):
-                    for ma_episode in shard:
-                        for sa_episode in ma_episode.agent_episodes.values():
-                            key = (i, sa_episode.module_id)
-                            per_mod_ts[key] += len(sa_episode)
-                            if per_mod_ts[key] > longest_ts:
-                                longest_ts = per_mod_ts[key]
-                min_total_mini_batches = self._compute_num_total_mini_batches(
-                    batch_size=longest_ts,
-                    mini_batch_size=minibatch_size,
-                    num_iters=num_iters,
-                )
-            partials = [
-                partial(
-                    _learner_update,
-                    episodes_shard=eps_shard,
-                    min_total_mini_batches=min_total_mini_batches,
-                )
-                for eps_shard in eps_shards
-            ]
+            from ray.data.iterator import DataIterator
+
+            if isinstance(episodes[0], DataIterator):
+                min_total_mini_batches = 0
+                partials = [
+                    partial(
+                        _learner_update,
+                        episodes_shard=episodes_shard,
+                        min_total_mini_batches=min_total_mini_batches,
+                    )
+                    for episodes_shard in episodes
+                ]
+            else:
+                eps_shards = list(
+                    ShardEpisodesIterator(episodes, len(self._workers))
+                )
+                # In the multi-agent case AND `minibatch_size` AND num_workers
+                # > 1, we compute a max iteration counter such that the different
+                # Learners will not go through a different number of iterations.
+                min_total_mini_batches = 0
+                if (
+                    isinstance(episodes[0], MultiAgentEpisode)
+                    and minibatch_size
+                    and len(self._workers) > 1
+                ):
+                    # Find episode w/ the largest single-agent episode in it, then
+                    # compute this single-agent episode's total number of mini
+                    # batches (if we iterated over it num_sgd_iter times with the
+                    # mini batch size).
+                    longest_ts = 0
+                    per_mod_ts = defaultdict(int)
+                    for i, shard in enumerate(eps_shards):
+                        for ma_episode in shard:
+                            for sa_episode in ma_episode.agent_episodes.values():
+                                key = (i, sa_episode.module_id)
+                                per_mod_ts[key] += len(sa_episode)
+                                if per_mod_ts[key] > longest_ts:
+                                    longest_ts = per_mod_ts[key]
+                    min_total_mini_batches = self._compute_num_total_mini_batches(
+                        batch_size=longest_ts,
+                        mini_batch_size=minibatch_size,
+                        num_iters=num_iters,
+                    )
+                partials = [
+                    partial(
+                        _learner_update,
+                        episodes_shard=eps_shard,
+                        min_total_mini_batches=min_total_mini_batches,
+                    )
+                    for eps_shard in eps_shards
+                ]
 
         if async_update:
             # Retrieve all ready results (kicked off by prior calls to this method).
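The new branch keys on the type of the first element: pre-sharded `ray.data` iterators (as produced by `OfflineData.sample(...)` with `return_iterator=True` and multiple shards) map one-to-one onto the workers, while plain episode lists still go through `ShardEpisodesIterator`. A condensed sketch of that dispatch; the helper name and the round-robin split are illustrative stand-ins, not the actual `ShardEpisodesIterator` logic:

```python
from functools import partial
from typing import Any, Callable, List

def build_update_partials(
    episodes: List[Any],
    num_workers: int,
    learner_update: Callable[..., Any],
) -> List[partial]:
    """Simplified stand-in for the dispatch above (name is hypothetical)."""
    try:
        from ray.data.iterator import DataIterator
        prefetched = isinstance(episodes[0], DataIterator)
    except ImportError:  # Keep the sketch runnable without Ray installed.
        prefetched = False

    if prefetched:
        # Already one prefetching iterator shard per worker: pass through.
        shards = episodes
    else:
        # Round-robin split as a stand-in for `ShardEpisodesIterator`.
        shards = [episodes[i::num_workers] for i in range(num_workers)]
    return [partial(learner_update, episodes_shard=s) for s in shards]

# Example: two workers, six episodes -> two shards of three episodes each.
partials = build_update_partials(
    list(range(6)), 2, lambda episodes_shard: episodes_shard
)
assert [p() for p in partials] == [[0, 2, 4], [1, 3, 5]]
```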
49 changes: 27 additions & 22 deletions rllib/offline/offline_data.py
@@ -1,3 +1,4 @@
+import functools
 import logging
 import numpy as np
 from pathlib import Path
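`functools` is imported for `functools.partial`, which binds the `is_multi_agent` flag of the new static mapper before it is handed to `Dataset.map_batches()`; `map_batches()` invokes its function with the batch as the only argument. A tiny illustration of that binding (`mapper` is a hypothetical stand-in):

```python
import functools

def mapper(is_multi_agent: bool, batch: dict) -> dict:
    # Stand-in for `OfflineData._map_to_episodes(is_multi_agent, batch)`.
    return {"num_rows": [len(batch["obs"])], "flag": [is_multi_agent]}

# `map_batches()` calls its function with the batch as the sole argument,
# so the leading flag gets bound up front:
bound = functools.partial(mapper, False)
assert bound({"obs": [1, 2, 3]}) == {"num_rows": [3], "flag": [False]}
```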
@@ -63,41 +64,45 @@ def sample(
         num_samples: int,
         return_iterator: bool = False,
         num_shards: int = 1,
-        as_episodes: bool = True,
     ):
+        if (
+            not return_iterator
+            or return_iterator
+            and num_shards <= 1
+            and not self.batch_iterator
+        ):
+            self.batch_iterator = self.data.map_batches(
+                functools.partial(self._map_to_episodes, self.is_multi_agent)
+            ).iter_batches(
+                batch_size=num_samples,
+                prefetch_batches=1,
+                local_shuffle_buffer_size=num_samples * 10,
+            )
+
         if return_iterator:
             if num_shards > 1:
-                return self.data.shards(num_shards)
+                return self.data.map_batches(
+                    functools.partial(self._map_to_episodes, self.is_multi_agent)
+                ).streaming_split(n=num_shards, equal=True)
             else:
-                return self.data.iter_batches(
-                    batch_size=num_samples,
-                    batch_format="numpy",
-                    local_shuffle_buffer_size=num_samples * 10,
-                )
+                return self.batch_iterator
         else:
-            if not self.batch_iterator:
-                self.batch_iterator = self.data.iter_batches(
-                    batch_size=num_samples,
-                    batch_format="numpy",
-                    local_shuffle_buffer_size=num_samples * 10,
-                )
-            # Return a single batch
-            if as_episodes:
-                return self._convert_to_episodes(next(iter(self.batch_iterator)))
-            else:
-                return self.data.take_batch(batch_size=num_samples)
+            return next(iter(self.batch_iterator))["episodes"]
 
-    def _convert_to_episodes(self, batch: Dict[str, np.ndarray]) -> List[EpisodeType]:
-        """Converts a batch of data to episodes."""
+    @staticmethod
+    def _map_to_episodes(
+        is_multi_agent: bool, batch: Dict[str, np.ndarray]
+    ) -> List[EpisodeType]:
+        """Maps a batch of data to episodes."""
 
         episodes = []
         # TODO (simon): Give users possibility to provide a custom schema.
         for i, obs in enumerate(batch["obs"]):
 
             # If multi-agent we need to extract the agent ID.
             # TODO (simon): Check, what happens with the module ID.
-            if self.is_multi_agent:
+            if is_multi_agent:
                 agent_id = (
                     batch[Columns.AGENT_ID][i][0]
                     if Columns.AGENT_ID in batch
@@ -110,7 +115,7 @@ def _convert_to_episodes(self, batch: Dict[str, np.ndarray]) -> List[EpisodeType]:
             else:
                 agent_id = None
 
-            if self.is_multi_agent:
+            if is_multi_agent:
                 # TODO (simon): Add support for multi-agent episodes.
                 pass
             else:
@@ -144,4 +149,4 @@ def _convert_to_episodes(self, batch: Dict[str, np.ndarray]) -> List[EpisodeType]:
                     len_lookback_buffer=0,
                 )
                 episodes.append(episode)
-        return episodes
+        return {"episodes": episodes}
