Here I try to solve an Empirical Bayes Poisson Mean problem using a VAE (details are described in https://www.overleaf.com/project/5bd084d90a33772e7a7f99a2, EBVAE.tex).
The model is very simple and can only model a truncated-normal prior family.
I need to run $10000$ iterations to get a relatively good result. (Previously I ran only $2000$ iterations, which was not enough.)
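For reference, the generative model assumed in this notebook is the standard EB Poisson mean setup (my summary; the details are in the linked EBVAE.tex):
$$
\lambda_i \overset{iid}{\sim} g(\cdot), \qquad x_i \mid \lambda_i \sim \mathrm{Poisson}(\lambda_i), \qquad i = 1, \dots, n,
$$
where $g$ is the prior to be estimated; the VAE learns both a prior $g_{\beta}$ within a parametric family and an approximate posterior $q(\lambda_i \mid x_i)$.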
from __future__ import print_function
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
import numpy as np
from matplotlib import pyplot as plt
class Map(dict):
    """A dict whose items are also accessible as attributes (used for the args object below)."""
    def __init__(self, *args, **kwargs):
        super(Map, self).__init__(*args, **kwargs)
        for arg in args:
            if isinstance(arg, dict):
                for k, v in arg.items():
                    self[k] = v
        if kwargs:
            for k, v in kwargs.items():
                self[k] = v

    def __getattr__(self, attr):
        return self.get(attr)

    def __setattr__(self, key, value):
        self.__setitem__(key, value)

    def __setitem__(self, key, value):
        super(Map, self).__setitem__(key, value)
        self.__dict__.update({key: value})

    def __delattr__(self, item):
        self.__delitem__(item)

    def __delitem__(self, key):
        super(Map, self).__delitem__(key)
        del self.__dict__[key]
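A quick usage example (my addition) of the Map class, just to show the attribute-style access used for the args object below:
# Example: items of a Map are accessible both as keys and as attributes.
m = Map(batch_size=16, epochs=3)
print(m.batch_size, m["epochs"])  # 16 3
m.lr = 1e-3                       # attribute assignment also updates the dict
print(m["lr"])                    # 0.001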
def sample_trunc_normal(mu, sd=1):
    """Draw one sample from N(mu, sd^2); negative draws are set to 0."""
    x = np.random.normal(loc=mu, scale=sd, size=1).item()
    return x if x > 0 else 0
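Note that setting negative draws to $0$ puts a point mass at $0$ rather than sampling an exactly truncated normal. If an exact truncated normal is wanted instead, one option (my suggestion, not what this notebook uses) is scipy.stats.truncnorm:
# Hypothetical alternative sampler: exact normal truncated to [0, inf).
# scipy.stats.truncnorm takes bounds standardized by loc and scale.
from scipy import stats
def sample_trunc_normal_exact(mu, sd=1):
    a, b = (0 - mu) / sd, np.inf
    return float(stats.truncnorm.rvs(a, b, loc=mu, scale=sd))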
def sample_poisson_trunc_normal(n, mu, sd, seed=123):
    """Simulate lam_i from the (truncated) normal prior, then x_i ~ Poisson(lam_i)."""
    np.random.seed(seed)
    lam = [sample_trunc_normal(mu, sd) for i in range(n)]
    x = np.random.poisson(lam=lam)
    return {"lam": lam, "x": x}
The Map class above is used to build the args object that holds the model arguments.
## simulate data
n = 5000
mu = 10
sd = 5
sim = sample_poisson_trunc_normal(n, mu, sd, seed=123)
plt.hist(sim["x"], bins=100)
plt.xlabel("x")
plt.title("histogram for data x")
plt.show()
## model arguments
args = Map(batch_size=n, epochs=10000, no_cuda=1, seed=123, log_interval=100, L=20)
args.cuda = not args.no_cuda and torch.cuda.is_available()
torch.manual_seed(args.seed)
device = torch.device("cuda" if args.cuda else "cpu")
kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
data = torch.Tensor(sim["x"].reshape(-1,1))
train_loader = torch.utils.data.DataLoader(
    data, batch_size=args.batch_size, shuffle=False, **kwargs)
class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()
        mid_len = 10
        self.fc1 = nn.Linear(1, mid_len)   # encoder hidden layer
        self.fc21 = nn.Linear(mid_len, 1)  # encoder output: mu
        self.fc22 = nn.Linear(mid_len, 1)  # encoder output: logvar
        self.fc3 = nn.Linear(1, 1)         # decoder f_beta: z -> theta

    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar, L=None):
        if L is None:  # L is the number of samples we take to compute E_q ...
            L = args.L
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std.unsqueeze(-1).expand(-1, L, -1))  # [batch_size, L, 1]
        return mu.unsqueeze(-1) + eps * std.unsqueeze(-1)            # [batch_size, L, 1]

    def decode(self, z):
        h3 = self.fc3(z).clamp(min=1e-20)  # keep theta positive for the Poisson log-likelihood
        return h3  # [batch_size, L, 1]

    def forward(self, x, L=None):
        mu, logvar = self.encode(x.view(-1, 1))  # mu, logvar: [batch_size, 1]
        z = self.reparameterize(mu, logvar, L)   # z: [batch_size, L, 1]
        theta = self.decode(z)                   # theta: [batch_size, L, 1]
        return theta, mu, logvar
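As a quick sanity check (my addition, not in the original notebook), the shapes coming out of a forward pass should match the comments above:
# Hypothetical shape check with a dummy batch of 4 observations and L = 3 samples.
_vae = VAE()
_theta, _mu, _logvar = _vae(torch.ones(4, 1), L=3)
print(_theta.shape, _mu.shape, _logvar.shape)  # [4, 3, 1], [4, 1], [4, 1]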
# Expected (under q) Poisson log-likelihood + KL divergence, summed over the batch
def loss_function(theta, x, mu, logvar):
    # Monte Carlo estimate (over the L samples) of E_q[log p(x | theta)],
    # dropping the constant -log(x!) term.
    poisson_ll = (-theta + x.view(-1, 1, 1) * torch.log(theta)).sum() / args.L
    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # KL(q || N(0,1)) = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return -poisson_ll + KLD
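For reference (my summary of what loss_function computes), the training loss is a Monte Carlo estimate of the negative ELBO:
$$
-\widehat{\mathrm{ELBO}} = -\sum_{i=1}^{n}\frac{1}{L}\sum_{l=1}^{L}\big(x_i \log \theta_{i,l} - \theta_{i,l}\big) + \sum_{i=1}^{n} \mathrm{KL}\big(N(\mu_i, \sigma_i^2)\,\|\,N(0,1)\big),
$$
where $\theta_{i,l} = f_{\beta}(\mu_i + \sigma_i \epsilon_{i,l})$ with $\epsilon_{i,l} \sim N(0,1)$, and the constant $-\log(x_i!)$ term is dropped since it does not depend on the parameters.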
def train(epoch):
    model.train()
    train_loss = 0
    for batch_idx, data in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        theta_batch, mu, logvar = model.forward(data)
        loss = loss_function(theta_batch, data, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
    if epoch % args.log_interval == 0:
        print('====> Epoch: {} Average loss: {:.4f}'.format(
            epoch, train_loss / len(train_loader.dataset)))
    return data
model = VAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.99))
if __name__ == "__main__":
    for epoch in range(1, args.epochs + 1):
        data = train(epoch)
Let's see how our posterior mean of $\lambda$ compares with the truth.
lam, mu, logvar = model.forward(data, L = 1000)
plt.scatter(sim["x"], lam.mean(dim = 1)[:,0].tolist())
plt.xlabel("x")
plt.ylabel("lam_posterior")
plt.show()
plt.scatter(sim["lam"], lam.mean(dim = 1)[:,0].tolist())
plt.xlabel("lam")
plt.ylabel("lam_posterior")
plt.show()
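To put a number on the comparison (this summary is my addition, not in the original notebook), we can compute the mean squared error of the posterior mean against the true $\lambda$, using $x$ itself as a baseline:
# Hypothetical numeric summary: posterior mean vs. truth, with x as a baseline.
post_mean = np.array(lam.mean(dim=1)[:, 0].tolist())
true_lam = np.array(sim["lam"])
print("MSE(posterior mean, lam):", np.mean((post_mean - true_lam) ** 2))
print("MSE(x, lam):", np.mean((sim["x"] - true_lam) ** 2))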
Let's take a look at $g_{\beta}(\theta)$.
We sample from $g_{\beta}(\theta)$ as follows:
1. Sample $z_i \sim p_0(\cdot) = N(\cdot\,; 0, 1)$.
2. Let $\theta_i = f_{\beta}(z_i)$.

Compared with the true $g(\theta)$, our model puts too much weight near 0, but the mean seems right.
N = len(sim["lam"])
z = torch.randn(N,1)
theta = model.decode(z)
bins = np.linspace(0, 25, 100)
plt.hist(theta[:,0].tolist(), bins, density = True, alpha=0.5, label='g_beta')
plt.hist(sim["lam"], bins, density = True, alpha=0.5, label='g_true')
plt.legend(loc='upper right')
plt.show()
Save the results for comparison with ebpm.
import pickle

out = {"x": sim["x"],
       "lam": np.array(sim["lam"]),
       "prior_vae": np.array(theta[:, 0].tolist()),
       "posterior_vae": np.array(lam.mean(dim=1)[:, 0].tolist())}
with open("../data/poisson-normal.pkl", "wb") as f:
    pickle.dump(out, f)
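For the comparison step, the saved pickle can be loaded back like this (a minimal sketch; the ebpm comparison itself lives elsewhere):
# Minimal sketch of reading the saved results back for the ebpm comparison.
with open("../data/poisson-normal.pkl", "rb") as f:
    res = pickle.load(f)
print(res.keys())  # dict_keys(['x', 'lam', 'prior_vae', 'posterior_vae'])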