# Source code for pysgmcmc.sampling
# vim: foldmethod=marker
from enum import Enum
class Sampler(Enum):
    """ Enumeration type for all samplers we support.

    Every member's value is its own name as a string. Besides the members,
    the class offers helpers to classify a sampling method
    (`is_burn_in_mcmc`, `is_supported`) and a factory (`get_sampler`) that
    constructs the sampler object belonging to a member.
    """

    SGHMC = "SGHMC"
    RelativisticSGHMC = "RelativisticSGHMC"
    SGLD = "SGLD"
    SVGD = "SVGD"

    @staticmethod
    def is_burn_in_mcmc(sampling_method):
        """
        Static method that returns true if `sampling_method` is a
        burn-in sampler, i.e. if it is `Sampler.SGHMC` or `Sampler.SGLD`.

        Parameters
        ----------
        sampling_method : object
            Candidate sampling method; usually a `Sampler` member, but any
            object is accepted.

        Returns
        -------
        is_burn_in : bool
            `True` for `Sampler.SGHMC` and `Sampler.SGLD`, `False` for
            everything else.

        Examples
        --------
        Burn-in sampling methods give `True`:

        >>> Sampler.is_burn_in_mcmc(Sampler.SGHMC)
        True

        Other sampling methods give `False`:

        >>> Sampler.is_burn_in_mcmc(Sampler.RelativisticSGHMC)
        False

        Other input types give `False`:

        >>> Sampler.is_burn_in_mcmc(0)
        False
        >>> Sampler.is_burn_in_mcmc("test")
        False

        """
        return sampling_method in (Sampler.SGHMC, Sampler.SGLD)

    @staticmethod
    def is_supported(sampling_method):
        """
        Static method that returns true if `sampling_method` is a
        supported sampler (i.e. it is a member of the `Sampler` enum).

        Parameters
        ----------
        sampling_method : object
            Candidate sampling method; any object is accepted.

        Returns
        -------
        supported : bool
            `True` if `sampling_method` is a `Sampler` member, `False`
            otherwise. Plain strings such as ``"SGHMC"`` are *not* members.

        Examples
        --------
        Supported sampling methods give `True`:

        >>> Sampler.is_supported(Sampler.SGHMC)
        True
        >>> Sampler.is_supported(Sampler.SVGD)
        True

        Other input types give `False`:

        >>> Sampler.is_supported(0)
        False
        >>> Sampler.is_supported("test")
        False

        """
        # `isinstance` instead of `in Sampler`: the containment check on an
        # Enum class raises or accepts raw values depending on the Python
        # version, whereas `isinstance` is stable everywhere.
        return isinstance(sampling_method, Sampler)

    @classmethod
    def get_sampler(cls, sampling_method, **sampler_args):
        """ Return a sampler object for supported `sampling_method`, where all
            default values for parameters in keyword dictionary `sampler_args`
            are overwritten.

        Parameters
        ----------
        sampling_method : Sampler
            Enum corresponding to sampling method to return a sampler for.
        **sampler_args : dict
            Keyword arguments that are passed to the constructor of the
            sampler for the specified `sampling_method`. Parameters of that
            constructor which are not given here fall back to their defaults.

        Returns
        -------
        sampler : object
            Instance of the sampler class that implements the specified
            `sampling_method` (e.g. `SGHMCSampler` for `Sampler.SGHMC`),
            initialized with inputs `sampler_args`.

        Raises
        ------
        ValueError
            If `sampling_method` is not a supported sampling method, if
            `sampler_args` contains a name that is not a parameter of the
            sampler's constructor, or if a constructor parameter without
            default value is missing from `sampler_args`.

        Examples
        --------
        We can use this method to construct a sampler for a given
        sampling method and override default values by providing them as
        keyword arguments:

        >>> import tensorflow as tf
        >>> params = [tf.Variable(0.)]
        >>> cost_fun = lambda params: tf.reduce_sum(params) # dummy cost function
        >>> session=tf.Session()
        >>> with tf.Session() as session: sampler = Sampler.get_sampler(Sampler.SGHMC, session=session, params=params, cost_fun=cost_fun, dtype=tf.float32)
        >>> type(sampler)
        <class 'pysgmcmc.samplers.sghmc.SGHMCSampler'>
        >>> sampler.dtype
        tf.float32
        >>> session.close()

        Construction of SGLD sampler:

        >>> import tensorflow as tf
        >>> params = [tf.Variable(0.)]
        >>> cost_fun = lambda params: tf.reduce_sum(params) # dummy cost function
        >>> session=tf.Session()
        >>> with tf.Session() as session: sampler = Sampler.get_sampler(Sampler.SGLD, session=session, params=params, cost_fun=cost_fun, dtype=tf.float32)
        >>> type(sampler)
        <class 'pysgmcmc.samplers.sgld.SGLDSampler'>
        >>> sampler.dtype
        tf.float32
        >>> session.close()

        Construction of Relativistic SGHMC sampler:

        >>> import tensorflow as tf
        >>> params = [tf.Variable(0.)]
        >>> cost_fun = lambda params: tf.reduce_sum(params) # dummy cost function
        >>> session=tf.Session()
        >>> with tf.Session() as session: sampler = Sampler.get_sampler(Sampler.RelativisticSGHMC, session=session, params=params, cost_fun=cost_fun, dtype=tf.float32)
        >>> type(sampler)
        <class 'pysgmcmc.samplers.relativistic_sghmc.RelativisticSGHMCSampler'>
        >>> sampler.dtype
        tf.float32
        >>> session.close()

        Sampler arguments that do not have a default *must* be provided as keyword
        argument, otherwise this method will raise an exception:

        >>> sampler = Sampler.get_sampler(Sampler.SGHMC, dtype=tf.float32)
        Traceback (most recent call last):
         ...
        ValueError: sampling.Sampler.get_sampler: params was not overwritten as sampler argument in `sampler_args` and does not have any default value in SGHMCSampler.__init__. Please pass an explicit value for this parameter.

        If an **optional** argument is not provided as keyword argument,
        the corresponding default value is used.
        If we do not provide/overwrite the `dtype` keyword argument,
        the samplers default value of `tf.float64` is used:

        >>> import tensorflow as tf
        >>> params = [tf.Variable(0., dtype=tf.float64)]
        >>> cost_fun = lambda params: tf.reduce_sum(params) # dummy cost function
        >>> with tf.Session() as session: sampler = Sampler.get_sampler(Sampler.SGHMC, session=session, params=params, cost_fun=cost_fun)
        >>> sampler.dtype
        tf.float64

        If a keyword argument that is provided does not represent a valid
        parameter of the corresponding `sampling_method`, a `ValueError` is
        raised:

        >>> import tensorflow as tf
        >>> params = [tf.Variable(0., dtype=tf.float64)]
        >>> cost_fun = lambda params: tf.reduce_sum(params) # dummy cost function
        >>> with tf.Session() as session: sampler = Sampler.get_sampler(Sampler.SGHMC, unknown_argument=None, session=session, params=params, cost_fun=cost_fun)
        Traceback (most recent call last):
         ...
        ValueError: sampling.Sampler.get_sampler: 'SGHMCSampler' does not take any parameter with name 'unknown_argument' which was specified as argument to this sampler. Please ensure, that you only specify sampler arguments that fit the corresponding sampling method.
        For your choice of sampling method ('Sampler.SGHMC'), supported parameters are:
        -params
        -cost_fun
        -batch_generator
        -stepsize_schedule
        -burn_in_steps
        -mdecay
        -scale_grad
        -session
        -dtype
        -seed

        """
        # The sampler modules are imported inside the branches, so only the
        # module of the requested sampling method is loaded.
        if sampling_method == cls.SGHMC:
            from pysgmcmc.samplers.sghmc import SGHMCSampler as sampler_class
        elif sampling_method == cls.SGLD:
            from pysgmcmc.samplers.sgld import SGLDSampler as sampler_class
        elif sampling_method == cls.RelativisticSGHMC:
            from pysgmcmc.samplers.relativistic_sghmc import (
                RelativisticSGHMCSampler as sampler_class
            )
        elif sampling_method == cls.SVGD:
            from pysgmcmc.samplers.svgd import (
                SVGDSampler as sampler_class
            )
        elif not cls.is_supported(sampling_method):
            raise ValueError(
                "sampling.Sampler.get_sampler: {sampler!r} is not a "
                "supported sampling method. "
                "Supported sampling methods are: {supported}.".format(
                    sampler=sampling_method,
                    supported=", ".join(str(member) for member in cls)
                )
            )
        else:
            # Only reachable for an enum member that was added to `Sampler`
            # without a matching import branch above.
            raise ValueError(
                "Sampling method {sampler} is supported, but function "
                "'pysgmcmc.sampling.get_sampler' is missing an `import` "
                "statement for the corresponding sampler object. "
                "Please add an import in the appropriate location.".format(
                    sampler=sampling_method
                )
            )

        from inspect import Parameter, signature

        # look up all initializer parameters with their (potential)
        # default values
        all_sampler_parameters = signature(sampler_class.__init__).parameters

        # Check if any invalid sampler arguments were passed
        # (sampler arguments that are not actually parameters of the specified
        # sampling method); the first offending name is reported.
        for parameter_name in sampler_args:
            if parameter_name not in all_sampler_parameters:
                raise ValueError(
                    "sampling.Sampler.get_sampler: '{sampler_name}' "
                    "does not take any parameter with name '{parameter}' "
                    "which was specified as argument to this sampler. "
                    "Please ensure, that you only specify sampler arguments "
                    "that fit the corresponding sampling method.\n"
                    "For your choice of sampling method ('{sampler}'), supported parameters are:\n"
                    "{valid_parameters}".format(
                        sampler_name=sampler_class.__name__,
                        sampler=sampling_method,
                        parameter=parameter_name,
                        valid_parameters="\n".join(
                            "-{}".format(valid_name)
                            for valid_name in all_sampler_parameters
                            if valid_name != "self"
                        )
                    )
                )

        # Assemble the full constructor call: explicitly passed values win,
        # everything else falls back to the default of the initializer.
        constructor_args = {}
        for parameter_name, parameter in all_sampler_parameters.items():
            if parameter_name == "self":
                # never pass `self` during construction
                continue

            if parameter_name in sampler_args:
                constructor_args[parameter_name] = sampler_args[parameter_name]
            elif parameter.default is Parameter.empty:
                # required parameter that the caller did not provide
                raise ValueError(
                    "sampling.Sampler.get_sampler: "
                    "{param_name} was not overwritten as sampler argument "
                    "in `sampler_args` and does not have any default value "
                    "in {sampler}.__init__. "
                    "Please pass an explicit value for this parameter.".format(
                        param_name=parameter_name,
                        sampler=sampler_class.__name__
                    )
                )
            else:
                constructor_args[parameter_name] = parameter.default

        return sampler_class(**constructor_args)