Updated train_on_episode_end #320

Open · wants to merge 2 commits into base: main
1 change: 1 addition & 0 deletions howto/work_with_steps.md
@@ -22,6 +22,7 @@ The hyper-parameters that refer to the *policy steps* are:
* `exploration_steps`: the number of policy steps in which the agent explores the environment in the P2E algorithms.
* `max_episode_steps`: the maximum number of policy steps an episode can last (`max_steps`); when this number is reached, a `truncated=True` is returned by the environment. This means that if you decide to have an action repeat greater than one (`action_repeat > 1`), then the environment performs a maximum number of steps equal to `env_steps = max_steps * action_repeat`.
* `learning_starts`: how many policy steps the agent has to perform before training starts. During the first `learning_starts` steps the buffer is pre-filled with actions sampled randomly from the environment's action space.
* `train_on_episode_end`: if set to `true`, training occurs only at the end of an episode rather than after every policy step. This is particularly beneficial when maintaining a high step rate (steps per second) is crucial, such as in real-time or physical simulations. Note that in distributed training this feature is automatically disabled to avoid conflicts between parallel processes. A minimal sketch of the resulting gate is shown after this list.

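The helper below is an illustrative sketch, not SheepRL code: it simply packages the boolean gate that the algorithm scripts write inline around their training block (built from `cfg.algo.train_on_episode_end`, `reset_envs`, and `fabric.world_size`).

```python
# Illustrative sketch of the train_on_episode_end gate; the helper name and
# the example values are hypothetical, only the boolean logic mirrors the diffs.
def should_train(train_on_episode_end: bool, reset_envs: int, world_size: int) -> bool:
    """Decide whether the train step runs at the current iteration.

    - train_on_episode_end=False: train after every policy step, as before.
    - train_on_episode_end=True: train only when at least one environment has
      just finished an episode (reset_envs > 0) and the run is not distributed.
    """
    is_distributed = world_size > 1
    if not train_on_episode_end:
        return True
    return reset_envs > 0 and not is_distributed


# Single process and an environment just ended its episode: training runs.
assert should_train(True, reset_envs=1, world_size=1)
# Flag disabled: training runs after every policy step, regardless of episodes.
assert should_train(False, reset_envs=0, world_size=4)
```

Assuming the usual Hydra-style overrides used by SheepRL, the flag can presumably be enabled with something like `algo.train_on_episode_end=True` on the command line.
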
## Gradient steps
A *gradient step* consists of an update of the parameters of the agent, i.e., a call of the *train* function. The number of gradient steps is proportional to the number of parallel processes: if there are $n$ parallel processes, `n * per_rank_gradient_steps` calls to the *train* method will be executed.
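As a quick numeric illustration of this relation (the numbers below are made up):

```python
# Hypothetical numbers, only to make n * per_rank_gradient_steps concrete.
world_size = 4               # n parallel processes (ranks)
per_rank_gradient_steps = 8  # train() calls performed by each rank
global_gradient_steps = world_size * per_rank_gradient_steps
print(global_gradient_steps)  # 32 calls to train() in total
```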
64 changes: 34 additions & 30 deletions sheeprl/algos/dreamer_v1/dreamer_v1.py
@@ -645,36 +645,40 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

# Train the agent
if iter_num >= learning_starts:
ratio_steps = policy_step - prefill_steps * policy_steps_per_iter
per_rank_gradient_steps = ratio(ratio_steps / world_size)
if per_rank_gradient_steps > 0:
with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute):
sample = rb.sample_tensors(
batch_size=cfg.algo.per_rank_batch_size,
sequence_length=cfg.algo.per_rank_sequence_length,
n_samples=per_rank_gradient_steps,
dtype=None,
device=device,
from_numpy=cfg.buffer.from_numpy,
) # [N_samples, Seq_len, Batch_size, ...]
for i in range(per_rank_gradient_steps):
batch = {k: v[i].float() for k, v in sample.items()}
train(
fabric,
world_model,
actor,
critic,
world_optimizer,
actor_optimizer,
critic_optimizer,
batch,
aggregator,
cfg,
)
cumulative_per_rank_gradient_steps += 1
train_step += world_size
if aggregator:
aggregator.update("Params/exploration_amount", actor._get_expl_amount(policy_step))
is_distributed = fabric.world_size > 1
if (
cfg.algo.train_on_episode_end and reset_envs > 0 and not is_distributed
) or not cfg.algo.train_on_episode_end:
ratio_steps = policy_step - prefill_steps * policy_steps_per_iter
per_rank_gradient_steps = ratio(ratio_steps / world_size)
if per_rank_gradient_steps > 0:
with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute):
sample = rb.sample_tensors(
batch_size=cfg.algo.per_rank_batch_size,
sequence_length=cfg.algo.per_rank_sequence_length,
n_samples=per_rank_gradient_steps,
dtype=None,
device=device,
from_numpy=cfg.buffer.from_numpy,
) # [N_samples, Seq_len, Batch_size, ...]
for i in range(per_rank_gradient_steps):
batch = {k: v[i].float() for k, v in sample.items()}
train(
fabric,
world_model,
actor,
critic,
world_optimizer,
actor_optimizer,
critic_optimizer,
batch,
aggregator,
cfg,
)
cumulative_per_rank_gradient_steps += 1
train_step += world_size
if aggregator:
aggregator.update("Params/exploration_amount", actor._get_expl_amount(policy_step))

# Log metrics
if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or iter_num == total_iters):
76 changes: 40 additions & 36 deletions sheeprl/algos/dreamer_v2/dreamer_v2.py
@@ -680,42 +680,46 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

# Train the agent
if iter_num >= learning_starts:
ratio_steps = policy_step - prefill_steps * policy_steps_per_iter
per_rank_gradient_steps = ratio(ratio_steps / world_size)
if per_rank_gradient_steps > 0:
local_data = rb.sample_tensors(
batch_size=cfg.algo.per_rank_batch_size,
sequence_length=cfg.algo.per_rank_sequence_length,
n_samples=per_rank_gradient_steps,
dtype=None,
device=fabric.device,
from_numpy=cfg.buffer.from_numpy,
)
with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute):
for i in range(per_rank_gradient_steps):
if (
cumulative_per_rank_gradient_steps % cfg.algo.critic.per_rank_target_network_update_freq
== 0
):
for cp, tcp in zip(critic.module.parameters(), target_critic.module.parameters()):
tcp.data.copy_(cp.data)
batch = {k: v[i].float() for k, v in local_data.items()}
train(
fabric,
world_model,
actor,
critic,
target_critic,
world_optimizer,
actor_optimizer,
critic_optimizer,
batch,
aggregator,
cfg,
actions_dim,
)
cumulative_per_rank_gradient_steps += 1
train_step += world_size
is_distributed = fabric.world_size > 1
if (
cfg.algo.train_on_episode_end and reset_envs > 0 and not is_distributed
) or not cfg.algo.train_on_episode_end:
ratio_steps = policy_step - prefill_steps * policy_steps_per_iter
per_rank_gradient_steps = ratio(ratio_steps / world_size)
if per_rank_gradient_steps > 0:
local_data = rb.sample_tensors(
batch_size=cfg.algo.per_rank_batch_size,
sequence_length=cfg.algo.per_rank_sequence_length,
n_samples=per_rank_gradient_steps,
dtype=None,
device=fabric.device,
from_numpy=cfg.buffer.from_numpy,
)
with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute):
for i in range(per_rank_gradient_steps):
if (
cumulative_per_rank_gradient_steps % cfg.algo.critic.per_rank_target_network_update_freq
== 0
):
for cp, tcp in zip(critic.module.parameters(), target_critic.module.parameters()):
tcp.data.copy_(cp.data)
batch = {k: v[i].float() for k, v in local_data.items()}
train(
fabric,
world_model,
actor,
critic,
target_critic,
world_optimizer,
actor_optimizer,
critic_optimizer,
batch,
aggregator,
cfg,
actions_dim,
)
cumulative_per_rank_gradient_steps += 1
train_step += world_size

# Log metrics
if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or iter_num == total_iters):
82 changes: 43 additions & 39 deletions sheeprl/algos/dreamer_v3/dreamer_v3.py
@@ -658,45 +658,49 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

# Train the agent
if iter_num >= learning_starts:
ratio_steps = policy_step - prefill_steps * policy_steps_per_iter
per_rank_gradient_steps = ratio(ratio_steps / world_size)
if per_rank_gradient_steps > 0:
local_data = rb.sample_tensors(
cfg.algo.per_rank_batch_size,
sequence_length=cfg.algo.per_rank_sequence_length,
n_samples=per_rank_gradient_steps,
dtype=None,
device=fabric.device,
from_numpy=cfg.buffer.from_numpy,
)
with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute):
for i in range(per_rank_gradient_steps):
if (
cumulative_per_rank_gradient_steps % cfg.algo.critic.per_rank_target_network_update_freq
== 0
):
tau = 1 if cumulative_per_rank_gradient_steps == 0 else cfg.algo.critic.tau
for cp, tcp in zip(critic.module.parameters(), target_critic.parameters()):
tcp.data.copy_(tau * cp.data + (1 - tau) * tcp.data)
batch = {k: v[i].float() for k, v in local_data.items()}
train(
fabric,
world_model,
actor,
critic,
target_critic,
world_optimizer,
actor_optimizer,
critic_optimizer,
batch,
aggregator,
cfg,
is_continuous,
actions_dim,
moments,
)
cumulative_per_rank_gradient_steps += 1
train_step += world_size
is_distributed = fabric.world_size > 1
if (
cfg.algo.train_on_episode_end and reset_envs > 0 and not is_distributed
) or not cfg.algo.train_on_episode_end:
ratio_steps = policy_step - prefill_steps * policy_steps_per_iter
per_rank_gradient_steps = ratio(ratio_steps / world_size)
if per_rank_gradient_steps > 0:
local_data = rb.sample_tensors(
cfg.algo.per_rank_batch_size,
sequence_length=cfg.algo.per_rank_sequence_length,
n_samples=per_rank_gradient_steps,
dtype=None,
device=fabric.device,
from_numpy=cfg.buffer.from_numpy,
)
with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute):
for i in range(per_rank_gradient_steps):
if (
cumulative_per_rank_gradient_steps % cfg.algo.critic.per_rank_target_network_update_freq
== 0
):
tau = 1 if cumulative_per_rank_gradient_steps == 0 else cfg.algo.critic.tau
for cp, tcp in zip(critic.module.parameters(), target_critic.parameters()):
tcp.data.copy_(tau * cp.data + (1 - tau) * tcp.data)
batch = {k: v[i].float() for k, v in local_data.items()}
train(
fabric,
world_model,
actor,
critic,
target_critic,
world_optimizer,
actor_optimizer,
critic_optimizer,
batch,
aggregator,
cfg,
is_continuous,
actions_dim,
moments,
)
cumulative_per_rank_gradient_steps += 1
train_step += world_size

# Log metrics
if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or iter_num == total_iters):
82 changes: 43 additions & 39 deletions sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py
@@ -669,46 +669,50 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

# Train the agent
if iter_num >= learning_starts:
ratio_steps = policy_step - prefill_steps * policy_steps_per_iter
per_rank_gradient_steps = ratio(ratio_steps / world_size)
if per_rank_gradient_steps > 0:
with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute):
sample = rb.sample_tensors(
batch_size=cfg.algo.per_rank_batch_size,
sequence_length=cfg.algo.per_rank_sequence_length,
n_samples=per_rank_gradient_steps,
dtype=None,
device=device,
from_numpy=cfg.buffer.from_numpy,
) # [N_samples, Seq_len, Batch_size, ...]
for i in range(per_rank_gradient_steps):
batch = {k: v[i].float() for k, v in sample.items()}
train(
fabric,
world_model,
actor_task,
critic_task,
world_optimizer,
actor_task_optimizer,
critic_task_optimizer,
batch,
aggregator,
cfg,
ensembles=ensembles,
ensemble_optimizer=ensemble_optimizer,
actor_exploration=actor_exploration,
critic_exploration=critic_exploration,
actor_exploration_optimizer=actor_exploration_optimizer,
critic_exploration_optimizer=critic_exploration_optimizer,
is_distributed = fabric.world_size > 1
if (
cfg.algo.train_on_episode_end and reset_envs > 0 and not is_distributed
) or not cfg.algo.train_on_episode_end:
ratio_steps = policy_step - prefill_steps * policy_steps_per_iter
per_rank_gradient_steps = ratio(ratio_steps / world_size)
if per_rank_gradient_steps > 0:
with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute):
sample = rb.sample_tensors(
batch_size=cfg.algo.per_rank_batch_size,
sequence_length=cfg.algo.per_rank_sequence_length,
n_samples=per_rank_gradient_steps,
dtype=None,
device=device,
from_numpy=cfg.buffer.from_numpy,
) # [N_samples, Seq_len, Batch_size, ...]
for i in range(per_rank_gradient_steps):
batch = {k: v[i].float() for k, v in sample.items()}
train(
fabric,
world_model,
actor_task,
critic_task,
world_optimizer,
actor_task_optimizer,
critic_task_optimizer,
batch,
aggregator,
cfg,
ensembles=ensembles,
ensemble_optimizer=ensemble_optimizer,
actor_exploration=actor_exploration,
critic_exploration=critic_exploration,
actor_exploration_optimizer=actor_exploration_optimizer,
critic_exploration_optimizer=critic_exploration_optimizer,
)
cumulative_per_rank_gradient_steps += 1
train_step += world_size

if aggregator and not aggregator.disabled:
aggregator.update("Params/exploration_amount_task", actor_task._get_expl_amount(policy_step))
aggregator.update(
"Params/exploration_amount_exploration", actor_exploration._get_expl_amount(policy_step)
)
cumulative_per_rank_gradient_steps += 1
train_step += world_size

if aggregator and not aggregator.disabled:
aggregator.update("Params/exploration_amount_task", actor_task._get_expl_amount(policy_step))
aggregator.update(
"Params/exploration_amount_exploration", actor_exploration._get_expl_amount(policy_step)
)

# Log metrics
if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or iter_num == total_iters):