9951 explained code solutions for 126 technologies


python-pytorchHow can I reshape a tensor using Python and Pytorch?


Reshaping a tensor using Python and Pytorch is a simple task. PyTorch provides a torch.view() function to reshape a tensor. The view() function returns a new tensor with the same data as the original tensor but with a different shape.

For example:

import torch
x = torch.randn(4, 4)
y = x.view(16)
print(x.size(), y.size())

# Output: torch.Size([4, 4]) torch.Size([16])

The code above creates a 4x4 tensor x and reshapes it to a single dimension of size 16 using the view() function.

The view() function takes in -1 as an argument to automatically calculate the correct dimension size.

For example:

import torch
x = torch.randn(2, 3, 4)
y = x.view(-1, 4)
print(x.size(), y.size())

# Output: torch.Size([2, 3, 4]) torch.Size([6, 4])

The code above creates a 2x3x4 tensor x and reshapes it to a 6x4 tensor using the view() function with -1 as an argument.

Helpful links

Edit this code on GitHub