Bidirectional Long Short-Term Memory (Bi-LSTM) is an extension of the LSTM that takes both past and future context into account when processing sequential data. Unlike a standard unidirectional LSTM, which only propagates information from past to future, a Bi-LSTM introduces two independent LSTM layers that read the sequence in the forward and the backward direction, respectively, and can therefore capture more complete sequence information. In a Bi-LSTM, the input sequence is processed by the forward LSTM layer and, in reverse order, by the backward LSTM layer; the hidden states from the two directions are then combined (typically concatenated) to form the representation of the whole sequence.

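The PyTorch sketch below illustrates this as a minimal sequence classifier: each 28×28 image is treated as a sequence of 28 row vectors of dimension 28, encoded with `nn.LSTM(..., bidirectional=True)`, and the output of the last time step (which concatenates the forward and backward hidden states, hence `2 * d_hid` features) is passed to a linear classification head. The dimensions, random data, and class name are illustrative only.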

import torch
import torch.nn as nn
from torch.optim import Adam

class BiLSTM(nn.Module):
    def __init__(self, d_in, d_hid, n_layer, n_class):
        super(BiLSTM, self).__init__()
        # bidirectional=True adds a second LSTM that reads the sequence in reverse
        self.enc = nn.LSTM(d_in, d_hid, n_layer, batch_first=True, bidirectional=True)
        # the two directions are concatenated, so the feature size is 2 * d_hid
        self.dec = nn.Linear(d_hid * 2, n_class)
    
    def forward(self, x, state0):
        hs, (hn, cn) = self.enc(x, state0)  # hs: (bsz, seq_len, 2 * d_hid)
        logits = self.dec(hs[:, -1, :])     # classify from the last time step
        return logits

if __name__ == "__main__":
    # config
    d_in, d_hid, n_layer, n_class = 28, 4, 2, 10
    bsz, lr = 2, 1e-4
    # data: treat each 28x28 image as a sequence of 28 row vectors
    x = torch.randn(bsz, 1, 28, 28)             # dummy image batch
    y = torch.tensor([1, 2], dtype=torch.long)  # dummy labels
    # model
    model = BiLSTM(d_in, d_hid, n_layer, n_class)
    optimizer = Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    # training: one initial hidden and cell state per layer and per direction
    h0 = torch.zeros(n_layer * 2, bsz, d_hid)
    c0 = torch.zeros(n_layer * 2, bsz, d_hid)
    optimizer.zero_grad()
    logits = model(x.view(bsz, 28, 28), (h0, c0))
    loss = loss_fn(logits, y)
    loss.backward()
    optimizer.step()
    print(loss.item())
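
Because the encoder is bidirectional, every time step of `hs` contains the forward and the backward hidden states concatenated along the feature dimension. If the two directions need to be read out separately (a common alternative is to combine the forward state of the last step with the backward state of the first step), the output can be reshaped. The standalone sketch below, with the same illustrative dimensions as above and hypothetical variable names, shows one way to do this.

import torch
import torch.nn as nn

# minimal sketch: separating the forward / backward halves of a Bi-LSTM output
bsz, seq_len, d_in, d_hid, n_layer = 2, 28, 28, 4, 2
enc = nn.LSTM(d_in, d_hid, n_layer, batch_first=True, bidirectional=True)
hs, (hn, cn) = enc(torch.randn(bsz, seq_len, d_in))  # initial states default to zeros
hs_dir = hs.reshape(bsz, seq_len, 2, d_hid)          # (bsz, seq_len, num_directions, d_hid)
fwd_last = hs_dir[:, -1, 0, :]  # forward LSTM after reading the whole sequence
bwd_last = hs_dir[:, 0, 1, :]   # backward LSTM after reading the whole sequence (it reads right to left)
print(fwd_last.shape, bwd_last.shape)  # torch.Size([2, 4]) torch.Size([2, 4])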