Skip to content

Commit

Permalink
mindspore
Browse files Browse the repository at this point in the history
  • Loading branch information
wenzhangliu committed Aug 28, 2024
1 parent 9c11535 commit 0af39b5
Show file tree
Hide file tree
Showing 40 changed files with 52 additions and 46 deletions.
4 changes: 2 additions & 2 deletions xuance/common/common_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,8 @@ def get_runner(method,
device = "CPU"
context.set_context(device_target=device)
# context.set_context(enable_graph_kernel=True)
context.set_context(mode=context.GRAPH_MODE) # Graph mode (静态图, 断点无法进入)
# context.set_context(mode=context.PYNATIVE_MODE) # Pynative mode (动态图, 便于调试)
# context.set_context(mode=context.GRAPH_MODE) # Graph mode (静态图, 断点无法进入)
context.set_context(mode=context.PYNATIVE_MODE) # Pynative mode (动态图, 便于调试)
elif dl_toolbox == "tensorflow":
from xuance.tensorflow.runners import REGISTRY_Runner
print("Deep learning toolbox: TensorFlow.")
Expand Down
5 changes: 4 additions & 1 deletion xuance/mindspore/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import mindspore as ms
from mindspore import Tensor
from mindspore import Tensor, ops
from mindspore.nn import Cell as Module
from mindspore.nn import CellDict as ModuleDict
from mindspore.experimental import optim
from xuance.mindspore.representations import REGISTRY_Representation
from xuance.mindspore.policies import REGISTRY_Policy
from xuance.mindspore.learners import REGISTRY_Learners
Expand All @@ -12,5 +13,7 @@
"Tensor",
"Module",
"ModuleDict",
"ops",
"optim",
"REGISTRY_Representation", "REGISTRY_Policy", "REGISTRY_Learners", "REGISTRY_Agents"
]
4 changes: 2 additions & 2 deletions xuance/mindspore/learners/learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from xuance.common import Optional, List, Union
from argparse import Namespace
from operator import itemgetter
from xuance.mindspore import Tensor, Module
from xuance.mindspore import Tensor, Module, optim


class Learner(ABC):
Expand All @@ -20,7 +20,7 @@ def __init__(self,
self.use_actions_mask = config.use_actions_mask if hasattr(config, 'use_actions_mask') else False
self.policy = policy
self.optimizer: Union[dict, list, Optional[ms.nn.Optimizer]] = None
self.scheduler: Union[dict, list, Optional[ms.experimental.optim.lr_scheduler.LRScheduler]] = None
self.scheduler: Union[dict, list, Optional[optim.lr_scheduler.LRScheduler]] = None

self.use_grad_clip = config.use_grad_clip
self.grad_clip_norm = config.grad_clip_norm
Expand Down
2 changes: 1 addition & 1 deletion xuance/mindspore/learners/multi_agent_rl/coma_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Paper link: https://ojs.aaai.org/index.php/AAAI/article/view/11794
Implementation: MindSpore
"""
from xuance.mindspore import ms, Module, Tensor
from xuance.mindspore import ms, Module, Tensor, optim
from xuance.mindspore.learners import LearnerMAS
from xuance.common import List
from argparse import Namespace
Expand Down
2 changes: 1 addition & 1 deletion xuance/mindspore/learners/multi_agent_rl/dcg_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Paper link: http://proceedings.mlr.press/v119/boehmer20a/boehmer20a.pdf
Implementation: MindSpore
"""
from xuance.mindspore import ms, Module, Tensor
from xuance.mindspore import ms, Module, Tensor, optim
from xuance.mindspore.learners import LearnerMAS
from xuance.common import List
from argparse import Namespace
Expand Down
2 changes: 1 addition & 1 deletion xuance/mindspore/learners/multi_agent_rl/iddpg_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Independent Deep Deterministic Policy Gradient (IDDPG)
Implementation: MindSpore
"""
from xuance.mindspore import ms, Module, Tensor
from xuance.mindspore import ms, Module, Tensor, optim
from xuance.mindspore.learners import LearnerMAS
from xuance.common import List
from argparse import Namespace
Expand Down
2 changes: 1 addition & 1 deletion xuance/mindspore/learners/multi_agent_rl/ippo_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
https://arxiv.org/pdf/2103.01955.pdf
Implementation: MindSpore
"""
from xuance.mindspore import ms, Module, Tensor
from xuance.mindspore import ms, Module, Tensor, optim
from xuance.mindspore.learners import LearnerMAS
from xuance.common import List
from argparse import Namespace
Expand Down
2 changes: 1 addition & 1 deletion xuance/mindspore/learners/multi_agent_rl/iql_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Independent Q-learning (IQL)
Implementation: MindSpore
"""
from xuance.mindspore import ms, Module, Tensor
from xuance.mindspore import ms, Module, Tensor, optim
from xuance.mindspore.learners import LearnerMAS
from xuance.common import List
from argparse import Namespace
Expand Down
2 changes: 1 addition & 1 deletion xuance/mindspore/learners/multi_agent_rl/isac_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Implementation: MindSpore
Creator: Kun Jiang ([email protected])
"""
from xuance.mindspore import ms, Module, Tensor
from xuance.mindspore import ms, Module, Tensor, optim
from xuance.mindspore.learners import LearnerMAS
from xuance.common import List
from argparse import Namespace
Expand Down
2 changes: 1 addition & 1 deletion xuance/mindspore/learners/multi_agent_rl/maddpg_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
Implementation: MindSpore
Trick: Parameter sharing for all agents, with agents' one-hot IDs as actor-critic's inputs.
"""
from xuance.mindspore import ms, Module, Tensor
from xuance.mindspore import ms, Module, Tensor, optim
from xuance.mindspore.learners import LearnerMAS
from xuance.common import List
from argparse import Namespace
Expand Down
2 changes: 1 addition & 1 deletion xuance/mindspore/learners/multi_agent_rl/mappo_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
https://arxiv.org/pdf/2103.01955.pdf
Implementation: MindSpore
"""
from xuance.mindspore import ms, Module, Tensor
from xuance.mindspore import ms, Module, Tensor, optim
from xuance.mindspore.learners import LearnerMAS
from xuance.common import List
from argparse import Namespace
Expand Down
2 changes: 1 addition & 1 deletion xuance/mindspore/learners/multi_agent_rl/masac_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Implementation: MindSpore
Creator: Kun Jiang ([email protected])
"""
from xuance.mindspore import ms, Module, Tensor
from xuance.mindspore import ms, Module, Tensor, optim
from xuance.mindspore.learners import LearnerMAS
from xuance.common import List
from argparse import Namespace
Expand Down
2 changes: 1 addition & 1 deletion xuance/mindspore/learners/multi_agent_rl/matd3_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Multi-Agent TD3
"""
from xuance.mindspore import ms, Module, Tensor
from xuance.mindspore import ms, Module, Tensor, optim
from xuance.mindspore.learners import LearnerMAS
from xuance.common import List
from argparse import Namespace
Expand Down
2 changes: 1 addition & 1 deletion xuance/mindspore/learners/multi_agent_rl/mfac_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
http://proceedings.mlr.press/v80/yang18d/yang18d.pdf
Implementation: MindSpore
"""
from xuance.mindspore import ms, Module, Tensor
from xuance.mindspore import ms, Module, Tensor, optim
from xuance.mindspore.learners import LearnerMAS
from xuance.common import List
from argparse import Namespace
Expand Down
2 changes: 1 addition & 1 deletion xuance/mindspore/learners/multi_agent_rl/mfq_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
http://proceedings.mlr.press/v80/yang18d/yang18d.pdf
Implementation: MindSpore
"""
from xuance.mindspore import ms, Module, Tensor
from xuance.mindspore import ms, Module, Tensor, optim
from xuance.mindspore.learners import LearnerMAS
from xuance.common import List
from argparse import Namespace
Expand Down
2 changes: 1 addition & 1 deletion xuance/mindspore/learners/multi_agent_rl/qmix_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
http://proceedings.mlr.press/v80/rashid18a/rashid18a.pdf
Implementation: MindSpore
"""
from xuance.mindspore import ms, Module, Tensor
from xuance.mindspore import ms, Module, Tensor, optim
from xuance.mindspore.learners import LearnerMAS
from xuance.common import List
from argparse import Namespace
Expand Down
2 changes: 1 addition & 1 deletion xuance/mindspore/learners/multi_agent_rl/qtran_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
http://proceedings.mlr.press/v97/son19a/son19a.pdf
Implementation: MindSpore
"""
from xuance.mindspore import ms, Module, Tensor
from xuance.mindspore import ms, Module, Tensor, optim
from xuance.mindspore.learners import LearnerMAS
from xuance.common import List
from argparse import Namespace
Expand Down
2 changes: 1 addition & 1 deletion xuance/mindspore/learners/multi_agent_rl/vdac_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Paper link: https://ojs.aaai.org/index.php/AAAI/article/view/17353
Implementation: MindSpore
"""
from xuance.mindspore import ms, Module, Tensor
from xuance.mindspore import ms, Module, Tensor, optim
from xuance.mindspore.learners import LearnerMAS
from xuance.common import List
from argparse import Namespace
Expand Down
2 changes: 1 addition & 1 deletion xuance/mindspore/learners/multi_agent_rl/vdn_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
https://arxiv.org/pdf/1706.05296.pdf
Implementation: MindSpore
"""
from xuance.mindspore import ms, Module, Tensor
from xuance.mindspore import ms, Module, Tensor, optim
from xuance.mindspore.learners import LearnerMAS
from xuance.common import List
from argparse import Namespace
Expand Down
2 changes: 1 addition & 1 deletion xuance/mindspore/learners/multi_agent_rl/wqmix_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
https://proceedings.neurips.cc/paper/2020/file/73a427badebe0e32caa2e1fc7530b7f3-Paper.pdf
Implementation: MindSpore
"""
from xuance.mindspore import ms, Module, Tensor
from xuance.mindspore import ms, Module, Tensor, optim
from xuance.mindspore.learners import LearnerMAS
from xuance.common import List
from argparse import Namespace
Expand Down
2 changes: 1 addition & 1 deletion xuance/mindspore/learners/policy_gradient/a2c_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Advantage Actor-Critic (A2C)
Implementation: MindSpore
"""
from xuance.mindspore import ms, Module, Tensor
from xuance.mindspore import ms, Module, Tensor, optim
from xuance.mindspore.learners import Learner
from argparse import Namespace

Expand Down
2 changes: 1 addition & 1 deletion xuance/mindspore/learners/policy_gradient/ddpg_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Paper link: https://arxiv.org/pdf/1509.02971.pdf
Implementation: MindSpore
"""
from xuance.mindspore import ms, Module, Tensor
from xuance.mindspore import ms, Module, Tensor, optim
from xuance.mindspore.learners import Learner
from argparse import Namespace

Expand Down
2 changes: 1 addition & 1 deletion xuance/mindspore/learners/policy_gradient/mpdqn_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Paper link: https://arxiv.org/pdf/1905.04388.pdf
Implementation: MindSpore
"""
from xuance.mindspore import ms, Module, Tensor
from xuance.mindspore import ms, Module, Tensor, optim
from xuance.mindspore.learners import Learner
from argparse import Namespace
from mindspore.ops import OneHot
Expand Down
2 changes: 1 addition & 1 deletion xuance/mindspore/learners/policy_gradient/pdqn_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Paper link: https://arxiv.org/pdf/1810.06394.pdf
Implementation: MindSpore
"""
from xuance.mindspore import ms, Module, Tensor
from xuance.mindspore import ms, Module, Tensor, optim
from xuance.mindspore.learners import Learner
from argparse import Namespace
from mindspore.ops import OneHot
Expand Down
2 changes: 1 addition & 1 deletion xuance/mindspore/learners/policy_gradient/pg_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Paper link: https://proceedings.neurips.cc/paper/2001/file/4b86abe48d358ecf194c56c69108433e-Paper.pdf
Implementation: MindSpore
"""
from xuance.mindspore import ms, Module, Tensor
from xuance.mindspore import ms, Module, Tensor, optim
from xuance.mindspore.learners import Learner
from argparse import Namespace

Expand Down
2 changes: 1 addition & 1 deletion xuance/mindspore/learners/policy_gradient/ppg_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Paper link: http://proceedings.mlr.press/v139/cobbe21a/cobbe21a.pdf
Implementation: MindSpore
"""
from xuance.mindspore import ms, Module, Tensor
from xuance.mindspore import ms, Module, Tensor, optim
from xuance.mindspore.learners import Learner
from argparse import Namespace
from xuance.mindspore.utils.operations import merge_distributions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Paper link: https://arxiv.org/pdf/1707.06347.pdf
Implementation: MindSpore
"""
from xuance.mindspore import ms, Module, Tensor
from xuance.mindspore import ms, Module, Tensor, optim
from xuance.mindspore.learners import Learner
from argparse import Namespace

Expand Down
2 changes: 1 addition & 1 deletion xuance/mindspore/learners/policy_gradient/ppokl_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Paper link: https://arxiv.org/pdf/1707.06347.pdf
Implementation: MindSpore
"""
from xuance.mindspore import ms, Module, Tensor
from xuance.mindspore import ms, Module, Tensor, optim
from xuance.mindspore.learners import Learner
from argparse import Namespace

Expand Down
2 changes: 1 addition & 1 deletion xuance/mindspore/learners/policy_gradient/sac_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Paper link: http://proceedings.mlr.press/v80/haarnoja18b/haarnoja18b.pdf
Implementation: MindSpore
"""
from xuance.mindspore import ms, Module, Tensor
from xuance.mindspore import ms, Module, Tensor, optim
from xuance.mindspore.learners import Learner
from argparse import Namespace
from mindspore.nn.probability.distribution import Normal
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Paper link: https://arxiv.org/pdf/1910.07207.pdf
Implementation: MindSpore
"""
from xuance.mindspore import ms, Module, Tensor
from xuance.mindspore import ms, Module, Tensor, optim
from xuance.mindspore.learners import Learner
from argparse import Namespace

Expand Down
2 changes: 1 addition & 1 deletion xuance/mindspore/learners/policy_gradient/spdqn_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Paper link: https://arxiv.org/pdf/1810.06394.pdf
Implementation: MindSpore
"""
from xuance.mindspore import ms, Module, Tensor
from xuance.mindspore import ms, Module, Tensor, optim
from xuance.mindspore.learners import Learner
from argparse import Namespace
from mindspore.ops import OneHot
Expand Down
2 changes: 1 addition & 1 deletion xuance/mindspore/learners/policy_gradient/td3_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Paper link: http://proceedings.mlr.press/v80/fujimoto18a/fujimoto18a.pdf
Implementation: MindSpore
"""
from xuance.mindspore import ms, Module, Tensor
from xuance.mindspore import ms, Module, Tensor, optim
from xuance.mindspore.learners import Learner
from argparse import Namespace

Expand Down
2 changes: 1 addition & 1 deletion xuance/mindspore/learners/qlearning_family/c51_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Paper link: http://proceedings.mlr.press/v70/bellemare17a/bellemare17a.pdf
Implementation: MindSpore
"""
from xuance.mindspore import ms, Module, Tensor
from xuance.mindspore import ms, Module, Tensor, optim
from xuance.mindspore.learners import Learner
from argparse import Namespace
from mindspore.ops import OneHot, Log, BatchMatMul, ExpandDims, Squeeze, ReduceSum, Abs, ReduceMean, clip_by_value
Expand Down
2 changes: 1 addition & 1 deletion xuance/mindspore/learners/qlearning_family/ddqn_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Paper link: https://ojs.aaai.org/index.php/AAAI/article/view/10295
Implementation: MindSpore
"""
from xuance.mindspore import ms, Module, Tensor
from xuance.mindspore import ms, Module, Tensor, optim
from xuance.mindspore.learners import Learner
from argparse import Namespace
from mindspore.ops import OneHot
Expand Down
13 changes: 8 additions & 5 deletions xuance/mindspore/learners/qlearning_family/dqn_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,21 @@
Paper link: https://www.nature.com/articles/nature14236
Implementation: MindSpore
"""
from xuance.mindspore import ms, Module, Tensor
from xuance.mindspore import ms, Module, Tensor, optim
from xuance.mindspore.learners import Learner
from argparse import Namespace
from mindspore.ops import OneHot
from mindspore.nn import MSELoss, Adam
from mindspore.nn import MSELoss


class DQN_Learner(Learner):
def __init__(self,
config: Namespace,
policy: Module):
super(DQN_Learner, self).__init__(config, policy)
self.optimizer = Adam(params=self.policy.trainable_params(), learning_rate=self.config.learning_rate, eps=1e-5)
self.optimizer = optim.Adam(params=self.policy.trainable_params(), lr=self.config.learning_rate, eps=1e-5)
self.scheduler = optim.lr_scheduler.LinearLR(self.optimizer, start_factor=1.0, end_factor=0.9,
total_iters=self.config.running_steps)
self.gamma = config.gamma
self.sync_frequency = config.sync_frequency
self.mse_loss = MSELoss()
Expand Down Expand Up @@ -50,12 +52,13 @@ def update(self, **samples):
if self.iterations % self.sync_frequency == 0:
self.policy.copy_target()

lr = self.scheduler(self.iterations).asnumpy()
self.scheduler.step()
lr = self.scheduler.get_last_lr()[0]

info = {
"Qloss": loss.asnumpy(),
"predictQ": predictQ.mean().asnumpy(),
"learning_rate": lr
"learning_rate": lr.asnumpy(),
}

return info
2 changes: 1 addition & 1 deletion xuance/mindspore/learners/qlearning_family/drqn_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Paper link: https://cdn.aaai.org/ocs/11673/11673-51288-1-PB.pdf
Implementation: MindSpore
"""
from xuance.mindspore import ms, Module, Tensor
from xuance.mindspore import ms, Module, Tensor, optim
from xuance.mindspore.learners import Learner
from argparse import Namespace
from mindspore.ops import OneHot
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Paper link: http://proceedings.mlr.press/v48/wangf16.pdf
Implementation: MindSpore
"""
from xuance.mindspore import ms, Module, Tensor
from xuance.mindspore import ms, Module, Tensor, optim
from xuance.mindspore.learners import Learner
from argparse import Namespace
from mindspore.ops import OneHot
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Paper link: https://arxiv.org/pdf/1511.05952.pdf
Implementation: MindSpore
"""
from xuance.mindspore import ms, Module, Tensor
from xuance.mindspore import ms, Module, Tensor, optim
from xuance.mindspore.learners import Learner
from argparse import Namespace
from mindspore.ops import OneHot
Expand Down
Loading

0 comments on commit 0af39b5

Please sign in to comment.