Stein Variational Gradient Descent¶
This notebook showcases how to use Stein Variational Gradient Descent to sample from the Banana function introduced in the paper Relativistic Monte Carlo.
In [1]:
%matplotlib inline
import sys
import os
sys.path.insert(0, os.path.join(os.path.abspath("."), "..", "..", ".."))
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
from pysgmcmc.samplers.svgd import SVGDSampler
from pysgmcmc.diagnostics.objective_functions import (
banana_log_likelihood,
)
from collections import namedtuple
ObjectiveFunction = namedtuple(
"ObjectiveFunction", ["function", "dimensionality"]
)
objective_functions = (
ObjectiveFunction(
function=banana_log_likelihood, dimensionality=2
),
)
def cost_function(log_likelihood_function):
def wrapped(*args, **kwargs):
return -log_likelihood_function(*args, **kwargs)
wrapped.__name__ = log_likelihood_function.__name__
return wrapped
# Banana Contour {{{ #
def banana_plot():
x = np.arange(-25, 25, 0.05)
y = np.arange(-50, 20, 0.05)
xx, yy = np.meshgrid(x, y, sparse=True)
densities = np.asarray([np.exp(banana_log_likelihood((x, y))) for x in xx for y in yy])
f, ax = plt.subplots(1)
xdata = [1, 4, 8]
ydata = [10, 20, 30]
ax.contour(x, y, densities, 1, label="Banana")
ax.plot([], [], label="Banana")
ax.legend()
ax.grid()
ax.set_ylim(ymin=-60, ymax=20)
ax.set_xlim(xmin=-30, xmax=30)
# }}} Banana Contour #
plot_functions = {
"banana_log_likelihood": banana_plot,
}
def extract_samples(sampler, n_samples=1000, keep_every=10):
from itertools import islice
n_iterations = n_samples * keep_every
return np.asarray(
[np.mean(sample, axis=0) for sample, _ in
islice(sampler, 0, n_iterations, keep_every)]
)
def plot_samples(sampler, n_samples=5000, keep_every=10):
samples = extract_samples(
sampler, n_samples=n_samples, keep_every=keep_every
)
plot_functions[sampler.cost_fun.__name__]()
first_sample = samples[0]
try:
sample_dimensionality, = first_sample.shape
except ValueError:
plt.scatter(samples, np.exp([-sampler.cost_fun(sample) for sample in samples]))
else:
plt.scatter(*[samples[:, i] for i in range(sample_dimensionality)])
n_particles = 10
graph = tf.Graph()
for function, dimensionality in objective_functions:
tf.reset_default_graph()
graph = tf.Graph()
with tf.Session(graph=graph) as session:
particles = [
tf.get_variable("particle_{}".format(i), (dimensionality,), initializer=tf.random_normal_initializer())
for i in range(n_particles)
]
sampler = SVGDSampler(
particles=particles,
cost_fun=cost_function(function),
session=session,
dtype=tf.float32
)
plot_samples(sampler, n_samples=5000)
plt.show()