tf.compat.v2.sparse.split

Split a SparseTensor into num_split tensors along axis.

If sp_input.dense_shape[axis] is not an integer multiple of num_split, each of the first sp_input.dense_shape[axis] % num_split slices gets one extra element along axis. For example, if axis = 1 and num_split = 2 and the input is:

input_tensor (shape = [2, 7]) =
[    a   d e  ]
[b c          ]

Graphically the output tensors are:

output_tensor[0] =
[    a ]
[b c   ]

output_tensor[1] =
[ d e  ]
[      ]
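
The slicing rule above can be sketched in plain Python, without TensorFlow. This is only an illustration of the semantics on COO-style lists, not the library's implementation; the helper names split_sizes and split_coo and the dict layout of each part are hypothetical, and the entry coordinates are read off the diagram above (a at column 2, d and e at columns 4 and 5, b and c at columns 0 and 1).

```python
def split_sizes(dim_size, num_split):
    """Width of each slice: the first dim_size % num_split slices get one extra."""
    base, extra = divmod(dim_size, num_split)
    return [base + 1 if i < extra else base for i in range(num_split)]


def split_coo(indices, values, dense_shape, num_split, axis):
    """Split a COO-style sparse tensor, given as plain lists, along `axis`."""
    sizes = split_sizes(dense_shape[axis], num_split)
    # Starting offset of each slice along the split axis.
    starts = [sum(sizes[:i]) for i in range(num_split)]
    parts = []
    for size in sizes:
        shape = list(dense_shape)
        shape[axis] = size
        parts.append({"indices": [], "values": [], "dense_shape": shape})
    for index, value in zip(indices, values):
        # Find the slice whose [start, start + size) range contains this entry.
        for part, start, size in zip(parts, starts, sizes):
            if start <= index[axis] < start + size:
                shifted = list(index)
                shifted[axis] -= start  # re-base the coordinate inside the slice
                part["indices"].append(shifted)
                part["values"].append(value)
                break
    return parts


# The [2, 7] input from the example, with entries in row-major order.
indices = [[0, 2], [0, 4], [0, 5], [1, 0], [1, 1]]
values = ["a", "d", "e", "b", "c"]
parts = split_coo(indices, values, dense_shape=[2, 7], num_split=2, axis=1)
for part in parts:
    print(part)
```

With 7 columns and num_split = 2, the sizes are [4, 3]: the first part has dense shape [2, 4] and holds a, b and c, and the second has dense shape [2, 3] and holds d and e at columns 0 and 1.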
Args
  sp_input: The SparseTensor to split.
  num_split: A Python integer. The number of ways to split.
  axis: A 0-D int32 Tensor. The dimension along which to split.
  name: A name for the operation (optional).
Returns
  A list of num_split SparseTensor objects resulting from splitting sp_input.
Raises
  TypeError: If sp_input is not a SparseTensor.

© 2020 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/compat/v2/sparse/split