diff --git a/lzero/entry/__init__.py b/lzero/entry/__init__.py
index 7dc0328f7..7189fcf04 100644
--- a/lzero/entry/__init__.py
+++ b/lzero/entry/__init__.py
@@ -9,4 +9,14 @@
 from .train_rezero import train_rezero
 from .train_unizero import train_unizero
 from .train_unizero_segment import train_unizero_segment
+
+from .train_muzero_multitask_segment_noddp import train_muzero_multitask_segment_noddp
+from .train_muzero_multitask_segment_ddp import train_muzero_multitask_segment_ddp
+
+
+from .train_unizero_multitask_serial import train_unizero_multitask_serial
+from .train_unizero_multitask_segment_ddp import train_unizero_multitask_segment_ddp
+from .train_unizero_multitask_segment_serial import train_unizero_multitask_segment_serial
+
+from .train_unizero_multitask_segment_eval import train_unizero_multitask_segment_eval
 from .utils import *
diff --git a/lzero/entry/compute_task_weight.py b/lzero/entry/compute_task_weight.py
new file mode 100644
index 000000000..84204a9a2
--- /dev/null
+++ b/lzero/entry/compute_task_weight.py
@@ -0,0 +1,80 @@
+
+
+
+import numpy as np
+import torch
+
+
+def symlog(x: torch.Tensor) -> torch.Tensor:
+    """
+    Symlog normalization, which reduces the magnitude gap between target values.
+    symlog(x) = sign(x) * log(|x| + 1)
+    """
+    return torch.sign(x) * torch.log(torch.abs(x) + 1)
+
+
+def inv_symlog(x: torch.Tensor) -> torch.Tensor:
+    """
+    Inverse of symlog, used to recover the original values.
+    inv_symlog(x) = sign(x) * (exp(|x|) - 1)
+    """
+    return torch.sign(x) * (torch.exp(torch.abs(x)) - 1)
+
+
+def compute_task_weights(
+    task_rewards: dict,
+    epsilon: float = 1e-6,
+    min_weight: float = 0.1,
+    max_weight: float = 0.5,
+    temperature: float = 1.0,
+    use_symlog: bool = True,
+) -> dict:
+    """
+    Improved task-weight computation with optional symlog correction and robustness safeguards.
+
+    Args:
+        task_rewards (dict): Mapping from task_id to its evaluation reward.
+        epsilon (float): Small value that avoids division by zero.
+        min_weight (float): Lower bound used when clipping the weights.
+        max_weight (float): Upper bound used when clipping the weights.
+        temperature (float): Temperature coefficient controlling how peaked the weight distribution is.
+        use_symlog (bool): Whether to apply symlog to task_rewards before weighting.
+
+    Returns:
+        dict: Mapping from task_id to its normalized and clipped weight.
+    """
+    # Step 1: optionally correct the rewards with symlog
+    if use_symlog:
+        rewards_tensor = torch.tensor(list(task_rewards.values()), dtype=torch.float32)
+        corrected_rewards = symlog(rewards_tensor).numpy()  # apply the symlog correction
+        task_rewards = dict(zip(task_rewards.keys(), corrected_rewards))
+
+    # Step 2: initial weights, inversely proportional to the (corrected) rewards
+    raw_weights = {task_id: 1 / (reward + epsilon) for task_id, reward in task_rewards.items()}
+
+    # Step 3: temperature scaling
+    scaled_weights = {task_id: weight ** (1 / temperature) for task_id, weight in raw_weights.items()}
+
+    # Step 4: normalize the weights
+    total_weight = sum(scaled_weights.values())
+    normalized_weights = {task_id: weight / total_weight for task_id, weight in scaled_weights.items()}
+
+    # Step 5: clip the weights into the [min_weight, max_weight] range
+    clipped_weights = {task_id: np.clip(weight, min_weight, max_weight) for task_id, weight in normalized_weights.items()}
+
+    final_weights = clipped_weights
+    return final_weights
+
+task_rewards_list = [
+    {"task1": 10, "task2": 100, "task3": 1000, "task4": 500, "task5": 300},
+    {"task1": 1, "task2": 10, "task3": 100, "task4": 1000, "task5": 10000},
+    {"task1": 0.1, "task2": 0.5, "task3": 0.9, "task4": 5, "task5": 10},
+]
+
+for i, task_rewards in enumerate(task_rewards_list, start=1):
+    print(f"Case {i}: Original Rewards: {task_rewards}")
+    print("Original Weights:")
+    print(compute_task_weights(task_rewards, use_symlog=False))
+    print("Improved Weights with Symlog:")
+    print(compute_task_weights(task_rewards, use_symlog=True))
+    print()
\ No newline at end of file
diff --git
a/lzero/entry/train_muzero_multitask_segment_ddp.py b/lzero/entry/train_muzero_multitask_segment_ddp.py new file mode 100644 index 000000000..b717b710e --- /dev/null +++ b/lzero/entry/train_muzero_multitask_segment_ddp.py @@ -0,0 +1,579 @@ +import logging +import os +from functools import partial +from typing import Tuple, Optional, List + +import torch +import numpy as np +from ding.config import compile_config +from ding.envs import create_env_manager, get_vec_env_setting +from ding.policy import create_policy +from ding.rl_utils import get_epsilon_greedy_fn +from ding.utils import set_pkg_seed, get_rank, get_world_size +from ding.worker import BaseLearner +from tensorboardX import SummaryWriter + +from lzero.entry.utils import log_buffer_memory_usage +from lzero.policy import visit_count_temperature +from lzero.mcts import MuZeroGameBuffer as GameBuffer +from lzero.worker import MuZeroEvaluator as Evaluator +from lzero.worker import MuZeroSegmentCollector as Collector +from ding.utils import EasyTimer +import torch.distributed as dist + +import concurrent.futures + +# ========== 超时时间设置 ========== +TIMEOUT = 3600 # 例如,60分钟 + + +def safe_eval( + evaluator: Evaluator, + learner: BaseLearner, + collector: Collector, + rank: int, + world_size: int +) -> Tuple[Optional[bool], Optional[float]]: + """ + 安全地执行评估操作,防止因超时导致训练过程阻塞。 + + Args: + evaluator (Evaluator): 评估器实例。 + learner (BaseLearner): 学习器实例。 + collector (Collector): 数据收集器实例。 + rank (int): 当前进程的排名。 + world_size (int): 总进程数。 + + Returns: + Tuple[Optional[bool], Optional[float]]: + - stop (Optional[bool]): 评估是否停止的标志。 + - reward (Optional[float]): 评估得到的奖励。 + """ + print(f"=========评估前 Rank {rank}/{world_size}===========") + # 重置 stop_event,确保每次评估前都处于未设置状态 + evaluator.stop_event.clear() + with concurrent.futures.ThreadPoolExecutor() as executor: + # 提交 evaluator.eval 任务 + future = executor.submit( + evaluator.eval, + learner.save_checkpoint, + learner.train_iter, + collector.envstep + ) + + try: + stop, reward = future.result(timeout=TIMEOUT) + except concurrent.futures.TimeoutError: + # 超时,设置 evaluator 的 stop_event + evaluator.stop_event.set() + print(f"评估操作在 Rank {rank}/{world_size} 上超过 {TIMEOUT} 秒超时。") + return None, None + + print(f"======评估后 Rank {rank}/{world_size}======") + return stop, reward + + +def allocate_batch_size( + cfgs: List, + game_buffers: List[GameBuffer], + alpha: float = 1.0, + clip_scale: int = 1 +) -> List[int]: + """ + 根据不同任务的 num_of_collected_episodes 反比分配 batch_size, + 并动态调整 batch_size 限制范围以提高训练的稳定性和效率。 + + Args: + cfgs (List): 每个任务的配置列表。 + game_buffers (List[GameBuffer]): 每个任务的 replay_buffer 实例列表。 + alpha (float): 控制反比程度的超参数 (默认为1.0)。 + clip_scale (int): 动态调整的缩放因子 (默认为1)。 + + Returns: + List[int]: 分配后的 batch_size 列表。 + """ + # 提取每个任务的 num_of_collected_episodes + buffer_num_of_collected_episodes = [ + buffer.num_of_collected_episodes for buffer in game_buffers + ] + + # 获取当前的 world_size 和 rank + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + + # 收集所有 rank 的 num_of_collected_episodes 列表 + all_task_num_of_collected_episodes = [None for _ in range(world_size)] + torch.distributed.all_gather_object( + all_task_num_of_collected_episodes, + buffer_num_of_collected_episodes + ) + + # 将所有 rank 的 num_of_collected_episodes 拼接成一个大列表 + all_task_num_of_collected_episodes = [ + item for sublist in all_task_num_of_collected_episodes for item in sublist + ] + if rank == 0: + print(f'all_task_num_of_collected_episodes: {all_task_num_of_collected_episodes}') + + # 计算每个任务的反比权重 + inv_episodes 
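To make the allocation rule in allocate_batch_size easier to follow, here is a minimal single-process sketch (no torch.distributed; names and numbers are illustrative only). It follows the UniZero variant later in this diff, which scales the normalized inverse-episode weights by the total batch budget and then clips each task's share to [avg / clip_scale, avg * clip_scale].

import numpy as np

def allocate_batch_size_local(episodes_per_task, total_batch_size, alpha=1.0, clip_scale=2):
    # Tasks that have collected fewer episodes get a larger share of the batch budget.
    inv = np.array([1.0 / (episodes + 1) for episodes in episodes_per_task])
    weights = (inv / inv.sum()) ** alpha
    sizes = total_batch_size * weights
    # Keep every task within [avg / clip_scale, avg * clip_scale] for stability.
    avg = total_batch_size / len(episodes_per_task)
    sizes = np.clip(sizes, avg / clip_scale, avg * clip_scale)
    return [int(size) for size in sizes]

print(allocate_batch_size_local([10, 100, 1000], total_batch_size=512))  # -> [341, 85, 85]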
= np.array([ + 1.0 / (episodes + 1) for episodes in all_task_num_of_collected_episodes + ]) + inv_sum = np.sum(inv_episodes) + + # 计算总的 batch_size (所有任务 cfg.policy.max_batch_size 的和) + max_batch_size = cfgs[0].policy.max_batch_size + + # 动态调整的部分:最小和最大的 batch_size 范围 + avg_batch_size = max_batch_size / world_size + min_batch_size = avg_batch_size / clip_scale + max_batch_size = avg_batch_size * clip_scale + + # 动态调整 alpha,让 batch_size 的变化更加平滑 + task_weights = (inv_episodes / inv_sum) ** alpha + batch_sizes = max_batch_size * task_weights + + # 控制 batch_size 在 [min_batch_size, max_batch_size] 之间 + batch_sizes = np.clip(batch_sizes, min_batch_size, max_batch_size) + + # 确保 batch_size 是整数 + batch_sizes = [int(size) for size in batch_sizes] + + # 返回最终分配的 batch_size 列表 + return batch_sizes + + +def train_muzero_multitask_segment_ddp( + input_cfg_list: List[Tuple[int, Tuple[dict, dict]]], + seed: int = 0, + model: Optional[torch.nn.Module] = None, + model_path: Optional[str] = None, + max_train_iter: Optional[int] = int(1e10), + max_env_step: Optional[int] = int(1e10), +) -> 'Policy': + """ + Overview: + The train entry for multi-task MuZero, adapted from UniZero's multi-task training. + This script aims to enhance the planning capabilities of reinforcement learning agents + by leveraging multi-task learning to address diverse environments. + + Args: + input_cfg_list (List[Tuple[int, Tuple[dict, dict]]]): + Configurations for different tasks as a list of tuples containing task ID and configuration dictionaries. + seed (int): + Random seed for reproducibility. + model (Optional[torch.nn.Module]): + Predefined model instance. If provided, it will be used instead of creating a new one. + model_path (Optional[str]): + Path to the pretrained model checkpoint. Should point to the ckpt file of the pretrained model. + max_train_iter (Optional[int]): + Maximum number of training iterations. Defaults to 1e10. + max_env_step (Optional[int]): + Maximum number of environment interaction steps. Defaults to 1e10. + + Returns: + Policy: + The trained policy instance. 
+ """ + # 获取当前进程的 rank 和总的进程数 + rank = get_rank() + world_size = get_world_size() + + # 任务划分 + total_tasks = len(input_cfg_list) + tasks_per_rank = total_tasks // world_size + remainder = total_tasks % world_size + + if rank < remainder: + start_idx = rank * (tasks_per_rank + 1) + end_idx = start_idx + tasks_per_rank + 1 + else: + start_idx = rank * tasks_per_rank + remainder + end_idx = start_idx + tasks_per_rank + + tasks_for_this_rank = input_cfg_list[start_idx:end_idx] + + # 确保至少有一个任务 + if len(tasks_for_this_rank) == 0: + logging.warning(f"Rank {rank}: 未分配任何任务,继续运行但无任务处理。") + # 初始化一些空列表以避免后续代码报错 + cfgs, game_buffers, collector_envs, evaluator_envs, collectors, evaluators = [], [], [], [], [], [] + return + + print(f"Rank {rank}/{world_size}, 处理任务 {start_idx} 到 {end_idx - 1}") + + cfgs = [] + game_buffers = [] + collector_envs = [] + evaluator_envs = [] + collectors = [] + evaluators = [] + + # 使用第一个任务的配置来创建共享的 policy + task_id, [cfg, create_cfg] = tasks_for_this_rank[0] + + # 设置每个任务的随机种子和任务编号 + for config in tasks_for_this_rank: + config[1][0].policy.task_num = len(tasks_for_this_rank) + + # 根据 CUDA 可用性设置设备 + cfg.policy.device = cfg.policy.model.device if torch.cuda.is_available() else 'cpu' + logging.info(f'cfg.policy.device: {cfg.policy.device}') + + # 编译配置 + cfg = compile_config( + cfg, + seed=seed, + env=None, + auto=True, + create_cfg=create_cfg, + save_cfg=True + ) + # 创建共享的 policy + policy = create_policy( + cfg.policy, + model=model, + enable_field=['learn', 'collect', 'eval'] + ) + + # 如果指定了预训练模型,则加载 + if model_path is not None: + logging.info(f'开始加载模型来自 {model_path}...') + policy.learn_mode.load_state_dict( + torch.load(model_path, map_location=cfg.policy.device) + ) + logging.info(f'完成加载模型来自 {model_path}.') + + # 创建 TensorBoard 的日志记录器 + log_dir = os.path.join(f'./{cfg.exp_name}/log', f'serial_rank_{rank}') + tb_logger = SummaryWriter(log_dir) + + # 创建共享的 learner + learner = BaseLearner( + cfg.policy.learn.learner, + policy.learn_mode, + tb_logger, + exp_name=cfg.exp_name + ) + + policy_config = cfg.policy + batch_size = policy_config.batch_size[0] + + # 只处理当前进程分配到的任务 + for local_task_id, (task_id, [cfg, create_cfg]) in enumerate(tasks_for_this_rank): + # 设置每个任务自己的随机种子 + cfg.policy.device = 'cuda' if cfg.policy.cuda and torch.cuda.is_available() else 'cpu' + cfg = compile_config( + cfg, + seed=seed + task_id, + env=None, + auto=True, + create_cfg=create_cfg, + save_cfg=True + ) + policy_config = cfg.policy + policy.collect_mode.get_attribute('cfg').n_episode = policy_config.n_episode + policy.eval_mode.get_attribute('cfg').n_episode = policy_config.n_episode + + env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) + collector_env = create_env_manager( + cfg.env.manager, + [partial(env_fn, cfg=c) for c in collector_env_cfg] + ) + evaluator_env = create_env_manager( + cfg.env.manager, + [partial(env_fn, cfg=c) for c in evaluator_env_cfg] + ) + collector_env.seed(cfg.seed + task_id) + evaluator_env.seed(cfg.seed + task_id, dynamic_seed=False) + set_pkg_seed(cfg.seed + task_id, use_cuda=cfg.policy.cuda) + + # 为每个任务创建不同的 game buffer、collector、evaluator + replay_buffer = GameBuffer(policy_config) + collector = Collector( + env=collector_env, + policy=policy.collect_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config, + task_id=task_id + ) + evaluator = Evaluator( + eval_freq=cfg.policy.eval_freq, + n_evaluator_episode=cfg.env.n_evaluator_episode, + stop_value=cfg.env.stop_value, + env=evaluator_env, + policy=policy.eval_mode, + 
tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config, + task_id=task_id + ) + + cfgs.append(cfg) + replay_buffer.batch_size = cfg.policy.batch_size[task_id] + + game_buffers.append(replay_buffer) + collector_envs.append(collector_env) + evaluator_envs.append(evaluator_env) + collectors.append(collector) + evaluators.append(evaluator) + + learner.call_hook('before_run') + value_priority_tasks = {} + + buffer_reanalyze_count = 0 + train_epoch = 0 + reanalyze_batch_size = cfg.policy.reanalyze_batch_size + update_per_collect = cfg.policy.update_per_collect + + while True: + torch.cuda.empty_cache() + + if cfg.policy.allocated_batch_sizes: + # TODO========== + # 线性变化的 随着 train_epoch 从 0 增加到 1000, clip_scale 从 1 线性增加到 4 + clip_scale = np.clip(1 + (3 * train_epoch / 1000), 1, 4) + allocated_batch_sizes = allocate_batch_size( + cfgs, + game_buffers, + alpha=1.0, + clip_scale=clip_scale + ) + if rank == 0: + print("分配后的 batch_sizes: ", allocated_batch_sizes) + for idx, (cfg, collector, evaluator, replay_buffer) in enumerate( + zip(cfgs, collectors, evaluators, game_buffers) + ): + cfg.policy.batch_size = allocated_batch_sizes[idx] + policy._cfg.batch_size[idx] = allocated_batch_sizes[idx] + + # 对于当前进程的每个任务,进行数据收集和评估 + for idx, (cfg, collector, evaluator, replay_buffer) in enumerate( + zip(cfgs, collectors, evaluators, game_buffers) + ): + + log_buffer_memory_usage( + learner.train_iter, + replay_buffer, + tb_logger, + cfg.policy.task_id + ) + + collect_kwargs = { + 'temperature': visit_count_temperature( + policy_config.manual_temperature_decay, + policy_config.fixed_temperature_value, + policy_config.threshold_training_steps_for_final_temperature, + trained_steps=learner.train_iter + ), + 'epsilon': 0.0 # 默认的 epsilon 值 + } + + if policy_config.eps.eps_greedy_exploration_in_collect: + epsilon_greedy_fn = get_epsilon_greedy_fn( + start=policy_config.eps.start, + end=policy_config.eps.end, + decay=policy_config.eps.decay, + type_=policy_config.eps.type + ) + collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep) + + if learner.train_iter == 0 or evaluator.should_eval(learner.train_iter): + print('=' * 20) + print(f'Rank {rank} 评估 task_id: {cfg.policy.task_id}...') + + # 在训练进程中调用 safe_eval + stop, reward = safe_eval( + evaluator, + learner, + collector, + rank, + world_size + ) + # 判断评估是否成功 + if stop is None or reward is None: + print(f"Rank {rank} 在评估期间遇到问题。继续训练中...") + else: + print(f"评估成功: stop={stop}, reward={reward}") + + print('=' * 20) + print(f'entry: Rank {rank} 收集 task_id: {cfg.policy.task_id}...') + + # 收集数据 + new_data = collector.collect( + train_iter=learner.train_iter, + policy_kwargs=collect_kwargs + ) + + # 更新 replay buffer + replay_buffer.push_game_segments(new_data) + replay_buffer.remove_oldest_data_to_fit() + + # 周期性地重新分析缓冲区 + if cfg.policy.buffer_reanalyze_freq >= 1: + # 在一个训练 epoch 中重新分析缓冲区 次 + reanalyze_interval = update_per_collect // cfg.policy.buffer_reanalyze_freq + else: + # 每 <1/buffer_reanalyze_freq> 个训练 epoch 重新分析一次缓冲区 + if ( + train_epoch % int(1 / cfg.policy.buffer_reanalyze_freq) == 0 and + replay_buffer.get_num_of_transitions() // cfg.policy.num_unroll_steps > + int(reanalyze_batch_size / cfg.policy.reanalyze_partition) + ): + with timer: + # 每个重新分析过程将重新分析 个序列 + replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy) + buffer_reanalyze_count += 1 + logging.info(f'缓冲区重新分析计数: {buffer_reanalyze_count}') + logging.info(f'缓冲区重新分析时间: {timer.value}') + + # 数据收集结束后添加日志 + logging.info(f'Rank {rank}: 完成任务 {cfg.policy.task_id} 的数据收集') + + # 
检查是否有足够的数据进行训练 + not_enough_data = any( + replay_buffer.get_num_of_transitions() < cfgs[0].policy.max_batch_size / world_size + for replay_buffer in game_buffers + ) + + # 同步训练前所有 rank 的准备状态 + try: + dist.barrier() + logging.info(f'Rank {rank}: 通过训练前的 barrier') + except Exception as e: + logging.error(f'Rank {rank}: Barrier 失败,错误: {e}') + break # 或者进行其他错误处理 + + # 学习策略 + if not not_enough_data: + # Learner 将在一次迭代中训练 update_per_collect 次 + for i in range(update_per_collect): + train_data_multi_task = [] + envstep_multi_task = 0 + for idx, (cfg, collector, replay_buffer) in enumerate( + zip(cfgs, collectors, game_buffers) + ): + envstep_multi_task += collector.envstep + batch_size = cfg.policy.batch_size[cfg.policy.task_id] + if replay_buffer.get_num_of_transitions() > batch_size: + if cfg.policy.buffer_reanalyze_freq >= 1: + # 在一个训练 epoch 中重新分析缓冲区 次 + if ( + i % reanalyze_interval == 0 and + replay_buffer.get_num_of_transitions() // cfg.policy.num_unroll_steps > + int(reanalyze_batch_size / cfg.policy.reanalyze_partition) + ): + with timer: + # 每个重新分析过程将重新分析 个序列 + replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy) + buffer_reanalyze_count += 1 + logging.info(f'缓冲区重新分析计数: {buffer_reanalyze_count}') + logging.info(f'缓冲区重新分析时间: {timer.value}') + + train_data = replay_buffer.sample(batch_size, policy) + # 追加 task_id,以便在训练时区分任务 + train_data.append(cfg.policy.task_id) + train_data_multi_task.append(train_data) + else: + logging.warning( + f'Replay buffer 中的数据不足以采样一个 mini-batch: ' + f'batch_size: {batch_size}, replay_buffer: {replay_buffer}' + ) + break + + if train_data_multi_task: + # 在训练时,DDP 会自动同步梯度和参数 + log_vars = learner.train(train_data_multi_task, envstep_multi_task) + + if cfg.policy.use_priority: + for idx, (cfg, replay_buffer) in enumerate( + zip(cfgs, game_buffers) + ): + # 更新任务特定的 replay buffer 的优先级 + task_id = cfg.policy.task_id + replay_buffer.update_priority( + train_data_multi_task[idx], + log_vars[0][f'value_priority_task{task_id}'] + ) + + current_priorities = log_vars[0][f'value_priority_task{task_id}'] + + mean_priority = np.mean(current_priorities) + std_priority = np.std(current_priorities) + + alpha = 0.1 # 运行均值的平滑因子 + if f'running_mean_priority_task{task_id}' not in value_priority_tasks: + # 如果不存在,则初始化运行均值 + value_priority_tasks[f'running_mean_priority_task{task_id}'] = mean_priority + else: + # 更新运行均值 + value_priority_tasks[f'running_mean_priority_task{task_id}'] = ( + alpha * mean_priority + + (1 - alpha) * value_priority_tasks[f'running_mean_priority_task{task_id}'] + ) + + # 使用运行均值计算归一化的优先级 + running_mean_priority = value_priority_tasks[f'running_mean_priority_task{task_id}'] + normalized_priorities = ( + current_priorities - running_mean_priority + ) / (std_priority + 1e-6) + + # 如果需要,可以将归一化的优先级存储回 replay buffer + # replay_buffer.update_priority(train_data_multi_task[idx], normalized_priorities) + + # 如果设置了 print_task_priority_logs 标志,则记录统计信息 + if cfg.policy.print_task_priority_logs: + print( + f"任务 {task_id} - 平均优先级: {mean_priority:.8f}, " + f"运行平均优先级: {running_mean_priority:.8f}, " + f"标准差: {std_priority:.8f}" + ) + + train_epoch += 1 + + # 同步所有 Rank,确保所有 Rank 都完成了训练 + try: + dist.barrier() + logging.info(f'Rank {rank}: 通过训练后的 barrier') + except Exception as e: + logging.error(f'Rank {rank}: Barrier 失败,错误: {e}') + break # 或者进行其他错误处理 + + # 检查是否需要终止训练 + try: + # local_envsteps 不再需要填充 + local_envsteps = [collector.envstep for collector in collectors] + + total_envsteps = [None for _ in range(world_size)] + dist.all_gather_object(total_envsteps, local_envsteps) 
+ + # 将所有 envsteps 拼接在一起 + all_envsteps = torch.cat([ + torch.tensor(envsteps, device=cfg.policy.device) + for envsteps in total_envsteps + ]) + max_envstep_reached = torch.all(all_envsteps >= max_env_step) + + # 收集所有进程的 train_iter + global_train_iter = torch.tensor([learner.train_iter], device=cfg.policy.device) + all_train_iters = [torch.zeros_like(global_train_iter) for _ in range(world_size)] + dist.all_gather(all_train_iters, global_train_iter) + + max_train_iter_reached = torch.any( + torch.stack(all_train_iters) >= max_train_iter + ) + + if max_envstep_reached.item() or max_train_iter_reached.item(): + logging.info(f'Rank {rank}: 满足终止条件') + dist.barrier() # 确保所有进程同步 + break + else: + pass + + except Exception as e: + logging.error(f'Rank {rank}: 终止检查失败,错误: {e}') + break # 或者进行其他错误处理 + + learner.call_hook('after_run') + return policy \ No newline at end of file diff --git a/lzero/entry/train_muzero_multitask_segment_noddp.py b/lzero/entry/train_muzero_multitask_segment_noddp.py new file mode 100644 index 000000000..bbeecb227 --- /dev/null +++ b/lzero/entry/train_muzero_multitask_segment_noddp.py @@ -0,0 +1,270 @@ +import logging +import os +from functools import partial +from typing import Tuple, Optional, List + +import torch +import numpy as np +from ding.config import compile_config +from ding.envs import create_env_manager, get_vec_env_setting +from ding.policy import create_policy +from ding.rl_utils import get_epsilon_greedy_fn +from ding.utils import set_pkg_seed, get_rank +from ding.worker import BaseLearner +from tensorboardX import SummaryWriter + +from lzero.entry.utils import log_buffer_memory_usage, log_buffer_run_time +from lzero.policy import visit_count_temperature +from lzero.policy.random_policy import LightZeroRandomPolicy +from lzero.worker import MuZeroEvaluator as Evaluator +from lzero.worker import MuZeroSegmentCollector as Collector +from lzero.mcts import MuZeroGameBuffer as GameBuffer # 根据不同策略选择合适的 GameBuffer +from .utils import random_collect + +from ding.utils import EasyTimer +timer = EasyTimer() +from line_profiler import line_profiler + + +def train_muzero_multitask_segment_noddp( + input_cfg_list: List[Tuple[int, Tuple[dict, dict]]], + seed: int = 0, + model: Optional[torch.nn.Module] = None, + model_path: Optional[str] = None, + max_train_iter: Optional[int] = int(1e10), + max_env_step: Optional[int] = int(1e10), +) -> 'Policy': # noqa + """ + Overview: + 多任务训练入口,基于 MuZero 的多任务版本,支持多任务环境的训练。 + 参考论文 UniZero: Generalized and Efficient Planning with Scalable Latent World Models。 + Arguments: + - input_cfg_list (List[Tuple[int, Tuple[dict, dict]]]): 不同任务的配置列表。 + - seed (int): 随机种子。 + - model (Optional[torch.nn.Module]): torch.nn.Module 的实例。 + - model_path (Optional[str]): 预训练模型路径,指向预训练模型的 ckpt 文件。 + - max_train_iter (Optional[int]): 最大训练迭代次数。 + - max_env_step (Optional[int]): 最大环境交互步数。 + Returns: + - policy (Policy): 收敛后的策略。 + """ + cfgs = [] + game_buffers = [] + collector_envs = [] + evaluator_envs = [] + collectors = [] + evaluators = [] + + task_id, [cfg, create_cfg] = input_cfg_list[0] + + # Ensure the specified policy type is supported + assert create_cfg.policy.type in ['muzero_multitask'], "train_muzero entry now only supports 'muzero'" + + # Set device based on CUDA availability + cfg.policy.device = cfg.policy.model.device if torch.cuda.is_available() else 'cpu' + logging.info(f'cfg.policy.device: {cfg.policy.device}') + + # Compile the configuration + cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, 
save_cfg=True) + # Create shared policy for all tasks + policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) + + # Load pretrained model if specified + if model_path is not None: + logging.info(f'Loading model from {model_path} begin...') + policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device)) + logging.info(f'Loading model from {model_path} end!') + + # Create SummaryWriter for TensorBoard logging + tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) if get_rank() == 0 else None + # Create shared learner for all tasks + learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) + + # TODO task_id = 0: + policy_config = cfg.policy + batch_size = policy_config.batch_size[0] + + # 初始化多任务配置 + for task_id, input_cfg in input_cfg_list: + + if task_id > 0: + # Get the configuration for each task + cfg, create_cfg = input_cfg + cfg.policy.device = 'cuda' if cfg.policy.cuda and torch.cuda.is_available() else 'cpu' + cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) + policy_config = cfg.policy + policy.collect_mode.get_attribute('cfg').n_episode = policy_config.n_episode + policy.eval_mode.get_attribute('cfg').n_episode = policy_config.n_episode + + env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) + collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) + evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + collector_env.seed(cfg.seed + task_id) + evaluator_env.seed(cfg.seed + task_id, dynamic_seed=False) + set_pkg_seed(cfg.seed + task_id, use_cuda=cfg.policy.cuda) + + # ===== NOTE: Create different game buffer, collector, evaluator for each task ==== + # TODO: share replay buffer for all tasks + replay_buffer = GameBuffer(policy_config) + collector = Collector( + env=collector_env, + policy=policy.collect_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config, + task_id=task_id + ) + evaluator = Evaluator( + eval_freq=cfg.policy.eval_freq, + n_evaluator_episode=cfg.env.n_evaluator_episode, + stop_value=cfg.env.stop_value, + env=evaluator_env, + policy=policy.eval_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config, + task_id=task_id + ) + + cfgs.append(cfg) + replay_buffer.batch_size = cfg.policy.batch_size[task_id] + game_buffers.append(replay_buffer) + collector_envs.append(collector_env) + evaluator_envs.append(evaluator_env) + collectors.append(collector) + evaluators.append(evaluator) + + learner.call_hook('before_run') + value_priority_tasks = {} + + buffer_reanalyze_count = 0 + train_epoch = 0 + reanalyze_batch_size = cfg.policy.reanalyze_batch_size + update_per_collect = cfg.policy.update_per_collect + + while True: + torch.cuda.empty_cache() + + # 遍历每个任务进行数据收集和评估 + for task_id, (cfg, collector, evaluator, replay_buffer) in enumerate( + zip(cfgs, collectors, evaluators, game_buffers)): + log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger) + + collect_kwargs = { + 'temperature': visit_count_temperature( + policy_config.manual_temperature_decay, + policy_config.fixed_temperature_value, + policy_config.threshold_training_steps_for_final_temperature, + trained_steps=learner.train_iter + ), + 'epsilon': 0.0 # 默认 epsilon 值 + } + + if policy_config.eps.eps_greedy_exploration_in_collect: + 
epsilon_greedy_fn = get_epsilon_greedy_fn( + start=policy_config.eps.start, + end=policy_config.eps.end, + decay=policy_config.eps.decay, + type_=policy_config.eps.type + ) + collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep) + + # 评估策略性能 + if learner.train_iter ==0 or evaluator.should_eval(learner.train_iter): + logging.info(f'========== 评估任务 {task_id} ==========') + stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) + if stop: + break + + # 收集数据 + logging.info(f'========== 收集任务 {task_id} 数据 ==========') + # collector.reset() + new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) + + # 确定每次收集后的更新次数 + if update_per_collect is None: + collected_transitions_num = sum( + min(len(game_segment), cfg.policy.game_segment_length) for game_segment in new_data[0]) + update_per_collect = int(collected_transitions_num * cfg.policy.replay_ratio) + + # 更新回放缓冲区 + replay_buffer.push_game_segments(new_data) + replay_buffer.remove_oldest_data_to_fit() + + # 定期重新分析缓冲区 + if cfg.policy.buffer_reanalyze_freq >= 1: + reanalyze_interval = update_per_collect // cfg.policy.buffer_reanalyze_freq + else: + if train_epoch % int(1 / cfg.policy.buffer_reanalyze_freq) == 0 and replay_buffer.get_num_of_transitions() // cfg.policy.num_unroll_steps > int(reanalyze_batch_size / cfg.policy.reanalyze_partition): + with timer: + replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy) + buffer_reanalyze_count += 1 + logging.info(f'缓冲区重新分析次数: {buffer_reanalyze_count}') + logging.info(f'缓冲区重新分析时间: {timer.value}') + + # 检查是否有足够的数据进行训练 + not_enough_data = any(replay_buffer.get_num_of_transitions() < batch_size for replay_buffer in game_buffers) + + if not not_enough_data: + # 进行训练 + for i in range(update_per_collect): + train_data_multi_task = [] + envstep_multi_task = 0 + for task_id, (cfg, collector, replay_buffer) in enumerate(zip(cfgs, collectors, game_buffers)): + envstep_multi_task += collector.envstep + if replay_buffer.get_num_of_transitions() > batch_size: + batch_size = cfg.policy.batch_size[task_id] + + + if cfg.policy.buffer_reanalyze_freq >= 1: + if i % reanalyze_interval == 0 and replay_buffer.get_num_of_transitions() // cfg.policy.num_unroll_steps > int( + reanalyze_batch_size / cfg.policy.reanalyze_partition): + with timer: + replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy) + buffer_reanalyze_count += 1 + logging.info(f'缓冲区重新分析次数: {buffer_reanalyze_count}') + logging.info(f'缓冲区重新分析时间: {timer.value}') + + train_data = replay_buffer.sample(batch_size, policy) + train_data.append(task_id) # 添加 task_id + train_data_multi_task.append(train_data) + else: + logging.warning( + f'回放缓冲区数据不足以采样 mini-batch: ' + f'batch_size: {batch_size}, 回放缓冲区: {replay_buffer}' + ) + break + + if train_data_multi_task: + log_vars = learner.train(train_data_multi_task, envstep_multi_task) + + if cfg.policy.use_priority: + for task_id, replay_buffer in enumerate(game_buffers): + current_priorities = log_vars[0][f'value_priority_task{task_id}'] + mean_priority = np.mean(current_priorities) + std_priority = np.std(current_priorities) + alpha = 0.1 + if f'running_mean_priority_task{task_id}' not in value_priority_tasks: + value_priority_tasks[f'running_mean_priority_task{task_id}'] = mean_priority + else: + value_priority_tasks[f'running_mean_priority_task{task_id}'] = ( + alpha * mean_priority + (1 - alpha) * value_priority_tasks[f'running_mean_priority_task{task_id}'] + ) + running_mean_priority = 
value_priority_tasks[f'running_mean_priority_task{task_id}'] + normalized_priorities = (current_priorities - running_mean_priority) / (std_priority + 1e-6) + + if cfg.policy.print_task_priority_logs: + print(f"任务 {task_id} - 平均优先级: {mean_priority:.8f}, " + f"运行平均优先级: {running_mean_priority:.8f}, " + f"标准差: {std_priority:.8f}") + + # 清除位置嵌入缓存 + train_epoch += 1 + + # 检查是否达到训练结束条件 + if all(collector.envstep >= max_env_step for collector in collectors) or learner.train_iter >= max_train_iter: + break + + # 调用学习器的 after_run 钩子 + learner.call_hook('after_run') + return policy \ No newline at end of file diff --git a/lzero/entry/train_unizero.py b/lzero/entry/train_unizero.py index cd7ff7605..c1687749b 100644 --- a/lzero/entry/train_unizero.py +++ b/lzero/entry/train_unizero.py @@ -119,6 +119,10 @@ def train_unizero( batch_size = policy._cfg.batch_size + # TODO: for visualize + # stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) + # import sys; sys.exit(0) + while True: # Log buffer memory usage log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger) diff --git a/lzero/entry/train_unizero_multitask_segment_ddp.py b/lzero/entry/train_unizero_multitask_segment_ddp.py new file mode 100644 index 000000000..3886efe0f --- /dev/null +++ b/lzero/entry/train_unizero_multitask_segment_ddp.py @@ -0,0 +1,711 @@ +import logging +import os +from functools import partial +from typing import Tuple, Optional, List + +import torch +import numpy as np +from ding.config import compile_config +from ding.envs import create_env_manager, get_vec_env_setting +from ding.policy import create_policy +from ding.rl_utils import get_epsilon_greedy_fn +from ding.utils import set_pkg_seed, get_rank, get_world_size +from ding.worker import BaseLearner +from tensorboardX import SummaryWriter + +from lzero.entry.utils import log_buffer_memory_usage, TemperatureScheduler +from lzero.policy import visit_count_temperature +from lzero.worker import MuZeroEvaluator as Evaluator +from lzero.worker import MuZeroSegmentCollector as Collector +from ding.utils import EasyTimer +import torch.nn.functional as F + +import torch.distributed as dist + +import concurrent.futures + + +# 设置超时时间 (秒) +TIMEOUT = 12000 # 例如200分钟 + +timer = EasyTimer() + + +def safe_eval( + evaluator: Evaluator, + learner: BaseLearner, + collector: Collector, + rank: int, + world_size: int +) -> Tuple[Optional[bool], Optional[float]]: + """ + Safely执行评估任务,避免超时。 + + Args: + evaluator (Evaluator): 评估器实例。 + learner (BaseLearner): 学习器实例。 + collector (Collector): 数据收集器实例。 + rank (int): 当前进程的rank。 + world_size (int): 总进程数。 + + Returns: + Tuple[Optional[bool], Optional[float]]: 如果评估成功,返回停止标志和奖励,否则返回(None, None)。 + """ + try: + print(f"=========评估开始 Rank {rank}/{world_size}===========") + # 重置 stop_event,确保每次评估前都处于未设置状态 + evaluator.stop_event.clear() + with concurrent.futures.ThreadPoolExecutor() as executor: + # 提交评估任务 + future = executor.submit(evaluator.eval, learner.save_checkpoint, learner.train_iter, collector.envstep) + try: + stop, reward = future.result(timeout=TIMEOUT) + except concurrent.futures.TimeoutError: + # 超时,设置 stop_event + evaluator.stop_event.set() + print(f"评估操作在 Rank {rank}/{world_size} 上超时,耗时 {TIMEOUT} 秒。") + return None, None + + print(f"======评估结束 Rank {rank}/{world_size}======") + return stop, reward + except Exception as e: + print(f"Rank {rank}/{world_size} 评估过程中发生错误: {e}") + return None, None + + +def allocate_batch_size( + cfgs: List[dict], + game_buffers, + alpha: float = 1.0, + 
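The safe_eval wrapper above guards evaluator.eval with a thread-pool timeout. A generic sketch of that pattern (illustrative names, not part of the patch):

import concurrent.futures

def run_with_timeout(fn, timeout_s, *args, **kwargs):
    # Run a blocking call in a worker thread and stop waiting after timeout_s seconds.
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
        future = executor.submit(fn, *args, **kwargs)
        try:
            return future.result(timeout=timeout_s)
        except concurrent.futures.TimeoutError:
            # Note: leaving the with-block still waits for the worker thread to finish,
            # which is why safe_eval also sets evaluator.stop_event so eval can return early.
            return None

# Usage mirroring safe_eval:
# result = run_with_timeout(evaluator.eval, TIMEOUT,
#                           learner.save_checkpoint, learner.train_iter, collector.envstep)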
clip_scale: int = 1 +) -> List[int]: + """ + 根据不同任务的收集剧集数反比分配batch_size,并动态调整batch_size范围以提高训练稳定性和效率。 + + Args: + cfgs (List[dict]): 每个任务的配置列表。 + game_buffers (List[GameBuffer]): 每个任务的重放缓冲区实例列表。 + alpha (float, optional): 控制反比程度的超参数。默认为1.0。 + clip_scale (int, optional): 动态调整的clip比例。默认为1。 + + Returns: + List[int]: 分配后的batch_size列表。 + """ + # 提取每个任务的 collected episodes 数量 + buffer_num_of_collected_episodes = [buffer.num_of_collected_episodes for buffer in game_buffers] + + # 获取当前的 world_size 和 rank + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + + # 收集所有 rank 的 collected episodes 列表 + all_task_num_of_collected_episodes = [None for _ in range(world_size)] + torch.distributed.all_gather_object(all_task_num_of_collected_episodes, buffer_num_of_collected_episodes) + + # 将所有 rank 的 collected episodes 合并为一个大列表 + all_task_num_of_collected_episodes = [ + episode for sublist in all_task_num_of_collected_episodes for episode in sublist + ] + if rank == 0: + print(f'所有任务的 collected episodes: {all_task_num_of_collected_episodes}') + + # 计算每个任务的反比权重 + inv_episodes = np.array([1.0 / (episodes + 1) for episodes in all_task_num_of_collected_episodes]) + inv_sum = np.sum(inv_episodes) + + # 计算总的batch_size (所有任务 cfg.policy.batch_size 的和) + total_batch_size = cfgs[0].policy.total_batch_size + + # 动态调整的部分:最小和最大的 batch_size 范围 + avg_batch_size = total_batch_size / world_size + min_batch_size = avg_batch_size / clip_scale + max_batch_size = avg_batch_size * clip_scale + + # 动态调整 alpha,让 batch_size 的变化更加平滑 + task_weights = (inv_episodes / inv_sum) ** alpha + batch_sizes = total_batch_size * task_weights + + # 控制 batch_size 在 [min_batch_size, max_batch_size] 之间 + batch_sizes = np.clip(batch_sizes, min_batch_size, max_batch_size) + + # 确保 batch_size 是整数 + batch_sizes = [int(size) for size in batch_sizes] + + return batch_sizes + +import numpy as np + + +def symlog(x: torch.Tensor) -> torch.Tensor: + """ + Symlog 归一化,减少目标值的幅度差异。 + symlog(x) = sign(x) * log(|x| + 1) + """ + return torch.sign(x) * torch.log(torch.abs(x) + 1) + +def inv_symlog(x: torch.Tensor) -> torch.Tensor: + """ + Symlog 的逆操作,用于恢复原始值。 + inv_symlog(x) = sign(x) * (exp(|x|) - 1) + """ + return torch.sign(x) * (torch.exp(torch.abs(x)) - 1) + +# 全局最大值和最小值(用于 "run-max-min") +GLOBAL_MAX = -float('inf') +GLOBAL_MIN = float('inf') + +def compute_task_weights( + task_rewards: dict, + option: str = "symlog", + epsilon: float = 1e-6, + temperature: float = 1.0, + use_softmax: bool = False, # 是否使用 Softmax + reverse: bool = False, # 正比 (False) 或反比 (True) + clip_min: float = 1e-2, # 权重的最小值 + clip_max: float = 1.0, # 权重的最大值 +) -> dict: + """ + 改进后的任务权重计算函数,支持多种标准化方式、Softmax 和正反比权重计算,并增加权重范围裁剪功能。 + + Args: + task_rewards (dict): 每个任务的字典,键为 task_id,值为评估奖励或损失。 + option (str): 标准化方式,可选值为 "symlog", "max-min", "run-max-min", "rank", "none"。 + epsilon (float): 避免分母为零的小值。 + temperature (float): 控制权重分布的温度系数。 + use_softmax (bool): 是否使用 Softmax 进行权重分配。 + reverse (bool): 若为 True,权重与值反比;若为 False,权重与值正比。 + clip_min (float): 权重的最小值,用于裁剪。 + clip_max (float): 权重的最大值,用于裁剪。 + + Returns: + dict: 每个任务的权重,键为 task_id,值为归一化后的权重。 + """ + import torch + import torch.nn.functional as F + + global GLOBAL_MAX, GLOBAL_MIN + + # 如果输入为空字典,直接返回空结果 + if not task_rewards: + return {} + + # Step 1: 对 task_rewards 的值构造张量 + task_ids = list(task_rewards.keys()) + rewards_tensor = torch.tensor(list(task_rewards.values()), dtype=torch.float32) + + if option == "symlog": + # 使用 symlog 标准化 + scaled_rewards = symlog(rewards_tensor) + elif option == "max-min": + # 
使用最大最小值归一化 + max_reward = rewards_tensor.max().item() + min_reward = rewards_tensor.min().item() + scaled_rewards = (rewards_tensor - min_reward) / (max_reward - min_reward + epsilon) + elif option == "run-max-min": + # 使用全局最大最小值归一化 + GLOBAL_MAX = max(GLOBAL_MAX, rewards_tensor.max().item()) + GLOBAL_MIN = min(GLOBAL_MIN, rewards_tensor.min().item()) + scaled_rewards = (rewards_tensor - GLOBAL_MIN) / (GLOBAL_MAX - GLOBAL_MIN + epsilon) + elif option == "rank": + # 使用 rank 标准化 + # Rank 是基于值大小的排名,1 表示最小值,越大排名越高 + sorted_indices = torch.argsort(rewards_tensor) + scaled_rewards = torch.empty_like(rewards_tensor) + rank_values = torch.arange(1, len(rewards_tensor) + 1, dtype=torch.float32) # 1 到 N + scaled_rewards[sorted_indices] = rank_values + elif option == "none": + # 不进行标准化 + scaled_rewards = rewards_tensor + else: + raise ValueError(f"Unsupported option: {option}") + + # Step 2: 根据 reverse 确定权重是正比还是反比 + if not reverse: + # 正比:权重与值正相关 + raw_weights = scaled_rewards + else: + # 反比:权重与值负相关 + # 避免 scaled_rewards 为负数或零 + scaled_rewards = torch.clamp(scaled_rewards, min=epsilon) + raw_weights = 1.0 / scaled_rewards + + # Step 3: 根据是否使用 Softmax 进行权重计算 + if use_softmax: + # 使用 Softmax 进行权重分配 + beta = 1.0 / max(temperature, epsilon) # 确保 temperature 不为零 + logits = -beta * raw_weights + softmax_weights = F.softmax(logits, dim=0).numpy() + weights = dict(zip(task_ids, softmax_weights)) + else: + # 不使用 Softmax,直接计算权重 + # 温度缩放 + scaled_weights = raw_weights ** (1 / max(temperature, epsilon)) # 确保温度不为零 + + # 归一化权重 + total_weight = scaled_weights.sum() + normalized_weights = scaled_weights / total_weight + + # 转换为字典 + weights = dict(zip(task_ids, normalized_weights.numpy())) + + # Step 4: Clip 权重范围 + for task_id in weights: + weights[task_id] = max(min(weights[task_id], clip_max), clip_min) + + return weights + +def train_unizero_multitask_segment_ddp( + input_cfg_list: List[Tuple[int, Tuple[dict, dict]]], + seed: int = 0, + model: Optional[torch.nn.Module] = None, + model_path: Optional[str] = None, + max_train_iter: Optional[int] = int(1e10), + max_env_step: Optional[int] = int(1e10), +) -> 'Policy': + """ + Overview: + UniZero的训练入口,旨在通过解决MuZero类算法在需要捕捉长期依赖环境中的局限性,提高强化学习代理的规划能力。 + 详细信息请参阅 https://arxiv.org/abs/2406.10667。 + + Args: + - input_cfg_list (:obj:`List[Tuple[int, Tuple[dict, dict]]]`): 不同任务的配置列表。 + - seed (:obj:`int`): 随机种子。 + - model (:obj:`Optional[torch.nn.Module]`): torch.nn.Module实例。 + - model_path (:obj:`Optional[str]`): 预训练模型路径,应指向预训练模型的ckpt文件。 + - max_train_iter (:obj:`Optional[int]`): 训练中的最大策略更新迭代次数。 + - max_env_step (:obj:`Optional[int]`): 最大收集环境交互步数。 + + Returns: + - policy (:obj:`Policy`): 收敛的策略。 + """ + # 初始化温度调度器 + initial_temperature = 10.0 + final_temperature = 1.0 + threshold_steps = int(1e4) # 训练步数达到 10k 时,温度降至 1.0 + temperature_scheduler = TemperatureScheduler( + initial_temp=initial_temperature, + final_temp=final_temperature, + threshold_steps=threshold_steps, + mode='linear' # 或 'exponential' + ) + + # 获取当前进程的rank和总进程数 + rank = get_rank() + world_size = get_world_size() + + # 任务划分 + total_tasks = len(input_cfg_list) + tasks_per_rank = total_tasks // world_size + remainder = total_tasks % world_size + + if rank < remainder: + start_idx = rank * (tasks_per_rank + 1) + end_idx = start_idx + tasks_per_rank + 1 + else: + start_idx = rank * tasks_per_rank + remainder + end_idx = start_idx + tasks_per_rank + + tasks_for_this_rank = input_cfg_list[start_idx:end_idx] + + # 确保至少有一个任务 + if len(tasks_for_this_rank) == 0: + logging.warning(f"Rank {rank}: 未分配任务,继续执行。") + # 
初始化空列表以避免后续代码报错 + cfgs, game_buffers, collector_envs, evaluator_envs, collectors, evaluators = [], [], [], [], [], [] + else: + print(f"Rank {rank}/{world_size}, 处理任务 {start_idx} 到 {end_idx - 1}") + + cfgs = [] + game_buffers = [] + collector_envs = [] + evaluator_envs = [] + collectors = [] + evaluators = [] + + if tasks_for_this_rank: + # 使用第一个任务的配置创建共享的policy + task_id, [cfg, create_cfg] = tasks_for_this_rank[0] + + for config in tasks_for_this_rank: + config[1][0].policy.task_num = tasks_per_rank + + # 确保指定的策略类型受支持 + assert create_cfg.policy.type in ['unizero_multitask', + 'sampled_unizero_multitask'], "train_unizero entry 目前仅支持 'unizero_multitask'" + + if create_cfg.policy.type == 'unizero_multitask': + from lzero.mcts import UniZeroGameBuffer as GameBuffer + if create_cfg.policy.type == 'sampled_unizero_multitask': + from lzero.mcts import SampledUniZeroGameBuffer as GameBuffer + + + # 根据CUDA可用性设置设备 + cfg.policy.device = cfg.policy.model.world_model_cfg.device if torch.cuda.is_available() else 'cpu' + logging.info(f'配置的设备: {cfg.policy.device}') + + # 编译配置 + cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) + # 创建共享的policy + policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) + + # 加载预训练模型(如果提供) + if model_path is not None: + logging.info(f'开始加载模型: {model_path}') + policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device)) + logging.info(f'完成加载模型: {model_path}') + + # 创建TensorBoard日志记录器 + log_dir = os.path.join('./{}/log'.format(cfg.exp_name), f'serial_rank_{rank}') + tb_logger = SummaryWriter(log_dir) + + # 创建共享的learner + learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) + + policy_config = cfg.policy + + # 处理当前进程分配到的每个任务 + for local_task_id, (task_id, [cfg, create_cfg]) in enumerate(tasks_for_this_rank): + # 设置每个任务的随机种子 + cfg.policy.device = 'cuda' if cfg.policy.cuda and torch.cuda.is_available() else 'cpu' + cfg = compile_config(cfg, seed=seed + task_id, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) + policy_config = cfg.policy + policy.collect_mode.get_attribute('cfg').n_episode = policy_config.n_episode + policy.eval_mode.get_attribute('cfg').n_episode = policy_config.n_episode + + # 创建环境 + env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) + collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) + evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + collector_env.seed(cfg.seed + task_id) + evaluator_env.seed(cfg.seed + task_id, dynamic_seed=False) + set_pkg_seed(cfg.seed + task_id, use_cuda=cfg.policy.cuda) + + # 创建不同的game buffer、collector和evaluator + replay_buffer = GameBuffer(policy_config) + collector = Collector( + env=collector_env, + policy=policy.collect_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config, + task_id=task_id + ) + evaluator = Evaluator( + eval_freq=cfg.policy.eval_freq, + n_evaluator_episode=cfg.env.n_evaluator_episode, + stop_value=cfg.env.stop_value, + env=evaluator_env, + policy=policy.eval_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config, + task_id=task_id + ) + + cfgs.append(cfg) + replay_buffer.batch_size = cfg.policy.batch_size[task_id] + + game_buffers.append(replay_buffer) + collector_envs.append(collector_env) + evaluator_envs.append(evaluator_env) + 
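The task-weight temperature used below comes from TemperatureScheduler(initial_temp=10.0, final_temp=1.0, threshold_steps=1e4, mode='linear'). A minimal linear schedule matching those constructor arguments might look like this (a sketch under that assumption; the actual implementation lives in lzero.entry.utils):

def linear_temperature(train_iter: int, initial: float = 10.0, final: float = 1.0,
                       threshold_steps: int = 10_000) -> float:
    # Linearly anneal from `initial` to `final` over the first `threshold_steps` iterations,
    # then hold at `final`. A higher temperature early on flattens the task weights.
    frac = min(train_iter / threshold_steps, 1.0)
    return initial + (final - initial) * frac

print([linear_temperature(t) for t in (0, 5_000, 10_000, 20_000)])  # [10.0, 5.5, 1.0, 1.0]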
collectors.append(collector) + evaluators.append(evaluator) + + # 调用learner的before_run钩子 + learner.call_hook('before_run') + value_priority_tasks = {} + + buffer_reanalyze_count = 0 + train_epoch = 0 + reanalyze_batch_size = cfg.policy.reanalyze_batch_size + update_per_collect = cfg.policy.update_per_collect + + task_complexity_weight = cfg.policy.task_complexity_weight + use_task_exploitation_weight = cfg.policy.use_task_exploitation_weight + task_exploitation_weight = None + + # 创建任务奖励字典 + task_rewards = {} # {task_id: reward} + + while True: + # 动态调整batch_size + if cfg.policy.allocated_batch_sizes: + clip_scale = np.clip(1 + (3 * train_epoch / 1000), 1, 4) + allocated_batch_sizes = allocate_batch_size(cfgs, game_buffers, alpha=1.0, clip_scale=clip_scale) + if rank == 0: + print("分配后的 batch_sizes: ", allocated_batch_sizes) + for idx, (cfg, collector, evaluator, replay_buffer) in enumerate( + zip(cfgs, collectors, evaluators, game_buffers)): + cfg.policy.batch_size = allocated_batch_sizes + policy._cfg.batch_size = allocated_batch_sizes + + # 对于当前进程的每个任务,进行数据收集和评估 + for idx, (cfg, collector, evaluator, replay_buffer) in enumerate( + zip(cfgs, collectors, evaluators, game_buffers)): + + # 记录缓冲区内存使用情况 + log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger, cfg.policy.task_id) + + collect_kwargs = { + 'temperature': visit_count_temperature( + policy_config.manual_temperature_decay, + policy_config.fixed_temperature_value, + policy_config.threshold_training_steps_for_final_temperature, + trained_steps=learner.train_iter + ), + 'epsilon': 0.0 # 默认的epsilon值 + } + + if policy_config.eps.eps_greedy_exploration_in_collect: + epsilon_greedy_fn = get_epsilon_greedy_fn( + start=policy_config.eps.start, + end=policy_config.eps.end, + decay=policy_config.eps.decay, + type_=policy_config.eps.type + ) + collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep) + + # 判断是否需要进行评估 + if learner.train_iter == 0 or evaluator.should_eval(learner.train_iter): + # if evaluator.should_eval(learner.train_iter): + print('=' * 20) + print(f'Rank {rank} 评估任务_id: {cfg.policy.task_id}...') + + # =========TODO========= + evaluator._policy.reset(reset_init_data=True, task_id=cfg.policy.task_id) + + # 执行安全评估 + stop, reward = safe_eval(evaluator, learner, collector, rank, world_size) + # 判断评估是否成功 + if stop is None or reward is None: + print(f"Rank {rank} 在评估过程中遇到问题,继续训练...") + task_rewards[cfg.policy.task_id] = float('inf') # 如果评估失败,将任务难度设为最大值 + else: + # 确保从评估结果中提取 `eval_episode_return_mean` 作为奖励值 + try: + eval_mean_reward = reward.get('eval_episode_return_mean', float('inf')) + print(f"任务 {cfg.policy.task_id} 的评估奖励: {eval_mean_reward}") + task_rewards[cfg.policy.task_id] = eval_mean_reward + except Exception as e: + print(f"提取评估奖励时发生错误: {e}") + task_rewards[cfg.policy.task_id] = float('inf') # 出现问题时,将奖励设为最大值 + + + print('=' * 20) + print(f'开始收集 Rank {rank} 的任务_id: {cfg.policy.task_id}...') + print(f'Rank {rank}: cfg.policy.task_id={cfg.policy.task_id} ') + + # 在每次收集之前重置初始数据,这对于多任务设置非常重要 + collector._policy.reset(reset_init_data=True, task_id=cfg.policy.task_id) + # 收集数据 + new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) + + # 更新重放缓冲区 + replay_buffer.push_game_segments(new_data) + replay_buffer.remove_oldest_data_to_fit() + + # 周期性地重新分析缓冲区 + if cfg.policy.buffer_reanalyze_freq >= 1: + reanalyze_interval = update_per_collect // cfg.policy.buffer_reanalyze_freq + else: + if train_epoch > 0 and train_epoch % int(1 / cfg.policy.buffer_reanalyze_freq) == 0 and \ + 
replay_buffer.get_num_of_transitions() // cfg.policy.num_unroll_steps > int( + reanalyze_batch_size / cfg.policy.reanalyze_partition): + with timer: + replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy) + buffer_reanalyze_count += 1 + logging.info(f'缓冲区重新分析次数: {buffer_reanalyze_count}') + logging.info(f'缓冲区重新分析耗时: {timer.value}') + + # 数据收集结束后添加日志 + logging.info(f'Rank {rank}: 完成任务 {cfg.policy.task_id} 的数据收集') + + # 检查是否有足够的数据进行训练 + not_enough_data = any( + replay_buffer.get_num_of_transitions() < cfgs[0].policy.total_batch_size / world_size + for replay_buffer in game_buffers + ) + + # 获取当前温度 + current_temperature_task_weight = temperature_scheduler.get_temperature(learner.train_iter) + # collector._policy._task_weight_temperature = current_temperature_task_weight + # policy.collect_mode.get_attribute('task_weight_temperature') = current_temperature_task_weight + + # 计算任务权重 + try: + # 汇聚任务奖励 + dist.barrier() + if task_complexity_weight: + all_task_rewards = [None for _ in range(world_size)] + dist.all_gather_object(all_task_rewards, task_rewards) + # 合并任务奖励 + merged_task_rewards = {} + for rewards in all_task_rewards: + if rewards: + merged_task_rewards.update(rewards) + # 计算全局任务权重 + task_weights = compute_task_weights(merged_task_rewards, temperature=current_temperature_task_weight) + # 同步任务权重 + dist.broadcast_object_list([task_weights], src=0) + print(f"rank{rank}, 全局任务权重 (按 task_id 排列): {task_weights}") + else: + task_weights = None + except Exception as e: + logging.error(f'Rank {rank}: 同步任务权重失败,错误: {e}') + break + + + # 学习策略 + if not not_enough_data: + for i in range(update_per_collect): + train_data_multi_task = [] + envstep_multi_task = 0 + for idx, (cfg, collector, replay_buffer) in enumerate(zip(cfgs, collectors, game_buffers)): + envstep_multi_task += collector.envstep + batch_size = cfg.policy.batch_size[cfg.policy.task_id] + if replay_buffer.get_num_of_transitions() > batch_size: + if cfg.policy.buffer_reanalyze_freq >= 1: + if i % reanalyze_interval == 0 and \ + replay_buffer.get_num_of_transitions() // cfg.policy.num_unroll_steps > int( + reanalyze_batch_size / cfg.policy.reanalyze_partition): + with timer: + replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy) + buffer_reanalyze_count += 1 + logging.info(f'缓冲区重新分析次数: {buffer_reanalyze_count}') + logging.info(f'缓冲区重新分析耗时: {timer.value}') + + train_data = replay_buffer.sample(batch_size, policy) + train_data.append(cfg.policy.task_id) # 追加task_id以区分任务 + train_data_multi_task.append(train_data) + else: + logging.warning( + f'重放缓冲区中的数据不足以采样mini-batch: ' + f'batch_size: {batch_size}, replay_buffer: {replay_buffer}' + ) + break + + if train_data_multi_task: + # learn_kwargs = {'task_exploitation_weight':task_exploitation_weight, 'task_weights':task_weights, } + learn_kwargs = {'task_weights':task_exploitation_weight} + + # 在训练时,DDP会自动同步梯度和参数 + log_vars = learner.train(train_data_multi_task, envstep_multi_task, policy_kwargs=learn_kwargs) + + # 判断是否需要计算task_exploitation_weight + if i == 0: + # 计算任务权重 + try: + dist.barrier() # 等待所有进程同步 + if use_task_exploitation_weight: + # 收集所有任务的 obs_loss + all_obs_loss = [None for _ in range(world_size)] + # 构建当前进程的任务 obs_loss 数据 + merged_obs_loss_task = {} + for cfg, replay_buffer in zip(cfgs, game_buffers): + task_id = cfg.policy.task_id + if f'noreduce_obs_loss_task{task_id}' in log_vars[0]: + merged_obs_loss_task[task_id] = log_vars[0][f'noreduce_obs_loss_task{task_id}'] + # 汇聚所有进程的 obs_loss 数据 + dist.all_gather_object(all_obs_loss, merged_obs_loss_task) + # 合并所有进程的 obs_loss 
数据 + global_obs_loss_task = {} + for obs_loss_task in all_obs_loss: + if obs_loss_task: + global_obs_loss_task.update(obs_loss_task) + # 计算全局任务权重 + if global_obs_loss_task: + task_exploitation_weight = compute_task_weights( + global_obs_loss_task, + option="rank", + # temperature=current_temperature_task_weight # TODO + temperature=1, + ) + # 广播任务权重到所有进程 + dist.broadcast_object_list([task_exploitation_weight], src=0) + print(f"rank{rank}, task_exploitation_weight (按 task_id 排列): {task_exploitation_weight}") + else: + logging.warning(f"Rank {rank}: 未能计算全局 obs_loss 任务权重,obs_loss 数据为空。") + task_exploitation_weight = None + else: + task_exploitation_weight = None + # 更新训练参数,使其包含计算后的任务权重 + learn_kwargs['task_weight'] = task_exploitation_weight + except Exception as e: + logging.error(f'Rank {rank}: 同步任务权重失败,错误: {e}') + raise e # 保留异常抛出,便于外部捕获和分析 + + + + if cfg.policy.use_priority: + for idx, (cfg, replay_buffer) in enumerate(zip(cfgs, game_buffers)): + # 更新任务特定的重放缓冲区优先级 + task_id = cfg.policy.task_id + replay_buffer.update_priority( + train_data_multi_task[idx], + log_vars[0][f'value_priority_task{task_id}'] + ) + + current_priorities = log_vars[0][f'value_priority_task{task_id}'] + mean_priority = np.mean(current_priorities) + std_priority = np.std(current_priorities) + + alpha = 0.1 # 平滑因子 + if f'running_mean_priority_task{task_id}' not in value_priority_tasks: + value_priority_tasks[f'running_mean_priority_task{task_id}'] = mean_priority + else: + value_priority_tasks[f'running_mean_priority_task{task_id}'] = ( + alpha * mean_priority + + (1 - alpha) * value_priority_tasks[f'running_mean_priority_task{task_id}'] + ) + + # 使用运行均值计算归一化的优先级 + running_mean_priority = value_priority_tasks[f'running_mean_priority_task{task_id}'] + normalized_priorities = (current_priorities - running_mean_priority) / (std_priority + 1e-6) + + # 如果需要,可以将归一化的优先级存储回重放缓冲区 + # replay_buffer.update_priority(train_data_multi_task[idx], normalized_priorities) + + # 记录优先级统计信息 + if cfg.policy.print_task_priority_logs: + print(f"任务 {task_id} - 平均优先级: {mean_priority:.8f}, " + f"运行平均优先级: {running_mean_priority:.8f}, " + f"标准差: {std_priority:.8f}") + + train_epoch += 1 + policy.recompute_pos_emb_diff_and_clear_cache() + + # 同步所有Rank,确保所有Rank完成训练 + try: + dist.barrier() + logging.info(f'Rank {rank}: 通过训练后的同步障碍') + except Exception as e: + logging.error(f'Rank {rank}: 同步障碍失败,错误: {e}') + break + + # 检查是否需要终止训练 + try: + local_envsteps = [collector.envstep for collector in collectors] + total_envsteps = [None for _ in range(world_size)] + dist.all_gather_object(total_envsteps, local_envsteps) + + all_envsteps = torch.cat([torch.tensor(envsteps, device=cfg.policy.device) for envsteps in total_envsteps]) + max_envstep_reached = torch.all(all_envsteps >= max_env_step) + + # 收集所有进程的train_iter + global_train_iter = torch.tensor([learner.train_iter], device=cfg.policy.device) + all_train_iters = [torch.zeros_like(global_train_iter) for _ in range(world_size)] + dist.all_gather(all_train_iters, global_train_iter) + + max_train_iter_reached = torch.any(torch.stack(all_train_iters) >= max_train_iter) + + if max_envstep_reached.item() or max_train_iter_reached.item(): + logging.info(f'Rank {rank}: 达到终止条件') + dist.barrier() # 确保所有进程同步 + break + except Exception as e: + logging.error(f'Rank {rank}: 终止检查失败,错误: {e}') + break + + # 调用learner的after_run钩子 + learner.call_hook('after_run') + return policy \ No newline at end of file diff --git a/lzero/entry/train_unizero_multitask_segment_ddp_bkp20250106.py 
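A short standalone illustration of the rank-based weighting used for task_exploitation_weight above: each task's obs_loss is replaced by its rank (1 = smallest) and the ranks are normalized, so tasks with larger loss receive larger weights (the loss values here are made up):

import torch

obs_loss_per_task = {0: 0.02, 1: 0.35, 2: 0.08}  # hypothetical per-task observation losses

values = torch.tensor(list(obs_loss_per_task.values()))
ranks = torch.empty_like(values)
ranks[torch.argsort(values)] = torch.arange(1, len(values) + 1, dtype=values.dtype)
weights = ranks / ranks.sum()
print(dict(zip(obs_loss_per_task.keys(), weights.tolist())))  # {0: ~0.17, 1: 0.5, 2: ~0.33}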
b/lzero/entry/train_unizero_multitask_segment_ddp_bkp20250106.py new file mode 100644 index 000000000..15c199bdc --- /dev/null +++ b/lzero/entry/train_unizero_multitask_segment_ddp_bkp20250106.py @@ -0,0 +1,492 @@ +import logging +import os +from functools import partial +from typing import Tuple, Optional, List + +import torch +import numpy as np +from ding.config import compile_config +from ding.envs import create_env_manager, get_vec_env_setting +from ding.policy import create_policy +from ding.rl_utils import get_epsilon_greedy_fn +from ding.utils import set_pkg_seed, get_rank, get_world_size +from ding.worker import BaseLearner +from tensorboardX import SummaryWriter + +from lzero.entry.utils import log_buffer_memory_usage +from lzero.policy import visit_count_temperature +from lzero.worker import MuZeroEvaluator as Evaluator +from lzero.worker import MuZeroSegmentCollector as Collector +from ding.utils import EasyTimer +import torch.distributed as dist + +import concurrent.futures + +# 设置超时时间 (秒) +TIMEOUT = 12000 # 例如200分钟 + +timer = EasyTimer() + + +def safe_eval( + evaluator: Evaluator, + learner: BaseLearner, + collector: Collector, + rank: int, + world_size: int +) -> Tuple[Optional[bool], Optional[float]]: + """ + Safely执行评估任务,避免超时。 + + Args: + evaluator (Evaluator): 评估器实例。 + learner (BaseLearner): 学习器实例。 + collector (Collector): 数据收集器实例。 + rank (int): 当前进程的rank。 + world_size (int): 总进程数。 + + Returns: + Tuple[Optional[bool], Optional[float]]: 如果评估成功,返回停止标志和奖励,否则返回(None, None)。 + """ + try: + print(f"=========评估开始 Rank {rank}/{world_size}===========") + # 重置 stop_event,确保每次评估前都处于未设置状态 + evaluator.stop_event.clear() + with concurrent.futures.ThreadPoolExecutor() as executor: + # 提交评估任务 + future = executor.submit(evaluator.eval, learner.save_checkpoint, learner.train_iter, collector.envstep) + try: + stop, reward = future.result(timeout=TIMEOUT) + except concurrent.futures.TimeoutError: + # 超时,设置 stop_event + evaluator.stop_event.set() + print(f"评估操作在 Rank {rank}/{world_size} 上超时,耗时 {TIMEOUT} 秒。") + return None, None + + print(f"======评估结束 Rank {rank}/{world_size}======") + return stop, reward + except Exception as e: + print(f"Rank {rank}/{world_size} 评估过程中发生错误: {e}") + return None, None + + +def allocate_batch_size( + cfgs: List[dict], + game_buffers, + alpha: float = 1.0, + clip_scale: int = 1 +) -> List[int]: + """ + 根据不同任务的收集剧集数反比分配batch_size,并动态调整batch_size范围以提高训练稳定性和效率。 + + Args: + cfgs (List[dict]): 每个任务的配置列表。 + game_buffers (List[GameBuffer]): 每个任务的重放缓冲区实例列表。 + alpha (float, optional): 控制反比程度的超参数。默认为1.0。 + clip_scale (int, optional): 动态调整的clip比例。默认为1。 + + Returns: + List[int]: 分配后的batch_size列表。 + """ + # 提取每个任务的 collected episodes 数量 + buffer_num_of_collected_episodes = [buffer.num_of_collected_episodes for buffer in game_buffers] + + # 获取当前的 world_size 和 rank + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + + # 收集所有 rank 的 collected episodes 列表 + all_task_num_of_collected_episodes = [None for _ in range(world_size)] + torch.distributed.all_gather_object(all_task_num_of_collected_episodes, buffer_num_of_collected_episodes) + + # 将所有 rank 的 collected episodes 合并为一个大列表 + all_task_num_of_collected_episodes = [ + episode for sublist in all_task_num_of_collected_episodes for episode in sublist + ] + if rank == 0: + print(f'所有任务的 collected episodes: {all_task_num_of_collected_episodes}') + + # 计算每个任务的反比权重 + inv_episodes = np.array([1.0 / (episodes + 1) for episodes in all_task_num_of_collected_episodes]) + inv_sum = np.sum(inv_episodes) + 
+ # 计算总的batch_size (所有任务 cfg.policy.batch_size 的和) + total_batch_size = cfgs[0].policy.total_batch_size + + # 动态调整的部分:最小和最大的 batch_size 范围 + avg_batch_size = total_batch_size / world_size + min_batch_size = avg_batch_size / clip_scale + max_batch_size = avg_batch_size * clip_scale + + # 动态调整 alpha,让 batch_size 的变化更加平滑 + task_weights = (inv_episodes / inv_sum) ** alpha + batch_sizes = total_batch_size * task_weights + + # 控制 batch_size 在 [min_batch_size, max_batch_size] 之间 + batch_sizes = np.clip(batch_sizes, min_batch_size, max_batch_size) + + # 确保 batch_size 是整数 + batch_sizes = [int(size) for size in batch_sizes] + + return batch_sizes + + +def train_unizero_multitask_segment_ddp( + input_cfg_list: List[Tuple[int, Tuple[dict, dict]]], + seed: int = 0, + model: Optional[torch.nn.Module] = None, + model_path: Optional[str] = None, + max_train_iter: Optional[int] = int(1e10), + max_env_step: Optional[int] = int(1e10), +) -> 'Policy': + """ + Overview: + UniZero的训练入口,旨在通过解决MuZero类算法在需要捕捉长期依赖环境中的局限性,提高强化学习代理的规划能力。 + 详细信息请参阅 https://arxiv.org/abs/2406.10667。 + + Args: + - input_cfg_list (:obj:`List[Tuple[int, Tuple[dict, dict]]]`): 不同任务的配置列表。 + - seed (:obj:`int`): 随机种子。 + - model (:obj:`Optional[torch.nn.Module]`): torch.nn.Module实例。 + - model_path (:obj:`Optional[str]`): 预训练模型路径,应指向预训练模型的ckpt文件。 + - max_train_iter (:obj:`Optional[int]`): 训练中的最大策略更新迭代次数。 + - max_env_step (:obj:`Optional[int]`): 最大收集环境交互步数。 + + Returns: + - policy (:obj:`Policy`): 收敛的策略。 + """ + # 获取当前进程的rank和总进程数 + rank = get_rank() + world_size = get_world_size() + + # 任务划分 + total_tasks = len(input_cfg_list) + tasks_per_rank = total_tasks // world_size + remainder = total_tasks % world_size + + if rank < remainder: + start_idx = rank * (tasks_per_rank + 1) + end_idx = start_idx + tasks_per_rank + 1 + else: + start_idx = rank * tasks_per_rank + remainder + end_idx = start_idx + tasks_per_rank + + tasks_for_this_rank = input_cfg_list[start_idx:end_idx] + + # 确保至少有一个任务 + if len(tasks_for_this_rank) == 0: + logging.warning(f"Rank {rank}: 未分配任务,继续执行。") + # 初始化空列表以避免后续代码报错 + cfgs, game_buffers, collector_envs, evaluator_envs, collectors, evaluators = [], [], [], [], [], [] + else: + print(f"Rank {rank}/{world_size}, 处理任务 {start_idx} 到 {end_idx - 1}") + + cfgs = [] + game_buffers = [] + collector_envs = [] + evaluator_envs = [] + collectors = [] + evaluators = [] + + if tasks_for_this_rank: + # 使用第一个任务的配置创建共享的policy + task_id, [cfg, create_cfg] = tasks_for_this_rank[0] + + for config in tasks_for_this_rank: + config[1][0].policy.task_num = tasks_per_rank + + # 确保指定的策略类型受支持 + assert create_cfg.policy.type in ['unizero_multitask', + 'sampled_unizero_multitask'], "train_unizero entry 目前仅支持 'unizero_multitask'" + + if create_cfg.policy.type == 'unizero_multitask': + from lzero.mcts import UniZeroGameBuffer as GameBuffer + if create_cfg.policy.type == 'sampled_unizero_multitask': + from lzero.mcts import SampledUniZeroGameBuffer as GameBuffer + + + # 根据CUDA可用性设置设备 + cfg.policy.device = cfg.policy.model.world_model_cfg.device if torch.cuda.is_available() else 'cpu' + logging.info(f'配置的设备: {cfg.policy.device}') + + # 编译配置 + cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) + # 创建共享的policy + policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) + + # 加载预训练模型(如果提供) + if model_path is not None: + logging.info(f'开始加载模型: {model_path}') + policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device)) + logging.info(f'完成加载模型: 
{model_path}') + + # 创建TensorBoard日志记录器 + log_dir = os.path.join('./{}/log'.format(cfg.exp_name), f'serial_rank_{rank}') + tb_logger = SummaryWriter(log_dir) + + # 创建共享的learner + learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) + + policy_config = cfg.policy + + # 处理当前进程分配到的每个任务 + for local_task_id, (task_id, [cfg, create_cfg]) in enumerate(tasks_for_this_rank): + # 设置每个任务的随机种子 + cfg.policy.device = 'cuda' if cfg.policy.cuda and torch.cuda.is_available() else 'cpu' + cfg = compile_config(cfg, seed=seed + task_id, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) + policy_config = cfg.policy + policy.collect_mode.get_attribute('cfg').n_episode = policy_config.n_episode + policy.eval_mode.get_attribute('cfg').n_episode = policy_config.n_episode + + # 创建环境 + env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) + collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) + evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + collector_env.seed(cfg.seed + task_id) + evaluator_env.seed(cfg.seed + task_id, dynamic_seed=False) + set_pkg_seed(cfg.seed + task_id, use_cuda=cfg.policy.cuda) + + # 创建不同的game buffer、collector和evaluator + replay_buffer = GameBuffer(policy_config) + collector = Collector( + env=collector_env, + policy=policy.collect_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config, + task_id=task_id + ) + evaluator = Evaluator( + eval_freq=cfg.policy.eval_freq, + n_evaluator_episode=cfg.env.n_evaluator_episode, + stop_value=cfg.env.stop_value, + env=evaluator_env, + policy=policy.eval_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config, + task_id=task_id + ) + + cfgs.append(cfg) + replay_buffer.batch_size = cfg.policy.batch_size[task_id] + + game_buffers.append(replay_buffer) + collector_envs.append(collector_env) + evaluator_envs.append(evaluator_env) + collectors.append(collector) + evaluators.append(evaluator) + + # 调用learner的before_run钩子 + learner.call_hook('before_run') + value_priority_tasks = {} + + buffer_reanalyze_count = 0 + train_epoch = 0 + reanalyze_batch_size = cfg.policy.reanalyze_batch_size + update_per_collect = cfg.policy.update_per_collect + + while True: + # 动态调整batch_size + if cfg.policy.allocated_batch_sizes: + clip_scale = np.clip(1 + (3 * train_epoch / 1000), 1, 4) + allocated_batch_sizes = allocate_batch_size(cfgs, game_buffers, alpha=1.0, clip_scale=clip_scale) + if rank == 0: + print("分配后的 batch_sizes: ", allocated_batch_sizes) + for idx, (cfg, collector, evaluator, replay_buffer) in enumerate( + zip(cfgs, collectors, evaluators, game_buffers)): + cfg.policy.batch_size = allocated_batch_sizes + policy._cfg.batch_size = allocated_batch_sizes + + # 对于当前进程的每个任务,进行数据收集和评估 + for idx, (cfg, collector, evaluator, replay_buffer) in enumerate( + zip(cfgs, collectors, evaluators, game_buffers)): + + # 记录缓冲区内存使用情况 + log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger, cfg.policy.task_id) + + collect_kwargs = { + 'temperature': visit_count_temperature( + policy_config.manual_temperature_decay, + policy_config.fixed_temperature_value, + policy_config.threshold_training_steps_for_final_temperature, + trained_steps=learner.train_iter + ), + 'epsilon': 0.0 # 默认的epsilon值 + } + + if policy_config.eps.eps_greedy_exploration_in_collect: + epsilon_greedy_fn = get_epsilon_greedy_fn( + start=policy_config.eps.start, + 
end=policy_config.eps.end, + decay=policy_config.eps.decay, + type_=policy_config.eps.type + ) + collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep) + + # 判断是否需要进行评估 + if learner.train_iter == 0 or evaluator.should_eval(learner.train_iter): + # if evaluator.should_eval(learner.train_iter): + print('=' * 20) + print(f'Rank {rank} 评估任务_id: {cfg.policy.task_id}...') + + # =========TODO========= + evaluator._policy.reset(reset_init_data=True, task_id=cfg.policy.task_id) + + # 执行安全评估 + stop, reward = safe_eval(evaluator, learner, collector, rank, world_size) + # 判断评估是否成功 + if stop is None or reward is None: + print(f"Rank {rank} 在评估过程中遇到问题,继续训练...") + else: + print(f"评估成功: stop={stop}, reward={reward}") + + print('=' * 20) + print(f'开始收集 Rank {rank} 的任务_id: {cfg.policy.task_id}...') + print(f'Rank {rank}: cfg.policy.task_id={cfg.policy.task_id} ') + + # 在每次收集之前重置初始数据,这对于多任务设置非常重要 + collector._policy.reset(reset_init_data=True, task_id=cfg.policy.task_id) + # 收集数据 + new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) + + # 更新重放缓冲区 + replay_buffer.push_game_segments(new_data) + replay_buffer.remove_oldest_data_to_fit() + + # 周期性地重新分析缓冲区 + if cfg.policy.buffer_reanalyze_freq >= 1: + reanalyze_interval = update_per_collect // cfg.policy.buffer_reanalyze_freq + else: + if train_epoch > 0 and train_epoch % int(1 / cfg.policy.buffer_reanalyze_freq) == 0 and \ + replay_buffer.get_num_of_transitions() // cfg.policy.num_unroll_steps > int( + reanalyze_batch_size / cfg.policy.reanalyze_partition): + with timer: + replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy) + buffer_reanalyze_count += 1 + logging.info(f'缓冲区重新分析次数: {buffer_reanalyze_count}') + logging.info(f'缓冲区重新分析耗时: {timer.value}') + + # 数据收集结束后添加日志 + logging.info(f'Rank {rank}: 完成任务 {cfg.policy.task_id} 的数据收集') + + # 检查是否有足够的数据进行训练 + not_enough_data = any( + replay_buffer.get_num_of_transitions() < cfgs[0].policy.total_batch_size / world_size + for replay_buffer in game_buffers + ) + + # 同步训练前所有rank的准备状态 + try: + dist.barrier() + logging.info(f'Rank {rank}: 通过训练前的同步障碍') + except Exception as e: + logging.error(f'Rank {rank}: 同步障碍失败,错误: {e}') + break + + # 学习策略 + if not not_enough_data: + for i in range(update_per_collect): + train_data_multi_task = [] + envstep_multi_task = 0 + for idx, (cfg, collector, replay_buffer) in enumerate(zip(cfgs, collectors, game_buffers)): + envstep_multi_task += collector.envstep + batch_size = cfg.policy.batch_size[cfg.policy.task_id] + if replay_buffer.get_num_of_transitions() > batch_size: + if cfg.policy.buffer_reanalyze_freq >= 1: + if i % reanalyze_interval == 0 and \ + replay_buffer.get_num_of_transitions() // cfg.policy.num_unroll_steps > int( + reanalyze_batch_size / cfg.policy.reanalyze_partition): + with timer: + replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy) + buffer_reanalyze_count += 1 + logging.info(f'缓冲区重新分析次数: {buffer_reanalyze_count}') + logging.info(f'缓冲区重新分析耗时: {timer.value}') + + train_data = replay_buffer.sample(batch_size, policy) + train_data.append(cfg.policy.task_id) # 追加task_id以区分任务 + train_data_multi_task.append(train_data) + else: + logging.warning( + f'重放缓冲区中的数据不足以采样mini-batch: ' + f'batch_size: {batch_size}, replay_buffer: {replay_buffer}' + ) + break + + if train_data_multi_task: + # 在训练时,DDP会自动同步梯度和参数 + log_vars = learner.train(train_data_multi_task, envstep_multi_task) + + if cfg.policy.use_priority: + for idx, (cfg, replay_buffer) in enumerate(zip(cfgs, game_buffers)): + # 更新任务特定的重放缓冲区优先级 + task_id = 
cfg.policy.task_id + replay_buffer.update_priority( + train_data_multi_task[idx], + log_vars[0][f'value_priority_task{task_id}'] + ) + + current_priorities = log_vars[0][f'value_priority_task{task_id}'] + mean_priority = np.mean(current_priorities) + std_priority = np.std(current_priorities) + + alpha = 0.1 # 平滑因子 + if f'running_mean_priority_task{task_id}' not in value_priority_tasks: + value_priority_tasks[f'running_mean_priority_task{task_id}'] = mean_priority + else: + value_priority_tasks[f'running_mean_priority_task{task_id}'] = ( + alpha * mean_priority + + (1 - alpha) * value_priority_tasks[f'running_mean_priority_task{task_id}'] + ) + + # 使用运行均值计算归一化的优先级 + running_mean_priority = value_priority_tasks[f'running_mean_priority_task{task_id}'] + normalized_priorities = (current_priorities - running_mean_priority) / (std_priority + 1e-6) + + # 如果需要,可以将归一化的优先级存储回重放缓冲区 + # replay_buffer.update_priority(train_data_multi_task[idx], normalized_priorities) + + # 记录优先级统计信息 + if cfg.policy.print_task_priority_logs: + print(f"任务 {task_id} - 平均优先级: {mean_priority:.8f}, " + f"运行平均优先级: {running_mean_priority:.8f}, " + f"标准差: {std_priority:.8f}") + + train_epoch += 1 + policy.recompute_pos_emb_diff_and_clear_cache() + + # 同步所有Rank,确保所有Rank完成训练 + try: + dist.barrier() + logging.info(f'Rank {rank}: 通过训练后的同步障碍') + except Exception as e: + logging.error(f'Rank {rank}: 同步障碍失败,错误: {e}') + break + + # 检查是否需要终止训练 + try: + local_envsteps = [collector.envstep for collector in collectors] + total_envsteps = [None for _ in range(world_size)] + dist.all_gather_object(total_envsteps, local_envsteps) + + all_envsteps = torch.cat([torch.tensor(envsteps, device=cfg.policy.device) for envsteps in total_envsteps]) + max_envstep_reached = torch.all(all_envsteps >= max_env_step) + + # 收集所有进程的train_iter + global_train_iter = torch.tensor([learner.train_iter], device=cfg.policy.device) + all_train_iters = [torch.zeros_like(global_train_iter) for _ in range(world_size)] + dist.all_gather(all_train_iters, global_train_iter) + + max_train_iter_reached = torch.any(torch.stack(all_train_iters) >= max_train_iter) + + if max_envstep_reached.item() or max_train_iter_reached.item(): + logging.info(f'Rank {rank}: 达到终止条件') + dist.barrier() # 确保所有进程同步 + break + except Exception as e: + logging.error(f'Rank {rank}: 终止检查失败,错误: {e}') + break + + # 调用learner的after_run钩子 + learner.call_hook('after_run') + return policy \ No newline at end of file diff --git a/lzero/entry/train_unizero_multitask_segment_eval.py b/lzero/entry/train_unizero_multitask_segment_eval.py new file mode 100644 index 000000000..f98e4c41b --- /dev/null +++ b/lzero/entry/train_unizero_multitask_segment_eval.py @@ -0,0 +1,480 @@ +import logging +import os +from functools import partial +from typing import Tuple, Optional, List, Dict, Any + +import torch +import numpy as np +from ding.config import compile_config +from ding.envs import create_env_manager, get_vec_env_setting +from ding.policy import create_policy +from ding.rl_utils import get_epsilon_greedy_fn +from ding.utils import set_pkg_seed, get_rank, get_world_size, EasyTimer +from ding.worker import BaseLearner +from tensorboardX import SummaryWriter + +from lzero.entry.utils import log_buffer_memory_usage +from lzero.policy import visit_count_temperature +from lzero.mcts import UniZeroGameBuffer as GameBuffer +from lzero.worker import MuZeroEvaluator as Evaluator +from lzero.worker import MuZeroSegmentCollector as Collector + +import torch.distributed as dist +import concurrent.futures + +# 设置超时时间 (秒) 
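+# NOTE: 12000 s (200 minutes) is a deliberately loose upper bound; a normal
+# evaluation finishes much earlier, and the timeout only prevents a hung
+# evaluator from blocking the rest of the DDP group (see safe_eval below).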
+TIMEOUT = 12000 # 例如200分钟 + + +def safe_eval( + evaluator: Evaluator, + learner: BaseLearner, + collector: Collector, + rank: int, + world_size: int +) -> Tuple[Optional[bool], Optional[float]]: + """ + Safely evaluates the policy using the evaluator with a timeout. + + Args: + evaluator (Evaluator): The evaluator instance. + learner (BaseLearner): The learner instance. + collector (Collector): The collector instance. + rank (int): The rank of the current process. + world_size (int): Total number of processes. + + Returns: + Tuple[Optional[bool], Optional[float]]: A tuple containing the stop flag and reward. + """ + try: + print(f"=========before eval Rank {rank}/{world_size}===========") + # 重置 stop_event,确保每次评估前都处于未设置状态 + evaluator.stop_event.clear() + with concurrent.futures.ThreadPoolExecutor() as executor: + # 提交 evaluator.eval 任务 + future = executor.submit( + evaluator.eval, + learner.save_checkpoint, + learner.train_iter, + collector.envstep + ) + + try: + stop, reward = future.result(timeout=TIMEOUT) + except concurrent.futures.TimeoutError: + # 超时,设置 evaluator 的 stop_event + evaluator.stop_event.set() + print(f"Eval operation timed out after {TIMEOUT} seconds on Rank {rank}/{world_size}.") + return None, None + + print(f"======after eval Rank {rank}/{world_size}======") + return stop, reward + except Exception as e: + print(f"An error occurred during evaluation on Rank {rank}/{world_size}: {e}") + return None, None + + +def allocate_batch_size( + cfgs: List[Any], + game_buffers: List[GameBuffer], + alpha: float = 1.0, + clip_scale: int = 1 +) -> List[int]: + """ + Allocates batch sizes inversely proportional to the number of collected episodes for each task. + Dynamically adjusts batch size within a specified range to enhance training stability and efficiency. + + Args: + cfgs (List[Any]): List of configurations for each task. + game_buffers (List[GameBuffer]): List of replay buffer instances for each task. + alpha (float): The hyperparameter controlling the degree of inverse proportionality. Default is 1.0. + clip_scale (int): The scaling factor to clip the batch size. Default is 1. + + Returns: + List[int]: A list of allocated batch sizes for each task. 
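+
+    Example:
+        Suppose three tasks have collected [10, 100, 1000] episodes and alpha=1.0.
+        The inverse weights 1/11, 1/101 and 1/1001 normalize to roughly
+        [0.89, 0.10, 0.01], so the least-trained task receives the largest share
+        of total_batch_size before clipping to [min_batch_size, max_batch_size].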
+ """ + # 提取每个任务的 num_of_collected_episodes + buffer_num_of_collected_episodes = [ + buffer.num_of_collected_episodes for buffer in game_buffers + ] + + # 获取当前的 world_size 和 rank + world_size = get_world_size() + rank = get_rank() + + # 收集所有 rank 的 num_of_collected_episodes 列表 + all_task_num_of_collected_episodes = [None for _ in range(world_size)] + dist.all_gather_object(all_task_num_of_collected_episodes, buffer_num_of_collected_episodes) + + # 将所有 rank 的 num_of_collected_episodes 拼接成一个大列表 + all_task_num_of_collected_episodes = [ + item for sublist in all_task_num_of_collected_episodes for item in sublist + ] + if rank == 0: + print(f'all_task_num_of_collected_episodes: {all_task_num_of_collected_episodes}') + + # 计算每个任务的反比权重 + inv_episodes = np.array([ + 1.0 / (episodes + 1) for episodes in all_task_num_of_collected_episodes + ]) + inv_sum = np.sum(inv_episodes) + + # 计算总的 batch_size (所有任务 cfg.policy.batch_size 的和) + total_batch_size = cfgs[0].policy.total_batch_size + + # 动态调整的部分:最小和最大的 batch_size 范围 + avg_batch_size = total_batch_size / world_size + min_batch_size = avg_batch_size / clip_scale + max_batch_size = avg_batch_size * clip_scale + + # 动态调整 alpha,让 batch_size 的变化更加平滑 + task_weights = (inv_episodes / inv_sum) ** alpha + batch_sizes = total_batch_size * task_weights + + # 控制 batch_size 在 [min_batch_size, max_batch_size] 之间 + batch_sizes = np.clip(batch_sizes, min_batch_size, max_batch_size) + + # 确保 batch_size 是整数 + batch_sizes = [int(size) for size in batch_sizes] + + # 返回最终分配的 batch_size 列表 + return batch_sizes + + +def train_unizero_multitask_segment_eval( + input_cfg_list: List[Tuple[int, Tuple[Dict[str, Any], Dict[str, Any]]]], + seed: int = 0, + model: Optional[torch.nn.Module] = None, + model_path: Optional[str] = None, + max_train_iter: Optional[int] = int(1e10), + max_env_step: Optional[int] = int(1e10), +) -> 'Policy': + """ + Overview: + The training entry point for UniZero, as proposed in the paper "UniZero: Generalized and Efficient Planning with Scalable Latent World Models". + UniZero aims to enhance the planning capabilities of reinforcement learning agents by addressing limitations found in MuZero-style algorithms, + particularly in environments requiring the capture of long-term dependencies. More details can be found in https://arxiv.org/abs/2406.10667. + + Args: + input_cfg_list (List[Tuple[int, Tuple[Dict[str, Any], Dict[str, Any]]]]): + List of configurations for different tasks. Each item is a tuple containing a task ID and a tuple of configuration dictionaries. + seed (int): + Random seed for reproducibility. + model (Optional[torch.nn.Module]): + Instance of torch.nn.Module representing the model. If None, a new model will be created. + model_path (Optional[str]): + Path to a pretrained model checkpoint. Should point to the ckpt file of the pretrained model. + max_train_iter (Optional[int]): + Maximum number of policy update iterations during training. Default is a very large number. + max_env_step (Optional[int]): + Maximum number of environment interaction steps to collect. Default is a very large number. + + Returns: + 'Policy': + The converged policy after training. 
+ """ + # 获取当前进程的 rank 和总的进程数 + rank = get_rank() + world_size = get_world_size() + + # 任务划分 + total_tasks = len(input_cfg_list) + tasks_per_rank = total_tasks // world_size + remainder = total_tasks % world_size + + if rank < remainder: + start_idx = rank * (tasks_per_rank + 1) + end_idx = start_idx + tasks_per_rank + 1 + else: + start_idx = rank * tasks_per_rank + remainder + end_idx = start_idx + tasks_per_rank + + tasks_for_this_rank = input_cfg_list[start_idx:end_idx] + + # 确保至少有一个任务 + if len(tasks_for_this_rank) == 0: + logging.warning(f"Rank {rank}: No tasks assigned, continuing without tasks.") + # 初始化一些空列表以避免后续代码报错 + cfgs, game_buffers, collectors, evaluators = [], [], [], [] + else: + print(f"Rank {rank}/{world_size}, handling tasks {start_idx} to {end_idx - 1}") + + cfgs: List[Any] = [] + game_buffers: List[GameBuffer] = [] + collectors: List[Collector] = [] + evaluators: List[Evaluator] = [] + + # 使用本rank的第一个任务的配置来创建共享的 policy + task_id, (cfg, create_cfg) = tasks_for_this_rank[0] + + # 设置每个任务的 task_num 以用于 learner_log + for config in tasks_for_this_rank: + config[1][0].policy.task_num = tasks_per_rank + + # 确保指定的 policy 类型是支持的 + assert create_cfg.policy.type in [ + 'unizero_multitask'], "train_unizero entry now only supports 'unizero_multitask'" + + # 根据 CUDA 可用性设置设备 + cfg.policy.device = cfg.policy.model.world_model_cfg.device if torch.cuda.is_available() else 'cpu' + logging.info(f'cfg.policy.device: {cfg.policy.device}') + + # 编译配置 + cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) + # 创建共享的 policy + policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) + + # 如果指定了预训练模型,则加载 + if model_path is not None: + logging.info(f'Loading model from {model_path} begin...') + policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device)) + logging.info(f'Loading model from {model_path} end!') + + # 创建 TensorBoard 的日志记录器 + log_dir = os.path.join('./{}/log'.format(cfg.exp_name), f'serial_rank_{rank}') + tb_logger = SummaryWriter(log_dir) + + # 创建共享的 learner + learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) + + policy_config = cfg.policy + batch_size = policy_config.batch_size[0] + + # 只处理当前进程分配到的任务 + for local_task_id, (task_id, (cfg, create_cfg)) in enumerate(tasks_for_this_rank): + # 设置每个任务自己的随机种子 + cfg.policy.device = 'cuda' if cfg.policy.cuda and torch.cuda.is_available() else 'cpu' + cfg = compile_config(cfg, seed=seed + task_id, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) + policy_config = cfg.policy + policy.collect_mode.get_attribute('cfg').n_episode = policy_config.n_episode + policy.eval_mode.get_attribute('cfg').n_episode = policy_config.n_episode + + env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) + collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) + evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + collector_env.seed(cfg.seed + task_id) + evaluator_env.seed(cfg.seed + task_id, dynamic_seed=False) + set_pkg_seed(cfg.seed + task_id, use_cuda=cfg.policy.cuda) + + # 为每个任务创建不同的 game buffer、collector、evaluator + replay_buffer = GameBuffer(policy_config) + collector = Collector( + env=collector_env, + policy=policy.collect_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config, + task_id=task_id + ) + evaluator = Evaluator( + 
eval_freq=cfg.policy.eval_freq, + n_evaluator_episode=cfg.env.n_evaluator_episode, + stop_value=cfg.env.stop_value, + env=evaluator_env, + policy=policy.eval_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config, + task_id=task_id + ) + + cfgs.append(cfg) + replay_buffer.batch_size = cfg.policy.batch_size[task_id] + + game_buffers.append(replay_buffer) + collectors.append(collector) + evaluators.append(evaluator) + + learner.call_hook('before_run') + buffer_reanalyze_count = 0 + train_epoch = 0 + reanalyze_batch_size = cfg.policy.reanalyze_batch_size + update_per_collect = cfg.policy.update_per_collect + + while True: + # 预先计算位置嵌入矩阵(如果需要) + # policy._collect_model.world_model.precompute_pos_emb_diff_kv() + # policy._target_model.world_model.precompute_pos_emb_diff_kv() + + if cfg.policy.allocated_batch_sizes: + # 动态调整 clip_scale 随着 train_epoch 从 0 增加到 1000, clip_scale 从 1 线性增加到 4 + clip_scale = np.clip(1 + (3 * train_epoch / 1000), 1, 4) + allocated_batch_sizes = allocate_batch_size(cfgs, game_buffers, alpha=1.0, clip_scale=clip_scale) + if rank == 0: + print("分配后的 batch_sizes: ", allocated_batch_sizes) + for cfg, _collector, _evaluator, replay_buffer in zip(cfgs, collectors, evaluators, game_buffers): + cfg.policy.batch_size = allocated_batch_sizes + policy._cfg.batch_size = allocated_batch_sizes + + # 对于当前进程的每个任务,进行数据收集和评估 + for cfg, collector, evaluator, replay_buffer in zip(cfgs, collectors, evaluators, game_buffers): + log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger, cfg.policy.task_id) + + collect_kwargs = { + 'temperature': visit_count_temperature( + policy_config.manual_temperature_decay, + policy_config.fixed_temperature_value, + policy_config.threshold_training_steps_for_final_temperature, + trained_steps=learner.train_iter + ), + 'epsilon': 0.0 # 默认的 epsilon 值 + } + + if policy_config.eps.eps_greedy_exploration_in_collect: + epsilon_greedy_fn = get_epsilon_greedy_fn( + start=policy_config.eps.start, + end=policy_config.eps.end, + decay=policy_config.eps.decay, + type_=policy_config.eps.type + ) + collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep) + + if evaluator.should_eval(learner.train_iter): + print('=' * 20) + print(f'Rank {rank} evaluates task_id: {cfg.policy.task_id}...') + + # 在训练进程中调用 safe_eval + stop, reward = safe_eval(evaluator, learner, collector, rank, world_size) + # 判断评估是否成功 + if stop is None or reward is None: + print(f"Rank {rank} encountered an issue during evaluation. 
Continuing training...") + else: + print(f"Evaluation successful: stop={stop}, reward={reward}") + + print('=' * 20) + print(f'entry: Rank {rank} collects task_id: {cfg.policy.task_id}...') + + # NOTE: 在每次收集之前重置初始数据,这对于多任务设置非常重要 + collector._policy.reset(reset_init_data=True) + # 收集数据 + new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) + + # 更新 replay buffer + replay_buffer.push_game_segments(new_data) + replay_buffer.remove_oldest_data_to_fit() + + # 周期性地重新分析缓冲区 + if cfg.policy.buffer_reanalyze_freq >= 1: + # 在一个训练 epoch 中重新分析缓冲区 次 + reanalyze_interval = update_per_collect // cfg.policy.buffer_reanalyze_freq + else: + # 每 <1/buffer_reanalyze_freq> 个训练 epoch 重新分析一次缓冲区 + if (train_epoch % int(1 / cfg.policy.buffer_reanalyze_freq) == 0 and + replay_buffer.get_num_of_transitions() // cfg.policy.num_unroll_steps > + int(reanalyze_batch_size / cfg.policy.reanalyze_partition)): + with EasyTimer() as timer: + # 每个重新分析过程将重新分析 个序列 + replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy) + buffer_reanalyze_count += 1 + logging.info(f'Buffer reanalyze count: {buffer_reanalyze_count}') + logging.info(f'Buffer reanalyze time: {timer.value}') + + # 数据收集结束后添加日志 + logging.info(f'Rank {rank}: Completed data collection for task {cfg.policy.task_id}') + + # 检查是否有足够的数据进行训练 + not_enough_data = any( + replay_buffer.get_num_of_transitions() < cfgs[0].policy.total_batch_size / world_size + for replay_buffer in game_buffers + ) + + # 同步训练前所有 rank 的准备状态 + try: + dist.barrier() + logging.info(f'Rank {rank}: Passed barrier before training') + except Exception as e: + logging.error(f'Rank {rank}: Barrier failed with error {e}') + break # 或者进行其他错误处理 + + # 学习策略 + if not not_enough_data: + # Learner 将在一次迭代中训练 update_per_collect 次 + for i in range(update_per_collect): + train_data_multi_task = [] + envstep_multi_task = 0 + for cfg, collector, replay_buffer in zip(cfgs, collectors, game_buffers): + envstep_multi_task += collector.envstep + batch_size = cfg.policy.batch_size[cfg.policy.task_id] + if replay_buffer.get_num_of_transitions() > batch_size: + if cfg.policy.buffer_reanalyze_freq >= 1: + # 在一个训练 epoch 中重新分析缓冲区 次 + if (i % reanalyze_interval == 0 and + replay_buffer.get_num_of_transitions() // cfg.policy.num_unroll_steps > + int(reanalyze_batch_size / cfg.policy.reanalyze_partition)): + with EasyTimer() as timer: + # 每个重新分析过程将重新分析 个序列 + replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy) + buffer_reanalyze_count += 1 + logging.info(f'Buffer reanalyze count: {buffer_reanalyze_count}') + logging.info(f'Buffer reanalyze time: {timer.value}') + + train_data = replay_buffer.sample(batch_size, policy) + # 追加 task_id,以便在训练时区分任务 + train_data.append(cfg.policy.task_id) + train_data_multi_task.append(train_data) + else: + logging.warning( + f'The data in replay_buffer is not sufficient to sample a mini-batch: ' + f'batch_size: {batch_size}, replay_buffer: {replay_buffer}' + ) + break + + if train_data_multi_task: + # 在训练时,DDP 会自动同步梯度和参数 + log_vars = learner.train(train_data_multi_task, envstep_multi_task) + + # 同步训练前所有 rank 的准备状态 + try: + dist.barrier() + logging.info(f'Rank {rank}: Passed barrier during training') + except Exception as e: + logging.error(f'Rank {rank}: Barrier failed with error {e}') + break # 或者进行其他错误处理 + + # TODO: 可选:终止进程 + import sys + sys.exit(0) + + train_epoch += 1 + policy.recompute_pos_emb_diff_and_clear_cache() + + # 同步所有 Rank,确保所有 Rank 都完成了训练 + try: + dist.barrier() + logging.info(f'Rank {rank}: Passed barrier after training') + except 
Exception as e: + logging.error(f'Rank {rank}: Barrier failed with error {e}') + break # 或者进行其他错误处理 + + # 检查是否需要终止训练 + try: + # 收集本地的 envsteps + local_envsteps = [collector.envstep for collector in collectors] + + # 收集所有进程的 envsteps + total_envsteps: List[Optional[int]] = [None for _ in range(world_size)] + dist.all_gather_object(total_envsteps, local_envsteps) + + # 将所有 envsteps 拼接在一起进行检查 + all_envsteps = torch.cat([ + torch.tensor(envsteps, device=cfg.policy.device) for envsteps in total_envsteps + ]) + max_envstep_reached = torch.all(all_envsteps >= max_env_step) + + # 收集所有进程的 train_iter + global_train_iter = torch.tensor([learner.train_iter], device=cfg.policy.device) + all_train_iters = [torch.zeros_like(global_train_iter) for _ in range(world_size)] + dist.all_gather(all_train_iters, global_train_iter) + + max_train_iter_reached = torch.any(torch.stack(all_train_iters) >= max_train_iter) + + if max_envstep_reached.item() or max_train_iter_reached.item(): + logging.info(f'Rank {rank}: Termination condition met') + dist.barrier() # 确保所有进程同步 + break + except Exception as e: + logging.error(f'Rank {rank}: Termination check failed with error {e}') + break # 或者进行其他错误处理 + + learner.call_hook('after_run') + return policy \ No newline at end of file diff --git a/lzero/entry/train_unizero_multitask_segment_serial.py b/lzero/entry/train_unizero_multitask_segment_serial.py new file mode 100644 index 000000000..adf9fd8f8 --- /dev/null +++ b/lzero/entry/train_unizero_multitask_segment_serial.py @@ -0,0 +1,299 @@ +import logging +import os +from functools import partial +from typing import List, Optional, Tuple + +import numpy as np +import torch +from ding.config import compile_config +from ding.envs import create_env_manager, get_vec_env_setting +from ding.policy import create_policy +from ding.rl_utils import get_epsilon_greedy_fn +from ding.utils import EasyTimer, set_pkg_seed, get_rank +from ding.worker import BaseLearner +from tensorboardX import SummaryWriter + +from lzero.entry.utils import log_buffer_memory_usage +from lzero.policy import visit_count_temperature +from lzero.worker import MuZeroEvaluator as Evaluator +from lzero.worker import MuZeroSegmentCollector as Collector + + +timer = EasyTimer() + + +def train_unizero_multitask_segment_serial( + input_cfg_list: List[Tuple[int, Tuple[dict, dict]]], + seed: int = 0, + model: Optional[torch.nn.Module] = None, + model_path: Optional[str] = None, + max_train_iter: Optional[int] = int(1e10), + max_env_step: Optional[int] = int(1e10), +) -> 'Policy': + """ + 概述: + UniZero的训练入口,基于论文《UniZero: Generalized and Efficient Planning with Scalable Latent World Models》提出。 + UniZero旨在通过解决MuZero风格算法在需要捕捉长期依赖的环境中的局限性,增强强化学习代理的规划能力。 + 详细内容可参考 https://arxiv.org/abs/2406.10667。 + + 参数: + - input_cfg_list (List[Tuple[int, Tuple[dict, dict]]]): 不同任务的配置列表。 + - seed (int): 随机种子。 + - model (Optional[torch.nn.Module]): torch.nn.Module的实例。 + - model_path (Optional[str]): 预训练模型路径,应指向预训练模型的ckpt文件。 + - max_train_iter (Optional[int]): 训练中的最大策略更新迭代次数。 + - max_env_step (Optional[int]): 收集环境交互步骤的最大数量。 + + 返回: + - policy (Policy): 收敛的策略对象。 + """ + cfgs = [] + game_buffers = [] + collector_envs = [] + evaluator_envs = [] + collectors = [] + evaluators = [] + + # 获取第一个任务的配置 + task_id, [cfg, create_cfg] = input_cfg_list[0] + + # 确保指定的策略类型受支持 + assert create_cfg.policy.type in ['unizero_multitask', 'sampled_unizero_multitask'], "train_unizero entry 目前仅支持 'unizero_multitask'" + + if create_cfg.policy.type == 'unizero_multitask': + from lzero.mcts import UniZeroGameBuffer as 
GameBuffer + if create_cfg.policy.type == 'sampled_unizero_multitask': + from lzero.mcts import SampledUniZeroGameBuffer as GameBuffer + + # 根据CUDA可用性设置设备 + cfg.policy.device = cfg.policy.model.world_model_cfg.device if torch.cuda.is_available() else 'cpu' + logging.info(f'cfg.policy.device: {cfg.policy.device}') + + # 编译配置 + cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) + # 为所有任务创建共享策略 + policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) + + # 如果指定了预训练模型路径,加载预训练模型 + if model_path is not None: + logging.info(f'开始加载模型: {model_path}') + policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device)) + logging.info(f'完成加载模型: {model_path}') + + # 为TensorBoard日志创建SummaryWriter + tb_logger = SummaryWriter(os.path.join(f'./{cfg.exp_name}/log/', 'serial')) if get_rank() == 0 else None + # 为所有任务创建共享学习器 + learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) + + policy_config = cfg.policy + batch_size = policy_config.batch_size[0] + + # 遍历所有任务的配置 + for task_id, input_cfg in input_cfg_list: + if task_id > 0: + cfg, create_cfg = input_cfg + cfg.policy.device = 'cuda' if cfg.policy.cuda and torch.cuda.is_available() else 'cpu' + cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) + policy_config = cfg.policy + # 更新收集和评估模式的配置 + policy.collect_mode.get_attribute('cfg').n_episode = policy_config.n_episode + policy.eval_mode.get_attribute('cfg').n_episode = policy_config.n_episode + + # 创建环境管理器 + env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) + collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) + evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + collector_env.seed(cfg.seed + task_id) + evaluator_env.seed(cfg.seed + task_id, dynamic_seed=False) + set_pkg_seed(cfg.seed + task_id, use_cuda=cfg.policy.cuda) + + # 创建各任务专属的游戏缓存、收集器和评估器 + replay_buffer = GameBuffer(policy_config) + collector = Collector( + env=collector_env, + policy=policy.collect_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config, + task_id=task_id + ) + evaluator = Evaluator( + eval_freq=cfg.policy.eval_freq, + n_evaluator_episode=cfg.env.n_evaluator_episode, + stop_value=cfg.env.stop_value, + env=evaluator_env, + policy=policy.eval_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config, + task_id=task_id + ) + + cfgs.append(cfg) + replay_buffer.batch_size = cfg.policy.batch_size[task_id] + game_buffers.append(replay_buffer) + collector_envs.append(collector_env) + evaluator_envs.append(evaluator_env) + collectors.append(collector) + evaluators.append(evaluator) + + learner.call_hook('before_run') + value_priority_tasks = {} + buffer_reanalyze_count = 0 + train_epoch = 0 + reanalyze_batch_size = cfg.policy.reanalyze_batch_size + update_per_collect = cfg.policy.update_per_collect + + while True: + # 预计算收集和评估时的位置嵌入矩阵(非训练阶段) + # policy._collect_model.world_model.precompute_pos_emb_diff_kv() + # policy._target_model.world_model.precompute_pos_emb_diff_kv() + + # 为每个任务收集数据 + for task_id, (cfg, collector, evaluator, replay_buffer) in enumerate(zip(cfgs, collectors, evaluators, game_buffers)): + log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger) + + collect_kwargs = { + 'temperature': visit_count_temperature( + 
policy_config.manual_temperature_decay, + policy_config.fixed_temperature_value, + policy_config.threshold_training_steps_for_final_temperature, + trained_steps=learner.train_iter + ), + 'epsilon': 0.0 # 默认epsilon值 + } + + # 如果启用了epsilon-greedy探索,计算当前epsilon值 + if policy_config.eps.eps_greedy_exploration_in_collect: + epsilon_greedy_fn = get_epsilon_greedy_fn( + start=policy_config.eps.start, + end=policy_config.eps.end, + decay=policy_config.eps.decay, + type_=policy_config.eps.type + ) + collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep) + + # 评估阶段 + if evaluator.should_eval(learner.train_iter): + print('=' * 20) + print(f'开始评估任务 id: {task_id}...') + stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) + if stop: + break + + print('=' * 20) + print(f'开始收集任务 id: {task_id}...') + + # 在每次收集前重置初始数据,对于多任务设置非常重要 + collector._policy.reset(reset_init_data=True, task_id=task_id) + new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) + + # 确定每次收集后的更新次数 + if update_per_collect is None: + # 如果未设置update_per_collect,则根据收集的转换数量和重放比例计算 + collected_transitions_num = sum( + min(len(game_segment), cfg.policy.game_segment_length) for game_segment in new_data[0] + ) + update_per_collect = int(collected_transitions_num * cfg.policy.replay_ratio) + + # 更新重放缓存 + replay_buffer.push_game_segments(new_data) + replay_buffer.remove_oldest_data_to_fit() + + # 定期重新分析重放缓存 + if cfg.policy.buffer_reanalyze_freq >= 1: + # 一个训练epoch内重新分析buffer的次数 + reanalyze_interval = update_per_collect // cfg.policy.buffer_reanalyze_freq + else: + # 每隔一定数量的训练epoch重新分析buffer + if train_epoch % int(1 / cfg.policy.buffer_reanalyze_freq) == 0 and replay_buffer.get_num_of_transitions() // cfg.policy.num_unroll_steps > int(reanalyze_batch_size / cfg.policy.reanalyze_partition): + with timer: + # 每次重新分析处理reanalyze_batch_size个序列 + replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy) + buffer_reanalyze_count += 1 + logging.info(f'重放缓存重新分析次数: {buffer_reanalyze_count}') + logging.info(f'重放缓存重新分析时间: {timer.value}') + + # 检查是否有重放缓存数据不足 + not_enough_data = any(replay_buffer.get_num_of_transitions() < batch_size for replay_buffer in game_buffers) + + # 从收集的数据中学习策略 + if not not_enough_data: + # 学习器将在一次迭代中进行update_per_collect次训练 + for i in range(update_per_collect): + train_data_multi_task = [] + envstep_multi_task = 0 + for task_id, (cfg, collector, replay_buffer) in enumerate(zip(cfgs, collectors, game_buffers)): + envstep_multi_task += collector.envstep + if replay_buffer.get_num_of_transitions() > batch_size: + batch_size = cfg.policy.batch_size[task_id] + + if cfg.policy.buffer_reanalyze_freq >= 1: + # 在一个训练epoch内按照频率重新分析buffer + if i % reanalyze_interval == 0 and replay_buffer.get_num_of_transitions() // cfg.policy.num_unroll_steps > int(reanalyze_batch_size / cfg.policy.reanalyze_partition): + with timer: + replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy) + buffer_reanalyze_count += 1 + logging.info(f'重放缓存重新分析次数: {buffer_reanalyze_count}') + logging.info(f'重放缓存重新分析时间: {timer.value}') + + train_data = replay_buffer.sample(batch_size, policy) + # 将task_id附加到训练数据 + train_data.append(task_id) + train_data_multi_task.append(train_data) + else: + logging.warning( + f'重放缓存中的数据不足以采样一个小批量: ' + f'batch_size: {batch_size}, replay_buffer: {replay_buffer}' + ) + break + + if train_data_multi_task: + log_vars = learner.train(train_data_multi_task, envstep_multi_task) + + # 如果使用优先级重放,更新各任务的优先级 + if cfg.policy.use_priority: + for task_id, replay_buffer 
in enumerate(game_buffers): + # 更新任务特定重放缓存的优先级 + replay_buffer.update_priority(train_data_multi_task[task_id], log_vars[0][f'value_priority_task{task_id}']) + + # 获取当前任务的更新后优先级 + current_priorities = log_vars[0][f'value_priority_task{task_id}'] + + # 计算优先级的均值和标准差 + mean_priority = np.mean(current_priorities) + std_priority = np.std(current_priorities) + + # 使用指数移动平均计算运行中的均值 + alpha = 0.1 # 平滑因子,可根据需要调整 + if f'running_mean_priority_task{task_id}' not in value_priority_tasks: + # 初始化运行均值 + value_priority_tasks[f'running_mean_priority_task{task_id}'] = mean_priority + else: + # 更新运行均值 + value_priority_tasks[f'running_mean_priority_task{task_id}'] = ( + alpha * mean_priority + + (1 - alpha) * value_priority_tasks[f'running_mean_priority_task{task_id}'] + ) + + # 计算归一化优先级 + running_mean_priority = value_priority_tasks[f'running_mean_priority_task{task_id}'] + normalized_priorities = (current_priorities - running_mean_priority) / (std_priority + 1e-6) + + # 记录统计信息 + if cfg.policy.print_task_priority_logs: + print( + f"任务 {task_id} - 优先级均值: {mean_priority:.8f}, " + f"运行均值优先级: {running_mean_priority:.8f}, " + f"标准差: {std_priority:.8f}" + ) + + train_epoch += 1 + policy.recompute_pos_emb_diff_and_clear_cache() + + # 检查是否达到训练或环境步数的最大限制 + if all(collector.envstep >= max_env_step for collector in collectors) or learner.train_iter >= max_train_iter: + break + + learner.call_hook('after_run') + return policy \ No newline at end of file diff --git a/lzero/entry/train_unizero_multitask_serial.py b/lzero/entry/train_unizero_multitask_serial.py new file mode 100644 index 000000000..0a5aaae25 --- /dev/null +++ b/lzero/entry/train_unizero_multitask_serial.py @@ -0,0 +1,256 @@ +import logging +import os +from functools import partial +from typing import Tuple, Optional, List + +import torch +import numpy as np +from ding.config import compile_config +from ding.envs import create_env_manager, get_vec_env_setting +from ding.policy import create_policy +from ding.rl_utils import get_epsilon_greedy_fn +from ding.utils import set_pkg_seed, get_rank +from ding.worker import BaseLearner +from tensorboardX import SummaryWriter + +from lzero.entry.utils import log_buffer_memory_usage +from lzero.policy import visit_count_temperature +from lzero.worker import MuZeroCollector as Collector, MuZeroEvaluator as Evaluator +from lzero.mcts import UniZeroGameBuffer as GameBuffer + +from line_profiler import line_profiler + +#@profile +def train_unizero_multitask_serial( + input_cfg_list: List[Tuple[int, Tuple[dict, dict]]], + seed: int = 0, + model: Optional[torch.nn.Module] = None, + model_path: Optional[str] = None, + max_train_iter: Optional[int] = int(1e10), + max_env_step: Optional[int] = int(1e10), +) -> 'Policy': + """ + Overview: + The train entry for UniZero, proposed in our paper UniZero: Generalized and Efficient Planning with Scalable Latent World Models. + UniZero aims to enhance the planning capabilities of reinforcement learning agents by addressing the limitations found in MuZero-style algorithms, + particularly in environments requiring the capture of long-term dependencies. More details can be found in https://arxiv.org/abs/2406.10667. + Arguments: + - input_cfg_list (List[Tuple[int, Tuple[dict, dict]]]): List of configurations for different tasks. + - seed (int): Random seed. + - model (Optional[torch.nn.Module]): Instance of torch.nn.Module. + - model_path (Optional[str]): The pretrained model path, which should point to the ckpt file of the pretrained model. 
+ - max_train_iter (Optional[int]): Maximum policy update iterations in training. + - max_env_step (Optional[int]): Maximum collected environment interaction steps. + Returns: + - policy (Policy): Converged policy. + """ + cfgs = [] + game_buffers = [] + collector_envs = [] + evaluator_envs = [] + collectors = [] + evaluators = [] + + task_id, [cfg, create_cfg] = input_cfg_list[0] + + # Ensure the specified policy type is supported + assert create_cfg.policy.type in ['unizero_multitask'], "train_unizero entry now only supports 'unizero'" + + # Set device based on CUDA availability + cfg.policy.device = cfg.policy.model.world_model_cfg.device if torch.cuda.is_available() else 'cpu' + logging.info(f'cfg.policy.device: {cfg.policy.device}') + + # Compile the configuration + cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) + # Create shared policy for all tasks + policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) + + # Load pretrained model if specified + if model_path is not None: + logging.info(f'Loading model from {model_path} begin...') + policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device)) + logging.info(f'Loading model from {model_path} end!') + + # Create SummaryWriter for TensorBoard logging + tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) if get_rank() == 0 else None + # Create shared learner for all tasks + learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) + + # TODO task_id = 0: + policy_config = cfg.policy + batch_size = policy_config.batch_size[0] + + for task_id, input_cfg in input_cfg_list: + if task_id > 0: + # Get the configuration for each task + cfg, create_cfg = input_cfg + cfg.policy.device = 'cuda' if cfg.policy.cuda and torch.cuda.is_available() else 'cpu' + cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) + policy_config = cfg.policy + policy.collect_mode.get_attribute('cfg').n_episode = policy_config.n_episode + policy.eval_mode.get_attribute('cfg').n_episode = policy_config.n_episode + + env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) + collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) + evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + collector_env.seed(cfg.seed + task_id) + evaluator_env.seed(cfg.seed + task_id, dynamic_seed=False) + set_pkg_seed(cfg.seed + task_id, use_cuda=cfg.policy.cuda) + + # ===== NOTE: Create different game buffer, collector, evaluator for each task ==== + # TODO: share replay buffer for all tasks + replay_buffer = GameBuffer(policy_config) + collector = Collector( + env=collector_env, + policy=policy.collect_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config, + task_id=task_id + ) + evaluator = Evaluator( + eval_freq=cfg.policy.eval_freq, + n_evaluator_episode=cfg.env.n_evaluator_episode, + stop_value=cfg.env.stop_value, + env=evaluator_env, + policy=policy.eval_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config, + task_id=task_id + ) + + cfgs.append(cfg) + replay_buffer.batch_size = cfg.policy.batch_size[task_id] + game_buffers.append(replay_buffer) + collector_envs.append(collector_env) + evaluator_envs.append(evaluator_env) + collectors.append(collector) + 
evaluators.append(evaluator) + + learner.call_hook('before_run') + value_priority_tasks = {} + update_per_collect = cfg.policy.update_per_collect + + while True: + # Precompute positional embedding matrices for collect/eval (not training) + policy._collect_model.world_model.precompute_pos_emb_diff_kv() + policy._target_model.world_model.precompute_pos_emb_diff_kv() + + # Collect data for each task + for task_id, (cfg, collector, evaluator, replay_buffer) in enumerate( + zip(cfgs, collectors, evaluators, game_buffers)): + log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger) + + collect_kwargs = { + 'temperature': visit_count_temperature( + policy_config.manual_temperature_decay, + policy_config.fixed_temperature_value, + policy_config.threshold_training_steps_for_final_temperature, + trained_steps=learner.train_iter + ), + 'epsilon': 0.0 # Default epsilon value + } + + if policy_config.eps.eps_greedy_exploration_in_collect: + epsilon_greedy_fn = get_epsilon_greedy_fn( + start=policy_config.eps.start, + end=policy_config.eps.end, + decay=policy_config.eps.decay, + type_=policy_config.eps.type + ) + collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep) + + if evaluator.should_eval(learner.train_iter): + print('=' * 20) + print(f'evaluate task_id: {task_id}...') + stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) + if stop: + break + + print('=' * 20) + print(f'collect task_id: {task_id}...') + + # Reset initial data before each collection + collector._policy.reset(reset_init_data=True) + new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) + + # Determine updates per collection + if update_per_collect is None: + collected_transitions_num = sum(len(game_segment) for game_segment in new_data[0]) + update_per_collect = int(collected_transitions_num * cfg.policy.replay_ratio) + + # Update replay buffer + replay_buffer.push_game_segments(new_data) + replay_buffer.remove_oldest_data_to_fit() + + not_enough_data = any(replay_buffer.get_num_of_transitions() < batch_size for replay_buffer in game_buffers) + + # Learn policy from collected data. + if not not_enough_data: + # Learner will train ``update_per_collect`` times in one iteration. + for i in range(update_per_collect): + train_data_multi_task = [] + envstep_multi_task = 0 + for task_id, (cfg, collector, replay_buffer) in enumerate(zip(cfgs, collectors, game_buffers)): + envstep_multi_task += collector.envstep + if replay_buffer.get_num_of_transitions() > batch_size: + batch_size = cfg.policy.batch_size[task_id] + train_data = replay_buffer.sample(batch_size, policy) + if cfg.policy.reanalyze_ratio > 0 and i % 20 == 0: + policy.recompute_pos_emb_diff_and_clear_cache() + # Append task_id to train_data + train_data.append(task_id) + train_data_multi_task.append(train_data) + else: + logging.warning( + f'The data in replay_buffer is not sufficient to sample a mini-batch: ' + f'batch_size: {batch_size}, replay_buffer: {replay_buffer}' + ) + break + + if train_data_multi_task: + log_vars = learner.train(train_data_multi_task, envstep_multi_task) + + if cfg.policy.use_priority: + for task_id, replay_buffer in enumerate(game_buffers): + # Update the priority for the task-specific replay buffer. + replay_buffer.update_priority(train_data_multi_task[task_id], log_vars[0][f'value_priority_task{task_id}']) + + # Retrieve the updated priorities for the current task. 
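+                        # The statistics below are for logging only; the normalized
+                        # priorities are not written back to the buffer unless the
+                        # commented-out update_priority call further down is enabled.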
+ current_priorities = log_vars[0][f'value_priority_task{task_id}'] + + # Calculate statistics: mean, running mean, standard deviation for the priorities. + mean_priority = np.mean(current_priorities) + std_priority = np.std(current_priorities) + + # Using exponential moving average for running mean (alpha is the smoothing factor). + alpha = 0.1 # You can adjust this smoothing factor as needed. + if f'running_mean_priority_task{task_id}' not in value_priority_tasks: + # Initialize running mean if it does not exist. + value_priority_tasks[f'running_mean_priority_task{task_id}'] = mean_priority + else: + # Update running mean. + value_priority_tasks[f'running_mean_priority_task{task_id}'] = ( + alpha * mean_priority + (1 - alpha) * value_priority_tasks[f'running_mean_priority_task{task_id}'] + ) + + # Calculate the normalized priority using the running mean. + running_mean_priority = value_priority_tasks[f'running_mean_priority_task{task_id}'] + normalized_priorities = (current_priorities - running_mean_priority) / (std_priority + 1e-6) + + # Store the normalized priorities back to the replay buffer (if needed). + # replay_buffer.update_priority(train_data_multi_task[task_id], normalized_priorities) + + # Log the statistics if the print_task_priority_logs flag is set. + if cfg.policy.print_task_priority_logs: + print(f"Task {task_id} - Mean Priority: {mean_priority:.8f}, " + f"Running Mean Priority: {running_mean_priority:.8f}, " + f"Standard Deviation: {std_priority:.8f}") + + + if all(collector.envstep >= max_env_step for collector in collectors) or learner.train_iter >= max_train_iter: + break + + learner.call_hook('after_run') + return policy \ No newline at end of file diff --git a/lzero/entry/utils.py b/lzero/entry/utils.py index d2c23f930..22bb16618 100644 --- a/lzero/entry/utils.py +++ b/lzero/entry/utils.py @@ -6,6 +6,49 @@ from tensorboardX import SummaryWriter from typing import Optional, Callable import torch +import numpy as np +import torch +import torch.nn.functional as F +import matplotlib.pyplot as plt + +class TemperatureScheduler: + def __init__(self, initial_temp: float, final_temp: float, threshold_steps: int, mode: str = 'linear'): + """ + 温度调度器,用于根据当前训练步数逐渐调整温度。 + + Args: + initial_temp (float): 初始温度值。 + final_temp (float): 最终温度值。 + threshold_steps (int): 温度衰减到最终温度所需的训练步数。 + mode (str): 衰减方式,可选 'linear' 或 'exponential'。默认 'linear'。 + """ + self.initial_temp = initial_temp + self.final_temp = final_temp + self.threshold_steps = threshold_steps + assert mode in ['linear', 'exponential'], "Mode must be 'linear' or 'exponential'." 
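+        # 'linear' interpolates evenly from initial_temp to final_temp over
+        # threshold_steps, while 'exponential' decays geometrically towards final_temp
+        # (see get_temperature below). For example, with initial_temp=10.0,
+        # final_temp=1.0 and threshold_steps=10000, linear mode returns 5.5 at step 5000.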
+ self.mode = mode + + def get_temperature(self, current_step: int) -> float: + """ + 根据当前步数计算温度。 + + Args: + current_step (int): 当前的训练步数。 + + Returns: + float: 当前温度值。 + """ + if current_step >= self.threshold_steps: + return self.final_temp + progress = current_step / self.threshold_steps + if self.mode == 'linear': + temp = self.initial_temp - (self.initial_temp - self.final_temp) * progress + elif self.mode == 'exponential': + # 指数衰减,确保温度逐渐接近 final_temp + decay_rate = np.log(self.final_temp / self.initial_temp) / self.threshold_steps + temp = self.initial_temp * np.exp(decay_rate * current_step) + temp = max(temp, self.final_temp) + return temp def initialize_zeros_batch(observation_shape, batch_size, device): @@ -63,7 +106,7 @@ def random_collect( collector.reset_policy(policy.collect_mode) -def log_buffer_memory_usage(train_iter: int, buffer: "GameBuffer", writer: SummaryWriter) -> None: +def log_buffer_memory_usage(train_iter: int, buffer: "GameBuffer", writer: SummaryWriter, task_id=0) -> None: """ Overview: Log the memory usage of the buffer and the current process to TensorBoard. @@ -74,9 +117,9 @@ def log_buffer_memory_usage(train_iter: int, buffer: "GameBuffer", writer: Summa """ # "writer is None" means we are in a slave process in the DDP setup. if writer is not None: - writer.add_scalar('Buffer/num_of_all_collected_episodes', buffer.num_of_collected_episodes, train_iter) - writer.add_scalar('Buffer/num_of_game_segments', len(buffer.game_segment_buffer), train_iter) - writer.add_scalar('Buffer/num_of_transitions', len(buffer.game_segment_game_pos_look_up), train_iter) + writer.add_scalar(f'Buffer/num_of_all_collected_episodes_{task_id}', buffer.num_of_collected_episodes, train_iter) + writer.add_scalar(f'Buffer/num_of_game_segments_{task_id}', len(buffer.game_segment_buffer), train_iter) + writer.add_scalar(f'Buffer/num_of_transitions_{task_id}', len(buffer.game_segment_game_pos_look_up), train_iter) game_segment_buffer = buffer.game_segment_buffer @@ -87,7 +130,7 @@ def log_buffer_memory_usage(train_iter: int, buffer: "GameBuffer", writer: Summa buffer_memory_usage_mb = buffer_memory_usage / (1024 * 1024) # Record the memory usage of self.game_segment_buffer to TensorBoard. - writer.add_scalar('Buffer/memory_usage/game_segment_buffer', buffer_memory_usage_mb, train_iter) + writer.add_scalar(f'Buffer/memory_usage/game_segment_buffer_{task_id}', buffer_memory_usage_mb, train_iter) # Get the amount of memory currently used by the process (in bytes). process = psutil.Process(os.getpid()) @@ -97,7 +140,7 @@ def log_buffer_memory_usage(train_iter: int, buffer: "GameBuffer", writer: Summa process_memory_usage_mb = process_memory_usage / (1024 * 1024) # Record the memory usage of the process to TensorBoard. 
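+    # NOTE: the task_id suffix added to these scalar tags keeps per-task buffer and
+    # memory curves distinguishable in TensorBoard when several tasks share one writer.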
- writer.add_scalar('Buffer/memory_usage/process', process_memory_usage_mb, train_iter) + writer.add_scalar(f'Buffer/memory_usage/process_{task_id}', process_memory_usage_mb, train_iter) def log_buffer_run_time(train_iter: int, buffer: "GameBuffer", writer: SummaryWriter) -> None: diff --git a/lzero/mcts/buffer/game_buffer.py b/lzero/mcts/buffer/game_buffer.py index 2678066e9..3383dd2ef 100644 --- a/lzero/mcts/buffer/game_buffer.py +++ b/lzero/mcts/buffer/game_buffer.py @@ -102,22 +102,23 @@ def _make_batch(self, orig_data: Any, reanalyze_ratio: float) -> Tuple[Any]: """ pass - def _sample_orig_data(self, batch_size: int) -> Tuple: + def _sample_orig_data(self, batch_size: int, print_priority_logs: bool = False) -> Tuple: """ Overview: - sample orig_data that contains: - game_segment_list: a list of game segments - pos_in_game_segment_list: transition index in game (relative index) - batch_index_list: the index of start transition of sampled minibatch in replay buffer - weights_list: the weight concerning the priority - make_time: the time the batch is made (for correctly updating replay buffer when data is deleted) + Sample original data which includes: + - game_segment_list: A list of game segments. + - pos_in_game_segment_list: Transition index in the game (relative index). + - batch_index_list: The index of the start transition of the sampled mini-batch in the replay buffer. + - weights_list: The weight concerning the priority. + - make_time: The time the batch is made (for correctly updating the replay buffer when data is deleted). Arguments: - - batch_size (:obj:`int`): batch size - - beta: float the parameter in PER for calculating the priority + - batch_size (:obj:`int`): The size of the batch. + - print_priority_logs (:obj:`bool`): Whether to print logs related to priority statistics, defaults to False. """ - assert self._beta > 0 + assert self._beta > 0, "Beta should be greater than 0" num_of_transitions = self.get_num_of_transitions() - if self._cfg.use_priority is False: + if not self._cfg.use_priority: + # If priority is not used, set all priorities to 1 self.game_pos_priorities = np.ones_like(self.game_pos_priorities) # +1e-6 for numerical stability @@ -126,20 +127,21 @@ def _sample_orig_data(self, batch_size: int) -> Tuple: # sample according to transition index batch_index_list = np.random.choice(num_of_transitions, batch_size, p=probs, replace=False) - - if self._cfg.reanalyze_outdated is True: - # NOTE: used in reanalyze part + + if self._cfg.reanalyze_outdated: + # Sort the batch indices if reanalyze is enabled batch_index_list.sort() - + + # Calculate weights for the sampled transitions weights_list = (num_of_transitions * probs[batch_index_list]) ** (-self._beta) - weights_list /= weights_list.max() + weights_list /= weights_list.max() # Normalize weights game_segment_list = [] pos_in_game_segment_list = [] for idx in batch_index_list: game_segment_idx, pos_in_game_segment = self.game_segment_game_pos_look_up[idx] - game_segment_idx -= self.base_idx + game_segment_idx -= self.base_idx # Adjust index based on base index game_segment = self.game_segment_buffer[game_segment_idx] game_segment_list.append(game_segment) @@ -151,14 +153,10 @@ def _sample_orig_data(self, batch_size: int) -> Tuple: # Indices exceeding `game_segment_length` are padded with the next segment and are not updated # in the current implementation. Therefore, we need to sample `pos_in_game_segment` within # [0, game_segment_length - num_unroll_steps] to avoid padded data. 
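+            # NOTE: the change below restricts the sampled start position to
+            # [0, game_segment_length - num_unroll_steps], so the unrolled targets
+            # never extend into the padded tail of a segment.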
- # TODO: Consider increasing `self._cfg.game_segment_length` to ensure sampling efficiency. - # if pos_in_game_segment >= self._cfg.game_segment_length - self._cfg.num_unroll_steps: - # pos_in_game_segment = np.random.choice(self._cfg.game_segment_length - self._cfg.num_unroll_steps, 1).item() - # NOTE: Sample the init position from the whole segment, but not from the padded part - if pos_in_game_segment >= self._cfg.game_segment_length: - pos_in_game_segment = np.random.choice(self._cfg.game_segment_length, 1).item() + if pos_in_game_segment >= self._cfg.game_segment_length - self._cfg.num_unroll_steps: + pos_in_game_segment = np.random.choice(self._cfg.game_segment_length - self._cfg.num_unroll_steps, 1).item() pos_in_game_segment_list.append(pos_in_game_segment) @@ -166,6 +164,12 @@ def _sample_orig_data(self, batch_size: int) -> Tuple: make_time = [time.time() for _ in range(len(batch_index_list))] orig_data = (game_segment_list, pos_in_game_segment_list, batch_index_list, weights_list, make_time) + + if print_priority_logs: + print(f"Sampled batch indices: {batch_index_list}") + print(f"Sampled priorities: {self.game_pos_priorities[batch_index_list]}") + print(f"Sampled weights: {weights_list}") + return orig_data def _sample_orig_reanalyze_batch(self, batch_size: int) -> Tuple: @@ -585,7 +589,8 @@ def remove_oldest_data_to_fit(self) -> None: Overview: remove some oldest data if the replay buffer is full. """ - assert self.replay_buffer_size > self._cfg.batch_size, "replay buffer size should be larger than batch size" + if isinstance(self._cfg.batch_size, int): + assert self.replay_buffer_size > self._cfg.batch_size, "replay buffer size should be larger than batch size" nums_of_game_segments = self.get_num_of_game_segments() total_transition = self.get_num_of_transitions() if total_transition > self.replay_buffer_size: @@ -597,8 +602,15 @@ def remove_oldest_data_to_fit(self) -> None: # find the max game_segment index to keep in the buffer index = i break - if total_transition >= self._cfg.batch_size: - self._remove(index + 1) + if isinstance(self._cfg.batch_size, int): + if total_transition >= self._cfg.batch_size: + self._remove(index + 1) + else: + try: + if total_transition >= self._cfg.batch_size[0]: + self._remove(index + 1) + except Exception as e: + print(e) def _remove(self, excess_game_segment_index: List[int]) -> None: """ diff --git a/lzero/mcts/buffer/game_buffer_muzero.py b/lzero/mcts/buffer/game_buffer_muzero.py index 2ba8180de..d062bd865 100644 --- a/lzero/mcts/buffer/game_buffer_muzero.py +++ b/lzero/mcts/buffer/game_buffer_muzero.py @@ -61,6 +61,16 @@ def __init__(self, cfg: dict): self.sample_times = 0 self.active_root_num = 0 + if hasattr(self._cfg, 'task_id'): + self.task_id = self._cfg.task_id + print(f"Task ID is set to {self.task_id}.") + self.action_space_size = self._cfg.model.action_space_size_list[self.task_id] + + else: + self.task_id = None + print("No task_id found in configuration. 
Task ID is set to None.") + self.action_space_size = self._cfg.model.action_space_size + def reset_runtime_metrics(self): """ Overview: @@ -146,7 +156,7 @@ def sample( self.compute_target_re_time += self._compute_target_timer.value batch_target_policies_non_re = self._compute_target_policy_non_reanalyzed( - policy_non_re_context, self._cfg.model.action_space_size + policy_non_re_context, self.action_space_size ) # fusion of batch_target_policies_re and batch_target_policies_non_re to batch_target_policies @@ -448,7 +458,11 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A end_index = self._cfg.mini_infer_size * (i + 1) m_obs = torch.from_numpy(value_obs_list[beg_index:end_index]).to(self._cfg.device) # calculate the target value - m_output = model.initial_inference(m_obs) + if self.task_id is not None: + m_output = model.initial_inference(m_obs, task_id=self.task_id) + else: + m_output = model.initial_inference(m_obs) + if not model.training: # if not in training, obtain the scalars of the value/reward @@ -573,7 +587,10 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: beg_index = self._cfg.mini_infer_size * i end_index = self._cfg.mini_infer_size * (i + 1) m_obs = torch.from_numpy(policy_obs_list[beg_index:end_index]).to(self._cfg.device) - m_output = model.initial_inference(m_obs) + if self.task_id is not None: + m_output = model.initial_inference(m_obs, task_id=self.task_id) + else: + m_output = model.initial_inference(m_obs) if not model.training: # if not in training, obtain the scalars of the value/reward @@ -591,7 +608,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: reward_pool = reward_pool.squeeze().tolist() policy_logits_pool = policy_logits_pool.tolist() noises = [ - np.random.dirichlet([self._cfg.root_dirichlet_alpha] * self._cfg.model.action_space_size + np.random.dirichlet([self._cfg.root_dirichlet_alpha] * self.action_space_size ).astype(np.float32).tolist() for _ in range(transition_batch_size) ] if self._cfg.mcts_ctree: @@ -603,7 +620,11 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: roots.prepare_no_noise(reward_pool, policy_logits_pool, to_play) # do MCTS for a new policy with the recent target model with self._origin_search_timer: - MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play) + if self.task_id is not None: + MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play, task_id=self.task_id) + else: + MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play) + self.origin_search_time += self._origin_search_timer.value else: # python mcts_tree @@ -613,7 +634,11 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: else: roots.prepare_no_noise(reward_pool, policy_logits_pool, to_play) # do MCTS for a new policy with the recent target model - MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play) + if self.task_id is not None: + MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play, task_id=self.task_id) + else: + MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play) + roots_legal_actions_list = legal_actions roots_distributions = roots.get_distributions() @@ -629,7 +654,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: if policy_mask[policy_index] == 0: # NOTE: the invalid padding target policy, O is to make sure the corresponding cross_entropy_loss=0 - 
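The repeated "if self.task_id is not None" branches around initial_inference and the MCTSCtree/MCTSPtree searches all implement the same dispatch: forward task_id only when the buffer was configured with one. One possible way to factor that pattern out (a sketch of an alternative, not what the PR does; the call sites in the comments are illustrative):

def task_kwargs(task_id):
    # Pass task_id downstream only when it is set, so single-task models
    # keep their original call signatures untouched.
    return {'task_id': task_id} if task_id is not None else {}

# Hypothetical usage at the call sites above:
#   m_output = model.initial_inference(m_obs, **task_kwargs(self.task_id))
#   MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play, **task_kwargs(self.task_id))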
target_policies.append([0 for _ in range(self._cfg.model.action_space_size)]) + target_policies.append([0 for _ in range(self.action_space_size)]) else: # NOTE: It is very important to use the latest MCTS visit count distribution. sum_visits = sum(distributions) @@ -638,7 +663,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: if distributions is None: # if at some obs, the legal_action is None, add the fake target_policy target_policies.append( - list(np.ones(self._cfg.model.action_space_size) / self._cfg.model.action_space_size) + list(np.ones(self.action_space_size) / self.action_space_size) ) else: # Update the data in game segment: @@ -655,7 +680,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: target_policies.append(policy) else: # for board games that have two players and legal_actions is dy - policy_tmp = [0 for _ in range(self._cfg.model.action_space_size)] + policy_tmp = [0 for _ in range(self.action_space_size)] # to make sure target_policies have the same dimension sum_visits = sum(distributions) policy = [visit_count / sum_visits for visit_count in distributions] @@ -684,7 +709,7 @@ def _compute_target_policy_non_reanalyzed( - game_segment_lens - action_mask_segment - to_play_segment - - policy_shape: self._cfg.model.action_space_size + - policy_shape: self.action_space_size Returns: - batch_target_policies_non_re """ @@ -707,7 +732,7 @@ def _compute_target_policy_non_reanalyzed( ] # NOTE: in continuous action space env: we set all legal_actions as -1 legal_actions = [ - [-1 for _ in range(self._cfg.model.action_space_size)] for _ in range(transition_batch_size) + [-1 for _ in range(self.action_space_size)] for _ in range(transition_batch_size) ] else: legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)] @@ -757,6 +782,7 @@ def update_priority(self, train_data: List[np.ndarray], batch_priorities: Any) - NOTE: train_data = [current_batch, target_batch] current_batch = [obs_list, action_list, improved_policy_list(only in Gumbel MuZero), mask_list, batch_index_list, weights, make_time_list] + target_batch = [batch_rewards, batch_target_values, batch_target_policies] """ indices = train_data[0][-3] metas = {'make_time': train_data[0][-1], 'batch_priorities': batch_priorities} diff --git a/lzero/mcts/buffer/game_buffer_sampled_unizero.py b/lzero/mcts/buffer/game_buffer_sampled_unizero.py index abb7c92a8..2f03e2a3a 100644 --- a/lzero/mcts/buffer/game_buffer_sampled_unizero.py +++ b/lzero/mcts/buffer/game_buffer_sampled_unizero.py @@ -48,9 +48,19 @@ def __init__(self, cfg: dict): self.game_segment_buffer = [] self.game_pos_priorities = [] self.game_segment_game_pos_look_up = [] - # self.task_id = self._cfg.task_id self.sample_type = self._cfg.sample_type # 'transition' or 'episode' + if hasattr(self._cfg, 'task_id'): + self.task_id = self._cfg.task_id + print(f"Task ID is set to {self.task_id}.") + self.action_space_size = self._cfg.model.action_space_size_list[self.task_id] + + else: + self.task_id = None + print("No task_id found in configuration. Task ID is set to None.") + self.action_space_size = self._cfg.model.action_space_size + + def reanalyze_buffer( self, batch_size: int, policy: Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy"] ) -> List[Any]: @@ -112,21 +122,22 @@ def _make_batch_for_reanalyze(self, batch_size: int, reanalyze_ratio: float) -> mask_tmp = [1. for i in range(len(root_sampled_actions_tmp))] mask_tmp += [0. 
for _ in range(self._cfg.num_unroll_steps + 1 - len(mask_tmp))] + # pad random action if self._cfg.model.continuous_action_space: actions_tmp += [ - np.random.randn(self._cfg.model.action_space_size) + np.random.randn(self.action_space_size) for _ in range(self._cfg.num_unroll_steps - len(actions_tmp)) ] root_sampled_actions_tmp += [ - np.random.rand(self._cfg.model.num_of_sampled_actions, self._cfg.model.action_space_size) + np.random.rand(self._cfg.model.num_of_sampled_actions, self.action_space_size) for _ in range(self._cfg.num_unroll_steps + 1 - len(root_sampled_actions_tmp)) ] else: # generate random `padded actions_tmp` actions_tmp += generate_random_actions_discrete( self._cfg.num_unroll_steps - len(actions_tmp), - self._cfg.model.action_space_size, + self.action_space_size, 1 # Number of sampled actions for actions_tmp is 1 ) @@ -135,7 +146,7 @@ def _make_batch_for_reanalyze(self, batch_size: int, reanalyze_ratio: float) -> reshape = True if self._cfg.mcts_ctree else False root_sampled_actions_tmp += generate_random_actions_discrete( self._cfg.num_unroll_steps + 1 - len(root_sampled_actions_tmp), - self._cfg.model.action_space_size, + self.action_space_size, self._cfg.model.num_of_sampled_actions, reshape=reshape ) @@ -272,18 +283,18 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: # pad random action if self._cfg.model.continuous_action_space: actions_tmp += [ - np.random.randn(self._cfg.model.action_space_size) + np.random.randn(self.action_space_size) for _ in range(self._cfg.num_unroll_steps - len(actions_tmp)) ] root_sampled_actions_tmp += [ - np.random.rand(self._cfg.model.num_of_sampled_actions, self._cfg.model.action_space_size) + np.random.rand(self._cfg.model.num_of_sampled_actions, self.action_space_size) for _ in range(self._cfg.num_unroll_steps + 1 - len(root_sampled_actions_tmp)) ] else: # generate random `padded actions_tmp` actions_tmp += generate_random_actions_discrete( self._cfg.num_unroll_steps - len(actions_tmp), - self._cfg.model.action_space_size, + self.action_space_size, 1 # Number of sampled actions for actions_tmp is 1 ) @@ -292,7 +303,7 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: reshape = True if self._cfg.mcts_ctree else False root_sampled_actions_tmp += generate_random_actions_discrete( self._cfg.num_unroll_steps + 1 - len(root_sampled_actions_tmp), - self._cfg.model.action_space_size, + self.action_space_size, self._cfg.model.num_of_sampled_actions, reshape=reshape ) @@ -316,7 +327,7 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: if self._cfg.model.continuous_action_space: # pad random action bootstrap_action_tmp += [ - np.random.randn(self._cfg.model.action_space_size) + np.random.randn(self.action_space_size) for _ in range(self._cfg.num_unroll_steps - len(bootstrap_action_tmp)) ] bootstrap_action_list.append(bootstrap_action_tmp) @@ -429,7 +440,7 @@ def _prepare_policy_reanalyzed_context( ] return policy_re_context - def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: Any, action_batch) -> np.ndarray: + def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: Any, batch_action) -> np.ndarray: """ Overview: prepare policy targets from the reanalyzed context of policies @@ -474,9 +485,15 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: # =============== NOTE: The key difference with MuZero ================= # calculate the target value - # action_batch.shape 
(32, 10)
+            # batch_action.shape (32, 10)
             # batch_obs.shape torch.Size([352, 3, 64, 64]) 32*11=352
-            m_output = model.initial_inference(batch_obs, action_batch[:self.reanalyze_num])  # NOTE: :self.reanalyze_num
+
+            if self.task_id is not None:
+                m_output = model.initial_inference(batch_obs, batch_action[:self.reanalyze_num], task_id=self.task_id)  # NOTE: :self.reanalyze_num
+            else:
+                m_output = model.initial_inference(batch_obs, batch_action[:self.reanalyze_num])  # NOTE: :self.reanalyze_num
+
             # =======================================================================

             if not model.training:
@@ -502,18 +519,24 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:
                 # cpp mcts_tree
                 # roots = MCTSCtree.roots(transition_batch_size, legal_actions)
                 roots = MCTSCtree.roots(
-                    transition_batch_size, legal_actions, self._cfg.model.action_space_size,
+                    transition_batch_size, legal_actions, self.action_space_size,
                     self._cfg.model.num_of_sampled_actions, self._cfg.model.continuous_action_space
                 )
                 roots.prepare(self._cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play)
                 # do MCTS for a new policy with the recent target model
-                MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play)
+                if self.task_id is not None:
+                    MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play, task_id=self.task_id)
+                else:
+                    MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play)
             else:
                 # python mcts_tree
                 roots = MCTSPtree.roots(transition_batch_size, legal_actions)
                 roots.prepare(self._cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play)
                 # do MCTS for a new policy with the recent target model
-                MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play)
+                if self.task_id is not None:
+                    MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play, task_id=self.task_id)
+                else:
+                    MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play)

             roots_legal_actions_list = legal_actions
             roots_distributions = roots.get_distributions()
@@ -592,8 +615,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:

         return batch_target_policies_re, root_sampled_actions

-
-    def _compute_target_reward_value(self, reward_value_context: List[Any], model: Any, action_batch) -> Tuple[
+    def _compute_target_reward_value(self, reward_value_context: List[Any], model: Any, batch_action) -> Tuple[
             Any, Any]:
         """
         Overview:
@@ -617,7 +639,7 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A
         if self._cfg.model.continuous_action_space is True:
             # when the action space of the environment is continuous, action_mask[:] is None.
action_mask = [ - list(np.ones(self._cfg.model.action_space_size, dtype=np.int8)) for _ in range(transition_batch_size) + list(np.ones(self.action_space_size, dtype=np.int8)) for _ in range(transition_batch_size) ] # NOTE: in continuous action space env: we set all legal_actions as -1 legal_actions = [ @@ -635,18 +657,24 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A # =============== NOTE: The key difference with MuZero ================= # calculate the target value # batch_obs.shape torch.Size([352, 3, 64, 64]) 32*11 = 352 - m_output = model.initial_inference(batch_obs, action_batch) + if self.task_id is not None: + m_output = model.initial_inference(batch_obs, batch_action, task_id=self.task_id) + else: + m_output = model.initial_inference(batch_obs, batch_action) # ====================================================================== - if not model.training: - # if not in training, obtain the scalars of the value/reward - [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( - [ - m_output.latent_state, - inverse_scalar_transform(m_output.value, self._cfg.model.support_scale), - m_output.policy_logits - ] - ) + # print(f'model.training:{model.training}') + # model.training = False + # if not model.training: + # if not in training, obtain the scalars of the value/reward + [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( + [ + m_output.latent_state, + inverse_scalar_transform(m_output.value, self._cfg.model.support_scale), + m_output.policy_logits + ] + ) + network_output.append(m_output) if self._cfg.use_root_value: diff --git a/lzero/mcts/buffer/game_buffer_unizero.py b/lzero/mcts/buffer/game_buffer_unizero.py index fcd47d851..3ad7c9ae7 100644 --- a/lzero/mcts/buffer/game_buffer_unizero.py +++ b/lzero/mcts/buffer/game_buffer_unizero.py @@ -11,6 +11,7 @@ if TYPE_CHECKING: from lzero.policy import MuZeroPolicy, EfficientZeroPolicy, SampledEfficientZeroPolicy +from line_profiler import line_profiler @BUFFER_REGISTRY.register('game_buffer_unizero') @@ -48,6 +49,17 @@ def __init__(self, cfg: dict): self.game_segment_game_pos_look_up = [] self.sample_type = self._cfg.sample_type # 'transition' or 'episode' + if hasattr(self._cfg, 'task_id'): + self.task_id = self._cfg.task_id + print(f"Task ID is set to {self.task_id}.") + self.action_space_size = self._cfg.model.action_space_size_list[self.task_id] + + else: + self.task_id = None + print("No task_id found in configuration. Task ID is set to None.") + self.action_space_size = self._cfg.model.action_space_size + + #@profile def sample( self, batch_size: int, policy: Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy"] ) -> List[Any]: @@ -77,7 +89,7 @@ def sample( # target policy batch_target_policies_re = self._compute_target_policy_reanalyzed(policy_re_context, policy._target_model, current_batch[1]) # current_batch[1] is batch_action batch_target_policies_non_re = self._compute_target_policy_non_reanalyzed( - policy_non_re_context, self._cfg.model.action_space_size + policy_non_re_context, self.action_space_size ) # fusion of batch_target_policies_re and batch_target_policies_non_re to batch_target_policies @@ -94,6 +106,7 @@ def sample( train_data = [current_batch, target_batch] return train_data + #@profile def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: """ Overview: @@ -136,6 +149,10 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: mask_tmp = [1. 
for i in range(min(len(actions_tmp), self._cfg.game_segment_length - pos_in_game_segment))] mask_tmp += [0. for _ in range(self._cfg.num_unroll_steps + 1 - len(mask_tmp))] + # TODO: original buffer mask + # mask_tmp = [1. for i in range(min(len(actions_tmp), self._cfg.game_segment_length - pos_in_game_segment))] + # mask_tmp += [0. for _ in range(self._cfg.num_unroll_steps + 1 - len(mask_tmp))] + # pad random action actions_tmp += [ np.random.randint(0, game.action_space_size) @@ -391,11 +408,11 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: if self._cfg.model.continuous_action_space is True: # when the action space of the environment is continuous, action_mask[:] is None. action_mask = [ - list(np.ones(self._cfg.model.action_space_size, dtype=np.int8)) for _ in range(transition_batch_size) + list(np.ones(self.action_space_size, dtype=np.int8)) for _ in range(transition_batch_size) ] # NOTE: in continuous action space env: we set all legal_actions as -1 legal_actions = [ - [-1 for _ in range(self._cfg.model.action_space_size)] for _ in range(transition_batch_size) + [-1 for _ in range(self.action_space_size)] for _ in range(transition_batch_size) ] else: legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)] @@ -411,7 +428,12 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: # =============== NOTE: The key difference with MuZero ================= # To obtain the target policy from MCTS guided by the recent target model # TODO: batch_obs (policy_obs_list) is at timestep t, batch_action is at timestep t - m_output = model.initial_inference(batch_obs, batch_action[:self.reanalyze_num]) # NOTE: :self.reanalyze_num + + if self.task_id is not None: + m_output = model.initial_inference(batch_obs, batch_action[:self.reanalyze_num], task_id=self.task_id) # NOTE: :self.reanalyze_num + else: + m_output = model.initial_inference(batch_obs, batch_action[:self.reanalyze_num]) # NOTE: :self.reanalyze_num + # ======================================================================= if not model.training: @@ -430,7 +452,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: reward_pool = reward_pool.squeeze().tolist() policy_logits_pool = policy_logits_pool.tolist() noises = [ - np.random.dirichlet([self._cfg.root_dirichlet_alpha] * self._cfg.model.action_space_size + np.random.dirichlet([self._cfg.root_dirichlet_alpha] * self.action_space_size ).astype(np.float32).tolist() for _ in range(transition_batch_size) ] if self._cfg.mcts_ctree: @@ -438,13 +460,19 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: roots = MCTSCtree.roots(transition_batch_size, legal_actions) roots.prepare(self._cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play) # do MCTS for a new policy with the recent target model - MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play) + if self.task_id is not None: + MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play, task_id=self.task_id) + else: + MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play) else: # python mcts_tree roots = MCTSPtree.roots(transition_batch_size, legal_actions) roots.prepare(self._cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play) # do MCTS for a new policy with the recent target model - MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play) + if self.task_id is not None: + 
MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play, task_id=self.task_id) + else: + MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play) roots_legal_actions_list = legal_actions roots_distributions = roots.get_distributions() @@ -455,7 +483,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: distributions = roots_distributions[policy_index] if policy_mask[policy_index] == 0: # NOTE: the invalid padding target policy, O is to make sure the corresponding cross_entropy_loss=0 - target_policies.append([0 for _ in range(self._cfg.model.action_space_size)]) + target_policies.append([0 for _ in range(self.action_space_size)]) else: # NOTE: It is very important to use the latest MCTS visit count distribution. sum_visits = sum(distributions) @@ -464,7 +492,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: if distributions is None: # if at some obs, the legal_action is None, add the fake target_policy target_policies.append( - list(np.ones(self._cfg.model.action_space_size) / self._cfg.model.action_space_size) + list(np.ones(self.action_space_size) / self.action_space_size) ) else: if self._cfg.env_type == 'not_board_games': @@ -474,7 +502,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: target_policies.append(policy) else: # for board games that have two players and legal_actions is dy - policy_tmp = [0 for _ in range(self._cfg.model.action_space_size)] + policy_tmp = [0 for _ in range(self.action_space_size)] # to make sure target_policies have the same dimension sum_visits = sum(distributions) policy = [visit_count / sum_visits for visit_count in distributions] @@ -519,7 +547,11 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A # =============== NOTE: The key difference with MuZero ================= # calculate the bootstrapped value and target value # NOTE: batch_obs(value_obs_list) is at t+td_steps, batch_action is at timestep t+td_steps - m_output = model.initial_inference(batch_obs, batch_action) + if self.task_id is not None: + m_output = model.initial_inference(batch_obs, batch_action, task_id=self.task_id) + else: + m_output = model.initial_inference(batch_obs, batch_action) + # ====================================================================== if not model.training: diff --git a/lzero/mcts/buffer/game_segment.py b/lzero/mcts/buffer/game_segment.py index a509f1360..004e664ed 100644 --- a/lzero/mcts/buffer/game_segment.py +++ b/lzero/mcts/buffer/game_segment.py @@ -31,7 +31,7 @@ class GameSegment: - store_search_stats """ - def __init__(self, action_space: int, game_segment_length: int = 200, config: EasyDict = None) -> None: + def __init__(self, action_space: int, game_segment_length: int = 200, config: EasyDict = None, task_id = None) -> None: """ Overview: Init the ``GameSegment`` according to the provided arguments. 
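The reanalyze path above regenerates root exploration noise using the per-task action space size. For reference, a small sketch of how Dirichlet noise of that dimensionality is typically mixed into the root prior (roots.prepare performs this mixing internally; the alpha and noise-weight values below are illustrative):

import numpy as np

def noisy_root_prior(policy_logits, action_space_size, root_dirichlet_alpha=0.3, root_noise_weight=0.25):
    # Softmax the raw logits into a prior over this task's own action space.
    prior = np.exp(policy_logits - policy_logits.max())
    prior /= prior.sum()
    # One Dirichlet draw with the same dimensionality as the task's action space.
    noise = np.random.dirichlet([root_dirichlet_alpha] * action_space_size)
    return (1 - root_noise_weight) * prior + root_noise_weight * noise

print(noisy_root_prior(np.zeros(6), action_space_size=6))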
@@ -45,19 +45,27 @@ def __init__(self, action_space: int, game_segment_length: int = 200, config: Ea self.td_steps = config.td_steps self.frame_stack_num = config.model.frame_stack_num self.discount_factor = config.discount_factor - self.action_space_size = config.model.action_space_size + # self.action_space_size = config.model.action_space_size self.gray_scale = config.gray_scale self.transform2string = config.transform2string self.sampled_algo = config.sampled_algo self.gumbel_algo = config.gumbel_algo self.use_ture_chance_label_in_chance_encoder = config.use_ture_chance_label_in_chance_encoder - if isinstance(config.model.observation_shape, int) or len(config.model.observation_shape) == 1: - # for vector obs input, e.g. classical control and box2d environments - self.zero_obs_shape = config.model.observation_shape - elif len(config.model.observation_shape) == 3: - # image obs input, e.g. atari environments - self.zero_obs_shape = (config.model.image_channel, config.model.observation_shape[-2], config.model.observation_shape[-1]) + if task_id is None: + if isinstance(config.model.observation_shape, int) or len(config.model.observation_shape) == 1: + # for vector obs input, e.g. classical control and box2d environments + self.zero_obs_shape = config.model.observation_shape + elif len(config.model.observation_shape) == 3: + # image obs input, e.g. atari environments + self.zero_obs_shape = (config.model.image_channel, config.model.observation_shape[-2], config.model.observation_shape[-1]) + else: + if isinstance(config.model.observation_shape_list[task_id], int) or len(config.model.observation_shape_list[task_id]) == 1: + # for vector obs input, e.g. classical control and box2d environments + self.zero_obs_shape = config.model.observation_shape_list[task_id] + elif len(config.model.observation_shape_list[task_id]) == 3: + # image obs input, e.g. atari environments + self.zero_obs_shape = (config.model.image_channel, config.model.observation_shape_list[task_id][-2], config.model.observation_shape_list[task_id][-1]) self.obs_segment = [] self.action_segment = [] diff --git a/lzero/mcts/ctree/ctree_sampled_muzero/lib/cnode.cpp b/lzero/mcts/ctree/ctree_sampled_muzero/lib/cnode.cpp index 7c5d11dd2..83f50e2da 100644 --- a/lzero/mcts/ctree/ctree_sampled_muzero/lib/cnode.cpp +++ b/lzero/mcts/ctree/ctree_sampled_muzero/lib/cnode.cpp @@ -22,6 +22,7 @@ #include #include + #ifdef _WIN32 #include "..\..\common_lib\utils.cpp" #else diff --git a/lzero/mcts/tree_search/mcts_ctree.py b/lzero/mcts/tree_search/mcts_ctree.py index dd58e8682..50d4b0927 100644 --- a/lzero/mcts/tree_search/mcts_ctree.py +++ b/lzero/mcts/tree_search/mcts_ctree.py @@ -15,6 +15,7 @@ from lzero.mcts.ctree.ctree_muzero import mz_tree as mz_ctree from lzero.mcts.ctree.ctree_gumbel_muzero import gmz_tree as gmz_ctree +from line_profiler import line_profiler class UniZeroMCTSCtree(object): """ @@ -71,10 +72,10 @@ def roots(cls: int, active_collect_env_num: int, legal_actions: List[Any]) -> "m from lzero.mcts.ctree.ctree_muzero import mz_tree as ctree return ctree.Roots(active_collect_env_num, legal_actions) - # @profile + #@profile def search( self, roots: Any, model: torch.nn.Module, latent_state_roots: List[Any], to_play_batch: Union[int, - List[Any]] + List[Any]], task_id=None ) -> None: """ Overview: @@ -144,7 +145,12 @@ def search( At the end of the simulation, the statistics along the trajectory are updated. 
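The GameSegment change above selects the zero-padding observation shape per task from observation_shape_list. A condensed sketch of that selection logic (image_channel and the example shapes are placeholders):

def zero_obs_shape_for(observation_shape, image_channel=3):
    # Vector observations (classic control, box2d): keep the flat length.
    if isinstance(observation_shape, int) or len(observation_shape) == 1:
        return observation_shape
    # Image observations (e.g. Atari): channel-first (C, H, W).
    if len(observation_shape) == 3:
        return (image_channel, observation_shape[-2], observation_shape[-1])
    raise ValueError(f"Unsupported observation shape: {observation_shape}")

observation_shape_list = [4, (3, 64, 64)]  # e.g. one vector-obs task and one image-obs task
print([zero_obs_shape_for(shape) for shape in observation_shape_list])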
""" # for UniZero - network_output = model.recurrent_inference(state_action_history, simulation_index, latent_state_index_in_search_path) + if task_id is not None: + # multi task setting + network_output = model.recurrent_inference(state_action_history, simulation_index, latent_state_index_in_search_path, task_id=task_id) + else: + # single task setting + network_output = model.recurrent_inference(state_action_history, simulation_index, latent_state_index_in_search_path) network_output.latent_state = to_detach_cpu_numpy(network_output.latent_state) network_output.policy_logits = to_detach_cpu_numpy(network_output.policy_logits) @@ -225,10 +231,10 @@ def roots(cls: int, active_collect_env_num: int, legal_actions: List[Any]) -> "m from lzero.mcts.ctree.ctree_muzero import mz_tree as ctree return ctree.Roots(active_collect_env_num, legal_actions) - # @profile + # #@profile def search( self, roots: Any, model: torch.nn.Module, latent_state_roots: List[Any], to_play_batch: Union[int, - List[Any]] + List[Any]], task_id=None ) -> None: """ Overview: @@ -298,6 +304,13 @@ def search( """ network_output = model.recurrent_inference(latent_states, last_actions) # for classic muzero + if task_id is not None: + # multi task setting + network_output = model.recurrent_inference(latent_states, last_actions, task_id=task_id) + else: + # single task setting + network_output = model.recurrent_inference(latent_states, last_actions) + network_output.latent_state = to_detach_cpu_numpy(network_output.latent_state) network_output.policy_logits = to_detach_cpu_numpy(network_output.policy_logits) network_output.value = to_detach_cpu_numpy(self.inverse_scalar_transform_handle(network_output.value)) @@ -495,7 +508,7 @@ def roots(cls: int, active_collect_env_num: int, legal_actions: List[Any]) -> "e """ return tree_muzero.Roots(active_collect_env_num, legal_actions) - # @profile + # #@profile def search( self, roots: Any, model: torch.nn.Module, latent_state_roots: List[Any], world_model_latent_history_roots: List[Any], to_play_batch: Union[int, List[Any]], ready_env_id=None, diff --git a/lzero/mcts/tree_search/mcts_ctree_sampled.py b/lzero/mcts/tree_search/mcts_ctree_sampled.py index 5f6f74740..2abd75b83 100644 --- a/lzero/mcts/tree_search/mcts_ctree_sampled.py +++ b/lzero/mcts/tree_search/mcts_ctree_sampled.py @@ -82,7 +82,7 @@ def roots( # @profile def search( self, roots: Any, model: torch.nn.Module, latent_state_roots: List[Any], to_play_batch: Union[int, - List[Any]] + List[Any]], task_id=None ) -> None: """ Overview: @@ -156,8 +156,12 @@ def search( At the end of the simulation, the statistics along the trajectory are updated. 
""" # for Sampled UniZero - network_output = model.recurrent_inference(state_action_history, simulation_index, - latent_state_index_in_search_path) + if task_id is not None: + # multi task setting + network_output = model.recurrent_inference(state_action_history, simulation_index, latent_state_index_in_search_path, task_id=task_id) + else: + # single task setting + network_output = model.recurrent_inference(state_action_history, simulation_index, latent_state_index_in_search_path) network_output.latent_state = to_detach_cpu_numpy(network_output.latent_state) network_output.policy_logits = to_detach_cpu_numpy(network_output.policy_logits) diff --git a/lzero/model/common.py b/lzero/model/common.py index 22afa95fe..7d8005fc1 100644 --- a/lzero/model/common.py +++ b/lzero/model/common.py @@ -185,6 +185,8 @@ def __init__(self, observation_shape: SequenceType, out_channels: int, super().__init__() assert norm_type in ['BN', 'LN'], "norm_type must in ['BN', 'LN']" + assert num_resblocks == 1, "num_resblocks must be 1 in DownSample" + self.observation_shape = observation_shape self.conv1 = nn.Conv2d( observation_shape[0], @@ -231,7 +233,7 @@ def __init__(self, observation_shape: SequenceType, out_channels: int, [ ResBlock( in_channels=out_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False - ) for _ in range(1) + ) for _ in range(num_resblocks) ] ) self.pooling2 = nn.AvgPool2d(kernel_size=3, stride=2, padding=1) @@ -318,6 +320,7 @@ def __init__( num_channels, activation=activation, norm_type=norm_type, + num_resblocks=1, ) else: self.conv = nn.Conv2d(observation_shape[0], num_channels, kernel_size=3, stride=1, padding=1, bias=False) @@ -343,10 +346,10 @@ def __init__( self.embedding_dim = embedding_dim if self.observation_shape[1] == 64: - self.last_linear = nn.Linear(64 * 8 * 8, self.embedding_dim, bias=False) + self.last_linear = nn.Linear(num_channels * 8 * 8, self.embedding_dim, bias=False) elif self.observation_shape[1] in [84, 96]: - self.last_linear = nn.Linear(64 * 6 * 6, self.embedding_dim, bias=False) + self.last_linear = nn.Linear(num_channels * 6 * 6, self.embedding_dim, bias=False) self.sim_norm = SimNorm(simnorm_dim=group_size) @@ -817,9 +820,9 @@ def __init__( self.conv1x1_policy = nn.Conv2d(num_channels, policy_head_channels, 1) if observation_shape[1] == 96: - latent_shape = (observation_shape[1] / 16, observation_shape[2] / 16) + latent_shape = (observation_shape[1] // 16, observation_shape[2] // 16) elif observation_shape[1] == 64: - latent_shape = (observation_shape[1] / 8, observation_shape[2] / 8) + latent_shape = (observation_shape[1] // 8, observation_shape[2] // 8) if norm_type == 'BN': self.norm_value = nn.BatchNorm2d(value_head_channels) diff --git a/lzero/model/muzero_model_multitask.py b/lzero/model/muzero_model_multitask.py new file mode 100644 index 000000000..6d7326152 --- /dev/null +++ b/lzero/model/muzero_model_multitask.py @@ -0,0 +1,389 @@ +from typing import Optional, Tuple + +import math +import torch +import torch.nn as nn +from ding.torch_utils import MLP, ResBlock +from ding.utils import MODEL_REGISTRY, SequenceType +from numpy import ndarray + +from .common import MZNetworkOutput, RepresentationNetwork, PredictionNetwork, FeatureAndGradientHook +from .utils import renormalize, get_params_mean, get_dynamic_mean, get_reward_mean + + +@MODEL_REGISTRY.register('MuZeroMTModel') +class MuZeroMTModel(nn.Module): + + def __init__( + self, + observation_shape: SequenceType = (12, 96, 96), + action_space_size: int = 6, + 
num_res_blocks: int = 1, + num_channels: int = 64, + reward_head_channels: int = 16, + value_head_channels: int = 16, + policy_head_channels: int = 16, + fc_reward_layers: SequenceType = [32], + fc_value_layers: SequenceType = [32], + fc_policy_layers: SequenceType = [32], + reward_support_size: int = 601, + value_support_size: int = 601, + proj_hid: int = 1024, + proj_out: int = 1024, + pred_hid: int = 512, + pred_out: int = 1024, + self_supervised_learning_loss: bool = False, + categorical_distribution: bool = True, + activation: nn.Module = nn.ReLU(inplace=True), + last_linear_layer_init_zero: bool = True, + state_norm: bool = False, + downsample: bool = False, + norm_type: Optional[str] = 'BN', + discrete_action_encoding_type: str = 'one_hot', + analysis_sim_norm: bool = False, + task_num: int = 1, # 任务数量 + *args, + **kwargs + ): + """ + 多任务MuZero模型的定义,继承自MuZeroModel。 + 增加了多任务相关的处理,如任务数量和动作空间大小调整。 + """ + super(MuZeroMTModel, self).__init__() + + print(f'==========MuZeroMTModel, num_res_blocks:{num_res_blocks}, num_channels:{num_channels}, task_num:{task_num}===========') + + if discrete_action_encoding_type == 'one_hot': + self.action_encoding_dim = action_space_size + elif discrete_action_encoding_type == 'not_one_hot': + self.action_encoding_dim = 1 + + assert discrete_action_encoding_type in ['one_hot', 'not_one_hot'], discrete_action_encoding_type + + if isinstance(observation_shape, int) or len(observation_shape) == 1: + # for vector obs input, e.g. classical control and box2d environments + # to be compatible with LightZero model/policy, transform to shape: [C, W, H] + observation_shape = [1, observation_shape, 1] + + self.categorical_distribution = categorical_distribution + if self.categorical_distribution: + self.reward_support_size = reward_support_size + self.value_support_size = value_support_size + else: + self.reward_support_size = 1 + self.value_support_size = 1 + + self.task_num = task_num + self.action_space_size = 18 # 假设每个任务的动作空间相同 + + self.categorical_distribution = categorical_distribution + + self.discrete_action_encoding_type = 'one_hot' + + # 共享表示网络 + self.representation_network = RepresentationNetwork( + observation_shape, + num_res_blocks, + num_channels, + downsample, + activation=activation, + norm_type=norm_type + ) + + # ====== for analysis ====== + if analysis_sim_norm: + self.encoder_hook = FeatureAndGradientHook() + self.encoder_hook.setup_hooks(self.representation_network) + + # 共享动态网络 + self.dynamics_network = DynamicsNetwork( + observation_shape, + action_encoding_dim=self.action_encoding_dim, + num_res_blocks=num_res_blocks, + num_channels=num_channels + self.action_encoding_dim, + reward_head_channels=reward_head_channels, + fc_reward_layers=fc_reward_layers, + output_support_size=reward_support_size, + flatten_output_size_for_reward_head=reward_head_channels * self._get_latent_size(observation_shape, downsample), + downsample=downsample, + last_linear_layer_init_zero=last_linear_layer_init_zero, + activation=activation, + norm_type=norm_type + ) + + # 独立的预测网络,每个任务一个 + # 计算flatten_output_size + value_flatten_size = int(value_head_channels * self._get_latent_size(observation_shape, downsample)) + policy_flatten_size = int(policy_head_channels * self._get_latent_size(observation_shape, downsample)) + + self.prediction_networks = nn.ModuleList([ + PredictionNetwork( + observation_shape, + action_space_size, + num_res_blocks, + num_channels, + value_head_channels, + policy_head_channels, + fc_value_layers, + fc_policy_layers, + 
self.value_support_size, + flatten_output_size_for_value_head=value_flatten_size, + flatten_output_size_for_policy_head=policy_flatten_size, + downsample=downsample, + last_linear_layer_init_zero=last_linear_layer_init_zero, + activation=activation, + norm_type=norm_type + ) for _ in range(task_num) + ]) + + # 共享投影和预测头(如果使用自监督学习损失) + if self_supervised_learning_loss: + self.projection_network = nn.Sequential( + nn.Linear(num_channels * self._get_latent_size(observation_shape, downsample), proj_hid), + nn.BatchNorm1d(proj_hid), + activation, + nn.Linear(proj_hid, proj_hid), + nn.BatchNorm1d(proj_hid), + activation, + nn.Linear(proj_hid, proj_out), + nn.BatchNorm1d(proj_out) + ) + + self.prediction_head = nn.Sequential( + nn.Linear(proj_out, pred_hid), + nn.BatchNorm1d(pred_hid), + activation, + nn.Linear(pred_hid, pred_out), + ) + + self.self_supervised_learning_loss = self_supervised_learning_loss + self.state_norm = state_norm + self.downsample = downsample + + def _get_latent_size(self, observation_shape: SequenceType, downsample: bool) -> int: + """ + 辅助函数,根据观测形状和下采样选项计算潜在状态的大小。 + """ + if downsample: + return math.ceil(observation_shape[-2] / 16) * math.ceil(observation_shape[-1] / 16) + else: + return observation_shape[-2] * observation_shape[-1] + + def initial_inference(self, obs: torch.Tensor, task_id: int = 0) -> MZNetworkOutput: + """ + 多任务初始推理,基于任务ID选择对应的预测网络。 + """ + batch_size = obs.size(0) + latent_state = self.representation_network(obs) + if self.state_norm: + latent_state = renormalize(latent_state) + prediction_net = self.prediction_networks[task_id] + policy_logits, value = prediction_net(latent_state) + + return MZNetworkOutput( + value, + [0. for _ in range(batch_size)], + policy_logits, + latent_state, + ) + + def recurrent_inference(self, latent_state: torch.Tensor, action: torch.Tensor, task_id: int = 0) -> MZNetworkOutput: + """ + 多任务递归推理,根据任务ID选择对应的预测网络。 + """ + next_latent_state, reward = self._dynamics(latent_state, action) + + if self.state_norm: + next_latent_state = renormalize(next_latent_state) + prediction_net = self.prediction_networks[task_id] + policy_logits, value = prediction_net(next_latent_state) + + return MZNetworkOutput(value, reward, policy_logits, next_latent_state) + + + def _dynamics(self, latent_state: torch.Tensor, action: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Overview: + Concatenate ``latent_state`` and ``action`` and use the dynamics network to predict ``next_latent_state`` + and ``reward``. + Arguments: + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + - action (:obj:`torch.Tensor`): The predicted action to rollout. + Returns: + - next_latent_state (:obj:`torch.Tensor`): The predicted latent state of the next timestep. + - reward (:obj:`torch.Tensor`): The predicted reward of the current latent state and selected action. + Shapes: + - latent_state (:obj:`torch.Tensor`): :math:`(B, H_, W_)`, where B is batch_size, H_ is the height of \ + latent state, W_ is the width of latent state. + - action (:obj:`torch.Tensor`): :math:`(B, )`, where B is batch_size. + - next_latent_state (:obj:`torch.Tensor`): :math:`(B, H_, W_)`, where B is batch_size, H_ is the height of \ + latent state, W_ is the width of latent state. + - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size. 
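MuZeroMTModel above shares the representation and dynamics networks across tasks and keeps one prediction network per task in an nn.ModuleList indexed by task_id. A stripped-down sketch of that shared-trunk, per-task-head pattern (the layer sizes are arbitrary placeholders, not the real head architecture):

import torch
import torch.nn as nn

class SharedTrunkMultiTaskNet(nn.Module):
    def __init__(self, obs_dim=64, latent_dim=32, action_dims=(4, 6, 18)):
        super().__init__()
        # Shared encoder, playing the role of the representation network.
        self.encoder = nn.Sequential(nn.Linear(obs_dim, latent_dim), nn.ReLU())
        # One policy/value head per task, selected by task_id at call time.
        self.heads = nn.ModuleList(nn.Linear(latent_dim, a + 1) for a in action_dims)

    def forward(self, obs, task_id):
        latent = self.encoder(obs)
        out = self.heads[task_id](latent)
        policy_logits, value = out[:, :-1], out[:, -1:]
        return policy_logits, value

net = SharedTrunkMultiTaskNet()
policy_logits, value = net(torch.randn(2, 64), task_id=1)
print(policy_logits.shape, value.shape)  # torch.Size([2, 6]) torch.Size([2, 1])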
+ """ + # NOTE: the discrete action encoding type is important for some environments + + # discrete action space + if self.discrete_action_encoding_type == 'one_hot': + # Stack latent_state with the one hot encoded action. + # The final action_encoding shape is (batch_size, action_space_size, latent_state[2], latent_state[3]), e.g. (8, 2, 4, 1). + if len(action.shape) == 1: + # (batch_size, ) -> (batch_size, 1) + # e.g., torch.Size([8]) -> torch.Size([8, 1]) + action = action.unsqueeze(-1) + + # transform action to one-hot encoding. + # action_one_hot shape: (batch_size, action_space_size), e.g., (8, 4) + action_one_hot = torch.zeros(action.shape[0], self.action_space_size, device=action.device) + # transform action to torch.int64 + action = action.long() + action_one_hot.scatter_(1, action, 1) + + action_encoding_tmp = action_one_hot.unsqueeze(-1).unsqueeze(-1) + action_encoding = action_encoding_tmp.expand( + latent_state.shape[0], self.action_space_size, latent_state.shape[2], latent_state.shape[3] + ) + + elif self.discrete_action_encoding_type == 'not_one_hot': + # Stack latent_state with the normalized encoded action. + # The final action_encoding shape is (batch_size, 1, latent_state[2], latent_state[3]), e.g. (8, 1, 4, 1). + if len(action.shape) == 2: + # (batch_size, action_dim=1) -> (batch_size, 1, 1, 1) + # e.g., torch.Size([8, 1]) -> torch.Size([8, 1, 1, 1]) + action = action.unsqueeze(-1).unsqueeze(-1) + elif len(action.shape) == 1: + # (batch_size,) -> (batch_size, 1, 1, 1) + # e.g., -> torch.Size([8, 1, 1, 1]) + action = action.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) + + action_encoding = action.expand( + latent_state.shape[0], 1, latent_state.shape[2], latent_state.shape[3] + ) / self.action_space_size + + # state_action_encoding shape: (batch_size, latent_state[1] + action_dim, latent_state[2], latent_state[3]) or + # (batch_size, latent_state[1] + action_space_size, latent_state[2], latent_state[3]) depending on the discrete_action_encoding_type. 
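The one-hot branch of _dynamics above turns a batch of discrete actions into a spatial encoding and concatenates it with the latent state along the channel dimension. A shape-focused sketch of those steps (batch size, action space and latent spatial size are illustrative):

import torch

batch_size, action_space_size = 8, 6
latent_state = torch.randn(batch_size, 64, 4, 4)  # (B, C, H, W)
action = torch.randint(0, action_space_size, (batch_size,))

# (B,) -> (B, 1) -> one-hot (B, A)
action_one_hot = torch.zeros(batch_size, action_space_size)
action_one_hot.scatter_(1, action.long().unsqueeze(-1), 1)

# (B, A) -> (B, A, 1, 1) -> broadcast to (B, A, H, W)
action_encoding = action_one_hot.unsqueeze(-1).unsqueeze(-1).expand(
    batch_size, action_space_size, latent_state.shape[2], latent_state.shape[3]
)

# Concatenate along channels: (B, C + A, H, W), the dynamics network input.
state_action_encoding = torch.cat((latent_state, action_encoding), dim=1)
print(state_action_encoding.shape)  # torch.Size([8, 70, 4, 4])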
+ state_action_encoding = torch.cat((latent_state, action_encoding), dim=1) + + next_latent_state, reward = self.dynamics_network(state_action_encoding) + if self.state_norm: + next_latent_state = renormalize(next_latent_state) + return next_latent_state, reward + + def project(self, latent_state: torch.Tensor, with_grad: bool = True) -> torch.Tensor: + """ + 多任务投影方法,当前实现为共享投影网络。 + """ + if not self.self_supervised_learning_loss: + raise NotImplementedError("Self-supervised learning loss is not enabled for this model.") + + latent_state = latent_state.reshape(latent_state.shape[0], -1) + proj = self.projection_network(latent_state) + if with_grad: + return self.prediction_head(proj) + else: + return proj.detach() + + def get_params_mean(self) -> float: + return get_params_mean(self) + + +class DynamicsNetwork(nn.Module): + + def __init__( + self, + observation_shape: SequenceType, + action_encoding_dim: int = 2, + num_res_blocks: int = 1, + num_channels: int = 64, + reward_head_channels: int = 64, + fc_reward_layers: SequenceType = [32], + output_support_size: int = 601, + flatten_output_size_for_reward_head: int = 64, + downsample: bool = False, + last_linear_layer_init_zero: bool = True, + activation: Optional[nn.Module] = nn.ReLU(inplace=True), + norm_type: Optional[str] = 'BN', + ): + """ + DynamicsNetwork定义,适用于多任务共享。 + """ + super().__init__() + assert norm_type in ['BN', 'LN'], "norm_type must be in ['BN', 'LN']" + assert num_channels > action_encoding_dim, f'num_channels:{num_channels} <= action_encoding_dim:{action_encoding_dim}' + + self.num_channels = num_channels + self.flatten_output_size_for_reward_head = flatten_output_size_for_reward_head + + self.action_encoding_dim = action_encoding_dim + self.conv = nn.Conv2d(num_channels, num_channels - self.action_encoding_dim, kernel_size=3, stride=1, padding=1, bias=False) + + if norm_type == 'BN': + self.norm_common = nn.BatchNorm2d(num_channels - self.action_encoding_dim) + elif norm_type == 'LN': + if downsample: + self.norm_common = nn.LayerNorm([num_channels - self.action_encoding_dim, math.ceil(observation_shape[-2] / 16), math.ceil(observation_shape[-1] / 16)]) + else: + self.norm_common = nn.LayerNorm([num_channels - self.action_encoding_dim, observation_shape[-2], observation_shape[-1]]) + + self.resblocks = nn.ModuleList( + [ + ResBlock( + in_channels=num_channels - self.action_encoding_dim, activation=activation, norm_type='BN', res_type='basic', bias=False + ) for _ in range(num_res_blocks) + ] + ) + + self.conv1x1_reward = nn.Conv2d(num_channels - self.action_encoding_dim, reward_head_channels, 1) + + if norm_type == 'BN': + self.norm_reward = nn.BatchNorm2d(reward_head_channels) + elif norm_type == 'LN': + if downsample: + self.norm_reward = nn.LayerNorm([reward_head_channels, math.ceil(observation_shape[-2] / 16), math.ceil(observation_shape[-1] / 16)]) + else: + self.norm_reward = nn.LayerNorm([reward_head_channels, observation_shape[-2], observation_shape[-1]]) + + self.fc_reward_head = MLP( + self.flatten_output_size_for_reward_head, + hidden_channels=fc_reward_layers[0], + layer_num=len(fc_reward_layers) + 1, + out_channels=output_support_size, + activation=activation, + norm_type=norm_type, + output_activation=False, + output_norm=False, + last_linear_layer_init_zero=last_linear_layer_init_zero + ) + self.activation = activation + + def forward(self, state_action_encoding: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + DynamicsNetwork的前向传播,预测下一个潜在状态和奖励。 + """ + # 提取状态编码(去除动作编码部分) + state_encoding = 
state_action_encoding[:, :-self.action_encoding_dim, :, :] + x = self.conv(state_action_encoding) + x = self.norm_common(x) + + # 残差连接 + x += state_encoding + x = self.activation(x) + + for block in self.resblocks: + x = block(x) + next_latent_state = x + + x = self.conv1x1_reward(next_latent_state) + x = self.norm_reward(x) + x = self.activation(x) + x = x.view(x.shape[0], -1) + + # 使用全连接层预测奖励 + reward = self.fc_reward_head(x) + + return next_latent_state, reward + + def get_dynamic_mean(self) -> float: + return get_dynamic_mean(self) + + def get_reward_mean(self) -> Tuple[ndarray, float]: + return get_reward_mean(self) \ No newline at end of file diff --git a/lzero/model/sampled_unizero_model_multitask.py b/lzero/model/sampled_unizero_model_multitask.py new file mode 100644 index 000000000..be7cd0e7a --- /dev/null +++ b/lzero/model/sampled_unizero_model_multitask.py @@ -0,0 +1,311 @@ +from typing import Optional, List + +import torch +import torch.nn as nn +from ding.torch_utils import MLP +from ding.utils import MODEL_REGISTRY, SequenceType +from easydict import EasyDict + +from .common import MZNetworkOutput, RepresentationNetworkUniZero, LatentDecoder, \ + FeatureAndGradientHook, SimNorm +from .unizero_world_models.tokenizer import Tokenizer +from .unizero_world_models.world_model_multitask import WorldModelMT + +class RepresentationNetworkMLPMT(nn.Module): + def __init__( + self, + observation_shape_list: List[int], # List of observation shapes for each task + hidden_channels: int = 64, + layer_num: int = 2, + activation: nn.Module = nn.GELU(approximate='tanh'), + norm_type: Optional[str] = 'BN', + group_size: int = 8, + use_shared_projection: bool = False, # 控制是否启用共享投影层 + shared_projection_dim: Optional[int] = None, # 共享投影层的维度 + ) -> torch.Tensor: + """ + Overview: + Representation network used in MuZero and derived algorithms. Encode the vector obs into latent state \ + with Multi-Layer Perceptron (MLP), optionally followed by a shared projection layer. + Arguments: + - observation_shape_list (:obj:`List[int]`): The list of observation shape for each task. + - hidden_channels (:obj:`int`): The channel of output hidden state. + - layer_num (:obj:`int`): The number of layers in the MLP. + - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.GELU(approximate='tanh'). + - norm_type (:obj:`str`): The type of normalization in networks, defaults to 'BN'. + - group_size (:obj:`int`): The group size used in SimNorm. + - use_shared_projection (:obj:`bool`): Whether to use a shared projection layer, defaults to False. + - shared_projection_dim (:obj:`Optional[int]`): The dimension of the shared projection layer. \ + If None, defaults to `hidden_channels`. + """ + super().__init__() + self.env_num = len(observation_shape_list) + self.use_shared_projection = use_shared_projection + self.hidden_channels = hidden_channels + self.shared_projection_dim = shared_projection_dim or hidden_channels + + # Task-specific representation networks + self.fc_representation = nn.ModuleList([ + MLP( + in_channels=obs_shape, + hidden_channels=hidden_channels, + out_channels=hidden_channels, + layer_num=layer_num, + activation=activation, + norm_type=norm_type, + # don't use activation and norm in the last layer of representation network is important for convergence. + output_activation=False, + output_norm=False, + # last_linear_layer_init_zero=True is beneficial for convergence speed. 
+ last_linear_layer_init_zero=True, + ) + for obs_shape in observation_shape_list + ]) + + # Shared projection layer + if self.use_shared_projection: + self.shared_projection = nn.Linear(hidden_channels, self.shared_projection_dim) + # self.projection_norm = nn.LayerNorm(self.shared_projection_dim) # Optional normalization for shared space + self.projection_norm = SimNorm(simnorm_dim=group_size) # Optional normalization for shared space + + # SimNorm for task-specific outputs + self.sim_norm = SimNorm(simnorm_dim=group_size) + + def forward(self, x: torch.Tensor, task_id: int) -> torch.Tensor: + """ + Shapes: + - x (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size, N is the length of vector observation. + - task_id (:obj:`int`): The ID of the current task. + - output (:obj:`torch.Tensor`): :math:`(B, hidden_channels)` if shared projection is not used, \ + otherwise :math:`(B, shared_projection_dim)`. + """ + # Task-specific representation + x = self.fc_representation[task_id](x) + x = self.sim_norm(x) + + # Shared projection layer (if enabled) + if self.use_shared_projection: + x = self.shared_projection(x) + x = self.projection_norm(x) # Optional normalization + return x + + +# class RepresentationNetworkMLPMT(nn.Module): +# def __init__( +# self, +# observation_shape_list: List[int], # List of observation shapes for each task +# hidden_channels: int = 64, +# layer_num: int = 2, +# activation: nn.Module = nn.GELU(approximate='tanh'), +# norm_type: Optional[str] = 'BN', +# group_size: int = 8, +# ) -> torch.Tensor: +# """ +# Overview: +# Representation network used in MuZero and derived algorithms. Encode the vector obs into latent state \ +# with Multi-Layer Perceptron (MLP). +# Arguments: +# - observation_shape_list (:obj:`List[int]`): The list of observation shape for each task. +# - hidden_channels (:obj:`int`): The channel of output hidden state. +# - layer_num (:obj:`int`): The number of layers in the MLP. +# - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.GELU(approximate='tanh'). +# - norm_type (:obj:`str`): The type of normalization in networks, defaults to 'BN'. +# - group_size (:obj:`int`): The group size used in SimNorm. +# """ +# super().__init__() +# self.env_num = len(observation_shape_list) +# self.fc_representation = nn.ModuleList([ +# MLP( +# in_channels=obs_shape, +# hidden_channels=hidden_channels, +# out_channels=hidden_channels, +# layer_num=layer_num, +# activation=activation, +# norm_type=norm_type, +# # don't use activation and norm in the last layer of representation network is important for convergence. +# output_activation=False, +# output_norm=False, +# # last_linear_layer_init_zero=True is beneficial for convergence speed. +# last_linear_layer_init_zero=True, +# ) +# for obs_shape in observation_shape_list +# ]) +# self.sim_norm = SimNorm(simnorm_dim=group_size) + +# def forward(self, x: torch.Tensor, task_id: int) -> torch.Tensor: +# """ +# Shapes: +# - x (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size, N is the length of vector observation. +# - task_id (:obj:`int`): The ID of the current task. +# - output (:obj:`torch.Tensor`): :math:`(B, hidden_channels)`, where B is batch size. 
+# """ +# x = self.fc_representation[task_id](x) +# x = self.sim_norm(x) +# return x + + +@MODEL_REGISTRY.register('SampledUniZeroMTModel') +class SampledUniZeroMTModel(nn.Module): + def __init__( + self, + observation_shape_list: List[SequenceType], # List of observation shapes for each task + action_space_size_list: List[int], # List of action space sizes for each task + num_res_blocks: int = 1, + num_channels: int = 64, + activation: nn.Module = nn.GELU(approximate='tanh'), + downsample: bool = True, + norm_type: Optional[str] = 'LN', + # world_model_cfgs: List[EasyDict] = None, # List of world model configs for each task + world_model_cfg: List[EasyDict] = None, # List of world model configs for each task + *args, + **kwargs + ): + """ + Overview: + The definition of data procession in the scalable latent world model of UniZero (https://arxiv.org/abs/2406.10667), including two main parts: + - initial_inference, which is used to predict the value, policy, and latent state based on the current observation. + - recurrent_inference, which is used to predict the value, policy, reward, and next latent state based on the current latent state and action. + The world model consists of three main components: + - a tokenizer, which encodes observations into embeddings, + - a transformer, which processes the input sequences, + - and heads, which generate the logits for observations, rewards, policy, and value. + Arguments: + - observation_shape_list (:obj:`List[SequenceType]`): List of observation space shapes for each task, e.g. [C, W, H]=[3, 64, 64] for Atari. + - action_space_size_list (:obj:`List[int]`): List of action space sizes for each task. + - num_res_blocks (:obj:`int`): The number of res blocks in UniZero model. + - num_channels (:obj:`int`): The channels of hidden states in representation network. + - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ + operation to speedup, e.g. ReLU(inplace=True). + - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``, \ + defaults to True. This option is often used in video games like Atari. In board games like go, \ + we don't need this module. + - norm_type (:obj=`str`): The type of normalization in networks. Defaults to 'LN'. + - world_model_cfgs (:obj=`List[EasyDict]`): The list of world model configurations for each task. 
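RepresentationNetworkMLPMT keeps one MLP encoder per task and can optionally map every task into a shared latent space before a group-wise normalization. The sketch below captures that flow with a simplified group softmax standing in for SimNorm; the real SimNorm and MLP details in LightZero may differ:

import torch
import torch.nn as nn
import torch.nn.functional as F

class GroupSoftmaxNorm(nn.Module):
    # Simplified SimNorm-like layer: softmax within consecutive groups of size group_size.
    def __init__(self, group_size=8):
        super().__init__()
        self.group_size = group_size

    def forward(self, x):
        b, d = x.shape
        x = x.view(b, d // self.group_size, self.group_size)
        return F.softmax(x, dim=-1).view(b, d)

class PerTaskEncoder(nn.Module):
    def __init__(self, obs_dims=(4, 8), hidden=64, shared_dim=64, use_shared_projection=True):
        super().__init__()
        self.encoders = nn.ModuleList(nn.Linear(d, hidden) for d in obs_dims)  # one encoder per task
        self.norm = GroupSoftmaxNorm(group_size=8)
        self.use_shared_projection = use_shared_projection
        self.shared_projection = nn.Linear(hidden, shared_dim)

    def forward(self, obs, task_id):
        x = self.norm(self.encoders[task_id](obs))
        if self.use_shared_projection:
            # Project all tasks into one shared latent space, then normalize again.
            x = self.norm(self.shared_projection(x))
        return x

encoder = PerTaskEncoder()
print(encoder(torch.randn(2, 8), task_id=1).shape)  # torch.Size([2, 64])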
+ """ + super(SampledUniZeroMTModel, self).__init__() + self.task_num = len(observation_shape_list) + self.activation = activation + self.downsample = downsample + + # Initialize environment-specific networks and models + self.representation_networks = nn.ModuleList() + # self.decoder_networks = nn.ModuleList() + # self.world_models = nn.ModuleList() + + if world_model_cfg.task_embed_option == "concat_task_embed": + obs_act_embed_dim = world_model_cfg.embed_dim - 96 + else: + obs_act_embed_dim = world_model_cfg.embed_dim + + for task_id in range(self.task_num): + # world_model_cfg = world_model_cfgs[task_id] + world_model_cfg.norm_type = norm_type + assert world_model_cfg.max_tokens == 2 * world_model_cfg.max_blocks, 'max_tokens should be 2 * max_blocks, because each timestep has 2 tokens: obs and action' + + if world_model_cfg.obs_type == 'vector': + self.representation_network = RepresentationNetworkMLPMT( + observation_shape_list=observation_shape_list, + hidden_channels=obs_act_embed_dim, + layer_num=2, + activation=self.activation, + norm_type=norm_type, + group_size=world_model_cfg.group_size, + use_shared_projection=world_model_cfg.use_shared_projection, + ) + self.tokenizer = Tokenizer(encoder=self.representation_network, + decoder_network=None, with_lpips=False) + self.world_model = WorldModelMT(config=world_model_cfg, tokenizer=self.tokenizer) + elif world_model_cfg.obs_type == 'image': + self.representation_network = nn.ModuleList() + # for task_id in range(self.task_num): # TODO: N independent encoder + for task_id in range(1): # TODO: one share encoder + self.representation_network.append(RepresentationNetworkUniZero( + observation_shape_list[task_id], + num_res_blocks, + num_channels, + self.downsample, + activation=self.activation, + norm_type=norm_type, + embedding_dim=obs_act_embed_dim, + group_size=world_model_cfg.group_size, + )) + # TODO: we should change the output_shape to the real observation shape + # self.decoder_network = LatentDecoder(embedding_dim=world_model_cfg.embed_dim, output_shape=(3, 64, 64)) + + + # Print model parameters for debugging + print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model') + print('==' * 20) + print(f'{sum(p.numel() for p in self.world_model.transformer.parameters())} parameters in agent.world_model.transformer') + print(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters())} parameters in agent.tokenizer.encoder') + print('==' * 20) + + def initial_inference(self, obs_batch: torch.Tensor, action_batch=None, current_obs_batch=None, task_id=None) -> MZNetworkOutput: + """ + Overview: + Initial inference of UniZero model, which is the first step of the UniZero model. + To perform the initial inference, we first use the representation network to obtain the ``latent_state``. + Then we use the prediction network to predict ``value`` and ``policy_logits`` of the ``latent_state``. + Arguments: + - obs_batch (:obj:`torch.Tensor`): The 3D image observation data. + - task_id (:obj:`int`): The ID of the current task. + Returns (MZNetworkOutput): + - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation. + - reward (:obj:`torch.Tensor`): The predicted reward of input state and selected action. \ + In initial inference, we set it to zero vector. + - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action. + - latent_state (:obj=`torch.Tensor`): The encoding latent state of input state. 
+ Shapes: + - obs (:obj:`torch.Tensor`): :math:`(B, num_channel, obs_shape[1], obs_shape[2])`, where B is batch_size. + - value (:obj=`torch.Tensor`): :math=`(B, value_support_size)`, where B is batch_size. + - reward (:obj=`torch.Tensor`): :math=`(B, reward_support_size)`, where B is batch_size. + - policy_logits (:obj=`torch.Tensor`): :math=`(B, action_dim)`, where B is batch_size. + - latent_state (:obj=`torch.Tensor`): :math=`(B, H_, W_)`, where B is batch_size, H_ is the height of latent state, W_ is the width of latent state. + """ + batch_size = obs_batch.size(0) + obs_act_dict = {'obs': obs_batch, 'action': action_batch, 'current_obs': current_obs_batch} + _, obs_token, logits_rewards, logits_policy, logits_value = self.world_model.forward_initial_inference(obs_act_dict, task_id=task_id) + latent_state, reward, policy_logits, value = obs_token, logits_rewards, logits_policy, logits_value + policy_logits = policy_logits.squeeze(1) + value = value.squeeze(1) + + return MZNetworkOutput( + value, + [0. for _ in range(batch_size)], + policy_logits, + latent_state, + ) + + def recurrent_inference(self, state_action_history: torch.Tensor, simulation_index=0, + latent_state_index_in_search_path=[], task_id=0) -> MZNetworkOutput: + """ + Overview: + Recurrent inference of UniZero model. To perform the recurrent inference, we concurrently predict the latent dynamics (reward/next_latent_state) + and decision-oriented quantities (value/policy) conditioned on the learned latent history in the world_model. + Arguments: + - state_action_history (:obj:`torch.Tensor`): The history of states and actions. + - task_id (:obj:`int`): The ID of the current task. + - simulation_index (:obj=`int`): The index of the current simulation. + - latent_state_index_in_search_path (:obj=`List[int]`): The indices of latent states in the search path. + Returns (MZNetworkOutput): + - value (:obj=`torch.Tensor`): The output value of input state to help policy improvement and evaluation. + - reward (:obj=`torch.Tensor`): The predicted reward of input state and selected action. + - policy_logits (:obj=`torch.Tensor`): The output logit to select discrete action. + - latent_state (:obj=`torch.Tensor`): The encoding latent state of input state. + - next_latent_state (:obj=`torch.Tensor`): The predicted next latent state. + Shapes: + - obs (:obj=`torch.Tensor`): :math=`(B, num_channel, obs_shape[1], obs_shape[2])`, where B is batch_size. + - action (:obj=`torch.Tensor`): :math=`(B, )`, where B is batch_size. + - value (:obj=`torch.Tensor`): :math=`(B, value_support_size)`, where B is batch_size. + - reward (:obj=`torch.Tensor`): :math=`(B, reward_support_size)`, where B is batch_size. + - policy_logits (:obj=`torch.Tensor`): :math=`(B, action_dim)`, where B is batch_size. + - latent_state (:obj=`torch.Tensor`): :math=`(B, H_, W_)`, where B is batch_size, H_ is the height of latent state, W_ is the width of latent state. + - next_latent_state (:obj=`torch.Tensor`): :math=`(B, H_, W_)`, where B is batch_size, H_ is the height of latent state, W_ is the width of latent state. 
+ """ + _, logits_observations, logits_rewards, logits_policy, logits_value = self.world_model.forward_recurrent_inference( + state_action_history, simulation_index, latent_state_index_in_search_path, task_id=task_id) + next_latent_state, reward, policy_logits, value = logits_observations, logits_rewards, logits_policy, logits_value + policy_logits = policy_logits.squeeze(1) + value = value.squeeze(1) + reward = reward.squeeze(1) + return MZNetworkOutput(value, reward, policy_logits, next_latent_state) \ No newline at end of file diff --git a/lzero/model/unizero_model_multitask.py b/lzero/model/unizero_model_multitask.py new file mode 100644 index 000000000..8f2e9f300 --- /dev/null +++ b/lzero/model/unizero_model_multitask.py @@ -0,0 +1,238 @@ +from typing import Optional + +import torch +import torch.nn as nn +from ding.utils import MODEL_REGISTRY, SequenceType +from easydict import EasyDict + +from .common import MZNetworkOutput, RepresentationNetworkUniZero, RepresentationNetworkMLP, LatentDecoder, \ + VectorDecoderForMemoryEnv, LatentEncoderForMemoryEnv, LatentDecoderForMemoryEnv, FeatureAndGradientHook +from .unizero_world_models.tokenizer import Tokenizer +from .unizero_world_models.world_model_multitask import WorldModelMT + +from line_profiler import line_profiler + +# use ModelRegistry to register the model, for more details about ModelRegistry, please refer to DI-engine's document. +@MODEL_REGISTRY.register('UniZeroMTModel') +class UniZeroMTModel(nn.Module): + + #@profile + def __init__( + self, + observation_shape: SequenceType = (4, 64, 64), + action_space_size: int = 6, + num_res_blocks: int = 1, + num_channels: int = 64, + activation: nn.Module = nn.GELU(approximate='tanh'), + downsample: bool = True, + norm_type: Optional[str] = 'BN', + world_model_cfg: EasyDict = None, + task_num: int = 1, + *args, + **kwargs + ): + """ + Overview: + The definition of data procession in the scalable latent world model of UniZero (https://arxiv.org/abs/2406.10667), including two main parts: + - initial_inference, which is used to predict the value, policy, and latent state based on the current observation. + - recurrent_inference, which is used to predict the value, policy, reward, and next latent state based on the current latent state and action. + The world model consists of three main components: + - a tokenizer, which encodes observations into embeddings, + - a transformer, which processes the input sequences, + - and heads, which generate the logits for observations, rewards, policy, and value. + Arguments: + - observation_shape (:obj:`SequenceType`): Observation space shape, e.g. [C, W, H]=[3, 64, 64] for Atari. + - action_space_size: (:obj:`int`): Action space size, usually an integer number for discrete action space. + - num_res_blocks (:obj:`int`): The number of res blocks in UniZero model. + - num_channels (:obj:`int`): The channels of hidden states in representation network. + - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ + operation to speedup, e.g. ReLU(inplace=True). + - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``, \ + defaults to True. This option is often used in video games like Atari. In board games like go, \ + we don't need this module. + - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. 
+ - world_model_cfg (:obj:`EasyDict`): The configuration of the world model, including the following keys: + - obs_type (:obj:`str`): The type of observation, which can be 'image', 'vector', or 'image_memory'. + - embed_dim (:obj:`int`): The dimension of the embedding. + - group_size (:obj:`int`): The group size of the transformer. + - max_blocks (:obj:`int`): The maximum number of blocks in the transformer. + - max_tokens (:obj:`int`): The maximum number of tokens in the transformer. + - context_length (:obj:`int`): The context length of the transformer. + - device (:obj:`str`): The device of the model, which can be 'cuda' or 'cpu'. + - action_space_size (:obj:`int`): The shape of the action. + - num_layers (:obj:`int`): The number of layers in the transformer. + - num_heads (:obj:`int`): The number of heads in the transformer. + - policy_entropy_weight (:obj:`float`): The weight of the policy entropy. + - analysis_sim_norm (:obj:`bool`): Whether to analyze the similarity of the norm. + """ + super(UniZeroMTModel, self).__init__() + + print(f'==========UniZeroMTModel, num_res_blocks:{num_res_blocks}, num_channels:{num_channels}===========') + + self.action_space_size = action_space_size + + # for multi-task + self.action_space_size = 18 + self.task_num = task_num + + self.activation = activation + self.downsample = downsample + world_model_cfg.norm_type = norm_type + assert world_model_cfg.max_tokens == 2 * world_model_cfg.max_blocks, 'max_tokens should be 2 * max_blocks, because each timestep has 2 tokens: obs and action' + + if world_model_cfg.obs_type == 'vector': + self.representation_network = RepresentationNetworkMLP( + observation_shape, + hidden_channels=world_model_cfg.embed_dim, + layer_num=2, + activation=self.activation, + group_size=world_model_cfg.group_size, + ) + # TODO: only for MemoryEnv now + self.decoder_network = VectorDecoderForMemoryEnv(embedding_dim=world_model_cfg.embed_dim, output_shape=25) + self.tokenizer = Tokenizer(encoder=self.representation_network, + decoder_network=self.decoder_network, with_lpips=False) + self.world_model = WorldModelMT(config=world_model_cfg, tokenizer=self.tokenizer) + print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model') + print('==' * 20) + print(f'{sum(p.numel() for p in self.world_model.transformer.parameters())} parameters in agent.world_model.transformer') + print(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters())} parameters in agent.tokenizer.encoder') + print('==' * 20) + elif world_model_cfg.obs_type == 'image': + self.representation_network = nn.ModuleList() + # for task_id in range(self.task_num): # TODO: N independent encoder + for task_id in range(1): # TODO: one share encoder + self.representation_network.append(RepresentationNetworkUniZero( + observation_shape, + num_res_blocks, + num_channels, + self.downsample, + activation=self.activation, + norm_type=norm_type, + embedding_dim=world_model_cfg.embed_dim, + group_size=world_model_cfg.group_size, + )) + # TODO: we should change the output_shape to the real observation shape + # self.decoder_network = LatentDecoder(embedding_dim=world_model_cfg.embed_dim, output_shape=(3, 64, 64)) + + # ====== for analysis ====== + if world_model_cfg.analysis_sim_norm: + self.encoder_hook = FeatureAndGradientHook() + self.encoder_hook.setup_hooks(self.representation_network) + + self.tokenizer = Tokenizer(encoder=self.representation_network, decoder_network=None, with_lpips=False) + self.world_model = 
WorldModelMT(config=world_model_cfg, tokenizer=self.tokenizer) + print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model') + print('==' * 20) + print(f'{sum(p.numel() for p in self.world_model.transformer.parameters())} parameters in agent.world_model.transformer') + print(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters())} parameters in agent.tokenizer.encoder') + print('==' * 20) + elif world_model_cfg.obs_type == 'image_memory': + self.representation_network = LatentEncoderForMemoryEnv( + image_shape=(3, 5, 5), + embedding_size=world_model_cfg.embed_dim, + channels=[16, 32, 64], + kernel_sizes=[3, 3, 3], + strides=[1, 1, 1], + activation=self.activation, + group_size=world_model_cfg.group_size, + ) + self.decoder_network = LatentDecoderForMemoryEnv( + image_shape=(3, 5, 5), + embedding_size=world_model_cfg.embed_dim, + channels=[64, 32, 16], + kernel_sizes=[3, 3, 3], + strides=[1, 1, 1], + activation=self.activation, + ) + + if world_model_cfg.analysis_sim_norm: + # ====== for analysis ====== + self.encoder_hook = FeatureAndGradientHook() + self.encoder_hook.setup_hooks(self.representation_network) + + self.tokenizer = Tokenizer(with_lpips=True, encoder=self.representation_network, + decoder_network=self.decoder_network) + self.world_model = WorldModelMT(config=world_model_cfg, tokenizer=self.tokenizer) + print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model') + print(f'{sum(p.numel() for p in self.world_model.parameters()) - sum(p.numel() for p in self.tokenizer.decoder_network.parameters()) - sum(p.numel() for p in self.tokenizer.lpips.parameters())} parameters in agent.world_model - (decoder_network and lpips)') + + print('==' * 20) + print(f'{sum(p.numel() for p in self.world_model.transformer.parameters())} parameters in agent.world_model.transformer') + print(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters())} parameters in agent.tokenizer.encoder') + print(f'{sum(p.numel() for p in self.tokenizer.decoder_network.parameters())} parameters in agent.tokenizer.decoder_network') + print('==' * 20) + + #@profile + def initial_inference(self, obs_batch: torch.Tensor, action_batch=None, current_obs_batch=None, task_id=None) -> MZNetworkOutput: + """ + Overview: + Initial inference of UniZero model, which is the first step of the UniZero model. + To perform the initial inference, we first use the representation network to obtain the ``latent_state``. + Then we use the prediction network to predict ``value`` and ``policy_logits`` of the ``latent_state``. + Arguments: + - obs_batch (:obj:`torch.Tensor`): The 3D image observation data. + Returns (MZNetworkOutput): + - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation. + - reward (:obj:`torch.Tensor`): The predicted reward of input state and selected action. \ + In initial inference, we set it to zero vector. + - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action. + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + Shapes: + - obs (:obj:`torch.Tensor`): :math:`(B, num_channel, obs_shape[1], obs_shape[2])`, where B is batch_size. + - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size. + - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size. + - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size. 
+            - latent_state (:obj:`torch.Tensor`): :math:`(B, H_, W_)`, where B is batch_size, H_ is the height of \
+                latent state, W_ is the width of latent state.
+        """
+        batch_size = obs_batch.size(0)
+        obs_act_dict = {'obs': obs_batch, 'action': action_batch, 'current_obs': current_obs_batch}
+        _, obs_token, logits_rewards, logits_policy, logits_value = self.world_model.forward_initial_inference(obs_act_dict, task_id=task_id)
+        latent_state, reward, policy_logits, value = obs_token, logits_rewards, logits_policy, logits_value
+        policy_logits = policy_logits.squeeze(1)
+        value = value.squeeze(1)
+
+        return MZNetworkOutput(
+            value,
+            [0. for _ in range(batch_size)],
+            policy_logits,
+            latent_state,
+        )
+
+    #@profile
+    def recurrent_inference(self, state_action_history: torch.Tensor, simulation_index=0,
+                            latent_state_index_in_search_path=[], task_id=None) -> MZNetworkOutput:
+        """
+        Overview:
+            Recurrent inference of UniZero model. To perform the recurrent inference, we concurrently predict the latent dynamics (reward/next_latent_state)
+            and decision-oriented quantities (value/policy) conditioned on the learned latent history in the world_model.
+        Arguments:
+            - state_action_history (:obj:`torch.Tensor`): The history of latent states and actions to condition on.
+            - simulation_index (:obj:`int`): The index of the current simulation.
+            - latent_state_index_in_search_path (:obj:`List[int]`): The indices of latent states in the search path.
+            - task_id (:obj:`int`): The ID of the current task.
+        Returns (MZNetworkOutput):
+            - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation.
+            - reward (:obj:`torch.Tensor`): The predicted reward of input state and selected action.
+            - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action.
+            - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state.
+            - next_latent_state (:obj:`torch.Tensor`): The predicted next latent state.
+        Shapes:
+            - obs (:obj:`torch.Tensor`): :math:`(B, num_channel, obs_shape[1], obs_shape[2])`, where B is batch_size.
+            - action (:obj:`torch.Tensor`): :math:`(B, )`, where B is batch_size.
+            - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size.
+            - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size.
+            - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size.
+            - latent_state (:obj:`torch.Tensor`): :math:`(B, H_, W_)`, where B is batch_size, H_ is the height of \
+                latent state, W_ is the width of latent state.
+            - next_latent_state (:obj:`torch.Tensor`): :math:`(B, H_, W_)`, where B is batch_size, H_ is the height of \
+                latent state, W_ is the width of latent state.
+ """ + _, logits_observations, logits_rewards, logits_policy, logits_value = self.world_model.forward_recurrent_inference( + state_action_history, simulation_index, latent_state_index_in_search_path, task_id=task_id) + next_latent_state, reward, policy_logits, value = logits_observations, logits_rewards, logits_policy, logits_value + policy_logits = policy_logits.squeeze(1) + value = value.squeeze(1) + reward = reward.squeeze(1) + return MZNetworkOutput(value, reward, policy_logits, next_latent_state) \ No newline at end of file diff --git a/lzero/model/unizero_world_models/lpips.py b/lzero/model/unizero_world_models/lpips.py index c6ee6426c..7abd5c062 100644 --- a/lzero/model/unizero_world_models/lpips.py +++ b/lzero/model/unizero_world_models/lpips.py @@ -22,11 +22,13 @@ def __init__(self, use_dropout: bool = True): self.chns = [64, 128, 256, 512, 512] # vg16 features # Comment out the following line if you don't need perceptual loss # self.net = vgg16(pretrained=True, requires_grad=False) - self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) - self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) - self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) - self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) - self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) + + # self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) + # self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) + # self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) + # self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) + # self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) + # Comment out the following line if you don't need perceptual loss # self.load_from_pretrained() # for param in self.parameters(): diff --git a/lzero/model/unizero_world_models/moe.py b/lzero/model/unizero_world_models/moe.py new file mode 100644 index 000000000..159afd69e --- /dev/null +++ b/lzero/model/unizero_world_models/moe.py @@ -0,0 +1,49 @@ +import dataclasses +from typing import List + +import torch +import torch.nn.functional as F +from simple_parsing.helpers import Serializable +from torch import nn + +# Modified from https://github.com/mistralai/mistral-inference/blob/main/src/mistral_inference/transformer.py#L108 +class MultiplicationFeedForward(nn.Module): + def __init__(self, config): + super().__init__() + + self.w1 = nn.Linear(config.embed_dim, 4 * config.embed_dim, bias=False) + self.w2 = nn.Linear(4 * config.embed_dim, config.embed_dim, bias=False) + self.w3 = nn.Linear(config.embed_dim, 4 * config.embed_dim, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x)) # type: ignore + +@dataclasses.dataclass +class MoeArgs(Serializable): + num_experts: int + num_experts_per_tok: int + + +class MoeLayer(nn.Module): + def __init__(self, experts: List[nn.Module], gate: nn.Module, num_experts_per_tok=1): + super().__init__() + assert len(experts) > 0 + self.experts = nn.ModuleList(experts) + self.gate = gate + self.num_experts_per_tok = num_experts_per_tok + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + # if len(self.experts) == 1: + # # 只有一个专家时,直接使用该专家 + # return self.experts[0](inputs) + + gate_logits = self.gate(inputs) + weights, selected_experts = torch.topk(gate_logits, self.num_experts_per_tok) + weights = F.softmax(weights, dim=1, dtype=torch.float).to(inputs.dtype) + results = torch.zeros_like(inputs) + for i, expert in 
enumerate(self.experts): + # batch_idx, nth_expert = torch.where(selected_experts == i) + # results[batch_idx] += weights[batch_idx, nth_expert, None] * expert(inputs[batch_idx]) + batch_idx, token_idx, nth_expert = torch.where(selected_experts == i) + results[batch_idx, token_idx] += weights[batch_idx, token_idx, nth_expert][:, None] * expert(inputs[batch_idx, token_idx]) + return results \ No newline at end of file diff --git a/lzero/model/unizero_world_models/test_moe.py b/lzero/model/unizero_world_models/test_moe.py new file mode 100644 index 000000000..6ab93cc16 --- /dev/null +++ b/lzero/model/unizero_world_models/test_moe.py @@ -0,0 +1,107 @@ +import dataclasses +from typing import List + +import torch +import torch.nn.functional as F +from simple_parsing.helpers import Serializable +from torch import nn + +# 定义MoeArgs数据类,用于存储MoE的配置参数 +@dataclasses.dataclass +class MoeArgs(Serializable): + num_experts: int + num_experts_per_tok: int + +# 定义Mixture of Experts(MoE)层 +class MoeLayer(nn.Module): + def __init__(self, experts: List[nn.Module], gate: nn.Module, num_experts_per_tok=1): + super().__init__() + assert len(experts) > 0 + self.experts = nn.ModuleList(experts) + self.gate = gate + self.num_experts_per_tok = num_experts_per_tok + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + if len(self.experts) == 1: + # 只有一个专家时,直接使用该专家 + return self.experts[0](inputs) + + gate_logits = self.gate(inputs) + weights, selected_experts = torch.topk(gate_logits, self.num_experts_per_tok) + weights = F.softmax(weights, dim=1, dtype=torch.float).to(inputs.dtype) + results = torch.zeros_like(inputs) + for i, expert in enumerate(self.experts): + batch_idx, token_idx, nth_expert = torch.where(selected_experts == i) + results[batch_idx, token_idx] += weights[batch_idx, token_idx, nth_expert][:, None] * expert(inputs[batch_idx, token_idx]) + return results + +# 定义一个简单的Transformer块 +class TransformerBlock(nn.Module): + def __init__(self, config): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(config.embed_dim, 4 * config.embed_dim), + nn.GELU(approximate='tanh'), + nn.Linear(4 * config.embed_dim, config.embed_dim), + nn.Dropout(config.resid_pdrop), + ) + + if config.moe_in_transformer: + self.feed_forward = MoeLayer( + experts=[self.mlp for _ in range(config.num_experts_of_moe_in_transformer)], + gate=nn.Linear(config.embed_dim, config.num_experts_of_moe_in_transformer, bias=False), + num_experts_per_tok=1, + ) + print("="*20) + print('使用MoE在Transformer的feed_forward中') + print("="*20) + else: + self.feed_forward = self.mlp + + def forward(self, x): + return self.feed_forward(x) + +# 定义配置类 +class Config: + def __init__(self, embed_dim, resid_pdrop, num_experts_of_moe_in_transformer, moe_in_transformer): + self.embed_dim = embed_dim + self.resid_pdrop = resid_pdrop + self.num_experts_of_moe_in_transformer = num_experts_of_moe_in_transformer + self.moe_in_transformer = moe_in_transformer + +# 测试代码 +def test_transformer_block(): + # 初始化配置 + embed_dim = 64 + resid_pdrop = 0.1 + num_experts_of_moe_in_transformer = 1 + + # 创建输入数据 + inputs = torch.randn(10, 5, embed_dim) # (batch_size, seq_len, embed_dim) + + # 初始化两个输出变量 + outputs_true = None + outputs_false = None + + # 对于moe_in_transformer为True和False分别进行测试 + for moe_in_transformer in [True, False]: + config = Config(embed_dim, resid_pdrop, num_experts_of_moe_in_transformer, moe_in_transformer) + transformer_block = TransformerBlock(config) + + outputs = transformer_block(inputs) + print(f"moe_in_transformer={moe_in_transformer}: 
outputs={outputs}") + + if moe_in_transformer: + outputs_true = outputs + else: + outputs_false = outputs + + # 计算输出的差异 + mse_difference = None + if outputs_true is not None and outputs_false is not None: + mse_difference = F.mse_loss(outputs_true, outputs_false).item() + + print(f"输出差异的均方误差(MSE): {mse_difference}") + +if __name__ == "__main__": + test_transformer_block() \ No newline at end of file diff --git a/lzero/model/unizero_world_models/tokenizer.py b/lzero/model/unizero_world_models/tokenizer.py index bd066ccec..d0f5e0483 100644 --- a/lzero/model/unizero_world_models/tokenizer.py +++ b/lzero/model/unizero_world_models/tokenizer.py @@ -54,35 +54,59 @@ def __init__(self, encoder=None, decoder_network=None, with_lpips: bool = False) self.encoder = encoder self.decoder_network = decoder_network - def encode_to_obs_embeddings(self, x: torch.Tensor) -> torch.Tensor: + def encode_to_obs_embeddings(self, x: torch.Tensor, task_id = None) -> torch.Tensor: """ Encode observations to embeddings. Arguments: - x (torch.Tensor): Input tensor of shape (B, ...). + - x (torch.Tensor): Input tensor of shape (B, ...). Returns: - torch.Tensor: Encoded embeddings of shape (B, 1, E). + - torch.Tensor: Encoded embeddings of shape (B, 1, E). """ shape = x.shape + # TODO: ====== + if task_id is None: + # for compatibility with multitask setting + task_id = 0 + else: + # task_id = 0 # one share encoder + task_id = task_id # TODO: one encoder per task + # print(f'='*20) + # print(f'x.shape:{x.shape}') + # print(f'self.encoder:{self.encoder}') + # Process input tensor based on its dimensionality if len(shape) == 2: # Case when input is 2D (B, E) - obs_embeddings = self.encoder(x) + # obs_embeddings = self.encoder[task_id](x) + obs_embeddings = self.encoder(x, task_id) # TODO: + obs_embeddings = rearrange(obs_embeddings, 'b e -> b 1 e') elif len(shape) == 3: # Case when input is 3D (B, T, E) x = x.contiguous().view(-1, shape[-1]) # Flatten the last two dimensions (B * T, E) - obs_embeddings = self.encoder(x) + # obs_embeddings = self.encoder[task_id](x) + obs_embeddings = self.encoder(x,task_id) # TODO: + obs_embeddings = rearrange(obs_embeddings, 'b e -> b 1 e') elif len(shape) == 4: # Case when input is 4D (B, C, H, W) - obs_embeddings = self.encoder(x) + try: + obs_embeddings = self.encoder(x, task_id=task_id) # TODO: for dmc multitask + # obs_embeddings = self.encoder[task_id](x) + except Exception as e: + print(e) + obs_embeddings = self.encoder(x) # TODO: for memory env + obs_embeddings = rearrange(obs_embeddings, 'b e -> b 1 e') elif len(shape) == 5: # Case when input is 5D (B, T, C, H, W) x = x.contiguous().view(-1, *shape[-3:]) # Flatten the first two dimensions (B * T, C, H, W) - obs_embeddings = self.encoder(x) + try: + obs_embeddings = self.encoder[task_id](x) + except Exception as e: + obs_embeddings = self.encoder(x) # TODO: for memory env obs_embeddings = rearrange(obs_embeddings, 'b e -> b 1 e') else: raise ValueError(f"Invalid input shape: {shape}") diff --git a/lzero/model/unizero_world_models/transformer.py b/lzero/model/unizero_world_models/transformer.py index 62536c892..f7169c355 100644 --- a/lzero/model/unizero_world_models/transformer.py +++ b/lzero/model/unizero_world_models/transformer.py @@ -13,6 +13,8 @@ from torch.nn import functional as F from .kv_caching import KeysValues +from .moe import MoeLayer, MultiplicationFeedForward +from line_profiler import line_profiler @dataclass @@ -67,8 +69,13 @@ def generate_empty_keys_values(self, n: int, max_tokens: int) -> KeysValues: - 
KeysValues: An object containing empty keys and values. """ device = self.ln_f.weight.device # Assumption: All submodules are on the same device - return KeysValues(n, self.config.num_heads, max_tokens, self.config.embed_dim, self.config.num_layers, device) + if self.config.task_embed_option == "concat_task_embed": + return KeysValues(n, self.config.num_heads, max_tokens, self.config.embed_dim, self.config.num_layers, device) + else: + return KeysValues(n, self.config.num_heads, max_tokens, self.config.embed_dim, self.config.num_layers, device) + + #@profile def forward(self, sequences: torch.Tensor, past_keys_values: Optional[KeysValues] = None, valid_context_lengths: Optional[torch.Tensor] = None) -> torch.Tensor: """ @@ -91,6 +98,8 @@ def forward(self, sequences: torch.Tensor, past_keys_values: Optional[KeysValues return x + + class Block(nn.Module): """ Transformer block class. @@ -121,12 +130,48 @@ def __init__(self, config: TransformerConfig) -> None: self.ln1 = nn.LayerNorm(config.embed_dim) self.ln2 = nn.LayerNorm(config.embed_dim) self.attn = SelfAttention(config) - self.mlp = nn.Sequential( - nn.Linear(config.embed_dim, 4 * config.embed_dim), - nn.GELU(approximate='tanh'), - nn.Linear(4 * config.embed_dim, config.embed_dim), - nn.Dropout(config.resid_pdrop), - ) + if config.moe_in_transformer: + # 创Create multiple independent MLP instances + self.experts = nn.ModuleList([ + nn.Sequential( + nn.Linear(config.embed_dim, 4 * config.embed_dim), + nn.GELU(approximate='tanh'), + nn.Linear(4 * config.embed_dim, config.embed_dim), + nn.Dropout(config.resid_pdrop), + ) for _ in range(config.num_experts_of_moe_in_transformer) + ]) + + self.feed_forward = MoeLayer( + experts=self.experts, + gate=nn.Linear(config.embed_dim, config.num_experts_of_moe_in_transformer, bias=False), + num_experts_per_tok=1, + ) + + print("="*20) + print(f'use moe in feed_forward of transformer, num of expert: {config.num_experts_of_moe_in_transformer}') + print("="*20) + elif config.multiplication_moe_in_transformer: + # Create multiple FeedForward instances for multiplication-based MoE + self.experts = nn.ModuleList([ + MultiplicationFeedForward(config) for _ in range(config.num_experts_of_moe_in_transformer) + ]) + + self.feed_forward = MoeLayer( + experts=self.experts, + gate=nn.Linear(config.embed_dim, config.num_experts_of_moe_in_transformer, bias=False), + num_experts_per_tok=1, + ) + + print("="*20) + print(f'use multiplication moe in feed_forward of transformer, num of expert: {config.num_experts_of_moe_in_transformer}') + print("="*20) + else: + self.feed_forward = nn.Sequential( + nn.Linear(config.embed_dim, 4 * config.embed_dim), + nn.GELU(approximate='tanh'), + nn.Linear(4 * config.embed_dim, config.embed_dim), + nn.Dropout(config.resid_pdrop), + ) def forward(self, x: torch.Tensor, past_keys_values: Optional[KeysValues] = None, valid_context_lengths: Optional[torch.Tensor] = None) -> torch.Tensor: @@ -144,10 +189,10 @@ def forward(self, x: torch.Tensor, past_keys_values: Optional[KeysValues] = None x_attn = self.attn(self.ln1(x), past_keys_values, valid_context_lengths) if self.gru_gating: x = self.gate1(x, x_attn) - x = self.gate2(x, self.mlp(self.ln2(x))) + x = self.gate2(x, self.feed_forward(self.ln2(x))) else: x = x + x_attn - x = x + self.mlp(self.ln2(x)) + x = x + self.feed_forward(self.ln2(x)) return x @@ -188,6 +233,7 @@ def __init__(self, config: TransformerConfig) -> None: causal_mask = torch.tril(torch.ones(config.max_tokens, config.max_tokens)) self.register_buffer('mask', causal_mask) + 
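The `Block` changes above switch the feed-forward path between a plain GELU MLP, a `MoeLayer` over several such MLP experts (`config.moe_in_transformer`), and a `MoeLayer` over `MultiplicationFeedForward` (SwiGLU-style) experts (`config.multiplication_moe_in_transformer`), in every case with top-1 routing. The standalone sketch below mirrors that routing so the dispatch logic can be sanity-checked in isolation; `TinyMoeLayer` and the toy sizes are illustrative names and values, not part of this patch.

```python
# Illustrative sketch of the top-1 token routing used by MoeLayer; standalone and
# independent of lzero. TinyMoeLayer and the toy dimensions are assumptions.
import torch
import torch.nn.functional as F
from torch import nn


class TinyMoeLayer(nn.Module):
    """Route each token to its top-k experts and sum the gate-weighted outputs."""

    def __init__(self, experts, gate, num_experts_per_tok=1):
        super().__init__()
        self.experts = nn.ModuleList(experts)
        self.gate = gate
        self.num_experts_per_tok = num_experts_per_tok

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: (B, T, C)
        gate_logits = self.gate(x)                                              # (B, T, num_experts)
        weights, selected = torch.topk(gate_logits, self.num_experts_per_tok)   # both (B, T, k)
        weights = F.softmax(weights, dim=-1, dtype=torch.float).to(x.dtype)
        out = torch.zeros_like(x)
        for i, expert in enumerate(self.experts):
            # Indices of the (batch, token) positions routed to expert i.
            b_idx, t_idx, k_idx = torch.where(selected == i)
            out[b_idx, t_idx] += weights[b_idx, t_idx, k_idx][:, None] * expert(x[b_idx, t_idx])
        return out


embed_dim, num_experts = 8, 4
experts = [
    nn.Sequential(
        nn.Linear(embed_dim, 4 * embed_dim),
        nn.GELU(approximate='tanh'),
        nn.Linear(4 * embed_dim, embed_dim),
    )
    for _ in range(num_experts)
]
moe = TinyMoeLayer(experts, gate=nn.Linear(embed_dim, num_experts, bias=False), num_experts_per_tok=1)
y = moe(torch.randn(2, 5, embed_dim))
print(y.shape)  # torch.Size([2, 5, 8]), same shape as the input
```

With `num_experts_per_tok=1` the softmax over the selected logits is a no-op, so each token is processed by exactly one expert and the accumulated update touches every (batch, token) index exactly once.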
#@profile def forward(self, x: torch.Tensor, kv_cache: Optional[KeysValues] = None, valid_context_lengths: Optional[torch.Tensor] = None) -> torch.Tensor: """ @@ -205,7 +251,10 @@ def forward(self, x: torch.Tensor, kv_cache: Optional[KeysValues] = None, B, T, C = x.size() if kv_cache is not None: b, nh, L, c = kv_cache.shape - assert nh == self.num_heads and b == B and c * nh == C, "Cache dimensions do not match input dimensions." + try: + assert nh == self.num_heads and b == B and c * nh == C, "Cache dimensions do not match input dimensions." + except Exception as e: + print('debug') else: L = 0 diff --git a/lzero/model/unizero_world_models/utils.py b/lzero/model/unizero_world_models/utils.py index 99c841cbe..0a0c9dd51 100644 --- a/lzero/model/unizero_world_models/utils.py +++ b/lzero/model/unizero_world_models/utils.py @@ -215,8 +215,14 @@ def init_weights(module, norm_type='BN'): module.bias.data.zero_() elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): print(f"Init {module} using zero bias, 1 weight") - module.bias.data.zero_() - module.weight.data.fill_(1.0) + try: + module.bias.data.zero_() + except Exception as e: + print(e) + try: + module.weight.data.fill_(1.0) + except Exception as e: + print(e) elif isinstance(module, nn.BatchNorm2d): print(f"Init nn.BatchNorm2d using zero bias, 1 weight") module.weight.data.fill_(1.0) @@ -294,7 +300,7 @@ def __init__(self, latent_recon_loss_weight=0, perceptual_loss_weight=0, continu self.loss_total += self.perceptual_loss_weight * v self.intermediate_losses = { - k: v if isinstance(v, dict) or isinstance(v, torch.Tensor) else (v if isinstance(v, float) else v.item()) + k: v if isinstance(v, dict) or isinstance(v, np.ndarray) or isinstance(v, torch.Tensor) else (v if isinstance(v, float) else v.item()) for k, v in kwargs.items() } diff --git a/lzero/model/unizero_world_models/world_model.py b/lzero/model/unizero_world_models/world_model.py index 37d4cd3ec..f98bc0033 100644 --- a/lzero/model/unizero_world_models/world_model.py +++ b/lzero/model/unizero_world_models/world_model.py @@ -45,8 +45,10 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: super().__init__() self.tokenizer = tokenizer self.config = config - self.transformer = Transformer(self.config) + self.task_embed_option = self.config.task_embed_option # Strategy for task embeddings + self.transformer = Transformer(self.config) + self.task_num = 1 if self.config.device == 'cpu': self.device = torch.device('cpu') else: @@ -82,7 +84,7 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: # Head modules self.head_rewards = self._create_head(self.act_tokens_pattern, self.support_size) - self.head_observations = self._create_head(self.all_but_last_latent_state_pattern, self.obs_per_embdding_dim, + self.head_observations = self._create_head(self.all_but_last_latent_state_pattern, self.config.embed_dim, self.sim_norm) # NOTE: we add a sim_norm to the head for observations if self.continuous_action_space: self.sigma_type = self.config.sigma_type @@ -259,7 +261,6 @@ def _initialize_config_parameters(self) -> None: self.max_cache_size = self.config.max_cache_size self.env_num = self.config.env_num self.num_layers = self.config.num_layers - self.obs_per_embdding_dim = self.config.embed_dim self.sim_norm = SimNorm(simnorm_dim=self.group_size) def _initialize_patterns(self) -> None: @@ -335,7 +336,13 @@ def _initialize_projection_input_dim(self) -> None: if self.num_observations_tokens == 16: self.projection_input_dim = 128 elif self.num_observations_tokens == 
1: - self.projection_input_dim = self.obs_per_embdding_dim + # self.projection_input_dim = self.config.embed_dim + if self.task_embed_option == "concat_task_embed": + self.projection_input_dim = self.config.embed_dim - 96 + elif self.task_embed_option == "register_task_embed": + self.projection_input_dim = self.config.embed_dim + elif self.task_embed_option == "add_task_embed": + self.projection_input_dim = self.config.embed_dim def _initialize_statistics(self) -> None: """Initialize counters for hit count and query count statistics.""" @@ -392,6 +399,7 @@ def precompute_pos_emb_diff_kv(self): self.pos_emb_diff_k.append(layer_pos_emb_diff_k) self.pos_emb_diff_v.append(layer_pos_emb_diff_v) + #@profile def _get_positional_embedding(self, layer, attn_type) -> torch.Tensor: """ Helper function to get positional embedding for a given layer and attention type. @@ -413,6 +421,7 @@ def _get_positional_embedding(self, layer, attn_type) -> torch.Tensor: 1, self.config.max_tokens, self.num_heads, self.embed_dim // self.num_heads ).transpose(1, 2).detach() + #@profile def forward(self, obs_embeddings_or_act_tokens: Dict[str, Union[torch.Tensor, tuple]], past_keys_values: Optional[torch.Tensor] = None, kvcache_independent: bool = False, is_init_infer: bool = True, @@ -484,6 +493,7 @@ def forward(self, obs_embeddings_or_act_tokens: Dict[str, Union[torch.Tensor, tu # logits_ends is None return WorldModelOutput(x, logits_observations, logits_rewards, None, logits_policy, logits_value) + #@profile def _add_position_embeddings(self, embeddings, prev_steps, num_steps, kvcache_independent, is_init_infer, valid_context_lengths): """ @@ -512,6 +522,7 @@ def _add_position_embeddings(self, embeddings, prev_steps, num_steps, kvcache_in valid_context_lengths + torch.arange(num_steps, device=self.device)).unsqueeze(1) return embeddings + position_embeddings + #@profile def _process_obs_act_combined_cont(self, obs_embeddings_or_act_tokens, prev_steps): """ Process combined observation embeddings and action tokens. @@ -548,6 +559,7 @@ def _process_obs_act_combined_cont(self, obs_embeddings_or_act_tokens, prev_step return obs_act_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=self.device)), num_steps + #@profile def _process_obs_act_combined(self, obs_embeddings_or_act_tokens, prev_steps): """ Process combined observation embeddings and action tokens. @@ -577,6 +589,7 @@ def _process_obs_act_combined(self, obs_embeddings_or_act_tokens, prev_steps): return obs_act_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=self.device)), num_steps + #@profile def _transformer_pass(self, sequences, past_keys_values, kvcache_independent, valid_context_lengths): """ Pass sequences through the transformer. @@ -597,6 +610,7 @@ def _transformer_pass(self, sequences, past_keys_values, kvcache_independent, va else: return self.transformer(sequences, past_keys_values, valid_context_lengths=valid_context_lengths) + #@profile @torch.no_grad() def reset_for_initial_inference(self, obs_act_dict: torch.FloatTensor) -> torch.FloatTensor: """ @@ -631,6 +645,7 @@ def reset_for_initial_inference(self, obs_act_dict: torch.FloatTensor) -> torch. 
return outputs_wm, self.latent_state + #@profile @torch.no_grad() def wm_forward_for_initial_infererence(self, last_obs_embeddings: torch.LongTensor, batch_action=None, @@ -724,7 +739,7 @@ def wm_forward_for_initial_infererence(self, last_obs_embeddings: torch.LongTens # ================ calculate the target value in Train phase ================ # [192, 16, 64] -> [32, 6, 16, 64] last_obs_embeddings = last_obs_embeddings.contiguous().view(batch_action.shape[0], -1, num_observations_tokens, - self.obs_per_embdding_dim) # (BL, K) for unroll_step=1 + self.config.embed_dim) # (BL, K) for unroll_step=1 last_obs_embeddings = last_obs_embeddings[:, :-1, :] batch_action = torch.from_numpy(batch_action).to(last_obs_embeddings.device) @@ -754,6 +769,7 @@ def wm_forward_for_initial_infererence(self, last_obs_embeddings: torch.LongTens return outputs_wm + #@profile @torch.no_grad() def forward_initial_inference(self, obs_act_dict): """ @@ -771,6 +787,7 @@ def forward_initial_inference(self, obs_act_dict): return (outputs_wm.output_sequence, latent_state, outputs_wm.logits_rewards, outputs_wm.logits_policy, outputs_wm.logits_value) + #@profile @torch.no_grad() def forward_recurrent_inference(self, state_action_history, simulation_index=0, latent_state_index_in_search_path=[]): @@ -856,6 +873,7 @@ def forward_recurrent_inference(self, state_action_history, simulation_index=0, return (outputs_wm.output_sequence, self.latent_state, reward, outputs_wm.logits_policy, outputs_wm.logits_value) + #@profile def trim_and_pad_kv_cache(self, is_init_infer=True) -> list: """ Adjusts the key-value cache for each environment to ensure they all have the same size. @@ -908,6 +926,7 @@ def trim_and_pad_kv_cache(self, is_init_infer=True) -> list: return self.keys_values_wm_size_list + #@profile def update_cache_context(self, latent_state, is_init_infer=True, simulation_index=0, latent_state_index_in_search_path=[], valid_context_lengths=None): """ @@ -1049,6 +1068,7 @@ def update_cache_context(self, latent_state, is_init_infer=True, simulation_inde self.past_kv_cache_recurrent_infer[cache_key] = cache_index + #@profile def retrieve_or_generate_kvcache(self, latent_state: list, ready_env_num: int, simulation_index: int = 0) -> list: """ @@ -1450,7 +1470,7 @@ def _calculate_policy_loss_cont_simple(self, outputs, batch: dict): return policy_loss, policy_entropy_loss, target_policy_entropy, target_sampled_actions, mu, sigma - def _calculate_policy_loss_cont(self, outputs, batch: dict) -> Tuple[torch.Tensor, torch.Tensor, float, torch.Tensor, torch.Tensor, torch.Tensor]: + def _calculate_policy_loss_cont(self, outputs, batch: dict, task_id=None) -> Tuple[torch.Tensor, torch.Tensor, float, torch.Tensor, torch.Tensor, torch.Tensor]: """ Calculate the policy loss for continuous actions. @@ -1465,9 +1485,12 @@ def _calculate_policy_loss_cont(self, outputs, batch: dict) -> Tuple[torch.Tenso - mu (:obj:`torch.Tensor`): The mean of the normal distribution. - sigma (:obj:`torch.Tensor`): The standard deviation of the normal distribution. 
""" - batch_size, num_unroll_steps, action_space_size = outputs.logits_policy.shape[ + if task_id is None: + batch_size, num_unroll_steps, action_space_size = outputs.logits_policy.shape[ 0], self.config.num_unroll_steps, self.config.action_space_size - + else: + batch_size, num_unroll_steps, action_space_size = outputs.logits_policy.shape[ + 0], self.config.num_unroll_steps, self.config.action_space_size_list[task_id] policy_logits_all = outputs.logits_policy mask_batch = batch['mask_padding'] child_sampled_actions_batch = batch['child_sampled_actions'] @@ -1509,6 +1532,8 @@ def _calculate_policy_loss_cont(self, outputs, batch: dict) -> Tuple[torch.Tenso # KL as projector target_log_prob_sampled_actions = torch.log(target_normalized_visit_count + 1e-6) + + # KL as projector policy_loss = -torch.sum( torch.exp(target_log_prob_sampled_actions.detach()) * log_prob_sampled_actions, 1 ) * mask_batch @@ -1558,6 +1583,7 @@ def compute_cross_entropy_loss(self, outputs, labels, batch, element='rewards'): return loss + #@profile def compute_policy_entropy_loss(self, logits, mask): # Compute entropy of the policy probs = torch.softmax(logits, dim=1) @@ -1567,6 +1593,7 @@ def compute_policy_entropy_loss(self, logits, mask): entropy_loss = (entropy * mask) return entropy_loss + #@profile def compute_labels_world_model(self, obs_embeddings: torch.Tensor, rewards: torch.Tensor, ends: torch.Tensor, mask_padding: torch.BoolTensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # assert torch.all(ends.sum(dim=1) <= 1) # Each sequence sample should have at most one 'done' flag @@ -1586,6 +1613,8 @@ def compute_labels_world_model(self, obs_embeddings: torch.Tensor, rewards: torc return labels_observations, labels_rewards.view(-1, self.support_size), None + + #@profile def compute_labels_world_model_value_policy(self, target_value: torch.Tensor, target_policy: torch.Tensor, mask_padding: torch.BoolTensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Compute labels for value and policy predictions. 
""" diff --git a/lzero/model/unizero_world_models/world_model_multitask.py b/lzero/model/unizero_world_models/world_model_multitask.py new file mode 100644 index 000000000..c6b0c0e13 --- /dev/null +++ b/lzero/model/unizero_world_models/world_model_multitask.py @@ -0,0 +1,1961 @@ +import collections +import logging +from typing import Any, Tuple +from typing import Optional +from typing import Union, Dict + +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +from lzero.model.common import SimNorm +from lzero.model.unizero_world_models.world_model import WorldModel +from lzero.model.utils import cal_dormant_ratio +from .moe import MoeLayer, MultiplicationFeedForward +from .slicer import Head +from .tokenizer import Tokenizer +from .transformer import Transformer, TransformerConfig +from .utils import LossWithIntermediateLosses, init_weights +from .utils import WorldModelOutput, hash_state + +logging.getLogger().setLevel(logging.DEBUG) +from ding.utils import get_rank +import torch.distributed as dist +from sklearn.manifold import TSNE +import os +import numpy as np +import matplotlib.pyplot as plt +from matplotlib.patches import Patch +from matplotlib.offsetbox import OffsetImage, AnnotationBbox +import torch + + +class WorldModelMT(WorldModel): + """ + Overview: + The WorldModel class is responsible for the scalable latent world model of UniZero (https://arxiv.org/abs/2406.10667), + which is used to predict the next latent state, rewards, policy, and value based on the current latent state and action. + The world model consists of three main components: + - a tokenizer, which encodes observations into embeddings, + - a transformer, which processes the input sequences, + - and heads, which generate the logits for observations, rewards, policy, and value. + """ + + #@profile + def __init__(self, config: TransformerConfig, tokenizer) -> None: + """ + Overview: + Initialize the WorldModel class. + Arguments: + - config (:obj:`TransformerConfig`): The configuration for the transformer. + - tokenizer (:obj:`Tokenizer`): The tokenizer. + + - task_embed_option (str): Strategy for incorporating task embeddings. Options: + - "add_task_embed": Adds task embeddings to observation embeddings (default). + - "concat_task_embed": Concatenates task embeddings with observation embeddings. + - "register_task_embed": Uses task embeddings as additional input tokens. 
+        """
+        super().__init__(config, tokenizer)
+        self.tokenizer = tokenizer
+        self.config = config
+
+        if self.config.device == 'cpu':
+            self.device = torch.device('cpu')
+        else:
+            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        # Move all modules to the specified device
+        print(f"self.device: {self.device}")
+        # Position embedding
+        self.pos_emb = nn.Embedding(config.max_tokens, self.config.embed_dim, device=self.device)
+        self.precompute_pos_emb_diff_kv()
+        print(f"self.pos_emb.weight.device: {self.pos_emb.weight.device}")
+
+        # Task embedding setup
+        self.use_task_embed = config.use_task_embed
+        self.task_embed_option = self.config.task_embed_option  # Strategy for task embeddings
+        self.task_num = config.task_num
+        self.task_embed_dim = config.task_embed_dim if hasattr(config, "task_embed_dim") else 96
+        self.register_token_length = config.register_token_length if hasattr(config, "register_token_length") else 4
+
+        self.sim_norm = SimNorm(simnorm_dim=self.group_size)
+        if self.task_embed_option == "concat_task_embed":
+            # TODO: currently, under "concat_task_embed", self.pos_emb needs to be fixed to zero.
+            self.task_emb = nn.Embedding(self.task_num, self.task_embed_dim, max_norm=1)  # TODO
+            # self.task_emb.weight = self.sim_norm(self.task_emb.weight)
+
+            self.obs_act_embed_dim = config.embed_dim - 96
+        elif self.task_embed_option == "register_task_embed":
+            self.task_emb = nn.Embedding(self.task_num, config.embed_dim, max_norm=1)  # TODO
+            self.obs_act_embed_dim = config.embed_dim
+            # self.register_task_tokens = nn.Parameter(
+            #     torch.zeros(self.register_token_length, config.embed_dim)
+            # )
+            # nn.init.normal_(self.register_task_tokens, mean=0.0, std=0.02)
+        elif self.task_embed_option == "add_task_embed":
+            self.task_emb = nn.Embedding(self.task_num, config.embed_dim, max_norm=1)  # TODO
+            self.obs_act_embed_dim = config.embed_dim
+
+        self.transformer = Transformer(self.config)
+
+        # TODO ========
+        self.analysis_tsne = self.config.get('analysis_tsne', False)
+
+        if self.analysis_tsne:
+            self.env_id_list = self.config.env_id_list
+            # Automatically build self.env_short_names
+            self.env_short_names = {}
+
+            # Iterate over env_id_list and extract a short name for each env
+            for env_id in self.config.env_id_list:
+                # Use the part before 'NoFrameskip-v4' as the short name
+                short_name = env_id.replace('NoFrameskip-v4', '')
+                self.env_short_names[env_id] = short_name
+            # Map env IDs to short names
+            # self.env_short_names = {
+            #     'PongNoFrameskip-v4': 'Pong',
+            #     'MsPacmanNoFrameskip-v4': 'MsPacman',
+            #     'SeaquestNoFrameskip-v4': 'Seaquest',
+            #     'BoxingNoFrameskip-v4': 'Boxing',
+            #     'AlienNoFrameskip-v4': 'Alien',
+            #     'ChopperCommandNoFrameskip-v4': 'Chopper',
+            #     'HeroNoFrameskip-v4': 'Hero',
+            #     'RoadRunnerNoFrameskip-v4': 'RoadRunner'
+            # }
+            # Color mapping, so that each task keeps a fixed color
+            self.num_tasks = len(self.env_id_list)
+
+            # Generate enough distinct colors
+            self.colors = self._generate_colors(len(self.env_id_list))
+
+        self.head_policy_multi_task = nn.ModuleList()
+        self.head_value_multi_task = nn.ModuleList()
+        self.head_rewards_multi_task = nn.ModuleList()
+        self.head_observations_multi_task = nn.ModuleList()
+
+        self.num_experts_in_moe_head = config.num_experts_in_moe_head
+        self.use_normal_head = config.use_normal_head
+        self.use_moe_head = config.use_moe_head
+        self.use_softmoe_head = config.use_softmoe_head
+
+        self.to(self.device)
+
+        # Initialize configuration parameters
+        self._initialize_config_parameters()
+
+        # Initialize patterns for block masks
+        self._initialize_patterns()
+
+        self.hidden_size = config.embed_dim // config.num_heads
+
+        self.continuous_action_space = self.config.continuous_action_space
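Because the three `task_embed_option` strategies imply different widths for the per-token observation/action embeddings (`embed_dim - task_embed_dim` for "concat_task_embed", `embed_dim` for "add_task_embed" and "register_task_embed"), a minimal sketch of the concat path follows. It is illustrative only, assumes the default `task_embed_dim=96`, and the helper name `encode_token` does not exist in this codebase.

```python
# Minimal sketch of the "concat_task_embed" dimension bookkeeping set up above.
# The sizes and the encode_token helper are assumptions for illustration.
import torch
from torch import nn

embed_dim, task_embed_dim, task_num = 768, 96, 8
obs_act_embed_dim = embed_dim - task_embed_dim          # 672: width the encoder / action table must produce
task_emb = nn.Embedding(task_num, task_embed_dim, max_norm=1)


def encode_token(obs_embedding: torch.Tensor, task_id: int) -> torch.Tensor:
    """Concatenate a (B, obs_act_embed_dim) embedding with its task embedding -> (B, embed_dim)."""
    batch = obs_embedding.shape[0]
    task = task_emb(torch.full((batch,), task_id, dtype=torch.long))  # (B, task_embed_dim)
    return torch.cat([obs_embedding, task], dim=-1)                   # (B, embed_dim)


token = encode_token(torch.randn(4, obs_act_embed_dim), task_id=3)
print(token.shape)  # torch.Size([4, 768])
```

Under "add_task_embed" the task embedding must instead match `embed_dim` so it can be summed with the observation embedding, while "register_task_embed" appends it as extra input tokens rather than widening them.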
+ + # Initialize action embedding table + if self.continuous_action_space: + # TODO: check the effect of SimNorm + # self.act_embedding_table = nn.Sequential( + # nn.Linear(config.action_space_size, config.embed_dim, device=self.device, bias=False), + # SimNorm(simnorm_dim=self.group_size)) + # print(f'config.action_space_size_list:{config.action_space_size_list}') + self.act_embedding_table = nn.ModuleList([ + nn.Sequential( + nn.Linear(config.action_space_size_list[task_id], self.obs_act_embed_dim, device=self.device, bias=False), + SimNorm(simnorm_dim=self.group_size) + ) + for task_id in range(self.task_num) + ]) + else: + # for discrete action space + self.act_embedding_table = nn.Embedding(config.action_space_size, self.obs_act_embed_dim, device=self.device) + print(f"self.act_embedding_table.weight.device: {self.act_embedding_table.weight.device}") + + # if self.num_experts_in_moe_head == -1: + assert self.num_experts_in_moe_head > 0 + if self.use_normal_head: + print('We use normal head') + # TODO: Normal Head + for task_id in range(self.task_num): + if self.continuous_action_space: + # TODO + self.sigma_type = self.config.sigma_type + self.bound_type = self.config.bound_type + self.head_policy = self._create_head_cont(self.value_policy_tokens_pattern, self.config.action_space_size_list[task_id]) # TODO + else: + self.head_policy = self._create_head(self.value_policy_tokens_pattern, self.action_space_size) + + self.head_policy_multi_task.append(self.head_policy) + + self.head_value = self._create_head(self.value_policy_tokens_pattern, self.support_size) + self.head_value_multi_task.append(self.head_value) + + self.head_rewards = self._create_head(self.act_tokens_pattern, self.support_size) + self.head_rewards_multi_task.append(self.head_rewards) + + self.head_observations = self._create_head(self.all_but_last_latent_state_pattern, + self.config.embed_dim, + self.sim_norm) # NOTE: we add a sim_norm to the head for observations + self.head_observations_multi_task.append(self.head_observations) + elif self.use_softmoe_head: + print(f'We use softmoe head, self.num_experts_in_moe_head is {self.num_experts_in_moe_head}') + # Dictionary to store SoftMoE instances + self.soft_moe_instances = {} + + # Create softmoe head modules + self.create_head_modules_softmoe() + + self.head_policy_multi_task.append(self.head_policy) + self.head_value_multi_task.append(self.head_value) + self.head_rewards_multi_task.append(self.head_rewards) + self.head_observations_multi_task.append(self.head_observations) + elif self.use_moe_head: + print(f'We use moe head, self.num_experts_in_moe_head is {self.num_experts_in_moe_head}') + # Dictionary to store moe instances + self.moe_instances = {} + + # Create moe head modules + self.create_head_modules_moe() + + self.head_policy_multi_task.append(self.head_policy) + self.head_value_multi_task.append(self.head_value) + self.head_rewards_multi_task.append(self.head_rewards) + self.head_observations_multi_task.append(self.head_observations) + + + # Apply weight initialization, the order is important + self.apply(lambda module: init_weights(module, norm_type=self.config.norm_type)) + self._initialize_last_layer() + + # Cache structures + self._initialize_cache_structures() + + # Projection input dimension + self._initialize_projection_input_dim() + + # Hit count and query count statistics + self._initialize_statistics() + + # Initialize keys and values for transformer + self._initialize_transformer_keys_values() + + self.latent_recon_loss = torch.tensor(0., 
device=self.device) + self.perceptual_loss = torch.tensor(0., device=self.device) + + # TODO: check the size of the shared pool + # for self.kv_cache_recurrent_infer + # If needed, recurrent_infer should store the results of the one MCTS search. + self.shared_pool_size = int(50*self.env_num) + self.shared_pool_recur_infer = [None] * self.shared_pool_size + self.shared_pool_index = 0 + + # for self.kv_cache_init_infer + # In contrast, init_infer only needs to retain the results of the most recent step. + # self.shared_pool_size_init = int(2*self.env_num) + self.shared_pool_size_init = int(2) # NOTE: Will having too many cause incorrect retrieval of the kv cache? + self.shared_pool_init_infer = [[None] * self.shared_pool_size_init for _ in range(self.env_num)] + self.shared_pool_index_init_envs = [0 for _ in range(self.env_num)] + + # for self.kv_cache_wm + self.shared_pool_size_wm = int(self.env_num) + self.shared_pool_wm = [None] * self.shared_pool_size_wm + self.shared_pool_index_wm = 0 + + self.reanalyze_phase = False + self._rank = get_rank() + + def _generate_colors(self, num_colors): + """ + 生成足够多的独特颜色,适用于大量分类。 + + 参数: + - num_colors: 所需颜色数量。 + + 返回: + - colors: 颜色列表。 + """ + # 使用多个matplotlib离散色图拼接 + color_maps = ['tab20', 'tab20b', 'tab20c'] + colors = [] + for cmap_name in color_maps: + cmap = plt.get_cmap(cmap_name) + colors.extend([cmap(i) for i in range(cmap.N)]) + if len(colors) >= num_colors: + break + if len(colors) < num_colors: + # 生成额外的颜色,如果需要 + additional_colors = plt.cm.get_cmap('hsv', num_colors - len(colors)) + colors.extend([additional_colors(i) for i in range(num_colors - len(colors))]) + return colors[:num_colors] + + def _initialize_config_parameters(self) -> None: + """Initialize configuration parameters.""" + self.policy_entropy_weight = self.config.policy_entropy_weight + self.predict_latent_loss_type = self.config.predict_latent_loss_type + self.group_size = self.config.group_size + self.num_groups = self.config.embed_dim // self.group_size + self.obs_type = self.config.obs_type + self.embed_dim = self.config.embed_dim + self.num_heads = self.config.num_heads + self.gamma = self.config.gamma + self.context_length = self.config.context_length + self.dormant_threshold = self.config.dormant_threshold + self.analysis_dormant_ratio = self.config.analysis_dormant_ratio + self.num_observations_tokens = self.config.tokens_per_block - 1 + self.latent_recon_loss_weight = self.config.latent_recon_loss_weight + self.perceptual_loss_weight = self.config.perceptual_loss_weight + self.device = self.config.device + self.support_size = self.config.support_size + self.action_space_size = self.config.action_space_size + self.max_cache_size = self.config.max_cache_size + self.env_num = self.config.env_num + self.num_layers = self.config.num_layers + self.sim_norm = SimNorm(simnorm_dim=self.group_size) + + def _initialize_patterns(self) -> None: + """Initialize patterns for block masks.""" + self.all_but_last_latent_state_pattern = torch.ones(self.config.tokens_per_block) + self.all_but_last_latent_state_pattern[-2] = 0 + self.act_tokens_pattern = torch.zeros(self.config.tokens_per_block) + self.act_tokens_pattern[-1] = 1 + self.value_policy_tokens_pattern = torch.zeros(self.config.tokens_per_block) + self.value_policy_tokens_pattern[-2] = 1 + + def _create_head(self, block_mask: torch.Tensor, output_dim: int, norm_layer=None) -> Head: + """Create head modules for the transformer.""" + modules = [ + nn.Linear(self.config.embed_dim, self.config.embed_dim), + 
nn.GELU(approximate='tanh'), + nn.Linear(self.config.embed_dim, output_dim) + ] + if norm_layer: + modules.append(norm_layer) + return Head( + max_blocks=self.config.max_blocks, + block_mask=block_mask, + head_module=nn.Sequential(*modules) + ) + + def _create_head_moe(self, block_mask: torch.Tensor, output_dim: int, norm_layer=None, moe=None) -> Head: + """Create moe head modules for the transformer.""" + modules = [ + moe, + nn.Linear(self.config.embed_dim, output_dim) + ] + if norm_layer: + modules.append(norm_layer) + return Head( + max_blocks=self.config.max_blocks, + block_mask=block_mask, + head_module=nn.Sequential(*modules) + ) + def get_moe(self, name): + """Get or create a MoE instance""" + if name not in self.moe_instances: + # Create multiple FeedForward instances for multiplication-based MoE + self.experts = nn.ModuleList([ + MultiplicationFeedForward(self.config) for _ in range(self.config.num_experts_of_moe_in_transformer) + ]) + + self.moe_instances[name] = MoeLayer( + experts=self.experts, + gate=nn.Linear(self.config.embed_dim, self.config.num_experts_of_moe_in_transformer, bias=False), + num_experts_per_tok=1, + ) + + return self.moe_instances[name] + + def create_head_modules_moe(self): + """Create all softmoe head modules""" + # Rewards head + self.head_rewards = self._create_head_moe( + self.act_tokens_pattern, + self.support_size, + moe=self.get_moe("rewards_moe") + ) + + # Observations head + self.head_observations = self._create_head_moe( + self.all_but_last_latent_state_pattern, + self.embdding_dim, + norm_layer=self.sim_norm, # NOTE + moe=self.get_moe("observations_moe") + ) + + # Policy head + self.head_policy = self._create_head_moe( + self.value_policy_tokens_pattern, + self.action_space_size, + moe=self.get_moe("policy_moe") + ) + + # Value head + self.head_value = self._create_head_moe( + self.value_policy_tokens_pattern, + self.support_size, + moe=self.get_moe("value_moe") + ) + + def _create_head_softmoe(self, block_mask: torch.Tensor, output_dim: int, norm_layer=None, soft_moe=None) -> Head: + """Create softmoe head modules for the transformer.""" + modules = [ + soft_moe, + nn.Linear(self.config.embed_dim, output_dim) + ] + if norm_layer: + modules.append(norm_layer) + return Head( + max_blocks=self.config.max_blocks, + block_mask=block_mask, + head_module=nn.Sequential(*modules) + ) + + def get_soft_moe(self, name): + """Get or create a SoftMoE instance""" + # from soft_moe_pytorch import SoftMoE + # if name not in self.soft_moe_instances: + # self.soft_moe_instances[name] = SoftMoE( + # dim=self.embed_dim, + # seq_len=20, # TODO + # num_experts=self.num_experts_in_moe_head, + # ) + from soft_moe_pytorch import DynamicSlotsSoftMoE as SoftMoE + if name not in self.soft_moe_instances: + self.soft_moe_instances[name] = SoftMoE( + dim=self.embed_dim, + num_experts=self.num_experts_in_moe_head, + geglu = True + ) + return self.soft_moe_instances[name] + + def create_head_modules_softmoe(self): + """Create all softmoe head modules""" + # Rewards head + self.head_rewards = self._create_head_softmoe( + self.act_tokens_pattern, + self.support_size, + soft_moe=self.get_soft_moe("rewards_soft_moe") + ) + + # Observations head + self.head_observations = self._create_head_softmoe( + self.all_but_last_latent_state_pattern, + self.config.embed_dim, + norm_layer=self.sim_norm, # NOTE + soft_moe=self.get_soft_moe("observations_soft_moe") + ) + + # Policy head + self.head_policy = self._create_head_softmoe( + self.value_policy_tokens_pattern, + self.action_space_size, + 
soft_moe=self.get_soft_moe("policy_soft_moe") + ) + + # Value head + self.head_value = self._create_head_softmoe( + self.value_policy_tokens_pattern, + self.support_size, + soft_moe=self.get_soft_moe("value_soft_moe") + ) + + def _initialize_last_layer(self) -> None: + """Initialize the last linear layer.""" + last_linear_layer_init_zero = True + print(f'world_model_mt.py:self.task_num:{self.task_num}') + if last_linear_layer_init_zero: + if self.continuous_action_space: + module_to_initialize = [self.head_value, self.head_rewards, self.head_observations] + else: + module_to_initialize = [self.head_policy, self.head_value, self.head_rewards, self.head_observations] + + # TODO: multitask + if self.task_num == 1: + for head in module_to_initialize: + for layer in reversed(head.head_module): + if isinstance(layer, nn.Linear): + nn.init.zeros_(layer.weight) + if layer.bias is not None: + nn.init.zeros_(layer.bias) + break + elif self.task_num > 1: + if self.continuous_action_space: + module_to_initialize = self.head_value_multi_task + self.head_rewards_multi_task + self.head_observations_multi_task + else: + module_to_initialize = self.head_policy_multi_task + self.head_value_multi_task + self.head_rewards_multi_task + self.head_observations_multi_task + + for head in module_to_initialize: + for layer in reversed(head.head_module): + if isinstance(layer, nn.Linear): + nn.init.zeros_(layer.weight) + if layer.bias is not None: + nn.init.zeros_(layer.bias) + break + + def _initialize_cache_structures(self) -> None: + """Initialize cache structures for past keys and values.""" + self.past_kv_cache_recurrent_infer = collections.OrderedDict() + self.past_kv_cache_init_infer = collections.OrderedDict() + self.past_kv_cache_init_infer_envs = [collections.OrderedDict() for _ in range(self.env_num)] + self.keys_values_wm_list = [] + self.keys_values_wm_size_list = [] + + def _initialize_projection_input_dim(self) -> None: + """Initialize the projection input dimension based on the number of observation tokens.""" + if self.num_observations_tokens == 16: + self.projection_input_dim = 128 + elif self.num_observations_tokens == 1: + # self.projection_input_dim = self.config.embed_dim + + if self.task_embed_option == "concat_task_embed": + self.projection_input_dim = self.config.embed_dim - 96 + elif self.task_embed_option == "register_task_embed": + self.projection_input_dim = self.config.embed_dim + elif self.task_embed_option == "add_task_embed": + self.projection_input_dim = self.config.embed_dim + + def _initialize_statistics(self) -> None: + """Initialize counters for hit count and query count statistics.""" + self.hit_count = 0 + self.total_query_count = 0 + self.length_largethan_maxminus5_context_cnt = 0 + self.length_largethan_maxminus7_context_cnt = 0 + self.root_hit_cnt = 0 + self.root_total_query_cnt = 0 + + #@profile + def _initialize_transformer_keys_values(self) -> None: + """Initialize keys and values for the transformer.""" + self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values(n=1, + max_tokens=self.context_length) + self.keys_values_wm = self.transformer.generate_empty_keys_values(n=self.env_num, + max_tokens=self.context_length) + + #@profile + def precompute_pos_emb_diff_kv(self): + """ Precompute positional embedding differences for key and value. 
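+
+        Note (editor's addition): because the positional embedding is added to the token embedding
+        *before* the key/value projections, a cached key can be moved from absolute position ``p_old``
+        to ``p_new`` by adding ``W_k(pos_emb[p_new]) - W_k(pos_emb[p_old])``; assuming the projections
+        are affine (``nn.Linear``), the bias cancels in the difference. Minimal standalone sketch of
+        this identity (toy modules, not attributes of this class):
+
+        >>> import torch, torch.nn as nn
+        >>> W_k = nn.Linear(4, 4)                      # stand-in for one attention key projection
+        >>> pos = nn.Embedding(10, 4)                  # stand-in for self.pos_emb
+        >>> e = torch.randn(4)                         # token embedding without position
+        >>> k_old = W_k(e + pos.weight[5])             # key cached at absolute position 5
+        >>> diff = W_k(pos.weight[2]) - W_k(pos.weight[5])
+        >>> assert torch.allclose(k_old + diff, W_k(e + pos.weight[2]), atol=1e-5)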
""" + if self.context_length <= 2: + # If context length is 2 or less, no context is present + return + + # Precompute positional embedding matrices for inference in collect/eval stages, not for training + self.positional_embedding_k = [ + self._get_positional_embedding(layer, 'key') + for layer in range(self.config.num_layers) + ] + self.positional_embedding_v = [ + self._get_positional_embedding(layer, 'value') + for layer in range(self.config.num_layers) + ] + + # Precompute all possible positional embedding differences + self.pos_emb_diff_k = [] + self.pos_emb_diff_v = [] + + for layer in range(self.config.num_layers): + layer_pos_emb_diff_k = {} + layer_pos_emb_diff_v = {} + + for start in [2]: + for end in [self.context_length - 1]: + original_pos_emb_k = self.positional_embedding_k[layer][:, :, start:end, :] + new_pos_emb_k = self.positional_embedding_k[layer][:, :, :end - start, :] + layer_pos_emb_diff_k[(start, end)] = new_pos_emb_k - original_pos_emb_k + + original_pos_emb_v = self.positional_embedding_v[layer][:, :, start:end, :] + new_pos_emb_v = self.positional_embedding_v[layer][:, :, :end - start, :] + layer_pos_emb_diff_v[(start, end)] = new_pos_emb_v - original_pos_emb_v + + self.pos_emb_diff_k.append(layer_pos_emb_diff_k) + self.pos_emb_diff_v.append(layer_pos_emb_diff_v) + + #@profile + def _get_positional_embedding(self, layer, attn_type) -> torch.Tensor: + """ + Helper function to get positional embedding for a given layer and attention type. + + Arguments: + - layer (:obj:`int`): Layer index. + - attn_type (:obj:`str`): Attention type, either 'key' or 'value'. + + Returns: + - torch.Tensor: The positional embedding tensor. + """ + # TODO: detach() ========== + attn_func = getattr(self.transformer.blocks[layer].attn, attn_type) + if torch.cuda.is_available(): + return attn_func(self.pos_emb.weight).view( + 1, self.config.max_tokens, self.num_heads, self.embed_dim // self.num_heads + ).transpose(1, 2).to(self.device).detach() + else: + return attn_func(self.pos_emb.weight).view( + 1, self.config.max_tokens, self.num_heads, self.embed_dim // self.num_heads + ).transpose(1, 2).detach() + + #@profile + def forward(self, obs_embeddings_or_act_tokens: Dict[str, Union[torch.Tensor, tuple]], + past_keys_values: Optional[torch.Tensor] = None, + kvcache_independent: bool = False, is_init_infer: bool = True, + valid_context_lengths: Optional[torch.Tensor] = None, task_id=0) -> WorldModelOutput: + """ + Forward pass for the model. + + Arguments: + - obs_embeddings_or_act_tokens (:obj:`dict`): Dictionary containing observation embeddings or action tokens. + - past_keys_values (:obj:`Optional[torch.Tensor]`): Previous keys and values for transformer. + - kvcache_independent (:obj:`bool`): Whether to use independent key-value caching. + - is_init_infer (:obj:`bool`): Initialize inference. + - valid_context_lengths (:obj:`Optional[torch.Tensor]`): Valid context lengths. + Returns: + - WorldModelOutput: Model output containing logits for observations, rewards, policy, and value. 
+ """ + if self.use_task_embed: + self.task_embeddings = self.task_emb(torch.tensor(task_id, device=self.device)) # NOTE: TODO + self.task_embeddings = self.sim_norm(self.task_embeddings.view(1,-1)).view(-1) # TODO + else: + self.task_embeddings = torch.zeros(self.config.embed_dim, device=self.device) # ============= TODO: no task_embeddings now ============= + + # Determine previous steps based on key-value caching method + if kvcache_independent: + prev_steps = torch.tensor([0 if past_keys_values is None else past_kv.size for past_kv in past_keys_values], + device=self.device) + else: + prev_steps = 0 if past_keys_values is None else past_keys_values.size + + # Reset valid_context_lengths during initial inference + if is_init_infer: + valid_context_lengths = None + + # Process observation embeddings + if 'obs_embeddings' in obs_embeddings_or_act_tokens: + obs_embeddings = obs_embeddings_or_act_tokens['obs_embeddings'] + if len(obs_embeddings.shape) == 2: + obs_embeddings = obs_embeddings.unsqueeze(1) + + # TODO: multitask + if self.task_embed_option == "add_task_embed": + obs_embeddings = obs_embeddings + self.task_embeddings + elif self.task_embed_option == "concat_task_embed": + # print(f'=='*20) + # print(f'obs_embeddings.shape:{obs_embeddings.shape}') + # print(f'self.task_embeddings.shape:{self.task_embeddings.shape}') + # print(f'=='*20) + if is_init_infer: + # 注意只有在inference时,只有在is_init_infer时拼接task embeddings,recurr_infer中以及含义task embeddings的信息了 + # Expand task embeddings to match the sequence shape + task_emb_expanded = self.task_embeddings.view(1, 1, -1).expand(obs_embeddings.shape[0], obs_embeddings.shape[1], -1) + obs_embeddings = torch.cat([obs_embeddings, task_emb_expanded], dim=-1) + + if is_init_infer: + if self.task_embed_option == "register_task_embed": + # Register task embeddings as input tokens + task_tokens = self.task_embeddings.expand(obs_embeddings.shape[0], self.register_token_length, -1) + obs_embeddings = torch.cat([task_tokens, obs_embeddings], dim=1) + + num_steps = obs_embeddings.size(1) + sequences = self._add_position_embeddings(obs_embeddings, prev_steps, num_steps, kvcache_independent, + is_init_infer, valid_context_lengths) + + + # Process action tokens + elif 'act_tokens' in obs_embeddings_or_act_tokens: + act_tokens = obs_embeddings_or_act_tokens['act_tokens'] + + if self.continuous_action_space: + num_steps = 1 + act_tokens = act_tokens.float() + if len(act_tokens.shape) == 2: + act_tokens = act_tokens.unsqueeze(1) + else: + if len(act_tokens.shape) == 3: + act_tokens = act_tokens.squeeze(1) + num_steps = act_tokens.size(1) + if self.task_num >= 1: + act_embeddings = self.act_embedding_table[task_id](act_tokens) + else: + act_embeddings = self.act_embedding_table(act_tokens) + + if self.task_embed_option == "add_task_embed": + # TODO: 对于action_token不需要增加task_embeddings会造成歧义,反而干扰学习 + # obs_embeddings = obs_embeddings + self.task_embeddings + pass + elif self.task_embed_option == "concat_task_embed": + # print(f'=='*20) + # print(f'act_embeddings.shape:{act_embeddings.shape}') + # print(f'self.task_embeddings.shape:{self.task_embeddings.shape}') + # print(f'=='*20) + # Expand task embeddings to match the sequence shape + task_emb_expanded = self.task_embeddings.view(1, 1, -1).expand(act_embeddings.shape[0], act_embeddings.shape[1], -1) + act_embeddings = torch.cat([act_embeddings, task_emb_expanded], dim=-1) + + + sequences = self._add_position_embeddings(act_embeddings, prev_steps, num_steps, kvcache_independent, + is_init_infer, 
valid_context_lengths) + + # Process combined observation embeddings and action tokens + else: + + # "add_task_embed"在self._process_obs_act_combined_cont方法内部处理 + if self.continuous_action_space: + sequences, num_steps = self._process_obs_act_combined_cont(obs_embeddings_or_act_tokens, prev_steps, task_id=task_id) + else: + sequences, num_steps = self._process_obs_act_combined(obs_embeddings_or_act_tokens, prev_steps) + + + # Pass sequences through transformer + x = self._transformer_pass(sequences, past_keys_values, kvcache_independent, valid_context_lengths) + + # Generate logits + + # 1,...,0,1 https://github.com/eloialonso/iris/issues/19 + # TODO: one head or moe head + if self.use_moe_head: + logits_observations = self.head_observations(x, num_steps=num_steps, prev_steps=prev_steps) + logits_rewards = self.head_rewards(x, num_steps=num_steps, prev_steps=prev_steps) + logits_policy = self.head_policy(x, num_steps=num_steps, prev_steps=prev_steps) + logits_value = self.head_value(x, num_steps=num_steps, prev_steps=prev_steps) + else: + # TODO: in total N head, one head per task + logits_observations = self.head_observations_multi_task[task_id](x, num_steps=num_steps, prev_steps=prev_steps) + logits_rewards = self.head_rewards_multi_task[task_id](x, num_steps=num_steps, prev_steps=prev_steps) + logits_policy = self.head_policy_multi_task[task_id](x, num_steps=num_steps, prev_steps=prev_steps) + logits_value = self.head_value_multi_task[task_id](x, num_steps=num_steps, prev_steps=prev_steps) + + # logits_ends is None + return WorldModelOutput(x, logits_observations, logits_rewards, None, logits_policy, logits_value) + + + #@profile + def _add_position_embeddings(self, embeddings, prev_steps, num_steps, kvcache_independent, is_init_infer, + valid_context_lengths): + """ + Add position embeddings to the input embeddings. + + Arguments: + - embeddings (:obj:`torch.Tensor`): Input embeddings. + - prev_steps (:obj:`torch.Tensor`): Previous steps. + - num_steps (:obj:`int`): Number of steps. + - kvcache_independent (:obj:`bool`): Whether to use independent key-value caching. + - is_init_infer (:obj:`bool`): Initialize inference. + - valid_context_lengths (:obj:`torch.Tensor`): Valid context lengths. + Returns: + - torch.Tensor: Embeddings with position information added. + """ + if kvcache_independent: + steps_indices = prev_steps + torch.arange(num_steps, device=embeddings.device) + position_embeddings = self.pos_emb(steps_indices).view(-1, num_steps, embeddings.shape[-1]) + return embeddings + position_embeddings + else: + if is_init_infer: + return embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=self.device)) + else: + valid_context_lengths = torch.tensor(self.keys_values_wm_size_list_current, device=self.device) + position_embeddings = self.pos_emb( + valid_context_lengths + torch.arange(num_steps, device=self.device)).unsqueeze(1) + return embeddings + position_embeddings + + #@profile + def _process_obs_act_combined_cont(self, obs_embeddings_or_act_tokens, prev_steps, task_id=0): + """ + Process combined observation embeddings and action tokens. + + Arguments: + - obs_embeddings_or_act_tokens (:obj:`dict`): Dictionary containing combined observation embeddings and action tokens. + - prev_steps (:obj:`torch.Tensor`): Previous steps. + Returns: + - torch.Tensor: Combined observation and action embeddings with position information added. 
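+
+        Note (editor's addition): per timestep the K observation tokens are followed by one action
+        token, so the returned sequence has length ``L * (K + 1)``. Ignoring the task-embedding
+        options, the loop below is equivalent to this standalone interleaving (toy shapes):
+
+        >>> import torch
+        >>> B, L, K, E = 2, 3, 4, 8
+        >>> obs = torch.randn(B, L, K, E)
+        >>> act = torch.randn(B, L, 1, E)
+        >>> torch.cat([obs, act], dim=2).view(B, L * (K + 1), E).shape
+        torch.Size([2, 15, 8])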
+ """ + obs_embeddings, act_tokens = obs_embeddings_or_act_tokens['obs_embeddings_and_act_tokens'] + if len(obs_embeddings.shape) == 3: + obs_embeddings = obs_embeddings.view(act_tokens.shape[0], act_tokens.shape[1], self.num_observations_tokens, + -1) + + num_steps = int(obs_embeddings.size(1) * (obs_embeddings.size(2) + 1)) + if self.continuous_action_space: + act_tokens = act_tokens.float() + if len(act_tokens.shape) == 2: # TODO + act_tokens = act_tokens.unsqueeze(-1) + + # B, L, E + act_embeddings = self.act_embedding_table[task_id](act_tokens) + + B, L, K, E = obs_embeddings.size() + + if self.task_embed_option == "concat_task_embed": + # B, L*2, E + obs_act_embeddings = torch.empty(B, L * (K + 1), self.config.embed_dim, device=self.device) + else: + # B, L*2, E + obs_act_embeddings = torch.empty(B, L * (K + 1), self.config.embed_dim, device=self.device) + + + if self.task_embed_option == "concat_task_embed": + # print(f'=='*20) + # print(f'self.task_embeddings.shape:{self.task_embeddings.shape}') + # print(f'=='*20) + # Expand task embeddings to match the sequence shape + task_emb_expanded = self.task_embeddings.view(1, 1, -1).expand(B, 1, -1) + + for i in range(L): + if self.task_embed_option == "add_task_embed": + obs = obs_embeddings[:, i, :, :] + self.task_embeddings # Shape: (B, K, E) TODO: task_embeddings + elif self.task_embed_option == "concat_task_embed": + # print(f'=='*20) + # print(f'obs_embeddings.shape:{obs_embeddings.shape}') + # print(f'=='*20) + obs = torch.cat([obs_embeddings[:, i, :, :], task_emb_expanded], dim=-1) + else: + obs = obs_embeddings[:, i, :, :] # Shape: (B, K, E) + + act = act_embeddings[:, i, :].unsqueeze(1) + if self.task_embed_option == "concat_task_embed": + # print(f'=='*20) + # print(f'act_embeddings.shape:{act_embeddings.shape}') + # print(f'=='*20) + act = torch.cat([act, task_emb_expanded], dim=-1) + + obs_act = torch.cat([obs, act], dim=1) + # print(f'obs_act.shape:{obs_act.shape}') + + obs_act_embeddings[:, i * (K + 1):(i + 1) * (K + 1), :] = obs_act + + if self.task_embed_option == "register_task_embed": + # =====TODO===== + # Register task embeddings as input tokens + task_tokens = self.task_embeddings.expand(B, self.register_token_length, -1) + obs_act_embeddings = torch.cat([task_tokens, obs_act_embeddings], dim=1) + + return obs_act_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=self.device)), num_steps + + + #@profile + def _process_obs_act_combined(self, obs_embeddings_or_act_tokens, prev_steps, task_id=0): + """ + Process combined observation embeddings and action tokens. + + Arguments: + - obs_embeddings_or_act_tokens (:obj:`dict`): Dictionary containing combined observation embeddings and action tokens. + - prev_steps (:obj:`torch.Tensor`): Previous steps. + Returns: + - torch.Tensor: Combined observation and action embeddings with position information added. 
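+
+        Note (editor's addition): the discrete action-embedding table maps integer action tokens of
+        shape (B, L, 1) to embeddings of shape (B, L, 1, E), which is why step ``i``'s action
+        embedding is read as ``act_embeddings[:, i, 0, :]`` below. Toy sketch:
+
+        >>> import torch, torch.nn as nn
+        >>> emb = nn.Embedding(6, 8)                       # 6 discrete actions, embedding dim 8
+        >>> act_tokens = torch.randint(0, 6, (2, 3, 1))    # (B, L, 1)
+        >>> emb(act_tokens).shape
+        torch.Size([2, 3, 1, 8])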
+ """ + obs_embeddings, act_tokens = obs_embeddings_or_act_tokens['obs_embeddings_and_act_tokens'] + if len(obs_embeddings.shape) == 3: + obs_embeddings = obs_embeddings.view(act_tokens.shape[0], act_tokens.shape[1], self.num_observations_tokens, + -1) + + num_steps = int(obs_embeddings.size(1) * (obs_embeddings.size(2) + 1)) + act_embeddings = self.act_embedding_table[task_id](act_tokens) + + B, L, K, E = obs_embeddings.size() + obs_act_embeddings = torch.empty(B, L * (K + 1), E, device=self.device) + + for i in range(L): + # obs = obs_embeddings[:, i, :, :] + obs = obs_embeddings[:, i, :, :] + self.task_embeddings # Shape: (B, K, E) TODO: task_embeddings + act = act_embeddings[:, i, 0, :].unsqueeze(1) + obs_act = torch.cat([obs, act], dim=1) + obs_act_embeddings[:, i * (K + 1):(i + 1) * (K + 1), :] = obs_act + + return obs_act_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=self.device)), num_steps + + #@profile + def _transformer_pass(self, sequences, past_keys_values, kvcache_independent, valid_context_lengths): + """ + Pass sequences through the transformer. + + Arguments: + - sequences (:obj:`torch.Tensor`): Input sequences. + - past_keys_values (:obj:`Optional[torch.Tensor]`): Previous keys and values for transformer. + - kvcache_independent (:obj:`bool`): Whether to use independent key-value caching. + - valid_context_lengths (:obj:`torch.Tensor`): Valid context lengths. + Returns: + - torch.Tensor: Transformer output. + """ + if kvcache_independent: + x = [self.transformer(sequences[k].unsqueeze(0), past_kv, + valid_context_lengths=valid_context_lengths[k].unsqueeze(0)) for k, past_kv in + enumerate(past_keys_values)] + return torch.cat(x, dim=0) + else: + return self.transformer(sequences, past_keys_values, valid_context_lengths=valid_context_lengths) + + #@profile + @torch.no_grad() + def reset_for_initial_inference(self, obs_act_dict: torch.FloatTensor, task_id = 0) -> torch.FloatTensor: + """ + Reset the model state based on initial observations and actions. + + Arguments: + - obs_act_dict (:obj:`torch.FloatTensor`): A dictionary containing 'obs', 'action', and 'current_obs'. + Returns: + - torch.FloatTensor: The outputs from the world model and the latent state. + """ + # Extract observations, actions, and current observations from the dictionary. + if isinstance(obs_act_dict, dict): + batch_obs = obs_act_dict['obs'] + batch_action = obs_act_dict['action'] + batch_current_obs = obs_act_dict['current_obs'] + + # Encode observations to latent embeddings. 
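+        # Editor's note: the expected layout of `obs_act_dict` is (illustrative, not checked here):
+        #   {'obs': stacked_history_obs, 'action': last_actions, 'current_obs': latest_obs_or_None}.
+        # When 'current_obs' is a tensor we are in the collect/eval phase and the per-environment KV
+        # caches are (re)built from it; when it is None we are computing training targets and only
+        # 'obs' is encoded (see the two branches below).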
+ obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch_obs, task_id=task_id) + + if batch_current_obs is not None: + # ================ Collect and Evaluation Phase ================ + # Encode current observations to latent embeddings + current_obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch_current_obs, task_id=task_id) + # print(f"current_obs_embeddings.device: {current_obs_embeddings.device}") + self.latent_state = current_obs_embeddings + outputs_wm = self.wm_forward_for_initial_inference(obs_embeddings, batch_action, current_obs_embeddings, task_id=task_id) + else: + # ================ calculate the target value in Train phase ================ + self.latent_state = obs_embeddings + outputs_wm = self.wm_forward_for_initial_inference(obs_embeddings, batch_action, None, task_id=task_id) + + return outputs_wm, self.latent_state + + + #@profile + @torch.no_grad() + def wm_forward_for_initial_inference(self, last_obs_embeddings: torch.LongTensor, + batch_action=None, + current_obs_embeddings=None, task_id = 0) -> torch.FloatTensor: + """ + Refresh key-value pairs with the initial latent state for inference. + + Arguments: + - latent_state (:obj:`torch.LongTensor`): The latent state embeddings. + - batch_action (optional): Actions taken. + - current_obs_embeddings (optional): Current observation embeddings. + Returns: + - torch.FloatTensor: The outputs from the world model. + """ + n, num_observations_tokens, _ = last_obs_embeddings.shape + if n <= self.env_num and current_obs_embeddings is not None: + # ================ Collect and Evaluation Phase ================ + if current_obs_embeddings is not None: + if self.continuous_action_space: + first_step_flag = not isinstance(batch_action[0], np.ndarray) + else: + first_step_flag = max(batch_action) == -1 + if first_step_flag: + # First step in an episode + self.keys_values_wm = self.transformer.generate_empty_keys_values(n=current_obs_embeddings.shape[0], + max_tokens=self.context_length) + # print(f"current_obs_embeddings.device: {current_obs_embeddings.device}") + outputs_wm = self.forward({'obs_embeddings': current_obs_embeddings}, + past_keys_values=self.keys_values_wm, is_init_infer=True, task_id=task_id) + + # Copy and store keys_values_wm for a single environment + self.update_cache_context(current_obs_embeddings, is_init_infer=True) + else: + # Assume latest_state is the new latent_state, containing information from ready_env_num environments + ready_env_num = current_obs_embeddings.shape[0] + self.keys_values_wm_list = [] + self.keys_values_wm_size_list = [] + for i in range(ready_env_num): + # Retrieve latent state for a single environment + state_single_env = last_obs_embeddings[i] + # Compute hash value using latent state for a single environment + cache_key = hash_state( + state_single_env.view(-1).cpu().numpy()) # last_obs_embeddings[i] is torch.Tensor + + # Retrieve cached value + cache_index = self.past_kv_cache_init_infer_envs[i].get(cache_key) + if cache_index is not None: + matched_value = self.shared_pool_init_infer[i][cache_index] + else: + matched_value = None + + self.root_total_query_cnt += 1 + if matched_value is not None: + # If a matching value is found, add it to the list + self.root_hit_cnt += 1 + # deepcopy is needed because forward modifies matched_value in place + self.keys_values_wm_list.append(self.custom_copy_kv_cache_to_shared_wm(matched_value)) + self.keys_values_wm_size_list.append(matched_value.size) + else: + # Reset using zero values + self.keys_values_wm_single_env = 
self.transformer.generate_empty_keys_values(n=1, max_tokens=self.context_length) + outputs_wm = self.forward({'obs_embeddings': state_single_env.unsqueeze(0)}, + past_keys_values=self.keys_values_wm_single_env, + is_init_infer=True, task_id=task_id) + self.keys_values_wm_list.append(self.keys_values_wm_single_env) + self.keys_values_wm_size_list.append(1) + + # Input self.keys_values_wm_list, output self.keys_values_wm + self.keys_values_wm_size_list_current = self.trim_and_pad_kv_cache(is_init_infer=True) + + batch_action = batch_action[:ready_env_num] + # if ready_env_num < self.env_num: + # print(f'init inference ready_env_num: {ready_env_num} < env_num: {self.env_num}') + if self.continuous_action_space: + act_tokens = torch.from_numpy(np.array(batch_action)).to(last_obs_embeddings.device).unsqueeze(1) + else: + act_tokens = torch.from_numpy(np.array(batch_action)).to(last_obs_embeddings.device).unsqueeze(-1) + outputs_wm = self.forward({'act_tokens': act_tokens}, past_keys_values=self.keys_values_wm, + is_init_infer=True, task_id=task_id) + + outputs_wm = self.forward({'obs_embeddings': current_obs_embeddings}, + past_keys_values=self.keys_values_wm, is_init_infer=True, task_id=task_id) + + # Copy and store keys_values_wm for a single environment + self.update_cache_context(current_obs_embeddings, is_init_infer=True) + + elif batch_action is not None and current_obs_embeddings is None: + # elif n > self.env_num and batch_action is not None and current_obs_embeddings is None: + # ================ calculate the target value in Train phase ================ + # [192, 16, 64] -> [32, 6, 16, 64] + last_obs_embeddings = last_obs_embeddings.contiguous().view(batch_action.shape[0], -1, num_observations_tokens, + self.obs_act_embed_dim) # (BL, K) for unroll_step=1 + + last_obs_embeddings = last_obs_embeddings[:, :-1, :] + batch_action = torch.from_numpy(batch_action).to(last_obs_embeddings.device) + if self.continuous_action_space: + act_tokens = batch_action + else: + act_tokens = rearrange(batch_action, 'b l -> b l 1') + + + # select the last timestep for each sample + # This will select the last column while keeping the dimensions unchanged, and the target policy/value in the final step itself is not used. + last_steps_act = act_tokens[:, -1:, :] + act_tokens = torch.cat((act_tokens, last_steps_act), dim=1) + + outputs_wm = self.forward({'obs_embeddings_and_act_tokens': (last_obs_embeddings, act_tokens)}, task_id=task_id) + + # select the last timestep for each sample + last_steps_value = outputs_wm.logits_value[:, -1:, :] + outputs_wm.logits_value = torch.cat((outputs_wm.logits_value, last_steps_value), dim=1) + + last_steps_policy = outputs_wm.logits_policy[:, -1:, :] + outputs_wm.logits_policy = torch.cat((outputs_wm.logits_policy, last_steps_policy), dim=1) + + # Reshape your tensors + # outputs_wm.logits_value.shape (B, H, 101) = (B*H, 101) + outputs_wm.logits_value = rearrange(outputs_wm.logits_value, 'b t e -> (b t) e') + outputs_wm.logits_policy = rearrange(outputs_wm.logits_policy, 'b t e -> (b t) e') + + return outputs_wm + + + #@profile + @torch.no_grad() + def forward_initial_inference(self, obs_act_dict, task_id = 0): + """ + Perform initial inference based on the given observation-action dictionary. + + Arguments: + - obs_act_dict (:obj:`dict`): Dictionary containing observations and actions. + Returns: + - tuple: A tuple containing output sequence, latent state, logits rewards, logits policy, and logits value. 
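+
+        Note (editor's addition): this is the entry point used at the MCTS root. It rebuilds the
+        per-environment KV caches from the latest observations and then clears
+        ``past_kv_cache_recurrent_infer``, so every new search starts from a clean recurrent cache.
+        Illustrative unpacking of the result (names are placeholders):
+
+        >>> # outputs = model.forward_initial_inference(obs_act_dict, task_id=0)
+        >>> # out_seq, latent_root, reward_logits, policy_logits, value_logits = outputs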
+ """ + # UniZero has context in the root node + outputs_wm, latent_state = self.reset_for_initial_inference(obs_act_dict, task_id=task_id) + self.past_kv_cache_recurrent_infer.clear() + + return (outputs_wm.output_sequence, latent_state, outputs_wm.logits_rewards, + outputs_wm.logits_policy, outputs_wm.logits_value) + + #@profile + @torch.no_grad() + def forward_recurrent_inference(self, state_action_history, simulation_index=0, + latent_state_index_in_search_path=[], task_id = 0): + """ + Perform recurrent inference based on the state-action history. + + Arguments: + - state_action_history (:obj:`list`): List containing tuples of state and action history. + - simulation_index (:obj:`int`, optional): Index of the current simulation. Defaults to 0. + - latent_state_index_in_search_path (:obj:`list`, optional): List containing indices of latent states in the search path. Defaults to []. + Returns: + - tuple: A tuple containing output sequence, updated latent state, reward, logits policy, and logits value. + """ + latest_state, action = state_action_history[-1] + ready_env_num = latest_state.shape[0] + + self.keys_values_wm_list = [] + self.keys_values_wm_size_list = [] + self.keys_values_wm_size_list = self.retrieve_or_generate_kvcache(latest_state, ready_env_num, simulation_index, task_id=task_id) + + latent_state_list = [] + if not self.continuous_action_space: + token = action.reshape(-1, 1) + else: + token = action.reshape(-1, self.config.action_space_size_list[task_id]) + + # ======= Print statistics for debugging ============= + # min_size = min(self.keys_values_wm_size_list) + # if min_size >= self.config.max_tokens - 5: + # self.length_largethan_maxminus5_context_cnt += len(self.keys_values_wm_size_list) + # if min_size >= self.config.max_tokens - 7: + # self.length_largethan_maxminus7_context_cnt += len(self.keys_values_wm_size_list) + # if self.total_query_count > 0 and self.total_query_count % 10000 == 0: + # self.hit_freq = self.hit_count / self.total_query_count + # print('total_query_count:', self.total_query_count) + # length_largethan_maxminus5_context_cnt_ratio = self.length_largethan_maxminus5_context_cnt / self.total_query_count + # print('recurrent largethan_maxminus5_context:', self.length_largethan_maxminus5_context_cnt) + # print('recurrent largethan_maxminus5_context_ratio:', length_largethan_maxminus5_context_cnt_ratio) + # length_largethan_maxminus7_context_cnt_ratio = self.length_largethan_maxminus7_context_cnt / self.total_query_count + # print('recurrent largethan_maxminus7_context_ratio:', length_largethan_maxminus7_context_cnt_ratio) + # print('recurrent largethan_maxminus7_context:', self.length_largethan_maxminus7_context_cnt) + + # Trim and pad kv_cache + self.keys_values_wm_size_list = self.trim_and_pad_kv_cache(is_init_infer=False) + self.keys_values_wm_size_list_current = self.keys_values_wm_size_list + + for k in range(2): + # action_token obs_token + if k == 0: + obs_embeddings_or_act_tokens = {'act_tokens': token} + else: + obs_embeddings_or_act_tokens = {'obs_embeddings': token} + + # Perform forward pass + outputs_wm = self.forward( + obs_embeddings_or_act_tokens, + past_keys_values=self.keys_values_wm, + kvcache_independent=False, + is_init_infer=False, + task_id = task_id + ) + + self.keys_values_wm_size_list_current = [i + 1 for i in self.keys_values_wm_size_list_current] + + if k == 0: + reward = outputs_wm.logits_rewards # (B,) + + if k < self.num_observations_tokens: + token = outputs_wm.logits_observations + if len(token.shape) != 3: + token 
= token.unsqueeze(1) # (8,1024) -> (8,1,1024) + latent_state_list.append(token) + + del self.latent_state # Very important to minimize cuda memory usage + self.latent_state = torch.cat(latent_state_list, dim=1) # (B, K) + + self.update_cache_context( + self.latent_state, + is_init_infer=False, + simulation_index=simulation_index, + latent_state_index_in_search_path=latent_state_index_in_search_path + ) + + return (outputs_wm.output_sequence, self.latent_state, reward, outputs_wm.logits_policy, outputs_wm.logits_value) + + def trim_and_pad_kv_cache(self, is_init_infer=True) -> list: + """ + Adjusts the key-value cache for each environment to ensure they all have the same size. + + In a multi-environment setting, the key-value cache (kv_cache) for each environment is stored separately. + During recurrent inference, the kv_cache sizes may vary across environments. This method pads each kv_cache + to match the largest size found among them, facilitating batch processing in the transformer forward pass. + + Arguments: + - is_init_infer (:obj:`bool`): Indicates if this is an initial inference. Default is True. + Returns: + - list: Updated sizes of the key-value caches. + """ + # Find the maximum size among all key-value caches + max_size = max(self.keys_values_wm_size_list) + + # Iterate over each layer of the transformer + for layer in range(self.num_layers): + kv_cache_k_list = [] + kv_cache_v_list = [] + + # Enumerate through each environment's key-value pairs + for idx, keys_values in enumerate(self.keys_values_wm_list): + k_cache = keys_values[layer]._k_cache._cache + v_cache = keys_values[layer]._v_cache._cache + + effective_size = self.keys_values_wm_size_list[idx] + pad_size = max_size - effective_size + + # If padding is required, trim the end and pad the beginning of the cache + if pad_size > 0: + k_cache_trimmed = k_cache[:, :, :-pad_size, :] + v_cache_trimmed = v_cache[:, :, :-pad_size, :] + k_cache_padded = F.pad(k_cache_trimmed, (0, 0, pad_size, 0), "constant", 0) + v_cache_padded = F.pad(v_cache_trimmed, (0, 0, pad_size, 0), "constant", 0) + else: + k_cache_padded = k_cache + v_cache_padded = v_cache + + kv_cache_k_list.append(k_cache_padded) + kv_cache_v_list.append(v_cache_padded) + + # Stack the caches along a new dimension and remove any extra dimensions + self.keys_values_wm._keys_values[layer]._k_cache._cache = torch.stack(kv_cache_k_list, dim=0).squeeze(1) + self.keys_values_wm._keys_values[layer]._v_cache._cache = torch.stack(kv_cache_v_list, dim=0).squeeze(1) + + # Update the cache size to the maximum size + self.keys_values_wm._keys_values[layer]._k_cache._size = max_size + self.keys_values_wm._keys_values[layer]._v_cache._size = max_size + + return self.keys_values_wm_size_list + + #@profile + def update_cache_context(self, latent_state, is_init_infer=True, simulation_index=0, + latent_state_index_in_search_path=[], valid_context_lengths=None): + """ + Update the cache context with the given latent state. + + Arguments: + - latent_state (:obj:`torch.Tensor`): The latent state tensor. + - is_init_infer (:obj:`bool`): Flag to indicate if this is the initial inference. + - simulation_index (:obj:`int`): Index of the simulation. + - latent_state_index_in_search_path (:obj:`list`): List of indices in the search path. + - valid_context_lengths (:obj:`list`): List of valid context lengths. + """ + if self.context_length <= 2: + # No context to update if the context length is less than or equal to 2. 
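+            # Editor's note on the update performed below (skipped in this trivial case): with
+            # context_length C, at most C - 3 cached timesteps are kept per environment. The oldest
+            # 2 steps are evicted, the kept slice [2:C-1] is shifted to the front of the cache with
+            # the precomputed positional-embedding correction (see precompute_pos_emb_diff_kv), and
+            # 3 zero slots are padded at the end, e.g. for a per-head cache of shape (H, S, D):
+            #   kept = cache[:, 2:C - 1, :]
+            #   cache = F.pad(kept, (0, 0, 0, 3))  # append 3 zero timesteps along the sequence dim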
+ return + for i in range(latent_state.size(0)): + # ============ Iterate over each environment ============ + cache_key = hash_state(latent_state[i].view(-1).cpu().numpy()) # latent_state[i] is torch.Tensor + context_length = self.context_length + + if not is_init_infer: + # ============ Internal Node ============ + # Retrieve KV from global KV cache self.keys_values_wm to single environment KV cache self.keys_values_wm_single_env, ensuring correct positional encoding + current_max_context_length = max(self.keys_values_wm_size_list_current) + trim_size = current_max_context_length - self.keys_values_wm_size_list_current[i] + for layer in range(self.num_layers): + # ============ Apply trimming and padding to each layer of kv_cache ============ + # cache shape [batch_size, num_heads, sequence_length, features] + k_cache_current = self.keys_values_wm._keys_values[layer]._k_cache._cache[i] + v_cache_current = self.keys_values_wm._keys_values[layer]._v_cache._cache[i] + + if trim_size > 0: + # Trim invalid leading zeros as per effective length + # Remove the first trim_size zero kv items + k_cache_trimmed = k_cache_current[:, trim_size:, :] + v_cache_trimmed = v_cache_current[:, trim_size:, :] + # If effective length < current_max_context_length, pad the end of cache with 'trim_size' zeros + k_cache_padded = F.pad(k_cache_trimmed, (0, 0, 0, trim_size), "constant", + 0) # Pad with 'trim_size' zeros at end of cache + v_cache_padded = F.pad(v_cache_trimmed, (0, 0, 0, trim_size), "constant", 0) + else: + k_cache_padded = k_cache_current + v_cache_padded = v_cache_current + + # Update cache of self.keys_values_wm_single_env + self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache = k_cache_padded.unsqueeze(0) + self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache = v_cache_padded.unsqueeze(0) + # Update size of self.keys_values_wm_single_env + self.keys_values_wm_single_env._keys_values[layer]._k_cache._size = \ + self.keys_values_wm_size_list_current[i] + self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = \ + self.keys_values_wm_size_list_current[i] + + # ============ NOTE: Very Important ============ + if self.keys_values_wm_single_env._keys_values[layer]._k_cache._size >= context_length - 1: + # Keep only the last self.context_length-3 timesteps of context + # For memory environments, training is for H steps, recurrent_inference might exceed H steps + # Assuming cache dimension is [batch_size, num_heads, sequence_length, features] + k_cache_current = self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache + v_cache_current = self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache + + # Remove the first 2 steps, keep the last self.context_length-3 steps + k_cache_trimmed = k_cache_current[:, :, 2:context_length - 1, :].squeeze(0) + v_cache_trimmed = v_cache_current[:, :, 2:context_length - 1, :].squeeze(0) + + # Index pre-computed positional encoding differences + pos_emb_diff_k = self.pos_emb_diff_k[layer][(2, context_length - 1)] + pos_emb_diff_v = self.pos_emb_diff_v[layer][(2, context_length - 1)] + # ============ NOTE: Very Important ============ + # Apply positional encoding correction to k and v + k_cache_trimmed += pos_emb_diff_k.squeeze(0) + v_cache_trimmed += pos_emb_diff_v.squeeze(0) + + # Pad the last 3 steps along the third dimension with zeros + # F.pad parameters (0, 0, 0, 3) specify padding amounts for each dimension: (left, right, top, bottom). 
For 3D tensor, they correspond to (dim2 left, dim2 right, dim1 left, dim1 right). + padding_size = (0, 0, 0, 3) + k_cache_padded = F.pad(k_cache_trimmed, padding_size, 'constant', 0) + v_cache_padded = F.pad(v_cache_trimmed, padding_size, 'constant', 0) + # Update single environment cache + self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache = k_cache_padded.unsqueeze(0) + self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache = v_cache_padded.unsqueeze(0) + + self.keys_values_wm_single_env._keys_values[layer]._k_cache._size = context_length - 3 + self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = context_length - 3 + + else: + # ============ Root Node ============ + # Retrieve KV from global KV cache self.keys_values_wm to single environment KV cache self.keys_values_wm_single_env, ensuring correct positional encoding + + for layer in range(self.num_layers): + # ============ Apply trimming and padding to each layer of kv_cache ============ + + if self.keys_values_wm._keys_values[layer]._k_cache._size < context_length - 1: # Keep only the last self.context_length-1 timesteps of context + self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache = self.keys_values_wm._keys_values[layer]._k_cache._cache[i].unsqueeze(0) # Shape torch.Size([2, 100, 512]) + self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache = self.keys_values_wm._keys_values[layer]._v_cache._cache[i].unsqueeze(0) + self.keys_values_wm_single_env._keys_values[layer]._k_cache._size = self.keys_values_wm._keys_values[layer]._k_cache._size + self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = self.keys_values_wm._keys_values[layer]._v_cache._size + else: + # Assuming cache dimension is [batch_size, num_heads, sequence_length, features] + k_cache_current = self.keys_values_wm._keys_values[layer]._k_cache._cache[i] + v_cache_current = self.keys_values_wm._keys_values[layer]._v_cache._cache[i] + + # Remove the first 2 steps, keep the last self.context_length-3 steps + k_cache_trimmed = k_cache_current[:, 2:context_length - 1, :] + v_cache_trimmed = v_cache_current[:, 2:context_length - 1, :] + + # Index pre-computed positional encoding differences + pos_emb_diff_k = self.pos_emb_diff_k[layer][(2, context_length - 1)] + pos_emb_diff_v = self.pos_emb_diff_v[layer][(2, context_length - 1)] + # ============ NOTE: Very Important ============ + # Apply positional encoding correction to k and v + k_cache_trimmed += pos_emb_diff_k.squeeze(0) + v_cache_trimmed += pos_emb_diff_v.squeeze(0) + + # Pad the last 3 steps along the third dimension with zeros + # F.pad parameters (0, 0, 0, 3) specify padding amounts for each dimension: (left, right, top, bottom). For 3D tensor, they correspond to (dim2 left, dim2 right, dim1 left, dim1 right). 
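+                        # Editor's example: for a per-environment cache slice of shape
+                        # (num_heads, seq_len, head_dim), F.pad(x, (0, 0, 0, 3)) leaves the feature
+                        # dim untouched and appends 3 zero timesteps, i.e. (H, S, D) -> (H, S + 3, D);
+                        # only the first context_length - 3 entries remain valid, which is the
+                        # `_size` recorded below.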
+ padding_size = (0, 0, 0, 3) + k_cache_padded = F.pad(k_cache_trimmed, padding_size, 'constant', 0) + v_cache_padded = F.pad(v_cache_trimmed, padding_size, 'constant', 0) + # Update cache of self.keys_values_wm_single_env + self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache = k_cache_padded.unsqueeze(0) + self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache = v_cache_padded.unsqueeze(0) + # Update size of self.keys_values_wm_single_env + self.keys_values_wm_single_env._keys_values[layer]._k_cache._size = context_length - 3 + self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = context_length - 3 + + if is_init_infer: + # Store the latest key-value cache for initial inference + cache_index = self.custom_copy_kv_cache_to_shared_init_envs(self.keys_values_wm_single_env, i) + self.past_kv_cache_init_infer_envs[i][cache_key] = cache_index + else: + # Store the latest key-value cache for recurrent inference + cache_index = self.custom_copy_kv_cache_to_shared_recur(self.keys_values_wm_single_env) + self.past_kv_cache_recurrent_infer[cache_key] = cache_index + + #@profile + def retrieve_or_generate_kvcache(self, latent_state: list, ready_env_num: int, + simulation_index: int = 0, task_id = 0) -> list: + """ + Retrieves or generates key-value caches for each environment based on the latent state. + + For each environment, this method either retrieves a matching cache from the predefined + caches if available, or generates a new cache if no match is found. The method updates + the internal lists with these caches and their sizes. + + Arguments: + - latent_state (:obj:`list`): List of latent states for each environment. + - ready_env_num (:obj:`int`): Number of environments ready for processing. + - simulation_index (:obj:`int`, optional): Index for simulation tracking. Default is 0. + Returns: + - list: Sizes of the key-value caches for each environment. 
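+
+        Note (editor's addition): caches are keyed by ``hash_state`` applied to the flattened latent
+        state; the per-environment init-inference pool is consulted first and the shared
+        recurrent-inference pool second. As written, the second lookup
+        (``self.shared_pool_recur_infer[self.past_kv_cache_recurrent_infer.get(cache_key)]``)
+        assumes the key is present there, since a ``.get`` miss would index the pool with ``None``.
+        Sketch of the lookup pattern (mirrors the body below):
+
+        >>> # cache_key = hash_state(latent_state[i])                      # np.ndarray -> hashable key
+        >>> # idx = self.past_kv_cache_init_infer_envs[i].get(cache_key)
+        >>> # kv = self.shared_pool_init_infer[i][idx] if idx is not None else None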
+ """ + for i in range(ready_env_num): + self.total_query_count += 1 + state_single_env = latent_state[i] # latent_state[i] is np.array + cache_key = hash_state(state_single_env) + + if self.reanalyze_phase: + # TODO: check if this is correct + matched_value = None + else: + # Try to retrieve the cached value from past_kv_cache_init_infer_envs + cache_index = self.past_kv_cache_init_infer_envs[i].get(cache_key) + if cache_index is not None: + matched_value = self.shared_pool_init_infer[i][cache_index] + else: + matched_value = None + + # If not found, try to retrieve from past_kv_cache_recurrent_infer + if matched_value is None: + matched_value = self.shared_pool_recur_infer[self.past_kv_cache_recurrent_infer.get(cache_key)] + + if matched_value is not None: + # If a matching cache is found, add it to the lists + self.hit_count += 1 + # Perform a deep copy because the transformer's forward pass might modify matched_value in-place + self.keys_values_wm_list.append(self.custom_copy_kv_cache_to_shared_wm(matched_value)) + self.keys_values_wm_size_list.append(matched_value.size) + else: + # If no matching cache is found, generate a new one using zero reset + self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values( + n=1, max_tokens=self.context_length + ) + self.forward( + {'obs_embeddings': torch.from_numpy(state_single_env).unsqueeze(0).to(self.device)}, + past_keys_values=self.keys_values_wm_single_env, is_init_infer=True, task_id=task_id + ) + self.keys_values_wm_list.append(self.keys_values_wm_single_env) + self.keys_values_wm_size_list.append(1) + + return self.keys_values_wm_size_list + + + def plot_embeddings(self, tsne_results, task_ids, observations, samples_per_task=5, save_dir='tsne_plots_26games'): + """ + 生成 t-SNE 可视化图,并在图中为每个任务随机标注指定数量的观测样本图像。 + + 参数: + - tsne_results: t-SNE 降维结果 (N x 2 的数组) + - task_ids: 环境任务 ID,用于着色 (N 的数组) + - observations: 对应的观测样本 (N x C x H x W 的张量或数组) + - samples_per_task: 每个任务选择的样本数量,默认 5 + - save_dir: 保存路径,默认 'tsne_plots_26games' + """ + + # 创建保存目录 + os.makedirs(save_dir, exist_ok=True) + print(f"[INFO] 保存目录已创建或已存在: {save_dir}") + + # 创建 t-SNE 图 + print("[INFO] 开始绘制 t-SNE 散点图...") + plt.figure(figsize=(18, 10)) # 增大图像宽度以适应右侧图例 + + # 散点图 + scatter = plt.scatter( + tsne_results[:, 0], + tsne_results[:, 1], + c=[self.colors[tid] for tid in task_ids], + alpha=0.6, + edgecolor='w', + linewidth=0.5 + ) + + # 创建自定义图例 + legend_elements = [] + for idx, env_id in enumerate(self.env_id_list): + short_name = self.env_short_names.get(env_id, env_id) + color = self.colors[idx] + legend_elements.append( + Patch(facecolor=color, edgecolor='w', label=f"{idx}: {short_name}") + ) + + # 将图例放在图像右侧,并且每个图例项占一行 + plt.legend( + handles=legend_elements, + title="Environment IDs", + loc='center left', + bbox_to_anchor=(1, 0.5), # 图例在图像右侧中央 + fontsize=10, + title_fontsize=12, + ncol=1, + frameon=False # 去除图例边框,增强美观 + ) + + # 设置标题和轴标签 + plt.title("t-SNE of Latent States across Environments", fontsize=16) + plt.xlabel("t-SNE Dimension 1", fontsize=14) + plt.ylabel("t-SNE Dimension 2", fontsize=14) + plt.xticks(fontsize=12) + plt.yticks(fontsize=12) + plt.grid(True, linestyle='--', alpha=0.5) + print(f"[INFO] t-SNE 散点图绘制完成,共有 {len(tsne_results)} 个点。") + + # 为每个任务选择指定数量的样本进行图像标注 + print(f"[INFO] 开始为每个任务选择 {samples_per_task} 个样本进行图像标注...") + for task_id in range(len(self.env_id_list)): + # 找到当前任务的所有索引 + task_indices = np.where(task_ids == task_id)[0] + if len(task_indices) == 0: + print(f"[WARNING] 任务 ID {task_id} 没有对应的样本。") + continue + # 如果样本数量少于所需,全部选取 + if 
len(task_indices) < samples_per_task: + selected_indices = task_indices + print(f"[INFO] 任务 ID {task_id} 的样本数量 ({len(task_indices)}) 少于 {samples_per_task},选取全部。") + else: + selected_indices = np.random.choice(task_indices, size=samples_per_task, replace=False) + print(f"[INFO] 任务 ID {task_id} 随机选取 {samples_per_task} 个样本进行标注。") + + for idx in selected_indices: + img = observations[idx] + if isinstance(img, torch.Tensor): + img = img.cpu().numpy() + if img.shape[0] == 1 or img.shape[0] == 3: # 处理灰度图或 RGB 图 + img = np.transpose(img, (1, 2, 0)) + else: + raise ValueError(f"Unsupported image shape: {img.shape}") + + # 标准化图像到 [0,1] 范围 + img_min, img_max = img.min(), img.max() + if img_max - img_min > 1e-5: + img = (img - img_min) / (img_max - img_min) + else: + img = np.zeros_like(img) + + imagebox = OffsetImage(img, zoom=0.5) + ab = AnnotationBbox( + imagebox, + (tsne_results[idx, 0], tsne_results[idx, 1]), + frameon=False, + pad=0.3 + ) + plt.gca().add_artist(ab) + print(f"[INFO] 已添加图像标注: 任务 ID {task_id}, 点索引 {idx}, t-SNE 坐标 ({tsne_results[idx, 0]:.2f}, {tsne_results[idx, 1]:.2f})") + + # 调整布局以适应图例 + plt.tight_layout(rect=[0, 0, 0.9, 1]) # 为右侧的图例预留空间 + + # 保存图像,使用高分辨率 + save_path_png = os.path.join(save_dir, 'tsne_plot.png') + save_path_pdf = os.path.join(save_dir, 'tsne_plot.pdf') + plt.savefig(save_path_png, dpi=300, bbox_inches='tight') + plt.savefig(save_path_pdf, dpi=300, bbox_inches='tight') + print(f"[INFO] t-SNE 可视化图已保存至: {save_path_png} 和 {save_path_pdf}") + plt.close() + + @torch.no_grad() + def gather_and_plot(self, local_embeddings, local_task_ids, local_observations): + world_size = dist.get_world_size() + rank = dist.get_rank() + + # 准备接收来自所有进程的CUDA张量 + embeddings_list = [torch.zeros_like(local_embeddings) for _ in range(world_size)] + task_ids_list = [torch.zeros_like(local_task_ids) for _ in range(world_size)] + + # 准备接收来自所有进程的CPU对象 + observations_list = [None for _ in range(world_size)] + + try: + # 收集CUDA张量:embeddings和task_ids + dist.all_gather(embeddings_list, local_embeddings) + dist.all_gather(task_ids_list, local_task_ids) + + # 收集CPU对象:observations + local_observations_cpu = local_observations.cpu().numpy().tolist() + dist.all_gather_object(observations_list, local_observations_cpu) + except RuntimeError as e: + print(f"Rank {rank}: all_gather failed with error: {e}") + return + + if rank == 0: + # 拼接所有embeddings和task_ids + all_embeddings = torch.cat(embeddings_list, dim=0).cpu().numpy() + all_task_ids = torch.cat(task_ids_list, dim=0).cpu().numpy() + + # 拼接所有observations + all_observations = [] + for obs in observations_list: + all_observations.extend(obs) + all_observations = np.array(all_observations) + + print(f"Shape of all_embeddings: {all_embeddings.shape}") + all_embeddings = all_embeddings.reshape(-1, all_embeddings.shape[-1]) + print(f"Shape of all_observations: {all_observations.shape}") + all_observations = all_observations.reshape(-1, *all_observations.shape[-3:]) + + # 执行t-SNE降维 + tsne = TSNE(n_components=2, random_state=42) + tsne_results = tsne.fit_transform(all_embeddings) + + # 绘制并保存图像 + self.plot_embeddings(tsne_results, all_task_ids, all_observations, save_dir=f'tsne_plots_{self.num_tasks}games') + + #@profile + def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar_transform_handle=None, task_id = 0, **kwargs: Any) -> LossWithIntermediateLosses: + # Encode observations into latent state representations + obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch['observations'], task_id=task_id) + + if 
self.analysis_tsne: + # =========== tsne analysis =========== + # 确保embeddings在CUDA设备上且为稠密张量 + if not obs_embeddings.is_cuda: + obs_embeddings = obs_embeddings.cuda() + obs_embeddings = obs_embeddings.contiguous() + + # 保存当前进程的 embeddings 和 task_id + local_embeddings = obs_embeddings.detach() + local_task_ids = torch.full((local_embeddings.size(0),), task_id, dtype=torch.long, device=local_embeddings.device) + + # 将observations移到CPU并转换为numpy + local_observations = batch['observations'].detach().cpu() + + # 进行数据收集和可视化 + self.gather_and_plot(local_embeddings, local_task_ids, local_observations) + + # ========= logging for analysis ========= + if self.analysis_dormant_ratio: + # Calculate dormant ratio of the encoder + shape = batch['observations'].shape # (..., C, H, W) + inputs = batch['observations'].contiguous().view(-1, *shape[-3:]) # (32,5,3,64,64) -> (160,3,64,64) + dormant_ratio_encoder = cal_dormant_ratio(self.tokenizer.representation_network, inputs.detach(), + percentage=self.dormant_threshold) + self.past_kv_cache_init_infer.clear() + self.past_kv_cache_recurrent_infer.clear() + self.keys_values_wm_list.clear() + torch.cuda.empty_cache() + else: + dormant_ratio_encoder = torch.tensor(0.) + + # Calculate the L2 norm of the latent state roots + latent_state_l2_norms = torch.norm(obs_embeddings, p=2, dim=2).mean() + + if self.obs_type == 'image': + # Reconstruct observations from latent state representations + # reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings) + + # ========== for visualization ========== + # Uncomment the lines below for visual analysis + # original_images, reconstructed_images = batch['observations'], reconstructed_images + # target_policy = batch['target_policy'] + # ==== for value priority ==== + # target_predict_value = inverse_scalar_transform_handle(batch['target_value'].reshape(-1, 101)).reshape( + # batch['observations'].shape[0], batch['observations'].shape[1], 1) + # true_rewards = inverse_scalar_transform_handle(batch['rewards'].reshape(-1, 101)).reshape( + # batch['observations'].shape[0], batch['observations'].shape[1], 1) + # ========== for visualization ========== + + # Calculate reconstruction loss and perceptual loss + # latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # NOTE: for stack=1 + # perceptual_loss = self.tokenizer.perceptual_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # NOTE: for stack=1 + latent_recon_loss = torch.tensor(0., device=batch['observations'].device, + dtype=batch['observations'].dtype) + perceptual_loss = torch.tensor(0., device=batch['observations'].device, + dtype=batch['observations'].dtype) + + elif self.obs_type == 'vector': + perceptual_loss = torch.tensor(0., device=batch['observations'].device, + dtype=batch['observations'].dtype) + + # Reconstruct observations from latent state representations + # reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings.reshape(-1, self.embed_dim)) + # # Calculate reconstruction loss + # latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 25), + # reconstructed_images) + latent_recon_loss = torch.tensor(0., device=batch['observations'].device, + dtype=batch['observations'].dtype) + + elif self.obs_type == 'image_memory': + # Reconstruct observations from latent state representations + # reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings) + # original_images, reconstructed_images = 
batch['observations'], reconstructed_images + + # ========== for visualization ========== + # Uncomment the lines below for visual analysis + # target_policy = batch['target_policy'] + # target_predict_value = inverse_scalar_transform_handle(batch['target_value'].reshape(-1, 101)).reshape( + # batch['observations'].shape[0], batch['observations'].shape[1], 1) + # true_rewards = inverse_scalar_transform_handle(batch['rewards'].reshape(-1, 101)).reshape( + # batch['observations'].shape[0], batch['observations'].shape[1], 1) + # ========== for visualization ========== + + # Calculate reconstruction loss and perceptual loss + # latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 3, 5, 5), + # reconstructed_images) + + latent_recon_loss = torch.tensor(0., device=batch['observations'].device, + dtype=batch['observations'].dtype) + perceptual_loss = torch.tensor(0., device=batch['observations'].device, + dtype=batch['observations'].dtype) + + # Action tokens + if self.continuous_action_space: + act_tokens = batch['actions'] + else: + act_tokens = rearrange(batch['actions'], 'b l -> b l 1') + + # Forward pass to obtain predictions for observations, rewards, and policies + outputs = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)}, task_id=task_id) + + # ========= logging for analysis ========= + if self.analysis_dormant_ratio: + # Calculate dormant ratio of the world model + dormant_ratio_world_model = cal_dormant_ratio(self, { + 'obs_embeddings_and_act_tokens': (obs_embeddings.detach(), act_tokens.detach())}, + percentage=self.dormant_threshold) + self.past_kv_cache_init_infer.clear() + self.past_kv_cache_recurrent_infer.clear() + self.keys_values_wm_list.clear() + torch.cuda.empty_cache() + else: + dormant_ratio_world_model = torch.tensor(0.) 
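+        # Editor's note: the latent prediction loss computed further below offers two options,
+        # plain MSE and a "group KL" variant: the embedding is split into num_groups groups of
+        # group_size entries (the groups SimNorm normalises onto a simplex) and each predicted
+        # group is compared to the target group with a KL divergence, e.g.:
+        #   pred = logits.reshape(BT, num_groups, group_size) + 1e-6
+        #   target = labels.reshape(BT, num_groups, group_size) + 1e-6
+        #   loss_obs = F.kl_div(pred.log(), target, reduction='none').sum(-1).mean(-1)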
+ + # ========== for visualization ========== + # Uncomment the lines below for visualization + # predict_policy = outputs.logits_policy + # predict_policy = F.softmax(outputs.logits_policy, dim=-1) + # predict_value = inverse_scalar_transform_handle(outputs.logits_value.reshape(-1, 101)).reshape(batch['observations'].shape[0], batch['observations'].shape[1], 1) + # predict_rewards = inverse_scalar_transform_handle(outputs.logits_rewards.reshape(-1, 101)).reshape(batch['observations'].shape[0], batch['observations'].shape[1], 1) + # import pdb; pdb.set_trace() + # visualize_reward_value_img_policy(original_images, reconstructed_images, target_predict_value, true_rewards, target_policy, predict_value, predict_rewards, predict_policy, not_plot_timesteps=[], suffix='pong_H10_H4_0613') + + # visualize_reward_value_img_policy(original_images, reconstructed_images, target_predict_value, true_rewards, target_policy, predict_value, predict_rewards, predict_policy, not_plot_timesteps=list(np.arange(4,60)), suffix='visual_match_memlen1-60-15/one_success_episode') + # visualize_reward_value_img_policy(original_images, reconstructed_images, target_predict_value, true_rewards, target_policy, predict_value, predict_rewards, predict_policy, not_plot_timesteps=list(np.arange(4,60)), suffix='visual_match_memlen1-60-15/one_fail_episode') + # ========== for visualization ========== + + # For training stability, use target_tokenizer to compute the true next latent state representations + with torch.no_grad(): + target_obs_embeddings = target_tokenizer.encode_to_obs_embeddings(batch['observations'], task_id=task_id) + + # Compute labels for observations, rewards, and ends + labels_observations, labels_rewards, labels_ends = self.compute_labels_world_model(target_obs_embeddings, + batch['rewards'], + batch['ends'], + batch['mask_padding']) + + # Reshape the logits and labels for observations + logits_observations = rearrange(outputs.logits_observations[:, :-1], 'b t o -> (b t) o') + labels_observations = labels_observations.reshape(-1, self.projection_input_dim) + + if self.use_task_embed and self.task_embed_option == "concat_task_embed": + # print(f'=='*20) + # print(f'labels_observations.shape:{labels_observations.shape}') + # print(f'=='*20) + # Expand task embeddings to match the sequence shape + self.task_embeddings = self.task_emb(torch.tensor(task_id, device=self.device)) # NOTE: TODO + self.task_embeddings = self.sim_norm(self.task_embeddings.view(1,-1)).view(-1) # TODO + task_emb_expanded = self.task_embeddings.expand(labels_observations.shape[0], -1) + # print(f'task_emb_expanded:{task_emb_expanded}') + # print(f"task_emb_expanded.shape: {task_emb_expanded.shape}") + # print(f"task_emb_expanded (min, max, mean): {task_emb_expanded.min()}, {task_emb_expanded.max()}, {task_emb_expanded.mean()}") + # assert not torch.isnan(task_emb_expanded).any(), "task_emb_expanded 存在 NaN 值" + # print(f"logits_observations.shape: {logits_observations.shape}") + labels_observations = torch.cat([labels_observations, task_emb_expanded.detach()], dim=-1) # NOTE: detach() + # print(f"labels_observations.shape: {labels_observations.shape}") + # assert logits_observations.shape == labels_observations.shape, "logits 和 labels 的形状不匹配" + + + + # Compute prediction loss for observations. 
Options: MSE and Group KL + if self.predict_latent_loss_type == 'mse': + # MSE loss, directly compare logits and labels + loss_obs = torch.nn.functional.mse_loss(logits_observations, labels_observations, reduction='none').mean( + -1) + elif self.predict_latent_loss_type == 'group_kl': + # Group KL loss, group features and calculate KL divergence within each group + batch_size, num_features = logits_observations.shape + epsilon = 1e-6 + logits_reshaped = logits_observations.reshape(batch_size, self.num_groups, self.group_size) + epsilon + labels_reshaped = labels_observations.reshape(batch_size, self.num_groups, self.group_size) + epsilon + + loss_obs = F.kl_div(logits_reshaped.log(), labels_reshaped, reduction='none').sum(dim=-1).mean(dim=-1) + + # ========== for debugging ========== + # assert not torch.isnan(logits_reshaped).any(), "logits_reshaped contains NaN values" + # assert not torch.isnan(labels_reshaped).any(), "labels_reshaped contains NaN values" + # print('loss_obs:', loss_obs.mean()) + # for name, param in self.tokenizer.encoder.named_parameters(): + # print('name, param.mean(), param.std():', name, param.mean(), param.std()) + # logits_grad = torch.autograd.grad(loss_obs.mean(), logits_observations, retain_graph=True)[0] + # print(f"logits_grad (min, max, mean): {logits_grad.min()}, {logits_grad.max()}, {logits_grad.mean()}") + + # Apply mask to loss_obs + mask_padding_expanded = batch['mask_padding'][:, 1:].contiguous().view(-1) + loss_obs = (loss_obs * mask_padding_expanded) + + # Compute labels for policy and value + labels_policy, labels_value = self.compute_labels_world_model_value_policy(batch['target_value'], + batch['target_policy'], + batch['mask_padding']) + + # Compute losses for rewards, policy, and value + loss_rewards = self.compute_cross_entropy_loss(outputs, labels_rewards, batch, element='rewards') + + if not self.continuous_action_space: + loss_policy, orig_policy_loss, policy_entropy = self.compute_cross_entropy_loss(outputs, labels_policy, + batch, + element='policy') + else: + # NOTE: for continuous action space + if self.config.policy_loss_type == 'simple': + orig_policy_loss, policy_entropy_loss, target_policy_entropy, target_sampled_actions, mu, sigma = self._calculate_policy_loss_cont_simple( + outputs, batch) + else: + orig_policy_loss, policy_entropy_loss, target_policy_entropy, target_sampled_actions, mu, sigma = self._calculate_policy_loss_cont( + outputs, batch, task_id=task_id) + + loss_policy = orig_policy_loss + self.policy_entropy_weight * policy_entropy_loss + policy_entropy = - policy_entropy_loss + + loss_value = self.compute_cross_entropy_loss(outputs, labels_value, batch, element='value') + + # Compute timesteps + timesteps = torch.arange(batch['actions'].shape[1], device=batch['actions'].device) + # Compute discount coefficients for each timestep + discounts = self.gamma ** timesteps + + if batch['mask_padding'].sum() == 0: + assert False, "mask_padding is all zeros" + + # Group losses into first step, middle step, and last step + first_step_losses = {} + middle_step_losses = {} + last_step_losses = {} + # batch['mask_padding'] indicates mask status for future H steps, exclude masked losses to maintain accurate mean statistics + # Group losses for each loss item + for loss_name, loss_tmp in zip( + ['loss_obs', 'loss_rewards', 'loss_value', 'loss_policy', 'orig_policy_loss', 'policy_entropy'], + [loss_obs, loss_rewards, loss_value, loss_policy, orig_policy_loss, policy_entropy] + ): + if loss_name == 'loss_obs': + seq_len = 
batch['actions'].shape[1] - 1 + # Get the corresponding mask_padding + mask_padding = batch['mask_padding'][:, 1:seq_len] + else: + seq_len = batch['actions'].shape[1] + # Get the corresponding mask_padding + mask_padding = batch['mask_padding'][:, :seq_len] + + # Adjust loss shape to (batch_size, seq_len) + loss_tmp = loss_tmp.view(-1, seq_len) + + # First step loss + first_step_mask = mask_padding[:, 0] + first_step_losses[loss_name] = loss_tmp[:, 0][first_step_mask].mean() + + # Middle step loss + middle_step_index = seq_len // 2 + middle_step_mask = mask_padding[:, middle_step_index] + middle_step_losses[loss_name] = loss_tmp[:, middle_step_index][middle_step_mask].mean() + + # Last step loss + last_step_mask = mask_padding[:, -1] + last_step_losses[loss_name] = loss_tmp[:, -1][last_step_mask].mean() + + # Discount reconstruction loss and perceptual loss + discounted_latent_recon_loss = latent_recon_loss + discounted_perceptual_loss = perceptual_loss + + # Calculate overall discounted loss + discounted_loss_obs = (loss_obs.view(-1, batch['actions'].shape[1] - 1) * discounts[1:]).sum()/ batch['mask_padding'][:,1:].sum() + discounted_loss_rewards = (loss_rewards.view(-1, batch['actions'].shape[1]) * discounts).sum()/ batch['mask_padding'].sum() + discounted_loss_value = (loss_value.view(-1, batch['actions'].shape[1]) * discounts).sum()/ batch['mask_padding'].sum() + discounted_loss_policy = (loss_policy.view(-1, batch['actions'].shape[1]) * discounts).sum()/ batch['mask_padding'].sum() + discounted_orig_policy_loss = (orig_policy_loss.view(-1, batch['actions'].shape[1]) * discounts).sum()/ batch['mask_padding'].sum() + discounted_policy_entropy = (policy_entropy.view(-1, batch['actions'].shape[1]) * discounts).sum()/ batch['mask_padding'].sum() + + if self.continuous_action_space: + return LossWithIntermediateLosses( + latent_recon_loss_weight=self.latent_recon_loss_weight, + perceptual_loss_weight=self.perceptual_loss_weight, + continuous_action_space=True, + loss_obs=discounted_loss_obs, + loss_rewards=discounted_loss_rewards, + loss_value=discounted_loss_value, + loss_policy=discounted_loss_policy, + latent_recon_loss=discounted_latent_recon_loss, + perceptual_loss=discounted_perceptual_loss, + orig_policy_loss=discounted_orig_policy_loss, + policy_entropy=discounted_policy_entropy, + first_step_losses=first_step_losses, + middle_step_losses=middle_step_losses, + last_step_losses=last_step_losses, + dormant_ratio_encoder=dormant_ratio_encoder, + dormant_ratio_world_model=dormant_ratio_world_model, + latent_state_l2_norms=latent_state_l2_norms, + policy_mu=mu, + policy_sigma=sigma, + target_sampled_actions=target_sampled_actions, + ) + else: + return LossWithIntermediateLosses( + latent_recon_loss_weight=self.latent_recon_loss_weight, + perceptual_loss_weight=self.perceptual_loss_weight, + continuous_action_space=False, + loss_obs=discounted_loss_obs, + loss_rewards=discounted_loss_rewards, + loss_value=discounted_loss_value, + loss_policy=discounted_loss_policy, + latent_recon_loss=discounted_latent_recon_loss, + perceptual_loss=discounted_perceptual_loss, + orig_policy_loss=discounted_orig_policy_loss, + policy_entropy=discounted_policy_entropy, + first_step_losses=first_step_losses, + middle_step_losses=middle_step_losses, + last_step_losses=last_step_losses, + dormant_ratio_encoder=dormant_ratio_encoder, + dormant_ratio_world_model=dormant_ratio_world_model, + latent_state_l2_norms=latent_state_l2_norms, + ) + + #@profile + def compute_cross_entropy_loss(self, outputs, labels, 
batch, element='rewards'): + # Assume outputs is an object with logits attributes like 'rewards', 'policy', and 'value'. + # labels is a target tensor for comparison. batch is a dictionary with a mask indicating valid timesteps. + + logits = getattr(outputs, f'logits_{element}') + + # Reshape your tensors + logits = rearrange(logits, 'b t e -> (b t) e') + labels = labels.reshape(-1, labels.shape[-1]) # Assume labels initially have shape [batch, time, dim] + + # Reshape your mask. True indicates valid data. + mask_padding = rearrange(batch['mask_padding'], 'b t -> (b t)') + + # Compute cross-entropy loss + loss = -(torch.log_softmax(logits, dim=1) * labels).sum(1) + loss = (loss * mask_padding) + + # if torch.isnan(loss).any(): + # raise ValueError(f"NaN detected in outputs for batch {batch} and element '{element}'") + + if element == 'policy': + # Compute policy entropy loss + policy_entropy = self.compute_policy_entropy_loss(logits, mask_padding) + # Combine losses with specified weight + combined_loss = loss - self.policy_entropy_weight * policy_entropy + return combined_loss, loss, policy_entropy + + return loss + + #@profile + def compute_policy_entropy_loss(self, logits, mask): + # Compute entropy of the policy + probs = torch.softmax(logits, dim=1) + log_probs = torch.log_softmax(logits, dim=1) + entropy = -(probs * log_probs).sum(1) + # Apply mask and return average entropy loss + entropy_loss = (entropy * mask) + return entropy_loss + + #@profile + def compute_labels_world_model(self, obs_embeddings: torch.Tensor, rewards: torch.Tensor, ends: torch.Tensor, + mask_padding: torch.BoolTensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # assert torch.all(ends.sum(dim=1) <= 1) # Each sequence sample should have at most one 'done' flag + mask_fill = torch.logical_not(mask_padding) + + # Prepare observation labels + labels_observations = obs_embeddings.contiguous().view(rewards.shape[0], -1, self.projection_input_dim)[:, 1:] + + # Fill the masked areas of rewards + mask_fill_rewards = mask_fill.unsqueeze(-1).expand_as(rewards) + labels_rewards = rewards.masked_fill(mask_fill_rewards, -100) + + # Fill the masked areas of ends + # labels_ends = ends.masked_fill(mask_fill, -100) + + # return labels_observations, labels_rewards.reshape(-1, self.support_size), labels_ends.reshape(-1) + return labels_observations, labels_rewards.view(-1, self.support_size), None + + #@profile + def compute_labels_world_model_value_policy(self, target_value: torch.Tensor, target_policy: torch.Tensor, + mask_padding: torch.BoolTensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ Compute labels for value and policy predictions. """ + mask_fill = torch.logical_not(mask_padding) + + # Fill the masked areas of policy + mask_fill_policy = mask_fill.unsqueeze(-1).expand_as(target_policy) + labels_policy = target_policy.masked_fill(mask_fill_policy, -100) + + # Fill the masked areas of value + mask_fill_value = mask_fill.unsqueeze(-1).expand_as(target_value) + labels_value = target_value.masked_fill(mask_fill_value, -100) + + if self.continuous_action_space: + return None, labels_value.reshape(-1, self.support_size) + else: + return labels_policy.reshape(-1, self.action_space_size), labels_value.reshape(-1, self.support_size) + + #@profile + def clear_caches(self): + """ + Clears the caches of the world model. 
+ """ + for kv_cache_dict_env in self.past_kv_cache_init_infer_envs: + kv_cache_dict_env.clear() + self.past_kv_cache_recurrent_infer.clear() + self.keys_values_wm_list.clear() + + print(f'rank {self._rank} Cleared {self.__class__.__name__} past_kv_cache.') + + def __repr__(self) -> str: + return "transformer-based latent world_model of UniZero" diff --git a/lzero/model/unizero_world_models/world_model_multitask_tsne-v0.py b/lzero/model/unizero_world_models/world_model_multitask_tsne-v0.py new file mode 100644 index 000000000..4fb1f07cc --- /dev/null +++ b/lzero/model/unizero_world_models/world_model_multitask_tsne-v0.py @@ -0,0 +1,1613 @@ +import collections +import copy +import logging +from typing import Any, Tuple +from typing import Optional +from typing import Union, Dict + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +from lzero.model.common import SimNorm +from lzero.model.utils import cal_dormant_ratio +from .slicer import Head +from .tokenizer import Tokenizer +from .transformer import Transformer, TransformerConfig +from .utils import LossWithIntermediateLosses, init_weights, to_device_for_kvcache +from .utils import WorldModelOutput, hash_state + +from .moe import MoeLayer, MultiplicationFeedForward +from lzero.model.unizero_world_models.world_model import WorldModel +logging.getLogger().setLevel(logging.DEBUG) +from ding.utils import build_logger, EasyTimer, SERIAL_COLLECTOR_REGISTRY, get_rank, get_world_size + +from line_profiler import line_profiler +import torch +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP +from sklearn.manifold import TSNE +import matplotlib.pyplot as plt +from matplotlib.offsetbox import OffsetImage, AnnotationBbox +import numpy as np +import os + +class WorldModelMT(WorldModel): + """ + Overview: + The WorldModel class is responsible for the scalable latent world model of UniZero (https://arxiv.org/abs/2406.10667), + which is used to predict the next latent state, rewards, policy, and value based on the current latent state and action. + The world model consists of three main components: + - a tokenizer, which encodes observations into embeddings, + - a transformer, which processes the input sequences, + - and heads, which generate the logits for observations, rewards, policy, and value. + """ + + #@profile + def __init__(self, config: TransformerConfig, tokenizer) -> None: + """ + Overview: + Initialize the WorldModel class. + Arguments: + - config (:obj:`TransformerConfig`): The configuration for the transformer. + - tokenizer (:obj:`Tokenizer`): The tokenizer. 
+ """ + super().__init__(config, tokenizer) + self.tokenizer = tokenizer + self.config = config + self.transformer = Transformer(self.config) + + self.analysis_mode = self.config.get('analysis_mode', False) + + # TODO: multitask + self.task_num = config.task_num + self.task_emb = nn.Embedding(self.task_num, config.embed_dim, max_norm=1) # TODO + self.head_policy_multi_task = nn.ModuleList() + self.head_value_multi_task = nn.ModuleList() + self.head_rewards_multi_task = nn.ModuleList() + self.head_observations_multi_task = nn.ModuleList() + + self.num_experts_in_moe_head = config.num_experts_in_moe_head + self.use_normal_head = config.use_normal_head + self.use_moe_head = config.use_moe_head + self.use_softmoe_head = config.use_softmoe_head + + if self.config.device == 'cpu': + self.device = torch.device('cpu') + else: + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + # Move all modules to the specified device + print(f"self.device: {self.device}") + self.to(self.device) + + # Initialize configuration parameters + self._initialize_config_parameters() + + # Initialize patterns for block masks + self._initialize_patterns() + + self.hidden_size = config.embed_dim // config.num_heads + + # Position embedding + self.pos_emb = nn.Embedding(config.max_tokens, config.embed_dim, device=self.device) + self.precompute_pos_emb_diff_kv() + print(f"self.pos_emb.weight.device: {self.pos_emb.weight.device}") + self.continuous_action_space = self.config.continuous_action_space + + # Initialize action embedding table + if self.continuous_action_space: + # TODO: check the effect of SimNorm + self.act_embedding_table = nn.Sequential( + nn.Linear(config.action_space_size, config.embed_dim, device=self.device, bias=False), + SimNorm(simnorm_dim=self.group_size)) + else: + # for discrete action space + self.act_embedding_table = nn.Embedding(config.action_space_size, config.embed_dim, device=self.device) + print(f"self.act_embedding_table.weight.device: {self.act_embedding_table.weight.device}") + + # if self.num_experts_in_moe_head == -1: + assert self.num_experts_in_moe_head > 0 + if self.use_normal_head: + print('We use normal head') + # TODO: Normal Head + for task_id in range(self.task_num): # TODO + action_space_size = self.action_space_size # TODO:====================== + # action_space_size=18 # TODO:====================== + self.head_policy = self._create_head(self.value_policy_tokens_pattern, self.action_space_size) + self.head_policy_multi_task.append(self.head_policy) + + self.head_value = self._create_head(self.value_policy_tokens_pattern, self.support_size) + self.head_value_multi_task.append(self.head_value) + + self.head_rewards = self._create_head(self.act_tokens_pattern, self.support_size) + self.head_rewards_multi_task.append(self.head_rewards) + + self.head_observations = self._create_head(self.all_but_last_latent_state_pattern, + self.obs_per_embdding_dim, + self.sim_norm) # NOTE: we add a sim_norm to the head for observations + self.head_observations_multi_task.append(self.head_observations) + elif self.use_softmoe_head: + print(f'We use softmoe head, self.num_experts_in_moe_head is {self.num_experts_in_moe_head}') + # Dictionary to store SoftMoE instances + self.soft_moe_instances = {} + + # Create softmoe head modules + self.create_head_modules_softmoe() + + self.head_policy_multi_task.append(self.head_policy) + self.head_value_multi_task.append(self.head_value) + self.head_rewards_multi_task.append(self.head_rewards) + 
self.head_observations_multi_task.append(self.head_observations) + elif self.use_moe_head: + print(f'We use moe head, self.num_experts_in_moe_head is {self.num_experts_in_moe_head}') + # Dictionary to store moe instances + self.moe_instances = {} + + # Create moe head modules + self.create_head_modules_moe() + + self.head_policy_multi_task.append(self.head_policy) + self.head_value_multi_task.append(self.head_value) + self.head_rewards_multi_task.append(self.head_rewards) + self.head_observations_multi_task.append(self.head_observations) + + + # Apply weight initialization, the order is important + self.apply(lambda module: init_weights(module, norm_type=self.config.norm_type)) + self._initialize_last_layer() + + # Cache structures + self._initialize_cache_structures() + + # Projection input dimension + self._initialize_projection_input_dim() + + # Hit count and query count statistics + self._initialize_statistics() + + # Initialize keys and values for transformer + self._initialize_transformer_keys_values() + + self.latent_recon_loss = torch.tensor(0., device=self.device) + self.perceptual_loss = torch.tensor(0., device=self.device) + + # TODO: check the size of the shared pool + # for self.kv_cache_recurrent_infer + # If needed, recurrent_infer should store the results of the one MCTS search. + self.shared_pool_size = int(50*self.env_num) + self.shared_pool_recur_infer = [None] * self.shared_pool_size + self.shared_pool_index = 0 + + # for self.kv_cache_init_infer + # In contrast, init_infer only needs to retain the results of the most recent step. + # self.shared_pool_size_init = int(2*self.env_num) + self.shared_pool_size_init = int(2) # NOTE: 过多会导致检索到错误的kvcache吗 + self.shared_pool_init_infer = [[None] * self.shared_pool_size_init for _ in range(self.env_num)] + self.shared_pool_index_init_envs = [0 for _ in range(self.env_num)] + + # for self.kv_cache_wm + self.shared_pool_size_wm = int(self.env_num) + self.shared_pool_wm = [None] * self.shared_pool_size_wm + self.shared_pool_index_wm = 0 + + self.reanalyze_phase = False + self._rank = get_rank() + + def _initialize_config_parameters(self) -> None: + """Initialize configuration parameters.""" + self.policy_entropy_weight = self.config.policy_entropy_weight + self.predict_latent_loss_type = self.config.predict_latent_loss_type + self.group_size = self.config.group_size + self.num_groups = self.config.embed_dim // self.group_size + self.obs_type = self.config.obs_type + self.embed_dim = self.config.embed_dim + self.num_heads = self.config.num_heads + self.gamma = self.config.gamma + self.context_length = self.config.context_length + self.dormant_threshold = self.config.dormant_threshold + self.analysis_dormant_ratio = self.config.analysis_dormant_ratio + self.num_observations_tokens = self.config.tokens_per_block - 1 + self.latent_recon_loss_weight = self.config.latent_recon_loss_weight + self.perceptual_loss_weight = self.config.perceptual_loss_weight + self.device = self.config.device + self.support_size = self.config.support_size + self.action_space_size = self.config.action_space_size + self.max_cache_size = self.config.max_cache_size + self.env_num = self.config.env_num + self.num_layers = self.config.num_layers + self.obs_per_embdding_dim = self.config.embed_dim + self.sim_norm = SimNorm(simnorm_dim=self.group_size) + + def _initialize_patterns(self) -> None: + """Initialize patterns for block masks.""" + self.all_but_last_latent_state_pattern = torch.ones(self.config.tokens_per_block) + self.all_but_last_latent_state_pattern[-2] 
= 0 + self.act_tokens_pattern = torch.zeros(self.config.tokens_per_block) + self.act_tokens_pattern[-1] = 1 + self.value_policy_tokens_pattern = torch.zeros(self.config.tokens_per_block) + self.value_policy_tokens_pattern[-2] = 1 + + def _create_head(self, block_mask: torch.Tensor, output_dim: int, norm_layer=None) -> Head: + """Create head modules for the transformer.""" + modules = [ + nn.Linear(self.config.embed_dim, self.config.embed_dim), + nn.GELU(approximate='tanh'), + nn.Linear(self.config.embed_dim, output_dim) + ] + if norm_layer: + modules.append(norm_layer) + return Head( + max_blocks=self.config.max_blocks, + block_mask=block_mask, + head_module=nn.Sequential(*modules) + ) + + def _create_head_moe(self, block_mask: torch.Tensor, output_dim: int, norm_layer=None, moe=None) -> Head: + """Create moe head modules for the transformer.""" + modules = [ + moe, + nn.Linear(self.config.embed_dim, output_dim) + ] + if norm_layer: + modules.append(norm_layer) + return Head( + max_blocks=self.config.max_blocks, + block_mask=block_mask, + head_module=nn.Sequential(*modules) + ) + def get_moe(self, name): + """Get or create a MoE instance""" + if name not in self.moe_instances: + # Create multiple FeedForward instances for multiplication-based MoE + self.experts = nn.ModuleList([ + MultiplicationFeedForward(self.config) for _ in range(self.config.num_experts_of_moe_in_transformer) + ]) + + self.moe_instances[name] = MoeLayer( + experts=self.experts, + gate=nn.Linear(self.config.embed_dim, self.config.num_experts_of_moe_in_transformer, bias=False), + num_experts_per_tok=1, + ) + + return self.moe_instances[name] + + def create_head_modules_moe(self): + """Create all softmoe head modules""" + # Rewards head + self.head_rewards = self._create_head_moe( + self.act_tokens_pattern, + self.support_size, + moe=self.get_moe("rewards_moe") + ) + + # Observations head + self.head_observations = self._create_head_moe( + self.all_but_last_latent_state_pattern, + self.obs_per_embdding_dim, + norm_layer=self.sim_norm, # NOTE + moe=self.get_moe("observations_moe") + ) + + # Policy head + self.head_policy = self._create_head_moe( + self.value_policy_tokens_pattern, + self.action_space_size, + moe=self.get_moe("policy_moe") + ) + + # Value head + self.head_value = self._create_head_moe( + self.value_policy_tokens_pattern, + self.support_size, + moe=self.get_moe("value_moe") + ) + + def _create_head_softmoe(self, block_mask: torch.Tensor, output_dim: int, norm_layer=None, soft_moe=None) -> Head: + """Create softmoe head modules for the transformer.""" + modules = [ + soft_moe, + nn.Linear(self.config.embed_dim, output_dim) + ] + if norm_layer: + modules.append(norm_layer) + return Head( + max_blocks=self.config.max_blocks, + block_mask=block_mask, + head_module=nn.Sequential(*modules) + ) + + def get_soft_moe(self, name): + """Get or create a SoftMoE instance""" + # from soft_moe_pytorch import SoftMoE + # if name not in self.soft_moe_instances: + # self.soft_moe_instances[name] = SoftMoE( + # dim=self.embed_dim, + # seq_len=20, # TODO + # num_experts=self.num_experts_in_moe_head, + # ) + from soft_moe_pytorch import DynamicSlotsSoftMoE as SoftMoE + if name not in self.soft_moe_instances: + self.soft_moe_instances[name] = SoftMoE( + dim=self.embed_dim, + num_experts=self.num_experts_in_moe_head, + geglu = True + ) + return self.soft_moe_instances[name] + + def create_head_modules_softmoe(self): + """Create all softmoe head modules""" + # Rewards head + self.head_rewards = self._create_head_softmoe( + 
self.act_tokens_pattern, + self.support_size, + soft_moe=self.get_soft_moe("rewards_soft_moe") + ) + + # Observations head + self.head_observations = self._create_head_softmoe( + self.all_but_last_latent_state_pattern, + self.obs_per_embdding_dim, + norm_layer=self.sim_norm, # NOTE + soft_moe=self.get_soft_moe("observations_soft_moe") + ) + + # Policy head + self.head_policy = self._create_head_softmoe( + self.value_policy_tokens_pattern, + self.action_space_size, + soft_moe=self.get_soft_moe("policy_soft_moe") + ) + + # Value head + self.head_value = self._create_head_softmoe( + self.value_policy_tokens_pattern, + self.support_size, + soft_moe=self.get_soft_moe("value_soft_moe") + ) + + def _initialize_last_layer(self) -> None: + """Initialize the last linear layer.""" + last_linear_layer_init_zero = True + if last_linear_layer_init_zero: + # TODO: multitask + if self.task_num == 1: + for head in [self.head_policy, self.head_value, self.head_rewards, self.head_observations]: + for layer in reversed(head.head_module): + if isinstance(layer, nn.Linear): + nn.init.zeros_(layer.weight) + if layer.bias is not None: + nn.init.zeros_(layer.bias) + break + elif self.task_num > 1: + for head in self.head_policy_multi_task + self.head_value_multi_task + self.head_rewards_multi_task + self.head_observations_multi_task: + for layer in reversed(head.head_module): + if isinstance(layer, nn.Linear): + nn.init.zeros_(layer.weight) + if layer.bias is not None: + nn.init.zeros_(layer.bias) + break + + def _initialize_cache_structures(self) -> None: + """Initialize cache structures for past keys and values.""" + self.past_kv_cache_recurrent_infer = collections.OrderedDict() + self.past_kv_cache_init_infer = collections.OrderedDict() + self.past_kv_cache_init_infer_envs = [collections.OrderedDict() for _ in range(self.env_num)] + self.keys_values_wm_list = [] + self.keys_values_wm_size_list = [] + + def _initialize_projection_input_dim(self) -> None: + """Initialize the projection input dimension based on the number of observation tokens.""" + if self.num_observations_tokens == 16: + self.projection_input_dim = 128 + elif self.num_observations_tokens == 1: + self.projection_input_dim = self.obs_per_embdding_dim + + def _initialize_statistics(self) -> None: + """Initialize counters for hit count and query count statistics.""" + self.hit_count = 0 + self.total_query_count = 0 + self.length_largethan_maxminus5_context_cnt = 0 + self.length_largethan_maxminus7_context_cnt = 0 + self.root_hit_cnt = 0 + self.root_total_query_cnt = 0 + + #@profile + def _initialize_transformer_keys_values(self) -> None: + """Initialize keys and values for the transformer.""" + self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values(n=1, + max_tokens=self.context_length) + self.keys_values_wm = self.transformer.generate_empty_keys_values(n=self.env_num, + max_tokens=self.context_length) + + #@profile + def precompute_pos_emb_diff_kv(self): + """ Precompute positional embedding differences for key and value. 
""" + if self.context_length <= 2: + # If context length is 2 or less, no context is present + return + + # Precompute positional embedding matrices for inference in collect/eval stages, not for training + self.positional_embedding_k = [ + self._get_positional_embedding(layer, 'key') + for layer in range(self.config.num_layers) + ] + self.positional_embedding_v = [ + self._get_positional_embedding(layer, 'value') + for layer in range(self.config.num_layers) + ] + + # Precompute all possible positional embedding differences + self.pos_emb_diff_k = [] + self.pos_emb_diff_v = [] + + for layer in range(self.config.num_layers): + layer_pos_emb_diff_k = {} + layer_pos_emb_diff_v = {} + + for start in [2]: + for end in [self.context_length - 1]: + original_pos_emb_k = self.positional_embedding_k[layer][:, :, start:end, :] + new_pos_emb_k = self.positional_embedding_k[layer][:, :, :end - start, :] + layer_pos_emb_diff_k[(start, end)] = new_pos_emb_k - original_pos_emb_k + + original_pos_emb_v = self.positional_embedding_v[layer][:, :, start:end, :] + new_pos_emb_v = self.positional_embedding_v[layer][:, :, :end - start, :] + layer_pos_emb_diff_v[(start, end)] = new_pos_emb_v - original_pos_emb_v + + self.pos_emb_diff_k.append(layer_pos_emb_diff_k) + self.pos_emb_diff_v.append(layer_pos_emb_diff_v) + + #@profile + def _get_positional_embedding(self, layer, attn_type) -> torch.Tensor: + """ + Helper function to get positional embedding for a given layer and attention type. + + Arguments: + - layer (:obj:`int`): Layer index. + - attn_type (:obj:`str`): Attention type, either 'key' or 'value'. + + Returns: + - torch.Tensor: The positional embedding tensor. + """ + attn_func = getattr(self.transformer.blocks[layer].attn, attn_type) + if torch.cuda.is_available(): + return attn_func(self.pos_emb.weight).view( + 1, self.config.max_tokens, self.num_heads, self.embed_dim // self.num_heads + ).transpose(1, 2).to(self.device).detach() + else: + return attn_func(self.pos_emb.weight).view( + 1, self.config.max_tokens, self.num_heads, self.embed_dim // self.num_heads + ).transpose(1, 2).detach() + + #@profile + def forward(self, obs_embeddings_or_act_tokens: Dict[str, Union[torch.Tensor, tuple]], + past_keys_values: Optional[torch.Tensor] = None, + kvcache_independent: bool = False, is_init_infer: bool = True, + valid_context_lengths: Optional[torch.Tensor] = None, task_id=0) -> WorldModelOutput: + """ + Forward pass for the model. + + Arguments: + - obs_embeddings_or_act_tokens (:obj:`dict`): Dictionary containing observation embeddings or action tokens. + - past_keys_values (:obj:`Optional[torch.Tensor]`): Previous keys and values for transformer. + - kvcache_independent (:obj:`bool`): Whether to use independent key-value caching. + - is_init_infer (:obj:`bool`): Initialize inference. + - valid_context_lengths (:obj:`Optional[torch.Tensor]`): Valid context lengths. + Returns: + - WorldModelOutput: Model output containing logits for observations, rewards, policy, and value. 
+ """ + # task_embeddings = self.task_emb(torch.tensor(task_id, device=self.device)) # NOTE: TODO + self.task_embeddings = torch.zeros(self.config.embed_dim, device=self.device) # NOTE:TODO no task_embeddings ============= + + # Determine previous steps based on key-value caching method + if kvcache_independent: + prev_steps = torch.tensor([0 if past_keys_values is None else past_kv.size for past_kv in past_keys_values], + device=self.device) + else: + prev_steps = 0 if past_keys_values is None else past_keys_values.size + + # Reset valid_context_lengths during initial inference + if is_init_infer: + valid_context_lengths = None + + # Process observation embeddings + if 'obs_embeddings' in obs_embeddings_or_act_tokens: + obs_embeddings = obs_embeddings_or_act_tokens['obs_embeddings'] + if len(obs_embeddings.shape) == 2: + obs_embeddings = obs_embeddings.unsqueeze(1) + num_steps = obs_embeddings.size(1) + sequences = self._add_position_embeddings(obs_embeddings, prev_steps, num_steps, kvcache_independent, + is_init_infer, valid_context_lengths) + # TODO: multitask + sequences = sequences + self.task_embeddings + + # Process action tokens + elif 'act_tokens' in obs_embeddings_or_act_tokens: + act_tokens = obs_embeddings_or_act_tokens['act_tokens'] + if len(act_tokens.shape) == 3: + act_tokens = act_tokens.squeeze(1) + num_steps = act_tokens.size(1) + act_embeddings = self.act_embedding_table(act_tokens) + sequences = self._add_position_embeddings(act_embeddings, prev_steps, num_steps, kvcache_independent, + is_init_infer, valid_context_lengths) + + # TODO: multitask + # TODO: 对于action_token不需要增加task_embeddings会造成歧义,反而干扰学习 + self.task_embeddings = torch.zeros(self.config.embed_dim, device=self.device) + sequences = sequences + self.task_embeddings + + # Process combined observation embeddings and action tokens + else: + sequences, num_steps = self._process_obs_act_combined(obs_embeddings_or_act_tokens, prev_steps) + + # Pass sequences through transformer + x = self._transformer_pass(sequences, past_keys_values, kvcache_independent, valid_context_lengths) + + # Generate logits + + # 1,...,0,1 https://github.com/eloialonso/iris/issues/19 + # TODO: one head or moe head + if self.use_moe_head: + logits_observations = self.head_observations(x, num_steps=num_steps, prev_steps=prev_steps) + logits_rewards = self.head_rewards(x, num_steps=num_steps, prev_steps=prev_steps) + logits_policy = self.head_policy(x, num_steps=num_steps, prev_steps=prev_steps) + logits_value = self.head_value(x, num_steps=num_steps, prev_steps=prev_steps) + else: + # TODO: N head + logits_observations = self.head_observations_multi_task[task_id](x, num_steps=num_steps, prev_steps=prev_steps) + logits_rewards = self.head_rewards_multi_task[task_id](x, num_steps=num_steps, prev_steps=prev_steps) + logits_policy = self.head_policy_multi_task[task_id](x, num_steps=num_steps, prev_steps=prev_steps) + logits_value = self.head_value_multi_task[task_id](x, num_steps=num_steps, prev_steps=prev_steps) + + # logits_ends is None + return WorldModelOutput(x, logits_observations, logits_rewards, None, logits_policy, logits_value) + + + #@profile + def _add_position_embeddings(self, embeddings, prev_steps, num_steps, kvcache_independent, is_init_infer, + valid_context_lengths): + """ + Add position embeddings to the input embeddings. + + Arguments: + - embeddings (:obj:`torch.Tensor`): Input embeddings. + - prev_steps (:obj:`torch.Tensor`): Previous steps. + - num_steps (:obj:`int`): Number of steps. 
+ - kvcache_independent (:obj:`bool`): Whether to use independent key-value caching. + - is_init_infer (:obj:`bool`): Initialize inference. + - valid_context_lengths (:obj:`torch.Tensor`): Valid context lengths. + Returns: + - torch.Tensor: Embeddings with position information added. + """ + if kvcache_independent: + steps_indices = prev_steps + torch.arange(num_steps, device=embeddings.device) + position_embeddings = self.pos_emb(steps_indices).view(-1, num_steps, embeddings.shape[-1]) + return embeddings + position_embeddings + else: + if is_init_infer: + return embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=self.device)) + else: + valid_context_lengths = torch.tensor(self.keys_values_wm_size_list_current, device=self.device) + position_embeddings = self.pos_emb( + valid_context_lengths + torch.arange(num_steps, device=self.device)).unsqueeze(1) + return embeddings + position_embeddings + + #@profile + def _process_obs_act_combined(self, obs_embeddings_or_act_tokens, prev_steps): + """ + Process combined observation embeddings and action tokens. + + Arguments: + - obs_embeddings_or_act_tokens (:obj:`dict`): Dictionary containing combined observation embeddings and action tokens. + - prev_steps (:obj:`torch.Tensor`): Previous steps. + Returns: + - torch.Tensor: Combined observation and action embeddings with position information added. + """ + obs_embeddings, act_tokens = obs_embeddings_or_act_tokens['obs_embeddings_and_act_tokens'] + if len(obs_embeddings.shape) == 3: + obs_embeddings = obs_embeddings.view(act_tokens.shape[0], act_tokens.shape[1], self.num_observations_tokens, + -1) + + num_steps = int(obs_embeddings.size(1) * (obs_embeddings.size(2) + 1)) + act_embeddings = self.act_embedding_table(act_tokens) + + B, L, K, E = obs_embeddings.size() + obs_act_embeddings = torch.empty(B, L * (K + 1), E, device=self.device) + + for i in range(L): + # obs = obs_embeddings[:, i, :, :] + obs = obs_embeddings[:, i, :, :] + self.task_embeddings # Shape: (B, K, E) TODO: task_embeddings + act = act_embeddings[:, i, 0, :].unsqueeze(1) + obs_act = torch.cat([obs, act], dim=1) + obs_act_embeddings[:, i * (K + 1):(i + 1) * (K + 1), :] = obs_act + + return obs_act_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=self.device)), num_steps + + #@profile + def _transformer_pass(self, sequences, past_keys_values, kvcache_independent, valid_context_lengths): + """ + Pass sequences through the transformer. + + Arguments: + - sequences (:obj:`torch.Tensor`): Input sequences. + - past_keys_values (:obj:`Optional[torch.Tensor]`): Previous keys and values for transformer. + - kvcache_independent (:obj:`bool`): Whether to use independent key-value caching. + - valid_context_lengths (:obj:`torch.Tensor`): Valid context lengths. + Returns: + - torch.Tensor: Transformer output. + """ + if kvcache_independent: + x = [self.transformer(sequences[k].unsqueeze(0), past_kv, + valid_context_lengths=valid_context_lengths[k].unsqueeze(0)) for k, past_kv in + enumerate(past_keys_values)] + return torch.cat(x, dim=0) + else: + return self.transformer(sequences, past_keys_values, valid_context_lengths=valid_context_lengths) + + #@profile + @torch.no_grad() + def reset_for_initial_inference(self, obs_act_dict: torch.FloatTensor, task_id=0) -> torch.FloatTensor: + """ + Reset the model state based on initial observations and actions. + + Arguments: + - obs_act_dict (:obj:`torch.FloatTensor`): A dictionary containing 'obs', 'action', and 'current_obs'. 
+ Returns: + - torch.FloatTensor: The outputs from the world model and the latent state. + """ + # Extract observations, actions, and current observations from the dictionary. + if isinstance(obs_act_dict, dict): + batch_obs = obs_act_dict['obs'] + batch_action = obs_act_dict['action'] + batch_current_obs = obs_act_dict['current_obs'] + + # Encode observations to latent embeddings. + obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch_obs, task_id=task_id) + + if batch_current_obs is not None: + # ================ Collect and Evaluation Phase ================ + # Encode current observations to latent embeddings + current_obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch_current_obs, task_id=task_id) + # print(f"current_obs_embeddings.device: {current_obs_embeddings.device}") + self.latent_state = current_obs_embeddings + outputs_wm = self.wm_forward_for_initial_inference(obs_embeddings, batch_action, current_obs_embeddings, task_id=task_id) + else: + # ================ calculate the target value in Train phase ================ + self.latent_state = obs_embeddings + outputs_wm = self.wm_forward_for_initial_inference(obs_embeddings, batch_action, None, task_id=task_id) + + return outputs_wm, self.latent_state + + + #@profile + @torch.no_grad() + def wm_forward_for_initial_inference(self, last_obs_embeddings: torch.LongTensor, + batch_action=None, + current_obs_embeddings=None, task_id=0) -> torch.FloatTensor: + """ + Refresh key-value pairs with the initial latent state for inference. + + Arguments: + - last_obs_embeddings (:obj:`torch.LongTensor`): The latent state embeddings of the previous steps. + - batch_action (optional): Actions taken. + - current_obs_embeddings (optional): Current observation embeddings. + Returns: + - torch.FloatTensor: The outputs from the world model. 
+ """ + n, num_observations_tokens, _ = last_obs_embeddings.shape + if n <= self.env_num and current_obs_embeddings is not None: + # ================ Collect and Evaluation Phase ================ + if current_obs_embeddings is not None: + if self.continuous_action_space: + first_step_flag = not isinstance(batch_action[0], np.ndarray) + else: + first_step_flag = max(batch_action) == -1 + if first_step_flag: + # First step in an episode + self.keys_values_wm = self.transformer.generate_empty_keys_values(n=current_obs_embeddings.shape[0], + max_tokens=self.context_length) + # print(f"current_obs_embeddings.device: {current_obs_embeddings.device}") + outputs_wm = self.forward({'obs_embeddings': current_obs_embeddings}, + past_keys_values=self.keys_values_wm, is_init_infer=True, task_id=task_id) + + # Copy and store keys_values_wm for a single environment + self.update_cache_context(current_obs_embeddings, is_init_infer=True) + else: + # Assume latest_state is the new latent_state, containing information from ready_env_num environments + ready_env_num = current_obs_embeddings.shape[0] + self.keys_values_wm_list = [] + self.keys_values_wm_size_list = [] + for i in range(ready_env_num): + # Retrieve latent state for a single environment + state_single_env = last_obs_embeddings[i] + # Compute hash value using latent state for a single environment + cache_key = hash_state( + state_single_env.view(-1).cpu().numpy()) # last_obs_embeddings[i] is torch.Tensor + + # Retrieve cached value + cache_index = self.past_kv_cache_init_infer_envs[i].get(cache_key) + if cache_index is not None: + matched_value = self.shared_pool_init_infer[i][cache_index] + else: + matched_value = None + + self.root_total_query_cnt += 1 + if matched_value is not None: + # If a matching value is found, add it to the list + self.root_hit_cnt += 1 + # deepcopy is needed because forward modifies matched_value in place + self.keys_values_wm_list.append(self.custom_copy_kv_cache_to_shared_wm(matched_value)) + self.keys_values_wm_size_list.append(matched_value.size) + else: + # Reset using zero values + self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values(n=1, max_tokens=self.context_length) + outputs_wm = self.forward({'obs_embeddings': state_single_env.unsqueeze(0)}, + past_keys_values=self.keys_values_wm_single_env, + is_init_infer=True, task_id=task_id) + self.keys_values_wm_list.append(self.keys_values_wm_single_env) + self.keys_values_wm_size_list.append(1) + + # Input self.keys_values_wm_list, output self.keys_values_wm + self.keys_values_wm_size_list_current = self.trim_and_pad_kv_cache(is_init_infer=True) + + batch_action = batch_action[:ready_env_num] + # if ready_env_num < self.env_num: + # print(f'init inference ready_env_num: {ready_env_num} < env_num: {self.env_num}') + act_tokens = torch.from_numpy(np.array(batch_action)).to(last_obs_embeddings.device).unsqueeze(-1) + outputs_wm = self.forward({'act_tokens': act_tokens}, past_keys_values=self.keys_values_wm, + is_init_infer=True, task_id=task_id) + + outputs_wm = self.forward({'obs_embeddings': current_obs_embeddings}, + past_keys_values=self.keys_values_wm, is_init_infer=True, task_id=task_id) + + # Copy and store keys_values_wm for a single environment + self.update_cache_context(current_obs_embeddings, is_init_infer=True) + + # elif n > self.env_num and batch_action is not None and current_obs_embeddings is None: + elif batch_action is not None and current_obs_embeddings is None: + # ================ calculate the target value in Train phase 
================ + # [192, 16, 64] -> [32, 6, 16, 64] + last_obs_embeddings = last_obs_embeddings.contiguous().view(batch_action.shape[0], -1, num_observations_tokens, + self.obs_per_embdding_dim) # (BL, K) for unroll_step=1 + + last_obs_embeddings = last_obs_embeddings[:, :-1, :] + batch_action = torch.from_numpy(batch_action).to(last_obs_embeddings.device) + act_tokens = rearrange(batch_action, 'b l -> b l 1') + + # select the last timestep for each sample + # This will select the last column while keeping the dimensions unchanged, and the target policy/value in the final step itself is not used. + last_steps_act = act_tokens[:, -1:, :] + act_tokens = torch.cat((act_tokens, last_steps_act), dim=1) + + outputs_wm = self.forward({'obs_embeddings_and_act_tokens': (last_obs_embeddings, act_tokens)}, task_id=task_id) + + # select the last timestep for each sample + last_steps_value = outputs_wm.logits_value[:, -1:, :] + outputs_wm.logits_value = torch.cat((outputs_wm.logits_value, last_steps_value), dim=1) + + last_steps_policy = outputs_wm.logits_policy[:, -1:, :] + outputs_wm.logits_policy = torch.cat((outputs_wm.logits_policy, last_steps_policy), dim=1) + + # Reshape your tensors + # outputs_wm.logits_value.shape (B, H, 101) = (B*H, 101) + outputs_wm.logits_value = rearrange(outputs_wm.logits_value, 'b t e -> (b t) e') + outputs_wm.logits_policy = rearrange(outputs_wm.logits_policy, 'b t e -> (b t) e') + + return outputs_wm + + + #@profile + @torch.no_grad() + def forward_initial_inference(self, obs_act_dict, task_id=0): + """ + Perform initial inference based on the given observation-action dictionary. + + Arguments: + - obs_act_dict (:obj:`dict`): Dictionary containing observations and actions. + Returns: + - tuple: A tuple containing output sequence, latent state, logits rewards, logits policy, and logits value. + """ + # UniZero has context in the root node + outputs_wm, latent_state = self.reset_for_initial_inference(obs_act_dict, task_id=task_id) + self.past_kv_cache_recurrent_infer.clear() + + return (outputs_wm.output_sequence, latent_state, outputs_wm.logits_rewards, + outputs_wm.logits_policy, outputs_wm.logits_value) + + #@profile + @torch.no_grad() + def forward_recurrent_inference(self, state_action_history, simulation_index=0, + latent_state_index_in_search_path=[], task_id=0): + """ + Perform recurrent inference based on the state-action history. + + Arguments: + - state_action_history (:obj:`list`): List containing tuples of state and action history. + - simulation_index (:obj:`int`, optional): Index of the current simulation. Defaults to 0. + - latent_state_index_in_search_path (:obj:`list`, optional): List containing indices of latent states in the search path. Defaults to []. + Returns: + - tuple: A tuple containing output sequence, updated latent state, reward, logits policy, and logits value. 
+ """ + latest_state, action = state_action_history[-1] + ready_env_num = latest_state.shape[0] + + self.keys_values_wm_list = [] + self.keys_values_wm_size_list = [] + self.keys_values_wm_size_list = self.retrieve_or_generate_kvcache(latest_state, ready_env_num, simulation_index, task_id=task_id) + + latent_state_list = [] + token = action.reshape(-1, 1) + + # ======= Print statistics for debugging ============= + # min_size = min(self.keys_values_wm_size_list) + # if min_size >= self.config.max_tokens - 5: + # self.length_largethan_maxminus5_context_cnt += len(self.keys_values_wm_size_list) + # if min_size >= self.config.max_tokens - 7: + # self.length_largethan_maxminus7_context_cnt += len(self.keys_values_wm_size_list) + # if self.total_query_count > 0 and self.total_query_count % 10000 == 0: + # self.hit_freq = self.hit_count / self.total_query_count + # print('total_query_count:', self.total_query_count) + # length_largethan_maxminus5_context_cnt_ratio = self.length_largethan_maxminus5_context_cnt / self.total_query_count + # print('recurrent largethan_maxminus5_context:', self.length_largethan_maxminus5_context_cnt) + # print('recurrent largethan_maxminus5_context_ratio:', length_largethan_maxminus5_context_cnt_ratio) + # length_largethan_maxminus7_context_cnt_ratio = self.length_largethan_maxminus7_context_cnt / self.total_query_count + # print('recurrent largethan_maxminus7_context_ratio:', length_largethan_maxminus7_context_cnt_ratio) + # print('recurrent largethan_maxminus7_context:', self.length_largethan_maxminus7_context_cnt) + + # Trim and pad kv_cache + self.keys_values_wm_size_list = self.trim_and_pad_kv_cache(is_init_infer=False) + self.keys_values_wm_size_list_current = self.keys_values_wm_size_list + + for k in range(2): + # action_token obs_token + if k == 0: + obs_embeddings_or_act_tokens = {'act_tokens': token} + else: + obs_embeddings_or_act_tokens = {'obs_embeddings': token} + + # Perform forward pass + outputs_wm = self.forward( + obs_embeddings_or_act_tokens, + past_keys_values=self.keys_values_wm, + kvcache_independent=False, + is_init_infer=False, + task_id = task_id + ) + + self.keys_values_wm_size_list_current = [i + 1 for i in self.keys_values_wm_size_list_current] + + if k == 0: + reward = outputs_wm.logits_rewards # (B,) + + if k < self.num_observations_tokens: + token = outputs_wm.logits_observations + if len(token.shape) != 3: + token = token.unsqueeze(1) # (8,1024) -> (8,1,1024) + latent_state_list.append(token) + + del self.latent_state # Very important to minimize cuda memory usage + self.latent_state = torch.cat(latent_state_list, dim=1) # (B, K) + + self.update_cache_context( + self.latent_state, + is_init_infer=False, + simulation_index=simulation_index, + latent_state_index_in_search_path=latent_state_index_in_search_path + ) + + return (outputs_wm.output_sequence, self.latent_state, reward, outputs_wm.logits_policy, outputs_wm.logits_value) + + def trim_and_pad_kv_cache(self, is_init_infer=True) -> list: + """ + Adjusts the key-value cache for each environment to ensure they all have the same size. + + In a multi-environment setting, the key-value cache (kv_cache) for each environment is stored separately. + During recurrent inference, the kv_cache sizes may vary across environments. This method pads each kv_cache + to match the largest size found among them, facilitating batch processing in the transformer forward pass. + + Arguments: + - is_init_infer (:obj:`bool`): Indicates if this is an initial inference. Default is True. 
+ Returns: + - list: Updated sizes of the key-value caches. + """ + # Find the maximum size among all key-value caches + max_size = max(self.keys_values_wm_size_list) + + # Iterate over each layer of the transformer + for layer in range(self.num_layers): + kv_cache_k_list = [] + kv_cache_v_list = [] + + # Enumerate through each environment's key-value pairs + for idx, keys_values in enumerate(self.keys_values_wm_list): + k_cache = keys_values[layer]._k_cache._cache + v_cache = keys_values[layer]._v_cache._cache + + effective_size = self.keys_values_wm_size_list[idx] + pad_size = max_size - effective_size + + # If padding is required, trim the end and pad the beginning of the cache + if pad_size > 0: + k_cache_trimmed = k_cache[:, :, :-pad_size, :] + v_cache_trimmed = v_cache[:, :, :-pad_size, :] + k_cache_padded = F.pad(k_cache_trimmed, (0, 0, pad_size, 0), "constant", 0) + v_cache_padded = F.pad(v_cache_trimmed, (0, 0, pad_size, 0), "constant", 0) + else: + k_cache_padded = k_cache + v_cache_padded = v_cache + + kv_cache_k_list.append(k_cache_padded) + kv_cache_v_list.append(v_cache_padded) + + # Stack the caches along a new dimension and remove any extra dimensions + self.keys_values_wm._keys_values[layer]._k_cache._cache = torch.stack(kv_cache_k_list, dim=0).squeeze(1) + self.keys_values_wm._keys_values[layer]._v_cache._cache = torch.stack(kv_cache_v_list, dim=0).squeeze(1) + + # Update the cache size to the maximum size + self.keys_values_wm._keys_values[layer]._k_cache._size = max_size + self.keys_values_wm._keys_values[layer]._v_cache._size = max_size + + return self.keys_values_wm_size_list + + #@profile + def update_cache_context(self, latent_state, is_init_infer=True, simulation_index=0, + latent_state_index_in_search_path=[], valid_context_lengths=None): + """ + Update the cache context with the given latent state. + + Arguments: + - latent_state (:obj:`torch.Tensor`): The latent state tensor. + - is_init_infer (:obj:`bool`): Flag to indicate if this is the initial inference. + - simulation_index (:obj:`int`): Index of the simulation. + - latent_state_index_in_search_path (:obj:`list`): List of indices in the search path. + - valid_context_lengths (:obj:`list`): List of valid context lengths. + """ + if self.context_length <= 2: + # No context to update if the context length is less than or equal to 2. 
+ return + for i in range(latent_state.size(0)): + # ============ Iterate over each environment ============ + cache_key = hash_state(latent_state[i].view(-1).cpu().numpy()) # latent_state[i] is torch.Tensor + context_length = self.context_length + + if not is_init_infer: + # ============ Internal Node ============ + # Retrieve KV from global KV cache self.keys_values_wm to single environment KV cache self.keys_values_wm_single_env, ensuring correct positional encoding + current_max_context_length = max(self.keys_values_wm_size_list_current) + trim_size = current_max_context_length - self.keys_values_wm_size_list_current[i] + for layer in range(self.num_layers): + # ============ Apply trimming and padding to each layer of kv_cache ============ + # cache shape [batch_size, num_heads, sequence_length, features] + k_cache_current = self.keys_values_wm._keys_values[layer]._k_cache._cache[i] + v_cache_current = self.keys_values_wm._keys_values[layer]._v_cache._cache[i] + + if trim_size > 0: + # Trim invalid leading zeros as per effective length + # Remove the first trim_size zero kv items + k_cache_trimmed = k_cache_current[:, trim_size:, :] + v_cache_trimmed = v_cache_current[:, trim_size:, :] + # If effective length < current_max_context_length, pad the end of cache with 'trim_size' zeros + k_cache_padded = F.pad(k_cache_trimmed, (0, 0, 0, trim_size), "constant", + 0) # Pad with 'trim_size' zeros at end of cache + v_cache_padded = F.pad(v_cache_trimmed, (0, 0, 0, trim_size), "constant", 0) + else: + k_cache_padded = k_cache_current + v_cache_padded = v_cache_current + + # Update cache of self.keys_values_wm_single_env + self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache = k_cache_padded.unsqueeze(0) + self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache = v_cache_padded.unsqueeze(0) + # Update size of self.keys_values_wm_single_env + self.keys_values_wm_single_env._keys_values[layer]._k_cache._size = \ + self.keys_values_wm_size_list_current[i] + self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = \ + self.keys_values_wm_size_list_current[i] + + # ============ NOTE: Very Important ============ + if self.keys_values_wm_single_env._keys_values[layer]._k_cache._size >= context_length - 1: + # Keep only the last self.context_length-3 timesteps of context + # For memory environments, training is for H steps, recurrent_inference might exceed H steps + # Assuming cache dimension is [batch_size, num_heads, sequence_length, features] + k_cache_current = self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache + v_cache_current = self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache + + # Remove the first 2 steps, keep the last self.context_length-3 steps + k_cache_trimmed = k_cache_current[:, :, 2:context_length - 1, :].squeeze(0) + v_cache_trimmed = v_cache_current[:, :, 2:context_length - 1, :].squeeze(0) + + # Index pre-computed positional encoding differences + pos_emb_diff_k = self.pos_emb_diff_k[layer][(2, context_length - 1)] + pos_emb_diff_v = self.pos_emb_diff_v[layer][(2, context_length - 1)] + # ============ NOTE: Very Important ============ + # Apply positional encoding correction to k and v + k_cache_trimmed += pos_emb_diff_k.squeeze(0) + v_cache_trimmed += pos_emb_diff_v.squeeze(0) + + # Pad the last 3 steps along the third dimension with zeros + # F.pad parameters (0, 0, 0, 3) specify padding amounts for each dimension: (left, right, top, bottom). 
For 3D tensor, they correspond to (dim2 left, dim2 right, dim1 left, dim1 right). + padding_size = (0, 0, 0, 3) + k_cache_padded = F.pad(k_cache_trimmed, padding_size, 'constant', 0) + v_cache_padded = F.pad(v_cache_trimmed, padding_size, 'constant', 0) + # Update single environment cache + self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache = k_cache_padded.unsqueeze(0) + self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache = v_cache_padded.unsqueeze(0) + + self.keys_values_wm_single_env._keys_values[layer]._k_cache._size = context_length - 3 + self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = context_length - 3 + + else: + # ============ Root Node ============ + # Retrieve KV from global KV cache self.keys_values_wm to single environment KV cache self.keys_values_wm_single_env, ensuring correct positional encoding + + for layer in range(self.num_layers): + # ============ Apply trimming and padding to each layer of kv_cache ============ + + if self.keys_values_wm._keys_values[layer]._k_cache._size < context_length - 1: # Keep only the last self.context_length-1 timesteps of context + self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache = self.keys_values_wm._keys_values[layer]._k_cache._cache[i].unsqueeze(0) # Shape torch.Size([2, 100, 512]) + self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache = self.keys_values_wm._keys_values[layer]._v_cache._cache[i].unsqueeze(0) + self.keys_values_wm_single_env._keys_values[layer]._k_cache._size = self.keys_values_wm._keys_values[layer]._k_cache._size + self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = self.keys_values_wm._keys_values[layer]._v_cache._size + else: + # Assuming cache dimension is [batch_size, num_heads, sequence_length, features] + k_cache_current = self.keys_values_wm._keys_values[layer]._k_cache._cache[i] + v_cache_current = self.keys_values_wm._keys_values[layer]._v_cache._cache[i] + + # Remove the first 2 steps, keep the last self.context_length-3 steps + k_cache_trimmed = k_cache_current[:, 2:context_length - 1, :] + v_cache_trimmed = v_cache_current[:, 2:context_length - 1, :] + + # Index pre-computed positional encoding differences + pos_emb_diff_k = self.pos_emb_diff_k[layer][(2, context_length - 1)] + pos_emb_diff_v = self.pos_emb_diff_v[layer][(2, context_length - 1)] + # ============ NOTE: Very Important ============ + # Apply positional encoding correction to k and v + k_cache_trimmed += pos_emb_diff_k.squeeze(0) + v_cache_trimmed += pos_emb_diff_v.squeeze(0) + + # Pad the last 3 steps along the third dimension with zeros + # F.pad parameters (0, 0, 0, 3) specify padding amounts for each dimension: (left, right, top, bottom). For 3D tensor, they correspond to (dim2 left, dim2 right, dim1 left, dim1 right). 
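+                        # Illustrative example (hypothetical shapes): for a trimmed cache of shape
+                        # (num_heads=8, seq=17, head_dim=64), F.pad(x, (0, 0, 0, 3)) leaves the feature
+                        # dimension unchanged and appends 3 zero timesteps, giving shape (8, 20, 64).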
+ padding_size = (0, 0, 0, 3) + k_cache_padded = F.pad(k_cache_trimmed, padding_size, 'constant', 0) + v_cache_padded = F.pad(v_cache_trimmed, padding_size, 'constant', 0) + # Update cache of self.keys_values_wm_single_env + self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache = k_cache_padded.unsqueeze(0) + self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache = v_cache_padded.unsqueeze(0) + # Update size of self.keys_values_wm_single_env + self.keys_values_wm_single_env._keys_values[layer]._k_cache._size = context_length - 3 + self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = context_length - 3 + + if is_init_infer: + # Store the latest key-value cache for initial inference + cache_index = self.custom_copy_kv_cache_to_shared_init_envs(self.keys_values_wm_single_env, i) + self.past_kv_cache_init_infer_envs[i][cache_key] = cache_index + else: + # Store the latest key-value cache for recurrent inference + cache_index = self.custom_copy_kv_cache_to_shared_recur(self.keys_values_wm_single_env) + self.past_kv_cache_recurrent_infer[cache_key] = cache_index + + #@profile + def retrieve_or_generate_kvcache(self, latent_state: list, ready_env_num: int, + simulation_index: int = 0, task_id=0) -> list: + """ + Retrieves or generates key-value caches for each environment based on the latent state. + + For each environment, this method either retrieves a matching cache from the predefined + caches if available, or generates a new cache if no match is found. The method updates + the internal lists with these caches and their sizes. + + Arguments: + - latent_state (:obj:`list`): List of latent states for each environment. + - ready_env_num (:obj:`int`): Number of environments ready for processing. + - simulation_index (:obj:`int`, optional): Index for simulation tracking. Default is 0. + Returns: + - list: Sizes of the key-value caches for each environment. 
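+              When no cached entry matches, a new cache is generated by a fresh ``forward`` call
+              (conditioned on ``task_id``), and the corresponding size in the returned list is 1.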
+        """
+        for i in range(ready_env_num):
+            self.total_query_count += 1
+            state_single_env = latent_state[i]  # latent_state[i] is np.array
+            cache_key = hash_state(state_single_env)
+
+            if self.reanalyze_phase:
+                # TODO: check if this is correct
+                matched_value = None
+            else:
+                # Try to retrieve the cached value from past_kv_cache_init_infer_envs
+                cache_index = self.past_kv_cache_init_infer_envs[i].get(cache_key)
+                if cache_index is not None:
+                    matched_value = self.shared_pool_init_infer[i][cache_index]
+                else:
+                    matched_value = None
+
+                # If not found, try to retrieve from past_kv_cache_recurrent_infer.
+                # Guard against a miss here as well: indexing the shared pool with None would raise.
+                if matched_value is None:
+                    recur_cache_index = self.past_kv_cache_recurrent_infer.get(cache_key)
+                    matched_value = self.shared_pool_recur_infer[recur_cache_index] if recur_cache_index is not None else None
+
+            if matched_value is not None:
+                # If a matching cache is found, add it to the lists
+                self.hit_count += 1
+                # Perform a deep copy because the transformer's forward pass might modify matched_value in-place
+                self.keys_values_wm_list.append(self.custom_copy_kv_cache_to_shared_wm(matched_value))
+                self.keys_values_wm_size_list.append(matched_value.size)
+            else:
+                # If no matching cache is found, generate a new one using zero reset
+                self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values(
+                    n=1, max_tokens=self.context_length
+                )
+                self.forward(
+                    {'obs_embeddings': torch.from_numpy(state_single_env).unsqueeze(0).to(self.device)},
+                    past_keys_values=self.keys_values_wm_single_env, is_init_infer=True, task_id=task_id
+                )
+                self.keys_values_wm_list.append(self.keys_values_wm_single_env)
+                self.keys_values_wm_size_list.append(1)
+
+        return self.keys_values_wm_size_list
+
+    def plot_embeddings(self, tsne_results, task_ids, observations, save_dir='tsne_plots'):
+        """
+        Generate a t-SNE visualization and annotate a random subset of points with their observation images.
+
+        Arguments:
+            - tsne_results: t-SNE projection of the embeddings (array of shape N x 2).
+            - task_ids: environment/task IDs used for coloring (array of length N).
+            - observations: corresponding observations (tensor or array of shape N x C x H x W).
+            - save_dir: output directory, defaults to 'tsne_plots'.
+        """
+        # Create the output directory
+        os.makedirs(save_dir, exist_ok=True)
+        print(f"[INFO] Save directory created (or already exists): {save_dir}")
+
+        # Draw the t-SNE scatter plot
+        print("[INFO] Plotting the t-SNE scatter plot...")
+        plt.figure(figsize=(16, 10))
+        scatter = plt.scatter(tsne_results[:, 0], tsne_results[:, 1], c=task_ids, cmap='tab10', alpha=0.6)
+        plt.legend(*scatter.legend_elements(), title="Env IDs")
+        plt.title("t-SNE of Observations Embeddings across Environments")
+        print(f"[INFO] t-SNE scatter plot finished, {len(tsne_results)} points in total.")
+
+        # Annotate representative points with their images (randomly pick up to 10 points)
+        num_images = 10
+        if len(tsne_results) > num_images:
+            print(f"[INFO] Number of points ({len(tsne_results)}) exceeds {num_images}, randomly selecting {num_images} of them for annotation...")
+            indices = np.random.choice(range(len(tsne_results)), size=num_images, replace=False)
+        else:
+            print(f"[INFO] Number of points ({len(tsne_results)}) is at most {num_images}, annotating all of them...")
+            indices = range(len(tsne_results))
+
+        for idx in indices:
+            img = observations[idx]
+            if isinstance(img, torch.Tensor):
+                img = img.cpu().numpy()
+            if img.shape[0] in [1, 3]:  # Handle grayscale or RGB images
+                img = np.transpose(img, (1, 2, 0))
+
+            imagebox = OffsetImage(img, zoom=0.5)
+            ab = AnnotationBbox(imagebox, (tsne_results[idx, 0], tsne_results[idx, 1]), frameon=False, pad=0.5)
+            plt.gca().add_artist(ab)
+            print(f"[INFO] Added image annotation: point index {idx}, t-SNE coordinates ({tsne_results[idx, 0]:.2f}, {tsne_results[idx, 1]:.2f})")
+
+        # Save the figure
+        save_path = os.path.join(save_dir, 'tsne_plot.png')
+        plt.savefig(save_path)
+        print(f"[INFO] t-SNE visualization saved to: {save_path}")
+        plt.close()
+
+    @torch.no_grad()
+    def gather_and_plot(self, local_embeddings, local_task_ids, local_observations):
+        world_size = 
dist.get_world_size() + rank = dist.get_rank() + + # 准备接收来自所有进程的CUDA张量 + embeddings_list = [torch.zeros_like(local_embeddings) for _ in range(world_size)] + task_ids_list = [torch.zeros_like(local_task_ids) for _ in range(world_size)] + + # 准备接收来自所有进程的CPU对象 + observations_list = [None for _ in range(world_size)] + + + try: + # 收集CUDA张量:embeddings和task_ids + dist.all_gather(embeddings_list, local_embeddings) + dist.all_gather(task_ids_list, local_task_ids) + + # 收集CPU对象:observations + local_observations_cpu = local_observations.cpu().numpy().tolist() + dist.all_gather_object(observations_list, local_observations_cpu) + except RuntimeError as e: + print(f"Rank {rank}: all_gather failed with error: {e}") + return + + if rank == 0: + # 拼接所有embeddings和task_ids + all_embeddings = torch.cat(embeddings_list, dim=0).cpu().numpy() + all_task_ids = torch.cat(task_ids_list, dim=0).cpu().numpy() + + # 拼接所有observations + all_observations = [] + for obs in observations_list: + all_observations.extend(obs) + all_observations = np.array(all_observations) + + print(f"Shape of all_embeddings: {all_embeddings.shape}") + all_embeddings = all_embeddings.reshape(-1, all_embeddings.shape[-1]) + print(f"Shape of all_observations: {all_observations.shape}") + all_observations = all_observations.reshape(-1, *all_observations.shape[-3:]) + + # 执行t-SNE降维 + tsne = TSNE(n_components=2, random_state=42) + tsne_results = tsne.fit_transform(all_embeddings) + + # 绘制并保存图像 + self.plot_embeddings(tsne_results, all_task_ids, all_observations) + + #@profile + def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar_transform_handle=None, task_id=0, **kwargs: Any) -> LossWithIntermediateLosses: + # Encode observations into latent state representations + obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch['observations'], task_id=task_id) + + if self.analysis_mode: + # 确保embeddings在CUDA设备上且为稠密张量 + if not obs_embeddings.is_cuda: + obs_embeddings = obs_embeddings.cuda() + obs_embeddings = obs_embeddings.contiguous() + + # 保存当前进程的 embeddings 和 task_id + local_embeddings = obs_embeddings.detach() + local_task_ids = torch.full((local_embeddings.size(0),), task_id, dtype=torch.long, device=local_embeddings.device) + + # 将observations移到CPU并转换为numpy + local_observations = batch['observations'].detach().cpu() + + # 进行数据收集和可视化 + self.gather_and_plot(local_embeddings, local_task_ids, local_observations) + + # ========= logging for analysis ========= + if self.analysis_dormant_ratio: + # Calculate dormant ratio of the encoder + shape = batch['observations'].shape # (..., C, H, W) + inputs = batch['observations'].contiguous().view(-1, *shape[-3:]) # (32,5,3,64,64) -> (160,3,64,64) + dormant_ratio_encoder = cal_dormant_ratio(self.tokenizer.representation_network, inputs.detach(), + percentage=self.dormant_threshold) + self.past_kv_cache_init_infer.clear() + self.past_kv_cache_recurrent_infer.clear() + self.keys_values_wm_list.clear() + torch.cuda.empty_cache() + else: + dormant_ratio_encoder = torch.tensor(0.) 
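+
+        # NOTE: the dormant ratio is a purely diagnostic statistic (roughly, the fraction of units whose
+        # activations fall below dormant_threshold); it is only logged and never enters the training loss.
+        # The cache clearing and torch.cuda.empty_cache() above presumably serve to release memory held by
+        # the extra analysis forward pass.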
+ + # Calculate the L2 norm of the latent state roots + latent_state_l2_norms = torch.norm(obs_embeddings, p=2, dim=2).mean() + + if self.obs_type == 'image': + # Reconstruct observations from latent state representations + # reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings) + + # ========== for visualization ========== + # Uncomment the lines below for visual analysis + # original_images, reconstructed_images = batch['observations'], reconstructed_images + # target_policy = batch['target_policy'] + # ==== for value priority ==== + # target_predict_value = inverse_scalar_transform_handle(batch['target_value'].reshape(-1, 101)).reshape( + # batch['observations'].shape[0], batch['observations'].shape[1], 1) + # true_rewards = inverse_scalar_transform_handle(batch['rewards'].reshape(-1, 101)).reshape( + # batch['observations'].shape[0], batch['observations'].shape[1], 1) + # ========== for visualization ========== + + # Calculate reconstruction loss and perceptual loss + # latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # NOTE: for stack=1 + # perceptual_loss = self.tokenizer.perceptual_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # NOTE: for stack=1 + latent_recon_loss = torch.tensor(0., device=batch['observations'].device, + dtype=batch['observations'].dtype) + perceptual_loss = torch.tensor(0., device=batch['observations'].device, + dtype=batch['observations'].dtype) + + elif self.obs_type == 'vector': + perceptual_loss = torch.tensor(0., device=batch['observations'].device, + dtype=batch['observations'].dtype) + + # Reconstruct observations from latent state representations + # reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings.reshape(-1, self.embed_dim)) + # # Calculate reconstruction loss + # latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 25), + # reconstructed_images) + latent_recon_loss = torch.tensor(0., device=batch['observations'].device, + dtype=batch['observations'].dtype) + + elif self.obs_type == 'image_memory': + # Reconstruct observations from latent state representations + # reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings) + # original_images, reconstructed_images = batch['observations'], reconstructed_images + + # ========== for visualization ========== + # Uncomment the lines below for visual analysis + # target_policy = batch['target_policy'] + # target_predict_value = inverse_scalar_transform_handle(batch['target_value'].reshape(-1, 101)).reshape( + # batch['observations'].shape[0], batch['observations'].shape[1], 1) + # true_rewards = inverse_scalar_transform_handle(batch['rewards'].reshape(-1, 101)).reshape( + # batch['observations'].shape[0], batch['observations'].shape[1], 1) + # ========== for visualization ========== + + # Calculate reconstruction loss and perceptual loss + # latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 3, 5, 5), + # reconstructed_images) + + latent_recon_loss = torch.tensor(0., device=batch['observations'].device, + dtype=batch['observations'].dtype) + perceptual_loss = torch.tensor(0., device=batch['observations'].device, + dtype=batch['observations'].dtype) + + # Action tokens + act_tokens = rearrange(batch['actions'], 'b l -> b l 1') + + # Forward pass to obtain predictions for observations, rewards, and policies + outputs = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)}, 
task_id=task_id) + + # ========= logging for analysis ========= + if self.analysis_dormant_ratio: + # Calculate dormant ratio of the world model + dormant_ratio_world_model = cal_dormant_ratio(self, { + 'obs_embeddings_and_act_tokens': (obs_embeddings.detach(), act_tokens.detach())}, + percentage=self.dormant_threshold) + self.past_kv_cache_init_infer.clear() + self.past_kv_cache_recurrent_infer.clear() + self.keys_values_wm_list.clear() + torch.cuda.empty_cache() + else: + dormant_ratio_world_model = torch.tensor(0.) + + # ========== for visualization ========== + # Uncomment the lines below for visualization + # predict_policy = outputs.logits_policy + # predict_policy = F.softmax(outputs.logits_policy, dim=-1) + # predict_value = inverse_scalar_transform_handle(outputs.logits_value.reshape(-1, 101)).reshape(batch['observations'].shape[0], batch['observations'].shape[1], 1) + # predict_rewards = inverse_scalar_transform_handle(outputs.logits_rewards.reshape(-1, 101)).reshape(batch['observations'].shape[0], batch['observations'].shape[1], 1) + # import pdb; pdb.set_trace() + # visualize_reward_value_img_policy(original_images, reconstructed_images, target_predict_value, true_rewards, target_policy, predict_value, predict_rewards, predict_policy, not_plot_timesteps=[], suffix='pong_H10_H4_0613') + + # visualize_reward_value_img_policy(original_images, reconstructed_images, target_predict_value, true_rewards, target_policy, predict_value, predict_rewards, predict_policy, not_plot_timesteps=list(np.arange(4,60)), suffix='visual_match_memlen1-60-15/one_success_episode') + # visualize_reward_value_img_policy(original_images, reconstructed_images, target_predict_value, true_rewards, target_policy, predict_value, predict_rewards, predict_policy, not_plot_timesteps=list(np.arange(4,60)), suffix='visual_match_memlen1-60-15/one_fail_episode') + # ========== for visualization ========== + + # For training stability, use target_tokenizer to compute the true next latent state representations + with torch.no_grad(): + target_obs_embeddings = target_tokenizer.encode_to_obs_embeddings(batch['observations'], task_id=task_id) + + # Compute labels for observations, rewards, and ends + labels_observations, labels_rewards, labels_ends = self.compute_labels_world_model(target_obs_embeddings, + batch['rewards'], + batch['ends'], + batch['mask_padding']) + + # Reshape the logits and labels for observations + logits_observations = rearrange(outputs.logits_observations[:, :-1], 'b t o -> (b t) o') + labels_observations = labels_observations.reshape(-1, self.projection_input_dim) + + # Compute prediction loss for observations. 
Options: MSE and Group KL + if self.predict_latent_loss_type == 'mse': + # MSE loss, directly compare logits and labels + loss_obs = torch.nn.functional.mse_loss(logits_observations, labels_observations, reduction='none').mean( + -1) + elif self.predict_latent_loss_type == 'group_kl': + # Group KL loss, group features and calculate KL divergence within each group + batch_size, num_features = logits_observations.shape + epsilon = 1e-6 + logits_reshaped = logits_observations.reshape(batch_size, self.num_groups, self.group_size) + epsilon + labels_reshaped = labels_observations.reshape(batch_size, self.num_groups, self.group_size) + epsilon + + loss_obs = F.kl_div(logits_reshaped.log(), labels_reshaped, reduction='none').sum(dim=-1).mean(dim=-1) + + # ========== for debugging ========== + # print('loss_obs:', loss_obs.mean()) + # assert not torch.isnan(loss_obs).any(), "loss_obs contains NaN values" + # assert not torch.isinf(loss_obs).any(), "loss_obs contains Inf values" + # for name, param in self.tokenizer.representation_network.named_parameters(): + # print('name, param.mean(), param.std():', name, param.mean(), param.std()) + + # Apply mask to loss_obs + mask_padding_expanded = batch['mask_padding'][:, 1:].contiguous().view(-1) + loss_obs = (loss_obs * mask_padding_expanded) + + # Compute labels for policy and value + labels_policy, labels_value = self.compute_labels_world_model_value_policy(batch['target_value'], + batch['target_policy'], + batch['mask_padding']) + + # Compute losses for rewards, policy, and value + loss_rewards = self.compute_cross_entropy_loss(outputs, labels_rewards, batch, element='rewards') + loss_policy, orig_policy_loss, policy_entropy = self.compute_cross_entropy_loss(outputs, labels_policy, batch, + element='policy') + loss_value = self.compute_cross_entropy_loss(outputs, labels_value, batch, element='value') + + # Compute timesteps + timesteps = torch.arange(batch['actions'].shape[1], device=batch['actions'].device) + # Compute discount coefficients for each timestep + discounts = self.gamma ** timesteps + + if batch['mask_padding'].sum() == 0: + assert False, "mask_padding is all zeros" + + # Group losses into first step, middle step, and last step + first_step_losses = {} + middle_step_losses = {} + last_step_losses = {} + # batch['mask_padding'] indicates mask status for future H steps, exclude masked losses to maintain accurate mean statistics + # Group losses for each loss item + for loss_name, loss_tmp in zip( + ['loss_obs', 'loss_rewards', 'loss_value', 'loss_policy', 'orig_policy_loss', 'policy_entropy'], + [loss_obs, loss_rewards, loss_value, loss_policy, orig_policy_loss, policy_entropy] + ): + if loss_name == 'loss_obs': + seq_len = batch['actions'].shape[1] - 1 + # Get the corresponding mask_padding + mask_padding = batch['mask_padding'][:, 1:seq_len] + else: + seq_len = batch['actions'].shape[1] + # Get the corresponding mask_padding + mask_padding = batch['mask_padding'][:, :seq_len] + + # Adjust loss shape to (batch_size, seq_len) + loss_tmp = loss_tmp.view(-1, seq_len) + + # First step loss + first_step_mask = mask_padding[:, 0] + first_step_losses[loss_name] = loss_tmp[:, 0][first_step_mask].mean() + + # Middle step loss + middle_step_index = seq_len // 2 + middle_step_mask = mask_padding[:, middle_step_index] + middle_step_losses[loss_name] = loss_tmp[:, middle_step_index][middle_step_mask].mean() + + # Last step loss + last_step_mask = mask_padding[:, -1] + last_step_losses[loss_name] = loss_tmp[:, -1][last_step_mask].mean() + + # Discount 
reconstruction loss and perceptual loss + discounted_latent_recon_loss = latent_recon_loss + discounted_perceptual_loss = perceptual_loss + + # Calculate overall discounted loss + discounted_loss_obs = (loss_obs.view(-1, batch['actions'].shape[1] - 1) * discounts[1:]).sum()/ batch['mask_padding'][:,1:].sum() + discounted_loss_rewards = (loss_rewards.view(-1, batch['actions'].shape[1]) * discounts).sum()/ batch['mask_padding'].sum() + discounted_loss_value = (loss_value.view(-1, batch['actions'].shape[1]) * discounts).sum()/ batch['mask_padding'].sum() + discounted_loss_policy = (loss_policy.view(-1, batch['actions'].shape[1]) * discounts).sum()/ batch['mask_padding'].sum() + discounted_orig_policy_loss = (orig_policy_loss.view(-1, batch['actions'].shape[1]) * discounts).sum()/ batch['mask_padding'].sum() + discounted_policy_entropy = (policy_entropy.view(-1, batch['actions'].shape[1]) * discounts).sum()/ batch['mask_padding'].sum() + + if self.continuous_action_space: + return LossWithIntermediateLosses( + latent_recon_loss_weight=self.latent_recon_loss_weight, + perceptual_loss_weight=self.perceptual_loss_weight, + continuous_action_space=True, + loss_obs=discounted_loss_obs, + loss_rewards=discounted_loss_rewards, + loss_value=discounted_loss_value, + loss_policy=discounted_loss_policy, + latent_recon_loss=discounted_latent_recon_loss, + perceptual_loss=discounted_perceptual_loss, + orig_policy_loss=discounted_orig_policy_loss, + policy_entropy=discounted_policy_entropy, + first_step_losses=first_step_losses, + middle_step_losses=middle_step_losses, + last_step_losses=last_step_losses, + dormant_ratio_encoder=dormant_ratio_encoder, + dormant_ratio_world_model=dormant_ratio_world_model, + latent_state_l2_norms=latent_state_l2_norms, + policy_mu=mu, + policy_sigma=sigma, + target_sampled_actions=target_sampled_actions, + ) + else: + return LossWithIntermediateLosses( + latent_recon_loss_weight=self.latent_recon_loss_weight, + perceptual_loss_weight=self.perceptual_loss_weight, + continuous_action_space=False, + loss_obs=discounted_loss_obs, + loss_rewards=discounted_loss_rewards, + loss_value=discounted_loss_value, + loss_policy=discounted_loss_policy, + latent_recon_loss=discounted_latent_recon_loss, + perceptual_loss=discounted_perceptual_loss, + orig_policy_loss=discounted_orig_policy_loss, + policy_entropy=discounted_policy_entropy, + first_step_losses=first_step_losses, + middle_step_losses=middle_step_losses, + last_step_losses=last_step_losses, + dormant_ratio_encoder=dormant_ratio_encoder, + dormant_ratio_world_model=dormant_ratio_world_model, + latent_state_l2_norms=latent_state_l2_norms, + ) + + #@profile + def compute_cross_entropy_loss(self, outputs, labels, batch, element='rewards'): + # Assume outputs is an object with logits attributes like 'rewards', 'policy', and 'value'. + # labels is a target tensor for comparison. batch is a dictionary with a mask indicating valid timesteps. + + logits = getattr(outputs, f'logits_{element}') + + # Reshape your tensors + logits = rearrange(logits, 'b t e -> (b t) e') + labels = labels.reshape(-1, labels.shape[-1]) # Assume labels initially have shape [batch, time, dim] + + # Reshape your mask. True indicates valid data. 
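+        # The label tensors were filled with -100 at padded positions in compute_labels_world_model*;
+        # multiplying the per-timestep loss by mask_padding below zeroes those entries, so the fill
+        # value never leaks into the reported loss.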
+ mask_padding = rearrange(batch['mask_padding'], 'b t -> (b t)') + + # Compute cross-entropy loss + loss = -(torch.log_softmax(logits, dim=1) * labels).sum(1) + loss = (loss * mask_padding) + + # if torch.isnan(loss).any(): + # raise ValueError(f"NaN detected in outputs for batch {batch} and element '{element}'") + + if element == 'policy': + # Compute policy entropy loss + policy_entropy = self.compute_policy_entropy_loss(logits, mask_padding) + # Combine losses with specified weight + combined_loss = loss - self.policy_entropy_weight * policy_entropy + return combined_loss, loss, policy_entropy + + return loss + + #@profile + def compute_policy_entropy_loss(self, logits, mask): + # Compute entropy of the policy + probs = torch.softmax(logits, dim=1) + log_probs = torch.log_softmax(logits, dim=1) + entropy = -(probs * log_probs).sum(1) + # Apply mask and return average entropy loss + entropy_loss = (entropy * mask) + return entropy_loss + + #@profile + def compute_labels_world_model(self, obs_embeddings: torch.Tensor, rewards: torch.Tensor, ends: torch.Tensor, + mask_padding: torch.BoolTensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # assert torch.all(ends.sum(dim=1) <= 1) # Each sequence sample should have at most one 'done' flag + mask_fill = torch.logical_not(mask_padding) + + # Prepare observation labels + labels_observations = obs_embeddings.contiguous().view(rewards.shape[0], -1, self.projection_input_dim)[:, 1:] + + # Fill the masked areas of rewards + mask_fill_rewards = mask_fill.unsqueeze(-1).expand_as(rewards) + labels_rewards = rewards.masked_fill(mask_fill_rewards, -100) + + # Fill the masked areas of ends + # labels_ends = ends.masked_fill(mask_fill, -100) + + # return labels_observations, labels_rewards.reshape(-1, self.support_size), labels_ends.reshape(-1) + return labels_observations, labels_rewards.view(-1, self.support_size), None + + #@profile + def compute_labels_world_model_value_policy(self, target_value: torch.Tensor, target_policy: torch.Tensor, + mask_padding: torch.BoolTensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ Compute labels for value and policy predictions. """ + mask_fill = torch.logical_not(mask_padding) + + # Fill the masked areas of policy + mask_fill_policy = mask_fill.unsqueeze(-1).expand_as(target_policy) + labels_policy = target_policy.masked_fill(mask_fill_policy, -100) + + # Fill the masked areas of value + mask_fill_value = mask_fill.unsqueeze(-1).expand_as(target_value) + labels_value = target_value.masked_fill(mask_fill_value, -100) + + if self.continuous_action_space: + return None, labels_value.reshape(-1, self.support_size) + else: + return labels_policy.reshape(-1, self.action_space_size), labels_value.reshape(-1, self.support_size) + + #@profile + def clear_caches(self): + """ + Clears the caches of the world model. 
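+        Only the per-environment init-infer lookup dictionaries, the recurrent-infer lookup dictionary
+        and the working list ``keys_values_wm_list`` are emptied; the underlying shared memory pools
+        are left untouched.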
+ """ + # self.past_kv_cache_init_infer.clear() + for kv_cache_dict_env in self.past_kv_cache_init_infer_envs: + kv_cache_dict_env.clear() + self.past_kv_cache_recurrent_infer.clear() + self.keys_values_wm_list.clear() + + print(f'rank {self._rank} Cleared {self.__class__.__name__} past_kv_cache.') + + def __repr__(self) -> str: + return "transformer-based latent world_model of UniZero" diff --git a/lzero/policy/muzero_multitask.py b/lzero/policy/muzero_multitask.py new file mode 100644 index 000000000..933a61f34 --- /dev/null +++ b/lzero/policy/muzero_multitask.py @@ -0,0 +1,859 @@ +import copy +from typing import List, Dict, Tuple, Union + +import numpy as np +import torch +import torch.optim as optim +from ding.model import model_wrap +from ding.torch_utils import to_tensor +from ding.utils import POLICY_REGISTRY + +from lzero.mcts import MuZeroMCTSCtree as MCTSCtree +from lzero.model import ImageTransforms +from lzero.model.utils import cal_dormant_ratio +from lzero.policy import ( + scalar_transform, + InverseScalarTransform, + cross_entropy_loss, + phi_transform, + DiscreteSupport, + to_torch_float_tensor, + mz_network_output_unpack, + select_action, + negative_cosine_similarity, + prepare_obs, +) +from lzero.policy.muzero import MuZeroPolicy + + +def generate_task_loss_dict(multi_task_losses, task_name_template, task_id): + """ + 生成每个任务的损失字典 + :param multi_task_losses: 包含每个任务损失的列表 + :param task_name_template: 任务名称模板,例如 'loss_task{}' + :param task_id: 任务起始ID + :return: 一个字典,包含每个任务的损失 + """ + task_loss_dict = {} + for task_idx, task_loss in enumerate(multi_task_losses): + task_name = task_name_template.format(task_idx + task_id) + try: + task_loss_dict[task_name] = task_loss.item() if hasattr(task_loss, 'item') else task_loss + except Exception: + task_loss_dict[task_name] = task_loss + return task_loss_dict + +class WrappedModelV2: + def __init__(self, tokenizer, transformer, pos_emb, task_emb, act_embedding_table): + self.tokenizer = tokenizer + self.transformer = transformer + self.pos_emb = pos_emb + self.task_emb = task_emb + self.act_embedding_table = act_embedding_table + + def parameters(self): + # 返回 tokenizer, transformer 以及所有嵌入层的参数 + return ( + list(self.tokenizer.parameters()) + + list(self.transformer.parameters()) + + list(self.pos_emb.parameters()) + + list(self.task_emb.parameters()) + + list(self.act_embedding_table.parameters()) + ) + + def zero_grad(self, set_to_none=False): + # 将 tokenizer, transformer 和所有嵌入层的梯度设为零 + self.tokenizer.zero_grad(set_to_none=set_to_none) + self.transformer.zero_grad(set_to_none=set_to_none) + self.pos_emb.zero_grad(set_to_none=set_to_none) + self.task_emb.zero_grad(set_to_none=set_to_none) + self.act_embedding_table.zero_grad(set_to_none=set_to_none) + +@POLICY_REGISTRY.register('muzero_multitask') +class MuZeroMTPolicy(MuZeroPolicy): + """ + 概述: + MuZero 的多任务策略类,扩展自 MuZeroPolicy。支持同时训练多个任务,通过分离每个任务的损失并进行优化。 + """ + + # MuZeroMTPolicy 的默认配置 + config = dict( + type='muzero_multitask', + model=dict( + model_type='conv', # options={'mlp', 'conv'} + continuous_action_space=False, + observation_shape=(4, 96, 96), # example shape + self_supervised_learning_loss=False, + categorical_distribution=True, + image_channel=1, + frame_stack_num=1, + num_res_blocks=1, + num_channels=64, + support_scale=300, + bias=True, + discrete_action_encoding_type='one_hot', + res_connection_in_dynamics=True, + norm_type='BN', + analysis_sim_norm=False, + analysis_dormant_ratio=False, + harmony_balance=False, + ), + # ****** common ****** + use_rnd_model=False, + 
multi_gpu=False, + sampled_algo=False, + gumbel_algo=False, + mcts_ctree=True, + cuda=True, + collector_env_num=8, + evaluator_env_num=3, + env_type='not_board_games', + action_type='fixed_action_space', + battle_mode='play_with_bot_mode', + monitor_extra_statistics=True, + game_segment_length=200, + eval_offline=False, + cal_dormant_ratio=False, + analysis_sim_norm=False, + analysis_dormant_ratio=False, + + # ****** observation ****** + transform2string=False, + gray_scale=False, + use_augmentation=False, + augmentation=['shift', 'intensity'], + + # ******* learn ****** + ignore_done=False, + update_per_collect=None, + replay_ratio=0.25, + batch_size=256, + optim_type='SGD', + learning_rate=0.2, + target_update_freq=100, + target_update_freq_for_intrinsic_reward=1000, + weight_decay=1e-4, + momentum=0.9, + grad_clip_value=10, + n_episode=8, + num_segments=8, + num_simulations=50, + discount_factor=0.997, + td_steps=5, + num_unroll_steps=5, + reward_loss_weight=1, + value_loss_weight=0.25, + policy_loss_weight=1, + policy_entropy_weight=0, + ssl_loss_weight=0, + lr_piecewise_constant_decay=True, + threshold_training_steps_for_final_lr=int(5e4), + manual_temperature_decay=False, + threshold_training_steps_for_final_temperature=int(1e5), + fixed_temperature_value=0.25, + use_ture_chance_label_in_chance_encoder=False, + + # ****** Priority ****** + use_priority=False, + priority_prob_alpha=0.6, + priority_prob_beta=0.4, + + # ****** UCB ****** + root_dirichlet_alpha=0.3, + root_noise_weight=0.25, + + # ****** Explore by random collect ****** + random_collect_episode_num=0, + + # ****** Explore by eps greedy ****** + eps=dict( + eps_greedy_exploration_in_collect=False, + type='linear', + start=1., + end=0.05, + decay=int(1e5), + ), + + # ****** 多任务相关 ****** + task_num=2, # 任务数量,根据实际需求调整 + task_id=0, # 当前任务的起始ID + ) + + def default_model(self) -> Tuple[str, List[str]]: + """ + 概述: + 返回该算法的默认模型设置。 + 返回: + - model_info (:obj:`Tuple[str, List[str]]`): 模型名称和模型导入路径列表。 + """ + return 'MuZeroMTModel', ['lzero.model.muzero_model_multitask'] + + def _init_learn(self) -> None: + """ + 概述: + 学习模式初始化方法。初始化学习模型、优化器和MCTS工具。 + """ + super()._init_learn() + + assert self._cfg.optim_type in ['SGD', 'Adam', 'AdamW'], self._cfg.optim_type + # NOTE: in board_games, for fixed lr 0.003, 'Adam' is better than 'SGD'. + if self._cfg.optim_type == 'SGD': + self._optimizer = optim.SGD( + self._model.parameters(), + lr=self._cfg.learning_rate, + momentum=self._cfg.momentum, + weight_decay=self._cfg.weight_decay, + ) + elif self._cfg.optim_type == 'Adam': + self._optimizer = optim.Adam( + self._model.parameters(), lr=self._cfg.learning_rate, weight_decay=self._cfg.weight_decay + ) + elif self._cfg.optim_type == 'AdamW': + self._optimizer = configure_optimizers(model=self._model, weight_decay=self._cfg.weight_decay, + learning_rate=self._cfg.learning_rate, device_type=self._cfg.device) + + if self._cfg.lr_piecewise_constant_decay: + from torch.optim.lr_scheduler import LambdaLR + max_step = self._cfg.threshold_training_steps_for_final_lr + # NOTE: the 1, 0.1, 0.01 is the decay rate, not the lr. 
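+            # With the defaults above (learning_rate=0.2, threshold_training_steps_for_final_lr=int(5e4)),
+            # the multiplier is 1 for the first 25k iterations, 0.1 up to 50k, and 0.01 afterwards,
+            # i.e. the effective lr steps down 0.2 -> 0.02 -> 0.002.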
+ lr_lambda = lambda step: 1 if step < max_step * 0.5 else (0.1 if step < max_step else 0.01) # noqa + self.lr_scheduler = LambdaLR(self._optimizer, lr_lambda=lr_lambda) + + # use model_wrapper for specialized demands of different modes + self._target_model = copy.deepcopy(self._model) + self._target_model = model_wrap( + self._target_model, + wrapper_name='target', + update_type='assign', + update_kwargs={'freq': self._cfg.target_update_freq} + ) + self._learn_model = self._model + + if self._cfg.use_augmentation: + self.image_transforms = ImageTransforms( + self._cfg.augmentation, + image_shape=(self._cfg.model.observation_shape[1], self._cfg.model.observation_shape[2]) + ) + self.value_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) + self.reward_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) + self.inverse_scalar_transform_handle = InverseScalarTransform( + self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution + ) + + # ============================================================== + # harmonydream (learnable weights for different losses) + # ============================================================== + if self._cfg.model.harmony_balance: + # List of parameter names + harmony_names = ["harmony_dynamics", "harmony_policy", "harmony_value", "harmony_reward", "harmony_entropy"] + # Initialize and name each parameter + for name in harmony_names: + param = torch.nn.Parameter(-torch.log(torch.tensor(1.0))) + setattr(self, name, param) + + if self._cfg.use_rnd_model: + if self._cfg.target_model_for_intrinsic_reward_update_type == 'assign': + self._target_model_for_intrinsic_reward = model_wrap( + self._target_model, + wrapper_name='target', + update_type='assign', + update_kwargs={'freq': self._cfg.target_update_freq_for_intrinsic_reward} + ) + elif self._cfg.target_model_for_intrinsic_reward_update_type == 'momentum': + self._target_model_for_intrinsic_reward = model_wrap( + self._target_model, + wrapper_name='target', + update_type='momentum', + update_kwargs={'theta': self._cfg.target_update_theta_for_intrinsic_reward} + ) + + # ========= logging for analysis ========= + self.l2_norm_before = 0. + self.l2_norm_after = 0. + self.grad_norm_before = 0. + self.grad_norm_after = 0. + self.dormant_ratio_encoder = 0. + self.dormant_ratio_dynamics = 0. 
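+
+        # The two attributes below identify the contiguous block of task ids handled by this rank and
+        # are used to prefix the per-task logging keys in `_monitor_vars_learn`.
+        # A minimal sketch (hypothetical variable names) of the `_forward_learn` input this implies,
+        # assuming two tasks on this rank:
+        #     data = [(current_batch_task0, target_batch_task0, self.task_id),
+        #             (current_batch_task1, target_batch_task1, self.task_id + 1)]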
+ # 初始化多任务相关参数 + self.task_num_for_current_rank = self._cfg.task_num + self.task_id = self._cfg.task_id + + def _forward_learn(self, data: List[Tuple[torch.Tensor, torch.Tensor, int]]) -> Dict[str, Union[float, int]]: + """ + 概述: + 学习模式的前向函数,是学习过程的核心。数据从重放缓冲区采样,计算损失并反向传播更新模型。 + 参数: + - data (:obj:`List[Tuple[torch.Tensor, torch.Tensor, int]]`): 每个任务的数据元组列表, + 每个元组包含 (current_batch, target_batch, task_id)。 + 返回: + - info_dict (:obj:`Dict[str, Union[float, int]]`): 用于记录的信息字典,包含当前学习损失和学习统计信息。 + """ + self._learn_model.train() + self._target_model.train() + + # 初始化多任务损失列表 + reward_loss_multi_task = [] + policy_loss_multi_task = [] + value_loss_multi_task = [] + consistency_loss_multi_task = [] + policy_entropy_multi_task = [] + lambd_multi_task = [] + value_priority_multi_task = [] + value_priority_mean_multi_task = [] + + weighted_total_loss = 0.0 # 初始化为0 + losses_list = [] # 用于存储每个任务的损失 + + for task_idx, (current_batch, target_batch, task_id) in enumerate(data): + obs_batch_ori, action_batch, mask_batch, indices, weights, make_time = current_batch + target_reward, target_value, target_policy = target_batch + + obs_batch, obs_target_batch = prepare_obs(obs_batch_ori, self._cfg) + + # 数据增强 + if self._cfg.use_augmentation: + obs_batch = self.image_transforms.transform(obs_batch) + if self._cfg.model.self_supervised_learning_loss: + obs_target_batch = self.image_transforms.transform(obs_target_batch) + + # 准备动作批次并转换为张量 + action_batch = torch.from_numpy(action_batch).to(self._cfg.device).unsqueeze(-1).long() + data_list = [mask_batch, target_reward, target_value, target_policy, weights] + mask_batch, target_reward, target_value, target_policy, weights = to_torch_float_tensor( + data_list, self._cfg.device + ) + + target_reward = target_reward.view(self._cfg.batch_size[task_idx], -1) + target_value = target_value.view(self._cfg.batch_size[task_idx], -1) + + assert obs_batch.size(0) == self._cfg.batch_size[task_idx] == target_reward.size(0) + + # 变换奖励和价值到缩放形式 + transformed_target_reward = scalar_transform(target_reward) + transformed_target_value = scalar_transform(target_value) + + # 转换为类别分布 + target_reward_categorical = phi_transform(self.reward_support, transformed_target_reward) + target_value_categorical = phi_transform(self.value_support, transformed_target_value) + + # 初始推理 + network_output = self._learn_model.initial_inference(obs_batch, task_id=task_id) + + latent_state, reward, value, policy_logits = mz_network_output_unpack(network_output) + + # 记录 Dormant Ratio 和 L2 Norm + if self._cfg.cal_dormant_ratio: + self.dormant_ratio_encoder = cal_dormant_ratio( + self._learn_model.representation_network, obs_batch.detach(), + percentage=self._cfg.dormant_threshold + ) + latent_state_l2_norms = torch.norm(latent_state.view(latent_state.shape[0], -1), p=2, dim=1).mean() + + # 逆变换价值 + original_value = self.inverse_scalar_transform_handle(value) + + # 初始化预测值和策略 + predicted_rewards = [] + if self._cfg.monitor_extra_statistics: + predicted_values, predicted_policies = original_value.detach().cpu(), torch.softmax( + policy_logits, dim=1 + ).detach().cpu() + + # 计算优先级 + value_priority = torch.nn.L1Loss(reduction='none')(original_value.squeeze(-1), target_value[:, 0]) + value_priority = value_priority.data.cpu().numpy() + 1e-6 + + # 计算第一个步骤的策略和价值损失 + policy_loss = cross_entropy_loss(policy_logits, target_policy[:, 0]) + value_loss = cross_entropy_loss(value, target_value_categorical[:, 0]) + + prob = torch.softmax(policy_logits, dim=-1) + entropy = -(prob * torch.log(prob + 1e-9)).sum(-1) + 
policy_entropy_loss = -entropy + + reward_loss = torch.zeros(self._cfg.batch_size[task_idx], device=self._cfg.device) + consistency_loss = torch.zeros(self._cfg.batch_size[task_idx], device=self._cfg.device) + target_policy_entropy = 0 + + # 循环进行多个unroll步骤 + for step_k in range(self._cfg.num_unroll_steps): + # 使用动态函数进行递归推理 + network_output = self._learn_model.recurrent_inference(latent_state, action_batch[:, step_k]) + latent_state, reward, value, policy_logits = mz_network_output_unpack(network_output) + + # 记录 Dormant Ratio + if step_k == self._cfg.num_unroll_steps - 1 and self._cfg.cal_dormant_ratio: + action_tmp = action_batch[:, step_k] + if len(action_tmp.shape) == 1: + action_tmp = action_tmp.unsqueeze(-1) + # 转换动作为独热编码 + action_one_hot = torch.zeros(action_tmp.shape[0], policy_logits.shape[-1], device=action_tmp.device) + action_tmp = action_tmp.long() + action_one_hot.scatter_(1, action_tmp, 1) + action_encoding_tmp = action_one_hot.unsqueeze(-1).unsqueeze(-1) + action_encoding = action_encoding_tmp.expand( + latent_state.shape[0], policy_logits.shape[-1], latent_state.shape[2], latent_state.shape[3] + ) + state_action_encoding = torch.cat((latent_state, action_encoding), dim=1) + self.dormant_ratio_dynamics = cal_dormant_ratio( + self._learn_model.dynamics_network, + state_action_encoding.detach(), + percentage=self._cfg.dormant_threshold + ) + + # 逆变换价值 + original_value = self.inverse_scalar_transform_handle(value) + + # 计算一致性损失 + if self._cfg.model.self_supervised_learning_loss and self._cfg.ssl_loss_weight > 0: + beg_index, end_index = self._get_target_obs_index_in_step_k(step_k) + network_output = self._learn_model.initial_inference(obs_target_batch[:, beg_index:end_index], task_id=task_id) + + latent_state = to_tensor(latent_state) + representation_state = to_tensor(network_output.latent_state) + + dynamic_proj = self._learn_model.project(latent_state, with_grad=True) + observation_proj = self._learn_model.project(representation_state, with_grad=False) + temp_loss = negative_cosine_similarity(dynamic_proj, observation_proj) * mask_batch[:, step_k] + consistency_loss += temp_loss + + # 计算策略和价值损失 + policy_loss += cross_entropy_loss(policy_logits, target_policy[:, step_k + 1]) + value_loss += cross_entropy_loss(value, target_value_categorical[:, step_k + 1]) + reward_loss += cross_entropy_loss(reward, target_reward_categorical[:, step_k]) + + # 计算策略熵损失 + prob = torch.softmax(policy_logits, dim=-1) + entropy = -(prob * torch.log(prob + 1e-9)).sum(-1) + policy_entropy_loss += -entropy + + # 计算目标策略熵(仅用于调试) + target_normalized_visit_count = target_policy[:, step_k + 1] + non_masked_indices = torch.nonzero(mask_batch[:, step_k + 1]).squeeze(-1) + if len(non_masked_indices) > 0: + target_normalized_visit_count_masked = torch.index_select( + target_normalized_visit_count, 0, non_masked_indices + ) + target_policy_entropy += -( + (target_normalized_visit_count_masked + 1e-6) * + torch.log(target_normalized_visit_count_masked + 1e-6) + ).sum(-1).mean() + else: + target_policy_entropy += torch.log( + torch.tensor(target_normalized_visit_count.shape[-1], device=self._cfg.device) + ) + + + # 记录预测值和奖励(如果监控额外统计) + if self._cfg.monitor_extra_statistics: + original_rewards = self.inverse_scalar_transform_handle(reward) + original_rewards_cpu = original_rewards.detach().cpu() + + predicted_values = torch.cat( + (predicted_values, self.inverse_scalar_transform_handle(value).detach().cpu()) + ) + predicted_rewards.append(original_rewards_cpu) + predicted_policies = torch.cat( + 
(predicted_policies, torch.softmax(policy_logits, dim=1).detach().cpu()) + ) + + # 核心学习模型更新步骤 + weighted_loss = self._cfg.policy_loss_weight * policy_loss + \ + self._cfg.value_loss_weight * value_loss + \ + self._cfg.reward_loss_weight * reward_loss + \ + self._cfg.ssl_loss_weight * consistency_loss + \ + self._cfg.policy_entropy_weight * policy_entropy_loss + + # 将多个任务的损失累加 + weighted_total_loss += weighted_loss.mean() + + # 保留每个任务的损失用于日志记录 + reward_loss_multi_task.append(reward_loss.mean().item()) + policy_loss_multi_task.append(policy_loss.mean().item()) + value_loss_multi_task.append(value_loss.mean().item()) + consistency_loss_multi_task.append(consistency_loss.mean().item()) + policy_entropy_multi_task.append(policy_entropy_loss.mean().item()) + lambd_multi_task.append(torch.tensor(0., device=self._cfg.device).item()) # TODO: 如果使用梯度校正,可以在这里调整 + value_priority_multi_task.append(value_priority.mean().item()) + value_priority_mean_multi_task.append(value_priority.mean().item()) + losses_list.append(weighted_loss.mean().item()) + + # 清零优化器的梯度 + self._optimizer.zero_grad() + + # 反向传播 + weighted_total_loss.backward() + + # 梯度裁剪 + total_grad_norm_before_clip_wm = torch.nn.utils.clip_grad_norm_( + self._learn_model.parameters(), + self._cfg.grad_clip_value + ) + + # 多GPU训练时同步梯度 + if self._cfg.multi_gpu: + self.sync_gradients(self._learn_model) + + # 更新优化器 + self._optimizer.step() + if self._cfg.lr_piecewise_constant_decay: + self.lr_scheduler.step() + + # 更新目标模型 + self._target_model.update(self._learn_model.state_dict()) + + # 获取GPU内存使用情况 + if torch.cuda.is_available(): + torch.cuda.synchronize() + current_memory_allocated = torch.cuda.memory_allocated() + max_memory_allocated = torch.cuda.max_memory_allocated() + current_memory_allocated_gb = current_memory_allocated / (1024 ** 3) + max_memory_allocated_gb = max_memory_allocated / (1024 ** 3) + else: + current_memory_allocated_gb = 0.0 + max_memory_allocated_gb = 0.0 + + # 构建返回的损失字典 + return_loss_dict = { + 'Current_GPU': current_memory_allocated_gb, + 'Max_GPU': max_memory_allocated_gb, + 'collect_mcts_temperature': self._collect_mcts_temperature, + 'collect_epsilon': self.collect_epsilon, + 'cur_lr_world_model': self._optimizer.param_groups[0]['lr'], + 'weighted_total_loss': weighted_total_loss.item(), + 'total_grad_norm_before_clip_wm': total_grad_norm_before_clip_wm.item(), + } + + # print(f'self.task_id:{self.task_id}') + # 生成任务相关的损失字典,并为每个任务相关的 loss 添加前缀 "noreduce_" + multi_task_loss_dicts = { + **generate_task_loss_dict(consistency_loss_multi_task, 'noreduce_consistency_loss_task{}', task_id=self.task_id), + **generate_task_loss_dict(reward_loss_multi_task, 'noreduce_reward_loss_task{}', task_id=self.task_id), + **generate_task_loss_dict(policy_loss_multi_task, 'noreduce_policy_loss_task{}', task_id=self.task_id), + **generate_task_loss_dict(value_loss_multi_task, 'noreduce_value_loss_task{}', task_id=self.task_id), + **generate_task_loss_dict(policy_entropy_multi_task, 'noreduce_policy_entropy_task{}', task_id=self.task_id), + **generate_task_loss_dict(lambd_multi_task, 'noreduce_lambd_task{}', task_id=self.task_id), + **generate_task_loss_dict(value_priority_multi_task, 'noreduce_value_priority_task{}', task_id=self.task_id), + **generate_task_loss_dict(value_priority_mean_multi_task, 'noreduce_value_priority_mean_task{}', task_id=self.task_id), + } + + # 合并两个字典 + return_loss_dict.update(multi_task_loss_dicts) + + # 返回最终的损失字典 + return return_loss_dict + + + def _monitor_vars_learn(self, num_tasks: int = None) -> List[str]: + """ + 
概述: + 注册学习模式中需要监控的变量。注册的变量将根据 `_forward_learn` 的返回值记录到tensorboard。 + 如果提供了 `num_tasks`,则为每个任务生成监控变量。 + 参数: + - num_tasks (:obj:`int`, 可选): 任务数量。 + 返回: + - monitored_vars (:obj:`List[str]`): 需要监控的变量列表。 + """ + # 基本监控变量 + monitored_vars = [ + 'Current_GPU', + 'Max_GPU', + 'collect_epsilon', + 'collect_mcts_temperature', + 'cur_lr_world_model', + 'weighted_total_loss', + 'total_grad_norm_before_clip_wm', + ] + + # 任务特定的监控变量 + task_specific_vars = [ + 'noreduce_consistency_loss', + 'noreduce_reward_loss', + 'noreduce_policy_loss', + 'noreduce_value_loss', + 'noreduce_policy_entropy', + 'noreduce_lambd', + 'noreduce_value_priority', + 'noreduce_value_priority_mean', + ] + # self.task_num_for_current_rank 作为当前rank的base_index + num_tasks = self.task_num_for_current_rank + print(f'self.task_num_for_current_rank: {self.task_num_for_current_rank}') + if num_tasks is not None: + for var in task_specific_vars: + for task_idx in range(num_tasks): + monitored_vars.append(f'{var}_task{self.task_id + task_idx}') + else: + monitored_vars.extend(task_specific_vars) + + return monitored_vars + + def _init_collect(self) -> None: + """ + Overview: + Collect mode init method. Called by ``self.__init__``. Initialize the collect model and MCTS utils. + """ + self._collect_model = self._model + if self._cfg.mcts_ctree: + self._mcts_collect = MCTSCtree(self._cfg) + else: + self._mcts_collect = MCTSPtree(self._cfg) + self._collect_mcts_temperature = 1. + self.collect_epsilon = 0.0 + if self._cfg.model.model_type == 'conv_context': + self.last_batch_obs = torch.zeros([8, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device) + self.last_batch_action = [-1 for i in range(8)] + + def _forward_collect( + self, + data: torch.Tensor, + action_mask: list = None, + temperature: float = 1, + to_play: List = [-1], + epsilon: float = 0.25, + ready_env_id: np.array = None, + task_id: int = None, + ) -> Dict: + """ + Overview: + The forward function for collecting data in collect mode. Use model to execute MCTS search. + Choosing the action through sampling during the collect mode. + Arguments: + - data (:obj:`torch.Tensor`): The input data, i.e. the observation. + - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected. + - temperature (:obj:`float`): The temperature of the policy. + - to_play (:obj:`int`): The player to play. + - epsilon (:obj:`float`): The epsilon of the eps greedy exploration. + - ready_env_id (:obj:`list`): The id of the env that is ready to collect. + Shape: + - data (:obj:`torch.Tensor`): + - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ + S is the number of stacked frames, H is the height of the image, W is the width of the image. + - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size. + - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env. + - temperature: :math:`(1, )`. + - to_play: :math:`(N, 1)`, where N is the number of collect_env. + - epsilon: :math:`(1, )`. + - ready_env_id: None + Returns: + - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \ + ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. 
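+
+        .. note::
+            ``task_id`` is forwarded to ``initial_inference`` and the MCTS search so that the
+            multi-task model is conditioned on the current task during collection.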
+ """ + self._collect_model.eval() + self._collect_mcts_temperature = temperature + self.collect_epsilon = epsilon + active_collect_env_num = data.shape[0] + if ready_env_id is None: + ready_env_id = np.arange(active_collect_env_num) + output = {i: None for i in ready_env_id} + with torch.no_grad(): + if self._cfg.model.model_type in ["conv", "mlp"]: + network_output = self._collect_model.initial_inference(data, task_id=task_id) + elif self._cfg.model.model_type == "conv_context": + network_output = self._collect_model.initial_inference(self.last_batch_obs, self.last_batch_action, + data, task_id=task_id) + + latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) + + pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() + latent_state_roots = latent_state_roots.detach().cpu().numpy() + policy_logits = policy_logits.detach().cpu().numpy().tolist() + + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_collect_env_num)] + if not self._cfg.collect_with_pure_policy: + # the only difference between collect and eval is the dirichlet noise + noises = [ + np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(sum(action_mask[j])) + ).astype(np.float32).tolist() for j in range(active_collect_env_num) + ] + if self._cfg.mcts_ctree: + # cpp mcts_tree + roots = MCTSCtree.roots(active_collect_env_num, legal_actions) + else: + # python mcts_tree + roots = MCTSPtree.roots(active_collect_env_num, legal_actions) + + roots.prepare(self._cfg.root_noise_weight, noises, reward_roots, policy_logits, to_play) + self._mcts_collect.search(roots, self._collect_model, latent_state_roots, to_play, task_id=task_id) + + # list of list, shape: ``{list: batch_size} -> {list: action_space_size}`` + roots_visit_count_distributions = roots.get_distributions() + roots_values = roots.get_values() # shape: {list: batch_size} + + batch_action = [] + for i, env_id in enumerate(ready_env_id): + distributions, value = roots_visit_count_distributions[i], roots_values[i] + if self._cfg.eps.eps_greedy_exploration_in_collect: + # eps greedy collect + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=self._collect_mcts_temperature, deterministic=True + ) + action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + if np.random.rand() < self.collect_epsilon: + action = np.random.choice(legal_actions[i]) + else: + # normal collect + # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents + # the index within the legal action set, rather than the index in the entire action set. + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=self._collect_mcts_temperature, deterministic=False + ) + # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the entire action set. 
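+                        # For example, with action_mask[i] = [0, 1, 0, 1] the legal action set is [1, 3];
+                        # an index of 1 within that set therefore maps to action 3 in the full action space.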
+ action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + output[env_id] = { + 'action': action, + 'visit_count_distributions': distributions, + 'visit_count_distribution_entropy': visit_count_distribution_entropy, + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], + } + if self._cfg.model.model_type in ["conv_context"]: + batch_action.append(action) + + if self._cfg.model.model_type in ["conv_context"]: + self.last_batch_obs = data + self.last_batch_action = batch_action + else: + for i, env_id in enumerate(ready_env_id): + policy_values = torch.softmax(torch.tensor([policy_logits[i][a] for a in legal_actions[i]]), + dim=0).tolist() + policy_values = policy_values / np.sum(policy_values) + action_index_in_legal_action_set = np.random.choice(len(legal_actions[i]), p=policy_values) + action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + output[env_id] = { + 'action': action, + 'searched_value': pred_values[i], + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], + } + + return output + + def _get_target_obs_index_in_step_k(self, step): + """ + Overview: + Get the begin index and end index of the target obs in step k. + Arguments: + - step (:obj:`int`): The current step k. + Returns: + - beg_index (:obj:`int`): The begin index of the target obs in step k. + - end_index (:obj:`int`): The end index of the target obs in step k. + Examples: + >>> self._cfg.model.model_type = 'conv' + >>> self._cfg.model.image_channel = 3 + >>> self._cfg.model.frame_stack_num = 4 + >>> self._get_target_obs_index_in_step_k(0) + >>> (0, 12) + """ + if self._cfg.model.model_type in ['conv', 'conv_context']: + beg_index = self._cfg.model.image_channel * step + end_index = self._cfg.model.image_channel * (step + self._cfg.model.frame_stack_num) + elif self._cfg.model.model_type in ['mlp', 'mlp_context']: + beg_index = self._cfg.model.observation_shape * step + end_index = self._cfg.model.observation_shape * (step + self._cfg.model.frame_stack_num) + return beg_index, end_index + + def _init_eval(self) -> None: + """ + Overview: + Evaluate mode init method. Called by ``self.__init__``. Initialize the eval model and MCTS utils. + """ + self._eval_model = self._model + if self._cfg.mcts_ctree: + self._mcts_eval = MCTSCtree(self._cfg) + else: + self._mcts_eval = MCTSPtree(self._cfg) + if self._cfg.model.model_type == 'conv_context': + self.last_batch_obs = torch.zeros([3, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device) + self.last_batch_action = [-1 for _ in range(3)] + # elif self._cfg.model.model_type == 'mlp_context': + # self.last_batch_obs = torch.zeros([3, self._cfg.model.observation_shape]).to(self._cfg.device) + # self.last_batch_action = [-1 for _ in range(3)] + + def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, + ready_env_id: np.array = None, task_id: int = None) -> Dict: + """ + Overview: + The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search. + Choosing the action with the highest value (argmax) rather than sampling during the eval mode. + Arguments: + - data (:obj:`torch.Tensor`): The input data, i.e. the observation. + - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected. + - to_play (:obj:`int`): The player to play. + - ready_env_id (:obj:`list`): The id of the env that is ready to collect. 
+ Shape: + - data (:obj:`torch.Tensor`): + - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ + S is the number of stacked frames, H is the height of the image, W is the width of the image. + - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size. + - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env. + - to_play: :math:`(N, 1)`, where N is the number of collect_env. + - ready_env_id: None + Returns: + - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \ + ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. + """ + self._eval_model.eval() + active_eval_env_num = data.shape[0] + if ready_env_id is None: + ready_env_id = np.arange(active_eval_env_num) + output = {i: None for i in ready_env_id} + with torch.no_grad(): + if self._cfg.model.model_type in ["conv", "mlp"]: + network_output = self._eval_model.initial_inference(data, task_id=task_id) + elif self._cfg.model.model_type == "conv_context": + network_output = self._eval_model.initial_inference(self.last_batch_obs, self.last_batch_action, data, task_id=task_id) + latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) + + if not self._eval_model.training: + # if not in training, obtain the scalars of the value/reward + pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() # shape(B, 1) + latent_state_roots = latent_state_roots.detach().cpu().numpy() + policy_logits = policy_logits.detach().cpu().numpy().tolist() # list shape(B, A) + + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_eval_env_num)] + if self._cfg.mcts_ctree: + # cpp mcts_tree + roots = MCTSCtree.roots(active_eval_env_num, legal_actions) + else: + # python mcts_tree + roots = MCTSPtree.roots(active_eval_env_num, legal_actions) + roots.prepare_no_noise(reward_roots, policy_logits, to_play) + self._mcts_eval.search(roots, self._eval_model, latent_state_roots, to_play, task_id=task_id) + + # list of list, shape: ``{list: batch_size} -> {list: action_space_size}`` + roots_visit_count_distributions = roots.get_distributions() + roots_values = roots.get_values() # shape: {list: batch_size} + + batch_action = [] + for i, env_id in enumerate(ready_env_id): + distributions, value = roots_visit_count_distributions[i], roots_values[i] + # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents + # the index within the legal action set, rather than the index in the entire action set. + # Setting deterministic=True implies choosing the action with the highest value (argmax) rather than + # sampling during the evaluation phase. + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=1, deterministic=True + ) + # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the + # entire action set. 
+ action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + + output[env_id] = { + 'action': action, + 'visit_count_distributions': distributions, + 'visit_count_distribution_entropy': visit_count_distribution_entropy, + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], + } + if self._cfg.model.model_type in ["conv_context"]: + batch_action.append(action) + + if self._cfg.model.model_type in ["conv_context"]: + self.last_batch_obs = data + self.last_batch_action = batch_action + + return output + diff --git a/lzero/policy/sampled_unizero_multitask.py b/lzero/policy/sampled_unizero_multitask.py new file mode 100644 index 000000000..d1c3eb5d1 --- /dev/null +++ b/lzero/policy/sampled_unizero_multitask.py @@ -0,0 +1,1003 @@ +# /Users/puyuan/code/LightZero/lzero/policy/sample_unizero_multitask.py + +import copy +import logging +from collections import defaultdict +from typing import List, Dict, Any, Tuple, Union + +import numpy as np +import torch +import wandb +from ding.model import model_wrap +from ding.utils import POLICY_REGISTRY + +from lzero.entry.utils import initialize_zeros_batch +from lzero.mcts import SampledUniZeroMCTSCtree as MCTSCtree +from lzero.model import ImageTransforms +from lzero.policy import ( + scalar_transform, + InverseScalarTransform, + phi_transform, + DiscreteSupport, + to_torch_float_tensor, + mz_network_output_unpack, + select_action, + prepare_obs, + prepare_obs_stack4_for_unizero +) +from lzero.policy.unizero import UniZeroPolicy +from .utils import configure_optimizers_nanogpt +import sys +sys.path.append('/mnt/afs/niuyazhe/code/LibMTL/') +from LibMTL.weighting.MoCo_unizero import MoCo as GradCorrect +# from LibMTL.weighting.CAGrad_unizero import CAGrad as GradCorrect + +def generate_task_loss_dict(multi_task_losses, task_name_template, task_id): + """ + 生成每个任务的损失字典 + :param multi_task_losses: 包含每个任务损失的列表 + :param task_name_template: 任务名称模板,例如 'obs_loss_task{}' + :param task_id: 基础任务 ID + :return: 一个字典,包含每个任务的损失 + """ + task_loss_dict = {} + for task_idx, task_loss in enumerate(multi_task_losses): + task_name = task_name_template.format(task_idx + task_id) + try: + task_loss_dict[task_name] = task_loss.item() if hasattr(task_loss, 'item') else float(task_loss) + except Exception as e: + task_loss_dict[task_name] = task_loss + return task_loss_dict + + +class WrappedModelV2: + def __init__(self, tokenizer, transformer, pos_emb, task_emb, act_embedding_table): + self.tokenizer = tokenizer + self.transformer = transformer + self.pos_emb = pos_emb + self.task_emb = task_emb + self.act_embedding_table = act_embedding_table + + def parameters(self): + # 返回 tokenizer, transformer 以及所有嵌入层的参数 + return ( + list(self.tokenizer.parameters()) + + list(self.transformer.parameters()) + + list(self.pos_emb.parameters()) + + list(self.task_emb.parameters()) + + list(self.act_embedding_table.parameters()) + ) + + def zero_grad(self, set_to_none=False): + # 将 tokenizer, transformer 和所有嵌入层的梯度设为零 + self.tokenizer.zero_grad(set_to_none=set_to_none) + self.transformer.zero_grad(set_to_none=set_to_none) + self.pos_emb.zero_grad(set_to_none=set_to_none) + self.task_emb.zero_grad(set_to_none=set_to_none) + self.act_embedding_table.zero_grad(set_to_none=set_to_none) + + +@POLICY_REGISTRY.register('sampled_unizero_multitask') +class SampledUniZeroMTPolicy(UniZeroPolicy): + """ + Overview: + The policy class for Sampled UniZero Multitask, combining multi-task learning with sampled-based MCTS. 
+ This implementation extends the UniZeroPolicy to handle multiple tasks simultaneously while utilizing + sampled MCTS for action selection. It ensures scalability and correctness in multi-task environments. + """ + + # The default_config for Sampled UniZero Multitask policy. + config = dict( + type='sampled_unizero_multitask', + model=dict( + model_type='conv', # options={'mlp', 'conv'} + continuous_action_space=False, + observation_shape=(3, 64, 64), + self_supervised_learning_loss=True, + categorical_distribution=True, + image_channel=3, + frame_stack_num=1, + num_res_blocks=1, + num_channels=64, + support_scale=50, + bias=True, + res_connection_in_dynamics=True, + norm_type='LN', + analysis_sim_norm=False, + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=10000, ), ), ), + world_model_cfg=dict( + tokens_per_block=2, + max_blocks=10, + max_tokens=20, + context_length=8, + gru_gating=False, + device='cpu', + analysis_sim_norm=False, + analysis_dormant_ratio=False, + action_space_size=6, + group_size=8, + attention='causal', + num_layers=2, + num_heads=8, + embed_dim=768, + embed_pdrop=0.1, + resid_pdrop=0.1, + attn_pdrop=0.1, + support_size=101, + max_cache_size=5000, + env_num=8, + latent_recon_loss_weight=0., + perceptual_loss_weight=0., + policy_entropy_weight=5e-3, + predict_latent_loss_type='group_kl', + obs_type='image', + gamma=1, + dormant_threshold=0.025, + policy_loss_type='kl', + ), + ), + use_rnd_model=False, + multi_gpu=True, + sampled_algo=True, + gumbel_algo=False, + mcts_ctree=True, + cuda=True, + collector_env_num=8, + evaluator_env_num=3, + env_type='not_board_games', + action_type='fixed_action_space', + battle_mode='play_with_bot_mode', + monitor_extra_statistics=True, + game_segment_length=400, + analysis_sim_norm=False, + collect_with_pure_policy=False, + eval_freq=int(5e3), + sample_type='transition', + + transform2string=False, + gray_scale=False, + use_augmentation=False, + augmentation=['shift', 'intensity'], + + ignore_done=False, + update_per_collect=None, + replay_ratio=0.25, + batch_size=256, + optim_type='AdamW', + learning_rate=0.0001, + init_w=3e-3, + target_update_freq=100, + target_update_theta=0.05, + target_update_freq_for_intrinsic_reward=1000, + weight_decay=1e-4, + momentum=0.9, + grad_clip_value=5, + n_episode=8, + num_simulations=50, + discount_factor=0.997, + td_steps=5, + num_unroll_steps=10, + reward_loss_weight=1, + value_loss_weight=0.25, + policy_loss_weight=1, + ssl_loss_weight=0, + cos_lr_scheduler=False, + piecewise_decay_lr_scheduler=False, + threshold_training_steps_for_final_lr=int(5e4), + manual_temperature_decay=False, + threshold_training_steps_for_final_temperature=int(1e5), + fixed_temperature_value=0.25, + use_ture_chance_label_in_chance_encoder=False, + + use_priority=False, + priority_prob_alpha=0.6, + priority_prob_beta=0.4, + train_start_after_envsteps=0, + + root_dirichlet_alpha=0.3, + root_noise_weight=0.25, + + random_collect_episode_num=0, + + eps=dict( + eps_greedy_exploration_in_collect=False, + type='linear', + start=1., + end=0.05, + decay=int(1e5), + ), + ) + + def default_model(self) -> Tuple[str, List[str]]: + """ + Return this algorithm's default model setting for demonstration. + """ + return 'SampledUniZeroMTModel', ['lzero.model.sampled_unizero_model_multitask'] + + def _init_learn(self) -> None: + """ + Learn mode init method. Initialize the learn model, optimizer, and MCTS utils. 
+ """ + # Configure optimizer for world model + self._optimizer_world_model = configure_optimizers_nanogpt( + model=self._model.world_model, + learning_rate=self._cfg.learning_rate, + weight_decay=self._cfg.weight_decay, + device_type=self._cfg.device, + betas=(0.9, 0.95), + ) + + if self._cfg.cos_lr_scheduler or self._cfg.piecewise_decay_lr_scheduler: + from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR + + if self._cfg.cos_lr_scheduler: + self.lr_scheduler = CosineAnnealingLR( + self._optimizer_world_model, T_max=int(1e5), eta_min=0, last_epoch=-1 + ) + elif self._cfg.piecewise_decay_lr_scheduler: + # Example step scheduler, adjust milestones and gamma as needed + self.lr_scheduler = StepLR( + self._optimizer_world_model, step_size=int(5e4), gamma=0.1 + ) + + if self._cfg.model.continuous_action_space: + # Weight Init for the last output layer of gaussian policy head in prediction network. + init_w = self._cfg.init_w + self._model.world_model.fc_policy_head.mu.weight.data.uniform_(-init_w, init_w) + self._model.world_model.fc_policy_head.mu.bias.data.uniform_(-init_w, init_w) + try: + self._model.world_model.fc_policy_head.log_sigma_layer.weight.data.uniform_(-init_w, init_w) + self._model.world_model.fc_policy_head.log_sigma_layer.bias.data.uniform_(-init_w, init_w) + except Exception as exception: + logging.warning(exception) + + # Initialize target model + self._target_model = copy.deepcopy(self._model) + # Ensure torch version >= 2.0 + assert int(''.join(filter(str.isdigit, torch.__version__))) >= 200, "We need torch version >= 2.0" + self._model = torch.compile(self._model) + self._target_model = torch.compile(self._target_model) + # Soft target update + self._target_model = model_wrap( + self._target_model, + wrapper_name='target', + update_type='momentum', + update_kwargs={'theta': self._cfg.target_update_theta} + ) + self._learn_model = self._model + + # if self._cfg.use_augmentation: + # self.image_transforms = ImageTransforms( + # self._cfg.augmentation, + # image_shape=(self._cfg.model.observation_shape[1], self._cfg.model.observation_shape[2]) + # ) + + self.value_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) + self.reward_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) + self.inverse_scalar_transform_handle = InverseScalarTransform( + self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution + ) + self.intermediate_losses = defaultdict(float) + self.l2_norm_before = 0. + self.l2_norm_after = 0. + self.grad_norm_before = 0. + self.grad_norm_after = 0. 
+ + self.task_id = self._cfg.task_id + self.task_num_for_current_rank = self._cfg.task_num + + if self._cfg.use_moco: + # 创建 WrappedModel 实例,仅矫正部分参数,保持可扩展性 + # wrapped_model = WrappedModelV2( + # self._learn_model.world_model.tokenizer.encoder[0], # 假设只有一个编码器 + # self._learn_model.world_model.transformer, + # self._learn_model.world_model.pos_emb, + # self._learn_model.world_model.task_emb, + # self._learn_model.world_model.act_embedding_table, + # ) + + # head 没有矫正梯度 + wrapped_model = WrappedModelV2( + self._learn_model.world_model.tokenizer.encoder, # TODO: one or N encoder inside + self._learn_model.world_model.transformer, + self._learn_model.world_model.pos_emb, + self._learn_model.world_model.task_emb, + self._learn_model.world_model.act_embedding_table, + ) + + # TODO + # 如果需要,可以在这里初始化梯度校正方法(如 MoCo, CAGrad) + # self.grad_correct = GradCorrect(wrapped_model, self.task_num, self._cfg.device) + self.grad_correct = GradCorrect(wrapped_model, self._cfg.task_num, self._cfg.device) # only compatiable with for 1GPU training + + self.grad_correct.init_param() + self.grad_correct.rep_grad = False + + + def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None) -> Dict[str, Union[float, int]]: + """ + Forward function for learning policy in learn mode, handling multiple tasks. + """ + self._learn_model.train() + self._target_model.train() + + # Initialize multi-task loss lists + task_weight_multi_task = [] + + obs_loss_multi_task = [] + reward_loss_multi_task = [] + policy_loss_multi_task = [] + orig_policy_loss_multi_task = [] + policy_entropy_multi_task = [] + value_loss_multi_task = [] + latent_recon_loss_multi_task = [] + perceptual_loss_multi_task = [] + latent_state_l2_norms_multi_task = [] + average_target_policy_entropy_multi_task = [] + value_priority_multi_task = [] + value_priority_mean_multi_task = [] + + weighted_total_loss = 0.0 + losses_list = [] # 存储每个任务的损失 + + for task_id, data_one_task in enumerate(data): + current_batch, target_batch, task_id = data_one_task + obs_batch_ori, action_batch, child_sampled_actions_batch, target_action_batch, mask_batch, indices, weights, make_time = current_batch + target_reward, target_value, target_policy = target_batch + + # Prepare observations based on frame stack number + if self._cfg.model.frame_stack_num == 4: + obs_batch, obs_target_batch = prepare_obs_stack4_for_unizero(obs_batch_ori, self._cfg) + else: + obs_batch, obs_target_batch = prepare_obs(obs_batch_ori, self._cfg, task_id) + + # Apply augmentations if needed + if self._cfg.use_augmentation: + obs_batch = self.image_transforms.transform(obs_batch) + if self._cfg.model.self_supervised_learning_loss: + obs_target_batch = self.image_transforms.transform(obs_target_batch) + + # Prepare action batch and convert to torch tensor + if self._cfg.model.continuous_action_space: + action_batch = torch.from_numpy(action_batch).to(self._cfg.device).unsqueeze(-1) + else: + action_batch = torch.from_numpy(action_batch).to(self._cfg.device).unsqueeze(-1).long() + data_list = [ + mask_batch, + target_reward.astype('float32'), + target_value.astype('float32'), + target_policy, + weights + ] + mask_batch, target_reward, target_value, target_policy, weights = to_torch_float_tensor(data_list, self._cfg.device) + + target_reward = target_reward.view(self._cfg.batch_size[task_id], -1) + target_value = target_value.view(self._cfg.batch_size[task_id], -1) + + # Transform rewards and values to their scaled forms + transformed_target_reward = scalar_transform(target_reward) + 
transformed_target_value = scalar_transform(target_value) + + # Convert to categorical distributions + target_reward_categorical = phi_transform(self.reward_support, transformed_target_reward) + target_value_categorical = phi_transform(self.value_support, transformed_target_value) + + # Prepare batch for GPT model + batch_for_gpt = {} + if isinstance(self._cfg.model.observation_shape_list[task_id], int) or len(self._cfg.model.observation_shape_list[task_id]) == 1: + batch_for_gpt['observations'] = torch.cat((obs_batch, obs_target_batch), dim=1).reshape( + self._cfg.batch_size[task_id], -1, self._cfg.model.observation_shape_list[task_id]) + elif len(self._cfg.model.observation_shape_list[task_id]) == 3: + batch_for_gpt['observations'] = torch.cat((obs_batch, obs_target_batch), dim=1).reshape( + self._cfg.batch_size[task_id], -1, *self._cfg.model.observation_shape_list[task_id]) + + batch_for_gpt['actions'] = action_batch.squeeze(-1) + batch_for_gpt['child_sampled_actions'] = torch.from_numpy(child_sampled_actions_batch).to(self._cfg.device)[:, :-1] + batch_for_gpt['rewards'] = target_reward_categorical[:, :-1] + batch_for_gpt['mask_padding'] = mask_batch == 1.0 # 0 means invalid padding data + batch_for_gpt['mask_padding'] = batch_for_gpt['mask_padding'][:, :-1] + batch_for_gpt['observations'] = batch_for_gpt['observations'][:, :-1] + batch_for_gpt['ends'] = torch.zeros(batch_for_gpt['mask_padding'].shape, dtype=torch.long, device=self._cfg.device) + batch_for_gpt['target_value'] = target_value_categorical[:, :-1] + batch_for_gpt['target_policy'] = target_policy[:, :-1] + + # Extract valid target policy data and compute entropy + valid_target_policy = batch_for_gpt['target_policy'][batch_for_gpt['mask_padding']] + target_policy_entropy = -torch.sum(valid_target_policy * torch.log(valid_target_policy + 1e-9), dim=-1) + average_target_policy_entropy = target_policy_entropy.mean().item() + + # Update world model + losses = self._learn_model.world_model.compute_loss( + batch_for_gpt, + self._target_model.world_model.tokenizer, + self.inverse_scalar_transform_handle, + task_id=task_id + ) + if task_weights is not None: + weighted_total_loss += losses.loss_total * task_weights[task_id] + losses_list.append(losses.loss_total * task_weights[task_id]) + + task_weight_multi_task.append(task_weights[task_id]) + else: + weighted_total_loss += losses.loss_total + losses_list.append(losses.loss_total) + + task_weight_multi_task.append(1) + + + for loss_name, loss_value in losses.intermediate_losses.items(): + self.intermediate_losses[f"{loss_name}"] = loss_value + # print(f'{loss_name}: {loss_value.sum()}') + # print(f'{loss_name}: {loss_value[0][0]}') + + # print(f"=== 全局任务权重 (按 task_id 排列): {task_weights}") + assert not torch.isnan(losses.loss_total).any(), f"Loss contains NaN values, losses.loss_total:{losses.loss_total}, losses:{losses}" + assert not torch.isinf(losses.loss_total).any(), f"Loss contains Inf values, losses.loss_total:{losses.loss_total}, losses:{losses}" + + # Collect losses per task + obs_loss = self.intermediate_losses.get('loss_obs', 0.0) or 0.0 + reward_loss = self.intermediate_losses.get('loss_rewards', 0.0) or 0.0 + policy_loss = self.intermediate_losses.get('loss_policy', 0.0) or 0.0 + orig_policy_loss = self.intermediate_losses.get('orig_policy_loss', 0.0) or 0.0 + policy_entropy = self.intermediate_losses.get('policy_entropy', 0.0) or 0.0 + value_loss = self.intermediate_losses.get('loss_value', 0.0) or 0.0 + latent_recon_loss = 
self.intermediate_losses.get('latent_recon_loss', 0.0) or 0.0 + perceptual_loss = self.intermediate_losses.get('perceptual_loss', 0.0) or 0.0 + latent_state_l2_norms = self.intermediate_losses.get('latent_state_l2_norms', 0.0) or 0.0 + value_priority = torch.tensor(0., device=self._cfg.device) # Placeholder, adjust as needed + + obs_loss_multi_task.append(obs_loss) + reward_loss_multi_task.append(reward_loss) + policy_loss_multi_task.append(policy_loss) + orig_policy_loss_multi_task.append(orig_policy_loss) + policy_entropy_multi_task.append(policy_entropy) + value_loss_multi_task.append(value_loss) + latent_recon_loss_multi_task.append(latent_recon_loss) + perceptual_loss_multi_task.append(perceptual_loss) + latent_state_l2_norms_multi_task.append(latent_state_l2_norms) + average_target_policy_entropy_multi_task.append(average_target_policy_entropy) + value_priority_multi_task.append(value_priority) + value_priority_mean_multi_task.append(value_priority.mean().item()) + + # Core learn model update step + self._optimizer_world_model.zero_grad() + + if self._cfg.use_moco: + # 这里可以集成 MoCo 或 CAGrad 等梯度校正方法, 1gpu 需要知道所有task对应的梯度 + lambd = self.grad_correct.backward(losses=losses_list, **self._cfg.grad_correct_params) + else: + # 不使用梯度校正的情况 + lambd = torch.tensor([0. for _ in range(self.task_num_for_current_rank)], device=self._cfg.device) + weighted_total_loss.backward() + + total_grad_norm_before_clip_wm = torch.nn.utils.clip_grad_norm_(self._learn_model.world_model.parameters(), self._cfg.grad_clip_value) + + if self._cfg.multi_gpu: + self.sync_gradients(self._learn_model) + + self._optimizer_world_model.step() + + if self._cfg.cos_lr_scheduler or self._cfg.piecewise_decay_lr_scheduler: + self.lr_scheduler.step() + + # Core target model update step + self._target_model.update(self._learn_model.state_dict()) + + # 获取GPU内存使用情况 + if torch.cuda.is_available(): + torch.cuda.synchronize() + current_memory_allocated = torch.cuda.memory_allocated() + max_memory_allocated = torch.cuda.max_memory_allocated() + current_memory_allocated_gb = current_memory_allocated / (1024 ** 3) + max_memory_allocated_gb = max_memory_allocated / (1024 ** 3) + else: + current_memory_allocated_gb = 0. + max_memory_allocated_gb = 0. 
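# --- Editor's illustrative sketch (not part of the patch) --------------------
# How ``generate_task_loss_dict`` (defined at the top of this file) names the
# per-task entries built below: ``task_id`` acts as this rank's base offset, so
# the i-th local loss is logged under ``f'...task{task_id + i}'``. The values
# here are invented for illustration.
import torch

losses = [torch.tensor(0.5), 0.25]
print(generate_task_loss_dict(losses, 'noreduce_obs_loss_task{}', task_id=2))
# {'noreduce_obs_loss_task2': 0.5, 'noreduce_obs_loss_task3': 0.25}
# ------------------------------------------------------------------------------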
+ + # 构建损失字典 + return_loss_dict = { + 'Current_GPU': current_memory_allocated_gb, + 'Max_GPU': max_memory_allocated_gb, + 'collect_mcts_temperature': self._collect_mcts_temperature, + 'collect_epsilon': self._collect_epsilon, + 'cur_lr_world_model': self._optimizer_world_model.param_groups[0]['lr'], + 'weighted_total_loss': weighted_total_loss.item(), + 'total_grad_norm_before_clip_wm': total_grad_norm_before_clip_wm.item(), + } + # if task_weights is None: + # task_weights = {self.task_id+i: 1 for i in range(self.task_num_for_current_rank)} + # else: + # print(f'task_weights:{task_weights}') + # from ding.utils import EasyTimer, set_pkg_seed, get_rank + + # print(f'rank:{get_rank()}, task_id:{self.task_id}') + + # 生成任务相关的损失字典,并为每个任务相关的 loss 添加前缀 "noreduce_" + multi_task_loss_dicts = { + **generate_task_loss_dict(task_weight_multi_task, 'noreduce_task_weight_task{}', task_id=self.task_id), + **generate_task_loss_dict(obs_loss_multi_task, 'noreduce_obs_loss_task{}', task_id=self.task_id), + **generate_task_loss_dict(latent_recon_loss_multi_task, 'noreduce_latent_recon_loss_task{}', task_id=self.task_id), + **generate_task_loss_dict(perceptual_loss_multi_task, 'noreduce_perceptual_loss_task{}', task_id=self.task_id), + **generate_task_loss_dict(latent_state_l2_norms_multi_task, 'noreduce_latent_state_l2_norms_task{}', task_id=self.task_id), + **generate_task_loss_dict(policy_loss_multi_task, 'noreduce_policy_loss_task{}', task_id=self.task_id), + **generate_task_loss_dict(orig_policy_loss_multi_task, 'noreduce_orig_policy_loss_task{}', task_id=self.task_id), + **generate_task_loss_dict(policy_entropy_multi_task, 'noreduce_policy_entropy_task{}', task_id=self.task_id), + **generate_task_loss_dict(reward_loss_multi_task, 'noreduce_reward_loss_task{}', task_id=self.task_id), + **generate_task_loss_dict(value_loss_multi_task, 'noreduce_value_loss_task{}', task_id=self.task_id), + **generate_task_loss_dict(average_target_policy_entropy_multi_task, 'noreduce_target_policy_entropy_task{}', task_id=self.task_id), + **generate_task_loss_dict(lambd, 'noreduce_lambd_task{}', task_id=self.task_id), + **generate_task_loss_dict(value_priority_multi_task, 'noreduce_value_priority_task{}', task_id=self.task_id), + **generate_task_loss_dict(value_priority_mean_multi_task, 'noreduce_value_priority_mean_task{}', task_id=self.task_id), + } + + # print(f'multi_task_loss_dicts:{ multi_task_loss_dicts}') + + # 合并两个字典 + return_loss_dict.update(multi_task_loss_dicts) + + # 如果需要,可以将损失字典记录到日志或其他地方 + if self._cfg.use_wandb: + wandb.log({'learner_step/' + k: v for k, v in return_loss_dict.items()}, step=self.env_step) + wandb.log({"learner_iter_vs_env_step": self.train_iter}, step=self.env_step) + + return return_loss_dict + + # TODO: num_tasks + def _monitor_vars_learn(self, num_tasks=2) -> List[str]: + """ + Overview: + Register the variables to be monitored in learn mode. The registered variables will be logged in + tensorboard according to the return value ``_forward_learn``. + If num_tasks is provided, generate monitored variables for each task. 
+ """ + # Basic monitored variables that do not depend on the number of tasks + monitored_vars = [ + 'Current_GPU', + 'Max_GPU', + 'collect_epsilon', + 'collect_mcts_temperature', + 'cur_lr_world_model', + 'weighted_total_loss', + 'total_grad_norm_before_clip_wm', + ] + + # rank = get_rank() + task_specific_vars = [ + 'noreduce_task_weight', + 'noreduce_obs_loss', + 'noreduce_orig_policy_loss', + 'noreduce_policy_loss', + 'noreduce_latent_recon_loss', + 'noreduce_policy_entropy', + 'noreduce_target_policy_entropy', + 'noreduce_reward_loss', + 'noreduce_value_loss', + 'noreduce_perceptual_loss', + 'noreduce_latent_state_l2_norms', + 'noreduce_lambd', + 'noreduce_value_priority_mean', + ] + # self.task_num_for_current_rank 作为当前rank的base_index + num_tasks = self.task_num_for_current_rank + # If the number of tasks is provided, extend the monitored variables list with task-specific variables + if num_tasks is not None: + for var in task_specific_vars: + for task_idx in range(num_tasks): + # print(f"learner policy Rank {rank}, self.task_id: {self.task_id}") + monitored_vars.append(f'{var}_task{self.task_id+task_idx}') + else: + # If num_tasks is not provided, we assume there's only one task and keep the original variable names + monitored_vars.extend(task_specific_vars) + + return monitored_vars + + def monitor_weights_and_grads(self, model): + """ + Monitor and print the weights and gradients of the model. + """ + for name, param in model.named_parameters(): + if param.requires_grad: + print(f"Layer: {name} | " + f"Weight mean: {param.data.mean():.4f} | " + f"Weight std: {param.data.std():.4f} | " + f"Grad mean: {param.grad.mean():.4f} | " + f"Grad std: {param.grad.std():.4f}") + + def _init_collect(self) -> None: + """ + Collect mode init method. Initialize the collect model and MCTS utils. + """ + self._collect_model = self._model + + if self._cfg.mcts_ctree: + self._mcts_collect = MCTSCtree(self._cfg) + else: + self._mcts_collect = MCTSPtree(self._cfg) + self._collect_mcts_temperature = 1. + self._task_weight_temperature = 10. + + self._collect_epsilon = 0.0 + self.collector_env_num = self._cfg.collector_env_num + if self._cfg.model.model_type == 'conv': + self.last_batch_obs = torch.zeros( + [self.collector_env_num, self._cfg.model.observation_shape[0], 64, 64] + ).to(self._cfg.device) + self.last_batch_action = [-1 for _ in range(self.collector_env_num)] + elif self._cfg.model.model_type == 'mlp': + self.last_batch_obs = torch.zeros( + [self.collector_env_num, self._cfg.model.observation_shape_list[0]] + ).to(self._cfg.device) + self.last_batch_action = [-1 for _ in range(self.collector_env_num)] + + def _forward_collect( + self, + data: torch.Tensor, + action_mask: list = None, + temperature: float = 1, + to_play: List = [-1], + epsilon: float = 0.25, + ready_env_id: np.array = None, + task_id: int = None, + ) -> Dict: + """ + Forward function for collecting data in collect mode, handling multiple tasks. 
+ """ + self._collect_model.eval() + + self._collect_mcts_temperature = temperature + self._collect_epsilon = epsilon + active_collect_env_num = data.shape[0] + if ready_env_id is None: + ready_env_id = np.arange(active_collect_env_num) + output = {i: None for i in ready_env_id} + + with torch.no_grad(): + network_output = self._collect_model.initial_inference( + self.last_batch_obs, + self.last_batch_action, + data, + task_id=task_id + ) + latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) + + pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() + latent_state_roots = latent_state_roots.detach().cpu().numpy() + policy_logits = policy_logits.detach().cpu().numpy().tolist() + + legal_actions = [ + [i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_collect_env_num) + ] if not self._cfg.model.continuous_action_space else [ + [-1 for _ in range(self._cfg.model.world_model_cfg.num_of_sampled_actions)] + for _ in range(active_collect_env_num) + ] + + noises = [ + np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(self._cfg.model.world_model_cfg.num_of_sampled_actions)) + .astype(np.float32).tolist() for _ in range(active_collect_env_num) + ] + + if self._cfg.mcts_ctree: + roots = MCTSCtree.roots( + active_collect_env_num, + legal_actions, + self._cfg.model.world_model_cfg.action_space_size, + self._cfg.model.world_model_cfg.num_of_sampled_actions, + self._cfg.model.continuous_action_space + ) + else: + roots = MCTSPtree.roots(active_collect_env_num, legal_actions) + + roots.prepare(self._cfg.root_noise_weight, noises, reward_roots, policy_logits, to_play) + self._mcts_collect.search(roots, self._collect_model, latent_state_roots, to_play, task_id=task_id) + + roots_visit_count_distributions = roots.get_distributions() + roots_values = roots.get_values() + roots_sampled_actions = roots.get_sampled_actions() + + batch_action = [] + for i, env_id in enumerate(ready_env_id): + distributions, value = roots_visit_count_distributions[i], roots_values[i] + root_sampled_actions = np.array([ + getattr(action, 'value', action) for action in roots_sampled_actions[i] + ]) + + # 选择动作 + action, visit_count_distribution_entropy = select_action( + distributions, temperature=self._collect_mcts_temperature, deterministic=False + ) + + # 获取采样动作 + action = root_sampled_actions[action] + if not self._cfg.model.continuous_action_space: + action = int(action.item()) + + output[env_id] = { + 'action': action, + 'visit_count_distributions': distributions, + 'root_sampled_actions': root_sampled_actions, + 'visit_count_distribution_entropy': visit_count_distribution_entropy, + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], + } + batch_action.append(action) + + self.last_batch_obs = data + self.last_batch_action = batch_action + + # 检查并重置采集器 + if active_collect_env_num < self.collector_env_num: + print('==========collect_forward============') + print(f'len(self.last_batch_obs) < self.collector_env_num, {active_collect_env_num}<{self.collector_env_num}') + self._reset_collect(reset_init_data=True, task_id=task_id) + + return output + + def _init_eval(self) -> None: + """ + Evaluate mode init method. Initialize the eval model and MCTS utils. 
+ """ + from ding.utils import EasyTimer, set_pkg_seed, get_rank + + self._eval_model = self._model + if self._cfg.mcts_ctree: + self._mcts_eval = MCTSCtree(self._cfg) + else: + self._mcts_eval = MCTSPtree(self._cfg) + self.evaluator_env_num = self._cfg.evaluator_env_num + + self.task_id_for_eval = self._cfg.task_id + self.task_num_for_current_rank = self._cfg.task_num + + if self._cfg.model.model_type == 'conv': + self.last_batch_obs_eval = torch.zeros( + [self.evaluator_env_num, self._cfg.model.observation_shape[0], 64, 64] + ).to(self._cfg.device) + self.last_batch_action = [-1 for _ in range(self.evaluator_env_num)] + elif self._cfg.model.model_type == 'mlp': + self.last_batch_obs_eval = torch.zeros( + [self.evaluator_env_num, self._cfg.model.observation_shape_list[self.task_id_for_eval]] # TODO + ).to(self._cfg.device) + print(f'rank {get_rank()} last_batch_obs_eval:', self.last_batch_obs_eval.shape) + self.last_batch_action = [-1 for _ in range(self.evaluator_env_num)] + + def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, + ready_env_id: np.array = None, task_id: int = None) -> Dict: + """ + Forward function for evaluating the current policy in eval mode, handling multiple tasks. + """ + self._eval_model.eval() + active_eval_env_num = data.shape[0] + if ready_env_id is None: + ready_env_id = np.arange(active_eval_env_num) + output = {i: None for i in ready_env_id} + with torch.no_grad(): + network_output = self._eval_model.initial_inference( + self.last_batch_obs_eval, + self.last_batch_action, + data, + task_id=task_id + ) + latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) + + # TODO:======== + # self._eval_model.training = False + # if not self._eval_model.training: + pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() + latent_state_roots = latent_state_roots.detach().cpu().numpy() + policy_logits = policy_logits.detach().cpu().numpy().tolist() + + legal_actions = [ + [i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_eval_env_num) + ] if not self._cfg.model.continuous_action_space else [ + [-1 for _ in range(self._cfg.model.world_model_cfg.num_of_sampled_actions)] + for _ in range(active_eval_env_num) + ] + + if self._cfg.mcts_ctree: + roots = MCTSCtree.roots( + active_eval_env_num, + legal_actions, + self._cfg.model.world_model_cfg.action_space_size, + self._cfg.model.world_model_cfg.num_of_sampled_actions, + self._cfg.model.continuous_action_space + ) + else: + roots = MCTSPtree.roots(active_eval_env_num, legal_actions) + + # print(f'type(policy_logits): {type(policy_logits)}') + # print(f'policy_logits.shape: {policy_logits.shape}') + # print(f'policy_logits: {policy_logits}') + + roots.prepare_no_noise(reward_roots, policy_logits, to_play) + self._mcts_eval.search(roots, self._eval_model, latent_state_roots, to_play, task_id=task_id) + + roots_visit_count_distributions = roots.get_distributions() + roots_values = roots.get_values() + roots_sampled_actions = roots.get_sampled_actions() + + batch_action = [] + for i, env_id in enumerate(ready_env_id): + distributions, value = roots_visit_count_distributions[i], roots_values[i] + root_sampled_actions = np.array([ + getattr(action, 'value', action) for action in roots_sampled_actions[i] + ]) + + # 选择动作(确定性) + action, visit_count_distribution_entropy = select_action( + distributions, temperature=1, deterministic=True + ) + + # 获取采样动作 + action = root_sampled_actions[action] + if not 
self._cfg.model.continuous_action_space: + action = int(action.item()) + + output[env_id] = { + 'action': action, + 'visit_count_distributions': distributions, + 'root_sampled_actions': root_sampled_actions, + 'visit_count_distribution_entropy': visit_count_distribution_entropy, + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], + } + batch_action.append(action) + + self.last_batch_obs_eval = data + self.last_batch_action = batch_action + + return output + + def _reset_collect(self, env_id: int = None, current_steps: int = 0, reset_init_data: bool = True, task_id: int = None) -> None: + """ + Reset the collection process for a specific environment. + """ + if reset_init_data: + if task_id is not None: + self.last_batch_obs = initialize_zeros_batch( + self._cfg.model.observation_shape_list[task_id], + self._cfg.collector_env_num, + self._cfg.device + ) + else: + self.last_batch_obs = initialize_zeros_batch( + self._cfg.model.observation_shape, + self._cfg.collector_env_num, + self._cfg.device + ) + self.last_batch_action = [-1 for _ in range(self._cfg.collector_env_num)] + logging.info(f'collector: last_batch_obs, last_batch_action reset() {self.last_batch_obs.shape}') + + if env_id is None or isinstance(env_id, list): + return + + clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else 200 + + if current_steps % clear_interval == 0: + logging.info(f'clear_interval: {clear_interval}') + + world_model = self._collect_model.world_model + world_model.past_kv_cache_init_infer.clear() + for kv_cache_dict_env in world_model.past_kv_cache_init_infer_envs: + kv_cache_dict_env.clear() + world_model.past_kv_cache_recurrent_infer.clear() + world_model.keys_values_wm_list.clear() + + torch.cuda.empty_cache() + + logging.info('collector: collect_model clear()') + logging.info(f'eps_steps_lst[{env_id}]: {current_steps}') + + self._reset_target_model() + + def _reset_target_model(self) -> None: + """ + Reset the target model's caches. + """ + world_model = self._target_model.world_model + world_model.past_kv_cache_init_infer.clear() + for kv_cache_dict_env in world_model.past_kv_cache_init_infer_envs: + kv_cache_dict_env.clear() + world_model.past_kv_cache_recurrent_infer.clear() + world_model.keys_values_wm_list.clear() + + torch.cuda.empty_cache() + logging.info('collector: target_model past_kv_cache.clear()') + + def _state_dict_learn(self) -> Dict[str, Any]: + """ + Return the state_dict of learn mode, including model, target_model, and optimizer. + """ + return { + 'model': self._learn_model.state_dict(), + 'target_model': self._target_model.state_dict(), + 'optimizer_world_model': self._optimizer_world_model.state_dict(), + } + + # ========== TODO: original version: load all parameters ========== + def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: + """ + Overview: + Load the state_dict variable into policy learn mode. + Arguments: + - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved before. 
+ """ + self._learn_model.load_state_dict(state_dict['model']) + self._target_model.load_state_dict(state_dict['target_model']) + self._optimizer_world_model.load_state_dict(state_dict['optimizer_world_model']) + + # ========== TODO: pretrain-finetue version: only load encoder and transformer-backbone parameters ========== + # def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: + # """ + # Overview: + # Load the state_dict variable into policy learn mode, excluding multi-task related parameters. + # Arguments: + # - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved previously. + # """ + # # 定义需要排除的参数前缀 + # exclude_prefixes = [ + # '_orig_mod.world_model.head_policy_multi_task.', + # '_orig_mod.world_model.head_value_multi_task.', + # '_orig_mod.world_model.head_rewards_multi_task.', + # '_orig_mod.world_model.head_observations_multi_task.', + # '_orig_mod.world_model.task_emb.' + # ] + + # # 定义需要排除的具体参数(如果有特殊情况) + # exclude_keys = [ + # '_orig_mod.world_model.task_emb.weight', + # '_orig_mod.world_model.task_emb.bias', # 如果存在则添加 + # # 添加其他需要排除的具体参数名 + # ] + + # def filter_state_dict(state_dict_loader: Dict[str, Any], exclude_prefixes: list, exclude_keys: list = []) -> Dict[str, Any]: + # """ + # 过滤掉需要排除的参数。 + # """ + # filtered = {} + # for k, v in state_dict_loader.items(): + # if any(k.startswith(prefix) for prefix in exclude_prefixes): + # print(f"Excluding parameter: {k}") # 调试用,查看哪些参数被排除 + # continue + # if k in exclude_keys: + # print(f"Excluding specific parameter: {k}") # 调试用 + # continue + # filtered[k] = v + # return filtered + + # # 过滤并加载 'model' 部分 + # if 'model' in state_dict: + # model_state_dict = state_dict['model'] + # filtered_model_state_dict = filter_state_dict(model_state_dict, exclude_prefixes, exclude_keys) + # missing_keys, unexpected_keys = self._learn_model.load_state_dict(filtered_model_state_dict, strict=False) + # if missing_keys: + # print(f"Missing keys when loading _learn_model: {missing_keys}") + # if unexpected_keys: + # print(f"Unexpected keys when loading _learn_model: {unexpected_keys}") + # else: + # print("No 'model' key found in the state_dict.") + + # # 过滤并加载 'target_model' 部分 + # if 'target_model' in state_dict: + # target_model_state_dict = state_dict['target_model'] + # filtered_target_model_state_dict = filter_state_dict(target_model_state_dict, exclude_prefixes, exclude_keys) + # missing_keys, unexpected_keys = self._target_model.load_state_dict(filtered_target_model_state_dict, strict=False) + # if missing_keys: + # print(f"Missing keys when loading _target_model: {missing_keys}") + # if unexpected_keys: + # print(f"Unexpected keys when loading _target_model: {unexpected_keys}") + # else: + # print("No 'target_model' key found in the state_dict.") + + # # 加载优化器的 state_dict,不需要过滤,因为优化器通常不包含模型参数 + # if 'optimizer_world_model' in state_dict: + # optimizer_state_dict = state_dict['optimizer_world_model'] + # try: + # self._optimizer_world_model.load_state_dict(optimizer_state_dict) + # except Exception as e: + # print(f"Error loading optimizer state_dict: {e}") + # else: + # print("No 'optimizer_world_model' key found in the state_dict.") + + # # 如果需要,还可以加载其他部分,例如 scheduler 等 \ No newline at end of file diff --git a/lzero/policy/unizero.py b/lzero/policy/unizero.py index cf95c46d9..6183c7d3c 100644 --- a/lzero/policy/unizero.py +++ b/lzero/policy/unizero.py @@ -155,7 +155,7 @@ class UniZeroPolicy(MuZeroPolicy): # (bool) Whether to use the pure policy to collect data. 
collect_with_pure_policy=False, # (int) The evaluation frequency. - eval_freq=int(2e3), + eval_freq=int(5e3), # (str) The sample type. Options are ['episode', 'transition']. sample_type='transition', # ****** observation ****** @@ -563,7 +563,8 @@ def _forward_collect( temperature: float = 1, to_play: List = [-1], epsilon: float = 0.25, - ready_env_id: np.array = None + ready_env_id: np.array = None, + task_id: int = None, ) -> Dict: """ Overview: @@ -575,6 +576,7 @@ def _forward_collect( - temperature (:obj:`float`): The temperature of the policy. - to_play (:obj:`int`): The player to play. - ready_env_id (:obj:`list`): The id of the env that is ready to collect. + - task_id (:obj:`int`): The task id. Default is None, which means UniZero is in the single-task mode. Shape: - data (:obj:`torch.Tensor`): - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ @@ -697,7 +699,7 @@ def _init_eval(self) -> None: self.last_batch_action = [-1 for _ in range(self.evaluator_env_num)] def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, - ready_env_id: np.array = None) -> Dict: + ready_env_id: np.array = None, task_id: int = None,) -> Dict: """ Overview: The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search. @@ -707,6 +709,7 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1 - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected. - to_play (:obj:`int`): The player to play. - ready_env_id (:obj:`list`): The id of the env that is ready to collect. + - task_id (:obj:`int`): The task id. Default is None, which means UniZero is in the single-task mode. Shape: - data (:obj:`torch.Tensor`): - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ @@ -725,7 +728,7 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1 ready_env_id = np.arange(active_eval_env_num) output = {i: None for i in ready_env_id} with torch.no_grad(): - network_output = self._eval_model.initial_inference(self.last_batch_obs, self.last_batch_action, data) + network_output = self._eval_model.initial_inference(self.last_batch_obs_eval, self.last_batch_action, data) latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) if not self._eval_model.training: @@ -775,12 +778,12 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1 } batch_action.append(action) - self.last_batch_obs = data + self.last_batch_obs_eval = data self.last_batch_action = batch_action return output - def _reset_collect(self, env_id: int = None, current_steps: int = None, reset_init_data: bool = True) -> None: + def _reset_collect(self, env_id: int = None, current_steps: int = None, reset_init_data: bool = True, task_id: int = None) -> None: """ Overview: This method resets the collection process for a specific environment. 
It clears caches and memory @@ -824,7 +827,7 @@ def _reset_collect(self, env_id: int = None, current_steps: int = None, reset_in print('collector: collect_model clear()') print(f'eps_steps_lst[{env_id}]: {current_steps}') - def _reset_eval(self, env_id: int = None, current_steps: int = None, reset_init_data: bool = True) -> None: + def _reset_eval(self, env_id: int = None, current_steps: int = None, reset_init_data: bool = True, task_id: int = None) -> None: """ Overview: This method resets the evaluation process for a specific environment. It clears caches and memory @@ -837,11 +840,22 @@ def _reset_eval(self, env_id: int = None, current_steps: int = None, reset_init_ - reset_init_data (:obj:`bool`, optional): Whether to reset the initial data. If True, the initial data will be reset. """ if reset_init_data: - self.last_batch_obs = initialize_zeros_batch( - self._cfg.model.observation_shape, - self._cfg.evaluator_env_num, - self._cfg.device - ) + if task_id is not None: + self.last_batch_obs_eval = initialize_zeros_batch( + self._cfg.model.observation_shape_list[task_id], + self._cfg.evaluator_env_num, + self._cfg.device + ) + print(f'unizero.py task_id:{task_id} after _reset_eval: last_batch_obs_eval:', self.last_batch_obs_eval.shape) + + else: + self.last_batch_obs_eval = initialize_zeros_batch( + self._cfg.model.observation_shape, + self._cfg.evaluator_env_num, + self._cfg.device + ) + print(f'unizero.py task_id:{task_id} after _reset_eval: last_batch_obs_eval:', self.last_batch_obs_eval.shape) + self.last_batch_action = [-1 for _ in range(self._cfg.evaluator_env_num)] # Return immediately if env_id is None or a list diff --git a/lzero/policy/unizero_multitask.py b/lzero/policy/unizero_multitask.py new file mode 100644 index 000000000..880ab67b7 --- /dev/null +++ b/lzero/policy/unizero_multitask.py @@ -0,0 +1,1242 @@ +import copy +from collections import defaultdict +from typing import List, Dict, Any, Tuple, Union + +import numpy as np +import torch +from ding.model import model_wrap +from ding.utils import POLICY_REGISTRY + +from lzero.entry.utils import initialize_zeros_batch +from lzero.mcts import UniZeroMCTSCtree as MCTSCtree +from lzero.model import ImageTransforms +from lzero.policy import prepare_obs_stack4_for_unizero +from lzero.policy import scalar_transform, InverseScalarTransform, phi_transform, \ + DiscreteSupport, to_torch_float_tensor, mz_network_output_unpack, select_action, prepare_obs +from lzero.policy.unizero import UniZeroPolicy +from .utils import configure_optimizers_nanogpt + + +# sys.path.append('/Users/puyuan/code/LibMTL/') +# from LibMTL.weighting.MoCo_unizero import MoCo as GradCorrect +# from LibMTL.weighting.CAGrad_unizero import CAGrad as GradCorrect + +# from LibMTL.weighting.abstract_weighting import AbsWeighting + + +def generate_task_loss_dict(multi_task_losses, task_name_template, task_id): + """ + 生成每个任务的损失字典 + :param multi_task_losses: 包含每个任务损失的列表 + :param task_name_template: 任务名称模板,例如 'obs_loss_task{}' + :return: 一个字典,包含每个任务的损失 + """ + task_loss_dict = {} + for task_idx, task_loss in enumerate(multi_task_losses): + task_name = task_name_template.format(task_idx + task_id) + try: + task_loss_dict[task_name] = task_loss.item() if hasattr(task_loss, 'item') else task_loss + except Exception as e: + task_loss_dict[task_name] = task_loss + return task_loss_dict + + + +class WrappedModel: + def __init__(self, world_model): + self.world_model = world_model + + def parameters(self): + # 返回 tokenizer, transformer 以及所有嵌入层的参数 + return 
self.world_model.parameters() + + def zero_grad(self, set_to_none=False): + # 将 tokenizer, transformer 和所有嵌入层的梯度设为零 + self.world_model.zero_grad(set_to_none=set_to_none) + + +class WrappedModelV2: + def __init__(self, tokenizer, transformer, pos_emb, task_emb, act_embedding_table): + self.tokenizer = tokenizer + self.transformer = transformer + self.pos_emb = pos_emb + self.task_emb = task_emb + self.act_embedding_table = act_embedding_table + + def parameters(self): + # 返回 tokenizer, transformer 以及所有嵌入层的参数 + return (list(self.tokenizer.parameters()) + + list(self.transformer.parameters()) + + list(self.pos_emb.parameters()) + + list(self.task_emb.parameters()) + + list(self.act_embedding_table.parameters())) + + def zero_grad(self, set_to_none=False): + # 将 tokenizer, transformer 和所有嵌入层的梯度设为零 + self.tokenizer.zero_grad(set_to_none=set_to_none) + self.transformer.zero_grad(set_to_none=set_to_none) + self.pos_emb.zero_grad(set_to_none=set_to_none) + self.task_emb.zero_grad(set_to_none=set_to_none) + self.act_embedding_table.zero_grad(set_to_none=set_to_none) + + +class WrappedModelV3: + def __init__(self, transformer, pos_emb, task_emb, act_embedding_table): + self.transformer = transformer + self.pos_emb = pos_emb + self.task_emb = task_emb + self.act_embedding_table = act_embedding_table + + def parameters(self): + # 返回 tokenizer, transformer 以及所有嵌入层的参数 + return (list(self.transformer.parameters()) + + list(self.pos_emb.parameters()) + + list(self.task_emb.parameters()) + + list(self.act_embedding_table.parameters())) + + def zero_grad(self, set_to_none=False): + # 将 tokenizer, transformer 和所有嵌入层的梯度设为零 + # self.tokenizer.zero_grad(set_to_none=set_to_none) + self.transformer.zero_grad(set_to_none=set_to_none) + self.pos_emb.zero_grad(set_to_none=set_to_none) + self.task_emb.zero_grad(set_to_none=set_to_none) + self.act_embedding_table.zero_grad(set_to_none=set_to_none) + + + +@POLICY_REGISTRY.register('unizero_multitask') +class UniZeroMTPolicy(UniZeroPolicy): + """ + Overview: + The policy class for UniZero, official implementation for paper UniZero: Generalized and Efficient Planning + with Scalable LatentWorld Models. UniZero aims to enhance the planning capabilities of reinforcement learning agents + by addressing the limitations found in MuZero-style algorithms, particularly in environments requiring the + capture of long-term dependencies. More details can be found in https://arxiv.org/abs/2406.10667. + """ + + # The default_config for UniZero policy. + config = dict( + type='unizero_multitask', + model=dict( + # (str) The model type. For 1-dimensional vector obs, we use mlp model. For the image obs, we use conv model. + model_type='conv', # options={'mlp', 'conv'} + # (bool) If True, the action space of the environment is continuous, otherwise discrete. + continuous_action_space=False, + # (tuple) The obs shape. + observation_shape=(3, 64, 64), + # (bool) Whether to use the self-supervised learning loss. + self_supervised_learning_loss=True, + # (bool) Whether to use discrete support to represent categorical distribution for value/reward/value_prefix. + categorical_distribution=True, + # (int) The image channel in image observation. + image_channel=3, + # (int) The number of frames to stack together. + frame_stack_num=1, + # (int) The number of res blocks in MuZero model. + num_res_blocks=1, + # (int) The number of channels of hidden states in MuZero model. + num_channels=64, + # (int) The scale of supports used in categorical distribution. 
+ # This variable is only effective when ``categorical_distribution=True``. + support_scale=50, + # (bool) whether to learn bias in the last linear layer in value and policy head. + bias=True, + # (bool) whether to use res connection in dynamics. + res_connection_in_dynamics=True, + # (str) The type of normalization in MuZero model. Options are ['BN', 'LN']. Default to 'BN'. + norm_type='LN', # NOTE: TODO + # (bool) Whether to analyze simulation normalization. + analysis_sim_norm=False, + # (int) The save interval of the model. + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=10000, ), ), ), + world_model_cfg=dict( + # (int) The number of tokens per block. + tokens_per_block=2, + # (int) The maximum number of blocks. + max_blocks=10, + # (int) The maximum number of tokens, calculated as tokens per block multiplied by max blocks. + max_tokens=2 * 10, + # (int) The context length, usually calculated as twice the number of some base unit. + context_length=2 * 4, + # (bool) Whether to use GRU gating mechanism. + gru_gating=False, + # (str) The device to be used for computation, e.g., 'cpu' or 'cuda'. + device='cpu', + # (bool) Whether to analyze simulation normalization. + analysis_sim_norm=False, + # (bool) Whether to analyze dormant ratio. + analysis_dormant_ratio=False, + # (int) The shape of the action space. + action_space_size=6, + # (int) The size of the group, related to simulation normalization. + group_size=8, # NOTE: sim_norm + # (str) The type of attention mechanism used. Options could be ['causal']. + attention='causal', + # (int) The number of layers in the model. + num_layers=2, + # (int) The number of attention heads. + num_heads=8, + # (int) The dimension of the embedding. + embed_dim=768, + # (float) The dropout probability for the embedding layer. + embed_pdrop=0.1, + # (float) The dropout probability for the residual connections. + resid_pdrop=0.1, + # (float) The dropout probability for the attention mechanism. + attn_pdrop=0.1, + # (int) The size of the support set for value and reward heads. + support_size=101, + # (int) The maximum size of the cache. + max_cache_size=5000, + # (int) The number of environments. + env_num=8, + # (float) The weight of the latent reconstruction loss. + latent_recon_loss_weight=0., + # (float) The weight of the perceptual loss. + perceptual_loss_weight=0., + # (float) The weight of the policy entropy. + policy_entropy_weight=1e-4, + # (str) The type of loss for predicting latent variables. Options could be ['group_kl', 'mse']. + predict_latent_loss_type='group_kl', + # (str) The type of observation. Options are ['image', 'vector']. + obs_type='image', + # (float) The discount factor for future rewards. + gamma=1, + # (float) The threshold for a dormant neuron. + dormant_threshold=0.025, + ), + ), + # ****** common ****** + # (bool) whether to use rnd model. + use_rnd_model=False, + # (bool) Whether to use multi-gpu training. + multi_gpu=True, + # (bool) Whether to enable the sampled-based algorithm (e.g. Sampled EfficientZero) + # this variable is used in ``collector``. + sampled_algo=False, + # (bool) Whether to enable the gumbel-based algorithm (e.g. Gumbel Muzero) + gumbel_algo=False, + # (bool) Whether to use C++ MCTS in policy. If False, use Python implementation. + mcts_ctree=True, + # (bool) Whether to use cuda for network. + cuda=True, + # (int) The number of environments used in collecting data. + collector_env_num=8, + # (int) The number of environments used in evaluating policy. 
+ evaluator_env_num=3, + # (str) The type of environment. Options are ['not_board_games', 'board_games']. + env_type='not_board_games', + # (str) The type of action space. Options are ['fixed_action_space', 'varied_action_space']. + action_type='fixed_action_space', + # (str) The type of battle mode. Options are ['play_with_bot_mode', 'self_play_mode']. + battle_mode='play_with_bot_mode', + # (bool) Whether to monitor extra statistics in tensorboard. + monitor_extra_statistics=True, + # (int) The transition number of one ``GameSegment``. + game_segment_length=400, + # (bool) Whether to analyze simulation normalization. + analysis_sim_norm=False, + # (bool) Whether to use the pure policy to collect data. + collect_with_pure_policy=False, + # (int) The evaluation frequency. + eval_freq=int(5e3), + # (str) The sample type. Options are ['episode', 'transition']. + sample_type='transition', + + # ****** observation ****** + # (bool) Whether to transform image to string to save memory. + transform2string=False, + # (bool) Whether to use gray scale image. + gray_scale=False, + # (bool) Whether to use data augmentation. + use_augmentation=False, + # (list) The style of augmentation. + augmentation=['shift', 'intensity'], + + # ******* learn ****** + # (bool) Whether to ignore the done flag in the training data. Typically, this value is set to False. + # However, for some environments with a fixed episode length, to ensure the accuracy of Q-value calculations, + # we should set it to True to avoid the influence of the done flag. + ignore_done=False, + # (int) How many updates(iterations) to train after collector's one collection. + # Bigger "update_per_collect" means bigger off-policy. + # collect data -> update policy-> collect data -> ... + # For different env, we have different episode_length, + # we usually set update_per_collect = collector_env_num * episode_length / batch_size * reuse_factor. + # If we set update_per_collect=None, we will set update_per_collect = collected_transitions_num * cfg.policy.replay_ratio automatically. + update_per_collect=None, + # (float) The ratio of the collected data used for training. Only effective when ``update_per_collect`` is not None. + replay_ratio=0.25, + # (int) Minibatch size for one gradient descent. + batch_size=256, + # (str) Optimizer for training policy network. + optim_type='AdamW', + # (float) Learning rate for training policy network. Initial lr for manually decay schedule. + learning_rate=0.0001, + # (int) Frequency of hard target network update. + target_update_freq=100, + # (int) Frequency of soft target network update. + target_update_theta=0.05, + # (int) Frequency of target network update. + target_update_freq_for_intrinsic_reward=1000, + # (float) Weight decay for training policy network. + weight_decay=1e-4, + # (float) One-order Momentum in optimizer, which stabilizes the training process (gradient direction). + momentum=0.9, + # (float) The maximum constraint value of gradient norm clipping. + grad_clip_value=5, + # (int) The number of episodes in each collecting stage when use muzero_collector. + n_episode=8, + # (int) The number of num_segments in each collecting stage when use muzero_segment_collector. + num_segments=8, + # (int) the number of simulations in MCTS. + num_simulations=50, + # (float) Discount factor (gamma) for returns. + discount_factor=0.997, + # (int) The number of steps for calculating target q_value. + td_steps=5, + # (int) The number of unroll steps in dynamics network. 
+ num_unroll_steps=10, + # (float) The weight of reward loss. + reward_loss_weight=1, + # (float) The weight of value loss. + value_loss_weight=0.25, + # (float) The weight of policy loss. + policy_loss_weight=1, + # (float) The weight of ssl (self-supervised learning) loss. + ssl_loss_weight=0, + # (bool) Whether to use piecewise constant learning rate decay. + # i.e. lr: 0.2 -> 0.02 -> 0.002 + lr_piecewise_constant_decay=False, + # (int) The number of final training iterations to control lr decay, which is only used for manually decay. + threshold_training_steps_for_final_lr=int(5e4), + # (bool) Whether to use manually decayed temperature. + manual_temperature_decay=False, + # (int) The number of final training iterations to control temperature, which is only used for manually decay. + threshold_training_steps_for_final_temperature=int(1e5), + # (float) The fixed temperature value for MCTS action selection, which is used to control the exploration. + # The larger the value, the more exploration. This value is only used when manual_temperature_decay=False. + fixed_temperature_value=0.25, + # (bool) Whether to use the true chance in MCTS in some environments with stochastic dynamics, such as 2048. + use_ture_chance_label_in_chance_encoder=False, + + # ****** Priority ****** + # (bool) Whether to use priority when sampling training data from the buffer. + use_priority=False, + # (float) The degree of prioritization to use. A value of 0 means no prioritization, + # while a value of 1 means full prioritization. + priority_prob_alpha=0.6, + # (float) The degree of correction to use. A value of 0 means no correction, + # while a value of 1 means full correction. + priority_prob_beta=0.4, + # (int) The initial Env Steps for training. + train_start_after_envsteps=int(0), + + # ****** UCB ****** + # (float) The alpha value used in the Dirichlet distribution for exploration at the root node of search tree. + root_dirichlet_alpha=0.3, + # (float) The noise weight at the root node of the search tree. + root_noise_weight=0.25, + + # ****** Explore by random collect ****** + # (int) The number of episodes to collect data randomly before training. + random_collect_episode_num=0, + + # ****** Explore by eps greedy ****** + eps=dict( + # (bool) Whether to use eps greedy exploration in collecting data. + eps_greedy_exploration_in_collect=False, + # (str) The type of decaying epsilon. Options are 'linear', 'exp'. + type='linear', + # (float) The start value of eps. + start=1., + # (float) The end value of eps. + end=0.05, + # (int) The decay steps from start to end eps. + decay=int(1e5), + ), + ) + + def default_model(self) -> Tuple[str, List[str]]: + """ + Overview: + Return this algorithm default model setting for demonstration. + Returns: + - model_info (:obj:`Tuple[str, List[str]]`): model name and model import_names. + - model_type (:obj:`str`): The model type used in this algorithm, which is registered in ModelRegistry. + - import_names (:obj:`List[str]`): The model class path list used in this algorithm. + .. note:: + The user can define and use customized network model but must obey the same interface definition indicated \ + by import_names path. For MuZero, ``lzero.model.unizero_model.MuZeroModel`` + """ + # NOTE: multi-task model + return 'UniZeroMTModel', ['lzero.model.unizero_model_multitask'] + + def _init_learn(self) -> None: + """ + Overview: + Learn mode init method. Called by ``self.__init__``. Initialize the learn model, optimizer and MCTS utils. 
+ """ + # NOTE: nanoGPT optimizer + self._optimizer_world_model = configure_optimizers_nanogpt( + model=self._model.world_model, + learning_rate=self._cfg.learning_rate, + weight_decay=self._cfg.weight_decay, + device_type=self._cfg.device, + betas=(0.9, 0.95), + ) + + # use model_wrapper for specialized demands of different modes + self._target_model = copy.deepcopy(self._model) + # Ensure that the installed torch version is greater than or equal to 2.0 + assert int(''.join(filter(str.isdigit, torch.__version__))) >= 200, "We need torch version >= 2.0" + self._model = torch.compile(self._model) + self._target_model = torch.compile(self._target_model) + # NOTE: soft target + self._target_model = model_wrap( + self._target_model, + wrapper_name='target', + update_type='momentum', + update_kwargs={'theta': self._cfg.target_update_theta} + ) + self._learn_model = self._model + + if self._cfg.use_augmentation: + self.image_transforms = ImageTransforms( + self._cfg.augmentation, + image_shape=(self._cfg.model.observation_shape[1], self._cfg.model.observation_shape[2]) + ) + self.value_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) + self.reward_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) + self.inverse_scalar_transform_handle = InverseScalarTransform( + self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution + ) + self.intermediate_losses = defaultdict(float) + self.l2_norm_before = 0. + self.l2_norm_after = 0. + self.grad_norm_before = 0. + self.grad_norm_after = 0. + + # 创建 WrappedModel 实例 + # 所有参数都共享,即所有参数都需要进行矫正 + # wrapped_model = WrappedModel( + # self._learn_model.world_model, + # ) + + # head 没有矫正梯度 + wrapped_model = WrappedModelV2( + # self._learn_model.world_model.tokenizer, # TODO: + self._learn_model.world_model.tokenizer.encoder[0], # TODO: one encoder + self._learn_model.world_model.transformer, + self._learn_model.world_model.pos_emb, + self._learn_model.world_model.task_emb, + self._learn_model.world_model.act_embedding_table, + ) + + # head 和 tokenizer.encoder 没有矫正梯度 + # wrapped_model = WrappedModelV3( + # self._learn_model.world_model.transformer, + # self._learn_model.world_model.pos_emb, + # self._learn_model.world_model.task_emb, + # self._learn_model.world_model.act_embedding_table, + # ) + + # 将 wrapped_model 作为 share_model 传递给 GradCorrect + # ========= 初始化 MoCo CAGrad 参数 ========= + # self.grad_correct = GradCorrect(wrapped_model, self.task_num, self._cfg.device) + # self.grad_correct.init_param() + # self.grad_correct.rep_grad = False + + self.task_id = self._cfg.task_id + self.task_num_for_current_rank = self._cfg.task_num + + + #@profile + def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]: + """ + Overview: + The forward function for learning policy in learn mode, which is the core of the learning process. + The data is sampled from replay buffer. + The loss is calculated by the loss function and the loss is backpropagated to update the model. + Arguments: + - data (:obj:`Tuple[torch.Tensor]`): The data sampled from replay buffer, which is a tuple of tensors. + The first tensor is the current_batch, the second tensor is the target_batch. + Returns: + - info_dict (:obj:`Dict[str, Union[float, int]]`): The information dict to be logged, which contains \ + current learning loss and learning statistics. 
+ """ + self._learn_model.train() + self._target_model.train() + + obs_loss_multi_task = [] + reward_loss_multi_task = [] + policy_loss_multi_task = [] + value_loss_multi_task = [] + latent_recon_loss_multi_task = [] + perceptual_loss_multi_task = [] + orig_policy_loss_multi_task = [] + policy_entropy_multi_task = [] + weighted_total_loss = 0.0 # 初始化为0,避免使用in-place操作 + + latent_state_l2_norms_multi_task = [] + average_target_policy_entropy_multi_task = [] + value_priority_multi_task = [] + value_priority_mean_multi_task = [] + + + losses_list = [] # 用于存储每个任务的损失 + for task_id, data_one_task in enumerate(data): + current_batch, target_batch, task_id = data_one_task + # current_batch, target_batch, _ = data + obs_batch_ori, action_batch, target_action_batch, mask_batch, indices, weights, make_time = current_batch + target_reward, target_value, target_policy = target_batch + + # Prepare observations based on frame stack number + if self._cfg.model.frame_stack_num == 4: + obs_batch, obs_target_batch = prepare_obs_stack4_for_unizero(obs_batch_ori, self._cfg) + else: + obs_batch, obs_target_batch = prepare_obs(obs_batch_ori, self._cfg) + + # Apply augmentations if needed + if self._cfg.use_augmentation: + obs_batch = self.image_transforms.transform(obs_batch) + if self._cfg.model.self_supervised_learning_loss: + obs_target_batch = self.image_transforms.transform(obs_target_batch) + + # Prepare action batch and convert to torch tensor + action_batch = torch.from_numpy(action_batch).to(self._cfg.device).unsqueeze( + -1).long() # For discrete action space + data_list = [mask_batch, target_reward.astype('float32'), target_value.astype('float32'), target_policy, + weights] + mask_batch, target_reward, target_value, target_policy, weights = to_torch_float_tensor(data_list, + self._cfg.device) + + + # rank = get_rank() + # print(f'Rank {rank}: cfg.policy.task_id : {self._cfg.task_id}, self._cfg.batch_size {self._cfg.batch_size}') + + target_reward = target_reward.view(self._cfg.batch_size[task_id], -1) + target_value = target_value.view(self._cfg.batch_size[task_id], -1) + + target_reward = target_reward.view(self._cfg.batch_size[task_id], -1) + target_value = target_value.view(self._cfg.batch_size[task_id], -1) + + # assert obs_batch.size(0) == self._cfg.batch_size == target_reward.size(0) + + # Transform rewards and values to their scaled forms + transformed_target_reward = scalar_transform(target_reward) + transformed_target_value = scalar_transform(target_value) + + # Convert to categorical distributions + target_reward_categorical = phi_transform(self.reward_support, transformed_target_reward) + target_value_categorical = phi_transform(self.value_support, transformed_target_value) + + # Prepare batch for a transformer-based world model + batch_for_gpt = {} + if isinstance(self._cfg.model.observation_shape, int) or len(self._cfg.model.observation_shape) == 1: + batch_for_gpt['observations'] = torch.cat((obs_batch, obs_target_batch), dim=1).reshape( + self._cfg.batch_size[task_id], -1, self._cfg.model.observation_shape) + elif len(self._cfg.model.observation_shape) == 3: + batch_for_gpt['observations'] = torch.cat((obs_batch, obs_target_batch), dim=1).reshape( + self._cfg.batch_size[task_id], -1, *self._cfg.model.observation_shape) + + batch_for_gpt['actions'] = action_batch.squeeze(-1) + batch_for_gpt['rewards'] = target_reward_categorical[:, :-1] + batch_for_gpt['mask_padding'] = mask_batch == 1.0 # 0 means invalid padding data + batch_for_gpt['mask_padding'] = batch_for_gpt['mask_padding'][:, :-1] 
+ batch_for_gpt['observations'] = batch_for_gpt['observations'][:, :-1] + batch_for_gpt['ends'] = torch.zeros(batch_for_gpt['mask_padding'].shape, dtype=torch.long, + device=self._cfg.device) + batch_for_gpt['target_value'] = target_value_categorical[:, :-1] + batch_for_gpt['target_policy'] = target_policy[:, :-1] + + # Extract valid target policy data and compute entropy + valid_target_policy = batch_for_gpt['target_policy'][batch_for_gpt['mask_padding']] + target_policy_entropy = -torch.sum(valid_target_policy * torch.log(valid_target_policy + 1e-9), dim=-1) + average_target_policy_entropy = target_policy_entropy.mean().item() + + # Update world model + intermediate_losses = defaultdict(float) + losses = self._learn_model.world_model.compute_loss( + batch_for_gpt, self._target_model.world_model.tokenizer, self.inverse_scalar_transform_handle, task_id=task_id + ) + + weighted_total_loss += losses.loss_total # TODO + + assert not torch.isnan(losses.loss_total).any(), "Loss contains NaN values" + assert not torch.isinf(losses.loss_total).any(), "Loss contains Inf values" + + losses_list.append(losses.loss_total) # TODO: for moco + + for loss_name, loss_value in losses.intermediate_losses.items(): + intermediate_losses[f"{loss_name}"] = loss_value + + obs_loss = intermediate_losses['loss_obs'] + reward_loss = intermediate_losses['loss_rewards'] + policy_loss = intermediate_losses['loss_policy'] + orig_policy_loss = intermediate_losses['orig_policy_loss'] + policy_entropy = intermediate_losses['policy_entropy'] + value_loss = intermediate_losses['loss_value'] + latent_recon_loss = intermediate_losses['latent_recon_loss'] + perceptual_loss = intermediate_losses['perceptual_loss'] + latent_state_l2_norms = intermediate_losses['latent_state_l2_norms'] + # value_priority = intermediate_losses['value_priority'] + # logits_value = intermediate_losses['logits_value'] + + # print(f'logits_value:" {logits_value}') + # print(f'logits_value.shape:" {logits_value.shape}') + # print(f"batch_for_gpt['observations'].shape: {batch_for_gpt['observations'].shape}") + + # ============ for value priority ============ + # transform the categorical representation of the scaled value to its original value + # original_value = self.inverse_scalar_transform_handle(logits_value.reshape(-1, 101)).reshape( + # batch_for_gpt['observations'].shape[0], batch_for_gpt['observations'].shape[1], 1) + # calculate the new priorities for each transition. 
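# --- Illustrative sketch (not part of the original computation) ---
# The commented-out priority computation above needs the categorical value logits mapped back to
# scalars. A minimal sketch of the usual MuZero inverse value transform is shown here, assuming the
# standard epsilon = 0.001; the actual ``InverseScalarTransform`` handle used in this file may differ
# in detail. ``torch`` is already imported at module level of this file.
def _inverse_value_transform_sketch(x: torch.Tensor, epsilon: float = 0.001) -> torch.Tensor:
    # Inverse of h(x) = sign(x) * (sqrt(|x| + 1) - 1) + epsilon * x
    return torch.sign(x) * (
        ((torch.sqrt(1. + 4. * epsilon * (torch.abs(x) + 1. + epsilon)) - 1.) / (2. * epsilon)) ** 2 - 1.
    )
# The per-transition priority would then be |v_pred - v_target| plus a small constant, exactly as in
# the commented lines that follow.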
+                # value_priority = torch.nn.L1Loss(reduction='none')(original_value.squeeze(-1)[:,0], target_value[:, 0]) # TODO: mix of mean and sum
+                # value_priority = value_priority.data.cpu().numpy() + 1e-6 # TODO: log-reduce not support array now
+                value_priority = torch.tensor(0., device=self._cfg.device)
+                # ============ for value priority ============
+
+                obs_loss_multi_task.append(obs_loss)
+                reward_loss_multi_task.append(reward_loss)
+                policy_loss_multi_task.append(policy_loss)
+                orig_policy_loss_multi_task.append(orig_policy_loss)
+                policy_entropy_multi_task.append(policy_entropy)
+                value_loss_multi_task.append(value_loss)
+                latent_recon_loss_multi_task.append(latent_recon_loss)
+                perceptual_loss_multi_task.append(perceptual_loss)
+                latent_state_l2_norms_multi_task.append(latent_state_l2_norms)
+                value_priority_multi_task.append(value_priority)
+                value_priority_mean_multi_task.append(value_priority.mean().item())
+
+        # Core learn model update step
+        self._optimizer_world_model.zero_grad()
+
+        # TODO: use MoCo or CAGrad to compute the per-task gradients and weights.
+        # ============= for CAGrad and MoCo =============
+        # lambd = self.grad_correct.backward(losses=losses_list, **self._cfg.grad_correct_params)
+
+        # ============= TODO: the case without gradient correction =============
+        lambd = torch.tensor([0. for i in range(self.task_num_for_current_rank)], device=self._cfg.device)
+        weighted_total_loss.backward()
+
+        # ========== for debugging ==========
+        # for name, param in self._learn_model.world_model.tokenizer.encoder.named_parameters():
+        #     print('name, param.mean(), param.std():', name, param.mean(), param.std())
+        #     if param.requires_grad:
+        #         print(name, param.grad.norm())
+
+        if self._cfg.analysis_sim_norm:
+            del self.l2_norm_before, self.l2_norm_after, self.grad_norm_before, self.grad_norm_after
+            self.l2_norm_before, self.l2_norm_after, self.grad_norm_before, self.grad_norm_after = self._learn_model.encoder_hook.analyze()
+            self._target_model.encoder_hook.clear_data()
+
+        total_grad_norm_before_clip_wm = torch.nn.utils.clip_grad_norm_(self._learn_model.world_model.parameters(),
+                                                                        self._cfg.grad_clip_value)
+
+        if self._cfg.multi_gpu:
+            # Very important to sync gradients before updating the model
+            # rank = get_rank()
+            # print(f'Rank {rank} train task_id: {self._cfg.task_id} sync grad begin...')
+            self.sync_gradients(self._learn_model)
+            # print(f'Rank {rank} train task_id: {self._cfg.task_id} sync grad end...')
+
+        self._optimizer_world_model.step()
+        if self._cfg.lr_piecewise_constant_decay:
+            self.lr_scheduler.step()
+
+        # Core target model update step
+        self._target_model.update(self._learn_model.state_dict())
+
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
+            current_memory_allocated = torch.cuda.memory_allocated()
+            max_memory_allocated = torch.cuda.max_memory_allocated()
+            current_memory_allocated_gb = current_memory_allocated / (1024 ** 3)
+            max_memory_allocated_gb = max_memory_allocated / (1024 ** 3)
+        else:
+            current_memory_allocated_gb = 0.
+            max_memory_allocated_gb = 0.
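# --- Illustrative sketch (assumed behavior, not the actual helper) ---
# The logging dict built below relies on ``generate_task_loss_dict`` to map each entry of a per-task
# loss list to a key whose suffix is the *global* task id, i.e. the local index offset by this rank's
# base ``self.task_id``. A minimal sketch of that assumed behavior:
def _generate_task_loss_dict_sketch(losses, name_template: str, task_id: int = 0) -> dict:
    loss_dict = {}
    for local_idx, loss in enumerate(losses):
        # Use the global task id in the key so that logs from different ranks do not collide.
        key = name_template.format(task_id + local_idx)
        # Convert tensors to plain floats so the logged scalars do not keep the autograd graph alive.
        loss_dict[key] = loss.item() if hasattr(loss, 'item') else float(loss)
    return loss_dict
# Example: _generate_task_loss_dict_sketch([0.12, 0.34], 'noreduce_obs_loss_task{}', task_id=2)
#          -> {'noreduce_obs_loss_task2': 0.12, 'noreduce_obs_loss_task3': 0.34}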
+
+        # Build the loss dict to be returned for logging.
+        return_loss_dict = {
+            'Current_GPU': current_memory_allocated_gb,
+            'Max_GPU': max_memory_allocated_gb,
+            'collect_mcts_temperature': self._collect_mcts_temperature,
+            'collect_epsilon': self._collect_epsilon,
+            'cur_lr_world_model': self._optimizer_world_model.param_groups[0]['lr'],
+            'weighted_total_loss': weighted_total_loss.item(),
+            # 'policy_entropy': policy_entropy,
+            # 'target_policy_entropy': average_target_policy_entropy,
+            'total_grad_norm_before_clip_wm': total_grad_norm_before_clip_wm.item(),
+        }
+
+        # Generate the task-specific loss dict; every task-related loss key is prefixed with "noreduce_".
+        multi_task_loss_dicts = {
+            **generate_task_loss_dict(obs_loss_multi_task, 'noreduce_obs_loss_task{}', task_id=self.task_id),
+            **generate_task_loss_dict(latent_recon_loss_multi_task, 'noreduce_latent_recon_loss_task{}', task_id=self.task_id),
+            **generate_task_loss_dict(perceptual_loss_multi_task, 'noreduce_perceptual_loss_task{}', task_id=self.task_id),
+            **generate_task_loss_dict(latent_state_l2_norms_multi_task, 'noreduce_latent_state_l2_norms_task{}', task_id=self.task_id),
+            **generate_task_loss_dict(policy_loss_multi_task, 'noreduce_policy_loss_task{}', task_id=self.task_id),
+            **generate_task_loss_dict(orig_policy_loss_multi_task, 'noreduce_orig_policy_loss_task{}', task_id=self.task_id),
+            **generate_task_loss_dict(policy_entropy_multi_task, 'noreduce_policy_entropy_task{}', task_id=self.task_id),
+            **generate_task_loss_dict(reward_loss_multi_task, 'noreduce_reward_loss_task{}', task_id=self.task_id),
+            **generate_task_loss_dict(value_loss_multi_task, 'noreduce_value_loss_task{}', task_id=self.task_id),
+            **generate_task_loss_dict(average_target_policy_entropy_multi_task, 'noreduce_target_policy_entropy_task{}', task_id=self.task_id),
+            **generate_task_loss_dict(lambd, 'noreduce_lambd_task{}', task_id=self.task_id),
+            **generate_task_loss_dict(value_priority_multi_task, 'noreduce_value_priority_task{}', task_id=self.task_id),
+            **generate_task_loss_dict(value_priority_mean_multi_task, 'noreduce_value_priority_mean_task{}', task_id=self.task_id),
+        }
+        # Merge the two dicts.
+        return_loss_dict.update(multi_task_loss_dicts)
+        # print(f'return_loss_dict:{return_loss_dict}')
+
+        # Return the final loss dict.
+        return return_loss_dict
+
+    def monitor_weights_and_grads(self, model):
+        for name, param in model.named_parameters():
+            if param.requires_grad:
+                print(f"Layer: {name} | "
+                      f"Weight mean: {param.data.mean():.4f} | "
+                      f"Weight std: {param.data.std():.4f} | "
+                      f"Grad mean: {param.grad.mean():.4f} | "
+                      f"Grad std: {param.grad.std():.4f}")
+
+    def _init_collect(self) -> None:
+        """
+        Overview:
+            Collect mode init method. Called by ``self.__init__``. Initialize the collect model and MCTS utils.
+        """
+        self._collect_model = self._model
+
+        if self._cfg.mcts_ctree:
+            self._mcts_collect = MCTSCtree(self._cfg)
+        else:
+            self._mcts_collect = MCTSPtree(self._cfg)
+        self._collect_mcts_temperature = 1.
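# --- Illustrative sketch (hypothetical helper, not the library implementation) ---
# ``self._collect_mcts_temperature`` set above is only an initial value; the collector overwrites it on
# every ``_forward_collect`` call. Based on the config fields above (``manual_temperature_decay``,
# ``fixed_temperature_value``, ``threshold_training_steps_for_final_temperature``), a piecewise schedule
# of roughly the following shape is assumed; the exact breakpoints in the library may differ.
def _collect_temperature_sketch(manual_temperature_decay: bool, fixed_temperature_value: float,
                                threshold_training_steps: int, trained_steps: int) -> float:
    if not manual_temperature_decay:
        # Constant exploration temperature.
        return fixed_temperature_value
    # Manual decay: anneal the temperature as training progresses.
    if trained_steps < 0.5 * threshold_training_steps:
        return 1.0
    elif trained_steps < 0.75 * threshold_training_steps:
        return 0.5
    return 0.25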
+ self._collect_epsilon = 0.0 + self.collector_env_num = self._cfg.collector_env_num + if self._cfg.model.model_type == 'conv': + self.last_batch_obs = torch.zeros([self.collector_env_num, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device) + self.last_batch_action = [-1 for i in range(self.collector_env_num)] + elif self._cfg.model.model_type == 'mlp': + self.last_batch_obs = torch.zeros([self.collector_env_num, self._cfg.model.observation_shape]).to(self._cfg.device) + self.last_batch_action = [-1 for i in range(self.collector_env_num)] + + # TODO: num_tasks + def _monitor_vars_learn(self, num_tasks=2) -> List[str]: + """ + Overview: + Register the variables to be monitored in learn mode. The registered variables will be logged in + tensorboard according to the return value ``_forward_learn``. + If num_tasks is provided, generate monitored variables for each task. + """ + # Basic monitored variables that do not depend on the number of tasks + monitored_vars = [ + 'Current_GPU', + 'Max_GPU', + 'collect_epsilon', + 'collect_mcts_temperature', + 'cur_lr_world_model', + 'weighted_total_loss', + 'total_grad_norm_before_clip_wm', + ] + + # rank = get_rank() + task_specific_vars = [ + 'noreduce_obs_loss', + 'noreduce_orig_policy_loss', + 'noreduce_policy_loss', + 'noreduce_latent_recon_loss', + 'noreduce_policy_entropy', + 'noreduce_target_policy_entropy', + 'noreduce_reward_loss', + 'noreduce_value_loss', + 'noreduce_perceptual_loss', + 'noreduce_latent_state_l2_norms', + 'noreduce_lambd', + 'noreduce_value_priority_mean', + ] + # self.task_num_for_current_rank 作为当前rank的base_index + num_tasks = self.task_num_for_current_rank + # If the number of tasks is provided, extend the monitored variables list with task-specific variables + if num_tasks is not None: + for var in task_specific_vars: + for task_idx in range(num_tasks): + # print(f"learner policy Rank {rank}, self.task_id: {self.task_id}") + monitored_vars.append(f'{var}_task{self.task_id+task_idx}') + else: + # If num_tasks is not provided, we assume there's only one task and keep the original variable names + monitored_vars.extend(task_specific_vars) + + return monitored_vars + + #@profile + def _forward_collect( + self, + data: torch.Tensor, + action_mask: list = None, + temperature: float = 1, + to_play: List = [-1], + epsilon: float = 0.25, + ready_env_id: np.array = None, + task_id: int = None, + ) -> Dict: + """ + Overview: + The forward function for collecting data in collect mode. Use model to execute MCTS search. + Choosing the action through sampling during the collect mode. + Arguments: + - data (:obj:`torch.Tensor`): The input data, i.e. the observation. + - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected. + - temperature (:obj:`float`): The temperature of the policy. + - to_play (:obj:`int`): The player to play. + - ready_env_id (:obj:`list`): The id of the env that is ready to collect. + Shape: + - data (:obj:`torch.Tensor`): + - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ + S is the number of stacked frames, H is the height of the image, W is the width of the image. + - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size. + - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env. + - temperature: :math:`(1, )`. + - to_play: :math:`(N, 1)`, where N is the number of collect_env. 
+ - ready_env_id: None + Returns: + - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \ + ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. + """ + self._collect_model.eval() + + self._collect_mcts_temperature = temperature + self._collect_epsilon = epsilon + active_collect_env_num = data.shape[0] + if ready_env_id is None: + ready_env_id = np.arange(active_collect_env_num) + output = {i: None for i in ready_env_id} + + with torch.no_grad(): + network_output = self._collect_model.initial_inference(self.last_batch_obs, self.last_batch_action, data, task_id=task_id) + latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) + + pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() + latent_state_roots = latent_state_roots.detach().cpu().numpy() + policy_logits = policy_logits.detach().cpu().numpy().tolist() + + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_collect_env_num)] + # the only difference between collect and eval is the dirichlet noise + noises = [ + np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(sum(action_mask[j])) + ).astype(np.float32).tolist() for j in range(active_collect_env_num) + ] + if self._cfg.mcts_ctree: + # cpp mcts_tree + roots = MCTSCtree.roots(active_collect_env_num, legal_actions) + else: + # python mcts_tree + roots = MCTSPtree.roots(active_collect_env_num, legal_actions) + + roots.prepare(self._cfg.root_noise_weight, noises, reward_roots, policy_logits, to_play) + self._mcts_collect.search(roots, self._collect_model, latent_state_roots, to_play, task_id=task_id) + + # list of list, shape: ``{list: batch_size} -> {list: action_space_size}`` + roots_visit_count_distributions = roots.get_distributions() + roots_values = roots.get_values() # shape: {list: batch_size} + + batch_action = [] + for i, env_id in enumerate(ready_env_id): + distributions, value = roots_visit_count_distributions[i], roots_values[i] + + if self._cfg.eps.eps_greedy_exploration_in_collect: + # eps greedy collect + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=self._collect_mcts_temperature, deterministic=True + ) + action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + if np.random.rand() < self._collect_epsilon: + action = np.random.choice(legal_actions[i]) + else: + # normal collect + # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents + # the index within the legal action set, rather than the index in the entire action set. + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=self._collect_mcts_temperature, deterministic=False + ) + # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the entire action set. 
+ action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + + # ============== TODO: only for visualize ============== + # action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + # distributions, temperature=self._collect_mcts_temperature, deterministic=True + # ) + # action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + # ============== TODO: only for visualize ============== + + output[env_id] = { + 'action': action, + 'visit_count_distributions': distributions, + 'visit_count_distribution_entropy': visit_count_distribution_entropy, + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], + } + batch_action.append(action) + + self.last_batch_obs = data + self.last_batch_action = batch_action + + # ========= TODO: for muzero_segment_collector now ========= + if active_collect_env_num < self.collector_env_num: + print('==========collect_forward============') + print(f'len(self.last_batch_obs) < self.collector_env_num, {active_collect_env_num}<{self.collector_env_num}') + self._reset_collect(reset_init_data=True) + + return output + + def _init_eval(self) -> None: + """ + Overview: + Evaluate mode init method. Called by ``self.__init__``. Initialize the eval model and MCTS utils. + """ + self._eval_model = self._model + if self._cfg.mcts_ctree: + self._mcts_eval = MCTSCtree(self._cfg) + else: + self._mcts_eval = MCTSPtree(self._cfg) + self.evaluator_env_num = self._cfg.evaluator_env_num + + if self._cfg.model.model_type == 'conv': + self.last_batch_obs = torch.zeros([self.evaluator_env_num, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device) + self.last_batch_action = [-1 for _ in range(self.evaluator_env_num)] + elif self._cfg.model.model_type == 'mlp': + self.last_batch_obs = torch.zeros([self.evaluator_env_num, self._cfg.model.observation_shape]).to(self._cfg.device) + self.last_batch_action = [-1 for _ in range(self.evaluator_env_num)] + + #@profile + def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, + ready_env_id: np.array = None, task_id: int = None) -> Dict: + """ + Overview: + The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search. + Choosing the action with the highest value (argmax) rather than sampling during the eval mode. + Arguments: + - data (:obj:`torch.Tensor`): The input data, i.e. the observation. + - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected. + - to_play (:obj:`int`): The player to play. + - ready_env_id (:obj:`list`): The id of the env that is ready to collect. + Shape: + - data (:obj:`torch.Tensor`): + - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ + S is the number of stacked frames, H is the height of the image, W is the width of the image. + - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size. + - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env. + - to_play: :math:`(N, 1)`, where N is the number of collect_env. + - ready_env_id: None + Returns: + - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \ + ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. 
+ """ + self._eval_model.eval() + active_eval_env_num = data.shape[0] + if ready_env_id is None: + ready_env_id = np.arange(active_eval_env_num) + output = {i: None for i in ready_env_id} + with torch.no_grad(): + network_output = self._eval_model.initial_inference(self.last_batch_obs_eval, self.last_batch_action, data, task_id=task_id) + latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) + + if not self._eval_model.training: + # if not in training, obtain the scalars of the value/reward + pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() # shape(B, 1) + latent_state_roots = latent_state_roots.detach().cpu().numpy() + policy_logits = policy_logits.detach().cpu().numpy().tolist() # list shape(B, A) + + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_eval_env_num)] + if self._cfg.mcts_ctree: + # cpp mcts_tree + roots = MCTSCtree.roots(active_eval_env_num, legal_actions) + else: + # python mcts_tree + roots = MCTSPtree.roots(active_eval_env_num, legal_actions) + roots.prepare_no_noise(reward_roots, policy_logits, to_play) + self._mcts_eval.search(roots, self._eval_model, latent_state_roots, to_play, task_id=task_id) + + # list of list, shape: ``{list: batch_size} -> {list: action_space_size}`` + roots_visit_count_distributions = roots.get_distributions() + roots_values = roots.get_values() # shape: {list: batch_size} + + batch_action = [] + + for i, env_id in enumerate(ready_env_id): + distributions, value = roots_visit_count_distributions[i], roots_values[i] + # print("roots_visit_count_distributions:", distributions, "root_value:", value) + + # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents + # the index within the legal action set, rather than the index in the entire action set. + # Setting deterministic=True implies choosing the action with the highest value (argmax) rather than + # sampling during the evaluation phase. + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=1, deterministic=True + ) + # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the + # entire action set. + action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + + output[env_id] = { + 'action': action, + 'visit_count_distributions': distributions, + 'visit_count_distribution_entropy': visit_count_distribution_entropy, + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], + } + batch_action.append(action) + + self.last_batch_obs_eval = data + self.last_batch_action = batch_action + + return output + + #@profile + def _reset_collect(self, env_id: int = None, current_steps: int = 0, reset_init_data: bool = True) -> None: + """ + Overview: + This method resets the collection process for a specific environment. It clears caches and memory + when certain conditions are met, ensuring optimal performance. If reset_init_data is True, the initial data + will be reset. + Arguments: + - env_id (:obj:`int`, optional): The ID of the environment to reset. If None or list, the function returns immediately. + - current_steps (:obj:`int`, optional): The current step count in the environment. Used to determine + whether to clear caches. + - reset_init_data (:obj:`bool`, optional): Whether to reset the initial data. If True, the initial data will be reset. 
+ """ + if reset_init_data: + self.last_batch_obs = initialize_zeros_batch( + self._cfg.model.observation_shape, + self._cfg.collector_env_num, + self._cfg.device + ) + self.last_batch_action = [-1 for _ in range(self._cfg.collector_env_num)] + # print('collector: last_batch_obs, last_batch_action reset()', self.last_batch_obs.shape) + + # Return immediately if env_id is None or a list + if env_id is None or isinstance(env_id, list): + return + + # Determine the clear interval based on the environment's sample type + clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else 200 + + # Clear caches if the current steps are a multiple of the clear interval + if current_steps % clear_interval == 0: + print(f'clear_interval: {clear_interval}') + + # Clear various caches in the collect model's world model + world_model = self._collect_model.world_model + world_model.past_kv_cache_init_infer.clear() + for kv_cache_dict_env in world_model.past_kv_cache_init_infer_envs: + kv_cache_dict_env.clear() + world_model.past_kv_cache_recurrent_infer.clear() + world_model.keys_values_wm_list.clear() + + # Free up GPU memory + torch.cuda.empty_cache() + + print('collector: collect_model clear()') + print(f'eps_steps_lst[{env_id}]: {current_steps}') + + # TODO: check its correctness ========= + self._reset_target_model() + + #@profile + def _reset_target_model(self) -> None: + """ + Overview: + This method resets the target model. It clears caches and memory, ensuring optimal performance. + Arguments: + - None + """ + + # Clear various caches in the target_model + world_model = self._target_model.world_model + world_model.past_kv_cache_init_infer.clear() + for kv_cache_dict_env in world_model.past_kv_cache_init_infer_envs: + kv_cache_dict_env.clear() + world_model.past_kv_cache_recurrent_infer.clear() + world_model.keys_values_wm_list.clear() + + # Free up GPU memory + torch.cuda.empty_cache() + print('collector: target_model past_kv_cache.clear()') + + #@profile + def _reset_eval(self, env_id: int = None, current_steps: int = 0, reset_init_data: bool = True, task_id: int = None) -> None: + """ + Overview: + This method resets the evaluation process for a specific environment. It clears caches and memory + when certain conditions are met, ensuring optimal performance. If reset_init_data is True, + the initial data will be reset. + Arguments: + - env_id (:obj:`int`, optional): The ID of the environment to reset. If None or list, the function returns immediately. + - current_steps (:obj:`int`, optional): The current step count in the environment. Used to determine + whether to clear caches. + - reset_init_data (:obj:`bool`, optional): Whether to reset the initial data. If True, the initial data will be reset. 
+ """ + if reset_init_data: + if task_id is not None: + self.last_batch_obs_eval = initialize_zeros_batch( + self._cfg.model.observation_shape_list[task_id], + self._cfg.evaluator_env_num, + self._cfg.device + ) + print('unizero_multitask.py task_id is not None after _reset_eval: last_batch_obs_eval:', self.last_batch_obs_eval.shape) + + else: + self.last_batch_obs_eval = initialize_zeros_batch( + self._cfg.model.observation_shape, + self._cfg.evaluator_env_num, + self._cfg.device + ) + print('unizero_multitask.py task_id is None after _reset_eval: last_batch_obs_eval:', self.last_batch_obs_eval.shape) + + self.last_batch_action = [-1 for _ in range(self._cfg.evaluator_env_num)] + + + # Return immediately if env_id is None or a list + if env_id is None or isinstance(env_id, list): + return + + # Determine the clear interval based on the environment's sample type + clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else 200 + + # Clear caches if the current steps are a multiple of the clear interval + if current_steps % clear_interval == 0: + print(f'clear_interval: {clear_interval}') + + # Clear various caches in the eval model's world model + world_model = self._eval_model.world_model + # world_model.past_kv_cache_init_infer.clear() + for kv_cache_dict_env in world_model.past_kv_cache_init_infer_envs: + kv_cache_dict_env.clear() + world_model.past_kv_cache_recurrent_infer.clear() + world_model.keys_values_wm_list.clear() + + # Free up GPU memory + torch.cuda.empty_cache() + + print('evaluator: eval_model clear()') + print(f'eps_steps_lst[{env_id}]: {current_steps}') + + + def recompute_pos_emb_diff_and_clear_cache(self) -> None: + """ + Overview: + Clear the caches and precompute positional embedding matrices in the model. + """ + # NOTE: Clear caches and precompute positional embedding matrices both for the collect and target models + for model in [self._collect_model, self._target_model]: + model.world_model.precompute_pos_emb_diff_kv() + model.world_model.clear_caches() + torch.cuda.empty_cache() + + def _state_dict_learn(self) -> Dict[str, Any]: + """ + Overview: + Return the state_dict of learn mode, usually including model, target_model and optimizer. + Returns: + - state_dict (:obj:`Dict[str, Any]`): The dict of current policy learn state, for saving and restoring. + """ + return { + 'model': self._learn_model.state_dict(), + 'target_model': self._target_model.state_dict(), + 'optimizer_world_model': self._optimizer_world_model.state_dict(), + } + + # ========== TODO: original version: load all parameters ========== + # def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: + # """ + # Overview: + # Load the state_dict variable into policy learn mode. + # Arguments: + # - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved before. + # """ + # self._learn_model.load_state_dict(state_dict['model']) + # self._target_model.load_state_dict(state_dict['target_model']) + # self._optimizer_world_model.load_state_dict(state_dict['optimizer_world_model']) + + # ========== TODO: pretrain-finetue version: only load encoder and transformer-backbone parameters ========== + def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: + """ + Overview: + Load the state_dict variable into policy learn mode, excluding multi-task related parameters. + Arguments: + - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved previously. 
+ """ + # 定义需要排除的参数前缀 + exclude_prefixes = [ + '_orig_mod.world_model.head_policy_multi_task.', + '_orig_mod.world_model.head_value_multi_task.', + '_orig_mod.world_model.head_rewards_multi_task.', + '_orig_mod.world_model.head_observations_multi_task.', + '_orig_mod.world_model.task_emb.' + ] + + # 定义需要排除的具体参数(如果有特殊情况) + exclude_keys = [ + '_orig_mod.world_model.task_emb.weight', + '_orig_mod.world_model.task_emb.bias', # 如果存在则添加 + # 添加其他需要排除的具体参数名 + ] + + def filter_state_dict(state_dict_loader: Dict[str, Any], exclude_prefixes: list, exclude_keys: list = []) -> Dict[str, Any]: + """ + 过滤掉需要排除的参数。 + """ + filtered = {} + for k, v in state_dict_loader.items(): + if any(k.startswith(prefix) for prefix in exclude_prefixes): + print(f"Excluding parameter: {k}") # 调试用,查看哪些参数被排除 + continue + if k in exclude_keys: + print(f"Excluding specific parameter: {k}") # 调试用 + continue + filtered[k] = v + return filtered + + # 过滤并加载 'model' 部分 + if 'model' in state_dict: + model_state_dict = state_dict['model'] + filtered_model_state_dict = filter_state_dict(model_state_dict, exclude_prefixes, exclude_keys) + missing_keys, unexpected_keys = self._learn_model.load_state_dict(filtered_model_state_dict, strict=False) + if missing_keys: + print(f"Missing keys when loading _learn_model: {missing_keys}") + if unexpected_keys: + print(f"Unexpected keys when loading _learn_model: {unexpected_keys}") + else: + print("No 'model' key found in the state_dict.") + + # 过滤并加载 'target_model' 部分 + if 'target_model' in state_dict: + target_model_state_dict = state_dict['target_model'] + filtered_target_model_state_dict = filter_state_dict(target_model_state_dict, exclude_prefixes, exclude_keys) + missing_keys, unexpected_keys = self._target_model.load_state_dict(filtered_target_model_state_dict, strict=False) + if missing_keys: + print(f"Missing keys when loading _target_model: {missing_keys}") + if unexpected_keys: + print(f"Unexpected keys when loading _target_model: {unexpected_keys}") + else: + print("No 'target_model' key found in the state_dict.") + + # 加载优化器的 state_dict,不需要过滤,因为优化器通常不包含模型参数 + if 'optimizer_world_model' in state_dict: + optimizer_state_dict = state_dict['optimizer_world_model'] + try: + self._optimizer_world_model.load_state_dict(optimizer_state_dict) + except Exception as e: + print(f"Error loading optimizer state_dict: {e}") + else: + print("No 'optimizer_world_model' key found in the state_dict.") + + # 如果需要,还可以加载其他部分,例如 scheduler 等 \ No newline at end of file diff --git a/lzero/policy/utils.py b/lzero/policy/utils.py index 198bd86ba..e002562c3 100644 --- a/lzero/policy/utils.py +++ b/lzero/policy/utils.py @@ -359,7 +359,7 @@ def prepare_obs_stack4_for_unizero(obs_batch_ori: np.ndarray, cfg: EasyDict) -> return obs_batch, obs_target_batch -def prepare_obs(obs_batch_ori: np.ndarray, cfg: EasyDict) -> Tuple[torch.Tensor, torch.Tensor]: +def prepare_obs(obs_batch_ori: np.ndarray, cfg: EasyDict, task_id = None) -> Tuple[torch.Tensor, torch.Tensor]: """ Overview: Prepare the observations for the model by converting the original batch of observations @@ -382,9 +382,12 @@ def prepare_obs(obs_batch_ori: np.ndarray, cfg: EasyDict) -> Tuple[torch.Tensor, # Calculate the dimension size to slice based on the model configuration. # For convolutional models ('conv'), use the number of frames to stack times the number of channels. # For multi-layer perceptron models ('mlp'), use the number of frames to stack times the size of the observation space. 
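# --- Illustrative sketch of the task-aware slicing introduced below (config values are hypothetical) ---
# With ``task_id`` given, the slice width comes from ``observation_shape_list[task_id]`` instead of the
# single-task ``observation_shape``. For an 'mlp' model, for example:
from easydict import EasyDict
cfg_sketch = EasyDict(dict(model=dict(model_type='mlp', frame_stack_num=1, image_channel=1,
                                      observation_shape=8, observation_shape_list=[8, 17])))
# Single-task call:              stack_dim = 1 * 8  -> obs_batch = obs_batch_ori[:, :8]
# Multi-task call, task_id = 1:  stack_dim = 1 * 17 -> obs_batch = obs_batch_ori[:, :17]
# so each task slices exactly its own observation width out of ``obs_batch_ori``.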
- stack_dim = cfg.model.frame_stack_num * ( + if task_id is None: + stack_dim = cfg.model.frame_stack_num * ( cfg.model.image_channel if cfg.model.model_type in ['conv', 'conv_context'] else cfg.model.observation_shape) - + else: + stack_dim = cfg.model.frame_stack_num * ( + cfg.model.image_channel if cfg.model.model_type in ['conv', 'conv_context'] else cfg.model.observation_shape_list[task_id]) # Slice the original observation tensor to obtain the batch for the initial inference. obs_batch = obs_batch_ori[:, :stack_dim] @@ -395,7 +398,10 @@ def prepare_obs(obs_batch_ori: np.ndarray, cfg: EasyDict) -> Tuple[torch.Tensor, # Determine the starting dimension to exclude based on the model type. # For 'conv', exclude the first 'image_channel' dimensions. # For 'mlp', exclude the first 'observation_shape' dimensions. - exclude_dim = cfg.model.image_channel if cfg.model.model_type in ['conv', 'conv_context'] else cfg.model.observation_shape + if task_id is None: + exclude_dim = cfg.model.image_channel if cfg.model.model_type in ['conv', 'conv_context'] else cfg.model.observation_shape + else: + exclude_dim = cfg.model.image_channel if cfg.model.model_type in ['conv', 'conv_context'] else cfg.model.observation_shape_list[task_id] # Slice the original observation tensor to obtain the batch for consistency loss calculation. obs_target_batch = obs_batch_ori[:, exclude_dim:] @@ -550,7 +556,11 @@ def concat_output_value(output_lst: List) -> np.ndarray: # concat the values of the model output list value_lst = [] for output in output_lst: - value_lst.append(output.value) + value_lst.append(output.value) # TODO:cpu + + # print(f'value_lst:{value_lst}') + # print(f'value_lst[0]:{value_lst[0]}') + # print(f'value_lst[0].shape:{value_lst[0].shape}') value_lst = np.concatenate(value_lst) diff --git a/lzero/worker/muzero_collector.py b/lzero/worker/muzero_collector.py index eff413df6..9b2f164e0 100644 --- a/lzero/worker/muzero_collector.py +++ b/lzero/worker/muzero_collector.py @@ -41,6 +41,7 @@ def __init__( exp_name: Optional[str] = 'default_experiment', instance_name: Optional[str] = 'collector', policy_config: 'policy_config' = None, # noqa + task_id: int = None, ) -> None: """ Overview: @@ -53,7 +54,9 @@ def __init__( - exp_name (:obj:`str`): Name of the experiment, used for logging and saving purposes. - instance_name (:obj:`str`): Unique identifier for this collector instance. - policy_config (:obj:`Optional[policy_config]`): Configuration object for the policy. + - task_id (:obj:`int`): Unique identifier for the task. If None, that means we are in the single task mode. 
""" + self.task_id = task_id self._exp_name = exp_name self._instance_name = instance_name self._collect_print_freq = collect_print_freq @@ -267,6 +270,7 @@ def pad_and_save_last_trajectory(self, i: int, last_game_segments: List[GameSegm end_index = beg_index + self.unroll_plus_td_steps - 1 pad_reward_lst = game_segments[i].reward_segment[beg_index:end_index] + if self.policy_config.use_ture_chance_label_in_chance_encoder: chance_lst = game_segments[i].chance_segment[beg_index:end_index] @@ -293,7 +297,7 @@ def pad_and_save_last_trajectory(self, i: int, last_game_segments: List[GameSegm game_segment element shape: obs: game_segment_length + stack + num_unroll_steps, 20+4 +5 rew: game_segment_length + stack + num_unroll_steps + td_steps -1 20 +5+3-1 - action: game_segment_length -> 20 + action: game_segment_length + num_unroll_steps + td_steps -> 20 +5+3 root_values: game_segment_length + num_unroll_steps + td_steps -> 20 +5+3 child_visits: game_segment_length + num_unroll_steps -> 20 +5 to_play: game_segment_length -> 20 @@ -434,8 +438,13 @@ def collect(self, # Key policy forward step # ============================================================== # print(f'ready_env_id:{ready_env_id}') - policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon, ready_env_id=ready_env_id) - + # policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon, ready_env_id=ready_env_id) + if self.task_id is None: + # single task setting + policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon, ready_env_id=ready_env_id) + else: + # multi-task setting + policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon, ready_env_id=ready_env_id, task_id=self.task_id) # Extract relevant policy outputs actions_with_env_id = {k: v['action'] for k, v in policy_output.items()} value_dict_with_env_id = {k: v['searched_value'] for k, v in policy_output.items()} @@ -554,9 +563,9 @@ def collect(self, completed_value_lst[env_id] += np.mean(np.array(completed_value_dict[env_id])) eps_steps_lst[env_id] += 1 - if self._policy.get_attribute('cfg').type in ['unizero', 'sampled_unizero']: - # only for UniZero now - self._policy.reset(env_id=env_id, current_steps=eps_steps_lst[env_id], reset_init_data=False) + if self._policy.get_attribute('cfg').type in ['unizero', 'sampled_unizero', 'unizero_multitask', 'sampled_unizero_multitask']: + # TODO: only for UniZero now + self._policy.reset(env_id=env_id, current_steps=eps_steps_lst[env_id], reset_init_data=False) # NOTE: reset_init_data=False total_transitions += 1 @@ -774,10 +783,16 @@ def _output_log(self, train_iter: int) -> None: for k, v in info.items(): if k in ['each_reward']: continue - self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter) + if self.task_id is None: + self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter) + else: + self._tb_logger.add_scalar('{}_iter_task{}/'.format(self._instance_name, self.task_id) + k, v, train_iter) if k in ['total_envstep_count']: continue - self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, self._total_envstep_count) + if self.task_id is None: + self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, self._total_envstep_count) + else: + self._tb_logger.add_scalar('{}_step_task{}/'.format(self._instance_name, self.task_id) + k, v, self._total_envstep_count) if self.policy_config.use_wandb: 
wandb.log({'{}_step/'.format(self._instance_name) + k: v for k, v in info.items()}, step=self._total_envstep_count) diff --git a/lzero/worker/muzero_evaluator.py b/lzero/worker/muzero_evaluator.py index f7cc39047..3908d0e4a 100644 --- a/lzero/worker/muzero_evaluator.py +++ b/lzero/worker/muzero_evaluator.py @@ -15,6 +15,7 @@ from lzero.mcts.buffer.game_segment import GameSegment from lzero.mcts.utils import prepare_observation +import threading class MuZeroEvaluator(ISerialEvaluator): @@ -56,6 +57,7 @@ def __init__( exp_name: Optional[str] = 'default_experiment', instance_name: Optional[str] = 'evaluator', policy_config: 'policy_config' = None, # noqa + task_id: int = None, ) -> None: """ Overview: @@ -70,7 +72,10 @@ def __init__( - exp_name (:obj:`str`): Name of the experiment, used to determine output directory. - instance_name (:obj:`str`): Name of this evaluator instance. - policy_config (:obj:`Optional[dict]`): Optional configuration for the game policy. + - task_id (:obj:`int`): Unique identifier for the task. If None, that means we are in the single task mode. """ + self.stop_event = threading.Event() # Add stop event to handle timeouts + self.task_id = task_id self._eval_freq = eval_freq self._exp_name = exp_name self._instance_name = instance_name @@ -88,7 +93,19 @@ def __init__( './{}/log/{}'.format(self._exp_name, self._instance_name), self._instance_name ) else: - self._logger, self._tb_logger = None, None # for close elegantly + # self._logger, self._tb_logger = None, None # for close elegantly + # ========== TODO: unizero_multitask ddp_v2 ======== + if tb_logger is not None: + self._logger, _ = build_logger( + './{}/log/{}'.format(self._exp_name, self._instance_name), self._instance_name, need_tb=False + ) + self._tb_logger = tb_logger + + + self._rank = get_rank() + + print(f'rank {self._rank}, self.task_id: {self.task_id}') + self.reset(policy, env) @@ -101,6 +118,9 @@ def __init__( # ============================================================== self.policy_config = policy_config + # def stop(self): + # self.stop_event.set() + def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None: """ Overview: @@ -129,7 +149,7 @@ def reset_policy(self, _policy: Optional[namedtuple] = None) -> None: assert hasattr(self, '_env'), "please set env first" if _policy is not None: self._policy = _policy - self._policy.reset() + self._policy.reset(task_id=self.task_id) def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvManager] = None) -> None: """ @@ -210,10 +230,20 @@ def eval( - stop_flag (:obj:`bool`): Indicates whether the training can be stopped based on the stop value. - episode_info (:obj:`Dict[str, Any]`): A dictionary containing information about the evaluation episodes. 
""" + if torch.cuda.is_available(): + print(f"=========in eval() Rank {get_rank()} ===========") + device = torch.cuda.current_device() + print(f"当前默认的 GPU 设备编号: {device}") + torch.cuda.set_device(get_rank()) + print(f"set device后的 GPU 设备编号: {get_rank()}") + # the evaluator only works on rank0 episode_info = None stop_flag = False - if get_rank() == 0: + # ======== TODO: unizero_multitask ddp_v2 ======== + # if get_rank() == 0: + if get_rank() >= 0: + if n_episode is None: n_episode = self._default_n_episode assert n_episode is not None, "please indicate eval n_episode" @@ -222,7 +252,7 @@ def eval( env_nums = self._env.env_num self._env.reset() - self._policy.reset() + self._policy.reset(task_id=self.task_id) # initializations init_obs = self._env.ready_obs @@ -250,7 +280,8 @@ def eval( GameSegment( self._env.action_space, game_segment_length=self.policy_config.game_segment_length, - config=self.policy_config + config=self.policy_config, + task_id=self.task_id ) for _ in range(env_nums) ] for i in range(env_nums): @@ -263,6 +294,12 @@ def eval( eps_steps_lst = np.zeros(env_nums) with self._timer: while not eval_monitor.is_finished(): + + # Check if stop_event is set (timeout occurred) + if self.stop_event.is_set(): + self._logger.info("[EVALUATOR]: Evaluation aborted due to timeout.") + break + # Get current ready env obs. obs = self._env.ready_obs new_available_env_id = set(obs.keys()).difference(ready_env_id) @@ -284,7 +321,13 @@ def eval( # ============================================================== # policy forward # ============================================================== - policy_output = self._policy.forward(stack_obs, action_mask, to_play, ready_env_id=ready_env_id) + # policy_output = self._policy.forward(stack_obs, action_mask, to_play, ready_env_id=ready_env_id) + if self.task_id is None: + # single task setting + policy_output = self._policy.forward(stack_obs, action_mask, to_play, ready_env_id=ready_env_id) + else: + # multi task setting + policy_output = self._policy.forward(stack_obs, action_mask, to_play, ready_env_id=ready_env_id, task_id=self.task_id) actions_with_env_id = {k: v['action'] for k, v in policy_output.items()} distributions_dict_with_env_id = {k: v['visit_count_distributions'] for k, v in policy_output.items()} @@ -328,7 +371,7 @@ def eval( eps_steps_lst[env_id] += 1 if self._policy.get_attribute('cfg').type in ['unizero', 'sampled_unizero']: # only for UniZero now - self._policy.reset(env_id=env_id, current_steps=eps_steps_lst[env_id], reset_init_data=False) + self._policy.reset(env_id=env_id, current_steps=eps_steps_lst[env_id], reset_init_data=False, task_id=self.task_id) game_segments[env_id].append( actions[env_id], to_ndarray(obs['observation']), reward, action_mask_dict[env_id], @@ -389,7 +432,8 @@ def eval( game_segments[env_id] = GameSegment( self._env.action_space, game_segment_length=self.policy_config.game_segment_length, - config=self.policy_config + config=self.policy_config, + task_id=self.task_id ) game_segments[env_id].reset( @@ -426,14 +470,23 @@ def eval( episode_info = eval_monitor.get_episode_info() if episode_info is not None: info.update(episode_info) + + print(f'rank {self._rank}, self.task_id: {self.task_id}') + self._logger.info(self._logger.get_tabulate_vars_hor(info)) for k, v in info.items(): if k in ['train_iter', 'ckpt_name', 'each_reward']: continue if not np.isscalar(v): continue - self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter) - 
self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, envstep) + if self.task_id is None: + self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter) + self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, envstep) + else: + self._tb_logger.add_scalar('{}_iter_task{}/'.format(self._instance_name, self.task_id) + k, v, + train_iter) + self._tb_logger.add_scalar('{}_step_task{}/'.format(self._instance_name, self.task_id) + k, v, + envstep) if self.policy_config.use_wandb: wandb.log({'{}_step/'.format(self._instance_name) + k: v}, step=envstep) @@ -451,12 +504,16 @@ def eval( ", so your MCTS/RL agent is converged, you can refer to 'log/evaluator/evaluator_logger.txt' for details." ) - if get_world_size() > 1: - objects = [stop_flag, episode_info] - broadcast_object_list(objects, src=0) - stop_flag, episode_info = objects + # ========== TODO: unizero_multitask ddp_v2 ======== + # if get_world_size() > 1: + # objects = [stop_flag, episode_info] + # print(f'rank {self._rank}, self.task_id: {self.task_id}') + # print('before broadcast_object_list') + # broadcast_object_list(objects, src=0) + # print('evaluator after broadcast_object_list') + # stop_flag, episode_info = objects episode_info = to_item(episode_info) if return_trajectory: episode_info['trajectory'] = game_segments - return stop_flag, episode_info + return stop_flag, episode_info \ No newline at end of file diff --git a/lzero/worker/muzero_segment_collector.py b/lzero/worker/muzero_segment_collector.py index 137d8ec89..ab33a85c7 100644 --- a/lzero/worker/muzero_segment_collector.py +++ b/lzero/worker/muzero_segment_collector.py @@ -46,19 +46,22 @@ def __init__( exp_name: Optional[str] = 'default_experiment', instance_name: Optional[str] = 'collector', policy_config: 'policy_config' = None, # noqa + task_id: int = None, ) -> None: """ Overview: - Initialize the MuZeroSegmentCollector with the given parameters. + Initialize the MuZeroCollector with the given parameters. Arguments: - collect_print_freq (:obj:`int`): Frequency (in training steps) at which to print collection information. - env (:obj:`Optional[BaseEnvManager]`): Instance of the subclass of vectorized environment manager. - - policy (:obj:`Optional[namedtuple]`): Namedtuple of the collection mode policy API. + - policy (:obj:`Optional[namedtuple]`): namedtuple of the collection mode policy API. - tb_logger (:obj:`Optional[SummaryWriter]`): TensorBoard logger instance. - exp_name (:obj:`str`): Name of the experiment, used for logging and saving purposes. - instance_name (:obj:`str`): Unique identifier for this collector instance. - policy_config (:obj:`Optional[policy_config]`): Configuration object for the policy. 
""" + self.task_id = task_id + self._exp_name = exp_name self._instance_name = instance_name self._collect_print_freq = collect_print_freq @@ -66,6 +69,10 @@ def __init__( self._end_flag = False self._rank = get_rank() + + print(f'rank {self._rank}, self.task_id: {self.task_id}') + + self._world_size = get_world_size() if self._rank == 0: if tb_logger is not None: @@ -83,7 +90,9 @@ def __init__( self._logger, _ = build_logger( path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name, need_tb=False ) - self._tb_logger = None + # =========== TODO: for unizero_multitask ddp_v2 ======== + self._tb_logger = tb_logger + self.policy_config = policy_config self.collect_with_pure_policy = self.policy_config.collect_with_pure_policy @@ -124,7 +133,7 @@ def reset_policy(self, _policy: Optional[namedtuple] = None) -> None: self._logger.debug( 'Set default num_segments mode(num_segments({}), env_num({}))'.format(self._default_num_segments, self._env_num) ) - self._policy.reset() + self._policy.reset(task_id=self.task_id) def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvManager] = None) -> None: """ @@ -384,7 +393,8 @@ def collect(self, GameSegment( self._env.action_space, game_segment_length=self.policy_config.game_segment_length, - config=self.policy_config + config=self.policy_config, + task_id=self.task_id ) for _ in range(env_nums) ] # stacked observation windows in reset stage for init game_segments @@ -442,6 +452,8 @@ def collect(self, # ready_env_id = set(obs.keys()) stack_obs = {env_id: game_segments[env_id].get_obs() for env_id in ready_env_id} + + stack_obs = list(stack_obs.values()) self.action_mask_dict_tmp = {env_id: self.action_mask_dict[env_id] for env_id in ready_env_id} @@ -460,8 +472,14 @@ def collect(self, # ============================================================== # Key policy forward step # ============================================================== - # logging.info(f'ready_env_id:{ready_env_id}') - policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon, ready_env_id=ready_env_id) + # print(f'ready_env_id:{ready_env_id}') + if self.task_id is None: + # single task setting + policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon, ready_env_id=ready_env_id) + else: + # multi task setting + policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon, ready_env_id=ready_env_id, task_id=self.task_id) + # Extract relevant policy outputs actions_with_env_id = {k: v['action'] for k, v in policy_output.items()} @@ -624,7 +642,8 @@ def collect(self, game_segments[env_id] = GameSegment( self._env.action_space, game_segment_length=self.policy_config.game_segment_length, - config=self.policy_config + config=self.policy_config, + task_id=self.task_id ) game_segments[env_id].reset(observation_window_stack[env_id]) @@ -686,7 +705,7 @@ def collect(self, # Env reset is done by env_manager automatically # NOTE: ============ reset the policy for the env_id. Default reset_init_data=True. 
================ - self._policy.reset([env_id]) + self._policy.reset([env_id], task_id=self.task_id) self._reset_stat(env_id) ready_env_id.remove(env_id) @@ -695,7 +714,8 @@ def collect(self, game_segments[env_id] = GameSegment( self._env.action_space, game_segment_length=self.policy_config.game_segment_length, - config=self.policy_config + config=self.policy_config, + task_id=self.task_id ) game_segments[env_id].reset(observation_window_stack[env_id]) @@ -716,11 +736,13 @@ def collect(self, break collected_duration = sum([d['time'] for d in self._episode_info]) + # TODO: for atari multitask new ddp pipeline # reduce data when enables DDP - if self._world_size > 1: - collected_step = allreduce_data(collected_step, 'sum') - collected_episode = allreduce_data(collected_episode, 'sum') - collected_duration = allreduce_data(collected_duration, 'sum') + # if self._world_size > 1: + # collected_step = allreduce_data(collected_step, 'sum') + # collected_episode = allreduce_data(collected_episode, 'sum') + # collected_duration = allreduce_data(collected_duration, 'sum') + self._total_envstep_count += collected_step self._total_episode_count += collected_episode self._total_duration += collected_duration @@ -736,8 +758,9 @@ def _output_log(self, train_iter: int) -> None: Arguments: - train_iter (:obj:`int`): Current training iteration number for logging context. """ - if self._rank != 0: - return + # TODO: for atari multitask new ddp pipeline + # if self._rank != 0: + # return if (train_iter - self._last_train_iter) >= self._collect_print_freq and len(self._episode_info) > 0: self._last_train_iter = train_iter episode_count = len(self._episode_info) @@ -770,11 +793,20 @@ def _output_log(self, train_iter: int) -> None: if self.policy_config.gumbel_algo: info['completed_value'] = np.mean(completed_value) self._episode_info.clear() + print(f'collector output_log: rank {self._rank}, self.task_id: {self.task_id}') self._logger.info("collect end:\n{}".format('\n'.join(['{}: {}'.format(k, v) for k, v in info.items()]))) for k, v in info.items(): if k in ['each_reward']: continue - self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter) + if self.task_id is None: + self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter) + else: + self._tb_logger.add_scalar('{}_iter_task{}/'.format(self._instance_name, self.task_id) + k, v, + train_iter) if k in ['total_envstep_count']: continue - self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, self._total_envstep_count) \ No newline at end of file + if self.task_id is None: + self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, self._total_envstep_count) + else: + self._tb_logger.add_scalar('{}_step_task{}/'.format(self._instance_name, self.task_id) + k, v, + self._total_envstep_count) diff --git a/requirements.txt b/requirements.txt index 831ae67c5..beec0c5ba 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,4 +7,5 @@ moviepy pytest line_profiler xxhash +simple_parsing einops diff --git a/zoo/atari/config/atari_muzero_multitask_segment_8games_config.py b/zoo/atari/config/atari_muzero_multitask_segment_8games_config.py new file mode 100644 index 000000000..ce486a050 --- /dev/null +++ b/zoo/atari/config/atari_muzero_multitask_segment_8games_config.py @@ -0,0 +1,260 @@ +from easydict import EasyDict + +def create_config( + env_id, + action_space_size, + collector_env_num, + evaluator_env_num, + n_episode, + num_simulations, + reanalyze_ratio, + batch_size, + 
num_unroll_steps, + infer_context_length, + norm_type, + buffer_reanalyze_freq, + reanalyze_batch_size, + reanalyze_partition, + num_segments +): + + return EasyDict(dict( + env=dict( + stop_value=int(5e5), # Adjusted max_env_step based on user TODO + env_id=env_id, + observation_shape=(4, 96, 96), + frame_stack_num=4, + gray_scale=True, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + full_action_space=True, + # ===== TODO: only for debug ===== + # collect_max_episode_steps=int(50), + # eval_max_episode_steps=int(50), + ), + policy=dict( + learn=dict( + learner=dict( + hook=dict(save_ckpt_after_iter=200000,), # Adjusted checkpoint frequency + ), + ), + grad_correct_params=dict( + # Placeholder for gradient correction parameters if needed + ), + task_num=len(env_id_list), + model=dict( + device='cuda', + num_res_blocks=2, # NOTE: encoder for 4 game + num_channels=256, + reward_head_channels= 16, + value_head_channels= 16, + policy_head_channels= 16, + fc_reward_layers= [32], + fc_value_layers= [32], + fc_policy_layers= [32], + observation_shape=(4, 96, 96), + frame_stack_num=4, + gray_scale=True, + action_space_size=action_space_size, + norm_type=norm_type, + model_type='conv', + image_channel=1, + downsample=True, + self_supervised_learning_loss=True, + discrete_action_encoding_type='one_hot', + use_sim_norm=True, + use_sim_norm_kl_loss=False, + task_num=len(env_id_list), + ), + cuda=True, + env_type='not_board_games', + # train_start_after_envsteps=2000, + train_start_after_envsteps=0, + game_segment_length=20, # Fixed segment length as per user config + random_collect_episode_num=0, + use_augmentation=True, + use_priority=False, + replay_ratio=0.25, + num_unroll_steps=num_unroll_steps, + # =========== TODO: debug =========== + # update_per_collect=2, # TODO: debug + update_per_collect=80, # Consistent with UniZero config + batch_size=batch_size, + optim_type='SGD', + td_steps=5, + lr_piecewise_constant_decay=True, + manual_temperature_decay=False, + learning_rate=0.2, + target_update_freq=100, + num_segments=num_segments, + num_simulations=num_simulations, + policy_entropy_weight=5e-3, #TODO + ssl_loss_weight=2, + eval_freq=int(5e3), + replay_buffer_size=int(5e5), # Adjusted as per UniZero config + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + # ============= The key different params for reanalyze ============= + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + ), + )) + +def generate_configs( + env_id_list, + action_space_size, + collector_env_num, + evaluator_env_num, + n_episode, + num_simulations, + reanalyze_ratio, + batch_size, + num_unroll_steps, + infer_context_length, + norm_type, + seed, + buffer_reanalyze_freq, + reanalyze_batch_size, + reanalyze_partition, + num_segments +): + configs = [] + exp_name_prefix = ( + f'data_muzero_mt_8games/{len(env_id_list)}games_brf{buffer_reanalyze_freq}/' + f'{len(env_id_list)}games_brf{buffer_reanalyze_freq}_1-encoder-{norm_type}-res2-channel256_gsl20_' + f'{len(env_id_list)}-pred-head_mbs-512_upc80_H{num_unroll_steps}_seed{seed}/' + ) + + for task_id, env_id in enumerate(env_id_list): + config = create_config( + env_id, + action_space_size, + # collector_env_num if env_id not in ['PongNoFrameskip-v4', 'BoxingNoFrameskip-v4'] else 2, # TODO: different collector_env_num for Pong and Boxing + # evaluator_env_num if 
env_id not in ['PongNoFrameskip-v4', 'BoxingNoFrameskip-v4'] else 2, + # n_episode if env_id not in ['PongNoFrameskip-v4', 'BoxingNoFrameskip-v4'] else 2, + collector_env_num, + evaluator_env_num, + n_episode, + num_simulations, + reanalyze_ratio, + batch_size, + num_unroll_steps, + infer_context_length, + norm_type, + buffer_reanalyze_freq, + reanalyze_batch_size, + reanalyze_partition, + num_segments + ) + config.policy.task_id = task_id + config.exp_name = f"{exp_name_prefix}{env_id.split('NoFrameskip')[0]}_muzero-mt_seed{seed}" + + configs.append([task_id, [config, create_env_manager()]]) + + return configs + +def create_env_manager(): + return EasyDict(dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + # env_manager=dict(type='base'), + policy=dict( + type='muzero_multitask', + import_names=['lzero.policy.muzero_multitask'], + ), + )) + +if __name__ == "__main__": + import sys + sys.path.insert(0, "/mnt/afs/niuyazhe/code/LightZero") + import lzero + print("lzero path:", lzero.__file__) + # import sys + # import os + # # 添加项目根目录到 PYTHONPATH + # sys.path.append(os.path.dirname(os.path.abspath(__file__))) + + from lzero.entry import train_muzero_multitask_segment_noddp + import argparse + + parser = argparse.ArgumentParser(description='Train MuZero Multitask on Atari') + parser.add_argument('--seed', type=int, default=0, help='Random seed') + args = parser.parse_args() + + # Define your list of environment IDs + env_id_list = [ + 'PongNoFrameskip-v4', + 'MsPacmanNoFrameskip-v4', + 'SeaquestNoFrameskip-v4', + 'BoxingNoFrameskip-v4', + 'AlienNoFrameskip-v4', + 'ChopperCommandNoFrameskip-v4', + 'HeroNoFrameskip-v4', + 'RoadRunnerNoFrameskip-v4', + ] + # env_id_list = [ + # 'PongNoFrameskip-v4', + # 'MsPacmanNoFrameskip-v4', + # ] + + action_space_size = 18 # Full action space, adjust if different per env + seed = args.seed + collector_env_num = 8 + evaluator_env_num = 3 + num_segments = 8 + n_episode = 8 + num_simulations = 50 + reanalyze_ratio = 0.0 + + max_batch_size = 512 + batch_size = [int(min(64, max_batch_size / len(env_id_list))) for _ in range(len(env_id_list))] + print(f'=========== batch_size: {batch_size} ===========') + + num_unroll_steps = 5 + infer_context_length = 4 + # norm_type = 'LN' + norm_type = 'BN' + + buffer_reanalyze_freq = 1 / 50 # Adjusted as per UniZero config + reanalyze_batch_size = 160 + reanalyze_partition = 0.75 + + num_segments = 8 + + # =========== TODO: debug =========== + # collector_env_num = 2 + # evaluator_env_num = 2 + # num_segments = 2 + # n_episode = 2 + # num_simulations = 5 + # batch_size = [int(min(2, max_batch_size / len(env_id_list))) for _ in range(len(env_id_list))] + + + # Generate configurations + configs = generate_configs( + env_id_list=env_id_list, + action_space_size=action_space_size, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_episode=n_episode, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + batch_size=batch_size, + num_unroll_steps=num_unroll_steps, + infer_context_length=infer_context_length, + norm_type=norm_type, + seed=seed, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + num_segments=num_segments + ) + + # Start training + train_muzero_multitask_segment_noddp(configs, seed=seed, max_env_step=int(5e5)) \ No newline at end of file diff --git 
a/zoo/atari/config/atari_muzero_multitask_segment_8games_ddp_config.py b/zoo/atari/config/atari_muzero_multitask_segment_8games_ddp_config.py new file mode 100644 index 000000000..c11790337 --- /dev/null +++ b/zoo/atari/config/atari_muzero_multitask_segment_8games_ddp_config.py @@ -0,0 +1,275 @@ +# zoo/atari/config/atari_muzero_multitask_segment_8games_config.py + +from easydict import EasyDict +from copy import deepcopy +from atari_env_action_space_map import atari_env_action_space_map + +def create_config( + env_id, + action_space_size, + collector_env_num, + evaluator_env_num, + n_episode, + num_simulations, + reanalyze_ratio, + batch_size, + num_unroll_steps, + infer_context_length, + norm_type, + buffer_reanalyze_freq, + reanalyze_batch_size, + reanalyze_partition, + num_segments +): + + return EasyDict(dict( + env=dict( + stop_value=int(5e5), # Adjusted max_env_step based on user TODO + env_id=env_id, + observation_shape=(4, 96, 96), + frame_stack_num=4, + gray_scale=True, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + full_action_space=True, + # ===== only for debug ===== + # collect_max_episode_steps=int(50), + # eval_max_episode_steps=int(50), + ), + policy=dict( + multi_gpu=True, # ======== Very important for ddp ============= + learn=dict( + learner=dict( + hook=dict(save_ckpt_after_iter=200000,), # Adjusted checkpoint frequency + ), + ), + grad_correct_params=dict( + # Placeholder for gradient correction parameters if needed + ), + task_num=len(env_id_list), + model=dict( + device='cuda', + num_res_blocks=2, # NOTE: encoder for 4 game + num_channels=256, + reward_head_channels= 16, + value_head_channels= 16, + policy_head_channels= 16, + fc_reward_layers= [32], + fc_value_layers= [32], + fc_policy_layers= [32], + observation_shape=(4, 96, 96), + frame_stack_num=4, + gray_scale=True, + action_space_size=action_space_size, + norm_type=norm_type, + model_type='conv', + image_channel=1, + downsample=True, + self_supervised_learning_loss=True, + discrete_action_encoding_type='one_hot', + use_sim_norm=True, + use_sim_norm_kl_loss=False, + task_num=len(env_id_list), + ), + allocated_batch_sizes=False, + # max_batch_size=max_batch_size, + max_batch_size=512,# TODO + cuda=True, + env_type='not_board_games', + # train_start_after_envsteps=2000, + train_start_after_envsteps=0, + game_segment_length=20, # Fixed segment length as per user config + random_collect_episode_num=0, + use_augmentation=True, + use_priority=False, + replay_ratio=0.25, + num_unroll_steps=num_unroll_steps, + # update_per_collect=2, # TODO: debug + update_per_collect=80, # Consistent with UniZero config + batch_size=batch_size, + optim_type='SGD', + td_steps=5, + lr_piecewise_constant_decay=True, + manual_temperature_decay=False, + learning_rate=0.2, + target_update_freq=100, + num_segments=num_segments, + num_simulations=num_simulations, + policy_entropy_weight=5e-3, #TODO + ssl_loss_weight=2, + eval_freq=int(5e3), + replay_buffer_size=int(5e5), # Adjusted as per UniZero config + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + # ============= The key different params for reanalyze ============= + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + ), + )) + +def generate_configs( + env_id_list, + action_space_size, + collector_env_num, + evaluator_env_num, + n_episode, + num_simulations, + 
reanalyze_ratio, + batch_size, + num_unroll_steps, + infer_context_length, + norm_type, + seed, + buffer_reanalyze_freq, + reanalyze_batch_size, + reanalyze_partition, + num_segments +): + configs = [] + exp_name_prefix = ( + f'data_muzero_mt_8games_ddp_8gpu_1129/{len(env_id_list)}games_brf{buffer_reanalyze_freq}/' + f'{len(env_id_list)}games_brf{buffer_reanalyze_freq}_1-encoder-{norm_type}-res2-channel256_gsl20_' + f'{len(env_id_list)}-pred-head_mbs-512_upc80_H{num_unroll_steps}_seed{seed}/' + ) + + for task_id, env_id in enumerate(env_id_list): + config = create_config( + env_id, + action_space_size, + # collector_env_num if env_id not in ['PongNoFrameskip-v4', 'BoxingNoFrameskip-v4'] else 2, # TODO: different collector_env_num for Pong and Boxing + # evaluator_env_num if env_id not in ['PongNoFrameskip-v4', 'BoxingNoFrameskip-v4'] else 2, + # n_episode if env_id not in ['PongNoFrameskip-v4', 'BoxingNoFrameskip-v4'] else 2, + collector_env_num, + evaluator_env_num, + n_episode, + num_simulations, + reanalyze_ratio, + batch_size, + num_unroll_steps, + infer_context_length, + norm_type, + buffer_reanalyze_freq, + reanalyze_batch_size, + reanalyze_partition, + num_segments + ) + config.policy.task_id = task_id + config.exp_name = f"{exp_name_prefix}{env_id.split('NoFrameskip')[0]}_muzero-mt_seed{seed}" + + configs.append([task_id, [config, create_env_manager()]]) + + return configs + +def create_env_manager(): + return EasyDict(dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + # env_manager=dict(type='base'), + policy=dict( + type='muzero_multitask', + import_names=['lzero.policy.muzero_multitask'], + ), + )) + +if __name__ == "__main__": + import sys + sys.path.insert(0, "/mnt/afs/niuyazhe/code/LightZero") + import lzero + print("lzero path:", lzero.__file__) + + # parser = argparse.ArgumentParser(description='Train MuZero Multitask on Atari') + # parser.add_argument('--seed', type=int, default=0, help='Random seed') + # args = parser.parse_args() + + # Define your list of environment IDs + env_id_list = [ + 'PongNoFrameskip-v4', + 'MsPacmanNoFrameskip-v4', + 'SeaquestNoFrameskip-v4', + 'BoxingNoFrameskip-v4', + 'AlienNoFrameskip-v4', + 'ChopperCommandNoFrameskip-v4', + 'HeroNoFrameskip-v4', + 'RoadRunnerNoFrameskip-v4', + ] + # env_id_list = [ + # 'PongNoFrameskip-v4', + # 'MsPacmanNoFrameskip-v4', + # ] + + action_space_size = 18 # Full action space, adjust if different per env + + # seed = args.seed + seed = 0 + + collector_env_num = 8 + evaluator_env_num = 3 + num_segments = 8 + n_episode = 8 + num_simulations = 50 + reanalyze_ratio = 0.0 + max_env_step = 5e5 + + max_batch_size = 512 + batch_size = [int(min(64, max_batch_size / len(env_id_list))) for _ in range(len(env_id_list))] + print(f'=========== batch_size: {batch_size} ===========') + + num_unroll_steps = 5 + infer_context_length = 4 + # norm_type = 'LN' + norm_type = 'BN' + + buffer_reanalyze_freq = 1 / 50 # Adjusted as per UniZero config + reanalyze_batch_size = 160 + reanalyze_partition = 0.75 + + num_segments = 8 + + # =========== TODO: debug =========== + # collector_env_num = 2 + # evaluator_env_num = 2 + # num_segments = 2 + # n_episode = 2 + # num_simulations = 5 + # batch_size = [int(min(2, max_batch_size / len(env_id_list))) for _ in range(len(env_id_list))] + + + # Generate configurations + configs = generate_configs( + env_id_list=env_id_list, + action_space_size=action_space_size, + collector_env_num=collector_env_num, + 
evaluator_env_num=evaluator_env_num, + n_episode=n_episode, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + batch_size=batch_size, + num_unroll_steps=num_unroll_steps, + infer_context_length=infer_context_length, + norm_type=norm_type, + seed=seed, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + num_segments=num_segments + ) + + """ + Overview: + This script should be executed with GPUs. + Run the following command to launch the script: + export NCCL_TIMEOUT=3600000 + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29501 ./zoo/atari/config/atari_muzero_multitask_segment_8games_ddp_config.py + 或者使用 torchrun: + torchrun --nproc_per_node=8 ./zoo/atari/config/atari_muzero_multitask_segment_8games_ddp_config.py + """ + from lzero.entry import train_muzero_multitask_segment_ddp + from ding.utils import DDPContext + with DDPContext(): + train_muzero_multitask_segment_ddp(configs, seed=seed, max_env_step=max_env_step) \ No newline at end of file diff --git a/zoo/atari/config/atari_rezero_mz_config.py b/zoo/atari/config/atari_rezero_mz_config.py index c7787831b..91517afd5 100644 --- a/zoo/atari/config/atari_rezero_mz_config.py +++ b/zoo/atari/config/atari_rezero_mz_config.py @@ -18,6 +18,17 @@ reuse_search = True collect_with_pure_policy = True buffer_reanalyze_freq = 1 + +# ====== only for debug ===== +# collector_env_num = 8 +# num_segments = 8 +# evaluator_env_num = 2 +# num_simulations = 5 +# max_env_step = int(2e5) +# reanalyze_ratio = 0.1 +# batch_size = 64 +# num_unroll_steps = 10 +# replay_ratio = 0.01 # ============================================================== # end of the most frequently changed config specified by the user # ============================================================== @@ -32,6 +43,9 @@ evaluator_env_num=evaluator_env_num, n_evaluator_episode=evaluator_env_num, manager=dict(shared_memory=False, ), + # # TODO: only for debug + # collect_max_episode_steps=int(20), + # eval_max_episode_steps=int(20), ), policy=dict( model=dict( diff --git a/zoo/atari/config/atari_unizero_multigpu_ddp_config.py b/zoo/atari/config/atari_unizero_multigpu_ddp_config.py index 82f64f141..26ecff41c 100644 --- a/zoo/atari/config/atari_unizero_multigpu_ddp_config.py +++ b/zoo/atari/config/atari_unizero_multigpu_ddp_config.py @@ -55,13 +55,20 @@ max_tokens=2 * num_unroll_steps, # NOTE: each timestep has 2 tokens: obs and action context_length=2 * infer_context_length, device='cuda', - # device='cpu', action_space_size=action_space_size, num_layers=2, num_heads=8, embed_dim=768, obs_type='image', env_num=max(collector_env_num, evaluator_env_num), + task_num=1, + use_normal_head=True, + use_softmoe_head=False, + use_moe_head=False, + num_experts_in_moe_head=4, + moe_in_transformer=False, + multiplication_moe_in_transformer=False, + num_experts_of_moe_in_transformer=4, ), ), # (str) The path of the pretrained model. If None, the model will be initialized by the default model. 
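The hunk above adds a set of head-selection switches to `world_model_cfg` (`use_normal_head`, `use_softmoe_head`, `use_moe_head`, `num_experts_in_moe_head`, `moe_in_transformer`, `multiplication_moe_in_transformer`, `num_experts_of_moe_in_transformer`) without further documentation. The sketch below only illustrates the kind of wiring such flags typically control, a plain linear head versus a small mixture-of-experts head; it is not the LightZero implementation, and every class, function, and argument name in it is an assumption (the soft-MoE and in-transformer-MoE variants are omitted).

import torch
import torch.nn as nn


class SimpleMoEHead(nn.Module):
    """Tiny mixture-of-experts head: a softmax gate mixes several linear experts."""

    def __init__(self, embed_dim: int, out_dim: int, num_experts: int = 4):
        super().__init__()
        self.experts = nn.ModuleList(nn.Linear(embed_dim, out_dim) for _ in range(num_experts))
        self.gate = nn.Linear(embed_dim, num_experts)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        weights = torch.softmax(self.gate(x), dim=-1)                   # (batch, num_experts)
        expert_out = torch.stack([e(x) for e in self.experts], dim=1)   # (batch, num_experts, out_dim)
        return (weights.unsqueeze(-1) * expert_out).sum(dim=1)          # (batch, out_dim)


def build_head(embed_dim: int, out_dim: int, use_normal_head: bool = True,
               use_moe_head: bool = False, num_experts_in_moe_head: int = 4) -> nn.Module:
    """Select the prediction head according to config flags (illustrative only)."""
    if use_moe_head:
        return SimpleMoEHead(embed_dim, out_dim, num_experts_in_moe_head)
    assert use_normal_head, 'exactly one head type should be enabled'
    return nn.Linear(embed_dim, out_dim)


if __name__ == '__main__':
    head = build_head(embed_dim=768, out_dim=18)    # matches embed_dim=768 and the 18-action Atari space
    print(head(torch.randn(2, 768)).shape)          # torch.Size([2, 18])

With `use_normal_head=True` and `use_moe_head=False`, as set in the hunk above, such a selector degenerates to a single linear head, so the single-task behavior of this config is unchanged.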
diff --git a/zoo/atari/config/atari_unizero_multitask_26games_serial_config.py b/zoo/atari/config/atari_unizero_multitask_26games_serial_config.py new file mode 100644 index 000000000..f78e9e33f --- /dev/null +++ b/zoo/atari/config/atari_unizero_multitask_26games_serial_config.py @@ -0,0 +1,158 @@ +from easydict import EasyDict +from copy import deepcopy + +def create_config(env_id, action_space_size, collector_env_num, evaluator_env_num, n_episode, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type): + return EasyDict(dict( + env=dict( + stop_value=int(1e6), + env_id=env_id, + observation_shape=(3, 64, 64), + gray_scale=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False), + full_action_space=True, + ), + policy=dict( + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=200000))), + grad_correct_params=dict( + MoCo_beta=0.5, + MoCo_beta_sigma=0.5, + MoCo_gamma=0.1, + MoCo_gamma_sigma=0.5, + MoCo_rho=0, + calpha=0.5, + rescale=1, + ), + task_num=len(env_id_list), + task_id=0, + model=dict( + observation_shape=(3, 64, 64), + action_space_size=action_space_size, + norm_type=norm_type, + world_model_cfg=dict( + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, + context_length=2 * infer_context_length, + device='cuda', + action_space_size=action_space_size, + num_layers=2, + num_heads=8, + embed_dim=768, + obs_type='image', + env_num=8, + task_num=len(env_id_list), + use_normal_head=True, + use_softmoe_head=False, + use_moe_head=False, + num_experts_in_moe_head=4, + moe_in_transformer=False, + multiplication_moe_in_transformer=False, + num_experts_of_moe_in_transformer=4, + ), + ), + use_priority=False, + print_task_priority_logs=False, + cuda=True, + model_path=None, + num_unroll_steps=num_unroll_steps, + update_per_collect=1000, + replay_ratio=0.25, + batch_size=batch_size, + optim_type='AdamW', + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + replay_buffer_size=int(1e6), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), + )) + +def generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, seed): + configs = [] + exp_name_prefix = f'data_unizero_mt/{len(env_id_list)}games_1-encoder-{norm_type}_26-head_lsd768-nlayer2-nh8_max-bs2000_upc1000_seed{seed}/' + + for task_id, env_id in enumerate(env_id_list): + config = create_config( + env_id, + action_space_size, + collector_env_num, + evaluator_env_num, + n_episode, + num_simulations, + reanalyze_ratio, + batch_size, + num_unroll_steps, + infer_context_length, + norm_type + ) + config.policy.task_id = task_id + config.exp_name = exp_name_prefix + f"{env_id.split('NoFrameskip')[0]}_unizero-mt_seed{seed}" + configs.append([task_id, [config, create_env_manager()]]) + return configs + +def create_env_manager(): + return EasyDict(dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='unizero_multitask', + import_names=['lzero.policy.unizero_multitask'], + ), + )) + +if __name__ == "__main__": + from lzero.entry import train_unizero_multitask_serial + + env_id_list = [ + 'PongNoFrameskip-v4', + 'MsPacmanNoFrameskip-v4', + 'SeaquestNoFrameskip-v4', + 
'BoxingNoFrameskip-v4', + 'AlienNoFrameskip-v4', + 'AmidarNoFrameskip-v4', + 'AssaultNoFrameskip-v4', + 'AsterixNoFrameskip-v4', + 'BankHeistNoFrameskip-v4', + 'BattleZoneNoFrameskip-v4', + 'ChopperCommandNoFrameskip-v4', + 'CrazyClimberNoFrameskip-v4', + 'DemonAttackNoFrameskip-v4', + 'FreewayNoFrameskip-v4', + 'FrostbiteNoFrameskip-v4', + 'GopherNoFrameskip-v4', + 'HeroNoFrameskip-v4', + 'JamesbondNoFrameskip-v4', + 'KangarooNoFrameskip-v4', + 'KrullNoFrameskip-v4', + 'KungFuMasterNoFrameskip-v4', + 'PrivateEyeNoFrameskip-v4', + 'RoadRunnerNoFrameskip-v4', + 'UpNDownNoFrameskip-v4', + 'QbertNoFrameskip-v4', + 'BreakoutNoFrameskip-v4', + ] + + action_space_size = 18 + seed = 0 + collector_env_num = 8 + n_episode = 8 + evaluator_env_num = 3 + num_simulations = 50 + max_env_step = int(1e6) + reanalyze_ratio = 0.0 + max_batch_size = 1000 + batch_size = [int(max_batch_size / len(env_id_list)) for _ in range(len(env_id_list))] + num_unroll_steps = 10 + infer_context_length = 4 + norm_type = 'LN' + + configs = generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, seed) + + # train_unizero_multitask_serial(configs[:4], seed=seed, max_env_step=max_env_step) # multitask learning on first four tasks + train_unizero_multitask_serial(configs, seed=seed, max_env_step=max_env_step) \ No newline at end of file diff --git a/zoo/atari/config/atari_unizero_multitask_segment_26games_ddp_config.py b/zoo/atari/config/atari_unizero_multitask_segment_26games_ddp_config.py new file mode 100644 index 000000000..20a85207e --- /dev/null +++ b/zoo/atari/config/atari_unizero_multitask_segment_26games_ddp_config.py @@ -0,0 +1,171 @@ +from easydict import EasyDict + +def create_config(env_id, action_space_size, collector_env_num, evaluator_env_num, n_episode, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments): + return EasyDict(dict( + env=dict( + stop_value=int(1e6), + env_id=env_id, + observation_shape=(3, 64, 64), + gray_scale=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False), + full_action_space=True, + collect_max_episode_steps=int(5e3), + eval_max_episode_steps=int(5e3), + ), + policy=dict( + multi_gpu=True, + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=50000))), + grad_correct_params=dict( + MoCo_beta=0.5, + MoCo_beta_sigma=0.5, + MoCo_gamma=0.1, + MoCo_gamma_sigma=0.5, + MoCo_rho=0, + calpha=0.5, + rescale=1, + ), + task_num=len(env_id_list), + task_id=0, + model=dict( + observation_shape=(3, 64, 64), + action_space_size=action_space_size, + norm_type=norm_type, + num_res_blocks=2, + num_channels=256, + world_model_cfg=dict( + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, + context_length=2 * infer_context_length, + device='cuda', + action_space_size=action_space_size, + num_layers=8, + num_heads=24, + embed_dim=768, + obs_type='image', + env_num=8, + task_num=len(env_id_list), + use_normal_head=True, + use_softmoe_head=False, + use_moe_head=False, + moe_in_transformer=False, + multiplication_moe_in_transformer=False, + num_experts_of_moe_in_transformer=4, + ), + ), + total_batch_size=total_batch_size, + allocated_batch_sizes=False, + train_start_after_envsteps=int(0), + use_priority=False, + 
print_task_priority_logs=False, + cuda=True, + model_path=None, + num_unroll_steps=num_unroll_steps, + game_segment_length=20, + update_per_collect=80, + replay_ratio=0.25, + batch_size=batch_size, + optim_type='AdamW', + num_segments=num_segments, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + replay_buffer_size=int(5e5), + eval_freq=int(1e4), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + ), + )) + +def generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments): + configs = [] + exp_name_prefix = f'data_unizero_mt_ddp-8gpu-26game_1201/{len(env_id_list)}games_brf{buffer_reanalyze_freq}_nlayer8-nhead24_seed{seed}/{len(env_id_list)}games_brf{buffer_reanalyze_freq}_1-encoder-{norm_type}-res2-channel256_gsl20_{len(env_id_list)}-pred-head_lsd768-nlayer8-nh24_mbs-512-bs64_upc80_seed{seed}/' + + for task_id, env_id in enumerate(env_id_list): + config = create_config( + env_id, + action_space_size, + collector_env_num, + evaluator_env_num, + n_episode, + num_simulations, + reanalyze_ratio, + batch_size, + num_unroll_steps, + infer_context_length, + norm_type, + buffer_reanalyze_freq, + reanalyze_batch_size, + reanalyze_partition, + num_segments + ) + config.policy.task_id = task_id + config.exp_name = exp_name_prefix + f"{env_id.split('NoFrameskip')[0]}_unizero-mt_seed{seed}" + configs.append([task_id, [config, create_env_manager()]]) + return configs + +def create_env_manager(): + return EasyDict(dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='unizero_multitask', + import_names=['lzero.policy.unizero_multitask'], + ), + )) + +if __name__ == "__main__": + """ + Overview: + This script should be executed with GPUs. 
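Before launching, it can help to sanity-check the per-task batch-size split that the `__main__` block below computes; the following is just that arithmetic written out for the defaults in this file (an illustrative calculation, not part of the launch procedure):

# With total_batch_size = 512 shared across the 26 games, each task trains with
# int(min(64, 512 / 26)) = 19 samples per step, i.e. about 26 * 19 = 494 samples per
# optimizer step in total; the min(64, ...) cap only takes effect with fewer than 8 tasks.
total_batch_size = 512
num_tasks = 26                     # len(env_id_list) for this config
per_task_batch = int(min(64, total_batch_size / num_tasks))
print(per_task_batch, num_tasks * per_task_batch)   # 19 494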
+ Run the following command to launch the script: + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29501 ./zoo/atari/config/atari_unizero_multitask_segment_26games_ddp_config.py + torchrun --nproc_per_node=8 ./zoo/atari/config/atari_unizero_multitask_segment_26games_ddp_config.py + """ + from lzero.entry import train_unizero_multitask_segment_ddp + from ding.utils import DDPContext + import os + + # List of Atari games used for multi-task learning + env_id_list = [ + 'PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4', 'BoxingNoFrameskip-v4', + 'AlienNoFrameskip-v4', 'ChopperCommandNoFrameskip-v4', 'HeroNoFrameskip-v4', 'RoadRunnerNoFrameskip-v4', + 'AmidarNoFrameskip-v4', 'AssaultNoFrameskip-v4', 'AsterixNoFrameskip-v4', 'BankHeistNoFrameskip-v4', + 'BattleZoneNoFrameskip-v4', 'CrazyClimberNoFrameskip-v4', 'DemonAttackNoFrameskip-v4', 'FreewayNoFrameskip-v4', + 'FrostbiteNoFrameskip-v4', 'GopherNoFrameskip-v4', 'JamesbondNoFrameskip-v4', 'KangarooNoFrameskip-v4', + 'KrullNoFrameskip-v4', 'KungFuMasterNoFrameskip-v4', 'PrivateEyeNoFrameskip-v4', 'UpNDownNoFrameskip-v4', + 'QbertNoFrameskip-v4', 'BreakoutNoFrameskip-v4', + ] + + # Hyperparameters + action_space_size = 18 + collector_env_num = 8 + evaluator_env_num = 3 + n_episode = 8 + num_segments = 8 + num_simulations = 50 + reanalyze_ratio = 0.0 + max_env_step = int(5e5) + total_batch_size = 512 + batch_size = [int(min(64, total_batch_size / len(env_id_list))) for _ in env_id_list] + num_unroll_steps = 10 + infer_context_length = 4 + norm_type = 'LN' + buffer_reanalyze_freq = 1 / 50 + reanalyze_batch_size = 160 + reanalyze_partition = 0.75 + + for seed in [0]: # Seed for reproducibility + configs = generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments) + + # Training with distributed data parallel (DDP) + with DDPContext(): + train_unizero_multitask_segment_ddp(configs, seed=seed, max_env_step=max_env_step) \ No newline at end of file diff --git a/zoo/atari/config/atari_unizero_multitask_segment_26games_serial_config.py b/zoo/atari/config/atari_unizero_multitask_segment_26games_serial_config.py new file mode 100644 index 000000000..0bcca387b --- /dev/null +++ b/zoo/atari/config/atari_unizero_multitask_segment_26games_serial_config.py @@ -0,0 +1,194 @@ +from easydict import EasyDict +from copy import deepcopy + +def create_config(env_id, action_space_size, collector_env_num, evaluator_env_num, n_episode, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments): + """ + Create the configuration for a specific environment. 
+ """ + return EasyDict(dict( + env=dict( + stop_value=int(1e6), + env_id=env_id, + observation_shape=(3, 64, 64), # Input observation dimensions + gray_scale=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False), + full_action_space=True, + # ===== TODO: only for debug ===== + # collect_max_episode_steps=int(50), + # eval_max_episode_steps=int(50), + ), + policy=dict( + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=200000))), + grad_correct_params=dict( + MoCo_beta=0.5, + MoCo_beta_sigma=0.5, + MoCo_gamma=0.1, + MoCo_gamma_sigma=0.5, + MoCo_rho=0, + calpha=0.5, + rescale=1, + ), + task_num=len(env_id_list), + task_id=0, + model=dict( + observation_shape=(3, 64, 64), + action_space_size=action_space_size, + norm_type=norm_type, + num_res_blocks=2, # Encoder configuration + num_channels=256, + world_model_cfg=dict( + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, + context_length=2 * infer_context_length, + device='cuda', + action_space_size=action_space_size, + num_layers=4, # Transformer layers + num_heads=8, + embed_dim=768, + obs_type='image', + env_num=8, + task_num=len(env_id_list), + use_normal_head=True, + use_moe_head=False, + moe_in_transformer=False, + num_experts_of_moe_in_transformer=4, + ), + ), + use_priority=False, + print_task_priority_logs=False, + cuda=True, + model_path=None, + num_unroll_steps=num_unroll_steps, + game_segment_length=20, + update_per_collect=80, # Update steps per collection + replay_ratio=0.25, + batch_size=batch_size, + optim_type='AdamW', + num_segments=num_segments, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + replay_buffer_size=int(5e5), + eval_freq=int(1e4), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + ), + )) + +def generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments): + """ + Generate configurations for all environments in `env_id_list`. + """ + configs = [] + exp_name_prefix = f'data_unizero_mt_segcollect_1107/{len(env_id_list)}games_brf{buffer_reanalyze_freq}/{len(env_id_list)}games_brf{buffer_reanalyze_freq}_1-encoder-{norm_type}-res2-channel256_gsl20_{len(env_id_list)}-pred-head_lsd768-nlayer4-nh8_maxbs-640_upc80_seed{seed}/' + + for task_id, env_id in enumerate(env_id_list): + config = create_config( + env_id, + action_space_size, + collector_env_num, + evaluator_env_num, + n_episode, + num_simulations, + reanalyze_ratio, + batch_size, + num_unroll_steps, + infer_context_length, + norm_type, + buffer_reanalyze_freq, + reanalyze_batch_size, + reanalyze_partition, + num_segments + ) + config.policy.task_id = task_id + config.exp_name = exp_name_prefix + f"{env_id.split('NoFrameskip')[0]}_unizero-mt_seed{seed}" + configs.append([task_id, [config, create_env_manager()]]) + return configs + +def create_env_manager(): + """ + Create the environment manager configuration. 
+ """ + return EasyDict(dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='unizero_multitask', + import_names=['lzero.policy.unizero_multitask'], + ), + )) + +if __name__ == "__main__": + from lzero.entry import train_unizero_multitask_segment_serial + + # Define environments + env_id_list = [ + 'PongNoFrameskip-v4', + 'MsPacmanNoFrameskip-v4', + 'SeaquestNoFrameskip-v4', + 'BoxingNoFrameskip-v4', + 'AlienNoFrameskip-v4', + 'ChopperCommandNoFrameskip-v4', + 'HeroNoFrameskip-v4', + 'RoadRunnerNoFrameskip-v4', + 'AmidarNoFrameskip-v4', + 'AssaultNoFrameskip-v4', + 'AsterixNoFrameskip-v4', + 'BankHeistNoFrameskip-v4', + 'BattleZoneNoFrameskip-v4', + 'CrazyClimberNoFrameskip-v4', + 'DemonAttackNoFrameskip-v4', + 'FreewayNoFrameskip-v4', + 'FrostbiteNoFrameskip-v4', + 'GopherNoFrameskip-v4', + 'JamesbondNoFrameskip-v4', + 'KangarooNoFrameskip-v4', + 'KrullNoFrameskip-v4', + 'KungFuMasterNoFrameskip-v4', + 'PrivateEyeNoFrameskip-v4', + 'UpNDownNoFrameskip-v4', + 'QbertNoFrameskip-v4', + 'BreakoutNoFrameskip-v4', + ] + + # Define hyperparameters + action_space_size = 18 + seed = 0 + collector_env_num = 8 + num_segments = 8 + n_episode = 8 + evaluator_env_num = 3 + num_simulations = 50 + max_env_step = int(1e6) + reanalyze_ratio = 0. + max_batch_size = 640 + batch_size = [int(min(64, max_batch_size / len(env_id_list))) for _ in range(len(env_id_list))] + num_unroll_steps = 10 + infer_context_length = 4 + norm_type = 'LN' + buffer_reanalyze_freq = 1 / 50 + reanalyze_batch_size = 160 + reanalyze_partition = 0.75 + + # ======== TODO: only for debug ======== + # collector_env_num = 2 + # num_segments = 2 + # n_episode = 2 + # evaluator_env_num = 2 + # num_simulations = 2 + # batch_size = [4, 4, 4, 4] + + # Generate configurations + configs = generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments) + + # Train using the generated configurations + train_unizero_multitask_segment_serial(configs, seed=seed, max_env_step=max_env_step) \ No newline at end of file diff --git a/zoo/atari/config/atari_unizero_multitask_segment_8games_ddp_config.py b/zoo/atari/config/atari_unizero_multitask_segment_8games_ddp_config.py new file mode 100644 index 000000000..cdddabc14 --- /dev/null +++ b/zoo/atari/config/atari_unizero_multitask_segment_8games_ddp_config.py @@ -0,0 +1,170 @@ +from easydict import EasyDict + +def create_config(env_id, action_space_size, collector_env_num, evaluator_env_num, n_episode, + num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, + norm_type, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments, + total_batch_size): + return EasyDict(dict( + env=dict( + stop_value=int(1e6), + env_id=env_id, + observation_shape=(3, 64, 64), + gray_scale=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False), + full_action_space=True, + collect_max_episode_steps=int(5e3), + eval_max_episode_steps=int(5e3), + # ===== only for debug ===== + # collect_max_episode_steps=int(30), + # eval_max_episode_steps=int(30), + ), + policy=dict( + multi_gpu=True, # Very important for ddp + 
learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=200000))), + grad_correct_params=dict( + MoCo_beta=0.5, MoCo_beta_sigma=0.5, MoCo_gamma=0.1, MoCo_gamma_sigma=0.5, MoCo_rho=0, + calpha=0.5, rescale=1, + ), + task_num=len(env_id_list), + task_id=0, + model=dict( + observation_shape=(3, 64, 64), + action_space_size=action_space_size, + norm_type=norm_type, + num_res_blocks=2, + num_channels=256, + world_model_cfg=dict( + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, + context_length=2 * infer_context_length, + device='cuda', + action_space_size=action_space_size, + # batch_size=64 8games训练时,每张卡大约占 12*3=36G cuda显存 + num_layers=12, + num_heads=24, + embed_dim=768, + obs_type='image', + env_num=8, + task_num=len(env_id_list), + use_normal_head=True, + use_softmoe_head=False, + moe_in_transformer=False, + num_experts_of_moe_in_transformer=4, + ), + ), + total_batch_size=total_batch_size, + allocated_batch_sizes=False, + train_start_after_envsteps=int(0), + use_priority=False, + print_task_priority_logs=False, + cuda=True, + model_path=None, + num_unroll_steps=num_unroll_steps, + game_segment_length=20, + update_per_collect=80, + replay_ratio=0.25, + batch_size=batch_size, + optim_type='AdamW', + num_segments=num_segments, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + replay_buffer_size=int(5e5), + eval_freq=int(2e4), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + ), + )) + +def generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, + num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, + norm_type, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, + num_segments, total_batch_size): + configs = [] + exp_name_prefix = f'data_unizero_mt_ddp-8gpu/{len(env_id_list)}games_brf{buffer_reanalyze_freq}_seed{seed}/' + + for task_id, env_id in enumerate(env_id_list): + config = create_config( + env_id, action_space_size, collector_env_num, evaluator_env_num, n_episode, num_simulations, + reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, + buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments, total_batch_size + ) + config.policy.task_id = task_id + config.exp_name = exp_name_prefix + f"{env_id.split('NoFrameskip')[0]}_seed{seed}" + configs.append([task_id, [config, create_env_manager()]]) + return configs + +def create_env_manager(): + return EasyDict(dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='unizero_multitask', + import_names=['lzero.policy.unizero_multitask'], + ), + )) + +if __name__ == "__main__": + """ + Overview: + This script should be executed with GPUs. 
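One more note on the model settings used here: the transformer context sizes follow the usual two-tokens-per-timestep bookkeeping (each timestep contributes an observation token and an action token, as noted in the single-task UniZero DDP config above), so they are fully determined by `num_unroll_steps` and `infer_context_length`. The lines below are only a worked illustration of that arithmetic for the values in this file:

# Each timestep contributes two tokens (observation + action), so the token limits are
# simply twice the step counts configured in this file.
num_unroll_steps = 10
infer_context_length = 4
max_blocks = num_unroll_steps               # 10 training timesteps per block
max_tokens = 2 * num_unroll_steps           # 20 tokens per block
context_length = 2 * infer_context_length   # 8 tokens of context kept at inference time
print(max_blocks, max_tokens, context_length)   # 10 20 8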
+ Run the following command to launch the script: + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29501 ./zoo/atari/config/atari_unizero_multitask_segment_8games_ddp_config.py + torchrun --nproc_per_node=8 ./zoo/atari/config/atari_unizero_multitask_segment_8games_ddp_config.py + """ + + from lzero.entry import train_unizero_multitask_segment_ddp + from ding.utils import DDPContext + import os + + os.environ["NCCL_TIMEOUT"] = "3600000000" + + env_id_list = [ + 'PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4', 'BoxingNoFrameskip-v4', + 'AlienNoFrameskip-v4', 'ChopperCommandNoFrameskip-v4', 'HeroNoFrameskip-v4', 'RoadRunnerNoFrameskip-v4', + ] + + action_space_size = 18 + collector_env_num = 8 + num_segments = 8 + n_episode = 8 + evaluator_env_num = 3 + num_simulations = 50 + max_env_step = int(5e5) + reanalyze_ratio = 0.0 + total_batch_size = 512 + batch_size = [int(min(64, total_batch_size / len(env_id_list))) for _ in range(len(env_id_list))] + num_unroll_steps = 10 + infer_context_length = 4 + norm_type = 'LN' + buffer_reanalyze_freq = 1 / 50 + reanalyze_batch_size = 160 + reanalyze_partition = 0.75 + + # ======== TODO: only for debug ======== + # collector_env_num = 2 + # num_segments = 2 + # n_episode = 2 + # evaluator_env_num = 2 + # num_simulations = 2 + # batch_size = [4, 4, 4, 4, 4, 4, 4, 4] + + + for seed in [0]: + configs = generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, + num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, + norm_type, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, + num_segments, total_batch_size) + + with DDPContext(): + train_unizero_multitask_segment_ddp(configs, seed=seed, max_env_step=max_env_step) + # train_unizero_multitask_segment_ddp(configs[:4], seed=seed, max_env_step=max_env_step) # train on the first four tasks diff --git a/zoo/atari/config/atari_unizero_multitask_segment_8games_serial_config.py b/zoo/atari/config/atari_unizero_multitask_segment_8games_serial_config.py new file mode 100644 index 000000000..790e06a37 --- /dev/null +++ b/zoo/atari/config/atari_unizero_multitask_segment_8games_serial_config.py @@ -0,0 +1,165 @@ +from easydict import EasyDict +from copy import deepcopy + +def create_config(env_id, action_space_size, collector_env_num, evaluator_env_num, n_episode, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments): + """ + Create a configuration for a specific environment. 
+ """ + return EasyDict(dict( + env=dict( + stop_value=int(1e6), + env_id=env_id, + observation_shape=(3, 64, 64), + gray_scale=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False), + full_action_space=True, + ), + policy=dict( + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=200000))), + grad_correct_params=dict( + MoCo_beta=0.5, + MoCo_beta_sigma=0.5, + MoCo_gamma=0.1, + MoCo_gamma_sigma=0.5, + MoCo_rho=0, + calpha=0.5, + rescale=1, + ), + task_num=len(env_id_list), + task_id=0, + model=dict( + observation_shape=(3, 64, 64), + action_space_size=action_space_size, + norm_type=norm_type, + num_res_blocks=2, + num_channels=256, + world_model_cfg=dict( + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, + context_length=2 * infer_context_length, + device='cuda', + action_space_size=action_space_size, + num_layers=4, + num_heads=8, + embed_dim=768, + obs_type='image', + env_num=8, + task_num=len(env_id_list), + use_normal_head=True, + use_softmoe_head=False, + use_moe_head=False, + num_experts_in_moe_head=4, + moe_in_transformer=False, + multiplication_moe_in_transformer=False, + num_experts_of_moe_in_transformer=4, + ), + ), + use_priority=False, + print_task_priority_logs=False, + cuda=True, + model_path=None, + num_unroll_steps=num_unroll_steps, + game_segment_length=20, + update_per_collect=160, + replay_ratio=0.25, + batch_size=batch_size, + optim_type='AdamW', + num_segments=num_segments, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + replay_buffer_size=int(5e5), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + ), + )) + +def generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments): + """ + Generate configurations for all tasks in the environment list. + """ + configs = [] + exp_name_prefix = f'data_unizero_mt_serial/{len(env_id_list)}games_brf{buffer_reanalyze_freq}/{len(env_id_list)}games_brf{buffer_reanalyze_freq}_1-encoder-{norm_type}-res2-channel256_gsl20_{len(env_id_list)}-pred-head_nlayer4-nh8-lsd768_mbs-320_upc160_seed{seed}/' + + for task_id, env_id in enumerate(env_id_list): + config = create_config( + env_id, + action_space_size, + collector_env_num, + evaluator_env_num, + n_episode, + num_simulations, + reanalyze_ratio, + batch_size, + num_unroll_steps, + infer_context_length, + norm_type, + buffer_reanalyze_freq, + reanalyze_batch_size, + reanalyze_partition, + num_segments + ) + config.policy.task_id = task_id + config.exp_name = exp_name_prefix + f"{env_id.split('NoFrameskip')[0]}_unizero-mt_seed{seed}" + configs.append([task_id, [config, create_env_manager()]]) + return configs + +def create_env_manager(): + """ + Create the environment manager configuration. 
+ """ + return EasyDict(dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='unizero_multitask', + import_names=['lzero.policy.unizero_multitask'], + ), + )) + +if __name__ == "__main__": + from lzero.entry import train_unizero_multitask_segment_serial + + # Define the environment list + env_id_list = [ + 'PongNoFrameskip-v4', + 'MsPacmanNoFrameskip-v4', + 'SeaquestNoFrameskip-v4', + 'BoxingNoFrameskip-v4', + 'AlienNoFrameskip-v4', + 'ChopperCommandNoFrameskip-v4', + 'HeroNoFrameskip-v4', + 'RoadRunnerNoFrameskip-v4', + ] + + # Define key parameters + action_space_size = 18 + seed = 0 + collector_env_num = 8 + num_segments = 8 + n_episode = 8 + evaluator_env_num = 3 + num_simulations = 50 + max_env_step = int(5e5) + reanalyze_ratio = 0. + max_batch_size = 320 + batch_size = [int(min(64, max_batch_size / len(env_id_list))) for _ in env_id_list] + num_unroll_steps = 10 + infer_context_length = 4 + norm_type = 'LN' + buffer_reanalyze_freq = 1 / 50 + reanalyze_batch_size = 160 + reanalyze_partition = 0.75 + + # Generate configurations and start training + configs = generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments) + train_unizero_multitask_segment_serial(configs, seed=seed, max_env_step=max_env_step) \ No newline at end of file diff --git a/zoo/atari/config/atari_unizero_multitask_segment_eval_config.py b/zoo/atari/config/atari_unizero_multitask_segment_eval_config.py new file mode 100644 index 000000000..6bb9972c9 --- /dev/null +++ b/zoo/atari/config/atari_unizero_multitask_segment_eval_config.py @@ -0,0 +1,166 @@ +from easydict import EasyDict + +def create_config(env_id, action_space_size, collector_env_num, evaluator_env_num, n_episode, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments, total_batch_size): + return EasyDict(dict( + env=dict( + stop_value=int(1e6), + env_id=env_id, + observation_shape=(3, 64, 64), + gray_scale=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False), + full_action_space=True, + collect_max_episode_steps=int(5e3), + eval_max_episode_steps=int(5e3), + ), + policy=dict( + multi_gpu=True, # Enable multi-GPU for DDP + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=50000))), + grad_correct_params=dict( + MoCo_beta=0.5, MoCo_beta_sigma=0.5, MoCo_gamma=0.1, MoCo_gamma_sigma=0.5, + MoCo_rho=0, calpha=0.5, rescale=1, + ), + task_num=len(env_id_list), + task_id=0, + model=dict( + observation_shape=(3, 64, 64), + action_space_size=action_space_size, + norm_type=norm_type, + num_res_blocks=2, + num_channels=256, + world_model_cfg=dict( + env_id_list=env_id_list, + analysis_tsne=True, + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, + context_length=2 * infer_context_length, + device='cuda', + action_space_size=action_space_size, + num_layers=8, # Transformer layers + num_heads=24, + embed_dim=768, + obs_type='image', + env_num=8, + task_num=len(env_id_list), + use_normal_head=True, + use_softmoe_head=False, + use_moe_head=False, + moe_in_transformer=False, + 
num_experts_of_moe_in_transformer=4, + ), + ), + total_batch_size=total_batch_size, + allocated_batch_sizes=False, + train_start_after_envsteps=int(0), + use_priority=False, + print_task_priority_logs=False, + cuda=True, + model_path=None, + num_unroll_steps=num_unroll_steps, + game_segment_length=20, + update_per_collect=80, + replay_ratio=0.25, + batch_size=batch_size, + optim_type='AdamW', + num_segments=num_segments, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + replay_buffer_size=int(5e5), + eval_freq=int(2e4), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + ), + )) + +def generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments, total_batch_size): + configs = [] + exp_name_prefix = f'data_unizero_mt_ddp-8gpu_eval-latent_state_tsne/{len(env_id_list)}games_brf{buffer_reanalyze_freq}_1-encoder-{norm_type}-res2-channel256_gsl20_{len(env_id_list)}-pred-head_nlayer8-nh24-lsd768_seed{seed}/' + + for task_id, env_id in enumerate(env_id_list): + config = create_config( + env_id, action_space_size, collector_env_num, evaluator_env_num, + n_episode, num_simulations, reanalyze_ratio, batch_size, + num_unroll_steps, infer_context_length, norm_type, + buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, + num_segments, total_batch_size + ) + config.policy.task_id = task_id + config.exp_name = exp_name_prefix + f"{env_id.split('NoFrameskip')[0]}_unizero-mt_seed{seed}" + configs.append([task_id, [config, create_env_manager()]]) + return configs + +def create_env_manager(): + return EasyDict(dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='unizero_multitask', + import_names=['lzero.policy.unizero_multitask'], + ), + )) + +if __name__ == "__main__": + """ + Overview: + This program is designed to obtain the t-SNE of the latent states in 8games multi-task learning. + + This script should be executed with GPUs. 
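The t-SNE analysis itself is carried out inside the entry/evaluation code (enabled by `analysis_tsne=True` in the `world_model_cfg` above); the snippet below is only a minimal, self-contained sketch of what such an analysis looks like, run on randomly generated stand-in latents rather than on real world-model states, so the data, shapes, and output path are all assumptions:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

rng = np.random.default_rng(0)
# Stand-in data: 8 tasks, 128 latent vectors of dimension 768 per task (dummy values).
latents = {task_id: rng.normal(loc=task_id, scale=1.0, size=(128, 768)) for task_id in range(8)}

features = np.concatenate(list(latents.values()), axis=0)
labels = np.concatenate([[task_id] * len(vecs) for task_id, vecs in latents.items()])

embedding = TSNE(n_components=2, perplexity=30, init='pca', random_state=0).fit_transform(features)
plt.scatter(embedding[:, 0], embedding[:, 1], c=labels, cmap='tab10', s=4)
plt.title('t-SNE of latent states, colored by task id')
plt.savefig('latent_state_tsne.png')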
+ Run the following command to launch the script: + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29501 ./zoo/atari/config/atari_unizero_multitask_segment_eval_config.py + torchrun --nproc_per_node=8 ./zoo/atari/config/atari_unizero_multitask_segment_eval_config.py + """ + + from lzero.entry import train_unizero_multitask_segment_eval + from ding.utils import DDPContext + + env_id_list = [ + 'PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4', + 'BoxingNoFrameskip-v4', 'AlienNoFrameskip-v4', 'ChopperCommandNoFrameskip-v4', + 'HeroNoFrameskip-v4', 'RoadRunnerNoFrameskip-v4', + ] + + action_space_size = 18 + + for seed in [0]: + collector_env_num = 2 + num_segments = 2 + n_episode = 2 + evaluator_env_num = 2 + num_simulations = 50 + max_env_step = int(4e5) + reanalyze_ratio = 0.0 + total_batch_size = int(4*len(env_id_list)) + batch_size = [4 for _ in range(len(env_id_list))] + num_unroll_steps = 10 + infer_context_length = 4 + norm_type = 'LN' + buffer_reanalyze_freq = 1/50 + reanalyze_batch_size = 160 + reanalyze_partition = 0.75 + + + configs = generate_configs( + env_id_list, action_space_size, collector_env_num, n_episode, + evaluator_env_num, num_simulations, reanalyze_ratio, batch_size, + num_unroll_steps, infer_context_length, norm_type, seed, + buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, + num_segments, total_batch_size + ) + + # Pretrained model paths + # 8games + pretrained_model_path = '/mnt/afs/niuyazhe/code/LightZero/data_unizero_mt_ddp-8gpu_1127/8games_brf0.02_nlayer8-nhead24_seed1/8games_brf0.02_1-encoder-LN-res2-channel256_gsl20_8-pred-head_lsd768-nlayer8-nh24_mbs-512-bs64_upc80_seed1/Pong_unizero-mt_seed1/ckpt/iteration_200000.pth.tar' + # 26games + # pretrained_model_path = '/mnt/afs/niuyazhe/code/LightZero/data_unizero_mt_ddp-8gpu-26game_1127/26games_brf0.02_nlayer8-nhead24_seed0/26games_brf0.02_1-encoder-LN-res2-channel256_gsl20_26-pred-head_lsd768-nlayer8-nh24_mbs-512-bs64_upc80_seed0/Pong_unizero-mt_seed0/ckpt/iteration_150000.pth.tar' + + with DDPContext(): + train_unizero_multitask_segment_eval(configs, seed=seed, model_path=pretrained_model_path, max_env_step=max_env_step) \ No newline at end of file diff --git a/zoo/atari/config/atari_unizero_multitask_segment_finetune_config.py b/zoo/atari/config/atari_unizero_multitask_segment_finetune_config.py new file mode 100644 index 000000000..aa11a8120 --- /dev/null +++ b/zoo/atari/config/atari_unizero_multitask_segment_finetune_config.py @@ -0,0 +1,169 @@ +from easydict import EasyDict + +def create_config(env_id, action_space_size, collector_env_num, evaluator_env_num, n_episode, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments, total_batch_size): + return EasyDict(dict( + env=dict( + stop_value=int(1e6), + env_id=env_id, + observation_shape=(3, 64, 64), + gray_scale=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False), + full_action_space=True, + collect_max_episode_steps=int(5e3), + eval_max_episode_steps=int(5e3), + ), + policy=dict( + multi_gpu=True, + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=50000))), + grad_correct_params=dict( # Gradient correction parameters + MoCo_beta=0.5, + MoCo_beta_sigma=0.5, + MoCo_gamma=0.1, + MoCo_gamma_sigma=0.5, + MoCo_rho=0, + calpha=0.5, + rescale=1, + ), + 
task_num=len(env_id_list), + task_id=0, + model=dict( + observation_shape=(3, 64, 64), + action_space_size=action_space_size, + norm_type=norm_type, + num_res_blocks=2, + num_channels=256, + world_model_cfg=dict( + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, + context_length=2 * infer_context_length, + device='cuda', + action_space_size=action_space_size, + num_layers=8, + num_heads=24, + embed_dim=768, + obs_type='image', + env_num=8, + task_num=len(env_id_list), + use_normal_head=True, + use_softmoe_head=False, + moe_in_transformer=False, + multiplication_moe_in_transformer=False, + num_experts_of_moe_in_transformer=4, + ), + ), + total_batch_size=total_batch_size, + allocated_batch_sizes=False, + train_start_after_envsteps=int(0), + use_priority=False, + print_task_priority_logs=False, + cuda=True, + model_path=None, + num_unroll_steps=num_unroll_steps, + game_segment_length=20, + update_per_collect=80, + replay_ratio=0.25, + batch_size=batch_size, + optim_type='AdamW', + num_segments=num_segments, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + replay_buffer_size=int(5e5), + eval_freq=int(2e4), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + ), + )) + +def generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments, total_batch_size): + configs = [] + exp_name_prefix = f'data_unizero_mt_ddp-2gpu_1201/finetune_pong/{len(env_id_list)}games_brf{buffer_reanalyze_freq}_1-encoder-{norm_type}-res2-channel256_gsl20_{len(env_id_list)}-pred-head_lsd768-nlayer8-nh24_mbs-512-bs64_upc80_seed{seed}/' + + for task_id, env_id in enumerate(env_id_list): + config = create_config( + env_id, + action_space_size, + collector_env_num, + evaluator_env_num, + n_episode, + num_simulations, + reanalyze_ratio, + batch_size, + num_unroll_steps, + infer_context_length, + norm_type, + buffer_reanalyze_freq, + reanalyze_batch_size, + reanalyze_partition, + num_segments, + total_batch_size + ) + config.policy.task_id = task_id + config.exp_name = exp_name_prefix + f"{env_id.split('NoFrameskip')[0]}_unizero-mt_seed{seed}" + configs.append([task_id, [config, create_env_manager()]]) + return configs + +def create_env_manager(): + return EasyDict(dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='unizero_multitask', + import_names=['lzero.policy.unizero_multitask'], + ), + )) + +if __name__ == "__main__": + """ + Overview: + This script should be executed with GPUs. 
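The fine-tuning run below only passes `model_path` (an 8-game pretrained checkpoint) to the entry function, which performs the actual restore internally. Purely as a rough illustration of the idea behind warm-starting from such a checkpoint, partial state-dict loading usually looks like the sketch below; the 'model' key, the stand-in `nn.Linear` network, and the helper name are assumptions, not LightZero internals:

import torch
import torch.nn as nn


def load_pretrained(model: nn.Module, checkpoint: dict) -> None:
    """Copy tensors whose names and shapes match into `model`; ignore everything else."""
    state_dict = checkpoint.get('model', checkpoint)   # assumption: weights stored under a 'model' key
    own = model.state_dict()
    filtered = {k: v for k, v in state_dict.items() if k in own and v.shape == own[k].shape}
    model.load_state_dict(filtered, strict=False)
    print(f'loaded {len(filtered)}/{len(own)} matching tensors from the pretrained checkpoint')


if __name__ == '__main__':
    pretrained = nn.Linear(768, 18)   # stand-in for a trained network
    target = nn.Linear(768, 18)       # fresh network to be fine-tuned
    # In practice the dict would come from torch.load(pretrained_model_path, map_location='cpu').
    load_pretrained(target, {'model': pretrained.state_dict()})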
+        Run the following command to launch the script:
+        python -m torch.distributed.launch --nproc_per_node=8 --master_port=29501 ./zoo/atari/config/atari_unizero_multitask_segment_finetune_config.py
+        torchrun --nproc_per_node=8 ./zoo/atari/config/atari_unizero_multitask_segment_finetune_config.py
+    """
+
+    from lzero.entry import train_unizero_multitask_segment_ddp
+    from ding.utils import DDPContext
+    from easydict import EasyDict
+
+    env_id_list = ['PongNoFrameskip-v4']  # Debug setup
+    action_space_size = 18
+
+    # NCCL environment setup
+    import os
+    os.environ["NCCL_TIMEOUT"] = "3600000000"
+
+    for seed in [0, 1, 2]:
+        collector_env_num = 8
+        num_segments = 8
+        n_episode = 8
+        evaluator_env_num = 3
+        num_simulations = 50
+        max_env_step = int(4e5)
+
+        reanalyze_ratio = 0.0
+        total_batch_size = 512
+        batch_size = [int(min(64, total_batch_size / len(env_id_list))) for _ in range(len(env_id_list))]
+
+        num_unroll_steps = 10
+        infer_context_length = 4
+        norm_type = 'LN'
+        buffer_reanalyze_freq = 1 / 50
+        reanalyze_batch_size = 160
+        reanalyze_partition = 0.75
+
+        configs = generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments, total_batch_size)
+
+        pretrained_model_path = '/mnt/afs/niuyazhe/code/LightZero/data_unizero_mt_ddp-8gpu_1127/8games_brf0.02_nlayer8-nhead24_seed1/8games_brf0.02_1-encoder-LN-res2-channel256_gsl20_8-pred-head_lsd768-nlayer8-nh24_mbs-512-bs64_upc80_seed1/Pong_unizero-mt_seed1/ckpt/iteration_200000.pth.tar'
+
+        with DDPContext():
+            train_unizero_multitask_segment_ddp(configs, seed=seed, model_path=pretrained_model_path, max_env_step=max_env_step)
\ No newline at end of file
diff --git a/zoo/atari/envs/atari_lightzero_env.py b/zoo/atari/envs/atari_lightzero_env.py
index 67a27c5f0..d7d0174c6 100644
--- a/zoo/atari/envs/atari_lightzero_env.py
+++ b/zoo/atari/envs/atari_lightzero_env.py
@@ -24,6 +24,8 @@ class AtariEnvLightZero(BaseEnv):
         _reward_space, obs, _eval_episode_return, has_reset, _seed, _dynamic_seed
     """
     config = dict(
+        # (bool) Whether to use the full action space of the environment. Default is False. If set to True, the action space size is 18 for Atari.
+        full_action_space=False,
         # (int) The number of environment instances used for data collection.
         collector_env_num=8,
         # (int) The number of environment instances used for evaluator.
@@ -156,6 +158,7 @@ def step(self, action: int) -> BaseEnvTimestep:
         observation = self.observe()
         if done:
             info['eval_episode_return'] = self._eval_episode_return
+            print(f'one episode of {self.cfg.env_id} done')
         return BaseEnvTimestep(observation, self.reward, done, info)

diff --git a/zoo/atari/envs/atari_wrappers.py b/zoo/atari/envs/atari_wrappers.py
index f38aa24d6..265ef31ac 100644
--- a/zoo/atari/envs/atari_wrappers.py
+++ b/zoo/atari/envs/atari_wrappers.py
@@ -93,9 +93,9 @@ def wrap_lightzero(config: EasyDict, episode_life: bool, clip_rewards: bool) ->
     - env (:obj:`gym.Env`): The wrapped Atari environment with the given configurations.
""" if config.render_mode_human: - env = gym.make(config.env_id, render_mode='human') + env = gym.make(config.env_id, render_mode='human', full_action_space=config.full_action_space) else: - env = gym.make(config.env_id, render_mode='rgb_array') + env = gym.make(config.env_id, render_mode='rgb_array', full_action_space=config.full_action_space) assert 'NoFrameskip' in env.spec.id if hasattr(config, 'save_replay') and config.save_replay \ and hasattr(config, 'replay_path') and config.replay_path is not None: diff --git a/zoo/box2d/box2d_suz_multitask.py b/zoo/box2d/box2d_suz_multitask.py new file mode 100644 index 000000000..cf87e189d --- /dev/null +++ b/zoo/box2d/box2d_suz_multitask.py @@ -0,0 +1,179 @@ +from easydict import EasyDict +from copy import deepcopy +import torch +def create_config(env_id, observation_shapes, action_space_sizes, collector_env_num, evaluator_env_num, n_episode, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type): + return EasyDict(dict( + env=dict( + stop_value=int(1e6), + env_id=env_id, + continuous=True, + manually_discretization=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=200000,),),), # default is 10000 + grad_correct_params=dict( + # for MoCo + MoCo_beta=0.5, + MoCo_beta_sigma=0.5, + MoCo_gamma=0.1, + MoCo_gamma_sigma=0.5, + MoCo_rho=0, + # for CAGrad + calpha=0.5, + rescale=1, + ), + task_num=len(env_id_list), + task_id=0, + model=dict( + observation_shapes=observation_shapes, + action_space_size=4, + action_space_sizes=action_space_sizes, + continuous_action_space=True, + num_of_sampled_actions=20, + model_type='mlp', + world_model_cfg=dict( + obs_type='vector', + num_unroll_steps=num_unroll_steps, + policy_entropy_loss_weight=1e-4, + continuous_action_space=True, + num_of_sampled_actions=20, + sigma_type='conditioned', + norm_type=norm_type, + bound_type=None, + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, + context_length=2 * infer_context_length, + device='cuda' if torch.cuda.is_available() else 'cpu', + action_space_size=action_space_sizes, + env_num=max(collector_env_num, evaluator_env_num), + task_num=len(env_id_list), + use_normal_head=True, + use_softmoe_head=False, + use_moe_head=False, + num_experts_in_moe_head=4, # NOTE + moe_in_transformer=False, # NOTE + multiplication_moe_in_transformer=False, # NOTE + num_experts_of_moe_in_transformer=4, + ), + ), + use_priority=True, + print_task_priority_logs=False, + cuda=True, + model_path=None, + num_unroll_steps=num_unroll_steps, + replay_ratio=0.25, + batch_size=batch_size, + optim_type='AdamW', + learning_rate=1e-4, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + eval_freq=int(2e3), + replay_buffer_size=int(1e6), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), + )) + +def generate_configs(env_id_list, observation_shapes, action_space_sizes, collector_env_num, n_episode, evaluator_env_num, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, seed): + configs = [] + exp_name_prefix = f'data_unizero_mt_box2d/{len(env_id_list)}games_cont_action_seed{seed}/' + + for task_id, (env_id, observation_shape, action_space_size) in enumerate(zip(env_id_list, observation_shapes, action_space_sizes)): + config = create_config( + 
+            env_id,
+            observation_shapes,  # TODO
+            action_space_sizes,
+            collector_env_num,
+            evaluator_env_num,
+            n_episode,
+            num_simulations,
+            reanalyze_ratio,
+            batch_size,
+            num_unroll_steps,
+            infer_context_length,
+            norm_type
+        )
+        config.policy.task_id = task_id
+        config.exp_name = exp_name_prefix + f"{env_id.split('-v')[0]}_unizero_mt_seed{seed}"
+
+        configs.append([task_id, [config, create_env_manager(env_name=env_id)]])
+    return configs
+
+def create_env_manager(env_name: str):
+    if env_name == 'LunarLanderContinuous-v2':
+        return EasyDict(dict(
+            env=dict(
+                type='lunarlander',
+                import_names=['zoo.box2d.lunarlander.envs.lunarlander_env'],
+            ),
+            env_manager=dict(type='subprocess'),
+            policy=dict(
+                type='sampled_unizero_multitask',
+                import_names=['lzero.policy.sampled_unizero_multitask'],
+            ),
+        ))
+    elif env_name == 'BipedalWalker-v3':
+        return EasyDict(dict(
+            env=dict(
+                type='bipedalwalker',
+                import_names=['zoo.box2d.bipedalwalker.envs.bipedalwalker_env'],
+            ),
+            env_manager=dict(type='subprocess'),
+            policy=dict(
+                type='sampled_unizero_multitask',
+                import_names=['lzero.policy.sampled_unizero_multitask'],
+            ),
+        ))
+
+if __name__ == "__main__":
+    from lzero.entry import train_unizero_multitask
+
+    env_id_list = [
+        'LunarLanderContinuous-v2',
+        'BipedalWalker-v3',
+    ]
+
+    observation_shapes = [
+        8,   # LunarLanderContinuous-v2
+        24,  # BipedalWalker-v3
+    ]
+
+    action_space_sizes = [
+        2,  # LunarLanderContinuous-v2
+        4,  # BipedalWalker-v3
+    ]
+
+    seed = 0
+    collector_env_num = 6
+    n_episode = 8
+    evaluator_env_num = 3
+    num_simulations = 50
+    max_env_step = int(1e6)
+    reanalyze_ratio = 0.
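The per-task batch sizes in these multi-task configs follow one convention: a total batch budget is divided evenly over the tasks, optionally with a per-task cap (compare the Atari/DMC DDP configs above and the `max_batch_size` split just below). A minimal sketch of that convention, with illustrative numbers only and a helper name that is not part of the repo:

    # Sketch only: split a total batch budget evenly across tasks, with an optional per-task cap.
    def split_batch_size(total_batch_size: int, num_tasks: int, per_task_cap: int = 64) -> list:
        per_task = int(min(per_task_cap, total_batch_size / num_tasks))
        return [per_task] * num_tasks

    # 512 total over 8 tasks -> eight entries of 64; 1000 over 2 tasks with cap 500 -> [500, 500].
    assert split_batch_size(512, 8) == [64] * 8
    assert split_batch_size(1000, 2, per_task_cap=500) == [500, 500]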
+ max_batch_size = 1000 + batch_size = [int(max_batch_size/len(env_id_list)) for i in range(len(env_id_list))] + num_unroll_steps = 10 + infer_context_length = 4 + norm_type = 'LN' + + configs = generate_configs(env_id_list, observation_shapes, action_space_sizes, collector_env_num, n_episode, evaluator_env_num, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, seed) + + train_unizero_multitask(configs, seed=seed, max_env_step=max_env_step) \ No newline at end of file diff --git a/zoo/dmc2gym/config/dmc2gym_pixels_sampled_unizero_config.py b/zoo/dmc2gym/config/dmc2gym_pixels_sampled_unizero_config.py new file mode 100644 index 000000000..4f5ca5bda --- /dev/null +++ b/zoo/dmc2gym/config/dmc2gym_pixels_sampled_unizero_config.py @@ -0,0 +1,132 @@ +from easydict import EasyDict +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== + +from zoo.dmc2gym.config.dmc_state_env_space_map import dmc_state_env_action_space_map, dmc_state_env_obs_space_map + +env_id = 'cartpole-swingup' # You can specify any DMC task here +action_space_size = dmc_state_env_action_space_map[env_id] +obs_space_size = dmc_state_env_obs_space_map[env_id] +print(f'env_id: {env_id}, action_space_size: {action_space_size}, obs_space_size: {obs_space_size}') + +domain_name = env_id.split('-')[0] +task_name = env_id.split('-')[1] + +continuous_action_space = True +K = 20 # num_of_sampled_actions +collector_env_num = 8 +n_episode = 8 +evaluator_env_num = 3 +num_simulations = 50 +update_per_collect = None +replay_ratio = 0.25 +max_env_step = int(1e6) +reanalyze_ratio = 0 +batch_size = 64 +num_unroll_steps = 10 +infer_context_length = 4 +norm_type = 'LN' +seed = 0 + +# for debug +# collector_env_num = 2 +# n_episode = 2 +# evaluator_env_num = 1 +# num_simulations = 2 +# batch_size = 2 +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +dmc2gym_pixels_cont_sampled_unizero_config = dict( + exp_name=f'data_sampled_unizero_0901/dmc2gym_{env_id}_image_cont_sampled_unizero_ns{num_simulations}_upc{update_per_collect}-rr{replay_ratio}_rer{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_{norm_type}_seed{seed}', + env=dict( + env_id='dmc2gym-v0', + continuous=True, + domain_name=domain_name, + task_name=task_name, + from_pixels=True, # pixel/image obs + frame_skip=2, + warp_frame=True, + scale=True, + frame_stack_num=1, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + model=dict( + observation_shape=(3, 84, 84), + action_space_size=action_space_size, + continuous_action_space=continuous_action_space, + num_of_sampled_actions=K, + world_model_cfg=dict( + obs_type='image', + num_unroll_steps=num_unroll_steps, + policy_entropy_loss_weight=5e-3, + continuous_action_space=continuous_action_space, + num_of_sampled_actions=K, + sigma_type='conditioned', + fixed_sigma_value=0.3, + bound_type=None, + model_type='conv', + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, # NOTE: each timestep has 2 tokens: obs and action + context_length=2 * infer_context_length, + # device='cpu', + device='cuda', + action_space_size=action_space_size, + num_layers=2, + num_heads=8, + 
embed_dim=768, + env_num=max(collector_env_num, evaluator_env_num), + ), + ), + # (str) The path of the pretrained model. If None, the model will be initialized by the default model. + model_path=None, + num_unroll_steps=num_unroll_steps, + cuda=True, + use_augmentation=False, + env_type='not_board_games', + game_segment_length=100, + replay_ratio=replay_ratio, + batch_size=batch_size, + optim_type='AdamW', + lr_piecewise_constant_decay=False, + learning_rate=0.0001, + target_update_freq=100, + grad_clip_value=5, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + eval_freq=int(2e3), + replay_buffer_size=int(1e6), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), +) + +dmc2gym_pixels_cont_sampled_unizero_config = EasyDict(dmc2gym_pixels_cont_sampled_unizero_config) +main_config = dmc2gym_pixels_cont_sampled_unizero_config + +dmc2gym_pixels_cont_sampled_unizero_create_config = dict( + env=dict( + type='dmc2gym_lightzero', + import_names=['zoo.dmc2gym.envs.dmc2gym_lightzero_env'], + ), + # env_manager=dict(type='subprocess'), + env_manager=dict(type='base'), + policy=dict( + type='sampled_unizero', + import_names=['lzero.policy.sampled_unizero'], + ), +) +dmc2gym_pixels_cont_sampled_unizero_create_config = EasyDict(dmc2gym_pixels_cont_sampled_unizero_create_config) +create_config = dmc2gym_pixels_cont_sampled_unizero_create_config + +if __name__ == "__main__": + from lzero.entry import train_unizero + + train_unizero([main_config, create_config], seed=seed, max_env_step=max_env_step) diff --git a/zoo/dmc2gym/config/dmc2gym_state_suz_segment_config.py b/zoo/dmc2gym/config/dmc2gym_state_suz_config.py similarity index 100% rename from zoo/dmc2gym/config/dmc2gym_state_suz_segment_config.py rename to zoo/dmc2gym/config/dmc2gym_state_suz_config.py diff --git a/zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_8games_config.py b/zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_8games_config.py new file mode 100644 index 000000000..a9d951c9b --- /dev/null +++ b/zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_8games_config.py @@ -0,0 +1,351 @@ +from easydict import EasyDict +from typing import List + +def create_config(env_id, observation_shape_list, action_space_size_list, collector_env_num, evaluator_env_num, n_episode, + num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, + norm_type, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments, + total_batch_size): + domain_name = env_id.split('-')[0] + task_name = env_id.split('-')[1] + return EasyDict(dict( + env=dict( + stop_value=int(5e5), + env_id=env_id, + domain_name=domain_name, + task_name=task_name, + observation_shape_list=observation_shape_list, + action_space_size_list=action_space_size_list, + from_pixels=False, + # ===== TODO: only for debug ===== + # frame_skip=10, # 10 + frame_skip=2, + continuous=True, # Assuming all DMC tasks use continuous action spaces + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False), + game_segment_length=100, # As per single-task config + # ===== only for debug ===== + # collect_max_episode_steps=int(20), + # eval_max_episode_steps=int(20), + ), + policy=dict( + multi_gpu=True, # TODO: enable multi-GPU for DDP + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=1000000))), + grad_correct_params=dict( + # Example gradient correction parameters, adjust as needed + 
MoCo_beta=0.5, MoCo_beta_sigma=0.5, MoCo_gamma=0.1, MoCo_gamma_sigma=0.5, MoCo_rho=0, + calpha=0.5, rescale=1, + ), + # use_moco=True, # ==============TODO============== + use_moco=False, # ==============TODO============== + task_num=len(env_id_list), + task_id=0, # To be set per task + model=dict( + observation_shape_list=observation_shape_list, + action_space_size_list=action_space_size_list, + continuous_action_space=True, + num_of_sampled_actions=20, + model_type='mlp', + world_model_cfg=dict( + observation_shape_list=observation_shape_list, + action_space_size_list=action_space_size_list, + policy_loss_type='kl', + obs_type='vector', + # use_shared_projection=True, # TODO + use_shared_projection=False, + task_embed_option='concat_task_embed', # ==============TODO============== + use_task_embed=True, # TODO + # use_task_embed=False, # ==============TODO============== + num_unroll_steps=num_unroll_steps, + policy_entropy_weight=5e-2, + continuous_action_space=True, + num_of_sampled_actions=20, + sigma_type='conditioned', + fixed_sigma_value=0.5, + bound_type=None, + model_type='mlp', + norm_type=norm_type, + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, # Each timestep has 2 tokens: obs and action + context_length=2 * infer_context_length, + device='cuda', + # device='cpu', # TODO + # num_layers=2, + # num_layers=4, # TODO + + num_layers=8, # TODO + num_heads=8, + + # num_layers=12, # TODO + # num_heads=12, + + embed_dim=768, + env_num=max(collector_env_num, evaluator_env_num), + task_num=len(env_id_list), + use_normal_head=True, + use_softmoe_head=False, + use_moe_head=False, + num_experts_in_moe_head=4, + moe_in_transformer=False, + multiplication_moe_in_transformer=False, + num_experts_of_moe_in_transformer=4, + ), + ), + # use_task_exploitation_weight=True, # TODO + use_task_exploitation_weight=False, # TODO + # task_complexity_weight=True, # TODO + task_complexity_weight=False, # TODO + total_batch_size=total_batch_size, + allocated_batch_sizes=False, + # train_start_after_envsteps=int(2e3), + train_start_after_envsteps=int(0), + use_priority=False, + print_task_priority_logs=False, + cuda=True, + model_path=None, + num_unroll_steps=num_unroll_steps, + # update_per_collect=2, # TODO: 80 + update_per_collect=200, # TODO: 8*100*0.25=200 + # update_per_collect=80, # TODO: 8*100*0.1=80 + replay_ratio=reanalyze_ratio, + batch_size=batch_size, + optim_type='AdamW', + num_segments=num_segments, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + replay_buffer_size=int(5e5), + # eval_freq=int(5e3), + eval_freq=int(4e3), + # eval_freq=int(2e3), + # eval_freq=int(1e3), # TODO: task_weight=========== + grad_clip_value=5, + learning_rate=1e-4, + discount_factor=0.99, + td_steps=5, + piecewise_decay_lr_scheduler=False, + manual_temperature_decay=True, + threshold_training_steps_for_final_temperature=int(2.5e4), + cos_lr_scheduler=True, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + ), + )) + + +def generate_configs(env_id_list: List[str], + collector_env_num: int, + n_episode: int, + evaluator_env_num: int, + num_simulations: int, + reanalyze_ratio: float, + batch_size: List[int], + num_unroll_steps: int, + infer_context_length: int, + norm_type: str, + seed: int, + buffer_reanalyze_freq: float, + reanalyze_batch_size: int, + reanalyze_partition: float, + num_segments: 
int, + total_batch_size: int): + configs = [] + # TODO: debug + exp_name_prefix = f'data_suz_mt_20250113/ddp_7gpu_nlayer8_upc200_no-taskweight-obsloss-temp1_concat-task-embed_{len(env_id_list)}tasks_brf{buffer_reanalyze_freq}_tbs{total_batch_size}_seed{seed}/' + + # exp_name_prefix = f'data_suz_mt_20250113/ddp_8gpu_nlayer8_upc200_taskweight-eval1e3-10k-temp10-1_task-embed_{len(env_id_list)}tasks_brf{buffer_reanalyze_freq}_tbs{total_batch_size}_seed{seed}/' + # exp_name_prefix = f'data_suz_mt_20250113_debug/ddp_8gpu_nlayer8_upc200_taskweight-eval1e3-10k-temp10-1_no-task-embed_{len(env_id_list)}tasks_brf{buffer_reanalyze_freq}_tbs{total_batch_size}_seed{seed}/' + + # exp_name_prefix = f'data_suz_mt_20250113/ddp_3gpu_3games_nlayer8_upc200_notusp_notaskweight-symlog-01-05-eval1e3_{len(env_id_list)}tasks_brf{buffer_reanalyze_freq}_tbs{total_batch_size}_seed{seed}/' + + action_space_size_list = [dmc_state_env_action_space_map[env_id] for env_id in env_id_list] + observation_shape_list = [dmc_state_env_obs_space_map[env_id] for env_id in env_id_list] + + for task_id, (env_id, obs_shape, act_space) in enumerate(zip(env_id_list, observation_shape_list, action_space_size_list)): + config = create_config( + env_id=env_id, + action_space_size_list=action_space_size_list, + observation_shape_list=observation_shape_list, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_episode=n_episode, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + batch_size=batch_size, + num_unroll_steps=num_unroll_steps, + infer_context_length=infer_context_length, + norm_type=norm_type, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + num_segments=num_segments, + total_batch_size=total_batch_size, + ) + config.policy.task_id = task_id + config.exp_name = exp_name_prefix + f"{env_id}_seed{seed}" + configs.append([task_id, [config, create_env_manager()]]) + return configs + + +def create_env_manager(): + return EasyDict(dict( + env=dict( + type='dmc2gym_lightzero', + import_names=['zoo.dmc2gym.envs.dmc2gym_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='sampled_unizero_multitask', + import_names=['lzero.policy.sampled_unizero_multitask'], + ), + )) + + +if __name__ == "__main__": + """ + Overview: + This script should be executed with GPUs. 
+ Run the following command to launch the script: + python -m torch.distributed.launch --nproc_per_node=7 --master_port=29503 ./zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_8games_config.py + + python -m torch.distributed.launch --nproc_per_node=1 --master_port=29503 ./zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_8games_config.py + torchrun --nproc_per_node=8 ./zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_config.py + """ + + from lzero.entry import train_unizero_multitask_segment_ddp + from ding.utils import DDPContext + import os + from zoo.dmc2gym.config.dmc_state_env_space_map import dmc_state_env_action_space_map, dmc_state_env_obs_space_map + + # os.environ["NCCL_TIMEOUT"] = "3600000000" + + # 定义环境列表 + env_id_list = [ + # 'acrobot-swingup', # 6 1 + 'cartpole-swingup', # 5 1 + ] + + + + # env_id_list = [ + # # 'acrobot-swingup', + # # 'cartpole-balance', + # # 'cartpole-balance_sparse', + # # 'cartpole-swingup', + # 'cartpole-swingup_sparse', + # 'cheetah-run', + # # "ball_in_cup-catch", + # "finger-spin", + # ] + + # env_id_list = [ + # # 'acrobot-swingup', + # # 'cartpole-balance', + # # 'cartpole-balance_sparse', + # # 'cartpole-swingup', + # # 'cartpole-swingup_sparse', + # # 'cheetah-run', + # # "ball_in_cup-catch", + # "finger-spin", + # ] + + # DMC 8games + env_id_list = [ + 'acrobot-swingup', + 'cartpole-balance', + 'cartpole-balance_sparse', + 'cartpole-swingup', + 'cartpole-swingup_sparse', + 'cheetah-run', + "ball_in_cup-catch", + "finger-spin", + ] + + # DMC 18games + # env_id_list = [ + # 'acrobot-swingup', + # 'cartpole-balance', + # 'cartpole-balance_sparse', + # 'cartpole-swingup', + # 'cartpole-swingup_sparse', + # 'cheetah-run', + # "ball_in_cup-catch", + # "finger-spin", + # "finger-turn_easy", + # "finger-turn_hard", + # 'hopper-hop', + # 'hopper-stand', + # 'pendulum-swingup', + # 'reacher-easy', + # 'reacher-hard', + # 'walker-run', + # 'walker-stand', + # 'walker-walk', + # ] + + # 获取各环境的 action_space_size 和 observation_shape + action_space_size_list = [dmc_state_env_action_space_map[env_id] for env_id in env_id_list] + observation_shape_list = [dmc_state_env_obs_space_map[env_id] for env_id in env_id_list] + + collector_env_num = 8 + num_segments = 8 + n_episode = 8 + evaluator_env_num = 3 + num_simulations = 50 + # max_env_step = int(5e5) + max_env_step = int(1e6) + reanalyze_ratio = 0.0 + + # nlayer=4 + total_batch_size = 512 + batch_size = [int(min(64, total_batch_size / len(env_id_list))) for _ in range(len(env_id_list))] + + # nlayer=8/12 + total_batch_size = 512 + batch_size = [int(min(64, total_batch_size / len(env_id_list))) for _ in range(len(env_id_list))] + + num_unroll_steps = 5 + infer_context_length = 2 + norm_type = 'LN' + buffer_reanalyze_freq = 1 / 100000 + reanalyze_batch_size = 160 + reanalyze_partition = 0.75 + + # ======== TODO: only for debug ======== + # collector_env_num = 2 + # num_segments = 2 + # n_episode = 2 + # evaluator_env_num = 2 + # num_simulations = 5 + # batch_size = [10 for _ in range(len(env_id_list))] + # ======================================= + + seed = 0 # You can iterate over multiple seeds if needed + + configs = generate_configs( + env_id_list=env_id_list, + collector_env_num=collector_env_num, + n_episode=n_episode, + evaluator_env_num=evaluator_env_num, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + batch_size=batch_size, + num_unroll_steps=num_unroll_steps, + infer_context_length=infer_context_length, + norm_type=norm_type, + seed=seed, + 
buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + num_segments=num_segments, + total_batch_size=total_batch_size, + ) + + with DDPContext(): + train_unizero_multitask_segment_ddp(configs, seed=seed, max_env_step=max_env_step) + # 如果只想训练部分任务,可以修改 configs,例如: + # train_unizero_multitask_segment_ddp(configs[:4], seed=seed, max_env_step=max_env_step) \ No newline at end of file diff --git a/zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_8games_config_cprofile.py b/zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_8games_config_cprofile.py new file mode 100644 index 000000000..361dcc947 --- /dev/null +++ b/zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_8games_config_cprofile.py @@ -0,0 +1,336 @@ +from easydict import EasyDict +from typing import List + +def create_config(env_id, observation_shape_list, action_space_size_list, collector_env_num, evaluator_env_num, n_episode, + num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, + norm_type, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments, + total_batch_size): + domain_name = env_id.split('-')[0] + task_name = env_id.split('-')[1] + return EasyDict(dict( + env=dict( + stop_value=int(5e5), + env_id=env_id, + domain_name=domain_name, + task_name=task_name, + observation_shape_list=observation_shape_list, + action_space_size_list=action_space_size_list, + from_pixels=False, + # ===== only for debug ===== + # frame_skip=10, # 100 + frame_skip=2, + continuous=True, # Assuming all DMC tasks use continuous action spaces + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False), + game_segment_length=100, # As per single-task config + # ===== only for debug ===== + # collect_max_episode_steps=int(20), + # eval_max_episode_steps=int(20), + ), + policy=dict( + multi_gpu=True, # TODO: nable multi-GPU for DDP + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=1000000))), + grad_correct_params=dict( + # Example gradient correction parameters, adjust as needed + MoCo_beta=0.5, MoCo_beta_sigma=0.5, MoCo_gamma=0.1, MoCo_gamma_sigma=0.5, MoCo_rho=0, + calpha=0.5, rescale=1, + ), + task_num=len(env_id_list), + task_id=0, # To be set per task + model=dict( + observation_shape_list=observation_shape_list, + action_space_size_list=action_space_size_list, + continuous_action_space=True, + num_of_sampled_actions=20, + model_type='mlp', + world_model_cfg=dict( + observation_shape_list=observation_shape_list, + action_space_size_list=action_space_size_list, + policy_loss_type='kl', + obs_type='vector', + # use_shared_projection=True, # TODO + use_shared_projection=False, + num_unroll_steps=num_unroll_steps, + policy_entropy_weight=5e-2, + continuous_action_space=True, + num_of_sampled_actions=20, + sigma_type='conditioned', + fixed_sigma_value=0.5, + bound_type=None, + model_type='mlp', + norm_type=norm_type, + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, # Each timestep has 2 tokens: obs and action + context_length=2 * infer_context_length, + device='cuda', + # device='cpu', # TODO + # num_layers=2, + # num_layers=4, # TODO + + num_layers=8, # TODO + num_heads=8, + + # num_layers=12, # TODO + # num_heads=12, + + embed_dim=768, + env_num=max(collector_env_num, evaluator_env_num), + task_num=len(env_id_list), + use_normal_head=True, + use_softmoe_head=False, + use_moe_head=False, + 
num_experts_in_moe_head=4, + moe_in_transformer=False, + multiplication_moe_in_transformer=False, + num_experts_of_moe_in_transformer=4, + ), + ), + # task_complexity_weight=True, # TODO + task_complexity_weight=False, # TODO + total_batch_size=total_batch_size, + allocated_batch_sizes=False, + # train_start_after_envsteps=int(2e3), + train_start_after_envsteps=int(0), + use_priority=False, + print_task_priority_logs=False, + cuda=True, + model_path=None, + num_unroll_steps=num_unroll_steps, + # update_per_collect=2, # TODO: 80 + update_per_collect=200, # TODO: 8*100*0.25=200 + # update_per_collect=80, # TODO: 8*100*0.1=80 + replay_ratio=reanalyze_ratio, + batch_size=batch_size, + optim_type='AdamW', + num_segments=num_segments, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + replay_buffer_size=int(1e6), + # eval_freq=int(5e3), + eval_freq=int(4e3), + grad_clip_value=5, + learning_rate=1e-4, + discount_factor=0.99, + td_steps=5, + piecewise_decay_lr_scheduler=False, + manual_temperature_decay=True, + threshold_training_steps_for_final_temperature=int(2.5e4), + cos_lr_scheduler=True, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + ), + )) + + +def generate_configs(env_id_list: List[str], + collector_env_num: int, + n_episode: int, + evaluator_env_num: int, + num_simulations: int, + reanalyze_ratio: float, + batch_size: List[int], + num_unroll_steps: int, + infer_context_length: int, + norm_type: str, + seed: int, + buffer_reanalyze_freq: float, + reanalyze_batch_size: int, + reanalyze_partition: float, + num_segments: int, + total_batch_size: int): + configs = [] + # exp_name_prefix = f'data_suz_mt_20250102/ddp_8gpu_nlayer8_upc80_notusp_taskweight_{len(env_id_list)}tasks_brf{buffer_reanalyze_freq}_tbs{total_batch_size}_seed{seed}/' + exp_name_prefix = f'data_suz_mt_20250107_cprofile/ddp_8gpu_nlayer8_upc200_notusp_{len(env_id_list)}tasks_brf{buffer_reanalyze_freq}_tbs{total_batch_size}_seed{seed}/' + + action_space_size_list = [dmc_state_env_action_space_map[env_id] for env_id in env_id_list] + observation_shape_list = [dmc_state_env_obs_space_map[env_id] for env_id in env_id_list] + + for task_id, (env_id, obs_shape, act_space) in enumerate(zip(env_id_list, observation_shape_list, action_space_size_list)): + config = create_config( + env_id=env_id, + action_space_size_list=action_space_size_list, + observation_shape_list=observation_shape_list, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_episode=n_episode, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + batch_size=batch_size, + num_unroll_steps=num_unroll_steps, + infer_context_length=infer_context_length, + norm_type=norm_type, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + num_segments=num_segments, + total_batch_size=total_batch_size, + ) + config.policy.task_id = task_id + config.exp_name = exp_name_prefix + f"{env_id}_seed{seed}" + configs.append([task_id, [config, create_env_manager()]]) + return configs + + +def create_env_manager(): + return EasyDict(dict( + env=dict( + type='dmc2gym_lightzero', + import_names=['zoo.dmc2gym.envs.dmc2gym_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='sampled_unizero_multitask', + 
import_names=['lzero.policy.sampled_unizero_multitask'], + ), + )) + + +if __name__ == "__main__": + """ + Overview: + This script should be executed with GPUs. + Run the following command to launch the script: + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29500 ./zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_8games_config_cprofile.py + torchrun --nproc_per_node=8 ./zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_config.py + """ + + from lzero.entry import train_unizero_multitask_segment_ddp + from ding.utils import DDPContext + import os + from zoo.dmc2gym.config.dmc_state_env_space_map import dmc_state_env_action_space_map, dmc_state_env_obs_space_map + + # os.environ["NCCL_TIMEOUT"] = "3600000000" + + # 定义环境列表 + env_id_list = [ + 'acrobot-swingup', # 6 1 + 'cartpole-swingup', # 5 1 + ] + + # DMC 8games + env_id_list = [ + 'acrobot-swingup', + 'cartpole-balance', + 'cartpole-balance_sparse', + 'cartpole-swingup', + 'cartpole-swingup_sparse', + 'cheetah-run', + "ball_in_cup-catch", + "finger-spin", + ] + + # DMC 18games + # env_id_list = [ + # 'acrobot-swingup', + # 'cartpole-balance', + # 'cartpole-balance_sparse', + # 'cartpole-swingup', + # 'cartpole-swingup_sparse', + # 'cheetah-run', + # "ball_in_cup-catch", + # "finger-spin", + # "finger-turn_easy", + # "finger-turn_hard", + # 'hopper-hop', + # 'hopper-stand', + # 'pendulum-swingup', + # 'reacher-easy', + # 'reacher-hard', + # 'walker-run', + # 'walker-stand', + # 'walker-walk', + # ] + + # 获取各环境的 action_space_size 和 observation_shape + action_space_size_list = [dmc_state_env_action_space_map[env_id] for env_id in env_id_list] + observation_shape_list = [dmc_state_env_obs_space_map[env_id] for env_id in env_id_list] + + collector_env_num = 8 + num_segments = 8 + n_episode = 8 + evaluator_env_num = 3 + num_simulations = 50 + max_env_step = int(5e5) + reanalyze_ratio = 0.0 + + # nlayer=4 + total_batch_size = 512 + batch_size = [int(min(64, total_batch_size / len(env_id_list))) for _ in range(len(env_id_list))] + + # nlayer=8/12 + total_batch_size = 256 + batch_size = [int(min(32, total_batch_size / len(env_id_list))) for _ in range(len(env_id_list))] + + num_unroll_steps = 5 + infer_context_length = 2 + norm_type = 'LN' + buffer_reanalyze_freq = 1 / 100000 + reanalyze_batch_size = 160 + reanalyze_partition = 0.75 + + # ======== TODO: only for debug ======== + # collector_env_num = 2 + # num_segments = 2 + # n_episode = 2 + # evaluator_env_num = 2 + # num_simulations = 2 + # batch_size = [4 for _ in range(len(env_id_list))] + # ======================================= + + seed = 0 # You can iterate over multiple seeds if needed + + configs = generate_configs( + env_id_list=env_id_list, + collector_env_num=collector_env_num, + n_episode=n_episode, + evaluator_env_num=evaluator_env_num, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + batch_size=batch_size, + num_unroll_steps=num_unroll_steps, + infer_context_length=infer_context_length, + norm_type=norm_type, + seed=seed, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + num_segments=num_segments, + total_batch_size=total_batch_size, + ) + + # with DDPContext(): + # train_unizero_multitask_segment_ddp(configs, seed=seed, max_env_step=max_env_step) + # # 如果只想训练部分任务,可以修改 configs,例如: + # # train_unizero_multitask_segment_ddp(configs[:4], seed=seed, max_env_step=max_env_step) + + + def run(max_env_step: int): + with DDPContext(): + 
+            train_unizero_multitask_segment_ddp(configs, seed=seed, max_env_step=max_env_step)
+            # To train only a subset of tasks, slice configs, e.g.:
+            # train_unizero_multitask_segment_ddp(configs[:4], seed=seed, max_env_step=max_env_step)
+
+    import cProfile
+    from ding.utils import EasyTimer, set_pkg_seed, get_rank
+
+    def profile_main_process_only(max_env_step: int):
+        # Get the rank of the current process in the distributed setup.
+        rank = get_rank()
+
+        # Run cProfile only on the main process (rank 0).
+        if rank == 0:
+            cProfile.run(f"run({max_env_step})", filename="ddp_main_process.prof", sort="cumulative")
+        else:
+            run(max_env_step)  # Other processes run normally, without profiling.
+
+    max_env_step = 1000
+    profile_main_process_only(max_env_step)
+
+    # import cProfile
+    # cProfile.run(f"run({1000})", filename="ddp_8gpu_nlayer8_upc200_cprofile_1k_envstep", sort="cumulative")
\ No newline at end of file
diff --git a/zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_8games_moco_config.py b/zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_8games_moco_config.py
new file mode 100644
index 000000000..74f713ebc
--- /dev/null
+++ b/zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_8games_moco_config.py
@@ -0,0 +1,334 @@
+from easydict import EasyDict
+from typing import List
+
+def create_config(env_id, observation_shape_list, action_space_size_list, collector_env_num, evaluator_env_num, n_episode,
+                  num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length,
+                  norm_type, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments,
+                  total_batch_size):
+    domain_name = env_id.split('-')[0]
+    task_name = env_id.split('-')[1]
+    return EasyDict(dict(
+        env=dict(
+            stop_value=int(5e5),
+            env_id=env_id,
+            domain_name=domain_name,
+            task_name=task_name,
+            observation_shape_list=observation_shape_list,
+            action_space_size_list=action_space_size_list,
+            from_pixels=False,
+            # ===== only for debug =====
+            # frame_skip=100,  # 100
+            frame_skip=2,
+            continuous=True,  # Assuming all DMC tasks use continuous action spaces
+            collector_env_num=collector_env_num,
+            evaluator_env_num=evaluator_env_num,
+            n_evaluator_episode=evaluator_env_num,
+            manager=dict(shared_memory=False),
+            game_segment_length=100,  # As per single-task config
+            # ===== only for debug =====
+            # collect_max_episode_steps=int(20),
+            # eval_max_episode_steps=int(20),
+        ),
+        policy=dict(
+            multi_gpu=True,  # TODO: enable multi-GPU for DDP
+            learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=1000000))),
+            grad_correct_params=dict(
+                # Example gradient correction parameters, adjust as needed
+                MoCo_beta=0.5, MoCo_beta_sigma=0.5, MoCo_gamma=0.1, MoCo_gamma_sigma=0.5, MoCo_rho=0,
+                calpha=0.5, rescale=1,
+            ),
+            use_moco=True,  # ==============TODO==============
+            task_num=len(env_id_list),
+            task_id=0,  # To be set per task
+            model=dict(
+                observation_shape_list=observation_shape_list,
+                action_space_size_list=action_space_size_list,
+                continuous_action_space=True,
+                num_of_sampled_actions=20,
+                model_type='mlp',
+                world_model_cfg=dict(
+                    observation_shape_list=observation_shape_list,
+                    action_space_size_list=action_space_size_list,
+                    policy_loss_type='kl',
+                    obs_type='vector',
+                    # use_shared_projection=True,  # TODO
+                    use_shared_projection=False,
+                    # use_task_embed=True,  # TODO
+                    use_task_embed=False,  # ==============TODO==============
+                    num_unroll_steps=num_unroll_steps,
+                    policy_entropy_weight=5e-2,
+                    continuous_action_space=True,
+                    num_of_sampled_actions=20,
+                    sigma_type='conditioned',
+                    fixed_sigma_value=0.5,
+                    bound_type=None,
+                    model_type='mlp',
+                    norm_type=norm_type,
+                    max_blocks=num_unroll_steps,
+                    max_tokens=2 * num_unroll_steps,  # Each
timestep has 2 tokens: obs and action + context_length=2 * infer_context_length, + device='cuda', + # device='cpu', # TODO + # num_layers=2, + # num_layers=4, # TODO + + num_layers=8, # TODO + num_heads=8, + + # num_layers=12, # TODO + # num_heads=12, + + embed_dim=768, + env_num=max(collector_env_num, evaluator_env_num), + task_num=len(env_id_list), + use_normal_head=True, + use_softmoe_head=False, + use_moe_head=False, + num_experts_in_moe_head=4, + moe_in_transformer=False, + multiplication_moe_in_transformer=False, + num_experts_of_moe_in_transformer=4, + ), + ), + # task_complexity_weight=True, # TODO + task_complexity_weight=False, # TODO + total_batch_size=total_batch_size, + allocated_batch_sizes=False, + # train_start_after_envsteps=int(2e3), + train_start_after_envsteps=int(0), + use_priority=False, + print_task_priority_logs=False, + cuda=True, + model_path=None, + num_unroll_steps=num_unroll_steps, + # update_per_collect=2, # TODO: 80 + update_per_collect=200, # TODO: 8*100*0.25=200 + # update_per_collect=80, # TODO: 8*100*0.1=80 + replay_ratio=reanalyze_ratio, + batch_size=batch_size, + optim_type='AdamW', + num_segments=num_segments, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + replay_buffer_size=int(1e6), + # eval_freq=int(5e3), + eval_freq=int(4e3), + # eval_freq=int(2e3), + # eval_freq=int(1e3), # TODO =========== + grad_clip_value=5, + learning_rate=1e-4, + discount_factor=0.99, + td_steps=5, + piecewise_decay_lr_scheduler=False, + manual_temperature_decay=True, + threshold_training_steps_for_final_temperature=int(2.5e4), + cos_lr_scheduler=True, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + ), + )) + + +def generate_configs(env_id_list: List[str], + collector_env_num: int, + n_episode: int, + evaluator_env_num: int, + num_simulations: int, + reanalyze_ratio: float, + batch_size: List[int], + num_unroll_steps: int, + infer_context_length: int, + norm_type: str, + seed: int, + buffer_reanalyze_freq: float, + reanalyze_batch_size: int, + reanalyze_partition: float, + num_segments: int, + total_batch_size: int): + configs = [] + # TODO: debug + # exp_name_prefix = f'data_suz_mt_20250113/ddp_8gpu_nlayer8_upc200_taskweight-eval1e3-10k-temp10-1_task-embed_{len(env_id_list)}tasks_brf{buffer_reanalyze_freq}_tbs{total_batch_size}_seed{seed}/' + + exp_name_prefix = f'data_suz_mt_20250113/ddp_1gpu-moco_nlayer8_upc80_notaskweight-eval1e3-10k-temp10-1_no-task-embed_{len(env_id_list)}tasks_brf{buffer_reanalyze_freq}_tbs{total_batch_size}_seed{seed}/' + + # exp_name_prefix = f'data_suz_mt_20250113/ddp_3gpu_3games_nlayer8_upc200_notusp_notaskweight-symlog-01-05-eval1e3_{len(env_id_list)}tasks_brf{buffer_reanalyze_freq}_tbs{total_batch_size}_seed{seed}/' + + action_space_size_list = [dmc_state_env_action_space_map[env_id] for env_id in env_id_list] + observation_shape_list = [dmc_state_env_obs_space_map[env_id] for env_id in env_id_list] + + for task_id, (env_id, obs_shape, act_space) in enumerate(zip(env_id_list, observation_shape_list, action_space_size_list)): + config = create_config( + env_id=env_id, + action_space_size_list=action_space_size_list, + observation_shape_list=observation_shape_list, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_episode=n_episode, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + 
batch_size=batch_size, + num_unroll_steps=num_unroll_steps, + infer_context_length=infer_context_length, + norm_type=norm_type, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + num_segments=num_segments, + total_batch_size=total_batch_size, + ) + config.policy.task_id = task_id + config.exp_name = exp_name_prefix + f"{env_id}_seed{seed}" + configs.append([task_id, [config, create_env_manager()]]) + return configs + + +def create_env_manager(): + return EasyDict(dict( + env=dict( + type='dmc2gym_lightzero', + import_names=['zoo.dmc2gym.envs.dmc2gym_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='sampled_unizero_multitask', + import_names=['lzero.policy.sampled_unizero_multitask'], + ), + )) + + +if __name__ == "__main__": + """ + Overview: + This script should be executed with GPUs. + Run the following command to launch the script: + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29500 ./zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_8games_config.py + torchrun --nproc_per_node=8 ./zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_config.py + """ + + from lzero.entry import train_unizero_multitask_segment_ddp + from ding.utils import DDPContext + import os + from zoo.dmc2gym.config.dmc_state_env_space_map import dmc_state_env_action_space_map, dmc_state_env_obs_space_map + + # os.environ["NCCL_TIMEOUT"] = "3600000000" + + # 定义环境列表 + env_id_list = [ + 'acrobot-swingup', # 6 1 + 'cartpole-swingup', # 5 1 + ] + + + + # env_id_list = [ + # # 'acrobot-swingup', + # # 'cartpole-balance', + # # 'cartpole-balance_sparse', + # # 'cartpole-swingup', + # 'cartpole-swingup_sparse', + # 'cheetah-run', + # # "ball_in_cup-catch", + # "finger-spin", + # ] + + # DMC 8games + env_id_list = [ + 'acrobot-swingup', + 'cartpole-balance', + 'cartpole-balance_sparse', + 'cartpole-swingup', + 'cartpole-swingup_sparse', + 'cheetah-run', + "ball_in_cup-catch", + "finger-spin", + ] + + # DMC 18games + # env_id_list = [ + # 'acrobot-swingup', + # 'cartpole-balance', + # 'cartpole-balance_sparse', + # 'cartpole-swingup', + # 'cartpole-swingup_sparse', + # 'cheetah-run', + # "ball_in_cup-catch", + # "finger-spin", + # "finger-turn_easy", + # "finger-turn_hard", + # 'hopper-hop', + # 'hopper-stand', + # 'pendulum-swingup', + # 'reacher-easy', + # 'reacher-hard', + # 'walker-run', + # 'walker-stand', + # 'walker-walk', + # ] + + # 获取各环境的 action_space_size 和 observation_shape + action_space_size_list = [dmc_state_env_action_space_map[env_id] for env_id in env_id_list] + observation_shape_list = [dmc_state_env_obs_space_map[env_id] for env_id in env_id_list] + + collector_env_num = 8 + num_segments = 8 + n_episode = 8 + evaluator_env_num = 3 + num_simulations = 50 + # max_env_step = int(5e5) + max_env_step = int(1e6) + + reanalyze_ratio = 0.0 + + # nlayer=4 + total_batch_size = 512 + batch_size = [int(min(64, total_batch_size / len(env_id_list))) for _ in range(len(env_id_list))] + + # nlayer=8/12 + total_batch_size = 256 + batch_size = [int(min(32, total_batch_size / len(env_id_list))) for _ in range(len(env_id_list))] + + num_unroll_steps = 5 + infer_context_length = 2 + norm_type = 'LN' + buffer_reanalyze_freq = 1 / 100000 + reanalyze_batch_size = 160 + reanalyze_partition = 0.75 + + # ======== TODO: only for debug ======== + # collector_env_num = 2 + # num_segments = 2 + # n_episode = 2 + # evaluator_env_num = 2 + # num_simulations = 1 + # batch_size = [4 for _ in 
range(len(env_id_list))] + # ======================================= + + seed = 0 # You can iterate over multiple seeds if needed + + configs = generate_configs( + env_id_list=env_id_list, + collector_env_num=collector_env_num, + n_episode=n_episode, + evaluator_env_num=evaluator_env_num, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + batch_size=batch_size, + num_unroll_steps=num_unroll_steps, + infer_context_length=infer_context_length, + norm_type=norm_type, + seed=seed, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + num_segments=num_segments, + total_batch_size=total_batch_size, + ) + + with DDPContext(): + train_unizero_multitask_segment_ddp(configs, seed=seed, max_env_step=max_env_step) + # 如果只想训练部分任务,可以修改 configs,例如: + # train_unizero_multitask_segment_ddp(configs[:4], seed=seed, max_env_step=max_env_step) \ No newline at end of file diff --git a/zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_config.py b/zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_config.py new file mode 100644 index 000000000..12cade98c --- /dev/null +++ b/zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_config.py @@ -0,0 +1,294 @@ +from easydict import EasyDict +from typing import List + +def create_config(env_id, observation_shape_list, action_space_size_list, collector_env_num, evaluator_env_num, n_episode, + num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, + norm_type, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments, + total_batch_size): + domain_name = env_id.split('-')[0] + task_name = env_id.split('-')[1] + return EasyDict(dict( + env=dict( + stop_value=int(5e5), + env_id=env_id, + domain_name=domain_name, + task_name=task_name, + observation_shape_list=observation_shape_list, + action_space_size_list=action_space_size_list, + from_pixels=False, + frame_skip=2, + continuous=True, # Assuming all DMC tasks use continuous action spaces + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False), + game_segment_length=100, # As per single-task config + # ===== only for debug ===== + # collect_max_episode_steps=int(20), + # eval_max_episode_steps=int(20), + ), + policy=dict( + multi_gpu=True, # TODO: nable multi-GPU for DDP + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=1000000))), + grad_correct_params=dict( + # Example gradient correction parameters, adjust as needed + MoCo_beta=0.5, MoCo_beta_sigma=0.5, MoCo_gamma=0.1, MoCo_gamma_sigma=0.5, MoCo_rho=0, + calpha=0.5, rescale=1, + ), + task_num=len(env_id_list), + task_id=0, # To be set per task + model=dict( + observation_shape_list=observation_shape_list, + action_space_size_list=action_space_size_list, + continuous_action_space=True, + num_of_sampled_actions=20, + model_type='mlp', + world_model_cfg=dict( + observation_shape_list=observation_shape_list, + action_space_size_list=action_space_size_list, + policy_loss_type='kl', + obs_type='vector', + num_unroll_steps=num_unroll_steps, + policy_entropy_weight=5e-2, + continuous_action_space=True, + num_of_sampled_actions=20, + sigma_type='conditioned', + fixed_sigma_value=0.5, + bound_type=None, + model_type='mlp', + norm_type=norm_type, + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, # Each timestep has 2 tokens: obs and action + context_length=2 * infer_context_length, + device='cuda', + # 
device='cpu', # TODO + # num_layers=2, + # num_layers=4, # TODO + num_layers=8, # TODO + num_heads=8, + # num_layers=12, # TODO + # num_heads=12, + embed_dim=768, + env_num=max(collector_env_num, evaluator_env_num), + task_num=len(env_id_list), + use_normal_head=True, + use_softmoe_head=False, + use_moe_head=False, + num_experts_in_moe_head=4, + moe_in_transformer=False, + multiplication_moe_in_transformer=False, + num_experts_of_moe_in_transformer=4, + ), + ), + total_batch_size=total_batch_size, + allocated_batch_sizes=False, + # train_start_after_envsteps=int(2e3), + train_start_after_envsteps=int(0), + use_priority=False, + print_task_priority_logs=False, + cuda=True, + model_path=None, + num_unroll_steps=num_unroll_steps, + # update_per_collect=2, # TODO: 80 + update_per_collect=200, # TODO: 8*100*0.25=200 + # update_per_collect=80, # TODO: 8*100*0.1=80 + replay_ratio=reanalyze_ratio, + batch_size=batch_size, + optim_type='AdamW', + num_segments=num_segments, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + replay_buffer_size=int(1e6), + # eval_freq=int(5e3), + eval_freq=int(4e3), + grad_clip_value=5, + learning_rate=1e-4, + discount_factor=0.99, + td_steps=5, + piecewise_decay_lr_scheduler=False, + manual_temperature_decay=True, + threshold_training_steps_for_final_temperature=int(2.5e4), + cos_lr_scheduler=True, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + ), + )) + + +def generate_configs(env_id_list: List[str], + collector_env_num: int, + n_episode: int, + evaluator_env_num: int, + num_simulations: int, + reanalyze_ratio: float, + batch_size: List[int], + num_unroll_steps: int, + infer_context_length: int, + norm_type: str, + seed: int, + buffer_reanalyze_freq: float, + reanalyze_batch_size: int, + reanalyze_partition: float, + num_segments: int, + total_batch_size: int): + configs = [] + exp_name_prefix = f'data_suz_mt_20250102/ddp_8gpu_nlayer8_{len(env_id_list)}tasks_brf{buffer_reanalyze_freq}_tbs256_seed{seed}/' + action_space_size_list = [dmc_state_env_action_space_map[env_id] for env_id in env_id_list] + observation_shape_list = [dmc_state_env_obs_space_map[env_id] for env_id in env_id_list] + + for task_id, (env_id, obs_shape, act_space) in enumerate(zip(env_id_list, observation_shape_list, action_space_size_list)): + config = create_config( + env_id=env_id, + action_space_size_list=action_space_size_list, + observation_shape_list=observation_shape_list, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_episode=n_episode, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + batch_size=batch_size, + num_unroll_steps=num_unroll_steps, + infer_context_length=infer_context_length, + norm_type=norm_type, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + num_segments=num_segments, + total_batch_size=total_batch_size, + ) + config.policy.task_id = task_id + config.exp_name = exp_name_prefix + f"{env_id}_seed{seed}" + configs.append([task_id, [config, create_env_manager()]]) + return configs + + +def create_env_manager(): + return EasyDict(dict( + env=dict( + type='dmc2gym_lightzero', + import_names=['zoo.dmc2gym.envs.dmc2gym_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='sampled_unizero_multitask', 
+ import_names=['lzero.policy.sampled_unizero_multitask'], + ), + )) + + +if __name__ == "__main__": + """ + Overview: + This script should be executed with GPUs. + Run the following command to launch the script: + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29501 ./zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_config.py + torchrun --nproc_per_node=8 ./zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_config.py + """ + + from lzero.entry import train_unizero_multitask_segment_ddp + from ding.utils import DDPContext + import os + from zoo.dmc2gym.config.dmc_state_env_space_map import dmc_state_env_action_space_map, dmc_state_env_obs_space_map + + os.environ["NCCL_TIMEOUT"] = "3600000000" + + + # DMC 8games + env_id_list = [ + 'acrobot-swingup', + 'cartpole-balance', + 'cartpole-balance_sparse', + # 'cartpole-swingup', + # 'cartpole-swingup_sparse', + # 'cheetah-run', + # "ball_in_cup-catch", + # "finger-spin", + ] + + # DMC 18games + env_id_list = [ + 'acrobot-swingup', + 'cartpole-balance', + 'cartpole-balance_sparse', + 'cartpole-swingup', + 'cartpole-swingup_sparse', + 'cheetah-run', + "ball_in_cup-catch", + "finger-spin", + "finger-turn_easy", + "finger-turn_hard", + 'hopper-hop', + 'hopper-stand', + 'pendulum-swingup', + 'reacher-easy', + 'reacher-hard', + 'walker-run', + 'walker-stand', + 'walker-walk', + ] + + # 获取各环境的 action_space_size 和 observation_shape + action_space_size_list = [dmc_state_env_action_space_map[env_id] for env_id in env_id_list] + observation_shape_list = [dmc_state_env_obs_space_map[env_id] for env_id in env_id_list] + + collector_env_num = 8 + num_segments = 8 + n_episode = 8 + evaluator_env_num = 3 + num_simulations = 50 + max_env_step = int(5e5) + reanalyze_ratio = 0.0 + + # nlayer=4 + total_batch_size = 512 + batch_size = [int(min(64, total_batch_size / len(env_id_list))) for _ in range(len(env_id_list))] + + # nlayer=8/12 + total_batch_size = 256 + batch_size = [int(min(32, total_batch_size / len(env_id_list))) for _ in range(len(env_id_list))] + + num_unroll_steps = 5 + infer_context_length = 2 + norm_type = 'LN' + buffer_reanalyze_freq = 1 / 100000 + reanalyze_batch_size = 160 + reanalyze_partition = 0.75 + + # ======== TODO: only for debug ======== + # collector_env_num = 2 + # num_segments = 2 + # n_episode = 2 + # evaluator_env_num = 2 + # num_simulations = 2 + # batch_size = [4 for _ in range(len(env_id_list))] + # ======================================= + + seed = 0 # You can iterate over multiple seeds if needed + + configs = generate_configs( + env_id_list=env_id_list, + collector_env_num=collector_env_num, + n_episode=n_episode, + evaluator_env_num=evaluator_env_num, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + batch_size=batch_size, + num_unroll_steps=num_unroll_steps, + infer_context_length=infer_context_length, + norm_type=norm_type, + seed=seed, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + num_segments=num_segments, + total_batch_size=total_batch_size, + ) + + with DDPContext(): + train_unizero_multitask_segment_ddp(configs, seed=seed, max_env_step=max_env_step) + # 如果只想训练部分任务,可以修改 configs,例如: + # train_unizero_multitask_segment_ddp(configs[:4], seed=seed, max_env_step=max_env_step) \ No newline at end of file diff --git a/zoo/dmc2gym/config/dmc2gym_state_suz_multitask_serial_config.py b/zoo/dmc2gym/config/dmc2gym_state_suz_multitask_serial_config.py new file mode 100644 index 000000000..7ac3e6d12 
--- /dev/null
+++ b/zoo/dmc2gym/config/dmc2gym_state_suz_multitask_serial_config.py
@@ -0,0 +1,263 @@
+from easydict import EasyDict
+from copy import deepcopy
+
+from zoo.dmc2gym.config.dmc_state_env_space_map import dmc_state_env_action_space_map, dmc_state_env_obs_space_map
+
+
+def create_config(env_id, action_space_size_list, observation_shape_list, collector_env_num, evaluator_env_num,
+                  n_episode, num_simulations, batch_size, num_unroll_steps, infer_context_length,
+                  norm_type, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments, seed, update_per_collect):
+    """
+    Create a multi-task configuration for DMC environments.
+    """
+    domain_name = env_id.split('-')[0]
+    task_name = env_id.split('-')[1]
+    return EasyDict(dict(
+        env=dict(
+            stop_value=int(5e5),
+            domain_name=domain_name,
+            task_name=task_name,
+            observation_shape_list=observation_shape_list,
+            action_space_size_list=action_space_size_list,
+            from_pixels=False,
+            frame_skip=2,
+            continuous=True,
+            collector_env_num=collector_env_num,
+            evaluator_env_num=evaluator_env_num,
+            n_evaluator_episode=evaluator_env_num,
+            manager=dict(shared_memory=False),
+            save_replay_gif=False,
+            replay_path_gif='./replay_gif',
+        ),
+        policy=dict(
+            multi_gpu=False,  # TODO: enable multi-GPU for DDP
+            learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=200000))),
+            model=dict(
+                observation_shape_list=observation_shape_list,
+                action_space_size_list=action_space_size_list,
+                continuous_action_space=True,
+                num_of_sampled_actions=20,
+                model_type='mlp',
+                world_model_cfg=dict(
+                    observation_shape_list=observation_shape_list,
+                    action_space_size_list=action_space_size_list,
+                    policy_loss_type='kl',
+                    obs_type='vector',
+                    num_unroll_steps=num_unroll_steps,
+                    policy_entropy_weight=5e-2,
+                    continuous_action_space=True,
+                    num_of_sampled_actions=20,
+                    sigma_type='conditioned',
+                    fixed_sigma_value=0.5,
+                    bound_type=None,
+                    model_type='mlp',
+                    norm_type=norm_type,
+                    max_blocks=num_unroll_steps,
+                    max_tokens=2 * num_unroll_steps,  # NOTE: each timestep has 2 tokens: obs and action
+                    context_length=2 * infer_context_length,
+                    device='cuda',
+                    # device='cpu',  # TODO
+                    # num_layers=2,
+                    num_layers=4,  # TODO
+                    num_heads=8,
+                    embed_dim=768,
+                    env_num=max(collector_env_num, evaluator_env_num),
+                    task_num=len(env_id_list),
+                    use_normal_head=True,
+                    use_softmoe_head=False,
+                    use_moe_head=False,
+                    num_experts_in_moe_head=4,
+                    moe_in_transformer=False,
+                    multiplication_moe_in_transformer=False,
+                    num_experts_of_moe_in_transformer=4,
+                ),
+            ),
+            use_priority=False,
+            print_task_priority_logs=False,
+            cuda=True,
+            model_path=None,
+            num_unroll_steps=num_unroll_steps,
+            game_segment_length=100,
+            update_per_collect=update_per_collect,  # TODO
+            replay_ratio=0.25,  # TODO
+            batch_size=batch_size,
+            optim_type='AdamW',
+            num_segments=num_segments,
+            num_simulations=num_simulations,
+            n_episode=n_episode,
+            replay_buffer_size=int(1e6),
+            grad_clip_value=5,
+            learning_rate=1e-4,
+            discount_factor=0.99,
+            td_steps=5,
+            piecewise_decay_lr_scheduler=False,
+            manual_temperature_decay=True,
+            threshold_training_steps_for_final_temperature=int(2.5e4),
+            cos_lr_scheduler=True,
+            collector_env_num=collector_env_num,
+            evaluator_env_num=evaluator_env_num,
+            buffer_reanalyze_freq=buffer_reanalyze_freq,
+            reanalyze_batch_size=reanalyze_batch_size,
+            reanalyze_partition=reanalyze_partition,
+        ),
+        seed=seed,
+    ))
+
+
+def generate_configs(env_id_list, seed, collector_env_num, evaluator_env_num,
n_episode, num_simulations, + batch_size, num_unroll_steps, infer_context_length, norm_type, + buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments, update_per_collect): + """ + Generate configurations for all DMC tasks in the environment list. + """ + configs = [] + exp_name_prefix = f'data_suz_mt_20241224/{len(env_id_list)}tasks_brf{buffer_reanalyze_freq}/' + action_space_size_list = [dmc_state_env_action_space_map[env_id] for env_id in env_id_list] + observation_shape_list = [dmc_state_env_obs_space_map[env_id] for env_id in env_id_list] + for task_id, env_id in enumerate(env_id_list): + config = create_config( + env_id, + action_space_size_list, + observation_shape_list, + collector_env_num, + evaluator_env_num, + n_episode, + num_simulations, + batch_size, + num_unroll_steps, + infer_context_length, + norm_type, + buffer_reanalyze_freq, + reanalyze_batch_size, + reanalyze_partition, + num_segments, + seed, + update_per_collect + ) + + # 设置多任务相关的配置 + config.policy.task_num = len(env_id_list) + config.policy.task_id = task_id + + # 生成实验名称前缀 + config.exp_name = exp_name_prefix + f'mt_unizero_seed{seed}' + configs.append([task_id, [config, create_env_manager()]]) + return configs + # return [[i, [deepcopy(config), create_env_manager()]] for i in range(len(env_id_list))] + + +def create_env_manager(): + """ + Create the environment manager configuration. + """ + return EasyDict(dict( + env=dict( + type='dmc2gym_lightzero', + import_names=['zoo.dmc2gym.envs.dmc2gym_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='sampled_unizero_multitask', + import_names=['lzero.policy.sampled_unizero_multitask'], + ), + )) + + +if __name__ == "__main__": + from lzero.entry import train_unizero_multitask_segment_serial + + import argparse + + parser = argparse.ArgumentParser(description='Train multi-task DMC Unizero model.') + parser.add_argument('--seed', type=int, default=0, help='Random seed') + args = parser.parse_args() + + # 定义环境列表 + env_id_list = [ + 'cartpole-swingup', + 'cartpole-balance', + # 'cheetah-run', + # 'walker-walk', + # 'hopper-hop', + # 'humanoid-walk', + # 'quadruped-run', + # 'finger-spin', + ] + + # DMC 18games + env_id_list = [ + 'acrobot-swingup', + 'cartpole-balance', + 'cartpole-balance_sparse', + 'cartpole-swingup', + 'cartpole-swingup_sparse', + 'cheetah-run', + "ball_in_cup-catch", + "finger-spin", + "finger-turn_easy", + "finger-turn_hard", + 'hopper-hop', + 'hopper-stand', + 'pendulum-swingup', + # 'quadruped-run', + # 'quadruped-walk', + 'reacher-easy', + 'reacher-hard', + 'walker-run', + 'walker-stand', + 'walker-walk', + # 'humanoid-run', + ] + + # 获取各环境的 action_space_size 和 observation_shape + action_space_size_list = [dmc_state_env_action_space_map[env_id] for env_id in env_id_list] + observation_shape_list = [dmc_state_env_obs_space_map[env_id] for env_id in env_id_list] + + # 定义关键参数 + seed = args.seed + collector_env_num = 8 + evaluator_env_num = 3 + num_segments = 8 + n_episode = 8 + num_simulations = 50 + batch_size = [64 for _ in range(len(env_id_list))] + num_unroll_steps = 5 + infer_context_length = 2 + norm_type = 'LN' + buffer_reanalyze_freq = 1 / 100000 + reanalyze_batch_size = 160 + reanalyze_partition = 0.75 + max_env_step = int(5e5) + update_per_collect = 100 + + # ========== TODO: debug config ============ + # collector_env_num = 2 + # evaluator_env_num = 2 + # num_segments = 2 + # n_episode = 2 + # num_simulations = 2 + # batch_size = [4,4] # 可以根据需要调整或者设置为列表 + # update_per_collect = 1 + + # 
生成配置 + configs = generate_configs( + env_id_list=env_id_list, + seed=seed, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_episode=n_episode, + num_simulations=num_simulations, + batch_size=batch_size, + num_unroll_steps=num_unroll_steps, + infer_context_length=infer_context_length, + norm_type=norm_type, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + num_segments=num_segments, + update_per_collect=update_per_collect + ) + + # 启动多任务训练 + train_unizero_multitask_segment_serial(configs, seed=seed, max_env_step=max_env_step) \ No newline at end of file diff --git a/zoo/dmc2gym/config/dmc2gym_state_suz_segment_config_2.py b/zoo/dmc2gym/config/dmc2gym_state_suz_segment_config_2.py new file mode 100644 index 000000000..c786ff4c5 --- /dev/null +++ b/zoo/dmc2gym/config/dmc2gym_state_suz_segment_config_2.py @@ -0,0 +1,228 @@ +from easydict import EasyDict +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== + +from zoo.dmc2gym.config.dmc_state_env_space_map import dmc_state_env_action_space_map, dmc_state_env_obs_space_map + +def main(env_id, seed): + action_space_size = dmc_state_env_action_space_map[env_id] + obs_space_size = dmc_state_env_obs_space_map[env_id] + print(f'env_id: {env_id}, action_space_size: {action_space_size}, obs_space_size: {obs_space_size}') + + domain_name = env_id.split('-')[0] + task_name = env_id.split('-')[1] + + continuous_action_space = True + K = 20 # num_of_sampled_actions + # K = 16 # num_of_sampled_actions + + collector_env_num = 8 + n_episode = 8 + num_segments = 8 + game_segment_length = 125 + # game_segment_length = 500 + + # collector_env_num = 16 + # n_episode = 16 + # num_segments = 16 + # game_segment_length = 125 + + # collector_env_num = 16 + # n_episode = 16 + # num_segments = 16 + # game_segment_length = 125 + + + evaluator_env_num = 3 + num_simulations = 50 # TODO + + # max_env_step = int(5e5) + max_env_step = int(1e6) + # max_env_step = int(3e6) # TODO + + reanalyze_ratio = 0 + batch_size = 64 + num_layers = 2 + # num_layers = 4 + + num_unroll_steps = 5 + # num_unroll_steps = 10 + infer_context_length = 2 + + # replay_ratio = 0.25 + # num_unroll_steps = 10 + # infer_context_length = 4 + + norm_type = 'LN' + + # Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, 2 means reanalyze once every two epochs. + buffer_reanalyze_freq = 1/1000000000 # TODO + # replay_ratio = 0.1 + replay_ratio = 0.25 + + + # buffer_reanalyze_freq = 1/10 + # replay_ratio = 0.1 + + # Each reanalyze process will reanalyze sequences ( transitions per sequence) + reanalyze_batch_size = 160 + # The partition of reanalyze. E.g., 1 means reanalyze_batch samples from the whole buffer, 0.5 means samples from the first half of the buffer. 
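+    # Worked example (illustrative): with reanalyze_partition = 0.75 (set below) and reanalyze_batch_size = 160, each reanalyze pass draws its 160 sequences from the first 75% of the replay buffer.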
+ reanalyze_partition=0.75 + + # for debug + # collector_env_num = 2 + # num_segments = 2 + # n_episode = 2 + # evaluator_env_num = 2 + # num_simulations = 3 + # batch_size = 3 + # reanalyze_batch_size = 1 + # ============================================================== + # end of the most frequently changed config specified by the user + # ============================================================== + + dmc2gym_state_cont_sampled_unizero_config = dict( + env=dict( + env_id='dmc2gym-v0', + domain_name=domain_name, + task_name=task_name, + from_pixels=False, # vector/state obs + # from_pixels=True, # vector/state obs + # frame_skip=2, + frame_skip=8, + continuous=True, + save_replay_gif=False, + replay_path_gif='./replay_gif', + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + # TODO: only for debug + # collect_max_episode_steps=int(20), + # eval_max_episode_steps=int(20), + ), + policy=dict( + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=1000000,),),), # default is 10000 + model=dict( + observation_shape=obs_space_size, + action_space_size=action_space_size, + continuous_action_space=continuous_action_space, + num_of_sampled_actions=K, + model_type='mlp', + world_model_cfg=dict( + num_simulations=num_simulations, + policy_loss_type='kl', # 'simple' + # policy_loss_type='simple', # 'simple' + obs_type='vector', + num_unroll_steps=num_unroll_steps, + # policy_entropy_weight=0, + # policy_entropy_weight=5e-3, + policy_entropy_weight=5e-2, + continuous_action_space=continuous_action_space, + num_of_sampled_actions=K, + sigma_type='conditioned', + # sigma_type='fixed', + fixed_sigma_value=0.5, + bound_type=None, + model_type='mlp', + norm_type=norm_type, + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, # NOTE: each timestep has 2 tokens: obs and action + context_length=2 * infer_context_length, + device='cuda', + action_space_size=action_space_size, + num_layers=num_layers, + num_heads=8, + embed_dim=768, # original + # embed_dim=512, + env_num=max(collector_env_num, evaluator_env_num), + ), + ), + # (str) The path of the pretrained model. If None, the model will be initialized by the default model. 
+ model_path=None, + num_unroll_steps=num_unroll_steps, + cuda=True, + use_root_value=False, + use_augmentation=False, + use_priority=False, + env_type='not_board_games', + replay_ratio=replay_ratio, + batch_size=batch_size, + discount_factor=0.99, + # discount_factor=1, + # td_steps=5, + # td_steps=10, + td_steps=game_segment_length, # TODO + + lr_piecewise_constant_decay=False, + learning_rate=1e-4, + grad_clip_value=5, + # grad_clip_value=0.3, # TODO + # manual_temperature_decay=False, + manual_temperature_decay=True, + threshold_training_steps_for_final_temperature=int(2.5e4), + + # cos_lr_scheduler=True, + cos_lr_scheduler=False, + + num_segments=num_segments, + train_start_after_envsteps=2000, + game_segment_length=game_segment_length, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + eval_freq=int(5e3), + replay_buffer_size=int(1e6), + # replay_buffer_size=int(5e4), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + # ============= The key different params for ReZero ============= + buffer_reanalyze_freq=buffer_reanalyze_freq, # 1 means reanalyze one times per epoch, 2 means reanalyze one times each two epoch + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + ), + ) + + dmc2gym_state_cont_sampled_unizero_config = EasyDict(dmc2gym_state_cont_sampled_unizero_config) + main_config = dmc2gym_state_cont_sampled_unizero_config + + dmc2gym_state_cont_sampled_unizero_create_config = dict( + env=dict( + type='dmc2gym_lightzero', + import_names=['zoo.dmc2gym.envs.dmc2gym_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='sampled_unizero', + import_names=['lzero.policy.sampled_unizero'], + ), + ) + dmc2gym_state_cont_sampled_unizero_create_config = EasyDict(dmc2gym_state_cont_sampled_unizero_create_config) + create_config = dmc2gym_state_cont_sampled_unizero_create_config + + # ============ use muzero_segment_collector instead of muzero_collector ============= + from lzero.entry import train_unizero_segment + main_config.exp_name=f'data_suz_1216/dmc2gym_{env_id}_state_cont_suz_fs8_act-simnorm_td{game_segment_length}_dc099_learn-sigma_gcv5_rbs1e6_no-corlr_embed768_temp2.5e4_pew5e-2_19prior1flatten_obs10value01_clamp4_brf{buffer_reanalyze_freq}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_K{K}_ns{num_simulations}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_{norm_type}_seed{seed}' + + train_unizero_segment([main_config, create_config], model_path=main_config.policy.model_path, seed=seed, max_env_step=max_env_step) + + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser(description='Process some environment.') + + parser.add_argument('--env', type=str, help='The environment to use', default='cartpole-swingup') + parser.add_argument('--seed', type=int, help='The seed to use', default=0) + + args = parser.parse_args() + + # args.env = 'cheetah-run' + # args.env = 'walker-walk' + # args.env = 'finger-spin' + # args.env = 'pendulum-swingup' + + # args.env = 'hopper-hop' + # args.env = 'acrobot-swingup' + + main(args.env, args.seed) \ No newline at end of file diff --git a/zoo/dmc2gym/envs/dmc2gym_lightzero_env.py b/zoo/dmc2gym/envs/dmc2gym_lightzero_env.py index 6a6d669a0..2be2ee5cc 100644 --- a/zoo/dmc2gym/envs/dmc2gym_lightzero_env.py +++ b/zoo/dmc2gym/envs/dmc2gym_lightzero_env.py @@ -19,6 +19,8 @@ from gym.spaces import Box from matplotlib import animation 
import imageio +import logging + def dmc2gym_observation_space(dim, minimum=-np.inf, maximum=np.inf, dtype=np.float32) -> Callable: def observation_space(from_pixels=True, height=84, width=84, channels_first=True) -> Box: @@ -391,6 +393,7 @@ def step(self, action: Union[int, np.ndarray]) -> BaseEnvTimestep: self._frames.append(image_obs) if done: + logging.info(f'one episode done! episode return: {self._eval_episode_return}') info['eval_episode_return'] = self._eval_episode_return if self._save_replay_gif: @@ -399,7 +402,8 @@ def step(self, action: Union[int, np.ndarray]) -> BaseEnvTimestep: timestamp = datetime.now().strftime("%Y%m%d%H%M%S") path = os.path.join( self._replay_path_gif, - '{}_episode_{}_seed{}_{}.gif'.format(f'{self._cfg["domain_name"]}_{self._cfg["task_name"]}', self._save_replay_count, self._seed, timestamp) + '{}_episode_{}_seed{}_{}.gif'.format(f'{self._cfg["domain_name"]}_{self._cfg["task_name"]}', + self._save_replay_count, self._seed, timestamp) ) self.display_frames_as_gif(self._frames, path) print(f'save episode {self._save_replay_count} in {self._replay_path_gif}!') @@ -464,7 +468,7 @@ def __repr__(self) -> str: String representation of the environment. """ return "LightZero DMC2Gym Env({}:{})".format(self._cfg["domain_name"], self._cfg["task_name"]) - + @staticmethod def create_collector_env_cfg(cfg: dict) -> List[dict]: collector_env_num = cfg.pop('collector_env_num') @@ -479,4 +483,4 @@ def create_evaluator_env_cfg(cfg: dict) -> List[dict]: cfg = copy.deepcopy(cfg) cfg.max_episode_steps = cfg.eval_max_episode_steps cfg.is_eval = True - return [cfg for _ in range(evaluator_env_num)] + return [cfg for _ in range(evaluator_env_num)] \ No newline at end of file diff --git a/zoo/jericho/configs/jericho_unizero_config_debug.py b/zoo/jericho/configs/jericho_unizero_config_debug.py new file mode 100644 index 000000000..cb7a62d86 --- /dev/null +++ b/zoo/jericho/configs/jericho_unizero_config_debug.py @@ -0,0 +1,174 @@ +import os +from easydict import EasyDict +# import os +# os.environ["HF_HOME"] = "/mnt/afs/zhangshenghan/.cache/huggingface/hub" + +def main(env_id='detective.z5', seed=0): + # action_space_size = 50 + action_space_size = 10 + max_steps = 50 + + # ============================================================== + # begin of the most frequently changed config specified by the user + # ============================================================== + # collector_env_num = 8 + # n_episode = 8 + collector_env_num = 4 + n_episode = 4 + evaluator_env_num = 2 + num_simulations = 50 + max_env_step = int(10e6) + + # batch_size = 8 + # num_unroll_steps = 10 + # infer_context_length = 4 + + batch_size = 16 + num_unroll_steps = 5 + infer_context_length = 2 + + num_layers = 2 + replay_ratio = 0.25 + update_per_collect = None # NOTE: very important for ddp + embed_dim = 768 + # Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, 2 means reanalyze once every two epochs. + # buffer_reanalyze_freq = 1/10 + buffer_reanalyze_freq = 1/100000 + # Each reanalyze process will reanalyze sequences ( transitions per sequence) + reanalyze_batch_size = 160 + # The partition of reanalyze. E.g., 1 means reanalyze_batch samples from the whole buffer, 0.5 means samples from the first half of the buffer. 
+ reanalyze_partition = 0.75 + model_name = 'BAAI/bge-base-en-v1.5' + # model_name = 'google-bert/bert-base-uncased' + # =========== TODO: only for debug =========== + collector_env_num = 2 + num_segments = 2 + game_segment_length = 20 + evaluator_env_num = 2 + max_env_step = int(5e5) + batch_size = 10 + num_simulations = 5 + num_unroll_steps = 5 + infer_context_length = 2 + max_steps = 10 + num_layers = 1 + replay_ratio = 0.05 + embed_dim = 768 + # TODO: MCTS内部的action_space受限于root节点的legal action + + # ============================================================== + # end of the most frequently changed config specified by the user + # ============================================================== + jericho_unizero_config = dict( + env=dict( + stop_value=int(1e6), + observation_shape=512, + max_steps=max_steps, + max_action_num=action_space_size, + tokenizer_path=model_name, + # tokenizer_path="/mnt/afs/zhangshenghan/.cache/huggingface/hub/models--google-bert--bert-base-uncased/snapshots/86b5e0934494bd15c9632b12f734a8a67f723594", + max_seq_len=512, + # game_path="z-machine-games-master/jericho-game-suite/" + env_id, + game_path="/mnt/afs/niuyazhe/code/LightZero/zoo/jericho/envs/z-machine-games-master/jericho-game-suite/"+ env_id, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ) + ), + policy=dict( + multi_gpu=False, # ======== Very important for ddp ============= + # multi_gpu=True, # ======== Very important for ddp ============= + # default is 10000 + use_wandb=False, + learn=dict(learner=dict( + hook=dict(save_ckpt_after_iter=1000000, ), ), ), + model=dict( + observation_shape=512, + action_space_size=action_space_size, + encoder_url=model_name, + # encoder_url='google-bert/bert-base-uncased', + # encoder_url='/mnt/afs/zhangshenghan/.cache/huggingface/hub/models--google-bert--bert-base-uncased/snapshots/86b5e0934494bd15c9632b12f734a8a67f723594', + # The input of the model is text, whose shape is identical to the mlp model. + model_type='mlp', + continuous_action_space=False, + world_model_cfg=dict( + policy_entropy_weight=5e-3, + continuous_action_space=False, + max_blocks=num_unroll_steps, + # NOTE: each timestep has 2 tokens: obs and action + max_tokens=2 * num_unroll_steps, + context_length=2 * infer_context_length, + device='cuda', + action_space_size=action_space_size, + num_layers=num_layers, + num_heads=8, + embed_dim=embed_dim, + obs_type='text', # TODO: Change it. + env_num=max(collector_env_num, evaluator_env_num), + ), + ), + update_per_collect=update_per_collect, + action_type='varied_action_space', + model_path=None, + num_unroll_steps=num_unroll_steps, + reanalyze_ratio=0, + replay_ratio=replay_ratio, + batch_size=batch_size, + learning_rate=0.0001, + cos_lr_scheduler=True, + manual_temperature_decay=True, + threshold_training_steps_for_final_temperature=int(2.5e4), + num_simulations=num_simulations, + # num_segments=num_segments, + n_episode=n_episode, + train_start_after_envsteps=0, # TODO + # game_segment_length=game_segment_length, + # replay_buffer_size=int(1e6), + replay_buffer_size=int(1e5), + eval_freq=int(5e3), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + # ============= The key different params for reanalyze ============= + # Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, 2 means reanalyze once every two epochs. 
+ buffer_reanalyze_freq=buffer_reanalyze_freq, + # Each reanalyze process will reanalyze sequences ( transitions per sequence) + reanalyze_batch_size=reanalyze_batch_size, + # The partition of reanalyze. E.g., 1 means reanalyze_batch samples from the whole buffer, 0.5 means samples from the first half of the buffer. + reanalyze_partition=reanalyze_partition, + ), + ) + jericho_unizero_config = EasyDict(jericho_unizero_config) + + jericho_unizero_create_config = dict( + env=dict( + type='jericho', + import_names=['zoo.jericho.envs.jericho_env'], + ), + # NOTE: use base env manager to avoid the bug of subprocess env manager. + env_manager=dict(type='base'), + # env_manager=dict(type='subprocess'), + policy=dict( + type='unizero', + import_names=['lzero.policy.unizero'], + ), + ) + jericho_unizero_create_config = EasyDict(jericho_unizero_create_config) + main_config = jericho_unizero_config + create_config = jericho_unizero_create_config + + main_config.exp_name = f'data_unizero_detective_debug/{model_name}/{env_id[:8]}_ms{max_steps}_action-space-{action_space_size}_uz_nlayer{num_layers}_rr{replay_ratio}-upc{update_per_collect}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}' + from lzero.entry import train_unizero + train_unizero([main_config, create_config], seed=seed, + model_path=main_config.policy.model_path, max_env_step=max_env_step) + + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser(description='Process some environment.') + parser.add_argument('--env', type=str, + help='The environment to use', default='detective.z5') # 'detective.z5' 'zork1.z5' + parser.add_argument('--seed', type=int, help='The seed to use', default=0) + args = parser.parse_args() + + os.environ['TOKENIZERS_PARALLELISM'] = 'false' + main(args.env, args.seed) diff --git a/zoo/jericho/configs/jericho_unizero_ddp_config.py b/zoo/jericho/configs/jericho_unizero_ddp_config.py new file mode 100644 index 000000000..4a2e46e48 --- /dev/null +++ b/zoo/jericho/configs/jericho_unizero_ddp_config.py @@ -0,0 +1,196 @@ +import os +from easydict import EasyDict +# import os +# os.environ["HF_HOME"] = "/root/.cache/huggingface/hub" + +def main(env_id='detective.z5', seed=0): + # action_space_size = 50 + action_space_size = 20 + + max_steps = 60 # for detective + + # ============================================================== + # begin of the most frequently changed config specified by the user + # ============================================================== + # collector_env_num = 4 + # num_segments = 4 + # n_episode = 4 + # batch_size = int(64//8) # TODO: for serve model + # num_unroll_steps = 10 + # infer_context_length = 4 + + collector_env_num = 2 + num_segments = 2 + n_episode = 2 + # batch_size = int(16//8) # TODO: for local model + # batch_size = 4 # TODO: for each gpu + batch_size = 8 # TODO: for each gpu + num_unroll_steps = 10 + infer_context_length = 4 + + # num_unroll_steps = 5 + # infer_context_length = 2 + + update_per_collect = 20 # NOTE: very important for ddp + game_segment_length = 20 + evaluator_env_num = 3 + num_simulations = 50 + max_env_step = int(10e6) + + num_layers = 2 + replay_ratio = 0.25 + embed_dim = 768 + # Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, 2 means reanalyze once every two epochs. + # buffer_reanalyze_freq = 1/10 + buffer_reanalyze_freq = 1/10000000 + # Each reanalyze process will reanalyze sequences ( transitions per sequence) + reanalyze_batch_size = 160 + # The partition of reanalyze. 
E.g., 1 means reanalyze_batch samples from the whole buffer, 0.5 means samples from the first half of the buffer. + reanalyze_partition = 0.75 + model_name = 'BAAI/bge-base-en-v1.5' + # model_name = 'google-bert/bert-base-uncased' + # model_name = "/mnt/afs/zhangshenghan/.cache/huggingface/hub/models--google-bert--bert-base-uncased/snapshots/86b5e0934494bd15c9632b12f734a8a67f723594" + # =========== TODO: only for debug =========== + # collector_env_num = 2 + # num_segments = 2 + # n_episode = 2 + # game_segment_length = 20 + # evaluator_env_num = 2 + # max_env_step = int(5e5) + # batch_size = 4 + # num_simulations = 5 + # num_unroll_steps = 5 + # infer_context_length = 2 + # max_steps = 20 + # num_layers = 2 + # replay_ratio = 0.05 + # embed_dim = 768 + # update_per_collect = 2 # NOTE: very important for ddp + + # TODO: MCTS内部的action_space受限于root节点的legal action + + # ============================================================== + # end of the most frequently changed config specified by the user + # ============================================================== + jericho_unizero_config = dict( + env=dict( + stop_value=int(1e6), + observation_shape=512, + max_steps=max_steps, + max_action_num=action_space_size, + tokenizer_path=model_name, + # tokenizer_path="google-bert/bert-base-uncased", + # tokenizer_path="/mnt/afs/zhangshenghan/.cache/huggingface/hub/models--google-bert--bert-base-uncased/snapshots/86b5e0934494bd15c9632b12f734a8a67f723594", + max_seq_len=512, + # game_path="z-machine-games-master/jericho-game-suite/" + env_id, + game_path="/mnt/afs/niuyazhe/code/LightZero/zoo/jericho/envs/z-machine-games-master/jericho-game-suite/"+ env_id, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ) + ), + policy=dict( + multi_gpu=True, # ======== Very important for ddp ============= + # default is 10000 + use_wandb=False, + learn=dict(learner=dict( + hook=dict(save_ckpt_after_iter=1000000, ), ), ), + model=dict( + observation_shape=512, + action_space_size=action_space_size, + encoder_url=model_name, + # encoder_url='google-bert/bert-base-uncased', + # encoder_url='/mnt/afs/zhangshenghan/.cache/huggingface/hub/models--google-bert--bert-base-uncased/snapshots/86b5e0934494bd15c9632b12f734a8a67f723594', + # The input of the model is text, whose shape is identical to the mlp model. + model_type='mlp', + continuous_action_space=False, + world_model_cfg=dict( + policy_entropy_weight=5e-2, + continuous_action_space=False, + max_blocks=num_unroll_steps, + # NOTE: each timestep has 2 tokens: obs and action + max_tokens=2 * num_unroll_steps, + context_length=2 * infer_context_length, + device='cuda', + action_space_size=action_space_size, + num_layers=num_layers, + num_heads=8, + embed_dim=embed_dim, + obs_type='text', # TODO: Change it. 
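+                    # NOTE: with obs_type='text', each observation is a token-id sequence of length max_seq_len=512 (see the env config above), which is why observation_shape is set to 512.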
+ env_num=max(collector_env_num, evaluator_env_num), + ), + ), + update_per_collect=update_per_collect, + action_type='varied_action_space', + model_path=None, + num_unroll_steps=num_unroll_steps, + reanalyze_ratio=0, + replay_ratio=replay_ratio, + batch_size=batch_size, + learning_rate=0.0001, + num_simulations=num_simulations, + num_segments=num_segments, + n_episode=n_episode, + train_start_after_envsteps=0, # TODO + game_segment_length=game_segment_length, + # replay_buffer_size=int(1e6), + replay_buffer_size=int(1e5), + eval_freq=int(1e4), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + # ============= The key different params for reanalyze ============= + # Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, 2 means reanalyze once every two epochs. + buffer_reanalyze_freq=buffer_reanalyze_freq, + # Each reanalyze process will reanalyze sequences ( transitions per sequence) + reanalyze_batch_size=reanalyze_batch_size, + # The partition of reanalyze. E.g., 1 means reanalyze_batch samples from the whole buffer, 0.5 means samples from the first half of the buffer. + reanalyze_partition=reanalyze_partition, + ), + ) + jericho_unizero_config = EasyDict(jericho_unizero_config) + + jericho_unizero_create_config = dict( + env=dict( + type='jericho', + import_names=['zoo.jericho.envs.jericho_env'], + ), + # NOTE: use base env manager to avoid the bug of subprocess env manager. + env_manager=dict(type='base'), + # env_manager=dict(type='subprocess'), + policy=dict( + type='unizero', + import_names=['lzero.policy.unizero'], + ), + ) + jericho_unizero_create_config = EasyDict(jericho_unizero_create_config) + main_config = jericho_unizero_config + create_config = jericho_unizero_create_config + + main_config.exp_name = f'data_unizero_detective_4gpu_{model_name}_20250102/{env_id[:8]}_ms{max_steps}_uz_nlayer{num_layers}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}' + from lzero.entry import train_unizero + train_unizero([main_config, create_config], seed=seed, + model_path=main_config.policy.model_path, max_env_step=max_env_step) + + +if __name__ == "__main__": + # import argparse + # parser = argparse.ArgumentParser(description='Process some environment.') + # parser.add_argument('--env', type=str, + # help='The environment to use', default='detective.z5') # 'detective.z5' 'zork1.z5' + # parser.add_argument('--seed', type=int, help='The seed to use', default=0) + # args = parser.parse_args() + """ + Overview: + This script should be executed with GPUs. 
+ Run the following command to launch the script: + export CUDA_VISIBLE_DEVICES=0,1,2,4,5,6,7 + python -m torch.distributed.launch --nproc_per_node=4 ./zoo/jericho/configs/jericho_unizero_ddp_config.py + torchrun --nproc_per_node=2 ./zoo/jericho/configs/jericho_unizero_ddp_config.py + + """ + # TODO ========== + import os + from ding.utils import DDPContext + with DDPContext(): + os.environ['TOKENIZERS_PARALLELISM'] = 'false' + main('detective.z5', 0) diff --git a/zoo/jericho/envs/jericho_env_noserve.py b/zoo/jericho/envs/jericho_env_noserve.py new file mode 100644 index 000000000..52a147469 --- /dev/null +++ b/zoo/jericho/envs/jericho_env_noserve.py @@ -0,0 +1,157 @@ +import copy +from typing import List + +import gym +import numpy as np +from transformers import AutoTokenizer +from ding.utils import ENV_REGISTRY +from ding.envs import BaseEnv, BaseEnvTimestep +from jericho import FrotzEnv + + +@ENV_REGISTRY.register('jericho') +class JerichoEnv(BaseEnv): + """ + Overview: + The environment for Jericho games. For more details about the game, please refer to the \ + `Jericho `. + """ + tokenizer = None + + def __init__(self, cfg): + self.cfg = cfg + self.max_steps = cfg.max_steps + self.game_path = cfg.game_path + self.max_action_num = cfg.max_action_num + self.max_seq_len = cfg.max_seq_len + + if JerichoEnv.tokenizer is None: + JerichoEnv.tokenizer = AutoTokenizer.from_pretrained(cfg.tokenizer_path) + + self._env = FrotzEnv(self.game_path) + self._action_list = None + self.finished = False + self._init_flag = False + self.episode_return = 0 + self.env_step = 0 + + self.observation_space = gym.spaces.Dict() + self.action_space = gym.spaces.Discrete(self.max_action_num) + self.reward_space = gym.spaces.Box( + low=-np.inf, high=np.inf, shape=(1, ), dtype=np.float32) + + def prepare_obs(self, obs, return_str: bool = False): + if self._action_list is None: + self._action_list = self._env.get_valid_actions() + full_obs = obs + "\nValid actions: " + str(self._action_list) + if not return_str: + full_obs = JerichoEnv.tokenizer( + [full_obs], truncation=True, padding="max_length", max_length=self.max_seq_len) + obs_attn_mask = full_obs['attention_mask'] + full_obs = np.array(full_obs['input_ids'][0], dtype=np.int32) # TODO: attn_mask + if len(self._action_list) <= self.max_action_num: + action_mask = [1] * len(self._action_list) + [0] * \ + (self.max_action_num - len(self._action_list)) + else: + action_mask = [1] * len(self._action_list) + + action_mask = np.array(action_mask, dtype=np.int8) + # return {'observation': full_obs, 'action_mask': action_mask, 'to_play': -1} + return {'observation': full_obs, 'obs_attn_mask': obs_attn_mask,'action_mask': action_mask, 'to_play': -1} + + def reset(self, return_str: bool = False): + initial_observation, info = self._env.reset() + self.finished = False + self._init_flag = True + self._action_list = None + self.episode_return = 0 + self.env_step = 0 + + return self.prepare_obs(initial_observation, return_str) + + def seed(self, seed: int, dynamic_seed: bool = True) -> None: + """ + Set the seed for the environment. 
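+        Note: this only seeds the underlying FrotzEnv; the NumPy RNG used for the random fallback action in step() is not seeded here.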
+ """ + self._seed = seed + self._env.seed(seed) + + def close(self) -> None: + self._init_flag = False + + def __repr__(self) -> str: + return "LightZero Jericho Env" + + def step(self, action: int, return_str: bool = False): + try: + action_str = self._action_list[action] + except Exception as e: + # TODO: why exits illegal action + print('='*20) + print(e, 'action is illegal now we randomly choose a legal action!') + action = np.random.choice(len(self._action_list)) + action_str = self._action_list[action] + + observation, reward, done, info = self._env.step(action_str) + self.env_step += 1 + self.episode_return += reward + self._action_list = None + observation = self.prepare_obs(observation, return_str) + + # print(f'observation:{observation}, action:{action}, reward:{reward}') + # print(f'self._action_list:{self._action_list}, action:{action}, reward:{reward}') + + if self.env_step >= self.max_steps: + print('='*20) + print('one episode done!') + print(f'self._action_list:{self._action_list}, action:{action}, reward:{reward}') + done = True + + if done: + self.finished = True + info['eval_episode_return'] = self.episode_return + + return BaseEnvTimestep(observation, reward, done, info) + + @staticmethod + def create_collector_env_cfg(cfg: dict) -> List[dict]: + collector_env_num = cfg.pop('collector_env_num') + cfg = copy.deepcopy(cfg) + # when in collect phase, sometimes we need to normalize the reward + # reward_normalize is determined by the config. + cfg.is_collect = True + return [cfg for _ in range(collector_env_num)] + + @staticmethod + def create_evaluator_env_cfg(cfg: dict) -> List[dict]: + evaluator_env_num = cfg.pop('evaluator_env_num') + cfg = copy.deepcopy(cfg) + # when in evaluate phase, we don't need to normalize the reward. + cfg.reward_normalize = False + cfg.is_collect = False + return [cfg for _ in range(evaluator_env_num)] + + +if __name__ == '__main__': + from easydict import EasyDict + env_cfg = EasyDict( + dict( + max_steps=100, + # game_path="z-machine-games-master/jericho-game-suite/zork1.z5", + game_path="/mnt/afs/niuyazhe/code/LightZero/zoo/jericho/envs/z-machine-games-master/jericho-game-suite/detective.z5", + # game_path="/mnt/afs/niuyazhe/code/LightZero/zoo/jericho/envs/z-machine-games-master/jericho-game-suite/905.z5", + max_action_num=50, + max_env_step=100, + tokenizer_path="google-bert/bert-base-uncased", + max_seq_len=512 + ) + ) + env = JerichoEnv(env_cfg) + obs = env.reset(return_str=True) + print(f'[OBS]:\n{obs["observation"]}') + while True: + action_id = int(input('Please input the action id:')) + obs, reward, done, info = env.step(action_id, return_str=True) + print(f'[OBS]:\n{obs["observation"]}') + if done: + break diff --git a/zoo/jericho/envs/jericho_env_output_str.py b/zoo/jericho/envs/jericho_env_output_str.py new file mode 100644 index 000000000..5b60e10e9 --- /dev/null +++ b/zoo/jericho/envs/jericho_env_output_str.py @@ -0,0 +1,162 @@ +""" +env返回的obs不是id 是string +""" +import copy +from typing import List + +import gym +import numpy as np +from transformers import AutoTokenizer +from ding.utils import ENV_REGISTRY +from ding.envs import BaseEnv, BaseEnvTimestep +from jericho import FrotzEnv + + +@ENV_REGISTRY.register('jericho') +class JerichoEnv(BaseEnv): + """ + Overview: + The environment for Jericho games. For more details about the game, please refer to the \ + `Jericho `. 
+ """ + + def __init__(self, cfg): + self.cfg = cfg + self.max_steps = cfg.max_steps + self.game_path = cfg.game_path + self.max_action_num = cfg.max_action_num + self.max_seq_len = cfg.max_seq_len + + # 初始化分词器以供其他用途(如动作提示),不用于观察 + self.tokenizer = AutoTokenizer.from_pretrained(cfg.tokenizer_path) + + self._env = FrotzEnv(self.game_path) + self._action_list = None + self.finished = False + self._init_flag = False + self.episode_return = 0 + self.env_step = 0 + + self.observation_space = gym.spaces.Dict({ + 'observation': gym.spaces.Text(), + 'action_mask': gym.spaces.Box(low=0, high=1, shape=(self.max_action_num,), dtype=np.int8), + 'to_play': gym.spaces.Discrete(1) + }) + self.action_space = gym.spaces.Discrete(self.max_action_num) + self.reward_space = gym.spaces.Box( + low=-np.inf, high=np.inf, shape=(1,), dtype=np.float32) + + def prepare_obs(self, obs, return_str: bool = True): + if self._action_list is None: + self._action_list = self._env.get_valid_actions() + full_obs = obs + "\nValid actions: " + str(self._action_list) + + # 始终返回字符串形式的观察 + if return_str: + return {'observation': full_obs, 'action_mask': self._create_action_mask(), 'to_play': -1} + else: + raise ValueError("Current configuration only supports string observations.") + + def _create_action_mask(self): + if len(self._action_list) <= self.max_action_num: + action_mask = [1] * len(self._action_list) + [0] * (self.max_action_num - len(self._action_list)) + else: + action_mask = [1] * self.max_action_num + return np.array(action_mask, dtype=np.int8) + + def reset(self): + initial_observation, info = self._env.reset() + self.finished = False + self._init_flag = True + self._action_list = None + self.episode_return = 0 + self.env_step = 0 + + return self.prepare_obs(initial_observation) + + def seed(self, seed: int, dynamic_seed: bool = True) -> None: + """ + Set the seed for the environment. + """ + self._seed = seed + self._env.seed(seed) + + def close(self) -> None: + self._init_flag = False + + def __repr__(self) -> str: + return "LightZero Jericho Env" + + def step(self, action: int): + try: + action_str = self._action_list[action] + except IndexError as e: + # 处理非法动作 + print('='*20) + print(e, 'Action is illegal. 
Randomly choosing a legal action!') + action = np.random.choice(len(self._action_list)) + action_str = self._action_list[action] + + observation, reward, done, info = self._env.step(action_str) + self.env_step += 1 + self.episode_return += reward + self._action_list = None + + observation = self.prepare_obs(observation) + + if self.env_step >= self.max_steps: + print('='*20) + print('One episode done!') + done = True + + if done: + self.finished = True + info['eval_episode_return'] = self.episode_return + + return BaseEnvTimestep(observation, reward, done, info) + + @staticmethod + def create_collector_env_cfg(cfg: dict) -> List[dict]: + collector_env_num = cfg.pop('collector_env_num') + cfg = copy.deepcopy(cfg) + cfg.is_collect = True + return [cfg for _ in range(collector_env_num)] + + @staticmethod + def create_evaluator_env_cfg(cfg: dict) -> List[dict]: + evaluator_env_num = cfg.pop('evaluator_env_num') + cfg = copy.deepcopy(cfg) + cfg.reward_normalize = False + cfg.is_collect = False + return [cfg for _ in range(evaluator_env_num)] + + +if __name__ == '__main__': + from easydict import EasyDict + env_cfg = EasyDict( + dict( + max_steps=100, + game_path="/mnt/afs/niuyazhe/code/LightZero/zoo/jericho/envs/z-machine-games-master/jericho-game-suite/detective.z5", + max_action_num=50, + max_seq_len=512, + tokenizer_path="google-bert/bert-base-uncased", + ) + ) + env = JerichoEnv(env_cfg) + obs = env.reset() + print(f'[OBS]:\n{obs["observation"]}') + while True: + try: + action_id = int(input('Please input the action id:')) + timestep = env.step(action_id) + obs = timestep.obs + reward = timestep.reward + done = timestep.done + info = timestep.info + print(f'[OBS]:\n{obs["observation"]}') + print(f'Reward: {reward}') + if done: + print('Episode finished.') + break + except Exception as e: + print(f'Error: {e}. 
Please try again.') \ No newline at end of file diff --git a/zoo/memory/config/memory_unizero_config_2.py b/zoo/memory/config/memory_unizero_config_2.py new file mode 100644 index 000000000..d2975599a --- /dev/null +++ b/zoo/memory/config/memory_unizero_config_2.py @@ -0,0 +1,161 @@ +from easydict import EasyDict +env_id = 'visual_match' # The name of the environment, options: 'visual_match', 'key_to_door' + +# memory_length = 2 # DEBUG +memory_length = 500 +# memory_length = 100 + +# max_env_step = int(1e6) # for visual_match [2, 60, 100] +max_env_step = int(3e6) # for visual_match [250,500] + +# embed_dim=256 +# num_layers=2 +# num_heads=2 + +# embed_dim=256 +# num_layers=8 +# num_heads=8 + + + +embed_dim=128 +num_layers=12 +num_heads=8 + +# memory_length = 500 +# max_env_step = int(3e6) # for visual_match [100,250,500] +# embed_dim=256 # for visual_match [100,250,500] +# num_layers=8 +# num_heads=8 +# ============================================================== +# begin of the most frequently changed config specified by the user, +# you should change the following configs to adapt to your own task +# ============================================================== +# for key_to_door +# num_unroll_steps = 30+memory_length +# game_segment_length = 30+memory_length # TODO: for "explore": 15 + +# for visual_match +num_unroll_steps = 16 + memory_length +game_segment_length = 16 + memory_length # TODO: for "explore": 1 +seed = 0 +collector_env_num = 8 +n_episode = 8 +evaluator_env_num = 10 + +num_simulations = 50 +# update_per_collect = None +# update_per_collect = 10 +update_per_collect = 50 + +replay_ratio = 0.1 +# batch_size = 160 # 32*5 = 160 +batch_size = 64 # 32*5 = 160 +reanalyze_ratio = 0 +# td_steps = 10 +# td_steps = 5 +td_steps = game_segment_length + +# eps_greedy_exploration_in_collect = True + +# ========= only for debug =========== +# collector_env_num = 2 +# n_episode = 2 +# evaluator_env_num = 2 +# num_simulations = 3 +# update_per_collect = None +# replay_ratio = 0.25 +# batch_size = 4 +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== +memory_unizero_config = dict( + # exp_name=f'data_{env_id}_1025_clean/{env_id}_memlen-{memory_length}_unizero_H{num_unroll_steps}_bs{batch_size}_seed{seed}', + env=dict( + stop_value=int(1e6), + env_id=env_id, + flatten_observation=False, # Whether to flatten the observation + max_frames={ + # ================ Maximum frames per phase ============= + # "explore": 15, # TODO: for key_to_door + "explore": 1, # for visual_match + "distractor": memory_length, + "reward": 15 + }, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=1000000,),),), # default is 10000 + sample_type='episode', # NOTE: very important for memory env + model=dict( + observation_shape=(3, 5, 5), + action_space_size=4, + world_model_cfg=dict( + # In order to preserve the observation data of the first frame in a memory environment, + # we must ensure that we do not exceed the episode_length during the MCTS of the last frame. + # Therefore, we set a longer context_length than during training to ensure that the observation data of the first frame is not lost. 
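+                # E.g., with memory_length = 500 above, num_unroll_steps = 16 + 500 = 516, so max_blocks = 521 and max_tokens = context_length = 2 * 521 = 1042.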
+ max_blocks=num_unroll_steps + 5, + max_tokens=2 * (num_unroll_steps + 5), + context_length=2 * (num_unroll_steps + 5), + # device='cpu', + device='cuda', + action_space_size=4, + num_layers=num_layers, + num_heads=num_heads, + embed_dim=embed_dim, + env_num=max(collector_env_num, evaluator_env_num), + obs_type='image_memory', + policy_entropy_weight=5e-3, + ), + ), + # (str) The path of the pretrained model. If None, the model will be initialized by the default model. + model_path=None, + num_unroll_steps=num_unroll_steps, + td_steps=td_steps, + # discount_factor=1, + discount_factor=0.99, + # cuda=True, + game_segment_length=game_segment_length, + replay_ratio=replay_ratio, + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='AdamW', + learning_rate=1e-4, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + eval_freq=int(5e3), + replay_buffer_size=int(1e6), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), +) + +memory_unizero_config = EasyDict(memory_unizero_config) +main_config = memory_unizero_config + +memory_unizero_create_config = dict( + env=dict( + type='memory_lightzero', + import_names=['zoo.memory.envs.memory_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='unizero', + import_names=['lzero.policy.unizero'], + ), +) +memory_unizero_create_config = EasyDict(memory_unizero_create_config) +create_config = memory_unizero_create_config + +if __name__ == "__main__": + # seeds = [0, 1, 2] # You can add more seed values here + seeds = [1] # You can add more seed values here + for seed in seeds: + main_config.exp_name = f'data_{env_id}_1202/{env_id}_memlen-{memory_length}_fixedcolormap_obs10value05_td{td_steps}_layer{num_layers}-head{num_heads}_unizero_edim{embed_dim}_H{num_unroll_steps}_bs{batch_size}_upc{update_per_collect}_seed{seed}' + # main_config.exp_name = f'data_{env_id}_1122/{env_id}_memlen-{memory_length}_randomcolormap/obs10value05_td{td_steps}_layer{num_layers}-head{num_heads}_unizero_edim{embed_dim}_H{num_unroll_steps}_bs{batch_size}_upc{update_per_collect}_seed{seed}' + from lzero.entry import train_unizero + train_unizero([main_config, create_config], seed=seed, model_path=main_config.policy.model_path, max_env_step=max_env_step) \ No newline at end of file diff --git a/zoo/memory/config/memory_unizero_config_3.py b/zoo/memory/config/memory_unizero_config_3.py new file mode 100644 index 000000000..3d61f0ec6 --- /dev/null +++ b/zoo/memory/config/memory_unizero_config_3.py @@ -0,0 +1,159 @@ +from easydict import EasyDict +env_id = 'visual_match' # The name of the environment, options: 'visual_match', 'key_to_door' + +# memory_length = 2 # DEBUG +# memory_length = 500 +memory_length = 100 + +# max_env_step = int(1e6) # for visual_match [2, 60, 100] +max_env_step = int(3e6) # for visual_match [250,500] + +# embed_dim=256 +# num_layers=2 +# num_heads=2 + +# embed_dim=256 +# num_layers=8 +# num_heads=8 + +embed_dim=256 +num_layers=16 +num_heads=16 + +# memory_length = 500 +# max_env_step = int(3e6) # for visual_match [100,250,500] +# embed_dim=256 # for visual_match [100,250,500] +# num_layers=8 +# num_heads=8 +# ============================================================== +# begin of the most frequently changed config specified by the user, +# you should change the following configs to adapt to your own task +# ============================================================== +# for key_to_door +# num_unroll_steps = 30+memory_length +# 
game_segment_length = 30+memory_length # TODO: for "explore": 15 + +# for visual_match +num_unroll_steps = 16 + memory_length +game_segment_length = 16 + memory_length # TODO: for "explore": 1 +seed = 0 +collector_env_num = 8 +n_episode = 8 +evaluator_env_num = 10 + +num_simulations = 50 +# update_per_collect = None +# update_per_collect = 10 +update_per_collect = 50 + +replay_ratio = 0.1 +# batch_size = 160 # 32*5 = 160 +batch_size = 64 # 32*5 = 160 +reanalyze_ratio = 0 +# td_steps = 10 +# td_steps = 5 +td_steps = game_segment_length + +# eps_greedy_exploration_in_collect = True + +# ========= only for debug =========== +# collector_env_num = 2 +# n_episode = 2 +# evaluator_env_num = 2 +# num_simulations = 3 +# update_per_collect = None +# replay_ratio = 0.25 +# batch_size = 4 +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== +memory_unizero_config = dict( + # exp_name=f'data_{env_id}_1025_clean/{env_id}_memlen-{memory_length}_unizero_H{num_unroll_steps}_bs{batch_size}_seed{seed}', + env=dict( + stop_value=int(1e6), + env_id=env_id, + flatten_observation=False, # Whether to flatten the observation + max_frames={ + # ================ Maximum frames per phase ============= + # "explore": 15, # TODO: for key_to_door + "explore": 1, # for visual_match + "distractor": memory_length, + "reward": 15 + }, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=1000000,),),), # default is 10000 + sample_type='episode', # NOTE: very important for memory env + model=dict( + observation_shape=(3, 5, 5), + action_space_size=4, + world_model_cfg=dict( + # In order to preserve the observation data of the first frame in a memory environment, + # we must ensure that we do not exceed the episode_length during the MCTS of the last frame. + # Therefore, we set a longer context_length than during training to ensure that the observation data of the first frame is not lost. + max_blocks=num_unroll_steps + 5, + max_tokens=2 * (num_unroll_steps + 5), + context_length=2 * (num_unroll_steps + 5), + # device='cpu', + device='cuda', + action_space_size=4, + num_layers=num_layers, + num_heads=num_heads, + embed_dim=embed_dim, + env_num=max(collector_env_num, evaluator_env_num), + obs_type='image_memory', + policy_entropy_weight=5e-3, + ), + ), + # (str) The path of the pretrained model. If None, the model will be initialized by the default model. 
+ model_path=None, + num_unroll_steps=num_unroll_steps, + td_steps=td_steps, + # discount_factor=1, + discount_factor=0.99, + # cuda=True, + game_segment_length=game_segment_length, + replay_ratio=replay_ratio, + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='AdamW', + learning_rate=1e-4, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + eval_freq=int(5e3), + replay_buffer_size=int(1e6), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), +) + +memory_unizero_config = EasyDict(memory_unizero_config) +main_config = memory_unizero_config + +memory_unizero_create_config = dict( + env=dict( + type='memory_lightzero', + import_names=['zoo.memory.envs.memory_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='unizero', + import_names=['lzero.policy.unizero'], + ), +) +memory_unizero_create_config = EasyDict(memory_unizero_create_config) +create_config = memory_unizero_create_config + +if __name__ == "__main__": + seeds = [0, 1, 2] # You can add more seed values here + # seeds = [0] # You can add more seed values here + for seed in seeds: + main_config.exp_name = f'data_{env_id}_1126/{env_id}_memlen-{memory_length}_fixedcolormap_obs10value05_td{td_steps}_layer{num_layers}-head{num_heads}_unizero_edim{embed_dim}_H{num_unroll_steps}_bs{batch_size}_upc{update_per_collect}_seed{seed}' + # main_config.exp_name = f'data_{env_id}_1122/{env_id}_memlen-{memory_length}_randomcolormap/obs10value05_td{td_steps}_layer{num_layers}-head{num_heads}_unizero_edim{embed_dim}_H{num_unroll_steps}_bs{batch_size}_upc{update_per_collect}_seed{seed}' + from lzero.entry import train_unizero + train_unizero([main_config, create_config], seed=seed, model_path=main_config.policy.model_path, max_env_step=max_env_step) \ No newline at end of file
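A minimal, illustrative sketch (not part of LightZero; helper names are invented here) restating two pieces of arithmetic the configs above repeat: the per-task batch-size split used by the multitask DDP config, and the max_blocks / max_tokens / context_length token budget of the world model.

def per_task_batch_size(total_batch_size: int, num_tasks: int, cap: int = 64) -> int:
    # Each task gets an equal share of the total batch, clipped at `cap`
    # (mirrors int(min(64, total_batch_size / len(env_id_list))) in the DDP config).
    return int(min(cap, total_batch_size / num_tasks))


def token_budget(num_unroll_steps: int, infer_context_length: int) -> dict:
    # Each timestep contributes 2 tokens (obs and action), hence the factor of 2.
    return dict(
        max_blocks=num_unroll_steps,
        max_tokens=2 * num_unroll_steps,
        context_length=2 * infer_context_length,
    )


if __name__ == "__main__":
    # 18 DMC tasks sharing total_batch_size=256 with a per-task cap of 32 -> 14 per task.
    print(per_task_batch_size(256, 18, cap=32))
    # num_unroll_steps=5, infer_context_length=2 -> max_tokens=10, context_length=4.
    print(token_budget(5, 2))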