tf.ensure_shape
Updates the shape of a tensor and checks at runtime that the shape holds.
tf.ensure_shape(
    x, shape, name=None
)
With eager execution this is a shape assertion that returns the input:
x = tf.constant([1,2,3])
print(x.shape)
(3,)
x = tf.ensure_shape(x, [3])
x = tf.ensure_shape(x, [5])
Traceback (most recent call last):
tf.errors.InvalidArgumentError: Shape of tensor dummy_input [3] is not compatible with expected shape [5]. [Op:EnsureShape]
Inside a tf.function or v1.Graph context it checks both the build-time and runtime shapes. This is stricter than tf.Tensor.set_shape, which only checks the build-time shape.
Note: This differs from tf.Tensor.set_shape in that it sets the static shape of the resulting tensor and enforces it at runtime, raising an error if the tensor's runtime shape is incompatible with the specified shape. tf.Tensor.set_shape sets the static shape of the tensor without enforcing it at runtime, which may result in inconsistencies between the statically-known shape of tensors and the runtime value of tensors.
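The difference described in the note can be made visible by inspecting the traced graph. The following is a minimal sketch, not part of the original reference: the names `SPEC`, `checked`, `unchecked`, and `op_types` are hypothetical, and it assumes a TensorFlow 2.x installation. It checks whether an `EnsureShape` op (the op name that appears in the error message above) was added to each function's graph.

```python
import tensorflow as tf

# Hypothetical signature: a 1-D int32 tensor whose length is unknown when tracing.
SPEC = [tf.TensorSpec([None], tf.int32)]


@tf.function(input_signature=SPEC)
def checked(x):
    # Refines the static shape to (3,) and adds an op that verifies it on each call.
    return tf.ensure_shape(x, [3])


@tf.function(input_signature=SPEC)
def unchecked(x):
    # Refines the static shape to (3,) but adds no runtime check to the graph.
    x.set_shape([3])
    return x


def op_types(fn):
    """Return the set of op types in the traced graph of `fn`."""
    graph = fn.get_concrete_function().graph
    return {op.type for op in graph.get_operations()}


has_check = "EnsureShape" in op_types(checked)
lacks_check = "EnsureShape" not in op_types(unchecked)
print(has_check, lacks_check)

# A length-5 input matches the signature but violates the ensured shape.
try:
    checked(tf.constant([1, 2, 3, 4, 5]))
    raised = False
except tf.errors.InvalidArgumentError:
    raised = True
print("runtime mismatch raised:", raised)
```

Only the `tf.ensure_shape` version carries a check into the graph, which is why it can reject a mismatched input at call time.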
For example, when loading images of a known size:
@tf.function
def decode_image(png):
    image = tf.image.decode_png(png, channels=3)
    # the `print` executes during tracing.
    print("Initial shape: ", image.shape)
    image = tf.ensure_shape(image, [28, 28, 3])
    print("Final shape: ", image.shape)
    return image
When tracing a function, no ops are executed, so shapes may be unknown. See the Concrete Functions Guide for details.
concrete_decode = decode_image.get_concrete_function(
    tf.TensorSpec([], dtype=tf.string))
Initial shape: (None, None, 3)
Final shape: (28, 28, 3)
image = tf.random.uniform(maxval=255, shape=[28, 28, 3], dtype=tf.int32)
image = tf.cast(image, tf.uint8)
png = tf.image.encode_png(image)
image2 = concrete_decode(png)
print(image2.shape)
(28, 28, 3)
image = tf.concat([image, image], axis=0)
print(image.shape)
(56, 28, 3)
png = tf.image.encode_png(image)
image2 = concrete_decode(png)
Traceback (most recent call last):
tf.errors.InvalidArgumentError: Shape of tensor DecodePng [56,28,3] is not compatible with expected shape [28,28,3].
@tf.function
def bad_decode_image(png):
    image = tf.image.decode_png(png, channels=3)
    # the `print` executes during tracing.
    print("Initial shape: ", image.shape)
    # BAD: forgot to use the returned tensor.
    tf.ensure_shape(image, [28, 28, 3])
    print("Final shape: ", image.shape)
    return image
image = bad_decode_image(png)
Initial shape: (None, None, 3)
Final shape: (None, None, 3)
print(image.shape)
(56, 28, 3)
| Args | |
|---|---|
| x | A Tensor. |
| shape | A TensorShape representing the shape of this tensor, a TensorShapeProto, a list, a tuple, or None. |
| name | A name for this operation (optional). Defaults to "EnsureShape". |
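As an illustration of the forms the shape argument can take, here is a small sketch run in eager mode. It is an assumption-labeled example rather than part of the original reference: the variable names `accepted` and `results` are hypothetical, and the TensorShapeProto form is left out for brevity. A `None` entry stands for an unknown dimension, and `None` as the whole argument stands for a fully unknown shape.

```python
import tensorflow as tf

x = tf.zeros([2, 3])

# Each of these describes a shape that the (2, 3) tensor satisfies.
accepted = [
    [2, 3],                  # list
    (2, 3),                  # tuple
    tf.TensorShape([2, 3]),  # TensorShape
    [None, 3],               # partially known: only the last dimension is checked
    None,                    # fully unknown shape: nothing to check
]

# ensure_shape returns its input, so the runtime shape is unchanged in every case.
results = [tf.ensure_shape(x, shape).shape.as_list() for shape in accepted]
print(results)
```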
| Returns |
|---|
| A Tensor. Has the same type and contents as x. At runtime, raises a tf.errors.InvalidArgumentError if shape is incompatible with the shape of x. |
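The word "incompatible" above can be read as follows: ranks must match and every known dimension must be equal, while unknown (`None`) dimensions match anything. The pure-Python sketch below illustrates that rule under this reading. It is not TensorFlow's implementation, and the function name `shapes_compatible` is hypothetical.

```python
from typing import Optional, Sequence


def shapes_compatible(
    actual: Sequence[int],
    expected: Optional[Sequence[Optional[int]]],
) -> bool:
    """Illustrative rule: does a concrete runtime shape satisfy an expected shape?"""
    if expected is None:
        # A fully unknown expected shape places no constraint on the tensor.
        return True
    if len(actual) != len(expected):
        # Different ranks can never match.
        return False
    # A None entry is a wildcard; known entries must be equal.
    return all(e is None or a == e for a, e in zip(actual, expected))


print(shapes_compatible((3,), [3]))
print(shapes_compatible((3,), [5]))
print(shapes_compatible((56, 28, 3), [28, 28, 3]))
print(shapes_compatible((56, 28, 3), [None, 28, 3]))
```

The second and third calls mirror the two failing cases shown earlier on this page: [3] against [5], and [56,28,3] against [28,28,3].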
© 2020 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/ensure_shape