평범한 필기장

[평범한 학부생이 하는 논문 리뷰] Variational Rectified Flow Matching 본문

AI/Generative Models

[평범한 학부생이 하는 논문 리뷰] Variational Rectified Flow Matching

junseok-rh 2025. 5. 19. 23:04

Paper : https://arxiv.org/abs/2502.09616

 

Variational Rectified Flow Matching

We study Variational Rectified Flow Matching, a framework that enhances classic rectified flow matching by modeling multi-modal velocity vector-fields. At inference time, classic rectified flow matching 'moves' samples from a source distribution to the tar

arxiv.org

Abstract

 본 논문은 multi-modal velocity vector-fields를 모델링함으로써 classic rectified flow matching을 강화한 framework인 Variational Rectified Flow Matching을 연구한다.

1. Introduction

 Rectified flow에서 velocity vector-field를 학습할 때, source와 target distribution에서 sampling해서 임의의 pair를 만들어 사용한다. 이는 Figure 1(a)에서와 같이 straight line으로 connect된다. 이는 결국 multi-modality/ambiguity를 야기한다. Classic rectified flow matching은 predicted velocity vector-field와 constructed velocity vector-field를 비교하는 standard squared-norm loss를 사용하기 때문에, 이러한 multi-modality를 capture하지 못한다. Figure 1(b)에서처럼 source와 target 분포를 대안적인 방식으로 matching한다.

 이 문제를 해결하기 위해서, 본 논문에서는 variational rectified flow matching을 연구한다. 직관적으로 variational rectified flow matching은 latent variable을 도입하는데 이는 data-domain-time-domain에서 각 위치에서 multi-modal/ambiguity flow direction을 disentangle하는 것을 허용한다. Figure 1(c)에서처럼, variational rectified flow matching은 intersect하는 flow trajectory를 모델링할 수 있다.

2. Variational Rectified Flow Matching

 본 논문의 목표는 "ground-truth" velocity vector-fields에 내재된 multi-modality를 capture하는 것이다. Multi-modality를 capture하는 것의 어려움은 더 곡선이고 그래서 inference time에서 integrate하기 어려운 velocity vector field를 야기한다. 결국 data와 잘 fit하지 않은 distribution을 야기한다.

 본 논문은 이를 위해서 rectified flow matching과 variational auto-encoder를 결합한다.

2.1 Objective

 Probability density $p_0, p_1$와 velocity vector-field $v_\theta$는 transport problem을 통해 연관된다.

혹은 integral form으로 연관된다.

Equation (2)의 partial differential equation을 푸는 것은 어렵다.

 그러나, probability density function을 Gaussian으로 가정하고 velocity vector-field를 constant로 가정하면, analytic solution을 얻을 수 있다.

두 개의 Gaussian probability density function 대신에, 본 논문은 주어진 data point $x_0,x_1$에 center된 Gaussian으로 구성된 source와 target data에 대한 real probability density function을 가정한다. 게다가, data-domain-time-domain location $(x_t,t)$에서 velocity vector field $v_\theta(x_t,t)$를 다음과 같은 uni-modal standard Gaussian으로 구체화한다.

Empirical "velocity data"의 log-likelihood를 최대화하는 것은 다음 objective와 동일하다.

이 objective는 classic rectified flow matching과 동일하다.

 이 유도는 key point를 강조한다 : vector field는 각 data-domain-time-domain location에서 Gaussian으로 parameterize되고 이 Gaussian은 uni-modal이기 때문에, multi-modality는 capture될 수 없다. 이런 이유로, classic rectified flow matching은 "ground-truth" velocity들을 평균낸다.

 이는 sub-optimal할 수 있다. Multi-modality를 capture하기 위해서, 본 논문은 각 data-domain-time-domain location에서 velocity들에 대한 mixture model의 사용을 연구한다. 이를 위해, 본 논문은 unobserved continuous random variable $z$가 velocity vector field의 conditional distribution의 mean을 결정한다고 가정한다.

이 model은 $p(v|x_t, t) = \int p(v|x_t, t,z)p(z)dz$가 Gaussian mixture이기 때문에 multi-modality를 capture한다.

 Random variable $z$는 관찰되지 않기 때문에, training time에 encoder라고 불리는 recognition model $q_\phi(z|x_0,x_1,x_t,t)$를 도입한다. 이는 $\phi$로 parameterize되고 intractable true posterior를 approximate한다.

 이 셋업을 이용해서, 개별 data point의 marginal likelihood는 다음과 같이 lower-bound될 수 있다.

Equation (3)에서 Gaussian의 log-probability를 equation (4)에서 주어진 lower bound로 대체하는 것은 variational rectified flow matching objective를 이끈다.

2.2 Training

 본 논문은 prior를 $p(z) = \mathcal{N}(z;0,I)$로 두고 approximate posterior를 $q_\phi(z|x_0,x_1,x_t,t) = \mathcal{N}(z;\mu_\phi(x_0,x_1,x_t,t),\sigma_\phi(x_0,x_1,x_t,t))$로 둔다. 이는 equation (5)에서 KL-divergence의 analytic computation을 가능하게 한다. 또한 objective의 optimization을 가능하게 하기 위해서 re-parameterization trick을 사용한다.

2.3 Inference

3. Experiments

 본 논문의 실험은 variational rectified flow matching이 data-domain-time-domain에서 multi-modal velocity를 capture할 수 있다는 것을 보인다. 게다가, 본 논문은 conditional latent $z$를 통해서 명시적으로 multi-modality를 모델링하는 것은 flow matching model의 interpretability를 강화한다. 이는 controllability를 이끈다.

3.1 Synthetic 1D data

Source 분포는 평균이 0이가 분산이 1인 Gaussian이고, target 분포는 -1과 1이 중심인 bimodal이다.

3.2 Synthetic 2D data

3.3 MNIST

 추론 동안에 unit square에 대해서 linearly spaced coordinate를 샘플링하고, 이를 Gaussian의 inverse CDF를 통해 변환해서 latents $z$를 얻는다. 이러한 latents를 이용해서 본 논문은 ODE solver로 샘플링한다. Source 분포 sample $x_0$와 latents $z$의 효과를 보이기 위해서, 본 논문은 random하게 샘플링된 두개의 $x_0$에 대한 학습된 MNIST manifold를 시각화한다. 그 결과는 latent space $z$는 2D manifold내에서 다른 digits사이의 부드러운 interpolation을 보인다.

3.4 CIFAR-10

 본 논문은 conditioning mechanism으로 두 개의 효과적인 방식을 확인했다.

  • Adaptive normalization : shift와 offset 파라미터를 계산하기 전에 time embedding에 $z$가 더해짐
  • Bottleneck sum : upsampling전에 weighted sum을 사용해서 가장 낮은 해상도에서 latent와 중간 activations를 혼합함

본 논문은 fusion mechanism과 posterior model $q_\phi$의 input, KL loss weighting에 대한 4가지 model variation 실험을 진행했다.

 MNIST의 결과와 유사하게, 본 논문은 생성된 샘플 $x_1$에 대한 색과 content에서의 명백한 패턴을 관찰했다. 이는 controllability의 정도를 증명한다.

동일한 latent에 대해서 condition된 image는 동일한 color pattern을 보이고 동일한 $x_0$에 대한 이미지는 유사한 content를 보인다.

3.5 ImageNet