【Transformer】DeiT-Training data-efficient image transformers & distillation



1. DeiT 논문 핵심 정리

  • ViT의 문제점 : JFT, ImageNet 매우매우 많은 데이터셋을 사용해서 학습시켜야, ImageNet SOTA를 찍을 수 있다. Transformer를 사용해서 좋은 성능이 나오려면, 일단 무조건 데이터셋이 아주아주아주 많으면 많을 수록 좋다.
  • Motivation : ImageNet 데이터셋만 가지고 학습시켜서 좋은 성능이 나오게 할 수 없을까?
  • 해결방법 : – Knowledge Distillation 을 사용하자!! Knowledge Distillation을 사용하면 Data Agumentation하기 이전의 class를 그대로 예측해야 할 필요가 없다. Teacher 또한 data augmentation한 결과를 보고 예측한 결과를 Stutent가 따라하면 되기때문에, 잘못된 모델학습이 되는 것을 막아준다. (PPT 2번 참조)
  • Transformer 만의 특장점: inductive biases가 매우 낮다. (inductive biases: [모집단(=전세계 모든)의 이미지의 centroids를 중심으로한 분포(+ decision boundary) VS 표본이미지(ImageNet)의 centroids와 분포]의 차이를 이야기 한다.) Conv와 fc에서는 weight가 Fixed되어 있지만, self-attention에서는 들어온 이미지에 대해서 서로 다른 weight[softmax(query x key) ]값을 가진다. (하지만 ImageNet에 대해서 좋은 Validation 결과가 나오려면.. inductive biases도 어느정도 필요하다.)
  • DeiT는 Hard Distillation을 사용해서 적은 데이터 만으로 빠르게 성능을 수렴시켰다.
  • 추가 기법 : Distillation Token (이거 추가해서 성능이 더 높아졌다. 아래 보충 내용 참조하기)
  • 아니그러면! Distillation Token 추가하면 Test 할 때, 뭐를 이용해서 Prediction하는가? Distillation Embeding, Class Embeding 값을 적절히 융합해서 Prediction한다. (PPT 7번 참조)
  • 실험결과를 보아 Convolution Network에 의해서 나오는 softmax결과를 Teacher로 이용했을때 가장 좋은 성능 나왔다, (inductive biases가 들어갔다!)
  • Transformer는 fc를 주력으로 이용한다. conv는 파라미터를 이용하면서 연산해야하므로, 병렬연산이 fc보단 좋지 않다. 따라서 파라미터를 BeiT가 더 많이 사용하기는 하지만, 더 빠른 인퍼런스 시간이 나온다.
  • 추가 기법들! bag of tricks :
    1. Transformer는 Initialization이 매우매우 중요하다. “The effect of initialization and architecture”, NIPS 2018 논문 내용을 사용했다.
    2. Transformer는 위에서 말했듯이, 데이터셋이 많으면 많을수록 좋기 떄문에, 여러여러여러가 data augmentation 기법을 사용했다. (그 기법들은 PPT 9번 참조)
    3. 이미지 resoltuion을 갑자기 늘려서 조금만 Fine-tuning 해주면 성능이 높아진다. resoltuion을 키우고 patch사이즈를 그대로 하면, patch의 갯수가 늘어나고, positional embeding의 갯수가 늘어나야한다. 이때 기존의 positional embedding을 interpolates으로 늘려주는데, bicubic 을 사용해야 그나마 성능이 늘어난다.

2. DeiT 보충 내용

  • Distillation token : Class token과 같은 역할을 한다. 하지만, the (hard) label predicted by the teacher 만으로 학습된다. (Its target objective is given by the distillation component of the loss) = 아래 loss함수의 (Fun) + (Fun)에서 오른쪽 항 만을 사용해서 학습시킨다.

  • Distillation token으로 만들어 지는 Distillation Embeding의 결과도 Classfication Prediction 결과가 나와야 하는데, Teacher모델이 예측하는 결과가 나오면 된다. (이미지 자체 Label에 대해서는 학습되지 않는다.)
    image-20210415204857822


3. DeiT Youtube 필기 내용

img01 img02 img03 img04 img05 img06 img07 img08 img09 img10 img11 img12 img13 img14 img15 img16 img17 img18 img19 img20 img21 img22 img23 img24 img25 img26 img27 img28 img29 img30 img31 img32


© All rights reserved By Junha Song.