평범한 필기장

[평범한 청강생의 논문 맛보기] Self-Rectifying Diffusion Sampling with Perturbed-Attention Guidance (PAG) 본문

Experience/DAVIAN Lab Computer Vision Study

[평범한 청강생의 논문 맛보기] Self-Rectifying Diffusion Sampling with Perturbed-Attention Guidance (PAG)

junseok-rh 2024. 5. 10. 00:51

https://arxiv.org/abs/2403.17377

 

Self-Rectifying Diffusion Sampling with Perturbed-Attention Guidance

Recent studies have demonstrated that diffusion models are capable of generating high-quality samples, but their quality heavily depends on sampling guidance techniques, such as classifier guidance (CG) and classifier-free guidance (CFG). These techniques

arxiv.org

1. Introduction

 Diffusion Model들은 CG(Classifier Guidance)나 CFG(Classifier Free Guidance)를 이용해서 high-fidelity이고 다양한 샘플들을 생성했다. 여기서 CG와 CFG는 더 높은 퀄리티의 이미지들을 생성하도록 하는데에 중요한 역할을 했다. 하지만 이러한 guidance 기법들은 다음과 같은 단점들을 지닌다.

  • 추가적인 training이나 외부 module의 통합을 필요로 한다.
  • Output sample들의 다양성을 해친다.
  • Unconditional generation에서 사용할 수 없다.

Unconditional Generation은 다음과 같은 중요한 실용적인 이점을 제공한다.

  • 데이터 생성의 근본적인 원칙과 내재된 구조를 이해하는 것을 돕는다.
  • Unconditional 기술들의 진보는 conditional generation을 강화한다.
  • Class label, text와 같은 human annotation에 대한 필요를 없앤다.
  • 강련한 general priors를 제공한다.

 본 논문에서는 추가적인 훈련과 외부 module의 통합 없이 uncondition과 condition에서 둘 다 diffusion sample quality를 향상시키는 Perturbed-Attention Guidance (PAG)라는 새로운 sampling guidance를 제안한다. 실험을 통해 본 논문의 guidance는 다양한 downstream task에서 diffusion model의 성능이 향상되는 것을 보였다.

2. PAG : Perturbed-Attention Guidance

2.1 Self-rectifying sampling with implicit discriminator

 본 논문에서는 diffusion process동안 실제 데이터 분포를 따르는 desirable sample과 undesirable sample을 구별짓는 implicit discriminator $\mathcal{D}$를 도입한다. 이 implicit disciminator $\mathcal{D}$는 sample들이 desirable 분포로 향하고 undesirable 분포로는 멀리 하도록 guide한다. Implicit discriminator는 다음과 같이 정의된다.

여기서 $y$와 $\hat{y}$는 각각 desirable dist.와 undesirable dist.의 가상 label이다.

 WGAN과 유사하게 본 논문은 implicit discriminator의 generator loss를 다음과 같이 셋팅하고 다음과 같이 나타낼 수 있다.

위 식을 이용해서 다음과 같은 새로운 diffusion sampling을 정의한다.

 Diffusion models는 이미 desired 분포를 이미 학습했기 때문에, 본 논문은 $-\sigma \Delta_{x_t}{\rm log}p(x_t|y)$의 근사로 pretrain된 score estimation network $\epsilon_\theta(x_t)$를 사용한다. Undesirable label $\hat{y}$에 대한 score를 위해, 본 논문은 pretrained network의 forward pass를 perturbing시킴으로써 근사하고 이를 $\hat{\epsilon}_\theta(x_t)$라고 쓴다. $\hat{\epsilon}_\theta(x)$는 input이나 내부의 representation에 적용된 perturbation을 포함해서 epsilon 예측 프로세스동안 어떤 형태의 perturbation으로 구현될 수 있다.

Connections to CFG.

 식 10은 CFG와 닮았는데, 실제로 CFG는 본 논문의 수식의 특정한 사례로 고려될 수 있다. 수식 10은 다음과 같이 class-conditional diffusion model로 정의될 수 있다.

CFG에서 $\hat{\epsilon}_\theta(x_t,c)$는 class label을 없앤 $\epsilon_\theta(x_t, \emptyset)$로 나타낼 수 있고 이를 본 논문에서는 perturbed forward pass로 정의한다. 본 논문에서 unconditional diffusion models에도 적용할 수 있는 perturbed forward pass의 개념을 확장한다.

2.2 Perturbing self-attention of U-Net diffusion model

 Input 이미지나 condition을 직접 perturbing하는 것은 out-of-distribution 문제를 야기하고, diffusion model이 부정확한 guidance signal들을 생성하도록 하고 diffusion sampling을 잘못된 방향으로 조종할 수 있다. 이전의 연구에서 attention map을 수정하는 것이 모델의 그럴듯한 이미지 생성하는 능력에 최소한의 영향을 준다는 것을 밝혔다. 그래서 본 논문에서는 perturbing 전략을 세우기 위해, conditional과 unconditional 모델 모두에 적용할 수 있도록 denoising U-Net의 self-attention 메커니즘에 집중한다. 또 다른 기준으로는 사진 2에서 볼 수 있듯이 guidance 없이 생성된 이미지는 collapsed structure를 종종 보여준다. 이를 해결하기 위해 desired guidance는 collapsed structure를 보이는 sample과 먼 denoising 궤적으로 조종해야 한다. 이는 CFG에서 null prompt가 class conditioning을 강화하기 위해 사용된 것과 유사하다. 최근의 연구들에서 attention map은 patch들 사이에 구조적인 정보와 의미적 대응을 포함한다는 것을 보였다. 그러므로 self-attention map을 perturbing하는 것은 collapsed structure를 가진 sample을 생성할 수 있다.

Perturbed self-attention

 최근 연구들은 diffusion U-Net의 self-attention module이 structure에 대한 query-key 유사성과 appearance에 대한 value라는 두 개의 다른 역할을 하는 path를 가진다고 한다. 이 모듈의 output은 다음과 같이 정의 된다.

여기서 structure 부분은 self-attention map이라고 흔히 불린다.

 이를 이용해 본 논문은 original sample로 부터 과한 탈선을 최소화하기 위해 self-attention map만 perturbing하는 것에 집중한다. 이 관점은 신경망 input에 대한 out-of-distribution 문제를 다루는 관점으로 이해될 수 있다. 직접 appearance component $V_t$을 perturbing하는 것은 그 후의 MLP가 이전에 본 적 없는 input을 직면하는 문제를 야기한다. 결국 이로 인해 뒤틀린 sample을 야기한다.

 그러나 value를 유지하고 self-attention map을 identity matrix로 사용하는 것은 $V_t$를 직접 perturbing하는 것보다 domain을 더 유지하는 것으로 보인다. 그러므로 본 논문에서는 structural component $A_t = Softmax(Q_t K_t^T/ \sqrt{d}) \in \mathbb{R}^{hw \times hw}$만 perturb한다. 그렇게 함으로써 appearance 정보는 유지하면서 structural 정보는 제거한다. 이 간단한 접근법은  다음과 같이 정의된다.

본 논문에서 이를 perturbed self-attention (PSA)라고 부른다.

 SA와 PSA를 사용해서 본 논문에서 $\epsilon_\theta(x_t)$와 $\hat{\epsilon}_\theta(x_t)$를 각각 구현한다. Input image $x_t$는 $\epsilon_\theta(\cdot)$와 $\hat{\epsilon}_\theta(\cdot)$에 들어가고 두 output은 선형결합돼서 $\tilde{\epsilon}_\theta(x_t)$를 얻게 된다. 아래는 본 논문의 method의 전체적인 파이프라인과 슈도코드이다.

2.3 Analysis on PAG

 위 이미지에서 본 논문에서 제안한 guidance term이 semantic cues를 어떻게 제공하는지를 보여준다. 위 이미지에서 빨간 사각형 부분을 보면 perturbed prediction ($\hat{\epsilon}_\theta(x_t)$)이 눈,코,혀와 같은 두드러진 특징을 놓치고 있다는 것이 명백하다. 이러한 생략은 perturbed perdiction이 global structure에 대한 이해가 부족하기 때문이다. 결국 $\hat{\Delta}_t$가 이러한 두드러진 포인트들에 집중하는 경향이 있다. 이 $\hat{\Delta}_t$를 original prediction $\epsilon_\theta(x_t)$에 추가함으로써 sample의 structure가 강화된다. Timestep이 0으로 갈수록 $\hat{\Delta}_t$가 점진적으로 더 세세한 디테일을 capture한다는 것도 주목해야한다. 이는 PAG가 효과적으로 모든 timestep에 따라 coarse 구조에서 세세한 디테일로 진화하면서 well-defined shape을 향해 샘플들을 가이드한다는 것을 제안한다.

3. Experiments

3.1 Experimental and Implementation Details

 본 논문은 pretrained model ADM, Stable Diffusion을 base로 한다. 실험과 구현 디테일들은 appendix에서 다룬다.

3.2 Pixel-Level Diffusion Models

 위 실험 결과를 통해 PAG를 사용한 ADM이 큰 폭으로 IS와 FID에서 좋은 성능을 보였다. Improved Recall과 Precision에서의 대조적인 패턴은 fidelity와 diversity사이의 trade-off에서 기인한다.

 이러한 trade-off에도 불구하고 사진 5를 보면 realistic하게 보이고 의미적으로 그럴듯한 구조를 보인다. 이는 PAG의 perturbed self-attention을 활용한 diffusion sampling path를 바로잡는 능력을 보여준다.  사진 6에서는 PAG와 SAG사이의 샘플링 결과를 비교한다.

3.3 Latent Diffusion Models

Unconditional generation on Stable Diffusion

 PAG를 사용한 결과 IS와 FID가 더 좋았다.

Text-to-Image synthesis on Stable Diffusion

Text-to-Image task에서는 CFG도 쓰일 수 있기에 표 2에서와 같이 4가지 경우에 대해 실험을 진행한다. 결과에서 볼 수 있듯이 PAG와 CFG를 적절히 결합했을 때 FID에서 상당한 향상이 일어났다.

위 이미지 결과를 보면 CFG와 PAG를 같이 썼을 때의 결과가 좋은 것을 볼 수 있다. 본 논문에서는 CFG의 text-image alignment에 대한 능력과 PAG의 구조적인 정보의 강화의 시너지가 이 둘을 합쳤을 때 더 시각적으로 좋은 이미지를 낸다고 말한다.

본 논문에서는 per-prompt diversity를 "프롬프트가 주어졌을 때 얼마나 다양한 이미지를 생성하는 능력"이라고 정의한다. 그래서 하나의 프롬프트로 얼마나 다양한 이미지를 생성하는지 실험한다. 위 표를 통해 그 결과를 IS와 LPIPS로 비교한다.

3.4 Downstream Tasks

Inverse problems

 Inverse problem은 unconditional generation에서 major task들 중 하나이다. 이는 noisy measurement $y = \mathcal{A}(x) + n$로 부터 $x$를 복원하는 것이 목적이다. 본 task에서는 text prompt가 없어야하기 때문에 본 논문에서의 방식은 적절히 작동하지만 다른 방식은 사용할 수 없다.

위 결과는 PSLD모델에 PAG를 썼을 때의 결과를 비교한 것으로 FID는 모든 task에 대해, LPIPS는 하나 제외하고 모든 task에 대해 더 좋은 결과를 보인다.

ControlNet

 ControlNet은 종종 pose condition과 같은 공간 control signal이 희박한 경우 특히 unconditional generation scenarios 아래에서 높은 퀄리티의 샘플을 생성하는데 애먹는다. PAG는 이러한 예시에서 sample quality를 강화한다.

3.5 Ablation Studies

Self-attention perturbation strategy

Guidance scale

 Guidance scale이 1일 때 가장 좋은 FID를 얻었고 2일 때 가장 좋은 IS를 얻었다.

Computational cost

 PAG는 CFG와 마찬가지로 Diffusion U-Net의 input을 복제함으로써 두개의 denoising pass들을 병렬화할 수 있기 때문에 computational cost는 CFG와 거의 동일하다.

4 Conclusion

 PAG는 추가 학습이나 외부 module들 없이 conditional과 unconditional 모두에서 우월한 샘플 퀄리티를 달성했다. 본 논문의 연구는 diffusion model을 text prompt와 CFG에 대한 의존에서 해방시키면서 unconditional diffusion models의 적용 가능성을 비춘다.

 

Appendix

 Appendix에서 정리할만한 부분들만 따로 정리하려고 한다.

Combination of CFG and PAG

 Text-to-Image 합성에서 CFG와 PAG를 같이 적용시키기위해 다음과 같은 방정식을 사용해서 $\tilde{\epsilon}_\theta(x_t,c)$를 생성한다.

여기서 $w$와 $s$는 guidance scale이다. 이러한 estimation은 CFG와 PAG의 델타를 더한 것을 포함한다. 이를 얻기 위해 본 논문에서는 denoising U-Net에서 동시에 $\epsilon_\theta(x_t,c), \epsilon_\theta(x_t,\emptyset), \hat{\epsilon}_\theta(x_t,c)$를 계산한다.

CFG with separately trained models

 CFG 논문에서 unconditional model과 conditional model을 jointly하게 학습한다. 그러고 각각 모델을 따로 학습시킴으로써 구현될 수 있다고도 말한다. 하지만 실제로는 그렇지 않다. 본 논문에서 그렇게 해서 낸 실험 결과는 다음과 같았다.

위 결과는 CFG가 단순히 diversity와 교환해서 이미지 퀄리티를 강화하는 것이 아닌 어떤 다른 key factor에 의해 작동한다는 것을 암시한다. 이 비결은 아마도 사진 2에서 분석한 것처럼 원래의 conditional prediction으로부터 중요한 feature들을 놓친 샘플들을 예측하고 나서 그러한 중요한 feature들을 강화하기 위해 차이를 추가한다는 것이다. CFG는 jointly하게 unconditional model을 학습하는 추가 비용으로 feature들을 놓치는 perturbed 경로를 생성하는 반면, PAG에 PSA는 추가 training이나 외부 모듈 없이 가능하다.'

Limitation

  • 높은 guidance weight에서 결과는 over-saturation을 보인다.
  • Computational overhead
  • High Resource