# Source code for pysgmcmc.models.losses
import typing
import torch
from torch.nn.modules.loss import _Loss, _assert_no_grad
from pysgmcmc.torch_typing import (
VariancePrior, WeightPrior, Predictions, Targets,
TorchLoss, TorchLossFunction
)
from pysgmcmc.models.priors import log_variance_prior, weight_prior
class NegativeLogLikelihood(_Loss):
    """ Implementation of BNN negative log likelihood for regression problems. """
    name = "NLL"

    def __init__(self, parameters: typing.Iterable[torch.Tensor],
                 num_datapoints: int,
                 variance_prior: VariancePrior=log_variance_prior,
                 weight_prior: WeightPrior=weight_prior,
                 size_average: bool=True, reduce: bool=False) -> None:
        """ Instantiate a loss object for given network `parameters`.
        Requires `num_datapoints` of the entire regression dataset
        for proper scaling.

        Parameters
        ----------
        parameters : typing.Iterable[torch.Tensor]
            Pytorch variables of BNN parameters.
        num_datapoints : int
            Total number of datapoints of the entire regression dataset to process.
        variance_prior : pysgmcmc.torch_typing.VariancePrior, optional
            Prior for BNN variance. Default: `pysgmcmc.models.priors.log_variance_prior`.
        weight_prior : pysgmcmc.torch_typing.WeightPrior, optional
            Prior for BNN weights. Default: `pysgmcmc.models.priors.weight_prior`.
        size_average : bool, optional
            Must be `True` (only averaged, non-reduced mode is supported).
        reduce : bool, optional
            Must be `False` (only averaged, non-reduced mode is supported).
        """
        # Only the size_average=True / reduce=False combination is implemented.
        # NOTE(review): `assert` is stripped under `python -O`; kept to preserve
        # the exception type (AssertionError) existing callers may rely on.
        assert size_average and not reduce
        super().__init__()
        self.parameters = tuple(parameters)
        self.num_datapoints = num_datapoints
        # BUG FIX: the original bound the module-level defaults here
        # (`log_variance_prior` / `weight_prior`), silently discarding any
        # priors passed by the caller. Bind the constructor arguments instead.
        self.log_variance_prior = variance_prior
        self.weight_prior = weight_prior

    def forward(self, input: Predictions, target: Targets) -> torch.Tensor:
        """ Compute NLL for 2d-network predictions `input` and (batch) labels `target`.

        Parameters
        ----------
        input : pysgmcmc.torch_typing.Predictions
            Network predictions; column 0 is the mean prediction,
            column 1 the predicted log variance.
        target : pysgmcmc.torch_typing.Targets
            Labels for each datapoint in the current batch.

        Returns
        ----------
        nll: torch.Tensor
            Scalar value.
            NLL of BNN predictions given as `input` with respect to labels `target`.
        """
        _assert_no_grad(target)
        batch_size = input.size(0)
        prediction_mean = input[:, 0].view((-1, 1))
        log_prediction_variance = input[:, 1].view((-1, 1))
        # Small epsilon guards against division by zero for very negative log variances.
        prediction_variance_inverse = 1. / (torch.exp(log_prediction_variance) + 1e-16)

        mean_squared_error = (target.view(-1, 1) - prediction_mean) ** 2

        # Gaussian log likelihood (up to an additive constant), averaged over the batch.
        log_likelihood = torch.sum(
            torch.sum(
                -mean_squared_error * (0.5 * prediction_variance_inverse)
                - 0.5 * log_prediction_variance,
                dim=1
            )
        )
        log_likelihood = log_likelihood / batch_size

        # Prior terms are scaled by the dataset size so that minibatch gradients
        # approximate full-dataset gradients.
        log_likelihood += (
            self.log_variance_prior(log_prediction_variance) / self.num_datapoints
        )

        log_likelihood += self.weight_prior(self.parameters) / self.num_datapoints

        return -log_likelihood
def get_loss(loss_cls: TorchLoss, **loss_kwargs) -> TorchLossFunction:
    """ Wrapper to use `NegativeLogLikelihood` interchangeably with other pytorch losses.

    `loss_kwargs` is expected to be a dict with key `parameters` mapped to
    network parameters and key `num_datapoints` mapped to an integer
    representing the amount of datapoints in the entire regression dataset.

    Parameters
    ----------
    loss_cls : pysgmcmc.torch_typing.TorchLoss
        Class type of a loss, e.g. `pysgmcmc.models.losses.NegativeLogLikelihood`.
    loss_kwargs : dict
        Keyword arguments to be passed to `loss_cls`.
        For `NegativeLogLikelihood` this must contain keys `parameters` for BNN
        parameters and `num_datapoints` for the amount of datapoints in the
        entire regression dataset; for other losses these keys are optional
        and are stripped before instantiation.

    Returns
    ----------
    loss_instance: pysgmcmc.torch_typing.TorchLossFunction
        Instance of `loss_cls`.
    """
    if loss_cls is NegativeLogLikelihood:
        # NLL needs the BNN-specific arguments; forward everything.
        return NegativeLogLikelihood(**loss_kwargs)
    # Ordinary pytorch losses do not accept these BNN-specific keys.
    # Use a default so a missing key is not a hard error (the original
    # raised KeyError whenever either key was absent).
    loss_kwargs.pop("parameters", None)
    loss_kwargs.pop("num_datapoints", None)
    return loss_cls(**loss_kwargs)
def to_bayesian_loss(torch_loss):
    """ Wrapper to make pytorch losses compatible with our BNN predictions.

    BNN predictions are 2-d, with the second dimension representing model variance.
    This wrapper essentially passes only the network mean prediction into
    `torch_loss`, which allows us to evaluate `torch_loss` on our network
    predictions as normally.

    Parameters
    ----------
    torch_loss: pysgmcmc.torch_typing.TorchLoss
        Class type of a pytorch loss to evaluate on our BNN, e.g. `torch.nn.MSELoss`.

    Returns
    ----------
    torch_loss_changed:
        Class type that behaves like `torch_loss` but assumes inputs coming from a BNN.
        It will evaluate `torch_loss` on the BNN predictions first dimension,
        on the mean prediction, only.
    """
    class BayesianLoss(torch_loss):
        def forward(self, input, target):
            # Column 0 holds the mean prediction; the variance column is ignored.
            return super().forward(input=input[:, 0], target=target)

    # Impersonate the wrapped loss for readable reprs and error messages.
    # Also update __qualname__ (the original only set __name__, so the class
    # still introspected as `to_bayesian_loss.<locals>.BayesianLoss`).
    BayesianLoss.__name__ = torch_loss.__name__
    BayesianLoss.__qualname__ = torch_loss.__qualname__
    return BayesianLoss