GAN(生成对抗网络)是一种机器学习模型,由两个网络组成:生成器和判别器。这两个网络相互对抗,使得生成器能够生成逼真的数据,而判别器则试图区分生成的数据与真实数据。GAN 的核心思想是通过竞争训练两个网络,从而使得生成器生成的数据越来越接近真实数据,同时使得判别器更难以区分生成的数据和真实数据。这种竞争过程最终会收敛到一个平衡点,生成器生成的数据无法被判别器区分出来。

import torch 
import torch.nn as nn

class D(nn.Module):
    def __init__(self, d_in, d_hid):
        super(D, self).__init__()

        self.clf = nn.Sequential(
            nn.Linear(d_in, d_hid) , nn.LeakyReLU(0.2),
            nn.Linear(d_hid, d_hid), nn.LeakyReLU(0.2),
            nn.Linear(d_hid, 1), nn.Sigmoid())
    
    def forward(self, x):
        logits = self.clf(x)
        return logits

class G(nn.Module):
    def __init__(self, d_in, d_hid, d_out) -> None:
        super(G, self).__init__()

        self.gen = nn.Sequential(
            nn.Linear(d_in , d_hid), nn.ReLU(),
            nn.Linear(d_hid, d_hid), nn.ReLU(),
            nn.Linear(d_hid, d_out), nn.Tanh())
        
    def forward(self, x):
        fake = self.gen(x)
        return fake
    

def train_D(netD, x_real, x_fake, y):
    netD.zero_grad()
    y_pred = netD(x_real)
    l_real = loss_fn(y_pred, y)
    # l_real.backward()
    y_pred = netD(x_fake.detach())
    l_fake = loss_fn(y_pred, y*0)
    # l_fake.backward()
    loss = l_real + l_fake
    loss.backward()
    optimD.step()
    print(l_real.item(), l_fake.item())


def train_G(netG, x_fake, y):
    netG.zero_grad()
    y_fake = netD(x_fake)
    loss = loss_fn(y_fake, y)
    loss.backward()
    optimG.step()
    print(loss.item())

if __name__ == "__main__":
    d_in, d_hid = 16, 32
    netD = D(d_in, d_hid)
    optimD = torch.optim.SGD(netD.parameters(), lr=0.1)

    d_in, d_hid, d_out = 4, 8, 16
    netG = G(d_in, d_hid, d_out)
    optimG = torch.optim.SGD(netG.parameters(), lr=0.1)
    loss_fn = nn.BCELoss()

    # Training
    noise = torch.randn(2, d_in)
    y = torch.ones(2, 1) # 0: fake, 1: real
    x_real = torch.randn(2, 16)
    x_fake = netG(noise)
    print(x_real.shape, x_fake.shape)

    train_D(netD, x_real, x_fake, y)
    train_G(netG, x_fake, y)