【Transformer】DeiT-Training data-efficient image transformers & distillation
- 논문 : Training data-efficient image transformers & distillation through attention
- 분류 : Transformer
- 읽는 배경 : 매우 핫함. Facebook AI
- 느낀점 :
- 참고 사이트 : Youtube 강의 링크
- 목차
1. DeiT 논문 핵심 정리
- ViT의 문제점 : JFT, ImageNet 매우매우 많은 데이터셋을 사용해서 학습시켜야, ImageNet SOTA를 찍을 수 있다. Transformer를 사용해서 좋은 성능이 나오려면, 일단 무조건 데이터셋이 아주아주아주 많으면 많을 수록 좋다.
- Motivation : ImageNet 데이터셋만 가지고 학습시켜서 좋은 성능이 나오게 할 수 없을까?
- 해결방법 : – Knowledge Distillation 을 사용하자!! Knowledge Distillation을 사용하면 Data Agumentation하기 이전의 class를 그대로 예측해야 할 필요가 없다. Teacher 또한 data augmentation한 결과를 보고 예측한 결과를 Stutent가 따라하면 되기때문에, 잘못된 모델학습이 되는 것을 막아준다. (PPT 2번 참조)
- Transformer 만의 특장점: inductive biases가 매우 낮다. (inductive biases: [모집단(=전세계 모든)의 이미지의 centroids를 중심으로한 분포(+ decision boundary) VS 표본이미지(ImageNet)의 centroids와 분포]의 차이를 이야기 한다.) Conv와 fc에서는 weight가 Fixed되어 있지만, self-attention에서는 들어온 이미지에 대해서 서로 다른 weight[softmax(query x key) ]값을 가진다. (하지만 ImageNet에 대해서 좋은 Validation 결과가 나오려면.. inductive biases도 어느정도 필요하다.)
- DeiT는 Hard Distillation을 사용해서 적은 데이터 만으로 빠르게 성능을 수렴시켰다.
- 추가 기법 : Distillation Token (이거 추가해서 성능이 더 높아졌다. 아래 보충 내용 참조하기)
- 아니그러면! Distillation Token 추가하면 Test 할 때, 뭐를 이용해서 Prediction하는가? Distillation Embeding, Class Embeding 값을 적절히 융합해서 Prediction한다. (PPT 7번 참조)
- 실험결과를 보아 Convolution Network에 의해서 나오는 softmax결과를 Teacher로 이용했을때 가장 좋은 성능 나왔다, (inductive biases가 들어갔다!)
- Transformer는 fc를 주력으로 이용한다. conv는 파라미터를 이용하면서 연산해야하므로, 병렬연산이 fc보단 좋지 않다. 따라서 파라미터를 BeiT가 더 많이 사용하기는 하지만, 더 빠른 인퍼런스 시간이 나온다.
- 추가 기법들! bag of tricks :
- Transformer는 Initialization이 매우매우 중요하다.
“The effect of initialization and architecture”, NIPS 2018
논문 내용을 사용했다. - Transformer는 위에서 말했듯이, 데이터셋이 많으면 많을수록 좋기 떄문에, 여러여러여러가 data augmentation 기법을 사용했다. (그 기법들은 PPT 9번 참조)
- 이미지 resoltuion을 갑자기 늘려서 조금만 Fine-tuning 해주면 성능이 높아진다. resoltuion을 키우고 patch사이즈를 그대로 하면, patch의 갯수가 늘어나고, positional embeding의 갯수가 늘어나야한다. 이때 기존의 positional embedding을 interpolates으로 늘려주는데, bicubic 을 사용해야 그나마 성능이 늘어난다.
- Transformer는 Initialization이 매우매우 중요하다.
2. DeiT 보충 내용
Distillation token : Class token과 같은 역할을 한다. 하지만,
the (hard) label predicted by the teacher
만으로 학습된다. (Its target objective is given by the distillation component of the loss) = 아래 loss함수의 (Fun) + (Fun)에서 오른쪽 항 만을 사용해서 학습시킨다.Distillation token으로 만들어 지는 Distillation Embeding의 결과도 Classfication Prediction 결과가 나와야 하는데, Teacher모델이 예측하는 결과가 나오면 된다. (이미지 자체 Label에 대해서는 학습되지 않는다.)
3. DeiT Youtube 필기 내용