tf.custom_gradient
Decorator to define a function with a custom gradient.
    tf.custom_gradient(
        f=None
    )
This decorator allows fine-grained control over the gradients of a sequence of operations. This may be useful for multiple reasons, including providing a more efficient or numerically stable gradient for a sequence of operations.
For example, consider the following function that commonly occurs in the computation of cross entropy and log likelihoods:
    def log1pexp(x):
      return tf.math.log(1 + tf.exp(x))
Due to numerical instability, the gradient of this function evaluated at x=100 is NaN. For example:
    x = tf.constant(100.)
    y = log1pexp(x)
    dy = tf.gradients(y, x)  # Will be NaN when evaluated.
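The same NaN shows up under eager execution; a minimal sketch using tf.GradientTape (tf.gradients itself can only be called inside a graph context):

    import tensorflow as tf

    x = tf.constant(100.)
    with tf.GradientTape() as tape:
      tape.watch(x)
      y = log1pexp(x)  # uses the log1pexp defined above
    # exp(100) overflows to inf, so the naive gradient exp(x) / (1 + exp(x))
    # evaluates as inf / inf = nan.
    print(tape.gradient(y, x))  # tf.Tensor(nan, shape=(), dtype=float32)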
The gradient expression can be analytically simplified to provide numerical stability:
    @tf.custom_gradient
    def log1pexp(x):
      e = tf.exp(x)
      def grad(dy):
        return dy * (1 - 1 / (1 + e))
      return tf.math.log(1 + e), grad
With this definition, the gradient at x=100 will be correctly evaluated as 1.0.
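A quick check (a minimal sketch, assuming the decorated log1pexp above is in scope):

    x = tf.constant(100.)
    with tf.GradientTape() as tape:
      tape.watch(x)
      y = log1pexp(x)
    # e = exp(100) is inf, but 1 - 1 / (1 + e) evaluates cleanly to 1.0.
    print(tape.gradient(y, x))  # tf.Tensor(1.0, shape=(), dtype=float32)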
The variable dy is the upstream gradient, i.e. the gradient accumulated from all the layers or functions applied after this one. By the chain rule,

dy/dx = dy/dx_0 * dx_0/dx_1 * ... * dx_i/dx_{i+1} * ... * dx_n/dx

where the current function contributes the factor dx_i/dx_{i+1} = (1 - 1 / (1 + e)). The upstream gradient dy is the partial product dy/dx_0 * dx_0/dx_1 * ... * dx_{i-1}/dx_i already accumulated by back-propagation; it is multiplied by the current gradient and the result is passed downstream.
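To make the role of the upstream gradient concrete, here is a small sketch (the name double and the constants are illustrative only): composing the custom op with a later multiplication means grad receives that multiplication's derivative as dy.

    @tf.custom_gradient
    def double(x):
      def grad(dy):
        # dy is the upstream gradient; for y = 3.0 * double(x) below it is 3.0.
        return dy * 2.0  # upstream gradient times the local gradient d(2x)/dx = 2
      return 2.0 * x, grad

    x = tf.constant(5.0)
    with tf.GradientTape() as tape:
      tape.watch(x)
      y = 3.0 * double(x)
    print(tape.gradient(y, x))  # tf.Tensor(6.0, shape=(), dtype=float32): 3.0 * 2.0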
If the function takes multiple variables as input, the grad function must return the same number of gradients, one for each input. Take the function z = x * y as an example:
    @tf.custom_gradient
    def bar(x, y):
      def grad(upstream):
        dz_dx = y
        dz_dy = x
        return upstream * dz_dx, upstream * dz_dy
      z = x * y
      return z, grad

    x = tf.constant(2.0, dtype=tf.float32)
    y = tf.constant(3.0, dtype=tf.float32)
    with tf.GradientTape(persistent=True) as tape:
      tape.watch(x)
      tape.watch(y)
      z = bar(x, y)

    z                    # <tf.Tensor: shape=(), dtype=float32, numpy=6.0>
    tape.gradient(z, x)  # <tf.Tensor: shape=(), dtype=float32, numpy=3.0>
    tape.gradient(z, y)  # <tf.Tensor: shape=(), dtype=float32, numpy=2.0>
Nesting custom gradients can lead to unintuitive results. The default behavior does not correspond to n-th order derivatives. For example:
    @tf.custom_gradient
    def op(x):
      y = op1(x)
      @tf.custom_gradient
      def grad_fn(dy):
        gdy = op2(x, y, dy)
        def grad_grad_fn(ddy):  # Not the 2nd order gradient of op w.r.t. x.
          return op3(x, y, dy, ddy)
        return gdy, grad_grad_fn
      return y, grad_fn
The function grad_grad_fn will compute the first-order gradient of grad_fn with respect to dy, which is used to generate forward-mode gradient graphs from backward-mode gradient graphs, but is not the same as the second-order gradient of op with respect to x.
Instead, wrap nested @tf.custom_gradients in another function:
    @tf.custom_gradient
    def op_with_fused_backprop(x):
      y, x_grad = fused_op(x)
      def first_order_gradient(dy):
        @tf.custom_gradient
        def first_order_custom(unused_x):
          def second_order_and_transpose(ddy):
            return second_order_for_x(...), gradient_wrt_dy(...)
          return x_grad, second_order_and_transpose
        return dy * first_order_custom(x)
      return y, first_order_gradient
Additional arguments to the inner @tf.custom_gradient-decorated function control the expected return values of the innermost function.
See also tf.RegisterGradient, which registers a gradient function for a primitive TensorFlow operation. tf.custom_gradient, on the other hand, allows fine-grained control over the gradient computation of a sequence of operations.

Note that if the decorated function uses Variables, the enclosing variable scope must be using ResourceVariables.
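When the decorated function does close over variables, grad_fn must additionally accept a variables keyword argument and return the variable gradients as a second result (see Args below). A minimal sketch, with a hypothetical scale_by_v op:

    v = tf.Variable(2.0)

    @tf.custom_gradient
    def scale_by_v(x):
      y = v * x
      def grad(dy, variables=None):
        # variables is the list of Variables read in the forward pass, here [v].
        grad_x = dy * v        # gradient with respect to the input x
        grad_vars = [dy * x]   # one gradient Tensor per variable in `variables`
        return grad_x, grad_vars
      return y, grad

    x = tf.constant(3.0)
    with tf.GradientTape() as tape:
      tape.watch(x)
      y = scale_by_v(x)
    dx, dv = tape.gradient(y, [x, v])  # dx = 2.0, dv = 3.0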
Args | |
---|---|
f | A function f(*x) that returns a tuple (y, grad_fn), where: x is a sequence of (nested structures of) Tensor inputs to the function; y is a (nested structure of) Tensor outputs of applying TensorFlow operations in f to x; and grad_fn is a function with the signature g(*grad_ys) that returns a list of Tensors the same size as (flattened) x, holding the derivatives of the Tensors in y with respect to the Tensors in x. grad_ys is a sequence of Tensors the same size as (flattened) y, holding the initial value gradients for each Tensor in y. In a pure mathematical sense, a vector-argument vector-valued function f's derivative should be its Jacobian matrix J; grad_fn instead expresses how J transforms a vector grad_ys when left-multiplied with it (grad_ys * J, the vector-Jacobian product), a representation that is convenient for chain-rule calculation such as back-propagation. If f uses Variables that are not part of its inputs, grad_fn should instead have the signature g(*grad_ys, variables=None), where variables is the list of those Variables, and return a 2-tuple (grad_xs, grad_vars), with grad_vars containing one gradient Tensor per variable in variables. |
Returns | |
---|---|
A function h(x) which returns the same value as f(x)[0] and whose gradient (as calculated by tf.gradients) is determined by f(x)[1]. |
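As noted in the Args description, grad_fn expresses the Jacobian J as a vector-Jacobian product. For instance, for an elementwise square the Jacobian is diag(2*x), so grad_ys * J reduces to an elementwise multiply; a small sketch:

    @tf.custom_gradient
    def square(x):
      def grad(upstream):
        # Vector-Jacobian product: upstream left-multiplied by diag(2*x).
        return upstream * 2.0 * x
      return x * x, grad

    x = tf.constant([1.0, 2.0, 3.0])
    with tf.GradientTape() as tape:
      tape.watch(x)
      y = tf.reduce_sum(square(x))
    print(tape.gradient(y, x))  # tf.Tensor([2. 4. 6.], shape=(3,), dtype=float32)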