LSTM (Long Short-Term Memory) is a variant of the recurrent neural network (RNN), designed specifically to address the vanishing- and exploding-gradient problems that vanilla RNNs run into on long sequences.
LSTM controls the flow of information through a mechanism called "gates", which lets it process long sequences more effectively and capture long-range dependencies. The gates are learned: at every time step they allow the LSTM to selectively remember, forget, or output information.
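To make the gating concrete, here is a minimal, illustrative sketch of a single LSTM step written with plain tensor ops (the class name ManualLSTMCell and its layer sizes are assumptions for illustration, not a library API); the nn.LSTM used below implements the same equations in an optimized, multi-layer form.

import torch
import torch.nn as nn

class ManualLSTMCell(nn.Module):
    """One LSTM time step: input/forget/output gates plus a candidate cell state (illustrative sketch)."""
    def __init__(self, d_in, d_hid):
        super().__init__()
        # One linear map produces all four pre-activations at once: [i, f, g, o]
        self.linear = nn.Linear(d_in + d_hid, 4 * d_hid)

    def forward(self, x_t, h_prev, c_prev):
        z = self.linear(torch.cat([x_t, h_prev], dim=-1))
        i, f, g, o = z.chunk(4, dim=-1)
        i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)  # gates take values in (0, 1)
        g = torch.tanh(g)                  # candidate memory content
        c_t = f * c_prev + i * g           # forget part of the old memory, write new memory
        h_t = o * torch.tanh(c_t)          # output a gated view of the cell state
        return h_t, c_t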


import torch 
import torch.nn as nn 
from torch.optim import SGD

class RNNLM(nn.Module):
    def __init__(self, vocab_size, d_emb, d_hid, n_layer):
        super(RNNLM, self).__init__()
        self.emb = nn.Embedding(vocab_size, d_emb)
        self.lstm = nn.LSTM(d_emb, d_hid, n_layer, batch_first=True)
        self.clf = nn.Linear(d_hid, vocab_size)
        
    def forward(self, x, hc=None):
        # token ids -> dense embeddings
        x = self.emb(x) # (bsz, seq_len, d_emb)
        bsz, seq_len, d_emb = x.shape
        
        # hs: (bsz    , seq_len, d_hid) -- top-layer output at every time step
        # h : (n_layer, bsz    , d_hid) -- final hidden state of every layer
        # so hs[:, -1] == h[-1]
        hs, (h, c) = self.lstm(x, hc)
        
        # flatten to (bsz * seq_len, d_hid) so the classifier scores every position
        hs = hs.reshape(bsz * seq_len, hs.size(2))
        logits = self.clf(hs)  # (bsz * seq_len, vocab_size)
        return logits, (h, c)
    
if __name__ == "__main__":
    # config
    vocab_size, d_emb, d_hid, n_layer = 16, 8, 4, 2
    lr, bsz = 1e-4, 2
    # data
    x = torch.tensor([[1, 2, 3], [7, 8, 9]], dtype=torch.long)   # input token ids
    y = torch.tensor([[2, 3, 4], [8, 9, 10]], dtype=torch.long)  # targets: x shifted by one (next-token prediction)
    # model
    model = RNNLM(vocab_size, d_emb, d_hid, n_layer)
    optimizer = SGD(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    # one training step, starting from a zero (h_0, c_0) state
    hc = (torch.zeros(n_layer, bsz, d_hid), torch.zeros(n_layer, bsz, d_hid))
    optimizer.zero_grad()
    logits, (h, c) = model(x, hc)
    loss = loss_fn(logits, y.reshape(-1))
    loss.backward()
    optimizer.step()
    print(loss.item())
    y_pred = logits.argmax(dim=-1).reshape(bsz, -1)  # greedy prediction at every position
    print(x, "\n==>\n", y_pred)
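
Because forward also returns the updated (h, c), the same model can be driven autoregressively: feed one token at a time, carry the state across steps, and take the argmax as the next input. A minimal greedy-decoding sketch under that assumption (generate is an illustrative helper, not part of the model above, and an untrained model will of course emit arbitrary tokens):

@torch.no_grad()
def generate(model, start_token, n_steps):
    # Greedy decoding: feed the previous prediction back in,
    # carrying the LSTM state (h, c) from step to step.
    model.eval()
    token = torch.tensor([[start_token]], dtype=torch.long)  # shape (1, 1)
    hc, out = None, [start_token]
    for _ in range(n_steps):
        logits, hc = model(token, hc)                # logits: (1, vocab_size)
        token = logits.argmax(dim=-1).reshape(1, 1)  # most likely next token
        out.append(token.item())
    return out

print(generate(model, start_token=1, n_steps=5))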