
Merge pull request #153 from WilfChen/bugfix
Adapt to 910B
WilfChen authored Nov 23, 2023
2 parents 86b2e9d + 7b742c3 commit d01d5d8
Showing 6 changed files with 25 additions and 15 deletions.
2 changes: 1 addition & 1 deletion example/sac/train.py
@@ -49,7 +49,7 @@ def train(episode=options.episode):
     if compute_type == mstype.float16 and options.device_target != 'Ascend':
         raise ValueError("Fp16 mode is supported by Ascend backend.")
 
-    context.set_context(mode=context.GRAPH_MODE)
+    context.set_context(mode=context.GRAPH_MODE, ascend_config={'precision_mode': 'allow_mix_precision'})
     sac_session = SACSession(options.env_yaml, options.algo_yaml)
     sac_session.run(class_type=SACTrainer, episode=episode)
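
Note on the train.py change: the session now requests 'allow_mix_precision' explicitly, presumably because the default precision mode on Ascend 910B differs from earlier Ascend chips (the PR title only says "Adapt to 910B"). A minimal sketch of the pattern; the device_target guard is an assumption for portability and is not part of this diff, which sets the option unconditionally:

    from mindspore import context

    # Sketch only: guard the Ascend-specific option behind a backend check.
    device_target = 'Ascend'  # placeholder, not taken from the PR
    context.set_context(mode=context.GRAPH_MODE)
    if device_target == 'Ascend':
        context.set_context(ascend_config={'precision_mode': 'allow_mix_precision'})
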
2 changes: 1 addition & 1 deletion mindspore_rl/algorithm/gail/gail_session.py
@@ -50,7 +50,7 @@ def __init__(self,
             shapes=[obs_space.shape, action_space.shape],
             dtypes=[obs_space.ms_dtype, action_space.ms_dtype])
 
-        policy_replay_buffer = UniformReplayBuffer(batch_size=policy_batch_size,
+        policy_replay_buffer = UniformReplayBuffer(sample_size=policy_batch_size,
             capacity=policy_buffer_size,
             shapes=[obs_space.shape, action_space.shape, obs_space.shape, (1,)],
             types=[obs_space.ms_dtype, action_space.ms_dtype,
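
Note on the gail_session.py change: the replay-buffer keyword is renamed from batch_size to sample_size to match the UniformReplayBuffer constructor used by this version of mindspore_rl; the stored transitions are unchanged. A small construction sketch with illustrative shapes and dtypes (not from the PR); the import path is an assumption and may vary between releases:

    import mindspore
    from mindspore_rl.core.uniform_replay_buffer import UniformReplayBuffer

    # Illustrative values only: a buffer holding (obs, action, next_obs, done)
    # with toy shapes, sampled 64 transitions at a time.
    buffer = UniformReplayBuffer(sample_size=64,
                                 capacity=10000,
                                 shapes=[(3,), (1,), (3,), (1,)],
                                 types=[mindspore.float32, mindspore.float32,
                                        mindspore.float32, mindspore.float32])
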
6 changes: 3 additions & 3 deletions mindspore_rl/algorithm/sac/sac.py
@@ -432,10 +432,10 @@ def __init__(self, params):
             critic_net1.trainable_params() + critic_net2.trainable_params()
         )
         critic_optim = nn.Adam(
-            critic_trainable_params, learning_rate=params["critic_lr"]
+            critic_trainable_params, learning_rate=params["critic_lr"], eps=1e-5
         )
         actor_optim = nn.Adam(
-            actor_net.trainable_params(), learning_rate=params["actor_lr"]
+            actor_net.trainable_params(), learning_rate=params["actor_lr"], eps=1e-5
         )
 
         self.critic_train = nn.TrainOneStepCell(critic_loss_net, critic_optim)
@@ -449,7 +449,7 @@ def __init__(self, params):
             params["alpha_loss_weight"],
             actor_net,
         )
-        alpha_optim = nn.Adam([log_alpha], learning_rate=params["alpha_lr"])
+        alpha_optim = nn.Adam([log_alpha], learning_rate=params["alpha_lr"], eps=1e-5)
         self.alpha_train = nn.TrainOneStepCell(alpha_loss_net, alpha_optim)
 
         factor, interval = params["update_factor"], params["update_interval"]
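
Note on the sac.py change: all three Adam optimizers now pass eps=1e-5 instead of MindSpore's default 1e-8. The eps term sits in the denominator of the Adam update, so raising it bounds the step size when the second-moment estimate is vanishingly small; the likely motivation is numerical stability under the mixed-precision mode enabled above, though the PR does not say so. A self-contained sketch with a placeholder network and learning rate:

    import mindspore.nn as nn

    # Adam update per parameter, for reference:
    #   m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
    #   v_t = beta2 * v_{t-1} + (1 - beta2) * g_t ** 2
    #   w_t = w_{t-1} - lr * m_hat / (sqrt(v_hat) + eps)
    # A larger eps keeps the denominator away from zero when v_hat is tiny.
    net = nn.Dense(4, 2)  # placeholder network, not from the PR
    optim = nn.Adam(net.trainable_params(), learning_rate=3e-4, eps=1e-5)
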
5 changes: 2 additions & 3 deletions mindspore_rl/algorithm/sac/sac_trainer.py
@@ -30,7 +30,6 @@ def __init__(self, msrl, params=None):
         super(SACTrainer, self).__init__(msrl)
         self.inited = Parameter(Tensor([False], mindspore.bool_), name="init_flag")
         self.zero = Tensor([0], mindspore.float32)
-        self.fill_value = Tensor([10000], mindspore.float32)
         self.false = Tensor([False], mindspore.bool_)
         self.true = Tensor([True], mindspore.bool_)
         self.less = P.Less()
@@ -47,8 +46,8 @@ def init_training(self):
         """Initialize training"""
         state = self.msrl.collect_environment.reset()
         done = self.false
-        i = self.zero
-        while self.less(i, self.fill_value):
+        i = Tensor([0], mindspore.int32)
+        while self.less(i, Tensor([10000], mindspore.int32)):
             new_state, action, reward, done = self.msrl.agent_act(trainer.INIT, state)
             self.msrl.replay_buffer_insert([state, action, reward, new_state, done])
             state = new_state
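
Note on the sac_trainer.py change: the warm-up loop no longer compares against a stored float32 fill_value; the counter and its bound are now int32 Tensors created inline. An integer loop index is generally friendlier to graph-mode control flow on Ascend, although the PR does not state the motivation. A stripped-down sketch of the counting pattern (environment interaction omitted):

    import mindspore
    from mindspore import Tensor
    from mindspore.ops import operations as P

    less = P.Less()
    i = Tensor([0], mindspore.int32)
    limit = Tensor([10000], mindspore.int32)
    while less(i, limit):
        # ... collect a transition and insert it into the replay buffer ...
        i = i + 1
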
2 changes: 1 addition & 1 deletion mindspore_rl/algorithm/td3/config.py
@@ -60,7 +60,7 @@
         "type": TD3Actor,
         "params": actor_params,
         "policies": [],
-        "networks": ["actor_net"],
+        "networks": ["actor_net", "init_policy"],
     },
     "learner": {
         "number": 1,
23 changes: 17 additions & 6 deletions mindspore_rl/algorithm/td3/td3.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ============================================================================
 """TD3"""
+import numpy as np
 import mindspore
 import mindspore.nn.probability.distribution as msd
 from mindspore import Parameter, Tensor, nn
@@ -40,11 +41,10 @@ class GaussianNoise(nn.Cell):
 
     def __init__(self, mean, stddev, clip=None):
         super().__init__()
-        self.abs = P.Abs()
         self.clip = clip
         if self.clip is not None:
-            self.high_clip = self.abs(Tensor(self.clip))
-            self.low_clip = -self.high_clip
+            self.high_clip = Tensor(np.abs(self.clip))
+            self.low_clip = Tensor(-np.abs(self.clip))
         self.normal = msd.Normal(mean, stddev)
 
     def construct(self, actions):
@@ -164,6 +164,15 @@ def construct(self, observation, action):
 
             return q
 
+    class RandomPolicy(nn.Cell):
+        def __init__(self, action_space_dim):
+            super(TD3Policy.RandomPolicy, self).__init__()
+            self.uniform = P.UniformReal()
+            self.shape = (action_space_dim,)
+
+        def construct(self):
+            return self.uniform(self.shape) * 2 - 1
+
     def __init__(self, params):
         self.actor_net = self.TD3ActorNet(
             params["state_space_dim"],
@@ -217,6 +226,7 @@ def __init__(self, params):
             params["compute_type"],
             name="target_critic_net_2.",
         )
+        self.init_policy = self.RandomPolicy(params['action_space_dim'])
 
 
 class TD3Actor(Actor):
@@ -225,6 +235,7 @@ class TD3Actor(Actor):
     def __init__(self, params=None):
         super().__init__()
         self.actor_net = params["actor_net"]
+        self.init_policy = params["init_policy"]
         self.env = params["collect_environment"]
         self.expand_dims = P.ExpandDims()
         self.squeeze = P.Squeeze()
@@ -242,7 +253,7 @@ def act(self, phase, params):
 
     def get_action(self, phase, params):
         if phase == 1:
-            actions = Tensor(self.env.action_space.sample(), mindspore.float32)
+            return self.init_policy()
         else:
             obs = self.expand_dims(params, 0)
             actions = self.actor_net(obs)
@@ -353,10 +364,10 @@ def __init__(self, params):
         # optimizer network
         critic_optimizer = nn.Adam(
             self.critic_net_1.trainable_params() + self.critic_net_2.trainable_params(),
-            learning_rate=params["critic_lr"],
+            learning_rate=params["critic_lr"], eps=1e-5
         )
         actor_optimizer = nn.Adam(
-            self.actor_net.trainable_params(), learning_rate=params["actor_lr"]
+            self.actor_net.trainable_params(), learning_rate=params["actor_lr"], eps=1e-5
         )
 
         # target networks and their initializations
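
Notes on the td3.py changes: GaussianNoise now computes its clip bounds with NumPy at construction time instead of instantiating a P.Abs operator; the initial data collection uses a new RandomPolicy cell instead of calling env.action_space.sample(), so the random action is produced inside the graph (listing "init_policy" under "networks" in config.py is what delivers it to TD3Actor as params["init_policy"]); and the critic/actor optimizers gain eps=1e-5, mirroring sac.py. A standalone sketch of the new random policy, assuming the action space is normalized to [-1, 1):

    import mindspore.nn as nn
    from mindspore.ops import operations as P

    class RandomPolicy(nn.Cell):
        """Samples a uniform random action inside the graph."""

        def __init__(self, action_space_dim):
            super().__init__()
            self.uniform = P.UniformReal()   # samples in [0, 1)
            self.shape = (action_space_dim,)

        def construct(self):
            # Rescale [0, 1) to [-1, 1); assumes normalized actions.
            return self.uniform(self.shape) * 2 - 1

    policy = RandomPolicy(4)   # 4 is an illustrative action dimension
    print(policy())            # a (4,)-shaped float32 Tensor of random actions
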
