# Source code for rl_coach.architectures.network_wrapper

#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import List, Tuple

from rl_coach.base_parameters import Frameworks, AgentParameters
from rl_coach.logger import failed_imports
from rl_coach.saver import SaverCollection
from rl_coach.spaces import SpacesDefinition
from rl_coach.utils import force_list


class NetworkWrapper(object):
    """
    The network wrapper contains multiple copies of the same network, each one with a different set of weights which is
    updating in a different time scale. The network wrapper will always contain an online network.
    It will contain an additional slow updating target network if it was requested by the user,
    and it will contain a global network shared between different workers, if Coach is run in a single-node
    multi-process distributed mode. The network wrapper contains functionality for managing these networks and syncing
    between them.
    """

    def __init__(self, agent_parameters: AgentParameters, has_target: bool, has_global: bool, name: str,
                 spaces: SpacesDefinition, replicated_device=None, worker_device=None):
        """
        Build the online network and, when requested, the global and target networks.

        :param agent_parameters: the agent parameters; ``agent_parameters.network_wrappers[name]`` holds the
                                 parameters of the network managed by this wrapper
        :param has_target: if True, a non-trainable target network is created
        :param has_global: if True, a global network is created on the replicated device(s)
        :param name: the name of this network wrapper (key into ``agent_parameters.network_wrappers``)
        :param spaces: the spaces definition passed to every network copy
        :param replicated_device: device(s) on which the global network is placed
        :param worker_device: device(s) on which the online and target networks are placed
        :raises Exception: if the configured framework is not installed or is not supported
        """
        self.ap = agent_parameters
        self.network_parameters = self.ap.network_wrappers[name]
        self.has_target = has_target
        self.has_global = has_global
        self.name = name
        self.sess = None

        # pick the framework-specific network constructor; the framework import itself is only an availability check
        if self.network_parameters.framework == Frameworks.tensorflow:
            try:
                import tensorflow  # noqa: F401
            except ImportError as err:
                raise Exception('Install tensorflow before using it as framework') from err
            from rl_coach.architectures.tensorflow_components.general_network import GeneralTensorFlowNetwork
            general_network = GeneralTensorFlowNetwork.construct
        elif self.network_parameters.framework == Frameworks.mxnet:
            try:
                import mxnet  # noqa: F401
            except ImportError as err:
                raise Exception('Install mxnet before using it as framework') from err
            from rl_coach.architectures.mxnet_components.general_network import GeneralMxnetNetwork
            general_network = GeneralMxnetNetwork.construct
        else:
            raise Exception("{} Framework is not supported"
                            .format(Frameworks().to_string(self.network_parameters.framework)))

        variable_scope = "{}/{}".format(self.ap.full_name_id, name)

        # Global network - the main network shared between threads
        self.global_network = None
        if self.has_global:
            # we assign the parameters of this network on the parameters server
            self.global_network = general_network(variable_scope=variable_scope,
                                                  devices=force_list(replicated_device),
                                                  agent_parameters=agent_parameters,
                                                  name='{}/global'.format(name),
                                                  global_network=None,
                                                  network_is_local=False,
                                                  spaces=spaces,
                                                  network_is_trainable=True)

        # Online network - local copy of the main network used for playing
        self.online_network = general_network(variable_scope=variable_scope,
                                              devices=force_list(worker_device),
                                              agent_parameters=agent_parameters,
                                              name='{}/online'.format(name),
                                              global_network=self.global_network,
                                              network_is_local=True,
                                              spaces=spaces,
                                              network_is_trainable=True)

        # Target network - a local, slow updating network used for stabilizing the learning
        self.target_network = None
        if self.has_target:
            self.target_network = general_network(variable_scope=variable_scope,
                                                  devices=force_list(worker_device),
                                                  agent_parameters=agent_parameters,
                                                  name='{}/target'.format(name),
                                                  global_network=self.global_network,
                                                  network_is_local=True,
                                                  spaces=spaces,
                                                  network_is_trainable=False)

    def sync(self):
        """
        Initializes the weights of the networks to match each other: global -> online, then online -> target.
        :return: None
        """
        self.update_online_network()
        self.update_target_network()

    def update_target_network(self, rate=1.0):
        """
        Copy weights: online network >>> target network. Does nothing if there is no target network.
        :param rate: the rate of copying the weights - 1 for copying exactly
        """
        if self.target_network:
            self.target_network.set_weights(self.online_network.get_weights(), rate)

    def update_online_network(self, rate=1.0):
        """
        Copy weights: global network >>> online network. Does nothing if there is no global network.
        :param rate: the rate of copying the weights - 1 for copying exactly
        """
        if self.global_network:
            self.online_network.set_weights(self.global_network.get_weights(), rate)

    def apply_gradients_to_global_network(self, gradients=None, additional_inputs=None):
        """
        Apply gradients from the online network on the global network
        :param gradients: optional gradients that will be used instead of the accumulated gradients
        :param additional_inputs: optional additional inputs required for when applying the gradients (e.g. batchnorm's
                                  update ops also requires the inputs)
        :return: None
        """
        if gradients is None:
            gradients = self.online_network.accumulated_gradients
        if self.network_parameters.shared_optimizer:
            self.global_network.apply_gradients(gradients, additional_inputs=additional_inputs)
        else:
            # NOTE(review): without a shared optimizer the online network's own apply op is used; presumably the
            # framework-specific network routes it to the global weights - confirm in the general_network classes
            self.online_network.apply_gradients(gradients, additional_inputs=additional_inputs)

    def apply_gradients_to_online_network(self, gradients=None, additional_inputs=None):
        """
        Apply gradients from the online network on itself
        :param gradients: optional gradients that will be used instead of the accumulated gradients
        :param additional_inputs: optional additional inputs required for when applying the gradients (e.g. batchnorm's
                                  update ops also requires the inputs)
        :return: None
        """
        if gradients is None:
            gradients = self.online_network.accumulated_gradients
        self.online_network.apply_gradients(gradients, additional_inputs=additional_inputs)

    def train_and_sync_networks(self, inputs, targets, additional_fetches=None, importance_weights=None,
                                use_inputs_for_apply_gradients=False):
        """
        A generic training function that enables multi-threading training using a global network if necessary.

        :param inputs: The inputs for the network.
        :param targets: The targets corresponding to the given inputs
        :param additional_fetches: Any additional tensors the user wants to fetch (defaults to an empty list)
        :param importance_weights: A coefficient for each sample in the batch, which will be used to rescale the loss
                                   error of this sample. If it is not given, the samples losses won't be scaled
        :param use_inputs_for_apply_gradients: Add the inputs also for when applying gradients
                                              (e.g. for incorporating batchnorm update ops)
        :return: The loss of the training iteration
        """
        if additional_fetches is None:
            # build a fresh list per call rather than sharing one mutable default object between calls
            additional_fetches = []
        # no_accumulation=True overwrites the gradients, so there is no need to reset them afterwards
        result = self.online_network.accumulate_gradients(inputs, targets, additional_fetches=additional_fetches,
                                                          importance_weights=importance_weights, no_accumulation=True)
        if use_inputs_for_apply_gradients:
            self.apply_gradients_and_sync_networks(reset_gradients=False, additional_inputs=inputs)
        else:
            self.apply_gradients_and_sync_networks(reset_gradients=False)

        return result

    def apply_gradients_and_sync_networks(self, reset_gradients=True, additional_inputs=None):
        """
        Applies the gradients accumulated in the online network to the global network or to itself and syncs the
        networks if necessary

        :param reset_gradients: If set to True, the accumulated gradients are reset to 0 after applying them to the
                                network. Setting it to False is useful when the accumulated gradients are overwritten
                                instead of accumulated by the accumulate_gradients function; this allows reducing
                                time complexity for this function by around 10%
        :param additional_inputs: optional additional inputs required for when applying the gradients (e.g. batchnorm's
                                  update ops also requires the inputs)
        :return: None
        """
        if self.global_network:
            self.apply_gradients_to_global_network(additional_inputs=additional_inputs)
            if reset_gradients:
                self.online_network.reset_accumulated_gradients()
            # pull the freshly updated global weights back into the online network
            self.update_online_network()
        else:
            if reset_gradients:
                self.online_network.apply_and_reset_gradients(self.online_network.accumulated_gradients,
                                                              additional_inputs=additional_inputs)
            else:
                self.online_network.apply_gradients(self.online_network.accumulated_gradients,
                                                    additional_inputs=additional_inputs)

    def parallel_prediction(self, network_input_tuples: List[Tuple]):
        """
        Run several network prediction in parallel. Currently this only supports running each of the network once.

        :param network_input_tuples: a list of tuples where the first element is the network (online_network,
                                     target_network or global_network) and the second element is the inputs
        :return: the outputs of all the networks in the same order as the inputs were given
        """
        return type(self.online_network).parallel_predict(self.sess, network_input_tuples)

    def set_is_training(self, state: bool):
        """
        Set the phase of the network between training and testing

        :param state: The current state (True = Training, False = Testing)
        :return: None
        """
        self.online_network.set_is_training(state)
        if self.target_network is not None:
            self.target_network.set_is_training(state)

    def set_session(self, sess):
        """
        Store the session and hand it to every existing network copy.

        :param sess: the session object forwarded to each network's set_session
        :return: None
        """
        self.sess = sess
        self.online_network.set_session(sess)
        if self.global_network:
            self.global_network.set_session(sess)
        if self.target_network:
            self.target_network.set_session(sess)

    def __str__(self):
        """
        :return: a header listing the existing network copies, an underline, and the online network's description
        """
        sub_networks = []
        if self.global_network:
            sub_networks.append("global network")
        if self.online_network:
            sub_networks.append("online network")
        if self.target_network:
            sub_networks.append("target network")

        result = []
        result.append("Network: {}, Copies: {} ({})".format(self.name, len(sub_networks), ' | '.join(sub_networks)))
        result.append("-"*len(result[-1]))
        result.append(str(self.online_network))
        result.append("")
        return '\n'.join(result)

    def collect_savers(self, parent_path_suffix: str) -> SaverCollection:
        """
        Collect all of network's savers for global or online network
        Note: global, online, and target network are all copies of the same network whose parameters are
            updated at different rates. So we only need to save one of the networks; the one that holds the most
            recent parameters. target network is created for some agents and used for stabilizing training by
            updating parameters from online network at a slower rate. As a result, target network never contains
            the most recent set of parameters. In single-worker training, no global network is created and online
            network contains the most recent parameters. In vertical distributed training with more than one worker,
            global network is updated by all workers and contains the most recent parameters.
            Therefore preference is given to global network if it exists, otherwise online network is used
            for saving.
        :param parent_path_suffix: path suffix of the parent of the network wrapper
            (e.g. could be name of level manager plus name of agent)
        :return: collection of all checkpoint objects
        """
        if self.global_network:
            savers = self.global_network.collect_savers(parent_path_suffix)
        else:
            savers = self.online_network.collect_savers(parent_path_suffix)
        return savers