Source code for rl_coach.data_stores.s3_data_store

#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


from rl_coach.data_stores.data_store import DataStoreParameters
from rl_coach.data_stores.checkpoint_data_store import CheckpointDataStore
from minio import Minio
from minio.error import ResponseError
from configparser import ConfigParser, Error
from rl_coach.checkpoint import CheckpointStateFile
from rl_coach.data_stores.data_store import SyncFiles

import os
import time
import io


class S3DataStoreParameters(DataStoreParameters):
    def __init__(self, ds_params, creds_file: str = None, end_point: str = None, bucket_name: str = None,
                 checkpoint_dir: str = None, expt_dir: str = None):

        super().__init__(ds_params.store_type, ds_params.orchestrator_type, ds_params.orchestrator_params)
        self.creds_file = creds_file
        self.end_point = end_point
        self.bucket_name = bucket_name
        self.checkpoint_dir = checkpoint_dir
        self.expt_dir = expt_dir

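# The credentials file, when provided, is expected to be an AWS-style INI file
# readable by ConfigParser (see S3DataStore.__init__ below), e.g.:
#
#     [default]
#     aws_access_key_id = <your access key>
#     aws_secret_access_key = <your secret key>
#
# Illustrative construction sketch (the values are assumptions for demonstration,
# not defaults shipped with Coach; in a real run the orchestrator supplies them):
#
#     base = DataStoreParameters("s3", "kubernetes", {})
#     params = S3DataStoreParameters(base, end_point="s3.amazonaws.com",
#                                    bucket_name="coach-checkpoints",
#                                    checkpoint_dir="/checkpoint",
#                                    expt_dir="/experiment")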

class S3DataStore(CheckpointDataStore):
    """
    An implementation of the data store using S3 for storing policy checkpoints when using Coach in distributed mode.
    The policy checkpoints are written by the trainer and read by the rollout worker.
    """

    def __init__(self, params: S3DataStoreParameters):
        """
        :param params: The parameters required to use the S3 data store.
        """
        super(S3DataStore, self).__init__(params)
        self.params = params
        access_key = None
        secret_key = None
        if params.creds_file:
            config = ConfigParser()
            config.read(params.creds_file)
            try:
                access_key = config.get('default', 'aws_access_key_id')
                secret_key = config.get('default', 'aws_secret_access_key')
            except Error as e:
                print("Error when reading S3 credentials file: %s" % e)
        else:
            access_key = os.environ.get('ACCESS_KEY_ID')
            secret_key = os.environ.get('SECRET_ACCESS_KEY')
        self.mc = Minio(self.params.end_point, access_key=access_key, secret_key=secret_key)

    def deploy(self) -> bool:
        return True

    def get_info(self):
        return "s3://{}".format(self.params.bucket_name)

    def undeploy(self) -> bool:
        return True

    def save_to_store(self):
        self._save_to_store(self.params.checkpoint_dir)

    def _save_to_store(self, checkpoint_dir):
        """
        save_to_store() uploads the policy checkpoint, gifs and videos to the S3 data store. It reads the checkpoint
        state files and uploads only the latest checkpoint files to S3. It is used by the trainer in Coach when used
        in the distributed mode.
        """
        try:
            # remove lock file if it exists
            self.mc.remove_object(self.params.bucket_name, SyncFiles.LOCKFILE.value)

            # acquire lock
            self.mc.put_object(self.params.bucket_name, SyncFiles.LOCKFILE.value, io.BytesIO(b''), 0)

            state_file = CheckpointStateFile(os.path.abspath(checkpoint_dir))
            if state_file.exists():
                ckpt_state = state_file.read()
                checkpoint_file = None
                for root, dirs, files in os.walk(checkpoint_dir):
                    for filename in files:
                        if filename == CheckpointStateFile.checkpoint_state_filename:
                            # defer uploading the state file until the checkpoint files are in place
                            checkpoint_file = (root, filename)
                            continue
                        if filename.startswith(ckpt_state.name):
                            abs_name = os.path.abspath(os.path.join(root, filename))
                            rel_name = os.path.relpath(abs_name, checkpoint_dir)
                            self.mc.fput_object(self.params.bucket_name, rel_name, abs_name)

                # upload the checkpoint state file last, so readers never see a state
                # that points at files which have not finished uploading
                abs_name = os.path.abspath(os.path.join(checkpoint_file[0], checkpoint_file[1]))
                rel_name = os.path.relpath(abs_name, checkpoint_dir)
                self.mc.fput_object(self.params.bucket_name, rel_name, abs_name)

            # upload Finished if present
            if os.path.exists(os.path.join(checkpoint_dir, SyncFiles.FINISHED.value)):
                self.mc.put_object(self.params.bucket_name, SyncFiles.FINISHED.value, io.BytesIO(b''), 0)

            # upload Ready if present
            if os.path.exists(os.path.join(checkpoint_dir, SyncFiles.TRAINER_READY.value)):
                self.mc.put_object(self.params.bucket_name, SyncFiles.TRAINER_READY.value, io.BytesIO(b''), 0)

            # release lock
            self.mc.remove_object(self.params.bucket_name, SyncFiles.LOCKFILE.value)

            # upload experiment artifacts: csv/json logs, videos and gifs
            if self.params.expt_dir and os.path.exists(self.params.expt_dir):
                for filename in os.listdir(self.params.expt_dir):
                    if filename.endswith((".csv", ".json")):
                        self.mc.fput_object(self.params.bucket_name, filename,
                                            os.path.join(self.params.expt_dir, filename))

            if self.params.expt_dir and os.path.exists(os.path.join(self.params.expt_dir, 'videos')):
                for filename in os.listdir(os.path.join(self.params.expt_dir, 'videos')):
                    self.mc.fput_object(self.params.bucket_name, filename,
                                        os.path.join(self.params.expt_dir, 'videos', filename))

            if self.params.expt_dir and os.path.exists(os.path.join(self.params.expt_dir, 'gifs')):
                for filename in os.listdir(os.path.join(self.params.expt_dir, 'gifs')):
                    self.mc.fput_object(self.params.bucket_name, filename,
                                        os.path.join(self.params.expt_dir, 'gifs', filename))

        except ResponseError as e:
            print("Got exception: %s while saving to S3" % e)

    def load_from_store(self):
        """
        load_from_store() downloads a new checkpoint from the S3 data store when it is not available locally. It is
        used by the rollout workers when using Coach in distributed mode.
        """
        try:
            state_file = CheckpointStateFile(os.path.abspath(self.params.checkpoint_dir))

            # wait until the lock is removed, then fetch the checkpoint state file
            while True:
                objects = self.mc.list_objects_v2(self.params.bucket_name, SyncFiles.LOCKFILE.value)

                if next(objects, None) is None:
                    try:
                        # fetch checkpoint state file from S3
                        self.mc.fget_object(self.params.bucket_name, state_file.filename, state_file.path)
                    except Exception:
                        continue
                    break
                time.sleep(10)

            # check if there's a finished file
            objects = self.mc.list_objects_v2(self.params.bucket_name, SyncFiles.FINISHED.value)
            if next(objects, None) is not None:
                try:
                    self.mc.fget_object(
                        self.params.bucket_name,
                        SyncFiles.FINISHED.value,
                        os.path.abspath(os.path.join(self.params.checkpoint_dir, SyncFiles.FINISHED.value))
                    )
                except Exception:
                    pass

            # check if there's a ready file
            objects = self.mc.list_objects_v2(self.params.bucket_name, SyncFiles.TRAINER_READY.value)
            if next(objects, None) is not None:
                try:
                    self.mc.fget_object(
                        self.params.bucket_name,
                        SyncFiles.TRAINER_READY.value,
                        os.path.abspath(os.path.join(self.params.checkpoint_dir, SyncFiles.TRAINER_READY.value))
                    )
                except Exception:
                    pass

            # download all files belonging to the latest checkpoint that are not already present locally
            checkpoint_state = state_file.read()
            if checkpoint_state is not None:
                objects = self.mc.list_objects_v2(self.params.bucket_name, prefix=checkpoint_state.name,
                                                  recursive=True)
                for obj in objects:
                    filename = os.path.abspath(os.path.join(self.params.checkpoint_dir, obj.object_name))
                    if not os.path.exists(filename):
                        self.mc.fget_object(obj.bucket_name, obj.object_name, filename)

        except ResponseError as e:
            print("Got exception: %s while loading from S3" % e)

    def setup_checkpoint_dir(self, crd=None):
        if crd:
            self._save_to_store(crd)
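# Illustrative end-to-end sketch (not part of the original module): the trainer
# and the rollout workers each construct an S3DataStore from the same parameters
# and communicate purely through the bucket. `params` is the hypothetical
# S3DataStoreParameters instance sketched above.
#
#     store = S3DataStore(params)
#
#     # trainer process: after writing a checkpoint locally, push it to S3
#     store.save_to_store()
#
#     # rollout-worker process: block until the lock object is gone, then pull
#     # the latest checkpoint files into params.checkpoint_dir
#     store.load_from_store()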