tf.clip_by_norm
Clips tensor values to a maximum L2-norm.
tf.clip_by_norm( t, clip_norm, axes=None, name=None )
Given a tensor t and a maximum clip value clip_norm, this operation normalizes t so that its L2-norm is less than or equal to clip_norm along the dimensions given in axes. Specifically, in the default case where all dimensions are used for calculation, if the L2-norm of t is already less than or equal to clip_norm, then t is not modified. If the L2-norm is greater than clip_norm, then this operation returns a tensor of the same type and shape as t with its values set to:
t * clip_norm / l2norm(t)
In this case, the L2-norm of the output tensor is clip_norm.
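The default-case rule can be checked by hand. The following is a minimal pure-Python sketch of the formula above, not TensorFlow's implementation; the helper names l2_norm and clip_by_norm_reference are hypothetical and operate on flat lists of floats.

```python
import math


def l2_norm(values):
    """L2-norm of a flat list of numbers."""
    return math.sqrt(sum(v * v for v in values))


def clip_by_norm_reference(values, clip_norm):
    """Scale values so their L2-norm is at most clip_norm (all elements used)."""
    norm = l2_norm(values)
    if norm <= clip_norm:
        # Already within the limit: return the values unchanged.
        return list(values)
    # Otherwise rescale: t * clip_norm / l2norm(t).
    return [v * clip_norm / norm for v in values]


clipped = clip_by_norm_reference([3.0, 4.0], 2.5)     # norm 5.0 exceeds 2.5
unchanged = clip_by_norm_reference([3.0, 4.0], 10.0)  # norm 5.0 is within 10.0
```

Here clipped is [1.5, 2.0], whose L2-norm equals the clip value 2.5, while unchanged is still [3.0, 4.0].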
As another example, if t is a matrix and axes == [1], then each row of the output will have L2-norm less than or equal to clip_norm. If axes == [0] instead, each column of the output will be clipped.
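To make the row and column cases concrete, here is a small pure-Python sketch of per-row and per-column clipping on a list-of-lists matrix. It illustrates the behavior described above and is not TensorFlow's implementation; clip_rows and clip_columns are hypothetical helper names.

```python
import math


def clip_rows(matrix, clip_norm):
    """Clip each row of a list-of-lists matrix to L2-norm <= clip_norm."""
    result = []
    for row in matrix:
        norm = math.sqrt(sum(v * v for v in row))
        # Only rows whose norm exceeds the limit are rescaled.
        scale = clip_norm / norm if norm > clip_norm else 1.0
        result.append([v * scale for v in row])
    return result


def clip_columns(matrix, clip_norm):
    """Clip each column by transposing, clipping rows, and transposing back."""
    transposed = [list(col) for col in zip(*matrix)]
    clipped = clip_rows(transposed, clip_norm)
    return [list(row) for row in zip(*clipped)]


m = [[3.0, 4.0],
     [0.0, 0.5]]
by_rows = clip_rows(m, 1.0)        # analogous to axes == [1]
by_columns = clip_columns(m, 1.0)  # analogous to axes == [0]
```

With a clip value of 1.0, the first row (norm 5.0) is rescaled to roughly [0.6, 0.8], and the second row (norm 0.5) is left as is. In the column case, both columns exceed the limit and are rescaled separately.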
Code example:
import tensorflow as tf

some_nums = tf.constant([[1, 2, 3, 4, 5]], dtype=tf.float32)
tf.clip_by_norm(some_nums, 2.0).numpy()
# array([[0.26967996, 0.5393599 , 0.80903983, 1.0787199 , 1.3483998 ]], dtype=float32)
This operation is typically used to clip gradients before applying them with an optimizer. A model's gradients are usually a collection of differently shaped tensors, one for each part of the model, so a common usage is:
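# Get your gradients after training (grad is a user-defined helper)
loss_value, grads = grad(model, features, labels)

# Apply some clipping
grads = [tf.clip_by_norm(g, norm) for g in grads]

# Continue on with training; apply_gradients expects (gradient, variable) pairs,
# assuming here that model is a Keras model
optimizer.apply_gradients(zip(grads, model.trainable_variables))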
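The snippet above relies on names defined elsewhere (grad, model, features, labels, norm, optimizer). As a self-contained sketch of the same pattern, the following uses tf.GradientTape with a hypothetical two-variable toy model. The variable values, max_norm, and learning_rate are illustrative assumptions, and a plain gradient-descent update stands in for the optimizer step.

```python
import tensorflow as tf

# Hypothetical toy model: one weight matrix and one bias vector.
w = tf.Variable([[1.0], [2.0]])
b = tf.Variable([0.5])
variables = [w, b]

features = tf.constant([[3.0, 4.0]])
labels = tf.constant([[0.0]])

with tf.GradientTape() as tape:
    predictions = tf.matmul(features, w) + b
    loss = tf.reduce_mean(tf.square(predictions - labels))

# One gradient tensor per variable, each with that variable's shape.
grads = tape.gradient(loss, variables)

# Clip each gradient tensor independently against the same limit.
max_norm = 1.0  # illustrative value
clipped = [tf.clip_by_norm(g, max_norm) for g in grads]

# Stand-in for an optimizer step: plain gradient descent.
learning_rate = 0.1  # illustrative value
for var, g in zip(variables, clipped):
    var.assign_sub(learning_rate * g)
```

Because the list comprehension calls tf.clip_by_norm once per tensor, each gradient is limited by its own L2-norm; the norms are not combined across tensors.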
Args | |
---|---|
t | A Tensor or IndexedSlices . This must be a floating point type. |
clip_norm | A 0-D (scalar) Tensor > 0. A maximum clipping value, also floating point. |
axes | A 1-D (vector) Tensor of type int32 containing the dimensions to use for computing the L2-norm. If None (the default), uses all dimensions. |
name | A name for the operation (optional). |
Returns | |
---|---|
A clipped Tensor or IndexedSlices . |
Raises | |
---|---|
ValueError | If the clip_norm tensor is not a 0-D scalar tensor. |
TypeError | If dtype of the input is not a floating point or complex type. |
© 2020 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/versions/r2.4/api_docs/python/tf/clip_by_norm