How can I use a Long Short-Term Memory (LSTM) network in Python with PyTorch?
To use a Long Short-Term Memory (LSTM) network in Python with PyTorch, import torch and torch.nn, then define the network as an nn.Module subclass that combines an nn.LSTM layer with a linear output layer.
import torch
import torch.nn as nn


class LSTMNetwork(nn.Module):
    def __init__(self, input_dim, hidden_dim, layer_dim, output_dim):
        super(LSTMNetwork, self).__init__()
        self.hidden_dim = hidden_dim
        self.layer_dim = layer_dim
        # Define the LSTM layer (batch_first=True expects input of shape batch, seq_len, input_dim)
        self.lstm = nn.LSTM(input_dim, hidden_dim, layer_dim, batch_first=True)
        # Define the output layer
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        # Initialize hidden state with zeros on the same device as the input
        h0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim, device=x.device)
        # Initialize cell state with zeros on the same device as the input
        c0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim, device=x.device)
        # Run the LSTM over the whole input sequence
        out, (hn, cn) = self.lstm(x, (h0, c0))
        # Index hidden state of last time step
        # out.size() --> 100, 28, 100 (batch, seq_len, hidden_dim)
        # out[:, -1, :] --> 100, 100 --> just want last time step hidden states!
        out = self.fc(out[:, -1, :])
        # out.size() --> 100, 10
        return out


input_dim = 28
hidden_dim = 100
layer_dim = 2
output_dim = 10
model = LSTMNetwork(input_dim, hidden_dim, layer_dim, output_dim)
The code above defines an LSTM network with two stacked layers, an input dimension of 28, a hidden dimension of 100, and an output dimension of 10. The forward method runs the LSTM over the entire input sequence and passes the output of the last time step through the linear output layer.
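To make the shapes concrete, here is a minimal sketch that calls nn.LSTM directly with the same layer settings. The batch of random numbers (100 sequences, 28 time steps, 28 features) is an assumption chosen to match the sizes in the code comments, not real data.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # make the random input reproducible

# Same layer configuration as the model: 28 inputs, 100 hidden units, 2 layers
lstm = nn.LSTM(input_size=28, hidden_size=100, num_layers=2, batch_first=True)

# Hypothetical batch: 100 sequences, 28 time steps, 28 features per step
x = torch.randn(100, 28, 28)

# When the initial states are omitted, nn.LSTM starts from zeros
out, (hn, cn) = lstm(x)

print(out.shape)  # (batch, seq_len, hidden_dim): top-layer output at every time step
print(hn.shape)   # (num_layers, batch, hidden_dim): final hidden state of each layer
print(cn.shape)   # (num_layers, batch, hidden_dim): final cell state of each layer

# For a unidirectional LSTM, the last time step of `out` is the
# final hidden state of the top layer
last_step = out[:, -1, :]
print(torch.allclose(last_step, hn[-1]))
```

A single call processes all 28 time steps, so indexing out[:, -1, :] selects the summary of the whole sequence that the output layer receives.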
Code explanation
- import torch and import torch.nn as nn: Import the necessary packages.
- class LSTMNetwork(nn.Module): Define the network architecture as a class.
- self.lstm = nn.LSTM(input_dim, hidden_dim, layer_dim, batch_first=True): Define the LSTM layer.
- self.fc = nn.Linear(hidden_dim, output_dim): Define the output layer.
- h0 = torch.zeros(...) and c0 = torch.zeros(...): Initialize the hidden state and cell state with zeros, each of shape (layer_dim, batch size, hidden_dim).
- out, (hn, cn) = self.lstm(x, (h0, c0)): Run the LSTM over the whole input sequence.
- out = self.fc(out[:, -1, :]): Take the hidden state of the last time step and pass it through the output layer.
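Once the model is defined, it can be trained like any other PyTorch module. The following is a minimal sketch of a single training step, not a complete training loop. The random inputs and labels, the cross-entropy loss, the SGD optimizer, and the learning rate are all assumptions made for illustration. The class is restated in compact form so the snippet runs on its own; it relies on nn.LSTM starting from zero states when none are passed.

```python
import torch
import torch.nn as nn


class LSTMNetwork(nn.Module):
    # Compact restatement of the model; nn.LSTM defaults to zero initial states
    def __init__(self, input_dim, hidden_dim, layer_dim, output_dim):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, layer_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        out, _ = self.lstm(x)          # out: (batch, seq_len, hidden_dim)
        return self.fc(out[:, -1, :])  # use the last time step only


torch.manual_seed(0)
model = LSTMNetwork(input_dim=28, hidden_dim=100, layer_dim=2, output_dim=10)

# Hypothetical data: 100 sequences of 28 steps with 28 features, 10 classes
inputs = torch.randn(100, 28, 28)
targets = torch.randint(0, 10, (100,))

criterion = nn.CrossEntropyLoss()                        # assumed classification loss
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)  # assumed optimizer and rate

optimizer.zero_grad()              # clear gradients from any previous step
logits = model(inputs)             # forward pass: (100, 10)
loss = criterion(logits, targets)  # compare predictions with labels
loss.backward()                    # backpropagate through time
optimizer.step()                   # update the weights

print(logits.shape)
print(loss.item())
```

In practice this step would be repeated over mini-batches drawn from a real dataset, for example through a torch.utils.data.DataLoader.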