# Source code for pysgmcmc.progressbar
import logging
try:
    from tqdm import tqdm
except ImportError:
    # ``logging.warn`` is a deprecated alias of ``warning``; logging through a
    # module-level logger also lets applications filter this message by name.
    logging.getLogger(__name__).warning(
        "`tqdm` package could not be found. No progressbar is available!\n"
        "Do `pip install tqdm` to display a progressbar."
    )

    class TrainingProgressbar(object):
        """ Dummy progressbar that displays nothing and simply iterates the iterable.

        Used as placeholder if `tqdm` is not installed. Any extra positional or
        keyword arguments are accepted and ignored.
        """

        def __init__(self, iterable=None, *args, **kwargs):
            self.iterable = iterable
            # ``next()`` requires an iterator, so wrap once here; this makes
            # plain sequences (lists, ranges, tuples) iterable as well.
            # ``None`` yields nothing.
            self._iterator = iter(() if iterable is None else iterable)

        def update(self, *args, **kwargs):
            """ No-op: there is no display to refresh. """

        def __iter__(self):
            return self

        def __next__(self):
            return next(self._iterator)
else:
    class TrainingProgressbar(tqdm):
        """ Slightly customized `tqdm` progressbar. """

        def __init__(self, losses=None, update_every=100, *args, **kwargs):
            """ Set up progressbar to track `losses` and update in a given interval.

            Parameters
            ----------
            losses : typing.Mapping[str, pysgmcmc.torch_typing.TorchLossFunction], optional
                Mapping from display name to loss function (e.g. an instance of
                a `torch.nn.modules.loss._Loss` subclass). Each function is
                called as ``loss_function(input=predictions, target=y_batch)``.
                Default: `None`, do not display additional loss metrics.
            update_every : int, optional
                Interval to update this progressbar.
                Default: `100`, update every `100` iterations.
            *args, **kwargs
                Forwarded unchanged to `tqdm.tqdm` (e.g. `iterable`, `total`).
            """
            super().__init__(*args, **kwargs)
            # ``update`` iterates ``self.losses.items()``, so an empty dict
            # stands in for "no losses".
            self.losses = losses if losses else dict()
            self.update_every = update_every

        @staticmethod
        def _as_display_value(value):
            """ Return a plain scalar for 0-d tensors/arrays, else `value` as is.

            Formatting a tensor directly prints its repr (``tensor(...)``);
            ``.item()`` unwraps single-element results into a Python number.
            """
            item = getattr(value, "item", None)
            if callable(item):
                try:
                    return item()
                except (RuntimeError, ValueError):
                    # Multi-element results cannot be turned into a scalar.
                    return value
            return value

        def update(self, predictions, y_batch, epoch):
            """ Check this progressbar for update.

            Recompute loss values and prettyprints them.

            NOTE(review): this shadows ``tqdm.update(n)`` and does not call it,
            so the bar's counter cannot be advanced manually through this method.

            Parameters
            ----------
            predictions: pysgmcmc.torch_typing.Predictions
                BNN predictions on current batch.
            y_batch: pysgmcmc.torch_typing.Targets
                Labels of current batch.
            epoch: int
                Current epoch count.
            """
            if epoch % self.update_every != 0:
                return
            postfix = " - ".join(
                "{loss}: {value}".format(
                    loss=loss_name,
                    value=self._as_display_value(
                        loss_function(input=predictions, target=y_batch)
                    ),
                )
                for loss_name, loss_function in self.losses.items()
            )
            self.set_postfix_str(postfix)