# Source code for nlp_architect.nn.tensorflow.python.keras.utils.layer_utils

# ******************************************************************************
# Copyright 2017-2018 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
import os
import pickle
import tempfile

from tensorflow import keras


def save_model(model: "keras.models.Model", topology: dict, filepath: str) -> None:
    """
    Save a model to a file (tf.keras models only).

    The model weights are written with ``model.save_weights`` to a temporary
    ``.h5`` file. Its raw bytes are then stored, together with the given
    topology dictionary, in a single pickle file at ``filepath``. The pickled
    object is a dict with the keys ``"model_weights"`` (bytes) and
    ``"model_topology"`` (the ``topology`` argument as given).

    Args:
        model: model object whose ``save_weights`` method is called
        topology (dict): a dictionary of topology elements and their values
        filepath (str): path to save model
    """
    # Save into a private temporary directory and read the weights back by path.
    # Reopening a still-open NamedTemporaryFile by name is not supported on
    # Windows, and reading through the original handle misses the data if the
    # writer replaces the file instead of writing it in place.
    with tempfile.TemporaryDirectory() as tmp_dir:
        # The ".h5" suffix is kept as in the original implementation, presumably
        # so that Keras selects the HDF5 weights format from the extension.
        weights_path = os.path.join(tmp_dir, "weights.h5")
        model.save_weights(weights_path)
        with open(weights_path, "rb") as weights_file:
            model_weights = weights_file.read()
    data = {"model_weights": model_weights, "model_topology": topology}
    with open(filepath, "wb") as fp:
        pickle.dump(data, fp)


def load_model(filepath: str, model) -> None:
    """
    Load a model (tf.keras) from disk, create topology from loaded values and load weights.

    Reads a pickle file containing a dict with the keys ``"model_topology"`` and
    ``"model_weights"``. It calls ``model.build(**topology)`` with the stored
    topology, writes the stored weight bytes to a temporary ``.h5`` file and
    passes that path to ``model.model.load_weights``.

    Warning:
        The file is read with ``pickle.load``, which can execute arbitrary code
        while unpickling. Only load files from trusted sources.

    Args:
        filepath (str): path to model
        model: model object to load. It must provide ``build(**topology)`` and,
            after building, an attribute ``model`` with a ``load_weights(path)``
            method (presumably the underlying tf.keras model; verify against callers).
    """
    # NOTE(security): unpickling untrusted data can execute arbitrary code.
    with open(filepath, "rb") as fp:
        model_data = pickle.load(fp)
    topology = model_data["model_topology"]
    model.build(**topology)
    # Write the weights to a real, closed file before handing its path to
    # load_weights. An open NamedTemporaryFile cannot be reopened by name on Windows.
    with tempfile.TemporaryDirectory() as tmp_dir:
        weights_path = os.path.join(tmp_dir, "weights.h5")
        with open(weights_path, "wb") as weights_file:
            weights_file.write(model_data["model_weights"])
        model.model.load_weights(weights_path)