일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | |||||
3 | 4 | 5 | 6 | 7 | 8 | 9 |
10 | 11 | 12 | 13 | 14 | 15 | 16 |
17 | 18 | 19 | 20 | 21 | 22 | 23 |
24 | 25 | 26 | 27 | 28 | 29 | 30 |
- plug-and-play
- image2image translation
- Programmers
- magdiff
- ddim inversion
- image editing
- 코딩테스트
- 논문리뷰
- DP
- prompt2prompt
- 3d generation
- style align
- ddpm inversion
- 3d editing
- transformer
- video editing
- 프로그래머스
- Vit
- diffusion models
- ami lab
- diffusion
- image generation
- BOJ
- 네이버 부스트캠프 ai tech 6기
- 코테
- video generation
- research intern
- VirtualTryON
- Python
- visiontransformer
- Today
- Total
평범한 필기장
[평범한 학부생이 하는 논문 리뷰] Generative Adversarial Nets (GAN) 본문
3-2학기 학부연구생을 하면서 처음으로 읽게 된 논문이 바로 이 Generative Adversarial Nets라는 논문인데, 딥러닝 기초만 조금 봐본 제가 공부하면서 정리하는 느낌으로 하는 리뷰이니 틀린 부분, 잘못 이해한 부분이 많을 수 있다는 점 말씀드리고 리뷰 시작해보도록 하겠습니다!
0. Abstract
일단 이 논문의 abstract에서는 저자들이 두 개의 모델을 적대적인 process로 학습시켜서 생성 모델을 평가하는 프레임 워크를 목적으로 했다고 밝혔다. 여기서 두 개의 모델은 생성 모델인 G와 판별 모델인 D로 설명을 했다.
- generative model G : 데이터의 분포를 모사해 D가 실수할 확률을 최대화하는 방향으로 학습
- discriminative model D : sample data가 G로 부터 생성된 것이 아니라 실제 데이터로 왔을 확률을 추정
- 이 프레임 워크는 minimax two-player game에 상응한다.
- 임의의 G와 D 함수 공간에서, G가 training data 분포를 찾아내 D가 어디서든 1/2로 동일한 유일한 해가 존재한다.
- G와 D가 MLP(multi-layer perceptron)로 정의되면, 전체 system은 역전파를 통해 학습될 수 있다.
- Markov Chain과 unrolled approximate inference가 학습과 생성동안 필요하지 않다.
Abstract에 나온 내용이 처음 읽을 때는 이해가 전혀 가지 않았었는데, 뒤에 설명을 읽고나니 이게 무슨 말을 하는건지 이해가 됐어요. 이 블로그를 읽고 있으시다면 이 부분에서 이해가 되지 않는다고 걱정할 필요가 없이 뒷 부분을 읽으면 자연스럽게 이해가 될거예요. 그러니 너무 걱정마세요^^
1. Introduction
첫 문단에서 뭐 natural image, audio waveforms containing speech이런 얘기를 하면서 어렵게 써놓은 것 같지만 deep learning에서 지금까지 두드러진 성공은 discriminative model을 포함해 왔다는 것을 얘기하고 싶어하는 것 같아요. 그리고 이 성공은 역전파와 drop out 알고리즘을 기초로 했다고 하네요. 그런데 generative model 은 여러 어려움 (다루기 어려운 확률적 계산을 근사하는 데에서 오는 어려움, linear unit의 이점을 활용하는 것에 대한 어려움) 으로 인해 적은 impact를 가져왔다고 합니다. 그래서 이 논문에서는 generative model이 이러한 어려움을 피하게 하려 하네요.
두번 째 문단에서는 gan을 공부할 때 제일 많이 듣는 비유인데 바로 화폐 위조자와 경찰입니다. G는 화폐 위조자이고 D는 경찰이라고 하면 화폐 위조자는 경찰을 속이기 위해 더 고도의 위조 화폐를 만들게 되고, 경찰은 속지 않기 위해 더 정교하게 화폐를 판별할 것입니다. 그래서 서로 속이기 위해, 속지 않기 위해 더 발전한다는 내용인데, gan이란 모델에서도 이러한 방향으로 G와 D가 서로 발전, 향상해 나간다는 얘기를 써놓았네요.
세번 째 문단에서는 이 프레임 워크가 많은 종류의 모델을 위해 특정한 학습 알고리즘들과 optimization 알고리즘을 양산할 수 있다. 그리고 abstract에 말한 것 처럼 MLP로 G와 D를 정의하면 다른 방식 없이 역전파와 dropout 알고리즘만으로 두 모델을 학습시킬 수 있다라고 설명하고 있습니다.
3. Adversarial nets
GAN의 목적 함수 :
- 이 목적 함수를 보면 D는 이 목적함수를 최대화하려 하고, G는 이 목적함수를 최소화 하려한다.
- D(x)에서 input인 x가 실제 데이터와 가까우면 output으로 1에 가까운 값이 나오고 가짜와 가까우면 output이 0에 가깝게 되는데 목적함수의 왼쪽 부분을 보면 실제 데이터로 부터 왔기 때문에 logD(x)의 기댓값은 0에 가까워 져야하고 오른쪽 부분에서는 G(z)는 생성모델이 생성한 가짜 데이터이기 때문에 D(x)의 값이 0에 가까워져 log(1-D(G(x)))의 기댓값이 0에 가까워 지게 돼서 위 식이 최대화되게 된다.
- G 입장에서 보면 오른쪽 식에 영향을 끼치게 되는데 G가 학습이 잘 되어서 실제와 가까운 데이터를 생성한다면 D는 이를 판별을 제대로 하지 못해 D(G(x)) 값이 1에 가까워져 log(1-D(G(x)))의 기댓값이 -(무한대)로 가서 목적 함수가 최소화 되게 된다.
G와 D가 학습되는 과정
- k번 만큼 D를 학습을 시키고 G를 한번 학습 시키는 과정을 반복하면서 전체 모델을 학습 시킨다.
- 이렇게 함으로써 G가 천천히 변하는 한 D가 optimal한 해 근방에서 유지되는 결과를 가져온다.
- 위 그림에서 D(G(z))가 1에 가깝게 가는 방향으로 학습을 해야하는데 G가 생성하는 이미지가 안좋을수록 기울기가 평평해 Gradient값이 작아 학습이 잘 되지 않는다
- 그래서 log(1-D(G(z)))를 최소화하는 방향으로 G를 학습시키기 보단 log(D(G(z))를 최대화 하는 방향으로 G를 학습 시키면 학습 초기에 strong한 gradient를 가져서 학습이 더 잘된다.
- 파란 점선 : D의 분포
- 검정 점선 : 실제 이미지 분포
- 초록 실선 : 생성되는 이미지의 분포
- (a) : random noise인 z에서 G를 통해 x를 생성하면 이 x는 초록 실선의 형태의 분포를 띈다. 이 분포는 실제 이미지 분포와 많은 차이를 보이며 D의 분포 또한 매우 불안정한 것을 볼 수 있다.
- (b) : G를 고정하고 D를 학습하면 D의 분포가 안정해지면서 실제 이미지와 생성되는 이미지를 잘 구분하게 된다.
- (c) : D를 고정하고 G를 학습하면 G의 분포가 실제 이미지의 분포와 유사해진다.
- (d) : 학습이 반복이 되다보면 실제 이미지 분포와 생성된 이미지 분포가 동일해지고, D가 실제 이미지와 생성된 이미지를 구분하지 못하게되어 1/2라는 답만 내놓게 된다. -> 완벽하게 학습됨, 이상적인 상황
4. Theoretical Results
4.1 Global Optimality of pg = pdata
4.2 Convergence of Algorithm 1
이 부분은 제 설명이 많이 부족한 것 같습니다ㅜㅜ
저도 공부를 하면서 완벽하게 이해한 느낌이 아니네요
5. Experiments
- MNIST, TFD, CIFAR-10를 포함한 데이터들로 학습
- generator nets (G)는 rectifier linear activations와 sigmoid activations를 혼합
- discriminator nets (D)는 maxout activations를 사용
- D를 학습시킬때 drop out을 적용
- G의 제일 하위 layer에만 input으로 noise를 사용
G를 통해 생성된 sample들에 Gaussian Parzen window를 fitting하고 이 분포 하의 log-likelihood를 reporting함으로써 p_g하에서 test set data의 확률을 측정하였다. Gaussian의 σ는 validation set에 대한 cross validation을 통해 얻었다. 이 방식은 다양한 생성 모델에서 likelihood가 계산할 수 없을 경우에 사용된다. likelihood를 추정하는 이 방식은 분산이 크고 고차원 공간에서 잘 수행되지는 않지만 저자들의 지식에서는 가능한 최선의 방법이라고 한다. 그리고 likelihood를 직접 추정하지 못하는 생성모델의 발전은 이러한 모델들을 어떻게 평가할지에 대한 앞으로의 research를 motivate한다. G로 생성한 sample들은 더 좋은 생성모델들에 최소한 경쟁력을 지닌다고 저자들은 믿으며, adversarial framework의 잠재력을 강조한다고 한다.
6. Advantages and disadvantages
단점
- pg(x)가 명시적으로 존재하지 않는다.
- G와 D가 균형있게 학습되어야 한다. (G는 D가 학습되는 동안 과도하게 학습되면 안된다. 'Helvetica scenario'를 피하기 위해서라는데 내용은 정확히 이해는 못했지만 잘 알려진 GAN의 단점인 mode collapse를 설명하는 거지 않을까라고 생각합니다,,,)
+ 제가 다른 논문 리뷰나 발표 영상들을 보면서 공부한 단점으론 평가할 수 있는 방식이 없고 Black-Box 방식이다 정도가 더 있다고 하네요
장점
- Markov Chain이 필요없이 역전파로만 gradients를 얻는다.
- 학습 동안 inference가 필요 없다.
- 모델에 다양한 함수가 포함될 수 있다.
7. Conclusions and future work
- G와 D에 c를 추가함으로써 conditional generative model p(x|c)를 얻을 수 있다.
- auxiliary network를 x가 주어졌을 때 z를 예측하도록 학습시킴으로써 Learned approximate inference가 수행될 수 있다. wake-sleep 알고리즘에 의해 inference net이 학습되어지는 것과 유사하지만, generator net이 학습이 완료되고 난 후 고정된 generator net를 위해 inference net이 학습 되어질 수 있다는 이점을 가지고 있다.
- parameters를 공유하는 conditionals model를 학습함으로써 다른 conditionals models을 근사적으로 모델링할 수 있다. 특히 MP-DBM의 stochastic extension의 구현에 대부분의 네트워크를 사용할 수 있다.
- Semi_supervised learning : discriminator 혹은 inference net으로부터의 feature들이 labeled data가 제한될 때 classifier들의 성능을 향상시킬 수 있다.
- Efficiency improvements : G와 D를 조정하는 더 좋은 방식들을 고안하거나 학습동안 sample z에 대한 더 좋은 분포를 결정함으로써 학습은 가속화될 수 있다.
처음으로 읽어본 논문이고 블로그에도 처음으로 올리는 논문이라 많이 부족할 수 있다는 점 양해 부탁드립니다!
긴 글, 깔끔하지 않은 리뷰를 읽어주셔서 감사합니다. 다른 논문 리뷰에서는 좀 더 신경써서 쓸 예정이니 읽어주시면 감사하겠습니당 ㅎㅎ