RNN(Recurrent Neural Network,循环神经网络)是一类用于处理序列数据的神经网络结构。与传统的前馈神经网络(Feedforward Neural Networks)不同,RNN 具有循环连接。RNN 的每个时间步都会接收输入和前一个时间步的隐藏状态,并输出当前时间步的隐藏状态。这种结构使得 RNN 能够捕捉序列数据中的时间相关性。

#=> input x

import torch
import torch.nn as nn
from torch.optim import Adam

class RNN(nn.Module):
    """Recurrent language model: embedding -> multi-layer RNN -> linear decoder.

    Maps a sequence of token ids to per-position logits over the vocabulary,
    flattened so they can be fed directly to ``nn.CrossEntropyLoss``.

    Args:
        d_emb: embedding dimension.
        d_hid: hidden-state dimension of the RNN.
        vocab_size: number of tokens in the vocabulary.
        n_layer: number of stacked RNN layers.
        dropout: dropout probability applied to the embeddings, between RNN
            layers (via ``nn.RNN``), and to the RNN outputs.
    """

    def __init__(self, d_emb, d_hid, vocab_size, n_layer, dropout=0.5):
        super().__init__()  # zero-arg super (Python 3 idiom)
        self.drop = nn.Dropout(dropout)
        self.emb = nn.Embedding(vocab_size, d_emb)
        # batch_first=False (default): input ids are (seq_len, bsz).
        self.enc = nn.RNN(d_emb, d_hid, n_layer, dropout=dropout)
        self.dec = nn.Linear(d_hid, vocab_size)
        # Kept for init_hidden(); backward-compatible attribute additions.
        self.n_layer = n_layer
        self.d_hid = d_hid

        self.init_weights()

    def init_weights(self):
        """Initialize embedding/decoder weights uniformly in [-0.1, 0.1]."""
        bound = 0.1
        nn.init.uniform_(self.emb.weight, -bound, bound)
        nn.init.uniform_(self.dec.weight, -bound, bound)
        nn.init.zeros_(self.dec.bias)

    def init_hidden(self, bsz):
        """Return a zero initial hidden state of shape (n_layer, bsz, d_hid).

        Allocated on the same device/dtype as the model parameters, so it
        stays valid after ``.to(device)`` / dtype changes.
        """
        ref = self.dec.weight
        return ref.new_zeros(self.n_layer, bsz, self.d_hid)

    def forward(self, x, hid):
        """Run the model over one token sequence.

        Args:
            x: LongTensor of token ids, shape (seq_len, bsz).
            hid: initial hidden state, shape (n_layer, bsz, d_hid).

        Returns:
            logits: tensor of shape (seq_len * bsz, vocab_size), flattened
                for CrossEntropyLoss against flattened targets.
            h: final hidden state, shape (n_layer, bsz, d_hid).
        """
        emb = self.drop(self.emb(x))
        # Read dims from the input ids; the embedding dim was unused before.
        seq_len, bsz = x.shape
        # hs: top-layer output per time step, (seq_len, bsz, d_hid)
        # h : final hidden state per layer,   (n_layer, bsz, d_hid)
        hs, h = self.enc(emb, hid)
        hs = self.drop(hs)
        logits = self.dec(hs)
        # reshape instead of view: same result here, but safe even if the
        # tensor ever becomes non-contiguous.
        return logits.reshape(seq_len * bsz, -1), h


if __name__ == "__main__":
    # --- config ---
    d_emb, d_hid, vocab_size, n_layer = 8, 4, 16, 5
    bsz, lr = 2, 1e-4
    # --- data: next-token prediction, ids shaped (seq_len, bsz) ---
    # y is x shifted by one step: the target at each position is the next id.
    # torch.tensor(..., dtype=torch.long) builds long ids directly instead of
    # the old torch.Tensor(...).long() float-then-cast pattern.
    x = torch.tensor([[1, 7],
                      [2, 8],
                      [3, 9]], dtype=torch.long)
    y = torch.tensor([[2, 8],
                      [3, 9],
                      [4, 10]], dtype=torch.long)
    # --- model / optimizer / loss ---
    model = RNN(d_emb, d_hid, vocab_size, n_layer)
    optimizer = Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    # --- one training step ---
    h = torch.zeros(n_layer, bsz, d_hid)  # zero initial hidden state
    optimizer.zero_grad()
    logits, h = model(x, h)
    # logits: (seq_len * bsz, vocab); flatten targets to match.
    loss = loss_fn(logits, y.reshape(-1))
    loss.backward()
    optimizer.step()
    print(loss.item())