9951 explained code solutions for 126 technologies


python-pytorchHow do I use the argmax function in Python PyTorch?


The argmax function in PyTorch is used to retrieve the index of the maximum value of a given tensor along a particular dimension.

For example, if x is a tensor with the values [1, 2, 3, 4], then torch.argmax(x) will return 3, because 4 is the maximum value in the tensor.

Here is an example code block that uses argmax:

import torch
x = torch.tensor([1, 2, 3, 4])
torch.argmax(x)

The output of this code is 3.

Code explanation

  • import torch: This is used to import the PyTorch library.
  • x = torch.tensor([1, 2, 3, 4]): This creates a tensor x with the values [1, 2, 3, 4].
  • torch.argmax(x): This applies the argmax function to the tensor x and returns the index of the maximum value in the tensor.

For more information on the argmax function in PyTorch, please see the following links:

Edit this code on GitHub