#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from collections import namedtuple
import copy
import math
from enum import Enum
from random import shuffle
from typing import List, Union, Dict, Any, Type
import numpy as np
from rl_coach.utils import force_list
ActionType = Union[int, float, np.ndarray, List]
GoalType = Union[None, np.ndarray]
ObservationType = np.ndarray
RewardType = Union[int, float, np.ndarray]
StateType = Dict[str, np.ndarray]
class GoalTypes(Enum):
Embedding = 1
EmbeddingChange = 2
Observation = 3
Measurements = 4
Record = namedtuple('Record', ['name', 'label'])
class TimeTypes(Enum):
EpisodeNumber = Record(name='Episode #', label='Episode #')
TrainingIteration = Record(name='Training Iter', label='Training Iteration')
EnvironmentSteps = Record(name='Total steps', label='Total steps (per worker)')
WallClockTime = Record(name='Wall-Clock Time', label='Wall-Clock Time (minutes)')
Epoch = Record(name='Epoch', label='Epoch #')
# step methods
class StepMethod(object):
def __init__(self, num_steps: int):
self._num_steps = self.num_steps = num_steps
@property
def num_steps(self) -> int:
return self._num_steps
@num_steps.setter
def num_steps(self, val: int) -> None:
self._num_steps = val
def __eq__(self, other):
return self.num_steps == other.num_steps
def __truediv__(self, other):
"""
divide this step method with other. If other is an integer, returns an object of the same
type as self. If other is the same type of self, returns an integer. In either case, any
floating point value is rounded up under the assumption that if we are dividing Steps, we
would rather overestimate than underestimate.
"""
if isinstance(other, type(self)):
return math.ceil(self.num_steps / other.num_steps)
elif isinstance(other, int):
return type(self)(math.ceil(self.num_steps / other))
else:
raise TypeError("cannot divide {} by {}".format(type(self), type(other)))
def __rtruediv__(self, other):
"""
divide this step method with other. If other is an integer, returns an object of the same
type as self. If other is the same type of self, returns an integer. In either case, any
floating point value is rounded up under the assumption that if we are dividing Steps, we
would rather overestimate than underestimate.
"""
if isinstance(other, type(self)):
return math.ceil(other.num_steps / self.num_steps)
elif isinstance(other, int):
return type(self)(math.ceil(other / self.num_steps))
else:
raise TypeError("cannot divide {} by {}".format(type(other), type(self)))
class Frames(StepMethod):
def __init__(self, num_steps):
super().__init__(num_steps)
class EnvironmentSteps(StepMethod):
def __init__(self, num_steps):
super().__init__(num_steps)
class EnvironmentEpisodes(StepMethod):
def __init__(self, num_steps):
super().__init__(num_steps)
class TrainingSteps(StepMethod):
def __init__(self, num_steps):
super().__init__(num_steps)
    def __truediv__(self, other):
        if isinstance(other, EnvironmentSteps):
            return math.ceil(self.num_steps / other.num_steps)
        else:
            return super().__truediv__(other)
class Time(StepMethod):
def __init__(self, num_steps):
super().__init__(num_steps)
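# A minimal, illustrative sketch of the step arithmetic defined above (not part
# of the library API; the values are arbitrary). Dividing two steps of the same
# type yields an int, dividing by an int yields a new step object, and both
# round up:
#
#   >>> EnvironmentSteps(10) / EnvironmentSteps(3)
#   4
#   >>> (EnvironmentSteps(10) / 3).num_steps
#   4
#   >>> (10 / EnvironmentSteps(4)).num_steps
#   3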
class PredictionType(object):
pass
class VStateValue(PredictionType):
pass
class QActionStateValue(PredictionType):
pass
class ActionProbabilities(PredictionType):
pass
class Embedding(PredictionType):
pass
class Measurements(PredictionType):
pass
class InputEmbedding(Embedding):
pass
class MiddlewareEmbedding(Embedding):
pass
class InputImageEmbedding(InputEmbedding):
pass
class InputVectorEmbedding(InputEmbedding):
pass
class InputTensorEmbedding(InputEmbedding):
pass
class Middleware_FC_Embedding(MiddlewareEmbedding):
pass
class Middleware_LSTM_Embedding(MiddlewareEmbedding):
pass
PlayingStepsType = Union[EnvironmentSteps, EnvironmentEpisodes, Frames]
# run phases
class RunPhase(Enum):
HEATUP = "Heatup"
TRAIN = "Training"
TEST = "Testing"
UNDEFINED = "Undefined"
# transitions
class Transition(object):
def __init__(self, state: Dict[str, np.ndarray]=None, action: ActionType=None, reward: RewardType=None,
next_state: Dict[str, np.ndarray]=None, game_over: bool=None, info: Dict=None):
"""
A transition is a tuple containing the information of a single step of interaction
between the agent and the environment. The most basic version should contain the following values:
(current state, action, reward, next state, game over)
For imitation learning algorithms, if the reward, next state or game over is not known,
it is sufficient to store the current state and action taken by the expert.
:param state: The current state. Assumed to be a dictionary where the observation
is located at state['observation']
:param action: The current action that was taken
:param reward: The reward received from the environment
:param next_state: The next state of the environment after applying the action.
The next state should be similar to the state in its structure.
:param game_over: A boolean which should be True if the episode terminated after
the execution of the action.
:param info: A dictionary containing any additional information to be stored in the transition
"""
self._state = self.state = state
self._action = self.action = action
self._reward = self.reward = reward
self._n_step_discounted_rewards = self.n_step_discounted_rewards = None
        if next_state is None:
            next_state = state
        self._next_state = self.next_state = next_state
self._game_over = self.game_over = game_over
if info is None:
self.info = {}
else:
self.info = info
def __repr__(self):
return str(self.__dict__)
@property
def state(self):
if self._state is None:
raise Exception("The state was not filled by any of the modules between the environment and the agent")
return self._state
@state.setter
def state(self, val):
self._state = val
@property
def action(self):
if self._action is None:
raise Exception("The action was not filled by any of the modules between the environment and the agent")
return self._action
@action.setter
def action(self, val):
self._action = val
@property
def reward(self):
if self._reward is None:
raise Exception("The reward was not filled by any of the modules between the environment and the agent")
return self._reward
@reward.setter
def reward(self, val):
self._reward = val
@property
def n_step_discounted_rewards(self):
if self._n_step_discounted_rewards is None:
raise Exception("The n_step_discounted_rewards were not filled by any of the modules between the "
"environment and the agent. Make sure that you are using an episodic experience replay.")
return self._n_step_discounted_rewards
@n_step_discounted_rewards.setter
def n_step_discounted_rewards(self, val):
self._n_step_discounted_rewards = val
@property
def game_over(self):
if self._game_over is None:
raise Exception("The done flag was not filled by any of the modules between the environment and the agent")
return self._game_over
@game_over.setter
def game_over(self, val):
self._game_over = val
@property
def next_state(self):
if self._next_state is None:
raise Exception("The next state was not filled by any of the modules between the environment and the agent")
return self._next_state
@next_state.setter
def next_state(self, val):
self._next_state = val
def add_info(self, new_info: Dict[str, Any]) -> None:
if not new_info.keys().isdisjoint(self.info.keys()):
raise ValueError("The new info dictionary can not be appended to the existing info dictionary since there "
"are overlapping keys between the two. old keys: {}, new keys: {}"
.format(self.info.keys(), new_info.keys()))
self.info.update(new_info)
def update_info(self, new_info: Dict[str, Any]) -> None:
self.info.update(new_info)
def __copy__(self):
new_transition = type(self)()
new_transition.__dict__.update(self.__dict__)
new_transition.state = copy.copy(new_transition.state)
new_transition.next_state = copy.copy(new_transition.next_state)
new_transition.info = copy.copy(new_transition.info)
return new_transition
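# A minimal, illustrative sketch of constructing a transition (the observation
# shape is arbitrary). When next_state is omitted it falls back to the current
# state:
#
#   >>> t = Transition(state={'observation': np.zeros(4)}, action=1,
#   ...                reward=0.5, game_over=False)
#   >>> t.next_state is t.state
#   True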
class EnvResponse(object):
def __init__(self, next_state: Dict[str, ObservationType], reward: RewardType, game_over: bool, info: Dict=None,
goal: ObservationType=None):
"""
An env response is a collection containing the information returning from the environment after a single action
has been performed on it.
:param next_state: The new state that the environment has transitioned into. Assumed to be a dictionary where the
observation is located at state['observation']
:param reward: The reward received from the environment
:param game_over: A boolean which should be True if the episode terminated after
the execution of the action.
:param info: any additional info from the environment
:param goal: a goal defined by the environment
"""
self._next_state = self.next_state = next_state
self._reward = self.reward = reward
self._game_over = self.game_over = game_over
self._goal = self.goal = goal
if info is None:
self.info = {}
else:
self.info = info
def __repr__(self):
return str(self.__dict__)
@property
def next_state(self):
return self._next_state
@next_state.setter
def next_state(self, val):
self._next_state = val
@property
def reward(self):
return self._reward
@reward.setter
def reward(self, val):
self._reward = val
@property
def game_over(self):
return self._game_over
@game_over.setter
def game_over(self, val):
self._game_over = val
@property
def goal(self):
return self._goal
@goal.setter
def goal(self, val):
self._goal = val
def add_info(self, info: Dict[str, Any]) -> None:
        if not info.keys().isdisjoint(self.info.keys()):
            raise ValueError("The new info dictionary cannot be appended to the existing info dictionary since "
                             "there are overlapping keys between the two")
self.info.update(info)
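# A minimal, illustrative sketch (shapes and values are arbitrary):
# non-overlapping keys can be appended to the response's info dictionary
# after construction.
#
#   >>> response = EnvResponse(next_state={'observation': np.zeros(2)},
#   ...                        reward=1.0, game_over=False)
#   >>> response.add_info({'lives': 3})
#   >>> response.info['lives']
#   3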
class ActionInfo(object):
"""
Action info is a class that holds an action and various additional information details about it
"""
def __init__(self, action: ActionType, all_action_probabilities: float=0,
action_value: float=0., state_value: float=0., max_action_value: float=None):
"""
:param action: the action
        :param all_action_probabilities: the probability assigned to the chosen action at selection time
:param action_value: the state-action value (Q value) of the action
:param state_value: the state value (V value) of the state where the action was taken
        :param max_action_value: if the action was selected randomly (e.g. for exploration), this is the value of
                                 the action that received the maximum value. If no value is given, the given action
                                 is assumed to be the action with the maximum value
"""
self.action = action
self.all_action_probabilities = all_action_probabilities
self.action_value = action_value
self.state_value = state_value
        if max_action_value is None:
self.max_action_value = action_value
else:
self.max_action_value = max_action_value
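# A minimal, illustrative sketch: when max_action_value is not supplied, it
# defaults to the value of the given action.
#
#   >>> info = ActionInfo(action=2, action_value=1.5)
#   >>> info.max_action_value
#   1.5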
class Batch(object):
"""
A wrapper around a list of transitions that helps extracting batches of parameters from it.
For example, one can extract a list of states corresponding to the list of transitions.
The class uses lazy evaluation in order to return each of the available parameters.
"""
def __init__(self, transitions: List[Transition]):
"""
:param transitions: a list of transitions to extract the batch from
"""
self.transitions = transitions
self._states = {}
self._actions = None
self._rewards = None
self._n_step_discounted_rewards = None
self._game_overs = None
self._next_states = {}
self._goals = None
self._info = {}
    def slice(self, start, end) -> None:
"""
Keep a slice from the batch and discard the rest of the batch
:param start: the start index in the slice
:param end: the end index in the slice
:return: None
"""
self.transitions = self.transitions[start:end]
for k, v in self._states.items():
self._states[k] = v[start:end]
if self._actions is not None:
self._actions = self._actions[start:end]
if self._rewards is not None:
self._rewards = self._rewards[start:end]
if self._n_step_discounted_rewards is not None:
self._n_step_discounted_rewards = self._n_step_discounted_rewards[start:end]
if self._game_overs is not None:
self._game_overs = self._game_overs[start:end]
for k, v in self._next_states.items():
self._next_states[k] = v[start:end]
if self._goals is not None:
self._goals = self._goals[start:end]
for k, v in self._info.items():
self._info[k] = v[start:end]
    def shuffle(self) -> None:
"""
Shuffle all the transitions in the batch
:return: None
"""
batch_order = list(range(self.size))
shuffle(batch_order)
self.transitions = [self.transitions[i] for i in batch_order]
self._states = {}
self._actions = None
self._rewards = None
self._n_step_discounted_rewards = None
self._game_overs = None
self._next_states = {}
self._goals = None
self._info = {}
# This seems to be slower
# for k, v in self._states.items():
# self._states[k] = [v[i] for i in batch_order]
# if self._actions is not None:
# self._actions = [self._actions[i] for i in batch_order]
# if self._rewards is not None:
# self._rewards = [self._rewards[i] for i in batch_order]
# if self._total_returns is not None:
# self._total_returns = [self._total_returns[i] for i in batch_order]
# if self._game_overs is not None:
# self._game_overs = [self._game_overs[i] for i in batch_order]
# for k, v in self._next_states.items():
# self._next_states[k] = [v[i] for i in batch_order]
# if self._goals is not None:
# self._goals = [self._goals[i] for i in batch_order]
# for k, v in self._info.items():
# self._info[k] = [v[i] for i in batch_order]
    def states(self, fetches: List[str], expand_dims=False) -> Dict[str, np.ndarray]:
        """
        Follow the keys in fetches to extract the corresponding items from the states in the batch,
        if these keys were not already extracted before. Return only the values corresponding to those keys.
        :param fetches: the keys of the state dictionary to extract
        :param expand_dims: add an extra dimension to each of the value batches
        :return: a dictionary containing a batch of values corresponding to each of the given fetches keys
        """
current_states = {}
# there are cases (e.g. ddpg) where the state does not contain all the information needed for running
# through the network and this has to be added externally (e.g. ddpg where the action needs to be given in
# addition to the current_state, so that all the inputs of the network will be filled)
for key in set(fetches).intersection(self.transitions[0].state.keys()):
if key not in self._states.keys():
self._states[key] = np.array([np.array(transition.state[key]) for transition in self.transitions])
if expand_dims:
current_states[key] = np.expand_dims(self._states[key], -1)
else:
current_states[key] = self._states[key]
return current_states
    def actions(self, expand_dims=False) -> np.ndarray:
"""
if the actions were not converted to a batch before, extract them to a batch and then return the batch
:param expand_dims: add an extra dimension to the actions batch
:return: a numpy array containing all the actions of the batch
"""
if self._actions is None:
self._actions = np.array([transition.action for transition in self.transitions])
if expand_dims:
return np.expand_dims(self._actions, -1)
return self._actions
    def rewards(self, expand_dims=False) -> np.ndarray:
"""
if the rewards were not converted to a batch before, extract them to a batch and then return the batch
:param expand_dims: add an extra dimension to the rewards batch
:return: a numpy array containing all the rewards of the batch
"""
if self._rewards is None:
self._rewards = np.array([transition.reward for transition in self.transitions])
if expand_dims:
return np.expand_dims(self._rewards, -1)
return self._rewards
    def n_step_discounted_rewards(self, expand_dims=False) -> np.ndarray:
        """
        if the n_step_discounted_rewards were not converted to a batch before, extract them to a batch and then
        return the batch.
        if the n step discounted rewards were not filled, this will raise an exception
        :param expand_dims: add an extra dimension to the n_step_discounted_rewards batch
        :return: a numpy array containing all the n-step discounted reward values of the batch
        """
if self._n_step_discounted_rewards is None:
self._n_step_discounted_rewards = np.array([transition.n_step_discounted_rewards for transition in
self.transitions])
if expand_dims:
return np.expand_dims(self._n_step_discounted_rewards, -1)
return self._n_step_discounted_rewards
    def game_overs(self, expand_dims=False) -> np.ndarray:
"""
if the game_overs were not converted to a batch before, extract them to a batch and then return the batch
:param expand_dims: add an extra dimension to the game_overs batch
:return: a numpy array containing all the game over flags of the batch
"""
if self._game_overs is None:
self._game_overs = np.array([transition.game_over for transition in self.transitions])
if expand_dims:
return np.expand_dims(self._game_overs, -1)
return self._game_overs
    def next_states(self, fetches: List[str], expand_dims=False) -> Dict[str, np.ndarray]:
        """
        Follow the keys in fetches to extract the corresponding items from the next states in the batch,
        if these keys were not already extracted before. Return only the values corresponding to those keys.
        :param fetches: the keys of the state dictionary to extract
        :param expand_dims: add an extra dimension to each of the value batches
        :return: a dictionary containing a batch of values corresponding to each of the given fetches keys
        """
next_states = {}
# there are cases (e.g. ddpg) where the state does not contain all the information needed for running
# through the network and this has to be added externally (e.g. ddpg where the action needs to be given in
# addition to the current_state, so that all the inputs of the network will be filled)
for key in set(fetches).intersection(self.transitions[0].next_state.keys()):
if key not in self._next_states.keys():
self._next_states[key] = np.array(
[np.array(transition.next_state[key]) for transition in self.transitions])
if expand_dims:
next_states[key] = np.expand_dims(self._next_states[key], -1)
else:
next_states[key] = self._next_states[key]
return next_states
    def goals(self, expand_dims=False) -> np.ndarray:
"""
if the goals were not converted to a batch before, extract them to a batch and then return the batch
if the goal was not filled, this will raise an exception
:param expand_dims: add an extra dimension to the goals batch
:return: a numpy array containing all the goals of the batch
"""
if self._goals is None:
self._goals = np.array([transition.goal for transition in self.transitions])
if expand_dims:
return np.expand_dims(self._goals, -1)
return self._goals
    def info_as_list(self, key) -> list:
        """
        get the info values for the given key and store them internally as a list, if they weren't stored before.
        :param key: the key of the info dictionary to extract
        :return: a list containing all the info values of the batch corresponding to the given key
        """
if key not in self._info.keys():
self._info[key] = [transition.info[key] for transition in self.transitions]
return self._info[key]
    def info(self, key, expand_dims=False) -> np.ndarray:
        """
        if the given info dictionary key was not converted to a batch before, extract it to a batch and then return
        the batch. if the key is not part of the keys in the info dictionary, this will raise an exception
        :param key: the key of the info dictionary to extract
        :param expand_dims: add an extra dimension to the info batch
        :return: a numpy array containing all the info values of the batch corresponding to the given key
        """
info_list = self.info_as_list(key)
if expand_dims:
return np.expand_dims(info_list, -1)
return np.array(info_list)
@property
def size(self) -> int:
"""
:return: the size of the batch
"""
return len(self.transitions)
def __getitem__(self, key):
"""
get an item from the transitions list
:param key: index of the transition in the batch
:return: the transition corresponding to the given index
"""
return self.transitions[key]
def __setitem__(self, key, item):
"""
set an item in the transition list
:param key: index of the transition in the batch
:param item: the transition to place in the given index
:return: None
"""
self.transitions[key] = item
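# A minimal, illustrative sketch of the lazy batch extraction above (shapes and
# values are arbitrary): fields are converted to numpy arrays on first access.
#
#   >>> ts = [Transition(state={'observation': np.ones(2)}, action=i,
#   ...                  reward=1.0, game_over=False) for i in range(3)]
#   >>> batch = Batch(ts)
#   >>> batch.size
#   3
#   >>> batch.actions().tolist()
#   [0, 1, 2]
#   >>> batch.states(['observation'])['observation'].shape
#   (3, 2)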
class TotalStepsCounter(object):
    """
    A wrapper around a dictionary that counts the total steps done, per StepMethod type.
    """
def __init__(self):
self.counters = {
EnvironmentEpisodes: 0,
EnvironmentSteps: 0,
TrainingSteps: 0
}
def __getitem__(self, key: Type[StepMethod]) -> int:
"""
get counter value
:param key: counter type
:return: the counter value
"""
return self.counters[key]
    def __setitem__(self, key: Type[StepMethod], item: int) -> None:
        """
        set a counter value
        :param key: counter type
        :param item: an integer representing the new counter value
        :return: None
        """
        self.counters[key] = item
    def __add__(self, other: StepMethod) -> StepMethod:
        return other.__class__(self.counters[other.__class__] + other.num_steps)
    def __lt__(self, other: StepMethod) -> bool:
        return self.counters[other.__class__] < other.num_steps
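# A minimal, illustrative sketch (values are arbitrary): counters are indexed
# by step class, and adding a step object produces a new step object of the
# same type.
#
#   >>> counter = TotalStepsCounter()
#   >>> counter[EnvironmentSteps] = 7
#   >>> (counter + EnvironmentSteps(5)).num_steps
#   12
#   >>> counter < EnvironmentSteps(10)
#   True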
class GradientClippingMethod(Enum):
ClipByGlobalNorm = 0
ClipByNorm = 1
ClipByValue = 2
class Episode(object):
    """
    An Episode represents a sequence of transitions that ends with a terminal state.
    """
def __init__(self, discount: float=0.99, bootstrap_total_return_from_old_policy: bool=False, n_step: int=-1):
"""
:param discount: the discount factor to use when calculating total returns
:param bootstrap_total_return_from_old_policy: should the total return be bootstrapped from the values in the
memory
:param n_step: the number of future steps to sum the reward over before bootstrapping
"""
self.transitions = []
self._length = 0
self.discount = discount
self.bootstrap_total_return_from_old_policy = bootstrap_total_return_from_old_policy
self.n_step = n_step
self.is_complete = False
    def insert(self, transition: Transition) -> None:
"""
Insert a new transition to the episode. If the game_over flag in the transition is set to True,
the episode will be marked as complete.
:param transition: The new transition to insert to the episode
:return: None
"""
self.transitions.append(transition)
self._length += 1
if transition.game_over:
self.is_complete = True
    def is_empty(self) -> bool:
"""
Check if the episode is empty
:return: A boolean value determining if the episode is empty or not
"""
return self.length() == 0
    def length(self) -> int:
"""
Return the length of the episode, which is the number of transitions it holds.
:return: The number of transitions in the episode
"""
return self._length
def __len__(self):
return self.length()
    def get_transition(self, transition_idx: int) -> Transition:
"""
Get a specific transition by its index.
:param transition_idx: The index of the transition to get
:return: The transition which is stored in the given index
"""
return self.transitions[transition_idx]
    def get_last_transition(self) -> Transition:
        """
        Get the last transition in the episode, or None if there are no transitions available
:return: The last transition in the episode
"""
return self.get_transition(-1) if self.length() > 0 else None
    def get_first_transition(self) -> Transition:
"""
Get the first transition in the episode, or None if there are no transitions available
:return: The first transition in the episode
"""
return self.get_transition(0) if self.length() > 0 else None
    def update_discounted_rewards(self):
"""
Update the discounted returns for all the transitions in the episode.
The returns will be calculated according to the rewards of each transition, together with the number of steps
to bootstrap from and the discount factor, as defined by n_step and discount respectively when initializing
the episode.
:return: None
"""
if self.n_step == -1 or self.n_step > self.length():
curr_n_step = self.length()
else:
curr_n_step = self.n_step
rewards = np.array([t.reward for t in self.transitions])
rewards = rewards.astype('float')
discounted_rewards = rewards.copy()
current_discount = self.discount
for i in range(1, curr_n_step):
discounted_rewards += current_discount * np.pad(rewards[i:], (0, i), 'constant', constant_values=0)
current_discount *= self.discount
# calculate the bootstrapped returns
if self.bootstrap_total_return_from_old_policy:
bootstraps = np.array([np.squeeze(t.info['max_action_value']) for t in self.transitions[curr_n_step:]])
bootstrapped_return = discounted_rewards + current_discount * np.pad(bootstraps, (0, curr_n_step),
'constant', constant_values=0)
discounted_rewards = bootstrapped_return
for transition_idx in range(self.length()):
self.transitions[transition_idx].n_step_discounted_rewards = discounted_rewards[transition_idx]
def update_transitions_rewards_and_bootstrap_data(self):
        if not isinstance(self.n_step, int) or (self.n_step < 1 and self.n_step != -1):
            raise ValueError("n_step should be an integer >= 1, or -1 to always use the full episode length.")
elif self.n_step > 1:
curr_n_step = self.n_step if self.n_step < self.length() else self.length()
for idx, transition in enumerate(self.transitions):
next_n_step_transition_idx = (idx + curr_n_step)
if next_n_step_transition_idx < len(self.transitions):
# next state will now point to the n-step next state
transition.next_state = self.transitions[next_n_step_transition_idx].state
transition.info['should_bootstrap_next_state'] = True
else:
transition.next_state = self.transitions[-1].next_state
transition.info['should_bootstrap_next_state'] = False
self.update_discounted_rewards()
    def get_transitions_attribute(self, attribute_name: str) -> List[Any]:
"""
Get the values for some transition attribute from all the transitions in the episode.
For example, this allows getting the rewards for all the transitions as a list by calling
get_transitions_attribute('reward')
:param attribute_name: The name of the attribute to extract from all the transitions
:return: A list of values from all the transitions according to the attribute given in attribute_name
"""
if len(self.transitions) > 0 and hasattr(self.transitions[0], attribute_name):
return [getattr(t, attribute_name) for t in self.transitions]
elif len(self.transitions) == 0:
return []
else:
raise ValueError("The transitions have no such attribute name")
def __getitem__(self, sliced):
return self.transitions[sliced]
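# A minimal, illustrative sketch of the n-step discounting above (values are
# arbitrary): with discount 0.5 and three unit rewards, the first transition's
# discounted return is 1 + 0.5 + 0.25 = 1.75.
#
#   >>> e = Episode(discount=0.5)
#   >>> for _ in range(3):
#   ...     e.insert(Transition(state={'observation': np.zeros(1)}, action=0,
#   ...                         reward=1.0, game_over=False))
#   >>> e.update_discounted_rewards()
#   >>> float(e.get_first_transition().n_step_discounted_rewards)
#   1.75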
"""
Video Dumping Methods
"""
class VideoDumpFilter(object):
    """
    Base class for deciding when to dump videos
    """
    def should_dump(self, episode_terminated=False, **kwargs):
        raise NotImplementedError("should_dump must be implemented by VideoDumpFilter subclasses")
class AlwaysDumpFilter(VideoDumpFilter):
"""
Dump video for every episode
"""
def __init__(self):
super().__init__()
def should_dump(self, episode_terminated=False, **kwargs):
return True
class MaxDumpFilter(VideoDumpFilter):
"""
Dump video every time a new max total reward has been achieved
"""
def __init__(self):
super().__init__()
self.max_reward_achieved = -np.inf
def should_dump(self, episode_terminated=False, **kwargs):
# if the episode has not finished yet we want to be prepared for dumping a video
if not episode_terminated:
return True
if kwargs['total_reward_in_current_episode'] > self.max_reward_achieved:
self.max_reward_achieved = kwargs['total_reward_in_current_episode']
return True
else:
return False
class EveryNEpisodesDumpFilter(VideoDumpFilter):
    """
    Dump videos once in every N episodes
    """
    def __init__(self, num_episodes_between_dumps: int):
        super().__init__()
        if num_episodes_between_dumps < 1:
            raise ValueError("the number of episodes between dumps should be a positive number")
        self.num_episodes_between_dumps = num_episodes_between_dumps
        self.last_dumped_episode = 0
def should_dump(self, episode_terminated=False, **kwargs):
if kwargs['episode_idx'] >= self.last_dumped_episode + self.num_episodes_between_dumps - 1:
self.last_dumped_episode = kwargs['episode_idx']
return True
else:
return False
class SelectedPhaseOnlyDumpFilter(VideoDumpFilter):
"""
Dump videos when the phase of the environment matches a predefined phase
"""
def __init__(self, run_phases: Union[RunPhase, List[RunPhase]]):
self.run_phases = force_list(run_phases)
def should_dump(self, episode_terminated=False, **kwargs):
if kwargs['_phase'] in self.run_phases:
return True
else:
return False
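# A minimal, illustrative sketch of combining dump filters (the keyword
# arguments follow the should_dump signatures above; values are arbitrary):
#
#   >>> filters = [MaxDumpFilter(), EveryNEpisodesDumpFilter(10)]
#   >>> any(f.should_dump(episode_terminated=True,
#   ...                   total_reward_in_current_episode=5.0, episode_idx=9)
#   ...     for f in filters)
#   True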
# TODO move to a NamedTuple, once we move to Python3.6
# https://stackoverflow.com/questions/34269772/type-hints-in-namedtuple/34269877
class CsvDataset(object):
def __init__(self, filepath: str, is_episodic: bool = True):
self.filepath = filepath
self.is_episodic = is_episodic
class PickledReplayBuffer(object):
def __init__(self, filepath: str):
self.filepath = filepath