# MLP (May 1, 2024): the multilayer perceptron is a basic artificial neural
# network (ANN) architecture.
import torch
import torch.nn as nn
import torch.nn.functional as F


class MLP(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Dropout -> Linear.

    Args:
        d_in: Number of input features after flattening every dimension
            except the first (batch) one.
        d_hid: Width of the hidden layer.
        n_class: Number of output logits.
        dropout: Dropout probability applied to the hidden activations.
            Defaults to 0.5, the value this model previously hard-coded.
    """

    def __init__(
        self,
        d_in: int,
        d_hid: int,
        n_class: int,
        dropout: float = 0.5,
    ) -> None:
        super().__init__()
        self.fc1 = nn.Linear(d_in, d_hid)
        self.fc2 = nn.Linear(d_hid, n_class)
        # nn.Dropout is only active in training mode; call .eval() for
        # deterministic inference.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return unnormalized class logits for a batch.

        Args:
            x: Tensor of shape (batch, ...). All dimensions after the first
                are flattened, and their product must equal ``d_in``.

        Returns:
            Tensor of shape (batch, n_class) holding raw logits (no softmax).
        """
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)


if __name__ == "__main__":
    d_in, d_hid, n_class = 28 * 28, 128, 10
    x = torch.randn(2, 1, 28, 28)
    model = MLP(d_in, d_hid, n_class)
    # Disable dropout and autograd bookkeeping for a plain inference pass.
    model.eval()
    with torch.no_grad():
        logits = model(x)
    print(logits.shape)  # torch.Size([2, 10])