diff --git a/README.md b/README.md index 1750d023f..ef005f4df 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,5 @@ -# Coach +# Warning: +## This branch of Coach is a work-in-progress migration to TF2 and should not be checked out [![CI](https://img.shields.io/circleci/project/github/NervanaSystems/coach/master.svg)](https://circleci.com/gh/NervanaSystems/workflows/coach/tree/master) [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/NervanaSystems/coach/blob/master/LICENSE) diff --git a/benchmarks/clipped_ppo/ant_clipped_ppo_tf2.png b/benchmarks/clipped_ppo/ant_clipped_ppo_tf2.png new file mode 100644 index 000000000..8a6a3fe57 Binary files /dev/null and b/benchmarks/clipped_ppo/ant_clipped_ppo_tf2.png differ diff --git a/benchmarks/clipped_ppo/ant_clipped_ppo_tf2_3.png b/benchmarks/clipped_ppo/ant_clipped_ppo_tf2_3.png new file mode 100644 index 000000000..2b4edcb45 Binary files /dev/null and b/benchmarks/clipped_ppo/ant_clipped_ppo_tf2_3.png differ diff --git a/benchmarks/clipped_ppo/half_cheetah_clipped_ppo_tf2.png b/benchmarks/clipped_ppo/half_cheetah_clipped_ppo_tf2.png new file mode 100644 index 000000000..94acb2f3f Binary files /dev/null and b/benchmarks/clipped_ppo/half_cheetah_clipped_ppo_tf2.png differ diff --git a/benchmarks/clipped_ppo/humanoid_clipped_ppo_tf2.png b/benchmarks/clipped_ppo/humanoid_clipped_ppo_tf2.png new file mode 100644 index 000000000..45f0c47af Binary files /dev/null and b/benchmarks/clipped_ppo/humanoid_clipped_ppo_tf2.png differ diff --git a/benchmarks/clipped_ppo/inverted_double_pendulum_clipped_ppo_tf2.png b/benchmarks/clipped_ppo/inverted_double_pendulum_clipped_ppo_tf2.png new file mode 100644 index 000000000..309bb2056 Binary files /dev/null and b/benchmarks/clipped_ppo/inverted_double_pendulum_clipped_ppo_tf2.png differ diff --git a/benchmarks/clipped_ppo/inverted_pendulum_clipped_ppo_tf2.png b/benchmarks/clipped_ppo/inverted_pendulum_clipped_ppo_tf2.png new file mode 100644 index 000000000..2103140cf Binary files /dev/null and b/benchmarks/clipped_ppo/inverted_pendulum_clipped_ppo_tf2.png differ diff --git a/benchmarks/dqn/breakout_dqn_tf2.png b/benchmarks/dqn/breakout_dqn_tf2.png new file mode 100644 index 000000000..3e0ba57c3 Binary files /dev/null and b/benchmarks/dqn/breakout_dqn_tf2.png differ diff --git a/benchmarks/dqn/pong_dqn_tf2.png b/benchmarks/dqn/pong_dqn_tf2.png new file mode 100644 index 000000000..38b112f00 Binary files /dev/null and b/benchmarks/dqn/pong_dqn_tf2.png differ diff --git a/benchmarks/dqn/space_invaders_dqn_tf2.png b/benchmarks/dqn/space_invaders_dqn_tf2.png new file mode 100644 index 000000000..966ce28f4 Binary files /dev/null and b/benchmarks/dqn/space_invaders_dqn_tf2.png differ diff --git a/rl_coach/agents/clipped_ppo_agent.py b/rl_coach/agents/clipped_ppo_agent.py index cc29f3339..aa6b4dc4a 100644 --- a/rl_coach/agents/clipped_ppo_agent.py +++ b/rl_coach/agents/clipped_ppo_agent.py @@ -202,10 +202,15 @@ def train_network(self, batch, epochs): 'entropy': [] } - fetches = [self.networks['main'].online_network.output_heads[1].kl_divergence, - self.networks['main'].online_network.output_heads[1].entropy, - self.networks['main'].online_network.output_heads[1].likelihood_ratio, - self.networks['main'].online_network.output_heads[1].clipped_likelihood_ratio] + # fetches = [self.networks['main'].online_network.output_heads[1].kl_divergence, + # self.networks['main'].online_network.output_heads[1].entropy, + # self.networks['main'].online_network.output_heads[1].likelihood_ratio, + # 
self.networks['main'].online_network.output_heads[1].clipped_likelihood_ratio] + + fetches = [(1, 'kl_divergence'), + (1, 'entropy'), + (1, 'likelihood_ratio'), + (1, 'clipped_likelihood_ratio')] # TODO-fixme if batch.size / self.ap.network_wrappers['main'].batch_size is not an integer, we do not train on # some of the data diff --git a/rl_coach/architectures/architecture.py b/rl_coach/architectures/architecture.py index 90dbd6eed..2950501c5 100644 --- a/rl_coach/architectures/architecture.py +++ b/rl_coach/architectures/architecture.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2017 Intel Corporation +# Copyright (c) 2019 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,15 +14,14 @@ # limitations under the License. # -from typing import Any, Dict, List, Tuple - import numpy as np - +from typing import Any, Dict, List, Tuple from rl_coach.base_parameters import AgentParameters from rl_coach.saver import SaverCollection from rl_coach.spaces import SpacesDefinition + class Architecture(object): @staticmethod def construct(variable_scope: str, devices: List[str], *args, **kwargs) -> 'Architecture': diff --git a/rl_coach/architectures/embedder_parameters.py b/rl_coach/architectures/embedder_parameters.py index 45269ef20..9a17490f2 100644 --- a/rl_coach/architectures/embedder_parameters.py +++ b/rl_coach/architectures/embedder_parameters.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2017 Intel Corporation +# Copyright (c) 2019 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/rl_coach/architectures/head_parameters.py b/rl_coach/architectures/head_parameters.py index 1c64b63af..3f8c704d2 100644 --- a/rl_coach/architectures/head_parameters.py +++ b/rl_coach/architectures/head_parameters.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2017 Intel Corporation +# Copyright (c) 2019 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/rl_coach/architectures/layers.py b/rl_coach/architectures/layers.py index e295199d8..93fb87126 100644 --- a/rl_coach/architectures/layers.py +++ b/rl_coach/architectures/layers.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2017 Intel Corporation +# Copyright (c) 2019 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/rl_coach/architectures/tensorflow_components/__init__.py b/rl_coach/architectures/legacy_tf_components/__init__.py similarity index 100% rename from rl_coach/architectures/tensorflow_components/__init__.py rename to rl_coach/architectures/legacy_tf_components/__init__.py diff --git a/rl_coach/architectures/legacy_tf_components/architecture.py b/rl_coach/architectures/legacy_tf_components/architecture.py new file mode 100644 index 000000000..68420febb --- /dev/null +++ b/rl_coach/architectures/legacy_tf_components/architecture.py @@ -0,0 +1,693 @@ +# +# Copyright (c) 2017 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os +import time +from typing import Any, List, Tuple, Dict + +import numpy as np +import tensorflow as tf + +from rl_coach.architectures.architecture import Architecture +from rl_coach.architectures.tensorflow_components.savers import GlobalVariableSaver +from rl_coach.base_parameters import AgentParameters, DistributedTaskParameters +from rl_coach.core_types import GradientClippingMethod +from rl_coach.saver import SaverCollection +from rl_coach.spaces import SpacesDefinition +from rl_coach.utils import force_list, squeeze_list, start_shell_command_and_wait + + +def variable_summaries(var): + """Attach a lot of summaries to a Tensor (for TensorBoard visualization).""" + with tf.name_scope('summaries'): + layer_weight_name = '_'.join(var.name.split('/')[-3:])[:-2] + + with tf.name_scope(layer_weight_name): + mean = tf.reduce_mean(var) + tf.summary.scalar('mean', mean) + with tf.name_scope('stddev'): + stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean))) + tf.summary.scalar('stddev', stddev) + tf.summary.scalar('max', tf.reduce_max(var)) + tf.summary.scalar('min', tf.reduce_min(var)) + tf.summary.histogram('histogram', var) + + +def local_getter(getter, name, *args, **kwargs): + """ + This is a wrapper around the tf.get_variable function which puts the variables in the local variables collection + instead of the global variables collection. The local variables collection will hold variables which are not shared + between workers. These variables are also assumed to be non-trainable (the optimizer does not apply gradients to + these variables), but we can calculate the gradients w.r.t. these variables, and we can update their content. 
+ """ + kwargs['collections'] = [tf.GraphKeys.LOCAL_VARIABLES] + return getter(name, *args, **kwargs) + + +class TensorFlowArchitecture(Architecture): + def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, name: str= "", + global_network=None, network_is_local: bool=True, network_is_trainable: bool=False): + """ + :param agent_parameters: the agent parameters + :param spaces: the spaces definition of the agent + :param name: the name of the network + :param global_network: the global network replica that is shared between all the workers + :param network_is_local: is the network global (shared between workers) or local (dedicated to the worker) + :param network_is_trainable: is the network trainable (we can apply gradients on it) + """ + super().__init__(agent_parameters, spaces, name) + self.middleware = None + self.network_is_local = network_is_local + self.global_network = global_network + if not self.network_parameters.tensorflow_support: + raise ValueError('TensorFlow is not supported for this agent') + self.sess = None + self.inputs = {} + self.outputs = [] + self.targets = [] + self.importance_weights = [] + self.losses = [] + self.total_loss = None + self.trainable_weights = [] + self.weights_placeholders = [] + self.shared_accumulated_gradients = [] + self.curr_rnn_c_in = None + self.curr_rnn_h_in = None + self.gradients_wrt_inputs = [] + self.train_writer = None + self.accumulated_gradients = None + self.network_is_trainable = network_is_trainable + + self.is_chief = self.ap.task_parameters.task_index == 0 + self.network_is_global = not self.network_is_local and global_network is None + self.distributed_training = self.network_is_global or self.network_is_local and global_network is not None + + self.optimizer_type = self.network_parameters.optimizer_type + if self.ap.task_parameters.seed is not None: + tf.set_random_seed(self.ap.task_parameters.seed) + with tf.variable_scope("/".join(self.name.split("/")[1:]), initializer=tf.contrib.layers.xavier_initializer(), + custom_getter=local_getter if network_is_local and global_network else None): + self.global_step = tf.train.get_or_create_global_step() + + # build the network + self.weights = self.get_model() + + # create the placeholder for the assigning gradients and some tensorboard summaries for the weights + for idx, var in enumerate(self.weights): + placeholder = tf.placeholder(tf.float32, shape=var.get_shape(), name=str(idx) + '_holder') + self.weights_placeholders.append(placeholder) + if self.ap.visualization.tensorboard: + variable_summaries(var) + + # create op for assigning a list of weights to the network weights + self.update_weights_from_list = [weights.assign(holder) for holder, weights in + zip(self.weights_placeholders, self.weights)] + + # locks for synchronous training + if self.network_is_global: + self._create_locks_for_synchronous_training() + + # gradients ops + self._create_gradient_ops() + + self.inc_step = self.global_step.assign_add(1) + + # reset LSTM hidden cells + self.reset_internal_memory() + + if self.ap.visualization.tensorboard: + current_scope_summaries = tf.get_collection(tf.GraphKeys.SUMMARIES, + scope=tf.contrib.framework.get_name_scope()) + self.merged = tf.summary.merge(current_scope_summaries) + + # initialize or restore model + self.init_op = tf.group( + tf.global_variables_initializer(), + tf.local_variables_initializer() + ) + + # set the fetches for training + self._set_initial_fetch_list() + + def get_model(self) -> List: + """ + Constructs the model using 
`network_parameters` and sets `input_embedders`, `middleware`, + `output_heads`, `outputs`, `losses`, `total_loss`, `adaptive_learning_rate_scheme`, + `current_learning_rate`, and `optimizer`. + + :return: A list of the model's weights + """ + raise NotImplementedError + + def _set_initial_fetch_list(self): + """ + Create an initial list of tensors to fetch in each training iteration + :return: None + """ + self.train_fetches = [self.gradients_norm] + if self.network_parameters.clip_gradients: + self.train_fetches.append(self.clipped_grads) + else: + self.train_fetches.append(self.tensor_gradients) + self.train_fetches += [self.total_loss, self.losses] + if self.middleware.__class__.__name__ == 'LSTMMiddleware': + self.train_fetches.append(self.middleware.state_out) + self.additional_fetches_start_idx = len(self.train_fetches) + + def _create_locks_for_synchronous_training(self): + """ + Create locks for synchronizing the different workers during training + :return: None + """ + self.lock_counter = tf.get_variable("lock_counter", [], tf.int32, + initializer=tf.constant_initializer(0, dtype=tf.int32), + trainable=False) + self.lock = self.lock_counter.assign_add(1, use_locking=True) + self.lock_init = self.lock_counter.assign(0) + + self.release_counter = tf.get_variable("release_counter", [], tf.int32, + initializer=tf.constant_initializer(0, dtype=tf.int32), + trainable=False) + self.release = self.release_counter.assign_add(1, use_locking=True) + self.release_decrement = self.release_counter.assign_add(-1, use_locking=True) + self.release_init = self.release_counter.assign(0) + + def _create_gradient_ops(self): + """ + Create all the tensorflow operations for calculating gradients, processing the gradients and applying them + :return: None + """ + + self.tensor_gradients = tf.gradients(self.total_loss, self.weights) + self.gradients_norm = tf.global_norm(self.tensor_gradients) + + # gradient clipping + if self.network_parameters.clip_gradients is not None and self.network_parameters.clip_gradients != 0: + self._create_gradient_clipping_ops() + + # when using a shared optimizer, we create accumulators to store gradients from all the workers before + # applying them + if self.distributed_training: + self._create_gradient_accumulators() + + # gradients of the outputs w.r.t. the inputs + self.gradients_wrt_inputs = [{name: tf.gradients(output, input_ph) for name, input_ph in + self.inputs.items()} for output in self.outputs] + self.gradients_weights_ph = [tf.placeholder('float32', self.outputs[i].shape, 'output_gradient_weights') + for i in range(len(self.outputs))] + self.weighted_gradients = [] + for i in range(len(self.outputs)): + unnormalized_gradients = tf.gradients(self.outputs[i], self.weights, self.gradients_weights_ph[i]) + # unnormalized gradients seem to be better for now. 
TODO: validate this across more environments + # self.weighted_gradients.append(list(map(lambda x: tf.div(x, self.network_parameters.batch_size), + # unnormalized_gradients))) + self.weighted_gradients.append(unnormalized_gradients) + + # defining the optimization process (for LBFGS we have less control over the optimizer) + if self.optimizer_type != 'LBFGS' and self.network_is_trainable: + self._create_gradient_applying_ops() + + def _create_gradient_accumulators(self): + if self.network_is_global: + self.shared_accumulated_gradients = [tf.Variable(initial_value=tf.zeros_like(var)) for var in self.weights] + self.accumulate_shared_gradients = [var.assign_add(holder, use_locking=True) for holder, var in + zip(self.weights_placeholders, self.shared_accumulated_gradients)] + self.init_shared_accumulated_gradients = [var.assign(tf.zeros_like(var)) for var in + self.shared_accumulated_gradients] + elif self.network_is_local: + self.accumulate_shared_gradients = self.global_network.accumulate_shared_gradients + self.init_shared_accumulated_gradients = self.global_network.init_shared_accumulated_gradients + + def _create_gradient_clipping_ops(self): + """ + Create tensorflow ops for clipping the gradients according to the given GradientClippingMethod + :return: None + """ + if self.network_parameters.gradients_clipping_method == GradientClippingMethod.ClipByGlobalNorm: + self.clipped_grads, self.grad_norms = tf.clip_by_global_norm(self.tensor_gradients, + self.network_parameters.clip_gradients) + elif self.network_parameters.gradients_clipping_method == GradientClippingMethod.ClipByValue: + self.clipped_grads = [tf.clip_by_value(grad, + -self.network_parameters.clip_gradients, + self.network_parameters.clip_gradients) + for grad in self.tensor_gradients] + elif self.network_parameters.gradients_clipping_method == GradientClippingMethod.ClipByNorm: + self.clipped_grads = [tf.clip_by_norm(grad, self.network_parameters.clip_gradients) + for grad in self.tensor_gradients] + + def _create_gradient_applying_ops(self): + """ + Create tensorflow ops for applying the gradients to the network weights according to the training scheme + (distributed training - local or global network, shared optimizer, etc.) + :return: None + """ + if self.network_is_global and self.network_parameters.shared_optimizer and \ + not self.network_parameters.async_training: + # synchronous training with shared optimizer? -> create an operation for applying the gradients + # accumulated in the shared gradients accumulator + self.update_weights_from_shared_gradients = self.optimizer.apply_gradients( + zip(self.shared_accumulated_gradients, self.weights), + global_step=self.global_step) + + elif self.distributed_training and self.network_is_local: + # distributed training but independent optimizer? -> create an operation for applying the gradients + # to the global weights + self.update_weights_from_batch_gradients = self.optimizer.apply_gradients( + zip(self.weights_placeholders, self.global_network.weights), global_step=self.global_step) + + elif self.network_is_trainable: + # not any of the above but is trainable? 
-> create an operation for applying the gradients to + # this network's weights + update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope=self.full_name) + + with tf.control_dependencies(update_ops): + self.update_weights_from_batch_gradients = self.optimizer.apply_gradients( + zip(self.weights_placeholders, self.weights), global_step=self.global_step) + + def set_session(self, sess): + self.sess = sess + + task_is_distributed = isinstance(self.ap.task_parameters, DistributedTaskParameters) + # initialize the session parameters in single threaded runs. Otherwise, this is done through the + # MonitoredSession object in the graph manager + if not task_is_distributed: + self.sess.run(self.init_op) + + if self.ap.visualization.tensorboard: + # Write the merged summaries to the current experiment directory + if not task_is_distributed: + self.train_writer = tf.summary.FileWriter(self.ap.task_parameters.experiment_path + '/tensorboard') + self.train_writer.add_graph(self.sess.graph) + elif self.network_is_local: + self.train_writer = tf.summary.FileWriter(self.ap.task_parameters.experiment_path + + '/tensorboard/worker{}'.format(self.ap.task_parameters.task_index)) + self.train_writer.add_graph(self.sess.graph) + + # wait for all the workers to set their session + if not self.network_is_local: + self.wait_for_all_workers_barrier() + + def reset_accumulated_gradients(self): + """ + Reset the gradients accumulation placeholder + """ + if self.accumulated_gradients is None: + self.accumulated_gradients = self.sess.run(self.weights) + + for ix, grad in enumerate(self.accumulated_gradients): + self.accumulated_gradients[ix] = grad * 0 + + def accumulate_gradients(self, inputs, targets, additional_fetches=None, importance_weights=None, + no_accumulation=False): + """ + Runs a forward pass & backward pass, clips gradients if needed and accumulates them into the accumulation + placeholders + :param additional_fetches: Optional tensors to fetch during gradients calculation + :param inputs: The input batch for the network + :param targets: The targets corresponding to the input batch + :param importance_weights: A coefficient for each sample in the batch, which will be used to rescale the loss + error of this sample. If it is not given, the samples' losses won't be scaled + :param no_accumulation: If set to True, the gradients in the accumulated gradients placeholder will be + replaced by the newly calculated gradients instead of accumulating the new gradients. + This can speed up the function runtime by around 10%. 
+ :return: A list containing the total loss and the individual network heads losses + """ + + if self.accumulated_gradients is None: + self.reset_accumulated_gradients() + + # feed inputs + if additional_fetches is None: + additional_fetches = [] + feed_dict = self.create_feed_dict(inputs) + + # feed targets + targets = force_list(targets) + for placeholder_idx, target in enumerate(targets): + feed_dict[self.targets[placeholder_idx]] = target + + # feed importance weights + importance_weights = force_list(importance_weights) + for placeholder_idx, target_ph in enumerate(targets): + if len(importance_weights) <= placeholder_idx or importance_weights[placeholder_idx] is None: + importance_weight = np.ones(target_ph.shape[0]) + else: + importance_weight = importance_weights[placeholder_idx] + importance_weight = np.reshape(importance_weight, (-1,) + (1,) * (len(target_ph.shape) - 1)) + + feed_dict[self.importance_weights[placeholder_idx]] = importance_weight + + if self.optimizer_type != 'LBFGS': + + # feed the lstm state if necessary + if self.middleware.__class__.__name__ == 'LSTMMiddleware': + # we can't always assume that we are starting from scratch here can we? + feed_dict[self.middleware.c_in] = self.middleware.c_init + feed_dict[self.middleware.h_in] = self.middleware.h_init + + fetches = self.train_fetches + additional_fetches + if self.ap.visualization.tensorboard: + fetches += [self.merged] + + # get grads + result = self.sess.run(fetches, feed_dict=feed_dict) + if hasattr(self, 'train_writer') and self.train_writer is not None: + self.train_writer.add_summary(result[-1], self.sess.run(self.global_step)) + + # extract the fetches + norm_unclipped_grads, grads, total_loss, losses = result[:4] + if self.middleware.__class__.__name__ == 'LSTMMiddleware': + (self.curr_rnn_c_in, self.curr_rnn_h_in) = result[4] + fetched_tensors = [] + if len(additional_fetches) > 0: + fetched_tensors = result[self.additional_fetches_start_idx:self.additional_fetches_start_idx + + len(additional_fetches)] + + # accumulate the gradients + for idx, grad in enumerate(grads): + if no_accumulation: + self.accumulated_gradients[idx] = grad + else: + self.accumulated_gradients[idx] += grad + + return total_loss, losses, norm_unclipped_grads, fetched_tensors + + else: + self.optimizer.minimize(session=self.sess, feed_dict=feed_dict) + + return [0] + + def create_feed_dict(self, inputs): + feed_dict = {} + for input_name, input_value in inputs.items(): + if isinstance(input_name, str): + if input_name not in self.inputs: + raise ValueError(( + 'input name {input_name} was provided to create a feed ' + 'dictionary, but there is no placeholder with that name. 
'placeholder names available include: {placeholder_names}' + ).format( + input_name=input_name, + placeholder_names=', '.join(self.inputs.keys()) + )) + + feed_dict[self.inputs[input_name]] = input_value + elif isinstance(input_name, tf.Tensor) and input_name.op.type == 'Placeholder': + feed_dict[input_name] = input_value + else: + raise ValueError(( + 'input dictionary expects strings or placeholders as keys, ' + 'but found key {key} of type {type}' + ).format( + key=input_name, + type=type(input_name), + )) + + return feed_dict + + def apply_and_reset_gradients(self, gradients, scaler=1., additional_inputs=None): + """ + Applies the given gradients to the network weights and resets the accumulation placeholder + :param gradients: The gradients to use for the update + :param scaler: A scaling factor that allows rescaling the gradients before applying them + :param additional_inputs: optional additional inputs required when applying the gradients (e.g. batchnorm's + update ops also require the inputs) + + """ + self.apply_gradients(gradients, scaler, additional_inputs=additional_inputs) + self.reset_accumulated_gradients() + + def wait_for_all_workers_to_lock(self, lock: str, include_only_training_workers: bool=False): + """ + Waits for all the workers to lock a certain lock and then continues + :param lock: the name of the lock to use + :param include_only_training_workers: wait only for training workers or for all the workers? + :return: None + """ + if include_only_training_workers: + num_workers_to_wait_for = self.ap.task_parameters.num_training_tasks + else: + num_workers_to_wait_for = self.ap.task_parameters.num_tasks + + # lock + if hasattr(self, '{}_counter'.format(lock)): + self.sess.run(getattr(self, lock)) + while self.sess.run(getattr(self, '{}_counter'.format(lock))) % num_workers_to_wait_for != 0: + time.sleep(0.00001) + # self.sess.run(getattr(self, '{}_init'.format(lock))) + else: + raise ValueError("no counter was defined for the lock {}".format(lock)) + + def wait_for_all_workers_barrier(self, include_only_training_workers: bool=False): + """ + A barrier that allows waiting for all the workers to finish a certain block of commands + :param include_only_training_workers: wait only for training workers or for all the workers? + :return: None + """ + self.wait_for_all_workers_to_lock('lock', include_only_training_workers=include_only_training_workers) + self.sess.run(self.lock_init) + + # we need to lock again (on a different lock) in order to prevent a situation where one of the workers continues + # and manages to increase the lock again by one, only to have a late worker reset it again. + # So we want to make sure that all workers are done resetting the lock before continuing to reuse that lock. + + self.wait_for_all_workers_to_lock('release', include_only_training_workers=include_only_training_workers) + self.sess.run(self.release_init) + + def apply_gradients(self, gradients, scaler=1., additional_inputs=None): + """ + Applies the given gradients to the network weights + :param gradients: The gradients to use for the update + :param scaler: A scaling factor that allows rescaling the gradients before applying them. + The gradients will be MULTIPLIED by this factor + :param additional_inputs: optional additional inputs required when applying the gradients (e.g. 
batchnorm's + update ops also require the inputs) + """ + + if self.network_parameters.async_training or not isinstance(self.ap.task_parameters, DistributedTaskParameters): + if hasattr(self, 'global_step') and not self.network_is_local: + self.sess.run(self.inc_step) + + if self.optimizer_type != 'LBFGS': + + if self.distributed_training and not self.network_parameters.async_training: + # rescale the gradients so that they average out with the gradients from the other workers + if self.network_parameters.scale_down_gradients_by_number_of_workers_for_sync_training: + scaler /= float(self.ap.task_parameters.num_training_tasks) + + # rescale the gradients + if scaler != 1.: + for gradient in gradients: + gradient *= scaler + + # apply the gradients + feed_dict = dict(zip(self.weights_placeholders, gradients)) + if self.distributed_training and self.network_parameters.shared_optimizer \ + and not self.network_parameters.async_training: + # synchronous distributed training with shared optimizer: + # - each worker adds its gradients to the shared gradients accumulators + # - we wait for all the workers to add their gradients + # - the chief worker (worker with task index = 0) applies the gradients once and resets the accumulators + + self.sess.run(self.accumulate_shared_gradients, feed_dict=feed_dict) + + self.wait_for_all_workers_barrier(include_only_training_workers=True) + + if self.is_chief: + self.sess.run(self.update_weights_from_shared_gradients) + self.sess.run(self.init_shared_accumulated_gradients) + else: + # async distributed training / distributed training with independent optimizer + # / non-distributed training - just apply the gradients + feed_dict = dict(zip(self.weights_placeholders, gradients)) + if additional_inputs is not None: + feed_dict = {**feed_dict, **self.create_feed_dict(additional_inputs)} + self.sess.run(self.update_weights_from_batch_gradients, feed_dict=feed_dict) + + # release barrier + if self.distributed_training and not self.network_parameters.async_training: + self.wait_for_all_workers_barrier(include_only_training_workers=True) + + def predict(self, inputs, outputs=None, squeeze_output=True, initial_feed_dict=None): + """ + Run a forward pass of the network using the given input + :param inputs: The input for the network + :param outputs: The output for the network, defaults to self.outputs + :param squeeze_output: call squeeze_list on output + :param initial_feed_dict: a dictionary to use as the initial feed_dict. other inputs will be added to this dict + :return: The network output + + WARNING: must only be called once per state since each call is assumed by the LSTM to be a new time step. 
+ """ + feed_dict = self.create_feed_dict(inputs) + if initial_feed_dict: + feed_dict.update(initial_feed_dict) + if outputs is None: + outputs = self.outputs + + if self.middleware.__class__.__name__ == 'LSTMMiddleware': + feed_dict[self.middleware.c_in] = self.curr_rnn_c_in + feed_dict[self.middleware.h_in] = self.curr_rnn_h_in + + output, (self.curr_rnn_c_in, self.curr_rnn_h_in) = self.sess.run([outputs, self.middleware.state_out], + feed_dict=feed_dict) + else: + output = self.sess.run(outputs, feed_dict) + + if squeeze_output: + output = squeeze_list(output) + return output + + @staticmethod + def parallel_predict(sess: Any, + network_input_tuples: List[Tuple['TensorFlowArchitecture', Dict[str, np.ndarray]]]) ->\ + List[np.ndarray]: + """ + :param sess: active session to use for prediction + :param network_input_tuples: tuple of network and corresponding input + :return: list of outputs from all networks + """ + feed_dict = {} + fetches = [] + + for network, input in network_input_tuples: + feed_dict.update(network.create_feed_dict(input)) + fetches += network.outputs + + outputs = sess.run(fetches, feed_dict) + + return outputs + + def train_on_batch(self, inputs, targets, scaler=1., additional_fetches=None, importance_weights=None): + """ + Given a batch of examples and targets, runs a forward pass & backward pass and then applies the gradients + :param additional_fetches: Optional tensors to fetch during the training process + :param inputs: The input for the network + :param targets: The targets corresponding to the input batch + :param scaler: A scaling factor that allows rescaling the gradients before applying them + :param importance_weights: A coefficient for each sample in the batch, which will be used to rescale the loss + error of this sample. If it is not given, the samples losses won't be scaled + :return: The loss of the network + """ + if additional_fetches is None: + additional_fetches = [] + force_list(additional_fetches) + loss = self.accumulate_gradients(inputs, targets, additional_fetches=additional_fetches, + importance_weights=importance_weights) + self.apply_and_reset_gradients(self.accumulated_gradients, scaler) + return loss + + def get_weights(self): + """ + :return: a list of tensors containing the network weights for each layer + """ + return self.weights + + def set_weights(self, weights, new_rate=1.0): + """ + Sets the network weights from the given list of weights tensors + """ + feed_dict = {} + old_weights, new_weights = self.sess.run([self.get_weights(), weights]) + for placeholder_idx, new_weight in enumerate(new_weights): + feed_dict[self.weights_placeholders[placeholder_idx]]\ + = new_rate * new_weight + (1 - new_rate) * old_weights[placeholder_idx] + self.sess.run(self.update_weights_from_list, feed_dict) + + def get_variable_value(self, variable): + """ + Get the value of a variable from the graph + :param variable: the variable + :return: the value of the variable + """ + return self.sess.run(variable) + + def set_variable_value(self, assign_op, value, placeholder=None): + """ + Updates the value of a variable. 
+ This requires having an assign operation for the variable, and a placeholder which will provide the value + :param assign_op: an assign operation for the variable + :param value: a value to set the variable to + :param placeholder: a placeholder to hold the given value for injecting it into the variable + """ + self.sess.run(assign_op, feed_dict={placeholder: value}) + + def set_is_training(self, state: bool): + """ + Set the phase of the network between training and testing + :param state: The current state (True = Training, False = Testing) + :return: None + """ + self.set_variable_value(self.assign_is_training, state, self.is_training_placeholder) + + def reset_internal_memory(self): + """ + Reset any internal memory used by the network. For example, an LSTM internal state + :return: None + """ + # initialize LSTM hidden states + if self.middleware.__class__.__name__ == 'LSTMMiddleware': + self.curr_rnn_c_in = self.middleware.c_init + self.curr_rnn_h_in = self.middleware.h_init + + def collect_savers(self, parent_path_suffix: str) -> SaverCollection: + """ + Collection of all checkpoints for the network (typically only one checkpoint) + :param parent_path_suffix: path suffix of the parent of the network + (e.g. could be name of level manager plus name of agent) + :return: checkpoint collection for the network + """ + savers = SaverCollection() + if not self.distributed_training: + savers.add(GlobalVariableSaver(self.name)) + return savers + + +def save_onnx_graph(input_nodes, output_nodes, checkpoint_save_dir: str) -> None: + """ + Given the input nodes and output nodes of the TF graph, save it as an onnx graph + This requires the TF graph and the weights checkpoint to be stored in the experiment directory. + It then freezes the graph (merging the graph and weights checkpoint), and converts it to ONNX. 
+ + :param input_nodes: A list of input nodes for the TF graph + :param output_nodes: A list of output nodes for the TF graph + :param checkpoint_save_dir: The directory to save the ONNX graph to + :return: None + """ + import tf2onnx # just to verify that tf2onnx is installed + + # freeze graph + frozen_graph_path = os.path.join(checkpoint_save_dir, "frozen_graph.pb") + freeze_graph_command = [ + "python -m tensorflow.python.tools.freeze_graph", + "--input_graph={}".format(os.path.join(checkpoint_save_dir, "graphdef.pb")), + "--input_binary=true", + "--output_node_names='{}'".format(','.join([o.split(":")[0] for o in output_nodes])), + "--input_checkpoint={}".format(tf.train.latest_checkpoint(checkpoint_save_dir)), + "--output_graph={}".format(frozen_graph_path) + ] + start_shell_command_and_wait(" ".join(freeze_graph_command)) + + # convert graph to onnx + onnx_graph_path = os.path.join(checkpoint_save_dir, "model.onnx") + convert_to_onnx_command = [ + "python -m tf2onnx.convert", + "--input {}".format(frozen_graph_path), + "--inputs '{}'".format(','.join(input_nodes)), + "--outputs '{}'".format(','.join(output_nodes)), + "--output {}".format(onnx_graph_path), + "--verbose" + ] + start_shell_command_and_wait(" ".join(convert_to_onnx_command)) diff --git a/rl_coach/architectures/legacy_tf_components/distributed_tf_utils.py b/rl_coach/architectures/legacy_tf_components/distributed_tf_utils.py new file mode 100644 index 000000000..bbbbc0f23 --- /dev/null +++ b/rl_coach/architectures/legacy_tf_components/distributed_tf_utils.py @@ -0,0 +1,103 @@ +# +# Copyright (c) 2017 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Tuple + +import tensorflow as tf + + +def create_cluster_spec(parameters_server: str, workers: str) -> tf.train.ClusterSpec: + """ + Creates a ClusterSpec object representing the cluster. 
+ :param parameters_server: comma-separated list of hostname:port pairs to which the parameter servers are assigned + :param workers: comma-separated list of hostname:port pairs to which the workers are assigned + :return: a ClusterSpec object representing the cluster + """ + # extract the parameter servers and workers from the given strings + ps_hosts = parameters_server.split(",") + worker_hosts = workers.split(",") + + # Create a cluster spec from the parameter server and worker hosts + cluster_spec = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts}) + + return cluster_spec + + +def create_and_start_parameters_server(cluster_spec: tf.train.ClusterSpec, config: tf.ConfigProto=None) -> None: + """ + Create and start a parameter server + :param cluster_spec: the ClusterSpec object representing the cluster + :param config: the tensorflow config to use + :return: None + """ + # create a server object for the parameter server + server = tf.train.Server(cluster_spec, job_name="ps", task_index=0, config=config) + + # wait for the server to finish + server.join() + + +def create_worker_server_and_device(cluster_spec: tf.train.ClusterSpec, task_index: int, + use_cpu: bool=True, config: tf.ConfigProto=None) -> Tuple[str, tf.device]: + """ + Creates a worker server and a device setter used to assign the workers operations to + :param cluster_spec: a ClusterSpec object representing the cluster + :param task_index: the index of the worker task + :param use_cpu: if use_cpu=True, all the agent operations will be assigned to a CPU instead of a GPU + :param config: the tensorflow config to use + :return: the target string for the tf.Session and the worker device setter object + """ + # Create and start a worker + server = tf.train.Server(cluster_spec, job_name="worker", task_index=task_index, config=config) + + # Assign ops to the local worker + worker_device = "/job:worker/task:{}".format(task_index) + if use_cpu: + worker_device += "/cpu:0" + else: + worker_device += "/device:GPU:0" + device = tf.train.replica_device_setter(worker_device=worker_device, cluster=cluster_spec) + + return server.target, device + + +def create_monitored_session(target: tf.train.Server, task_index: int, + checkpoint_dir: str, checkpoint_save_secs: int, config: tf.ConfigProto=None) -> tf.Session: + """ + Create a monitored session for the worker + :param target: the target string for the tf.Session + :param task_index: the task index of the worker + :param checkpoint_dir: a directory path where the checkpoints will be stored + :param checkpoint_save_secs: number of seconds between checkpoints storing + :param config: the tensorflow configuration (optional) + :return: the session to use for the run + """ + # we chose the first task to be the chief + is_chief = task_index == 0 + + # Create the monitored session + sess = tf.train.MonitoredTrainingSession( + master=target, + is_chief=is_chief, + hooks=[], + checkpoint_dir=checkpoint_dir, + save_checkpoint_secs=checkpoint_save_secs, + config=config, + log_step_count_steps=0 # disable logging of steps to avoid TF warning during inference + ) + + return sess + diff --git a/rl_coach/architectures/legacy_tf_components/embedders/__init__.py b/rl_coach/architectures/legacy_tf_components/embedders/__init__.py new file mode 100644 index 000000000..5091f35c1 --- /dev/null +++ b/rl_coach/architectures/legacy_tf_components/embedders/__init__.py @@ -0,0 +1,5 @@ +from .image_embedder import ImageEmbedder +from .vector_embedder import VectorEmbedder +from .tensor_embedder import 
TensorEmbedder + +__all__ = ['ImageEmbedder', 'VectorEmbedder', 'TensorEmbedder'] diff --git a/rl_coach/architectures/legacy_tf_components/embedders/embedder.py b/rl_coach/architectures/legacy_tf_components/embedders/embedder.py new file mode 100644 index 000000000..13544c9ac --- /dev/null +++ b/rl_coach/architectures/legacy_tf_components/embedders/embedder.py @@ -0,0 +1,157 @@ +# +# Copyright (c) 2017 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import List, Union, Tuple +import copy + +import numpy as np +import tensorflow as tf + +from rl_coach.architectures.tensorflow_components.layers import BatchnormActivationDropout, convert_layer, Dense +from rl_coach.base_parameters import EmbedderScheme, NetworkComponentParameters + +from rl_coach.core_types import InputEmbedding +from rl_coach.utils import force_list + + +class InputEmbedder(object): + """ + An input embedder is the first part of the network, which takes the input from the state and produces a vector + embedding by passing it through a neural network. The embedder will mostly be input type dependent, and there + can be multiple embedders in a single network + """ + def __init__(self, input_size: List[int], activation_function=tf.nn.relu, + scheme: EmbedderScheme=None, batchnorm: bool=False, dropout_rate: float=0.0, + name: str= "embedder", input_rescaling=1.0, input_offset=0.0, input_clipping=None, dense_layer=Dense, + is_training=False): + self.name = name + self.input_size = input_size + self.activation_function = activation_function + self.batchnorm = batchnorm + self.dropout_rate = dropout_rate + self.input = None + self.output = None + self.scheme = scheme + self.return_type = InputEmbedding + self.layers_params = [] + self.layers = [] + self.input_rescaling = input_rescaling + self.input_offset = input_offset + self.input_clipping = input_clipping + self.dense_layer = dense_layer + if self.dense_layer is None: + self.dense_layer = Dense + self.is_training = is_training + + # layers order is conv -> batchnorm -> activation -> dropout + if isinstance(self.scheme, EmbedderScheme): + self.layers_params = copy.copy(self.schemes[self.scheme]) + self.layers_params = [convert_layer(l) for l in self.layers_params] + else: + # if scheme is specified directly, convert to TF layer if it's not a callable object + # NOTE: if layer object is callable, it must return a TF tensor when invoked + self.layers_params = [convert_layer(l) for l in copy.copy(self.scheme)] + + # we allow adding batchnorm, dropout or activation functions after each layer. 
+ # The motivation is to simplify the transition between a network with batchnorm and a network without + # batchnorm to a single flag (the same applies to activation function and dropout) + if self.batchnorm or self.activation_function or self.dropout_rate > 0: + for layer_idx in reversed(range(len(self.layers_params))): + self.layers_params.insert(layer_idx+1, + BatchnormActivationDropout(batchnorm=self.batchnorm, + activation_function=self.activation_function, + dropout_rate=self.dropout_rate)) + + def __call__(self, prev_input_placeholder: tf.placeholder=None) -> Tuple[tf.Tensor, tf.Tensor]: + """ + Wrapper for building the module graph including scoping and loss creation + :param prev_input_placeholder: the input to the graph + :return: the input placeholder and the output of the last layer + """ + with tf.variable_scope(self.get_name()): + if prev_input_placeholder is None: + self.input = tf.placeholder("float", shape=[None] + self.input_size, name=self.get_name()) + else: + self.input = prev_input_placeholder + self._build_module() + + return self.input, self.output + + def _build_module(self) -> None: + """ + Builds the graph of the module + This method is called early on from __call__. It is expected to store the graph + in self.output. + :return: None + """ + # NOTE: for image inputs, we expect the data format to be of type uint8, to be memory efficient. We chose not + # to implement the rescaling as an input filter (filters.observation.observation_filter), as this would have + # caused the input to the network to be float, which is 4x more expensive in memory, + # thus causing each saved transition in the memory to also be 4x more expensive. + + input_layer = self.input / self.input_rescaling + input_layer -= self.input_offset + # clip input using the given range + if self.input_clipping is not None: + input_layer = tf.clip_by_value(input_layer, self.input_clipping[0], self.input_clipping[1]) + + self.layers.append(input_layer) + + for idx, layer_params in enumerate(self.layers_params): + self.layers.extend(force_list( + layer_params(input_layer=self.layers[-1], name='{}_{}'.format(layer_params.__class__.__name__, idx), + is_training=self.is_training) + )) + + self.output = tf.contrib.layers.flatten(self.layers[-1]) + + @property + def input_size(self) -> List[int]: + return self._input_size + + @input_size.setter + def input_size(self, value: Union[int, List[int]]): + if isinstance(value, np.ndarray) or isinstance(value, tuple): + value = list(value) + elif isinstance(value, int): + value = [value] + if not isinstance(value, list): + raise ValueError(( + 'input_size expected to be a list, found {value} which has type {type}' + ).format(value=value, type=type(value))) + self._input_size = value + + @property + def schemes(self): + raise NotImplementedError("Inheriting embedder must define schemes matching its allowed default " + "configurations.") + + def get_name(self) -> str: + """ + Get a formatted name for the module + :return: the formatted name + """ + return self.name + + def __str__(self): + result = ['Input size = {}'.format(self._input_size)] + if self.input_rescaling != 1.0 or self.input_offset != 0.0: + result.append('Input Normalization (scale = {}, offset = {})'.format(self.input_rescaling, self.input_offset)) + result.extend([str(l) for l in self.layers_params]) + if not self.layers_params: + result.append('No layers') + + return '\n'.join(result) diff --git a/rl_coach/architectures/legacy_tf_components/embedders/image_embedder.py 
b/rl_coach/architectures/legacy_tf_components/embedders/image_embedder.py new file mode 100644 index 000000000..b05ec8e03 --- /dev/null +++ b/rl_coach/architectures/legacy_tf_components/embedders/image_embedder.py @@ -0,0 +1,78 @@ +# +# Copyright (c) 2017 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import List + +import tensorflow as tf + +from rl_coach.architectures.tensorflow_components.layers import Conv2d, Dense +from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedder +from rl_coach.base_parameters import EmbedderScheme +from rl_coach.core_types import InputImageEmbedding + + +class ImageEmbedder(InputEmbedder): + """ + An input embedder that performs convolutions on the input and then flattens the result. + The embedder is intended for image like inputs, where the channels are expected to be the last axis. + The embedder also allows custom rescaling of the input prior to the neural network. + """ + + def __init__(self, input_size: List[int], activation_function=tf.nn.relu, + scheme: EmbedderScheme=EmbedderScheme.Medium, batchnorm: bool=False, dropout_rate: float=0.0, + name: str= "embedder", input_rescaling: float=255.0, input_offset: float=0.0, input_clipping=None, + dense_layer=Dense, is_training=False): + super().__init__(input_size, activation_function, scheme, batchnorm, dropout_rate, name, input_rescaling, + input_offset, input_clipping, dense_layer=dense_layer, is_training=is_training) + self.return_type = InputImageEmbedding + if len(input_size) != 3 and scheme != EmbedderScheme.Empty: + raise ValueError("Image embedders expect the input size to have 3 dimensions. The given size is: {}" + .format(input_size)) + + @property + def schemes(self): + return { + EmbedderScheme.Empty: + [], + + EmbedderScheme.Shallow: + [ + Conv2d(32, 3, 1) + ], + + # atari dqn + EmbedderScheme.Medium: + [ + Conv2d(32, 8, 4), + Conv2d(64, 4, 2), + Conv2d(64, 3, 1) + ], + + # carla + EmbedderScheme.Deep: \ + [ + Conv2d(32, 5, 2), + Conv2d(32, 3, 1), + Conv2d(64, 3, 2), + Conv2d(64, 3, 1), + Conv2d(128, 3, 2), + Conv2d(128, 3, 1), + Conv2d(256, 3, 2), + Conv2d(256, 3, 1) + ] + } + + diff --git a/rl_coach/architectures/legacy_tf_components/embedders/tensor_embedder.py b/rl_coach/architectures/legacy_tf_components/embedders/tensor_embedder.py new file mode 100644 index 000000000..286442c4e --- /dev/null +++ b/rl_coach/architectures/legacy_tf_components/embedders/tensor_embedder.py @@ -0,0 +1,52 @@ +# +# Copyright (c) 2017 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import List + +import tensorflow as tf + +from rl_coach.architectures.tensorflow_components.layers import Conv2d, Dense +from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedder +from rl_coach.base_parameters import EmbedderScheme +from rl_coach.core_types import InputTensorEmbedding + + +class TensorEmbedder(InputEmbedder): + """ + A tensor embedder is an input embedder that takes a tensor with arbitrary dimension and produces a vector + embedding by passing it through a neural network. Examples are video data or 3D image data (i.e. 4D tensors), + or any other type of data that has more than one dimension (i.e. is not a vector) but is not an image. + + NOTE: There are no pre-defined schemes for the tensor embedder. The user must define a custom scheme by passing + a callable object as InputEmbedderParameters.scheme when defining the respective preset. This callable + object must accept a single input, the normalized observation, and return a Tensorflow symbol which + will calculate an embedding vector for each sample in the batch. + Keep in mind that the scheme is a list of Tensorflow symbols, which are stacked by optional batchnorm, + activation, and dropout in between as specified in InputEmbedderParameters. + """ + + def __init__(self, input_size: List[int], activation_function=tf.nn.relu, + scheme: EmbedderScheme=None, batchnorm: bool=False, dropout_rate: float=0.0, + name: str= "embedder", input_rescaling: float=1.0, input_offset: float=0.0, input_clipping=None, + dense_layer=Dense, is_training=False): + super().__init__(input_size, activation_function, scheme, batchnorm, dropout_rate, name, input_rescaling, + input_offset, input_clipping, dense_layer=dense_layer, is_training=is_training) + self.return_type = InputTensorEmbedding + assert scheme is not None, "Custom scheme (a list of callables) must be specified for TensorEmbedder" + + @property + def schemes(self): + return {} diff --git a/rl_coach/architectures/legacy_tf_components/embedders/vector_embedder.py b/rl_coach/architectures/legacy_tf_components/embedders/vector_embedder.py new file mode 100644 index 000000000..60b728dbd --- /dev/null +++ b/rl_coach/architectures/legacy_tf_components/embedders/vector_embedder.py @@ -0,0 +1,69 @@ +# +# Copyright (c) 2017 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import List + +import tensorflow as tf + +from rl_coach.architectures.tensorflow_components.layers import Dense +from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedder +from rl_coach.base_parameters import EmbedderScheme +from rl_coach.core_types import InputVectorEmbedding + + +class VectorEmbedder(InputEmbedder): + """ + An input embedder that is intended for inputs that can be represented as vectors. + The embedder flattens the input, applies several dense layers to it and returns the output. 
+ """ + + def __init__(self, input_size: List[int], activation_function=tf.nn.relu, + scheme: EmbedderScheme=EmbedderScheme.Medium, batchnorm: bool=False, dropout_rate: float=0.0, + name: str= "embedder", input_rescaling: float=1.0, input_offset: float=0.0, input_clipping=None, + dense_layer=Dense, is_training=False): + super().__init__(input_size, activation_function, scheme, batchnorm, dropout_rate, name, + input_rescaling, input_offset, input_clipping, dense_layer=dense_layer, + is_training=is_training) + + self.return_type = InputVectorEmbedding + if len(self.input_size) != 1 and scheme != EmbedderScheme.Empty: + raise ValueError("The input size of a vector embedder must contain only a single dimension") + + @property + def schemes(self): + return { + EmbedderScheme.Empty: + [], + + EmbedderScheme.Shallow: + [ + self.dense_layer(128) + ], + + # dqn + EmbedderScheme.Medium: + [ + self.dense_layer(256) + ], + + # carla + EmbedderScheme.Deep: \ + [ + self.dense_layer(128), + self.dense_layer(128), + self.dense_layer(128) + ] + } diff --git a/rl_coach/architectures/legacy_tf_components/general_network.py b/rl_coach/architectures/legacy_tf_components/general_network.py new file mode 100644 index 000000000..8821ac6cc --- /dev/null +++ b/rl_coach/architectures/legacy_tf_components/general_network.py @@ -0,0 +1,449 @@ +# +# Copyright (c) 2017 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import copy +from types import MethodType +from typing import Dict, List, Union + +import numpy as np +import tensorflow as tf + +from rl_coach.architectures.embedder_parameters import InputEmbedderParameters +from rl_coach.architectures.head_parameters import HeadParameters +from rl_coach.architectures.middleware_parameters import MiddlewareParameters +from rl_coach.architectures.tensorflow_components.architecture import TensorFlowArchitecture +from rl_coach.architectures.tensorflow_components import utils +from rl_coach.base_parameters import AgentParameters, Device, DeviceType, EmbeddingMergerType +from rl_coach.core_types import PredictionType +from rl_coach.logger import screen +from rl_coach.spaces import SpacesDefinition, PlanarMapsObservationSpace, TensorObservationSpace +from rl_coach.utils import get_all_subclasses, dynamic_import_and_instantiate_module_from_params, indent_string + + +class GeneralTensorFlowNetwork(TensorFlowArchitecture): + """ + A generalized version of all possible networks implemented using tensorflow. 
+    """
+    # dictionary of variable-scope name to variable-scope object to prevent tensorflow from
+    # creating a new auxiliary variable scope even when name is properly specified
+    variable_scopes_dict = dict()
+
+    @staticmethod
+    def construct(variable_scope: str, devices: List[str], *args, **kwargs) -> 'GeneralTensorFlowNetwork':
+        """
+        Construct a network class using the provided variable scope and on requested devices
+        :param variable_scope: string specifying variable scope under which to create network variables
+        :param devices: list of devices (can be list of Device objects, or string for TF distributed)
+        :param args: all other arguments for class initializer
+        :param kwargs: all other keyword arguments for class initializer
+        :return: a GeneralTensorFlowNetwork object
+        """
+        if len(devices) > 1:
+            screen.warning("The TensorFlow implementation only supports a single device. Using {}".format(devices[0]))
+
+        def construct_on_device():
+            with tf.device(GeneralTensorFlowNetwork._tf_device(devices[0])):
+                return GeneralTensorFlowNetwork(*args, **kwargs)
+
+        # If variable_scope is in our dictionary, then this is not the first time that this variable_scope
+        # is being used with construct(). So to avoid TF adding an incrementing number to the end of the
+        # variable_scope to uniquify it, we have to both pass the previous variable_scope object to the new
+        # variable_scope() call and also recover the name scope using name_scope
+        if variable_scope in GeneralTensorFlowNetwork.variable_scopes_dict:
+            variable_scope = GeneralTensorFlowNetwork.variable_scopes_dict[variable_scope]
+            with tf.variable_scope(variable_scope, auxiliary_name_scope=False) as vs:
+                with tf.name_scope(vs.original_name_scope):
+                    return construct_on_device()
+        else:
+            with tf.variable_scope(variable_scope, auxiliary_name_scope=True) as vs:
+                # Add variable_scope object to dictionary for next call to construct
+                GeneralTensorFlowNetwork.variable_scopes_dict[variable_scope] = vs
+                return construct_on_device()
+
+    @staticmethod
+    def _tf_device(device: Union[str, MethodType, Device]) -> str:
+        """
+        Convert device to tensorflow-specific device representation
+        :param device: either a specific string or method (used in distributed mode) which is returned without
+                       any change, or a Device type, which will be converted to a string
+        :return: tensorflow-specific string for device
+        """
+        if isinstance(device, str) or isinstance(device, MethodType):
+            return device
+        elif isinstance(device, Device):
+            if device.device_type == DeviceType.CPU:
+                return "/cpu:0"
+            elif device.device_type == DeviceType.GPU:
+                return "/device:GPU:{}".format(device.index)
+            else:
+                raise ValueError("Invalid device_type: {}".format(device.device_type))
+        else:
+            raise ValueError("Invalid device instance type: {}".format(type(device)))
+
+    def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, name: str,
+                 global_network=None, network_is_local: bool=True, network_is_trainable: bool=False):
+        """
+        :param agent_parameters: the agent parameters
+        :param spaces: the spaces definition of the agent
+        :param name: the name of the network
+        :param global_network: the global network replica that is shared between all the workers
+        :param network_is_local: whether the network is local (dedicated to the worker) or global (shared between workers)
+        :param network_is_trainable: is the network trainable (we can apply gradients on it)
+        """
+        self.global_network = global_network
+        self.network_is_local = network_is_local
+        self.network_wrapper_name = name.split('/')[0]
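+        # note (an assumption based on the surrounding code): 'name' is expected to be of the form
+        # '<wrapper name>/<suffix>' (e.g. 'main/online'), so the prefix above selects the matching
+        # entry in agent_parameters.network_wrappers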
+        self.network_parameters = agent_parameters.network_wrappers[self.network_wrapper_name]
+        self.num_heads_per_network = 1 if self.network_parameters.use_separate_networks_per_head else \
+            len(self.network_parameters.heads_parameters)
+        self.num_networks = 1 if not self.network_parameters.use_separate_networks_per_head else \
+            len(self.network_parameters.heads_parameters)
+
+        self.gradients_from_head_rescalers = []
+        self.gradients_from_head_rescalers_placeholders = []
+        self.update_head_rescaler_value_ops = []
+
+        self.adaptive_learning_rate_scheme = None
+        self.current_learning_rate = None
+
+        # init network modules containers
+        self.input_embedders = []
+        self.output_heads = []
+        super().__init__(agent_parameters, spaces, name, global_network,
+                         network_is_local, network_is_trainable)
+
+        self.available_return_types = self._available_return_types()
+        self.is_training = None
+
+    def _available_return_types(self):
+        ret_dict = {cls: [] for cls in get_all_subclasses(PredictionType)}
+
+        components = self.input_embedders + [self.middleware] + self.output_heads
+        for component in components:
+            if not hasattr(component, 'return_type'):
+                raise ValueError((
+                    "{} has no return_type attribute. Without this, it is "
+                    "unclear how this component should be used."
+                ).format(component))
+
+            if component.return_type is not None:
+                ret_dict[component.return_type].append(component)
+
+        return ret_dict
+
+    def predict_with_prediction_type(self, states: Dict[str, np.ndarray],
+                                     prediction_type: PredictionType) -> Dict[str, np.ndarray]:
+        """
+        Search for the component(s) whose return_type is set to the requested PredictionType, and get
+        predictions for them.
+
+        :param states: The input states to the network.
+        :param prediction_type: The requested PredictionType to look for in the network components
+        :return: A dictionary with predictions for all components matching the requested prediction type
+        """
+
+        ret_dict = {}
+        for component in self.available_return_types[prediction_type]:
+            ret_dict[component] = self.predict(inputs=states, outputs=component.output)
+
+        return ret_dict
+
+    def get_input_embedder(self, input_name: str, embedder_params: InputEmbedderParameters):
+        """
+        Given an input embedder parameters class, creates the input embedder and returns it
+        :param input_name: the name of the input to the embedder (used for retrieving the shape). The input should
+                           be a value within the state or the action.
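+                           For example, 'observation' is a typical state sub-space key; 'action'
+                           and 'goal' are also accepted (see allowed_inputs below).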
+        :param embedder_params: the parameters of the class of the embedder
+        :return: the embedder instance
+        """
+        allowed_inputs = copy.copy(self.spaces.state.sub_spaces)
+        allowed_inputs["action"] = copy.copy(self.spaces.action)
+        allowed_inputs["goal"] = copy.copy(self.spaces.goal)
+
+        if input_name not in allowed_inputs.keys():
+            raise ValueError("The key for the input embedder ({}) must match one of the following keys: {}"
+                             .format(input_name, allowed_inputs.keys()))
+
+        emb_type = "vector"
+        if isinstance(allowed_inputs[input_name], TensorObservationSpace):
+            emb_type = "tensor"
+        elif isinstance(allowed_inputs[input_name], PlanarMapsObservationSpace):
+            emb_type = "image"
+
+        embedder_path = embedder_params.path(emb_type)
+        embedder_params_copy = copy.copy(embedder_params)
+        embedder_params_copy.is_training = self.is_training
+        embedder_params_copy.activation_function = utils.get_activation_function(embedder_params.activation_function)
+        embedder_params_copy.input_rescaling = embedder_params_copy.input_rescaling[emb_type]
+        embedder_params_copy.input_offset = embedder_params_copy.input_offset[emb_type]
+        embedder_params_copy.name = input_name
+        module = dynamic_import_and_instantiate_module_from_params(embedder_params_copy,
+                                                                   path=embedder_path,
+                                                                   positional_args=[allowed_inputs[input_name].shape])
+        return module
+
+    def get_middleware(self, middleware_params: MiddlewareParameters):
+        """
+        Given a middleware type, creates the middleware and returns it
+        :param middleware_params: the parameters of the middleware class
+        :return: the middleware instance
+        """
+        mod_name = middleware_params.parameterized_class_name
+        middleware_path = middleware_params.path
+        middleware_params_copy = copy.copy(middleware_params)
+        middleware_params_copy.activation_function = utils.get_activation_function(middleware_params.activation_function)
+        middleware_params_copy.is_training = self.is_training
+        module = dynamic_import_and_instantiate_module_from_params(middleware_params_copy, path=middleware_path)
+        return module
+
+    def get_output_head(self, head_params: HeadParameters, head_idx: int):
+        """
+        Given a head type, creates the head and returns it
+        :param head_params: the parameters of the head to create
+        :param head_idx: the head index
+        :return: the head
+        """
+        mod_name = head_params.parameterized_class_name
+        head_path = head_params.path
+        head_params_copy = copy.copy(head_params)
+        head_params_copy.activation_function = utils.get_activation_function(head_params_copy.activation_function)
+        head_params_copy.is_training = self.is_training
+        return dynamic_import_and_instantiate_module_from_params(head_params_copy, path=head_path, extra_kwargs={
+            'agent_parameters': self.ap, 'spaces': self.spaces, 'network_name': self.network_wrapper_name,
+            'head_idx': head_idx, 'is_local': self.network_is_local})
+
+    def get_model(self) -> List:
+        # validate the configuration
+        if len(self.network_parameters.input_embedders_parameters) == 0:
+            raise ValueError("At least one input type should be defined")
+
+        if len(self.network_parameters.heads_parameters) == 0:
+            raise ValueError("At least one output type should be defined")
+
+        if self.network_parameters.middleware_parameters is None:
+            raise ValueError("Exactly one middleware type should be defined")
+
+        # ops for defining the training / testing phase
+        self.is_training = tf.Variable(False, trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
+        self.is_training_placeholder = tf.placeholder("bool")
+        self.assign_is_training =
tf.assign(self.is_training, self.is_training_placeholder) + + for network_idx in range(self.num_networks): + with tf.variable_scope('network_{}'.format(network_idx)): + + #################### + # Input Embeddings # + #################### + + state_embedding = [] + for input_name in sorted(self.network_parameters.input_embedders_parameters): + input_type = self.network_parameters.input_embedders_parameters[input_name] + # get the class of the input embedder + input_embedder = self.get_input_embedder(input_name, input_type) + self.input_embedders.append(input_embedder) + + # input placeholders are reused between networks. on the first network, store the placeholders + # generated by the input_embedders in self.inputs. on the rest of the networks, pass + # the existing input_placeholders into the input_embedders. + if network_idx == 0: + input_placeholder, embedding = input_embedder() + self.inputs[input_name] = input_placeholder + else: + input_placeholder, embedding = input_embedder(self.inputs[input_name]) + + state_embedding.append(embedding) + + ########## + # Merger # + ########## + + if len(state_embedding) == 1: + state_embedding = state_embedding[0] + else: + if self.network_parameters.embedding_merger_type == EmbeddingMergerType.Concat: + state_embedding = tf.concat(state_embedding, axis=-1, name="merger") + elif self.network_parameters.embedding_merger_type == EmbeddingMergerType.Sum: + state_embedding = tf.add_n(state_embedding, name="merger") + + ############## + # Middleware # + ############## + + self.middleware = self.get_middleware(self.network_parameters.middleware_parameters) + _, self.state_embedding = self.middleware(state_embedding) + + ################ + # Output Heads # + ################ + + head_count = 0 + for head_idx in range(self.num_heads_per_network): + + if self.network_parameters.use_separate_networks_per_head: + # if we use separate networks per head, then the head type corresponds to the network idx + head_type_idx = network_idx + head_count = network_idx + else: + # if we use a single network with multiple embedders, then the head type is the current head idx + head_type_idx = head_idx + head_params = self.network_parameters.heads_parameters[head_type_idx] + + for head_copy_idx in range(head_params.num_output_head_copies): + # create output head and add it to the output heads list + self.output_heads.append( + self.get_output_head(head_params, + head_idx*head_params.num_output_head_copies + head_copy_idx) + ) + + # rescale the gradients from the head + self.gradients_from_head_rescalers.append( + tf.get_variable('gradients_from_head_{}-{}_rescalers'.format(head_idx, head_copy_idx), + initializer=float(head_params.rescale_gradient_from_head_by_factor), + dtype=tf.float32)) + + self.gradients_from_head_rescalers_placeholders.append( + tf.placeholder('float', + name='gradients_from_head_{}-{}_rescalers'.format(head_type_idx, head_copy_idx))) + + self.update_head_rescaler_value_ops.append(self.gradients_from_head_rescalers[head_count].assign( + self.gradients_from_head_rescalers_placeholders[head_count])) + + head_input = (1-self.gradients_from_head_rescalers[head_count]) * tf.stop_gradient(self.state_embedding) + \ + self.gradients_from_head_rescalers[head_count] * self.state_embedding + + # build the head + if self.network_is_local: + output, target_placeholder, input_placeholders, importance_weight_ph = \ + self.output_heads[-1](head_input) + + self.targets.extend(target_placeholder) + self.importance_weights.extend(importance_weight_ph) + else: + output, 
input_placeholders = self.output_heads[-1](head_input) + + self.outputs.extend(output) + # TODO: use head names as well + for placeholder_index, input_placeholder in enumerate(input_placeholders): + self.inputs['output_{}_{}'.format(head_type_idx, placeholder_index)] = input_placeholder + + head_count += 1 + + # model weights + if not self.distributed_training or self.network_is_global: + self.weights = [var for var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=self.full_name) if + 'global_step' not in var.name] + else: + self.weights = [var for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.full_name)] + + # Losses + self.losses = tf.losses.get_losses(self.full_name) + + # L2 regularization + if self.network_parameters.l2_regularization != 0: + self.l2_regularization = tf.add_n([tf.nn.l2_loss(v) for v in self.weights]) \ + * self.network_parameters.l2_regularization + self.losses += self.l2_regularization + + self.total_loss = tf.reduce_sum(self.losses) + # tf.summary.scalar('total_loss', self.total_loss) + + # Learning rate + if self.network_parameters.learning_rate_decay_rate != 0: + self.adaptive_learning_rate_scheme = \ + tf.train.exponential_decay( + self.network_parameters.learning_rate, + self.global_step, + decay_steps=self.network_parameters.learning_rate_decay_steps, + decay_rate=self.network_parameters.learning_rate_decay_rate, + staircase=True) + + self.current_learning_rate = self.adaptive_learning_rate_scheme + else: + self.current_learning_rate = self.network_parameters.learning_rate + + # Optimizer + if self.distributed_training and self.network_is_local and self.network_parameters.shared_optimizer: + # distributed training + is a local network + optimizer shared -> take the global optimizer + self.optimizer = self.global_network.optimizer + elif (self.distributed_training and self.network_is_local and not self.network_parameters.shared_optimizer) \ + or self.network_parameters.shared_optimizer or not self.distributed_training: + # distributed training + is a global network + optimizer shared + # OR + # distributed training + is a local network + optimizer not shared + # OR + # non-distributed training + # -> create an optimizer + + if self.network_parameters.optimizer_type == 'Adam': + self.optimizer = tf.train.AdamOptimizer(learning_rate=self.current_learning_rate, + beta1=self.network_parameters.adam_optimizer_beta1, + beta2=self.network_parameters.adam_optimizer_beta2, + epsilon=self.network_parameters.optimizer_epsilon) + elif self.network_parameters.optimizer_type == 'RMSProp': + self.optimizer = tf.train.RMSPropOptimizer(self.current_learning_rate, + decay=self.network_parameters.rms_prop_optimizer_decay, + epsilon=self.network_parameters.optimizer_epsilon) + elif self.network_parameters.optimizer_type == 'LBFGS': + self.optimizer = tf.contrib.opt.ScipyOptimizerInterface(self.total_loss, method='L-BFGS-B', + options={'maxiter': 25}) + else: + raise Exception("{} is not a valid optimizer type".format(self.network_parameters.optimizer_type)) + + return self.weights + + def __str__(self): + result = [] + + for network in range(self.num_networks): + network_structure = [] + + # embedder + for embedder in self.input_embedders: + network_structure.append("Input Embedder: {}".format(embedder.name)) + network_structure.append(indent_string(str(embedder))) + + if len(self.input_embedders) > 1: + network_structure.append("{} ({})".format(self.network_parameters.embedding_merger_type.name, + ", ".join(["{} embedding".format(e.name) for e in 
self.input_embedders])))
+
+            # middleware
+            network_structure.append("Middleware:")
+            network_structure.append(indent_string(str(self.middleware)))
+
+            # head
+            if self.network_parameters.use_separate_networks_per_head:
+                heads = range(network, network+1)
+            else:
+                heads = range(0, len(self.output_heads))
+
+            for head_idx in heads:
+                head = self.output_heads[head_idx]
+                head_params = self.network_parameters.heads_parameters[head_idx]
+                if head_params.num_output_head_copies > 1:
+                    network_structure.append("Output Head: {} (num copies = {})".format(head.name, head_params.num_output_head_copies))
+                else:
+                    network_structure.append("Output Head: {}".format(head.name))
+                network_structure.append(indent_string(str(head)))
+
+            # finalize network
+            if self.num_networks > 1:
+                result.append("Sub-network for head: {}".format(self.output_heads[network].name))
+                result.append(indent_string('\n'.join(network_structure)))
+            else:
+                result.append('\n'.join(network_structure))
+
+        result = '\n'.join(result)
+        return result
diff --git a/rl_coach/architectures/legacy_tf_components/heads/__init__.py b/rl_coach/architectures/legacy_tf_components/heads/__init__.py
new file mode 100644
index 000000000..03c237a84
--- /dev/null
+++ b/rl_coach/architectures/legacy_tf_components/heads/__init__.py
@@ -0,0 +1,43 @@
+from .q_head import QHead
+from .categorical_q_head import CategoricalQHead
+from .ddpg_actor_head import DDPGActor
+from .dnd_q_head import DNDQHead
+from .dueling_q_head import DuelingQHead
+from .measurements_prediction_head import MeasurementsPredictionHead
+from .naf_head import NAFHead
+from .policy_head import PolicyHead
+from .ppo_head import PPOHead
+from .ppo_v_head import PPOVHead
+from .quantile_regression_q_head import QuantileRegressionQHead
+from .rainbow_q_head import RainbowQHead
+from .v_head import VHead
+from .acer_policy_head import ACERPolicyHead
+from .sac_head import SACPolicyHead
+from .sac_q_head import SACQHead
+from .classification_head import ClassificationHead
+from .cil_head import RegressionHead
+from .td3_v_head import TD3VHead
+from .ddpg_v_head import DDPGVHead
+
+__all__ = [
+    'CategoricalQHead',
+    'DDPGActor',
+    'DNDQHead',
+    'DuelingQHead',
+    'MeasurementsPredictionHead',
+    'NAFHead',
+    'PolicyHead',
+    'PPOHead',
+    'PPOVHead',
+    'QHead',
+    'QuantileRegressionQHead',
+    'RainbowQHead',
+    'VHead',
+    'ACERPolicyHead',
+    'SACPolicyHead',
+    'SACQHead',
+    'ClassificationHead',
+    'RegressionHead',
+    'TD3VHead',
+    'DDPGVHead'
+]
diff --git a/rl_coach/architectures/tensorflow_components/heads/acer_policy_head.py b/rl_coach/architectures/legacy_tf_components/heads/acer_policy_head.py
similarity index 100%
rename from rl_coach/architectures/tensorflow_components/heads/acer_policy_head.py
rename to rl_coach/architectures/legacy_tf_components/heads/acer_policy_head.py
diff --git a/rl_coach/architectures/tensorflow_components/heads/categorical_q_head.py b/rl_coach/architectures/legacy_tf_components/heads/categorical_q_head.py
similarity index 100%
rename from rl_coach/architectures/tensorflow_components/heads/categorical_q_head.py
rename to rl_coach/architectures/legacy_tf_components/heads/categorical_q_head.py
diff --git a/rl_coach/architectures/tensorflow_components/heads/cil_head.py b/rl_coach/architectures/legacy_tf_components/heads/cil_head.py
similarity index 100%
rename from rl_coach/architectures/tensorflow_components/heads/cil_head.py
rename to rl_coach/architectures/legacy_tf_components/heads/cil_head.py
diff --git
a/rl_coach/architectures/tensorflow_components/heads/classification_head.py b/rl_coach/architectures/legacy_tf_components/heads/classification_head.py similarity index 100% rename from rl_coach/architectures/tensorflow_components/heads/classification_head.py rename to rl_coach/architectures/legacy_tf_components/heads/classification_head.py diff --git a/rl_coach/architectures/tensorflow_components/heads/ddpg_actor_head.py b/rl_coach/architectures/legacy_tf_components/heads/ddpg_actor_head.py similarity index 100% rename from rl_coach/architectures/tensorflow_components/heads/ddpg_actor_head.py rename to rl_coach/architectures/legacy_tf_components/heads/ddpg_actor_head.py diff --git a/rl_coach/architectures/tensorflow_components/heads/ddpg_v_head.py b/rl_coach/architectures/legacy_tf_components/heads/ddpg_v_head.py similarity index 100% rename from rl_coach/architectures/tensorflow_components/heads/ddpg_v_head.py rename to rl_coach/architectures/legacy_tf_components/heads/ddpg_v_head.py diff --git a/rl_coach/architectures/tensorflow_components/heads/dnd_q_head.py b/rl_coach/architectures/legacy_tf_components/heads/dnd_q_head.py similarity index 100% rename from rl_coach/architectures/tensorflow_components/heads/dnd_q_head.py rename to rl_coach/architectures/legacy_tf_components/heads/dnd_q_head.py diff --git a/rl_coach/architectures/tensorflow_components/heads/dueling_q_head.py b/rl_coach/architectures/legacy_tf_components/heads/dueling_q_head.py similarity index 100% rename from rl_coach/architectures/tensorflow_components/heads/dueling_q_head.py rename to rl_coach/architectures/legacy_tf_components/heads/dueling_q_head.py diff --git a/rl_coach/architectures/legacy_tf_components/heads/head.py b/rl_coach/architectures/legacy_tf_components/heads/head.py new file mode 100644 index 000000000..e971889e9 --- /dev/null +++ b/rl_coach/architectures/legacy_tf_components/heads/head.py @@ -0,0 +1,166 @@ +# +# Copyright (c) 2017 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import numpy as np +import tensorflow as tf +from tensorflow.python.ops.losses.losses_impl import Reduction +from rl_coach.architectures.tensorflow_components.layers import Dense, convert_layer_class +from rl_coach.base_parameters import AgentParameters +from rl_coach.spaces import SpacesDefinition +from rl_coach.utils import force_list +from rl_coach.architectures.tensorflow_components.utils import squeeze_tensor + +# Used to initialize weights for policy and value output layers +def normalized_columns_initializer(std=1.0): + def _initializer(shape, dtype=None, partition_info=None): + out = np.random.randn(*shape).astype(np.float32) + out *= std / np.sqrt(np.square(out).sum(axis=0, keepdims=True)) + return tf.constant(out) + return _initializer + + +class Head(object): + """ + A head is the final part of the network. It takes the embedding from the middleware embedder and passes it through + a neural network to produce the output of the network. 
There can be multiple heads in a network, and each one has + an assigned loss function. The heads are algorithm dependent. + """ + def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str, + head_idx: int=0, loss_weight: float=1., is_local: bool=True, activation_function: str='relu', + dense_layer=Dense, is_training=False): + self.head_idx = head_idx + self.network_name = network_name + self.network_parameters = agent_parameters.network_wrappers[self.network_name] + self.name = "head" + self.output = [] + self.loss = [] + self.loss_type = [] + self.regularizations = [] + self.loss_weight = tf.Variable([float(w) for w in force_list(loss_weight)], + trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES]) + self.target = [] + self.importance_weight = [] + self.input = [] + self.is_local = is_local + self.ap = agent_parameters + self.spaces = spaces + self.return_type = None + self.activation_function = activation_function + self.dense_layer = dense_layer + if self.dense_layer is None: + self.dense_layer = Dense + else: + self.dense_layer = convert_layer_class(self.dense_layer) + self.is_training = is_training + + def __call__(self, input_layer): + """ + Wrapper for building the module graph including scoping and loss creation + :param input_layer: the input to the graph + :return: the output of the last layer and the target placeholder + """ + + with tf.variable_scope(self.get_name(), initializer=tf.contrib.layers.xavier_initializer()): + self._build_module(squeeze_tensor(input_layer)) + + self.output = force_list(self.output) + self.target = force_list(self.target) + self.input = force_list(self.input) + self.loss_type = force_list(self.loss_type) + self.loss = force_list(self.loss) + self.regularizations = force_list(self.regularizations) + if self.is_local: + self.set_loss() + self._post_build() + + if self.is_local: + return self.output, self.target, self.input, self.importance_weight + else: + return self.output, self.input + + def _build_module(self, input_layer): + """ + Builds the graph of the module + This method is called early on from __call__. It is expected to store the graph + in self.output. 
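+
+        A minimal subclass sketch (hypothetical, for illustration; real heads typically also
+        set self.loss_type or self.loss, as the concrete heads below do):
+
+            class MyValueHead(Head):
+                def _build_module(self, input_layer):
+                    self.output = self.dense_layer(1)(input_layer, name='output')
+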
+        :param input_layer: the input to the graph
+        :return: None
+        """
+        pass
+
+    def _post_build(self):
+        """
+        Optional function that allows adding any extra definitions after the head has been fully defined
+        For example, this allows doing additional calculations that are based on the loss
+        :return: None
+        """
+        pass
+
+    def get_name(self):
+        """
+        Get a formatted name for the module
+        :return: the formatted name
+        """
+        return '{}_{}'.format(self.name, self.head_idx)
+
+    def set_loss(self):
+        """
+        Creates a target placeholder and loss function for each loss_type and regularization
+        :return: None
+        """
+
+        # there are heads that define the loss internally, but we need to create additional placeholders for them
+        for idx in range(len(self.loss)):
+            importance_weight = tf.placeholder('float',
+                                               [None] + [1] * (len(self.target[idx].shape) - 1),
+                                               '{}_importance_weight'.format(self.get_name()))
+            self.importance_weight.append(importance_weight)
+
+        # add losses and target placeholder
+        for idx in range(len(self.loss_type)):
+            # create target placeholder
+            target = tf.placeholder('float', self.output[idx].shape, '{}_target'.format(self.get_name()))
+            self.target.append(target)
+
+            # create importance sampling weights placeholder
+            num_target_dims = len(self.target[idx].shape)
+            importance_weight = tf.placeholder('float', [None] + [1] * (num_target_dims - 1),
+                                               '{}_importance_weight'.format(self.get_name()))
+            self.importance_weight.append(importance_weight)
+
+            # compute the weighted loss. importance_weight weights over the samples in the batch, while self.loss_weight
+            # weights the specific loss of this head against other losses in this head or in other heads
+            loss_weight = self.loss_weight[idx]*importance_weight
+            loss = self.loss_type[idx](self.target[-1], self.output[idx],
+                                       scope=self.get_name(), reduction=Reduction.NONE, loss_collection=None)
+
+            # the loss is first summed over each sample in the batch and then the mean over the batch is taken
+            loss = tf.reduce_mean(loss_weight*tf.reduce_sum(loss, axis=list(range(1, num_target_dims))))
+
+            # we add the loss to the losses collection and later we will extract it in general_network
+            tf.losses.add_loss(loss)
+            self.loss.append(loss)
+
+        # add regularizations
+        for regularization in self.regularizations:
+            self.loss.append(regularization)
+            tf.losses.add_loss(regularization)
+
+    @classmethod
+    def path(cls):
+        return cls.__name__
diff --git a/rl_coach/architectures/tensorflow_components/heads/measurements_prediction_head.py b/rl_coach/architectures/legacy_tf_components/heads/measurements_prediction_head.py
similarity index 100%
rename from rl_coach/architectures/tensorflow_components/heads/measurements_prediction_head.py
rename to rl_coach/architectures/legacy_tf_components/heads/measurements_prediction_head.py
diff --git a/rl_coach/architectures/tensorflow_components/heads/naf_head.py b/rl_coach/architectures/legacy_tf_components/heads/naf_head.py
similarity index 100%
rename from rl_coach/architectures/tensorflow_components/heads/naf_head.py
rename to rl_coach/architectures/legacy_tf_components/heads/naf_head.py
diff --git a/rl_coach/architectures/tensorflow_components/heads/policy_head.py b/rl_coach/architectures/legacy_tf_components/heads/policy_head.py
similarity index 100%
rename from rl_coach/architectures/tensorflow_components/heads/policy_head.py
rename to
rl_coach/architectures/legacy_tf_components/heads/policy_head.py diff --git a/rl_coach/architectures/legacy_tf_components/heads/ppo_head.py b/rl_coach/architectures/legacy_tf_components/heads/ppo_head.py new file mode 100644 index 000000000..63f95a3ba --- /dev/null +++ b/rl_coach/architectures/legacy_tf_components/heads/ppo_head.py @@ -0,0 +1,156 @@ +# +# Copyright (c) 2017 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import numpy as np +import tensorflow as tf + +from rl_coach.architectures.tensorflow_components.layers import Dense +from rl_coach.architectures.tensorflow_components.heads.head import Head, normalized_columns_initializer +from rl_coach.base_parameters import AgentParameters, DistributedTaskParameters +from rl_coach.core_types import ActionProbabilities +from rl_coach.spaces import BoxActionSpace, DiscreteActionSpace +from rl_coach.spaces import SpacesDefinition +from rl_coach.utils import eps + + +class PPOHead(Head): + def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str, + head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='tanh', + dense_layer=Dense): + super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function, + dense_layer=dense_layer) + self.name = 'ppo_head' + self.return_type = ActionProbabilities + + # used in regular PPO + self.use_kl_regularization = agent_parameters.algorithm.use_kl_regularization + if self.use_kl_regularization: + # kl coefficient and its corresponding assignment operation and placeholder + self.kl_coefficient = tf.Variable(agent_parameters.algorithm.initial_kl_coefficient, + trainable=False, name='kl_coefficient') + self.kl_coefficient_ph = tf.placeholder('float', name='kl_coefficient_ph') + self.assign_kl_coefficient = tf.assign(self.kl_coefficient, self.kl_coefficient_ph) + self.kl_cutoff = 2 * agent_parameters.algorithm.target_kl_divergence + self.high_kl_penalty_coefficient = agent_parameters.algorithm.high_kl_penalty_coefficient + + self.clip_likelihood_ratio_using_epsilon = agent_parameters.algorithm.clip_likelihood_ratio_using_epsilon + self.beta = agent_parameters.algorithm.beta_entropy + + def _build_module(self, input_layer): + if isinstance(self.spaces.action, DiscreteActionSpace): + self._build_discrete_net(input_layer, self.spaces.action) + elif isinstance(self.spaces.action, BoxActionSpace): + self._build_continuous_net(input_layer, self.spaces.action) + else: + raise ValueError("only discrete or continuous action spaces are supported for PPO") + + self.action_probs_wrt_policy = self.policy_distribution.log_prob(self.actions) + self.action_probs_wrt_old_policy = self.old_policy_distribution.log_prob(self.actions) + self.entropy = tf.reduce_mean(self.policy_distribution.entropy()) + + # Used by regular PPO only + # add kl divergence regularization + self.kl_divergence = tf.reduce_mean(tf.distributions.kl_divergence(self.old_policy_distribution, self.policy_distribution)) + 
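+        # For reference: the clipped surrogate objective built below follows Schulman et al., 2017
+        # (https://arxiv.org/abs/1707.06347):
+        #     L_CLIP = E[min(r_t * A_t, clip(r_t, 1 - eps, 1 + eps) * A_t)]
+        # where r_t = pi(a_t|s_t) / pi_old(a_t|s_t). Note that the action_probs_wrt_* tensors
+        # above hold log-probabilities, so r_t is computed as exp(log pi - log pi_old).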
+ if self.use_kl_regularization: + # no clipping => use kl regularization + self.weighted_kl_divergence = tf.multiply(self.kl_coefficient, self.kl_divergence) + self.regularizations += [self.weighted_kl_divergence + self.high_kl_penalty_coefficient * \ + tf.square(tf.maximum(0.0, self.kl_divergence - self.kl_cutoff))] + + # calculate surrogate loss + self.advantages = tf.placeholder(tf.float32, [None], name="advantages") + self.target = self.advantages + # action_probs_wrt_old_policy != 0 because it is e^... + self.likelihood_ratio = tf.exp(self.action_probs_wrt_policy - self.action_probs_wrt_old_policy) + if self.clip_likelihood_ratio_using_epsilon is not None: + self.clip_param_rescaler = tf.placeholder(tf.float32, ()) + self.input.append(self.clip_param_rescaler) + max_value = 1 + self.clip_likelihood_ratio_using_epsilon * self.clip_param_rescaler + min_value = 1 - self.clip_likelihood_ratio_using_epsilon * self.clip_param_rescaler + self.clipped_likelihood_ratio = tf.clip_by_value(self.likelihood_ratio, min_value, max_value) + self.scaled_advantages = tf.minimum(self.likelihood_ratio * self.advantages, + self.clipped_likelihood_ratio * self.advantages) + else: + self.scaled_advantages = self.likelihood_ratio * self.advantages + # minus sign is in order to set an objective to minimize (we actually strive for maximizing the surrogate loss) + self.surrogate_loss = -tf.reduce_mean(self.scaled_advantages) + if self.is_local: + # add entropy regularization + if self.beta: + self.entropy = tf.reduce_mean(self.policy_distribution.entropy()) + self.regularizations += [-tf.multiply(self.beta, self.entropy, name='entropy_regularization')] + + self.loss = self.surrogate_loss + tf.losses.add_loss(self.loss) + + def _build_discrete_net(self, input_layer, action_space): + num_actions = len(action_space.actions) + self.actions = tf.placeholder(tf.int32, [None], name="actions") + + self.old_policy_mean = tf.placeholder(tf.float32, [None, num_actions], "old_policy_mean") + self.old_policy_std = tf.placeholder(tf.float32, [None, num_actions], "old_policy_std") + + # Policy Head + self.input = [self.actions, self.old_policy_mean] + policy_values = self.dense_layer(num_actions)(input_layer, name='policy_fc') + self.policy_mean = tf.nn.softmax(policy_values, name="policy") + + # define the distributions for the policy and the old policy + self.policy_distribution = tf.contrib.distributions.Categorical(probs=self.policy_mean) + self.old_policy_distribution = tf.contrib.distributions.Categorical(probs=self.old_policy_mean) + + self.output = self.policy_mean + + def _build_continuous_net(self, input_layer, action_space): + num_actions = action_space.shape[0] + self.actions = tf.placeholder(tf.float32, [None, num_actions], name="actions") + + self.old_policy_mean = tf.placeholder(tf.float32, [None, num_actions], "old_policy_mean") + self.old_policy_std = tf.placeholder(tf.float32, [None, num_actions], "old_policy_std") + + self.input = [self.actions, self.old_policy_mean, self.old_policy_std] + self.policy_mean = self.dense_layer(num_actions)(input_layer, name='policy_mean', + kernel_initializer=normalized_columns_initializer(0.01)) + + # for local networks in distributed settings, we need to move variables we create manually to the + # tf.GraphKeys.LOCAL_VARIABLES collection, since the variable scope custom getter which is set in + # Architecture does not apply to them + if self.is_local and isinstance(self.ap.task_parameters, DistributedTaskParameters): + self.policy_logstd = tf.Variable(np.zeros((1, 
num_actions)), dtype='float32', + collections=[tf.GraphKeys.LOCAL_VARIABLES], name="policy_log_std") + else: + self.policy_logstd = tf.Variable(np.zeros((1, num_actions)), dtype='float32', name="policy_log_std") + + self.policy_std = tf.tile(tf.exp(self.policy_logstd), [tf.shape(input_layer)[0], 1], name='policy_std') + + # define the distributions for the policy and the old policy + self.policy_distribution = tf.contrib.distributions.MultivariateNormalDiag(self.policy_mean, self.policy_std + eps) + self.old_policy_distribution = tf.contrib.distributions.MultivariateNormalDiag(self.old_policy_mean, self.old_policy_std + eps) + + self.output = [self.policy_mean, self.policy_std] + + def __str__(self): + action_head_mean_result = [] + if isinstance(self.spaces.action, DiscreteActionSpace): + # create a discrete action network (softmax probabilities output) + action_head_mean_result.append("Dense (num outputs = {})".format(len(self.spaces.action.actions))) + action_head_mean_result.append("Softmax") + elif isinstance(self.spaces.action, BoxActionSpace): + # create a continuous action network (bounded mean and stdev outputs) + action_head_mean_result.append("Dense (num outputs = {})".format(self.spaces.action.shape)) + + return '\n'.join(action_head_mean_result) diff --git a/rl_coach/architectures/tensorflow_components/heads/ppo_v_head.py b/rl_coach/architectures/legacy_tf_components/heads/ppo_v_head.py similarity index 100% rename from rl_coach/architectures/tensorflow_components/heads/ppo_v_head.py rename to rl_coach/architectures/legacy_tf_components/heads/ppo_v_head.py diff --git a/rl_coach/architectures/legacy_tf_components/heads/q_head.py b/rl_coach/architectures/legacy_tf_components/heads/q_head.py new file mode 100644 index 000000000..ecc1461a0 --- /dev/null +++ b/rl_coach/architectures/legacy_tf_components/heads/q_head.py @@ -0,0 +1,66 @@ +# +# Copyright (c) 2017 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import tensorflow as tf + +from rl_coach.architectures.tensorflow_components.layers import Dense +from rl_coach.architectures.tensorflow_components.heads.head import Head +from rl_coach.base_parameters import AgentParameters +from rl_coach.core_types import QActionStateValue +from rl_coach.spaces import SpacesDefinition, BoxActionSpace, DiscreteActionSpace + + +class QHead(Head): + def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str, + head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='relu', + dense_layer=Dense): + super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function, + dense_layer=dense_layer) + self.name = 'q_values_head' + if isinstance(self.spaces.action, BoxActionSpace): + self.num_actions = 1 + elif isinstance(self.spaces.action, DiscreteActionSpace): + self.num_actions = len(self.spaces.action.actions) + else: + raise ValueError( + 'QHead does not support action spaces of type: {class_name}'.format( + class_name=self.spaces.action.__class__.__name__, + ) + ) + self.return_type = QActionStateValue + if agent_parameters.network_wrappers[self.network_name].replace_mse_with_huber_loss: + self.loss_type = tf.losses.huber_loss + else: + self.loss_type = tf.losses.mean_squared_error + + def _build_module(self, input_layer): + # Standard Q Network + self.q_values = self.output = self.dense_layer(self.num_actions)(input_layer, name='output') + + # used in batch-rl to estimate a probablity distribution over actions + self.softmax = self.add_softmax_with_temperature() + + def __str__(self): + result = [ + "Dense (num outputs = {})".format(self.num_actions) + ] + return '\n'.join(result) + + def add_softmax_with_temperature(self): + temperature = self.ap.network_wrappers[self.network_name].softmax_temperature + temperature_scaled_outputs = self.q_values / temperature + return tf.nn.softmax(temperature_scaled_outputs, name="softmax") + diff --git a/rl_coach/architectures/tensorflow_components/heads/quantile_regression_q_head.py b/rl_coach/architectures/legacy_tf_components/heads/quantile_regression_q_head.py similarity index 100% rename from rl_coach/architectures/tensorflow_components/heads/quantile_regression_q_head.py rename to rl_coach/architectures/legacy_tf_components/heads/quantile_regression_q_head.py diff --git a/rl_coach/architectures/tensorflow_components/heads/rainbow_q_head.py b/rl_coach/architectures/legacy_tf_components/heads/rainbow_q_head.py similarity index 100% rename from rl_coach/architectures/tensorflow_components/heads/rainbow_q_head.py rename to rl_coach/architectures/legacy_tf_components/heads/rainbow_q_head.py diff --git a/rl_coach/architectures/tensorflow_components/heads/sac_head.py b/rl_coach/architectures/legacy_tf_components/heads/sac_head.py similarity index 100% rename from rl_coach/architectures/tensorflow_components/heads/sac_head.py rename to rl_coach/architectures/legacy_tf_components/heads/sac_head.py diff --git a/rl_coach/architectures/tensorflow_components/heads/sac_q_head.py b/rl_coach/architectures/legacy_tf_components/heads/sac_q_head.py similarity index 100% rename from rl_coach/architectures/tensorflow_components/heads/sac_q_head.py rename to rl_coach/architectures/legacy_tf_components/heads/sac_q_head.py diff --git a/rl_coach/architectures/tensorflow_components/heads/td3_v_head.py b/rl_coach/architectures/legacy_tf_components/heads/td3_v_head.py similarity index 100% rename from 
rl_coach/architectures/tensorflow_components/heads/td3_v_head.py rename to rl_coach/architectures/legacy_tf_components/heads/td3_v_head.py diff --git a/rl_coach/architectures/legacy_tf_components/heads/v_head.py b/rl_coach/architectures/legacy_tf_components/heads/v_head.py new file mode 100644 index 000000000..62bfba03b --- /dev/null +++ b/rl_coach/architectures/legacy_tf_components/heads/v_head.py @@ -0,0 +1,54 @@ +# +# Copyright (c) 2017 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import tensorflow as tf + +from rl_coach.architectures.tensorflow_components.layers import Dense +from rl_coach.architectures.tensorflow_components.heads.head import Head, normalized_columns_initializer +from rl_coach.base_parameters import AgentParameters +from rl_coach.core_types import VStateValue +from rl_coach.spaces import SpacesDefinition + + +class VHead(Head): + def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str, + head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='relu', + dense_layer=Dense, initializer='normalized_columns'): + super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function, + dense_layer=dense_layer) + self.name = 'v_values_head' + self.return_type = VStateValue + + if agent_parameters.network_wrappers[self.network_name.split('/')[0]].replace_mse_with_huber_loss: + self.loss_type = tf.losses.huber_loss + else: + self.loss_type = tf.losses.mean_squared_error + + self.initializer = initializer + + def _build_module(self, input_layer): + # Standard V Network + if self.initializer == 'normalized_columns': + self.output = self.dense_layer(1)(input_layer, name='output', + kernel_initializer=normalized_columns_initializer(1.0)) + elif self.initializer == 'xavier' or self.initializer is None: + self.output = self.dense_layer(1)(input_layer, name='output') + + def __str__(self): + result = [ + "Dense (num outputs = 1)" + ] + return '\n'.join(result) diff --git a/rl_coach/architectures/legacy_tf_components/layers.py b/rl_coach/architectures/legacy_tf_components/layers.py new file mode 100644 index 000000000..eb6326234 --- /dev/null +++ b/rl_coach/architectures/legacy_tf_components/layers.py @@ -0,0 +1,260 @@ +# +# Copyright (c) 2017 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + + +import math +from types import FunctionType +import tensorflow as tf + +from rl_coach.architectures import layers +from rl_coach.architectures.tensorflow_components import utils + + +def batchnorm_activation_dropout(input_layer, batchnorm, activation_function, dropout_rate, is_training, name): + layers = [input_layer] + + # Rationale: passing a bool here will mean that batchnorm and or activation will never activate + assert not isinstance(is_training, bool) + + # batchnorm + if batchnorm: + layers.append( + tf.layers.batch_normalization(layers[-1], name="{}_batchnorm".format(name), training=is_training) + ) + + # activation + if activation_function: + if isinstance(activation_function, str): + activation_function = utils.get_activation_function(activation_function) + layers.append( + activation_function(layers[-1], name="{}_activation".format(name)) + ) + + # dropout + if dropout_rate > 0: + layers.append( + tf.layers.dropout(layers[-1], dropout_rate, name="{}_dropout".format(name), training=is_training) + ) + + # remove the input layer from the layers list + del layers[0] + + return layers + + +# define global dictionary for storing layer type to layer implementation mapping +tf_layer_dict = dict() +tf_layer_class_dict = dict() + + +def reg_to_tf_instance(layer_type) -> FunctionType: + """ function decorator that registers layer implementation + :return: decorated function + """ + def reg_impl_decorator(func): + assert layer_type not in tf_layer_dict + tf_layer_dict[layer_type] = func + return func + return reg_impl_decorator + + +def reg_to_tf_class(layer_type) -> FunctionType: + """ function decorator that registers layer type + :return: decorated function + """ + def reg_impl_decorator(func): + assert layer_type not in tf_layer_class_dict + tf_layer_class_dict[layer_type] = func + return func + return reg_impl_decorator + + +def convert_layer(layer): + """ + If layer instance is callable (meaning this is already a concrete TF class), return layer, otherwise convert to TF type + :param layer: layer to be converted + :return: converted layer if not callable, otherwise layer itself + """ + if callable(layer): + return layer + return tf_layer_dict[type(layer)](layer) + + +def convert_layer_class(layer_class): + """ + If layer instance is callable, return layer, otherwise convert to TF type + :param layer: layer to be converted + :return: converted layer if not callable, otherwise layer itself + """ + if hasattr(layer_class, 'to_tf_instance'): + return layer_class + else: + return tf_layer_class_dict[layer_class]() + + +class Conv2d(layers.Conv2d): + def __init__(self, num_filters: int, kernel_size: int, strides: int): + super(Conv2d, self).__init__(num_filters=num_filters, kernel_size=kernel_size, strides=strides) + + def __call__(self, input_layer, name: str=None, is_training=None): + """ + returns a tensorflow conv2d layer + :param input_layer: previous layer + :param name: layer name + :return: conv2d layer + """ + return tf.layers.conv2d(input_layer, filters=self.num_filters, kernel_size=self.kernel_size, + strides=self.strides, data_format='channels_last', name=name) + + @staticmethod + @reg_to_tf_instance(layers.Conv2d) + def to_tf_instance(base: layers.Conv2d): + return Conv2d( + num_filters=base.num_filters, + kernel_size=base.kernel_size, + strides=base.strides) + + @staticmethod + @reg_to_tf_class(layers.Conv2d) + def to_tf_class(): + return Conv2d + + +class BatchnormActivationDropout(layers.BatchnormActivationDropout): + def __init__(self, batchnorm: bool=False, 
activation_function=None, dropout_rate: float=0): + super(BatchnormActivationDropout, self).__init__( + batchnorm=batchnorm, activation_function=activation_function, dropout_rate=dropout_rate) + + def __call__(self, input_layer, name: str=None, is_training=None): + """ + returns a list of tensorflow batchnorm, activation and dropout layers + :param input_layer: previous layer + :param name: layer name + :return: batchnorm, activation and dropout layers + """ + return batchnorm_activation_dropout(input_layer, batchnorm=self.batchnorm, + activation_function=self.activation_function, + dropout_rate=self.dropout_rate, + is_training=is_training, name=name) + + @staticmethod + @reg_to_tf_instance(layers.BatchnormActivationDropout) + def to_tf_instance(base: layers.BatchnormActivationDropout): + return BatchnormActivationDropout, BatchnormActivationDropout( + batchnorm=base.batchnorm, + activation_function=base.activation_function, + dropout_rate=base.dropout_rate) + + @staticmethod + @reg_to_tf_class(layers.BatchnormActivationDropout) + def to_tf_class(): + return BatchnormActivationDropout + + +class Dense(layers.Dense): + def __init__(self, units: int): + super(Dense, self).__init__(units=units) + + def __call__(self, input_layer, name: str=None, kernel_initializer=None, activation=None, is_training=None): + """ + returns a tensorflow dense layer + :param input_layer: previous layer + :param name: layer name + :return: dense layer + """ + return tf.layers.dense(input_layer, self.units, name=name, kernel_initializer=kernel_initializer, + activation=activation) + + @staticmethod + @reg_to_tf_instance(layers.Dense) + def to_tf_instance(base: layers.Dense): + return Dense(units=base.units) + + @staticmethod + @reg_to_tf_class(layers.Dense) + def to_tf_class(): + return Dense + + +class NoisyNetDense(layers.NoisyNetDense): + """ + A factorized Noisy Net layer + + https://arxiv.org/abs/1706.10295. + """ + + def __init__(self, units: int): + super(NoisyNetDense, self).__init__(units=units) + + def __call__(self, input_layer, name: str, kernel_initializer=None, activation=None, is_training=None): + """ + returns a NoisyNet dense layer + :param input_layer: previous layer + :param name: layer name + :param kernel_initializer: initializer for kernels. Default is to use Gaussian noise that preserves stddev. + :param activation: the activation function + :return: dense layer + """ + #TODO: noise sampling should be externally controlled. DQN is fine with sampling noise for every + # forward (either act or train, both for online and target networks). + # A3C, on the other hand, should sample noise only when policy changes (i.e. after every t_max steps) + + def _f(values): + return tf.sqrt(tf.abs(values)) * tf.sign(values) + + def _factorized_noise(inputs, outputs): + # TODO: use factorized noise only for compute intensive algos (e.g. DQN). + # lighter algos (e.g. 
A3C) should not use it
+            noise1 = _f(tf.random_normal((inputs, 1)))
+            noise2 = _f(tf.random_normal((1, outputs)))
+            return tf.matmul(noise1, noise2)
+
+        num_inputs = input_layer.get_shape()[-1].value
+        num_outputs = self.units
+
+        stddev = 1 / math.sqrt(num_inputs)
+        activation = activation if activation is not None else (lambda x: x)
+
+        if kernel_initializer is None:
+            kernel_mean_initializer = tf.random_uniform_initializer(-stddev, stddev)
+            kernel_stddev_initializer = tf.random_uniform_initializer(-stddev * self.sigma0, stddev * self.sigma0)
+        else:
+            kernel_mean_initializer = kernel_stddev_initializer = kernel_initializer
+        with tf.variable_scope(None, default_name=name):
+            weight_mean = tf.get_variable('weight_mean', shape=(num_inputs, num_outputs),
+                                          initializer=kernel_mean_initializer)
+            bias_mean = tf.get_variable('bias_mean', shape=(num_outputs,), initializer=tf.zeros_initializer())
+
+            weight_stddev = tf.get_variable('weight_stddev', shape=(num_inputs, num_outputs),
+                                            initializer=kernel_stddev_initializer)
+            bias_stddev = tf.get_variable('bias_stddev', shape=(num_outputs,),
+                                          initializer=kernel_stddev_initializer)
+            bias_noise = _f(tf.random_normal((num_outputs,)))
+            weight_noise = _factorized_noise(num_inputs, num_outputs)
+
+            bias = bias_mean + bias_stddev * bias_noise
+            weight = weight_mean + weight_stddev * weight_noise
+            return activation(tf.matmul(input_layer, weight) + bias)
+
+    @staticmethod
+    @reg_to_tf_instance(layers.NoisyNetDense)
+    def to_tf_instance(base: layers.NoisyNetDense):
+        return NoisyNetDense(units=base.units)
+
+    @staticmethod
+    @reg_to_tf_class(layers.NoisyNetDense)
+    def to_tf_class():
+        return NoisyNetDense
diff --git a/rl_coach/architectures/legacy_tf_components/middlewares/__init__.py b/rl_coach/architectures/legacy_tf_components/middlewares/__init__.py
new file mode 100644
index 000000000..481eab0bf
--- /dev/null
+++ b/rl_coach/architectures/legacy_tf_components/middlewares/__init__.py
@@ -0,0 +1,4 @@
+from .fc_middleware import FCMiddleware
+from .lstm_middleware import LSTMMiddleware
+
+__all__ = ["FCMiddleware", "LSTMMiddleware"]
diff --git a/rl_coach/architectures/legacy_tf_components/middlewares/fc_middleware.py b/rl_coach/architectures/legacy_tf_components/middlewares/fc_middleware.py
new file mode 100644
index 000000000..4361e171d
--- /dev/null
+++ b/rl_coach/architectures/legacy_tf_components/middlewares/fc_middleware.py
@@ -0,0 +1,91 @@
+#
+# Copyright (c) 2017 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# +from typing import Union, List + +import tensorflow as tf + +from rl_coach.architectures.tensorflow_components.layers import Dense +from rl_coach.architectures.tensorflow_components.middlewares.middleware import Middleware +from rl_coach.base_parameters import MiddlewareScheme +from rl_coach.core_types import Middleware_FC_Embedding +from rl_coach.utils import force_list + + +class FCMiddleware(Middleware): + def __init__(self, activation_function=tf.nn.relu, + scheme: MiddlewareScheme = MiddlewareScheme.Medium, + batchnorm: bool = False, dropout_rate: float = 0.0, + name="middleware_fc_embedder", dense_layer=Dense, is_training=False, num_streams: int = 1): + super().__init__(activation_function=activation_function, batchnorm=batchnorm, + dropout_rate=dropout_rate, scheme=scheme, name=name, dense_layer=dense_layer, + is_training=is_training) + self.return_type = Middleware_FC_Embedding + + assert(isinstance(num_streams, int) and num_streams >= 1) + self.num_streams = num_streams + + def _build_module(self): + self.output = [] + + for stream_idx in range(self.num_streams): + layers = [self.input] + + for idx, layer_params in enumerate(self.layers_params): + layers.extend(force_list( + layer_params(layers[-1], name='{}_{}'.format(layer_params.__class__.__name__, + idx + stream_idx * len(self.layers_params)), + is_training=self.is_training) + )) + self.output.append((layers[-1])) + + @property + def schemes(self): + return { + MiddlewareScheme.Empty: + [], + + # ppo + MiddlewareScheme.Shallow: + [ + self.dense_layer(64) + ], + + # dqn + MiddlewareScheme.Medium: + [ + self.dense_layer(512) + ], + + MiddlewareScheme.Deep: \ + [ + self.dense_layer(128), + self.dense_layer(128), + self.dense_layer(128) + ] + } + + def __str__(self): + stream = [str(l) for l in self.layers_params] + if self.layers_params: + if self.num_streams > 1: + stream = [''] + ['\t' + l for l in stream] + result = stream * self.num_streams + result[0::len(stream)] = ['Stream {}'.format(i) for i in range(self.num_streams)] + else: + result = stream + return '\n'.join(result) + else: + return 'No layers' diff --git a/rl_coach/architectures/tensorflow_components/middlewares/lstm_middleware.py b/rl_coach/architectures/legacy_tf_components/middlewares/lstm_middleware.py similarity index 100% rename from rl_coach/architectures/tensorflow_components/middlewares/lstm_middleware.py rename to rl_coach/architectures/legacy_tf_components/middlewares/lstm_middleware.py diff --git a/rl_coach/architectures/legacy_tf_components/middlewares/middleware.py b/rl_coach/architectures/legacy_tf_components/middlewares/middleware.py new file mode 100644 index 000000000..64c578fc1 --- /dev/null +++ b/rl_coach/architectures/legacy_tf_components/middlewares/middleware.py @@ -0,0 +1,107 @@ +# +# Copyright (c) 2017 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/rl_coach/architectures/legacy_tf_components/middlewares/middleware.py b/rl_coach/architectures/legacy_tf_components/middlewares/middleware.py
new file mode 100644
index 000000000..64c578fc1
--- /dev/null
+++ b/rl_coach/architectures/legacy_tf_components/middlewares/middleware.py
@@ -0,0 +1,107 @@
+#
+# Copyright (c) 2017 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import copy
+from typing import Union, Tuple
+
+import tensorflow as tf
+
+from rl_coach.architectures.legacy_tf_components.layers import BatchnormActivationDropout, convert_layer, Dense
+from rl_coach.base_parameters import MiddlewareScheme, NetworkComponentParameters
+from rl_coach.core_types import MiddlewareEmbedding
+
+
+class Middleware(object):
+    """
+    A middleware embedder is the middle part of the network. It takes the embeddings from the input embedders,
+    after they have been aggregated by some method (for example, concatenation), and passes them through a
+    neural network that can be customized, but is shared between all the heads of the network
+    """
+    def __init__(self, activation_function=tf.nn.relu,
+                 scheme: MiddlewareScheme = MiddlewareScheme.Medium,
+                 batchnorm: bool = False, dropout_rate: float = 0.0, name="middleware_embedder", dense_layer=Dense,
+                 is_training=False):
+        self.name = name
+        self.input = None
+        self.output = None
+        self.activation_function = activation_function
+        self.batchnorm = batchnorm
+        self.dropout_rate = dropout_rate
+        self.scheme = scheme
+        self.return_type = MiddlewareEmbedding
+        self.dense_layer = dense_layer
+        if self.dense_layer is None:
+            self.dense_layer = Dense
+        self.is_training = is_training
+
+        # layers order is conv -> batchnorm -> activation -> dropout
+        if isinstance(self.scheme, MiddlewareScheme):
+            self.layers_params = copy.copy(self.schemes[self.scheme])
+            self.layers_params = [convert_layer(l) for l in self.layers_params]
+        else:
+            # if the scheme is specified directly, convert each layer to a TF layer if it's not a callable object
+            # NOTE: if a layer object is callable, it must return a TF tensor when invoked
+            self.layers_params = [convert_layer(l) for l in copy.copy(self.scheme)]
+
+        # we allow adding batchnorm, dropout or activation functions after each layer.
+        # The motivation is to simplify the transition between a network with batchnorm and a network without
+        # batchnorm to a single flag (the same applies to activation function and dropout)
+        if self.batchnorm or self.activation_function or self.dropout_rate > 0:
+            for layer_idx in reversed(range(len(self.layers_params))):
+                self.layers_params.insert(layer_idx+1,
+                                          BatchnormActivationDropout(batchnorm=self.batchnorm,
+                                                                     activation_function=self.activation_function,
+                                                                     dropout_rate=self.dropout_rate))
+
+    def __call__(self, input_layer: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
+        """
+        Wrapper for building the module graph including scoping and loss creation
+        :param input_layer: the input to the graph
+        :return: the input placeholder and the output of the last layer
+        """
+        with tf.variable_scope(self.get_name()):
+            self.input = input_layer
+            self._build_module()
+
+        return self.input, self.output
+
+    def _build_module(self) -> None:
+        """
+        Builds the graph of the module.
+        This method is called early on from __call__. It is expected to store the graph
+        in self.output.
+        The input is available in self.input, which is set by __call__.
+        :return: None
+        """
+        pass
+
+    def get_name(self) -> str:
+        """
+        Get a formatted name for the module
+        :return: the formatted name
+        """
+        return self.name
+
+    @property
+    def schemes(self):
+        raise NotImplementedError("Inheriting middleware must define schemes matching its allowed default "
+                                  "configurations.")
+
+    def __str__(self):
+        result = [str(l) for l in self.layers_params]
+        if self.layers_params:
+            return '\n'.join(result)
+        else:
+            return 'No layers'
diff --git a/rl_coach/architectures/legacy_tf_components/savers.py b/rl_coach/architectures/legacy_tf_components/savers.py
new file mode 100644
index 000000000..531c5236a
--- /dev/null
+++ b/rl_coach/architectures/legacy_tf_components/savers.py
@@ -0,0 +1,141 @@
+#
+# Copyright (c) 2017 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import pickle
+from typing import Any, List, Dict
+
+import tensorflow as tf
+import numpy as np
+
+from rl_coach.saver import Saver
+
+
+class GlobalVariableSaver(Saver):
+    def __init__(self, name):
+        self._names = [name]
+        # if the graph is finalized, savers must already have been added. This happens
+        # in the case of a MonitoredSession
+        self._variables = tf.global_variables()
+
+        # the target network is never saved or restored directly from a checkpoint, so we remove all its variables from the list.
+        # the target network is synced back from the online network in graph_manager.improve(...), at the beginning of the run flow.
+        self._variables = [v for v in self._variables if "/target" not in v.name]
+
+        # Using a placeholder to update the variable during restore to avoid memory leak.
+        # Ref: https://github.com/tensorflow/tensorflow/issues/4151
+        self._variable_placeholders = []
+        self._variable_update_ops = []
+        for v in self._variables:
+            variable_placeholder = tf.placeholder(v.dtype, shape=v.get_shape())
+            self._variable_placeholders.append(variable_placeholder)
+            self._variable_update_ops.append(v.assign(variable_placeholder))
+
+        self._saver = tf.train.Saver(self._variables, max_to_keep=None)
+
+    @property
+    def path(self):
+        """
+        Relative path for save/load. If two checkpoint objects return the same path, they must be merge-able.
+ """ + return "" # use empty string for global file + + def save(self, sess: None, save_path: str) -> List[str]: + """ + Save to save_path + :param sess: active session + :param save_path: full path to save checkpoint (typically directory plus checkpoint prefix plus self.path) + :return: list of all saved paths + """ + save_path = self._saver.save(sess, save_path) + return [save_path] + + def to_arrays(self, session: Any) -> Dict[str, np.ndarray]: + """ + Save to dictionary of arrays + :param sess: active session + :return: dictionary of arrays + """ + return { + k.name.split(":")[0]: v for k, v in zip(self._variables, session.run(self._variables)) + } + + def from_arrays(self, session: Any, tensors: Any): + """ + Restore from restore_path + :param sess: active session for session-based frameworks (e.g. TF) + :param tensors: {name: array} + """ + # if variable was saved using global network, re-map it to online + # network + # TODO: Can this be more generic so that `global/` and `online/` are not + # hardcoded here? + if isinstance(tensors, dict): + tensors = tensors.items() + + variables = {k.replace("global/", "online/"): v for k, v in tensors} + + # Assign all variables using placeholder + placeholder_dict = { + ph: variables[v.name.split(":")[0]] + for ph, v in zip(self._variable_placeholders, self._variables) + } + session.run(self._variable_update_ops, placeholder_dict) + + def to_string(self, session: Any) -> str: + """ + Save to byte string + :param session: active session + :return: serialized byte string + """ + return pickle.dumps(self.to_arrays(session), protocol=-1) + + def from_string(self, session: Any, string: str): + """ + Restore from byte string + :param session: active session + :param string: byte string to restore from + """ + self.from_arrays(session, pickle.loads(string)) + + def _read_tensors(self, restore_path: str): + """ + Load tensors from a checkpoint + :param restore_path: full path to load checkpoint from. + """ + # We don't use saver.restore() because checkpoint is loaded to online + # network, but if the checkpoint is from the global network, a namespace + # mismatch exists and variable name must be modified before loading. + reader = tf.contrib.framework.load_checkpoint(restore_path) + for var_name, _ in reader.get_variable_to_shape_map().items(): + yield var_name, reader.get_tensor(var_name) + + def restore(self, sess: Any, restore_path: str): + """ + Restore from restore_path + :param sess: active session for session-based frameworks (e.g. TF) + :param restore_path: full path to load checkpoint from. + """ + self.from_arrays(sess, self._read_tensors(restore_path)) + + def merge(self, other: "Saver"): + """ + Merge other saver into this saver + :param other: saver to be merged into self + """ + assert isinstance(other, GlobalVariableSaver) + self._names.extend(other._names) + # There is nothing else to do because variables must already be part of + # the global collection. diff --git a/rl_coach/architectures/legacy_tf_components/shared_variables.py b/rl_coach/architectures/legacy_tf_components/shared_variables.py new file mode 100644 index 000000000..fe805afed --- /dev/null +++ b/rl_coach/architectures/legacy_tf_components/shared_variables.py @@ -0,0 +1,155 @@ +# +# Copyright (c) 2017 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os +import pickle + +import numpy as np +import tensorflow as tf + +from rl_coach.utilities.shared_running_stats import SharedRunningStats + + +class TFSharedRunningStats(SharedRunningStats): + def __init__(self, replicated_device=None, epsilon=1e-2, name="", create_ops=True, pubsub_params=None): + super().__init__(name=name, pubsub_params=pubsub_params) + self.sess = None + self.replicated_device = replicated_device + self.epsilon = epsilon + self.ops_were_created = False + if create_ops: + with tf.device(replicated_device): + self.set_params() + + def set_params(self, shape=[1], clip_values=None): + """ + set params and create ops + + :param shape: shape of the stats to track + :param clip_values: if not None, sets clip min/max thresholds + """ + + self.clip_values = clip_values + with tf.variable_scope(self.name): + self._sum = tf.get_variable( + dtype=tf.float64, + initializer=tf.constant_initializer(0.0), + name="running_sum", trainable=False, shape=shape, validate_shape=False, + collections=[tf.GraphKeys.GLOBAL_VARIABLES]) + self._sum_squares = tf.get_variable( + dtype=tf.float64, + initializer=tf.constant_initializer(self.epsilon), + name="running_sum_squares", trainable=False, shape=shape, validate_shape=False, + collections=[tf.GraphKeys.GLOBAL_VARIABLES]) + self._count = tf.get_variable( + dtype=tf.float64, + shape=(), + initializer=tf.constant_initializer(self.epsilon), + name="count", trainable=False, collections=[tf.GraphKeys.GLOBAL_VARIABLES]) + + self._shape = None + self._mean = tf.div(self._sum, self._count, name="mean") + self._std = tf.sqrt(tf.maximum((self._sum_squares - self._count * tf.square(self._mean)) + / tf.maximum(self._count-1, 1), self.epsilon), name="stdev") + self.tf_mean = tf.cast(self._mean, 'float32') + self.tf_std = tf.cast(self._std, 'float32') + + self.new_sum = tf.placeholder(dtype=tf.float64, name='sum') + self.new_sum_squares = tf.placeholder(dtype=tf.float64, name='var') + self.newcount = tf.placeholder(shape=[], dtype=tf.float64, name='count') + + self._inc_sum = tf.assign_add(self._sum, self.new_sum, use_locking=True) + self._inc_sum_squares = tf.assign_add(self._sum_squares, self.new_sum_squares, use_locking=True) + self._inc_count = tf.assign_add(self._count, self.newcount, use_locking=True) + + self.raw_obs = tf.placeholder(dtype=tf.float64, name='raw_obs') + self.normalized_obs = (self.raw_obs - self._mean) / self._std + if self.clip_values is not None: + self.clipped_obs = tf.clip_by_value(self.normalized_obs, self.clip_values[0], self.clip_values[1]) + + self.ops_were_created = True + + def set_session(self, sess): + self.sess = sess + + def push_val(self, x): + x = x.astype('float64') + self.sess.run([self._inc_sum, self._inc_sum_squares, self._inc_count], + feed_dict={ + self.new_sum: x.sum(axis=0).ravel(), + self.new_sum_squares: np.square(x).sum(axis=0).ravel(), + self.newcount: np.array(len(x), dtype='float64') + }) + if self._shape is None: + self._shape = x.shape + + @property + def n(self): + return self.sess.run(self._count) + + @property + def mean(self): + return self.sess.run(self._mean) + + @property + def 
var(self):
+        return self.std ** 2
+
+    @property
+    def std(self):
+        return self.sess.run(self._std)
+
+    @property
+    def shape(self):
+        return self._shape
+
+    @shape.setter
+    def shape(self, val):
+        self._shape = val
+        self.new_sum.set_shape(val)
+        self.new_sum_squares.set_shape(val)
+        self.tf_mean.set_shape(val)
+        self.tf_std.set_shape(val)
+        self._sum.set_shape(val)
+        self._sum_squares.set_shape(val)
+
+    def normalize(self, batch):
+        if self.clip_values is not None:
+            return self.sess.run(self.clipped_obs, feed_dict={self.raw_obs: batch})
+        else:
+            return self.sess.run(self.normalized_obs, feed_dict={self.raw_obs: batch})
+
+    def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
+        # Since the internal state is maintained as part of the TF graph, there is no need to do anything special
+        # for save/restore when going from a single-node-multi-thread run back to a single-node-multi-worker run.
+        # However, if we want to restore this checkpoint into either a single-worker or a
+        # multi-node-multi-worker run, we have to save the internal state, so that it can be restored into the
+        # NumpySharedRunningStats implementation.
+
+        dict_to_save = {'_mean': self.mean,
+                        '_std': self.std,
+                        '_count': self.n,
+                        '_sum': self.sess.run(self._sum),
+                        '_sum_squares': self.sess.run(self._sum_squares)}
+
+        with open(os.path.join(checkpoint_dir, str(checkpoint_prefix) + '.srs'), 'wb') as f:
+            pickle.dump(dict_to_save, f, pickle.HIGHEST_PROTOCOL)
+
+    def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
+        # Since the internal state is maintained as part of the TF graph, there is no need to do anything special
+        # for save/restore when going from a single-node-multi-thread run back to a single-node-multi-worker run.
+        # Restoring from either a single-worker or a multi-node-multi-worker run to a single-node-multi-thread run
+        # is not supported.
+        pass
diff --git a/rl_coach/architectures/legacy_tf_components/utils.py b/rl_coach/architectures/legacy_tf_components/utils.py
new file mode 100644
index 000000000..45f6d01ae
--- /dev/null
+++ b/rl_coach/architectures/legacy_tf_components/utils.py
@@ -0,0 +1,47 @@
+#
+# Copyright (c) 2017 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""
+Module containing utility functions
+"""
+import tensorflow as tf
+
+
+def get_activation_function(activation_function_string: str):
+    """
+    Map the activation function from a string to the tensorflow framework equivalent
+    :param activation_function_string: the type of the activation function
+    :return: the tensorflow activation function
+    """
+    activation_functions = {
+        'relu': tf.nn.relu,
+        'tanh': tf.nn.tanh,
+        'sigmoid': tf.nn.sigmoid,
+        'elu': tf.nn.elu,
+        'selu': tf.nn.selu,
+        'leaky_relu': tf.nn.leaky_relu,
+        'none': None
+    }
+    assert activation_function_string in activation_functions.keys(), \
+        "Activation function must be one of the following: {}. Instead it was: {}" \
+        .format(activation_functions.keys(), activation_function_string)
+    return activation_functions[activation_function_string]
+
+
+def squeeze_tensor(tensor):
+    if tensor.shape[0] == 1:
+        return tensor[0]
+    else:
+        return tensor
\ No newline at end of file
diff --git a/rl_coach/architectures/loss_parameters.py b/rl_coach/architectures/loss_parameters.py
new file mode 100644
index 000000000..d6cb83c0e
--- /dev/null
+++ b/rl_coach/architectures/loss_parameters.py
@@ -0,0 +1,48 @@
+#
+# Copyright (c) 2019 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import Type
+
+from rl_coach.base_parameters import Parameters
+
+
+class LossParameters(Parameters):
+    def __init__(self, parameterized_class_name: str,
+                 name: str = 'loss',
+                 num_output_head_copies: int = 1,
+                 loss_weight: float = 1.0,
+                 is_training=False):
+        super().__init__()
+        self.name = name
+        self.num_output_head_copies = num_output_head_copies
+        self.loss_weight = loss_weight
+        self.parameterized_class_name = parameterized_class_name
+        self.is_training = is_training
+
+    @property
+    def path(self):
+        return 'rl_coach.architectures.tensorflow_components.losses:' + self.parameterized_class_name
+
+
+class QLossParameters(LossParameters):
+    def __init__(self, name: str = 'q_loss_params',
+                 num_output_head_copies: int = 1,
+                 loss_weight: float = 1.0):
+        super().__init__(parameterized_class_name="QLoss",
+                         name=name,
+                         num_output_head_copies=num_output_head_copies,
+                         loss_weight=loss_weight)
\ No newline at end of file
diff --git a/rl_coach/architectures/middleware_parameters.py b/rl_coach/architectures/middleware_parameters.py
index 0bdbe09f4..e06c91976 100644
--- a/rl_coach/architectures/middleware_parameters.py
+++ b/rl_coach/architectures/middleware_parameters.py
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2017 Intel Corporation
+# Copyright (c) 2019 Intel Corporation
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
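The `path` property above follows Coach's 'module:ClassName' convention for locating a class at runtime (the ':' separator is visible in the returned string). The loader itself is not part of this diff; an illustrative resolver for such paths (not Coach's actual implementation) would look like:

import importlib

def resolve_path(path: str):
    # split 'package.module:ClassName' and import the named class
    module_name, class_name = path.split(':')
    return getattr(importlib.import_module(module_name), class_name)

# e.g. resolve_path(QLossParameters().path) would import the QLoss class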
diff --git a/rl_coach/architectures/mxnet_components/layers.py b/rl_coach/architectures/mxnet_components/layers.py index 60a982a98..12eea2dc9 100644 --- a/rl_coach/architectures/mxnet_components/layers.py +++ b/rl_coach/architectures/mxnet_components/layers.py @@ -70,6 +70,7 @@ def to_mx(base: layers.Conv2d): return Conv2d(num_filters=base.num_filters, kernel_size=base.kernel_size, strides=base.strides) + class BatchnormActivationDropout(layers.BatchnormActivationDropout): def __init__(self, batchnorm: bool=False, activation_function=None, dropout_rate: float=0): super(BatchnormActivationDropout, self).__init__( diff --git a/rl_coach/architectures/network_wrapper.py b/rl_coach/architectures/network_wrapper.py index dfefc4122..188a56dba 100644 --- a/rl_coach/architectures/network_wrapper.py +++ b/rl_coach/architectures/network_wrapper.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2017 Intel Corporation +# Copyright (c) 2019 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,6 +15,8 @@ # from typing import List, Tuple +import numpy as np +import time from rl_coach.base_parameters import Frameworks, AgentParameters from rl_coach.logger import failed_imports @@ -76,6 +78,7 @@ def __init__(self, agent_parameters: AgentParameters, has_target: bool, has_glob # Online network - local copy of the main network used for playing self.online_network = None + #time.sleep(1) self.online_network = general_network(variable_scope=variable_scope, devices=force_list(worker_device), agent_parameters=agent_parameters, @@ -88,6 +91,7 @@ def __init__(self, agent_parameters: AgentParameters, has_target: bool, has_glob # Target network - a local, slow updating network used for stabilizing the learning self.target_network = None if self.has_target: + #time.sleep(1) self.target_network = general_network(variable_scope=variable_scope, devices=force_list(worker_device), agent_parameters=agent_parameters, @@ -199,8 +203,7 @@ def apply_gradients_and_sync_networks(self, reset_gradients=True, additional_inp self.online_network.apply_and_reset_gradients(self.online_network.accumulated_gradients, additional_inputs=additional_inputs) else: - self.online_network.apply_gradients(self.online_network.accumulated_gradients, - additional_inputs=additional_inputs) + self.online_network.apply_gradients(self.online_network.accumulated_gradients) def parallel_prediction(self, network_input_tuples: List[Tuple]): """ diff --git a/rl_coach/architectures/tensorflow_components/architecture.py b/rl_coach/architectures/tensorflow_components/architecture.py index 68420febb..caa611147 100644 --- a/rl_coach/architectures/tensorflow_components/architecture.py +++ b/rl_coach/architectures/tensorflow_components/architecture.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2017 Intel Corporation +# Copyright (c) 2019 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,20 +13,22 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -import os -import time -from typing import Any, List, Tuple, Dict + import numpy as np import tensorflow as tf - +from typing import Any, Dict, List, Tuple +from tensorflow_probability.python.distributions import Distribution +from tensorflow.keras.losses import Loss from rl_coach.architectures.architecture import Architecture -from rl_coach.architectures.tensorflow_components.savers import GlobalVariableSaver -from rl_coach.base_parameters import AgentParameters, DistributedTaskParameters -from rl_coach.core_types import GradientClippingMethod +from rl_coach.base_parameters import AgentParameters from rl_coach.saver import SaverCollection from rl_coach.spaces import SpacesDefinition -from rl_coach.utils import force_list, squeeze_list, start_shell_command_and_wait +from rl_coach.utils import force_list +from rl_coach.architectures.tensorflow_components import utils +from rl_coach.core_types import GradientClippingMethod +from rl_coach.architectures.tensorflow_components.savers import TfSaver +from rl_coach.architectures.tensorflow_components.losses.head_loss import LOSS_OUT_TYPE_LOSS, LOSS_OUT_TYPE_REGULARIZATION def variable_summaries(var): @@ -45,43 +47,30 @@ def variable_summaries(var): tf.summary.histogram('histogram', var) -def local_getter(getter, name, *args, **kwargs): + +class TensorFlowArchitecture(Architecture): """ - This is a wrapper around the tf.get_variable function which puts the variables in the local variables collection - instead of the global variables collection. The local variables collection will hold variables which are not shared - between workers. these variables are also assumed to be non-trainable (the optimizer does not apply gradients to - these variables), but we can calculate the gradients wrt these variables, and we can update their content. 
+ :param agent_parameters: the agent parameters + :param spaces: the spaces definition of the agent + :param name: the name of the network + :param global_network: the global network replica that is shared between all the workers + :param network_is_local: is the network global (shared between workers) or local (dedicated to the worker) + :param network_is_trainable: is the network trainable (we can apply gradients on it) """ - kwargs['collections'] = [tf.GraphKeys.LOCAL_VARIABLES] - return getter(name, *args, **kwargs) + def __init__(self, agent_parameters: AgentParameters, + spaces: SpacesDefinition, + name: str = "", + global_network=None, + network_is_local: bool=True, + network_is_trainable: bool=False): - -class TensorFlowArchitecture(Architecture): - def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, name: str= "", - global_network=None, network_is_local: bool=True, network_is_trainable: bool=False): - """ - :param agent_parameters: the agent parameters - :param spaces: the spaces definition of the agent - :param name: the name of the network - :param global_network: the global network replica that is shared between all the workers - :param network_is_local: is the network global (shared between workers) or local (dedicated to the worker) - :param network_is_trainable: is the network trainable (we can apply gradients on it) - """ super().__init__(agent_parameters, spaces, name) self.middleware = None self.network_is_local = network_is_local self.global_network = global_network if not self.network_parameters.tensorflow_support: raise ValueError('TensorFlow is not supported for this agent') - self.sess = None - self.inputs = {} - self.outputs = [] - self.targets = [] - self.importance_weights = [] - self.losses = [] - self.total_loss = None - self.trainable_weights = [] - self.weights_placeholders = [] + self.losses = [] # type: List[Loss] self.shared_accumulated_gradients = [] self.curr_rnn_c_in = None self.curr_rnn_h_in = None @@ -89,335 +78,146 @@ def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, self.train_writer = None self.accumulated_gradients = None self.network_is_trainable = network_is_trainable + self.is_training = False + self.model = None # type: GeneralModel self.is_chief = self.ap.task_parameters.task_index == 0 self.network_is_global = not self.network_is_local and global_network is None self.distributed_training = self.network_is_global or self.network_is_local and global_network is not None self.optimizer_type = self.network_parameters.optimizer_type + self.emmbeding_types = list(self.network_parameters.input_embedders_parameters.keys()) if self.ap.task_parameters.seed is not None: - tf.set_random_seed(self.ap.task_parameters.seed) - with tf.variable_scope("/".join(self.name.split("/")[1:]), initializer=tf.contrib.layers.xavier_initializer(), - custom_getter=local_getter if network_is_local and global_network else None): - self.global_step = tf.train.get_or_create_global_step() - - # build the network - self.weights = self.get_model() - - # create the placeholder for the assigning gradients and some tensorboard summaries for the weights - for idx, var in enumerate(self.weights): - placeholder = tf.placeholder(tf.float32, shape=var.get_shape(), name=str(idx) + '_holder') - self.weights_placeholders.append(placeholder) - if self.ap.visualization.tensorboard: - variable_summaries(var) + tf.random.set_seed(self.ap.task_parameters.seed) + # TODO - tf2: convert to tf2 syntax + # 
tf.compat.v1.set_random_seed(self.ap.task_parameters.seed)

-            # create op for assigning a list of weights to the network weights
-            self.update_weights_from_list = [weights.assign(holder) for holder, weights in
-                                             zip(self.weights_placeholders, self.weights)]
+        # Call to child class to create the model
+        self.construct_model()
+        self.trainer = None

-            # locks for synchronous training
-            if self.network_is_global:
-                self._create_locks_for_synchronous_training()
-
-            # gradients ops
-            self._create_gradient_ops()
-
-            self.inc_step = self.global_step.assign_add(1)
-
-            # reset LSTM hidden cells
-            self.reset_internal_memory()
-
-            if self.ap.visualization.tensorboard:
-                current_scope_summaries = tf.get_collection(tf.GraphKeys.SUMMARIES,
-                                                            scope=tf.contrib.framework.get_name_scope())
-                self.merged = tf.summary.merge(current_scope_summaries)
+        # TODO: Need to fix this, when do we want to save the weights?
+        if self.ap.visualization.tensorboard:
+            weights = self.get_weights()
+            for idx, var in enumerate(weights):
+                variable_summaries(var)

-        # initialize or restore model
-        self.init_op = tf.group(
-            tf.global_variables_initializer(),
-            tf.local_variables_initializer()
-        )
+        # global step is used for multi-threaded agents
+        # self.global_step = tf.train.get_or_create_global_step()
+        # locks for synchronous training - not supported in tf2
+        # if self.network_is_global:
+        #     self._create_locks_for_synchronous_training()
+        # self.inc_step = self.global_step.assign_add(1)

-        # set the fetches for training
-        self._set_initial_fetch_list()
+    def __str__(self):
+        return self.model.summary(self._dummy_model_inputs())

-    def get_model(self) -> List:
+    def construct_model(self) -> None:
         """
-        Constructs the model using `network_parameters` and sets `input_embedders`, `middleware`,
-        `output_heads`, `outputs`, `losses`, `total_loss`, `adaptive_learning_rate_scheme`,
-        `current_learning_rate`, and `optimizer`.
-
-        :return: A list of the model's weights
+        Construct network model. Implemented by child class.
         """
-        raise NotImplementedError
+        print('Construct is empty for now and is called from class constructor')

-    def _set_initial_fetch_list(self):
+    def set_session(self, sess) -> None:
         """
-        Create an initial list of tensors to fetch in each training iteration
-        :return: None
+        Initializes the model parameters and creates the model trainer.
+        NOTE: Session for TF2 backend must be None.
+        :param sess: must be None
+        """
-        self.train_fetches = [self.gradients_norm]
-        if self.network_parameters.clip_gradients:
-            self.train_fetches.append(self.clipped_grads)
-        else:
-            self.train_fetches.append(self.tensor_gradients)
-        self.train_fetches += [self.total_loss, self.losses]
-        if self.middleware.__class__.__name__ == 'LSTMMiddleware':
-            self.train_fetches.append(self.middleware.state_out)
-        self.additional_fetches_start_idx = len(self.train_fetches)
+        assert sess is None

-    def _create_locks_for_synchronous_training(self):
+    def reset_accumulated_gradients(self) -> None:
         """
-        Create locks for synchronizing the different workers during training
-        :return: None
-        """
-        self.lock_counter = tf.get_variable("lock_counter", [], tf.int32,
-                                            initializer=tf.constant_initializer(0, dtype=tf.int32),
-                                            trainable=False)
-        self.lock = self.lock_counter.assign_add(1, use_locking=True)
-        self.lock_init = self.lock_counter.assign(0)
-
-        self.release_counter = tf.get_variable("release_counter", [], tf.int32,
-                                               initializer=tf.constant_initializer(0, dtype=tf.int32),
-                                               trainable=False)
-        self.release = self.release_counter.assign_add(1, use_locking=True)
-        self.release_decrement = self.release_counter.assign_add(-1, use_locking=True)
-        self.release_init = self.release_counter.assign(0)
-
-    def _create_gradient_ops(self):
-        """
-        Create all the tensorflow operations for calculating gradients, processing the gradients and applying them
-        :return: None
+        Reset model gradients as well as accumulated gradients to zero. If the accumulated gradients
+        have not been created yet, this method constructs them.
         """
+        if self.accumulated_gradients is None:
+            self.accumulated_gradients = self.model.get_weights().copy()
+
+        self.accumulated_gradients = list(map(lambda grad: grad * 0, self.accumulated_gradients))
+
+    def accumulate_gradients(self,
+                             inputs: Dict[str, np.ndarray],
+                             targets: List[np.ndarray],
+                             additional_fetches: List[Tuple[int, str]] = None,
+                             importance_weights: np.ndarray = None,
+                             no_accumulation: bool = False) -> Tuple[float, List[float], float, list]:
+        """
+        Runs a forward & backward pass, clips gradients if needed and accumulates them into the accumulated gradients
+        :param inputs: environment states (observation, etc.) as well as extra inputs required by loss. Shape of ndarray
+            is (batch_size, observation_space_size) or (batch_size, observation_space_size, stack_size)
+        :param targets: targets required by loss (e.g. sum of discounted rewards)
+        :param additional_fetches: additional fetches to calculate and return. Each fetch is specified as (int, str)
+            tuple of head-type-index and fetch-name. The tuple is obtained from each head.
+        :param importance_weights: ndarray of shape (batch_size,) to multiply with batch loss.
+        :param no_accumulation: if True, set gradient values to the new gradients, otherwise sum with previously
+            calculated gradients
+        :return: tuple of total_loss, losses, norm_unclipped_grads, fetched_tensors
+            total_loss (float): sum of all head losses
+            losses (list of float): list of all losses. The order is list of target losses followed by list of
+                regularization losses. The specifics of the losses are dependent on the network parameters
+                (number of heads, etc.)
+            norm_unclipped_grads (float): global norm of all gradients before any gradient clipping is applied
+            fetched_tensors: all values for additional_fetches
+        """
+        assert self.middleware.__class__.__name__ != 'LSTMMiddleware', "LSTM middleware not supported"
+        if self.accumulated_gradients is None:
+            self.reset_accumulated_gradients()

-        self.tensor_gradients = tf.gradients(self.total_loss, self.weights)
-        self.gradients_norm = tf.global_norm(self.tensor_gradients)
+        model_inputs = tuple(inputs[emb_type] for emb_type in self.emmbeding_types)
+        targets = force_list(targets)
+        targets = list(map(lambda x: tf.cast(x, tf.float32), targets))
+        losses = list()
+        regularisations = list()

-        # gradient clipping
-        if self.network_parameters.clip_gradients is not None and self.network_parameters.clip_gradients != 0:
-            self._create_gradient_clipping_ops()
-
-        # when using a shared optimizer, we create accumulators to store gradients from all the workers before
-        # applying them
-        if self.distributed_training:
-            self._create_gradient_accumulators()
-
-        # gradients of the outputs w.r.t. the inputs
-        self.gradients_wrt_inputs = [{name: tf.gradients(output, input_ph) for name, input_ph in
-                                      self.inputs.items()} for output in self.outputs]
-        self.gradients_weights_ph = [tf.placeholder('float32', self.outputs[i].shape, 'output_gradient_weights')
-                                     for i in range(len(self.outputs))]
-        self.weighted_gradients = []
-        for i in range(len(self.outputs)):
-            unnormalized_gradients = tf.gradients(self.outputs[i], self.weights, self.gradients_weights_ph[i])
-            # unnormalized gradients seems to be better at the time. TODO: validate this accross more environments
-            # self.weighted_gradients.append(list(map(lambda x: tf.div(x, self.network_parameters.batch_size),
-            #                                         unnormalized_gradients)))
-            self.weighted_gradients.append(unnormalized_gradients)
-
-        # defining the optimization process (for LBFGS we have less control over the optimizer)
-        if self.optimizer_type != 'LBFGS' and self.network_is_trainable:
-            self._create_gradient_applying_ops()
-
-    def _create_gradient_accumulators(self):
-        if self.network_is_global:
-            self.shared_accumulated_gradients = [tf.Variable(initial_value=tf.zeros_like(var)) for var in self.weights]
-            self.accumulate_shared_gradients = [var.assign_add(holder, use_locking=True) for holder, var in
-                                                zip(self.weights_placeholders, self.shared_accumulated_gradients)]
-            self.init_shared_accumulated_gradients = [var.assign(tf.zeros_like(var)) for var in
-                                                      self.shared_accumulated_gradients]
-        elif self.network_is_local:
-            self.accumulate_shared_gradients = self.global_network.accumulate_shared_gradients
-            self.init_shared_accumulated_gradients = self.global_network.init_shared_accumulated_gradients
-
-    def _create_gradient_clipping_ops(self):
-        """
-        Create tensorflow ops for clipping the gradients according to the given GradientClippingMethod
-        :return: None
-        """
-        if self.network_parameters.gradients_clipping_method == GradientClippingMethod.ClipByGlobalNorm:
-            self.clipped_grads, self.grad_norms = tf.clip_by_global_norm(self.tensor_gradients,
-                                                                         self.network_parameters.clip_gradients)
-        elif self.network_parameters.gradients_clipping_method == GradientClippingMethod.ClipByValue:
-            self.clipped_grads = [tf.clip_by_value(grad,
-                                                   -self.network_parameters.clip_gradients,
-                                                   self.network_parameters.clip_gradients)
-                                  for grad in self.tensor_gradients]
-        elif self.network_parameters.gradients_clipping_method == GradientClippingMethod.ClipByNorm:
-            self.clipped_grads = [tf.clip_by_norm(grad, 
self.network_parameters.clip_gradients) - for grad in self.tensor_gradients] - - def _create_gradient_applying_ops(self): - """ - Create tensorflow ops for applying the gradients to the network weights according to the training scheme - (distributed training - local or global network, shared optimizer, etc.) - :return: None - """ - if self.network_is_global and self.network_parameters.shared_optimizer and \ - not self.network_parameters.async_training: - # synchronous training with shared optimizer? -> create an operation for applying the gradients - # accumulated in the shared gradients accumulator - self.update_weights_from_shared_gradients = self.optimizer.apply_gradients( - zip(self.shared_accumulated_gradients, self.weights), - global_step=self.global_step) - - elif self.distributed_training and self.network_is_local: - # distributed training but independent optimizer? -> create an operation for applying the gradients - # to the global weights - self.update_weights_from_batch_gradients = self.optimizer.apply_gradients( - zip(self.weights_placeholders, self.global_network.weights), global_step=self.global_step) - - elif self.network_is_trainable: - # not any of the above but is trainable? -> create an operation for applying the gradients to - # this network weights - update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope=self.full_name) - - with tf.control_dependencies(update_ops): - self.update_weights_from_batch_gradients = self.optimizer.apply_gradients( - zip(self.weights_placeholders, self.weights), global_step=self.global_step) - - def set_session(self, sess): - self.sess = sess - - task_is_distributed = isinstance(self.ap.task_parameters, DistributedTaskParameters) - # initialize the session parameters in single threaded runs. Otherwise, this is done through the - # MonitoredSession object in the graph manager - if not task_is_distributed: - self.sess.run(self.init_op) + with tf.GradientTape(persistent=True) as tape: - if self.ap.visualization.tensorboard: - # Write the merged summaries to the current experiment directory - if not task_is_distributed: - self.train_writer = tf.summary.FileWriter(self.ap.task_parameters.experiment_path + '/tensorboard') - self.train_writer.add_graph(self.sess.graph) - elif self.network_is_local: - self.train_writer = tf.summary.FileWriter(self.ap.task_parameters.experiment_path + - '/tensorboard/worker{}'.format(self.ap.task_parameters.task_index)) - self.train_writer.add_graph(self.sess.graph) - - # wait for all the workers to set their session - if not self.network_is_local: - self.wait_for_all_workers_barrier() - - def reset_accumulated_gradients(self): - """ - Reset the gradients accumulation placeholder - """ - if self.accumulated_gradients is None: - self.accumulated_gradients = self.sess.run(self.weights) + model_outputs = force_list(self.model(model_inputs)) - for ix, grad in enumerate(self.accumulated_gradients): - self.accumulated_gradients[ix] = grad * 0 + for head_loss, head_output in zip(self.losses, model_outputs): + non_trainable_args = utils.extract_loss_inputs(head_loss.head_idx, inputs, targets) + loss_outputs = head_loss([head_output], non_trainable_args) + fetched_tensors = utils.extract_fetches(loss_outputs, head_loss.head_idx, additional_fetches) - def accumulate_gradients(self, inputs, targets, additional_fetches=None, importance_weights=None, - no_accumulation=False): - """ - Runs a forward pass & backward pass, clips gradients if needed and accumulates them into the accumulation - placeholders - :param 
additional_fetches: Optional tensors to fetch during gradients calculation
-        :param inputs: The input batch for the network
-        :param targets: The targets corresponding to the input batch
-        :param importance_weights: A coefficient for each sample in the batch, which will be used to rescale the loss
-                                   error of this sample. If it is not given, the samples losses won't be scaled
-        :param no_accumulation: If is set to True, the gradients in the accumulated gradients placeholder will be
-                                replaced by the newely calculated gradients instead of accumulating the new gradients.
-                                This can speed up the function runtime by around 10%.
-        :return: A list containing the total loss and the individual network heads losses
-        """
+                if LOSS_OUT_TYPE_LOSS in loss_outputs:
+                    losses.extend(loss_outputs[LOSS_OUT_TYPE_LOSS])
+                if LOSS_OUT_TYPE_REGULARIZATION in loss_outputs:
+                    regularisations.extend(loss_outputs[LOSS_OUT_TYPE_REGULARIZATION])
-        if self.accumulated_gradients is None:
-            self.reset_accumulated_gradients()
+            # Total loss is losses and regularization
+            total_loss_list = losses + regularisations
+            total_loss = tf.add_n(total_loss_list)
+
+        # Calculate gradients
+        gradients = tape.gradient(total_loss, self.model.trainable_variables)
+        norm_unclipped_grads = tf.linalg.global_norm(gradients)
+        # Gradient clipping
+        # NOTE: clip_gradients returns only the clipped gradients (no norm), so there is a single value to assign
+        if self.network_parameters.clip_gradients is not None and self.network_parameters.clip_gradients != 0:
+            gradients = self.clip_gradients(gradients,
+                                            self.network_parameters.gradients_clipping_method,
+                                            self.network_parameters.clip_gradients)
+        # Update self.accumulated_gradients depending on no_accumulation flag
+        if no_accumulation:
+            self.accumulated_gradients = gradients.copy()
+        else:
+            # sum elementwise per weight tensor (list += list would concatenate the lists instead)
+            self.accumulated_gradients = [acc + grad for acc, grad in zip(self.accumulated_gradients, gradients)]
-        # feed inputs
-        if additional_fetches is None:
-            additional_fetches = []
-        feed_dict = self.create_feed_dict(inputs)
+        # convert everything to numpy or scalar before returning
+        result = (total_loss, total_loss_list, norm_unclipped_grads.numpy(), fetched_tensors)
+        return result
-        # feed targets
-        targets = force_list(targets)
-        for placeholder_idx, target in enumerate(targets):
-            feed_dict[self.targets[placeholder_idx]] = target
-
-        # feed importance weights
-        importance_weights = force_list(importance_weights)
-        for placeholder_idx, target_ph in enumerate(targets):
-            if len(importance_weights) <= placeholder_idx or importance_weights[placeholder_idx] is None:
-                importance_weight = np.ones(target_ph.shape[0])
-            else:
-                importance_weight = importance_weights[placeholder_idx]
-            importance_weight = np.reshape(importance_weight, (-1,) + (1,) * (len(target_ph.shape) - 1))
-
-            feed_dict[self.importance_weights[placeholder_idx]] = importance_weight
-
-        if self.optimizer_type != 'LBFGS':
-
-            # feed the lstm state if necessary
-            if self.middleware.__class__.__name__ == 'LSTMMiddleware':
-                # we can't always assume that we are starting from scratch here can we?
-
- feed_dict[self.middleware.c_in] = self.middleware.c_init - feed_dict[self.middleware.h_in] = self.middleware.h_init - - fetches = self.train_fetches + additional_fetches - if self.ap.visualization.tensorboard: - fetches += [self.merged] - - # get grads - result = self.sess.run(fetches, feed_dict=feed_dict) - if hasattr(self, 'train_writer') and self.train_writer is not None: - self.train_writer.add_summary(result[-1], self.sess.run(self.global_step)) - - # extract the fetches - norm_unclipped_grads, grads, total_loss, losses = result[:4] - if self.middleware.__class__.__name__ == 'LSTMMiddleware': - (self.curr_rnn_c_in, self.curr_rnn_h_in) = result[4] - fetched_tensors = [] - if len(additional_fetches) > 0: - fetched_tensors = result[self.additional_fetches_start_idx:self.additional_fetches_start_idx + - len(additional_fetches)] - - # accumulate the gradients - for idx, grad in enumerate(grads): - if no_accumulation: - self.accumulated_gradients[idx] = grad - else: - self.accumulated_gradients[idx] += grad - - return total_loss, losses, norm_unclipped_grads, fetched_tensors + def apply_gradients(self, gradients: List[np.ndarray], scaler: float=1., additional_inputs=None) -> None: + """ + Applies the given gradients to the network weights + :param gradients: The gradients to use for the update + :param scaler: A scaling factor that allows rescaling the gradients before applying them. + The gradients will be MULTIPLIED by this factor + """ + assert self.optimizer_type != 'LBFGS', 'LBFGS not supported' - else: - self.optimizer.minimize(session=self.sess, feed_dict=feed_dict) - - return [0] - - def create_feed_dict(self, inputs): - feed_dict = {} - for input_name, input_value in inputs.items(): - if isinstance(input_name, str): - if input_name not in self.inputs: - raise ValueError(( - 'input name {input_name} was provided to create a feed ' - 'dictionary, but there is no placeholder with that name. 
' - 'placeholder names available include: {placeholder_names}' - ).format( - input_name=input_name, - placeholder_names=', '.join(self.inputs.keys()) - )) - - feed_dict[self.inputs[input_name]] = input_value - elif isinstance(input_name, tf.Tensor) and input_name.op.type == 'Placeholder': - feed_dict[input_name] = input_value - else: - raise ValueError(( - 'input dictionary expects strings or placeholders as keys, ' - 'but found key {key} of type {type}' - ).format( - key=input_name, - type=type(input_name), - )) - - return feed_dict - - def apply_and_reset_gradients(self, gradients, scaler=1., additional_inputs=None): + self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables)) + + def apply_and_reset_gradients(self, gradients: List[np.ndarray], scaler: float = 1., additional_inputs=None) -> None: """ Applies the given gradients to the network weights and resets the accumulation placeholder :param gradients: The gradients to use for the update @@ -426,221 +226,126 @@ def apply_and_reset_gradients(self, gradients, scaler=1., additional_inputs=None update ops also requires the inputs) """ - self.apply_gradients(gradients, scaler, additional_inputs=additional_inputs) + self.apply_gradients(gradients, scaler) self.reset_accumulated_gradients() - def wait_for_all_workers_to_lock(self, lock: str, include_only_training_workers: bool=False): + def clip_gradients(self, grads: List[np.ndarray], + clip_method: GradientClippingMethod, + clip_val: float) -> List[np.ndarray]: """ - Waits for all the workers to lock a certain lock and then continues - :param lock: the name of the lock to use - :param include_only_training_workers: wait only for training workers or for all the workers? - :return: None + Clip gradient values + :param grads: gradients to be clipped + :param clip_method: clipping method + :param clip_val: clipping value. Interpreted differently depending on clipping method. + :return: clipped gradients """ - if include_only_training_workers: - num_workers_to_wait_for = self.ap.task_parameters.num_training_tasks - else: - num_workers_to_wait_for = self.ap.task_parameters.num_tasks - - # lock - if hasattr(self, '{}_counter'.format(lock)): - self.sess.run(getattr(self, lock)) - while self.sess.run(getattr(self, '{}_counter'.format(lock))) % num_workers_to_wait_for != 0: - time.sleep(0.00001) - # self.sess.run(getattr(self, '{}_init'.format(lock))) + + if clip_method == GradientClippingMethod.ClipByGlobalNorm: + clipped_grads, grad_norms = tf.clip_by_global_norm(grads, clip_val) + + elif clip_method == GradientClippingMethod.ClipByValue: + clipped_grads = [tf.clip_by_value(grad, -clip_val, clip_val) for grad in grads] + elif clip_method == GradientClippingMethod.ClipByNorm: + + clipped_grads = [tf.clip_by_norm(grad, clip_val) for grad in grads] else: - raise ValueError("no counter was defined for the lock {}".format(lock)) + raise KeyError('Unsupported gradient clipping method') + return clipped_grads - def wait_for_all_workers_barrier(self, include_only_training_workers: bool=False): + def _predict(self, inputs: Dict[str, np.ndarray]) -> Tuple[np.ndarray, ...]: """ - A barrier that allows waiting for all the workers to finish a certain block of commands - :param include_only_training_workers: wait only for training workers or for all the workers? - :return: None + Run a forward pass of the network using the given input + :param inputs: The input dictionary for the network. Key is name of the embedder. 
+ :return: The network output per each head + """ - self.wait_for_all_workers_to_lock('lock', include_only_training_workers=include_only_training_workers) - self.sess.run(self.lock_init) + assert self.middleware.__class__.__name__ != 'LSTMMiddleware' - # we need to lock again (on a different lock) in order to prevent a situation where one of the workers continue - # and then was able to first increase the lock again by one, only to have a late worker to reset it again. - # so we want to make sure that all workers are done resetting the lock before continuting to reuse that lock. + model_inputs = tuple(inputs[emb] for emb in self.emmbeding_types) + model_outputs = self.model(model_inputs) - self.wait_for_all_workers_to_lock('release', include_only_training_workers=include_only_training_workers) - self.sess.run(self.release_init) + distribution_output = list(filter(lambda x: isinstance(x, Distribution), model_outputs)) - def apply_gradients(self, gradients, scaler=1., additional_inputs=None): - """ - Applies the given gradients to the network weights - :param gradients: The gradients to use for the update - :param scaler: A scaling factor that allows rescaling the gradients before applying them. - The gradients will be MULTIPLIED by this factor - :param additional_inputs: optional additional inputs required for when applying the gradients (e.g. batchnorm's - update ops also requires the inputs) - """ + if distribution_output: + output_per_head = [] + distribution_output = distribution_output.pop() + policy_mean = distribution_output.mean().numpy() + policy_stddev = distribution_output.stddev().numpy() + value_output = list(filter(lambda x: not (isinstance(x, Distribution)), model_outputs)).pop() + value_output = value_output.numpy().reshape(-1,) + output_per_head.append(value_output) + output_per_head.append(policy_mean) + output_per_head.append(policy_stddev) + else: + output_per_head = model_outputs.numpy() + + return output_per_head - if self.network_parameters.async_training or not isinstance(self.ap.task_parameters, DistributedTaskParameters): - if hasattr(self, 'global_step') and not self.network_is_local: - self.sess.run(self.inc_step) - - if self.optimizer_type != 'LBFGS': - - if self.distributed_training and not self.network_parameters.async_training: - # rescale the gradients so that they average out with the gradients from the other workers - if self.network_parameters.scale_down_gradients_by_number_of_workers_for_sync_training: - scaler /= float(self.ap.task_parameters.num_training_tasks) - - # rescale the gradients - if scaler != 1.: - for gradient in gradients: - gradient *= scaler - - # apply the gradients - feed_dict = dict(zip(self.weights_placeholders, gradients)) - if self.distributed_training and self.network_parameters.shared_optimizer \ - and not self.network_parameters.async_training: - # synchronous distributed training with shared optimizer: - # - each worker adds its gradients to the shared gradients accumulators - # - we wait for all the workers to add their gradients - # - the chief worker (worker with task index = 0) applies the gradients once and resets the accumulators - - self.sess.run(self.accumulate_shared_gradients, feed_dict=feed_dict) - - self.wait_for_all_workers_barrier(include_only_training_workers=True) - - if self.is_chief: - self.sess.run(self.update_weights_from_shared_gradients) - self.sess.run(self.init_shared_accumulated_gradients) - else: - # async distributed training / distributed training with independent optimizer - # / non-distributed 
training - just apply the gradients - feed_dict = dict(zip(self.weights_placeholders, gradients)) - if additional_inputs is not None: - feed_dict = {**feed_dict, **self.create_feed_dict(additional_inputs)} - self.sess.run(self.update_weights_from_batch_gradients, feed_dict=feed_dict) - - # release barrier - if self.distributed_training and not self.network_parameters.async_training: - self.wait_for_all_workers_barrier(include_only_training_workers=True) - - def predict(self, inputs, outputs=None, squeeze_output=True, initial_feed_dict=None): + def predict(self, + inputs: Dict[str, np.ndarray], + outputs: List[str] = None, + squeeze_output: bool = True, + initial_feed_dict: Dict[str, np.ndarray] = None) -> Tuple[np.ndarray, ...]: """ Run a forward pass of the network using the given input - :param inputs: The input for the network - :param outputs: The output for the network, defaults to self.outputs - :param squeeze_output: call squeeze_list on output - :param initial_feed_dict: a dictionary to use as the initial feed_dict. other inputs will be added to this dict + :param inputs: The input dictionary for the network. Key is name of the embedder. + :param outputs: list of outputs to return. Return all outputs if unspecified (currently not supported) + :param squeeze_output: call squeeze_list on output if True + :param initial_feed_dict: a dictionary of extra inputs for forward pass (currently not supported) :return: The network output WARNING: must only call once per state since each call is assumed by LSTM to be a new time step. """ - feed_dict = self.create_feed_dict(inputs) - if initial_feed_dict: - feed_dict.update(initial_feed_dict) - if outputs is None: - outputs = self.outputs - - if self.middleware.__class__.__name__ == 'LSTMMiddleware': - feed_dict[self.middleware.c_in] = self.curr_rnn_c_in - feed_dict[self.middleware.h_in] = self.curr_rnn_h_in - - output, (self.curr_rnn_c_in, self.curr_rnn_h_in) = self.sess.run([outputs, self.middleware.state_out], - feed_dict=feed_dict) - else: - output = self.sess.run(outputs, feed_dict) - - if squeeze_output: - output = squeeze_list(output) + assert initial_feed_dict is None, "initial_feed_dict must be None" + assert outputs is None, "outputs must be None" + output = self._predict(inputs) return output @staticmethod def parallel_predict(sess: Any, - network_input_tuples: List[Tuple['TensorFlowArchitecture', Dict[str, np.ndarray]]]) ->\ - List[np.ndarray]: + network_input_tuples: List[Tuple['TensorFlowArchitecture', + Dict[str, np.ndarray]]]) -> Tuple[np.ndarray, ...]: """ - :param sess: active session to use for prediction + :param sess: active session to use for prediction (must be None for TF2) :param network_input_tuples: tuple of network and corresponding input - :return: list of outputs from all networks + :return: tuple of outputs from all networks """ - feed_dict = {} - fetches = [] - - for network, input in network_input_tuples: - feed_dict.update(network.create_feed_dict(input)) - fetches += network.outputs - - outputs = sess.run(fetches, feed_dict) - - return outputs + assert sess is None + output = [net._predict(inputs) for net, inputs in network_input_tuples] + return output - def train_on_batch(self, inputs, targets, scaler=1., additional_fetches=None, importance_weights=None): - """ - Given a batch of examples and targets, runs a forward pass & backward pass and then applies the gradients - :param additional_fetches: Optional tensors to fetch during the training process - :param inputs: The input for the network - :param targets: 
The targets corresponding to the input batch
-        :param scaler: A scaling factor that allows rescaling the gradients before applying them
-        :param importance_weights: A coefficient for each sample in the batch, which will be used to rescale the loss
-                                   error of this sample. If it is not given, the samples losses won't be scaled
-        :return: The loss of the network
-        """
-        if additional_fetches is None:
-            additional_fetches = []
-        force_list(additional_fetches)
-        loss = self.accumulate_gradients(inputs, targets, additional_fetches=additional_fetches,
-                                         importance_weights=importance_weights)
-        self.apply_and_reset_gradients(self.accumulated_gradients, scaler)
-        return loss
-
-    def get_weights(self):
+    def get_weights(self) -> List:
         """
         :return: a list of tensors containing the network weights for each layer
         """
-        return self.weights
+        return self.model.get_weights()

-    def set_weights(self, weights, new_rate=1.0):
+    def set_weights(self, source_weights: List, new_rate: float = 1.0) -> None:
         """
-        Sets the network weights from the given list of weights tensors
+        Updates the target network weights from the given source model weights tensors
         """
-        feed_dict = {}
-        old_weights, new_weights = self.sess.run([self.get_weights(), weights])
-        for placeholder_idx, new_weight in enumerate(new_weights):
-            feed_dict[self.weights_placeholders[placeholder_idx]]\
-                = new_rate * new_weight + (1 - new_rate) * old_weights[placeholder_idx]
-        self.sess.run(self.update_weights_from_list, feed_dict)
-
-    def get_variable_value(self, variable):
-        """
-        Get the value of a variable from the graph
-        :param variable: the variable
-        :return: the value of the variable
-        """
-        return self.sess.run(variable)
+        updated_target = []
+        if new_rate < 0 or new_rate > 1:
+            raise ValueError('new_rate parameter values should be between 0 and 1.')
+        target_weights = self.model.get_weights()
+        for (source_layer, target_layer) in zip(source_weights, target_weights):
+            updated_target.append(new_rate * source_layer + (1 - new_rate) * target_layer)
+        self.model.set_weights(updated_target)

-    def set_variable_value(self, assign_op, value, placeholder=None):
-        """
-        Updates the value of a variable.
-        This requires having an assign operation for the variable, and a placeholder which will provide the value
-        :param assign_op: an assign operation for the variable
-        :param value: a value to set the variable to
-        :param placeholder: a placeholder to hold the given value for injecting it into the variable
-        """
-        self.sess.run(assign_op, feed_dict={placeholder: value})
-
-    def set_is_training(self, state: bool):
+    def set_is_training(self, state: bool) -> None:
         """
         Set the phase of the network between training and testing
         :param state: The current state (True = Training, False = Testing)
         :return: None
         """
-        self.set_variable_value(self.assign_is_training, state, self.is_training_placeholder)
+        self.is_training = state

-    def reset_internal_memory(self):
+    def reset_internal_memory(self) -> None:
         """
         Reset any internal memory used by the network. For example, an LSTM internal state
         :return: None
         """
-        # initialize LSTM hidden states
-        if self.middleware.__class__.__name__ == 'LSTMMiddleware':
-            self.curr_rnn_c_in = self.middleware.c_init
-            self.curr_rnn_h_in = self.middleware.h_init
+        assert self.middleware.__class__.__name__ != 'LSTMMiddleware', 'LSTM middleware not supported'

     def collect_savers(self, parent_path_suffix: str) -> SaverCollection:
         """
@@ -649,45 +354,11 @@ def collect_savers(self, parent_path_suffix: str) -> SaverCollection:
         (e.g.
could be name of level manager plus name of agent) :return: checkpoint collection for the network """ + name = self.name.replace('/', '.') savers = SaverCollection() if not self.distributed_training: - savers.add(GlobalVariableSaver(self.name)) + savers.add(TfSaver( + name="{}.{}".format(parent_path_suffix, name), + model=self.model)) return savers - -def save_onnx_graph(input_nodes, output_nodes, checkpoint_save_dir: str) -> None: - """ - Given the input nodes and output nodes of the TF graph, save it as an onnx graph - This requires the TF graph and the weights checkpoint to be stored in the experiment directory. - It then freezes the graph (merging the graph and weights checkpoint), and converts it to ONNX. - - :param input_nodes: A list of input nodes for the TF graph - :param output_nodes: A list of output nodes for the TF graph - :param checkpoint_save_dir: The directory to save the ONNX graph to - :return: None - """ - import tf2onnx # just to verify that tf2onnx is installed - - # freeze graph - frozen_graph_path = os.path.join(checkpoint_save_dir, "frozen_graph.pb") - freeze_graph_command = [ - "python -m tensorflow.python.tools.freeze_graph", - "--input_graph={}".format(os.path.join(checkpoint_save_dir, "graphdef.pb")), - "--input_binary=true", - "--output_node_names='{}'".format(','.join([o.split(":")[0] for o in output_nodes])), - "--input_checkpoint={}".format(tf.train.latest_checkpoint(checkpoint_save_dir)), - "--output_graph={}".format(frozen_graph_path) - ] - start_shell_command_and_wait(" ".join(freeze_graph_command)) - - # convert graph to onnx - onnx_graph_path = os.path.join(checkpoint_save_dir, "model.onnx") - convert_to_onnx_command = [ - "python -m tf2onnx.convert", - "--input {}".format(frozen_graph_path), - "--inputs '{}'".format(','.join(input_nodes)), - "--outputs '{}'".format(','.join(output_nodes)), - "--output {}".format(onnx_graph_path), - "--verbose" - ] - start_shell_command_and_wait(" ".join(convert_to_onnx_command)) diff --git a/rl_coach/architectures/tensorflow_components/distributed_tf_utils.py b/rl_coach/architectures/tensorflow_components/distributed_tf_utils.py index bbbbc0f23..a74b26396 100644 --- a/rl_coach/architectures/tensorflow_components/distributed_tf_utils.py +++ b/rl_coach/architectures/tensorflow_components/distributed_tf_utils.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2017 Intel Corporation +# Copyright (c) 2019 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,11 +14,13 @@ # limitations under the License. # + from typing import Tuple import tensorflow as tf + def create_cluster_spec(parameters_server: str, workers: str) -> tf.train.ClusterSpec: """ Creates a ClusterSpec object representing the cluster. 
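As a sanity check on the soft target update implemented in `set_weights()` above, the arithmetic can be exercised with plain NumPy arrays (a self-contained sketch, assuming per-layer weight lists as returned by `keras.Model.get_weights()`):

```python
import numpy as np

new_rate = 0.1  # Polyak-style averaging coefficient
source_weights = [np.ones((2, 2)), np.zeros(3)]   # stand-in for the online network
target_weights = [np.zeros((2, 2)), np.ones(3)]   # stand-in for the target network

# Mirrors the loop in set_weights(): target <- new_rate * source + (1 - new_rate) * target
updated_target = [new_rate * s + (1 - new_rate) * t
                  for s, t in zip(source_weights, target_weights)]

assert np.allclose(updated_target[0], 0.1)
assert np.allclose(updated_target[1], 0.9)
# With new_rate=1.0 the update degenerates to a hard copy of the source weights.
```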
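The hunks below relocate the TF1 server APIs to their TF2 homes (`tf.train.Server` becomes `tf.distribute.Server`, and `ConfigProto`/`MonitoredTrainingSession` move under `tf.compat.v1`). A minimal sketch of the resulting call sequence, with hypothetical localhost addresses:

```python
import tensorflow as tf

# create_cluster_spec() builds this mapping from comma-separated address strings.
cluster_spec = tf.train.ClusterSpec({
    'ps': ['localhost:2222'],
    'worker': ['localhost:2223'],
})

# In TF2 the in-process server lives under tf.distribute.Server.
server = tf.distribute.Server(cluster_spec, job_name='worker', task_index=0)
print(server.target)  # gRPC target string, usable as the session master
```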
@@ -36,7 +38,7 @@ def create_cluster_spec(parameters_server: str, workers: str) -> tf.train.Cluste return cluster_spec -def create_and_start_parameters_server(cluster_spec: tf.train.ClusterSpec, config: tf.ConfigProto=None) -> None: +def create_and_start_parameters_server(cluster_spec: tf.train.ClusterSpec, config: tf.compat.v1.ConfigProto=None) -> None: """ Create and start a parameter server :param cluster_spec: the ClusterSpec object representing the cluster @@ -44,14 +46,14 @@ def create_and_start_parameters_server(cluster_spec: tf.train.ClusterSpec, confi :return: None """ # create a server object for the parameter server - server = tf.train.Server(cluster_spec, job_name="ps", task_index=0, config=config) + server = tf.distribute.Server(cluster_spec, job_name="ps", task_index=0, config=config) # wait for the server to finish server.join() def create_worker_server_and_device(cluster_spec: tf.train.ClusterSpec, task_index: int, - use_cpu: bool=True, config: tf.ConfigProto=None) -> Tuple[str, tf.device]: + use_cpu: bool=True, config: tf.compat.v1.ConfigProto=None) -> Tuple[str, tf.device]: """ Creates a worker server and a device setter used to assign the workers operations to :param cluster_spec: a ClusterSpec object representing the cluster @@ -61,7 +63,7 @@ def create_worker_server_and_device(cluster_spec: tf.train.ClusterSpec, task_ind :return: the target string for the tf.Session and the worker device setter object """ # Create and start a worker - server = tf.train.Server(cluster_spec, job_name="worker", task_index=task_index, config=config) + server = tf.distribute.Server(cluster_spec, job_name="worker", task_index=task_index, config=config) # Assign ops to the local worker worker_device = "/job:worker/task:{}".format(task_index) @@ -69,13 +71,13 @@ def create_worker_server_and_device(cluster_spec: tf.train.ClusterSpec, task_ind worker_device += "/cpu:0" else: worker_device += "/device:GPU:0" - device = tf.train.replica_device_setter(worker_device=worker_device, cluster=cluster_spec) + device = tf.compat.v1.train.replica_device_setter(worker_device=worker_device, cluster=cluster_spec) return server.target, device -def create_monitored_session(target: tf.train.Server, task_index: int, - checkpoint_dir: str, checkpoint_save_secs: int, config: tf.ConfigProto=None) -> tf.Session: +def create_monitored_session(target: tf.distribute.Server, task_index: int, + checkpoint_dir: str, checkpoint_save_secs: int, config: tf.compat.v1.ConfigProto=None) -> tf.compat.v1.Session: """ Create a monitored session for the worker :param target: the target string for the tf.Session @@ -89,7 +91,7 @@ def create_monitored_session(target: tf.train.Server, task_index: int, is_chief = task_index == 0 # Create the monitored session - sess = tf.train.MonitoredTrainingSession( + sess = tf.compat.v1.train.MonitoredTrainingSession( master=target, is_chief=is_chief, hooks=[], diff --git a/rl_coach/architectures/tensorflow_components/dnn_model.py b/rl_coach/architectures/tensorflow_components/dnn_model.py new file mode 100644 index 000000000..f824ee6ce --- /dev/null +++ b/rl_coach/architectures/tensorflow_components/dnn_model.py @@ -0,0 +1,320 @@ +# +# Copyright (c) 2019 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import copy
+import numpy as np
+import tensorflow as tf
+from tensorflow import keras
+from typing import List
+from types import ModuleType
+from tensorflow.keras.layers import Input
+
+from rl_coach.architectures.tensorflow_components.embedders import ImageEmbedder, TensorEmbedder, VectorEmbedder
+from rl_coach.architectures.middleware_parameters import FCMiddlewareParameters, LSTMMiddlewareParameters
+from rl_coach.architectures.tensorflow_components.middlewares import FCMiddleware
+from rl_coach.architectures.tensorflow_components.heads import Head  # , PPOHead, PPOVHead, VHead, QHead
+from rl_coach.architectures.tensorflow_components.heads.ppo_head import continuous_ppo_head
+from rl_coach.architectures.tensorflow_components.heads.v_head import value_head
+from rl_coach.architectures.head_parameters import HeadParameters, PPOHeadParameters
+from rl_coach.architectures.head_parameters import PPOVHeadParameters, VHeadParameters, QHeadParameters
+from rl_coach.architectures.embedder_parameters import InputEmbedderParameters
+from rl_coach.architectures.middleware_parameters import MiddlewareParameters
+from rl_coach.base_parameters import AgentParameters, EmbeddingMergerType
+from rl_coach.base_parameters import NetworkParameters
+from rl_coach.spaces import SpacesDefinition, PlanarMapsObservationSpace, TensorObservationSpace
+
+
+def _get_input_embedder(name_prefix: str,
+                        spaces: SpacesDefinition,
+                        input_name: str,
+                        embedder_params: InputEmbedderParameters) -> ModuleType:
+    """
+    Given an input embedder parameters class, creates the input embedder and returns it
+    :param name_prefix: the prefix to use for the embedder name
+    :param spaces: the state and action space definitions
+    :param input_name: the name of the input to the embedder (used for retrieving the shape). The input should
+    be a value within the state or the action.
+    :param embedder_params: the parameters of the class of the embedder
+    :return: the embedder instance
+    """
+    embedder_params = copy.copy(embedder_params)
+
+    embedder_params.name = name_prefix + '_input_embedder'
+
+    allowed_inputs = copy.copy(spaces.state.sub_spaces)
+    allowed_inputs["action"] = copy.copy(spaces.action)
+    allowed_inputs["goal"] = copy.copy(spaces.goal)
+
+    if input_name not in allowed_inputs.keys():
+        raise ValueError("The key for the input embedder ({}) must match one of the following keys: {}"
+                         .format(input_name, allowed_inputs.keys()))
+
+    emb_type = "vector"
+    if isinstance(allowed_inputs[input_name], TensorObservationSpace):
+        emb_type = "tensor"
+    elif isinstance(allowed_inputs[input_name], PlanarMapsObservationSpace):
+        emb_type = "image"
+
+    if emb_type == 'vector':
+        module = VectorEmbedder(input_size=allowed_inputs[input_name].shape,
+                                activation_function=embedder_params.activation_function,
+                                scheme=embedder_params.scheme,
+                                batchnorm=embedder_params.batchnorm,
+                                dropout_rate=embedder_params.dropout_rate,
+                                name=embedder_params.name,
+                                input_rescaling=embedder_params.input_rescaling[emb_type],
+                                input_offset=embedder_params.input_offset[emb_type],
+                                input_clipping=embedder_params.input_clipping,
+                                is_training=embedder_params.is_training)
+    elif emb_type == 'image':
+        module = ImageEmbedder(input_size=allowed_inputs[input_name].shape,
+                               activation_function=embedder_params.activation_function,
+                               scheme=embedder_params.scheme,
+                               batchnorm=embedder_params.batchnorm,
+                               dropout_rate=embedder_params.dropout_rate,
+                               name=embedder_params.name,
+                               input_rescaling=embedder_params.input_rescaling[emb_type],
+                               input_offset=embedder_params.input_offset[emb_type],
+                               input_clipping=embedder_params.input_clipping,
+                               is_training=embedder_params.is_training)
+    elif emb_type == 'tensor':
+        module = TensorEmbedder(embedder_params)
+    else:
+        raise KeyError('Unsupported embedder type: {}'.format(emb_type))
+    return module
+
+
+def _get_middleware(middleware_params: MiddlewareParameters) -> ModuleType:
+    """
+    Given a middleware type, creates the middleware and returns it
+    :param middleware_params: the parameters of the middleware class
+    :return: the middleware instance
+    """
+    if isinstance(middleware_params, FCMiddlewareParameters):
+        module = FCMiddleware(activation_function=middleware_params.activation_function,
+                              scheme=middleware_params.scheme,
+                              batchnorm=middleware_params.batchnorm,
+                              dropout_rate=middleware_params.dropout_rate,
+                              name=middleware_params.name,
+                              is_training=middleware_params.is_training,
+                              num_streams=middleware_params.num_streams)
+    elif isinstance(middleware_params, LSTMMiddlewareParameters):
+        raise KeyError('LSTM middleware not supported: {}'.format(type(middleware_params)))
+    else:
+        raise KeyError('Unsupported middleware type: {}'.format(type(middleware_params)))
+    return module
+
+
+def _get_output_head(
+        head_params: HeadParameters,
+        head_input_dim: int,
+        head_idx: int,
+        head_type_index: int,
+        agent_params: AgentParameters,
+        spaces: SpacesDefinition,
+        network_name: str,
+        is_local: bool) -> Head:
+    """
+    Given a head type, creates the head and returns it
+    :param head_params: the parameters of the head to create
+    :param head_idx: the head index
+    :param head_type_index: the head type index (same index if head_param.num_output_head_copies > 0)
+    :param agent_params: agent parameters
+    :param spaces: state and action space definitions
+    :param network_name: name of the network
+    :param is_local: True if the network is local
+    :return: the head block
+    """
+
+    if isinstance(head_params,
QHeadParameters): + head_output_dim = len(spaces.action.actions) + module = value_head(head_input_dim, head_output_dim) + # module = QHead( + # agent_parameters=agent_params, + # spaces=spaces, + # network_name=network_name, + # head_type_idx=head_type_index, + # loss_weight=head_params.loss_weight, + # is_local=is_local, + # activation_function=head_params.activation_function, + # dense_layer=head_params.dense_layer) + + elif isinstance(head_params, PPOHeadParameters): + head_output_dim = spaces.action.shape[0] + module = continuous_ppo_head(head_input_dim, head_output_dim) + + # module = PPOHead( + # agent_parameters=agent_params, + # spaces=spaces, + # network_name=network_name, + # head_type_idx=head_type_index, + # loss_weight=head_params.loss_weight, + # is_local=is_local, + # activation_function=head_params.activation_function, + # dense_layer=head_params.dense_layer) + + elif isinstance(head_params, VHeadParameters): + head_output_dim = 1 + module = value_head(head_input_dim, head_output_dim) + # module = VHead( + # agent_parameters=agent_params, + # spaces=spaces, + # network_name=network_name, + # head_type_idx=head_type_index, + # loss_weight=head_params.loss_weight, + # is_local=is_local, + # activation_function=head_params.activation_function, + # dense_layer=head_params.dense_layer) + + else: + raise KeyError('Unsupported head type: {}'.format(type(head_params))) + + return module + + +def create_single_network(inputs_shapes, + name: str, + network_is_local: bool, + head_type_idx_start: int, + agent_parameters: AgentParameters, + input_embedders_parameters: {str: InputEmbedderParameters}, + embedding_merger_type: EmbeddingMergerType, + middleware_param: MiddlewareParameters, + head_param_list: [HeadParameters], + spaces: SpacesDefinition): + """ + + Block that connects a single embedder, with middleware and heads + + :param network_is_local: True if network is local + :param name: name of the network + :param agent_parameters: agent parameters + :param input_embedders_parameters: dictionary of embedder name to embedding parameters + :param embedding_merger_type: type of merging output of embedders: concatenate or sum + :param middleware_param: middleware parameters + :param head_param_list: list of head parameters, one per head type + :param head_type_idx_start: start index for head type index counting + :param spaces: state and action space definition + """ + name = name + '_' + head_param_list[0].name.replace('head_params', '_network') + inputs = list(map(lambda x: Input(name=name + '_input', shape=x), inputs_shapes)) + # Get list of input embedders + embedders = [_get_input_embedder(name, spaces, k, v) for k, v in input_embedders_parameters.items()] + # Apply each embbeder on its corresponding input + state_embeddings = [embedder(input_t) for embedder, input_t in zip(embedders, inputs)] + + # Merge embedders outputs + if len(state_embeddings) == 1: + # TODO: change to squeeze + state_embeddings = state_embeddings[0] + else: + if embedding_merger_type == EmbeddingMergerType.Concat: + state_embeddings = tf.keras.layers.Concatenate()(state_embeddings) + elif embedding_merger_type == EmbeddingMergerType.Sum: + state_embeddings = tf.keras.layers.Add()(state_embeddings) + + middleware_output = _get_middleware(middleware_param)(state_embeddings) + + heads_outputs = list() + + for i, head_param in enumerate(head_param_list): + for head_copy_idx in range(head_param.num_output_head_copies): + # create output head and add it to the output heads list + head_idx = (head_type_idx_start 
+ i) * head_param.num_output_head_copies + head_copy_idx
+            network_head = _get_output_head(
+                head_input_dim=middleware_output.shape[-1],
+                head_idx=head_idx,
+                head_type_index=head_type_idx_start + i,
+                network_name=name,
+                spaces=spaces,
+                is_local=network_is_local,
+                agent_params=agent_parameters,
+                head_params=head_param)
+
+            heads_outputs.append(network_head(middleware_output))
+
+    model = keras.Model(name=name, inputs=inputs, outputs=heads_outputs)
+    return model
+
+
+def create_full_model(num_networks: int,
+                      num_heads_per_network: int,
+                      network_is_local: bool,
+                      network_name: str,
+                      agent_parameters: AgentParameters,
+                      network_parameters: NetworkParameters,
+                      spaces: SpacesDefinition):
+    """
+    Creates the full model, which may be composed of multiple single networks.
+    For example, these can be two single networks, one for the actor and one for the critic, or an online and a
+    target network.
+    :param num_networks: number of networks to create
+    :param num_heads_per_network: number of heads per network to create
+    :param network_is_local: True if network is local
+    :param network_name: name of the network
+    :param agent_parameters: agent parameters
+    :param network_parameters: network parameters
+    :param spaces: state and action space definitions
+    """
+
+    input_embedders_types = network_parameters.input_embedders_parameters.keys()
+    input_shapes = get_input_shapes(spaces, input_embedders_types)
+    inputs = list(map(lambda x: Input(name=network_name + '_input', shape=x), input_shapes))
+
+    outputs = list()
+    networks = {}
+    for network_idx in range(num_networks):
+        head_type_idx_start = network_idx * num_heads_per_network
+        head_type_idx_end = head_type_idx_start + num_heads_per_network
+        networks[network_idx] = create_single_network(inputs_shapes=input_shapes,
+                                                      name=network_name,
+                                                      network_is_local=network_is_local,
+                                                      head_type_idx_start=head_type_idx_start,
+                                                      agent_parameters=agent_parameters,
+                                                      input_embedders_parameters=network_parameters.input_embedders_parameters,
+                                                      embedding_merger_type=network_parameters.embedding_merger_type,
+                                                      middleware_param=network_parameters.middleware_parameters,
+                                                      head_param_list=network_parameters.heads_parameters[head_type_idx_start:head_type_idx_end],
+                                                      spaces=spaces)
+
+        outputs.append(networks[network_idx](inputs))
+
+    model = keras.Model(name=network_name + '_full_model', inputs=inputs, outputs=outputs)
+    # Build the model variables with a dummy forward pass (batch size 1, hence [1] + shape)
+    dummy_inputs = tuple(np.zeros(tuple([1] + shape)) for shape in input_shapes)
+    model(dummy_inputs)
+    return model
+
+
+def get_input_shapes(spaces, input_embedders_types) -> List[List[int]]:
+    """
+    Create a list of input array shapes
+    :return: a list of input shapes, one per input embedder
+    """
+    allowed_inputs = copy.copy(spaces.state.sub_spaces)
+    allowed_inputs["action"] = copy.copy(spaces.action)
+    allowed_inputs["goal"] = copy.copy(spaces.goal)
+    return list(allowed_inputs[embedder_type].shape.tolist() for embedder_type in input_embedders_types)
+
+
+def squeeze_model_outputs(model_outputs):
+    if len(model_outputs) == 1:
+        return model_outputs
+    else:
+        return list(map(lambda output: output[0], model_outputs))
+
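To make the factories above concrete, this is roughly the object graph that `create_single_network()`/`create_full_model()` assemble for a single-network, single-head preset; the layer sizes below are made up for illustration:

```python
import numpy as np
from tensorflow import keras

# Hand-rolled embedder -> middleware -> head chain for one 4-dimensional vector input.
inputs = keras.layers.Input(shape=(4,), name='observation_input')
embedding = keras.layers.Dense(256, activation='relu')(inputs)       # vector embedder (Medium scheme)
middleware = keras.layers.Dense(512, activation='relu')(embedding)   # FC middleware
value = keras.layers.Dense(1)(middleware)                            # value head
model = keras.Model(inputs=inputs, outputs=value)

# create_full_model() builds the variables the same way: one dummy forward pass with batch size 1.
model(np.zeros((1, 4), dtype=np.float32))
model.summary()
```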
diff --git a/rl_coach/architectures/tensorflow_components/embedders/__init__.py b/rl_coach/architectures/tensorflow_components/embedders/__init__.py
index 5091f35c1..93d79f234 100644
--- a/rl_coach/architectures/tensorflow_components/embedders/__init__.py
+++ b/rl_coach/architectures/tensorflow_components/embedders/__init__.py
@@ -1,5 +1,7 @@
 from .image_embedder import ImageEmbedder
-from .vector_embedder import VectorEmbedder
 from .tensor_embedder import TensorEmbedder
+from .vector_embedder import VectorEmbedder

-__all__ = ['ImageEmbedder', 'VectorEmbedder', 'TensorEmbedder']
+__all__ = ['ImageEmbedder',
+           'TensorEmbedder',
+           'VectorEmbedder']
diff --git a/rl_coach/architectures/tensorflow_components/embedders/embedder.py b/rl_coach/architectures/tensorflow_components/embedders/embedder.py
index 13544c9ac..75eef2507 100644
--- a/rl_coach/architectures/tensorflow_components/embedders/embedder.py
+++ b/rl_coach/architectures/tensorflow_components/embedders/embedder.py
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2017 Intel Corporation
+# Copyright (c) 2019 Intel Corporation
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,109 +14,76 @@
 # limitations under the License.
 #

-from typing import List, Union, Tuple
-import copy
-
+from typing import List, Union
 import numpy as np
 import tensorflow as tf
-
-from rl_coach.architectures.tensorflow_components.layers import BatchnormActivationDropout, convert_layer, Dense
-from rl_coach.base_parameters import EmbedderScheme, NetworkComponentParameters
-
+from tensorflow import keras, Tensor
+from rl_coach.base_parameters import EmbedderScheme
 from rl_coach.core_types import InputEmbedding
-from rl_coach.utils import force_list
+from rl_coach.architectures.tensorflow_components.layers import convert_layer


-class InputEmbedder(object):
+class InputEmbedder(keras.layers.Layer):
     """
     An input embedder is the first part of the network, which takes the input from the state and produces a vector
     embedding by passing it through a neural network. The embedder will mostly be input type dependent, and there
     can be multiple embedders in a single network
     """
-    def __init__(self, input_size: List[int], activation_function=tf.nn.relu,
-                 scheme: EmbedderScheme=None, batchnorm: bool=False, dropout_rate: float=0.0,
-                 name: str= "embedder", input_rescaling=1.0, input_offset=0.0, input_clipping=None, dense_layer=Dense,
-                 is_training=False):
-        self.name = name
+    def __init__(self,
+                 input_size: List[int],
+                 activation_function=tf.nn.relu,
+                 scheme: EmbedderScheme = None,
+                 batchnorm: bool = False,
+                 dropout_rate: float = 0.0,
+                 name: str = "embedder",
+                 input_rescaling=1.0,
+                 input_offset=0.0,
+                 input_clipping=None,
+                 is_training=False,
+                 **kwargs):
+
+        super(InputEmbedder, self).__init__(name=name)  # , trainable=is_training)
         self.input_size = input_size
-        self.activation_function = activation_function
-        self.batchnorm = batchnorm
-        self.dropout_rate = dropout_rate
-        self.input = None
-        self.output = None
-        self.scheme = scheme
         self.return_type = InputEmbedding
-        self.layers_params = []
-        self.layers = []
         self.input_rescaling = input_rescaling
         self.input_offset = input_offset
         self.input_clipping = input_clipping
-        self.dense_layer = dense_layer
-        if self.dense_layer is None:
-            self.dense_layer = Dense
-        self.is_training = is_training
-
-        # layers order is conv -> batchnorm -> activation -> dropout
-        if isinstance(self.scheme, EmbedderScheme):
-            self.layers_params = copy.copy(self.schemes[self.scheme])
-            self.layers_params = [convert_layer(l) for l in self.layers_params]
-        else:
-            # if scheme is specified directly, convert to TF layer if it's not a callable object
-            # NOTE: if layer object is callable, it must return a TF tensor when invoked
-            self.layers_params = [convert_layer(l) for l in copy.copy(self.scheme)]
-
-        # we allow adding batchnorm, dropout or
activation functions after each layer. - # The motivation is to simplify the transition between a network with batchnorm and a network without - # batchnorm to a single flag (the same applies to activation function and dropout) - if self.batchnorm or self.activation_function or self.dropout_rate > 0: - for layer_idx in reversed(range(len(self.layers_params))): - self.layers_params.insert(layer_idx+1, - BatchnormActivationDropout(batchnorm=self.batchnorm, - activation_function=self.activation_function, - dropout_rate=self.dropout_rate)) - - def __call__(self, prev_input_placeholder: tf.placeholder=None) -> Tuple[tf.Tensor, tf.Tensor]: - """ - Wrapper for building the module graph including scoping and loss creation - :param prev_input_placeholder: the input to the graph - :return: the input placeholder and the output of the last layer - """ - with tf.variable_scope(self.get_name()): - if prev_input_placeholder is None: - self.input = tf.placeholder("float", shape=[None] + self.input_size, name=self.get_name()) - else: - self.input = prev_input_placeholder - self._build_module() - - return self.input, self.output + self.embbeder_layers = [] - def _build_module(self) -> None: + if isinstance(scheme, EmbedderScheme): + layers = self.schemes[scheme] + else: + layers = scheme + # Convert layer to TensorFlow layer + layers = [convert_layer(l) for l in layers] + + for layer in layers: + self.embbeder_layers.extend([layer]) + if batchnorm: + self.embbeder_layers.extend([keras.layers.BatchNormalization()]) + if activation_function: + self.embbeder_layers.extend([keras.activations.get(activation_function)]) + if dropout_rate: + self.embbeder_layers.extend([keras.layers.Dropout(rate=dropout_rate)]) + + def call(self, inputs) -> Tensor: """ - Builds the graph of the module - This method is called early on from __call__. It is expected to store the graph - in self.output. - :return: None + Used for forward pass through embedder network. + :param inputs: environment state, where first dimension is batch_size, then dimensions are data type dependent. + :return: embedding of environment state, where shape is (batch_size, channels). """ - # NOTE: for image inputs, we expect the data format to be of type uint8, so to be memory efficient. we chose not - # to implement the rescaling as an input filters.observation.observation_filter, as this would have caused the - # input to the network to be float, which is 4x more expensive in memory. - # thus causing each saved transition in the memory to also be 4x more pricier. 
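The removed NOTE above survives in spirit in the new `call()`: observations stay uint8 in replay memory and are only cast and rescaled on the forward pass. A small self-contained sketch of that preprocessing, with made-up scaling values:

```python
import numpy as np
import tensorflow as tf

frame = np.full((1, 84, 84, 3), 255, dtype=np.uint8)  # stored cheaply as uint8

x = tf.cast(frame, tf.float32)
x = tf.math.divide(x, 255.0)          # input_rescaling (image default)
x = x - 0.0                           # input_offset
x = tf.clip_by_value(x, 0.0, 1.0)     # optional input_clipping
assert float(tf.reduce_max(x)) == 1.0
```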
- - input_layer = self.input / self.input_rescaling - input_layer -= self.input_offset - # clip input using te given range + inputs = tf.cast(inputs, tf.float32) + x = tf.math.divide(inputs, self.input_rescaling) + x = x - self.input_offset if self.input_clipping is not None: - input_layer = tf.clip_by_value(input_layer, self.input_clipping[0], self.input_clipping[1]) - - self.layers.append(input_layer) + x = tf.clip_by_value(x, self.input_clipping[0], self.input_clipping[1]) - for idx, layer_params in enumerate(self.layers_params): - self.layers.extend(force_list( - layer_params(input_layer=self.layers[-1], name='{}_{}'.format(layer_params.__class__.__name__, idx), - is_training=self.is_training) - )) + for layer in self.embbeder_layers: + x = layer(x) - self.output = tf.contrib.layers.flatten(self.layers[-1]) + # For convolution layer + x = keras.layers.Flatten()(x) + return x @property def input_size(self) -> List[int]: @@ -135,7 +102,14 @@ def input_size(self, value: Union[int, List[int]]): self._input_size = value @property - def schemes(self): + def schemes(self) -> dict: + """ + Schemes are the pre-defined network architectures of various depths and complexities that can be used for the + InputEmbedder. Should be implemented in child classes, and are used to create Block when InputEmbedder is + initialised. + + :return: dictionary of schemes, with key of type EmbedderScheme enum and value being list of Tensorflow layers. + """ raise NotImplementedError("Inheriting embedder must define schemes matching its allowed default " "configurations.") @@ -146,12 +120,17 @@ def get_name(self) -> str: """ return self.name + def get_config(self): + config = super(InputEmbedder, self).get_config() + config.update({'name': self.name}) + return config + def __str__(self): result = ['Input size = {}'.format(self._input_size)] if self.input_rescaling != 1.0 or self.input_offset != 0.0: result.append('Input Normalization (scale = {}, offset = {})'.format(self.input_rescaling, self.input_offset)) - result.extend([str(l) for l in self.layers_params]) - if not self.layers_params: + result.extend([str(l) for l in self.embbeder_layers]) + if not self.embbeder_layers: result.append('No layers') return '\n'.join(result) diff --git a/rl_coach/architectures/tensorflow_components/embedders/image_embedder.py b/rl_coach/architectures/tensorflow_components/embedders/image_embedder.py index b05ec8e03..dcb4e08c5 100644 --- a/rl_coach/architectures/tensorflow_components/embedders/image_embedder.py +++ b/rl_coach/architectures/tensorflow_components/embedders/image_embedder.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2017 Intel Corporation +# Copyright (c) 2019 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,11 +14,13 @@ # limitations under the License. # -from typing import List +from typing import List import tensorflow as tf - -from rl_coach.architectures.tensorflow_components.layers import Conv2d, Dense +from tensorflow import keras +from functools import partial +from typing import Dict +from rl_coach.architectures.layers import Conv2d from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedder from rl_coach.base_parameters import EmbedderScheme from rl_coach.core_types import InputImageEmbedding @@ -31,19 +33,42 @@ class ImageEmbedder(InputEmbedder): The embedder also allows custom rescaling of the input prior to the neural network. 
""" - def __init__(self, input_size: List[int], activation_function=tf.nn.relu, - scheme: EmbedderScheme=EmbedderScheme.Medium, batchnorm: bool=False, dropout_rate: float=0.0, - name: str= "embedder", input_rescaling: float=255.0, input_offset: float=0.0, input_clipping=None, - dense_layer=Dense, is_training=False): - super().__init__(input_size, activation_function, scheme, batchnorm, dropout_rate, name, input_rescaling, - input_offset, input_clipping, dense_layer=dense_layer, is_training=is_training) + def __init__(self, input_size: List[int], + activation_function=tf.nn.relu, + scheme: EmbedderScheme = EmbedderScheme.Medium, + batchnorm: bool = False, + dropout_rate: float = 0.0, + name: str = "embedder", + input_rescaling: float = 255.0, + input_offset: float = 0.0, + input_clipping=None, + is_training=False): + + super().__init__(input_size, + activation_function, + scheme, + batchnorm, + dropout_rate, + name, + input_rescaling, + input_offset, + input_clipping, + is_training=is_training) + self.return_type = InputImageEmbedding if len(input_size) != 3 and scheme != EmbedderScheme.Empty: raise ValueError("Image embedders expect the input size to have 3 dimensions. The given size is: {}" .format(input_size)) + DefaultConv2D = partial(keras.layers.Conv2D, default_data_format='channels_last', activation=None, padding="SAME") @property - def schemes(self): + def schemes(self) -> Dict: + """ + Schemes are the pre-defined network architectures of various depths and complexities that can be used. Are used + to create Block when ImageEmbedder is initialised. + + :return: dictionary of schemes, with key of type EmbedderScheme enum and value being list of Tensorflow layers. + """ return { EmbedderScheme.Empty: [], @@ -74,5 +99,3 @@ def schemes(self): Conv2d(256, 3, 1) ] } - - diff --git a/rl_coach/architectures/tensorflow_components/embedders/tensor_embedder.py b/rl_coach/architectures/tensorflow_components/embedders/tensor_embedder.py index 286442c4e..080f7c950 100644 --- a/rl_coach/architectures/tensorflow_components/embedders/tensor_embedder.py +++ b/rl_coach/architectures/tensorflow_components/embedders/tensor_embedder.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2017 Intel Corporation +# Copyright (c) 2019 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,11 +14,10 @@ # limitations under the License. # -from typing import List import tensorflow as tf - -from rl_coach.architectures.tensorflow_components.layers import Conv2d, Dense +from typing import List +from rl_coach.architectures.tensorflow_components.layers import Dense from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedder from rl_coach.base_parameters import EmbedderScheme from rl_coach.core_types import InputTensorEmbedding @@ -38,15 +37,39 @@ class TensorEmbedder(InputEmbedder): activation, and dropout in between as specified in InputEmbedderParameters. 
""" - def __init__(self, input_size: List[int], activation_function=tf.nn.relu, - scheme: EmbedderScheme=None, batchnorm: bool=False, dropout_rate: float=0.0, - name: str= "embedder", input_rescaling: float=1.0, input_offset: float=0.0, input_clipping=None, - dense_layer=Dense, is_training=False): - super().__init__(input_size, activation_function, scheme, batchnorm, dropout_rate, name, input_rescaling, - input_offset, input_clipping, dense_layer=dense_layer, is_training=is_training) + def __init__(self, input_size: List[int], + activation_function=tf.nn.relu, + scheme: EmbedderScheme = None, + batchnorm: bool = False, + dropout_rate: float = 0.0, + name: str = "embedder", + input_rescaling: float = 1.0, + input_offset: float = 0.0, + input_clipping=None, + dense_layer=Dense, + is_training=False): + super().__init__(input_size, + activation_function, + scheme, batchnorm, + dropout_rate, + name, + input_rescaling, + input_offset, + input_clipping, + dense_layer=dense_layer, + is_training=is_training) self.return_type = InputTensorEmbedding assert scheme is not None, "Custom scheme (a list of callables) must be specified for TensorEmbedder" @property def schemes(self): + """ + Schemes are the pre-defined network architectures of various depths and complexities that can be used. Are used + to create Block when InputEmbedder is initialised. + + Note: Tensor embedder doesn't define any pre-defined scheme. User must provide custom scheme in preset. + + :return: dictionary of schemes, with key of type EmbedderScheme enum and value being list of TensorFlow layers. + For tensor embedder, this is an empty dictionary. + """ return {} diff --git a/rl_coach/architectures/tensorflow_components/embedders/vector_embedder.py b/rl_coach/architectures/tensorflow_components/embedders/vector_embedder.py index 60b728dbd..eb8c300d9 100644 --- a/rl_coach/architectures/tensorflow_components/embedders/vector_embedder.py +++ b/rl_coach/architectures/tensorflow_components/embedders/vector_embedder.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2017 Intel Corporation +# Copyright (c) 2019 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,11 +14,9 @@ # limitations under the License. # -from typing import List - import tensorflow as tf - -from rl_coach.architectures.tensorflow_components.layers import Dense +from typing import List, Dict +from rl_coach.architectures.layers import Dense from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedder from rl_coach.base_parameters import EmbedderScheme from rl_coach.core_types import InputVectorEmbedding @@ -26,16 +24,26 @@ class VectorEmbedder(InputEmbedder): """ - An input embedder that is intended for inputs that can be represented as vectors. - The embedder flattens the input, applies several dense layers to it and returns the output. + A vector embedder is an input embedder that takes an vector input from the state and produces a vector + embedding by passing it through a neural network. + + :param params: parameters object containing input_clipping, input_rescaling, batchnorm, activation_function + and dropout properties. 
""" - def __init__(self, input_size: List[int], activation_function=tf.nn.relu, - scheme: EmbedderScheme=EmbedderScheme.Medium, batchnorm: bool=False, dropout_rate: float=0.0, - name: str= "embedder", input_rescaling: float=1.0, input_offset: float=0.0, input_clipping=None, - dense_layer=Dense, is_training=False): + def __init__(self, input_size: List[int], + activation_function=tf.nn.relu, + scheme: EmbedderScheme=EmbedderScheme.Medium, + batchnorm: bool=False, + dropout_rate: float=0.0, + name: str = "embedder", + input_rescaling: float=1.0, + input_offset: float=0.0, + input_clipping=None, + is_training=False): + super().__init__(input_size, activation_function, scheme, batchnorm, dropout_rate, name, - input_rescaling, input_offset, input_clipping, dense_layer=dense_layer, + input_rescaling, input_offset, input_clipping, is_training=is_training) self.return_type = InputVectorEmbedding @@ -43,27 +51,33 @@ def __init__(self, input_size: List[int], activation_function=tf.nn.relu, raise ValueError("The input size of a vector embedder must contain only a single dimension") @property - def schemes(self): + def schemes(self) -> Dict: + """ + Schemes are the pre-defined network architectures of various depths and complexities that can be used. Are used + to create Block when VectorEmbedder is initialised. + + :return: dictionary of schemes, with key of type EmbedderScheme enum and value being list of Tensorflow layers. + """ return { EmbedderScheme.Empty: [], EmbedderScheme.Shallow: [ - self.dense_layer(128) + Dense(128) ], - # dqn + # Use for DQN EmbedderScheme.Medium: [ - self.dense_layer(256) + Dense(256) ], - # carla + # Use for Carla EmbedderScheme.Deep: \ [ - self.dense_layer(128), - self.dense_layer(128), - self.dense_layer(128) + Dense(128), + Dense(128), + Dense(128) ] } diff --git a/rl_coach/architectures/tensorflow_components/general_network.py b/rl_coach/architectures/tensorflow_components/general_network.py index 8821ac6cc..0950e946e 100644 --- a/rl_coach/architectures/tensorflow_components/general_network.py +++ b/rl_coach/architectures/tensorflow_components/general_network.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2017 Intel Corporation +# Copyright (c) 2019 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,32 +14,31 @@ # limitations under the License. 
# -import copy +import matplotlib.pyplot as plt +import matplotlib.image as mpimg +from typing import List, Union from types import MethodType -from typing import Dict, List, Union +from tensorflow import keras +from tensorflow.keras.losses import Loss, Huber, MeanSquaredError -import numpy as np -import tensorflow as tf - -from rl_coach.architectures.embedder_parameters import InputEmbedderParameters -from rl_coach.architectures.head_parameters import HeadParameters -from rl_coach.architectures.middleware_parameters import MiddlewareParameters +from rl_coach.base_parameters import AgentParameters +from rl_coach.spaces import SpacesDefinition +from rl_coach.base_parameters import NetworkParameters, Device, DeviceType from rl_coach.architectures.tensorflow_components.architecture import TensorFlowArchitecture -from rl_coach.architectures.tensorflow_components import utils -from rl_coach.base_parameters import AgentParameters, Device, DeviceType, EmbeddingMergerType -from rl_coach.core_types import PredictionType +from rl_coach.architectures.tensorflow_components.dnn_model import create_full_model +from rl_coach.architectures.tensorflow_components.losses.head_loss import HeadLoss +from rl_coach.architectures.tensorflow_components.losses.q_loss import QLoss +from rl_coach.architectures.tensorflow_components.losses.v_loss import VLoss +from rl_coach.architectures.tensorflow_components.losses.ppo_loss import PPOLoss +from rl_coach.architectures.head_parameters import HeadParameters, PPOHeadParameters, VHeadParameters, QHeadParameters from rl_coach.logger import screen -from rl_coach.spaces import SpacesDefinition, PlanarMapsObservationSpace, TensorObservationSpace -from rl_coach.utils import get_all_subclasses, dynamic_import_and_instantiate_module_from_params, indent_string +from rl_coach.utils import dynamic_import_and_instantiate_module_from_params class GeneralTensorFlowNetwork(TensorFlowArchitecture): """ - A generalized version of all possible networks implemented using tensorflow. + A generalized version of all possible networks implemented using tensorflow along with the optimizer and loss. """ - # dictionary of variable-scope name to variable-scope object to prevent tensorflow from - # creating a new auxiliary variable scope even when name is properly specified - variable_scopes_dict = dict() @staticmethod def construct(variable_scope: str, devices: List[str], *args, **kwargs) -> 'GeneralTensorFlowNetwork': @@ -51,27 +50,32 @@ def construct(variable_scope: str, devices: List[str], *args, **kwargs) -> 'Gene :param kwargs: all other keyword arguments for class initializer :return: a GeneralTensorFlowNetwork object """ + + if len(devices) > 1: screen.warning("Tensorflow implementation only support a single device. Using {}".format(devices[0])) - def construct_on_device(): - with tf.device(GeneralTensorFlowNetwork._tf_device(devices[0])): - return GeneralTensorFlowNetwork(*args, **kwargs) - - # If variable_scope is in our dictionary, then this is not the first time that this variable_scope - # is being used with construct(). 
So to avoid TF adding an incrementing number to the end of the - # variable_scope to uniquify it, we have to both pass the previous variable_scope object to the new - # variable_scope() call and also recover the name space using name_scope - if variable_scope in GeneralTensorFlowNetwork.variable_scopes_dict: - variable_scope = GeneralTensorFlowNetwork.variable_scopes_dict[variable_scope] - with tf.variable_scope(variable_scope, auxiliary_name_scope=False) as vs: - with tf.name_scope(vs.original_name_scope): - return construct_on_device() - else: - with tf.variable_scope(variable_scope, auxiliary_name_scope=True) as vs: - # Add variable_scope object to dictionary for next call to construct - GeneralTensorFlowNetwork.variable_scopes_dict[variable_scope] = vs - return construct_on_device() + # strategy = tf.distribute.MirroredStrategy() + # with strategy.scope(): + # generalized_network = GeneralTensorFlowNetwork(*args, **kwargs) + + generalized_network = GeneralTensorFlowNetwork(*args, **kwargs) + + # with tf.device(GeneralTensorFlowNetwork._tf_device(devices[0])): + # generalized_network = GeneralTensorFlowNetwork(*args, **kwargs) + + generalized_network.model.summary() + + keras.utils.plot_model(generalized_network.model, + expand_nested=True, + show_shapes=True, + to_file='model_plot_new.png') + img = mpimg.imread('model_plot_new.png') + plt.imshow(img) + plt.show() + + return generalized_network + @staticmethod def _tf_device(device: Union[str, MethodType, Device]) -> str: @@ -93,357 +97,147 @@ def _tf_device(device: Union[str, MethodType, Device]) -> str: else: raise ValueError("Invalid device instance type: {}".format(type(device))) - def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, name: str, - global_network=None, network_is_local: bool=True, network_is_trainable: bool=False): + def __init__(self, + agent_parameters: AgentParameters, + spaces: SpacesDefinition, + name: str, + global_network=None, + network_is_local: bool=True, + network_is_trainable: bool=False): """ :param agent_parameters: the agent parameters :param spaces: the spaces definition of the agent + :param devices: list of devices to run the network on :param name: the name of the network :param global_network: the global network replica that is shared between all the workers :param network_is_local: is the network global (shared between workers) or local (dedicated to the worker) :param network_is_trainable: is the network trainable (we can apply gradients on it) """ - self.global_network = global_network - self.network_is_local = network_is_local - self.network_wrapper_name = name.split('/')[0] - self.network_parameters = agent_parameters.network_wrappers[self.network_wrapper_name] - self.num_heads_per_network = 1 if self.network_parameters.use_separate_networks_per_head else \ - len(self.network_parameters.heads_parameters) - self.num_networks = 1 if not self.network_parameters.use_separate_networks_per_head else \ - len(self.network_parameters.heads_parameters) - self.gradients_from_head_rescalers = [] - self.gradients_from_head_rescalers_placeholders = [] - self.update_head_rescaler_value_ops = [] + super().__init__(agent_parameters, spaces, name, global_network, network_is_local, network_is_trainable) - self.adaptive_learning_rate_scheme = None - self.current_learning_rate = None + self.global_network = global_network - # init network modules containers - self.input_embedders = [] - self.output_heads = [] - super().__init__(agent_parameters, spaces, name, global_network, - 
network_is_local, network_is_trainable) + self.network_wrapper_name = name.split('/')[0] - self.available_return_types = self._available_return_types() - self.is_training = None + network_name = name.split('/')[0] + '_' + name.split('/')[1] - def _available_return_types(self): - ret_dict = {cls: [] for cls in get_all_subclasses(PredictionType)} + network_parameters = agent_parameters.network_wrappers[self.network_wrapper_name] - components = self.input_embedders + [self.middleware] + self.output_heads - for component in components: - if not hasattr(component, 'return_type'): - raise ValueError(( - "{} has no return_type attribute. Without this, it is " - "unclear how this component should be used." - ).format(component)) + if len(network_parameters.input_embedders_parameters) == 0: + raise ValueError("At least one input type should be defined") - if component.return_type is not None: - ret_dict[component.return_type].append(component) + if len(network_parameters.heads_parameters) == 0: + raise ValueError("At least one output type should be defined") - return ret_dict + if network_parameters.middleware_parameters is None: + raise ValueError("Exactly one middleware type should be defined") - def predict_with_prediction_type(self, states: Dict[str, np.ndarray], - prediction_type: PredictionType) -> Dict[str, np.ndarray]: - """ - Search for a component[s] which has a return_type set to the to the requested PredictionType, and get - predictions for it. + if network_parameters.use_separate_networks_per_head: + num_heads_per_network = 1 + num_networks = len(network_parameters.heads_parameters) + else: + num_heads_per_network = len(network_parameters.heads_parameters) + num_networks = 1 + + self.model = create_full_model(num_networks=num_networks, + num_heads_per_network=num_heads_per_network, + network_is_local=network_is_local, + network_name=network_name, + agent_parameters=agent_parameters, + network_parameters=network_parameters, + spaces=spaces) + + self.losses = list() + for index, loss_params in enumerate(network_parameters.heads_parameters): + loss = self._get_loss(agent_parameters=agent_parameters, + loss_params=loss_params, + network_name=loss_params.name, + num_actions=spaces.action.shape[0], + head_idx=index, + loss_weight=loss_params.loss_weight) + self.losses.append(loss) + + self.optimizer = self._get_optimizer(network_parameters) + self.network_parameters = agent_parameters.network_wrappers[self.network_wrapper_name] - :param states: The input states to the network. - :param prediction_type: The requested PredictionType to look for in the network components - :return: A dictionary with predictions for all components matching the requested prediction type + def _get_optimizer(self, network_parameters: NetworkParameters) -> keras.optimizers: """ + :param network_parameters: class containing the relevant optimizer parameters. 
+        :return: an instantiation of a TensorFlow optimizer with the configured parameters
+        """
+        if network_parameters.optimizer_type == 'Adam':
+            optimizer = keras.optimizers.Adam(
+                learning_rate=network_parameters.learning_rate,
+                beta_1=network_parameters.adam_optimizer_beta1,
+                beta_2=network_parameters.adam_optimizer_beta2,
+                epsilon=network_parameters.optimizer_epsilon)

+        elif network_parameters.optimizer_type == 'RMSProp':
+            optimizer = keras.optimizers.RMSprop(
+                learning_rate=network_parameters.learning_rate,
+                decay=network_parameters.rms_prop_optimizer_decay,
+                epsilon=network_parameters.optimizer_epsilon)
+
+        elif network_parameters.optimizer_type == 'LBFGS':
+            raise NotImplementedError('Could not find an updated LBFGS implementation')
+        else:
+            raise Exception("{} is not a valid optimizer type".format(network_parameters.optimizer_type))

-        ret_dict = {}
-        for component in self.available_return_types[prediction_type]:
-            ret_dict[component] = self.predict(inputs=states, outputs=component.output)
-
-        return ret_dict
+        return optimizer

-    def get_input_embedder(self, input_name: str, embedder_params: InputEmbedderParameters):
-        """
-        Given an input embedder parameters class, creates the input embedder and returns it
-        :param input_name: the name of the input to the embedder (used for retrieving the shape). The input should
-            be a value within the state or the action.
-        :param embedder_params: the parameters of the class of the embedder
-        :return: the embedder instance
-        """
-        allowed_inputs = copy.copy(self.spaces.state.sub_spaces)
-        allowed_inputs["action"] = copy.copy(self.spaces.action)
-        allowed_inputs["goal"] = copy.copy(self.spaces.goal)
-
-        if input_name not in allowed_inputs.keys():
-            raise ValueError("The key for the input embedder ({}) must match one of the following keys: {}"
-                             .format(input_name, allowed_inputs.keys()))
-
-        emb_type = "vector"
-        if isinstance(allowed_inputs[input_name], TensorObservationSpace):
-            emb_type = "tensor"
-        elif isinstance(allowed_inputs[input_name], PlanarMapsObservationSpace):
-            emb_type = "image"
-
-        embedder_path = embedder_params.path(emb_type)
-        embedder_params_copy = copy.copy(embedder_params)
-        embedder_params_copy.is_training = self.is_training
-        embedder_params_copy.activation_function = utils.get_activation_function(embedder_params.activation_function)
-        embedder_params_copy.input_rescaling = embedder_params_copy.input_rescaling[emb_type]
-        embedder_params_copy.input_offset = embedder_params_copy.input_offset[emb_type]
-        embedder_params_copy.name = input_name
-        module = dynamic_import_and_instantiate_module_from_params(embedder_params_copy,
-                                                                   path=embedder_path,
-                                                                   positional_args=[allowed_inputs[input_name].shape])
-        return module
-
-    def get_middleware(self, middleware_params: MiddlewareParameters):
-        """
-        Given a middleware type, creates the middleware and returns it
-        :param middleware_params: the paramaeters of the middleware class
-        :return: the middleware instance
+    def _get_loss(self,
+                  agent_parameters: AgentParameters,
+                  loss_params: HeadParameters,
+                  network_name: str,
+                  num_actions: int,
+                  head_idx: int,
+                  loss_weight: float) -> HeadLoss:
         """
-        mod_name = middleware_params.parameterized_class_name
-        middleware_path = middleware_params.path
-        middleware_params_copy = copy.copy(middleware_params)
-        middleware_params_copy.activation_function = utils.get_activation_function(middleware_params.activation_function)
-        middleware_params_copy.is_training = self.is_training
-        module =
dynamic_import_and_instantiate_module_from_params(middleware_params_copy, path=middleware_path) - return module - - def get_output_head(self, head_params: HeadParameters, head_idx: int): - """ - Given a head type, creates the head and returns it - :param head_params: the parameters of the head to create + Given a loss type, creates the loss and returns it + :param loss_params: the parameters of the loss to create :param head_idx: the head index - :return: the head + :param network_name: name of the network + :return: loss block """ - mod_name = head_params.parameterized_class_name - head_path = head_params.path - head_params_copy = copy.copy(head_params) - head_params_copy.activation_function = utils.get_activation_function(head_params_copy.activation_function) - head_params_copy.is_training = self.is_training - return dynamic_import_and_instantiate_module_from_params(head_params_copy, path=head_path, extra_kwargs={ - 'agent_parameters': self.ap, 'spaces': self.spaces, 'network_name': self.network_wrapper_name, - 'head_idx': head_idx, 'is_local': self.network_is_local}) - - def get_model(self) -> List: - # validate the configuration - if len(self.network_parameters.input_embedders_parameters) == 0: - raise ValueError("At least one input type should be defined") - - if len(self.network_parameters.heads_parameters) == 0: - raise ValueError("At least one output type should be defined") - - if self.network_parameters.middleware_parameters is None: - raise ValueError("Exactly one middleware type should be defined") - - # ops for defining the training / testing phase - self.is_training = tf.Variable(False, trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES]) - self.is_training_placeholder = tf.placeholder("bool") - self.assign_is_training = tf.assign(self.is_training, self.is_training_placeholder) - - for network_idx in range(self.num_networks): - with tf.variable_scope('network_{}'.format(network_idx)): - - #################### - # Input Embeddings # - #################### - - state_embedding = [] - for input_name in sorted(self.network_parameters.input_embedders_parameters): - input_type = self.network_parameters.input_embedders_parameters[input_name] - # get the class of the input embedder - input_embedder = self.get_input_embedder(input_name, input_type) - self.input_embedders.append(input_embedder) - - # input placeholders are reused between networks. on the first network, store the placeholders - # generated by the input_embedders in self.inputs. on the rest of the networks, pass - # the existing input_placeholders into the input_embedders. 
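While reading the removed graph-mode optimizer block below against its TF2 replacement above, it may help to see what `_get_optimizer()` and `_get_loss()` now boil down to: plain Keras objects. A sketch with hypothetical hyperparameter values:

```python
from tensorflow import keras

# As in _get_optimizer() for optimizer_type == 'Adam'
optimizer = keras.optimizers.Adam(learning_rate=0.00025, beta_1=0.9, beta_2=0.999, epsilon=1e-8)

# As in _get_loss(): Huber when replace_mse_with_huber_loss is set, MSE otherwise
replace_mse_with_huber_loss = True  # hypothetical preset flag
loss_cls = keras.losses.Huber if replace_mse_with_huber_loss else keras.losses.MeanSquaredError
loss_fn = loss_cls()
print(float(loss_fn([0.0, 1.0], [0.5, 0.5])))  # 0.125 for the Huber loss
```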
- if network_idx == 0: - input_placeholder, embedding = input_embedder() - self.inputs[input_name] = input_placeholder - else: - input_placeholder, embedding = input_embedder(self.inputs[input_name]) - - state_embedding.append(embedding) - - ########## - # Merger # - ########## - - if len(state_embedding) == 1: - state_embedding = state_embedding[0] - else: - if self.network_parameters.embedding_merger_type == EmbeddingMergerType.Concat: - state_embedding = tf.concat(state_embedding, axis=-1, name="merger") - elif self.network_parameters.embedding_merger_type == EmbeddingMergerType.Sum: - state_embedding = tf.add_n(state_embedding, name="merger") - - ############## - # Middleware # - ############## - - self.middleware = self.get_middleware(self.network_parameters.middleware_parameters) - _, self.state_embedding = self.middleware(state_embedding) - - ################ - # Output Heads # - ################ - - head_count = 0 - for head_idx in range(self.num_heads_per_network): - - if self.network_parameters.use_separate_networks_per_head: - # if we use separate networks per head, then the head type corresponds to the network idx - head_type_idx = network_idx - head_count = network_idx - else: - # if we use a single network with multiple embedders, then the head type is the current head idx - head_type_idx = head_idx - head_params = self.network_parameters.heads_parameters[head_type_idx] - - for head_copy_idx in range(head_params.num_output_head_copies): - # create output head and add it to the output heads list - self.output_heads.append( - self.get_output_head(head_params, - head_idx*head_params.num_output_head_copies + head_copy_idx) - ) - - # rescale the gradients from the head - self.gradients_from_head_rescalers.append( - tf.get_variable('gradients_from_head_{}-{}_rescalers'.format(head_idx, head_copy_idx), - initializer=float(head_params.rescale_gradient_from_head_by_factor), - dtype=tf.float32)) - - self.gradients_from_head_rescalers_placeholders.append( - tf.placeholder('float', - name='gradients_from_head_{}-{}_rescalers'.format(head_type_idx, head_copy_idx))) - - self.update_head_rescaler_value_ops.append(self.gradients_from_head_rescalers[head_count].assign( - self.gradients_from_head_rescalers_placeholders[head_count])) - - head_input = (1-self.gradients_from_head_rescalers[head_count]) * tf.stop_gradient(self.state_embedding) + \ - self.gradients_from_head_rescalers[head_count] * self.state_embedding - - # build the head - if self.network_is_local: - output, target_placeholder, input_placeholders, importance_weight_ph = \ - self.output_heads[-1](head_input) - - self.targets.extend(target_placeholder) - self.importance_weights.extend(importance_weight_ph) - else: - output, input_placeholders = self.output_heads[-1](head_input) - - self.outputs.extend(output) - # TODO: use head names as well - for placeholder_index, input_placeholder in enumerate(input_placeholders): - self.inputs['output_{}_{}'.format(head_type_idx, placeholder_index)] = input_placeholder - - head_count += 1 - - # model weights - if not self.distributed_training or self.network_is_global: - self.weights = [var for var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=self.full_name) if - 'global_step' not in var.name] - else: - self.weights = [var for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.full_name)] - - # Losses - self.losses = tf.losses.get_losses(self.full_name) - - # L2 regularization - if self.network_parameters.l2_regularization != 0: - self.l2_regularization = 
tf.add_n([tf.nn.l2_loss(v) for v in self.weights]) \ - * self.network_parameters.l2_regularization - self.losses += self.l2_regularization - - self.total_loss = tf.reduce_sum(self.losses) - # tf.summary.scalar('total_loss', self.total_loss) - - # Learning rate - if self.network_parameters.learning_rate_decay_rate != 0: - self.adaptive_learning_rate_scheme = \ - tf.train.exponential_decay( - self.network_parameters.learning_rate, - self.global_step, - decay_steps=self.network_parameters.learning_rate_decay_steps, - decay_rate=self.network_parameters.learning_rate_decay_rate, - staircase=True) - - self.current_learning_rate = self.adaptive_learning_rate_scheme + if agent_parameters.network_wrappers['main'].replace_mse_with_huber_loss: + loss_type = Huber else: - self.current_learning_rate = self.network_parameters.learning_rate - - # Optimizer - if self.distributed_training and self.network_is_local and self.network_parameters.shared_optimizer: - # distributed training + is a local network + optimizer shared -> take the global optimizer - self.optimizer = self.global_network.optimizer - elif (self.distributed_training and self.network_is_local and not self.network_parameters.shared_optimizer) \ - or self.network_parameters.shared_optimizer or not self.distributed_training: - # distributed training + is a global network + optimizer shared - # OR - # distributed training + is a local network + optimizer not shared - # OR - # non-distributed training - # -> create an optimizer - - if self.network_parameters.optimizer_type == 'Adam': - self.optimizer = tf.train.AdamOptimizer(learning_rate=self.current_learning_rate, - beta1=self.network_parameters.adam_optimizer_beta1, - beta2=self.network_parameters.adam_optimizer_beta2, - epsilon=self.network_parameters.optimizer_epsilon) - elif self.network_parameters.optimizer_type == 'RMSProp': - self.optimizer = tf.train.RMSPropOptimizer(self.current_learning_rate, - decay=self.network_parameters.rms_prop_optimizer_decay, - epsilon=self.network_parameters.optimizer_epsilon) - elif self.network_parameters.optimizer_type == 'LBFGS': - self.optimizer = tf.contrib.opt.ScipyOptimizerInterface(self.total_loss, method='L-BFGS-B', - options={'maxiter': 25}) - else: - raise Exception("{} is not a valid optimizer type".format(self.network_parameters.optimizer_type)) - - return self.weights + loss_type = MeanSquaredError + + # loss_path = 'rl_coach.architectures.tensorflow_components.losses:QLoss' + # loss_path = loss_params.path + # loss = dynamic_import_and_instantiate_module_from_params(loss_params, path=loss_path) + + if isinstance(loss_params, QHeadParameters): + loss = QLoss(network_name=network_name, + head_idx=head_idx, + loss_type=loss_type, + agent_parameters=agent_parameters, + loss_weight=loss_weight) + + elif isinstance(loss_params, VHeadParameters): + loss = VLoss(network_name=network_name, + head_idx=head_idx, + loss_type=loss_type, + loss_weight=loss_weight) + + elif isinstance(loss_params, PPOHeadParameters): + loss = PPOLoss(network_name=network_name, + head_idx=head_idx, + loss_type=loss_type, + loss_weight=loss_weight, + agent_parameters=agent_parameters, + num_actions=num_actions) - def __str__(self): - result = [] - - for network in range(self.num_networks): - network_structure = [] - - # embedder - for embedder in self.input_embedders: - network_structure.append("Input Embedder: {}".format(embedder.name)) - network_structure.append(indent_string(str(embedder))) + else: + raise KeyError('Unsupported loss type: {}'.format(type(loss_params))) - if 
len(self.input_embedders) > 1: - network_structure.append("{} ({})".format(self.network_parameters.embedding_merger_type.name, - ", ".join(["{} embedding".format(e.name) for e in self.input_embedders]))) + return loss - # middleware - network_structure.append("Middleware:") - network_structure.append(indent_string(str(self.middleware))) + @property + def output_heads(self): + output_heads = list(map(lambda sub_model: sub_model.layers[-1], self.model.layers[1:])) + return output_heads - # head - if self.network_parameters.use_separate_networks_per_head: - heads = range(network, network+1) - else: - heads = range(0, len(self.output_heads)) - - for head_idx in heads: - head = self.output_heads[head_idx] - head_params = self.network_parameters.heads_parameters[head_idx] - if head_params.num_output_head_copies > 1: - network_structure.append("Output Head: {} (num copies = {})".format(head.name, head_params.num_output_head_copies)) - else: - network_structure.append("Output Head: {}".format(head.name)) - network_structure.append(indent_string(str(head))) - - # finalize network - if self.num_networks > 1: - result.append("Sub-network for head: {}".format(self.output_heads[network].name)) - result.append(indent_string('\n'.join(network_structure))) - else: - result.append('\n'.join(network_structure)) - result = '\n'.join(result) - return result diff --git a/rl_coach/architectures/tensorflow_components/heads/__init__.py b/rl_coach/architectures/tensorflow_components/heads/__init__.py index 03c237a84..300f4f30b 100644 --- a/rl_coach/architectures/tensorflow_components/heads/__init__.py +++ b/rl_coach/architectures/tensorflow_components/heads/__init__.py @@ -1,43 +1,13 @@ -from .q_head import QHead -from .categorical_q_head import CategoricalQHead -from .ddpg_actor_head import DDPGActor -from .dnd_q_head import DNDQHead -from .dueling_q_head import DuelingQHead -from .measurements_prediction_head import MeasurementsPredictionHead -from .naf_head import NAFHead -from .policy_head import PolicyHead -from .ppo_head import PPOHead -from .ppo_v_head import PPOVHead -from .quantile_regression_q_head import QuantileRegressionQHead -from .rainbow_q_head import RainbowQHead -from .v_head import VHead -from .acer_policy_head import ACERPolicyHead -from .sac_head import SACPolicyHead -from .sac_q_head import SACQHead -from .classification_head import ClassificationHead -from .cil_head import RegressionHead -from .td3_v_head import TD3VHead -from .ddpg_v_head import DDPGVHead +from .head import Head#, HeadLoss +#from .q_head import QHead +#from .ppo_head import PPOHead +#from .v_head import VHead + __all__ = [ - 'CategoricalQHead', - 'DDPGActor', - 'DNDQHead', - 'DuelingQHead', - 'MeasurementsPredictionHead', - 'NAFHead', - 'PolicyHead', - 'PPOHead', - 'PPOVHead', - 'QHead', - 'QuantileRegressionQHead', - 'RainbowQHead', - 'VHead', - 'ACERPolicyHead', - 'SACPolicyHead', - 'SACQHead', - 'ClassificationHead', - 'RegressionHead', - 'TD3VHead' - 'DDPGVHead' + 'Head', + #'PPOHead', + #'QHead', + #'VHead' ] + diff --git a/rl_coach/architectures/tensorflow_components/heads/head.py b/rl_coach/architectures/tensorflow_components/heads/head.py index e971889e9..5ee2f45e0 100644 --- a/rl_coach/architectures/tensorflow_components/heads/head.py +++ b/rl_coach/architectures/tensorflow_components/heads/head.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2017 Intel Corporation +# Copyright (c) 2019 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with 
the License. @@ -14,153 +14,59 @@ # limitations under the License. # -import numpy as np -import tensorflow as tf -from tensorflow.python.ops.losses.losses_impl import Reduction -from rl_coach.architectures.tensorflow_components.layers import Dense, convert_layer_class +from tensorflow import keras from rl_coach.base_parameters import AgentParameters from rl_coach.spaces import SpacesDefinition -from rl_coach.utils import force_list -from rl_coach.architectures.tensorflow_components.utils import squeeze_tensor - -# Used to initialize weights for policy and value output layers -def normalized_columns_initializer(std=1.0): - def _initializer(shape, dtype=None, partition_info=None): - out = np.random.randn(*shape).astype(np.float32) - out *= std / np.sqrt(np.square(out).sum(axis=0, keepdims=True)) - return tf.constant(out) - return _initializer - - -class Head(object): - """ - A head is the final part of the network. It takes the embedding from the middleware embedder and passes it through - a neural network to produce the output of the network. There can be multiple heads in a network, and each one has - an assigned loss function. The heads are algorithm dependent. - """ - def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str, - head_idx: int=0, loss_weight: float=1., is_local: bool=True, activation_function: str='relu', - dense_layer=Dense, is_training=False): - self.head_idx = head_idx +from rl_coach.architectures.tensorflow_components.layers import Dense, convert_layer + + +class Head(keras.layers.Layer): + def __init__(self, agent_parameters: AgentParameters, + spaces: SpacesDefinition, + network_name: str, + head_type_idx: int = 0, + loss_weight: float = 1., + is_local: bool = True, + activation_function: str = 'relu', + dense_layer: None = None): + """ + A head is the final part of the network. It takes the embedding from the middleware embedder and passes it + through a neural network to produce the output of the network. There can be multiple heads in a network, and + each one has an assigned loss function. The heads are algorithm dependent. + + :param agent_parameters: containing algorithm parameters such as clip_likelihood_ratio_using_epsilon + and beta_entropy. + :param spaces: containing action spaces used for defining size of network output. + :param network_name: name of head network. currently unused. + :param head_type_idx: index of head network. currently unused. + :param loss_weight: scalar used to adjust relative weight of loss (if using this loss with others). + :param is_local: flag to denote if network is local. currently unused. + :param activation_function: activation function to use between layers. currently unused. + :param dense_layer: type of dense layer to use in network. currently unused. 
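+
+        Note: concrete heads are expected to build their layers in __init__ and implement call() for the
+        forward pass (see, e.g., the commented-out QHead in q_head.py for the intended pattern).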
+ """ + super(Head, self).__init__() + self.head_type_idx = head_type_idx self.network_name = network_name - self.network_parameters = agent_parameters.network_wrappers[self.network_name] - self.name = "head" - self.output = [] - self.loss = [] - self.loss_type = [] - self.regularizations = [] - self.loss_weight = tf.Variable([float(w) for w in force_list(loss_weight)], - trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES]) - self.target = [] - self.importance_weight = [] - self.input = [] + self.loss_weight = loss_weight self.is_local = is_local self.ap = agent_parameters self.spaces = spaces self.return_type = None self.activation_function = activation_function self.dense_layer = dense_layer + if self.dense_layer is None: self.dense_layer = Dense else: - self.dense_layer = convert_layer_class(self.dense_layer) - self.is_training = is_training + self.dense_layer = convert_layer(self.dense_layer) - def __call__(self, input_layer): - """ - Wrapper for building the module graph including scoping and loss creation - :param input_layer: the input to the graph - :return: the output of the last layer and the target placeholder - """ + self.num_outputs_ = None - with tf.variable_scope(self.get_name(), initializer=tf.contrib.layers.xavier_initializer()): - self._build_module(squeeze_tensor(input_layer)) + @property + def num_outputs(self): + """ Returns number of outputs that forward() call will return - self.output = force_list(self.output) - self.target = force_list(self.target) - self.input = force_list(self.input) - self.loss_type = force_list(self.loss_type) - self.loss = force_list(self.loss) - self.regularizations = force_list(self.regularizations) - if self.is_local: - self.set_loss() - self._post_build() - - if self.is_local: - return self.output, self.target, self.input, self.importance_weight - else: - return self.output, self.input - - def _build_module(self, input_layer): - """ - Builds the graph of the module - This method is called early on from __call__. It is expected to store the graph - in self.output. 
- :param input_layer: the input to the graph - :return: None - """ - pass - - def _post_build(self): - """ - Optional function that allows adding any extra definitions after the head has been fully defined - For example, this allows doing additional calculations that are based on the loss - :return: None + :return: """ - pass - - def get_name(self): - """ - Get a formatted name for the module - :return: the formatted name - """ - return '{}_{}'.format(self.name, self.head_idx) - - def set_loss(self): - """ - Creates a target placeholder and loss function for each loss_type and regularization - :param loss_type: a tensorflow loss function - :param scope: the name scope to include the tensors in - :return: None - """ - - # there are heads that define the loss internally, but we need to create additional placeholders for them - for idx in range(len(self.loss)): - importance_weight = tf.placeholder('float', - [None] + [1] * (len(self.target[idx].shape) - 1), - '{}_importance_weight'.format(self.get_name())) - self.importance_weight.append(importance_weight) - - # add losses and target placeholder - for idx in range(len(self.loss_type)): - # create target placeholder - target = tf.placeholder('float', self.output[idx].shape, '{}_target'.format(self.get_name())) - self.target.append(target) - - # create importance sampling weights placeholder - num_target_dims = len(self.target[idx].shape) - importance_weight = tf.placeholder('float', [None] + [1] * (num_target_dims - 1), - '{}_importance_weight'.format(self.get_name())) - self.importance_weight.append(importance_weight) - - # compute the weighted loss. importance_weight weights over the samples in the batch, while self.loss_weight - # weights the specific loss of this head against other losses in this head or in other heads - loss_weight = self.loss_weight[idx]*importance_weight - loss = self.loss_type[idx](self.target[-1], self.output[idx], - scope=self.get_name(), reduction=Reduction.NONE, loss_collection=None) - - # the loss is first summed over each sample in the batch and then the mean over the batch is taken - loss = tf.reduce_mean(loss_weight*tf.reduce_sum(loss, axis=list(range(1, num_target_dims)))) - - # we add the loss to the losses collection and later we will extract it in general_network - tf.losses.add_loss(loss) - self.loss.append(loss) - - # add regularizations - for regularization in self.regularizations: - self.loss.append(regularization) - tf.losses.add_loss(regularization) - - @classmethod - def path(cls): - return cls.__class__.__name__ + assert self.num_outputs_ is not None, 'must call forward() once to configure number of outputs' + return self.num_outputs_ diff --git a/rl_coach/architectures/tensorflow_components/heads/ppo_head.py b/rl_coach/architectures/tensorflow_components/heads/ppo_head.py index 63f95a3ba..a2e7f912c 100644 --- a/rl_coach/architectures/tensorflow_components/heads/ppo_head.py +++ b/rl_coach/architectures/tensorflow_components/heads/ppo_head.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2017 Intel Corporation +# Copyright (c) 2019 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,143 +14,126 @@ # limitations under the License. 
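+
+# The LOSS_OUT_TYPE_* constants below name the extra tensors that the PPO loss returns alongside the
+# surrogate loss (see losses/ppo_loss.py); agents fetch them as (head_idx, name) pairs, as in the
+# commented-out properties of PPOHead further down. (Editor sketch of intent, inferred from this diff.)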
# +from typing import List, Tuple, Union + +from tensorflow import keras +from tensorflow import Tensor +from tensorflow.keras.layers import Dense, Input, Lambda import numpy as np import tensorflow as tf +import tensorflow_probability as tfp +tfd = tfp.distributions + +LOSS_OUT_TYPE_KL = 'kl_divergence' +LOSS_OUT_TYPE_ENTROPY = 'entropy' +LOSS_OUT_TYPE_LIKELIHOOD_RATIO = 'likelihood_ratio' +LOSS_OUT_TYPE_CLIPPED_LIKELIHOOD_RATIO = 'clipped_likelihood_ratio' + +# from rl_coach.architectures.tensorflow_components.heads.head import Head +# from rl_coach.base_parameters import AgentParameters, DistributedTaskParameters +# from rl_coach.core_types import ActionProbabilities +# from rl_coach.spaces import BoxActionSpace, DiscreteActionSpace +# from rl_coach.spaces import SpacesDefinition +# from rl_coach.utils import eps + +# class PPOHead(Head): +# def __init__(self, +# agent_parameters: AgentParameters, +# spaces: SpacesDefinition, +# network_name: str, +# head_type_idx: int=0, +# loss_weight: float=1., +# is_local: bool=True, +# activation_function: str='tanh', +# dense_layer: None=None) -> None: +# """ +# Head block for Proximal Policy Optimization, to calculate probabilities for each action given middleware +# representation of the environment state. +# +# :param agent_parameters: containing algorithm parameters such as clip_likelihood_ratio_using_epsilon +# and beta_entropy. +# :param spaces: containing action spaces used for defining size of network output. +# :param network_name: name of head network. currently unused. +# :param head_type_idx: index of head network. currently unused. +# :param loss_weight: scalar used to adjust relative weight of loss (if using this loss with others). +# :param is_local: flag to denote if network is local. currently unused. +# :param activation_function: activation function to use between layers. currently unused. +# :param dense_layer: type of dense layer to use in network. currently unused. 
+# """ +# super().__init__(agent_parameters, spaces, network_name, head_type_idx, loss_weight, is_local, activation_function, +# dense_layer=dense_layer) +# self.return_type = ActionProbabilities +# +# self.clip_likelihood_ratio_using_epsilon = agent_parameters.algorithm.clip_likelihood_ratio_using_epsilon +# self.beta = agent_parameters.algorithm.beta_entropy +# self.use_kl_regularization = agent_parameters.algorithm.use_kl_regularization +# if self.use_kl_regularization: +# self.initial_kl_coefficient = agent_parameters.algorithm.initial_kl_coefficient +# self.kl_cutoff = 2 * agent_parameters.algorithm.target_kl_divergence +# self.high_kl_penalty_coefficient = agent_parameters.algorithm.high_kl_penalty_coefficient +# else: +# self.initial_kl_coefficient, self.kl_cutoff, self.high_kl_penalty_coefficient = (None, None, None) +# self._loss = [] +# +# if isinstance(self.spaces.action, BoxActionSpace): +# #self.net = ContinuousPPOHead(num_actions=self.spaces.action.shape[0]) +# head_input_dim = 64 # middleware output dim hard coded, because scheme is hard coded +# head_output_dim = self.spaces.action.shape[0] +# self.net = continuous_ppo_head(head_input_dim, head_output_dim) +# else: +# raise ValueError("Only discrete or continuous action spaces are supported for PPO.") +# +# def call(self, inputs): +# """ +# :param inputs: middleware embedding +# :return: policy parameters/probabilities +# """ +# return self.net(inputs) +# +# @property +# def kl_divergence(self): +# return self.head_type_idx, LOSS_OUT_TYPE_KL +# +# @property +# def entropy(self): +# return self.head_type_idx, LOSS_OUT_TYPE_ENTROPY +# +# @property +# def likelihood_ratio(self): +# return self.head_type_idx, LOSS_OUT_TYPE_LIKELIHOOD_RATIO +# +# @property +# def clipped_likelihood_ratio(self): +# return self.head_type_idx, LOSS_OUT_TYPE_CLIPPED_LIKELIHOOD_RATIO + + +def normalized_columns_initializer(std=1.0): + """ + Standardizes Root Sum of Squares along the input channel dimension. + Used for Dense layer weight matrices only (ie. do not use on Convolution kernels). + """ + def _initializer(shape, dtype=None, partition_info=None): + out = np.random.randn(*shape).astype(np.float32) + out *= std / np.sqrt(np.square(out).sum(axis=0, keepdims=True)) + return tf.constant(out) + return _initializer + + +def continuous_ppo_head(input_dim, output_dim): + """ + Used for forward pass through Proximal Policy Optimization head Layer. + :param input_dim: middleware state representation, of shape (batch_size, in_channels). + :param output_dim: dimension of the output. + :return: probabilities distribution conditioned on the given middleware + representation of the environment state. 
+ """ + inputs = Input(shape=([input_dim])) + policy_means = Dense(units=output_dim, name="policy_means", kernel_initializer=normalized_columns_initializer(0.01))(inputs) + policy_stds = tfp.layers.VariableLayer(shape=(1, output_dim), dtype=tf.float32)(inputs) + actions_proba = tfp.layers.DistributionLambda( + lambda t: tfd.MultivariateNormalDiag( + loc=t[0], scale_diag=tf.exp(t[1])))([policy_means, policy_stds]) + model = keras.Model(name='continuous_ppo_head', inputs=inputs, outputs=actions_proba) + + return model -from rl_coach.architectures.tensorflow_components.layers import Dense -from rl_coach.architectures.tensorflow_components.heads.head import Head, normalized_columns_initializer -from rl_coach.base_parameters import AgentParameters, DistributedTaskParameters -from rl_coach.core_types import ActionProbabilities -from rl_coach.spaces import BoxActionSpace, DiscreteActionSpace -from rl_coach.spaces import SpacesDefinition -from rl_coach.utils import eps - - -class PPOHead(Head): - def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str, - head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='tanh', - dense_layer=Dense): - super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function, - dense_layer=dense_layer) - self.name = 'ppo_head' - self.return_type = ActionProbabilities - - # used in regular PPO - self.use_kl_regularization = agent_parameters.algorithm.use_kl_regularization - if self.use_kl_regularization: - # kl coefficient and its corresponding assignment operation and placeholder - self.kl_coefficient = tf.Variable(agent_parameters.algorithm.initial_kl_coefficient, - trainable=False, name='kl_coefficient') - self.kl_coefficient_ph = tf.placeholder('float', name='kl_coefficient_ph') - self.assign_kl_coefficient = tf.assign(self.kl_coefficient, self.kl_coefficient_ph) - self.kl_cutoff = 2 * agent_parameters.algorithm.target_kl_divergence - self.high_kl_penalty_coefficient = agent_parameters.algorithm.high_kl_penalty_coefficient - - self.clip_likelihood_ratio_using_epsilon = agent_parameters.algorithm.clip_likelihood_ratio_using_epsilon - self.beta = agent_parameters.algorithm.beta_entropy - - def _build_module(self, input_layer): - if isinstance(self.spaces.action, DiscreteActionSpace): - self._build_discrete_net(input_layer, self.spaces.action) - elif isinstance(self.spaces.action, BoxActionSpace): - self._build_continuous_net(input_layer, self.spaces.action) - else: - raise ValueError("only discrete or continuous action spaces are supported for PPO") - - self.action_probs_wrt_policy = self.policy_distribution.log_prob(self.actions) - self.action_probs_wrt_old_policy = self.old_policy_distribution.log_prob(self.actions) - self.entropy = tf.reduce_mean(self.policy_distribution.entropy()) - - # Used by regular PPO only - # add kl divergence regularization - self.kl_divergence = tf.reduce_mean(tf.distributions.kl_divergence(self.old_policy_distribution, self.policy_distribution)) - - if self.use_kl_regularization: - # no clipping => use kl regularization - self.weighted_kl_divergence = tf.multiply(self.kl_coefficient, self.kl_divergence) - self.regularizations += [self.weighted_kl_divergence + self.high_kl_penalty_coefficient * \ - tf.square(tf.maximum(0.0, self.kl_divergence - self.kl_cutoff))] - - # calculate surrogate loss - self.advantages = tf.placeholder(tf.float32, [None], name="advantages") - self.target = self.advantages - # 
action_probs_wrt_old_policy != 0 because it is e^... - self.likelihood_ratio = tf.exp(self.action_probs_wrt_policy - self.action_probs_wrt_old_policy) - if self.clip_likelihood_ratio_using_epsilon is not None: - self.clip_param_rescaler = tf.placeholder(tf.float32, ()) - self.input.append(self.clip_param_rescaler) - max_value = 1 + self.clip_likelihood_ratio_using_epsilon * self.clip_param_rescaler - min_value = 1 - self.clip_likelihood_ratio_using_epsilon * self.clip_param_rescaler - self.clipped_likelihood_ratio = tf.clip_by_value(self.likelihood_ratio, min_value, max_value) - self.scaled_advantages = tf.minimum(self.likelihood_ratio * self.advantages, - self.clipped_likelihood_ratio * self.advantages) - else: - self.scaled_advantages = self.likelihood_ratio * self.advantages - # minus sign is in order to set an objective to minimize (we actually strive for maximizing the surrogate loss) - self.surrogate_loss = -tf.reduce_mean(self.scaled_advantages) - if self.is_local: - # add entropy regularization - if self.beta: - self.entropy = tf.reduce_mean(self.policy_distribution.entropy()) - self.regularizations += [-tf.multiply(self.beta, self.entropy, name='entropy_regularization')] - - self.loss = self.surrogate_loss - tf.losses.add_loss(self.loss) - - def _build_discrete_net(self, input_layer, action_space): - num_actions = len(action_space.actions) - self.actions = tf.placeholder(tf.int32, [None], name="actions") - - self.old_policy_mean = tf.placeholder(tf.float32, [None, num_actions], "old_policy_mean") - self.old_policy_std = tf.placeholder(tf.float32, [None, num_actions], "old_policy_std") - - # Policy Head - self.input = [self.actions, self.old_policy_mean] - policy_values = self.dense_layer(num_actions)(input_layer, name='policy_fc') - self.policy_mean = tf.nn.softmax(policy_values, name="policy") - - # define the distributions for the policy and the old policy - self.policy_distribution = tf.contrib.distributions.Categorical(probs=self.policy_mean) - self.old_policy_distribution = tf.contrib.distributions.Categorical(probs=self.old_policy_mean) - - self.output = self.policy_mean - - def _build_continuous_net(self, input_layer, action_space): - num_actions = action_space.shape[0] - self.actions = tf.placeholder(tf.float32, [None, num_actions], name="actions") - - self.old_policy_mean = tf.placeholder(tf.float32, [None, num_actions], "old_policy_mean") - self.old_policy_std = tf.placeholder(tf.float32, [None, num_actions], "old_policy_std") - - self.input = [self.actions, self.old_policy_mean, self.old_policy_std] - self.policy_mean = self.dense_layer(num_actions)(input_layer, name='policy_mean', - kernel_initializer=normalized_columns_initializer(0.01)) - - # for local networks in distributed settings, we need to move variables we create manually to the - # tf.GraphKeys.LOCAL_VARIABLES collection, since the variable scope custom getter which is set in - # Architecture does not apply to them - if self.is_local and isinstance(self.ap.task_parameters, DistributedTaskParameters): - self.policy_logstd = tf.Variable(np.zeros((1, num_actions)), dtype='float32', - collections=[tf.GraphKeys.LOCAL_VARIABLES], name="policy_log_std") - else: - self.policy_logstd = tf.Variable(np.zeros((1, num_actions)), dtype='float32', name="policy_log_std") - - self.policy_std = tf.tile(tf.exp(self.policy_logstd), [tf.shape(input_layer)[0], 1], name='policy_std') - - # define the distributions for the policy and the old policy - self.policy_distribution = 
tf.contrib.distributions.MultivariateNormalDiag(self.policy_mean, self.policy_std + eps) - self.old_policy_distribution = tf.contrib.distributions.MultivariateNormalDiag(self.old_policy_mean, self.old_policy_std + eps) - - self.output = [self.policy_mean, self.policy_std] - - def __str__(self): - action_head_mean_result = [] - if isinstance(self.spaces.action, DiscreteActionSpace): - # create a discrete action network (softmax probabilities output) - action_head_mean_result.append("Dense (num outputs = {})".format(len(self.spaces.action.actions))) - action_head_mean_result.append("Softmax") - elif isinstance(self.spaces.action, BoxActionSpace): - # create a continuous action network (bounded mean and stdev outputs) - action_head_mean_result.append("Dense (num outputs = {})".format(self.spaces.action.shape)) - - return '\n'.join(action_head_mean_result) diff --git a/rl_coach/architectures/tensorflow_components/heads/q_head.py b/rl_coach/architectures/tensorflow_components/heads/q_head.py index ecc1461a0..1f3dcc008 100644 --- a/rl_coach/architectures/tensorflow_components/heads/q_head.py +++ b/rl_coach/architectures/tensorflow_components/heads/q_head.py @@ -1,5 +1,5 @@ -# -# Copyright (c) 2017 Intel Corporation + +# Copyright (c) 2019 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,53 +14,76 @@ # limitations under the License. # -import tensorflow as tf - -from rl_coach.architectures.tensorflow_components.layers import Dense -from rl_coach.architectures.tensorflow_components.heads.head import Head -from rl_coach.base_parameters import AgentParameters -from rl_coach.core_types import QActionStateValue -from rl_coach.spaces import SpacesDefinition, BoxActionSpace, DiscreteActionSpace +from tensorflow import keras +from tensorflow.keras.layers import Dense, Input, Lambda -class QHead(Head): - def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str, - head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='relu', - dense_layer=Dense): - super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function, - dense_layer=dense_layer) - self.name = 'q_values_head' - if isinstance(self.spaces.action, BoxActionSpace): - self.num_actions = 1 - elif isinstance(self.spaces.action, DiscreteActionSpace): - self.num_actions = len(self.spaces.action.actions) - else: - raise ValueError( - 'QHead does not support action spaces of type: {class_name}'.format( - class_name=self.spaces.action.__class__.__name__, - ) - ) - self.return_type = QActionStateValue - if agent_parameters.network_wrappers[self.network_name].replace_mse_with_huber_loss: - self.loss_type = tf.losses.huber_loss - else: - self.loss_type = tf.losses.mean_squared_error - - def _build_module(self, input_layer): - # Standard Q Network - self.q_values = self.output = self.dense_layer(self.num_actions)(input_layer, name='output') +#from rl_coach.architectures.layers import Dense +# from rl_coach.architectures.tensorflow_components.heads.head import Head +# from rl_coach.base_parameters import AgentParameters +# from rl_coach.core_types import QActionStateValue +# from rl_coach.spaces import SpacesDefinition, BoxActionSpace, DiscreteActionSpace +# +# +# class QHead(Head): +# def __init__(self, +# agent_parameters: AgentParameters, +# spaces: SpacesDefinition, +# network_name: str, +# head_type_idx: int = 0, +# 
loss_weight: float = 1.,
+#                  is_local: bool = True,
+#                  activation_function: str = 'relu',
+#                  dense_layer: None = None) -> None:
+#         """
+#         Q-Value Head for predicting state-action Q-Values.
+#
+#         :param agent_parameters: containing algorithm parameters, but currently unused.
+#         :param spaces: containing action spaces used for defining size of network output.
+#         :param network_name: name of head network. currently unused.
+#         :param head_type_idx: index of head network. currently unused.
+#         :param loss_weight: scalar used to adjust relative weight of loss (if using this loss with others).
+#         :param is_local: flag to denote if network is local. currently unused.
+#         :param activation_function: activation function to use between layers. currently unused.
+#         :param dense_layer: type of dense layer to use in network. currently unused.
+#         :param loss_type: loss function to use.
+#         """
+#         super(QHead, self).__init__(agent_parameters, spaces, network_name, head_type_idx, loss_weight,
+#                                     is_local, activation_function, dense_layer)
+#         if isinstance(self.spaces.action, BoxActionSpace):
+#             self.num_actions = 1
+#         elif isinstance(self.spaces.action, DiscreteActionSpace):
+#             self.num_actions = len(self.spaces.action.actions)
+#         else:
+#             raise ValueError(
+#                 'QHead does not support action spaces of type: {class_name}'.format(
+#                     class_name=self.spaces.action.__class__.__name__,
+#                 )
+#             )
+#         self.return_type = QActionStateValue
+#         self.dense = Dense(units=self.num_actions)
+#
+#     def call(self, inputs):
+#         """
+#         Used for forward pass through Q-Value head network.
+#
+#         :param inputs: middleware state representation, of shape (batch_size, in_channels).
+#         :return: predicted state-action q-values, of shape (batch_size, num_actions).
+#         """
+#         q_value = self.dense(inputs)
+#         return q_value

-        # used in batch-rl to estimate a probablity distribution over actions
-        self.softmax = self.add_softmax_with_temperature()

-    def __str__(self):
-        result = [
-            "Dense (num outputs = {})".format(self.num_actions)
-        ]
-        return '\n'.join(result)

-    def add_softmax_with_temperature(self):
-        temperature = self.ap.network_wrappers[self.network_name].softmax_temperature
-        temperature_scaled_outputs = self.q_values / temperature
-        return tf.nn.softmax(temperature_scaled_outputs, name="softmax")

+def q_value_head(input_dim, output_dim):
+    """
+    Builds the functional Q-Value head model.
+    :param input_dim: dimension of the middleware embedding (in_channels).
+    :param output_dim: number of actions.
+    :return: a Keras model producing state-action q-values of shape (batch_size, num_actions).
+    """
+    inputs = Input(shape=([input_dim]))
+    value = Dense(units=output_dim, name="q_value_output")(inputs)
+    model = keras.Model(name='q_value_head', inputs=inputs, outputs=value)
+    return model
diff --git a/rl_coach/architectures/tensorflow_components/heads/v_head.py b/rl_coach/architectures/tensorflow_components/heads/v_head.py
index 62bfba03b..b5c234ac5 100644
--- a/rl_coach/architectures/tensorflow_components/heads/v_head.py
+++ b/rl_coach/architectures/tensorflow_components/heads/v_head.py
@@ -1,5 +1,5 @@
-#
-# Copyright (c) 2017 Intel Corporation
+#
+# Copyright (c) 2019 Intel Corporation
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,41 +14,67 @@
 # limitations under the License.
 #
-import tensorflow as tf
+from tensorflow import keras
+from tensorflow.keras.layers import Dense, Input, Lambda

-from rl_coach.architectures.tensorflow_components.layers import Dense
-from rl_coach.architectures.tensorflow_components.heads.head import Head, normalized_columns_initializer
+from rl_coach.architectures.tensorflow_components.heads.head import Head
 from rl_coach.base_parameters import AgentParameters
 from rl_coach.core_types import VStateValue
 from rl_coach.spaces import SpacesDefinition

+#
+# class VHead(Head):
+#     def __init__(self,
+#                  agent_parameters: AgentParameters,
+#                  spaces: SpacesDefinition,
+#                  network_name: str,
+#                  head_type_idx: int=0,
+#                  loss_weight: float=1.,
+#                  is_local: bool=True,
+#                  activation_function: str='relu',
+#                  dense_layer: None=None):
+#         """
+#         Value Head for predicting state values.
+#         :param agent_parameters: containing algorithm parameters, but currently unused.
+#         :param spaces: containing action spaces, but currently unused.
+#         :param network_name: name of head network. currently unused.
+#         :param head_type_idx: index of head network. currently unused.
+#         :param loss_weight: scalar used to adjust relative weight of loss (if using this loss with others).
+#         :param is_local: flag to denote if network is local. currently unused.
+#         :param activation_function: activation function to use between layers. currently unused.
+#         :param dense_layer: type of dense layer to use in network. currently unused.
+#         """
+#         super(VHead, self).__init__(agent_parameters,
+#                                     spaces,
+#                                     network_name,
+#                                     head_type_idx,
+#                                     loss_weight,
+#                                     is_local,
+#                                     activation_function,
+#                                     dense_layer)
+#
+#         self.return_type = VStateValue
+#         self.dense = keras.layers.Dense(units=1)
+#
+#     def call(self, inputs):
+#         """
+#         Used for forward pass through Value head network.
+#
+#         :param inputs: middleware state representation, of shape (batch_size, in_channels).
+#         :return: predicted state values, of shape (batch_size, 1).
+#         """
+#         value = self.dense(inputs)
+#         return value

-class VHead(Head):
-    def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
-                 head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='relu',
-                 dense_layer=Dense, initializer='normalized_columns'):
-        super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function,
-                         dense_layer=dense_layer)
-        self.name = 'v_values_head'
-        self.return_type = VStateValue
-
-        if agent_parameters.network_wrappers[self.network_name.split('/')[0]].replace_mse_with_huber_loss:
-            self.loss_type = tf.losses.huber_loss
-        else:
-            self.loss_type = tf.losses.mean_squared_error
-
-        self.initializer = initializer
-
-    def _build_module(self, input_layer):
-        # Standard V Network
-        if self.initializer == 'normalized_columns':
-            self.output = self.dense_layer(1)(input_layer, name='output',
-                                              kernel_initializer=normalized_columns_initializer(1.0))
-        elif self.initializer == 'xavier' or self.initializer is None:
-            self.output = self.dense_layer(1)(input_layer, name='output')

-    def __str__(self):
-        result = [
-            "Dense (num outputs = 1)"
-        ]
-        return '\n'.join(result)

+def value_head(input_dim, output_dim):
+    """
+    Builds the functional Value head model.
+    :param input_dim: dimension of the middleware embedding (in_channels).
+    :param output_dim: dimension of the value output (usually 1).
+    :return: a Keras model producing state values of shape (batch_size, output_dim).
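+
+    Example (illustrative sketch only)::
+
+        head = value_head(input_dim=64, output_dim=1)
+        state_values = head(embedding)  # shape (batch_size, 1)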
+ """ + inputs = Input(shape=([input_dim])) + value = Dense(units=output_dim, name="value_output")(inputs) + model = keras.Model(name='value_head', inputs=inputs, outputs=value) + return model diff --git a/rl_coach/architectures/tensorflow_components/layers.py b/rl_coach/architectures/tensorflow_components/layers.py index eb6326234..196354966 100644 --- a/rl_coach/architectures/tensorflow_components/layers.py +++ b/rl_coach/architectures/tensorflow_components/layers.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2017 Intel Corporation +# Copyright (c) 2019 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,52 +12,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# -import math from types import FunctionType import tensorflow as tf - from rl_coach.architectures import layers from rl_coach.architectures.tensorflow_components import utils - -def batchnorm_activation_dropout(input_layer, batchnorm, activation_function, dropout_rate, is_training, name): - layers = [input_layer] - - # Rationale: passing a bool here will mean that batchnorm and or activation will never activate - assert not isinstance(is_training, bool) - - # batchnorm - if batchnorm: - layers.append( - tf.layers.batch_normalization(layers[-1], name="{}_batchnorm".format(name), training=is_training) - ) - - # activation - if activation_function: - if isinstance(activation_function, str): - activation_function = utils.get_activation_function(activation_function) - layers.append( - activation_function(layers[-1], name="{}_activation".format(name)) - ) - - # dropout - if dropout_rate > 0: - layers.append( - tf.layers.dropout(layers[-1], dropout_rate, name="{}_dropout".format(name), training=is_training) - ) - - # remove the input layer from the layers list - del layers[0] - - return layers - - # define global dictionary for storing layer type to layer implementation mapping tf_layer_dict = dict() -tf_layer_class_dict = dict() def reg_to_tf_instance(layer_type) -> FunctionType: @@ -71,17 +34,6 @@ def reg_impl_decorator(func): return reg_impl_decorator -def reg_to_tf_class(layer_type) -> FunctionType: - """ function decorator that registers layer type - :return: decorated function - """ - def reg_impl_decorator(func): - assert layer_type not in tf_layer_class_dict - tf_layer_class_dict[layer_type] = func - return func - return reg_impl_decorator - - def convert_layer(layer): """ If layer instance is callable (meaning this is already a concrete TF class), return layer, otherwise convert to TF type @@ -93,31 +45,36 @@ def convert_layer(layer): return tf_layer_dict[type(layer)](layer) -def convert_layer_class(layer_class): - """ - If layer instance is callable, return layer, otherwise convert to TF type - :param layer: layer to be converted - :return: converted layer if not callable, otherwise layer itself - """ - if hasattr(layer_class, 'to_tf_instance'): - return layer_class - else: - return tf_layer_class_dict[layer_class]() +class Dense(layers.Dense): + def __init__(self, units: int): + super(Dense, self).__init__(units=units) + + def __call__(self): + """ + returns a tensorflow dense layer + :return: dense layer + """ + return tf.keras.layers.Dense(self.units) + + @staticmethod + @reg_to_tf_instance(layers.Dense) + def to_tf_instance(base: layers.Dense): + return Dense(units=base.units)() class 
Conv2d(layers.Conv2d): def __init__(self, num_filters: int, kernel_size: int, strides: int): super(Conv2d, self).__init__(num_filters=num_filters, kernel_size=kernel_size, strides=strides) - def __call__(self, input_layer, name: str=None, is_training=None): + def __call__(self): """ returns a tensorflow conv2d layer :param input_layer: previous layer :param name: layer name :return: conv2d layer """ - return tf.layers.conv2d(input_layer, filters=self.num_filters, kernel_size=self.kernel_size, - strides=self.strides, data_format='channels_last', name=name) + return tf.keras.layers.Conv2D(filters=self.num_filters, kernel_size=self.kernel_size, strides=self.strides) + #, data_format='channels_last') @staticmethod @reg_to_tf_instance(layers.Conv2d) @@ -125,12 +82,7 @@ def to_tf_instance(base: layers.Conv2d): return Conv2d( num_filters=base.num_filters, kernel_size=base.kernel_size, - strides=base.strides) - - @staticmethod - @reg_to_tf_class(layers.Conv2d) - def to_tf_class(): - return Conv2d + strides=base.strides)() class BatchnormActivationDropout(layers.BatchnormActivationDropout): @@ -158,103 +110,40 @@ def to_tf_instance(base: layers.BatchnormActivationDropout): activation_function=base.activation_function, dropout_rate=base.dropout_rate) - @staticmethod - @reg_to_tf_class(layers.BatchnormActivationDropout) - def to_tf_class(): - return BatchnormActivationDropout +def batchnorm_activation_dropout(input_layer, batchnorm, activation_function, dropout_rate, is_training, name): + # TODO - tf2: remove tf1 compatibility code -class Dense(layers.Dense): - def __init__(self, units: int): - super(Dense, self).__init__(units=units) + layers = [input_layer] - def __call__(self, input_layer, name: str=None, kernel_initializer=None, activation=None, is_training=None): - """ - returns a tensorflow dense layer - :param input_layer: previous layer - :param name: layer name - :return: dense layer - """ - return tf.layers.dense(input_layer, self.units, name=name, kernel_initializer=kernel_initializer, - activation=activation) + # Rationale: passing a bool here will mean that batchnorm and or activation will never activate + assert not isinstance(is_training, bool) - @staticmethod - @reg_to_tf_instance(layers.Dense) - def to_tf_instance(base: layers.Dense): - return Dense(units=base.units) + # batchnorm + if batchnorm: - @staticmethod - @reg_to_tf_class(layers.Dense) - def to_tf_class(): - return Dense + layers.append( + tf.compat.v1.layers.batch_normalization(layers[-1], name="{}_batchnorm".format(name), training=is_training) + ) + # activation + if activation_function: + if isinstance(activation_function, str): + activation_function = utils.get_activation_function(activation_function) + layers.append( + activation_function(layers[-1], name="{}_activation".format(name)) + ) -class NoisyNetDense(layers.NoisyNetDense): - """ - A factorized Noisy Net layer + # dropout + if dropout_rate > 0: + layers.append( + tf.compat.v1.layers.dropout(layers[-1], dropout_rate, name="{}_dropout".format(name), training=is_training) + ) - https://arxiv.org/abs/1706.10295. - """ + # remove the input layer from the layers list + del layers[0] - def __init__(self, units: int): - super(NoisyNetDense, self).__init__(units=units) + return layers - def __call__(self, input_layer, name: str, kernel_initializer=None, activation=None, is_training=None): - """ - returns a NoisyNet dense layer - :param input_layer: previous layer - :param name: layer name - :param kernel_initializer: initializer for kernels. 
Default is to use Gaussian noise that preserves stddev. - :param activation: the activation function - :return: dense layer - """ - #TODO: noise sampling should be externally controlled. DQN is fine with sampling noise for every - # forward (either act or train, both for online and target networks). - # A3C, on the other hand, should sample noise only when policy changes (i.e. after every t_max steps) - - def _f(values): - return tf.sqrt(tf.abs(values)) * tf.sign(values) - - def _factorized_noise(inputs, outputs): - # TODO: use factorized noise only for compute intensive algos (e.g. DQN). - # lighter algos (e.g. DQN) should not use it - noise1 = _f(tf.random_normal((inputs, 1))) - noise2 = _f(tf.random_normal((1, outputs))) - return tf.matmul(noise1, noise2) - - num_inputs = input_layer.get_shape()[-1].value - num_outputs = self.units - - stddev = 1 / math.sqrt(num_inputs) - activation = activation if activation is not None else (lambda x: x) - - if kernel_initializer is None: - kernel_mean_initializer = tf.random_uniform_initializer(-stddev, stddev) - kernel_stddev_initializer = tf.random_uniform_initializer(-stddev * self.sigma0, stddev * self.sigma0) - else: - kernel_mean_initializer = kernel_stddev_initializer = kernel_initializer - with tf.variable_scope(None, default_name=name): - weight_mean = tf.get_variable('weight_mean', shape=(num_inputs, num_outputs), - initializer=kernel_mean_initializer) - bias_mean = tf.get_variable('bias_mean', shape=(num_outputs,), initializer=tf.zeros_initializer()) - - weight_stddev = tf.get_variable('weight_stddev', shape=(num_inputs, num_outputs), - initializer=kernel_stddev_initializer) - bias_stddev = tf.get_variable('bias_stddev', shape=(num_outputs,), - initializer=kernel_stddev_initializer) - bias_noise = _f(tf.random_normal((num_outputs,))) - weight_noise = _factorized_noise(num_inputs, num_outputs) - - bias = bias_mean + bias_stddev * bias_noise - weight = weight_mean + weight_stddev * weight_noise - return activation(tf.matmul(input_layer, weight) + bias) - @staticmethod - @reg_to_tf_instance(layers.NoisyNetDense) - def to_tf_instance(base: layers.NoisyNetDense): - return NoisyNetDense(units=base.units) - @staticmethod - @reg_to_tf_class(layers.NoisyNetDense) - def to_tf_class(): - return NoisyNetDense diff --git a/rl_coach/architectures/tensorflow_components/losses/__init__.py b/rl_coach/architectures/tensorflow_components/losses/__init__.py new file mode 100644 index 000000000..900d272f4 --- /dev/null +++ b/rl_coach/architectures/tensorflow_components/losses/__init__.py @@ -0,0 +1,3 @@ +# from .loss import HeadLoss +from .q_loss import QLoss +from .v_loss import VLoss \ No newline at end of file diff --git a/rl_coach/architectures/tensorflow_components/losses/head_loss.py b/rl_coach/architectures/tensorflow_components/losses/head_loss.py new file mode 100644 index 000000000..bd6ffb82a --- /dev/null +++ b/rl_coach/architectures/tensorflow_components/losses/head_loss.py @@ -0,0 +1,102 @@ +# +# Copyright (c) 2019 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from tensorflow import keras
+from typing import Dict, List, Tuple
+from tensorflow import Tensor
+import numpy as np
+
+LOSS_OUT_TYPE_LOSS = 'loss'
+LOSS_OUT_TYPE_REGULARIZATION = 'regularization'
+
+
+class LossInputSchema(object):
+    """
+    Helper class to contain schema for loss input
+    """
+    def __init__(self, model_outputs: List[str], non_trainable_args: List[str]):
+        """
+        :param model_outputs: list of argument names in call that are the outputs of the head.
+            The order and number MUST MATCH the output from the head.
+        :param non_trainable_args: list of argument names that are inputs from the agent.
+            The order and number MUST MATCH the loss_forward call for this head.
+        """
+        self._model_outputs = model_outputs
+        self._non_trainable_args = non_trainable_args
+
+    @property
+    def model_outputs(self):
+        return self._model_outputs
+
+    @property
+    def non_trainable_args(self):
+        return self._non_trainable_args
+
+
+class HeadLoss(keras.layers.Layer):
+    """
+    ABC for loss functions of each agent. Child class must implement input_schema() and loss_forward()
+    """
+
+    def __init__(self, *args, **kwargs):
+        super(HeadLoss, self).__init__(*args, **kwargs)
+
+    @property
+    def input_schema(self) -> LossInputSchema:
+        """
+        :return: schema for input of loss forward. Read docstring for LossInputSchema for details.
+        """
+        raise NotImplementedError
+
+    def call(self, model_outputs: List[Tensor], non_trainable_args: List[np.ndarray]) -> Dict[str, Tensor]:
+        """
+        Extracts and aligns loss arguments, then passes the call on to loss_forward().
+        :param model_outputs: list of all trainable model_outputs for this loss
+        :param non_trainable_args: list of all non trainable args for this loss
+        :return: dict containing the loss values, regularization values and additional fetches.
+        """
+        loss_args = self.align_loss_args(model_outputs, non_trainable_args)
+        return self.loss_forward(*loss_args)
+
+    def loss_forward(self, *args, **kwargs):
+        """
+        Needs to be implemented by each child class
+        """
+        raise NotImplementedError
+
+    def align_loss_args(self,
+                        model_outputs: List[Tensor],
+                        non_trainable_args: List[np.ndarray]) -> List[Tensor]:
+        """
+        Creates a list of arguments from model_outputs and non_trainable_args aligned with the parameters of
+        loss.loss_forward() based on their order in the loss input_schema.
+        :param model_outputs: list of all trainable model_outputs for this loss
+        :param non_trainable_args: list of all non trainable args for this loss
+        :return: list of arguments in correct order to be passed to loss
+        """
+        arg_list = list()
+        schema = self.input_schema
+        assert len(schema.model_outputs) == len(model_outputs)
+        assert len(schema.non_trainable_args) == len(non_trainable_args)
+
+        arg_list.extend(model_outputs)
+        arg_list.extend(non_trainable_args)
+        return arg_list
+
+    @classmethod
+    def path(cls):
+        return cls.__name__
\ No newline at end of file
diff --git a/rl_coach/architectures/tensorflow_components/losses/ppo_loss.py b/rl_coach/architectures/tensorflow_components/losses/ppo_loss.py
new file mode 100644
index 000000000..6b1957709
--- /dev/null
+++ b/rl_coach/architectures/tensorflow_components/losses/ppo_loss.py
@@ -0,0 +1,155 @@
+# Copyright (c) 2019 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import tensorflow as tf
+import tensorflow_probability as tfp
+from tensorflow import Tensor
+from tensorflow.keras.losses import Loss
+from typing import Dict
+
+from rl_coach.base_parameters import AgentParameters
+from rl_coach.architectures.tensorflow_components.losses.head_loss import HeadLoss, LossInputSchema,\
+    LOSS_OUT_TYPE_LOSS, LOSS_OUT_TYPE_REGULARIZATION
+
+
+tfd = tfp.distributions
+LOSS_OUT_TYPE_KL = 'kl_divergence'
+LOSS_OUT_TYPE_ENTROPY = 'entropy'
+LOSS_OUT_TYPE_LIKELIHOOD_RATIO = 'likelihood_ratio'
+LOSS_OUT_TYPE_CLIPPED_LIKELIHOOD_RATIO = 'clipped_likelihood_ratio'
+
+
+class PPOLoss(HeadLoss):
+    def __init__(self,
+                 network_name: str,
+                 head_idx: int,
+                 agent_parameters: AgentParameters,
+                 num_actions: int,
+                 loss_type: Loss,
+                 loss_weight: float = 1.):
+        """
+        Loss for continuous version of Clipped PPO.
+
+        :param network_name: name of the head network, used as the name of the loss layer.
+        :param head_idx: index of the corresponding head.
+        :param agent_parameters: provides clip_likelihood_ratio_using_epsilon (epsilon for likelihood ratio
+            clipping), beta_entropy (loss coefficient applied to entropy) and the kl regularization settings
+            (use_kl_regularization, initial_kl_coefficient, target_kl_divergence and the
+            high_kl_penalty_coefficient applied to kl divergence above kl_cutoff).
+        :param num_actions: number of actions in action space.
+        :param loss_type: loss function to use. currently unused.
+        :param loss_weight: scalar used to adjust relative weight of loss (if using this loss with others).
+        """
+        super(PPOLoss, self).__init__(name=network_name)
+        self.head_idx = head_idx
+        self.weight = loss_weight
+        self.num_actions = num_actions
+        self.clip_likelihood_ratio_using_epsilon = agent_parameters.algorithm.clip_likelihood_ratio_using_epsilon
+        self.beta = agent_parameters.algorithm.beta_entropy
+        self.use_kl_regularization = agent_parameters.algorithm.use_kl_regularization
+
+        if self.use_kl_regularization:
+            self.initial_kl_coefficient = agent_parameters.algorithm.initial_kl_coefficient
+            self.kl_cutoff = 2 * agent_parameters.algorithm.target_kl_divergence
+            self.high_kl_penalty_coefficient = agent_parameters.algorithm.high_kl_penalty_coefficient
+        else:
+            self.initial_kl_coefficient, self.kl_cutoff, self.high_kl_penalty_coefficient = (0.0, None, None)
+
+    @property
+    def input_schema(self) -> LossInputSchema:
+        return LossInputSchema(
+            model_outputs=['new_policy_distribution'],
+            non_trainable_args=['actions', 'old_policy_means', 'old_policy_stds', 'clip_param_rescaler', 'advantages']
+        )
+
+    def loss_forward(self,
+                     new_policy_distribution,
+                     actions,
+                     old_policy_means,
+                     old_policy_stds,
+                     clip_param_rescaler,
+                     advantages) -> Dict[str, Tensor]:
+        """
+        Used for forward pass through loss computations.
+        Works with batches of data, and optionally time_steps, but be consistent in usage: i.e.
if using time_step,
+        new_policy_means, old_policy_means, actions and advantages all must include a time_step dimension.
+        :param new_policy_distribution: distribution used for calculating the log probability of the actions,
+            over actions of shape (batch_size, num_actions) or
+            of shape (batch_size, time_step, num_actions).
+        :param actions: true actions taken during rollout,
+            of shape (batch_size, num_actions) or
+            of shape (batch_size, time_step, num_actions).
+        :param old_policy_means: action means for previous policy,
+            of shape (batch_size, num_actions) or
+            of shape (batch_size, time_step, num_actions).
+        :param old_policy_stds: action standard deviation returned by head previously,
+            of shape (batch_size, num_actions) or
+            of shape (batch_size, time_step, num_actions).
+        :param clip_param_rescaler: scales epsilon to use for likelihood ratio clipping.
+        :param advantages: change in state value after taking action (a.k.a. advantage),
+            of shape (batch_size,) or
+            of shape (batch_size, time_step).
+        :return: dict of loss outputs: the (scalar) surrogate loss and regularization term, plus the
+            kl divergence, entropy and likelihood-ratio fetches.
+        """
+        old_policy_dist = tfd.MultivariateNormalDiag(loc=old_policy_means, scale_diag=old_policy_stds)
+        action_probs_wrt_old_policy = old_policy_dist.log_prob(actions)
+
+        action_probs_wrt_new_policy = new_policy_distribution.log_prob(actions)
+
+        entropy_loss = - self.beta * tf.reduce_mean(new_policy_distribution.entropy())
+
+        assert not self.use_kl_regularization  # Not supported yet
+        kl_div_loss = tf.constant(0, dtype=tf.float32)
+        # working with log probs, so minus first, then exponential (same as division)
+        likelihood_ratio = tf.exp(action_probs_wrt_new_policy - action_probs_wrt_old_policy)
+        # Added when changed to functional
+        # advantages = np.float32(advantages).reshape(likelihood_ratio.shape)
+
+        if self.clip_likelihood_ratio_using_epsilon is not None:
+            # clipping of likelihood ratio
+            min_value = 1 - self.clip_likelihood_ratio_using_epsilon * clip_param_rescaler
+            max_value = 1 + self.clip_likelihood_ratio_using_epsilon * clip_param_rescaler
+
+            clipped_likelihood_ratio = tf.clip_by_value(likelihood_ratio, min_value, max_value)
+            # pessimistic lower bound: element-wise min of the original and clipped scaled advantages
+            unclipped_scaled_advantages = likelihood_ratio * advantages
+            clipped_scaled_advantages = clipped_likelihood_ratio * advantages
+            scaled_advantages = tf.minimum(unclipped_scaled_advantages, clipped_scaled_advantages)
+
+        else:
+            scaled_advantages = likelihood_ratio * advantages
+            clipped_likelihood_ratio = tf.zeros_like(likelihood_ratio)
+
+        # expectation of scaled_advantages across the batch (and time steps, when present)
+        expected_scaled_advantages = tf.reduce_mean(scaled_advantages)
+        # want to maximize expected_scaled_advantages, add minus so can minimize.
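+        # self.weight applies this head's configured loss_weight relative to other losses.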
+        surrogate_loss = -expected_scaled_advantages * self.weight
+
+        return {
+            LOSS_OUT_TYPE_LOSS: [surrogate_loss],
+            LOSS_OUT_TYPE_REGULARIZATION: [(entropy_loss + kl_div_loss)],
+            LOSS_OUT_TYPE_KL: kl_div_loss,
+            LOSS_OUT_TYPE_ENTROPY: [entropy_loss],
+            LOSS_OUT_TYPE_LIKELIHOOD_RATIO: [likelihood_ratio],
+            LOSS_OUT_TYPE_CLIPPED_LIKELIHOOD_RATIO: [clipped_likelihood_ratio],
+        }
diff --git a/rl_coach/architectures/tensorflow_components/losses/q_loss.py b/rl_coach/architectures/tensorflow_components/losses/q_loss.py
new file mode 100644
index 000000000..d7d2647db
--- /dev/null
+++ b/rl_coach/architectures/tensorflow_components/losses/q_loss.py
@@ -0,0 +1,62 @@
+#
+# Copyright (c) 2019 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import tensorflow as tf
+from tensorflow import keras
+from tensorflow.keras.losses import Loss, Huber, MeanSquaredError
+
+from rl_coach.architectures.tensorflow_components.losses.head_loss import HeadLoss, LossInputSchema
+from rl_coach.architectures.tensorflow_components.losses.head_loss import LOSS_OUT_TYPE_LOSS
+from rl_coach.base_parameters import AgentParameters
+
+
+class QLoss(HeadLoss):
+    def __init__(self,
+                 network_name: str,
+                 head_idx: int,
+                 agent_parameters: AgentParameters,
+                 loss_type: Loss = MeanSquaredError,
+                 loss_weight: float = 1.):
+        """
+        Loss for Q-Value Head.
+        :param network_name: name of the head network, used as the name of the loss layer.
+        :param head_idx: the index of the corresponding head.
+        :param agent_parameters: agent parameters. currently unused.
+        :param loss_type: loss function with default of mean squared error (i.e. L2Loss).
+        :param loss_weight: scalar used to adjust relative weight of loss (if using this loss with others).
+        """
+        super(QLoss, self).__init__(name=network_name)
+        self.head_idx = head_idx
+        assert (loss_type == MeanSquaredError) or (loss_type == Huber), "Only expecting L2Loss or HuberLoss."
+        # self.loss_fn = keras.losses.mean_squared_error
+        self.loss_fn = keras.losses.get(loss_type)()
+        # sample_weight can be used like https://github.com/keras-team/keras/blob/master/keras/losses.py
+
+    @property
+    def input_schema(self) -> LossInputSchema:
+        return LossInputSchema(
+            model_outputs=['q_value_pred'],
+            non_trainable_args=['target']
+        )
+
+    def loss_forward(self, q_value_pred, target):
+        """
+        Used for forward pass through loss computations.
+        :param q_value_pred: state-action q-values predicted by QHead network, of shape (batch_size, num_actions).
+        :param target: actual state-action q-values, of shape (batch_size, num_actions).
+        :return: loss (a scalar; the keras loss object already averages over the batch).
+        """
+        # TODO: preferable to return a tensor containing one loss per instance, rather than returning the mean loss.
+        #  This way, Keras can apply class weights or sample weights when requested.
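+        # Note: keras loss objects are called as loss_fn(y_true, y_pred); passing (q_value_pred, target)
+        # below only works because MSE and Huber are symmetric in their two arguments.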
+        loss = tf.reduce_mean(self.loss_fn(q_value_pred, target))
+        return {LOSS_OUT_TYPE_LOSS: [loss]}
diff --git a/rl_coach/architectures/tensorflow_components/losses/v_loss.py b/rl_coach/architectures/tensorflow_components/losses/v_loss.py
new file mode 100644
index 000000000..af02f2eec
--- /dev/null
+++ b/rl_coach/architectures/tensorflow_components/losses/v_loss.py
@@ -0,0 +1,56 @@
+# Copyright (c) 2019 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+import tensorflow as tf
+from tensorflow import keras
+from tensorflow.keras.losses import Loss, Huber, MeanSquaredError
+from rl_coach.architectures.tensorflow_components.losses.head_loss import HeadLoss, LossInputSchema, LOSS_OUT_TYPE_LOSS
+
+
+class VLoss(HeadLoss):
+
+    def __init__(self,
+                 network_name,
+                 head_idx: int = 0,
+                 loss_type: Loss = MeanSquaredError,
+                 loss_weight: float=1.):
+        """
+        Loss for Value Head.
+        :param head_idx: the index of the corresponding head.
+        :param loss_type: loss function with default of mean squared error (i.e. L2Loss).
+        :param loss_weight: scalar used to adjust relative weight of loss (if using this loss with others).
+        """
+        super(VLoss, self).__init__(name=network_name)
+        self.head_idx = head_idx
+        assert (loss_type == MeanSquaredError) or (loss_type == Huber), "Only expecting L2Loss or HuberLoss."
+        self.loss_fn = keras.losses.get(loss_type)()
+
+    @property
+    def input_schema(self) -> LossInputSchema:
+        return LossInputSchema(
+            model_outputs=['value_prediction'],
+            non_trainable_args=['target']
+        )
+
+    def loss_forward(self, value_prediction, target):
+        """
+        Used for forward pass through loss computations.
+        :param value_prediction: state values predicted by VHead network, of shape (batch_size,).
+        :param target: actual state values, of shape (batch_size,).
+        :return: loss (a scalar mean over the batch with the default Keras reduction).
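+        Note: a per-instance loss of shape (batch_size,) would let Keras apply sample weights here as
+        well; see the matching TODO in QLoss.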
+ """ + loss = self.loss_fn(value_prediction, target) + return {LOSS_OUT_TYPE_LOSS: [loss]} diff --git a/rl_coach/architectures/tensorflow_components/middlewares/__init__.py b/rl_coach/architectures/tensorflow_components/middlewares/__init__.py index 481eab0bf..ecef74100 100644 --- a/rl_coach/architectures/tensorflow_components/middlewares/__init__.py +++ b/rl_coach/architectures/tensorflow_components/middlewares/__init__.py @@ -1,4 +1,4 @@ from .fc_middleware import FCMiddleware -from .lstm_middleware import LSTMMiddleware +#from .lstm_middleware import LSTMMiddleware -__all__ = ["FCMiddleware", "LSTMMiddleware"] +__all__ = ["FCMiddleware"]#, "LSTMMiddleware"] diff --git a/rl_coach/architectures/tensorflow_components/middlewares/fc_middleware.py b/rl_coach/architectures/tensorflow_components/middlewares/fc_middleware.py index 4361e171d..e89f8b9a1 100644 --- a/rl_coach/architectures/tensorflow_components/middlewares/fc_middleware.py +++ b/rl_coach/architectures/tensorflow_components/middlewares/fc_middleware.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2017 Intel Corporation +# Copyright (c) 2019 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,67 +13,72 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from typing import Union, List +from typing import Dict import tensorflow as tf - -from rl_coach.architectures.tensorflow_components.layers import Dense +from rl_coach.architectures.layers import Dense from rl_coach.architectures.tensorflow_components.middlewares.middleware import Middleware from rl_coach.base_parameters import MiddlewareScheme from rl_coach.core_types import Middleware_FC_Embedding -from rl_coach.utils import force_list + +""" +Module that defines the fully-connected middleware class +""" class FCMiddleware(Middleware): - def __init__(self, activation_function=tf.nn.relu, + """ + FCMiddleware or Fully-Connected Middleware can be used in the middle part of the network. It takes the + embeddings from the input embedders, after they were aggregated in some method (for example, concatenation) + and passes it through a neural network which can be customizable but shared between the heads of the network. + + :param params: parameters object containing batchnorm, activation_function and dropout properties. 
+ """ + def __init__(self, + activation_function=tf.nn.relu, scheme: MiddlewareScheme = MiddlewareScheme.Medium, - batchnorm: bool = False, dropout_rate: float = 0.0, - name="middleware_fc_embedder", dense_layer=Dense, is_training=False, num_streams: int = 1): + batchnorm: bool = False, + dropout_rate: float = 0.0, + name="middleware_fc_embedder", + is_training=False, + num_streams: int = 1): super().__init__(activation_function=activation_function, batchnorm=batchnorm, - dropout_rate=dropout_rate, scheme=scheme, name=name, dense_layer=dense_layer, + dropout_rate=dropout_rate, scheme=scheme, name=name, is_training=is_training) self.return_type = Middleware_FC_Embedding assert(isinstance(num_streams, int) and num_streams >= 1) self.num_streams = num_streams - def _build_module(self): - self.output = [] - - for stream_idx in range(self.num_streams): - layers = [self.input] - - for idx, layer_params in enumerate(self.layers_params): - layers.extend(force_list( - layer_params(layers[-1], name='{}_{}'.format(layer_params.__class__.__name__, - idx + stream_idx * len(self.layers_params)), - is_training=self.is_training) - )) - self.output.append((layers[-1])) - @property - def schemes(self): + def schemes(self) -> Dict: + """ + Schemes are the pre-defined network architectures of various depths and complexities that can be used for the + Middleware. Are used to create Block when FCMiddleware is initialised. + + :return: dictionary of schemes, with key of type MiddlewareScheme enum and value being list of Tensorflow layers. + """ return { MiddlewareScheme.Empty: [], - # ppo + # Use for PPO MiddlewareScheme.Shallow: [ - self.dense_layer(64) + Dense(64), ], - # dqn + # Use for DQN MiddlewareScheme.Medium: [ - self.dense_layer(512) + Dense(512), ], MiddlewareScheme.Deep: \ [ - self.dense_layer(128), - self.dense_layer(128), - self.dense_layer(128) + Dense(128), + Dense(128), + Dense(128) ] } diff --git a/rl_coach/architectures/tensorflow_components/middlewares/middleware.py b/rl_coach/architectures/tensorflow_components/middlewares/middleware.py index 64c578fc1..d71d303c5 100644 --- a/rl_coach/architectures/tensorflow_components/middlewares/middleware.py +++ b/rl_coach/architectures/tensorflow_components/middlewares/middleware.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2017 Intel Corporation +# Copyright (c) 2019 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,79 +13,68 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import copy -from typing import Union, Tuple -import tensorflow as tf +from typing import Tuple +import tensorflow as tf +from tensorflow import keras from rl_coach.architectures.tensorflow_components.layers import BatchnormActivationDropout, convert_layer, Dense -from rl_coach.base_parameters import MiddlewareScheme, NetworkComponentParameters +from rl_coach.base_parameters import MiddlewareScheme from rl_coach.core_types import MiddlewareEmbedding -class Middleware(object): +class Middleware(keras.layers.Layer): """ A middleware embedder is the middle part of the network. 
It takes the embeddings from the input embedders,
     after they were aggregated in some method (for example, concatenation)
     and passes it through a neural network which can be customizable but shared between the heads of the network
     """
-    def __init__(self, activation_function=tf.nn.relu,
+    def __init__(self,
+                 activation_function=tf.nn.relu,
                  scheme: MiddlewareScheme = MiddlewareScheme.Medium,
-                 batchnorm: bool = False, dropout_rate: float = 0.0, name="middleware_embedder", dense_layer=Dense,
+                 batchnorm: bool = False,
+                 dropout_rate: float = 0.0,
+                 name="middleware_embedder",
+                 dense_layer=Dense,
                  is_training=False):
-        self.name = name
-        self.input = None
-        self.output = None
-        self.activation_function = activation_function
-        self.batchnorm = batchnorm
-        self.dropout_rate = dropout_rate
-        self.scheme = scheme
+        super(Middleware, self).__init__(name=name)
+
+        self.middleware_layers = []
         self.return_type = MiddlewareEmbedding
-        self.dense_layer = dense_layer
-        if self.dense_layer is None:
-            self.dense_layer = Dense
         self.is_training = is_training

-        # layers order is conv -> batchnorm -> activation -> dropout
-        if isinstance(self.scheme, MiddlewareScheme):
-            self.layers_params = copy.copy(self.schemes[self.scheme])
-            self.layers_params = [convert_layer(l) for l in self.layers_params]
-        else:
-            # if scheme is specified directly, convert to TF layer if it's not a callable object
-            # NOTE: if layer object is callable, it must return a TF tensor when invoked
-            self.layers_params = [convert_layer(l) for l in copy.copy(self.scheme)]
+        # self.dense_layer = dense_layer
+        # if self.dense_layer is None:
+        #     self.dense_layer = Dense

-        # we allow adding batchnorm, dropout or activation functions after each layer.
-        # The motivation is to simplify the transition between a network with batchnorm and a network without
-        # batchnorm to a single flag (the same applies to activation function and dropout)
-        if self.batchnorm or self.activation_function or self.dropout_rate > 0:
-            for layer_idx in reversed(range(len(self.layers_params))):
-                self.layers_params.insert(layer_idx+1,
-                                          BatchnormActivationDropout(batchnorm=self.batchnorm,
-                                                                     activation_function=self.activation_function,
-                                                                     dropout_rate=self.dropout_rate))
+        if isinstance(scheme, MiddlewareScheme):
+            layers = self.schemes[scheme]
+        else:
+            layers = scheme

-    def __call__(self, input_layer: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
-        """
-        Wrapper for building the module graph including scoping and loss creation
-        :param input_layer: the input to the graph
-        :return: the input placeholder and the output of the last layer
-        """
-        with tf.variable_scope(self.get_name()):
-            self.input = input_layer
-            self._build_module()
+        # Convert each layer spec to a TensorFlow layer
+        layers = [convert_layer(l) for l in layers]

-        return self.input, self.output
+        for layer in layers:
+            self.middleware_layers.extend([layer])
+            if batchnorm:
+                self.middleware_layers.extend([keras.layers.BatchNormalization()])
+            if activation_function:
+                self.middleware_layers.extend([keras.activations.get(activation_function)])
+            if dropout_rate:
+                self.middleware_layers.extend([keras.layers.Dropout(rate=dropout_rate)])

-    def _build_module(self) -> None:
+    def call(self, inputs: tf.Tensor) -> tf.Tensor:
         """
-        Builds the graph of the module
-        This method is called early on from __call__. It is expected to store the graph
-        in self.output.
-        :param input_layer: the input to the graph
-        :return: None
+        Used for forward pass through middleware network.
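+        Layers are applied sequentially, in the order assembled in __init__
+        (scheme layer -> optional batch normalization -> activation -> optional dropout).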
+
+        :param inputs: state embedding, of shape (batch_size, in_channels).
+        :return: state middleware embedding, of shape (batch_size, channels).
         """
-        pass
+        x = inputs
+        for layer in self.middleware_layers:
+            x = layer(x)
+        return x

     def get_name(self) -> str:
         """
diff --git a/rl_coach/architectures/tensorflow_components/savers.py b/rl_coach/architectures/tensorflow_components/savers.py
index 531c5236a..d64c8127c 100644
--- a/rl_coach/architectures/tensorflow_components/savers.py
+++ b/rl_coach/architectures/tensorflow_components/savers.py
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2017 Intel Corporation
+# Copyright (c) 2019 Intel Corporation
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,36 +14,22 @@
 # limitations under the License.
 #

-import pickle
-from typing import Any, List, Dict
-
-import tensorflow as tf
-import numpy as np
-
+from typing import Any, List
 from rl_coach.saver import Saver


-class GlobalVariableSaver(Saver):
-    def __init__(self, name):
-        self._names = [name]
-        # if graph is finalized, savers must have already already been added. This happens
-        # in the case of a MonitoredSession
-        self._variables = tf.global_variables()
-
-        # target network is never saved or restored directly from checkpoint, so we are removing all its variables from the list
-        # the target network would be synched back from the online network in graph_manager.improve(...), at the beginning of the run flow.
-        self._variables = [v for v in self._variables if "/target" not in v.name]
-
-        # Using a placeholder to update the variable during restore to avoid memory leak.
-        # Ref: https://github.com/tensorflow/tensorflow/issues/4151
-        self._variable_placeholders = []
-        self._variable_update_ops = []
-        for v in self._variables:
-            variable_placeholder = tf.placeholder(v.dtype, shape=v.get_shape())
-            self._variable_placeholders.append(variable_placeholder)
-            self._variable_update_ops.append(v.assign(variable_placeholder))
+class TfSaver(Saver):
+    """
+    Child class that implements a saver for saving a TensorFlow DNN model.
+    """

-        self._saver = tf.train.Saver(self._variables, max_to_keep=None)
+    def __init__(self,
+                 name: str,
+                 model):
+        self._name = name
+        self.model = model
+        #self.model._set_inputs(inputs)
+        self._weights_dict = model.get_weights()

     @property
     def path(self):
@@ -55,72 +41,24 @@ def path(self):
     def save(self, sess: None, save_path: str) -> List[str]:
         """
         Save to save_path
-        :param sess: active session
+        :param sess: active session for session-based frameworks (TF1 legacy). Must be None here.
         :param save_path: full path to save checkpoint (typically directory plus checkpoint prefix plus self.path)
         :return: list of all saved paths
         """
-        save_path = self._saver.save(sess, save_path)
-        return [save_path]
-
-    def to_arrays(self, session: Any) -> Dict[str, np.ndarray]:
-        """
-        Save to dictionary of arrays
-        :param sess: active session
-        :return: dictionary of arrays
-        """
-        return {
-            k.name.split(":")[0]: v for k, v in zip(self._variables, session.run(self._variables))
-        }
-
-    def from_arrays(self, session: Any, tensors: Any):
-        """
-        Restore from restore_path
-        :param sess: active session for session-based frameworks (e.g. TF)
-        :param tensors: {name: array}
-        """
-        # if variable was saved using global network, re-map it to online
-        # network
-        # TODO: Can this be more generic so that `global/` and `online/` are not
-        # hardcoded here?
- if isinstance(tensors, dict): - tensors = tensors.items() - - variables = {k.replace("global/", "online/"): v for k, v in tensors} - - # Assign all variables using placeholder - placeholder_dict = { - ph: variables[v.name.split(":")[0]] - for ph, v in zip(self._variable_placeholders, self._variables) - } - session.run(self._variable_update_ops, placeholder_dict) - - def to_string(self, session: Any) -> str: - """ - Save to byte string - :param session: active session - :return: serialized byte string - """ - return pickle.dumps(self.to_arrays(session), protocol=-1) + assert sess is None + #self.model.save(save_path, save_format="tf") + self.model.save_weights(save_path) - def from_string(self, session: Any, string: str): - """ - Restore from byte string - :param session: active session - :param string: byte string to restore from - """ - self.from_arrays(session, pickle.loads(string)) + # # Save the model weights + # model_weights_path = "{}.{}.h5".format(save_path, 'weights') + # self.model.save_weights(model_weights_path) + # + # # Save the model architecture + # model_architecture_path = "{}.{}.json".format(save_path, 'architecture') + # with open(model_architecture_path, 'w') as f: + # f.write(self.model.to_json()) - def _read_tensors(self, restore_path: str): - """ - Load tensors from a checkpoint - :param restore_path: full path to load checkpoint from. - """ - # We don't use saver.restore() because checkpoint is loaded to online - # network, but if the checkpoint is from the global network, a namespace - # mismatch exists and variable name must be modified before loading. - reader = tf.contrib.framework.load_checkpoint(restore_path) - for var_name, _ in reader.get_variable_to_shape_map().items(): - yield var_name, reader.get_tensor(var_name) + return [save_path] def restore(self, sess: Any, restore_path: str): """ @@ -128,14 +66,17 @@ def restore(self, sess: Any, restore_path: str): :param sess: active session for session-based frameworks (e.g. TF) :param restore_path: full path to load checkpoint from. """ - self.from_arrays(sess, self._read_tensors(restore_path)) + assert sess is None + self.model.load_weights(restore_path) + self._weights_dict = self.model.get_weights() def merge(self, other: "Saver"): """ Merge other saver into this saver :param other: saver to be merged into self """ - assert isinstance(other, GlobalVariableSaver) - self._names.extend(other._names) - # There is nothing else to do because variables must already be part of - # the global collection. + pass + + + + diff --git a/rl_coach/architectures/tensorflow_components/shared_variables.py b/rl_coach/architectures/tensorflow_components/shared_variables.py index fe805afed..265e8e8cb 100644 --- a/rl_coach/architectures/tensorflow_components/shared_variables.py +++ b/rl_coach/architectures/tensorflow_components/shared_variables.py @@ -13,14 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# + import os import pickle - import numpy as np import tensorflow as tf - from rl_coach.utilities.shared_running_stats import SharedRunningStats +# TODO-tf2 here: remove tf1 compatibility code +tf.compat.v1.disable_resource_variables() class TFSharedRunningStats(SharedRunningStats): def __init__(self, replicated_device=None, epsilon=1e-2, name="", create_ops=True, pubsub_params=None): @@ -42,39 +43,39 @@ def set_params(self, shape=[1], clip_values=None): """ self.clip_values = clip_values - with tf.variable_scope(self.name): - self._sum = tf.get_variable( + with tf.compat.v1.variable_scope(self.name): + self._sum = tf.compat.v1.get_variable( dtype=tf.float64, - initializer=tf.constant_initializer(0.0), + initializer=tf.compat.v1.constant_initializer(0.0), name="running_sum", trainable=False, shape=shape, validate_shape=False, - collections=[tf.GraphKeys.GLOBAL_VARIABLES]) - self._sum_squares = tf.get_variable( + collections=[tf.compat.v1.GraphKeys.GLOBAL_VARIABLES]) + self._sum_squares = tf.compat.v1.get_variable( dtype=tf.float64, - initializer=tf.constant_initializer(self.epsilon), + initializer=tf.compat.v1.constant_initializer(self.epsilon), name="running_sum_squares", trainable=False, shape=shape, validate_shape=False, - collections=[tf.GraphKeys.GLOBAL_VARIABLES]) - self._count = tf.get_variable( + collections=[tf.compat.v1.GraphKeys.GLOBAL_VARIABLES]) + self._count = tf.compat.v1.get_variable( dtype=tf.float64, shape=(), - initializer=tf.constant_initializer(self.epsilon), - name="count", trainable=False, collections=[tf.GraphKeys.GLOBAL_VARIABLES]) + initializer=tf.compat.v1.constant_initializer(self.epsilon), + name="count", trainable=False, collections=[tf.compat.v1.GraphKeys.GLOBAL_VARIABLES]) self._shape = None - self._mean = tf.div(self._sum, self._count, name="mean") + self._mean = tf.compat.v1.div(self._sum, self._count, name="mean") self._std = tf.sqrt(tf.maximum((self._sum_squares - self._count * tf.square(self._mean)) / tf.maximum(self._count-1, 1), self.epsilon), name="stdev") self.tf_mean = tf.cast(self._mean, 'float32') self.tf_std = tf.cast(self._std, 'float32') - self.new_sum = tf.placeholder(dtype=tf.float64, name='sum') - self.new_sum_squares = tf.placeholder(dtype=tf.float64, name='var') - self.newcount = tf.placeholder(shape=[], dtype=tf.float64, name='count') + self.new_sum = tf.compat.v1.placeholder(dtype=tf.float64, name='sum') + self.new_sum_squares = tf.compat.v1.placeholder(dtype=tf.float64, name='var') + self.newcount = tf.compat.v1.placeholder(shape=[], dtype=tf.float64, name='count') - self._inc_sum = tf.assign_add(self._sum, self.new_sum, use_locking=True) - self._inc_sum_squares = tf.assign_add(self._sum_squares, self.new_sum_squares, use_locking=True) - self._inc_count = tf.assign_add(self._count, self.newcount, use_locking=True) + self._inc_sum = tf.compat.v1.assign_add(self._sum, self.new_sum, use_locking=True) + self._inc_sum_squares = tf.compat.v1.assign_add(self._sum_squares, self.new_sum_squares, use_locking=True) + self._inc_count = tf.compat.v1.assign_add(self._count, self.newcount, use_locking=True) - self.raw_obs = tf.placeholder(dtype=tf.float64, name='raw_obs') + self.raw_obs = tf.compat.v1.placeholder(dtype=tf.float64, name='raw_obs') self.normalized_obs = (self.raw_obs - self._mean) / self._std if self.clip_values is not None: self.clipped_obs = tf.clip_by_value(self.normalized_obs, self.clip_values[0], self.clip_values[1]) diff --git a/rl_coach/architectures/tensorflow_components/utils.py 
b/rl_coach/architectures/tensorflow_components/utils.py
index 45f6d01ae..bce9c4b11 100644
--- a/rl_coach/architectures/tensorflow_components/utils.py
+++ b/rl_coach/architectures/tensorflow_components/utils.py
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2017 Intel Corporation
+# Copyright (c) 2019 Intel Corporation
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,11 +12,21 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-#
+
+
+import tensorflow as tf
+from tensorflow import keras
+import numpy as np
+from typing import List
+from typing import Union, Any
+from tensorflow import Tensor
+
+
 """
 Module containing utility functions
 """
-import tensorflow as tf
+LOSS_OUT_TYPE_LOSS = 'loss'
+LOSS_OUT_TYPE_REGULARIZATION = 'regularization'


 def get_activation_function(activation_function_string: str):
@@ -25,6 +35,7 @@
     :param activation_function_string: the type of the activation function
     :return: the tensorflow activation function
     """
+
     activation_functions = {
         'relu': tf.nn.relu,
         'tanh': tf.nn.tanh,
@@ -34,14 +45,76 @@
         'leaky_relu': tf.nn.leaky_relu,
         'none': None
     }
+
     assert activation_function_string in activation_functions.keys(), \
         "Activation function must be one of the following {}. instead it was: {}" \
             .format(activation_functions.keys(), activation_function_string)
-    return activation_functions[activation_function_string]
+    return keras.activations.get(activation_function_string)


 def squeeze_tensor(tensor):
     if tensor.shape[0] == 1:
         return tensor[0]
     else:
-        return tensor
\ No newline at end of file
+        return tensor
+
+
+def to_list(data: Union[tuple, list, Any]):
+    """
+    If the input is a tuple, it is converted to a list. If it is a list, it is returned untouched.
+    Otherwise, a single-element list containing the data is returned.
+    :return: list-ified data
+    """
+    if isinstance(data, list):
+        pass
+    elif isinstance(data, tuple):
+        data = list(data)
+    else:
+        data = [data]
+    return data
+
+
+def extract_loss_inputs(head_index: int, inputs, targets: List[np.ndarray]) -> List[np.ndarray]:
+    """
+    Creates a list of arguments from model_outputs and non_trainable_args aligned with parameters of
+    loss.loss_forward() based on their name in the loss input_schema.
+    :param head_index: the head index corresponding to the loss.
+    :param inputs: environment states (observation, etc.) as well as extra inputs required by the loss. The shape
+        of each ndarray is (batch_size, observation_space_size) or (batch_size, observation_space_size, stack_size)
+    :param targets: targets required by the loss (e.g. sum of discounted rewards)
+    :return: list of non-trainable arguments, in the correct order to be passed to the loss
+    """
+    arg_list = filter(lambda elem: elem[0].startswith('output_{}_'.format(head_index)), inputs.items())
+    arg_list = dict(arg_list)
+    non_trainable = []
+    for key in sorted(arg_list.keys()):
+        non_trainable.append(arg_list[key])
+
+    if non_trainable:
+        non_trainable_args = non_trainable + [targets[head_index]]
+    else:
+        non_trainable_args = [targets[head_index]]
+
+    return non_trainable_args
+
+
+def extract_fetches(loss_outputs: List[Tensor], head_index, additional_fetches):
+    """
+    Creates a dictionary for loss output based on the output schema.
If two output values have the same
+    type string in the schema, they are concatenated into the same dictionary item.
+    :param head_index: the head index corresponding to the loss.
+    :param loss_outputs: list of output values from the head loss
+    :param additional_fetches: additional fetches to calculate and return. Each fetch is specified as an (int, str)
+        tuple of head-type-index and fetch-name. The tuple is obtained from each head.
+    """
+    additional_fetches = [(k, None) for k in additional_fetches]
+
+    for i, fetch in enumerate(additional_fetches):
+        head_type_idx, fetch_name = fetch[0]  # fetch key is a tuple of (head_type_index, fetch_name)
+        if head_index == head_type_idx:
+            assert fetch[1] is None  # sanity check that fetch is None
+            additional_fetches[i] = (fetch[0], loss_outputs[fetch_name])
+
+    # results of the additional fetches
+    fetched_tensors = [fetch[1] for fetch in additional_fetches]
+    return fetched_tensors
diff --git a/rl_coach/coach.py b/rl_coach/coach.py
index 7edfbe5b0..dc8eb6392 100644
--- a/rl_coach/coach.py
+++ b/rl_coach/coach.py
@@ -14,7 +14,11 @@
 #

 import sys
+import tensorflow as tf
 sys.path.append('.')
+# TODO: Remove. This is added for running the script from command line without rl-coach package installation
+from os import sys, path
+sys.path.append(path.dirname(path.dirname(path.abspath(__file__))))
 import copy
 from configparser import ConfigParser, Error
@@ -36,6 +40,7 @@
 import subprocess
 from glob import glob

+
 from rl_coach.graph_managers.graph_manager import HumanPlayScheduleParameters, GraphManager
 from rl_coach.utils import list_all_presets, short_dynamic_import, get_open_port, SharedMemoryScratchPad, get_base_dir
 from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
@@ -771,9 +776,53 @@ def run(self):


 def main():
+
+    gpus = tf.config.experimental.list_physical_devices('GPU')
+    if gpus:
+        # Restrict TensorFlow to only allocate 2GB of memory on the first GPU
+        try:
+            tf.config.experimental.set_virtual_device_configuration(
+                gpus[0],
+                [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=2048)])
+            logical_gpus = tf.config.experimental.list_logical_devices('GPU')
+            print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
+        except RuntimeError as e:
+            # Virtual devices must be set before GPUs have been initialized
+            print(e)
+
     launcher = CoachLauncher()
     launcher.launch()


 if __name__ == "__main__":
+    #os.environ['CUDA_VISIBLE_DEVICES'] = ""
+    #import tensorflow as tf
+
+    # print("GPU Available: ", tf.test.is_gpu_available())
+    # #physical_devices = tf.config.experimental.list_physical_devices('GPU')
+    # print("Num GPUs Available: ", len(physical_devices))
+    #
+    # print('Device name is: ', tf.test.gpu_device_name())
+
+    #sys.argv.append('-p')
+    # sys.argv.append('Atari_DQN')
+    # sys.argv.extend(['-lvl', 'breakout'])
+
+    #sys.argv.append('CartPole_DQN')
+
+    # sys.argv.append('Mujoco_ClippedPPO')
+    # sys.argv.extend(['-lvl', 'inverted_pendulum'])
+
+
+    #sys.argv.extend(['-f', 'mxnet'])
+    #sys.argv.extend(['-n', '8'])
+
+    #sys.argv.extend(['-s', '30'])
+    #CHECKPOINT_RESTORE_DIR = os.path.join('experiments', 'atari', '04_09_2019-20_52', 'checkpoint')
+    # sys.argv.append('--evaluate')
+    # CHECKPOINT_RESTORE_DIR = os.path.join('experiments', 'debug', '10_12_2019-10_17', 'checkpoint')
+    # sys.argv.extend(['-crd', CHECKPOINT_RESTORE_DIR])
+
+
+    #with tf.device("/GPU:0"):
     main()
diff --git a/rl_coach/coach_script_runner.py b/rl_coach/coach_script_runner.py
new file mode 100644
index 000000000..d15e9bfee
--- /dev/null
+++
b/rl_coach/coach_script_runner.py
@@ -0,0 +1,54 @@
+
+import os
+import sys
+from rl_coach.coach import main
+
+import tensorflow as tf
+
+
+# Added for running the script from command line without rl-coach package installation
+# from os import sys, path
+# sys.path.append(path.dirname(path.dirname(path.abspath(__file__))))
+
+
+
+
+#
+# print("GPU Available: ", tf.test.is_gpu_available())
+# print("Num GPUs Available: ", len(tf.config.experimental.list_physical_devices('GPU')))
+# print('Device name is: ', tf.test.gpu_device_name())
+#
+# # with tf.device('/CPU:0'):
+# with tf.device("/device:GPU:0"):
+#     dan_delete = tf.compat.v1.Variable(False, trainable=False, collections=[tf.compat.v1.GraphKeys.LOCAL_VARIABLES])
+
+print(os.getcwd())
+print(sys.executable)
+
+sys.argv.append('-p')
+
+#sys.argv.append('CartPole_DQN')
+
+# sys.argv.append('Atari_DQN')
+# sys.argv.extend(['-lvl', 'breakout'])
+
+sys.argv.append('Mujoco_ClippedPPO')
+#sys.argv.extend(['-lvl', 'inverted_pendulum'])
+sys.argv.extend(['-lvl', 'humanoid'])
+
+# sys.argv.extend(['-f', 'mxnet'])
+# sys.argv.extend(['-s', '1500'])
+
+CHECKPOINT_RESTORE_DIR = os.path.join('experiments', 'atari', '04_09_2019-20_52', 'checkpoint')
+# sys.argv.extend(['-crd', CHECKPOINT_RESTORE_DIR])
+
+# sys.argv.extend('--evaluate')
+
+print(sys.argv)
+
+
+# main()
+# with tf.device("/device:GPU:0"):
+
+#with tf.device("/GPU:0"):
+main()
diff --git a/rl_coach/environments/gym_environment.py b/rl_coach/environments/gym_environment.py
index b1ebafbb0..5946aea92 100644
--- a/rl_coach/environments/gym_environment.py
+++ b/rl_coach/environments/gym_environment.py
@@ -142,7 +142,7 @@ def __init__(self, level=None):
 atari_schedule.improve_steps = EnvironmentSteps(50000000)
 atari_schedule.steps_between_evaluation_periods = EnvironmentSteps(250000)
 atari_schedule.evaluation_steps = EnvironmentSteps(135000)
-atari_schedule.heatup_steps = EnvironmentSteps(1)
+atari_schedule.heatup_steps = EnvironmentSteps(50000)


 class MaxOverFramesAndFrameskipEnvWrapper(gym.Wrapper):
diff --git a/rl_coach/filters/observation/observation_normalization_filter.py b/rl_coach/filters/observation/observation_normalization_filter.py
index 791b345f1..95f575726 100644
--- a/rl_coach/filters/observation/observation_normalization_filter.py
+++ b/rl_coach/filters/observation/observation_normalization_filter.py
@@ -52,6 +52,7 @@ def set_device(self, device, memory_backend_params=None, mode='numpy') -> None:
         :param mode: the arithmetic module to use {'tf' | 'numpy'}
         :return: None
         """
+        mode = 'numpy'  # force numpy-based stats for now; the TF path below still relies on TF1 shared variables
        if mode == 'tf':
            from rl_coach.architectures.tensorflow_components.shared_variables import TFSharedRunningStats
            self.running_observation_stats = TFSharedRunningStats(device, name=self.name, create_ops=False,
diff --git a/rl_coach/graph_managers/graph_manager.py b/rl_coach/graph_managers/graph_manager.py
index 60afceef3..4626057b7 100644
--- a/rl_coach/graph_managers/graph_manager.py
+++ b/rl_coach/graph_managers/graph_manager.py
@@ -212,7 +212,8 @@ def create_worker_or_parameters_server(task_parameters: DistributedTaskParameter
     def _create_session_tf(self, task_parameters: TaskParameters):
         import tensorflow as tf
-        config = tf.ConfigProto()
+
+        config = tf.compat.v1.ConfigProto()
         config.allow_soft_placement = True  # allow placing ops on cpu if they are not fit for gpu
         config.gpu_options.allow_growth = True  # allow the gpu memory allocated for the worker to grow if needed
         # config.gpu_options.per_process_gpu_memory_fraction = 0.2
@@ -255,9 +256,16 @@ def _create_session_mx(self):
""" self.set_session(sess=None) # Initialize all modules + def _create_session_tf2(self): + """ + Call set_session to initialize parameters and construct checkpoint_saver + """ + self.set_session(sess=None) # Initialize all modules + def create_session(self, task_parameters: TaskParameters): if task_parameters.framework_type == Frameworks.tensorflow: - self._create_session_tf(task_parameters) + + self._create_session_tf2() elif task_parameters.framework_type == Frameworks.mxnet: self._create_session_mx() else: diff --git a/rl_coach/presets/Mujoco_ClippedPPO.py b/rl_coach/presets/Mujoco_ClippedPPO.py index 9d00911c2..f5091ffb1 100644 --- a/rl_coach/presets/Mujoco_ClippedPPO.py +++ b/rl_coach/presets/Mujoco_ClippedPPO.py @@ -1,5 +1,6 @@ from rl_coach.agents.clipped_ppo_agent import ClippedPPOAgentParameters from rl_coach.architectures.layers import Dense +#from rl_coach.architectures.tensorflow_components.layers import Dense from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters, DistributedCoachSynchronizationType from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps from rl_coach.environments.environment import SingleLevelSelection @@ -11,13 +12,17 @@ from rl_coach.graph_managers.graph_manager import ScheduleParameters from rl_coach.schedules import LinearSchedule +from tensorflow import keras + #################### # Graph Scheduling # #################### schedule_params = ScheduleParameters() schedule_params.improve_steps = TrainingSteps(10000000) +#schedule_params.improve_steps = TrainingSteps(10000) schedule_params.steps_between_evaluation_periods = EnvironmentSteps(2048) +#schedule_params.steps_between_evaluation_periods = EnvironmentSteps(200) schedule_params.evaluation_steps = EnvironmentEpisodes(5) schedule_params.heatup_steps = EnvironmentSteps(0) diff --git a/rl_coach/run_multiple_seeds.py b/rl_coach/run_multiple_seeds.py index 6a6ddde83..b27d3dd17 100644 --- a/rl_coach/run_multiple_seeds.py +++ b/rl_coach/run_multiple_seeds.py @@ -13,13 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. # - +# Added for running the script from command line without rl-coach package installation +from os import sys, path +sys.path.append(path.dirname(path.dirname(path.abspath(__file__)))) import sys sys.path.append('.') from subprocess import Popen import argparse from rl_coach.utils import set_gpu, force_list + + """ This script makes it easier to run multiple instances of a given preset. Each instance uses a different seed, and optionally, multiple environment levels can be configured as well. 
@@ -120,15 +124,17 @@ separator = "/" else: separator = "_" - command.extend(['-e', '{dir_prefix}{preset}_{seed}_{separator}{level}_{num_workers}_workers'.format( + command.extend(['-e', '{dir_prefix}{preset}__{separator}{level}__workers'.format( dir_prefix=dir_prefix, preset=preset, seed=seed, level=level, separator=separator, num_workers=args.num_workers)]) else: - command.extend(['-e', '{dir_prefix}{preset}_{seed}_{num_workers}_workers'.format( + command.extend(['-e', '{dir_prefix}{preset}___workers'.format( dir_prefix=dir_prefix, preset=preset, seed=seed, num_workers=args.num_workers)]) print(command) p = Popen(command) + import time + time.sleep(5) processes.append(p) # for each run, select the next gpu from the available gpus diff --git a/rl_coach/utils.py b/rl_coach/utils.py index f51b02b18..ee4ff1643 100644 --- a/rl_coach/utils.py +++ b/rl_coach/utils.py @@ -475,16 +475,19 @@ def some_worker_is_writing(self): return self.now_writing is True def lock_writing_and_reading(self): + #pass self.writers_lock.acquire() # first things first - block all other writers self.now_writing = True # block new readers who haven't started reading yet while self.some_worker_is_reading(): # let existing readers finish their homework time.sleep(0.05) def release_writing_and_reading(self): + #pass self.now_writing = False # release readers - guarantee no readers starvation self.writers_lock.release() # release writers def lock_writing(self): + #pass while self.now_writing: time.sleep(0.05) @@ -493,6 +496,7 @@ def lock_writing(self): self.num_readers_lock.release() def release_writing(self): + #pass self.num_readers_lock.acquire() self.num_readers -= 1 self.num_readers_lock.release()
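The `lock_writing*` methods patched in the hunk above implement a simple readers-writer protocol: `lock_writing_and_reading`/`release_writing_and_reading` are the writer side, while `lock_writing`/`release_writing` are used by readers. A self-contained sketch of that pattern, reconstructed from the calls visible in this hunk (attribute names are assumed):

```python
import threading
import time


class SimpleRWLock:
    """Readers-writer lock: many concurrent readers, one exclusive writer."""
    def __init__(self):
        self.writers_lock = threading.Lock()
        self.num_readers_lock = threading.Lock()
        self.num_readers = 0
        self.now_writing = False

    def lock_writing_and_reading(self):      # writer entry
        self.writers_lock.acquire()          # first things first - block all other writers
        self.now_writing = True              # block readers that haven't started yet
        while self.num_readers > 0:          # let existing readers finish
            time.sleep(0.05)

    def release_writing_and_reading(self):   # writer exit
        self.now_writing = False             # release readers
        self.writers_lock.release()          # release writers

    def lock_writing(self):                  # reader entry: wait while a writer is active
        while self.now_writing:
            time.sleep(0.05)
        with self.num_readers_lock:
            self.num_readers += 1

    def release_writing(self):               # reader exit
        with self.num_readers_lock:
            self.num_readers -= 1
```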