Variational Inference with Linear Regression
How to turn point estimates into probability distributions!

Intro

At its core, image generation is just the ability to accurately approximate the distribution of the images in our training dataset. If we're then able to sample images from this approximated distribution, we have theoretically "generated" an image, i.e. we have a brand new image that looks like it could've come from our training dataset, but in actuality has never been seen before.

Naturally, you might be a little concerned, given that "approximating the distribution" of our training images sounds complicated. With the correct framing, though, we'll see that stable diffusion finds its roots in variational inference, which is what enables us to achieve this seemingly daunting task.

Variational inference is a family of methods that allows us to approximate complicated probability distributions, like the aforementioned "distribution of images." While this post won't have us generating any new images just yet, it'll serve as a stepping stone. We'll cover the very basics of variational inference in the context of linear regression and implement a couple of examples from scratch with PyTorch! Future articles will then build upon this to finally yield us our own image generators.

Simple Linear Regression

Let's say we have the following dataset. The code to generate this dataset was borrowed from the TensorFlow Probability documentation linked here.

image

Here's the code. Don't worry too much if you don't understand it.

import torch

def generate_data(n=150, b0=5.0, w0=0.125, x_range=(-20, 60)):
    # Noise scale function
    def s(x):
        # g ranges from 0.0 to 1.0 as x spans x_range
        g = (x - x_range[0]) / (x_range[1] - x_range[0])
        return 3.0 * (0.25 + g**2)
 
    # Generate inputs uniformly in the specified range
    x = (x_range[1] - x_range[0]) * torch.rand(n) + x_range[0]  # shape: (n,)
 
    # Generate noise with scale s(x)
    eps = torch.randn(n) * s(x)  # shape: (n,)
 
    # Nonlinear function for y
    y = w0 * x * (1.0 + torch.sin(x)) + b0 + eps  # shape: (n,)
 
    # Reshape x to (n, 1)
    x = x.unsqueeze(-1)
 
    # Create test grid x_tst
    x_tst = torch.linspace(x_range[0], x_range[1], steps=n).unsqueeze(-1)  # shape: (n, 1)
 
    # Normalize y and reshape to (n, 1)
    y = (y - y.mean()) / y.std()
    y = y.unsqueeze(-1)
 
    return x, y, x_tst
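If you'd like to follow along, here's a minimal sketch of how the dataset could be generated and visualized (assuming matplotlib is installed):

import matplotlib.pyplot as plt

x, y, x_tst = generate_data()

plt.scatter(x.squeeze(), y.squeeze(), s=10)
plt.xlabel("x")
plt.ylabel("y (normalized)")
plt.show()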

Fitting a line to this dataset is pretty trivial with PyTorch. Here's the Model class and the training loop for it below. We could use a Linear layer, but I did things manually to illustrate the simplicity.

class Model(torch.nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
 
        # self.linear = torch.nn.Linear(1, 1)
        self.w = torch.nn.Parameter(torch.tensor(0, dtype=torch.float32))
        self.b = torch.nn.Parameter(torch.tensor(1, dtype=torch.float32))
 
    def forward(self, x):
        # return self.linear(x)
        return self.b + self.w * x
 
from tqdm import trange

EPOCHS = 10000
LEARNING_RATE = 0.001
 
model = Model()
 
optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)
loss_fn = torch.nn.MSELoss()
 
for i in trange(EPOCHS):
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
 
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

The final fitted line looks like this!

image

If all of this wasn't review, take a look at my blog article about linear regression, linked here!

Bayesian Linear Regression

So now let's spice things up. Right now, our model just outputs a point prediction $\hat{y}$ for any given $x$. However, this doesn't encompass the entirety of the data's characteristics because the points with larger $x$ values are much more spread out than the points with smaller $x$ values. So what if instead of point estimates, we have our model output a distribution? If we want to use a Gaussian distribution, that means for any given $x$, our model must output two values: a mean and a standard deviation. This will allow our model to not just capture the general trend the y-values exhibit, but also the aleatoric uncertainty, or the inherent noise present in the data generation process.

The changes for this are pretty simple.

class Model(torch.nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
 
        self.linear = torch.nn.Linear(1, 2) # Outputs 2 values
        self.softplus = torch.nn.Softplus()
 
       
    def forward(self, x):
        x = self.linear(x)
        
        mean = x[..., 0]
        raw_std = x[..., 1]
 
        # std = torch.exp(0.5 * raw_std)
        std = self.softplus(raw_std)
 
        return torch.distributions.normal.Normal(mean, std)

All we need to do is modify our network so that it outputs two values per prediction rather than one. In our forward function, we use the first column as the predicted mean and the second column as the raw_std. Since we want our standard deviation to always be positive, we pass this raw value through the Softplus function. Softplus is essentially a smoothed-out version of ReLU, seen below.

image

Since PyTorch is tracking all of these operations through the graph it's creating, it'll automatically associate the second column as the input to the Softplus function and optimize the model's parameters appropriately.
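If you'd like to see Softplus in action, here's a quick sanity check. Softplus(x) is just log(1 + e^x), so its output is always positive and behaves like ReLU for large inputs:

softplus = torch.nn.Softplus()

softplus(torch.tensor([-5.0, 0.0, 5.0]))
# tensor([0.0067, 0.6931, 5.0067]) -- always positive, ReLU-like for large x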

Negative Log-Likelihood

The other change we need to make is to our loss function. We're replacing the MSELoss with the negative log-likelihood.

We need to measure whether our model's outputted distribution for a given $x$ is accurate, so we want to know how "likely" the associated y-value in our training data is under this predicted distribution. This value is known as the "likelihood" and is just the value of the distribution's PDF, parametrized by the model's predicted mean & standard deviation, evaluated at that y-value. The higher the likelihood, the better, as it indicates the ground-truth y-value is more plausible as a sample from the predicted distribution.

image

Assume this is the PDF of our distribution. If our mean is 0 and standard deviation is 1, the likelihood of -1 is 0.241..., or the value of this function evaluated at -1. By having our loss function be the negative log-likelihood, we're telling the optimization to try its best to center this distribution (because the center is where the likelihood is the largest) around the y-values in our training data.

Additionally, it's typically easier to work with logs because of numerical stability (the log makes numbers that are close to zero slightly more spread out), and since the log function is strictly monotonically increasing, it doesn't affect our optimization. Also, multiplication & division become addition & subtraction in log-space, making calculating derivatives a breeze.

Finally, because our goal is to maximize the likelihood but our optimizer minimizes whatever loss function we give it, we multiply the log-likelihood by -1, hence the name negative log-likelihood!

The log_prob function for any distribution in PyTorch just computes the value of the log of the distribution's PDF. And since our model outputs a Normal distribution object, we can use its built-in log_prob function to evaluate the likelihood of our training data.
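As a quick sanity check of the numbers from the figure above, here's the likelihood of -1 under a standard normal, computed via log_prob:

standard_normal = torch.distributions.Normal(0.0, 1.0)

standard_normal.log_prob(torch.tensor(-1.0)).exp()
# tensor(0.2420) -- the PDF of a standard normal evaluated at -1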

def nll(y_pred, y):
    y = y.squeeze()
    return -torch.mean(y_pred.log_prob(y))

Our training loop looks exactly the same, except we switch out the MSELoss with the nll function!
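For completeness, here's a rough sketch of what that looks like (reusing the hyperparameters from before; in practice you may need to tweak the learning rate or optimizer):

model = Model()
optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)

for i in trange(EPOCHS):
    y_pred = model(x)      # now a Normal distribution, not a tensor
    loss = nll(y_pred, y)  # negative log-likelihood instead of MSE

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()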

If I plot the results, we get the following! The green lines are two standard deviations away from the mean for any given $x$ value. What's really amazing is the standard deviation is increasing as our $x$ values increase! The model is aware of the uncertainty that comes with larger $x$ values.

image

Instead of plotting the predicted mean and standard deviation, we can also sample from that distribution for any $x$ value, resulting in the same $x$ value yielding a different prediction each time! This is what that looks like.
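Here's roughly how those samples could be drawn, assuming the trained model and the x_tst grid from earlier:

with torch.no_grad():
    pred_dist = model(x_tst)        # a Normal distribution for every test x
    y_samples = pred_dist.sample()  # one random y per x; re-run for a different draw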

TODO: Add graph about sampling data from our distribution

TODO: Add second graph showing the spread for $x = 10, 20, 30, \ldots$

What we're doing here is sometimes referred to as maximum likelihood estimation (MLE), a term that might help if you check out other resources!

Small Aside on MAP (Maximum a Posteriori)

In the previous example, we incorporated no prior beliefs about the distribution of our y-values. MLE can be handy, but may also fall short with limited data or run the risk of overfitting on our data. Additionally, if we have prior beliefs regarding what we think the distribution of our training y-values might look like, then why not give that to our model? The more information, the better, right?

With MLE, we were optimizing $w^{\text{MLE}} = \underset{w}{\mathrm{argmax}} \log P(D|w)$, effectively selecting the set of weights $w$ that maximizes the likelihood of the data $D$ under the model conditioned on those weights.

If we choose to add a prior $P(w)$ over our weights, we can formulate our objective as:

\begin{align*} w^{MAP} = \underset{w}{\mathrm{argmax}} [\log P(D|w) + \log P(w)] = \underset{w}{\mathrm{argmax}} \log P(w|D) \end{align*}

We got this from Bayes' theorem, which states:

\begin{align*} P(w|D) &= \frac{P(D|w)P(w)}{P(D)}\\ \text{posterior} &= \frac{\text{likelihood} \times \text{prior}}{\text{evidence}} \end{align*}

Since our evidence, $P(D)$, isn't a function of $w$, it wasn't included in the argmax. This is also known as the maximum a posteriori objective, or MAP.

The issue is that in a lot of cases, this posterior distribution, $P(w|D)$, is intractable, meaning there's no neat closed-form solution. Estimating it requires many samples of $w$, which can become computationally infeasible.
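To make MAP a little more concrete, here's a hypothetical sketch (we won't use it later) of what the objective could look like for our toy model with a standard-normal prior over the weights. The negative log of a Gaussian prior works out to a simple L2 penalty on the weights, ignoring constants and the scaling between the averaged likelihood and the prior:

def map_loss(y_pred, y, model, prior_std=1.0):
    # Negative log-likelihood of the data, exactly as before
    data_term = nll(y_pred, y)

    # Negative log of a N(0, prior_std^2) prior over every weight,
    # which is just an L2 penalty (up to an additive constant)
    prior_term = sum((p ** 2).sum() for p in model.parameters()) / (2 * prior_std ** 2)

    return data_term + prior_term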

Variational Inference

While it is possible to use MAP on linear regression given the model's simplicity, we'll now dive into variational inference to learn how we'd be able to use it in situations where MAP will not help.

Variational inference is a family of methods where we try to approximate the posterior distribution $P(w|D)$ with a variational distribution $q(w|\theta)$, where $\theta$ are the learned parameters of this distribution.

We want our variational posterior to be as closely aligned with the true posterior as possible, which we can measure with the KL divergence! The more accurate our estimate, the smaller our KL divergence will be!

KL Divergence Review

The KL divergence is a measure, with its roots in information theory, of how similar two distributions are. The more similar they are, the smaller the value. For two identical distributions, the KL divergence is 0. The KL divergence is never negative.

Can you tell which pair of distributions has a smaller KL divergence below?

image

Let's check your answer with some simple PyTorch code!

import torch.distributions as dist
 
# Left two distributions
dist1 = dist.Normal(loc=0.0, scale=1.0)
dist2 = dist.Normal(loc=0.5, scale=1.2)
 
# Right two distributions
dist3 = dist.Normal(loc=1.0, scale=0.5)
dist4 = dist.Normal(loc=-1.0,scale=0.5)
 
dist.kl_divergence(dist1, dist2), dist.kl_divergence(dist3, dist4)
# Left KL divergence -> 0.1163
# Right KL divergence -> 8

If you guessed the first pair would have the smaller KL divergence, you'd be correct!

The actual equation for the KL divergence is:

\begin{align*} D_{KL}(q(x)||p(x)) = \int q(x)\log\left( \frac{q(x)}{p(x)} \right) dx \end{align*}

Additionally, we can estimate it by taking repeated samples from $q(x)$ and averaging the difference between the two distributions' log-likelihoods at those samples.

\begin{align*} D_{KL}(q(x)||p(x)) &= \mathbb{E}_{q(x)} \left[ \log \frac{q(x)}{p(x)} \right]\\ &= \mathbb{E}_{q(x)} \left[ \log q(x) - \log p(x) \right]\\ &\approx \frac{1}{N}\sum_{i=1}^N \left[ \log q(x_i) - \log p(x_i) \right], \quad x_i \sim q(x) \end{align*}
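We can sanity-check this estimate against PyTorch's closed-form KL divergence (the Monte Carlo value will wobble a little from run to run):

q = torch.distributions.Normal(0.0, 1.0)
p = torch.distributions.Normal(0.5, 1.2)

# Average log q(x) - log p(x) over many samples drawn from q
x_samples = q.sample((100_000,))
kl_estimate = (q.log_prob(x_samples) - p.log_prob(x_samples)).mean()

kl_exact = torch.distributions.kl_divergence(q, p)
# kl_estimate ≈ 0.116, kl_exact -> 0.1163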

Back to VI

The KL divergence between our variational posterior and true posterior can be written as $D_{\text{KL}}(q(w|\theta) \| P(w|D))$. Since we cannot compute our true posterior, let's convert this equation into a form that's easier to manage.

Using the definition of KL divergence we get the following integral.

\begin{align*} D_{\text{KL}}(q(w|\theta) \| P(w|D)) &= \int q(w|\theta)\log \left(\frac{q(w|\theta)}{P(w|D)}\right)dw \end{align*}

Remember that $P(w|D)=\frac{P(D|w)P(w)}{P(D)}$. We just divide $q(w|\theta)$ by $\frac{P(D|w)P(w)}{P(D)}$.

\begin{align*} &= \int q(w|\theta)\log \left(\frac{q(w|\theta)P(D)}{P(D|w)P(w)}\right)dw\\ &= \int q(w|\theta) \log P(D)\,dw + \int q(w|\theta) \log \frac{q(w|\theta)}{P(w)}dw - \int q(w|\theta) \log P(D|w)\,dw \end{align*}

On the second line, all we did was use the properties of logs, $\log ab = \log a + \log b$ and $\log \frac{a}{b} = \log a - \log b$, to break up the integral into three parts.

One final algebraic manipulation yields us:

\begin{align*} D_{\text{KL}}(q(w|\theta) \| P(w|D)) &= \log P(D) + D_{\text{KL}}( q(w|\theta)||P(w)) - \mathbb{E}_{w \sim q(w|\theta)} [\log P(D | w)] \end{align*}

For the first part, $\int q(w|\theta) \log P(D)\,dw$ just simplifies to $\log P(D)$ because $\int q(w|\theta)\, dw$ evaluates to 1 (it's a probability distribution) and $\log P(D)$ isn't a function of $w$. However, since $P(D)$ is just a constant (we're never optimizing the data itself), we can remove it from our final loss; it doesn't affect our optimization process.

The second component, $\int q(w|\theta) \log \frac{q(w|\theta)}{P(w)}dw$, yields us another KL divergence term, but this time between our variational posterior and our chosen prior. For the third component, $\int q(w|\theta) \log P(D|w)\,dw$, we can rewrite it as an expectation of $\log P(D|w)$ over $q(w|\theta)$.

This means our loss function that we want to minimize is:

\begin{align*} \mathcal{L}(\theta|D) = D_{\text{KL}}( q(w|\theta)||P(w)) - \mathbb{E}_{w \sim q(w|\theta)} [\log P(D | w)] \end{align*}

This loss has a neat interpretation. We're already familiar with the second term: it's the negative log-likelihood that we've already optimized! The only addition here is the KL divergence between our variational posterior and prior. You can think of this first term as a regularization term, essentially preventing our variational posterior from straying too far from our chosen prior.

Our loss function is the negative of a quantity known as the evidence lower bound, or the ELBO. Let's see why. Take this equation again.

\begin{align*} D_{\text{KL}}(q(w|\theta) \| P(w|D)) &= \log P(D) + D_{\text{KL}}( q(w|\theta)||P(w)) - \mathbb{E}_{w \sim q(w|\theta)} [\log P(D | w)] \end{align*}

Moving around a couple terms gives us the following.

\begin{align*} \log P(D) &= \mathbb{E}_{w \sim q(w|\theta)} [\log P(D | w)] - D_{\text{KL}}( q(w|\theta)||P(w)) + D_{KL}(q(w|\theta)||P(w|D)) \end{align*}

Again, remember that the whole idea here is that we can't compute $P(w|D)$, our true posterior. We don't want it in our equations! Thankfully, since the KL divergence is always zero or positive, we can drop $D_{KL}(q(w|\theta)||P(w|D))$ to give us a lower bound on the log-evidence, $\log P(D)$.

\begin{align*} \log P(D) &\geq \mathbb{E}_{w \sim q(w|\theta)} [\log P(D | w)] - D_{\text{KL}}( q(w|\theta)||P(w)) \end{align*}

Now hopefully the evidence lower bound name makes sense! It's the core quantity we're trying to maximize in variational inference, and you'll see it anywhere we're trying to fit a distribution to some data by maximizing the data's likelihood.

The Code

The only change we need to make to our code is our loss function! In addition to the log likelihood, we also need to add the KL divergence term between our variational posterior & our prior, which in this case we'll assume to be the Normal distribution with mean 0 and standard deviation 1.

def elbo(y_pred, y):
    y = y.squeeze()
    likelihood = y_pred.log_prob(y)
 
    prior = torch.distributions.normal.Normal(0, 1)
    kl_divergence = torch.distributions.kl_divergence(y_pred, prior)
 
    return kl_divergence.mean() - likelihood.mean()
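In the training loop itself, the only part that changes is the loss computation, roughly:

y_pred = model(x)       # still a Normal distribution
loss = elbo(y_pred, y)  # KL-regularized negative log-likelihood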

If we run our training loop after making sure we're using the right loss function, we can see how the data is now fit.

image

Comparing to the old version, the regularizing effect of the KL divergence term is quite apparent. The standard deviation lines are much further apart, signifying the model's reluctance to risk overfitting to the dataset.

The Reparameterization Trick

There is one final issue we need to discuss. The current model returns a PyTorch Distribution object, which we can use directly to compute the closed-form value of our loss. However, sometimes we may need to use a sampled point from said distribution. This is prevalent in more advanced use cases of VI, such as variational autoencoders.

Unfortunately, since sampling is not a differentiable operation, this will break our gradient graph & we'll be unable to optimize our model's predictions of the distributional parameters.

Thankfully, we can navigate this with the reparameterization trick. Instead of sampling from our variational posterior distribution, which would break the gradient graph, we can sample a value $\epsilon$ from the standard normal, $\epsilon \sim \mathcal{N}(0, 1)$, and then "morph" this sample into a sample from our variational posterior by multiplying it with our estimated standard deviation and adding it to our estimated mean.

In other words, we turn $y \sim \mathcal{N}(\mu_\theta, \sigma_\theta)$ into $y = \mu_\theta + \sigma_\theta \times \epsilon$. As shown in the diagram below, this isolates the stochasticity into its own node that doesn't block the path of the gradients to $\mu$ and $\sigma$.

TODO: Recreate picture of the reparameterization trick

While this reparameterization trick isn't necessary for the example in this article, it's quite easy to implement. Let's write a new Model class that outputs a point prediction that represents a sample from our predicted distribution.

class Model(torch.nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
 
        self.linear = torch.nn.Linear(1, 2)
        self.softplus = torch.nn.Softplus()
 
    def reparam(self, mu, sigma):
        eps = torch.randn_like(sigma)
        return mu + sigma * eps
 
    def forward(self, x):
        x = self.linear(x)
 
        mean = x[..., 0]
        raw_std = x[..., 1]
        std = self.softplus(raw_std)
 
        return self.reparam(mean, std), mean, std

This means that for any loss function where we use the reparameterized sample (the first return value of our forward function), we can still easily calculate gradients of the loss with respect to our model's parameters!

PyTorch Distribution objects have an rsample function that uses the reparameterization trick to keep sampling from breaking the gradient graph, but it's always more insightful to do things ourselves!
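You can see the difference directly: sample() detaches the result from the graph, while rsample() (like our manual reparam) keeps it attached. A tiny sketch:

mu = torch.tensor(0.0, requires_grad=True)
sigma = torch.tensor(1.0, requires_grad=True)
normal = torch.distributions.Normal(mu, sigma)

print(normal.sample().grad_fn)   # None -- no path back to mu and sigma
print(normal.rsample().grad_fn)  # an autograd node -- gradients can flow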

Now let's modify our elbo function to no longer use the closed-form solution of the KL divergence, but rather the outputted samples from our model.

def loglik_gaussian(y, mu, sigma):
    y = y.squeeze()
    return -0.5 * torch.log(2 * torch.pi * sigma**2) - (1 / (2 * sigma**2))* (y - mu)**2
 
def elbo(y_pred, y, mu, sigma):
    # The likelihood of our original data belonging to this predicted distribution
    likelihood = loglik_gaussian(y, mu, sigma)
 
    # Monte Carlo approx. of the KL divergence;
    # y_pred represents samples from our posterior for various x-values
    log_prior = loglik_gaussian(y_pred, torch.tensor(0.0), torch.tensor(1.0))
    log_var_post = loglik_gaussian(y_pred, mu, sigma)
 
    kl_divergence_estimate = log_var_post - log_prior
    
    # Taking the mean approximates the expectation
    return kl_divergence_estimate.mean() - likelihood.mean()

Don't forget that we can approximate the KL divergence $D_{\text{KL}}( q(w|\theta)||P(w))$ with samples from $q(w|\theta)$, which is exactly what we're doing here.

In the above code we've manually coded the log-likelihood function for the Gaussian distribution, which is:

\begin{align*} \log p(y \mid \mu, \sigma^2) = -\frac{1}{2} \log(2\pi \sigma^2) - \frac{(y - \mu)^2}{2\sigma^2} \end{align*}

Try taking the formula for the Gaussian distribution and deriving this yourself!
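As a quick check that our hand-written loglik_gaussian matches PyTorch's built-in version (arbitrary numbers, just a sketch):

y_check = torch.tensor([-1.0, 0.0, 2.0])
mu_check = torch.tensor(0.5)
sigma_check = torch.tensor(1.5)

loglik_gaussian(y_check, mu_check, sigma_check)
torch.distributions.Normal(mu_check, sigma_check).log_prob(y_check)
# Both lines return the same log-likelihoods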

All Done!

Well guys, that's the basics of Variational Inference! We used it in the context of linear regression, which is like putting a V8 engine in a Toyota Corolla, but we gotta learn how to drive in school zones before we tackle the race track!

I plan on turning this blog post into a YouTube video, and then the next topic we'll tackle is autoencoders. In the meantime, take care and please reach out should there be any questions!