Here I try to solve an Empirical Bayes Poisson Mean problem using a VAE (details described in https://www.overleaf.com/project/5bd084d90a33772e7a7f99a2 EBVAE.tex).

The model is very simple and can only represent priors in the truncated-normal family.

I need to run $10000$ iterations to get a relatively good result. (Previously I ran only $2000$ iterations, which was not enough.)
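
For concreteness, the setup (summarizing the simulation and loss code below) is:

  • Prior: $\lambda_i \sim g$, where $g$ is $N(\mu, \sigma^2)$ truncated at $0$; likelihood: $x_i \sim Pois(\lambda_i)$ for $i = 1, \dots, n$.
  • VAE: the prior is modeled as $\theta = f_{\beta}(z)$ with $z \sim N(0, 1)$, the encoder gives a Gaussian variational posterior $q(z \mid x)$, and training minimizes the negative ELBO: the KL term minus the expected Poisson log-likelihood under $q$ (up to the constant $-\log x!$).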

In [1]:
from __future__ import print_function
import argparse
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image

import numpy as np
from matplotlib import pyplot as plt

import pdb
In [2]:
class Map(dict):
    def __init__(self, *args, **kwargs):
        super(Map, self).__init__(*args, **kwargs)
        for arg in args:
            if isinstance(arg, dict):
                for k, v in arg.items():
                    self[k] = v

        if kwargs:
            for k, v in kwargs.items():
                self[k] = v

    def __getattr__(self, attr):
        return self.get(attr)

    def __setattr__(self, key, value):
        self.__setitem__(key, value)

    def __setitem__(self, key, value):
        super(Map, self).__setitem__(key, value)
        self.__dict__.update({key: value})

    def __delattr__(self, item):
        self.__delitem__(item)

    def __delitem__(self, key):
        super(Map, self).__delitem__(key)
        del self.__dict__[key]
In [3]:
def sample_trunc_normal(mu, sd = 1):
    ## draw from N(mu, sd^2); negative draws are set to 0
    x = np.random.normal(loc = mu, scale = sd, size = 1).item()
    if x > 0:
        return x
    else:
        return 0
    
def sample_poisson_trunc_normal(n, mu, sd, seed = 123):
    np.random.seed(seed)
    lam = [sample_trunc_normal(mu, sd) for i in range(n)]
    print(len(lam))
    x = np.random.poisson(lam = lam)
    return {"lam":lam, "x":x}
  • Simulate data: the $\lambda$s are drawn from the truncated normal; then $x \sim Pois(\lambda)$. I plot the histogram of the data below.
  • Model: specified in the args Map below.
In [4]:
## simulate data
n = 5000
mu = 10
sd = 5
sim = sample_poisson_trunc_normal(n, mu,sd, seed = 123)
5000
In [5]:
plt.hist(sim["x"], bins=100)
plt.xlabel("x")
plt.ylabel("count")
plt.title("histogram for data x")
plt.show()
In [6]:
## model arguments; L is the number of Monte Carlo samples used for the expectation under q
args = Map(batch_size = n, epochs = 10000, no_cuda = 1, seed = 123, log_interval= 100, L = 20)
args.cuda = not args.no_cuda and torch.cuda.is_available()
torch.manual_seed(args.seed)
device = torch.device("cuda" if args.cuda else "cpu")
kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}

data = torch.Tensor(sim["x"].reshape(-1,1))
train_loader = torch.utils.data.DataLoader(data, 
                                           batch_size=args.batch_size, shuffle=False, **kwargs)
In [7]:
class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()

        mid_len = 10
        self.fc1 = nn.Linear(1, mid_len)
        self.fc21 = nn.Linear(mid_len, 1)
        self.fc22 = nn.Linear(mid_len, 1)
        self.fc3 = nn.Linear(1, 1)

    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)
 
    def reparameterize(self, mu, logvar, L = None):
        if L is None: ## L is the number of samples we take to compute E_q ...
            L = args.L
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std.unsqueeze(-1).expand(-1,L, -1)) ## [batch_size, args.L, 1]
        return mu.unsqueeze(-1) + eps*std.unsqueeze(-1) ## [batch_size, args.L, 1]

    def decode(self, z):
        h3 = self.fc3(z).clamp(min = 1e-20) ## clamp so the Poisson rate stays positive
        return h3 ## [batch_size, args.L, 1]

    def forward(self, x, L = None):
        mu, logvar = self.encode(x.view(-1, 1)) ## mu, logvar: [batch_size, 1]
        z = self.reparameterize(mu, logvar, L) ## z: [batch_size, args.L, 1]
        theta = self.decode(z) ## theta: [batch_size, args.L, 1]
        return theta, mu, logvar
    

# Loss = negative ELBO: expected (under q) Poisson log-likelihood (up to the constant -log(x!)),
# averaged over the L Monte Carlo samples, plus the KL divergence, summed over the batch
def loss_function(theta, x, mu, logvar):
    poisson_ll = (- theta + x.view(-1,1,1) * torch.log(theta)).sum()/args.L
    # KL(q(z|x) || N(0,1)); see Appendix B from the VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # KLD = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return - poisson_ll + KLD


def train(epoch):
    model.train()
    train_loss = 0
    for batch_idx, data in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        theta_batch, mu, logvar = model.forward(data)
        loss = loss_function(theta_batch, data, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
    if epoch % args.log_interval == 0:
        print('====> Epoch: {} Average loss: {:.4f}'.format(
              epoch, train_loss / len(train_loader.dataset)))
    
    return data

      
model = VAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.99))
if __name__ == "__main__":
    for epoch in range(1, args.epochs + 1):
        data = train(epoch)
====> Epoch: 100 Average loss: -0.0074
====> Epoch: 200 Average loss: -1.7422
====> Epoch: 300 Average loss: -3.0664
====> Epoch: 400 Average loss: -4.1306
====> Epoch: 500 Average loss: -5.0390
====> Epoch: 600 Average loss: -5.8160
====> Epoch: 700 Average loss: -6.5002
====> Epoch: 800 Average loss: -7.1114
====> Epoch: 900 Average loss: -7.6601
====> Epoch: 1000 Average loss: -8.1486
====> Epoch: 1100 Average loss: -8.5819
====> Epoch: 1200 Average loss: -8.9880
====> Epoch: 1300 Average loss: -9.3480
====> Epoch: 1400 Average loss: -9.6796
====> Epoch: 1500 Average loss: -9.9907
====> Epoch: 1600 Average loss: -10.2608
====> Epoch: 1700 Average loss: -10.5398
====> Epoch: 1800 Average loss: -10.7647
====> Epoch: 1900 Average loss: -10.9893
====> Epoch: 2000 Average loss: -11.1824
====> Epoch: 2100 Average loss: -11.3477
====> Epoch: 2200 Average loss: -11.5563
====> Epoch: 2300 Average loss: -11.6966
====> Epoch: 2400 Average loss: -11.8118
====> Epoch: 2500 Average loss: -11.9771
====> Epoch: 2600 Average loss: -12.1144
====> Epoch: 2700 Average loss: -12.2207
====> Epoch: 2800 Average loss: -12.3347
====> Epoch: 2900 Average loss: -12.4470
====> Epoch: 3000 Average loss: -12.5348
====> Epoch: 3100 Average loss: -12.6192
====> Epoch: 3200 Average loss: -12.6994
====> Epoch: 3300 Average loss: -12.7650
====> Epoch: 3400 Average loss: -12.8402
====> Epoch: 3500 Average loss: -12.8978
====> Epoch: 3600 Average loss: -12.9593
====> Epoch: 3700 Average loss: -13.0075
====> Epoch: 3800 Average loss: -13.0148
====> Epoch: 3900 Average loss: -13.0733
====> Epoch: 4000 Average loss: -13.1366
====> Epoch: 4100 Average loss: -13.1594
====> Epoch: 4200 Average loss: -13.2000
====> Epoch: 4300 Average loss: -13.2481
====> Epoch: 4400 Average loss: -13.2602
====> Epoch: 4500 Average loss: -13.3010
====> Epoch: 4600 Average loss: -13.3461
====> Epoch: 4700 Average loss: -13.3708
====> Epoch: 4800 Average loss: -13.4186
====> Epoch: 4900 Average loss: -13.4256
====> Epoch: 5000 Average loss: -13.4209
====> Epoch: 5100 Average loss: -13.4588
====> Epoch: 5200 Average loss: -13.4928
====> Epoch: 5300 Average loss: -13.5179
====> Epoch: 5400 Average loss: -13.5233
====> Epoch: 5500 Average loss: -13.5515
====> Epoch: 5600 Average loss: -13.5651
====> Epoch: 5700 Average loss: -13.5830
====> Epoch: 5800 Average loss: -13.5922
====> Epoch: 5900 Average loss: -13.5917
====> Epoch: 6000 Average loss: -13.6155
====> Epoch: 6100 Average loss: -13.6376
====> Epoch: 6200 Average loss: -13.6479
====> Epoch: 6300 Average loss: -13.6474
====> Epoch: 6400 Average loss: -13.6627
====> Epoch: 6500 Average loss: -13.6651
====> Epoch: 6600 Average loss: -13.6763
====> Epoch: 6700 Average loss: -13.7074
====> Epoch: 6800 Average loss: -13.6985
====> Epoch: 6900 Average loss: -13.7036
====> Epoch: 7000 Average loss: -13.7178
====> Epoch: 7100 Average loss: -13.7113
====> Epoch: 7200 Average loss: -13.7311
====> Epoch: 7300 Average loss: -13.7408
====> Epoch: 7400 Average loss: -13.7633
====> Epoch: 7500 Average loss: -13.7543
====> Epoch: 7600 Average loss: -13.7683
====> Epoch: 7700 Average loss: -13.7718
====> Epoch: 7800 Average loss: -13.7906
====> Epoch: 7900 Average loss: -13.7877
====> Epoch: 8000 Average loss: -13.7883
====> Epoch: 8100 Average loss: -13.7846
====> Epoch: 8200 Average loss: -13.7998
====> Epoch: 8300 Average loss: -13.7658
====> Epoch: 8400 Average loss: -13.8039
====> Epoch: 8500 Average loss: -13.8108
====> Epoch: 8600 Average loss: -13.7875
====> Epoch: 8700 Average loss: -13.7951
====> Epoch: 8800 Average loss: -13.8043
====> Epoch: 8900 Average loss: -13.8078
====> Epoch: 9000 Average loss: -13.8250
====> Epoch: 9100 Average loss: -13.8243
====> Epoch: 9200 Average loss: -13.8233
====> Epoch: 9300 Average loss: -13.8242
====> Epoch: 9400 Average loss: -13.8263
====> Epoch: 9500 Average loss: -13.8329
====> Epoch: 9600 Average loss: -13.8161
====> Epoch: 9700 Average loss: -13.8499
====> Epoch: 9800 Average loss: -13.8256
====> Epoch: 9900 Average loss: -13.8614
====> Epoch: 10000 Average loss: -13.8578
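
As a quick sanity check on the KL term in loss_function, the closed-form expression should agree with torch.distributions.kl_divergence between $N(\mu(x), \sigma^2(x))$ and $N(0, 1)$. A minimal sketch (not run as part of the original analysis), using the trained model and data from above:

In [ ]:
from torch.distributions import Normal, kl_divergence

with torch.no_grad():
    mu_chk, logvar_chk = model.encode(data.view(-1, 1))
    ## closed form used in loss_function
    kl_closed = -0.5 * torch.sum(1 + logvar_chk - mu_chk.pow(2) - logvar_chk.exp())
    ## same quantity via torch.distributions
    kl_dist = kl_divergence(Normal(mu_chk, torch.exp(0.5 * logvar_chk)), Normal(0., 1.)).sum()
    print(kl_closed.item(), kl_dist.item())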

Let's see how the posterior mean of $\lambda$ compares with the truth.

In [8]:
with torch.no_grad():
    lam, mu, logvar = model(data, L = 1000) ## posterior samples of theta: [n, 1000, 1]
plt.scatter(sim["x"], lam.mean(dim = 1)[:,0].tolist())
plt.xlabel("x")
plt.ylabel("lam_posterior")
plt.show()
In [9]:
plt.scatter(sim["lam"], lam.mean(dim = 1)[:,0].tolist())
plt.xlabel("lam")
plt.ylabel("lam_posterior")
plt.show()

Let's take a look at $g_{\beta}(\theta)$.

We sample from $g_{\beta}(\theta)$ as follows:

  • Sample $z_i \sim p_0(\cdot) = N(\cdot\,; 0, 1)$

  • Let $\theta_i = f_{\beta}(z_i)$

Compared with the true $g(\theta)$, our model puts too much weight near $0$, but the mean seems about right.

In [10]:
N = len(sim["lam"])
z = torch.randn(N,1)
with torch.no_grad():
    theta = model.decode(z) ## samples from the fitted prior g_beta

bins = np.linspace(0, 25, 100)
plt.hist(theta[:,0].tolist(), bins, density = True, alpha=0.5, label='g_beta')
plt.hist(sim["lam"], bins, density = True, alpha=0.5, label='g_true')
plt.legend(loc='upper right')
plt.show()

Save the results for comparison with ebpm.

In [12]:
import pickle
out = {"x": sim["x"],
       "lam": np.array(sim["lam"]),
       "prior_vae": np.array(theta[:,0].tolist()),
       "posterior_vae": np.array(lam.mean(dim = 1)[:,0].tolist())}
with open("../data/poisson-normal.pkl", "wb") as f:
    pickle.dump(out, f)
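
For reference, the saved file can be read back in as follows (a minimal sketch of the reload step, not part of this notebook's analysis):

In [ ]:
with open("../data/poisson-normal.pkl", "rb") as f:
    res = pickle.load(f)
print(res.keys())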