Mean Flow
References
Main Idea
为了解决源分布 \(p_1\)(往往是标准高斯分布)到目标分布 \(p_0\)(往往是目标数据分布)的转移问题,flow matching 旨在建模一个瞬时速度场 \(v(z_t, t)\),使得 \(z_1\sim p_1\) 通过解如下 ODE 可以得到 \(z_0\sim p_0\):
为了能够单步生成,Mean Flow 的基本思想很简单:建模平均速度 \(u(z_t, r, t)\),使得有
这样就可以通过 \(u(z_1, 0, 1)\) 从 \(z_1\) 得到 \(z_0\),注意到实际上就有
MeanFlow Properties
Consistency
Consistency Model 在概率流 ODE 的轨迹解 \(\{x_t\}_{t\in[0, 1]}\) 上定义了一致性函数 (consistency function) \(f(x_t, t)\),要求其具有自一致性 (self-consistency) 的条件:
Consistency Model 中所建模的 \(f_{\theta}(x_t, t)\) 要求有 \(f_{\theta}(x_t, t)\equiv x_0\)。
我们会发现,建模的平均速度 \(u(z_t, r, t)\) 也具有很相似的一致性性质,即
这个一致性等式是通过积分可加性,在建模时就定义好的,实际上并不作为学习的目标;但如果学得比较好的话,应当是要满足这个一致性等式的。
MeanFlow Identity
接下来的推导将是 Mean Flow 训练的关键。
对于上式,考虑 \(r\) 与 \(t\) 无关,则两边对 \(t\) 求导有
如果 \(r\) 与 \(t\) 有关,则右边会多出一个 \(-v(z_r, r)\frac{\mathrm{d}r}{\mathrm{d}t}\) 的项。
由此我们就得到了用于获得 \(u(z_t, r, t)\) 训练时监督信号的 MeanFlow Identity:
Algorithm
Training
我们对 \(u(z_t, r, t)\) 进行建模,利用 MeanFlow Identity 获得监督信号。注意到现有 Flow Matching 的实践算法一般会预先定义好轨迹簇,例如 Rectified Flow 直接使用直线轨迹,那么 \(v(z_t, t)\) 在给定 \(r, t\) 以及 \(z_t\) 的采样时是已知的。所以关键问题就是如何给出 \(\mathrm{d}u(z_t, r, t)/\mathrm{d}t\) 的监督信号,我们直接有
依然考虑与 \(t\) 无关的 \(r\),且注意到其实有 \(\mathrm{d}z_t/\mathrm{d}t = v(z_t, t)\),那么就有
实践中,可以直接用 Jacobian-vector product(JVP,在 torch 和 jax 中已经有了成熟的实现)来计算这个全导数,即
即计算 Jacobian matrix 和 vector 的向量内积
但是实现细节上,JVP 的结果同样是可以被梯度追踪的,如果同时允许其梯度反向传播有可能会导致显存爆炸、梯度混乱等问题,因此 Mean Flow 利用了 Stop Gradient 算子 \(\operatorname{sg}\) 阻止梯度通过 JVP 反向传播:
Sampling
前面已经提到,单步生成可以通过
实现,如果希望利用多步提高生存质量,也可以
在给定的 scheduling 下进行多步迭代生成。