tf.keras.mixed_precision.set_global_policy

Sets the global dtype policy.

The global policy is the default tf.keras.mixed_precision.Policy that layers use when no policy is passed to the layer constructor.

>>> tf.keras.mixed_precision.set_global_policy('mixed_float16')
>>> tf.keras.mixed_precision.global_policy()
<Policy "mixed_float16">
>>> tf.keras.layers.Dense(10).dtype_policy
<Policy "mixed_float16">
>>> # Global policy is not used if a policy is directly passed to constructor
>>> tf.keras.layers.Dense(10, dtype='float64').dtype_policy
<Policy "float64">
>>> tf.keras.mixed_precision.set_global_policy('float32')

If no global policy is set, layers will instead default to a Policy constructed from tf.keras.backend.floatx().
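The fallback described above can be checked with a short sketch. It assumes a fresh Python session in which set_global_policy has not been called yet; the variable names are illustrative only.

```python
import tensorflow as tf

# Keras' default float type; 'float32' unless it has been changed.
floatx = tf.keras.backend.floatx()

# With no global policy set, the global policy is built from floatx().
default_policy_name = tf.keras.mixed_precision.global_policy().name

# A layer created without an explicit dtype picks up the same default.
layer_policy_name = tf.keras.layers.Dense(10).dtype_policy.name

print(floatx, default_policy_name, layer_policy_name)
```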

To use mixed precision, set the global policy to 'mixed_float16' or 'mixed_bfloat16', so that every layer uses a 16-bit compute dtype and a float32 variable dtype by default.
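A minimal sketch of what this looks like in practice, using only the policy's compute_dtype and variable_dtype attributes and the layer's dtype_policy. It resets the policy to 'float32' at the end so later code is unaffected. On hardware without float16 acceleration TensorFlow may log a performance warning, but the policy is still applied.

```python
import tensorflow as tf

# Make 'mixed_float16' the default policy for layers created from here on.
tf.keras.mixed_precision.set_global_policy('mixed_float16')

policy = tf.keras.mixed_precision.global_policy()
compute_dtype = policy.compute_dtype    # dtype used for layer computations
variable_dtype = policy.variable_dtype  # dtype used for layer variables

# A new layer inherits the global policy.
layer = tf.keras.layers.Dense(10)
layer_compute_dtype = layer.dtype_policy.compute_dtype

print(compute_dtype, variable_dtype, layer_compute_dtype)

# Restore the float32 policy.
tf.keras.mixed_precision.set_global_policy('float32')
```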

Only floating point policies can be set as the global policy, such as 'float32' and 'mixed_float16'. Non-floating point policies such as 'int32' and 'complex64' cannot be set as the global policy because most layers do not support such policies.

See tf.keras.mixed_precision.Policy for more information.

Args
policy: A Policy, or a string that will be converted to a Policy. Can also be None, in which case the global policy will be constructed from tf.keras.backend.floatx().
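As an illustration of the policy argument, the sketch below passes a Policy instance rather than a string; both forms should select the same policy. The variable names are illustrative, and the final call restores 'float32'.

```python
import tensorflow as tf

# Build the Policy explicitly instead of passing its name as a string.
bf16_policy = tf.keras.mixed_precision.Policy('mixed_bfloat16')
tf.keras.mixed_precision.set_global_policy(bf16_policy)

active_name = tf.keras.mixed_precision.global_policy().name
active_compute_dtype = tf.keras.mixed_precision.global_policy().compute_dtype

print(active_name, active_compute_dtype)

# Restore the float32 policy.
tf.keras.mixed_precision.set_global_policy('float32')
```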

© 2020 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/versions/r2.4/api_docs/python/tf/keras/mixed_precision/set_global_policy