# Source code for pysgmcmc.optimizers.sgld
import torch
from torch.optim import Optimizer
# PyTorch port of a previous TensorFlow implementation in `tensorflow_probability`:
# https://github.com/tensorflow/probability/blob/master/tensorflow_probability/g3doc/api_docs/python/tfp/optimizer/StochasticGradientLangevinDynamics.md
class SGLD(Optimizer):
""" Stochastic Gradient Langevin Dynamics Sampler with preconditioning.
Optimization variable is viewed as a posterior sample under Stochastic
    Gradient Langevin Dynamics with noise rescaled in each dimension
according to RMSProp.
"""
    def __init__(self,
params,
lr=1e-2,
precondition_decay_rate=0.95,
num_pseudo_batches=1,
num_burn_in_steps=3000,
diagonal_bias=1e-8) -> None:
""" Set up a SGLD Optimizer.
Parameters
----------
params : iterable
Parameters serving as optimization variable.
lr : float, optional
Base learning rate for this optimizer.
Must be tuned to the specific function being minimized.
Default: `1e-2`.
precondition_decay_rate : float, optional
Exponential decay rate of the rescaling of the preconditioner (RMSprop).
            Should be smaller than, but close to, `1` to approximate sampling from the posterior.
Default: `0.95`
num_pseudo_batches : int, optional
Effective number of minibatches in the data set.
Trades off noise and prior with the SGD likelihood term.
            Note: Assumes the loss is taken as the mean over a minibatch.
            If the sum is taken instead, divide this number by the batch size.
Default: `1`.
num_burn_in_steps : int, optional
Number of iterations to collect gradient statistics to update the
preconditioner before starting to draw noisy samples.
Default: `3000`.
diagonal_bias : float, optional
Term added to the diagonal of the preconditioner to prevent it from
degenerating.
Default: `1e-8`.
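
        Examples
        --------
        A minimal usage sketch (the quadratic loss below is an illustrative
        stand-in, not part of this module):

        >>> import torch
        >>> params = [torch.zeros(2, requires_grad=True)]
        >>> sampler = SGLD(params, lr=1e-2, num_burn_in_steps=10)
        >>> for _ in range(20):
        ...     sampler.zero_grad()
        ...     loss = (params[0] ** 2).sum()
        ...     loss.backward()
        ...     _ = sampler.step()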
"""
if lr < 0.0:
raise ValueError("Invalid learning rate: {}".format(lr))
if num_burn_in_steps < 0:
raise ValueError("Invalid num_burn_in_steps: {}".format(num_burn_in_steps))
defaults = dict(
lr=lr, precondition_decay_rate=precondition_decay_rate,
num_pseudo_batches=num_pseudo_batches,
num_burn_in_steps=num_burn_in_steps,
            diagonal_bias=diagonal_bias,
)
super().__init__(params, defaults)
    def step(self, closure=None):
        """ Perform a single SGLD update step.
        Parameters
        ----------
        closure : callable, optional
            Closure that reevaluates the model and returns the loss.
        """
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for parameter in group["params"]:
if parameter.grad is None:
continue
state = self.state[parameter]
lr = group["lr"]
num_pseudo_batches = group["num_pseudo_batches"]
precondition_decay_rate = group["precondition_decay_rate"]
gradient = parameter.grad.data
# State initialization {{{ #
if len(state) == 0:
state["iteration"] = 0
state["momentum"] = torch.ones_like(parameter)
# }}} State initialization #
state["iteration"] += 1
momentum = state["momentum"]
# Momentum update {{{ #
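                # In-place exponential moving average of squared gradients
                # (the RMSProp statistic): V_t = a * V_{t-1} + (1 - a) * g_t^2.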
momentum.add_(
(1.0 - precondition_decay_rate) * ((gradient ** 2) - momentum)
)
# }}} Momentum update #
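                # After burn-in, sigma = 1/sqrt(lr) turns the injected noise on;
                # during burn-in, sigma = 0 and the update reduces to
                # preconditioned SGD while gradient statistics are collected.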
if state["iteration"] > group["num_burn_in_steps"]:
sigma = 1. / torch.sqrt(torch.tensor(lr))
else:
sigma = torch.zeros_like(parameter)
preconditioner = (
1. / torch.sqrt(momentum + group["diagonal_bias"])
)
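                # SGLD update direction: preconditioned gradient (drift term)
                # plus Gaussian noise; after the -lr scaling below, the noise
                # has per-dimension variance lr * preconditioner.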
scaled_grad = (
0.5 * preconditioner * gradient * num_pseudo_batches +
torch.normal(
mean=torch.zeros_like(gradient),
std=torch.ones_like(gradient)
) * sigma * torch.sqrt(preconditioner)
)
parameter.data.add_(-lr * scaled_grad)
return loss
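
# A minimal smoke test for the sampler above (a sketch, not part of the
# original module; the 1-d standard-normal target and the step counts are
# illustrative choices). The loss is the negative log-density of N(0, 1)
# up to a constant, so post-burn-in iterates are approximate posterior samples.
if __name__ == "__main__":
    x = torch.zeros(1, requires_grad=True)
    sampler = SGLD([x], lr=1e-2, num_burn_in_steps=100)
    samples = []
    for step in range(1100):
        sampler.zero_grad()
        loss = 0.5 * (x ** 2).sum()
        loss.backward()
        sampler.step()
        if step >= 100:
            samples.append(x.detach().clone())
    stacked = torch.cat(samples)
    print("mean: {:.3f}, std: {:.3f}".format(
        stacked.mean().item(), stacked.std().item()))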