tf.saved_model.load_v2
Load a SavedModel from export_dir.
tf.saved_model.load_v2(export_dir, tags=None)
Signatures associated with the SavedModel are available as functions:
    imported = tf.saved_model.load(path)
    f = imported.signatures["serving_default"]
    print(f(x=tf.constant([[1.]])))
Objects exported with tf.saved_model.save additionally have trackable objects and functions assigned to attributes:
    exported = tf.train.Checkpoint(v=tf.Variable(3.))
    exported.f = tf.function(
        lambda x: exported.v * x,
        input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
    tf.saved_model.save(exported, path)
    imported = tf.saved_model.load(path)
    assert 3. == imported.v.numpy()
    assert 6. == imported.f(x=tf.constant(2.)).numpy()
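The two access paths described above (the signatures mapping and restored attributes) behave slightly differently, which a self-contained round trip may make clearer. The sketch below assumes TensorFlow 2.x with eager execution, where tf.saved_model.load is the loader documented here. The Scaler class, its attribute names, and the temporary export directory are hypothetical and chosen only for illustration.

```python
import tempfile

import tensorflow as tf


class Scaler(tf.Module):
    """Hypothetical module that multiplies its input by a stored variable."""

    def __init__(self):
        super().__init__()
        self.v = tf.Variable(3.0)

    @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.float32, name="x")])
    def f(self, x):
        return self.v * x


export_dir = tempfile.mkdtemp()  # assumed scratch location
exported = Scaler()
# Register f explicitly under the default serving key.
tf.saved_model.save(
    exported, export_dir, signatures={"serving_default": exported.f})

imported = tf.saved_model.load(export_dir)

# Path 1: a signature function takes keyword arguments and returns
# a dictionary of output tensors.
signature_fn = imported.signatures["serving_default"]
signature_outputs = signature_fn(x=tf.constant(2.0))

# Path 2: a restored attribute is called like the original function
# and returns the tensor directly.
attribute_output = imported.f(tf.constant(2.0))

print(sorted(imported.signatures.keys()))
print({name: tensor.numpy() for name, tensor in signature_outputs.items()})
print(attribute_output.numpy())
```

Signature functions are the form that serving systems typically consume, while the attribute path is usually more convenient when the loaded object is used from Python.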
Loading Keras models
Keras models are trackable, so they can be saved to SavedModel. The object returned by tf.saved_model.load is not a Keras object (i.e. it doesn't have .fit, .predict, etc. methods). A few attributes and functions are still available: .variables, .trainable_variables and .__call__.
    model = tf.keras.Model(...)
    tf.saved_model.save(model, path)
    imported = tf.saved_model.load(path)
    outputs = imported(inputs)
Use tf.keras.models.load_model to restore the Keras model.
Importing SavedModels from TensorFlow 1.x
SavedModels from tf.estimator.Estimator or 1.x SavedModel APIs have a flat graph instead of tf.function objects. These SavedModels will have functions corresponding to their signatures in the .signatures attribute, but also have a .prune method which allows you to extract functions for new subgraphs. This is equivalent to importing the SavedModel and naming feeds and fetches in a Session from TensorFlow 1.x.
    imported = tf.saved_model.load(path_to_v1_saved_model)
    pruned = imported.prune("x:0", "out:0")
    pruned(tf.ones([]))
See tf.compat.v1.wrap_function for details. These SavedModels also have a .variables attribute containing imported variables, and a .graph attribute representing the whole imported graph. For SavedModels exported from tf.saved_model.save, variables are instead assigned to whichever attributes they were assigned before export.
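To make the TensorFlow 1.x import path concrete, the following sketch first writes a small flat-graph SavedModel with the tf.compat.v1 APIs and then loads it. It assumes TensorFlow 2.x with eager execution. The tensor names "x" and "out", the variable "w", and the export directory are hypothetical values picked for this illustration, and tf.compat.v1.saved_model.simple_save stands in for whichever 1.x export API produced a real model.

```python
import os
import tempfile

import tensorflow as tf

# Assumed scratch location; a subdirectory that does not exist yet is used
# because the 1.x SavedModel builder may refuse an existing directory.
v1_export_dir = os.path.join(tempfile.mkdtemp(), "v1_saved_model")

# Build and export a flat TF 1.x-style graph computing out = x * w.
with tf.Graph().as_default() as graph:
    x = tf.compat.v1.placeholder(tf.float32, shape=[], name="x")
    w = tf.compat.v1.get_variable("w", initializer=2.0)
    out = tf.multiply(x, w, name="out")
    with tf.compat.v1.Session(graph=graph) as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        tf.compat.v1.saved_model.simple_save(
            sess, v1_export_dir, inputs={"x": x}, outputs={"out": out})

imported = tf.saved_model.load(v1_export_dir)

# The exported signature is available under .signatures.
signature_fn = imported.signatures["serving_default"]
signature_outputs = signature_fn(x=tf.constant(3.0))

# .prune extracts a function for a chosen feed/fetch pair, named the same
# way they would be named in a TensorFlow 1.x Session.
pruned = imported.prune("x:0", "out:0")
pruned_output = pruned(tf.constant(3.0))

print(signature_outputs["out"].numpy())
print(pruned_output)
print([variable.name for variable in imported.variables])
print(imported.graph)
```

Pruning is mainly useful when the tensors of interest are not covered by an exported signature, for example an intermediate activation inside the graph.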
Args | |
---|---|
export_dir | The SavedModel directory to load from. |
tags | A tag or sequence of tags identifying the MetaGraph to load. Optional if the SavedModel contains a single MetaGraph, as for those exported from tf.saved_model.save. |
Returns | |
---|---|
A trackable object with a signatures attribute mapping from signature keys to functions. If the SavedModel was exported by tf.saved_model.save, it also points to trackable objects and functions which were attached to the exported object. |
Raises | |
---|---|
ValueError | If tags don't match a MetaGraph in the SavedModel. |
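
The tags argument and the ValueError listed above can be exercised with a short sketch. It assumes TensorFlow 2.x, where tf.saved_model.save writes a single MetaGraph tagged "serve" (also available as tf.saved_model.SERVING). The export directory and the tag "hypothetical_missing_tag" are made up for this illustration.

```python
import tempfile

import tensorflow as tf

export_dir = tempfile.mkdtemp()  # assumed scratch location
tf.saved_model.save(tf.train.Checkpoint(v=tf.Variable(1.0)), export_dir)

# With a single MetaGraph, tags may be omitted or spelled out explicitly.
imported = tf.saved_model.load(export_dir, tags=[tf.saved_model.SERVING])

# Requesting a tag set that matches no MetaGraph raises ValueError.
try:
    tf.saved_model.load(export_dir, tags=["hypothetical_missing_tag"])
    mismatch_error = None
except ValueError as err:
    mismatch_error = err

print(imported.v.numpy())
print(type(mismatch_error).__name__)
```

Catching the error at load time, as here, may be preferable to letting it surface later when the tag set comes from user configuration.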
© 2020 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/saved_model/load_v2