From 7a3419d674869468743b2338af35d7024511f3a9 Mon Sep 17 00:00:00 2001
From: wenzhangliu
Date: Wed, 2 Oct 2024 22:44:20 +0800
Subject: [PATCH] distributed training for tf, off policy

---
 .../agents/qlearning_family/c51_agent.py      |   3 +-
 .../agents/qlearning_family/drqn_agent.py     |   3 +-
 .../agents/qlearning_family/dueldqn_agent.py  |   3 +-
 .../agents/qlearning_family/noisydqn_agent.py |   3 +-
 .../agents/qlearning_family/qrdqn_agent.py    |   3 +-
 .../learners/qlearning_family/c51_learner.py  |  23 ++-
 .../learners/qlearning_family/ddqn_learner.py |  24 ++-
 .../learners/qlearning_family/dqn_learner.py  |   3 +-
 .../learners/qlearning_family/drqn_learner.py |  23 ++-
 .../qlearning_family/dueldqn_learner.py       |  24 ++-
 .../qlearning_family/perdqn_learner.py        |  25 ++-
 .../qlearning_family/qrdqn_learner.py         |  23 ++-
 xuance/tensorflow/policies/deterministic.py   | 150 ++++++++++++++----
 13 files changed, 251 insertions(+), 59 deletions(-)

diff --git a/xuance/tensorflow/agents/qlearning_family/c51_agent.py b/xuance/tensorflow/agents/qlearning_family/c51_agent.py
index d2b8a876..309cb9bb 100644
--- a/xuance/tensorflow/agents/qlearning_family/c51_agent.py
+++ b/xuance/tensorflow/agents/qlearning_family/c51_agent.py
@@ -32,7 +32,8 @@ def _build_policy(self) -> Module:
                 action_space=self.action_space, atom_num=self.config.atom_num,
                 v_min=self.config.v_min, v_max=self.config.v_max,
                 representation=representation, hidden_size=self.config.q_hidden_size,
-                normalize=normalize_fn, initialize=initializer, activation=activation)
+                normalize=normalize_fn, initialize=initializer, activation=activation,
+                use_distributed_training=self.distributed_training)
         else:
             raise AttributeError(f"C51 currently does not support the policy named {self.config.policy}.")

diff --git a/xuance/tensorflow/agents/qlearning_family/drqn_agent.py b/xuance/tensorflow/agents/qlearning_family/drqn_agent.py
index ba74b563..4eb91ae2 100644
--- a/xuance/tensorflow/agents/qlearning_family/drqn_agent.py
+++ b/xuance/tensorflow/agents/qlearning_family/drqn_agent.py
@@ -59,7 +59,8 @@ def _build_policy(self) -> Module:
                 action_space=self.action_space, representation=representation, rnn=self.config.rnn,
                 recurrent_hidden_size=self.config.recurrent_hidden_size,
                 recurrent_layer_N=self.config.recurrent_layer_N, dropout=self.config.dropout,
-                normalize=normalize_fn, initialize=initializer, activation=activation)
+                normalize=normalize_fn, initialize=initializer, activation=activation,
+                use_distributed_training=self.distributed_training)
         else:
             raise AttributeError(
                 f"{self.config.agent} currently does not support the policy named {self.config.policy}.")
diff --git a/xuance/tensorflow/agents/qlearning_family/dueldqn_agent.py b/xuance/tensorflow/agents/qlearning_family/dueldqn_agent.py
index 1bd43177..3862f4f6 100644
--- a/xuance/tensorflow/agents/qlearning_family/dueldqn_agent.py
+++ b/xuance/tensorflow/agents/qlearning_family/dueldqn_agent.py
@@ -30,7 +30,8 @@ def _build_policy(self) -> Module:
         if self.config.policy == "Duel_Q_network":
             policy = REGISTRY_Policy["Duel_Q_network"](
                 action_space=self.action_space, representation=representation, hidden_size=self.config.q_hidden_size,
-                normalize=normalize_fn, initialize=initializer, activation=activation)
+                normalize=normalize_fn, initialize=initializer, activation=activation,
+                use_distributed_training=self.distributed_training)
         else:
             raise AttributeError(f"{self.config.agent} currently does not support the policy named {self.config.policy}.")

diff --git a/xuance/tensorflow/agents/qlearning_family/noisydqn_agent.py b/xuance/tensorflow/agents/qlearning_family/noisydqn_agent.py
index fab4cda3..e5070fe7 100644
--- a/xuance/tensorflow/agents/qlearning_family/noisydqn_agent.py
+++ b/xuance/tensorflow/agents/qlearning_family/noisydqn_agent.py
@@ -54,7 +54,8 @@ def _build_policy(self) -> Module:
         if self.config.policy == "Noisy_Q_network":
             policy = REGISTRY_Policy["Noisy_Q_network"](
                 action_space=self.action_space, representation=representation, hidden_size=self.config.q_hidden_size,
-                normalize=normalize_fn, initialize=initializer, activation=activation)
+                normalize=normalize_fn, initialize=initializer, activation=activation,
+                use_distributed_training=self.distributed_training)
         else:
             raise AttributeError(f"{self.config.agent} currently does not support the policy named {self.config.policy}.")

diff --git a/xuance/tensorflow/agents/qlearning_family/qrdqn_agent.py b/xuance/tensorflow/agents/qlearning_family/qrdqn_agent.py
index d092bf9f..9cbdf92a 100644
--- a/xuance/tensorflow/agents/qlearning_family/qrdqn_agent.py
+++ b/xuance/tensorflow/agents/qlearning_family/qrdqn_agent.py
@@ -31,7 +31,8 @@ def _build_policy(self) -> Module:
             policy = REGISTRY_Policy["QR_Q_network"](
                 action_space=self.action_space, quantile_num=self.config.quantile_num,
                 representation=representation, hidden_size=self.config.q_hidden_size,
-                normalize=normalize_fn, initialize=initializer, activation=activation)
+                normalize=normalize_fn, initialize=initializer, activation=activation,
+                use_distributed_training=self.distributed_training)
         else:
             raise AttributeError(f"{self.config.agent} currently does not support the policy named {self.config.policy}.")

diff --git a/xuance/tensorflow/learners/qlearning_family/c51_learner.py b/xuance/tensorflow/learners/qlearning_family/c51_learner.py
index 4becdd15..25e9babf 100644
--- a/xuance/tensorflow/learners/qlearning_family/c51_learner.py
+++ b/xuance/tensorflow/learners/qlearning_family/c51_learner.py
@@ -15,14 +15,22 @@ def __init__(self,
                  policy: Module):
         super(C51_Learner, self).__init__(config, policy)
         if ("macOS" in self.os_name) and ("arm" in self.os_name):  # For macOS with Apple's M-series chips.
-            self.optimizer = tk.optimizers.legacy.Adam(config.learning_rate)
+            if self.distributed_training:
+                with self.policy.mirrored_strategy.scope():
+                    self.optimizer = tk.optimizers.legacy.Adam(config.learning_rate)
+            else:
+                self.optimizer = tk.optimizers.legacy.Adam(config.learning_rate)
         else:
-            self.optimizer = tk.optimizers.Adam(config.learning_rate)
+            if self.distributed_training:
+                with self.policy.mirrored_strategy.scope():
+                    self.optimizer = tk.optimizers.Adam(config.learning_rate)
+            else:
+                self.optimizer = tk.optimizers.Adam(config.learning_rate)
         self.gamma = config.gamma
         self.sync_frequency = config.sync_frequency

     @tf.function
-    def learn(self, obs_batch, act_batch, next_batch, rew_batch, ter_batch):
+    def forward_fn(self, obs_batch, act_batch, next_batch, rew_batch, ter_batch):
         with tf.GradientTape() as tape:
             _, _, evalZ = self.policy(obs_batch)
             _, targetA, targetZ = self.policy.target(next_batch)
@@ -59,6 +67,15 @@ def learn(self, obs_batch, act_batch, next_batch, rew_batch, ter_batch):

         return loss

+    @tf.function
+    def learn(self, *inputs):
+        if self.distributed_training:
+            loss = self.policy.mirrored_strategy.run(self.forward_fn, args=inputs)
+            return self.policy.mirrored_strategy.reduce(tf.distribute.ReduceOp.SUM, loss, axis=None)
+        else:
+            loss = self.forward_fn(*inputs)
+            return loss
+
     def update(self, **samples):
         self.iterations += 1
         obs_batch = samples['obs']
diff --git a/xuance/tensorflow/learners/qlearning_family/ddqn_learner.py b/xuance/tensorflow/learners/qlearning_family/ddqn_learner.py
index 6ff710f0..336b09ec 100644
--- a/xuance/tensorflow/learners/qlearning_family/ddqn_learner.py
+++ b/xuance/tensorflow/learners/qlearning_family/ddqn_learner.py
@@ -15,14 +15,22 @@ def __init__(self,
                  policy: Module):
         super(DDQN_Learner, self).__init__(config, policy)
         if ("macOS" in self.os_name) and ("arm" in self.os_name):  # For macOS with Apple's M-series chips.
-            self.optimizer = tk.optimizers.legacy.Adam(config.learning_rate)
+            if self.distributed_training:
+                with self.policy.mirrored_strategy.scope():
+                    self.optimizer = tk.optimizers.legacy.Adam(config.learning_rate)
+            else:
+                self.optimizer = tk.optimizers.legacy.Adam(config.learning_rate)
         else:
-            self.optimizer = tk.optimizers.Adam(config.learning_rate)
+            if self.distributed_training:
+                with self.policy.mirrored_strategy.scope():
+                    self.optimizer = tk.optimizers.Adam(config.learning_rate)
+            else:
+                self.optimizer = tk.optimizers.Adam(config.learning_rate)
         self.gamma = config.gamma
         self.sync_frequency = config.sync_frequency

     @tf.function
-    def learn(self, obs_batch, act_batch, next_batch, rew_batch, ter_batch):
+    def forward_fn(self, obs_batch, act_batch, next_batch, rew_batch, ter_batch):
         with tf.GradientTape() as tape:
             _, _, evalQ = self.policy(obs_batch)
             _, targetA, targetQ = self.policy.target(next_batch)
@@ -49,6 +57,16 @@ def learn(self, obs_batch, act_batch, next_batch, rew_batch, ter_batch):
             ])
         return predictQ, loss

+    @tf.function
+    def learn(self, *inputs):
+        if self.distributed_training:
+            predictQ, loss = self.policy.mirrored_strategy.run(self.forward_fn, args=inputs)
+            return (self.policy.mirrored_strategy.reduce(tf.distribute.ReduceOp.SUM, predictQ, axis=None),
+                    self.policy.mirrored_strategy.reduce(tf.distribute.ReduceOp.SUM, loss, axis=None))
+        else:
+            predictQ, loss = self.forward_fn(*inputs)
+            return predictQ, loss
+
     def update(self, **samples):
         self.iterations += 1
         obs_batch = samples['obs']
diff --git a/xuance/tensorflow/learners/qlearning_family/dqn_learner.py b/xuance/tensorflow/learners/qlearning_family/dqn_learner.py
index 2166ef3b..d19e53a7 100644
--- a/xuance/tensorflow/learners/qlearning_family/dqn_learner.py
+++ b/xuance/tensorflow/learners/qlearning_family/dqn_learner.py
@@ -61,7 +61,8 @@ def forward_fn(self, obs_batch, act_batch, next_batch, rew_batch, ter_batch):
     def learn(self, *inputs):
         if self.distributed_training:
             predictQ, loss = self.policy.mirrored_strategy.run(self.forward_fn, args=inputs)
-            return predictQ, self.policy.mirrored_strategy.reduce(tf.distribute.ReduceOp.SUM, loss, axis=None)
+            return (self.policy.mirrored_strategy.reduce(tf.distribute.ReduceOp.SUM, predictQ, axis=None),
+                    self.policy.mirrored_strategy.reduce(tf.distribute.ReduceOp.SUM, loss, axis=None))
         else:
             predictQ, loss = self.forward_fn(*inputs)
             return predictQ, loss
diff --git a/xuance/tensorflow/learners/qlearning_family/drqn_learner.py b/xuance/tensorflow/learners/qlearning_family/drqn_learner.py
index 968ebb4f..3ca2afe6 100644
--- a/xuance/tensorflow/learners/qlearning_family/drqn_learner.py
+++ b/xuance/tensorflow/learners/qlearning_family/drqn_learner.py
@@ -15,15 +15,23 @@ def __init__(self,
                  policy: Module):
         super(DRQN_Learner, self).__init__(config, policy)
         if ("macOS" in self.os_name) and ("arm" in self.os_name):  # For macOS with Apple's M-series chips.
-            self.optimizer = tk.optimizers.legacy.Adam(config.learning_rate)
+            if self.distributed_training:
+                with self.policy.mirrored_strategy.scope():
+                    self.optimizer = tk.optimizers.legacy.Adam(config.learning_rate)
+            else:
+                self.optimizer = tk.optimizers.legacy.Adam(config.learning_rate)
         else:
-            self.optimizer = tk.optimizers.Adam(config.learning_rate)
+            if self.distributed_training:
+                with self.policy.mirrored_strategy.scope():
+                    self.optimizer = tk.optimizers.Adam(config.learning_rate)
+            else:
+                self.optimizer = tk.optimizers.Adam(config.learning_rate)
         self.gamma = config.gamma
         self.sync_frequency = config.sync_frequency
         self.n_actions = self.policy.action_dim

     @tf.function
-    def learn(self, batch_size, obs_batch, act_batch, rew_batch, ter_batch):
+    def forward_fn(self, batch_size, obs_batch, act_batch, rew_batch, ter_batch):
         with tf.GradientTape() as tape:
             rnn_hidden = self.policy.init_hidden(batch_size)
             _, _, evalQ, _ = self.policy(obs_batch[:, 0:-1], *rnn_hidden)
@@ -56,6 +64,15 @@ def learn(self, batch_size, obs_batch, act_batch, rew_batch, ter_batch):

         return predictQ, loss

+    @tf.function
+    def learn(self, *inputs):
+        if self.distributed_training:
+            predictQ, loss = self.policy.mirrored_strategy.run(self.forward_fn, args=inputs)
+            return predictQ, self.policy.mirrored_strategy.reduce(tf.distribute.ReduceOp.SUM, loss, axis=None)
+        else:
+            predictQ, loss = self.forward_fn(*inputs)
+            return predictQ, loss
+
     def update(self, **samples):
         self.iterations += 1
         obs_batch = samples['obs']
diff --git a/xuance/tensorflow/learners/qlearning_family/dueldqn_learner.py b/xuance/tensorflow/learners/qlearning_family/dueldqn_learner.py
index 4a894aca..e9715c04 100644
--- a/xuance/tensorflow/learners/qlearning_family/dueldqn_learner.py
+++ b/xuance/tensorflow/learners/qlearning_family/dueldqn_learner.py
@@ -15,14 +15,22 @@ def __init__(self,
                  policy: Module):
         super(DuelDQN_Learner, self).__init__(config, policy)
         if ("macOS" in self.os_name) and ("arm" in self.os_name):  # For macOS with Apple's M-series chips.
-            self.optimizer = tk.optimizers.legacy.Adam(config.learning_rate)
+            if self.distributed_training:
+                with self.policy.mirrored_strategy.scope():
+                    self.optimizer = tk.optimizers.legacy.Adam(config.learning_rate)
+            else:
+                self.optimizer = tk.optimizers.legacy.Adam(config.learning_rate)
         else:
-            self.optimizer = tk.optimizers.Adam(config.learning_rate)
+            if self.distributed_training:
+                with self.policy.mirrored_strategy.scope():
+                    self.optimizer = tk.optimizers.Adam(config.learning_rate)
+            else:
+                self.optimizer = tk.optimizers.Adam(config.learning_rate)
         self.gamma = config.gamma
         self.sync_frequency = config.sync_frequency

     @tf.function
-    def learn(self, obs_batch, act_batch, next_batch, rew_batch, ter_batch):
+    def forward_fn(self, obs_batch, act_batch, next_batch, rew_batch, ter_batch):
         with tf.GradientTape() as tape:
             _, _, evalQ = self.policy(obs_batch)
             _, _, targetQ = self.policy.target(next_batch)
@@ -47,6 +55,16 @@ def learn(self, obs_batch, act_batch, next_batch, rew_batch, ter_batch):
             ])
         return predictQ, loss

+    @tf.function
+    def learn(self, *inputs):
+        if self.distributed_training:
+            predictQ, loss = self.policy.mirrored_strategy.run(self.forward_fn, args=inputs)
+            return (self.policy.mirrored_strategy.reduce(tf.distribute.ReduceOp.SUM, predictQ, axis=None),
+                    self.policy.mirrored_strategy.reduce(tf.distribute.ReduceOp.SUM, loss, axis=None))
+        else:
+            predictQ, loss = self.forward_fn(*inputs)
+            return predictQ, loss
+
     def update(self, **samples):
         self.iterations += 1
         obs_batch = samples['obs']
diff --git a/xuance/tensorflow/learners/qlearning_family/perdqn_learner.py b/xuance/tensorflow/learners/qlearning_family/perdqn_learner.py
index 953d1277..9acdf22a 100644
--- a/xuance/tensorflow/learners/qlearning_family/perdqn_learner.py
+++ b/xuance/tensorflow/learners/qlearning_family/perdqn_learner.py
@@ -15,14 +15,22 @@ def __init__(self,
                  policy: Module):
         super(PerDQN_Learner, self).__init__(config, policy)
         if ("macOS" in self.os_name) and ("arm" in self.os_name):  # For macOS with Apple's M-series chips.
-            self.optimizer = tk.optimizers.legacy.Adam(config.learning_rate)
+            if self.distributed_training:
+                with self.policy.mirrored_strategy.scope():
+                    self.optimizer = tk.optimizers.legacy.Adam(config.learning_rate)
+            else:
+                self.optimizer = tk.optimizers.legacy.Adam(config.learning_rate)
         else:
-            self.optimizer = tk.optimizers.Adam(config.learning_rate)
+            if self.distributed_training:
+                with self.policy.mirrored_strategy.scope():
+                    self.optimizer = tk.optimizers.Adam(config.learning_rate)
+            else:
+                self.optimizer = tk.optimizers.Adam(config.learning_rate)
         self.gamma = config.gamma
         self.sync_frequency = config.sync_frequency

     @tf.function
-    def learn(self, obs_batch, act_batch, next_batch, rew_batch, ter_batch):
+    def forward_fn(self, obs_batch, act_batch, next_batch, rew_batch, ter_batch):
         with tf.GradientTape() as tape:
             _, _, evalQ = self.policy(obs_batch)
             _, _, targetQ = self.policy.target(next_batch)
@@ -48,6 +56,17 @@ def learn(self, obs_batch, act_batch, next_batch, rew_batch, ter_batch):
             ])
         return td_error, predictQ, loss

+    @tf.function
+    def learn(self, *inputs):
+        if self.distributed_training:
+            td_error, predictQ, loss = self.policy.mirrored_strategy.run(self.forward_fn, args=inputs)
+            return (self.policy.mirrored_strategy.reduce(tf.distribute.ReduceOp.SUM, td_error, axis=None),
+                    self.policy.mirrored_strategy.reduce(tf.distribute.ReduceOp.SUM, predictQ, axis=None),
+                    self.policy.mirrored_strategy.reduce(tf.distribute.ReduceOp.SUM, loss, axis=None))
+        else:
+            td_error, predictQ, loss = self.forward_fn(*inputs)
+            return td_error, predictQ, loss
+
     def update(self, **samples):
         self.iterations += 1
         obs_batch = samples['obs']
diff --git a/xuance/tensorflow/learners/qlearning_family/qrdqn_learner.py b/xuance/tensorflow/learners/qlearning_family/qrdqn_learner.py
index 1f3a9cce..02a3ef8b 100644
--- a/xuance/tensorflow/learners/qlearning_family/qrdqn_learner.py
+++ b/xuance/tensorflow/learners/qlearning_family/qrdqn_learner.py
@@ -15,14 +15,22 @@ def __init__(self,
                  policy: Module):
         super(QRDQN_Learner, self).__init__(config, policy)
         if ("macOS" in self.os_name) and ("arm" in self.os_name):  # For macOS with Apple's M-series chips.
-            self.optimizer = tk.optimizers.legacy.Adam(config.learning_rate)
+            if self.distributed_training:
+                with self.policy.mirrored_strategy.scope():
+                    self.optimizer = tk.optimizers.legacy.Adam(config.learning_rate)
+            else:
+                self.optimizer = tk.optimizers.legacy.Adam(config.learning_rate)
         else:
-            self.optimizer = tk.optimizers.Adam(config.learning_rate)
+            if self.distributed_training:
+                with self.policy.mirrored_strategy.scope():
+                    self.optimizer = tk.optimizers.Adam(config.learning_rate)
+            else:
+                self.optimizer = tk.optimizers.Adam(config.learning_rate)
         self.gamma = config.gamma
         self.sync_frequency = config.sync_frequency

     @tf.function
-    def learn(self, obs_batch, act_batch, next_batch, rew_batch, ter_batch):
+    def forward_fn(self, obs_batch, act_batch, next_batch, rew_batch, ter_batch):
         with tf.GradientTape() as tape:
             _, _, evalZ = self.policy(obs_batch)
             _, targetA, targetZ = self.policy.target(next_batch)
@@ -50,6 +58,15 @@ def learn(self, obs_batch, act_batch, next_batch, rew_batch, ter_batch):
             ])
         return current_quantile, loss

+    def learn(self, *inputs):
+        if self.distributed_training:
+            predictQ, loss = self.policy.mirrored_strategy.run(self.forward_fn, args=inputs)
+            return (self.policy.mirrored_strategy.reduce(tf.distribute.ReduceOp.SUM, predictQ, axis=None),
+                    self.policy.mirrored_strategy.reduce(tf.distribute.ReduceOp.SUM, loss, axis=None))
+        else:
+            predictQ, loss = self.forward_fn(*inputs)
+            return predictQ, loss
+
     def update(self, **samples):
         self.iterations += 1
         obs_batch = samples['obs']
diff --git a/xuance/tensorflow/policies/deterministic.py b/xuance/tensorflow/policies/deterministic.py
index 84cb5d0a..1b4605d2 100644
--- a/xuance/tensorflow/policies/deterministic.py
+++ b/xuance/tensorflow/policies/deterministic.py
@@ -37,6 +37,7 @@ def __init__(self,
             with self.mirrored_strategy.scope():
                 self.representation = representation
                 self.target_representation = deepcopy(representation)
+                self.representation.build((None,) + self.representation.input_shapes)
                 self.representation_info_shape = self.representation.output_shapes
                 self.eval_Qhead = BasicQhead(self.representation.output_shapes['state'][0], self.action_dim, hidden_size,
                                              normalize, initialize, activation)
@@ -48,12 +49,16 @@ def __init__(self,
             self.representation = representation
             self.target_representation = deepcopy(representation)
             self.representation_info_shape = self.representation.output_shapes
-            self.eval_Qhead = BasicQhead(self.representation.output_shapes['state'][0], self.action_dim, hidden_size,
-                                         normalize, initialize, activation)
-            self.target_Qhead = BasicQhead(self.representation.output_shapes['state'][0], self.action_dim, hidden_size,
-                                           normalize, initialize, activation)
+            self.eval_Qhead = BasicQhead(self.representation.output_shapes['state'][0], self.action_dim,
+                                         hidden_size, normalize, initialize, activation)
+            self.target_Qhead = BasicQhead(self.representation.output_shapes['state'][0], self.action_dim,
+                                           hidden_size, normalize, initialize, activation)
             self.target_Qhead.set_weights(self.eval_Qhead.get_weights())

+    @property
+    def trainable_variables(self):
+        return self.representation.trainable_variables + self.eval_Qhead.trainable_variables
+
     @tf.function
     def call(self, observation: Union[Tensor, np.ndarray]):
         """
@@ -119,14 +124,33 @@ def __init__(self,
                  use_distributed_training: bool = False):
         super(DuelQnetwork, self).__init__()
         self.action_dim = action_space.n
-        self.representation = representation
-        self.target_representation = deepcopy(representation)
-        self.representation_info_shape = self.representation.output_shapes
-        self.eval_Qhead = DuelQhead(self.representation.output_shapes['state'][0], self.action_dim, hidden_size,
-                                    normalize, initialize, activation)
-        self.target_Qhead = DuelQhead(self.representation.output_shapes['state'][0], self.action_dim, hidden_size,
-                                      normalize, initialize, activation)
-        self.target_Qhead.set_weights(self.eval_Qhead.get_weights())
+
+        self.use_distributed_training = use_distributed_training
+        if self.use_distributed_training:
+            self.mirrored_strategy = tf.distribute.MirroredStrategy()
+            with self.mirrored_strategy.scope():
+                self.representation = representation
+                self.target_representation = deepcopy(representation)
+                self.representation.build((None,) + self.representation.input_shapes)
+                self.representation_info_shape = self.representation.output_shapes
+                self.eval_Qhead = DuelQhead(self.representation.output_shapes['state'][0], self.action_dim,
+                                            hidden_size, normalize, initialize, activation)
+                self.target_Qhead = DuelQhead(self.representation.output_shapes['state'][0], self.action_dim,
+                                              hidden_size, normalize, initialize, activation)
+                self.target_Qhead.set_weights(self.eval_Qhead.get_weights())
+        else:
+            self.representation = representation
+            self.target_representation = deepcopy(representation)
+            self.representation_info_shape = self.representation.output_shapes
+            self.eval_Qhead = DuelQhead(self.representation.output_shapes['state'][0], self.action_dim,
+                                        hidden_size, normalize, initialize, activation)
+            self.target_Qhead = DuelQhead(self.representation.output_shapes['state'][0], self.action_dim,
+                                          hidden_size, normalize, initialize, activation)
+            self.target_Qhead.set_weights(self.eval_Qhead.get_weights())
+
+    @property
+    def trainable_variables(self):
+        return self.representation.trainable_variables + self.eval_Qhead.trainable_variables

     @tf.function
     def call(self, observation: Union[np.ndarray, dict], **kwargs):
@@ -193,18 +217,36 @@ def __init__(self,
                  use_distributed_training: bool = False):
         super(NoisyQnetwork, self).__init__()
         self.action_dim = action_space.n
-        self.representation = representation
-        self.target_representation = deepcopy(representation)
-        self.representation_info_shape = self.representation.output_shapes
-        self.eval_Qhead = BasicQhead(self.representation.output_shapes['state'][0], self.action_dim, hidden_size,
-                                     normalize, initialize, activation)
-        self.target_Qhead = BasicQhead(self.representation.output_shapes['state'][0], self.action_dim, hidden_size,
-                                       normalize, initialize, activation)
+
+        self.use_distributed_training = use_distributed_training
+        if self.use_distributed_training:
+            self.mirrored_strategy = tf.distribute.MirroredStrategy()
+            with self.mirrored_strategy.scope():
+                self.representation = representation
+                self.target_representation = deepcopy(representation)
+                self.representation.build((None,) + self.representation.input_shapes)
+                self.representation_info_shape = self.representation.output_shapes
+                self.eval_Qhead = BasicQhead(self.representation.output_shapes['state'][0], self.action_dim,
+                                             hidden_size, normalize, initialize, activation)
+                self.target_Qhead = BasicQhead(self.representation.output_shapes['state'][0], self.action_dim,
+                                               hidden_size, normalize, initialize, activation)
+        else:
+            self.representation = representation
+            self.target_representation = deepcopy(representation)
+            self.representation_info_shape = self.representation.output_shapes
+            self.eval_Qhead = BasicQhead(self.representation.output_shapes['state'][0], self.action_dim,
+                                         hidden_size, normalize, initialize, activation)
+            self.target_Qhead = BasicQhead(self.representation.output_shapes['state'][0], self.action_dim,
+                                           hidden_size, normalize, initialize, activation)
         self.target_Qhead.set_weights(self.eval_Qhead.get_weights())
         self.noise_scale = 0.0
         self.eval_noise_parameter = []
         self.target_noise_parameter = []

+    @property
+    def trainable_variables(self):
+        return self.representation.trainable_variables + self.eval_Qhead.trainable_variables
+
     def update_noise(self, noisy_bound: float = 0.0):
         """Updates the noises for network parameters."""
         self.eval_noise_parameter = []
@@ -295,17 +337,37 @@ def __init__(self,
         self.atom_num = atom_num
         self.v_min = v_min
         self.v_max = v_max
-        self.representation = representation
-        self.target_representation = deepcopy(representation)
-        self.representation_info_shape = self.representation.output_shapes
-        self.eval_Zhead = C51Qhead(self.representation.output_shapes['state'][0], self.action_dim, self.atom_num,
-                                   hidden_size, normalize, initialize, activation)
-        self.target_Zhead = C51Qhead(self.representation.output_shapes['state'][0], self.action_dim, self.atom_num,
-                                     hidden_size, normalize, initialize, activation)
-        self.target_Zhead.set_weights(self.eval_Zhead.get_weights())
+
+        self.use_distributed_training = use_distributed_training
+        if self.use_distributed_training:
+            self.mirrored_strategy = tf.distribute.MirroredStrategy()
+            with self.mirrored_strategy.scope():
+                self.representation = representation
+                self.target_representation = deepcopy(representation)
+                self.representation.build((None,) + self.representation.input_shapes)
+                self.representation_info_shape = self.representation.output_shapes
+                self.eval_Zhead = C51Qhead(self.representation.output_shapes['state'][0], self.action_dim,
+                                           self.atom_num, hidden_size, normalize, initialize, activation)
+                self.target_Zhead = C51Qhead(self.representation.output_shapes['state'][0], self.action_dim,
+                                             self.atom_num, hidden_size, normalize, initialize, activation)
+                self.target_Zhead.set_weights(self.eval_Zhead.get_weights())
+        else:
+            self.representation = representation
+            self.target_representation = deepcopy(representation)
+            self.representation_info_shape = self.representation.output_shapes
+            self.eval_Zhead = C51Qhead(self.representation.output_shapes['state'][0], self.action_dim,
+                                       self.atom_num, hidden_size, normalize, initialize, activation)
+            self.target_Zhead = C51Qhead(self.representation.output_shapes['state'][0], self.action_dim,
+                                         self.atom_num, hidden_size, normalize, initialize, activation)
+            self.target_Zhead.set_weights(self.eval_Zhead.get_weights())
+
         self.supports = tf.cast(tf.linspace(self.v_min, self.v_max, self.atom_num), dtype=tf.float32)
         self.deltaz = (v_max - v_min) / (atom_num - 1)

+    @property
+    def trainable_variables(self):
+        return self.representation.trainable_variables + self.eval_Zhead.trainable_variables
+
     @tf.function
     def call(self, observation: Union[np.ndarray, dict], **kwargs):
         """
@@ -376,14 +438,32 @@ def __init__(self,
         super(QRDQN_Network, self).__init__()
         self.action_dim = action_space.n
         self.quantile_num = quantile_num
-        self.representation = representation
-        self.target_representation = deepcopy(representation)
-        self.representation_info_shape = self.representation.output_shapes
-        self.eval_Zhead = QRDQNhead(self.representation.output_shapes['state'][0], self.action_dim, self.quantile_num,
-                                    hidden_size, normalize, initialize, activation)
-        self.target_Zhead = QRDQNhead(self.representation.output_shapes['state'][0], self.action_dim, self.quantile_num,
-                                      hidden_size, normalize, initialize, activation)
-        self.target_Zhead.set_weights(self.eval_Zhead.get_weights())
+
+        self.use_distributed_training = use_distributed_training
+        if self.use_distributed_training:
+            self.mirrored_strategy = tf.distribute.MirroredStrategy()
+            with self.mirrored_strategy.scope():
+                self.representation = representation
+                self.target_representation = deepcopy(representation)
+                self.representation_info_shape = self.representation.output_shapes
+                self.eval_Zhead = QRDQNhead(self.representation.output_shapes['state'][0], self.action_dim,
+                                            self.quantile_num, hidden_size, normalize, initialize, activation)
+                self.target_Zhead = QRDQNhead(self.representation.output_shapes['state'][0], self.action_dim,
+                                              self.quantile_num, hidden_size, normalize, initialize, activation)
+                self.target_Zhead.set_weights(self.eval_Zhead.get_weights())
+        else:
+            self.representation = representation
+            self.target_representation = deepcopy(representation)
+            self.representation_info_shape = self.representation.output_shapes
+            self.eval_Zhead = QRDQNhead(self.representation.output_shapes['state'][0], self.action_dim,
+                                        self.quantile_num, hidden_size, normalize, initialize, activation)
+            self.target_Zhead = QRDQNhead(self.representation.output_shapes['state'][0], self.action_dim,
+                                          self.quantile_num, hidden_size, normalize, initialize, activation)
+            self.target_Zhead.set_weights(self.eval_Zhead.get_weights())
+
+    @property
+    def trainable_variables(self):
+        return self.representation.trainable_variables + self.eval_Zhead.trainable_variables

     @tf.function
     def call(self, observation: Union[np.ndarray, dict], **kwargs):
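
Note (not part of the patch): every learner changed above follows the same tf.distribute pattern: variables (policy and optimizer) are created under MirroredStrategy.scope(), the per-replica step is executed with strategy.run(), and the per-replica outputs are aggregated with strategy.reduce(). The following is a minimal standalone sketch of that pattern; the toy network, shapes, and variable names are illustrative assumptions, not code from XuanCe.

# Illustrative sketch only: the MirroredStrategy pattern used by the learners above,
# with a toy Q-network standing in for the policy/representation pair.
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    # Variables (model weights and optimizer slots) must be created under the scope.
    model = tf.keras.Sequential([tf.keras.layers.Dense(64, activation="relu"),
                                 tf.keras.layers.Dense(4)])
    model.build((None, 8))
    optimizer = tf.keras.optimizers.Adam(1e-3)

def forward_fn(obs, target_q):
    # Per-replica training step (analogous to the learners' forward_fn).
    with tf.GradientTape() as tape:
        q = model(obs)
        loss = tf.reduce_mean(tf.square(tf.reduce_max(q, axis=-1) - target_q))
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

@tf.function
def learn(obs, target_q):
    # Run the step on every replica, then sum the per-replica losses,
    # mirroring the strategy.run(...)/strategy.reduce(...) calls in the patch.
    per_replica_loss = strategy.run(forward_fn, args=(obs, target_q))
    return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_loss, axis=None)

obs = tf.random.normal([32, 8])
target_q = tf.random.normal([32])
print(learn(obs, target_q).numpy())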