Source code for rl_coach.filters.observation.observation_normalization_filter

#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import pickle
from typing import List

import numpy as np

from rl_coach.core_types import ObservationType
from rl_coach.filters.observation.observation_filter import ObservationFilter
from rl_coach.spaces import ObservationSpace
from rl_coach.utilities.shared_running_stats import NumpySharedRunningStats, NumpySharedRunningStats


class ObservationNormalizationFilter(ObservationFilter):
    """
    Normalizes the observation values with a running mean and standard deviation of
    all the observations seen so far. The normalization is performed element-wise.
    Additionally, when working with multiple workers, the statistics used for the
    normalization operation are accumulated over all the workers.
    """
    def __init__(self, clip_min: float=-5.0, clip_max: float=5.0, name='observation_stats'):
        """
        :param clip_min: The minimum value to allow after normalizing the observation
        :param clip_max: The maximum value to allow after normalizing the observation
        :param name: The name handed to the underlying running statistics object
        """
        super().__init__()
        self.clip_min = clip_min
        self.clip_max = clip_max
        # The statistics backend is created lazily in set_device(), since the concrete
        # implementation depends on the requested arithmetic mode.
        self.running_observation_stats = None
        self.name = name
        self.supports_batching = True
        self.observation_space = None
        # Snapshots of the statistics taken on the most recent updating call to filter().
        # Declared here so that reading them before the first update does not raise
        # an AttributeError.
        self.last_mean = None
        self.last_stdev = None

    def set_device(self, device, memory_backend_params=None, mode='numpy') -> None:
        """
        An optional function that allows the filter to get the device if it is required
        to use tensorflow ops

        :param device: the device to use
        :param memory_backend_params: if not None, holds params for a memory backend for
                                      sharing data (e.g. Redis)
        :param mode: the arithmetic module to use {'tf' | 'numpy'}
        :return: None
        :raises ValueError: if mode is neither 'tf' nor 'numpy'
        """
        if mode == 'tf':
            # Imported locally so that tensorflow is only required when the 'tf' mode is used.
            from rl_coach.architectures.tensorflow_components.shared_variables import TFSharedRunningStats
            self.running_observation_stats = TFSharedRunningStats(device, name=self.name, create_ops=False,
                                                                  pubsub_params=memory_backend_params)
        elif mode == 'numpy':
            self.running_observation_stats = NumpySharedRunningStats(name=self.name,
                                                                     pubsub_params=memory_backend_params)
        else:
            # Fail here instead of leaving the statistics object unset, which would
            # otherwise surface later as an unrelated AttributeError on None.
            raise ValueError("Unsupported mode {!r} for ObservationNormalizationFilter. "
                             "Expected 'tf' or 'numpy'.".format(mode))

    def set_session(self, sess) -> None:
        """
        An optional function that allows the filter to get the session if it is required
        to use tensorflow ops

        :param sess: the session
        :return: None
        """
        self.running_observation_stats.set_session(sess)

    def filter(self, observations: List[ObservationType], update_internal_state: bool=True) -> ObservationType:
        """
        Normalize the given observations using the running statistics.

        :param observations: the observations to normalize; they are converted to a
                             single numpy array before being processed
        :param update_internal_state: if True, the observations are first pushed into the
                                      running statistics, and last_mean / last_stdev are
                                      refreshed from them
        :return: the result of the running statistics' normalize() for the observations
        """
        observations = np.array(observations)
        if update_internal_state:
            self.running_observation_stats.push(observations)
            self.last_mean = self.running_observation_stats.mean
            self.last_stdev = self.running_observation_stats.std

        return self.running_observation_stats.normalize(observations)

    def get_filtered_observation_space(self, input_observation_space: ObservationSpace) -> ObservationSpace:
        """
        Configure the running statistics for the incoming observation space.

        :param input_observation_space: the observation space entering the filter
        :return: the same observation space object that was passed in
        """
        # The statistics object receives the observation shape together with the
        # clipping range configured on this filter.
        self.running_observation_stats.set_params(shape=input_observation_space.shape,
                                                  clip_values=(self.clip_min, self.clip_max))
        return input_observation_space

    def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
        """
        Delegate saving of the filter state to the running statistics object.

        :param checkpoint_dir: forwarded as-is to the running statistics object
        :param checkpoint_prefix: forwarded as-is to the running statistics object
        """
        self.running_observation_stats.save_state_to_checkpoint(checkpoint_dir, checkpoint_prefix)

    def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
        """
        Delegate restoring of the filter state to the running statistics object.

        :param checkpoint_dir: forwarded as-is to the running statistics object
        :param checkpoint_prefix: forwarded as-is to the running statistics object
        """
        self.running_observation_stats.restore_state_from_checkpoint(checkpoint_dir, checkpoint_prefix)