tf.contrib.opt.ModelAverageCustomGetter
Custom getter class used for model averaging.
tf.contrib.opt.ModelAverageCustomGetter( worker_device )
The custom getter does the following:
- Change trainable variables to the local collection and place them on the worker device
- Generate global variables

Note that the class should be used with tf.compat.v1.train.replica_device_setter, so that the global center variables and the global step variable are placed on the ps device. Also, create variables with 'tf.compat.v1.get_variable' rather than 'tf.Variable': only variables created through get_variable pass through this custom getter.
For example:

    import tensorflow as tf

    # worker_device, cluster, IMAGE_PIXELS and FLAGS are assumed to be
    # defined by the surrounding training program.
    ma_custom_getter = tf.contrib.opt.ModelAverageCustomGetter(worker_device)
    with tf.device(
        tf.compat.v1.train.replica_device_setter(
            worker_device=worker_device,
            ps_device="/job:ps/cpu:0",
            cluster=cluster)), \
         tf.compat.v1.variable_scope('', custom_getter=ma_custom_getter):
        hid_w = tf.compat.v1.get_variable(
            initializer=tf.random.truncated_normal(
                [IMAGE_PIXELS * IMAGE_PIXELS, FLAGS.hidden_units],
                stddev=1.0 / IMAGE_PIXELS),
            name="hid_w")
        hid_b = tf.compat.v1.get_variable(
            initializer=tf.zeros([FLAGS.hidden_units]), name="hid_b")
Args | |
---|---|
worker_device | String. Name of the worker job. |
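The table does not spell out the format of the string. Since the same value is also passed to replica_device_setter in the example, it is presumably a TensorFlow device specification rather than a bare job name. The sketch below builds such a string in plain Python; the helper name worker_device_name and the task index are hypothetical and are not part of the TensorFlow API.

```python
def worker_device_name(task_index, job_name="worker"):
    """Hypothetical helper: build a device string such as '/job:worker/task:0'."""
    return "/job:%s/task:%d" % (job_name, task_index)


# Assumed setup: this process is task 0 of the "worker" job.
worker_device = worker_device_name(0)
```

The resulting string is what would be passed as worker_device, both to ModelAverageCustomGetter and to replica_device_setter.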
Methods
__call__
__call__( getter, name, trainable, collections, *args, **kwargs )
Call self as a function.
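To make the calling convention concrete, here is a toy model in plain Python of how a custom getter with this signature could carry out the two behaviours listed in the class description. It is a sketch under stated assumptions, not TensorFlow's implementation: ToyVariable, ToyGraph, ToyModelAverageGetter, the "center/" name prefix and the device strings are all hypothetical stand-ins.

```python
from collections import defaultdict

# Collection keys for the toy model (illustrative names only).
GLOBAL_VARIABLES = "global_variables"
LOCAL_VARIABLES = "local_variables"


class ToyVariable:
    """Stand-in for a variable: just a name, a trainable flag and a device."""

    def __init__(self, name, trainable, device):
        self.name = name
        self.trainable = trainable
        self.device = device


class ToyGraph:
    """Stand-in for a graph: named collections plus a default (ps) device."""

    def __init__(self, default_device):
        self.default_device = default_device
        self.collections = defaultdict(list)

    def base_getter(self, name, trainable=True, collections=None, device=None):
        # Without an explicit device, the variable lands on the default device.
        var = ToyVariable(name, trainable, device or self.default_device)
        for key in collections or [GLOBAL_VARIABLES]:
            self.collections[key].append(var)
        return var


class ToyModelAverageGetter:
    """Toy custom getter using the documented __call__ signature."""

    def __init__(self, worker_device):
        self.worker_device = worker_device
        self.local_to_global = {}

    def __call__(self, getter, name, trainable, collections, *args, **kwargs):
        if not trainable:
            # Non-trainable variables are passed through unchanged.
            return getter(name, *args, trainable=False,
                          collections=collections, **kwargs)
        # Trainable variable: local collection, pinned to the worker device.
        local_var = getter(name, *args, trainable=True,
                           collections=[LOCAL_VARIABLES],
                           device=self.worker_device, **kwargs)
        # Matching global center variable, left on the default (ps) device.
        global_var = getter("center/" + name, *args, trainable=False,
                            collections=[GLOBAL_VARIABLES], **kwargs)
        self.local_to_global[local_var] = global_var
        return local_var


graph = ToyGraph(default_device="/job:ps/task:0")
ma_getter = ToyModelAverageGetter(worker_device="/job:worker/task:0")
hid_w = ma_getter(graph.base_getter, "hid_w", True, None)
step = ma_getter(graph.base_getter, "global_step", False, None)
```

In this toy run, hid_w ends up in the local collection on the worker device, while its center copy and global_step stay in the global collection on the ps device, which mirrors the placement the class description asks for.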
© 2020 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/contrib/opt/ModelAverageCustomGetter