是一种基本的人工神经网络(ANN)架构。

import torch
import torch.nn as nn
import torch.nn.functional as F

class MLP(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Dropout -> Linear.

    The input is flattened from dimension 1 onward before the first linear
    layer, so any tensor whose non-leading dimensions multiply to ``d_in``
    is accepted.
    """

    def __init__(
        self,
        d_in: int,
        d_hid: int,
        n_class: int,
        *,
        p_drop: float = 0.5,
    ) -> None:
        """Build the layers.

        Args:
            d_in: Number of input features after flattening.
            d_hid: Width of the hidden layer.
            n_class: Number of output units (one score per class).
            p_drop: Dropout probability applied to the hidden activations.
                Defaults to 0.5, the value that was previously hard-coded.
                ``nn.Dropout`` rejects values outside ``[0, 1]``.
        """
        super().__init__()
        self.fc1 = nn.Linear(d_in, d_hid)
        self.fc2 = nn.Linear(d_hid, n_class)
        self.dropout = nn.Dropout(p_drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the output of the final linear layer for ``x``.

        Args:
            x: Tensor with at least two dimensions; everything from
                dimension 1 onward is flattened into a single feature axis.

        Returns:
            The raw output of ``fc2`` (no softmax or other normalization is
            applied here), with ``n_class`` entries along the last axis.
        """
        # Collapse all dimensions except the first into one feature axis.
        flat = torch.flatten(x, 1)
        hidden = F.relu(self.fc1(flat))
        # nn.Dropout is only active while the module is in training mode.
        hidden = self.dropout(hidden)
        return self.fc2(hidden)
    
if __name__ == "__main__":
    # Smoke test: push a random batch of two 1x28x28 inputs through the
    # network and print the shape of the result.
    batch = torch.randn(2, 1, 28, 28)
    net = MLP(d_in=28 * 28, d_hid=128, n_class=10)
    scores = net(batch)
    print(scores.shape)  # torch.Size([2, 10])