How can I use a Long Short-Term Memory (LSTM) network in Python with PyTorch?
To use a Long Short-Term Memory (LSTM) network in Python with PyTorch, you need to import the necessary packages and define the network architecture.
import torch
import torch.nn as nn

class LSTMNetwork(nn.Module):
    def __init__(self, input_dim, hidden_dim, layer_dim, output_dim):
        super(LSTMNetwork, self).__init__()
        self.hidden_dim = hidden_dim
        self.layer_dim = layer_dim
        # Define the LSTM layer (batch_first=True: input is batch, seq_len, input_dim)
        self.lstm = nn.LSTM(input_dim, hidden_dim, layer_dim, batch_first=True)
        # Define the output layer
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        # Initialize hidden state: (layer_dim, batch, hidden_dim)
        h0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim).requires_grad_()
        # Initialize cell state: (layer_dim, batch, hidden_dim)
        c0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim).requires_grad_()
        # Run the LSTM over the whole input sequence
        out, (hn, cn) = self.lstm(x, (h0, c0))
        # Index hidden state of last time step
        # out.size() --> 100, 28, 100 (batch, seq_len, hidden_dim)
        # out[:, -1, :] --> 100, 100 --> just want last time step hidden states!
        out = self.fc(out[:, -1, :])
        # out.size() --> 100, 10
        return out

input_dim = 28
hidden_dim = 100
layer_dim = 2
output_dim = 10
model = LSTMNetwork(input_dim, hidden_dim, layer_dim, output_dim)
The code above defines an LSTM network with two stacked LSTM layers, an input dimension of 28, a hidden dimension of 100, and an output dimension of 10. The forward method runs the LSTM over the whole input sequence, takes the hidden state of the top layer at the last time step, and passes it through the linear output layer.
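To see the model in use, here is a minimal sketch of a forward pass and a single training step. The class is repeated so the snippet runs on its own. The random input batch, the random labels, the choice of cross-entropy loss, and the Adam optimizer with its learning rate are all assumptions made for illustration, not part of the original answer. The `device=x.device` argument is also an addition, so that the initial states follow the input if it is moved to a GPU.

```python
import torch
import torch.nn as nn


class LSTMNetwork(nn.Module):
    def __init__(self, input_dim, hidden_dim, layer_dim, output_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.layer_dim = layer_dim
        self.lstm = nn.LSTM(input_dim, hidden_dim, layer_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        # Zero initial states, created on the same device as the input (assumption)
        h0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim, device=x.device)
        c0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim, device=x.device)
        out, _ = self.lstm(x, (h0, c0))
        # Use the last time step of the top layer as the sequence summary
        return self.fc(out[:, -1, :])


model = LSTMNetwork(input_dim=28, hidden_dim=100, layer_dim=2, output_dim=10)

# Hypothetical data: 100 sequences, 28 time steps, 28 features per step
x = torch.randn(100, 28, 28)
y = torch.randint(0, 10, (100,))  # hypothetical class labels in [0, 10)

criterion = nn.CrossEntropyLoss()                          # assumed loss
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # assumed optimizer

# One training step
model.train()
optimizer.zero_grad()
logits = model(x)            # shape: (100, 10)
loss = criterion(logits, y)
loss.backward()
optimizer.step()

print(logits.shape)  # torch.Size([100, 10])
```

In a real project the random tensors would be replaced by batches from a `DataLoader`, and the step would run inside a loop over epochs.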
Code explanation
- import torch and import torch.nn as nn: Import the necessary packages.
- class LSTMNetwork(nn.Module): Define the network architecture as a class.
- self.lstm = nn.LSTM(input_dim, hidden_dim, layer_dim, batch_first=True): Define the LSTM layer. With batch_first=True the input has shape (batch, sequence length, input_dim).
- self.fc = nn.Linear(hidden_dim, output_dim): Define the output layer.
- h0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim).requires_grad_() and c0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim).requires_grad_(): Initialize the hidden state and cell state with zeros, one per layer and per sequence in the batch.
- out, (hn, cn) = self.lstm(x, (h0, c0)): Run the LSTM over the whole input sequence. out holds the top layer's hidden state at every time step.
- out = self.fc(out[:, -1, :]): Take the hidden state of the last time step and pass it through the output layer.
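The shapes and the last-step indexing described above can be checked directly. The following is a small sketch, assuming a unidirectional `nn.LSTM` and a hypothetical batch of 4 sequences. It shows that `out[:, -1, :]` matches the final hidden state of the top layer (`hn[-1]`), and that passing explicit zero states gives the same result as omitting them, since `nn.LSTM` defaults to zero initial states.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Same layer configuration as in the answer; batch size 4 is an arbitrary choice
lstm = nn.LSTM(input_size=28, hidden_size=100, num_layers=2, batch_first=True)
x = torch.randn(4, 28, 28)  # (batch, seq_len, input_dim)

# Explicit zero initial states: (num_layers, batch, hidden_size)
h0 = torch.zeros(2, 4, 100)
c0 = torch.zeros(2, 4, 100)

with torch.no_grad():
    out_explicit, (hn, cn) = lstm(x, (h0, c0))
    out_default, _ = lstm(x)  # initial states default to zeros

print(out_explicit.shape)  # torch.Size([4, 28, 100])
print(hn.shape)            # torch.Size([2, 4, 100])

# Explicit zeros and the default initial state give the same outputs
same_states = torch.allclose(out_explicit, out_default)

# Last time step of `out` equals the final hidden state of the top layer
same_last = torch.allclose(out_explicit[:, -1, :], hn[-1])

print(same_states, same_last)  # True True
```

The second check holds for a unidirectional LSTM only. With `bidirectional=True` the backward direction finishes at the first time step, so `out[:, -1, :]` and `hn` no longer line up this way.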