#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import List
from rl_coach.core_types import ActionType
from rl_coach.filters.action.action_filter import ActionFilter
from rl_coach.spaces import DiscreteActionSpace, ActionSpace
class PartialDiscreteActionSpaceMap(ActionFilter):
    """
    Partial map of two countable action spaces. For example, consider an environment
    with a MultiSelect action space (select multiple actions at the same time, such as jump and go right), with 8 actual
    MultiSelect actions. If we want the agent to be able to select only 5 of those actions by their index (0-4), we can
    map a discrete action space with 5 actions into the 5 selected MultiSelect actions. This will both allow the agent to
    use regular discrete actions, and mask 3 of the actions from the agent.
    """
    def __init__(self, target_actions: List[ActionType]=None, descriptions: List[str]=None):
        """
        :param target_actions: A partial list of actions from the target space to map to.
        :param descriptions: a list of descriptions of each of the actions
        """
        self.target_actions = target_actions
        self.descriptions = descriptions
        super().__init__()

    def validate_output_action_space(self, output_action_space: ActionSpace):
        """
        Check that every target action is a member of the given output action space.

        :param output_action_space: the action space the target actions must belong to
        :raises ValueError: if no target actions were set, or if one of them is not contained
                            in output_action_space
        """
        if not self.target_actions:
            raise ValueError("The target actions were not set")
        for v in self.target_actions:
            if not output_action_space.contains(v):
                raise ValueError("The values in the output actions ({}) do not match the output action "
                                 "space definition ({})".format(v, output_action_space))

    def get_unfiltered_action_space(self, output_action_space: ActionSpace) -> DiscreteActionSpace:
        """
        Build the discrete action space seen on the input side of this filter. It has one discrete
        action per target action. Both the given output space and the created input space are
        stored on the filter.

        :param output_action_space: the action space that the target actions belong to
        :return: a discrete action space with len(target_actions) actions
        :raises ValueError: if the target actions were not set
        """
        # Fail with the same explicit error as validate_output_action_space instead of the
        # opaque TypeError that len(None) would produce below.
        if self.target_actions is None:
            raise ValueError("The target actions were not set")
        self.output_action_space = output_action_space
        self.input_action_space = DiscreteActionSpace(len(self.target_actions), self.descriptions,
                                                      filtered_action_space=output_action_space)
        return self.input_action_space

    def filter(self, action: ActionType) -> ActionType:
        """
        Map a discrete action index to the target action stored at that index.

        :param action: an index into target_actions
        :return: the target action at that index
        """
        return self.target_actions[action]

    def reverse_filter(self, action: ActionType) -> ActionType:
        """
        Map a target action back to its index in target_actions.

        :param action: an action equal to one of the target actions
        :return: the index of the first target action that equals the given action
        :raises ValueError: if the action does not equal any of the target actions
        """
        for index, target in enumerate(self.target_actions):
            matches = action == target
            # Array-like comparisons are element-wise and must be reduced with all(). Plain Python
            # scalars, lists and tuples already compare to a single bool, which has no all().
            reduce_all = getattr(matches, "all", None)
            if reduce_all is not None:
                matches = reduce_all()
            if matches:
                return index
        raise ValueError("The action ({}) does not match any of the target actions".format(action))