Merge pull request #154 from MashiroChen/master
910B + GE acc fix
WilfChen authored Nov 24, 2023
2 parents d01d5d8 + f11d379 commit 4e11304
Showing 5 changed files with 6 additions and 6 deletions.
4 changes: 2 additions & 2 deletions mindspore_rl/algorithm/maddpg/maddpg.py
@@ -240,8 +240,8 @@ def __init__(self, params):
         self.actor_net = params.get('actor_net')

         # optimizer network
-        critic_optimizer = nn.Adam(self.critic_net.trainable_params(), learning_rate=params.get('learning_rate'))
-        actor_optimizer = nn.Adam(self.actor_net.trainable_params(), learning_rate=params.get('learning_rate'))
+        critic_optimizer = nn.Adam(self.critic_net.trainable_params(), learning_rate=params.get('learning_rate'), eps=1e-5)
+        actor_optimizer = nn.Adam(self.actor_net.trainable_params(), learning_rate=params.get('learning_rate'), eps=1e-5)

         # loss network
         self.target_actor_net = params.get('target_actor_net')
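Note (not part of the commit): the accuracy fix here is presumably numerical. Adam divides each update by sqrt(v) + eps, so raising eps from mindspore.nn.Adam's small default (1e-8) to 1e-5 bounds the step size when the second-moment estimate v is near zero, which is where reduced-precision Ascend 910B + GraphEngine (GE) kernels tend to lose accuracy. A minimal NumPy sketch of the effect:

import numpy as np

grad = np.float32(1e-6)             # a very small gradient
m, v = grad, grad ** 2              # first/second moments after one step (no bias correction)
for eps in (1e-8, 1e-5):
    step = m / (np.sqrt(v) + np.float32(eps))
    print(f"eps={eps:g}: step magnitude {step:.3f}")
# eps=1e-8 -> ~0.990 (near-full step); eps=1e-5 -> ~0.091 (damped step)

The same eps=1e-5 change is applied to the PPO and QMIX optimizers below.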
2 changes: 1 addition & 1 deletion mindspore_rl/algorithm/maddpg/maddpg_trainer.py
@@ -96,7 +96,7 @@ def train(self, episodes, callbacks=None, ckpt_path=None):
                     and represent for `loss, rewards, steps, [optional]others.` in order"
                 )
             episode_rewards.append(float(rewards.asnumpy()))
-            if i % 1000 == 0:
+            if (i % 1000 == 0) and (i != 0):
                 print("-----------------------------------------")
                 # pylint: disable=C0209
                 print(
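Note (not part of the commit): the added `i != 0` guard only suppresses the progress report at episode 0, where `i % 1000 == 0` would otherwise fire before any statistics have accumulated. A trivial sketch of the guard's effect:

for i in range(3001):
    if (i % 1000 == 0) and (i != 0):
        print(f"report at episode {i}")  # fires at 1000, 2000, 3000 -- not at 0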
2 changes: 1 addition & 1 deletion mindspore_rl/algorithm/ppo/ppo.py
@@ -273,7 +273,7 @@ def __init__(self, params):
         trainable_parameter = (
             self.critic_net.trainable_params() + self.actor_net.trainable_params()
         )
-        optimizer_ppo = nn.Adam(trainable_parameter, learning_rate=params["lr"])
+        optimizer_ppo = nn.Adam(trainable_parameter, learning_rate=params["lr"], eps=1e-5)
         ppo_loss_net = self.PPOLossCell(
             self.actor_net,
             self.critic_net,
2 changes: 1 addition & 1 deletion mindspore_rl/algorithm/qmix/qmix.py
@@ -813,7 +813,7 @@ def __init__(self, params):
         trainable_params = (
             self.policy_net.trainable_params() + self.mixer_net.trainable_params()
         )
-        optimizer = nn.Adam(trainable_params, learning_rate=params["lr"])
+        optimizer = nn.Adam(trainable_params, learning_rate=params["lr"], eps=1e-5)

         qmix_loss_cell = self.QMIXLossCell(
             params, self.policy_net, self.mixer_net, self.target_mixer_net
2 changes: 1 addition & 1 deletion mindspore_rl/environment/petting_zoo_mpe_environment.py
@@ -87,7 +87,7 @@ def __init__(self, params, env_id=0):
         observation_space = gym2ms_adapter(list(self._env.observation_spaces.values()))
         env_action_space = self._env.action_spaces["agent_0"]
         action_space = Space(
-            (env_action_space.n,), env_action_space.dtype.type, batch_shape=(self._num,)
+            (env_action_space.n,), np.float32, batch_shape=(self._num,)
         )
         reward_space = Space((1,), np.float32, batch_shape=(self._num,))
         done_space = Space((1,), np.bool_, low=0, high=2, batch_shape=(self._num,))
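Note (not part of the commit): PettingZoo's Discrete action spaces report an integer dtype, so `env_action_space.dtype.type` produced an integer action Space while the reward and done spaces above use float32/bool_; the commit pins the action Space to np.float32 instead. A hedged sketch of the original dtype (assumes an MPE parallel env exposing the `action_spaces` dict, as this file does; the scenario name is illustrative):

from pettingzoo.mpe import simple_spread_v2  # illustrative MPE scenario

env = simple_spread_v2.parallel_env()
space = env.action_spaces["agent_0"]
print(space.n, space.dtype)                  # e.g. 5 int64 -- integer, not float32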
