"""VAE demo — May 2, 2024.

A VAE (Variational Autoencoder) is a generative model combining an
autoencoder with variational inference. It has two main parts: an
encoder that maps input data to latent-space variables and a decoder
that maps latent variables back to data space. Unlike a plain
autoencoder, the latent variables are trained to model the data
distribution. Rough pipeline:

    input x -> encoder -> (mu, logvar) -> reparameterize -> z -> decoder -> x_rec
"""
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F


class VAE(nn.Module):
    """Minimal fully-connected VAE.

    Args:
        d_in:  flattened input dimension (e.g. 784 for 28x28 images).
        d_hid: hidden-layer width shared by encoder and decoder.
        d_mu:  latent dimension (size of mu / logvar vectors).
    """

    def __init__(self, d_in, d_hid, d_mu):
        super().__init__()
        self.enc = nn.Sequential(
            nn.Linear(d_in, d_hid),
            nn.ReLU())
        # Sigmoid keeps the reconstruction in [0, 1] so it can be
        # compared against binary/normalized inputs with BCE.
        self.dec = nn.Sequential(
            nn.Linear(d_mu, d_hid),
            nn.ReLU(),
            nn.Linear(d_hid, d_in),
            nn.Sigmoid())
        # Two heads on top of the encoder: mean and log-variance of q(z|x).
        self.mu = nn.Linear(d_hid, d_mu)
        self.var = nn.Linear(d_hid, d_mu)

    def reparameterize(self, mu, logvar):
        """Sample z ~ N(mu, sigma^2) via the reparameterization trick.

        Using z = mu + eps * sigma (eps ~ N(0, I)) keeps the sampling
        step differentiable with respect to mu and logvar.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Encode x, sample a latent z, and decode it.

        Returns:
            (x_rec, mu, logvar): reconstruction of shape (batch, d_in)
            plus the latent-distribution parameters, each (batch, d_mu).
        """
        x = torch.flatten(x, start_dim=1)
        x = self.enc(x)
        mu, logvar = self.mu(x), self.var(x)
        z = self.reparameterize(mu, logvar)
        return self.dec(z), mu, logvar


def loss_fn(x_rec, x, mu, logvar):
    """ELBO loss: reconstruction BCE + KL divergence to N(0, I).

    Args:
        x_rec:  decoder output, shape (batch, d_in), values in [0, 1].
        x:      original input of any shape with batch as dim 0; it is
                flattened here to match x_rec (was hard-coded to 784,
                which broke any d_in != 784).
        mu, logvar: latent-distribution parameters from the encoder.

    Returns:
        Scalar tensor: summed (not averaged) BCE + KLD over the batch.
    """
    BCE = F.binary_cross_entropy(x_rec, x.view(x.size(0), -1), reduction='sum')
    # Closed-form KL(N(mu, sigma^2) || N(0, I)).
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD


if __name__ == "__main__":
    # config
    d_in, d_hid, d_mu, lr = 784, 128, 16, 1e-4
    # data: a random binary "image" batch, shape (2, 1, 28, 28)
    x = (torch.randn(2, 1, 28, 28) > 0) * 1.0
    # model
    model = VAE(d_in, d_hid, d_mu)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    x_rec, mu, logvar = model(x)
    loss = loss_fn(x_rec, x, mu, logvar)
    print(loss.item())
    # One optimization step so the optimizer is actually exercised
    # (it was previously created but never used).
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()