Stein Variational Gradient Descent

This notebook showcases how to use Stein Variational Gradient Descent (SVGD) to sample from the banana-shaped density introduced in the paper "Relativistic Monte Carlo".

In [1]:
%matplotlib inline
import sys
import os
sys.path.insert(0, os.path.join(os.path.abspath("."), "..", "..", ".."))

import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
from pysgmcmc.samplers.svgd import SVGDSampler

from pysgmcmc.diagnostics.objective_functions import (
    banana_log_likelihood,
)

from collections import namedtuple

# Lightweight record pairing an objective's log-likelihood function with the
# dimensionality of its input space.
ObjectiveFunction = namedtuple(
    "ObjectiveFunction", ["function", "dimensionality"]
)

# All objectives sampled in this notebook; the banana log likelihood lives
# on a 2-d input space.
objective_functions = (
    ObjectiveFunction(
        function=banana_log_likelihood, dimensionality=2
    ),
)


def cost_function(log_likelihood_function):
    """Turn a log-likelihood function into a cost function by negation.

    Parameters
    ----------
    log_likelihood_function : callable
        Function returning a log likelihood for its inputs.

    Returns
    -------
    callable
        Wrapper with the same call signature that returns the *negative*
        log likelihood, as expected by samplers that minimize a cost.
        Its metadata (``__name__`` in particular, which ``plot_functions``
        keys on) is copied from ``log_likelihood_function``.
    """
    from functools import wraps

    # wraps() copies __name__ (as the original code did by hand) plus
    # __doc__, __module__ and __wrapped__ for free.
    @wraps(log_likelihood_function)
    def wrapped(*args, **kwargs):
        return -log_likelihood_function(*args, **kwargs)
    return wrapped

#  Banana Contour {{{ #
def banana_plot():
    """Draw a single contour line of the banana density on fresh axes.

    Evaluates ``exp(banana_log_likelihood((x, y)))`` over the grid
    ``x in [-25, 25)`` x ``y in [-50, 20)`` (step 0.05) and adds a legend
    entry, a grid, and fixed axis limits.
    """
    x = np.arange(-25, 25, 0.05)
    y = np.arange(-50, 20, 0.05)
    # Sparse meshgrid: xx has shape (1, len(x)), yy has shape (len(y), 1).
    xx, yy = np.meshgrid(x, y, sparse=True)
    # Broadcasting xx against yy produces the full (len(y), len(x)) density
    # grid in one vectorized call — equivalent to the row-by-row list
    # comprehension this replaces, which relied on the same broadcasting.
    densities = np.exp(banana_log_likelihood((xx, yy)))
    f, ax = plt.subplots(1)
    # contour() ignores `label` for legend purposes; the empty plot() call
    # below supplies the legend handle instead.
    ax.contour(x, y, densities, 1, label="Banana")
    ax.plot([], [], label="Banana")
    ax.legend()
    ax.grid()
    # NOTE(review): ymin/ymax and xmin/xmax kwargs are deprecated in
    # matplotlib >= 3.0 (removed in 3.5) in favor of bottom/top and
    # left/right — kept as-is for the old matplotlib this repo targets.
    ax.set_ylim(ymin=-60, ymax=20)
    ax.set_xlim(xmin=-30, xmax=30)

#  }}} Banana Contour #


# Maps a cost function's ``__name__`` to the routine that draws the matching
# reference/background plot (looked up in ``plot_samples``).
plot_functions = {
    "banana_log_likelihood": banana_plot,
}


def extract_samples(sampler, n_samples=1000, keep_every=10):
    """Collect thinned, particle-averaged samples from a sampler.

    Consumes up to ``n_samples * keep_every`` ``(sample, cost)`` pairs from
    the ``sampler`` iterable, keeps every ``keep_every``-th pair, and
    averages each kept sample over its first axis (i.e. over particles).

    Parameters
    ----------
    sampler : iterable
        Yields ``(sample, cost)`` tuples; the cost component is discarded.
    n_samples : int
        Number of thinned samples to return.
    keep_every : int
        Thinning interval between kept samples.

    Returns
    -------
    np.ndarray
        Array with ``n_samples`` leading entries, one averaged sample each.
    """
    from itertools import islice

    total_steps = n_samples * keep_every
    thinned = islice(sampler, 0, total_steps, keep_every)
    averaged = [np.mean(state, axis=0) for state, _ in thinned]
    return np.asarray(averaged)

def plot_samples(sampler, n_samples=5000, keep_every=10):
    """Scatter thinned samples from ``sampler`` over its reference plot.

    Draws the background plot registered for the sampler's cost function
    (keyed by ``cost_fun.__name__``), then scatters the extracted samples:
    coordinate-wise for 1-d sample vectors, or against their density
    ``exp(-cost)`` otherwise.
    """
    samples = extract_samples(
        sampler, n_samples=n_samples, keep_every=keep_every
    )

    # Background contour/reference figure for this objective.
    plot_functions[sampler.cost_fun.__name__]()

    first = samples[0]
    if first.ndim == 1:
        # 1-d sample vectors: scatter one coordinate axis per dimension.
        dimensionality = first.shape[0]
        coordinates = [samples[:, dim] for dim in range(dimensionality)]
        plt.scatter(*coordinates)
    else:
        # Shape cannot be unpacked into a single dimension: plot samples
        # against their (unnormalized) density instead.
        densities = np.exp([-sampler.cost_fun(sample) for sample in samples])
        plt.scatter(samples, densities)


# Number of SVGD particles (independent chain members updated jointly).
n_particles = 10

for function, dimensionality in objective_functions:
    # Fresh graph per objective. The dead pre-loop `graph = tf.Graph()`
    # assignment was removed: it was always shadowed here.
    tf.reset_default_graph()
    graph = tf.Graph()

    # BUGFIX: `tf.Session(graph=graph)` binds the session to `graph` but
    # does NOT make `graph` the default graph, so `tf.get_variable` used to
    # create the particles in the (separate) default graph — a graph the
    # session could not run. Entering `graph.as_default()` makes variable
    # construction and session execution agree on one graph.
    with graph.as_default(), tf.Session(graph=graph) as session:
        # One trainable particle per chain member, standard-normal init.
        particles = [
            tf.get_variable(
                "particle_{}".format(i),
                (dimensionality,),
                initializer=tf.random_normal_initializer()
            )
            for i in range(n_particles)
        ]
        sampler = SVGDSampler(
            particles=particles,
            cost_fun=cost_function(function),
            session=session,
            dtype=tf.float32
        )
        plot_samples(sampler, n_samples=5000)
        plt.show()
../_images/notebooks_SVGD_1_0.png