How to use TensorFlow get_shape Function (get tensor shape)

In this TensorFlow tutorial, I will explain how to use the TensorFlow get_shape function. This function returns the shape of the given tensor.

In my project, I had to process images stored as 4-dimensional tensors, and I had to validate each of those dimensions: batch size, height, width, and color channels.

The get_shape() function helped me validate the dimensions of the images being processed. In this tutorial, I show how to use the TensorFlow get_shape function to get the shape of your data.
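The project code itself is not shown in this tutorial, so here is a minimal sketch of that kind of check. It assumes images are batched as (batch, height, width, channels), and `validate_image_batch` is a hypothetical helper written for illustration. It uses get_shape() to confirm the tensor is 4-dimensional and has the expected number of color channels.

```python
import tensorflow as tf

def validate_image_batch(images, channels=3):
    """Hypothetical helper: check that `images` is a 4-D batch shaped
    (batch, height, width, channels) with the expected channel count."""
    shape = images.get_shape()
    # rank is the number of dimensions (entries in the shape).
    if shape.rank != 4:
        raise ValueError(f"expected a 4-D tensor, got shape {shape}")
    batch, height, width, ch = shape.as_list()
    if ch != channels:
        raise ValueError(f"expected {channels} color channels, got {ch}")
    return batch, height, width, ch

# A dummy batch of 8 RGB images, 32x32 pixels each (assumed sizes).
images = tf.zeros([8, 32, 32, 3])
print(validate_image_batch(images))
```

Passing a single image shaped (32, 32, 3) instead of a batch would raise a ValueError, because its shape has only three entries.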

Apart from get_shape(), I also explain the related function tf.shape() and the tensor.shape property, which also give you a tensor's shape.

What is the TensorFlow get_shape function?

TensorFlow tensors have a built-in method, get_shape(), that returns the shape of the tensor it is called on. In simpler words, if you have created a tensor and want to know its dimensions or size, call get_shape() on that tensor.

The syntax for using TensorFlow get_shape() is given below.

tensor_name.get_shape()

Where,

  • tensor_name: the tensor whose shape you want to know.

For example, let’s create a new tensor and get its shape.

Import the TensorFlow library and create a tensor named tensor_data, as shown below.

import tensorflow as tf

tensor_data = tf.constant([[34, 56, 8], [56, 77, 25]])

To know the shape of the tensor_data, call the get_shape() function on it, as shown below.

print(tensor_data.get_shape())

Look at the output: calling get_shape() on tensor_data returns the shape (2, 3). Strictly speaking, the return value is a tf.TensorShape object, which prints like a tuple.
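To see what get_shape() actually returns, the short sketch below inspects the result for the same tensor. It uses TensorShape.as_list() to turn the shape into a plain Python list and compares it with the tensor's shape property.

```python
import tensorflow as tf

tensor_data = tf.constant([[34, 56, 8], [56, 77, 25]])

shape = tensor_data.get_shape()
# get_shape() returns a tf.TensorShape, not a Python tuple.
print(type(shape).__name__)
# as_list() converts the TensorShape into a plain Python list.
print(shape.as_list())
# The result matches the tensor's shape property.
print(shape == tensor_data.shape)
```

Converting with as_list() is handy when you want to compare a shape against ordinary Python values.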


Let’s understand the returned shape (2, 3). It describes a 2-dimensional tensor, that is, a matrix.

To determine the number of dimensions, count the elements in the tuple: (2, 3) has two elements, so this is a 2-dimensional tensor. A tuple with three elements, such as (2, 3, 4), would describe a 3-dimensional tensor.
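Counting the entries by eye works for small examples. The sketch below shows a few ways to get the number of dimensions in code: len() on the TensorShape, its rank attribute, and the tf.rank() function. The second tensor is an assumed zero-filled example with shape (2, 3, 4).

```python
import tensorflow as tf

matrix = tf.constant([[34, 56, 8], [56, 77, 25]])
cube = tf.zeros([2, 3, 4])  # assumed example of a 3-dimensional tensor

# The number of dimensions equals the number of entries in the shape.
print(len(matrix.get_shape()))
print(matrix.get_shape().rank)
print(cube.get_shape().rank)
# tf.rank() returns the same count as a scalar tensor.
print(tf.rank(cube).numpy())
```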

Each element in the tuple gives the size of the tensor along that dimension.

  • The first element, 2, is the size of the first dimension. From the matrix perspective, it is the number of rows, so this tensor has 2 rows.
  • The second element, 3, is the size of the second dimension, which is the number of columns; this tensor has 3 columns.
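The sizes can also be read individually by indexing the TensorShape. As a small illustration, the sketch below reads the row and column counts and computes the total number of elements. The total is the product of the sizes (2 × 3 = 6), available through TensorShape.num_elements() or tf.size().

```python
import tensorflow as tf

tensor_data = tf.constant([[34, 56, 8], [56, 77, 25]])
shape = tensor_data.get_shape()

rows = shape[0]     # size of the first dimension
columns = shape[1]  # size of the second dimension
print(rows, columns)

# The total number of elements is the product of the sizes.
print(shape.num_elements())
print(tf.size(tensor_data).numpy())
```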

To learn more about shape and dimensions, see the tutorial Understanding the Dimensions and Shape of Tensor in TensorFlow.

Let’s take another example and check the shape of a higher-dimensional tensor.

Create a new tensor, as shown in the code below.

tensor_data = tf.constant([[[2, 3, 4]],[[7, 2, 8]], [[4, 7, 9]], [[1, 4, 8]]])

print(tensor_data.get_shape())

The shape of tensor_data is (4, 1, 3); let’s interpret it. How many dimensions does this tensor have? Three, because the tuple contains three elements.

In the tuple (4, 1, 3), the first element, 4, means the first dimension has size 4. The second element, 1, means the second dimension has size 1: each item along the first dimension contains only a single sub-element.


Lastly, the third element, 3, means the third dimension has size 3: each of those sub-elements contains 3 values.
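One way to check this interpretation is to index into the tensor and call get_shape() on each piece. In the sketch below, each indexing step removes the leading dimension from the shape.

```python
import tensorflow as tf

tensor_data = tf.constant([[[2, 3, 4]], [[7, 2, 8]], [[4, 7, 9]], [[1, 4, 8]]])

print(tensor_data.get_shape())
# One item along the first dimension: a single sub-element of 3 values.
print(tensor_data[0].get_shape())
# That single sub-element holds the 3 values themselves.
print(tensor_data[0][0].get_shape())
print(tensor_data[0][0].numpy())
```

The printed shapes shrink from (4, 1, 3) to (1, 3) and then to (3,), matching the three dimensions described above.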

This is how to use the TensorFlow get_shape() function to retrieve a tensor’s shape and interpret it.

Now that you are familiar with get_shape(), let’s look at tf.shape(), a related TensorFlow function that reports the same shape information but returns it as a tensor.

The complete syntax of tf.shape is given below.

tf.shape(input, out_type=tf.dtypes.int32, name=None)

Where the parameters are:

  • input: the input tensor whose shape you want to know.
  • out_type: optional; the data type of the output, tf.dtypes.int32 by default (tf.dtypes.int64 is also accepted).
  • name: optional; a name for the operation.
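As a brief illustration of the optional parameters, the sketch below calls tf.shape() once with the defaults and once with out_type=tf.int64. The operation name "shape_int64" is an arbitrary example label.

```python
import tensorflow as tf

tensor = tf.constant([[[15, 67, 89], [34, 27, 89]],
                      [[45, 89, 189], [68, 91, 46]]])

default_shape = tf.shape(tensor)  # int32 by default
long_shape = tf.shape(tensor, out_type=tf.int64, name="shape_int64")

print(default_shape.dtype)
print(long_shape.dtype)
print(long_shape.numpy())
```

An int64 output can be useful when the shape values will be combined with other int64 tensors.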

For example, execute the code below.

import tensorflow as tf

tensor = tf.constant([[[15, 67, 89], [34, 27, 89]],
                      [[45, 89, 189], [68, 91, 46]]])

result=tf.shape(tensor)

print(result)

When you execute the above code, it prints tf.Tensor([2 2 3], shape=(3,), dtype=int32). Look at the part [2 2 3]: these are the same values get_shape() would return. The difference is that tf.shape() returns them as a 1-dimensional int32 tensor instead of a TensorShape.

You already know how to interpret this shape, so don’t worry that it is printed in list form; [2 2 3] and (2, 2, 3) represent the same thing.
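To compare the two results directly, you can convert both to Python lists. The sketch below uses .numpy().tolist() on the tensor from tf.shape() and .as_list() on the TensorShape from get_shape().

```python
import tensorflow as tf

tensor = tf.constant([[[15, 67, 89], [34, 27, 89]],
                      [[45, 89, 189], [68, 91, 46]]])

# tf.shape() returns a 1-D tensor; .numpy().tolist() gives a Python list.
dynamic = tf.shape(tensor).numpy().tolist()
# get_shape() returns a TensorShape; .as_list() gives a Python list.
static = tensor.get_shape().as_list()

print(dynamic)
print(static)
print(dynamic == static)
```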

Before ending this tutorial, I want to show you one more way to get the shape of a tensor: read its shape property, tensor.shape.

For example, execute the code below.

import tensorflow as tf
tensor_value = tf.constant([[3, 5], [1, 8]])
print(tensor_value.shape)

In the above code, tensor_value is a tensor, and accessing its shape property as tensor_value.shape returns the shape (2, 2).
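Because tensor.shape is a TensorShape, it can be indexed and unpacked much like a tuple. The sketch below shows this for the same 2 × 2 tensor; the variable names rows and columns are just illustrative.

```python
import tensorflow as tf

tensor_value = tf.constant([[3, 5], [1, 8]])

# Unpack the two dimension sizes directly from the TensorShape.
rows, columns = tensor_value.shape
print(rows, columns)
# Negative indexing reads the last dimension.
print(tensor_value.shape[-1])
# tuple() turns the TensorShape into an ordinary Python tuple.
print(tuple(tensor_value.shape))
```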


Whether you call tensor_value.get_shape(), tf.shape(tensor_value), or tensor_value.shape, you get the same shape. To verify this, execute the code below.

print('using tensor_value.shape', tensor_value.shape)
print('using tf.shape(tensor_value)', tf.shape(tensor_value))
print('using tensor_value.get_shape()', tensor_value.get_shape())

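All three report the shape (2, 2) for tensor_value. Only the format differs: tensor_value.shape and tensor_value.get_shape() print (2, 2), while tf.shape(tensor_value) prints the tensor tf.Tensor([2 2], shape=(2,), dtype=int32).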

So, you can use TensorFlow get_shape, tf.shape(), and tensor.shape to find the shape of any tensor.
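One caveat applies inside a tf.function, where some dimensions may not be known when the function is traced. In that case, get_shape() (and .shape) can only report the static shape, with None for unknown sizes, while tf.shape() is evaluated when the function runs and gives the actual sizes. The sketch below demonstrates this, assuming an input signature with an unspecified batch dimension; `dynamic_shape` and `traced_static_shapes` are illustrative names.

```python
import tensorflow as tf

traced_static_shapes = []  # filled in while the function is traced

@tf.function(input_signature=[tf.TensorSpec(shape=[None, 3], dtype=tf.int32)])
def dynamic_shape(x):
    # Runs at trace time, when the first dimension is still unknown.
    traced_static_shapes.append(x.get_shape().as_list())
    # tf.shape() is computed each time the function actually runs.
    return tf.shape(x)

result = dynamic_shape(tf.constant([[1, 2, 3], [4, 5, 6]]))
print(traced_static_shapes)
print(result.numpy().tolist())
```

In eager code like the earlier examples, every dimension is known, so both approaches agree.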

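One important thing to remember is that what each dimension means depends on the context. For example, in image processing a 4-dimensional shape is commonly read as (batch size, height, width, color channels), while a 2-dimensional tensor holding a tabular machine learning dataset is often read as (samples, features).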

This tutorial does not cover those context-specific conventions in detail; you can find more information in the TensorFlow documentation.

Conclusion

In this TensorFlow tutorial, you learned how to use the TensorFlow get_shape() function to get tensor shape.

Additionally, you learned how to use the tf.shape() function and the tensor.shape property to get the shape of a tensor.
