MLP
MLP（Multilayer Perceptron，多层感知机）是一种基本的人工神经网络（ANN）架构。
import torch
import torch.nn as nn
import torch.nn.functional as F
class MLP(nn.Module):
    """Two-layer perceptron: flatten -> Linear -> ReLU -> Dropout -> Linear.

    Args:
        d_in: Number of input features after flattening every dimension
            except the first (batch) one; must equal ``x[0].numel()`` for
            tensors passed to :meth:`forward`.
        d_hid: Width of the hidden layer.
        n_class: Number of output scores per sample.
        dropout: Dropout probability applied to the hidden activations
            (only active in training mode). Defaults to 0.5, the value the
            original implementation hard-coded.
    """

    def __init__(self, d_in: int, d_hid: int, n_class: int, *, dropout: float = 0.5) -> None:
        super().__init__()
        self.fc1 = nn.Linear(d_in, d_hid)
        self.fc2 = nn.Linear(d_hid, n_class)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute raw class scores for a batch.

        Args:
            x: Tensor whose first dimension is the batch; all remaining
                dimensions are flattened into ``d_in`` features.

        Returns:
            Tensor of shape ``(batch, n_class)``. No softmax is applied.
        """
        x = torch.flatten(x, 1)  # keep dim 0 (batch), flatten the rest
        x = F.relu(self.fc1(x))
        x = self.dropout(x)  # identity in eval() mode
        return self.fc2(x)
if __name__ == "__main__":
    # Smoke test: run a random batch of shape (2, 1, 28, 28) through the
    # model and print the shape of the resulting scores.
    x = torch.randn(2, 1, 28, 28)
    # Derive d_in from one sample so it always matches the flattened input
    # (1 * 28 * 28 = 784 here).
    d_in, d_hid, n_class = x[0].numel(), 128, 10
    model = MLP(d_in, d_hid, n_class)
    model.eval()  # disable dropout so the forward pass is deterministic
    with torch.no_grad():  # inference only; no autograd graph needed
        logits = model(x)
    print(logits.shape)  # torch.Size([2, 10])