mindspore: policy gradient
wenzhangliu committed Aug 30, 2024
Commit 5679b6b (1 parent: 5d29d92)
Showing 10 changed files with 286 additions and 265 deletions.
xuance/mindspore/agents/core/on_policy.py (6 changes: 3 additions & 3 deletions)
@@ -76,13 +76,13 @@ def action(self, observations: np.ndarray,
"""
_, policy_dists, values = self.policy(observations)
actions = policy_dists.stochastic_sample()
log_pi = policy_dists.log_prob(actions).detach().cpu().numpy() if return_logpi else None
log_pi = policy_dists.log_prob(actions).asnumpy() if return_logpi else None
dists = split_distributions(policy_dists) if return_dists else None
actions = actions.detach().cpu().numpy()
actions = actions.asnumpy()
if values is None:
values = 0
else:
values = values.detach().cpu().numpy()
values = values.asnumpy()
return {"actions": actions, "values": values, "dists": dists, "log_pi": log_pi}

def get_aux_info(self, policy_output: dict = None):
xuance/mindspore/learners/policy_gradient/a2c_learner.py (74 changes: 36 additions & 38 deletions)
@@ -5,57 +5,55 @@
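Note on the change above: MindSpore tensors convert to NumPy with Tensor.asnumpy(); there is no autograd handle or device copy to detach, so the PyTorch-style .detach().cpu().numpy() chain collapses to a single call. A minimal standalone sketch, not part of the repository:

import numpy as np
import mindspore as ms

# Toy tensor standing in for the policy outputs handled in on_policy.py.
values = ms.Tensor(np.array([0.25, 0.75], dtype=np.float32))

# MindSpore: one call returns a NumPy array holding the tensor's values.
values_np = values.asnumpy()
print(type(values_np), values_np)   # <class 'numpy.ndarray'> [0.25 0.75]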
from xuance.mindspore import ms, Module, Tensor, optim
from xuance.mindspore.learners import Learner
from argparse import Namespace
from mindspore.nn import MSELoss


class A2C_Learner(Learner):
class ACNetWithLossCell(Module):
def __init__(self, backbone, ent_coef, vf_coef):
super(A2C_Learner.ACNetWithLossCell, self).__init__()
self._backbone = backbone
self._mean = ms.ops.ReduceMean(keep_dims=True)
self._loss_c = nn.MSELoss()
self._ent_coef = ent_coef
self._vf_coef = vf_coef

def construct(self, x, a, adv, r):
_, act_probs, v_pred = self._backbone(x)
log_prob = self._backbone.actor.log_prob(value=a, probs=act_probs)
loss_a = -self._mean(adv * log_prob)
loss_c = self._loss_c(logits=v_pred, labels=r)
loss_e = self._mean(self._backbone.actor.entropy(probs=act_probs))
loss = loss_a - self._ent_coef * loss_e + self._vf_coef * loss_c

return loss

def __init__(self,
config: Namespace,
policy: Module):
super(A2C_Learner, self).__init__(config, policy)
self.vf_coef = vf_coef
self.ent_coef = ent_coef
self.clip_grad = clip_grad
# define mindspore trainer
self.loss_net = self.ACNetWithLossCell(policy, self.ent_coef, self.vf_coef)
# self.policy_train = nn.TrainOneStepCell(self.loss_net, optimizer)
self.policy_train = TrainOneStepCellWithGradClip(self.loss_net, optimizer,
clip_type=clip_type, clip_value=clip_grad)
self.policy_train.set_train()

def update(self, obs_batch, act_batch, ret_batch, adv_batch):
self.optimizer = optim.Adam(params=self.policy.trainable_params(), lr=self.config.learning_rate, eps=1e-5)
self.scheduler = optim.lr_scheduler.LinearLR(self.optimizer, start_factor=1.0, end_factor=0.5,
total_iters=self.config.running_steps)
self.mse_loss = MSELoss()
self._mean = ms.ops.ReduceMean(keep_dims=True)
self.vf_coef = config.vf_coef
self.ent_coef = config.ent_coef
# Get gradient function
self.grad_fn = ms.value_and_grad(self.forward_fn, None, self.optimizer.parameters, has_aux=True)
self.policy.set_train()

def forward_fn(self, x, a, adv, r):
_, a_dist, v_pred = self.policy(x)
log_prob = a_dist.log_prob(a)
loss_a = -self._mean(adv * log_prob)
loss_c = self.mse_loss(logits=v_pred, labels=r)
loss_e = self._mean(a_dist.entropy())
loss = loss_a - self.ent_coef * loss_e + self.vf_coef * loss_c

return loss, loss_a, loss_e, loss_c, v_pred

def update(self, **samples):
self.iterations += 1
obs_batch = Tensor(obs_batch)
act_batch = Tensor(act_batch)
ret_batch = Tensor(ret_batch)
adv_batch = Tensor(adv_batch)
obs_batch = Tensor(samples['obs'])
act_batch = Tensor(samples['actions'])
ret_batch = Tensor(samples['returns'])
adv_batch = Tensor(samples['advantages'])

loss = self.policy_train(obs_batch, act_batch, adv_batch, ret_batch)
(loss, loss_a, loss_e, loss_c, v_pred), grads = self.grad_fn(obs_batch, act_batch, adv_batch, ret_batch)
self.optimizer(grads)

# Logger
lr = self.scheduler(self.iterations).asnumpy()
self.scheduler.step()
lr = self.scheduler.get_last_lr()[0]

info = {
"total-loss": loss.asnumpy(),
"learning_rate": lr
"actor-loss": loss_a.asnumpy(),
"critic-loss": loss_c.asnumpy(),
"entropy": loss_e.asnumpy(),
"learning_rate": lr.asnumpy(),
"predict_value": v_pred.mean().asnumpy(),
}

return info
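Note: the rewritten learner drops the ACNetWithLossCell / TrainOneStepCellWithGradClip pairing in favor of MindSpore's functional ms.value_and_grad API with the experimental optimizer, the same pattern visible in the diff above. A minimal, self-contained sketch of that pattern follows; the toy network, loss, and hyperparameters are placeholders (not XuanCe code), and it assumes a MindSpore version that ships mindspore.experimental.optim:

import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore.experimental import optim

net = nn.Dense(4, 2)                                        # toy "policy" head (placeholder)
optimizer = optim.Adam(params=net.trainable_params(), lr=1e-3)

def forward_fn(x, target):
    # Return (loss, aux) so has_aux=True exposes the extra outputs for logging.
    pred = net(x)
    loss = nn.MSELoss()(pred, target)
    return loss, pred

# Differentiate forward_fn with respect to the optimizer's parameters only.
grad_fn = ms.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)

x = ms.Tensor(np.random.randn(8, 4).astype(np.float32))
y = ms.Tensor(np.random.randn(8, 2).astype(np.float32))
(loss, pred), grads = grad_fn(x, y)
optimizer(grads)                                            # apply the gradients in place
print(loss.asnumpy())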
xuance/mindspore/learners/policy_gradient/pg_learner.py (58 changes: 28 additions & 30 deletions)
@@ -9,46 +9,44 @@


class PG_Learner(Learner):
class PolicyNetWithLossCell(Module):
def __init__(self, backbone, ent_coef):
super(PG_Learner.PolicyNetWithLossCell, self).__init__(auto_prefix=False)
self._backbone = backbone
self._ent_coef = ent_coef
self._mean = ms.ops.ReduceMean(keep_dims=True)

def construct(self, x, a, r):
_, act_probs = self._backbone(x)
log_prob = self._backbone.actor.log_prob(value=a, probs=act_probs)
loss_a = -self._mean(r * log_prob)
loss_e = self._mean(self._backbone.actor.entropy(probs=act_probs))
loss = loss_a - self._ent_coef * loss_e
return loss

def __init__(self,
config: Namespace,
policy: Module):
super(PG_Learner, self).__init__(config, policy)
self.optimizer = ms.nn.Adam(params=policy.trainable_params(), learning_rate=self.config.learning_rate, eps=1e-5)
# define mindspore trainer
self.loss_net = self.PolicyNetWithLossCell(policy, self.ent_coef)
# self.policy_train = nn.TrainOneStepCell(self.loss_net, optimizer)
self.policy_train = TrainOneStepCellWithGradClip(self.loss_net, optimizer,
clip_type=clip_type, clip_value=clip_grad)
self.policy_train.set_train()

def update(self, obs_batch, act_batch, ret_batch):
self.optimizer = optim.Adam(params=self.policy.trainable_params(), lr=self.config.learning_rate, eps=1e-5)
self.scheduler = optim.lr_scheduler.LinearLR(self.optimizer, start_factor=1.0, end_factor=0.5,
total_iters=self.config.running_steps)
self.ent_coef = config.ent_coef
self._mean = ms.ops.ReduceMean(keep_dims=True)
# Get gradient function
self.grad_fn = ms.value_and_grad(self.forward_fn, None, self.optimizer.parameters, has_aux=True)
self.policy.set_train()

def forward_fn(self, x, a, r):
_, a_dist, _ = self.policy(x)
log_prob = a_dist.log_prob(a)
loss_a = -self._mean(r * log_prob)
loss_e = self._mean(a_dist.entropy())
loss = loss_a - self.ent_coef * loss_e
return loss, loss_a, loss_e

def update(self, **samples):
self.iterations += 1
obs_batch = Tensor(obs_batch)
act_batch = Tensor(act_batch)
ret_batch = Tensor(ret_batch)
obs_batch = Tensor(samples['obs'])
act_batch = Tensor(samples['actions'])
ret_batch = Tensor(samples['returns'])

loss = self.policy_train(obs_batch, act_batch, ret_batch)
(loss, loss_a, loss_e), grads = self.grad_fn(obs_batch, act_batch, ret_batch)
self.optimizer(grads)

lr = self.scheduler(self.iterations).asnumpy()
self.scheduler.step()
lr = self.scheduler.get_last_lr()[0]

info = {
"total-loss": loss.asnumpy(),
"learning_rate": lr
"actor-loss": loss_a.asnumpy(),
"entropy": loss_e.asnumpy(),
"learning_rate": lr.asnumpy(),
}

return info
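Note: both the PG and A2C learners also swap the old self.scheduler(self.iterations) call for the LinearLR scheduler from mindspore.experimental.optim.lr_scheduler, stepping it once per update and reading the current rate back with get_last_lr() for logging. A small standalone sketch of that pattern; the toy module and values are assumptions, not XuanCe code:

import mindspore.nn as nn
from mindspore.experimental import optim

toy = nn.Dense(2, 2)                          # toy module so the optimizer owns parameters
optimizer = optim.Adam(params=toy.trainable_params(), lr=7e-4)

# Decay the learning rate linearly from 1.0x to 0.5x of the base value over
# total_iters steps, mirroring the scheduler configured in the learners above.
scheduler = optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.5,
                                        total_iters=10)

for iteration in range(3):
    # ... compute grads with a grad_fn and call optimizer(grads) here ...
    scheduler.step()
    lr = scheduler.get_last_lr()[0]           # one entry per parameter group
    print(iteration, lr.asnumpy())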
xuance/mindspore/learners/policy_gradient/ppg_learner.py (156 changes: 85 additions & 71 deletions)
@@ -7,93 +7,107 @@
from xuance.mindspore.learners import Learner
from argparse import Namespace
from xuance.mindspore.utils.operations import merge_distributions
from mindspore.nn.probability.distribution import Categorical
from mindspore.nn import MSELoss


class PPG_Learner(Learner):
class PolicyNetWithLossCell(Module):
def __init__(self, backbone, ent_coef, kl_beta, clip_range, loss_fn):
super(PPG_Learner.PolicyNetWithLossCell, self).__init__(auto_prefix=False)
self._backbone = backbone
self._ent_coef = ent_coef
self._kl_beta = kl_beta
self._clip_range = clip_range
self._loss_fn = loss_fn
self._mean = ms.ops.ReduceMean(keep_dims=True)
self._minimum = ms.ops.Minimum()
self._exp = ms.ops.Exp()
self._categorical = Categorical()

def construct(self, x, a, r, adv, old_log, old_dist_logits, v, update_type):
loss = 0
if update_type == 0:
_, a_dist, _, _ = self._backbone(x)
log_prob = self._categorical.log_prob(a, a_dist)
# ppo-clip core implementations
ratio = self._exp(log_prob - old_log)
surrogate1 = ms.ops.clip_by_value(ratio, 1.0 - self._clip_range, 1.0 + self._clip_range) * adv
surrogate2 = adv * ratio
a_loss = -self._minimum(surrogate1, surrogate2).mean()
entropy = self._categorical.entropy(a_dist)
e_loss = entropy.mean()
loss = a_loss - self._ent_coef * e_loss
elif update_type == 1:
_,_,v_pred,_ = self._backbone(x)
loss = self._loss_fn(v_pred, r)
elif update_type == 2:
_, a_dist, _, aux_v = self._backbone(x)
aux_loss = self._loss_fn(v, aux_v)
kl_loss = self._categorical.kl_loss('Categorical',a_dist, old_dist_logits).mean()
value_loss = self._loss_fn(v,r)
loss = aux_loss + self._kl_beta * kl_loss + value_loss
return loss

def __init__(self,
config: Namespace,
policy: Module):
super(PPG_Learner, self).__init__(config, policy)
self.ent_coef = ent_coef
self.clip_range = clip_range
self.kl_beta = kl_beta
self.optimizer = optim.Adam(params=self.policy.trainable_params(), lr=self.config.learning_rate, eps=1e-5)
self.scheduler = optim.lr_scheduler.LinearLR(self.optimizer, start_factor=1.0, end_factor=0.5,
total_iters=self.config.running_steps)
self.mse_loss = MSELoss()
self.ent_coef = config.ent_coef
self.clip_range = config.clip_range
self.kl_beta = config.kl_beta
self.policy_iterations = 0
self.value_iterations = 0
loss_fn = nn.MSELoss()
# define mindspore trainer
self.loss_net = self.PolicyNetWithLossCell(policy, self.ent_coef, self.kl_beta, self.clip_range, loss_fn)
self.policy_train = nn.TrainOneStepCell(self.loss_net, optimizer)
self.policy_train.set_train()
# Get gradient function
self._mean = ms.ops.ReduceMean(keep_dims=True)
self._minimum = ms.ops.Minimum()
self._exp = ms.ops.Exp()
self.grad_fn_policy = ms.value_and_grad(self.forward_fn_policy, None, self.optimizer.parameters, has_aux=True)
self.grad_fn_critic = ms.value_and_grad(self.forward_fn_critic, None, self.optimizer.parameters, has_aux=True)
self.grad_fn_auxiliary = ms.value_and_grad(self.forward_fn_auxiliary, None, self.optimizer.parameters,
has_aux=True)
self.policy.set_train()

def update(self, obs_batch, act_batch, ret_batch, adv_batch, old_dists, update_type):
def forward_fn_policy(self, obs_batch, act_batch, adv_batch, old_logp_batch):
_, a_dist, _, _ = self.policy(obs_batch)
log_prob = a_dist.log_prob(act_batch)
# ppo-clip core implementations
ratio = self._exp(log_prob - old_logp_batch)
surrogate1 = ms.ops.clip_by_value(ratio, 1.0 - self.clip_range, 1.0 + self.clip_range) * adv_batch
surrogate2 = adv_batch * ratio
a_loss = -self._minimum(surrogate1, surrogate2).mean()
e_loss = a_dist.entropy().mean()
loss = a_loss - self.ent_coef * e_loss
return loss, a_loss, e_loss, ratio

def forward_fn_critic(self, obs_batch, ret_batch):
_, _, v_pred, _ = self.policy(obs_batch)
loss = self.mse_loss(v_pred, ret_batch)
return loss, v_pred

def forward_fn_auxiliary(self, obs_batch, ret_batch, old_dist):
_, a_dist, v, aux_v = self.policy(obs_batch)
aux_loss = self.mse_loss(v, aux_v)
kl_loss = self._categorical.kl_loss('Categorical', a_dist, old_dist).mean()
value_loss = self.mse_loss(v, ret_batch)
loss = aux_loss + self.kl_beta * kl_loss + value_loss
return loss, v

def update_policy(self, **samples):
self.iterations += 1
info = {}
obs_batch = Tensor(obs_batch)
act_batch = Tensor(act_batch)
ret_batch = Tensor(ret_batch)
adv_batch = Tensor(adv_batch)
old_dist = merge_distributions(old_dists)
obs_batch = Tensor(samples['obs'])
act_batch = Tensor(samples['actions'])
adv_batch = Tensor(samples['advantages'])
old_dist = merge_distributions(samples['aux_batch']['old_dist'])
old_logp_batch = old_dist.log_prob(act_batch)

_, _, v, _ = self.policy(obs_batch)
(loss, a_loss, e_loss, ratio), grads = self.grad_fn_policy(obs_batch, act_batch, adv_batch, old_logp_batch)
self.optimizer(grads)

self.scheduler.step()
lr = self.scheduler.get_last_lr()[0]

info = {
"actor-loss": a_loss.asnumpy(),
"entropy": e_loss.asnumpy(),
"learning_rate": lr.asnumpy(),
"clip_ratio": ratio.asnumpy(),
}
self.policy_iterations += 1

if update_type == 0:
loss = self.policy_train(obs_batch, act_batch, ret_batch, adv_batch, old_logp_batch, old_dist.logits, v, update_type)
return info

def update_critic(self, **samples):
obs_batch = Tensor(samples['obs'])
ret_batch = Tensor(samples['returns'])

lr = self.scheduler(self.iterations).asnumpy()
# self.writer.add_scalar("actor-loss", self.loss_net.loss_a.asnumpy(), self.iterations)
# self.writer.add_scalar("entropy", self.loss_net.loss_e.asnumpy(), self.iterations)
info["total-loss"] = loss.asnumpy()
info["learning_rate"] = lr
self.policy_iterations += 1

elif update_type == 1:
loss = self.policy_train(obs_batch, act_batch, ret_batch, adv_batch, old_logp_batch, old_dist.logits, v, update_type)
(loss, v_pred), grads = self.grad_fn_critic(obs_batch, ret_batch)
self.optimizer(grads)

info["critic-loss"] = loss.asnumpy()
self.value_iterations += 1

elif update_type == 2:
loss = self.policy_train(obs_batch, act_batch, ret_batch, adv_batch, old_logp_batch, old_dist.logits, v, update_type)
info = {
"critic-loss": loss.asnumpy()
}
self.value_iterations += 1
return info

info["kl-loss"] = loss.asnumpy()
def update_auxiliary(self, **samples):
obs_batch = samples['obs']
ret_batch = Tensor(samples['returns'])
old_dist = merge_distributions(samples['aux_batch']['old_dist'])

(loss, v), grads = self.grad_fn_auxiliary(obs_batch, ret_batch, old_dist)
self.optimizer(grads)

info = {
"kl-loss": loss.asnumpy()
}
return info

def update(self, *args):
return
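Note: splitting the old update(..., update_type) switch into update_policy, update_critic, and update_auxiliary (with update reduced to a no-op) means the caller now drives the three PPG phases explicitly. A schematic driver is sketched below; learner is a constructed PPG_Learner, sample() is a hypothetical helper returning a dict with the keys read above ('obs', 'actions', 'returns', 'advantages', and 'aux_batch' containing 'old_dist'), and the phase counts are illustrative only:

def ppg_train_epoch(learner, sample, n_policy=4, n_value=4, n_aux=6):
    """Run one PPG epoch: policy phase, value phase, then the auxiliary (KL) phase."""
    logs = []
    for _ in range(n_policy):                 # phase 1: clipped policy-gradient updates
        logs.append(learner.update_policy(**sample()))
    for _ in range(n_value):                  # phase 2: value-function regression
        logs.append(learner.update_critic(**sample()))
    for _ in range(n_aux):                    # phase 3: auxiliary distillation with KL penalty
        logs.append(learner.update_auxiliary(**sample()))
    return logs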
xuance/mindspore/policies/categorical.py (10 changes: 1 addition & 9 deletions)
@@ -1,6 +1,4 @@
import mindspore as ms
import mindspore.nn as nn
import numpy as np
from copy import deepcopy
from gym.spaces import Discrete
from xuance.common import Sequence, Optional, Callable, Union
@@ -11,12 +9,6 @@
from .core import BasicQhead, CriticNet


def _init_layer(layer, gain=np.sqrt(2), bias=0.0):
nn.init.orthogonal_(layer.weight, gain=gain)
nn.init.constant_(layer.bias, bias)
return layer


class ActorPolicy(Module):
"""
Actor for stochastic policy with categorical distributions. (Discrete action space)
@@ -48,7 +40,7 @@ def __init__(self,
def construct(self, observation: Tensor):
outputs = self.representation(observation)
a = self.actor(outputs['state'])
return outputs, a
return outputs, a, None


class ActorCriticPolicy(Module):
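Note on the change above: returning (outputs, a, None) lets the actor-only policy be unpacked the same way as actor-critic policies in on_policy.py (_, policy_dists, values = self.policy(observations)), with the None falling through to the values = 0 fallback. A toy illustration, not XuanCe's classes:

def actor_only_policy(observation):
    # Stand-in for ActorPolicy.construct: no critic, so the value slot is None.
    outputs = {"state": observation}
    dist = "categorical-distribution-placeholder"
    return outputs, dist, None

_, policy_dists, values = actor_only_policy([0.0, 1.0])
values = 0 if values is None else values      # same fallback used by the agent above
print(policy_dists, values)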
(The remaining 5 changed files are not shown.)
