diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2.py b/sheeprl/algos/p2e_dv2/p2e_dv2.py
index 91da9e27..000c688f 100644
--- a/sheeprl/algos/p2e_dv2/p2e_dv2.py
+++ b/sheeprl/algos/p2e_dv2/p2e_dv2.py
@@ -7,6 +7,7 @@
 from typing import Dict, Sequence
 
 import gymnasium as gym
+import jsonlines
 import numpy as np
 import torch
 import torch.nn.functional as F
@@ -37,8 +38,7 @@
 from sheeprl.utils.registry import register_algorithm
 from sheeprl.utils.utils import compute_lambda_values, polynomial_decay
 
-# Decomment the following two lines if you are using MineDojo on an headless machine
-# os.environ["MINEDOJO_HEADLESS"] = "1"
+os.environ["MINEDOJO_HEADLESS"] = "1"
 
 
 def train(
@@ -729,6 +729,12 @@ def main():
         max_decay_steps=max_step_expl_decay,
     )
 
+    # Stats file info
+    if "minedojo" in args.env_id:
+        stats_dir = os.path.join(log_dir, "stats")
+        os.makedirs(stats_dir, exist_ok=True)
+        stats_filename = os.path.join(stats_dir, "stats.jsonl")
+
     # Get the first environment observation and start the optimization
     episode_steps = []
     o, infos = env.reset(seed=args.seed)
@@ -785,6 +791,18 @@
                 else:
                     real_actions = np.array([real_act.cpu().argmax() for real_act in real_actions])
 
+        # Save stats
+        if "minedojo" in args.env_id:
+            with jsonlines.open(stats_filename, mode="a") as writer:
+                writer.write(
+                    {
+                        "life_stats": infos["life_stats"],
+                        "location_stats": infos["location_stats"],
+                        "action": real_actions.tolist(),
+                        "biomeid": infos["biomeid"],
+                    }
+                )
+
         step_data["is_first"] = copy.deepcopy(step_data["dones"])
         o, rewards, dones, truncated, infos = env.step(real_actions.reshape(env.action_space.shape))
         dones = np.logical_or(dones, truncated)
diff --git a/sheeprl/algos/ppo/ppo_atari.py b/sheeprl/algos/ppo/ppo_atari.py
deleted file mode 100644
index ac9eb104..00000000
--- a/sheeprl/algos/ppo/ppo_atari.py
+++ /dev/null
@@ -1,499 +0,0 @@
-import copy
-import os
-import time
-import warnings
-from dataclasses import asdict
-from datetime import datetime
-
-import gymnasium as gym
-import torch
-from gymnasium.vector import SyncVectorEnv
-from gymnasium.wrappers.atari_preprocessing import AtariPreprocessing
-from lightning.fabric import Fabric
-from lightning.fabric.fabric import _is_using_cli
-from lightning.fabric.loggers import TensorBoardLogger
-from lightning.fabric.plugins.collectives import TorchCollective
-from lightning.fabric.plugins.collectives.collective import CollectibleGroup
-from lightning.fabric.strategies import DDPStrategy
-from tensordict import TensorDict
-from tensordict.tensordict import TensorDictBase, make_tensordict
-from torch import Tensor
-from torch.distributed.algorithms.join import Join
-from torch.distributions import Categorical
-from torch.optim import Adam
-from torch.utils.data import BatchSampler, RandomSampler
-from torchmetrics import MeanMetric
-
-from sheeprl.algos.ppo.args import PPOArgs
-from sheeprl.algos.ppo.loss import entropy_loss, policy_loss, value_loss
-from sheeprl.algos.ppo.utils import test
-from sheeprl.data import ReplayBuffer
-from sheeprl.models.models import MLP, NatureCNN
-from sheeprl.utils.imports import _IS_ATARI_AVAILABLE, _IS_ATARI_ROMS_AVAILABLE
-from sheeprl.utils.metric import MetricAggregator
-from sheeprl.utils.parser import HfArgumentParser
-from sheeprl.utils.registry import register_algorithm
-from sheeprl.utils.utils import gae, normalize_tensor, polynomial_decay
-
-if not _IS_ATARI_AVAILABLE:
-    raise ModuleNotFoundError(str(_IS_ATARI_AVAILABLE))
-
-if not _IS_ATARI_ROMS_AVAILABLE:
-    raise ModuleNotFoundError(str(_IS_ATARI_ROMS_AVAILABLE))
-
-
-def make_env(env_id, seed, idx, capture_video, run_name, prefix: str = "", vector_env_idx: int = 0):
-    def thunk():
-        env = gym.make(env_id, render_mode="rgb_array")
-        env = gym.wrappers.RecordEpisodeStatistics(env)
-        if capture_video:
-            if vector_env_idx == 0 and idx == 0:
-                env = gym.experimental.wrappers.RecordVideoV0(
-                    env, os.path.join(run_name, prefix + "_videos" if prefix else "videos"), disable_logger=True
-                )
-        env = AtariPreprocessing(env, grayscale_obs=True, grayscale_newaxis=False, scale_obs=True)
-        env = gym.wrappers.FrameStack(env, 4)
-        env.action_space.seed(seed)
-        env.observation_space.seed(seed)
-        return env
-
-    return thunk
-
-
-# Simple wrapper to let torch.distributed.algorithms.join.Join
-# correctly injects fake communication hooks when we are
-# working with uneven inputs
-class Agent(torch.nn.Module):
-    def __init__(self, feature_extractor, actor, critic) -> None:
-        super().__init__()
-        self.feature_extractor = feature_extractor
-        self.actor = actor
-        self.critic = critic
-
-    def forward(self, obs: Tensor) -> Tensor:
-        feat = self.feature_extractor(obs)
-        return self.actor(feat), self.critic(feat)
-
-
-@torch.no_grad()
-def player(args: PPOArgs, world_collective: TorchCollective, player_trainer_collective: TorchCollective):
-    run_name = f"{args.env_id}_{args.exp_name}_{args.seed}"
-
-    logger = TensorBoardLogger(
-        root_dir=os.path.join("logs", "ppo_atari", datetime.today().strftime("%Y-%m-%d_%H-%M-%S")), name=run_name
-    )
-    logger.log_hyperparams(asdict(args))
-
-    # Initialize Fabric object
-    fabric = Fabric(loggers=logger)
-    if not _is_using_cli():
-        fabric.launch()
-    device = fabric.device
-    fabric.seed_everything(args.seed)
-    torch.backends.cudnn.deterministic = args.torch_deterministic
-
-    # Environment setup
-    envs = SyncVectorEnv(
-        [
-            make_env(args.env_id, args.seed + i, 0, args.capture_video, logger.log_dir, "train", vector_env_idx=i)
-            for i in range(args.num_envs)
-        ]
-    )
-    assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"
-
-    # Create the actor and critic models
-    features_dim = 512
-    feature_extractor = NatureCNN(in_channels=4, features_dim=features_dim).to(device)
-    actor = MLP(
-        input_dims=features_dim,
-        output_dim=envs.single_action_space.n,
-        hidden_sizes=(),
-        activation=torch.nn.ReLU,
-    ).to(device)
-    critic = MLP(
-        input_dims=features_dim,
-        output_dim=1,
-        hidden_sizes=(),
-        activation=torch.nn.ReLU,
-    ).to(device)
-
-    flattened_parameters = torch.empty_like(
-        torch.nn.utils.convert_parameters.parameters_to_vector(
-            list(feature_extractor.parameters()) + list(actor.parameters()) + list(critic.parameters())
-        ),
-        device=device,
-    )
-
-    # Receive the first weights from the rank-1, a.k.a. the first of the trainers
-    # In this way we are sure that before the first iteration everyone starts with the same parameters
-    player_trainer_collective.broadcast(flattened_parameters, src=1)
-    torch.nn.utils.convert_parameters.vector_to_parameters(
-        flattened_parameters,
-        list(feature_extractor.parameters()) + list(actor.parameters()) + list(critic.parameters()),
-    )
-
-    # Metrics
-    with device:
-        aggregator = MetricAggregator(
-            {
-                "Rewards/rew_avg": MeanMetric(sync_on_compute=False),
-                "Game/ep_len_avg": MeanMetric(sync_on_compute=False),
-                "Time/step_per_second": MeanMetric(sync_on_compute=False),
-            }
-        )
-
-    # Local data
-    rb = ReplayBuffer(args.rollout_steps, args.num_envs, device=device)
-    step_data = TensorDict({}, batch_size=[args.num_envs], device=device)
-
-    # Global variables
-    global_step = 0
-    start_time = time.time()
-    single_global_step = int(args.num_envs * args.rollout_steps)
-    num_updates = args.total_steps // single_global_step if not args.dry_run else 1
-    if single_global_step < world_collective.world_size - 1:
-        raise RuntimeError(
-            "The number of trainers ({}) is greater than the available collected data ({}). ".format(
-                world_collective.world_size - 1, single_global_step
-            )
-            + "Consider to lower the number of trainers at least to the size of available collected data"
-        )
-    chunks_sizes = [
-        len(chunk) for chunk in torch.tensor_split(torch.arange(single_global_step), world_collective.world_size - 1)
-    ]
-
-    # Broadcast num_updates to all the world
-    update_t = torch.tensor([num_updates], device=device, dtype=torch.float32)
-    world_collective.broadcast(update_t, src=0)
-
-    with device:
-        # Get the first environment observation and start the optimization
-        next_obs = torch.tensor(envs.reset(seed=args.seed)[0], device=device)
-        next_done = torch.zeros(args.num_envs, 1).to(device)
-
-    for update in range(1, num_updates + 1):
-        for _ in range(0, args.rollout_steps):
-            global_step += args.num_envs
-
-            # Sample an action given the observation received by the environment
-            features = feature_extractor(next_obs)
-            actions_logits = actor(features)
-            dist = Categorical(logits=actions_logits.unsqueeze(-2))
-            action = dist.sample()
-            logprob = dist.log_prob(action)
-
-            # Compute the value of the current observation
-            value = critic(features)
-
-            # Single environment step
-            obs, reward, done, truncated, info = envs.step(action.cpu().numpy().reshape(envs.action_space.shape))
-
-            with device:
-                obs = torch.tensor(obs)  # [N_envs, N_obs]
-                rewards = torch.tensor(reward).view(args.num_envs, -1)  # [N_envs, 1]
-                done = torch.logical_or(torch.tensor(done), torch.tensor(truncated))  # [N_envs, 1]
-                done = done.view(args.num_envs, -1).float()
-
-            # Update the step data
-            step_data["dones"] = next_done
-            step_data["values"] = value
-            step_data["actions"] = action
-            step_data["logprobs"] = logprob
-            step_data["rewards"] = rewards
-            step_data["observations"] = next_obs
-
-            # Append data to buffer
-            rb.add(step_data.unsqueeze(0))
-
-            # Update the observation and done
-            next_obs = obs
-            next_done = done
-
-            if "final_info" in info:
-                for i, agent_final_info in enumerate(info["final_info"]):
-                    if agent_final_info is not None and "episode" in agent_final_info:
-                        fabric.print(
-                            f"Rank-0: global_step={global_step}, reward_env_{i}={agent_final_info['episode']['r'][0]}"
-                        )
-                        aggregator.update("Rewards/rew_avg", agent_final_info["episode"]["r"][0])
-                        aggregator.update("Game/ep_len_avg", agent_final_info["episode"]["l"][0])
-
-        # Estimate returns with GAE (https://arxiv.org/abs/1506.02438)
-        next_features = feature_extractor(next_obs)
-        next_values = critic(next_features)
-        returns, advantages = gae(
-            rb["rewards"],
-            rb["values"],
-            rb["dones"],
-            next_values,
-            next_done,
-            args.rollout_steps,
-            args.gamma,
-            args.gae_lambda,
-        )
-
-        # Add returns and advantages to the buffer
-        rb["returns"] = returns.float()
-        rb["advantages"] = advantages.float()
-
-        # Flatten the batch
-        local_data = rb.buffer.view(-1)
-
-        # Send data to the training agents
-        # Split data in an even way, when possible
-        perm = torch.randperm(local_data.shape[0], device=device)
-        chunks = local_data[perm].split(chunks_sizes)
-        world_collective.scatter_object_list([None], [None] + chunks, src=0)
-
-        # Gather metrics from the trainers to be plotted
-        metrics = [None]
-        player_trainer_collective.broadcast_object_list(metrics, src=1)
-
-        # Wait the trainers to finish
-        player_trainer_collective.broadcast(flattened_parameters, src=1)
-
-        # Convert back the parameters
-        torch.nn.utils.convert_parameters.vector_to_parameters(
-            flattened_parameters,
-            list(feature_extractor.parameters()) + list(actor.parameters()) + list(critic.parameters()),
-        )
-
-        # Log metrics
-        aggregator.update("Time/step_per_second", int(global_step / (time.time() - start_time)))
-        fabric.log_dict(metrics[0], global_step)
-        fabric.log_dict(aggregator.compute(), global_step)
-        aggregator.reset()
-
-        # Checkpoint model on rank-0: send it everything
-        if (args.checkpoint_every > 0 and update % args.checkpoint_every == 0) or args.dry_run:
-            state = [None]
-            player_trainer_collective.broadcast_object_list(state, src=1)
-            ckpt_path = fabric.logger.log_dir + f"/checkpoint/ckpt_{update}_{fabric.global_rank}.ckpt"
-            fabric.save(ckpt_path, state[0])
-
-    world_collective.scatter_object_list([None], [None] + [-1] * (world_collective.world_size - 1), src=0)
-    envs.close()
-    if fabric.is_global_zero:
-        test(torch.nn.Sequential(feature_extractor, actor), envs, fabric, args)
-
-
-def trainer(
-    args: PPOArgs,
-    world_collective: TorchCollective,
-    player_trainer_collective: TorchCollective,
-    optimization_pg: CollectibleGroup,
-):
-    global_rank = world_collective.rank
-
-    # Initialize Fabric
-    fabric = Fabric(strategy=DDPStrategy(process_group=optimization_pg))
-    if not _is_using_cli():
-        fabric.launch()
-    device = fabric.device
-    fabric.seed_everything(args.seed)
-    torch.backends.cudnn.deterministic = args.torch_deterministic
-
-    # Environment setup
-    envs = SyncVectorEnv([make_env(args.env_id, 0, 0, False, None)])
-    assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"
-
-    # Create the actor and critic models
-    features_dim = 512
-    feature_extractor = NatureCNN(in_channels=4, features_dim=features_dim)
-    actor = MLP(
-        input_dims=features_dim,
-        output_dim=envs.single_action_space.n,
-        hidden_sizes=(),
-        activation=torch.nn.ReLU,
-    )
-    critic = MLP(input_dims=features_dim, output_dim=1, hidden_sizes=(), activation=torch.nn.ReLU)
-    agent = Agent(feature_extractor, actor, critic)
-
-    # Define the agent and the optimizer and setup them with Fabric
-    optimizer = Adam(agent.parameters(), lr=args.lr, eps=1e-4)
-    agent = fabric.setup_module(agent)
-    optimizer = fabric.setup_optimizers(optimizer)
-
-    # Send weights to rank-0, a.k.a. the player
-    if global_rank == 1:
-        player_trainer_collective.broadcast(
-            torch.nn.utils.convert_parameters.parameters_to_vector(agent.parameters()), src=1
-        )
-
-    # Receive maximum number of updates from the player
-    num_updates = torch.zeros(1, device=device)
-    world_collective.broadcast(num_updates, src=0)
-    num_updates = num_updates.item()
-
-    # Linear learning rate scheduler
-    if args.anneal_lr:
-        from torch.optim.lr_scheduler import PolynomialLR
-
-        scheduler = PolynomialLR(optimizer=optimizer, total_iters=num_updates, power=1.0)
-
-    # Metrics
-    with fabric.device:
-        aggregator = MetricAggregator(
-            {
-                "Loss/value_loss": MeanMetric(process_group=optimization_pg),
-                "Loss/policy_loss": MeanMetric(process_group=optimization_pg),
-                "Loss/entropy_loss": MeanMetric(process_group=optimization_pg),
-            }
-        )
-
-    # Start training
-    update = 0
-    initial_ent_coef = copy.deepcopy(args.ent_coef)
-    initial_clip_coef = copy.deepcopy(args.clip_coef)
-    while True:
-        # Wait for data
-        data = [None]
-        world_collective.scatter_object_list(data, [None for _ in range(world_collective.world_size)], src=0)
-        data = data[0]
-        if not isinstance(data, TensorDictBase) and data == -1:
-            return
-        data = make_tensordict(data, device=device)
-        update += 1
-
-        # Prepare sampler
-        indexes = list(range(data.shape[0]))
-        sampler = BatchSampler(RandomSampler(indexes), batch_size=args.per_rank_batch_size, drop_last=False)
-
-        # The Join context is needed because there can be the possibility
-        # that some ranks receive less data
-        with Join([agent._forward_module]):
-            for _ in range(args.update_epochs):
-                for batch_idxes in sampler:
-                    batch = data[batch_idxes]
-                    actions_logits, new_values = agent(batch["observations"])
-
-                    dist = Categorical(logits=actions_logits.unsqueeze(-2))
-                    if args.normalize_advantages:
-                        batch["advantages"] = normalize_tensor(batch["advantages"])
-
-                    # Policy loss
-                    pg_loss = policy_loss(
-                        dist.log_prob(batch["actions"]),
-                        batch["logprobs"],
-                        batch["advantages"],
-                        args.clip_coef,
-                        args.loss_reduction,
-                    )
-
-                    # Value loss
-                    v_loss = value_loss(
-                        new_values,
-                        batch["values"],
-                        batch["returns"],
-                        args.clip_coef,
-                        args.clip_vloss,
-                        args.loss_reduction,
-                    )
-
-                    # Entropy loss
-                    entropy = entropy_loss(dist.entropy(), reduction=args.loss_reduction)
-
-                    # Equation (9) in the paper
-                    loss = pg_loss + args.vf_coef * v_loss + args.ent_coef * entropy
-                    optimizer.zero_grad(set_to_none=True)
-                    fabric.backward(loss)
-                    if args.max_grad_norm > 0.0:
-                        fabric.clip_gradients(agent, optimizer, max_norm=args.max_grad_norm)
-                    optimizer.step()
-
-                    # Update metrics
-                    aggregator.update("Loss/policy_loss", pg_loss.detach())
-                    aggregator.update("Loss/value_loss", v_loss.detach())
-                    aggregator.update("Loss/entropy_loss", entropy.detach())
-
-        # Send updated weights to the player
-        metrics = aggregator.compute()
-        aggregator.reset()
-        if global_rank == 1:
-            if args.anneal_lr:
-                metrics["Info/learning_rate"] = scheduler.get_last_lr()[0]
-            else:
-                metrics["Info/learning_rate"] = args.lr
-            metrics["Info/clip_coef"] = args.clip_coef
-            metrics["Info/ent_coef"] = args.ent_coef
-            player_trainer_collective.broadcast_object_list(
-                [metrics], src=1
-            )  # Broadcast metrics: fake send with object list between rank-0 and rank-1
-            player_trainer_collective.broadcast(
-                torch.nn.utils.convert_parameters.parameters_to_vector(
-                    list(feature_extractor.parameters()) + list(actor.parameters()) + list(critic.parameters())
-                ),
-                src=1,
-            )
-
-        if args.anneal_lr:
-            scheduler.step()
-
-        if args.anneal_clip_coef:
-            args.clip_coef = polynomial_decay(
-                update, initial=initial_clip_coef, final=0.0, max_decay_steps=num_updates, power=1.0
-            )
-
-        if args.anneal_ent_coef:
-            args.ent_coef = polynomial_decay(
-                update, initial=initial_ent_coef, final=0.0, max_decay_steps=num_updates, power=1.0
-            )
-
-        # Checkpoint Model
-        if (args.checkpoint_every > 0 and update % args.checkpoint_every == 0) or args.dry_run:
-            if global_rank == 1:
-                state = {
-                    "agent": agent.state_dict(),
-                    "optimizer": optimizer.state_dict(),
-                    "args": asdict(args),
-                    "update_step": update,
-                    "scheduler": scheduler.state_dict() if args.anneal_lr else None,
-                }
-                player_trainer_collective.broadcast_object_list([state], src=1)
-            # Fake save for the other ranks
-            fabric.barrier()
-
-
-@register_algorithm(decoupled=True)
-def main():
-    devices = os.environ.get("LT_DEVICES", None)
-    if devices is None or devices == "1":
-        raise RuntimeError(
-            "Please run the script with the number of devices greater than 1: "
-            "`lightning run model --devices=2 main.py ...`"
-        )
-
-    parser = HfArgumentParser(PPOArgs)
-    args: PPOArgs = parser.parse_args_into_dataclasses()[0]
-
-    if args.share_data:
-        warnings.warn(
-            "You have called the script with `--share_data=True`: "
-            "decoupled scripts splits collected data in an almost-even way between the number of trainers"
-        )
-
-    world_collective = TorchCollective()
-    player_trainer_collective = TorchCollective()
-    world_collective.setup(backend="nccl" if os.environ.get("LT_ACCELERATOR", None) in ("gpu", "cuda") else "gloo")
-
-    # Create a global group, assigning it to the collective: used by the player to exchange
-    # collected experiences with the trainers
-    world_collective.create_group()
-    global_rank = world_collective.rank
-
-    # Create a group between rank-0 (player) and rank-1 (trainer), assigning it to the collective:
-    # used by rank-1 to send metrics to be tracked by the rank-0 at the end of a training episode
-    player_trainer_collective.create_group(ranks=[0, 1])
-
-    # Create a new group, without assigning it to the collective: in this way the trainers can
-    # still communicate with the player through the global group, but they can optimize the agent
-    # between themselves
-    optimization_pg = world_collective.new_group(ranks=list(range(1, world_collective.world_size)))
-    if global_rank == 0:
-        player(args, world_collective, player_trainer_collective)
-    else:
-        trainer(args, world_collective, player_trainer_collective, optimization_pg)
-
-
-if __name__ == "__main__":
-    main()
diff --git a/sheeprl/envs/minedojo.py b/sheeprl/envs/minedojo.py
index 8bf962d7..d261babd 100644
--- a/sheeprl/envs/minedojo.py
+++ b/sheeprl/envs/minedojo.py
@@ -272,6 +272,7 @@ def reset(
                 "food": float(obs["life_stats"]["food"].item()),
             },
             "location_stats": copy.deepcopy(self._pos),
+            "biomeid": float(obs["location_stats"]["biome_id"].item()),
         }
 
     def close(self):
diff --git a/tests/test_algos/test_algos.py b/tests/test_algos/test_algos.py
index 03df5e5e..b8319b3a 100644
--- a/tests/test_algos/test_algos.py
+++ b/tests/test_algos/test_algos.py
@@ -534,6 +534,66 @@ def test_p2e_dv1(standard_args, env_id, checkpoint_buffer, start_time):
     shutil.rmtree("pytest_" + start_time)
 
 
+@pytest.mark.timeout(60)
+@pytest.mark.parametrize("env_id", ["discrete_dummy", "multidiscrete_dummy", "continuous_dummy"])
+@pytest.mark.parametrize("checkpoint_buffer", [True, False])
+def test_p2e_dv2(standard_args, env_id, checkpoint_buffer, start_time):
+    task = importlib.import_module("sheeprl.algos.p2e_dv2.p2e_dv2")
+    root_dir = os.path.join("pytest_" + start_time, "p2e_dv2", os.environ["LT_DEVICES"])
+    run_name = "checkpoint_buffer" if checkpoint_buffer else "no_checkpoint_buffer"
+    ckpt_path = os.path.join(root_dir, run_name)
+    version = 0 if not os.path.isdir(ckpt_path) else len(os.listdir(ckpt_path))
+    ckpt_path = os.path.join(ckpt_path, f"version_{version}", "checkpoint")
+    args = standard_args + [
+        "--per_rank_batch_size=2",
+        "--per_rank_sequence_length=2",
+        f"--buffer_size={int(os.environ['LT_DEVICES'])}",
+        "--learning_starts=0",
+        "--gradient_steps=1",
+        "--horizon=2",
+        "--env_id=" + env_id,
+        "--root_dir=" + root_dir,
+        "--run_name=" + run_name,
+        "--dense_units=8",
+        "--cnn_channels_multiplier=2",
+        "--recurrent_state_size=8",
+        "--hidden_size=8",
+        "--cnn_keys=rgb",
+        "--pretrain_steps=1",
+    ]
+    if checkpoint_buffer:
+        args.append("--checkpoint_buffer")
+
+    with mock.patch.object(sys, "argv", [task.__file__] + args):
+        for command in task.__all__:
+            if command == "main":
+                task.__dict__[command]()
+
+    with mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu", "LT_DEVICES": str(1)}):
+        keys = {
+            "world_model",
+            "actor_task",
+            "critic_task",
+            "ensembles",
+            "world_optimizer",
+            "actor_task_optimizer",
+            "critic_task_optimizer",
+            "ensemble_optimizer",
+            "expl_decay_steps",
+            "args",
+            "global_step",
+            "batch_size",
+            "actor_exploration",
+            "critic_exploration",
+            "actor_exploration_optimizer",
+            "critic_exploration_optimizer",
+        }
+        if checkpoint_buffer:
+            keys.add("rb")
+        check_checkpoint(ckpt_path, keys, checkpoint_buffer)
+    shutil.rmtree("pytest_" + start_time)
+
+
 @pytest.mark.timeout(60)
 @pytest.mark.parametrize("env_id", ["discrete_dummy", "multidiscrete_dummy", "continuous_dummy"])
 @pytest.mark.parametrize("checkpoint_buffer", [True, False])