LSTM

May 2, 2024

LSTM (Long Short-Term Memory) is a variant of the recurrent neural network (RNN), designed specifically to address the vanishing and exploding gradient problems that traditional RNNs run into when processing long sequences. LSTM introduces a mechanism called "gates" to control the flow of information, which lets it handle long sequences more effectively and capture long-range dependencies. The gates are learned: at every time step they let the LSTM selectively remember, forget, or output information.

```python
import torch
import torch.nn as nn
from torch.optim import SGD


class RNNLM(nn.Module):
    def __init__(self, vocab_size, d_emb, d_hid, n_layer):
        super(RNNLM, self).__init__()
        self.emb = nn.Embedding(vocab_size, d_emb)
        self.lstm = nn.LSTM(d_emb, d_hid, n_layer, batch_first=True)
        self.clf = nn.Linear(d_hid, vocab_size)

    def forward(self, x, hc=None):
        # token ids -> embeddings
        x = self.emb(x)                       # (bsz, seq_len, d_emb)
        bsz, seq_len, d_emb = x.shape
        # hs: (bsz, seq_len, d_hid)  -- top-layer hidden state at every step
        # h : (n_layer, bsz, d_hid)  -- last-step hidden state of every layer
        # c : (n_layer, bsz, d_hid)  -- last-step cell state of every layer
        # note: hs[:, -1] == h[-1]
        hs, (h, c) = self.lstm(x, hc)
        hs = hs.reshape(bsz * seq_len, hs.size(2))
        logits = self.clf(hs)                 # (bsz * seq_len, vocab_size)
        return logits, (h, c)


if __name__ == "__main__":
    # config
    vocab_size, d_emb, d_hid, n_layer = 16, 8, 4, 2
    lr, bsz = 1e-4, 2

    # data: predict the next token at every position
    x = torch.tensor([[1, 2, 3], [7, 8, 9]], dtype=torch.long)
    y = torch.tensor([[2, 3, 4], [8, 9, 10]], dtype=torch.long)

    # model
    model = RNNLM(vocab_size, d_emb, d_hid, n_layer)
    optimizer = SGD(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()

    # one training step; the initial state is a (h_0, c_0) tuple
    hc = (torch.zeros(n_layer, bsz, d_hid), torch.zeros(n_layer, bsz, d_hid))
    optimizer.zero_grad()
    logits, (h, c) = model(x, hc)
    loss = loss_fn(logits, y.reshape(-1))
    loss.backward()
    optimizer.step()
    print(loss.item())

    y_pred = logits.argmax(dim=-1).reshape(bsz, -1)
    print(x, "\n==>\n", y_pred)
```
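To make the gate mechanism described above concrete, here is a minimal sketch of a single LSTM cell step written with plain tensor operations. The weight layout (`W`, `U`, `b` holding the four gates stacked together) and the function name `lstm_cell_step` are illustrative assumptions following the textbook formulation, not PyTorch's internal parameter layout.

```python
import torch


def lstm_cell_step(x_t, h_prev, c_prev, W, U, b):
    """One LSTM time step; W: (4*d_hid, d_in), U: (4*d_hid, d_hid), b: (4*d_hid,).
    The four chunks correspond to the input, forget, candidate, and output gates."""
    d_hid = h_prev.size(-1)
    gates = x_t @ W.T + h_prev @ U.T + b   # (bsz, 4*d_hid)
    i, f, g, o = gates.split(d_hid, dim=-1)
    i = torch.sigmoid(i)         # input gate: how much new information to write
    f = torch.sigmoid(f)         # forget gate: how much of the old cell state to keep
    g = torch.tanh(g)            # candidate cell content
    o = torch.sigmoid(o)         # output gate: how much of the cell state to expose
    c_t = f * c_prev + i * g     # updated cell state (long-term memory)
    h_t = o * torch.tanh(c_t)    # hidden state passed to the next step / layer
    return h_t, c_t


if __name__ == "__main__":
    bsz, d_in, d_hid = 2, 8, 4
    x_t = torch.randn(bsz, d_in)
    h_prev = torch.zeros(bsz, d_hid)
    c_prev = torch.zeros(bsz, d_hid)
    W = torch.randn(4 * d_hid, d_in)
    U = torch.randn(4 * d_hid, d_hid)
    b = torch.zeros(4 * d_hid)
    h_t, c_t = lstm_cell_step(x_t, h_prev, c_prev, W, U, b)
    print(h_t.shape, c_t.shape)  # torch.Size([2, 4]) torch.Size([2, 4])
```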
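Because `forward` returns the final `(h, c)`, the same model can also be run autoregressively at inference time by feeding the predicted token back in and carrying the state across calls. A minimal sketch with greedy decoding, assuming the `RNNLM` class above; `greedy_generate` is a hypothetical helper, not part of the original script:

```python
@torch.no_grad()
def greedy_generate(model, prefix, n_new, n_layer, d_hid):
    """Greedily extend a (1, prefix_len) tensor of token ids by n_new tokens."""
    model.eval()
    hc = (torch.zeros(n_layer, 1, d_hid), torch.zeros(n_layer, 1, d_hid))
    tokens = prefix
    inp = prefix
    for _ in range(n_new):
        logits, hc = model(inp, hc)                   # logits: (seq_len, vocab_size) for bsz=1
        next_id = logits[-1].argmax().reshape(1, 1)   # most likely next token
        tokens = torch.cat([tokens, next_id], dim=1)
        inp = next_id                                 # feed only the new token; (h, c) carries the context
    return tokens


# e.g. continue the sequence [1, 2, 3] by three tokens:
# print(greedy_generate(model, torch.tensor([[1, 2, 3]]), 3, n_layer, d_hid))
```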