LSTM
LSTM (Long Short-Term Memory) is a variant of the recurrent neural network (RNN) designed to address the vanishing- and exploding-gradient problems that plain RNNs suffer from when processing long sequences.
LSTM controls the flow of information through a mechanism called "gates", which lets it handle long sequences more effectively and capture long-range dependencies. The gates are themselves learned: at every time step they allow the LSTM to selectively remember, forget, or output information.
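To make the gate mechanism concrete, here is a minimal sketch of a single LSTM cell step written by hand (the function name lstm_cell_step and the small sizes are illustrative assumptions, not part of the training script further down). It computes the input, forget, and output gates plus the candidate cell content, and compares the result against torch.nn.LSTMCell, which packs its weights in the same (i, f, g, o) gate order.

import torch

def lstm_cell_step(x_t, h_prev, c_prev, W_ih, W_hh, b_ih, b_hh):
    # Project the input and the previous hidden state, then split into four gates.
    gates = x_t @ W_ih.T + h_prev @ W_hh.T + b_ih + b_hh  # (bsz, 4 * d_hid)
    i, f, g, o = gates.chunk(4, dim=-1)
    i = torch.sigmoid(i)       # input gate: how much new content to write
    f = torch.sigmoid(f)       # forget gate: how much old cell state to keep
    g = torch.tanh(g)          # candidate cell content
    o = torch.sigmoid(o)       # output gate: how much cell state to expose
    c_t = f * c_prev + i * g   # new cell state (long-term memory)
    h_t = o * torch.tanh(c_t)  # new hidden state (this step's output)
    return h_t, c_t

# Sanity check against PyTorch's built-in LSTMCell (illustrative sizes).
d_in, d_hid, bsz = 3, 5, 2
cell = torch.nn.LSTMCell(d_in, d_hid)
x_t = torch.randn(bsz, d_in)
h0, c0 = torch.zeros(bsz, d_hid), torch.zeros(bsz, d_hid)
h1, c1 = lstm_cell_step(x_t, h0, c0, cell.weight_ih, cell.weight_hh, cell.bias_ih, cell.bias_hh)
h_ref, c_ref = cell(x_t, (h0, c0))
print(torch.allclose(h1, h_ref, atol=1e-6), torch.allclose(c1, c_ref, atol=1e-6))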
The script below puts nn.LSTM to work in a small RNN language model: given input tokens x, it predicts the next token at every position.
import torch
import torch.nn as nn
from torch.optim import SGD


class RNNLM(nn.Module):
    def __init__(self, vocab_size, d_emb, d_hid, n_layer):
        super(RNNLM, self).__init__()
        self.emb = nn.Embedding(vocab_size, d_emb)
        self.lstm = nn.LSTM(d_emb, d_hid, n_layer, batch_first=True)
        self.clf = nn.Linear(d_hid, vocab_size)

    def forward(self, x, hc=None):
        # token ids -> embeddings
        x = self.emb(x)  # (bsz, seq_len, d_emb)
        bsz, seq_len, d_emb = x.shape
        # hs: (bsz    , seq_len, d_hid)  -- top-layer hidden state at every step
        # h : (n_layer, bsz    , d_hid)  -- final hidden state of every layer
        # hs[:, -1] == h[-1]
        hs, (h, c) = self.lstm(x, hc)
        # flatten time so every position gets a next-token prediction
        hs = hs.reshape(bsz * seq_len, hs.size(2))
        logits = self.clf(hs)  # (bsz * seq_len, vocab_size)
        return logits, (h, c)
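
# Illustrative sanity check (not part of the original script) of the shape
# comments in forward(): with batch_first=True, hs holds the top layer's hidden
# state at every time step, while h holds the final hidden state of every layer,
# so hs[:, -1] equals h[-1]. The sizes below are arbitrary assumptions.
_demo_lstm = nn.LSTM(input_size=8, hidden_size=4, num_layers=2, batch_first=True)
_demo_in = torch.randn(2, 3, 8)                      # (bsz, seq_len, d_emb)
_demo_hs, (_demo_h, _demo_c) = _demo_lstm(_demo_in)
print(_demo_hs.shape, _demo_h.shape)                 # (2, 3, 4) and (2, 2, 4)
print(torch.allclose(_demo_hs[:, -1], _demo_h[-1]))  # True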
if __name__ == "__main__":
    # config
    vocab_size, d_emb, d_hid, n_layer = 16, 8, 4, 2
    lr, bsz = 1e-4, 2

    # data: y is x shifted by one position (next-token targets)
    x = torch.tensor([[1, 2, 3], [7, 8, 9]], dtype=torch.long)
    y = torch.tensor([[2, 3, 4], [8, 9, 10]], dtype=torch.long)

    # model
    model = RNNLM(vocab_size, d_emb, d_hid, n_layer)
    optimizer = SGD(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()

    # training: one step; the initial state is a tuple of (n_layer, bsz, d_hid) zeros
    hc = (torch.zeros(n_layer, bsz, d_hid), torch.zeros(n_layer, bsz, d_hid))
    optimizer.zero_grad()
    logits, (h, c) = model(x, hc)
    loss = loss_fn(logits, y.reshape(-1))
    loss.backward()
    optimizer.step()
    print(loss.item())

    # greedy prediction for each position
    y_pred = logits.argmax(dim=-1).reshape(bsz, -1)
    print(x, "\n==>\n", y_pred)
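
As a follow-up sketch (not part of the original script), the (h, c) state returned by the model can be carried across calls to decode one token at a time; the greedy-decoding loop below is an illustrative assumption and is meant to be appended inside the if __name__ == "__main__": block above.

    # inference sketch: feed one token at a time and carry (h, c) forward so the
    # LSTM remembers the earlier context (greedy decoding; the model has only
    # taken a single training step, so the output is illustrative only)
    model.eval()
    with torch.no_grad():
        token = x[:, :1]      # start from the first token of each sequence
        state = None          # let the LSTM start from a zero (h, c) state
        generated = [token]
        for _ in range(3):    # decode three more tokens per sequence
            logits, state = model(token, state)           # logits: (bsz, vocab_size)
            token = logits.argmax(dim=-1).reshape(-1, 1)  # greedy next token
            generated.append(token)
        print(torch.cat(generated, dim=1))                # (bsz, 4) token ids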