[Paper Review] ๐ฆฉ Flamingo: a Visual Language Model for Few-Shot Learning
[๋ ผ๋ฌธ ๋ฆฌ๋ทฐ]๐ฆฉ Flamingo: a Visual Language Model for Few-Shot Learning
๐ฆฉ Flamingo: a Visual Language Model for Few-Shot Learning
Jean-Baptiste Alayrac et al
NeurIPS 2022
[arXiv]
๊ตฌ๊ธ DeepMind์์ ๊ฐ๋ฐํ Vision-Language Model์ด๋ค.
Background
Multimodal learning์ ์ด๋ฏธ์ง์ ํ ์คํธ๋ฅผ ๋์์ ์ดํดํ์ฌ VQA(Vision Question Anwsering), Image Captioning ๋ฑ์ task๋ฅผ ์ํํ๋ ๊ฒ์ ๋ชฉํ๋ก ํ๋ค.
๊ธฐ์กด์ VLM๋ค์ ์ฃผ๋ก ๋๋์ Image-Text ๋ฐ์ดํฐ๋ก pretrainํ ํ, ๊ฐ downstream task๋ณ๋ก fine-tuningํ๋ Supervised Learning์ ์ฌ์ฉํ๋ค.
โ task๋ง๋ค ์์ฒ ๊ฐ์ ๋ผ๋ฒจ๋ง๋ ๋ฐ์ดํฐ๋ฅผ ์์งํ๊ณ ๋ชจ๋ธ์ ์ฌํ์ตํด์ผ ํ๋ค๋ ์ ์์ ํ์ฅ์ฑ์ ํ๊ณ
์๋ฅผ ๋ค์ด, CLIP๊ณผ ๊ฐ์ contrastive learning ๊ธฐ๋ฐ ๋ชจ๋ธ๋ค์ ์น์ผ๋ก๋ถํฐ ๋๊ท๋ชจ Image-Text๋ฅผ ํ์ตํ์ฌ zero-shot Classification์ ๋ณด์ฌ์ฃผ์์ง๋ง, ์ถ๋ ฅ์ด Image-Text similarity score ํํ์ ํ์ ๋์ด ์์ด Generation์ด ํ์ํ ๊ฐ๋ฐฉํ ์ง๋ฌธ(์: captioning, Q&A)์๋ ๊ทธ๋๋ก ์ ์ฉํ๊ธฐ ์ด๋ ต๋ค.
ViLT ๋ฑ ์ด๋ฏธ์ง์ ํ ์คํธ๋ฅผ ํจ๊ป ์ ๋ ฅ์ผ๋ก ๋ฐ์ Transformer๋ก ์ฒ๋ฆฌํ๋ ์์ฑํ VLM์ ๋ค์๊ณผ ๊ฐ์ ํ๊ณ๊ฐ ์กด์ฌํ๋ค.
- ํฌ๊ธฐ๊ฐ ์ ํ์
- ์๋์ ์์๋ง์ผ๋ก ์๋ก์ด task์ generalization ๋ถ๊ฐ
์์ฝํ๋ฉด, ๊ธฐ์กด VLM๋ค์ ๊ฑฐ๋ํ์ง๋ง ๋น์์ฑ์ ๋ชจ๋ธ(์: CLIP) ๋๋ ์์ฑ์ ๊ฐ๋ฅํด๋ ํ์คํฌ ์ ์์ ์ํด ์ถ๊ฐ ํ์ต์ด ํ์ํ ๋ชจ๋ธ(์: ViLT ๋ฑ)๋ก ๊ตฌ๋ถ๋๋ฉฐ, ์ฌ๋์ฒ๋ผ ๋ช ๊ฐ์ง ์์๋ง ๋ณด๊ณ ๋ ์๋ก์ด ์๊ฐ ์ธ์ด task๋ฅผ ์ํํ ์ ์๋ ๋ฒ์ฉ ๋ชจ๋ธ์ ๋ถ์ฌํ๋ค.
Contrastive Learning \(L_{\text{contrastive:txt2im}} = -\frac1N\sum_{i=1}^N \log\frac{\exp\bigl(L_i^\top V_i\beta\bigr)} {\sum_{j=1}^N\exp\bigl(L_i^\top V_j\beta\bigr)}\)
\[L_{\text{contrastive:I2T}} = -\frac1N\sum_{i=1}^N \log\frac{\exp\bigl(V_i^\top L_i\beta\bigr)} {\sum_{j=1}^N\exp\bigl(V_i^\top L_j\beta\bigr)}\]- ๋ถ์์ ๋ค์ด๊ฐ๋ $\exp(L_i^\top V_i/\tau)$ ํญ์ positive ์์ ์ ์ฌ๋๋ฅผ ๋์
- ๋ถ๋ชจ์ ์๋ $\sum_j \exp(L_i^\top V_j/\tau)$ ํญ์ โnegative ์๊ณผ์ ์ ์ฌ๋๋ฅผ ์๋์ ์ผ๋ก ๋ฎ์ถค
โ์ฌ๋ฐ๋ฅธ(pair) ์๋ฒ ๋ฉ๋ง ๊ณจ๋ผ๋ด๋๋กโ ํฌ๋ก์ค์ํธ๋กํผ๋ฅผ ์ต์ํํ๊ณ , ๋ค๋ฅธ ๋ชจ๋ ์๋ฒ ๋ฉ์๋ ๊ฑฐ๋ฆฌ๋ฅผ ๋ฒ๋ฆฌ๋๋ก ํ์ตํ๋ Loss
Introduction
Flamingo๋ Few-Shot ํ์ต์ ํนํ๋ VLM๋ก์, ๋ช ๊ฐ์ง ์์(Text-Image pair)๋ง์ผ๋ก๋ ์๋ก์ด Multimodal task๋ฅผ ํด๊ฒฐํ ์ ์๋๋ก ์ค๊ณ๋์๋ค. Flamingo๋ ์ด๋ฏธ์ง/๋น๋์ค์ ํ ์คํธ๊ฐ ์์ธ ์ํ์ค๋ฅผ ์ ๋ ฅ ๋ฐ์ ๋ค์์ ์ด์ด์ง ํ ์คํธ๋ฅผ ์์ฑํจ์ผ๋ก์จ VQA, ์ด๋ฏธ์ง ์บก์ ์์ฑ, ๋น๋์ค ์ค๋ช ๋ฑ ๋ค์ํ ์์ ์ ํ๋์ ๋ชจ๋ธ๋ก ์ํํ๋ค.
Flamingo์ ํต์ฌ ์์ด๋์ด๋ ๊ฐ๋ ฅํ pretrained (Vision, LLM)์ ์ฐ๊ฒฐํ๋ ๊ฒ์ด๋ค.
Flamingo์์๋ ์ด ๋์ ํจ๊ณผ์ ์ผ๋ก ๊ฒฐํฉํ๊ธฐ ์ํด ์๋ก์ด ์ํคํ ์ฒ ์์๋ค์ ๋์ ํ๋ค.
- Pretrained Vision and LLM์ ์ฐ๊ฒฐ module
- Image-Text๊ฐ ์์๋ก ๋ฐฐ์น๋ ์ํ์ค ์ฒ๋ฆฌ ๋ฉ์ปค๋์ฆ
- ์ด๋ฏธ์ง๋ฟ ์๋๋ผ ๋์์๊น์ง Input
์ด๋ฌํ ๊ตฌ์กฐ ๋๋ถ์ Flamingo๋ ์ธํฐ๋ท์์ ์์งํ ๋๊ท๋ชจ ์ด๋ฏธ์ง/ํ ์คํธ ํผํฉ ๋ฐ์ดํฐ๋ก ๋ฉํฐ๋ชจ๋ฌ ์ฌ์ ํ์ต์ ์ํํ ์ ์์๋ค.
๊ทธ ๊ฒฐ๊ณผ ๋ณ๋ fine-tuning ์์ด ๋ค์ํ Vision-Language task์์ ๋ฐ์ด๋ few-shot ์ฑ๋ฅ์ ๋ณด์ฌ์ฃผ์๋ค .
Flamingo๋ LLM์ ๋ฒ์ฉ์ ์ดํด๋ ฅ๊ณผ Vision ๋ชจ๋ธ์ ์ธ์ง ๋ฅ๋ ฅ์ ๊ฒฐํฉํ์ฌ, ์ ์ ์์๋ง์ผ๋ก ์ multimodal task๋ค์ ํ์ตํ ์ ์๋ ๋ฒ์ฉ VLM์ด๋ค.
์ข : ์ฌ๋ฌ ๋ฐ์ดํฐ์ ์์ ์ด์ ๊น์ง SOTA ์ฑ๋ฅ๊ณผ Flamingo์ ์ฑ๋ฅ ๋น๊ต ๊ทธ๋ํ
์ฐ : Flamingo ๋ชจ๋ธ ์ฌ์ด์ฆ์ shot (๋ฐ์ดํฐ) ์์ ๋ฐ๋ฅธ ์ฑ๋ฅ ๊ทธ๋ํ
Goal
few-shot๋ง์ผ๋ก ์๋ก์ด Vision-Language task๋ฅผ ์ถ๊ฐ ํ์ต ์์ด ์ํํ ์ ์๋ ๋ฒ์ฉ VLM์ ๋ง๋๋ ๊ฒ.
Motivation
๊ธฐ์กด VLM๋ค์
- Contrastive learning(CLIP)
- zero-shot classification์ ๊ฐ๋ฅ, text generation ๋ฅ๋ ฅ ๋ถ์ฌ
- Generative ๋ชจ๋ธ(ViLT)
- text generation์ ๊ฐ๋ฅ, task๋ณ fine-tuning ํ์
๋ ๋ฐฉ์ ๋ชจ๋ ์๋ก์ด task๋ง๋ค ๋๋์ labeled ๋ฐ์ดํฐ์ ๋ชจ๋ธ ์ฌํ์ต์ด ํ์ํ๋ฏ๋ก, โ์ฌ๋์ฒ๋ผ ๋ช ๊ฐ์ง ์์๋ง ๋ณด๊ณ ๋ ์ฆ์ ์ ์โํ๋ Few-Shot ๋ฉํฐ๋ชจ๋ฌ ํ์ต์ ๊ตฌํํ ํ์๊ฐ ์์๋ค.
Contributions
Flamingo ์ํคํ ์ฒ ์ ์
๊ฐ๋ ฅํ pretrain Vision ๋ชจ๋ธ(NFNet-F6)๊ณผ Language ๋ชจ๋ธ(Chinchilla)์ freezeํ ์ฑ
๋ ๋ชจ๋ธ ์ฌ์ด์ Perceiver Resampler์ ๊ฒ์ดํธ๋ cross-attention-dense(XATTN-Dense) layer ์ฝ์
โ์ด๋ฏธ์ง/๋น๋์ค โ ํ ์คํธโ๋ฅผ ์์๋ก ๊ต์ฐจํ ์ํ์ค๋ฅผ ์ฒ๋ฆฌํ๊ณ ์์ ํ ํ ์คํธ๋ฅผ ์์ฑ
In-context Few-Shot learning
- ๋ณ๋ fine-tuning ์์ด, 4~32๊ฐ์ ์์(Image text pair)๋ฅผ ํ๋กฌํํธ๋ก ์ ๊ณตํ๋ฉด ์๋ก์ด task์ ์ฆ์ ์ ์
- VQA, Captioning, Video QA ๋ฑ 16๊ฐ ๋ฒค์น๋งํฌ์์ ๊ธฐ์กด zero/few-shot SOTA ๋ฌ์ฑ
- 6๊ฐ task์์๋ ํ์ธํ๋๋ SOTA๋ ๋ฅ๊ฐ
In-context learning
- ๋ชจ๋ธ์๊ฒ **์ถ๊ฐ fine-tuning **์์ด, ์ ๋ ฅ ํ๋กฌํํธ(๋ฌธ์ ์ค๋ช +์์)๋ง ๋ณด์ฌ ์ค์ผ๋ก์จ ์๋ก์ด task๋ฅผ ์ํํ๋๋ก ํ๋ ๋ฐฉ์
Few-shot learning
- โ์์(๋ช ๊ฐ)์ labeled ์์๋ง์ผ๋กโ ์๋ก์ด ํ์คํฌ๋ฅผ ํ์ตํ๊ฑฐ๋ ์ ์ํ๋ ๋ฅ๋ ฅ
- ์ผ๋ฐ์ ์ผ๋ก 1~32๊ฐ ์ ๋์ ์์(shots)๋ฅผ ํ๋กฌํํธ์ ํฌํจ์์ผ ๋ชจ๋ธ์ด ํจํด์ ํ์ ํ๋๋ก ํ๋ค.
- ๋ณ๋ fine-tuning ์์ด, 4~32๊ฐ์ ์์(Image text pair)๋ฅผ ํ๋กฌํํธ๋ก ์ ๊ณตํ๋ฉด ์๋ก์ด task์ ์ฆ์ ์ ์
ํจ์จ์ Pretrain multi-modal ์ ๋ต
- unlabeling ๋ฐ์ดํฐ๋ฅผ ์ฌ์ฉ
- ์นํ์ด์ง์์ ์ถ์ถํ โimage-text interleavedโ ๋๊ท๋ชจ ์ฝํผ์ค(M3W)
- ์์ญ์ต ์์ Image-text ๋ฐ Video-text pair
- ์ด๋ฅผ ํผํฉ ํ์ตํ์ฌ, ๋ฉํฐ๋ชจ๋ฌ ๋ฌธ๋งฅ์์์ ๋ค์ token ์์ธก ๋ฅ๋ ฅ์ ํ๋ณด
- unlabeling ๋ฐ์ดํฐ๋ฅผ ์ฌ์ฉ
ํ์ ์ ๊ธฐ์ ์์
- Perceiver Resampler: ๊ฐ๋ณ ๊ธธ์ด Vision ํน์ง์ 64๊ฐ token์ผ๋ก ์์ถ
- ๊ฒ์ดํธ๋ XATTN-Dense: tanh ๊ฒ์ดํธ๋ก ์์ ํ๋ cross-attention ์ฝ์ ๊ธฐ๋ฒ
- ์ด๋ฏธ์ง๋ณ ์ดํ ์ ๋ง์คํน: ๊ฐ ํ ์คํธ token์ด ์ง์ ์ด๋ฏธ์ง token๋ง ์ฐธ์กฐํ๋๋ก ์ ํ
Method
Flamingo๋ ์ด๋ฏธ์ง/๋น๋์ค์ ํ ์คํธ๊ฐ ๊ต์ฐจ๋ ์ํ์ค๋ฅผ ์ ๋ ฅ์ผ๋ก ๋ฐ์ Autoregressive text generation์ ์ํํ๋ ๋ชจ๋ธ์ด๋ค.
- (frozen) Vision encoder๊ฐ ํฝ์ ์ด๋ฏธ์ง๋ฅผ ๊ณ ์ฐจ์ feature์ผ๋ก ๋ณํ
- Perceiver Resampler๊ฐ ์ด ๊ฐ๋ณ ๊ธธ์ด์ visoin feature๋ค์ ๊ณ ์ ๊ธธ์ด์ token์ผ๋ก ์์ฝ
- (frozen) LLM ๋ด๋ถ์ cross-attention layer๋ค์ ์ฝ์
- ์ด๋ฏธ์ง/๋น๋์ค๋ก๋ถํฐ ์ป์ ์ ๋ณด๋ฅผ ํ ์คํธ ์์ฑ์ ํ์ฉ
Perceiver Resampler์ Cross-Attention ๋ชจ๋๋ง ํ์ต
Flamingo๋ ์ฝ์ ๋ ์ด๋ฏธ์ง ๋ฐ ๋น๋์ค $๐ฅ$์ ์กฐ๊ฑด๋ถ๋ก ํ ์คํธ $๐ฆ$์ ํ๋ฅ ์ ๋ค์๊ณผ ๊ฐ์ด ๋ชจ๋ธ๋งํ๋ค.
\[p(y|x) = \prod^{L}_{\ell = 1} p(y_\ell | y_{<\ell}, x_{ \leq \ell})\]- $y_\ell$: ์ ๋ ฅ text์ $\ell$ ๋ฒ์งธ token
- $y<\ell$: ์ด์ text token ์งํฉ
- $x \leq \ell$: ์ด์ ์ ์์นํ ์ด๋ฏธ์ง/๋น๋์ค ์งํฉ
์ด๋ฅผ ํตํด ์๋ ์ฌ์ ํ์ต ๋ชจ๋ธ๋ค์ด ์ง๋ ์ง์๊ณผ ๋ฅ๋ ฅ์ ์ต๋ํ ๋ณด์กดํ๋ฉด์ ๋ฉํฐ๋ชจ๋ฌ ์ฒ๋ฆฌ๋ฅผ ๊ฐ๋ฅํ๊ฒํ๋ค.
Visual processing
Vision Encoder : NFNet-F6
Vision Encoder๋ Normalizer-Free ResNet (NFNet) ๊ณ์ด์ F6 ๋ชจ๋ธ์ ์ฌ์ฉํ๋ค.
ImageNet ๋ฑ์์ ๋์ ์ฑ๋ฅ์ ๋ณด์ด๋ ๊ฐ๋ ฅํ ์ด๋ฏธ์ง feauter extractor
NFNet-F6 Encoder๋ Flamingo ํ์ต ์ด์ ์ ๋ณ๋๋ก ๋๊ท๋ชจ Image-text pair ๋ฐ์ดํฐ์ ๋ํด contrastive learning์ผ๋ก pretrain๋์์ผ๋ฉฐ freezeํ์ฌ ์ฌ์ฉํ๋ค.
CLIP๊ณผ ์ ์ฌํ๋ฉฐ, ํ์ต ๊ณผ์ ์์ Image/Video์ feature๋ง ์ ๊ณต
๋ณ๋ ํ์ต ๊ณผ์ ์์ BERT ์ธ์ฝ๋๋ฅผ ์ฌ์ฉํ์ฌ ๋์กฐํ์ตํจ โ โ๋ฌธ์ฅ ์ ์ฒด โ ํ๋์ ๊ณ ์ ๊ธธ์ด ๋ฒกํฐโ๋ฅผ ์ ๋ฝ์๋ด๋ ค๋ฉด ์๋ฐฉํฅ ์ปจํ ์คํธ๊ฐ ์ค์
Vision ์ธ์ฝ๋๋ฅผ contrastive ๋ฐฉ์์ผ๋ก ํ๋ํ์ฌ ์ข์ Vision ์๋ฒ ๋ฉ์ ์ป๋ ๊ฒโ๊ณผ โ์ป์ Vision ์๋ฒ ๋ฉ์ ๋์ค์ Chinchilla ๊ธฐ๋ฐ ์์ฑ ๋ชจ๋ธ์ ์ฃผ์ ํ๋ ๊ฒโ์ ์์ ํ ๋ถ๋ฆฌ
Flamingo๋ NFNet-F6์ ๋ง์ง๋ง layer ์ถ๋ ฅ์ feature map์ผ๋ก ๋ฐ์๋ค์ธ๋ค.
๊ตฌ์ฒด์ ์ผ๋ก,
- ์ด๋ฏธ์ง
- NFNet์ผ๋ก $HรW$ ํฌ๊ธฐ์ 2D feature map์ผ๋ก ๋ณํ, flatten โ ํ๋์ 1D token ์ํ์ค
- ๋์์
- 1์ด๋น 1 frame์ ์ํ๋งํ์ฌ ๊ฐ๊ฐ NFNet์ผ๋ก ์ธ์ฝ๋ฉ
- ์๊ฐ ์ถ๊น์ง ํฌํจํ 3D feature map์ ์ป๊ณ ์ฌ๊ธฐ์ ํ์ต๋ time embedding์ ๋ํด์ค๋ค.
- ๋ชจ๋ frame์ ๊ณต๊ฐ-์๊ฐ feautre๋ค์ flattenํ์ฌ 1D ์ํ์ค๋ก ๋ณํ
์ด๋ ๊ฒ ๋์จ 1D ์ํ์ค๋ Perceiver Resampler์ ์ ๋ฌ๋๋ค.
Perceiver Resampler
Vision Encoder๊ฐ ์ถ๋ ฅํ๋ ๊ฐ๋ณ ๊ธธ์ด์ ๋ฐฉ๋ํ feature ์ํ์ค๋ฅผ ์ผ์ ํ ๊ธธ์ด๋ก ์์ฝํ๋ ์ญํ ์ ํ๋ค.
- Input : NFNet์์ ์ถ์ถ๋ visual feature
- Output: 64๊ฐ์ visual token
์ด๋ฏธ์ง/๋์์ ํฌ๊ธฐ๊ฐ ์ด๋ป๋ , Perceiver Resampler๋ 64๊ฐ์ token์ผ๋ก ๊ทธ ์ ๋ณด๋ฅผ ์์ถํด์ ์ธ์ด ๋ชจ๋ธ ์ชฝ์ผ๋ก ์ ๋ฌํ๋ค.
LM๊ณผ visual feature๋ฅผ cross-attentionํ ๋ ๋น์ฉ์ด ํฌ๊ฒ ์ ๊ฐ
Perceiver Resampler์ ๊ตฌ์กฐ๋ Transformer encoder์ ํ ๋ธ๋ก ์ ๋๋ก ๋ณผ ์ ์๋ค.
64๊ฐ์ learnableํ latent query ๋ฒกํฐ๋ฅผ ์ค๋นํ๊ณ , ์ด๋ฅผ visual feature๋ค์ ๋ํด cross-attention ์ํค๋ ๋ฐฉ์์ด๋ค .
์ฝ๊ฒ ๋งํด, ์๋ฐฑ ๊ฐ์ ์ด๋ฅด๋ Image feautre ์ค ๊ฐ์ฅ ์ค์ํ ์ ๋ณด๋ง ๋ฝ์ 64๊ฐ ๋ฒกํฐ์ ๋ด๋ ์ญํ
์ ์๋ค์ ์ด๋ ๊ฒ Resampler ์ ์ฉ ๋ชจ๋์ ์ฌ์ฉํ๋ ๊ฒ์ด ๋จ์ํ flattened feature๋ฅผ ๋ฐ๋ก ์ค์ด๊ธฐ ์ํ MLP๋ Transformer๋ฅผ ์ฐ๋ ๊ฒ๋ณด๋ค ์ฑ๋ฅ์ด ์ฐ์ํ๋ค๊ณ ํ๋ค
๊ฒฐ๊ณผ์ ์ผ๋ก, Resampler๋ฅผ ํตํด ์์ถ๋ 64๊ฐ์ token์ ์ดํ ์ธ์ด ๋ชจ๋ธ์ด ์ฐธ๊ณ ํ ์๊ฐ์ context๋ก ํ์ฉ๋๋ค.
Conditioning frozen LM on visual representations
Flamingo์ ์ธ์ด ์ดํด ๋ฐ ์์ฑ ๋ฅ๋ ฅ์ DeepMind๊ฐ ๊ฐ๋ฐํ LLM์ธ Chinchilla๋ก๋ถํฐ ๋์จ๋ค.
Flamingo์์๋ Chinchilla ๋ชจ๋ธ์ ํฌ๊ธฐ๋ณ๋ก 3๊ฐ์ง ์ฌ์ฉํ๋ค:
- Flamingo-3B
- Flamingo-9B
- Flamingo-80B
LM ์ญ์ ํ์ต ์ frozen๋์ด, ์๋์ ์ธ์ด ์ง์์ ๊ทธ๋๋ก ๊ฐ์งํ ์ฑ ์ฌ์ฉ๋๋ค.
๋์ Flamingo๋ ์ด LM ๋ด๋ถ์ ์๊ฐ ์ ๋ณด๋ฅผ ๋ผ์๋ฃ์ ์ ์๋ ์๋ก์ด layer๋ค์ ์ถ๊ฐํ๋ค.
Gated XATTN-Dense layers
Flamingo์์๋ learnableํ cross-attention ๋ธ๋ก๋ค์ pretrained LM์ ์ค๊ฐ์ ์ฝ์ ํ์ฌ, LM์ด ์์ฑ ๊ณผ์ ์์ Visual token์ ์ฃผ์๋ฅผ ๊ธฐ์ธ์ด๋๋ก ํ๋ค.
์ ๊ทธ๋ฆผ๊ณผ ๊ฐ์ด Cross Attention layer์ feed-forward layer๋ก ๊ตฌ์ฑ๋๋ฉฐ, tanh ๊ฒ์ดํธ๊ฐ ๊ณฑํด์ง๋ค.
- Cross Attention layer
- Query : LM์ ์ค๊ฐ hidden state
- Key, Value : 64๊ฐ visual token
- ์ธ์ด ๋ชจ๋ธ์ ํ์ฌ๊น์ง ์์ฑ๋ ํ ์คํธ ๋งฅ๋ฝ์ ๋ง์ถ์ด ์๊ฐ ์ ๋ณด์ ์ง์(query)๋ฅผ ๋ณด๋ด ํ์ํ ๋ด์ฉ์ ์ป์ด์ฌ ์ ์๋ค.
- Feed-Forward(Dense) layer :์๊ฐ ์ ๋ณด๋ฅผ ๋ฐ์ํ ํํ์ ๊ฐ ์์น๋ณ๋ก ๋ณํ
์ด๋ ๊ฒ ์ถ๊ฐ๋ ๋ ์ด์ด๋ค์ ์ถ๋ ฅ์ tanh ๊ฒ์ดํธ๋ฅผ ํตํด ์ค์ผ์ผ์ด ์กฐ์ ๋ ํ, ์๋ ์ธ์ด ๋ชจ๋ธ์ ๋ ์ด์ด ์ถ๋ ฅ๊ณผ ํฉ์ณ์ง๋ค.
gate์์๋ ํ์ต ๊ฐ๋ฅํ ์ค์นผ๋ผ $\alpha$๋ฅผ ํตํด $\tanh(\alpha)$๋งํผ ์ค์ผ์ผ์ ์กฐ์ ํ๋ค.
์ฒ์์ $\alpha=0$์ผ๋ก ์ค์ ํ์ฌ $\tanh(0)=0$, ์ฆ ํ์ต ์ด๋ฐ์๋ ์๋ก์ด ๋ ์ด์ด๊ฐ ์ธ์ด ๋ชจ๋ธ์ ์ํฅ์ ์ฃผ์ง ์๋๋ค.
์๊ฐ์ด ์ง๋จ์ ๋ฐ๋ผ $\alpha$๊ฐ ํ์ต๋๋ฉด์ gate๊ฐ ์ด๋ ค ์๊ฐ ์ ๋ณด๊ฐ ์ ์ง์ ์ผ๋ก ํตํฉ๋๋ค.
์ ์ด๋ฏธ์ง๋ tahn ๊ฒ์ดํธ์ ๊ฐ์ด ํ์ต ๊ณผ์ ์์ ์ด๋ป๊ฒ ๋ณํ๋์ง ๋ํ๋ธ ๊ทธ๋ํ์ด๋ค.
Code
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
def gated_xattn_dense(y, x, alpha_xattn, alpha_dense):
# y: ์ธ์ด ๋ชจ๋ธ์ ์ค๊ฐ ํ๋ (queries)
# x: Vision token(key, value) โ Perceiver Resampler ์ถ๋ ฅ
# alpha_xattn, alpha_dense: ํ์ต ๊ฐ๋ฅํ ์ค์นผ๋ผ(๊ฒ์ดํธ ํ๋ผ๋ฏธํฐ)
# 1) Cross-Attention with tanh gating
y = y + tanh(alpha_xattn) * Attention(q=y, kv=x)
# 2) Feed-Forward Dense layer with tanh gating
y = y + tanh(alpha_dense) * FeedForward(y)
# 3) ๊ธฐ์กด ์ธ์ด ๋ชจ๋ธ์ Self-Attention + FFN (๋๊ฒฐ๋ ํ๋ผ๋ฏธํฐ)
y = y + FrozenSelfAttention(q=y, kv=y)
y = y + FrozenFeedForward(y)
return y
- Attention(
q=y, kv=x
): ์ฟผ๋ฆฌy
๊ฐ Vision tokenx
์ cross-attention์ ์ํ - FeedForward(
y
): Transformer์ ์ผ๋ฐ์ ์ธ position-wise FFN(Dense) - ๋ค์ FrozenSelfAttention๊ณผ FrozenFeedForward๋ ์ฌ์ ํ์ต๋ ์ธ์ด ๋ชจ๋ธ ํ๋ผ๋ฏธํฐ๋ฅผ ๊ทธ๋๋ก ์ฌ์ฉํ๋ฉฐ, ํ์ต๋์ง ์์ ์ฑ frozen๋์ด ์๋ค .
ํจ์จ์ฑ๊ณผ ํํ๋ ฅ ์ฌ์ด trade-off๋ฅผ ๋ง์ถ๊ธฐ ์ํด ๋ช ๊ฐ layer๋ง๋ค ํ ๋ฒ์ฉ ์ฝ์ ํ๋ ์ ๋ต์ ํํ๋ค.
์ด๋ ๊ฒ ํจ์ผ๋ก์จ ํ๋ผ๋ฏธํฐ ์์ ๊ณ์ฐ๋์ ๋๋ฆฌ์ง ์์ผ๋ฉด์๋ ์๊ฐ ์ ๋ณด๊ฐ ์ถฉ๋ถํ ์ธ์ด ๋ชจ๋ธ์ ์ฃผ์ ๋๋ค.
Multi-visual input support
per-image attention masking์ ํตํด ํ๋์ ๋ํ ๋ด์ ์ฌ๋ฌ ์ด๋ฏธ์ง/๋น๋์ค๊ฐ interleaved๋ ๊ฒฝ์ฐ๋ ์์ฐ์ค๋ฝ๊ฒ ์ฒ๋ฆฌํ ์ ์๊ฒ ๋๋ค.
{์ด๋ฏธ์ง A - ์ง๋ฌธ X - ์ด๋ฏธ์ง B - ์ง๋ฌธ Y} ์ฒ๋ผ ํ ์คํธ์ ์ด๋ฏธ์ง๊ฐ ๋ฒ๊ฐ์ ์ฌ๋ฌ ๊ฐ ๋ฑ์ฅํด๋ ์ผ๊ด์ฑ ์๊ฒ ์ดํดํ๊ณ ๋๋ตํ ์ ์๋ค.
Cross-Attention ๋จ๊ณ์์, ๊ฐ text token์ด ๋ณผ ์ ์๋ Vision token์ ์ ํ์ ๊ฑฐ๋ ๊ฒ์ด๋ค.
๊ตฌ์ฒด์ ์ผ๋ก Flamingo๋ ํน์ ์์น์ ํ ์คํธ token์ด ์ค์ง ์ง์ ์ ๋ฑ์ฅํ ํ ์ฅ์ ์ด๋ฏธ์ง๋ก๋ถํฐ ๋์จ Vision token๋ค๋ง ์ฐธ์กฐํ๋๋ก ์ดํ ์ ๋ง์คํฌ๋ฅผ ์ ์ฉํ๋ค .
์๋ฅผ ๋ค์ด ๋ํ ํ๋กฌํํธ๊ฐ <Image1><Question1><Image2><Question2> ๋ผ๋ฉด, <Question2>์ ๋ํ ์์ฑ ์์๋ Image2 token์๋ง attention์ ์ ์ฉ
์ด๋ ๊ฒ ํ๋ฉด ๊ฐ ์ง๋ฌธ์ ํด๋นํ๋ ์ด๋ฏธ์ง ์ ๋ณด๋ง ์ง์คํ์ฌ ๋ต์ ์์ฑํ๊ฒ ๋๊ณ , ์ฌ๋ฌ ์ด๋ฏธ์ง๊ฐ ์๋ ๊ฒฝ์ฐ์๋ ๋ฌธ๋งฅ์ด ์์ฌ ํผ๋๋๋ ๊ฒ์ ๋ง์ ์ ์๋ค.
LM ๋ด๋ถ์ self-attention์ ๋ชจ๋ token๋ฅผ ํ์ฉํ๋ฏ๋ก ๊ฐ์ ์ ์ผ๋ก ๋ค๋ฅธ token๋ค๊ณผ ์ด์ด์ง ์ ์์ง๋ง, ์ง์ ์ ์ธ cross-attention ์ฐ๊ฒฐ์ ํ ์ด๋ฏธ์ง์ฉ์ผ๋ก ์ ํํจ์ผ๋ก์จ, Flamingo๋ ํ๋ จ ์ ์ฌ์ฉ๋ ์ด๋ฏธ์ง ๊ฐ์๋ณด๋ค ๋ ๋ง์ ์ด๋ฏธ์ง๊ฐ ๋ค์ด์๋ ์ ๋์ํ ์ ์๊ฒ ๋์๋ค .
์ด ๋๋ถ์ ํ์ต ๋ ํ ์ํ์ค๋น ์ต๋ 5์ฅ์ ์ด๋ฏธ์ง๋ง ์ฌ์ฉํ์ง๋ง, Test ์์๋ ์ต๋ 32๊ฐ์ Image-Text pair๊น์ง ๋ค๋ฃฐ ์ ์์๋ค.
์ฐธ๊ณ ๋ก, ์ ์๋ค์ ๋์ ์คํ์ผ๋ก โํ ์คํธ๊ฐ ๋ชจ๋ ์ด์ ์ด๋ฏธ์ง๋ค์ cross-attendํ๋๋กโ ํด๋ณธ ๊ฒฝ์ฐ ์คํ๋ ค ์ฑ๋ฅ์ด ๋จ์ด์ก๋ค๊ณ ๋ณด๊ณ ํ๋ค . ์ด๋ ํ๊บผ๋ฒ์ ์ฌ๋ฌ ์ด๋ฏธ์ง๋ฅผ ๋ชจ๋ ์ฐธ๊ณ ํ๋ฉด ๋ชจ๋ธ์ด ์ด๋ ์ ๋ณด๋ฅผ ์ด๋์ ์จ์ผ ํ ์ง ํผ๋์ค๋ฌ์ํ๊ธฐ ๋๋ฌธ์ผ๋ก ์ถ์ ๋๋ค.
๊ฒฐ๊ตญ ์ด๋ฏธ์ง-์ธ๊ณผ์ (image-causal) ์ดํ ์ ๋ง์คํน์ ์ ์ฉํ ํ์ฌ์ ๋ฐฉ์์ด ๊ฐ์ฅ ํจ๊ณผ์ ์ด์๋ค๊ณ ํ๋ค.
Training on a mixture of vision and language datasets
Flamingo์ ์ฌ์ ํ์ต(Pretraining)์๋ ๋๊ท๋ชจ unlabeled multi-modal ์น ๋ฐ์ดํฐ๊ฐ ํ์ฉ๋์๋ค.
๋ฐ์ดํฐ๋ 3๊ฐ์ง๋ก ๊ตฌ์ฑ๋์ด์๋ค.
- M3W (MultiModal MassiveWeb)
- ์นํ์ด์ง๋ก๋ถํฐ ์ถ์ถํ ์ด๋ฏธ์ง-ํ ์คํธ ํผํฉ ์ํ์ค ๋ฐ์ดํฐ์ ์ด๋ค.
- ๊ฐ ํ์ด์ง์์ ์ด๋ฏธ์ง๊ฐ ๋ฑ์ฅํ ์๋ฆฌ์ <image> token์ ์ฝ์ ํ๊ณ , ๋ฌธ์์ ์น์ ์ด ๋๋ ๋ <EOC> token์ ๋ฃ๋ ์์ผ๋ก ์ํ์คํ
- Image-Text Pairs
- ALIGN ๋ฐ์ดํฐ์ + ์ถ๊ฐ๋ก ์์งํ LTIP (Long Text Image Pairs) ๋ฐ์ดํฐ์
- Flamingo ์ ๋ ฅ ํ์์ ๋ง์ถ๊ธฐ ์ํด ์บก์ ์์๋ <image> token์, ๋์๋ <EOC> token์ ๋ถ์ฌ ๊ตฌ์ฑ
- Video-Text Pairs
- ์์ฒด ์์งํ ๋์์ ์ค๋ช ๋ฐ์ดํฐ
- ํ๊ท 22์ด ๋ถ๋ ๋์์ - ํ ๋ฌธ์ฅ์ง๋ฆฌ ์ค๋ช
- ๋ง์ฐฌ๊ฐ์ง๋ก <image> (๋๋ ๋์์ ํ๋ ์์ placeholder)์ <EOC> token์ ํ์ฉํด ์ ๋ ฅ ์ํ์ค๋ฅผ ๊ตฌ์ฑ
์ด ์ธ ๊ฐ์ง ๋ฐ์ดํฐ๋ฅผ ์์ด ๋ชจ๋ธ์ ํ์ตํ ๋, ๋จ์ํ ํ ๋ฐ์ดํฐ์ ์ฉ ๋ฒ๊ฐ์ ํ๋ จํ๋ ๊ฒ๋ณด๋ค weight๋ฅผ ์ฃผ๋ฉฐ batch์ ์์ด gradient๋ฅผ ๋์ ํ๋ ๋ฐฉ์์ด ๋ ํจ๊ณผ์ ์ด์๋ค๊ณ ํ๋ค.
Training strategy
Flamingo ๋ชจ๋ธ์ ์ ๊ฑฐ๋ํ ์น ๋ฉํฐ๋ชจ๋ฌ corpora์์ ๋ค์ ๋จ์ด ์์ธก task๋ฅผ ์ํํ๋ฉฐ ์ฌ์ ํ์ต๋์๋ค.
์ฆ, ์ฃผ์ด์ง ์ํ์ค์์ ํ token์ฉ autoregressiveํ๊ฒ ์์ธกํ๋ ๋ฐฉ์์ผ๋ก, ์ด๋ฏธ์ง/ํ ์คํธ ํผํฉ ๋ฌธ๋งฅ์์ ํ ์คํธ ์์ฑ ํ๋ฅ $P(\text{ํ ์คํธ} \mid \text{์ด์ ํ ์คํธ+์ด๋ฏธ์ง})$ ์ ์ต๋ํํ๋๋ก ํ๋ จ๋์๋ค . ํ๋ จ ๊ฒฐ๊ณผ, Flamingo๋ ํ ์คํธ์ ์ด๋ฏธ์ง๊ฐ ์์ธ ๊ธด ์ํ์ค๋ฅผ ๋ณด๊ณ ๋ ๋ค์์ ์ฌ ๋จ์ด๋ฅผ ์์ฐ์ค๋ฝ๊ฒ ๋ง๋ค์ด๋ด๋ ๋ฅ๋ ฅ์ ํ๋ํ๋ค. ์ด๋ฌํ ๋ฅ๋ ฅ์ ๋ฐํ์ผ๋ก, ๋ชจ๋ธ์ด ํ์ต์ ์ฌ์ฉ๋์ง ์์ ์๋ก์ด ๋ฒค์น๋งํฌ์ ๋ํด few-shot ์ค์ ์ผ๋ก ๋น ๋ฅด๊ฒ ์ ์ํ ์ ์์์ ๋ ผ๋ฌธ์์ ์คํ์ผ๋ก ์ฆ๋ช ํ์๋ค.
Loss
- ์ธ ๋ฐ์ดํฐ์ $\mathcal{D}_m $ ๊ฐ๊ฐ์ ๋ํ negative log-likelihood ์์ค์ ๊ฐ์ค์น$\lambda_m$๋ก ํฉ์ฐ:
- ์ฌ๊ธฐ์ $x$๋ Vision ์ ๋ ฅ(์ด๋ฏธ์ง/๋น๋์ค), $y$๋ ํ ์คํธ token
- ๊ฐ์ค์น ํ๋์ ํตํด, ๋ชจ๋ธ์ด ์ธ ์์ค ๋ชจ๋์์ ๊ณ ๋ฅธ ์ฑ๋ฅ์ ๋ด๋๋ก ๊ท ํ์ ๋ง์ถ๋ค.
Gradient ๋์
- ๊ฐ step๋ง๋ค ๋ชจ๋ ๋ฐ์ดํฐ์ ์ ๊ฑธ์ณ gradient๋ฅผ ๋์ (accumulate)ํ๊ณ ํ ๋ฒ์ ์ ๋ฐ์ดํธ
Task adaption with few-shot in-context learning
์ฌ์ ํ์ต์ ๋ง์น Flamingo๋, task๋ณ fine-tuning ์์ด โํ๋กฌํํธ์ ๋ช ๊ฐ์ ์์๋ง ๋ณด์ฌ ์ค์ผ๋ก์จโ ์๋ก์ด task์ ์ฆ์ ์ ์ฉํ ์ ์๋ค.
In-context learning
Prompt ๊ตฌ์ฑ
- Support set: ํด๋น task์ ${(x_1,y_1),\dots,(x_K,y_K)}$ ์์ $K$๊ฐ(๋ณดํต $K=4,8,16,32$).
- $x$ : Image or Video
- Query: โ์๋ก์ด ์ ๋ ฅโ $x_{q}$ (์ด๋ฏธ์งยท์์)
- ์ต์ข ์ ๋ ฅ ์ํ์ค
์์ ์๋ต ์์ โOutput:โ์ ์ถ๊ฐ
์๊ฐ์ QA task์๋
โQuestion: {question} Answer: {answer}โ
ํ์์ผ๋ก ํ๋กฌํํธ๋ฅผ ๊ตฌ์ฑ
Decoding method
Open-ended task(์บก์ ยท์์ ์๋ต): Beam Search๋ก ํ ์คํธ ์์ฑ
Beam Search
์์ฑ ๋ชจ๋ธ์์ ๊ฐ์ฅ ํ๋ฅ ์ด ๋์ ์ถ๋ ฅ ์ํ์ค๋ฅผ ์ฐพ๊ธฐ ์ํ ํ์ ์๊ณ ๋ฆฌ์ฆ์ด๋ค. ํ๋ณด๋ฅผ ๋ฌด์์ ํ๋๋ง ๋ฝ๋ greedy search์ ๋นํด, ๋ค์ํ ๊ฐ๋ฅ์ฑ์ ๊ณ ๋ คํด ์ต์ข ๋ฌธ์ฅ์ ์์ฑํ๋ฏ๋ก ํ์ง ์ข์ ํ ์คํธ๋ฅผ ์ป๊ธฐ ์ฝ๋ค.
์ด๊ธฐํ
- ๋น ์ํ์ค []๋ฅผ ์ ์ผํ ํ๋ณด๋ก ๋๊ณ , ๊ทธ ์ ์(log-prob) $s=0$
- ๋น ํฌ๊ธฐ(beam width) $B$๋ฅผ ๋ฏธ๋ฆฌ ์ ํจ (์: $B=5$ ๋๋ $B=10$)
๋ฐ๋ณต(๋งค ํ์์คํ t)
๊ฐ ํ๋ณด ์ํ์ค $y_{1:t-1}$์ ๋ํด, ๋ชจ๋ธ์ด ์์ฑํ ์ ์๋ ๋ค์ token $v$๋ค๊ณผ ๊ทธ ๋ก๊ทธ ํ๋ฅ $\log p(v\mid y_{1:t-1},\,x)$์ ๊ณ์ฐ
๋ชจ๋ ํ๋ณด ร ๋ชจ๋ token ์์ ๋ํด, ํ์ฅ๋ ์ํ์ค์ ๋์ ์ ์๋ฅผ ๊ตฌํจ:
$\bigl(y_{1:t-1},\,v\bigr)\quad,\quad s_{\text{new}} = s_{\text{old}} + \log p(v \mid y_{1:t-1},\,x)$
์ด๋ค ์ค ๊ฐ์ฅ ์ ์๊ฐ ๋์ ์์ B๊ฐ ํ๋ณด๋ง์ ๋ค์ ๋จ๊ณ์ ํ๋ณด๋ก ์ ์ง
์ด ๊ณผ์ ์์ โ์ข ๊ฒฐ tokenโ(</s>)์ด ๋์จ ํ๋ณด๋ ๋ณ๋์ ์ ์ฅํด ๋์ ์ ์์
์ข ๋ฃ ์กฐ๊ฑด
- ๋ชจ๋ ํ๋ณด๊ฐ ์ข ๊ฒฐ token์ ๋ฝ์๊ฑฐ๋, ๋ฏธ๋ฆฌ ์ ํ ์ต๋ ๊ธธ์ด์ ๋๋ฌํ๋ฉด ์ข ๋ฃ
- ์ ์ฅํด๋ โ์์ฑ๋โ ํ๋ณด๋ค ์ค ๊ฐ์ฅ ๋์ ์ ์๋ฅผ ๊ฐ์ง ์ํ์ค๋ฅผ ์ต์ข ์ถ๋ ฅ
Closed-ended task(์ ๋คํ): ๊ฐ๋ฅํ ๋ชจ๋ ์ ๋ต ํ๋ณด๋ฅผ ์ง์ ์ด๋ฏธ์ง ๋ค์ ๊ฐ๊ฐ ํ๋์ฉ ๋ถ์ฌ log-likelihood๋ฅผ ๊ณ์ฐํด ๊ฐ์ฅ ๋์ ์ต์ ์ ํ
Zero-shot Generalization
๋ ผ๋ฌธ์์๋ Zero-shot ์ฑ๋ฅ์ ์ธก์ ํ ๋, ์์ ์์ด ํ ์คํธ ์์ ๋ ๊ฐ๋ง ์ฃผ๊ณ (์ด๋ฏธ์ง๋ ์ ๊ฑฐ) prompt๋ฅผ ๋ค์๊ณผ ๊ฐ์ด ๊ตฌ์ฑํ๋ค.
1 2 3 4
<BOS> Output: This is a cat wearing sunglasses.<EOC> Output: Three elephants walking in the savanna.<EOC> <image> Output:
1 ๊ฐ ์์๋ง ๋ณด์ฌ ์ฃผ๋ฉด ๋ชจ๋ธ์ด ๊ณผ๋ํ๊ฒ ํธํฅ๋๋ฏ๋ก, 2 ๊ฐ ์์๋ฅผ ์ฌ์ฉํ๋ฉฐ, ๊ทธ ์ด์์ ์ฑ๋ฅ ํฅ์์ด ๋ฏธ๋ฏธํด ๋ ๊ฐ๋ก ๊ณ ์ ํ๋ค.
Retrieval-based In-Context Example Selection (RICES)
๋๊ท๋ชจ ์์ ์งํฉ์์๋ ํ๋กฌํํธ ๊ธธ์ด ์ ํ ๋๋ฌธ์ ๋ชจ๋ ๋ฃ๊ธฐ ์ด๋ ต๊ณ , ์ํ์ค ๊ธธ์ด๊ฐ ๋๋ฌด ๊ธธ๋ฉด ์ผ๋ฐํ ์ฑ๋ฅ์ด ์ ํ๋๋ค.
์ด๋ด ๋ An Empirical Study of GPT-3 for Few-Shot Knowledge-Based VQA์ RICES ๋ฐฉ์์ ๋ฐ๋ผ,
- ์ง์ ์ด๋ฏธ์ง์ visual feature ๋ฒกํฐ๋ฅผ ๋น๊ตํด ๊ฐ์ฅ ์ ์ฌํ ์์ $N$๊ฐ์ ์์๋ง ์ ํ
- ์ ์ฌ๋ ์, ์ฆ ๊ฐ์ฅ ๋น์ทํ ์์๊ฐ Query ์ง์ ์ ์ค๋๋ก ํ๋กฌํํธ๋ฅผ ๊ตฌ์ฑํ๋ค.
โ ์ด๋ ๊ฒ ํ๋ฉด ๊ธธ์ด๋ฅผ ์ ํํ๋ฉด์๋ ํ๋กฌํํธ ํ์ง์ ๋์ฌ ์ฑ๋ฅ์ ๊ฐ์ ํ ์ ์๋ค.
Experiments
Flamingo ๋ชจ๋ธ์ด ์ผ๋ง๋ ๋ค์ํ ์๋ก์ด Vision-language task์ ๋น ๋ฅด๊ฒ ์ ์(fast adaptation)ํ๋์ง ํ๊ฐํ๋ค. ์ด๋ฅผ ์ํด ์ด 16๊ฐ์ ๋ํ multimodal image/video - language benchmark๋ฅผ ์ ์ ํ๊ณ , ์ด๋ค ์ค 5๊ฐ๋ ๋ชจ๋ธ ์ค๊ณยทํ์ดํผํ๋ผ๋ฏธํฐ ํ๋(DEV set) ๊ณผ์ ์์, ๋๋จธ์ง 11๊ฐ๋ ์ค์ง ์ต์ข ํ๊ฐ(held-out) ์ฉ๋๋ก๋ง ์ฌ์ฉํ๋ค.
DEV ๋ฒค์น๋งํฌ(๋ชจ๋ธ ๊ฐ๋ฐ์ ์ฌ์ฉ)
- COCO Captions, OK-VQA, VQAv2, MSVDQA, VATEX
Held-out ๋ฒค์น๋งํฌ(์ต์ข ์ฑ๋ฅ ์ธก์ )
- Flickr30k, YouCook2 VideoQA, Visual Dialogue, Hateful Memes, TextVQA, STAR, NextQA, RareAct ๋ฑ 11๊ฐ
ํ๊ฐ ๋ฐฉ์ ํต์ผ
Few-shot in-context learning์ผ๋ก๋ง ๋ชจ๋ธ์ ์ ์ฉ
Open-ended๋ Beam Search(beam size=3)
Closed-ended๋ log-likelihood scoring ๋ฐฉ์์ผ๋ก ์ ๋ต ์ ํ
ํ์ดํผํ๋ผ๋ฏธํฐ, promt๊ตฌ์ฑ, beam size ๋ฑ์ ์ ๋ฒค์น๋งํฌ์ ๊ฑธ์ณ ๊ณ ์ ํด ํ๊ฐ ํธํฅ์ ์ต์ํ
Ablation study
- ๋ถ๋ก B.2: Flamingo ๋ชจ๋ธ์ fine-tuning ์ฑ๋ฅ(VQAv2, VATEX, VizWiz ๋ฑ 9๊ฐ task)
- ๋ถ๋ก B.2: ImageNetยทKinetics700 ๋ถ๋ฅ ์ฑ๋ฅ, contrastive ๋น์ ์ธ์ฝ๋ ์ฑ๋ฅ
- ๋ถ๋ก C: ์ง์์๋ตยท์บก์ ยท๋ํ ๋ฑ ๋ค์ํ ์ ์ฑ์ ์์
Few-shot learning
Flamingo(80B) ๋ชจ๋ธ์ 16๊ฐ ๋ฒค์น๋งํฌ ๋ชจ๋์์,
- ๊ธฐ์กด Zero/Few-shot SOTA ๋ณด๋ค ์ข์ ์ฑ๋ฅ
6๊ฐ task(OK-VQA, MSVDQA ๋ฑ)์์๋ 32-shot์ผ๋ก fine-tuning SOTA ๋ฅ๊ฐ
๋ชจ๋ธ ํฌ๊ธฐ(3Bโ9Bโ80B)์ ์ท ์(0โ4โ32)๋ฅผ ๋๋ฆด์๋ก ์ฑ๋ฅ์ด ์ผ๊ด๋๊ฒ ํฅ์
Fine-tuning
Flamingo๋ฅผ pretrained VLM์ผ๋ก ์ฌ๊ธฐ๊ณ labeling๋ fine-tuningํ๋ ๋ฐฉ๋ฒ์ผ๋ก task์ ์ ์ฉํ๊ธฐ๋ ํ๋ค.
- ๋ชจ๋ธ ๊ตฌ์กฐ
- Chinchilla์ Flamingo์ cross-attention ๋ฐ Perceiver Resampler ๋ชจ๋๊น์ง ๋ชจ๋ ๋ฏธ์ธ์กฐ์
- NFNet-F6๋ โ๋ ๋์ ํด์๋ ์ ๋ ฅโ์ ๋ฐ๋๋ก unfreezeํ์ฌ ํ์ต
fine-tuning ํ Flamingo๋ Few-shot ๊ฒฐ๊ณผ๋ฅผ ํฌ๊ฒ ๋ฐ์ด๋์ด, ๊ธฐ์กด์ fine-tuning SOTA์ ๋น๊ตํ์ ๋ VQAv2, VATEX, VizWiz, MSRVTTQA, Hateful Memes 5๊ฐ task์์ ์ SOTA๋ฅผ ๋ฌ์ฑํ๋ค.
์ด๋ Flamingo๊ฐ โ๋จ์ผ ๊ฐ์ค์นโ๋ก few-shot ๋ฐ fine-tuning ๋ ๊ฐ์ง ๋ชจ๋์ ๋์ ๊ฐ๋ฅํ๋ค๋ ๊ฒ์ ์ ์ฆํ๋ค.
Ablation Studies
Training data
- ์๋ณธ: M3W(ํฌ๋กค๋ง) + ImageโText pair (ALIGN+LTIP) + VideoโText pair(VTP)
- w/o M3W : ์ ์ฒด 17.3% ํ๋ฝ
- w/o Image-text pair : ์ ์ฒด 9.8% ํ๋ฝ
- w/o Video-text pair : video task ์ฑ๋ฅ ํ๋ฝ
- Image-text pair โ LAION dataset์ผ๋ก ๊ต์ฒด : 4.3% ํ๋ฝ
- ์๋ณธ: M3W(ํฌ๋กค๋ง) + ImageโText pair (ALIGN+LTIP) + VideoโText pair(VTP)
Optimisation
- ์๋ณธ: ๋ชจ๋ ๋ฐ์ดํฐ์ ์ ๊ทธ๋๋์ธํธ๋ฅผ ํ ์คํ ์ ๋์ (accumulate)
- ๋น๊ต: โround-robinโ ์ฌ์ฉ ์ ์ ์ฒด ์ ์ 70.7% โ 62.9%
Tanh gating
์๋ณธ: gate $\tanh(ฮฑ)$ ์ $ฮฑ$๋ฅผ 0์ผ๋ก ์ด๊ธฐํํ๊ณ XATTN-Dense ์ถ๋ ฅ์ ์ค์ผ์ผ
๋น๊ต: ๊ฒ์ดํ ์์ด : ์ ์ฒด ์ ์ 70.7 โ 66.5 (โ4.2)
๊ฒ์ดํ ์ด ์์ผ๋ฉด ์ด๊ธฐํ ์ pretrained LM ์ถ๋ ฅ๊ณผ ์ผ์นํ์ง ์์ ํ๋ จ ๋ถ์์ ์ด๋ .
Cross-attention architecture
- ์๋ณธ: GATED XATTN-DENSE
- ๋น๊ต:
- VANILLA XATTN (๊ธฐ์กด Transformer cross-attn๋ง ์ฝ์ ) : 70.7 โ 66.9
- GRAFTING (frozen LM ์์ cross+self-attn ์ธต ๋ง๋ถ์) : 70.7 โ 63.1
Cross-attention frequency
- ์๋ณธ: GATED XATTN-DENSE๋ฅผ ๋งค ์ธต๋ง๋ค ์ฝ์ (cost ์ฆ๊ฐ)
- ๋น๊ต:
- ๋งค 2๋ฒ์งธ ์ธต : 70.7 โ 68.2 (โ2.5%)
- ๋งค 4๋ฒ์งธ ์ธต : 70.7 โ 68.8 (โ1.9%)
- ํ ๋ฒ(์ค๊ฐ) : 70.7 โ 59.8 (โ15.4%)
- ์ ์ถฉ: trade-off๋ฅผ ๊ณ ๋ คํด์ Flamingo-9B๋ 4์ธต๋ง๋ค, Flamingo-80B๋ 7์ธต๋ง๋ค ์ฝ์
Perceiver Resampler
- ์๋ณธ: Perceiver Resampler (64๊ฐ ์๊ฐ ํ ํฐ ์ถ๋ ฅ)
- ๋น๊ต:
- Transformer (๋์ผ ํ๋ผ๋ฏธํฐ) : 70.7 โ 66.7
- ๋จ์ผ MLP : 70.7 โ 66.6
Vision encoder
- ์๋ณธ: NFNet-F6 (contrastive pre-trained)
- ๋น๊ต:
- CLIP ViT-L/14 : 70.7 โ 64.9 (โ5.8)
- NFNet-F0 : 70.7 โ 62.7 (โ8.0)
- ๊ฒฐ๋ก : ๊ฐ๋ ฅํ contrastive pretrained NFNet-F6๊ฐ ์ต์ .
Freezing LM
- ์๋ณธ: Chinchilla LM ์ธต ๋ชจ๋ freeze
- ๋น๊ต:
- LM๋ฅผ ์ฒ์๋ถํฐ ํ์ต(random init) : 70.7 โ 57.8 (โ12.9)
- LM๋ฅผ pretrain๋ ์ํ๋ก fine-tuning(unfreeze) : 70.7 โ 62.7 (โ8.0)
CLIP vs NFNet-F6
CLIP ViT-L/14๋ contrastive learning์ผ๋ก ํ์ต๋ ๊ฐ๋ ฅํ ๋น์ ์ธ์ฝ๋์ง๋ง, Flamingo์์๋ NFNet-F6๊ฐ ํจ์ฌ ๋ ์ข์ ์ฑ๋ฅ์ ๋ณด์
- ํด์๋์ ์ํคํ ์ฒ ์ค๊ณ์ ์ฐจ์ด
ํญ๋ชฉ | CLIP ViT-L/14 | NFNet-F6 |
---|---|---|
๊ธฐ๋ณธ ๊ตฌ์กฐ | Vision Transformer (ViT-L/14) | Convolutional Feedforward Network (ResNet ๊ณ์ด) |
์ ๋ ฅ ํด์๋ | ๋ณดํต 224ร224 | 288ร288๋ก ํ์ต (Flamingo์์๋ 480ร480๊น์ง ์ฌ์ฉ) |
receptive field | ์ ํ์ (patch 14ร14 ๊ธฐ์ค) | ๋ ๋๊ณ dense (CNN ํน์ฑ) |
inductive bias | ๊ฑฐ์ ์์ (transformer) | ์์ (CNN: ์ง์ญ์ฑ, ๊ณ์ธต์ ํํ) |
- CLIP
- self-attention ๊ธฐ๋ฐ์ด๊ธฐ ๋๋ฌธ์ ์ง์ญ ํจํด, low-level visual feature๋ฅผ ํ์ตํ๊ธฐ ์ด๋ ค์
NFNet
- ์ด๋ฏธ์ง์ local ๊ตฌ์กฐ, ๊ณ์ธต์ feauture ์ถ์ถ์ ๋งค์ฐ ๊ฐํจ
โ Flamingo์ฒ๋ผ ๋ค์ํ ์ข ๋ฅ์ Vision-language task๋ฅผ ์ฒ๋ฆฌํด์ผ ํ ๊ฒฝ์ฐ, ์ผ๋ฐ์ ์ธ ์ธ์ ๋ฅ๋ ฅ์ด ๋ฐ์ด๋ NFNet์ด ๋ ํจ๊ณผ์
ํ์ต ๋ฐ์ดํฐ์ ์ค์ผ์ผ ์ฐจ์ด
NFNet-F6
ALIGN + LTIP ๋ฐ์ดํฐ
ALIGN: 1.8B image-text pair (Google ๋ด๋ถ ๋ฐ์ดํฐ)
LTIP: 4B text image pair (web-scale large-scale textโimage pairs)
ํ์ต ์คํ : 1.2M update steps, batch size 16,384, TPUv4 ร 512 ์ฌ์ฉ
CLIP
LAION-400M ์์ค์ ์น ๋ฐ์ดํฐ
์๋์ ์ผ๋ก ํ์ต ๊ท๋ชจ๊ฐ ์๊ณ , fine-grained noise๋ ๋ ๋ง์
โ โ Flamingo์ NFNet์ ๋ ์์ง์ ๋ฐ์ดํฐ๋ก ํจ์ฌ ๋ ์ค๋ซ๋์ ํ์ต๋์๊ธฐ ๋๋ฌธ์, CLIP๋ณด๋ค ์๊ฐ ํํ์ด ํจ์ฌ ์ ๊ต
Flamingo ์ํคํ ์ฒ์์ ์ ํฉ์ฑ
Flamingo์์๋ ๋น์ ์ธ์ฝ๋์ ์ถ๋ ฅ์ด Perceiver Resampler๋ฅผ ํตํด 64๊ฐ์ latent token์ผ๋ก ์์ถ๋์ด ์ธ์ด ๋ชจ๋ธ์ ์ฐ๊ฒฐ
CLIP ViT-L/14๋ patch-wise token ์ถ๋ ฅ์ด๋ผ ๋น์ ํ ๊ตฌ์กฐ์ ์ ๋ง์ง ์์ ์ ์์
๋ฐ๋ฉด NFNet์ convolutional feature map ํํ๋ก ์ถ๋ ฅ๋๊ธฐ ๋๋ฌธ์
โ Perceiver ๊ตฌ์กฐ์์ ํ ํฐ ์ถ์ถ, ์์น ๋ ๋ฆฝ์ฑ, ์ ๋ณด ์์ถ์ ํจ์ฌ ์ ๋ฆฌ
Conclusion
Flamingo๋ LLM์ few-shot learning ๋ฅ๋ ฅ์ Image/Video ๋๋ฉ์ธ์ผ๋ก ํ์ฅํจ์ผ๋ก์จ, ์ฃผ์ด์ง ๋ช ๊ฐ์ ์์๋ง์ผ๋ก ์๋ก์ด Vision-language task๋ค์ ์ ์ํ ํ์ตํ ์ ์์์ ์ ์ฆํ๋ค.
Flamingo์ ์ํคํ ์ฒ๋ ์ฌ์ ํ์ต๋ Vision backbone๊ณผ Language model์ ํจ๊ณผ์ ์ผ๋ก ์ฐ๊ฒฐํ์ฌ, ๋ ๋ชจ๋ธ์ด ์ถ์ ํ ์ง์์ ์ต๋ํ ํ์ฉํ๋ค.
๊ทธ ๊ฒฐ๊ณผ, ๋๊ท๋ชจ ์น ๋ฐ์ดํฐ๋ก ํ์ต๋ Flamingo๋ Captioning, VQA, ์์ ์ง๋ฌธ์๋ต, ๋ํํ ์๋ต ๋ฑ ๋ค์ํ task์์ SOTA ์์ค์ ์ฑ๋ฅ์ ๋ฌ์ฑํ๋ค.
ํนํ Inference ์ ์ถ๊ฐ ํ์ต ์์ด๋ ๋์ ์ฑ๋ฅ์ ๋ณด์๋ค๋ ์ ์์, ํฅํ Multimodal AI์ ๊ฐ๋ฐ ๋ฐ ์์ฉ ํจ๋ฌ๋ค์์ ๋ณํ๋ฅผ ๋ถ๋ฌ์ฌ ์ ์๋ค. ์ด์ ๊น์ง๋ ๊ฐ task๋ง๋ค ๊ฐ๋ณ ๋ชจ๋ธ์ ํ๋ จํด์ผ ํ๋ค๋ฉด, ์ด์ ๋ Flamingo ๊ฐ์ ๋ฒ์ฉ ๋ชจ๋ธ์ ๋ช ๊ฐ์ง ์์๋ฅผ ์ฃผ๋ ๊ฒ๋ง์ผ๋ก ํด๊ฒฐํ๋ ๋ฐฉํฅ์ผ๋ก ๋์๊ฐ ๊ฐ๋ฅ์ฑ์ ์์ฌํ ๊ฒ์ด๋ค.
Flamingo๋ Few-Shot Multimodal Learning์ ๊ฐ๋ฅ์ฑ์ ์ด์์ผ๋ฉฐ, Multimodal AI๊ฐ ์ด๋ป๊ฒ ์งํํ ์ง์ ๋ํ ํ๋์ ๋ฐฉํฅ์ฑ์ ์ ๊ณตํ๋ค.
Limitations
LM์ ์ฝ์ ์ ๊ณ์น : hallucination and ungrounded guesses
- ์ด๋ฏธ์ง ๋ด์ฉ๊ณผ ๋ฌด๊ดํ์ง๋ง ํ ์คํธ ๋งฅ๋ฝ์ ๊ทธ๋ด๋ฏํด ๋ณด์ด๋ ๋ต๋ณ์ ์์ฑํ๋ ํ์
- ์ธ์ด ๋ชจ๋ธ์ ์ฌ์ ์ง์์ ๊ณผ๋ํ๊ฒ ์์กดํ๊ธฐ ๋๋ฌธ
- train ๋ ๋ณธ ์ ์๋ ๊ธธ์ด์ ์
๋ ฅ ์ํ์ค์ ๋ํด์๋ ์ผ๋ฐํ์๋จ
- ๋๋ฌด ๋ง์ ์์๋ฅผ ํ ๋ฒ์ ๋ฃ์ผ๋ฉด ์ฑ๋ฅ์ด ํ๋ฝ
๋์ ๊ณ์ฐ ๋น์ฉ
- Flamingo-80B์ NFNet-F6๊ณผ Resampler๊น์ง ๊ฒฐํฉ๋์ด ์ฐ์ฐ๋๊ณผ ๋ฉ๋ชจ๋ฆฌ ์๊ตฌ๋์ด ๋ง๋
- TPUv4 ์นฉ 512๊ฐ
- ์ด 1.2 million update steps
- ์ถ๋ก ์ ์์ญ ์ท์ ์์๋ฅผ ์ ๋ ฅํ๋ฉด token ์๊ฐ ๊ธธ์ด์ ธ ๊ณ์ฐ cost๊ฐ ์ ํ์ ์ผ๋ก ์ฆ๊ฐ
- few-shot learning์ ๋ณ๋ fine-tuning์ด ํ์ ์๋ค๋ ์ฅ์ ์ด ์๋ ๋์ Inference cost ์ฆ๊ฐ .
- Flamingo-80B์ NFNet-F6๊ณผ Resampler๊น์ง ๊ฒฐํฉ๋์ด ์ฐ์ฐ๋๊ณผ ๋ฉ๋ชจ๋ฆฌ ์๊ตฌ๋์ด ๋ง๋
Prompt sensitive
- LLM์ฒ๋ผ ํ๋กฌํํธ์ ์ ๊ณตํ๋ ์์์ ๊ตฌ์ฑ๊ณผ ํํ์ ๋ฏผ๊ฐ
- ์๋ฅผ ๋ค์ด, ๋์ผํ 4๊ฐ์ ์์๋ผ๋ ๋ฐฐ์ด ์์๋ ์ค๋ช ์ดํฌ์ ๋ฐ๋ผ ๋ชจ๋ธ ์ถ๋ ฅ์ด ๋ฌ๋ผ์ง ์ ์๋ค.
- few-shot์์ shot ์๋ฅผ ๋๋ฆด์๋ก ์ด๋ ์ ๋๊น์ง๋ ์ฑ๋ฅ์ด ํฅ์๋์ง๋ง, ์ผ์ ์์ค(์์ญ ๊ฐ ์ด์)์ ๋์ด์๋ฉด ์คํ๋ ค ๋ชจ๋ธ์ด ํผ๋์ ์ผ์ผํค๊ฑฐ๋ ์ฑ๋ฅ ๊ฐ์ ์ด ์ ์ฒด๋๋ ํ์ ๋ฐ์
Classification task
Classification task์์๋ ์ต์ฒจ๋จ Contrastive model(CLIP)๋ณด๋ค ์ฑ๋ฅ ํ๋ฝ
ImageNet few-shot์์ CLIP ๋ฑ Image-text embedding ๊ธฐ๋ฐ classification ๋ชจ๋ธ๋ณด๋ค ํ๋ฝ
Flamingo๋ text generation ํ๋ฅ ์ ์ต๋ํํ๋๋ก ํ์ต๋์๊ธฐ ๋๋ฌธ
Flamingo๋ ๋ค์ํ task๋ฅผ ํญ๋๊ฒ ๋ค๋ฃจ๊ธฐ ์ํด ํนํ๋ ์ต์ ํ๋ ํฌ๊ธฐ