tf.distribute.DistributedValues
Base class for representing distributed values.
tf.distribute.DistributedValues(values)
A subclass instance of tf.distribute.DistributedValues is created when variables are created within a distribution strategy, when a tf.distribute.DistributedDataset is iterated, or as the return value of tf.distribute.Strategy.run. This base class should never be instantiated directly.
tf.distribute.DistributedValues contains one value per replica. Depending on the subclass, the values are synced on update, synced on demand, or never synced.
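The variable case can be sketched as follows. This is a minimal illustration rather than a definitive recipe: it assumes a default tf.distribute.MirroredStrategy() with no explicit device list (so it also runs with a single CPU replica), and pairing the default variable with "synced on update" and the ON_READ variable with "synced on demand" is an interpretation, not something this page states.

```python
import tensorflow as tf

# Assumption: no device list is passed, so the strategy uses whatever devices
# are visible (a single CPU replica on a machine without GPUs).
strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    # Default synchronization: every replica's copy is kept in sync on update.
    mirrored_var = tf.Variable(1.0)
    # ON_READ: each replica updates its own copy; copies are aggregated on read.
    on_read_var = tf.Variable(
        0.0,
        synchronization=tf.VariableSynchronization.ON_READ,
        aggregation=tf.VariableAggregation.SUM,
    )

# Print the concrete subclass names and whether they derive from the base class.
for var in (mirrored_var, on_read_var):
    print(type(var).__name__, isinstance(var, tf.distribute.DistributedValues))

# experimental_local_results returns one component per local replica.
mirrored_components = strategy.experimental_local_results(mirrored_var)
on_read_components = strategy.experimental_local_results(on_read_var)
print(len(mirrored_components), len(on_read_components))
```

The printed class names are internal and may differ between TensorFlow versions, so the sketch prints them instead of depending on them.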
tf.distribute.DistributedValues can be reduced to obtain a single value across replicas, passed as input to tf.distribute.Strategy.run, or inspected per replica using tf.distribute.Strategy.experimental_local_results.
Example usage:
- Created from a tf.distribute.DistributedDataset:
    import tensorflow as tf

    # Requires two visible GPUs.
    strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
    dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
    dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
    # Each element of the distributed dataset holds one value per replica.
    distributed_values = next(dataset_iterator)
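- Returned by tf.distribute.Strategy.run:
    import tensorflow as tf

    # Requires two visible GPUs.
    strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])

    @tf.function
    def run():
        ctx = tf.distribute.get_replica_context()
        return ctx.replica_id_in_sync_group

    # Holds the replica ID returned by each replica.
    distributed_values = strategy.run(run)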
- As input into tf.distribute.Strategy.run:
    import tensorflow as tf

    # Requires two visible GPUs.
    strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
    dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
    dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
    distributed_values = next(dataset_iterator)

    @tf.function
    def run(input):
        return input + 1.0

    # Each replica receives its own component of distributed_values.
    updated_value = strategy.run(run, args=(distributed_values,))
- Reduce value:
    import tensorflow as tf

    # Requires two visible GPUs.
    strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
    dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
    dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
    distributed_values = next(dataset_iterator)
    # Sum over the batch dimension and across replicas.
    reduced_value = strategy.reduce(
        tf.distribute.ReduceOp.SUM, distributed_values, axis=0)
- Inspect local replica values:
    import tensorflow as tf

    # Requires two visible GPUs.
    strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
    dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
    dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
    distributed_values = next(dataset_iterator)
    # Returns a tuple with one entry per local replica.
    per_replica_values = strategy.experimental_local_results(
        distributed_values)
    per_replica_values
    # (<tf.Tensor: shape=(1,), dtype=float32, numpy=array([5.], dtype=float32)>,
    #  <tf.Tensor: shape=(1,), dtype=float32, numpy=array([6.], dtype=float32)>)
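The examples above assume two GPUs. As a hedged sketch for machines with a different device count, the following chains the same three operations (run, inspect, reduce) using a default tf.distribute.MirroredStrategy(). The function name replica_fn and the choice to return the replica ID plus one are illustrative assumptions, not part of the API.

```python
import tensorflow as tf

# Assumption: use all visible devices (a single CPU replica if no GPU exists).
strategy = tf.distribute.MirroredStrategy()
num_replicas = strategy.num_replicas_in_sync


@tf.function
def replica_fn():
    # Hypothetical per-replica computation: replica ID + 1 as a float.
    ctx = tf.distribute.get_replica_context()
    replica_id = tf.cast(ctx.replica_id_in_sync_group, tf.float32)
    return replica_id + 1.0


# Run once on every replica and collect the per-replica results.
per_replica = strategy.run(replica_fn)

# Inspect the values held on each local replica.
local_results = strategy.experimental_local_results(per_replica)
print([float(t) for t in local_results])

# Reduce across replicas; axis=None because each value is a scalar.
total = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica, axis=None)
print(float(total))
```

With N replicas, the per-replica values are 1.0 through N, so the reduced sum should equal N * (N + 1) / 2.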
© 2020 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/versions/r2.4/api_docs/python/tf/distribute/DistributedValues