SVGD
class pysgmcmc.samplers.svgd.SVGDSampler(particles, cost_fun, batch_generator=None, stepsize_schedule=<pysgmcmc.stepsize_schedules.ConstantStepsizeSchedule object>, alpha=0.9, fudge_factor=1e-06, session=None, dtype=tf.float64, seed=None)
Stein Variational Gradient Descent Sampler.
See [1] for more details on Stein variational gradient descent.
- [1] Q. Liu, D. Wang. Stein Variational Gradient Descent: A General Purpose Bayesian Inference Algorithm. In Advances in Neural Information Processing Systems 29 (2016).
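The update described in [1] moves every particle x_i along phi(x_i) = (1/n) * sum_j [ k(x_j, x_i) * grad log p(x_j) + grad_{x_j} k(x_j, x_i) ]. The following is a minimal pure-Python sketch of that step for 1-d particles, not this sampler's TensorFlow implementation. The names svgd_step and rbf are hypothetical, the RBF bandwidth h is held fixed for simplicity (the paper instead picks it with a median heuristic), and grad_cost stands in for the gradient of the cost U = -log p that cost_fun provides.

```python
import math
from typing import Callable, List


def rbf(x: float, y: float, h: float) -> float:
    """RBF kernel k(x, y) = exp(-(x - y)^2 / h)."""
    return math.exp(-((x - y) ** 2) / h)


def svgd_step(
    particles: List[float],
    grad_cost: Callable[[float], float],
    stepsize: float,
    h: float = 1.0,
) -> List[float]:
    """Apply one SVGD update to a list of 1-d particles (hypothetical helper).

    grad_cost(x) is dU/dx for the cost U = -log p, so grad log p = -grad_cost.
    All particles are moved simultaneously, using their old positions.
    """
    n = len(particles)
    grad_log_p = [-grad_cost(x) for x in particles]
    updated = []
    for x_i in particles:
        phi = 0.0
        for x_j, g_j in zip(particles, grad_log_p):
            k = rbf(x_j, x_i, h)
            # Driving term: kernel-weighted score pulls x_i towards high density.
            phi += k * g_j
            # Repulsive term: d/dx_j k(x_j, x_i) = 2 (x_i - x_j) / h * k,
            # which pushes x_i away from nearby particles.
            phi += 2.0 * (x_i - x_j) / h * k
        updated.append(x_i + stepsize * phi / n)
    return updated


# Usage: approximate a standard normal, whose cost U(x) = x^2 / 2 has dU/dx = x.
particles = [2.0 + 0.2 * i for i in range(10)]  # deliberately off-centre start
for _ in range(2000):
    particles = svgd_step(particles, grad_cost=lambda x: x, stepsize=0.1)
print([round(p, 2) for p in particles])
```

Without the repulsive term every particle would drift to the mode of p. The kernel gradient is what keeps the particle set spread out so that, taken together, the particles behave like a sample from p.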
__init__(particles, cost_fun, batch_generator=None, stepsize_schedule=<pysgmcmc.stepsize_schedules.ConstantStepsizeSchedule object>, alpha=0.9, fudge_factor=1e-06, session=None, dtype=tf.float64, seed=None)
Initialize the sampler parameters and set up a tensorflow.Graph for later queries.
Parameters:
- particles (List[tensorflow.Variable]) – List of particles, each representing a different guess of the target parameters of this sampler.
- cost_fun (callable) – Function that takes the parameters of one particle as input and returns a 1-d tensorflow.Tensor containing the cost value. Frequently denoted U in the literature.
- batch_generator (iterable, optional) – Iterable that yields dictionaries to feed into tensorflow.Session.run() calls when evaluating the cost function. Defaults to None, which indicates that no batches are fed.
- stepsize_schedule (pysgmcmc.stepsize_schedules.StepsizeSchedule, optional) – Iterator that produces a stream of stepsize values for the sampler to use. Defaults to a pysgmcmc.stepsize_schedules.ConstantStepsizeSchedule instance. See also: pysgmcmc.stepsize_schedules
- alpha (float, optional) – TODO: document. Defaults to 0.9.
- fudge_factor (float, optional) – TODO: document. Defaults to 1e-6.
- session (tensorflow.Session, optional) – Session that knows about the external part of the graph (which defines the cost and, possibly, the batches). Used internally to evaluate the sampler during burn-in and sampling. Defaults to None.
- dtype (tensorflow.DType, optional) – Type of elements of tensorflow.Tensor objects used in this sampler. Defaults to tensorflow.float64.
- seed (int, optional) – Random seed to use. Defaults to None.
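alpha and fudge_factor are not described above. The reference SVGD code released alongside [1] uses parameters with these names and defaults (0.9 and 1e-6) for an AdaGrad-style scaling of the update: alpha is the decay rate of a running average of squared update directions, and fudge_factor is a small constant that keeps the per-coordinate divisor away from zero. Assuming (not confirmed by this page) that SVGDSampler follows the same scheme, a hypothetical pure-Python sketch of that scaling looks like this; the class name AdaGradAdjuster is illustrative only.

```python
import math
from typing import List, Optional


class AdaGradAdjuster:
    """Hypothetical helper: AdaGrad-style scaling of SVGD update directions.

    Assumption: mirrors the reference SVGD code, where alpha decays a running
    average of squared directions and fudge_factor guards the denominator.
    """

    def __init__(self, alpha: float = 0.9, fudge_factor: float = 1e-6) -> None:
        self.alpha = alpha
        self.fudge_factor = fudge_factor
        self.historical: Optional[List[float]] = None

    def adjust(self, direction: List[float]) -> List[float]:
        squared = [d * d for d in direction]
        if self.historical is None:
            # First call: initialise the accumulator with the raw squares.
            self.historical = squared
        else:
            # Later calls: exponential moving average controlled by alpha.
            self.historical = [
                self.alpha * h + (1.0 - self.alpha) * s
                for h, s in zip(self.historical, squared)
            ]
        # Divide each coordinate by the root of its accumulated magnitude.
        return [
            d / (self.fudge_factor + math.sqrt(h))
            for d, h in zip(direction, self.historical)
        ]


adjuster = AdaGradAdjuster()
print(adjuster.adjust([2.0, -0.5]))  # each coordinate scaled to roughly +/-1
```

In an SVGD loop, the adjusted direction would then be multiplied by the current stepsize (for example, the next value from stepsize_schedule) and added to each particle. Normalizing per coordinate makes the effective step roughly independent of the raw gradient scale.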
See also
- pysgmcmc.sampling.MCMCSampler() – Base class for SVGDSampler that specifies how the actual sampling is performed (using the iterator protocol, e.g. next(sampler)).