Abstract Base Classes
class pysgmcmc.sampling.Sampler
Enumeration type for all samplers we support.
classmethod get_sampler(sampling_method, **sampler_args)
Return a sampler object for the supported sampling_method. The default value of every parameter given in the keyword dictionary sampler_args is overwritten.
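Parameters:
sampling_method – A member of the Sampler enum, e.g. Sampler.SGHMC, that selects the sampler to construct.
sampler_args – Keyword arguments passed on to the constructor of the chosen sampler. They override its default values.
Returns: sampler – A sampler instance that implements the specified sampling_method and is initialized with the inputs sampler_args.
Return type: Subclass of sampling.MCMCSampler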
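The dispatch described above can be sketched in plain Python. This sketch maps an enum member to a sampler class and instantiates that class with sampler_args, so any argument that is left out keeps its constructor default. It is an illustration under assumptions, not pysgmcmc's implementation. SamplerKind, ToySGHMC, ToySGLD, SAMPLER_CLASSES and make_sampler are hypothetical stand-ins, and a string dtype replaces the TensorFlow dtype.

```python
from enum import Enum


class SamplerKind(Enum):
    """Hypothetical enum playing the role of Sampler."""
    SGHMC = "SGHMC"
    SGLD = "SGLD"


class ToySGHMC:
    """Hypothetical stand-in for an SGHMC sampler class."""

    def __init__(self, params, cost_fun, dtype="float64"):
        self.params = params
        self.cost_fun = cost_fun
        self.dtype = dtype  # default used when the caller omits dtype


class ToySGLD:
    """Hypothetical stand-in for an SGLD sampler class."""

    def __init__(self, params, cost_fun, dtype="float64"):
        self.params = params
        self.cost_fun = cost_fun
        self.dtype = dtype


# Hypothetical registry: each enum member maps to the class implementing it.
SAMPLER_CLASSES = {SamplerKind.SGHMC: ToySGHMC, SamplerKind.SGLD: ToySGLD}


def make_sampler(sampling_method, **sampler_args):
    """Instantiate the class registered for sampling_method.

    Keyword arguments override constructor defaults. Any parameter that
    is not passed keeps the default declared in the class's __init__.
    """
    sampler_cls = SAMPLER_CLASSES[sampling_method]
    return sampler_cls(**sampler_args)


# Usage: dtype is overridden for SGHMC and left at its default for SGLD.
sghmc = make_sampler(SamplerKind.SGHMC, params=[0.0], cost_fun=sum, dtype="float32")
sgld = make_sampler(SamplerKind.SGLD, params=[0.0], cost_fun=sum)
print(type(sghmc).__name__, sghmc.dtype)  # ToySGHMC float32
print(type(sgld).__name__, sgld.dtype)    # ToySGLD float64
```

A dictionary registry keeps the enum free of import-time dependencies on the sampler classes. Whether pysgmcmc dispatches this way is an assumption.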
Examples
We can use this method to construct a sampler for a given sampling method and override default values by providing them as keyword arguments:
>>> import tensorflow as tf
>>> params = [tf.Variable(0.)]
>>> cost_fun = lambda params: tf.reduce_sum(params)  # dummy cost function
>>> with tf.Session() as session:
...     sampler = Sampler.get_sampler(Sampler.SGHMC, session=session, params=params, cost_fun=cost_fun, dtype=tf.float32)
>>> type(sampler)
<class 'pysgmcmc.samplers.sghmc.SGHMCSampler'>
>>> sampler.dtype
tf.float32
Construction of an SGLD sampler:
>>> import tensorflow as tf
>>> params = [tf.Variable(0.)]
>>> cost_fun = lambda params: tf.reduce_sum(params)  # dummy cost function
>>> with tf.Session() as session:
...     sampler = Sampler.get_sampler(Sampler.SGLD, session=session, params=params, cost_fun=cost_fun, dtype=tf.float32)
>>> type(sampler)
<class 'pysgmcmc.samplers.sgld.SGLDSampler'>
>>> sampler.dtype
tf.float32
Construction of a Relativistic SGHMC sampler:
>>> import tensorflow as tf
>>> params = [tf.Variable(0.)]
>>> cost_fun = lambda params: tf.reduce_sum(params)  # dummy cost function
>>> with tf.Session() as session:
...     sampler = Sampler.get_sampler(Sampler.RelativisticSGHMC, session=session, params=params, cost_fun=cost_fun, dtype=tf.float32)
>>> type(sampler)
<class 'pysgmcmc.samplers.relativistic_sghmc.RelativisticSGHMCSampler'>
>>> sampler.dtype
tf.float32
Sampler arguments that have no default value must be provided as keyword arguments. Otherwise this method raises a ValueError:
>>> sampler = Sampler.get_sampler(Sampler.SGHMC, dtype=tf.float32)
Traceback (most recent call last):
  ...
ValueError: sampling.Sampler.get_sampler: params was not overwritten as sampler argument in `sampler_args` and does not have any default value in SGHMCSampler.__init__Please pass an explicit value for this parameter.
If an optional argument is not provided as a keyword argument, its default value is used. For example, if we do not override the dtype keyword argument, the sampler's default value of tf.float64 is used:
>>> import tensorflow as tf
>>> params = [tf.Variable(0., dtype=tf.float64)]
>>> cost_fun = lambda params: tf.reduce_sum(params)  # dummy cost function
>>> with tf.Session() as session:
...     sampler = Sampler.get_sampler(Sampler.SGHMC, session=session, params=params, cost_fun=cost_fun)
>>> sampler.dtype
tf.float64
If a provided keyword argument is not a valid parameter of the corresponding sampling_method, a ValueError is raised:
>>> import tensorflow as tf
>>> params = [tf.Variable(0., dtype=tf.float64)]
>>> cost_fun = lambda params: tf.reduce_sum(params)  # dummy cost function
>>> with tf.Session() as session:
...     sampler = Sampler.get_sampler(Sampler.SGHMC, unknown_argument=None, session=session, params=params, cost_fun=cost_fun)
Traceback (most recent call last):
  ...
ValueError: sampling.Sampler.get_sampler: 'SGHMCSampler' does not take any parameter with name 'unknown_argument' which was specified as argument to this sampler. Please ensure, that you only specify sampler arguments that fit the corresponding sampling method. For your choice of sampling method ('Sampler.SGHMC'), supported parameters are:
-params
-cost_fun
-batch_generator
-stepsize_schedule
-burn_in_steps
-mdecay
-scale_grad
-session
-dtype
-seed
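Both error cases above can be detected by comparing sampler_args with the constructor's signature. Here is a minimal sketch using the standard library's inspect.signature. It assumes a hypothetical ToySGHMC class and a hypothetical helper checked_instantiate, and its error messages only paraphrase the ones shown in the examples. It is not pysgmcmc's actual code.

```python
import inspect


class ToySGHMC:
    """Hypothetical stand-in sampler; only its __init__ signature matters."""

    def __init__(self, params, cost_fun, seed=None, dtype="float64"):
        self.params = params
        self.cost_fun = cost_fun
        self.seed = seed
        self.dtype = dtype


def checked_instantiate(sampler_cls, **sampler_args):
    """Validate sampler_args against sampler_cls.__init__, then instantiate.

    Raises ValueError for an argument the constructor does not accept and
    for a required constructor parameter (no default) that was not passed.
    """
    parameters = inspect.signature(sampler_cls.__init__).parameters
    names = [name for name in parameters if name != "self"]

    # Reject keyword arguments the constructor does not declare.
    for key in sampler_args:
        if key not in names:
            raise ValueError(
                f"{sampler_cls.__name__!r} does not take any parameter with name "
                f"{key!r}. Supported parameters are: {', '.join(names)}"
            )

    # Require every parameter that has no default value.
    for name in names:
        if name not in sampler_args and parameters[name].default is inspect.Parameter.empty:
            raise ValueError(
                f"{name} was not passed in sampler_args and has no default value "
                f"in {sampler_cls.__name__}.__init__. Please pass an explicit value."
            )

    return sampler_cls(**sampler_args)


# Usage: both error cases from the examples above, then a valid call.
for bad_args in ({"dtype": "float32"},
                 {"params": [0.0], "cost_fun": sum, "unknown_argument": None}):
    try:
        checked_instantiate(ToySGHMC, **bad_args)
    except ValueError as err:
        print("ValueError:", err)

sampler = checked_instantiate(ToySGHMC, params=[0.0], cost_fun=sum)
print(sampler.dtype)  # float64
```

Checking unknown names before missing ones means a misspelled required argument is reported as unknown. The sketch also ignores constructors that take *args or **kwargs.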
static is_burn_in_mcmc(sampling_method)
Static method that returns True if sampling_method is a burn-in sampler, and False for any other sampling method or input type.
Examples
Burn-in sampling methods give True:
>>> Sampler.is_burn_in_mcmc(Sampler.SGHMC)
True
Other sampling methods give False:
>>> Sampler.is_burn_in_mcmc(Sampler.RelativisticSGHMC)
False
Other input types give False:
>>> Sampler.is_burn_in_mcmc(0)
False
>>> Sampler.is_burn_in_mcmc("test")
False
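The docstring does not say how burn-in samplers are recorded. One plausible sketch uses a hypothetical enum SamplerKind and a hypothetical frozenset BURN_IN_SAMPLERS. It combines a type check with a set lookup, so inputs such as 0, "test" or even unhashable values return False instead of raising. Only SGHMC is placed in the set, matching the single positive example. Which other samplers belong there is not stated in the text.

```python
from enum import Enum


class SamplerKind(Enum):
    """Hypothetical enum; member names follow the examples above."""
    SGHMC = "SGHMC"
    RelativisticSGHMC = "RelativisticSGHMC"


# Hypothetical: the subset of members treated as burn-in samplers.
BURN_IN_SAMPLERS = frozenset({SamplerKind.SGHMC})


def is_burn_in_mcmc(sampling_method):
    """Return True only for enum members listed in BURN_IN_SAMPLERS.

    The isinstance check runs first, so non-members (including unhashable
    values such as lists) short-circuit to False before the set lookup.
    """
    return isinstance(sampling_method, SamplerKind) and sampling_method in BURN_IN_SAMPLERS


# Usage: mirrors the examples above.
for candidate in (SamplerKind.SGHMC, SamplerKind.RelativisticSGHMC, 0, "test"):
    print(candidate, is_burn_in_mcmc(candidate))
```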
static is_supported(sampling_method)
Static method that returns True if sampling_method is a supported sampler, i.e. there is an entry for it in the Sampler enum, and False otherwise.
Examples
Supported sampling methods give True:
>>> Sampler.is_supported(Sampler.SGHMC)
True
Other input types give False:
>>> Sampler.is_supported(0)
False
>>> Sampler.is_supported("test")
False
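A membership test like is_supported can be sketched with isinstance. This sketch assumes a hypothetical enum SamplerKind in place of Sampler. Writing `sampling_method in SamplerKind` directly is less robust. Depending on the Python version, that test either raises TypeError for non-member inputs or compares them against member values, so a string such as "SGHMC" could count as supported.

```python
from enum import Enum


class SamplerKind(Enum):
    """Hypothetical enum standing in for Sampler."""
    SGHMC = "SGHMC"
    SGLD = "SGLD"
    RelativisticSGHMC = "RelativisticSGHMC"


def is_supported(sampling_method):
    """Return True iff sampling_method is a member of SamplerKind.

    isinstance is used instead of `sampling_method in SamplerKind`, whose
    behaviour for non-member inputs differs between Python versions.
    """
    return isinstance(sampling_method, SamplerKind)


# Usage: mirrors the examples above, plus a string equal to a member value.
for candidate in (SamplerKind.SGHMC, 0, "test", "SGHMC"):
    print(repr(candidate), is_supported(candidate))
```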
classmethod