From 5679b6b0c94350e5cffae5c41fb5b457791ca306 Mon Sep 17 00:00:00 2001 From: wenzhangliu Date: Fri, 30 Aug 2024 21:18:29 +0800 Subject: [PATCH] mindspore: policy gradient --- xuance/mindspore/agents/core/on_policy.py | 6 +- .../learners/policy_gradient/a2c_learner.py | 74 ++++----- .../learners/policy_gradient/pg_learner.py | 58 ++++--- .../learners/policy_gradient/ppg_learner.py | 156 ++++++++++-------- xuance/mindspore/policies/categorical.py | 10 +- xuance/mindspore/policies/core.py | 156 ++++++++---------- xuance/mindspore/policies/gaussian.py | 65 ++++++-- xuance/mindspore/utils/__init__.py | 4 +- xuance/mindspore/utils/operations.py | 18 +- .../learners/policy_gradient/ppg_learner.py | 4 +- 10 files changed, 286 insertions(+), 265 deletions(-) diff --git a/xuance/mindspore/agents/core/on_policy.py b/xuance/mindspore/agents/core/on_policy.py index 143e46a5..7de1540c 100644 --- a/xuance/mindspore/agents/core/on_policy.py +++ b/xuance/mindspore/agents/core/on_policy.py @@ -76,13 +76,13 @@ def action(self, observations: np.ndarray, """ _, policy_dists, values = self.policy(observations) actions = policy_dists.stochastic_sample() - log_pi = policy_dists.log_prob(actions).detach().cpu().numpy() if return_logpi else None + log_pi = policy_dists.log_prob(actions).asnumpy() if return_logpi else None dists = split_distributions(policy_dists) if return_dists else None - actions = actions.detach().cpu().numpy() + actions = actions.asnumpy() if values is None: values = 0 else: - values = values.detach().cpu().numpy() + values = values.asnumpy() return {"actions": actions, "values": values, "dists": dists, "log_pi": log_pi} def get_aux_info(self, policy_output: dict = None): diff --git a/xuance/mindspore/learners/policy_gradient/a2c_learner.py b/xuance/mindspore/learners/policy_gradient/a2c_learner.py index 98a4ba20..aeeed493 100644 --- a/xuance/mindspore/learners/policy_gradient/a2c_learner.py +++ b/xuance/mindspore/learners/policy_gradient/a2c_learner.py @@ -5,57 +5,55 @@ from xuance.mindspore import ms, Module, Tensor, optim from xuance.mindspore.learners import Learner from argparse import Namespace +from mindspore.nn import MSELoss class A2C_Learner(Learner): - class ACNetWithLossCell(Module): - def __init__(self, backbone, ent_coef, vf_coef): - super(A2C_Learner.ACNetWithLossCell, self).__init__() - self._backbone = backbone - self._mean = ms.ops.ReduceMean(keep_dims=True) - self._loss_c = nn.MSELoss() - self._ent_coef = ent_coef - self._vf_coef = vf_coef - - def construct(self, x, a, adv, r): - _, act_probs, v_pred = self._backbone(x) - log_prob = self._backbone.actor.log_prob(value=a, probs=act_probs) - loss_a = -self._mean(adv * log_prob) - loss_c = self._loss_c(logits=v_pred, labels=r) - loss_e = self._mean(self._backbone.actor.entropy(probs=act_probs)) - loss = loss_a - self._ent_coef * loss_e + self._vf_coef * loss_c - - return loss - def __init__(self, config: Namespace, policy: Module): super(A2C_Learner, self).__init__(config, policy) - self.vf_coef = vf_coef - self.ent_coef = ent_coef - self.clip_grad = clip_grad - # define mindspore trainer - self.loss_net = self.ACNetWithLossCell(policy, self.ent_coef, self.vf_coef) - # self.policy_train = nn.TrainOneStepCell(self.loss_net, optimizer) - self.policy_train = TrainOneStepCellWithGradClip(self.loss_net, optimizer, - clip_type=clip_type, clip_value=clip_grad) - self.policy_train.set_train() - - def update(self, obs_batch, act_batch, ret_batch, adv_batch): + self.optimizer = optim.Adam(params=self.policy.trainable_params(), 
lr=self.config.learning_rate, eps=1e-5) + self.scheduler = optim.lr_scheduler.LinearLR(self.optimizer, start_factor=1.0, end_factor=0.5, + total_iters=self.config.running_steps) + self.mse_loss = MSELoss() + self._mean = ms.ops.ReduceMean(keep_dims=True) + self.vf_coef = config.vf_coef + self.ent_coef = config.ent_coef + # Get gradient function + self.grad_fn = ms.value_and_grad(self.forward_fn, None, self.optimizer.parameters, has_aux=True) + self.policy.set_train() + + def forward_fn(self, x, a, adv, r): + _, a_dist, v_pred = self.policy(x) + log_prob = a_dist.log_prob(a) + loss_a = -self._mean(adv * log_prob) + loss_c = self.mse_loss(logits=v_pred, labels=r) + loss_e = self._mean(a_dist.entropy()) + loss = loss_a - self.ent_coef * loss_e + self.vf_coef * loss_c + + return loss, loss_a, loss_e, loss_c, v_pred + + def update(self, **samples): self.iterations += 1 - obs_batch = Tensor(obs_batch) - act_batch = Tensor(act_batch) - ret_batch = Tensor(ret_batch) - adv_batch = Tensor(adv_batch) + obs_batch = Tensor(samples['obs']) + act_batch = Tensor(samples['actions']) + ret_batch = Tensor(samples['returns']) + adv_batch = Tensor(samples['advantages']) - loss = self.policy_train(obs_batch, act_batch, adv_batch, ret_batch) + (loss, loss_a, loss_e, loss_c, v_pred), grads = self.grad_fn(obs_batch, act_batch, adv_batch, ret_batch) + self.optimizer(grads) - # Logger - lr = self.scheduler(self.iterations).asnumpy() + self.scheduler.step() + lr = self.scheduler.get_last_lr()[0] info = { "total-loss": loss.asnumpy(), - "learning_rate": lr + "actor-loss": loss_a.asnumpy(), + "critic-loss": loss_c.asnumpy(), + "entropy": loss_e.asnumpy(), + "learning_rate": lr.asnumpy(), + "predict_value": v_pred.mean().asnumpy(), } return info diff --git a/xuance/mindspore/learners/policy_gradient/pg_learner.py b/xuance/mindspore/learners/policy_gradient/pg_learner.py index c9e67b5d..83c07f67 100644 --- a/xuance/mindspore/learners/policy_gradient/pg_learner.py +++ b/xuance/mindspore/learners/policy_gradient/pg_learner.py @@ -9,46 +9,44 @@ class PG_Learner(Learner): - class PolicyNetWithLossCell(Module): - def __init__(self, backbone, ent_coef): - super(PG_Learner.PolicyNetWithLossCell, self).__init__(auto_prefix=False) - self._backbone = backbone - self._ent_coef = ent_coef - self._mean = ms.ops.ReduceMean(keep_dims=True) - - def construct(self, x, a, r): - _, act_probs = self._backbone(x) - log_prob = self._backbone.actor.log_prob(value=a, probs=act_probs) - loss_a = -self._mean(r * log_prob) - loss_e = self._mean(self._backbone.actor.entropy(probs=act_probs)) - loss = loss_a - self._ent_coef * loss_e - return loss - def __init__(self, config: Namespace, policy: Module): super(PG_Learner, self).__init__(config, policy) - self.optimizer = ms.nn.Adam(params=policy.trainable_params(), learning_rate=self.config.learning_rate, eps=1e-5) - # define mindspore trainer - self.loss_net = self.PolicyNetWithLossCell(policy, self.ent_coef) - # self.policy_train = nn.TrainOneStepCell(self.loss_net, optimizer) - self.policy_train = TrainOneStepCellWithGradClip(self.loss_net, optimizer, - clip_type=clip_type, clip_value=clip_grad) - self.policy_train.set_train() - - def update(self, obs_batch, act_batch, ret_batch): + self.optimizer = optim.Adam(params=self.policy.trainable_params(), lr=self.config.learning_rate, eps=1e-5) + self.scheduler = optim.lr_scheduler.LinearLR(self.optimizer, start_factor=1.0, end_factor=0.5, + total_iters=self.config.running_steps) + self.ent_coef = config.ent_coef + self._mean = 
ms.ops.ReduceMean(keep_dims=True) + # Get gradient function + self.grad_fn = ms.value_and_grad(self.forward_fn, None, self.optimizer.parameters, has_aux=True) + self.policy.set_train() + + def forward_fn(self, x, a, r): + _, a_dist, _ = self.policy(x) + log_prob = a_dist.log_prob(a) + loss_a = -self._mean(r * log_prob) + loss_e = self._mean(a_dist.entropy()) + loss = loss_a - self.ent_coef * loss_e + return loss, loss_a, loss_e + + def update(self, **samples): self.iterations += 1 - obs_batch = Tensor(obs_batch) - act_batch = Tensor(act_batch) - ret_batch = Tensor(ret_batch) + obs_batch = Tensor(samples['obs']) + act_batch = Tensor(samples['actions']) + ret_batch = Tensor(samples['returns']) - loss = self.policy_train(obs_batch, act_batch, ret_batch) + (loss, loss_a, loss_e), grads = self.grad_fn(obs_batch, act_batch, ret_batch) + self.optimizer(grads) - lr = self.scheduler(self.iterations).asnumpy() + self.scheduler.step() + lr = self.scheduler.get_last_lr()[0] info = { "total-loss": loss.asnumpy(), - "learning_rate": lr + "actor-loss": loss_a.asnumpy(), + "entropy": loss_e.asnumpy(), + "learning_rate": lr.asnumpy(), } return info diff --git a/xuance/mindspore/learners/policy_gradient/ppg_learner.py b/xuance/mindspore/learners/policy_gradient/ppg_learner.py index a29e1c6b..a2995d82 100644 --- a/xuance/mindspore/learners/policy_gradient/ppg_learner.py +++ b/xuance/mindspore/learners/policy_gradient/ppg_learner.py @@ -7,93 +7,107 @@ from xuance.mindspore.learners import Learner from argparse import Namespace from xuance.mindspore.utils.operations import merge_distributions -from mindspore.nn.probability.distribution import Categorical +from mindspore.nn import MSELoss class PPG_Learner(Learner): - class PolicyNetWithLossCell(Module): - def __init__(self, backbone, ent_coef, kl_beta, clip_range, loss_fn): - super(PPG_Learner.PolicyNetWithLossCell, self).__init__(auto_prefix=False) - self._backbone = backbone - self._ent_coef = ent_coef - self._kl_beta = kl_beta - self._clip_range = clip_range - self._loss_fn = loss_fn - self._mean = ms.ops.ReduceMean(keep_dims=True) - self._minimum = ms.ops.Minimum() - self._exp = ms.ops.Exp() - self._categorical = Categorical() - - def construct(self, x, a, r, adv, old_log, old_dist_logits, v, update_type): - loss = 0 - if update_type == 0: - _, a_dist, _, _ = self._backbone(x) - log_prob = self._categorical.log_prob(a, a_dist) - # ppo-clip core implementations - ratio = self._exp(log_prob - old_log) - surrogate1 = ms.ops.clip_by_value(ratio, 1.0 - self._clip_range, 1.0 + self._clip_range) * adv - surrogate2 = adv * ratio - a_loss = -self._minimum(surrogate1, surrogate2).mean() - entropy = self._categorical.entropy(a_dist) - e_loss = entropy.mean() - loss = a_loss - self._ent_coef * e_loss - elif update_type == 1: - _,_,v_pred,_ = self._backbone(x) - loss = self._loss_fn(v_pred, r) - elif update_type == 2: - _, a_dist, _, aux_v = self._backbone(x) - aux_loss = self._loss_fn(v, aux_v) - kl_loss = self._categorical.kl_loss('Categorical',a_dist, old_dist_logits).mean() - value_loss = self._loss_fn(v,r) - loss = aux_loss + self._kl_beta * kl_loss + value_loss - return loss - def __init__(self, config: Namespace, policy: Module): super(PPG_Learner, self).__init__(config, policy) - self.ent_coef = ent_coef - self.clip_range = clip_range - self.kl_beta = kl_beta + self.optimizer = optim.Adam(params=self.policy.trainable_params(), lr=self.config.learning_rate, eps=1e-5) + self.scheduler = optim.lr_scheduler.LinearLR(self.optimizer, start_factor=1.0, end_factor=0.5, 
+ total_iters=self.config.running_steps) + self.mse_loss = MSELoss() + self.ent_coef = config.ent_coef + self.clip_range = config.clip_range + self.kl_beta = config.kl_beta self.policy_iterations = 0 self.value_iterations = 0 - loss_fn = nn.MSELoss() - # define mindspore trainer - self.loss_net = self.PolicyNetWithLossCell(policy, self.ent_coef, self.kl_beta, self.clip_range, loss_fn) - self.policy_train = nn.TrainOneStepCell(self.loss_net, optimizer) - self.policy_train.set_train() + # Get gradient function + self._mean = ms.ops.ReduceMean(keep_dims=True) + self._minimum = ms.ops.Minimum() + self._exp = ms.ops.Exp() + self.grad_fn_policy = ms.value_and_grad(self.forward_fn_policy, None, self.optimizer.parameters, has_aux=True) + self.grad_fn_critic = ms.value_and_grad(self.forward_fn_critic, None, self.optimizer.parameters, has_aux=True) + self.grad_fn_auxiliary = ms.value_and_grad(self.forward_fn_auxiliary, None, self.optimizer.parameters, + has_aux=True) + self.policy.set_train() - def update(self, obs_batch, act_batch, ret_batch, adv_batch, old_dists, update_type): + def forward_fn_policy(self, obs_batch, act_batch, adv_batch, old_logp_batch): + _, a_dist, _, _ = self.policy(obs_batch) + log_prob = a_dist.log_prob(act_batch) + # ppo-clip core implementations + ratio = self._exp(log_prob - old_logp_batch) + surrogate1 = ms.ops.clip_by_value(ratio, 1.0 - self.clip_range, 1.0 + self.clip_range) * adv_batch + surrogate2 = adv_batch * ratio + a_loss = -self._minimum(surrogate1, surrogate2).mean() + e_loss = a_dist.entropy().mean() + loss = a_loss - self.ent_coef * e_loss + return loss, a_loss, e_loss, ratio + + def forward_fn_critic(self, obs_batch, ret_batch): + _, _, v_pred, _ = self.policy(obs_batch) + loss = self.mse_loss(v_pred, ret_batch) + return loss, v_pred + + def forward_fn_auxiliary(self, obs_batch, ret_batch, old_dist): + _, a_dist, v, aux_v = self.policy(obs_batch) + aux_loss = self.mse_loss(v, aux_v) + kl_loss = self._categorical.kl_loss('Categorical', a_dist, old_dist).mean() + value_loss = self.mse_loss(v, ret_batch) + loss = aux_loss + self.kl_beta * kl_loss + value_loss + return loss, v + + def update_policy(self, **samples): self.iterations += 1 - info = {} - obs_batch = Tensor(obs_batch) - act_batch = Tensor(act_batch) - ret_batch = Tensor(ret_batch) - adv_batch = Tensor(adv_batch) - old_dist = merge_distributions(old_dists) + obs_batch = Tensor(samples['obs']) + act_batch = Tensor(samples['actions']) + adv_batch = Tensor(samples['advantages']) + old_dist = merge_distributions(samples['aux_batch']['old_dist']) old_logp_batch = old_dist.log_prob(act_batch) - _, _, v, _ = self.policy(obs_batch) + (loss, a_loss, e_loss, ratio), grads = self.grad_fn_policy(obs_batch, act_batch, adv_batch, old_logp_batch) + self.optimizer(grads) + + self.scheduler.step() + lr = self.scheduler.get_last_lr()[0] + + info = { + "actor-loss": a_loss.asnumpy(), + "entropy": e_loss.asnumpy(), + "learning_rate": lr.asnumpy(), + "clip_ratio": ratio.asnumpy(), + } + self.policy_iterations += 1 - if update_type == 0: - loss = self.policy_train(obs_batch, act_batch, ret_batch, adv_batch, old_logp_batch, old_dist.logits, v, update_type) + return info + + def update_critic(self, **samples): + obs_batch = Tensor(samples['obs']) + ret_batch = Tensor(samples['returns']) - lr = self.scheduler(self.iterations).asnumpy() - # self.writer.add_scalar("actor-loss", self.loss_net.loss_a.asnumpy(), self.iterations) - # self.writer.add_scalar("entropy", self.loss_net.loss_e.asnumpy(), self.iterations) - 
info["total-loss"] = loss.asnumpy() - info["learning_rate"] = lr - self.policy_iterations += 1 - - elif update_type == 1: - loss = self.policy_train(obs_batch, act_batch, ret_batch, adv_batch, old_logp_batch, old_dist.logits, v, update_type) + (loss, v_pred), grads = self.grad_fn_critic(obs_batch, ret_batch) + self.optimizer(grads) - info["critic-loss"] = loss.asnumpy() - self.value_iterations += 1 - - elif update_type == 2: - loss = self.policy_train(obs_batch, act_batch, ret_batch, adv_batch, old_logp_batch, old_dist.logits, v, update_type) + info = { + "critic-loss": loss.asnumpy() + } + self.value_iterations += 1 + return info - info["kl-loss"] = loss.asnumpy() + def update_auxiliary(self, **samples): + obs_batch = samples['obs'] + ret_batch = Tensor(samples['returns']) + old_dist = merge_distributions(samples['aux_batch']['old_dist']) + (loss, v), grads = self.grad_fn_auxiliary(obs_batch, ret_batch, old_dist) + self.optimizer(grads) + + info = { + "kl-loss": loss.asnumpy() + } return info + + def update(self, *args): + return \ No newline at end of file diff --git a/xuance/mindspore/policies/categorical.py b/xuance/mindspore/policies/categorical.py index 43e31614..d510f124 100644 --- a/xuance/mindspore/policies/categorical.py +++ b/xuance/mindspore/policies/categorical.py @@ -1,6 +1,4 @@ import mindspore as ms -import mindspore.nn as nn -import numpy as np from copy import deepcopy from gym.spaces import Discrete from xuance.common import Sequence, Optional, Callable, Union @@ -11,12 +9,6 @@ from .core import BasicQhead, CriticNet -def _init_layer(layer, gain=np.sqrt(2), bias=0.0): - nn.init.orthogonal_(layer.weight, gain=gain) - nn.init.constant_(layer.bias, bias) - return layer - - class ActorPolicy(Module): """ Actor for stochastic policy with categorical distributions. 
(Discrete action space) @@ -48,7 +40,7 @@ def __init__(self, def construct(self, observation: Tensor): outputs = self.representation(observation) a = self.actor(outputs['state']) - return outputs, a + return outputs, a, None class ActorCriticPolicy(Module): diff --git a/xuance/mindspore/policies/core.py b/xuance/mindspore/policies/core.py index 06eff5ce..53c3b9d9 100644 --- a/xuance/mindspore/policies/core.py +++ b/xuance/mindspore/policies/core.py @@ -3,10 +3,10 @@ from xuance.common import Sequence, Optional, Callable, Union from xuance.mindspore import Tensor, Module from xuance.mindspore.utils import ModuleType, mlp_block, gru_block, lstm_block -from mindspore.nn.probability.distribution import Categorical, Normal +from xuance.mindspore.utils import CategoricalDistribution, DiagGaussianDistribution, ActivatedDiagGaussianDistribution -class BasicQhead(nn.Cell): +class BasicQhead(Module): def __init__(self, state_dim: int, action_dim: int, @@ -27,7 +27,7 @@ def construct(self, x: ms.tensor): return self.model(x) -class DuelQhead(nn.Cell): +class DuelQhead(Module): def __init__(self, state_dim: int, action_dim: int, @@ -63,7 +63,7 @@ def construct(self, x: ms.tensor): return q -class C51Qhead(nn.Cell): +class C51Qhead(Module): def __init__(self, state_dim: int, action_dim: int, @@ -91,7 +91,7 @@ def construct(self, x: ms.tensor): return dist_probs -class QRDQNhead(nn.Cell): +class QRDQNhead(Module): def __init__(self, state_dim: int, action_dim: int, @@ -116,7 +116,7 @@ def construct(self, x: ms.tensor): return self.model(x).view(-1, self.action_dim, self.atom_num) -class BasicRecurrent(nn.Cell): +class BasicRecurrent(Module): def __init__(self, **kwargs): super(BasicRecurrent, self).__init__() self.lstm = False @@ -149,7 +149,7 @@ def construct(self, x: ms.tensor, h: ms.tensor, c: ms.tensor = None): return hn, self.model(output) -class ActorNet(nn.Cell): +class ActorNet(Module): def __init__(self, state_dim: int, action_dim: int, @@ -169,53 +169,52 @@ def construct(self, x: ms.tensor): return self.model(x) -class CategoricalActorNet(nn.Cell): - class Sample(nn.Cell): - def __init__(self): - super(ActorNet.Sample, self).__init__() - self._dist = Categorical(dtype=ms.float32) - - def construct(self, probs: ms.tensor): - return self._dist.sample(probs=probs).astype("int32") - - class LogProb(nn.Cell): - def __init__(self): - super(ActorNet.LogProb, self).__init__() - self._dist = Categorical(dtype=ms.float32) - - def construct(self, value, probs): - return self._dist._log_prob(value=value, probs=probs) - - class Entropy(nn.Cell): - def __init__(self): - super(ActorNet.Entropy, self).__init__() - self._dist = Categorical(dtype=ms.float32) - - def construct(self, probs): - return self._dist.entropy(probs=probs) +class CategoricalActorNet(Module): + """ + The actor network for categorical policy, which outputs a distribution over all discrete actions. + Args: + state_dim (int): The input state dimension. + action_dim (int): The dimension of continuous action space. + hidden_sizes (Sequence[int]): List of hidden units for fully connect layers. + normalize (Optional[ModuleType]): The layer normalization over a minibatch of inputs. + initialize (Optional[Callable[..., Tensor]]): The parameters initializer. + activation (Optional[ModuleType]): The activation function for each layer. 
+ """ def __init__(self, state_dim: int, action_dim: int, hidden_sizes: Sequence[int], normalize: Optional[ModuleType] = None, initialize: Optional[Callable[..., ms.Tensor]] = None, - activation: Optional[ModuleType] = None - ): - super(ActorNet, self).__init__() + activation: Optional[ModuleType] = None): + super(CategoricalActorNet, self).__init__() layers = [] input_shape = (state_dim,) for h in hidden_sizes: mlp, input_shape = mlp_block(input_shape[0], h, normalize, activation, initialize) layers.extend(mlp) - layers.extend(mlp_block(input_shape[0], action_dim, None, nn.Softmax, None)[0]) + layers.extend(mlp_block(input_shape[0], action_dim, None, None, None)[0]) self.model = nn.SequentialCell(*layers) - self.sample = self.Sample() - self.log_prob = self.LogProb() - self.entropy = self.Entropy() + self.softmax = nn.Softmax() + self.dist = CategoricalDistribution(action_dim) - def construct(self, x: ms.Tensor): - return self.model(x) + def construct(self, x: Tensor, avail_actions: Optional[Tensor] = None): + """ + Returns the stochastic distribution over all discrete actions. + Parameters: + x (Tensor): The input tensor. + avail_actions (Optional[Tensor]): The actions mask values when use actions mask, default is None. + + Returns: + self.dist: CategoricalDistribution(action_dim), a distribution over all discrete actions. + """ + logits = self.model(x) + if avail_actions is not None: + logits[avail_actions == 0] = -1e10 + probs = self.softmax(logits) + self.dist.set_param(probs=probs) + return self.dist class CategoricalActorNet_SAC(CategoricalActorNet): @@ -260,66 +259,53 @@ def forward(self, x: Tensor, avail_actions: Optional[Tensor] = None): return self.dist -class GaussianActorNet(nn.Cell): - class Sample(nn.Cell): - def __init__(self, log_std): - super(ActorNet.Sample, self).__init__() - self._dist = Normal(dtype=ms.float32) - self.logstd = log_std - self._exp = ms.ops.Exp() - - def construct(self, mean: ms.tensor): - return self._dist.sample(mean=mean, sd=self._exp(self.logstd)) - - class LogProb(nn.Cell): - def __init__(self, log_std): - super(ActorNet.LogProb, self).__init__() - self._dist = Normal(dtype=ms.float32) - self.logstd = log_std - self._exp = ms.ops.Exp() - self._sum = ms.ops.ReduceSum(keep_dims=False) - - def construct(self, value: ms.tensor, probs: ms.tensor): - return self._sum(self._dist.log_prob(value, probs, self._exp(self.logstd)), -1) - - class Entropy(nn.Cell): - def __init__(self, log_std): - super(ActorNet.Entropy, self).__init__() - self._dist = Normal(dtype=ms.float32) - self.logstd = log_std - self._exp = ms.ops.Exp() - self._sum = ms.ops.ReduceSum(keep_dims=False) - - def construct(self, probs: ms.tensor): - return self._sum(self._dist.entropy(probs, self._exp(self.logstd)), -1) +class GaussianActorNet(Module): + """ + The actor network for Gaussian policy, which outputs a distribution over the continuous action space. + Args: + state_dim (int): The input state dimension. + action_dim (int): The dimension of continuous action space. + hidden_sizes (Sequence[int]): List of hidden units for fully connect layers. + normalize (Optional[ModuleType]): The layer normalization over a minibatch of inputs. + initialize (Optional[Callable[..., Tensor]]): The parameters initializer. + activation (Optional[ModuleType]): The activation function for each layer. + activation_action (Optional[ModuleType]): The activation of final layer to bound the actions. 
+ """ def __init__(self, state_dim: int, action_dim: int, hidden_sizes: Sequence[int], normalize: Optional[ModuleType] = None, initialize: Optional[Callable[..., ms.Tensor]] = None, - activation: Optional[ModuleType] = None): - super(ActorNet, self).__init__() + activation: Optional[ModuleType] = None, + activation_action: Optional[ModuleType] = None): + super(GaussianActorNet, self).__init__() layers = [] input_shape = (state_dim,) for h in hidden_sizes: mlp, input_shape = mlp_block(input_shape[0], h, normalize, activation, initialize) layers.extend(mlp) - layers.extend(mlp_block(input_shape[0], action_dim, None, None, initialize)[0]) + layers.extend(mlp_block(input_shape[0], action_dim, None, activation_action, initialize)[0]) self.mu = nn.SequentialCell(*layers) self._ones = ms.ops.Ones() self.logstd = ms.Parameter(-self._ones((action_dim,), ms.float32)) - # define the distribution methods - self.sample = self.Sample(self.logstd) - self.log_prob = self.LogProb(self.logstd) - self.entropy = self.Entropy(self.logstd) + self.dist = DiagGaussianDistribution(action_dim) def construct(self, x: ms.Tensor): - return self.mu(x) + """ + Returns the stochastic distribution over the continuous action space. + Parameters: + x (Tensor): The input tensor. + + Returns: + self.dist: A distribution over the continuous action space. + """ + self.dist.set_param(self.mu(x), self.logstd.exp()) + return self.dist -class CriticNet(nn.Cell): +class CriticNet(Module): def __init__(self, state_dim: int, hidden_sizes: Sequence[int], @@ -369,7 +355,7 @@ def construct(self, x: Tensor): return mu, std -class VDN_mixer(nn.Cell): +class VDN_mixer(Module): def __init__(self): super(VDN_mixer, self).__init__() self._sum = ms.ops.ReduceSum(keep_dims=False) @@ -378,7 +364,7 @@ def construct(self, values_n, states=None): return self._sum(values_n, 1) -class QMIX_mixer(nn.Cell): +class QMIX_mixer(Module): def __init__(self, dim_state, dim_hidden, dim_hypernet_hidden, n_agents): super(QMIX_mixer, self).__init__() self.dim_state = dim_state @@ -422,7 +408,7 @@ def construct(self, values_n, states): return q_tot -class QMIX_FF_mixer(nn.Cell): +class QMIX_FF_mixer(Module): def __init__(self, dim_state, dim_hidden, n_agents): super(QMIX_FF_mixer, self).__init__() self.dim_state = dim_state @@ -452,7 +438,7 @@ def construct(self, values_n, states): return q_tot -class QTRAN_base(nn.Cell): +class QTRAN_base(Module): def __init__(self, dim_state, dim_action, dim_hidden, n_agents, dim_utility_hidden): super(QTRAN_base, self).__init__() self.dim_state = dim_state diff --git a/xuance/mindspore/policies/gaussian.py b/xuance/mindspore/policies/gaussian.py index 44134f32..907aa939 100644 --- a/xuance/mindspore/policies/gaussian.py +++ b/xuance/mindspore/policies/gaussian.py @@ -1,30 +1,40 @@ -import mindspore as ms -import mindspore.nn as nn -import numpy as np -from xuance.common import Sequence, Optional, Callable, Union from copy import deepcopy from gym.spaces import Box -from xuance.torch import Module, Tensor -from xuance.torch.utils import ModuleType +from xuance.common import Sequence, Optional, Callable, Union +from xuance.mindspore import Module, Tensor +from xuance.mindspore.utils import ModuleType from .core import GaussianActorNet as ActorNet from .core import CriticNet, GaussianActorNet_SAC class ActorPolicy(Module): + """ + Actor for stochastic policy with Gaussian distributions. (Continuous action space) + + Args: + action_space (Box): The continuous action space. + representation (Module): The representation module. 
+ actor_hidden_size (Sequence[int]): A list of hidden layer sizes for actor network. + normalize (Optional[ModuleType]): The layer normalization over a minibatch of inputs. + initialize (Optional[Callable[..., Tensor]]): The parameters initializer. + activation (Optional[ModuleType]): The activation function for each layer. + activation_action (Optional[ModuleType]): The activation of final layer to bound the actions. + """ def __init__(self, action_space: Box, - representation: ModuleType, + representation: Module, actor_hidden_size: Sequence[int] = None, normalize: Optional[ModuleType] = None, initialize: Optional[Callable[..., Tensor]] = None, - activation: Optional[ModuleType] = None): - assert isinstance(action_space, Box) + activation: Optional[ModuleType] = None, + activation_action: Optional[ModuleType] = None, + fixed_std: bool = True): super(ActorPolicy, self).__init__() self.action_dim = action_space.shape[0] self.representation = representation self.representation_info_shape = self.representation.output_shapes self.actor = ActorNet(representation.output_shapes['state'][0], self.action_dim, actor_hidden_size, - normalize, initialize, activation) + normalize, initialize, activation, activation_action) def construct(self, observation: Tensor): outputs = self.representation(observation) @@ -33,26 +43,49 @@ def construct(self, observation: Tensor): class ActorCriticPolicy(Module): + """ + Actor-Critic for stochastic policy with Gaussian distributions. (Continuous action space) + + Args: + action_space (Box): The continuous action space. + representation (Module): The representation module. + actor_hidden_size (Sequence[int]): A list of hidden layer sizes for actor network. + critic_hidden_size (Sequence[int]): A list of hidden layer sizes for critic network. + normalize (Optional[ModuleType]): The layer normalization over a minibatch of inputs. + initialize (Optional[Callable[..., Tensor]]): The parameters initializer. + activation (Optional[ModuleType]): The activation function for each layer. + activation_action (Optional[ModuleType]): The activation of final layer to bound the actions. + """ def __init__(self, action_space: Box, - representation: ModuleType, + representation: Module, actor_hidden_size: Sequence[int] = None, critic_hidden_size: Sequence[int] = None, normalize: Optional[ModuleType] = None, initialize: Optional[Callable[..., Tensor]] = None, - activation: Optional[ModuleType] = None - ): - assert isinstance(action_space, Box) + activation: Optional[ModuleType] = None, + activation_action: Optional[ModuleType] = None): super(ActorCriticPolicy, self).__init__() self.action_dim = action_space.shape[0] self.representation = representation self.representation_info_shape = self.representation.output_shapes self.actor = ActorNet(representation.output_shapes['state'][0], self.action_dim, actor_hidden_size, - normalize, initialize, activation) + normalize, initialize, activation, activation_action) self.critic = CriticNet(representation.output_shapes['state'][0], critic_hidden_size, normalize, initialize, activation) def construct(self, observation: Tensor): + """ + Returns the hidden states, action distribution, and values. + + Parameters: + observation: The original observation of agent. + + Returns: + outputs: The outputs of representation. + a_dist: The distribution of actions output by actor. + value: The state values output by critic. 
+ """ outputs = self.representation(observation) a = self.actor(outputs['state']) v = self.critic(outputs['state']) @@ -62,7 +95,7 @@ def construct(self, observation: Tensor): class SACPolicy(Module): def __init__(self, action_space: Box, - representation: ModuleType, + representation: Module, actor_hidden_size: Sequence[int], critic_hidden_size: Sequence[int], initialize: Optional[Callable[..., Tensor]] = None, diff --git a/xuance/mindspore/utils/__init__.py b/xuance/mindspore/utils/__init__.py index 631e3b61..4cdc7550 100644 --- a/xuance/mindspore/utils/__init__.py +++ b/xuance/mindspore/utils/__init__.py @@ -6,8 +6,8 @@ from .distributions import ( Distribution, CategoricalDistribution, - # DiagGaussianDistribution, - # ActivatedDiagGaussianDistribution + DiagGaussianDistribution, + ActivatedDiagGaussianDistribution ) from .operations import (update_linear_decay, set_seed, get_flat_grad, get_flat_params, assign_from_flat_grads, assign_from_flat_params, split_distributions, merge_distributions) diff --git a/xuance/mindspore/utils/operations.py b/xuance/mindspore/utils/operations.py index bef8e900..7d3d8cb0 100644 --- a/xuance/mindspore/utils/operations.py +++ b/xuance/mindspore/utils/operations.py @@ -52,11 +52,11 @@ def split_distributions(distribution): _unsqueeze = ExpandDims() return_list = [] if isinstance(distribution, CategoricalDistribution): - shape = distribution.logits.shape - logits = distribution.logits.view(-1,shape[-1]) - for logit in logits: - dist = CategoricalDistribution(logits.shape[-1]) - dist.set_param(_unsqueeze(logit, 0)) + shape = distribution.probs.shape + probs = distribution.probs.view(-1,shape[-1]) + for prob in probs: + dist = CategoricalDistribution(probs.shape[-1]) + dist.set_param(_unsqueeze(prob, 0)) return_list.append(dist) else: raise NotImplementedError @@ -65,10 +65,10 @@ def split_distributions(distribution): def merge_distributions(distribution_list): if isinstance(distribution_list[0], CategoricalDistribution): - logits = ms.ops.concat([dist.logits for dist in distribution_list], 0) - action_dim = logits.shape[-1] + probs = ms.ops.concat([dist.probs for dist in distribution_list], 0) + action_dim = probs.shape[-1] dist = CategoricalDistribution(action_dim) - dist.set_param(logits) + dist.set_param(probs) return dist else: - raise NotImplementedError \ No newline at end of file + raise NotImplementedError diff --git a/xuance/torch/learners/policy_gradient/ppg_learner.py b/xuance/torch/learners/policy_gradient/ppg_learner.py index 75689c0a..6ac539b1 100644 --- a/xuance/torch/learners/policy_gradient/ppg_learner.py +++ b/xuance/torch/learners/policy_gradient/ppg_learner.py @@ -99,5 +99,5 @@ def update_auxiliary(self, **samples): } return info - def update(self): - pass + def update(self, *args): + return
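Note: every rewritten learner in this patch follows the same MindSpore functional training-step pattern: build the loss in a forward_fn, wrap it with ms.value_and_grad(..., has_aux=True) over optimizer.parameters, apply the returned gradients with optimizer(grads), and step a LinearLR scheduler. The snippet below is a minimal, self-contained sketch of that pattern for reference only; ToyPolicy, forward_fn, the toy tensors, and the chosen learning-rate/total_iters values are illustrative placeholders and are not part of XuanCe.

import numpy as np
import mindspore as ms
from mindspore import nn, Tensor
from mindspore.experimental import optim


class ToyPolicy(nn.Cell):
    """A stand-in for an actor-critic policy; only used to illustrate the training step."""
    def __init__(self, obs_dim=4, act_dim=2):
        super().__init__()
        self.actor = nn.Dense(obs_dim, act_dim)

    def construct(self, x):
        return self.actor(x)


policy = ToyPolicy()
optimizer = optim.Adam(params=policy.trainable_params(), lr=1e-3, eps=1e-5)
scheduler = optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.5, total_iters=100)
loss_fn = nn.MSELoss()


def forward_fn(obs, target):
    # Compute the loss; extra outputs are passed through untouched when has_aux=True.
    logits = policy(obs)
    loss = loss_fn(logits, target)
    return loss, logits


# Differentiate the first output of forward_fn w.r.t. the optimizer's parameters.
grad_fn = ms.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)

obs = Tensor(np.random.randn(8, 4).astype(np.float32))
target = Tensor(np.zeros((8, 2), dtype=np.float32))
(loss, logits), grads = grad_fn(obs, target)
optimizer(grads)   # apply the gradients (functional-style optimizer call)
scheduler.step()   # advance the linear learning-rate decay
print(loss.asnumpy(), scheduler.get_last_lr()[0].asnumpy())

This mirrors how the new A2C/PG/PPG update methods log the learning rate via scheduler.get_last_lr()[0] and report auxiliary loss terms returned alongside the main loss.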