tf.estimator.TrainSpec

Configuration for the "train" part of the train_and_evaluate call.

TrainSpec determines the input data for training and how long training runs. Optional hooks run at various stages of training.

Usage:

>>> train_spec = tf.estimator.TrainSpec(
...    input_fn=lambda: 1,
...    max_steps=100,
...    hooks=[_StopAtSecsHook(stop_after_secs=10)],
...    saving_listeners=[_NewCheckpointListenerForEvaluate(None, 20, None)])
>>> train_spec.saving_listeners[0]._eval_throttle_secs
20
>>> train_spec.hooks[0]._stop_after_secs
10
>>> train_spec.max_steps
100
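
The usage above relies on underscore-prefixed helpers, which by Python convention are private. As a minimal sketch that uses only public names, and assuming a TensorFlow 2.x release that still ships tf.estimator (such as the r2.4 release documented here), the following builds an input_fn that returns a tf.data.Dataset of (features, labels) tuples and passes it to TrainSpec. The feature name "x", the toy data, and the batch size are illustrative assumptions.

```python
import tensorflow as tf


def train_input_fn():
    """Return a Dataset that yields (features, labels) minibatches."""
    # Toy data (assumption): four examples with one float feature named "x".
    features = {"x": [[1.0], [2.0], [3.0], [4.0]]}
    labels = [[0], [1], [0], [1]]
    dataset = tf.data.Dataset.from_tensor_slices((features, labels))
    # Repeat indefinitely so that max_steps, not the end of the data,
    # decides when training stops.
    return dataset.shuffle(4).repeat().batch(2)


# Only input_fn is required; max_steps caps the total number of steps.
train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=100)
print(train_spec.max_steps)
```

Repeating the dataset is a design choice that fits the note under max_steps: the training input_fn is not expected to run out of data.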
Args
input_fn A function that provides input data for training as minibatches. See Premade Estimators for more information. The function should construct and return one of the following:
  • A tf.data.Dataset object: the outputs of the Dataset must be a tuple (features, labels) with the same constraints as below.
  • A tuple (features, labels), where features is a Tensor or a dictionary mapping string feature names to Tensors, and labels is a Tensor or a dictionary mapping string label names to Tensors.
max_steps Int. Positive total number of steps for which to train the model. If None, train forever. The training input_fn is not expected to generate OutOfRangeError or StopIteration exceptions. See the train_and_evaluate stop condition section for details.
hooks Iterable of tf.estimator.SessionRunHook objects to run on all workers (including the chief) during training.
saving_listeners Iterable of tf.estimator.CheckpointSaverListener objects to run on the chief during training.
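
To illustrate the max_steps argument, the pure-Python sketch below treats it as a cap on the total step count rather than a per-call increment, so a run resumed from a checkpoint only trains for the steps that are left. The helper steps_remaining is hypothetical and models the arithmetic only; it is not part of TensorFlow.

```python
from typing import Optional


def steps_remaining(global_step: int, max_steps: Optional[int]) -> Optional[int]:
    """Hypothetical helper: number of steps left before training stops.

    Returns None when max_steps is None, meaning there is no step limit.
    """
    if max_steps is None:
        return None
    # max_steps counts total steps, so steps already taken are subtracted.
    return max(max_steps - global_step, 0)


print(steps_remaining(0, 100))    # fresh run
print(steps_remaining(60, 100))   # resumed from a checkpoint at step 60
print(steps_remaining(100, 100))  # limit already reached
print(steps_remaining(60, None))  # no step limit
```

The four calls print 100, 40, 0, and None respectively.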
Raises
ValueError If any of the input arguments is invalid.
TypeError If any of the arguments is not of the expected type.
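
As a rough sketch of how the two error types could be told apart, the pure-Python pre-check below raises TypeError for an argument of the wrong type and ValueError for a value outside the documented range. The function check_train_spec_args is hypothetical, and which condition maps to which exception is an assumption based on the descriptions above, not the library's implementation.

```python
def check_train_spec_args(input_fn, max_steps=None):
    """Hypothetical pre-check that mirrors the documented error types."""
    # Wrong type: input_fn must be a function (any callable).
    if not callable(input_fn):
        raise TypeError(f"input_fn must be callable, given: {input_fn!r}")
    # Invalid value: max_steps must be positive, or None for no limit.
    if max_steps is not None and max_steps <= 0:
        raise ValueError(
            f"max_steps must be positive or None, given: {max_steps!r}")


check_train_spec_args(lambda: None, max_steps=100)  # passes silently

try:
    check_train_spec_args(lambda: None, max_steps=0)
except ValueError as err:
    print("ValueError:", err)
```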
Attributes
input_fn
max_steps
hooks
saving_listeners

© 2020 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/versions/r2.4/api_docs/python/tf/estimator/TrainSpec