GAN
A GAN (Generative Adversarial Network) is a machine learning model built from two networks: a generator and a discriminator. The two are trained adversarially: the generator learns to produce realistic data, while the discriminator tries to tell generated data apart from real data. The core idea is that this competition pushes the generator's output ever closer to the real data distribution while making the discriminator's job harder. Ideally, training converges to an equilibrium at which the discriminator can no longer distinguish generated samples from real ones.
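Formally, the generator G and discriminator D play a minimax game (Goodfellow et al., 2014). With $D(x)$ the estimated probability that $x$ is real and $G(z)$ mapping noise $z$ to a sample, the objective is

$$\min_G \max_D \; \mathbb{E}_{x \sim p_{\mathrm{data}}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]$$

The minimal PyTorch implementation below follows this setup on toy random data: D is an MLP ending in a sigmoid, G maps a noise vector to a fake sample through an MLP, and each network receives one training step.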
import torch
import torch.nn as nn

class D(nn.Module):
    """Discriminator: an MLP that outputs the probability that a sample is real."""
    def __init__(self, d_in, d_hid):
        super().__init__()
        self.clf = nn.Sequential(
            nn.Linear(d_in, d_hid), nn.LeakyReLU(0.2),
            nn.Linear(d_hid, d_hid), nn.LeakyReLU(0.2),
            nn.Linear(d_hid, 1), nn.Sigmoid())  # sigmoid -> probability in (0, 1)

    def forward(self, x):
        prob = self.clf(x)  # post-sigmoid probability, not raw logits
        return prob

class G(nn.Module):
    """Generator: an MLP that maps a noise vector to a fake sample."""
    def __init__(self, d_in, d_hid, d_out):
        super().__init__()
        self.gen = nn.Sequential(
            nn.Linear(d_in, d_hid), nn.ReLU(),
            nn.Linear(d_hid, d_hid), nn.ReLU(),
            nn.Linear(d_hid, d_out), nn.Tanh())  # tanh squashes outputs to [-1, 1]

    def forward(self, z):
        fake = self.gen(z)
        return fake
def train_D(netD, x_real, x_fake, y):
    """One discriminator step: real samples labeled 1, fakes labeled 0.

    Relies on the module-level loss_fn and optimD defined in __main__.
    """
    netD.zero_grad()
    # Real batch: push D(x_real) towards 1.
    y_pred = netD(x_real)
    l_real = loss_fn(y_pred, y)
    # Fake batch: detach so no gradients flow back into the generator.
    y_pred = netD(x_fake.detach())
    l_fake = loss_fn(y_pred, y * 0)  # y * 0 -> all-zeros "fake" labels
    # Backpropagate both losses together and update D only.
    loss = l_real + l_fake
    loss.backward()
    optimD.step()
    print(l_real.item(), l_fake.item())
def train_G(netG, netD, x_fake, y):
    """One generator step with the non-saturating loss: label fakes as real (1)
    so G is trained to fool D. Relies on the module-level loss_fn and optimG.
    """
    netG.zero_grad()
    y_fake = netD(x_fake)  # gradients flow through D into G; D itself is updated in train_D
    loss = loss_fn(y_fake, y)
    loss.backward()
    optimG.step()
    print(loss.item())
if __name__ == "__main__":
    # Discriminator classifies 16-dimensional samples.
    d_in, d_hid = 16, 32
    netD = D(d_in, d_hid)
    optimD = torch.optim.SGD(netD.parameters(), lr=0.1)
    # Generator maps 4-dimensional noise to 16-dimensional fake samples.
    z_dim, g_hid, d_out = 4, 8, 16
    netG = G(z_dim, g_hid, d_out)
    optimG = torch.optim.SGD(netG.parameters(), lr=0.1)
    loss_fn = nn.BCELoss()  # binary cross-entropy on D's sigmoid output

    # One training step on a toy batch of 2 samples.
    noise = torch.randn(2, z_dim)
    y = torch.ones(2, 1)  # labels: 0 = fake, 1 = real
    x_real = torch.randn(2, d_in)  # random stand-in for a real data batch
    x_fake = netG(noise)
    print(x_real.shape, x_fake.shape)  # both should be torch.Size([2, 16])
    train_D(netD, x_real, x_fake, y)
    train_G(netG, netD, x_fake, y)
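The script above performs only a single update per network. In practice the two steps alternate for many iterations, with fresh noise (and, on a real task, fresh real batches) each time. A minimal sketch of such a loop, meant to be appended inside the __main__ block so the names defined there stay in scope; n_steps and the recomputed fakes are illustrative additions, not part of the original:

    # Hypothetical extension (a sketch, not in the original script):
    # alternate one D step and one G step per iteration.
    n_steps = 100  # illustrative; the original performs a single step
    for step in range(n_steps):
        noise = torch.randn(2, z_dim)
        x_real = torch.randn(2, d_in)  # stand-in; use real data batches in practice
        x_fake = netG(noise)
        y = torch.ones(2, 1)
        train_D(netD, x_real, x_fake, y)
        # Recompute fakes so G's update uses a fresh graph through the generator.
        x_fake = netG(torch.randn(2, z_dim))
        train_G(netG, netD, x_fake, y)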