tf.contrib.opt.extend_with_decoupled_weight_decay
Factory function returning an optimizer class with decoupled weight decay.
tf.contrib.opt.extend_with_decoupled_weight_decay(base_optimizer)
Returns an optimizer class. An instance of the returned class computes the update step of `base_optimizer` and additionally decays the weights. E.g., the class returned by `extend_with_decoupled_weight_decay(tf.compat.v1.train.AdamOptimizer)` is equivalent to `tf.contrib.opt.AdamWOptimizer`.
The API of the new optimizer class slightly differs from the API of the base optimizer:

- The first argument to the constructor is the weight decay rate.
- `minimize` and `apply_gradients` accept the optional keyword argument `decay_var_list`, which specifies the variables that should be decayed. If `None`, all variables that are optimized are decayed.
Usage example:
```python
# MyAdamW is a new class
MyAdamW = extend_with_decoupled_weight_decay(tf.compat.v1.train.AdamOptimizer)
# Create a MyAdamW object
optimizer = MyAdamW(weight_decay=0.001, learning_rate=0.001)
sess.run(optimizer.minimize(loss, decay_var_list=[var1, var2]))
```

Note that this extension decays weights BEFORE applying the update based on the gradient, i.e. this extension only has the desired behaviour for optimizers which do not depend on the value of 'var' in the update step!
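The `decay_var_list` keyword is also accepted by `apply_gradients`. The sketch below is illustrative only; it assumes `loss`, `kernel`, and `bias` already exist in the graph and decays only `kernel`:

```python
import tensorflow as tf

# Assumes `loss`, `kernel`, and `bias` are existing tensors/variables.
MyAdamW = tf.contrib.opt.extend_with_decoupled_weight_decay(
    tf.compat.v1.train.AdamOptimizer)
optimizer = MyAdamW(weight_decay=1e-4, learning_rate=1e-3)

# Compute gradients as usual, then apply them while decaying only `kernel`;
# `bias` is still updated by Adam but receives no weight decay.
grads_and_vars = optimizer.compute_gradients(loss, var_list=[kernel, bias])
train_op = optimizer.apply_gradients(grads_and_vars, decay_var_list=[kernel])
```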
| Args | |
|---|---|
| `base_optimizer` | An optimizer class that inherits from `tf.train.Optimizer`. |
| Returns |
|---|
| A new optimizer class that inherits from `DecoupledWeightDecayExtension` and `base_optimizer`. |
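Since the returned class inherits from both `DecoupledWeightDecayExtension` and the given base optimizer, the factory can be pictured as a mixin composition. A minimal sketch under that assumption (using `tf.contrib.opt.DecoupledWeightDecayExtension`; an illustration, not the library source) follows:

```python
import tensorflow as tf

def make_decoupled_weight_decay_optimizer(base_optimizer):
  """Illustrative factory: mix the decay extension into `base_optimizer`."""

  class OptimizerWithDecoupledWeightDecay(
      tf.contrib.opt.DecoupledWeightDecayExtension, base_optimizer):

    def __init__(self, weight_decay, *args, **kwargs):
      # `weight_decay` is the first constructor argument; everything else
      # (learning_rate, momentum, beta1, ...) is forwarded to the base
      # optimizer's constructor.
      super(OptimizerWithDecoupledWeightDecay, self).__init__(
          weight_decay, *args, **kwargs)

  return OptimizerWithDecoupledWeightDecay

# Usage mirrors the built-in factory:
# SGDW = make_decoupled_weight_decay_optimizer(tf.train.GradientDescentOptimizer)
# optimizer = SGDW(weight_decay=1e-4, learning_rate=0.1)
```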