# Source code for rl_coach.filters.action.partial_discrete_action_space_map

#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import List

import numpy as np

from rl_coach.core_types import ActionType
from rl_coach.filters.action.action_filter import ActionFilter
from rl_coach.spaces import DiscreteActionSpace, ActionSpace


class PartialDiscreteActionSpaceMap(ActionFilter):
    """
    Partial map of two countable action spaces. For example, consider an environment
    with a MultiSelect action space (select multiple actions at the same time, such as jump and go right),
    with 8 actual MultiSelect actions. If we want the agent to be able to select only 5 of those actions by their
    index (0-4), we can map a discrete action space with 5 actions into the 5 selected MultiSelect actions.
    This will both allow the agent to use regular discrete actions, and mask 3 of the actions from the agent.
    """
    def __init__(self, target_actions: List[ActionType]=None, descriptions: List[str]=None):
        """
        :param target_actions: A partial list of actions from the target space to map to.
                               The i-th entry is the target action selected by discrete action i.
        :param descriptions: a list of descriptions of each of the actions
        """
        self.target_actions = target_actions
        self.descriptions = descriptions
        super().__init__()

    def validate_output_action_space(self, output_action_space: ActionSpace):
        """
        Check that the configured target actions can be used with the given output action space.

        :param output_action_space: the action space of the environment (the space this filter maps into)
        :raises ValueError: if no target actions were set, or if any target action is not
                            contained in output_action_space
        """
        if not self.target_actions:
            raise ValueError("The target actions were not set")
        for v in self.target_actions:
            if not output_action_space.contains(v):
                raise ValueError("The values in the output actions ({}) do not match the output action "
                                 "space definition ({})".format(v, output_action_space))

    def get_unfiltered_action_space(self, output_action_space: ActionSpace) -> DiscreteActionSpace:
        """
        Build the discrete action space exposed to the agent.

        The space has one discrete action per target action, and it records output_action_space
        as the space it is filtered into. Both spaces are stored on the filter.

        :param output_action_space: the action space of the environment
        :return: a DiscreteActionSpace with len(self.target_actions) actions
        """
        self.output_action_space = output_action_space
        self.input_action_space = DiscreteActionSpace(len(self.target_actions), self.descriptions,
                                                      filtered_action_space=output_action_space)
        return self.input_action_space

    def filter(self, action: ActionType) -> ActionType:
        """
        Map a discrete action index chosen by the agent to its target action.

        :param action: index into the list of target actions
        :return: the corresponding target action
        """
        return self.target_actions[action]

    def reverse_filter(self, action: ActionType) -> ActionType:
        """
        Map a target-space action back to the discrete index that selects it.

        Comparison is element-wise (with numpy broadcasting), so the action and the target
        actions may be numpy arrays, numpy scalars, plain Python ints or lists.

        :param action: an action from the output (target) action space
        :return: the index of the first target action equal to the given action
        :raises ValueError: if the action is not one of the target actions
        """
        # np.asarray lets plain Python ints/lists compare element-wise too. A bare
        # `(a == b).all()` fails on builtin bools, which have no `.all()`.
        for index, target_action in enumerate(self.target_actions):
            if np.all(np.asarray(action) == np.asarray(target_action)):
                return index
        raise ValueError("The action {} is not one of the target actions of this filter ({})"
                         .format(action, self.target_actions))