tf.contrib.training.create_train_op
Creates an Operation that evaluates the gradients and returns the loss.
tf.contrib.training.create_train_op(total_loss, optimizer, global_step=_USE_GLOBAL_STEP, update_ops=None, variables_to_train=None, transform_grads_fn=None, summarize_gradients=False, gate_gradients=tf_optimizer.Optimizer.GATE_OP, aggregation_method=None, colocate_gradients_with_ops=False, check_numerics=True)
Args | |
---|---|
total_loss | A Tensor representing the total loss. |
optimizer | A tf.compat.v1.train.Optimizer to use for computing the gradients. |
global_step | A Tensor representing the global step variable. If left as _USE_GLOBAL_STEP, then tf.train.get_or_create_global_step() is used. |
update_ops | An optional list of updates to execute. If update_ops is None, the update ops are set to the contents of the tf.GraphKeys.UPDATE_OPS collection. If update_ops is not None but does not contain all of the update ops in tf.GraphKeys.UPDATE_OPS, a warning is displayed. |
variables_to_train | An optional list of variables to train. If None, defaults to all variables returned by tf.compat.v1.trainable_variables(). |
transform_grads_fn | A function that takes a single argument, a list of (gradient, variable) pairs (tuples), performs any requested gradient updates, such as gradient clipping or multipliers, and returns the updated list. |
summarize_gradients | Whether or not to add summaries for each gradient. |
gate_gradients | How to gate the computation of gradients (GATE_NONE, GATE_OP, or GATE_GRAPH). See tf.compat.v1.train.Optimizer. |
aggregation_method | Specifies the method used to combine gradient terms. Valid values are defined in the class AggregationMethod. |
colocate_gradients_with_ops | Whether or not to try colocating the gradients with the ops that generated them. |
check_numerics | Whether or not to apply check_numerics to total_loss, which raises an error if the loss is inf or NaN. |
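
The transform_grads_fn contract described above (take a list of (gradient, variable) pairs, return an updated list) can be sketched as follows. This is a minimal illustration, not part of the library: make_clip_by_value_fn and its clip_op parameter are hypothetical names introduced here so the logic can be exercised without building a graph. In TensorFlow 1.x one would presumably pass tf.clip_by_value as clip_op.

```python
def make_clip_by_value_fn(clip_min, clip_max, clip_op):
    """Build a function usable as transform_grads_fn (hypothetical helper).

    clip_op is an assumed injection point: any callable with the shape
    clip_op(gradient, clip_min, clip_max), for example tf.clip_by_value.
    """
    def transform_grads_fn(grads_and_vars):
        transformed = []
        for grad, var in grads_and_vars:
            # Variables that do not affect the loss have a gradient of None;
            # pass those pairs through unchanged.
            if grad is not None:
                grad = clip_op(grad, clip_min, clip_max)
            transformed.append((grad, var))
        return transformed

    return transform_grads_fn
```

The returned function would then be supplied as transform_grads_fn=make_clip_by_value_fn(-1.0, 1.0, tf.clip_by_value). Keeping None gradients untouched matters because the clipping op cannot be applied to them.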
Returns | |
---|---|
A Tensor that, when evaluated, computes the gradients and returns the total loss value. |
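
A typical call might look like the sketch below, which assumes TensorFlow 1.15 (tf.contrib is not available in 2.x). The toy linear model, the learning rate, and the function name build_and_run are illustrative assumptions, not taken from the reference above. Because the returned Tensor yields the loss, running it once performs one training step and reports that step's loss.

```python
def build_and_run(num_steps=10):
    """Train a one-weight linear model and return the loss at each step."""
    # Imported lazily: tf.contrib exists only in TensorFlow 1.x (e.g. 1.15).
    import tensorflow as tf

    graph = tf.Graph()
    with graph.as_default():
        # Toy data for y = 2 * x (illustrative values).
        x = tf.constant([[1.0], [2.0], [3.0]])
        y = tf.constant([[2.0], [4.0], [6.0]])
        w = tf.Variable([[0.0]], name="w")

        total_loss = tf.reduce_mean(tf.square(tf.matmul(x, w) - y))
        optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.05)

        # All other arguments keep their defaults, including global_step.
        train_op = tf.contrib.training.create_train_op(total_loss, optimizer)

        with tf.Session() as sess:
            # Create the initializer after create_train_op so that any
            # variables it adds to the graph are initialized as well.
            sess.run(tf.global_variables_initializer())
            losses = [sess.run(train_op) for _ in range(num_steps)]
    return losses
```

Each sess.run(train_op) call returns a scalar loss value, so the returned list can be used to monitor training progress.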
© 2020 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/contrib/training/create_train_op