tf.contrib.distribute.StandardSingleLossStep
A step function that implements a training step for a feed-forward network.
Inherits From: StandardInputStep
```python
tf.contrib.distribute.StandardSingleLossStep(
    dataset_fn, loss_fn, optimizer, distribution, iterations_per_step=1
)
```
An instance of this class is intended to be used as a callable:
```python
...
step = tf.contrib.distribute.StandardSingleLossStep(
    dataset_fn, loss_fn, optimizer, distribution)

# Run a single training step with the DistributionStrategy passed above.
# __call__ takes no arguments.
step()
...
```
| Args | Description |
|---|---|
| `dataset_fn` | a function that returns a `tf.data.Dataset` that produces the input for the model. |
| `loss_fn` | a function that takes a context and inputs as arguments and returns the loss for those inputs. `context` is an instance of `values.MultiStepContext` that is passed in when `loss_fn` is run; it can be used to specify the outputs to be returned from `loss_fn`, among other things. |
| `optimizer` | an optimizer that implements an update rule. |
| `distribution` | a `DistributionStrategy` object. |
| `iterations_per_step` | the number of training iterations run each time the step is called. Defaults to 1. |
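The `loss_fn` contract can be illustrated without TensorFlow. The sketch below is a plain-Python stand-in, not the real API: `ToyContext` and its `set_output` method are hypothetical and only mimic the role the table gives to `values.MultiStepContext`, namely letting `loss_fn` name extra outputs besides the loss. The module-level `weight` plays the part of a model variable.

```python
class ToyContext:
    """Hypothetical stand-in for values.MultiStepContext: collects named outputs."""

    def __init__(self):
        self.outputs = {}

    def set_output(self, name, value):
        # Record a value that loss_fn wants returned in addition to the loss.
        self.outputs[name] = value


weight = 2.0  # a model parameter, playing the role of a variable


def loss_fn(context, inputs):
    """Mean squared error of y_hat = weight * x over one batch of (xs, ys)."""
    xs, ys = inputs
    errors = [(weight * x - y) ** 2 for x, y in zip(xs, ys)]
    # Use the context to expose an extra output alongside the loss.
    context.set_output("max_error", max(errors))
    return sum(errors) / len(errors)


ctx = ToyContext()
loss = loss_fn(ctx, ([1.0, 2.0], [1.0, 2.0]))
print(loss, ctx.outputs)  # 2.5 {'max_error': 4.0}
```

The loss itself is still the return value; the context only carries the additional, named outputs.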
| Attributes | Description |
|---|---|
| `distribution` | the `DistributionStrategy` object passed to the constructor. |
Methods
initialize
initialize()
Returns the initialization ops for the input iterator built from `dataset_fn`.
__call__
__call__()
Perform one step of this training algorithm.
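The calling pattern of these methods can be sketched in plain Python. Everything below is hypothetical: `ToySingleLossStep`, `ToySGD`, the `params` dict, and the finite-difference gradients stand in for TensorFlow variables, a `DistributionStrategy`, and automatic differentiation, and the sketch runs on a single process. It only mirrors the usage: build the step, call `initialize()` once, then call the step with no arguments, assuming each call runs `iterations_per_step` iterations.

```python
class ToySGD:
    """Hypothetical optimizer: params[name] -= learning_rate * grad."""

    def __init__(self, learning_rate):
        self.learning_rate = learning_rate

    def apply_gradients(self, params, grads):
        for name, grad in grads.items():
            params[name] -= self.learning_rate * grad


class ToySingleLossStep:
    """Single-process, plain-Python analogue of StandardSingleLossStep (hypothetical)."""

    def __init__(self, dataset_fn, loss_fn, optimizer, params, iterations_per_step=1):
        self._dataset_fn = dataset_fn
        self._loss_fn = loss_fn  # called as loss_fn(context, inputs)
        self._optimizer = optimizer
        self._params = params  # mutable state that loss_fn reads, like variables
        self._iterations_per_step = iterations_per_step
        self._iterator = None

    def initialize(self):
        # Build the input iterator from dataset_fn.
        self._iterator = iter(self._dataset_fn())

    def _gradients(self, context, inputs, eps=1e-6):
        # Central finite differences stand in for automatic differentiation.
        grads = {}
        for name, value in list(self._params.items()):
            self._params[name] = value + eps
            loss_up = self._loss_fn(context, inputs)
            self._params[name] = value - eps
            loss_down = self._loss_fn(context, inputs)
            self._params[name] = value
            grads[name] = (loss_up - loss_down) / (2 * eps)
        return grads

    def __call__(self):
        # One call runs iterations_per_step iterations of "gradient, then update".
        if self._iterator is None:
            raise RuntimeError("call initialize() before the first step")
        loss = None
        for _ in range(self._iterations_per_step):
            inputs = next(self._iterator)
            context = {}  # hypothetical per-iteration context
            loss = self._loss_fn(context, inputs)
            grads = self._gradients(context, inputs)
            self._optimizer.apply_gradients(self._params, grads)
        return loss  # loss of the last iteration, computed before its update


params = {"w": 0.0}


def loss_fn(context, inputs):
    # Mean squared error of the model y_hat = w * x.
    xs, ys = inputs
    return sum((params["w"] * x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


step = ToySingleLossStep(
    dataset_fn=lambda: [([1.0, 2.0], [3.0, 6.0])] * 100,
    loss_fn=loss_fn,
    optimizer=ToySGD(learning_rate=0.1),
    params=params,
    iterations_per_step=2,
)
step.initialize()
step()  # two iterations: w goes from 0.0 to about 1.5, then to about 2.25
print(params["w"])
```

Here the loss gradient is 5 * (w - 3), so with a learning rate of 0.1 each iteration halves the distance between `w` and 3; calling `step()` again runs two more iterations on the next batches from the same iterator.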
© 2020 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/contrib/distribute/StandardSingleLossStep