
Poisson Identifiable VAE (pi-VAE) 2.0

This is a PyTorch implementation of the Poisson Identifiable Variational Autoencoder (pi-VAE), which constructs latent variable models of neural activity while simultaneously modeling the relationship between the latent variables and task variables (non-neural variables, e.g. sensory, motor, and other externally observable states).

A special thank you to Zhongxuan Wu who helped in the design and testing of this implementation.


Model Versions

pi-VAE 1.0 and 2.0 differ solely in their loss function, specifically how the Kullback–Leibler divergence component of the loss is computed. Additional information is available in the Loss Function - ELBOLoss section of this documentation.

Version 2.0

Version 1.0

Installation

Install the latest release from PyPI using pip:

pip install pi-vae-pytorch

or conda, using the conda-forge channel:

conda install -c conda-forge pi-vae-pytorch

Alternatively, clone the repository and install it in editable mode:

git clone https://github.com/mmcinnestaylor/pi-vae-pytorch.git
cd pi-vae-pytorch
pip install -e .
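To confirm the installation, importing the classes used throughout this documentation should succeed:

python -c "from pi_vae_pytorch import PiVAE, ELBOLoss"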

Model Architecture

pi-VAE comprises three main components: the encoder, the label prior estimator, and the decoder.

MLP Structure

The multilayer perceptron (MLP) is the primary building block of these components. Each MLP in this implementation is configured through the corresponding parameters passed when PiVAE is initialized (see Initialization below).
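As an illustration of how these parameters shape an MLP, here is a minimal sketch of such a configurable stack in PyTorch. The build_mlp helper is hypothetical and not part of this package; it only mirrors the *_n_hidden_layers, *_hidden_layer_dim, and *_hidden_layer_activation parameters accepted by PiVAE.

import torch.nn as nn

def build_mlp(in_dim, out_dim, n_hidden_layers=2, hidden_layer_dim=128, activation=nn.Tanh):
    # n_hidden_layers fully connected layers of width hidden_layer_dim,
    # each followed by the chosen activation, then a linear output layer
    layers = [nn.Linear(in_dim, hidden_layer_dim), activation()]
    for _ in range(n_hidden_layers - 1):
        layers += [nn.Linear(hidden_layer_dim, hidden_layer_dim), activation()]
    layers.append(nn.Linear(hidden_layer_dim, out_dim))
    return nn.Sequential(*layers)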

Encoder

The model’s encoder consists of a single MLP, which learns to approximate the distribution q(z | x).

Label Prior Estimator

The model’s label prior estimator learns to approximate the distribution p(z | u). In the discrete label regime, this module consists of two nn.Embedding submodules; in the continuous label regime, it consists of a single MLP.
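In the discrete regime, the two embedding tables act as lookup tables from a label index to the mean and log variance of p(z | u). The following is a hypothetical sketch of that idea, not the package's internal module:

import torch
import torch.nn as nn

class DiscreteLabelPriorSketch(nn.Module):
    # one embedding table per statistic: label id -> mean / log variance of p(z | u)
    def __init__(self, u_dim, z_dim):
        super().__init__()
        self.mean = nn.Embedding(u_dim, z_dim)
        self.log_variance = nn.Embedding(u_dim, z_dim)

    def forward(self, u): # u: Size([n_samples]) of integer labels in [0, u_dim)
        return self.mean(u), self.log_variance(u)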

Decoder

The model’s decoder learns to map a latent sample z to its predicted firing rate in the model’s observation space. Inputs to the decoder are passed through a sequence of affine and GIN normalizing flow submodules, configured via the decoder_* initialization parameters.
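As a small illustration of what decoder_fr_clamp_min and decoder_fr_clamp_max control (a sketch under assumption, not the package's internal code): the predicted firing rate is clamped to a strictly positive, bounded range, which keeps terms such as log(rate) in the Poisson likelihood numerically stable.

import torch

raw_rate = torch.randn(1, 100).exp() # stand-in for an unclamped decoder output
firing_rate = torch.clamp(raw_rate, min=1E-7, max=1E7) # decoder_fr_clamp_min / decoder_fr_clamp_max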

Initialization

pi_vae_pytorch.PiVAE(
    x_dim,
    u_dim,
    z_dim,
    discrete_labels=True,
    encoder_n_hidden_layers=2,
    encoder_hidden_layer_dim=128,
    encoder_hidden_layer_activation=nn.Tanh,
    decoder_n_gin_blocks=2,
    decoder_gin_block_depth=2,
    decoder_affine_input_layer_slice_dim=None,
    decoder_affine_n_hidden_layers=2,
    decoder_affine_hidden_layer_dim=None,
    decoder_affine_hidden_layer_activation=nn.ReLU,
    decoder_nflow_n_hidden_layers=2,
    decoder_nflow_hidden_layer_dim=None,
    decoder_nflow_hidden_layer_activation=nn.ReLU,
    decoder_observation_model='poisson',
    decoder_fr_clamp_min=1E-7,
    decoder_fr_clamp_max=1E7,
    label_prior_n_hidden_layers=2,
    label_prior_hidden_layer_dim=32,
    label_prior_hidden_layer_activation=nn.Tanh)

Attributes

Basic operation

For every observation space sample x and associated label u provided to pi-VAE’s forward method, the encoder and label prior modules each produce a mean and log variance. These statistics are then combined to obtain the mean and log variance of the posterior q(z | x, u).

The reparameterization trick is applied to the posterior mean and log variance to draw the sample’s representation in the model’s latent space. This latent representation is then passed through the model’s decoder module, which generates the predicted firing rate in the model’s observation space.
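The reparameterization step is the standard one used in VAEs; a minimal sketch of the technique (not necessarily the package's exact internals):

import torch

def reparameterize(mean, log_variance):
    # z = mean + sigma * epsilon, where epsilon ~ N(0, I) and sigma = exp(0.5 * log_variance)
    std = torch.exp(0.5 * log_variance)
    eps = torch.randn_like(std)
    return mean + eps * std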

Inputs

Outputs

A dict with the following items:

Inference Mode

A dict with the following items:

Examples

Continuous Labels

import torch
from pi_vae_pytorch import PiVAE

model = PiVAE(
    x_dim = 100,
    u_dim = 3,
    z_dim = 2,
    discrete_labels=False
)

x = torch.randn(1, 100) # Size([n_samples, x_dim])

u = torch.randn(1, 3) # Size([n_samples, u_dim])

outputs = model(x, u) # dict

Discrete Labels

import torch
from pi_vae_pytorch import PiVAE

model = PiVAE(
    x_dim = 100,
    u_dim = 3,
    z_dim = 2,
    discrete_labels=True
)

x = torch.randn(1, 100) # Size([n_samples, x_dim])

u = torch.randint(3, (1,)) # Size([n_samples]); integer labels in [0, u_dim) with u_dim=3

outputs = model(x, u) # dict

Class Methods

Static Methods

Loss Function - ELBOLoss

pi-VAE learns the deep generative model and the approximate posterior q(z | x, u) of the true posterior p(z | x, u) by maximizing the evidence lower bound (ELBO) of p(x | u). This loss function is implemented in the included ELBOLoss class.
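In standard form, the bound being maximized is

\log p(x \mid u) \;\geq\; \mathbb{E}_{q(z \mid x, u)}\!\left[\log p(x \mid z)\right] - D_{\mathrm{KL}}\!\left(q(z \mid x, u) \,\|\, p(z \mid u)\right)

where the first term is the reconstruction likelihood under the chosen observation model and the second is the Kullback–Leibler divergence; versions 1.0 and 2.0 of the model differ only in how this KL term is computed.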

Initialization

pi_vae_pytorch.ELBOLoss(
    version=2,
    alpha=0.5,
    observation_model='poisson',
    device=None)

Inputs

Outputs

Static Methods

Examples

Poisson observation model

from pi_vae_pytorch import ELBOLoss

loss_fn = ELBOLoss()

# model, x, and u as defined in the Examples section above;
# the model was initialized with decoder_observation_model='poisson'
outputs = model(x, u)

loss = loss_fn(
    x=x,
    posterior_firing_rate=outputs['posterior_firing_rate'],
    posterior_mean=outputs['posterior_mean'],
    posterior_log_variance=outputs['posterior_log_variance'],
    label_mean=outputs['label_mean'],
    label_log_variance=outputs['label_log_variance'],
    encoder_mean=outputs['encoder_mean'],
    encoder_log_variance=outputs['encoder_log_variance']
)

loss.backward()

Gaussian observation model

import torch
from pi_vae_pytorch import ELBOLoss

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = model.to(device) # Initialized with decoder_observation_model='gaussian'

loss_fn = ELBOLoss(observation_model='gaussian', device=device)

x = x.to(device) # move inputs onto the same device as the model
u = u.to(device)

outputs = model(x, u)

loss = loss_fn(
    x=x,
    posterior_firing_rate=outputs['posterior_firing_rate'],
    posterior_mean=outputs['posterior_mean'],
    posterior_log_variance=outputs['posterior_log_variance'],
    label_mean=outputs['label_mean'],
    label_log_variance=outputs['label_log_variance'],
    encoder_mean=outputs['encoder_mean'],
    encoder_log_variance=outputs['encoder_log_variance'],
    observation_noise_model=model.observation_noise_model
)

loss.backward()

Citation

@misc{zhou2020learning,
    title={Learning identifiable and interpretable latent models of high-dimensional neural activity using pi-VAE}, 
    author={Ding Zhou and Xue-Xin Wei},
    year={2020},
    eprint={2011.04798},
    archivePrefix={arXiv},
    primaryClass={stat.ML}
}