
Commit

gradient
wenzhangliu committed Jul 28, 2024
1 parent 0786500 commit e1ac717
Showing 6 changed files with 72 additions and 71 deletions.
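
The learner changes below all apply one refactor: the tf.GradientTape() context moves from wrapping the whole update block to sitting inside the per-agent loop, so each model key gets its own tape and its own tape.gradient(...) call (a non-persistent tape raises an error if gradient() is requested twice). The vdn_agents.py change is separate: it builds a second VDN_mixer and passes it to the policy alongside the eval mixer. A minimal sketch of the per-key tape pattern, using hypothetical models, optimizers and loss_fn placeholders rather than XuanCe's actual learner interfaces:

import tensorflow as tf

def update_per_key(model_keys, models, optimizers, loss_fn, batch):
    """Open one GradientTape per model key so each loss gets its own gradient pass."""
    for key in model_keys:
        with tf.GradientTape() as tape:
            # forward pass and loss for this agent/key only
            loss = loss_fn(models[key], batch, key)
        # a non-persistent tape supports a single gradient() call, hence one tape per key
        grads = tape.gradient(loss, models[key].trainable_variables)
        optimizers[key].apply_gradients(zip(grads, models[key].trainable_variables))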
3 changes: 2 additions & 1 deletion xuance/tensorflow/agents/multi_agent_rl/vdn_agents.py
@@ -41,10 +41,11 @@ def _build_policy(self):
 
         # build policies
         mixer = VDN_mixer()
+        target_mixer = VDN_mixer()
         if self.config.policy == "Mixing_Q_network":
             policy = REGISTRY_Policy["Mixing_Q_network"](
                 action_space=self.action_space, n_agents=self.n_agents, representation=representation,
-                mixer=mixer, hidden_size=self.config.q_hidden_size,
+                mixer=[mixer, target_mixer], hidden_size=self.config.q_hidden_size,
                 normalize=normalize_fn, initialize=initializer, activation=activation,
                 use_parameter_sharing=self.use_parameter_sharing, model_keys=self.model_keys,
                 use_rnn=self.use_rnn, rnn=self.config.rnn if self.use_rnn else None)
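
On the vdn_agents.py hunk above: the policy now receives the eval mixer and a separately constructed target mixer as a pair, presumably so the target side of the TD backup can go through its own mixer instance. A rough generic illustration of the idea, not XuanCe's actual Mixing_Q_network interface:

import tensorflow as tf

class SumMixer(tf.keras.layers.Layer):
    """VDN-style mixer: the joint Q-value is the sum of the per-agent Q-values."""
    def call(self, agent_qs):                     # agent_qs: [batch, n_agents]
        return tf.reduce_sum(agent_qs, axis=-1)

mixer, target_mixer = SumMixer(), SumMixer()      # eval copy and target copy
# q_tot      = mixer(q_agents)                     -> trained with the online Q-networks
# q_tot_next = target_mixer(q_agents_next_target)  -> used only inside the TD target

A plain sum has no trainable parameters, so the two copies behave identically; the separate object mainly keeps the interface uniform with parameterized mixers (QMIX-style), which do need a frozen target copy.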
12 changes: 6 additions & 6 deletions xuance/tensorflow/learners/multi_agent_rl/iddpg_learner.py
@@ -53,11 +53,11 @@ def update(self, sample):
             bs = batch_size
 
         # updata critic
-        with tf.GradientTape() as tape:
-            for key in self.model_keys:
         _, q_eval = self.policy.Qpolicy(observation=obs, actions=actions, agent_ids=IDs)
         _, next_actions = self.policy.Atarget(next_observation=obs_next, agent_ids=IDs)
         _, q_next = self.policy.Qtarget(next_observation=obs_next, next_actions=next_actions, agent_ids=IDs)
+        for key in self.model_keys:
+            with tf.GradientTape() as tape:
                 mask_values = agent_mask[key]
                 q_eval_a = tf.reshape(q_eval[key], [bs])
                 q_next_i = tf.reshape(q_next[key], [bs])
@@ -82,10 +82,10 @@ def update(self, sample):
                          f"{key}/predictQ": tf.math.reduce_mean(q_eval[key]).numpy()})
 
         # update actor
-        with tf.GradientTape() as tape:
-            _, actions_eval = self.policy(observation=obs, agent_ids=IDs)
-            _, q_policy = self.policy.Qpolicy(observation=obs, actions=actions_eval, agent_ids=IDs)
-            for key in self.model_keys:
+        _, actions_eval = self.policy(observation=obs, agent_ids=IDs)
+        _, q_policy = self.policy.Qpolicy(observation=obs, actions=actions_eval, agent_ids=IDs)
+        for key in self.model_keys:
+            with tf.GradientTape() as tape:
                 mask_values = agent_mask[key]
                 q_policy_i = tf.reshape(q_policy[key], [bs])
                 loss_a = -tf.reduce_sum(q_policy_i * mask_values) / tf.reduce_sum(mask_values)
22 changes: 11 additions & 11 deletions xuance/tensorflow/learners/multi_agent_rl/isac_learner.py
@@ -65,12 +65,12 @@ def update(self, sample):
             bs = batch_size
 
         # Update critic
-        with tf.GradientTape() as tape:
-            _, actions_next, log_pi_next = self.policy(observation=obs_next, agent_ids=IDs)
-            _, _, action_q_1, action_q_2 = self.policy.Qaction(observation=obs, actions=actions, agent_ids=IDs)
-            _, _, next_q = self.policy.Qtarget(next_observation=obs_next, next_actions=actions_next, agent_ids=IDs)
+        _, actions_next, log_pi_next = self.policy(observation=obs_next, agent_ids=IDs)
+        _, _, action_q_1, action_q_2 = self.policy.Qaction(observation=obs, actions=actions, agent_ids=IDs)
+        _, _, next_q = self.policy.Qtarget(next_observation=obs_next, next_actions=actions_next, agent_ids=IDs)
 
-            for key in self.model_keys:
+        for key in self.model_keys:
+            with tf.GradientTape() as tape:
                 mask_values = agent_mask[key]
                 action_q_1_i, action_q_2_i = tf.reshape(action_q_1[key], [bs]), tf.reshape(action_q_2[key], [bs])
                 log_pi_next_eval = tf.reshape(log_pi_next[key], [bs])
@@ -99,10 +99,10 @@ def update(self, sample):
                 info.update({f"{key}/loss_critic": loss_c.numpy()})
 
         # Update actor
-        with tf.GradientTape() as tape:
-            _, actions_eval, log_pi_eval = self.policy(observation=obs, agent_ids=IDs)
-            log_pi_eval_i = {}
-            for key in self.model_keys:
+        _, actions_eval, log_pi_eval = self.policy(observation=obs, agent_ids=IDs)
+        log_pi_eval_i = {}
+        for key in self.model_keys:
+            with tf.GradientTape() as tape:
                 _, _, policy_q_1, policy_q_2 = self.policy.Qpolicy(observation=obs, actions=actions_eval, agent_ids=IDs,
                                                                    agent_key=key)
                 log_pi_eval_i[key] = tf.reshape(log_pi_eval[key], [bs])
@@ -128,8 +128,8 @@ def update(self, sample):
 
         # Automatic entropy tuning
         if self.use_automatic_entropy_tuning:
-            with tf.GradientTape() as tape:
-                for key in self.model_keys:
+            for key in self.model_keys:
+                with tf.GradientTape() as tape:
                     alpha_loss = -tf.math.reduce_mean(self.alpha_layer[key].log_alpha.value() * (log_pi_eval_i[key] + self.target_entropy[key]))
                     gradients = tape.gradient(alpha_loss, self.alpha_layer[key].trainable_variables)
                     self.alpha_optimizer[key].apply_gradients([
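
The entropy-tuning hunk above follows the same per-key pattern: each key's temperature is trained under its own tape, so alpha_loss is differentiated only with respect to that key's log_alpha. A small self-contained sketch of that update with dummy log-probabilities and a made-up target entropy (not the learner's real data flow):

import tensorflow as tf

log_alpha = tf.Variable(0.0)                   # one temperature parameter per model key
alpha_optimizer = tf.keras.optimizers.Adam(1e-3)
target_entropy = -2.0                          # e.g. -dim(action space)
log_pi = tf.constant([-1.3, -0.7, -2.1])       # log-probs of sampled actions (dummy values)

with tf.GradientTape() as tape:
    alpha_loss = -tf.math.reduce_mean(log_alpha * (log_pi + target_entropy))
gradients = tape.gradient(alpha_loss, [log_alpha])
alpha_optimizer.apply_gradients(zip(gradients, [log_alpha]))
alpha = tf.exp(log_alpha)                      # temperature fed back into the actor loss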
32 changes: 16 additions & 16 deletions xuance/tensorflow/learners/multi_agent_rl/maddpg_learner.py
@@ -64,19 +64,19 @@ def update(self, sample):
             actions_joint = np.concat(itemgetter(*self.agent_keys)(actions), axis=-1).reshape(batch_size, -1)
 
         # update critic
-        with tf.GradientTape() as tape:
-            _, actions_next = self.policy.Atarget(next_observation=obs_next, agent_ids=IDs)
-            if self.use_parameter_sharing:
-                key = self.model_keys[0]
-                actions_next_joint = tf.reshape(tf.reshape(actions_next[key], [batch_size, self.n_agents, -1]),
-                                                [batch_size, -1])
-            else:
-                actions_next_joint = tf.reshape(tf.concat(itemgetter(*self.model_keys)(actions_next), -1),
-                                                [batch_size, -1])
-            _, q_eval = self.policy.Qpolicy(joint_observation=obs_joint, joint_actions=actions_joint, agent_ids=IDs)
-            _, q_next = self.policy.Qtarget(joint_observation=next_obs_joint, joint_actions=actions_next_joint,
-                                            agent_ids=IDs)
-            for key in self.model_keys:
+        _, actions_next = self.policy.Atarget(next_observation=obs_next, agent_ids=IDs)
+        if self.use_parameter_sharing:
+            key = self.model_keys[0]
+            actions_next_joint = tf.reshape(tf.reshape(actions_next[key], [batch_size, self.n_agents, -1]),
+                                            [batch_size, -1])
+        else:
+            actions_next_joint = tf.reshape(tf.concat(itemgetter(*self.model_keys)(actions_next), -1),
+                                            [batch_size, -1])
+        _, q_eval = self.policy.Qpolicy(joint_observation=obs_joint, joint_actions=actions_joint, agent_ids=IDs)
+        _, q_next = self.policy.Qtarget(joint_observation=next_obs_joint, joint_actions=actions_next_joint,
+                                        agent_ids=IDs)
+        for key in self.model_keys:
+            with tf.GradientTape() as tape:
                 mask_values = agent_mask[key]
                 q_eval_a = tf.reshape(q_eval[key], [bs])
                 q_next_i = tf.reshape(q_next[key], [bs])
@@ -101,9 +101,9 @@ def update(self, sample):
                          f"{key}/predictQ": tf.math.reduce_mean(q_eval[key]).numpy()})
 
         # Update actor
-        with tf.GradientTape() as tape:
-            _, actions_eval = self.policy(observation=obs, agent_ids=IDs)
-            for key in self.model_keys:
+        _, actions_eval = self.policy(observation=obs, agent_ids=IDs)
+        for key in self.model_keys:
+            with tf.GradientTape() as tape:
                 mask_values = agent_mask[key]
                 if self.use_parameter_sharing:
                     act_eval = tf.reshape(tf.reshape(actions_eval[key], [batch_size, self.n_agents, -1]), [batch_size, -1])
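
The MADDPG critic here is centralized, so per-agent next actions are flattened into one joint tensor of shape [batch_size, n_agents * act_dim]: a double reshape when a single shared model emits all agents' actions stacked along the batch dimension, and a last-axis concat when each key has its own model. A toy shape check with invented sizes, independent of the learner:

import tensorflow as tf

batch_size, n_agents, act_dim = 4, 3, 2

# parameter sharing: one model outputs actions for all agents stacked in the batch dim
shared = tf.random.normal([batch_size * n_agents, act_dim])
joint_shared = tf.reshape(tf.reshape(shared, [batch_size, n_agents, -1]), [batch_size, -1])

# independent models: one [batch, act_dim] tensor per agent key, concatenated on the last axis
per_agent = [tf.random.normal([batch_size, act_dim]) for _ in range(n_agents)]
joint_independent = tf.reshape(tf.concat(per_agent, axis=-1), [batch_size, -1])

print(joint_shared.shape, joint_independent.shape)    # both (4, 6)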
40 changes: 20 additions & 20 deletions xuance/tensorflow/learners/multi_agent_rl/masac_learner.py
@@ -49,20 +49,20 @@ def update(self, sample):
             actions_joint = np.stack(itemgetter(*self.agent_keys)(actions), axis=-1).reshape(batch_size, -1)
 
         # Update critic
-        with tf.GradientTape() as tape:
-            _, actions_next, log_pi_next = self.policy(observation=obs_next, agent_ids=IDs)
-            if self.use_parameter_sharing:
-                key = self.model_keys[0]
-                actions_next_joint = tf.reshape(tf.reshape(actions_next[key], [batch_size, self.n_agents, -1]),
-                                                [batch_size, -1])
-            else:
-                actions_next_joint = tf.reshape(tf.concat(itemgetter(*self.model_keys)(actions_next), -1),
-                                                [batch_size, -1])
-            _, _, action_q_1, action_q_2 = self.policy.Qaction(joint_observation=obs_joint, joint_actions=actions_joint,
-                                                               agent_ids=IDs)
-            _, _, target_q = self.policy.Qtarget(joint_observation=next_obs_joint, joint_actions=actions_next_joint,
-                                                 agent_ids=IDs)
-            for key in self.model_keys:
+        _, actions_next, log_pi_next = self.policy(observation=obs_next, agent_ids=IDs)
+        if self.use_parameter_sharing:
+            key = self.model_keys[0]
+            actions_next_joint = tf.reshape(tf.reshape(actions_next[key], [batch_size, self.n_agents, -1]),
+                                            [batch_size, -1])
+        else:
+            actions_next_joint = tf.reshape(tf.concat(itemgetter(*self.model_keys)(actions_next), -1),
+                                            [batch_size, -1])
+        _, _, action_q_1, action_q_2 = self.policy.Qaction(joint_observation=obs_joint, joint_actions=actions_joint,
+                                                           agent_ids=IDs)
+        _, _, target_q = self.policy.Qtarget(joint_observation=next_obs_joint, joint_actions=actions_next_joint,
+                                             agent_ids=IDs)
+        for key in self.model_keys:
+            with tf.GradientTape() as tape:
                 mask_values = agent_mask[key]
                 action_q_1_i = tf.reshape(action_q_1[key], [bs])
                 action_q_2_i = tf.reshape(action_q_2[key], [bs])
@@ -91,10 +91,10 @@ def update(self, sample):
                 info.update({f"{key}/loss_critic": loss_c.numpy()})
 
         # Update actor
-        with tf.GradientTape() as tape:
-            _, actions_eval, log_pi_eval = self.policy(observation=obs, agent_ids=IDs)
-            log_pi_eval_i = {}
-            for key in self.model_keys:
+        _, actions_eval, log_pi_eval = self.policy(observation=obs, agent_ids=IDs)
+        log_pi_eval_i = {}
+        for key in self.model_keys:
+            with tf.GradientTape() as tape:
                 mask_values = agent_mask[key]
                 if self.use_parameter_sharing:
                     actions_eval_joint = tf.reshape(tf.reshape(actions_eval[key], [batch_size, self.n_agents, -1]),
@@ -130,8 +130,8 @@ def update(self, sample):
 
         # Automatically entropy tuning
        if self.use_automatic_entropy_tuning:
-            with tf.GradientTape() as tape:
-                for key in self.model_keys:
+            for key in self.model_keys:
+                with tf.GradientTape() as tape:
                     alpha_loss = -tf.math.reduce_mean(
                         self.alpha_layer[key].log_alpha.value() * (log_pi_eval_i[key] + self.target_entropy[key]))
                     gradients = tape.gradient(alpha_loss, self.alpha_layer[key].trainable_variables)
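
For orientation, the MASAC critic hunk is built around the standard soft actor-critic backup: the next-state value from the target critic is reduced by the entropy term alpha * log_pi(a'|s'), then discounted and added to the reward. A worked toy version of that target with invented numbers (my sketch of the textbook formula, not code lifted from the diff):

import tensorflow as tf

gamma, alpha = 0.99, 0.2
rewards = tf.constant([1.0, 0.0])
terminals = tf.constant([0.0, 1.0])             # 1.0 where the episode ended
target_q_next = tf.constant([9.5, 5.0])         # target-critic value at s'
log_pi_next = tf.constant([-1.0, -0.5])         # log pi(a'|s') for the sampled next action

# y = r + gamma * (1 - done) * (Q_target(s', a') - alpha * log pi(a'|s'))
backup = tf.stop_gradient(rewards + gamma * (1.0 - terminals) * (target_q_next - alpha * log_pi_next))

q1_eval = tf.constant([8.0, 4.0])               # current twin-critic estimates at (s, a)
q2_eval = tf.constant([8.4, 3.6])
loss_c = tf.reduce_mean((q1_eval - backup) ** 2) + tf.reduce_mean((q2_eval - backup) ** 2)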
34 changes: 17 additions & 17 deletions xuance/tensorflow/learners/multi_agent_rl/matd3_learner.py
@@ -62,21 +62,21 @@ def update(self, sample):
             actions_joint = np.concat(itemgetter(*self.agent_keys)(actions), axis=-1).reshape(batch_size, -1)
 
         # Update critic
-        with tf.GradientTape() as tape:
-            _, actions_next = self.policy.Atarget(next_observation=obs_next, agent_ids=IDs)
-            if self.use_parameter_sharing:
-                key = self.model_keys[0]
-                actions_next_joint = tf.reshape(tf.reshape(actions_next[key], [batch_size, self.n_agents, -1]),
-                                                [batch_size, -1])
-            else:
-                actions_next_joint = tf.reshape(tf.concat(itemgetter(*self.model_keys)(actions_next), -1),
-                                                [batch_size, -1])
-            q_eval_A, q_eval_B, _ = self.policy.Qpolicy(joint_observation=obs_joint, joint_actions=actions_joint,
-                                                        agent_ids=IDs)
-            q_next = self.policy.Qtarget(joint_observation=next_obs_joint, joint_actions=actions_next_joint,
-                                         agent_ids=IDs)
+        _, actions_next = self.policy.Atarget(next_observation=obs_next, agent_ids=IDs)
+        if self.use_parameter_sharing:
+            key = self.model_keys[0]
+            actions_next_joint = tf.reshape(tf.reshape(actions_next[key], [batch_size, self.n_agents, -1]),
+                                            [batch_size, -1])
+        else:
+            actions_next_joint = tf.reshape(tf.concat(itemgetter(*self.model_keys)(actions_next), -1),
+                                            [batch_size, -1])
+        q_eval_A, q_eval_B, _ = self.policy.Qpolicy(joint_observation=obs_joint, joint_actions=actions_joint,
+                                                    agent_ids=IDs)
+        q_next = self.policy.Qtarget(joint_observation=next_obs_joint, joint_actions=actions_next_joint,
+                                     agent_ids=IDs)
 
-            for key in self.model_keys:
+        for key in self.model_keys:
+            with tf.GradientTape() as tape:
                 mask_values = agent_mask[key]
                 q_eval_A_i, q_eval_B_i = tf.reshape(q_eval_A[key], [bs]), tf.reshape(q_eval_B[key], [bs])
                 q_next_i = tf.reshape(q_next[key], [bs])
@@ -104,9 +104,9 @@ def update(self, sample):
 
         # Update actor
         if self.iterations % self.actor_update_delay == 0:
-            with tf.GradientTape() as tape:
-                _, actions_eval = self.policy(observation=obs, agent_ids=IDs)
-                for key in self.model_keys:
+            _, actions_eval = self.policy(observation=obs, agent_ids=IDs)
+            for key in self.model_keys:
+                with tf.GradientTape() as tape:
                     mask_values = agent_mask[key]
                     if self.use_parameter_sharing:
                         act_eval = tf.reshape(tf.reshape(actions_eval[key], [batch_size, self.n_agents, -1]),
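
The matd3_learner.py actor hunk keeps TD3's delayed policy updates: the twin critics q_eval_A / q_eval_B are trained on every call, while the actor only updates when self.iterations % self.actor_update_delay == 0. A bare-bones version of that schedule, with hypothetical callbacks rather than the learner's real methods:

class DelayedActorUpdater:
    """TD3-style schedule: update the critics every step, the actor every `delay` steps."""

    def __init__(self, delay: int = 2):
        self.delay = delay
        self.iterations = 0

    def step(self, update_critic, update_actor):
        update_critic()                         # twin critics trained every iteration
        if self.iterations % self.delay == 0:
            update_actor()                      # delayed actor (and target-network) update
        self.iterations += 1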

