# Source code for pysgmcmc.samplers.mixin

import typing

import numpy as np
import torch

class SamplerMixin(object):
    """ Mixin class that turns a `torch.optim.Optimizer` into a MCMC sampler.

    Combine it with an optimizer class, listing the mixin first so that its
    ``__init__`` (which consumes ``negative_log_likelihood``) runs before the
    optimizer's, e.g. ``class Sampler(SamplerMixin, SomeOptimizer): ...``.

    The resulting object is an iterator: each ``next(sampler)`` performs one
    optimizer step and yields the parameters and cost from *before* that step.
    """

    def __init__(
        self,
        negative_log_likelihood: typing.Callable[
            [typing.Iterable[torch.Tensor]], torch.Tensor
        ],
        params: typing.Iterable[torch.Tensor],
        *args,
        **kwargs
    ) -> None:
        """ Instantiate a sampler object.

        (Initial) parameters are passed as iterable `params`,
        `negative_log_likelihood` is a function mapping parameters to a NLL
        value and `*args` and `**kwargs` allow specifying additional arguments
        to pass to a sampler, e.g. `lr` or `mdecay`.

        Parameters
        ----------
        negative_log_likelihood : typing.Callable[[typing.Iterable[torch.Tensor]], torch.Tensor]
            Callable mapping parameters to a NLL value.
        params : iterable
            Iterable of parameters used to construct samples. It is
            materialized into a tuple, so a one-shot generator such as
            ``model.parameters()`` is fine.
        *args, **kwargs
            Forwarded unchanged to the next ``__init__`` in the MRO
            (normally the optimizer's), after ``params``.

        Raises
        ------
        TypeError
            If `negative_log_likelihood` is not callable.

        See also
        ----------
        pysgmcmc.samplers.sghmc.SGHMC: SGHMC sampler that uses this mixin.
        """
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not callable(negative_log_likelihood):
            raise TypeError(
                "negative_log_likelihood must be callable, got {!r}".format(
                    type(negative_log_likelihood).__name__
                )
            )
        self.negative_log_likelihood = negative_log_likelihood
        # Materialize once: the same tensors are handed to the optimizer and
        # to `negative_log_likelihood` on every step.
        self.params = tuple(params)
        # Pass `params` positionally: a `params=` keyword combined with extra
        # positional `*args` would bind the first of them to `params` as well
        # and raise "got multiple values for argument 'params'".
        super().__init__(self.params, *args, **kwargs)

    @property
    def parameters(self) -> typing.Tuple[np.ndarray, ...]:
        """ Return last sample as tuple of numpy arrays.

        Each array is an independent copy, so it does not change when the
        underlying tensors are later updated in place by `step`.

        Returns
        ----------
        current_parameters: typing.Tuple[numpy.ndarray, ...]
            Tuple of numpy arrays containing last sampled values.
        """
        # detach(): `.numpy()` refuses tensors that require grad.
        # cpu(): `.numpy()` only works on host memory.
        # copy(): on CPU `.numpy()` shares storage with the tensor; without a
        # copy the snapshot taken in `sample_step` would alias the in-place
        # update performed by `self.step()`.
        return tuple(
            np.asarray(parameter.detach().cpu().numpy().copy())
            for parameter in self.params
        )

    def sample_step(
        self,
    ) -> typing.Tuple[
        typing.Tuple[np.ndarray, ...], torch.Tensor, typing.Tuple[np.ndarray, ...]
    ]:
        """ Perform a single step with the sampler.

        Returns
        ----------
        parameters: typing.Tuple[numpy.ndarray, ...]
            Current parameters (snapshot taken before the step).
        cost: torch.Tensor
            NLL value associated with `parameters`.
        next_parameters: typing.Tuple[numpy.ndarray, ...]
            Parameters to evaluate on a subsequent call.
        """
        # Clear gradients accumulated by a previous step before backprop.
        self.zero_grad()
        last_parameters = self.parameters
        last_loss = self.negative_log_likelihood(self.params)
        last_loss.backward()
        # Updates the tensors in `self.params` in place.
        self.step()
        return last_parameters, last_loss, self.parameters

    def __next__(self) -> typing.Tuple[typing.Tuple[np.ndarray, ...], torch.Tensor]:
        """ Perform a step of this sampler and return parameters with costs.

        Together with `__iter__`, this allows using samplers as iterables.

        Returns
        ----------
        parameters: typing.Tuple[numpy.ndarray, ...]
            Current parameters.
        cost: torch.Tensor
            NLL value associated with `parameters`.
        """
        parameters, cost, _ = self.sample_step()
        return parameters, cost

    def __iter__(self) -> "SamplerMixin":
        """ Return the sampler itself; it is its own (infinite) iterator. """
        return self