VAE(Variational Autoencoder,变分自动编码器)是一种生成模型,它结合了自动编码器(Autoencoder)和变分推断(Variational Inference)的思想。VAE 包括两个主要部分:编码器(Encoder)和解码器(Decoder)。编码器将输入数据映射到潜在空间(latent space)的变量,而解码器将潜在空间的变量映射回原始数据空间。与传统的自动编码器不同,VAE 中潜在空间的变量要学习数据的分布。其大致过程如下:

#=> input x

import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F

class VAE(nn.Module):
    def __init__(self, d_in, d_hid, d_mu):
        super(VAE, self).__init__()
        self.enc = nn.Sequential(
            nn.Linear(d_in, d_hid),
            nn.ReLU())
        
        self.dec = nn.Sequential(
            nn.Linear(d_mu, d_hid),
            nn.ReLU(),
            nn.Linear(d_hid, d_in),
            nn.Sigmoid())
        
        self.mu = nn.Linear(d_hid, d_mu)
        self.var = nn.Linear(d_hid, d_mu)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        x = torch.flatten(x, start_dim=1)
        x = self.enc(x)
        mu, logvar = self.mu(x), self.var(x)
        z = self.reparameterize(mu, logvar)
        return self.dec(z), mu, logvar
    
def loss_fn(x_rec, x, mu, logvar):
    BCE = F.binary_cross_entropy(x_rec, x.view(-1, 784), reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

if __name__ == "__main__":
    # config
    d_in, d_hid, d_mu, lr = 784, 128, 16, 1e-4
    # data
    x = (torch.randn(2, 1, 28, 28) > 0) * 1.0 # img
    # model
    model = VAE(d_in, d_hid, d_mu)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    
    x_rec, mu, logvar = model(x)
    loss = loss_fn(x_rec, x, mu, logvar)
    print(loss.item())