tf.distribute.InputContext

A class wrapping information needed by an input function.

This is a context class that is passed to the user's input function and contains information about the compute replicas and input pipelines. The number of compute replicas (in sync training) is used to compute the per-replica (local) batch size from the desired global batch size. The input pipeline information can be used to return a different subset of the input in each replica (e.g. by sharding the input pipeline or using a different input source).
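As a rough illustration of the paragraph above, the sketch below shows one way an input function might use the context to shard and batch a dataset. The names `GLOBAL_BATCH_SIZE` and `dataset_fn` and the toy `Dataset.range(16)` source are assumptions made for this example. Normally a `tf.distribute` strategy builds the `InputContext` and passes it to the input function; here it is constructed by hand only so the sketch can run on its own.

```python
import tensorflow as tf

GLOBAL_BATCH_SIZE = 8  # hypothetical value chosen for this sketch


def dataset_fn(input_context):
    """Builds the dataset for one input pipeline (illustrative only)."""
    # Derive the per-replica batch size from the global batch size.
    batch_size = input_context.get_per_replica_batch_size(GLOBAL_BATCH_SIZE)
    dataset = tf.data.Dataset.range(16)  # toy data source
    # Keep only the elements that belong to this input pipeline.
    dataset = dataset.shard(
        input_context.num_input_pipelines,
        input_context.input_pipeline_id,
    )
    return dataset.batch(batch_size)


# Hand-built context: 2 input pipelines, 4 replicas in sync, pipeline 1.
ctx = tf.distribute.InputContext(
    num_input_pipelines=2,
    input_pipeline_id=1,
    num_replicas_in_sync=4,
)
for batch in dataset_fn(ctx):
    print(batch.numpy())
```

Sharding before batching means each pipeline reads only its own elements. Batching first and sharding afterwards is also possible and would distribute whole batches instead.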

Args
num_input_pipelines The number of input pipelines in a cluster.
input_pipeline_id The ID of the current input pipeline; should be an int in [0, num_input_pipelines).
num_replicas_in_sync The number of replicas that are in sync.
Attributes
input_pipeline_id Returns the input pipeline ID.
num_input_pipelines Returns the number of input pipelines.
num_replicas_in_sync Returns the number of compute replicas in sync.

Methods

get_per_replica_batch_size

Returns the per-replica batch size.

Args
global_batch_size The global batch size, which should be divisible by num_replicas_in_sync.
Returns
The per-replica batch size.
Raises
ValueError if global_batch_size is not divisible by num_replicas_in_sync.
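The arithmetic described for this method can be sketched in plain Python. This is not the TensorFlow source; `per_replica_batch_size` is a hypothetical stand-in written only to illustrate the documented behavior (exact division, or a ValueError when the global batch size is not divisible).

```python
def per_replica_batch_size(global_batch_size: int,
                           num_replicas_in_sync: int) -> int:
    """Hypothetical stand-in mirroring the documented contract."""
    if global_batch_size % num_replicas_in_sync != 0:
        raise ValueError(
            f"global_batch_size ({global_batch_size}) must be divisible by "
            f"num_replicas_in_sync ({num_replicas_in_sync})."
        )
    # Every replica receives an equal share of the global batch.
    return global_batch_size // num_replicas_in_sync


print(per_replica_batch_size(64, 8))

try:
    per_replica_batch_size(10, 4)
except ValueError as err:
    print("rejected:", err)
```

With 8 replicas in sync and a global batch size of 64, each replica processes 8 examples per step. A global batch size of 10 across 4 replicas has no equal split, so it is rejected.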

© 2020 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/versions/r2.4/api_docs/python/tf/distribute/InputContext