# Source code for rl_coach.filters.observation.observation_reduction_by_sub_parts_name_filter

#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import copy
from enum import Enum
from typing import List

import numpy as np

from rl_coach.core_types import ObservationType
from rl_coach.filters.observation.observation_filter import ObservationFilter
from rl_coach.spaces import ObservationSpace, VectorObservationSpace


class ObservationReductionBySubPartsNameFilter(ObservationFilter):
    """
    Allows keeping only parts of the observation, by specifying their name.
    This is useful when the environment has a measurements vector as observation which includes several different
    measurements, but you want the agent to only see some of the measurements and not all.
    For example, the CARLA environment extracts multiple measurements that can be used by the agent, such as
    speed and location. If we want to only use the speed, it can be done using this filter.
    This will currently work only for VectorObservationSpace observations
    """
    class ReductionMethod(Enum):
        Keep = 0
        Discard = 1

    def __init__(self, part_names: List[str], reduction_method: ReductionMethod):
        """
        :param part_names: A list of part names to reduce
        :param reduction_method: A reduction method to use - keep or discard the given parts
        """
        super().__init__()
        self.part_names = part_names
        self.reduction_method = reduction_method
        # Both are populated by get_filtered_observation_space(); filter() refuses to run until then.
        self.measurement_names = None
        self.indices_to_keep = None

    def filter(self, observation: ObservationType, update_internal_state: bool=True) -> ObservationType:
        """
        Select the retained entries along the last axis of the observation.

        :param observation: the observation to reduce; must be a numpy array
        :param update_internal_state: accepted for interface compatibility; not used by this filter
        :return: the observation indexed by the retained indices on its last axis
        :raises ValueError: if the observation is not a numpy array, or if get_filtered_observation_space
                            has not been called yet
        """
        if not isinstance(observation, np.ndarray):
            raise ValueError("All the state values are expected to be numpy arrays")
        if self.indices_to_keep is None:
            raise ValueError("To use ObservationReductionBySubPartsNameFilter, the get_filtered_observation_space "
                             "function should be called before filtering an observation")
        observation = observation[..., self.indices_to_keep]
        return observation

    def validate_input_observation_space(self, input_observation_space: ObservationSpace):
        """
        Check that the given observation space is a VectorObservationSpace.

        :param input_observation_space: the observation space to validate
        :raises ValueError: if the space is not a VectorObservationSpace instance
        """
        if not isinstance(input_observation_space, VectorObservationSpace):
            raise ValueError("The ObservationReductionBySubPartsNameFilter support only VectorObservationSpace "
                             "observations. The given observation space was: {}"
                             .format(input_observation_space.__class__))

    def get_filtered_observation_space(self, input_observation_space: VectorObservationSpace) -> ObservationSpace:
        """
        Compute which measurement indices survive the reduction and update the observation space to match.

        The given space is modified in place: the last entry of its shape and its measurements_names are
        both derived from the retained indices, so they always agree with what filter() returns. Names in
        part_names that do not appear in the space's measurements_names are ignored, and the retained
        measurements keep the order they have in the input space.

        :param input_observation_space: the observation space to reduce (modified in place)
        :return: the same observation space object, after the reduction
        :raises ValueError: if the reduction method is neither Keep nor Discard
        """
        self.measurement_names = copy.copy(input_observation_space.measurements_names)

        if self.reduction_method == self.ReductionMethod.Keep:
            keep_listed_names = True
        elif self.reduction_method == self.ReductionMethod.Discard:
            keep_listed_names = False
        else:
            raise ValueError("The given reduction method is not supported")

        # A set gives constant-time membership tests and makes duplicated part names harmless.
        listed_names = set(self.part_names)
        self.indices_to_keep = [idx for idx, name in enumerate(self.measurement_names)
                                if (name in listed_names) == keep_listed_names]

        # Derive the shape and the names from the retained indices rather than from len(part_names):
        # part_names may contain unknown or repeated names, or be ordered differently from the space.
        input_observation_space.shape[-1] = len(self.indices_to_keep)
        input_observation_space.measurements_names = [self.measurement_names[idx]
                                                      for idx in self.indices_to_keep]

        return input_observation_space