diff --git a/_posts/2024-01-20-flow-matching.md b/_posts/2024-01-20-flow-matching.md index cebbb94..af1a456 100644 --- a/_posts/2024-01-20-flow-matching.md +++ b/_posts/2024-01-20-flow-matching.md @@ -96,6 +96,10 @@ draft: true color: red; text-decoration: line-through; } + +main .image-container .caption { + text-align: center; +} @@ -139,7 +143,13 @@ $$ \nonumber $$ -# Flow matching +# Table of contents +{:.no_toc} + +1. placeholder +{:toc} + +# Introduction @@ -147,6 +157,7 @@ $$ ## Generative Modelling +{:.no_toc} -Let's assume we have data samples $x_1, x_2, \ldots, x_n$ from a distribution of interest $q_1(x)$, which density is unknown. We're interested in using these samples to learn a probabilistic model approximating $q_1$. In particular, we want efficient generation of new samples (approximately ) distributed from $q_1$. This task is referred to as **generative modelling**.
+Let's assume we have data samples $x_1, x_2, \ldots, x_n$ from a distribution of interest $q_1(x)$, whose density is unknown. We're interested in using these samples to learn a probabilistic model approximating $q_1$. In particular, we want efficient generation of new samples (approximately) distributed according to $q_1$. This task is referred to as **generative modelling**.
@@ -186,10 +197,11 @@ The advancement in generative modelling methods over the past decade has been no ## Outline +{:.no_toc} Flow Matching (FM) models are in nature most closely related to (Continuous) Normalising Flows (CNFs). Therefore, we start this blogpost by briefly recapping the core concepts behind CNFs. We then continue by discussing the difficulties of CNFs and how FM models address them. -## Basics: Normalising Flows +# Normalising Flows -Let $\phi: \mathbb{R}^d \rightarrow \mathbb{R}^d$ be a continuously differentiable function which transforms elements of $\mathbb{R}^d$, with a continously differentiable inverse $\phi^{-1}: \mathbb{R}^d \to \mathbb{R}^d$. Let $q_0(x)$ be a density on $\mathbb{R}^d$ and let $p_1(\cdot)$ be the density induced by the following sampling procedure
+Let $\phi: \mathbb{R}^d \rightarrow \mathbb{R}^d$ be a continuously differentiable function which transforms elements of $\mathbb{R}^d$, with a continuously differentiable inverse $\phi^{-1}: \mathbb{R}^d \to \mathbb{R}^d$. Let $q_0(x)$ be a density on $\mathbb{R}^d$ and let $p_1(\cdot)$ be the density induced by the following sampling procedure
@@ -208,12 +220,13 @@ $$ \begin{align} \label{eq:changevar} p_1(y) &= q_0(\phi^{-1}(y)) \abs{\det\left[\frac{\partial \phi^{-1}}{\partial y}(y)\right]} \\ +\label{eq:changevar-alt} &= \frac{q_0(x)}{\abs{\det\left[\frac{\partial \phi}{\partial x}(x)\right]}} \quad \text{with } x = \phi^{-1}(y) \end{align} $$ -where the last equality can be seen from the fact that $\phi \circ \phi^{-1} = \Id$ and a simple application of the chain rule[^chainrule]. The quantity $\frac{\partial \phi^{-1}}{\partial y}$ is the Jacobian of the inverse map. It is a matrix of size $d\times d$ containing $J_{ij} = \frac{d\phi^{-1}_i}{dx_j}$.
+where the last equality can be seen from the fact that $\phi \circ \phi^{-1} = \Id$ and a simple application of the chain rule[^chainrule]. The quantity $\frac{\partial \phi^{-1}}{\partial y}$ is the Jacobian of the inverse map. It is a matrix of size $d\times d$ with entries $J_{ij} = \frac{\partial \phi^{-1}_i}{\partial y_j}$.
-Depending on the task at hand, evaluation of likelihood or sampling, one of the two formulation of $\eqref{eq:changevar}$ is preferred. +Depending on the task at hand, likelihood evaluation or sampling, the formulation in $\eqref{eq:changevar}$ or $\eqref{eq:changevar-alt}$ is preferred (Friedman, 1987; Chen & Gopinath, 2000).
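+
+As a quick sanity check of $\eqref{eq:changevar}$, here is a small numerical sketch (ours, not from any of the referenced papers) for a linear map in 1D, where the Jacobian determinant reduces to $\abs{a}$:
+
+```python
+import numpy as np
+from scipy.stats import norm
+
+# Flow phi(x) = a * x + b, with inverse phi^{-1}(y) = (y - b) / a.
+a, b = 2.0, 1.0
+q0 = norm(loc=0.0, scale=1.0)  # reference density q_0 = N(0, 1)
+
+def p1_change_of_variables(y):
+    x = (y - b) / a            # phi^{-1}(y)
+    return q0.pdf(x) / abs(a)  # |det d(phi^{-1})/dy| = 1 / |a| in 1D
+
+# phi pushes N(0, 1) forward to N(b, a^2), so both expressions must agree.
+y = np.linspace(-5.0, 7.0, 5)
+assert np.allclose(p1_change_of_variables(y), norm(loc=b, scale=abs(a)).pdf(y))
+```
+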
#### Example: Transformation of 1D Gaussian variables by linear map @@ -399,6 +412,7 @@ $$ #### Continuous change-in-variables +{:.no_toc} Of course, this only defines the map $\phi_t(x)$; for this to be a useful normalising flow, we still need to compute the log-abs-determinant of the Jacobian! @@ -412,7 +426,7 @@ $$ $$ -This statement on the time-evolution of $p_t$ is generally known as the *Continuity Equation* or *Transport Equation*. We refer to $p_t$ as the probability path induced by $u_t$. +This statement on the time-evolution of $p_t$ is generally known as the *Transport Equation*. We refer to $p_t$ as the probability path induced by $u_t$. Computing the *total* derivative (as $x_t$ also depends on $t$) in log-space yields[^log_pdf] @@ -437,14 +451,16 @@ $$ \log p_\theta(x) \triangleq \log p_1(x) = \log p_0(x_0) - \int_0^1 (\nabla \cdot u_\theta)(x_t) \dd t. $$ -In practice, both the time evolution of $x_t$ and its log density $\log p_t$ are solved jointly +In practice, to compute $\log p_t$ one can either solve both the time evolution of $x_t$ and its log density $\log p_t$ jointly $$ \begin{equation} -\frac{\dd}{\dd t} \Biggl( \begin{aligned} x_t \ \quad \\ \log p(x_t) \end{aligned} \Biggr) = \Biggl( \begin{aligned} u_\theta(t, x_t) \quad \\ - \div u_\theta(t, x_t) \end{aligned} \Biggr). +\frac{\dd}{\dd t} \Biggl( \begin{aligned} x_t \ \quad \\ \log p_t(x_t) \end{aligned} \Biggr) = \Biggl( \begin{aligned} u_\theta(t, x_t) \quad \\ - \div u_\theta(t, x_t) \end{aligned} \Biggr), \end{equation} $$ +or solve only for $x_t$ and then use quadrature methods to estimate $\log p_t(x_t)$. + -Feeding this (joint) vector field to an adaptive step-size ODE solver allows us to control both the error in the sample $x_t$ and the error in the $\log p_t(x)$.
+Feeding this (joint) vector field to an adaptive step-size ODE solver allows us to control both the error in the sample $x_t$ and the error in the log-density $\log p_t(x_t)$.
-One may legitimately wonder why should we bother with such *time-continuous* flows versus *discrete* residual flows. There are a couple of benefits:
+One may legitimately wonder why we should bother with such *time-continuous* flows versus *discrete* residual flows. There are a couple of benefits:
@@ -459,6 +475,7 @@ Now that you know why CNFs are cool, let's have a look at what such a flow would
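+
+To make the joint solve concrete, here is a minimal sketch (ours; the post itself ships no code) using a toy 1D vector field $u(t, x) = -x$ whose divergence is available in closed form, so no trace estimator is needed:
+
+```python
+import numpy as np
+from scipy.integrate import solve_ivp
+from scipy.stats import norm
+
+def u(t, x):
+    return -x      # toy vector field; its divergence in 1D is -1
+
+def joint_ode(t, state):
+    x, _ = state
+    div_u = -1.0   # closed form here; in general use autodiff or Hutchinson's estimator
+    return [u(t, x), -div_u]
+
+x0 = 0.7
+sol = solve_ivp(joint_ode, (0.0, 1.0), [x0, norm.logpdf(x0)], rtol=1e-8, atol=1e-8)
+x1, logp1 = sol.y[:, -1]
+
+# This flow contracts N(0, 1) to p_1 = N(0, e^{-2}), so the log-density must match.
+assert np.isclose(logp1, norm(scale=np.exp(-1.0)).logpdf(x1), atol=1e-5)
+```
+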
#### A simple example: $u_t$ from a Gaussian to a Gaussian +{:.no_toc} -Let's come back to our earlier example of mapping a 1D Gaussian to another one with different mean. In contrast to previously where we derived a 'one-shot' (i.e. *discrete*) flow bridging between the two Gaussians, we now aim to derive a time-*continuous flow* $\phi_t$ which would correspond to the time integrating a vector field $u_t$.
+Let's come back to our earlier example of mapping a 1D Gaussian to another one with a different mean. In contrast to before, where we derived a 'one-shot' (i.e. *discrete*) flow bridging the two Gaussians, we now aim to derive a time-*continuous* flow $\phi_t$, which corresponds to integrating a vector field $u_t$ over time.
@@ -571,6 +588,8 @@ We could of course have gone the other way, i.e. define the $u_t$ such that $p_0 #### Training CNFs +{:.no_toc} + -Similarly to any flows, CNFs can be trained by maximum log-likelihood
+As with other flows, CNFs can be trained by maximising the log-likelihood
$$ \mathcal{L}(\theta) = \mathbb{E}_{x\sim q_1} [\log p_1(x)], $$ @@ -587,7 +606,7 @@ CNFs are very expressive as they parametrise a large class of flows, and therefo -## Flow matching +# Flow matching And that is exactly where Flow Matching (FM) comes in! @@ -673,7 +692,7 @@ -->
{% include image.html - ref="forward_samples-one-color-3.png" + ref="forward_samples_ot-one-color.png" src="flow-matching/forward_samples-one-color-3.png" width=400 %} @@ -682,7 +701,7 @@ $$ -->
{% include image.html - ref="forward_samples_ot-one-color.png" + ref="forward_samples-one-color-3.png" src="flow-matching/forward_samples_ot-one-color.png" width=400 %} @@ -693,7 +712,7 @@ $$ -->

-Figure : *Different paths with the same endpoints marginals[^interpolation].* +Figure 7: *Different paths with the same endpoint marginals[^interpolation].*

@@ -708,7 +727,7 @@ Figure : *Different paths with the same endpoints marginals[^interpolation].* ### Conditional Flows -First, let's remind that the *continuity equation* relates a vector field $u_t$ to (the time evolution of) a probability path $p_t$ +First, let's remind ourselves that the transport equation relates a vector field $u_t$ to (the time evolution of) a probability path $p_t$ $$ \begin{equation} \pdv{p_t(x)}{t} = - \nabla \cdot \big( u_t(x) p_t(x) \big), @@ -742,16 +761,16 @@ with $\sigmamin > 0$ small, and for whatever reference $p_0$ we choose, typicall
{% include image.html - name="Figure" - alt="Two probability path $p_t(x|x_1)$." + name="Figure 8" + alt="Two conditional flows $\phi_t(x \mid x_1)$ for two univariate Gaussians." ref="heatmap_with_cond_traj-v3" src="flow-matching/heatmap_with_cond_traj-v3.png" - width=400 + width=600 %}
-The conditional probability path also satisfies the continuity equation with the **conditional vector field** $u_t(x \mid x_1)$: +The conditional probability path also satisfies the transport equation with the **conditional vector field** $u_t(x \mid x_1)$: $$ \begin{equation} \label{eq:continuity-cond} \pdv{p_t(x \mid x_1)}{t} = - \nabla \cdot \big( u_t(x \mid x_1) p_t(x \mid x_1) \big), \end{equation} $$ - -To see why this $u_t$ the same the vector field as the one defined earlier, i.e. the one generating the (marginal) pribability path $p_t$, we need to show that the expression above for the marginal vector field $u_t(x)$ satisfies the continuity equation +To see why this $u_t$ is the same vector field as the one defined earlier, i.e. the one generating the (marginal) probability path $p_t$, we need to show that the expression above for the marginal vector field $u_t(x)$ satisfies the transport equation $$ \begin{equation} \pdv{p_t(x)}{t} = - \nabla \cdot \big( u_t(x) p_t(x) \big). \end{equation} $$ @@ -809,6 +828,40 @@ where in the $\hlone{\text{first highlighted step}}$ we used \eqref{eq:continuity-cond} and in the $\hltwo{\text{last highlighted step}}$ we used the expression of $u_t(x)$ in \eqref{eq:cf-from-cond-vf}. +The relation between $\phi_t(x_0)$, $\phi_t(x_0 \mid x_1)$ and their induced densities is illustrated in [Figure 9](#figure-flow-matching-diagram) below. And since $\phi_t(x_0)$ and $\phi_t(x_0 \mid x_1)$ are solutions corresponding to the vector fields $u_t(x)$ and $u_t(x \mid x_1)$ with $x(0) = x_0$, [Figure 9](#figure-flow-matching-diagram) is equivalent to [Figure 10](#figure-flow-matching-diagram-2). +
+
+
+ + +{% include image.html + name="Figure 9" + alt="Diagram illustrating the relation between the paths $\phi_t(x_0)$, $\phi_t(x_0 \mid x_1)$, and their induced marginal and conditional densities." + ref="flow-matching-diagram" + src="flow-matching/flow-matching-diagram.png" + width=500 +%} + +
+
+
+ +
+
+
+ +{% include image.html + name="Figure 10" + alt="Diagram illustrating the relation between the vector fields $u_t(x_0)$, $u_t(x_0 \mid x_1)$, and their induced marginal and conditional densities." + ref="flow-matching-diagram-2" + src="flow-matching/flow-matching-diagram-2.png" + width=700 +%} + +
+
+
@@ -838,15 +891,34 @@ p_0 = \mathcal{N}([-\mu, 0], I) \quad & \text{and} \quad p_1 = \mathcal{N}([+\mu \end{split} \end{equation} $$ -for some $\mu > 0$. We're effectively transforming a Gaussian to another Gaussian using a simple time-linear map, as illustrated in the following figure. +with $\mu = 10$ unless otherwise specified. We're effectively transforming a Gaussian to another Gaussian using a simple time-linear map, as illustrated in the following figure. + +
+
+
+ +{% include image.html + name="Figure 11" + alt="Example conditional paths $\phi_t(x_0 \mid x_1)$ of \eqref{eq:g2g} with $\mu = 10$." + ref="g2g-cond-paths-one-color" + src="flow-matching/g2g-cond-paths-one-color.png" + width=400 +%} + +
+
+
+</div>

+In the end, we're really just interested in learning the *marginal* paths $\phi_t(x_0)$ for initial points $x_0$ that are probable under $p_0$, which we can then use to generate samples $x_1 = \phi_1(x_0)$. In this simple example, we can obtain closed-form expressions for $\phi_t(x_0)$ corresponding to the conditional paths $\phi_t(x_0 \mid x_1)$ of \eqref{eq:g2g}, as visualised below.
{% include image.html - name="Figure" - ref="g2g-forward_samples-one-color.png" + name="Figure 12" + alt="Example marginal paths $\phi_t(x_0)$ of \eqref{eq:g2g} with $\mu = 10$." + ref="g2g-marginal-paths-one-color" src="flow-matching/g2g-forward_samples-one-color.png" width=400 %} @@ -855,7 +927,7 @@ for some $\mu > 0$. We're effectively transforming a Gaussian to another Gaussia
-In the end, we're really just interested in learning the marginal paths $\phi_t(x_0)$ for initial points $x_0$ that are probable under $p_0$. With that in mind, let's pick a random initial point $x_0$ from $p_0$, and then compare a MC estimator for $u_t(x_0)$ at different values of $t$ along the path $\phi_t(x_0)$, i.e. we'll be looking at +With that in mind, let's pick a random initial point $x_0$ from $p_0$, and then compare an MC estimator for $u_t(x_0)$ at different values of $t$ along the path $\phi_t(x_0)$, i.e. we'll be looking at @@ -866,7 +938,7 @@ u_t \big( \phi_t(x_0) \big) &\approx \frac{1}{n} \sum_{i = 1}^n u_t \big( \phi_t(x_0) \mid x_1^{(i)} \big) \ \text{with } x_1^{(i)} \sim p_{1|t}(x_1 \mid \phi_t(x_0)). \end{align} $$ -In practice we don't have access to the posterior $p_{1|t}(x_1|x_t)$, but in this specific setting we do have closed-form expressions for everything (Albergo & Boffi, 2023), and so we can visualise the marginal vector field $u_t\big( \phi_t(x_0)\big)$ and the conditional vector fields $u_t \big( \phi_t(x_0) \mid x_1^{(i)} \big)$ for all our "data" samples $x_1^{(i)}$ and see how they compare. This is shown in the figure below. +In practice we don't have access to the posterior $p_{1|t}(x_1|x_t)$, but in this specific setting we do have closed-form expressions for everything (Albergo et al., 2023), and so we can visualise the marginal vector field $u_t\big( \phi_t(x_0)\big)$ and the conditional vector fields $u_t \big( \phi_t(x_0) \mid x_1^{(i)} \big)$ for all our "data" samples $x_1^{(i)}$ and see how they compare. This is shown in the figure below.
@@ -898,7 +970,7 @@ In practice we don't have access to the posterior $p_{1|t}(x_1|x_t)$, but in thi

-Figure 7: *Marginal vector field $u_t(x)$ vs. conditional vector field $u_t(x \mid x_1)$ for samples $x_1 \sim p_1$. Here $p_0 = p_1 = \mathcal{N}(0, 1)$ and the two trajectories are according to the marginal vector field $u_t(x)$. Samples $x_1$ transparency is given by the IS weight $p_t(x \mid x_1) / p_t(x)$.* +Figure 13: *Marginal vector field $u_t(x)$ vs. conditional vector field $u_t(x \mid x_1)$ for samples $x_1 \sim p_1$. Here $p_0 = p_1 = \mathcal{N}(0, 1)$ and the two trajectories follow the marginal vector field $u_t(x)$. The transparency of samples $x_1$ is given by the IS weight $p_t(x \mid x_1) / p_t(x)$.*

@@ -935,11 +1007,9 @@ $$ \nabla_\theta \mathcal{L}_{\mathrm{FM}}(\theta) = \nabla_\theta \mathcal{L}_{\mathrm{CFM}}(\theta), \end{equation} $$ -which implies that we can use ${\mathcal{L}}_{\text{CFM}}$ - -instead to train the parametric vector field $u_{\theta}$. +which implies that we can use $${\mathcal{L}}_{\text{CFM}}$$ instead to train the parametric vector field $u_{\theta}$. The defer the full proof to the footnote[^CFM], but show the key idea below. -By developing the squared norm in both losses, we can easily show that the squared terms are equal or independant of $\theta$. +By developing the squared norm in both losses, we can easily show that the squared terms are equal or independent of $\theta$. Let's develop inner product term for ${\mathcal{L}}_{\text{FM}}$ and show that it is equal to the inner product of ${\mathcal{L}}_{\text{CFM}}$: @@ -985,7 +1055,7 @@ As a result both endpoints constraint are satisfied since ones recovers
We have defined a probability path $p_t$ in terms of conditional probability path $p_t(\cdot|x_1)$, yet how do we define the latter? -We know that the continuity equation $\frac{\partial}{\partial_t} p_t(x_t) = - (\nabla \cdot (u_t p_t))(x_t)$ relates a vector field (i.e. vector field) to a propability path $p_t$ (given an initial value $p_{t=0} = q_0$). +We know that the transport equation $\frac{\partial}{\partial_t} p_t(x_t) = - (\nabla \cdot (u_t p_t))(x_t)$ relates a vector field $u_t$ to a probability path $p_t$ (given an initial value $p_{t=0} = q_0$). As such it is sufficient to construct a _conditional vector field_ $u_t(\cdot|x_1)$ which induces a conditional probability path $p_t(\cdot|x_1)$ with the right boundary conditions. --> @@ -1087,6 +1157,7 @@ The simplest solution to the above is then just
#### Example: Linear interpolation +{:.no_toc} A simple choice for the mean $\mu_t(x_1)$ and std. $\sigma_t(x_1)$ is the linear interpolation for both, i.e. @@ -1101,7 +1172,7 @@ so that $$ \begin{equation} -\big( \hlone{\mu_0(x_1)} + \hlthree{\sigma_0(x_1)} x_1 \big) \sim p_0 \quad \text{and} \quad \big( \hlone{\mu_1(x_1)} + \hlthree{\sigma_1(x_1)} x_1 \big) \sim \mathcal{N}(x_1, \sigmamin^2 I) +\big( {\hlone{\mu_0(x_1)}} + {\hlthree{\sigma_0(x_1)}} x_1 \big) \sim p_0 \quad \text{and} \quad \big( {\hlone{\mu_1(x_1)}} + {\hlthree{\sigma_1(x_1)}} x_1 \big) \sim \mathcal{N}(x_1, \sigmamin^2 I) \end{equation} $$ @@ -1133,7 +1204,7 @@ Below you can see the difference between $\phi_t(x_0)$ (top figure) and $\phi_t(
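+
+With this choice of $\mu_t$ and $\sigma_t$, one CFM training step takes only a few lines. Below is a minimal PyTorch sketch; the `VectorField` architecture, the stand-in batch, and the $\sigmamin$ value are illustrative choices of ours rather than anything prescribed by the referenced papers:
+
+```python
+import torch
+import torch.nn as nn
+
+class VectorField(nn.Module):
+    """Hypothetical network u_theta(t, x); any architecture taking (t, x) works."""
+    def __init__(self, dim=2, hidden=64):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.Linear(dim + 1, hidden), nn.SiLU(),
+            nn.Linear(hidden, hidden), nn.SiLU(),
+            nn.Linear(hidden, dim),
+        )
+
+    def forward(self, t, x):
+        return self.net(torch.cat([t, x], dim=-1))
+
+def cfm_loss(model, x1, sigma_min=1e-4):
+    x0 = torch.randn_like(x1)                     # x_0 ~ p_0 = N(0, I)
+    t = torch.rand(x1.shape[0], 1)                # t ~ U[0, 1]
+    xt = (1 - (1 - sigma_min) * t) * x0 + t * x1  # phi_t(x_0 | x_1)
+    target = x1 - (1 - sigma_min) * x0            # u_t(x_t | x_1) = d/dt phi_t(x_0 | x_1)
+    return ((model(t, xt) - target) ** 2).sum(-1).mean()
+
+model = VectorField()
+opt = torch.optim.Adam(model.parameters(), lr=1e-3)
+x1 = torch.randn(256, 2) + 10.0                   # stand-in mini-batch from q_1
+opt.zero_grad()
+loss = cfm_loss(model, x1)
+loss.backward()
+opt.step()
+```
+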
{% include image.html - name="Figure 8" + name="Figure 14" alt="Realizations of paths from $p_0$ to $p_1$ following conditional vector fields $u_t(x \mid x_1)$. Paths are highlighted by the sign of the 2nd vector component at time $t=1$." ref="g2g-vector-field-samples-cond.png" src="flow-matching/g2g-vector-field-samples-cond.png" @@ -1145,7 +1216,7 @@ Below you can see the difference between $\phi_t(x_0)$ (top figure) and $\phi_t(
{% include image.html - name="Figure 9" + name="Figure 15" alt="Paths from $p_0$ to $p_1$ following the true marginal vector field $u_t(x)$. Paths are highlighted by the sign of the 2nd vector component." ref="g2g-forward_samples.png" src="flow-matching/g2g-forward_samples.png" @@ -1225,10 +1296,10 @@ result in paths that are quite different from the marginal paths as illustrated
{% include image.html
-    name="Figure 10"
+    name="Figure 16"
-    alt="Realizations of conditional paths from $p_0 = p_1 = \mathcal{N}(0, 1)$ for two different $x_1^{(i)}, x_1^{(2)} \sim q$ with conditional vector field given by $u_t(x \mid x_1) = (1 - t) x + t x_1$."
+    alt="Realizations of conditional paths from $p_0 = p_1 = \mathcal{N}(0, 1)$ for two different $x_1^{(1)}, x_1^{(2)} \sim q$ with conditional flows given by $\phi_t(x \mid x_1) = (1 - t) x + t x_1$."
-    ref=".png"
+    ref="g2g-vector-field-samples-cond.png"
-    src="flow-matching/g2g-forward_samples.png"
+    src="flow-matching/g2g-vector-field-samples-cond.png"
    width=400
%}
{% include image.html
-    name="Figure 11"
+    name="Figure 17"
    alt="Paths from $p_0$ to $p_1$ following the true marginal vector field $u_t(x)$. Paths are highlighted by the sign of the 2nd vector component."
-    ref=".png"
+    ref="g2g-forward_samples.png"
    src="flow-matching/g2g-forward_samples.png"
    width=400
%}

@@ -1268,8 +1339,19 @@
where $t \sim \mathcal{U}[0, 1]$, $\hlone{x_1^{(1)}}, \hlthree{x_1^{(2)}} \sim q

In such a scenario, we're attempting to align $u_{\theta}(t, x)$ with two different vector fields whose corresponding paths are impossible under the marginal vector field $u(t, x)$ that we're trying to learn! This fact can lead to increased variance in the gradient estimate, and thus slower convergence.

-In slightly more complex scenarios, the situation becomes even more striking. Below we see a nice example from Liu et al. (2022) where our reference and target are two different mixture of Gaussians in 2D. Here we see that marginal paths (bottom figure) end up looking *very* different from the conditional paths (top figure). Indeed, at training time paths may intersect, whilst at sampling time they cannot (due to the uniqueness of the ODE solution). As such we see on the bottom plot that some (marginal) paths are quite curved and would therefore require a greater number of discretisation steps from the ODE solver during inference.
-
+In slightly more complex scenarios, the situation becomes even more striking. Below we see a nice example from Liu et al. (2022) where our reference and target are two different mixtures of Gaussians in 2D, differing only in the sign of the mean's x-component. Specifically,
+$$
+\begin{equation}
+\tag{MoG-to-MoG}
+\label{eq:mog2mog}
+\begin{split}
+p_{\hlone{0}} &= (1 / 2)\mathcal{N}([{\hlone{-\mu}}, -\mu], I) + (1 / 2) \mathcal{N}([{\hlone{-\mu}}, +\mu], I) \\
+\text{and} \quad p_{\hltwo{1}} &= (1 / 2) \mathcal{N}([{\hltwo{+\mu}}, -\mu], I) + (1 / 2) \mathcal{N}([{\hltwo{+\mu}}, +\mu], I) \\
+\text{with} \quad \phi_t(x_0 \mid x_1) &= (1 - t) x_0 + t x_1
+\end{split}
+\end{equation}
+$$
+where we set $\mu = 10$, unless otherwise specified.
@@ -1278,8 +1360,8 @@ In slightly more complex scenarios, the situation becomes even more striking. Be
{% include image.html - name="Figure 12" - alt="Realizations of conditional paths from $p_0 = \mathcal{N}([-\mu, 0], I)$ to $p_1 = \mathcal{N}([\mu, 0], I)$ following realizations of the conditional vector field $u_t(x \mid x_1)$. Paths are highlighted by the sign of the 2nd vector component." + name="Figure 18" + alt="Realizations of conditional paths following conditional vector field $u_t(x \mid x_1)$ from \eqref{eq:mog2mog}. Paths are highlighted by the sign of the 2nd vector component." ref="vector-field-samples-cond.png" src="flow-matching/vector-field-samples-cond.png" width=400 @@ -1290,8 +1372,8 @@ In slightly more complex scenarios, the situation becomes even more striking. Be
{% include image.html - name="Figure 13" - alt="Paths from $p_0 = \mathcal{N}([-\mu, 0], I)$ to $p_1 = \mathcal{N}([\mu, 0], I)$ following the true marginal vector field $u_t(x)$. Paths are highlighted by the sign of the 2nd vector component." + name="Figure 19" + alt="Realizations of marginal paths following the marginal vector field $u_t(x)$ from \eqref{eq:mog2mog}. Paths are highlighted by the sign of the 2nd vector component." ref="vector-field-samples-marginal.png" src="flow-matching/vector-field-samples-marginal.png" width=400 @@ -1303,6 +1385,7 @@ In slightly more complex scenarios, the situation becomes even more striking. Be
+Here we see that marginal paths (bottom figure) end up looking *very* different from the conditional paths (top figure). Indeed, at training time paths may intersect, whilst at sampling time they cannot (due to the uniqueness of the ODE solution). As such we see on the bottom plot that some (marginal) paths are quite curved and would therefore require a greater number of discretisation steps from the ODE solver during inference.

We can also see how this leads to a significant variance of the CFM loss estimate for $t \approx 0.5$ in the figure below.

-More generally, samples from the reference distribution which are arbitrarily close to eachothers can be associated with either target modes, leading to high variance in the vector field regression loss.
+More generally, samples from the reference distribution which are arbitrarily close to each other can be associated with either target mode, leading to high variance in the vector field regression loss.

@@ -1315,8 +1398,8 @@
{% include image.html - name="Figure 14" - alt="Realizations of conditional paths from $p_0 = \mathcal{N}([-\mu, 0], I)$ to $p_1 = \mathcal{N}([\mu, 0], I)$ following realizations of the conditional vector field $u_t(x \mid x_1)$." + name="Figure 20" + alt="Realizations of conditional paths $\phi_t(x_0 \mid x_1)$ following the conditional vector field $u_t(x \mid x_1)$ for \eqref{eq:mog2mog}." ref="vector-field-samples-with-traj.png" src="flow-matching/vector-field-samples-with-traj.png" width=400 @@ -1328,8 +1411,8 @@ More generally, samples from the reference distribution which are arbitrarily cl
{% include image.html - name="Figure 15" - alt="Variance of conditional vector field over $p_{1|t}$ for both blue and red trajectories." + name="Figure 21" + alt="Variance of conditional vector field over $p_{1|t}$ for both blue and red trajectories for \eqref{eq:mog2mog}." ref="variance_cond_vector_field.png" src="flow-matching/variance_cond_vector_field.png" width=400 @@ -1406,8 +1489,8 @@ e.g. $\ p(x_t | x_1) = \mathrm{N}(x_t|tx_1, (1-t)^2)$.
{% include image.html
-    name="Figure 16"
-    alt="One sided interpolation."
+    name="Figure 22"
+    alt="One-sided interpolation. Source: Figure 2 in Albergo et al. (2023)."
    ref="albergo_one_sided.jpg"
    src="flow-matching/albergo_one_sided.jpg"
    width=800
%}
{% include image.html
-    name="Figure 17"
-    alt="Two sided interpolation."
+    name="Figure 23"
+    alt="Two-sided interpolation. Source: Figure 2 in Albergo et al. (2023)."
    ref="albergo_two_sided.jpg"
    src="flow-matching/albergo_two_sided.jpg"
    width=800
%}

@@ -1463,12 +1546,16 @@
$$
-->

#### Optimal Transport (OT) coupling
+{:.no_toc}

-Now let's go back to the idea of *not* using an independant coupling (i.e. pairing) but instead to correlate pairs $(x_1, x_0)$ with a joint $q(x_1, x_0) \neq q_1(x_1) q_0(x_0)$.
+Now let's go back to the idea of *not* using an independent coupling (i.e. pairing) but instead to correlate pairs $(x_1, x_0)$ with a joint $q(x_1, x_0) \neq q_1(x_1) q_0(x_0)$.

Tong et al. (2023) and Pooladian et al. (2023) suggest using the *optimal transport coupling*
+
$$
\begin{equation}
+\tag{OT}
+\label{eq:ot}
q(x_1, x_0) = \pi(x_1, x_0) \in \arg\inf_{\pi \in \Pi} \int \|x_1 - x_0\|_2^2 \mathrm{d} \pi(x_1, x_0)
\end{equation}
$$

-which minimises the optimal transport (i.e. Wasserstein) cost (Monge, 1781, Peyr
+which minimises the optimal transport (i.e. Wasserstein) cost (Monge, 1781; Peyré and Cuturi, 2020).

The OT coupling $\pi$ associates samples $x_0$ and $x_1$ such that the total distance is minimised.

-This OT coupling is illustrated in the right hand side of the figure below. In contrast to the middle figure which an independent coupling, the OT one does not have paths that cross. This leads to lower training variance and faster sampling[^OT].
+This OT coupling is illustrated in the right-hand side of the figure below, adapted from Tong et al. (2023). In contrast to the middle figure, which shows an independent coupling, the OT one does not have paths that cross. This leads to lower training variance and faster sampling[^OT].
@@ -1487,7 +1574,7 @@ This OT coupling is illustrated in the right hand side of the figure below. In c
{% include image.html - name="Figure 18" + name="Figure 24" alt="One-sided conditioning (Lipman et al., 2022)" ref="trajectory-marginals-vertical.png" src="flow-matching/trajectory-marginals-vertical.png" @@ -1498,7 +1585,7 @@ This OT coupling is illustrated in the right hand side of the figure below. In c
{% include image.html - name="Figure 19" + name="Figure 25" alt="Two-sided conditioning (Tong et al., 2023)" ref="trajectory-marginals-vertical-cond.png" src="flow-matching/trajectory-marginals-vertical-cond.png" @@ -1509,7 +1596,7 @@ This OT coupling is illustrated in the right hand side of the figure below. In c
{% include image.html
-    name="Figure 20"
+    name="Figure 26"
    alt="OT coupling (Tong et al., 2023)"
    ref="trajectory-marginals-vertical-ot.png"
    src="flow-matching/trajectory-marginals-vertical-ot.png"
    width=400
%}

@@ -1533,17 +1620,97 @@
In practice, we cannot compute the optimal coupling $\pi$ between $x_1 \sim q_1$ and $x_0 \sim q_0$, as algorithms solving this problem are only known for finite distributions. In fact, finding a map from $q_0$ to $q_1$ is the generative modelling problem that we are trying to solve in the first place!

-Tong et al. (2023) and Pooladian et al. (2023) propose to approximate the OT coupling $\pi$ by computing such optimal coupling only over each mini-batch of data and noise samples, coined **mini-batch OT**. This is scalable as for finite collection of samples the OT problem can be computed with quadratic complexity via the Sinkhorn algorithm (Peyre and Cuturi, 2020). This algorithm returns a permutation $\sigma$ of a random pairing between samples
-$$
-\{x_0^{(i)}\}_{i=1,\dots,B} \quad\text{and}\quad \{x_1^{(i)}\}_{i=1,\dots,B},
-$$
-such that $\sum_{i,j} \|x_1^{(i)} - x_0^{(\sigma(j))}\|^2$ is minimised (over all potential permutations).
+Tong et al. (2023) and Pooladian et al. (2023) propose to approximate the OT coupling $\pi$ by computing such an optimal coupling only over each mini-batch of data and noise samples, coined **mini-batch OT** (Fatras et al., 2020). This is scalable as for finite collections of samples the OT problem can be computed with quadratic complexity via the Sinkhorn algorithm (Peyre and Cuturi, 2020). This results in a *joint* distribution $\gamma(i, j)$ over "inputs" $$\big(x_0^{(i)}\big)_{i=1,\dots,B}$$ and "outputs" $$\big(x_1^{(j)}\big)_{j=1,\dots,B}$$ such that the expected distance is (approximately) minimised. Finally, to construct a mini-batch from this $\gamma$ which we can subsequently use for training, we can either compute the expectation wrt. $\gamma(i, j)$ by considering all $B^2$ pairs (in practice, this can often boil down to only needing to consider $B$ disjoint pairs[^mini-batch-ot-deterministic-vs-stochastic]) or sample a new collection of training pairs $(x_0^{(i')}, x_1^{(j')})$ with $(i', j') \sim \gamma$[^mini-batch-ot-sampling-size].
+
+For example, we can apply this to the \eqref{eq:g2g} example from before, which almost completely removes the crossing paths behaviour described earlier, as can be seen in the figure below.
+
+ +
+ +
+ +{% include image.html + ref="g2g-cond-paths-one-color--ot" + src="flow-matching/g2g-cond-paths-one-color.png" + width=400 +%} + +
+ +
+ +{% include image.html + ref="g2g-cond-paths-one-color-ot--ot" + src="flow-matching/g2g-cond-paths-one-color-ot.png" + width=400 +%} + +
+ +
+ +
+ +

+Figure 27: *\eqref{eq:g2g} with uniformly sampled pairings (left) and with OT pairings (right).*

+ +
+ +
+
+ +We also observe similar behavior when applying this the more complex example \eqref{eq:mog2mog}, as can be seen in the figure below. + +
+
+ +
+ +
+<div class="column">

{% include image.html
    ref="mog2mog-cond-paths-one-color--ot"
    src="flow-matching/mog2mog-cond-paths-one-color.png"
    width=400
%}
+ +
+<div class="column">

{% include image.html
    ref="mog2mog-cond-paths-one-color-ot--ot"
    src="flow-matching/mog2mog-cond-paths-one-color-ot.png"
    width=400
%}
+ +
+ +
+ +

+Figure 28: *\eqref{eq:mog2mog} with uniformly sampled pairings (left) and with OT pairings (right).*

+ +
+ +
+
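+
+To make the mini-batch OT pairing concrete, here is a short sketch (ours; the batch size and distributions are illustrative) of the deterministic variant discussed in the footnote, where the empirical OT problem reduces to a linear assignment problem:
+
+```python
+import numpy as np
+from scipy.optimize import linear_sum_assignment
+
+rng = np.random.default_rng(0)
+B = 256  # mini-batch size
+x0 = rng.normal(size=(B, 2)) + np.array([-10.0, 0.0])  # mini-batch from the reference p_0
+x1 = rng.normal(size=(B, 2)) + np.array([+10.0, 0.0])  # mini-batch from the data q_1
+
+# Squared Euclidean cost between every pair (x_0^(i), x_1^(j)).
+cost = ((x0[:, None, :] - x1[None, :, :]) ** 2).sum(-1)
+
+# With uniform weights 1/B, the empirical OT problem is a linear assignment
+# problem: every i is matched to exactly one j, giving the B disjoint pairs.
+i, j = linear_sum_assignment(cost)
+x0_paired, x1_paired = x0[i], x1[j]  # training pairs for the CFM objective
+```
+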
+ +All in all, making use of mini-batch OT seems to be a strict improvement over the uniform sampling approach to constructing the mini-batch in the above examples and has been shown to improve practical performance in a wide range of applications (Tong et al., 2023; Klein et al., 2023). + +It's worth noting that in \eqref{eq:ot} we only considered choosing the coupling $\gamma(i, j)$ such that we minimize the expected squared Euclidean distance. This works well in the examples \eqref{eq:g2g} and \eqref{eq:mog2mog}, but we could also replace squared Euclidean distance with some other distance metric when constructing the coupling $\gamma(i, j)$. For example, if we were modeling molecules using CNFs, it might also make sense to pick $(i, j)$ such that $x_0^{(i)}$ and $x_1^{(j)}$ are also rotationally aligned as is done in the work of Klein et al. (2023). -## Quick Summary +# Quick Summary @@ -1556,7 +1723,7 @@ Similarly to CNFs, sampled can be obtained at inference time by solving the ODE -## Citation +# Citation Please cite us as: @@ -1567,51 +1734,53 @@ Please cite us as: journal = "https://mlg.eng.cam.ac.uk/blog/", year = "2024", month = "January", - url = "https://mlg.eng.cam.ac.uk/blog/TODO" + url = "https://mlg.eng.cam.ac.uk/blog/2024/01/20/flow-matching.html" } ``` # References -- Lipman, Yaron and Chen, Ricky T. Q. and Ben-Hamu, Heli and Nickel, Maximilian and Le, Matt (2022). [Flow Matching for Generative Modeling](http://arxiv.org/abs/2210.02747). - - Albergo, Michael S. and Boffi, Nicholas M. and Vanden-Eijnden, Eric (2023). [Stochastic Interpolants: A Unifying Framework for Flows and Diffusions](http://arxiv.org/abs/2303.08797). -- Chen & Lipman (2023). [Riemannian Flow Matching on General Geometries](http://arxiv.org/abs/2302.03660v2). - -- De Bortoli, Mathieu & Hutchinson et al. (2022). [Riemannian Score-Based Generative Modelling](http://arxiv.org/abs/2202.02763v3). - -- Dupont, Doucet & Teh (2019). [Augmented Neural Odes](http://arxiv.org/abs/1904.01681v3). +- Behrmann, Jens and Grathwohl, Will and Chen, Ricky T. Q. and Duvenaud, David and Jacobsen, Joern-Henrik (2019). [Invertible Residual Networks](https://proceedings.mlr.press/v97/behrmann19a.html). -- Klein, Krämer & Noé (2023). [Equivariant Flow Matching](http://arxiv.org/abs/2306.15030v2). +- Betker, James, Gabriel Goh, Li Jing, TimBrooks, Jianfeng Wang, Linjie Li, LongOuyang, JuntangZhuang, JoyceLee, YufeiGuo, WesamManassra, PrafullaDhariwal, CaseyChu, YunxinJiao and Aditya Ramesh (2023). [Improving Image Generation with Better Captions](https://cdn.openai.com/papers/dall-e-3.pdf). -- Tong, Malkin & Huguet et al. (2023). [Improving and Generalizing Flow-Based Generative Models With Minibatch Optimal Transport](http://arxiv.org/abs/2302.00482v2). +- Chen & Gopinath (2000) Gaussianization, Advances in Neural Information Processing Systems. -- Pooladian, Aram-Alexandre and {Ben-Hamu}, Heli and {Domingo-Enrich}, Carles and Amos, Brandon and Lipman, Yaron and Chen, Ricky T. Q. (2023). [Multisample Flow Matching: Straightening Flows With Minibatch Couplings](http://arxiv.org/abs/2304.14772). +- Chen & Lipman (2023). [Riemannian Flow Matching on General Geometries](http://arxiv.org/abs/2302.03660v2). -- Liu, Xingchao and Gong, Chengyue and Liu, Qiang (2022). [Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow](http://arxiv.org/abs/2209.03003). +- Chen, Ricky T. Q. and Behrmann, Jens and Duvenaud, David K and Jacobsen, Joern-Henrik (2019). 
[Residual flows for invertible generative modeling](http://arxiv.org/abs/1906.02735).

-- Song, Sohl-Dickstein & Kingma et al. (2020). [Score-Based Generative Modeling Through Stochastic Differential Equations](http://arxiv.org/abs/2011.13456v2).
+- De Bortoli, Mathieu & Hutchinson et al. (2022). [Riemannian Score-Based Generative Modelling](http://arxiv.org/abs/2202.02763v3).

-- Tong, Alexander and Malkin, Nikolay and Fatras, Kilian and Atanackovic, Lazar and Zhang, Yanlei and Huguet, Guillaume and Wolf, Guy and Bengio, Yoshua (2023). [Simulation-Free Schrodinger Bridges via Score and Flow Matching](http://arxiv.org/abs/2307.03672).
+- Dupont, Doucet & Teh (2019). [Augmented Neural Odes](http://arxiv.org/abs/1904.01681v3).
+
+- Fatras, Kilian and Zine, Younes and Flamary, Rémi and Gribonval, Rémi and Courty, Nicolas (2020). Learning with Minibatch Wasserstein: Asymptotic and Gradient Properties.

-- Behrmann, Jens and Grathwohl, Will and Chen, Ricky T. Q. and Duvenaud, David and Jacobsen, Joern-Henrik (2019). [Invertible Residual Networks](https://proceedings.mlr.press/v97/behrmann19a.html).
+- Friedman (1987). Exploratory Projection Pursuit. Journal of the American Statistical Association.

-- Chen, Ricky T. Q. and Behrmann, Jens and Duvenaud, David K and Jacobsen, Joern-Henrik (2019). [Residual flows for invertible generative modeling](http://arxiv.org/abs/1906.02735).
+- George Papamakarios, Theo Pavlakou, Iain Murray (2018). [Masked Autoregressive Flow for Density Estimation](https://proceedings.neurips.cc/paper/2017/file/6c1da886822c67822bcf3679d04369fa-Paper.pdf).

- Huang, Chin-Wei and Krueger, David and Lacoste, Alexandre and Courville, Aaron (2018). [Neural Autoregressive Flows](http://arxiv.org/abs/1804.00779).

-- George Papamakarios, Theo Pavlakou, Iain Murray (2018). [Masked Autoregressive Flow for Density Estimation](https://proceedings.neurips.cc/paper/2017/file/6c1da886822c67822bcf3679d04369fa-Paper.pdf).
+- Klein, Krämer & Noé (2023). [Equivariant Flow Matching](http://arxiv.org/abs/2306.15030v2).

-- Watson, Joseph L. and Juergens, David and Bennett, Nathaniel R. and Trippe, Brian L. and Yim, Jason and Eisenach, Helen E. and Ahern, Woody and Borst, Andrew J. and Ragotte, Robert J. and Milles, Lukas F. and Wicky, Basile I. M. and Hanikel, Nikita and Pellock, Samuel J. and Courbet, Alexis and Sheffler, William and Wang, Jue and Venkatesh, Preetham and Sappington, Isaac and Torres, Susana V{\'a}zquez and Lauko, Anna and De Bortoli, Valentin and Mathieu, Emile and Ovchinnikov, Sergey and Barzilay, Regina and Jaakkola, Tommi S. and DiMaio, Frank and Baek, Minkyung and Baker, David (2023). [De Novo Design of Protein Structure and Function with RFdiffusion](https://www.nature.com/articles/s41586-023-06415-8).
+- Lipman, Yaron and Chen, Ricky T. Q. and Ben-Hamu, Heli and Nickel, Maximilian and Le, Matt (2022). [Flow Matching for Generative Modeling](http://arxiv.org/abs/2210.02747).
+
+- Liu, Xingchao and Gong, Chengyue and Liu, Qiang (2022). [Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow](http://arxiv.org/abs/2209.03003).

- Monge, Gaspard (1781). Mémoire Sur La Théorie Des Déblais et Des Remblais.

- Peyré, Gabriel and Cuturi, Marco (2020). [Computational Optimal Transport](http://arxiv.org/abs/1803.00567).

-- Betker, James, Gabriel Goh, Li Jing, TimBrooks, Jianfeng Wang, Linjie Li, LongOuyang, JuntangZhuang, JoyceLee, YufeiGuo, WesamManassra, PrafullaDhariwal, CaseyChu, YunxinJiao and Aditya Ramesh (2023). [Improving Image Generation with Better Captions](https://cdn.openai.com/papers/dall-e-3.pdf). 
+- Pooladian, Aram-Alexandre and {Ben-Hamu}, Heli and {Domingo-Enrich}, Carles and Amos, Brandon and Lipman, Yaron and Chen, Ricky T. Q. (2023). [Multisample Flow Matching: Straightening Flows With Minibatch Couplings](http://arxiv.org/abs/2304.14772). +- Song, Sohl-Dickstein & Kingma et al. (2020). [Score-Based Generative Modeling Through Stochastic Differential Equations](http://arxiv.org/abs/2011.13456v2). +- Tong, Alexander and Malkin, Nikolay and Fatras, Kilian and Atanackovic, Lazar and Zhang, Yanlei and Huguet, Guillaume and Wolf, Guy and Bengio, Yoshua (2023). [Simulation-Free Schrodinger Bridges via Score and Flow Matching](http://arxiv.org/abs/2307.03672). + +- Tong, Malkin & Huguet et al. (2023). [Improving and Generalizing Flow-Based Generative Models With Minibatch Optimal Transport](http://arxiv.org/abs/2302.00482v2). + +- Watson, Joseph L. and Juergens, David and Bennett, Nathaniel R. and Trippe, Brian L. and Yim, Jason and Eisenach, Helen E. and Ahern, Woody and Borst, Andrew J. and Ragotte, Robert J. and Milles, Lukas F. and Wicky, Basile I. M. and Hanikel, Nikita and Pellock, Samuel J. and Courbet, Alexis and Sheffler, William and Wang, Jue and Venkatesh, Preetham and Sappington, Isaac and Torres, Susana V{\'a}zquez and Lauko, Anna and De Bortoli, Valentin and Mathieu, Emile and Ovchinnikov, Sergey and Barzilay, Regina and Jaakkola, Tommi S. and DiMaio, Frank and Baek, Minkyung and Baker, David (2023). [De Novo Design of Protein Structure and Function with RFdiffusion](https://www.nature.com/articles/s41586-023-06415-8). [^chainrule]: The property $\phi \circ \phi^{-1} = \Id$ implies, by the chain rule, @@ -1636,7 +1805,7 @@ Please cite us as: [^residual_flow]: A sufficient condition for $\phi_k$ to be invertible is for $u_k$ to be $1/h$-Lipschitz [Behrmann et al., 2019]. The inverse $\phi_k^{-1}$ can be approximated via fixed-point iteration (Chen et al., 2019). -[^log_pdf]: Expanding the divergence in the _continuity equation_ we have: +[^log_pdf]: Expanding the divergence in the _transport equation_ we have: $$ \begin{equation} \frac{\partial}{\partial_t} p_t(x_t) @@ -1678,7 +1847,7 @@ Please cite us as: [^ODE_conditions]: A sufficient condition for $\phi_t$ to be invertible is for $u_t$ to be Lipschitz and continuous by Picard–Lindelöf theorem. -[^FPE]: The _Fokker–Planck equation_ gives the time evolution of the density induced by a stochastic process. For ODEs where the diffusion term is zero, one recovers the continuity equation. +[^FPE]: The _Fokker–Planck equation_ gives the time evolution of the density induced by a stochastic process. For ODEs where the diffusion term is zero, one recovers the transport equation. [^hutchinson]: The Skilling-Hutchinson trace estimator is given by $\Tr(A) = \E[v^\top A v]$ with $v \sim p$ isotropic and centred. In our setting we are interested in $\div(u_t)(x) = \Tr(\frac{\partial u_t(x)}{\partial x}) = \E[v^\top \frac{\partial u_t(x)}{\partial x} v]$ which can be approximated with a Monte-Carlo estimator, where the integrand is computed via automatic forward or backward differentiation. @@ -1690,3 +1859,9 @@ Please cite us as: [^1]: There's also the difference that in (standard) score-based diffusion models, we don't have "exact endpoints" in the sense that our $p_0$ is actually the reference we use during inference. Instead, we "just hope" that the chosen integration time $T$ is sufficiently large so that $p_0 \approx q_0$. 
-[^2]: We can of course just compute $\nabla \log p_t(x)$ of the $p_t$ induced by $u_t$, but this will generally be ridiculoulsly expensive.
+[^2]: We can of course just compute $\nabla \log p_t(x)$ of the $p_t$ induced by $u_t$, but this will generally be ridiculously expensive.
+
+[^interpolation]: The top row uses the reference $p_0 = \mathcal{N}([-a, 0], I)$ and target $p_1 = (1/2) \mathcal{N}([a, -10], I) + (1 / 2) \mathcal{N}([a, 10], I)$, and the bottom row is the \eqref{eq:g2g} example. The left column shows the straight-line solutions for the *marginals* and the right column shows the marginal solutions induced by considering the straight-line *conditional* interpolants.
+
+[^mini-batch-ot-sampling-size]: Note that the size of the mini-batch sampled from $\gamma(i, j)$ does not necessarily have to match the mini-batch size used to construct the mini-batch OT approximation, as we can sample from $\gamma$ with replacement, but using the same size is typically done in practice, e.g. Tong et al. (2023).
+
+[^mini-batch-ot-deterministic-vs-stochastic]: In mini-batch OT, we only work with the empirical distributions over $x_0^{(i)}$ and $x_1^{(j)}$, i.e. they all have weights $1 / B$, where $B$ is the size of the mini-batch. This means that we can find a $\gamma$ attaining the $\inf$ in \eqref{eq:ot} by solving what's referred to as a [linear assignment problem](https://en.wikipedia.org/wiki/Assignment_problem). This results in a sparse matrix with exactly $B$ non-zero entries, each with a weight of $1 / B$. In such a scenario, computing the expectation over the joint $\gamma(i, j)$, which has $B^2$ entries but in this case only $B$ non-zero ones, can be done by considering only $B$ training pairs, where every $i$ is involved in exactly one pair and similarly for every $j$. This is usually what's done in practice. When solving the assignment problem is too computationally intensive, using Sinkhorn and sampling from the coupling might be the preferable approach.
diff --git a/assets/images/flow-matching/flow-matching-diagram-2.png b/assets/images/flow-matching/flow-matching-diagram-2.png new file mode 100644 index 0000000..ba97f2e Binary files /dev/null and b/assets/images/flow-matching/flow-matching-diagram-2.png differ diff --git a/assets/images/flow-matching/flow-matching-diagram.png b/assets/images/flow-matching/flow-matching-diagram.png new file mode 100644 index 0000000..3d42703 Binary files /dev/null and b/assets/images/flow-matching/flow-matching-diagram.png differ diff --git a/assets/images/flow-matching/g2g-cond-paths-one-color-ot.png b/assets/images/flow-matching/g2g-cond-paths-one-color-ot.png new file mode 100644 index 0000000..24a15d8 Binary files /dev/null and b/assets/images/flow-matching/g2g-cond-paths-one-color-ot.png differ diff --git a/assets/images/flow-matching/g2g-cond-paths-one-color.png b/assets/images/flow-matching/g2g-cond-paths-one-color.png new file mode 100644 index 0000000..8f217fd Binary files /dev/null and b/assets/images/flow-matching/g2g-cond-paths-one-color.png differ diff --git a/assets/images/flow-matching/mog2mog-cond-paths-one-color-ot.png b/assets/images/flow-matching/mog2mog-cond-paths-one-color-ot.png new file mode 100644 index 0000000..72395bd Binary files /dev/null and b/assets/images/flow-matching/mog2mog-cond-paths-one-color-ot.png differ diff --git a/assets/images/flow-matching/mog2mog-cond-paths-one-color.png b/assets/images/flow-matching/mog2mog-cond-paths-one-color.png new file mode 100644 index 0000000..36587c6 Binary files /dev/null and b/assets/images/flow-matching/mog2mog-cond-paths-one-color.png differ