VAE
VAE(Variational Autoencoder,变分自动编码器)是一种生成模型,它结合了自动编码器(Autoencoder)和变分推断(Variational Inference)的思想。VAE 包括两个主要部分:编码器(Encoder)和解码器(Decoder)。编码器将输入数据映射到潜在空间(latent space)的变量,而解码器将潜在空间的变量映射回原始数据空间。与传统的自动编码器不同,VAE 中潜在空间的变量要学习数据的分布。其大致过程如下:
#=> input x
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
class VAE(nn.Module):
def __init__(self, d_in, d_hid, d_mu):
super(VAE, self).__init__()
self.enc = nn.Sequential(
nn.Linear(d_in, d_hid),
nn.ReLU())
self.dec = nn.Sequential(
nn.Linear(d_mu, d_hid),
nn.ReLU(),
nn.Linear(d_hid, d_in),
nn.Sigmoid())
self.mu = nn.Linear(d_hid, d_mu)
self.var = nn.Linear(d_hid, d_mu)
def reparameterize(self, mu, logvar):
std = torch.exp(0.5*logvar)
eps = torch.randn_like(std)
return mu + eps*std
def forward(self, x):
x = torch.flatten(x, start_dim=1)
x = self.enc(x)
mu, logvar = self.mu(x), self.var(x)
z = self.reparameterize(mu, logvar)
return self.dec(z), mu, logvar
def loss_fn(x_rec, x, mu, logvar):
BCE = F.binary_cross_entropy(x_rec, x.view(-1, 784), reduction='sum')
KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
return BCE + KLD
if __name__ == "__main__":
# config
d_in, d_hid, d_mu, lr = 784, 128, 16, 1e-4
# data
x = (torch.randn(2, 1, 28, 28) > 0) * 1.0 # img
# model
model = VAE(d_in, d_hid, d_mu)
optimizer = optim.Adam(model.parameters(), lr=lr)
x_rec, mu, logvar = model(x)
loss = loss_fn(x_rec, x, mu, logvar)
print(loss.item())