tf.distribute.experimental.partitioners.Partitioner

Partitioner base class: all partitioners inherit from this class.

Partitioners should implement a __call__ method with the following signature:

def __call__(self, shape, dtype, axis=0):
  # Partitions the given `shape` and returns the partition results.
  # See docstring of `__call__` method for the format of partition results.
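
To make the contract concrete, here is a minimal sketch of a partitioner that follows the signature above. The class name `MaxSliceSizePartitioner` and its `max_slice_size` parameter are hypothetical and not part of TensorFlow. So that the sketch runs without TensorFlow installed, it does not subclass `Partitioner`, it assumes `shape` behaves like a sequence of ints, and it passes a plain tuple and a string in place of a tf.TensorShape and a dtype; real code would subclass the base class and receive the TensorFlow types.

```python
import math


class MaxSliceSizePartitioner:
    """Hypothetical partitioner: caps the number of entries per shard along `axis`."""

    def __init__(self, max_slice_size):
        if max_slice_size < 1:
            raise ValueError("max_slice_size must be >= 1")
        self._max_slice_size = max_slice_size

    def __call__(self, shape, dtype, axis=0):
        del dtype  # Unused here; a byte-size-based partitioner would need it.
        dims = [int(d) for d in shape]
        # One partition on every axis except the one being split.
        partitions = [1] * len(dims)
        partitions[axis] = max(1, math.ceil(dims[axis] / self._max_slice_size))
        return partitions


partitioner = MaxSliceSizePartitioner(max_slice_size=4)
partitions = partitioner((10, 3), "float32", axis=0)
print(partitions)  # [3, 1]
```

The result has one entry per axis of `shape`, and only the entry for `axis` differs from 1.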

Methods

__call__


Partitions the given shape and returns the partition results.

Example of a partitioner that allocates a fixed number of shards:

partitioner = FixedShardsPartitioner(num_shards=2)
partitions = partitioner(tf.TensorShape([10, 3]), tf.float32, axis=0)
print(partitions) # [2, 1]
Args
shape a tf.TensorShape, the shape to partition.
dtype a tf.dtypes.DType indicating the type of the partition value.
axis The axis to partition along. Default: 0, the outermost axis.
Returns
A list of integers representing the number of partitions on each axis, where the i-th value corresponds to the i-th axis.
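
As an illustration of how such a list can be read, the sketch below turns a shape and a partition list into per-shard shapes. The helper `shard_shapes` is hypothetical, and the even split it uses (earlier shards absorb the remainder) is an assumption made for illustration, not a statement of how TensorFlow divides variables internally. It also assumes at most one axis has more than one partition.

```python
def shard_shapes(shape, partitions):
    """Hypothetical helper: per-shard shapes for a single partitioned axis."""
    # Find the partitioned axis; fall back to axis 0 if nothing is split.
    axis = next((i for i, p in enumerate(partitions) if p > 1), 0)
    num_shards = partitions[axis]
    base, extra = divmod(shape[axis], num_shards)
    shapes = []
    for i in range(num_shards):
        dims = list(shape)
        # The first `extra` shards take one additional entry each.
        dims[axis] = base + (1 if i < extra else 0)
        shapes.append(tuple(dims))
    return shapes


print(shard_shapes((10, 3), [3, 1]))  # [(4, 3), (3, 3), (3, 3)]
```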

© 2020 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/versions/r2.4/api_docs/python/tf/distribute/experimental/partitioners/Partitioner