GAN May 2, 2024 GAN(生成对抗网络)是一种机器学习模型,由两个网络组成:生成器和判别器。这两个网络相互对抗,使得生成器能够生成逼真的数据,而判别器则试图区分生成的数据与真实数据。GAN 的核心思想是通过竞争训练两个网络,从而使得生成器生成的数据越来越接近真实数据,同时使得判别器更难以区分生成的数据和真实数据。这种竞争过程最终会收敛到一个平衡点,生成器生成的数据无法被判别器区分出来。 import torch import torch.nn as nn class D(nn.Module): def __init__(self, d_in, d_hid): super(D, self).__init__() self.clf = nn.Sequential( nn.Linear(d_in, d_hid) , nn.LeakyReLU(0.2), nn.Linear(d_hid, d_hid), nn.LeakyReLU(0.2), nn.Linear(d_hid, 1), nn.Sigmoid()) def forward(self, x): logits = self.clf(x) return logits class G(nn.Module): def __init__(self, d_in, d_hid, d_out) -> None: super(G, self).__init__() self.gen = nn.Sequential( nn.Linear(d_in , d_hid), nn.ReLU(), nn.Linear(d_hid, d_hid), nn.ReLU(), nn.Linear(d_hid, d_out), nn.Tanh()) def forward(self, x): fake = self.gen(x) return fake def train_D(netD, x_real, x_fake, y): netD.zero_grad() y_pred = netD(x_real) l_real = loss_fn(y_pred, y) # l_real.backward() y_pred = netD(x_fake.detach()) l_fake = loss_fn(y_pred, y*0) # l_fake.backward() loss = l_real + l_fake loss.backward() optimD.step() print(l_real.item(), l_fake.item()) def train_G(netG, x_fake, y): netG.zero_grad() y_fake = netD(x_fake) loss = loss_fn(y_fake, y) loss.backward() optimG.step() print(loss.item()) if __name__ == "__main__": d_in, d_hid = 16, 32 netD = D(d_in, d_hid) optimD = torch.optim.SGD(netD.parameters(), lr=0.1) d_in, d_hid, d_out = 4, 8, 16 netG = G(d_in, d_hid, d_out) optimG = torch.optim.SGD(netG.parameters(), lr=0.1) loss_fn = nn.BCELoss() # Training noise = torch.randn(2, d_in) y = torch.ones(2, 1) # 0: fake, 1: real x_real = torch.randn(2, 16) x_fake = netG(noise) print(x_real.shape, x_fake.shape) train_D(netD, x_real, x_fake, y) train_G(netG, x_fake, y)