BACKGROUND & STATE OF THE ART
์์ฐ์ด ์ฒ๋ฆฌ(NLP) ์์ญ์์ ์ธ์ด ๋ชจ๋ธ์ ๊ณผ๊ฑฐ ์ ๋ ฅ ํ ํฐ์ ์ํ์ค๋ฅผ ์ฌ์ฉํ์ฌ ํ ํฐ(์: ๋จ์ด)์ ์์ฑํ๋ ๋ฐ ์ฌ์ฉ๋ฉ๋๋ค. ๋์ฉ๋ ์ธ์ด ๋ชจ๋ธ(Large Language Models, LLMs)์ ์ด ๊ณต๊ฐ์์์ ์ต์ ๋ฅ๋ฌ๋ ํ์ ์ผ๋ก, ์ธ๊ฐ๊ณผ ์ ์ฌํ ๋ฐฉ์์ผ๋ก ํ ์คํธ๋ฅผ ์์ฑํ๊ธฐ ์ํด ์ค๊ณ๋์์ต๋๋ค. ์ด๋ฌํ ๋ชจ๋ธ์ ์ผ๋ฐ์ ์ผ๋ก ์ ๋ ฅ ํ ํฐ์ ํฐ ์ํ์ค์ ๋ํ ์ฃผ์๋ฅผ ๊ฐ์ ํ๊ธฐ ์ํด transformer๋ฅผ ์ฌ์ฉํฉ๋๋ค.
LLaMA๋ 1์กฐ ๊ฐ ์ด์์ ํ ํฐ์ผ๋ก ํ๋ จ๋ ๊ฐ๋ ฅํ ๊ธฐ๋ฐ LLM์ผ๋ก, Meta AI์์ ์คํ ์์ค๋ก ์ ๊ณต๋ฉ๋๋ค. LLaMA๋ GPT-3, Chinchilla, PaLM๊ณผ ๊ฐ์ ๋ง์ ์ต๊ณ ์ ๋ชจ๋ธ๊ณผ ๊ฒฝ์๋ ฅ์ ๊ฐ์ง๊ณ ์์ต๋๋ค.
LLaMA (13B)๋ GPT-3 (175B)๋ณด๋ค ๋ฐ์ด๋ ์ฑ๋ฅ์ ๋ณด์ฌ์ฃผ๋ฉฐ, ๊ฐ ๋ชจ๋ธ ํ๋ผ๋ฏธํฐ์์ ๋ ๋ง์ ๊ณ์ฐ์ ์ถ์ถํ ์ ์๋ ๋ฅ๋ ฅ์ ๊ฐ์ถ๊ณ ์์ต๋๋ค.
์ด ๋ธ๋ก๊ทธ ํฌ์คํธ์์๋ LLaMA๋ฅผ ์์ ๋ชจ๋ธ๋ก ์ฌ์ฉํ์ฌ PyTorch/XLA๊ฐ LLM ์ถ๋ก ์ ์ด๋ป๊ฒ ์ฌ์ฉ๋๋์ง์ ๊ธฐ๋ฅ์ ๋ณด์ฌ์ค๋๋ค. ์ฌ๊ธฐ์์ ์ค๋ช ํ๋ ๊ณ์ฐ ๊ธฐ์ ๊ณผ ์ต์ ํ ๊ธฐ์ ์ด Google Cloud TPU v4 (v4-16)์์ ๊ตฌ๋๋๋ 65B ํ๋ผ๋ฏธํฐ LLaMA ๋ชจ๋ธ์ ์ถ๋ก ๋๊ธฐ ์๊ฐ์ 6.4๋ฐฐ ๊ฐ์ ํ๋ ๋ฐฉ๋ฒ์ ๋ํด ๋ ผ์ํฉ๋๋ค.
MODEL OVERVIEW
์ฐ๋ฆฌ๋ PyTorch/XLA์ ์ฑ๋ฅ ๊ธฐ๋ฅ์ Meta์ ์ต์ LLM์ธ LLaMA์์ ๋ณด์ฌ์ค๋๋ค. ์ฐ๋ฆฌ๋ ์ผ๋ จ์ ์ผ๋ฐ์ ์ธ LLaMA ๊ตฌ์ฑ์ ๋ํ ์ฑ๋ฅ ์ต์ ํ๋ฅผ ์์ฐํฉ๋๋ค. ์๋์์ ์ธ๊ธ๋ 175B ํ๋ผ๋ฏธํฐ ๋ชจ๋ธ ๊ตฌ์ฑ์ ๊ณต๊ฐ ๋๋ฉ์ธ์ ์์ต๋๋ค. ์ฐ๋ฆฌ๋ ์๋์์ ์ธ๊ธ๋ 175B ํ๋ผ๋ฏธํฐ ๋ชจ๋ธ์ ๋ํด LLaMA ์ฝ๋ ๊ธฐ๋ฐ์ OPT 175B ๋ชจ๋ธ ๊ตฌ์ฑ์ ์ ์ฉํฉ๋๋ค. ๊ทธ ์ธ์ ๊ฒฝ์ฐ์๋ ๋ชจ๋ ๊ตฌ์ฑ์์ max_seq_len=256 ๋ฐ dtype=bfloat16์ ์ฌ์ฉํ์ฌ ๊ฐ์ค์น์ ํ์ฑํ๋ฅผ ์ค์ ํฉ๋๋ค.
Table 1: ์ด ๋ฌธ์์์ ํ์ํ ๋ชจ๋ธ ๊ตฌ์ฑ
LLaMA | ๋ชจ๋ธ ํ์ดํผ ํ๋ผ๋ฏธํฐ |
---|---|
7B | 4,096 |
33B | 6,656 |
65B | 8,192 |
175B | 12,288 |
PERFORMANCE CHALLENGES OF LLMS
LLM์ ์ปดํ์ผ๋ฌ ์ต์ ํ์ ์ด๋ ค์์ ์ค ์ ์๋ ๋ช ๊ฐ์ง ํน์ฑ์ ๊ฐ๊ณ ์์ต๋๋ค. (a) LLM์ ์ด์ ํ ํฐ์ ๊ธฐ๋ฐ์ผ๋ก ๋ค์ ํ ํฐ์ ์์ฑํ๊ธฐ ์ํด ์๊ธฐํ๊ท ๋์ฝ๋ฉ์ ์ฌ์ฉํฉ๋๋ค. ์ด๋ ํ๋กฌํํธ ํ ์์ ์ฝ์น๊ฐ ๋์ ์ธ ํํ๋ฅผ ๊ฐ๋๋ค๋ ๊ฒ์ ์๋ฏธํฉ๋๋ค. (b) LLM์ ์ ๋ ฅ ํ ์ ๋ชจ์์ ๋ณ๊ฒฝ์ผ๋ก ์ธํ ๋ค์ ์ปดํ์ผ์ ์ ๋ฐํ์ง ์๊ณ ๋ ๊ฐ๋ณ์ ์ธ ์ ๋ ฅ ํ๋กฌํํธ ๊ธธ์ด๋ก ์๋ํด์ผ ํฉ๋๋ค. ์ ๋ ฅ ํ ์๋ ์ ์ ํ๊ฒ ๋ฒํทํ๋๊ณ ํจ๋ฉ๋์ด ๋ค์ ์ปดํ์ผ์ ํผํด์ผ ํฉ๋๋ค. (c) LLM์ ์ข ์ข ๋จ์ผ TPU(๋๋ GPU) ์ฅ์น๊ฐ ์ง์ํ ์ ์๋ ๋ฉ๋ชจ๋ฆฌ๋ณด๋ค ๋ง์ ๋ฉ๋ชจ๋ฆฌ๊ฐ ํ์ํฉ๋๋ค. ๋ชจ๋ธ ์ค๋ฉ ๋ฐฉ์์ ๋ถ์ฐ ์ปดํจํ ์ํคํ ์ฒ์ ๋ชจ๋ธ์ ๋ง์ถ๊ธฐ ์ํด ํ์ํฉ๋๋ค. ์๋ฅผ ๋ค์ด, 65B ํ๋ผ๋ฏธํฐ๋ฅผ ๊ฐ์ง LLaMA ๋ชจ๋ธ์ 8๊ฐ์ A100 GPU์ ๋น๊ต ๊ฐ๋ฅํ v4-16 ํด๋ผ์ฐ๋ TPU์ ๋ง์ถ ์ ์์ต๋๋ค. (d) LLM์ ์ด์ ํ๊ฒฝ์์ ์คํํ๋ ๊ฒ์ ๋น์ฉ์ด ๋ง์ด ๋ค ์ ์์ต๋๋ค. ์ด ์์ ๋น์ฉ ๋น ์ฑ๋ฅ(Perf/TCO)์ ํฅ์์ํค๋ ํ ๊ฐ์ง ๋ฐฉ๋ฒ์ ์์ํ์ ๋๋ค. ์์ํ๋ ํ๋์จ์ด ์๊ตฌ ์ฌํญ์ ์ค์ผ ์ ์์ต๋๋ค.
INFERENCE TECH STACK IN PYTORCH/XLA
TorchDynamo, PjRt, OpenXLA ๋ฐ ์ฌ๋ฌ ๋ชจ๋ธ ๋ณ๋ ฌํ ๊ธฐ๋ฒ. TorchDynamo๋ ๋ฐํ์์์ ์ถ์ ์ค๋ฒํค๋๋ฅผ ์ ๊ฑฐํ๋ฉฐ, PjRt๋ ํจ์จ์ ์ธ ํธ์คํธ-์ฅ์น ํต์ ์ ๊ฐ๋ฅํ๊ฒ ํฉ๋๋ค. PyTorch/XLA ์ถ์ ๊ฐ๋ฅํ ์ปฌ๋ ํฐ๋ธ๋ฅผ ํตํด LLaMA์์ ๋ชจ๋ธ ๋ฐ ๋ฐ์ดํฐ ๋ณ๋ ฌํ๋ฅผ ๊ฐ๋ฅํ๊ฒ ํ๋ PyTorch/XLA ์ถ์ ๊ฐ๋ฅํ ์ปฌ๋ ํฐ๋ธ๋ฅผ ์ฌ์ฉํฉ๋๋ค. ์ฐ๋ฆฌ์ ๊ฒฐ๊ณผ๋ฅผ ์๋ํ๋ ค๋ฉด, ์ฐ๋ฆฌ์ ์ปค์คํ torch, torch-xla ํ ์ ์ฌ์ฉํ์ฌ LLaMA ์ถ๋ก ์๋ฃจ์ ์ ์ฌํํด ๋ณด์ญ์์ค. PyTorch/XLA 2.1์ ์ด ๊ฒ์๋ฌผ์์ ๋ ผ์ํ๋ ๊ธฐ๋ฅ์ ๊ธฐ๋ณธ์ ์ผ๋ก ์ง์ํ ๊ฒ์ ๋๋ค.
PARALLEL COMPUTING
FairScale Sharding
LLaMA๋ FairScale ๋ชจ๋ธ ์ค๋ฉ API (fairscale.nn.model_parallel.layers)๋ฅผ ์ฌ์ฉํฉ๋๋ค. ์ฐ๋ฆฌ๋ ๊ฐ์๊ธฐ ๊ฐ์ ํ๋ก๊ทธ๋จ ์ํ(ํ์ฑํ ๋ฑ)๋ฅผ ํต์ ํ๊ธฐ ์ํด ๋ชจ๋ ๋ฆฌ๋์ค์ ๊ฐ์ PyTorch/XLA ํต์ ์ปฌ๋ ํฐ๋ธ(Collective Communication) ์ฐ์ฐ์ ์ฌ์ฉํ์ฌ ์ด API์ ๋๋ฑํ ํํ์ ๊ตฌ์ถํ์ต๋๋ค. TorchDynamo๋ ํ์ฌ CC ์ฐ์ฐ์ ์์ ํ ์ง์ํ์ง ์์ต๋๋ค(์ฆ, ์ถ์ ๊ฐ๋ฅํ ์ปฌ๋ ํฐ๋ธ๋ฅผ ์ง์ํ์ง ์์). ์ด๋ฌํ ์ง์์ด ์์ผ๋ฉด TorchDynamo FX ๊ทธ๋ํ๋ ๋ชจ๋ ๋๋ฐ์ด์ค ํต์ ๋ง๋ค(์ฆ, ๋ชจ๋ธ ๋ ์ด์ด๋ง๋ค) ์๋ฆด ๊ฒ์ ๋๋ค. ๊ทธ๋ํ ์๋ฆผ์ XLA ์ปดํ์ผ๋ฌ๊ฐ ์ ์ฒด ๊ทธ๋ํ ์ต์ ํ ๊ธฐํ๋ฅผ ์์ด ์ฑ๋ฅ ์์ค์ ์ผ๊ธฐํฉ๋๋ค. ์ด๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด ๊ธฐ์กด์ CC API์ ๋์คํจ์ฒ ์ปฌ๋ ํฐ๋ธ๋ฅผ ํตํฉํจ์ผ๋ก์จ PyTorch/XLA ์ถ์ ๊ฐ๋ฅํ ์ปฌ๋ ํฐ๋ธ๋ฅผ ์ ๊ณตํฉ๋๋ค. ์ฐจ์ด์ ์ PyTorch/XLA์ ์ง์ฐ ์คํ ํน์ฑ์ผ๋ก ์ธํด ์ปฌ๋ ํฐ๋ธ ์ดํ์ c10d.wait() ์ฐ์ฐ์ ์ฝ์ ํ ํ์๊ฐ ์๋ค๋ ๊ฒ์ ๋๋ค. ์ถ๋ก ๊ฐ๋ฅํ ์ปฌ๋ ํฐ๋ธ๋ฅผ ์ง์ํจ์ผ๋ก์จ PyTorch/XLA๋ TorchDynamo์์ ๋จ์ผ FX ๊ทธ๋ํ ์์ฑ์ ๊ฐ๋ฅํ๊ฒ ํฉ๋๋ค.
AUTOREGRESSIVE DECODING ON PYTORCH/XLA
LLM์ ์ด์ ๋จ์ด๋ฅผ prompt๋ก ์ฌ์ฉํ์ฌ ๋ค์ ํ ํฐ์ ์์ธกํ๊ธฐ ์ํด autoregressive decoding์ ์ฌ์ฉํฉ๋๋ค. Autoregressive decoding์ ์ ํ ์๋ ๋์ ํํ ๋ฌธ์ ๋ฅผ ์ผ๊ธฐํ๋ฉฐ, ์ด๋ ๋ชจ๋ prompt์ ์ฌ์ปดํ์ผ์ ์ด๋ํฉ๋๋ค. ์ ํฌ๋ LLaMA autoregressive decoder๋ฅผ ์ต์ ํํ์ฌ ๊ณ ์ ๋ ํํ๋ก ์๋ํ๋๋ก ํ์ฌ KV-cache, ์ถ๋ ฅ ์ํ์ค ๋ฐ ์ดํ ์ ๋ง์คํฌ๋ฅผ ํจ์จ์ ์ผ๋ก ์ ๋ฐ์ดํธํ๊ณ ์ ํ์ต๋๋ค. ํจ๋ฉ, ๋ง์คํน ๋ฐ ์ธ๋ฑ์ค ์์ ์ ์กฐํฉ์ ํตํด ๋๋ฌด ๋ง์ ๊ทธ๋ํ ์ฌ์ปดํ์ผ์ ํผํ๊ณ ํจ์จ์ ์ธ autoregressive decoding์ ๋ฌ์ฑํ์ต๋๋ค.
KV-Cache ์ต์ ํ
LLaMA๋ KV-cache๋ฅผ ์ฌ์ฉํ์ฌ autoregressive decoding์ ๊ตฌํํฉ๋๋ค. ์์ฑ๋ ๊ฐ ํ ํฐ๋ง๋ค KV-cache๋ ๊ฐ Transformer ๋ ์ด์ด์ ์ดํ ์ ํค/๊ฐ ํ์ฑํ๋ฅผ ์ ์ฅํฉ๋๋ค. ๋ฐ๋ผ์ ์๋ก์ด ํ ํฐ์ ๋์ฝ๋ฉํ ๋ ์ด์ ํ ํฐ์ ํค/๊ฐ์ ๋ค์ ๊ณ์ฐํ ํ์๊ฐ ์์ต๋๋ค.
LLaMA์์๋ KV-cache ํ ์ ์ฌ๋ผ์ด์ค๋ฅผ ์๋ ์๋ฆฌ์ ์ ๋ฐ์ดํธํ๋ฏ๋ก ํ ํฐ์ด ์์ฑ๋ ๋๋ง๋ค ์ฌ์ปดํ์ผ ์ด๋ฒคํธ๊ฐ ๋ฐ์ํฉ๋๋ค. ์ด ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด ์ธ๋ฑ์ค ํ ์์ tensor.index_copy() ์ฐ์ฐ์ ์ฌ์ฉํ์ฌ ์๋ ์๋ฆฌ์ ์ฌ๋ผ์ด์ค ์ ๋ฐ์ดํธ๋ฅผ ๋์ฒดํฉ๋๋ค. ์ดํ ์ ๋ง์คํฌ์ ์ถ๋ ฅ ์ํ์ค๋ ๋์ผํ ์ต์ ํ๋ฅผ ๋ฐ์ต๋๋ค.
INPUT PROMPT OPTIMIZATION
LLM ์์ฉ ํ๋ก๊ทธ๋จ์์ ๊ฐ๋ณ ๊ธธ์ด์ ์ ๋ ฅ ํ๋กฌํํธ๋ ์ผ๋ฐ์ ์ ๋๋ค. ์ด ์์ฑ์ ์ ๋ ฅ ํ ์์ ํํ ๋์ ์ฑ์ ์ผ๊ธฐํ๊ณ , ์ด์ ๋ฐ๋ผ ์ฌ์ปดํ์ผ ์ด๋ฒคํธ๊ฐ ๋ฐ์ํฉ๋๋ค. KV-cache๋ฅผ ์ฑ์ฐ๊ธฐ ์ํด ํ๋กฌํํธ๋ฅผ ์ฒ๋ฆฌํ ๋ (a) ์ ๋ ฅ ํ๋กฌํํธ๋ฅผ ํ ํฐ๋ณ๋ก ์ฒ๋ฆฌํ๊ฑฐ๋ (b) ์ ์ฒด ํ๋กฌํํธ๋ฅผ ํ ๋ฒ์ ์ฒ๋ฆฌํฉ๋๋ค. ๊ฐ ๋ฐฉ๋ฒ์ ์ฅ๋จ์ ์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
1๊ฐ์ ๊ทธ๋ํ๋ฅผ ๋ฏธ๋ฆฌ ์ปดํ์ผํ๊ณ ํ๋กฌํํธ๋ฅผ ํ ํฐ๋ณ๋ก ์ฒ๋ฆฌํจ
์ฅ์ : ์๋ฐ์
๋์ 1๊ฐ์ ๊ทธ๋ํ๋ง ์ปดํ์ผ๋จ
๋จ์ : ์
๋ ฅ ํ๋กฌํํธ ๊ธธ์ด L์ ์ฒ๋ฆฌํ๋ ๋ฐ O(L)์ด ๊ฑธ๋ฆผ - ๊ธด ํ๋กฌํํธ์ ๋จ์
1์์ max_seq_len (์: 2,048)๊น์ง์ ์ ๋ ฅ ๊ธธ์ด๋ก ๋ชจ๋ ๊ทธ๋ํ๋ฅผ ๋ฏธ๋ฆฌ ์ปดํ์ผํจ
์ฅ์ : ์๋ฐ์
์๊ฐ ๋์ max_seq_len ๊ทธ๋ํ๋ฅผ ๋ฏธ๋ฆฌ ์ปดํ์ผํ๊ณ ์บ์ํจ
๋จ์ : ์
๋ ฅ์ ์ฒ๋ฆฌํ๋ ๋ฐ 1๊ฐ์ ๊ทธ๋ํ ์คํ์ด ํ์ํจ
์ ํฌ๋ ๋ ๊ฐ์ง ๋์ ์ฌ์ด์์ ๊ท ํ์ ๋ง์ถ๊ธฐ ์ํด ํ๋กฌํํธ ๊ธธ์ด ๋ฒํทํ๋ฅผ ๋์ ํ์ต๋๋ค. ์ฐ๋ฆฌ๋ ์ค๋ฆ์ฐจ์์ผ๋ก ์ ๋ ฌ๋ ๋ฒํท ํฌ๊ธฐ ์งํฉ (b0, b1, b2, ..., bB-1)๋ฅผ ์ ์ํ๊ณ , ์ด๋ฌํ ๋ฒํท ๊ฐ (G0, G1, G2, ..., GB-1)์ ๋ฐ๋ผ ์ ๋ ฅ ํฌ๊ธฐ์ ๋ํ ํ๋ก๊ทธ๋จ ๊ทธ๋ํ๋ฅผ ๋ฏธ๋ฆฌ ์ปดํ์ผํฉ๋๋ค. B๋ ๋ฒํท์ ์์ ๋๋ค. ์ฃผ์ด์ง ์ ๋ ฅ ํ๋กฌํํธ์ ๋ํด ํ๋กฌํํธ ๊ธธ์ด๋ฅผ ๊ฐ์ฅ ๊ฐ๊น์ด ๋ฒํท ๊ฐ bn์ผ๋ก ๋ฐ์ฌ๋ฆผํ๊ณ ์ํ์ค๋ฅผ ํจ๋ฉํ ๋ค์ Gn์ ์ฌ์ฉํ์ฌ ํ๋กฌํํธ๋ฅผ ํ ๋ฒ์ ์ฒ๋ฆฌํฉ๋๋ค. ํจ๋ฉ ํ ํฐ์ ๊ณ์ฐ์ ์ญ์ ๋ฉ๋๋ค. ๊ฐ์ฅ ํฐ ๋ฒํท ํฌ๊ธฐ๋ณด๋ค ํฐ ํ๋กฌํํธ์ ๊ฒฝ์ฐ, ์น์ ๋ณ๋ก ์ฒ๋ฆฌํฉ๋๋ค.
์ต์ ์ ๋ฒํท ํฌ๊ธฐ๋ ๋์ ์์ฉ ํ๋ก๊ทธ๋จ์์ ํ๋กฌํํธ ๊ธธ์ด ๋ถํฌ์ ์ํด ๊ฒฐ์ ๋์ด์ผ ํฉ๋๋ค. ์ฌ๊ธฐ์์๋ ๋ฒํท ๊ธธ์ด๋ฅผ 128, 256, 384, 512๋ก ์ฑํํ์ต๋๋ค. ์ต๋ 2,047 ํ ํฐ์ ํฌํจํ๋ ๋ชจ๋ ์ ๋ ฅ ํ๋กฌํํธ๋ ์ต๋ 4๊ฐ์ ๊ทธ๋ํ ์คํ์ด ํ์ํฉ๋๋ค. ์๋ฅผ ๋ค์ด, 1,500์ ์ ๋ ฅ ํ๋กฌํํธ์ 256์ ์์ฑ ๊ธธ์ด๋ฅผ ๊ฐ์ง ๊ฒฝ์ฐ, 260๊ฐ์ ๊ทธ๋ํ ์คํ์ด ํ์ํฉ๋๋ค - ์ ๋ ฅ์ ์ฒ๋ฆฌํ๊ธฐ ์ํ 4๊ฐ์ ์ถ๋ ฅ์ ์์ฑํ๊ธฐ ์ํ 256๊ฐ์ ๊ทธ๋ํ ์คํ์ ๋๋ค.
QUANTIZATION
์์ํ๋ ๊ฐ์ ํํํ๋ ๋ฐ ํ์ํ ๋นํธ ์๋ฅผ ์ค์ด๋ ๊ฒ์ผ๋ก, ์ฌ๋ฌ ๊ฐ์๊ธฐ ๋ ธ๋ ๊ฐ ๋ฐ์ดํฐ ํต์ ๋์ญํญ์ ์ค์ด๊ณ ํน์ ๋ชจ๋ธ ํฌ๊ธฐ๋ฅผ ์ฒ๋ฆฌํ๋ ๋ฐ ํ์ํ ํ๋์จ์ด ์๊ตฌ ์ฌํญ์ ๋ฎ์ถฅ๋๋ค.
์ผ๋ฐ์ ์ผ๋ก BF16 ๊ฐ์ค์น๋ก๋ 175B ๋งค๊ฐ ๋ณ์ ๋ชจ๋ธ์ด ์ฝ 351GB์ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ์ฌ์ฉํ๋ฉฐ, ๋ฐ๋ผ์ ๋ชจ๋ธ์ ์์ฉํ๊ธฐ ์ํด v4-32 ์ธ์คํด์ค๊ฐ ํ์ํฉ๋๋ค. ๊ฐ์ค์น๋ฅผ INT8๋ก ์์ํํจ์ผ๋ก์จ ๋ชจ๋ธ ํฌ๊ธฐ๋ฅผ ๋๋ต 50% ์ค์ฌ v4-16 ์ธ์คํด์ค์์ ์คํํ ์ ์๊ฒ ๋์์ต๋๋ค. LLaMA๋ ๋ชจ๋ธ ํ์ฑํ๋ฅผ ์ค๋ฉํ๋ฏ๋ก, ์์ํ๋ ๋ฌด์ํ ์ ์๋ ํต์ ์ด๋์ ์ ๊ณตํฉ๋๋ค.
์ ํฌ ์คํ์์๋ ์ ํ ๋ ์ด์ด๋ฅผ ์์ํํ์ต๋๋ค. LLaMA ๋ชจ๋ธ ์ฒดํฌํฌ์ธํธ๋ ๊ณต๊ฐ์ ์ผ๋ก ์ฌ์ฉํ ์ ์์ผ๋ฉฐ, ์ฐ๋ฆฌ์ ๋ชฉํ๋ ์ฑ๋ฅ์ ํ๊ฐํ๋ ๊ฒ์ด๋ฏ๋ก ์์ํ๋ ๋ชจ๋ธ์ ์์์ ๊ฐ์ค์น๋ก ์ด๊ธฐํ๋์์ต๋๋ค. AWQ ๋ฐ Integer or Floating Point?์ ๊ฐ์ ์ต๊ทผ ๋ฌธํ์ ๋ค์ํ ์ ๋นํธ ์์ํ ์ฒด๊ณ์์ LLaMA์ ์ฑ๋ฅ ํน์ฑ์ ๋ํ ํต์ฐฐ๋ ฅ์ ์ ๊ณตํฉ๋๋ค.
Effect of Batch Size on Quantization Performance
๋ชจ๋ธ ๋ฐฐ์น ํฌ๊ธฐ (BS) > 1์ผ ๋ TPU v4๋ ํ๋ ฌ ๊ณฑ์ ์ ๋ (MXU)์์ matmul์ ์คํํฉ๋๋ค. BS = 1์ธ ๊ฒฝ์ฐ matmul์ ๋ฒกํฐ ํ๋ก์ธ์ ์ ๋ (VPU)์์ ์คํ๋ฉ๋๋ค. MXU๊ฐ VPU๋ณด๋ค ํจ์จ์ ์ด๋ฏ๋ก INT8 ์์ํ๋ BS > 1์ผ ๋ ์ฑ๋ฅ์ ํฅ์์ํต๋๋ค. ์์ธํ ๋ด์ฉ์ ์ฑ๋ฅ ๋ถ์ ์น์ ์ ์ฐธ์กฐํ์ญ์์ค.
OP SUPPORT
๊ฐ๋์ฉ ์๋ก์ด ๋ชจ๋ธ์ ์ปดํ์ผ์ ์ํด PyTorch/XLA๊ฐ ์ง์ํ๋ ์ฐ์ฐ ์งํฉ์ ํ์ฅํด์ผ ํ๋ ์๋ก์ด ์ํ ์ฐ์ฐ์ ์๊ฐํฉ๋๋ค. LLaMA์ ๊ฒฝ์ฐ, multinomial์ ์ง์ํ์ต๋๋ค.
METHODOLOGY
LLaMA๋ LazyTensorCore์์ PyTorch/XLA๋ฅผ ์ฌ์ฉํ์ฌ ๊ธฐ๋ณธ ๊ตฌ์ฑ์ผ๋ก ์๋ํฉ๋๋ค. ๋ชจ๋ ์คํ์ 256 ๊ธธ์ด์ ์ ๋ ฅ ํ๋กฌํํธ๋ฅผ ๊ฐ์ ํฉ๋๋ค. ๊ณต๊ฐ์ ์ผ๋ก ์ฌ์ฉ ๊ฐ๋ฅํ ๋ชจ๋ธ ์ฒดํฌํฌ์ธํธ๊ฐ ์์ผ๋ฏ๋ก, ์ฐ๋ฆฌ๋ ์ด ์ถ๋ก ์คํ ์ต์ ํ ์์ ์ ์ํด ์์์ ํ ์ ์ด๊ธฐํ๋ฅผ ์ฌ์ฉํ์ต๋๋ค. ์ฌ๊ธฐ์๋ ๋ชจ๋ธ ์ฒดํฌํฌ์ธํธ๊ฐ ์ง์ฐ ์๊ฐ ๊ฒฐ๊ณผ์ ์ํฅ์ ๋ฏธ์น์ง ์์ ๊ฒ์ผ๋ก ์์๋ฉ๋๋ค.
Model Sizing
N์ด ํ๋ผ๋ฏธํฐ ์, dimensions์ด hidden size, n_layers๊ฐ ๋ ์ด์ด ์, n_heads๊ฐ ์ดํ ์ ํค๋ ์๋ผ๊ณ ๊ฐ์ ํ ๋, ์๋์ ์์ ์ฌ์ฉํ์ฌ ๋ชจ๋ธ์ ํฌ๊ธฐ๋ฅผ ๊ทผ์ฌํ ์ ์์ต๋๋ค. ์์ธํ ๋ด์ฉ์ ๋ชจ๋ธ ๊ฐ์ ์น์ ์ ์ฐธ์กฐํ์ญ์์ค.
n_heads๋ N์ ์ํฅ์ ์ฃผ์ง ์์ง๋ง, ์คํ ์์ค ๋ชจ๋ธ ๊ตฌ์ฑ์ ๋ํด์๋ ๋ค์ ์์ด ์ฑ๋ฆฝํฉ๋๋ค.
Cache Sizing
๋ชจ๋ธ ํ๋ผ๋ฏธํฐ์ ์ดํ ์ ๋ธ๋ก์ ์บ์ ๋ ์ด์ด๋ ๋ฉ๋ชจ๋ฆฌ ์๋น์ ์ํฅ์ ์ค๋๋ค. ์ด ์น์ ์ ๋ฉ๋ชจ๋ฆฌ ์๋น ๊ณ์ฐ์ BF16 ๊ฐ์ค์น๋ฅผ ์ฌ์ฉํ๋ ๊ธฐ๋ณธ LLaMA ๋ชจ๋ธ์ ๊ธฐ๋ฐ์ผ๋ก ํฉ๋๋ค.
์บ์ ๋ ์ด์ด์ ํฌ๊ธฐ๋ cache_size = max_batch_size * max_seq_len * dimensions๋ก ๊ณ์ฐ๋ฉ๋๋ค. ๋ค์ ๊ณ์ฐ์์ ์์ ๊ตฌ์ฑ์ผ๋ก max_batch_size = 1 ๋ฐ max_seq_len = 256์ ์ฌ์ฉํฉ๋๋ค. ๊ฐ ์ดํ ์ ๋ธ๋ก์๋ 2๊ฐ์ ์บ์ ๋ ์ด์ด๊ฐ ์์ต๋๋ค. ๋ฐ๋ผ์ ์ด LLaMA ์บ์ ํฌ๊ธฐ(๋ฐ์ดํธ)๋ total_cache_size = n_layers * 2 * cache_size * (2 ๋ฐ์ดํธ)์ ๋๋ค.
TPU v4 Hardware Sizing
๊ฐ TPU v4 ์นฉ์ 32GB์ ์ฌ์ฉ ๊ฐ๋ฅํ ๊ณ ๋์ญํญ ๋ฉ๋ชจ๋ฆฌ(HBM)๋ฅผ ๊ฐ์ง๊ณ ์์ต๋๋ค. ํ 2์๋ ๋ฉ๋ชจ๋ฆฌ ์๋น ๋ฐ LLaMA ๋ชจ๋ธ์ ๋ณด์ ํ๊ธฐ ์ํด ํ์ํ TPU ์นฉ ์์ ๋ํ ์ธ๋ถ ์ ๋ณด๊ฐ ์์ต๋๋ค.
ํ 2: LLaMA TPU v4 HBM ์๊ตฌ ์ฌํญ (์ฆ, TPU v4 ์นฉ ์๊ตฌ ์ฌํญ)
ํ๋ผ๋ฏธํฐ
ํ๋ผ๋ฏธํฐ | (MB) | ์บ์ (MB) | ์ ์ฒด (GB) | ์ต์ TPU v4 ์นฉ ์ |
---|---|---|---|---|
7B | 14,000 | 134 | 14.128 | 1 |
33B | 66,000 | 408 | 66.41 | 3 |
65B | 130,000 | 671 | 130.67 | 5 |
175B | 350,000 | 1,208 | 351.21 | 11 |
Metrics
์๋๋ ์ถ๋ก ์๋๋ฅผ ์ธก์ ํ๋ ์ ์ฉํ ๋ฉํธ๋ฆญ์ค์ ๋๋ค. T๋ ์ด ์๊ฐ, B๋ ๋ฐฐ์น ํฌ๊ธฐ, L์ ๋์ฝ๋ฉ๋ ์ํ์ค ๊ธธ์ด๋ฅผ ๊ฐ์ ํฉ๋๋ค.
Latency Definition
์ง์ฐ ์๊ฐ์ ํ๊น ๊ธธ์ด L์์ ๋์ฝ๋ฉ๋ ๊ฒฐ๊ณผ๋ฅผ ์ป๊ธฐ๊น์ง ๊ฑธ๋ฆฌ๋ ์๊ฐ์ผ๋ก, ๋ฐฐ์น ํฌ๊ธฐ B์๋ ์๊ด์์ด ์ฌ์ฉ์๊ฐ ์์ฑ ๋ชจ๋ธ์ ์๋ต์ ๋ฐ๊ธฐ๊น์ง ๊ธฐ๋ค๋ ค์ผ ํ๋ ์๊ฐ์ ๋ํ๋ ๋๋ค.
Per-token latency
์๊ธฐ ํ๊ท ๋์ฝ๋ฉ์ ํ ๋จ๊ณ๋ ๋ฐฐ์น์ ๊ฐ ์ํ์ ๋ํด ํ ํฐ์ ์์ฑํฉ๋๋ค. ํ ํฐ๋น ์ง์ฐ ์๊ฐ์ ํด๋น ํ ๋จ๊ณ์ ๋ํ ํ๊ท ์๊ฐ์ ๋๋ค.
Throughput
์ฒ๋ฆฌ๋์ ๋จ์ ์๊ฐ๋น ์์ฑ๋ ํ ํฐ ์๋ฅผ ์ธก์ ํฉ๋๋ค. ์จ๋ผ์ธ ์๋น์ ํ๊ฐํ๋ ๋ฐ ์ ์ฉํ ๋ฉํธ๋ฆญ์ค๋ ์๋์ง๋ง ๋ฐฐ์น ์ฒ๋ฆฌ ์๋๋ฅผ ์ธก์ ํ๋ ๋ฐ ์ ์ฉํฉ๋๋ค.
์ง์ฐ ์๊ฐ๊ณผ ์ฒ๋ฆฌ๋์ ํผํฉํ๋ T / (B * L)๊ณผ ๊ฐ์ ๋ฉํธ๋ฆญ์ค๋ ํผ๋๊ณผ ์คํด๋ฅผ ์ค์ด๊ธฐ ์ํด ํผํ๋ ๊ฒ์ด ์ข์ต๋๋ค.
RESULTS
Figure 1์ LLaMA 7B์์ 175B ๋ชจ๋ธ์ ๋ํ ์ง์ฐ ์๊ฐ/ํ ํฐ ๊ฒฐ๊ณผ๋ฅผ ๋ณด์ฌ์ค๋๋ค. ๊ฐ ๊ฒฝ์ฐ ๋ชจ๋ธ์ ๋ค์ํ TPU v4 ๊ตฌ์ฑ์์ ์คํ๋ฉ๋๋ค. ์๋ฅผ ๋ค์ด, LLaMA 7B๋ v4-8์์๋ 4.7ms/ํ ํฐ, v4-16์์๋ 3.8ms/ํ ํฐ์ ๋ณด์ฌ์ค๋๋ค. ๋ ๋ง์ ๋น๊ต๋ฅผ ์ํด์๋ HuggingFace LLM ์ฑ๋ฅ ๋ฆฌ๋๋ณด๋๋ฅผ ๋ฐฉ๋ฌธํ์ญ์์ค.
์ด ๋ธ๋ก๊ทธ ํฌ์คํธ์์ ๋ ผ์๋ ๊ธฐ๋ฅ์ด ์๋ ๊ฒฝ์ฐ, LLaMA 65B๋ ์ฌ๊ธฐ์ ์ป์ 14.5ms/ํ ํฐ ๋์ 120ms/ํ ํฐ์ผ๋ก v4-32์์ ์คํ๋์ด 8.3๋ฐฐ์ ๊ฐ์์ ์ ๊ณตํฉ๋๋ค. ์ด์ ์ ์ธ๊ธํ ๋๋ก, ๊ฐ๋ฐ์๋ LLaMA ์ถ๋ก ๊ฒฐ๊ณผ๋ฅผ ์ ๊ธ ํด์ ํ๋ ์ฌ์ฉ์ ์ ์ torch, torch-xla ํ ์ ์๋ํ ๊ฒ์ ๊ถ์ฅํฉ๋๋ค.
Figure 1: TPU v4 ํ๋์จ์ด์์์ LLaMA ์ถ๋ก ์ฑ๋ฅ
PyTorch/XLA:GPU ์ฑ๋ฅ์ PyTorch:GPU eager๋ณด๋ค ์ฐ์ํ๋ฉฐ PyTorch Inductor์ ์ ์ฌํฉ๋๋ค. PyTorch/XLA:TPU ์ฑ๋ฅ์ PyTorch/XLA:GPU๋ณด๋ค ์ฐ์ํฉ๋๋ค. ๊ฐ๊น์ด ๋ฏธ๋์ XLA:GPU๋ XLA:TPU์ ๋๋ฑํ ์ต์ ํ๋ฅผ ์ ๊ณตํ ๊ฒ์ ๋๋ค. ๋จ์ผ A100 ๊ตฌ์ฑ์ LLaMA 7B๋ง ๋ง์ต๋๋ค. 8-A100์ LLaMA 175B์ ๋ง์ง ์์ต๋๋ค.
Figure 2: GPU A100 ํ๋์จ์ด์์์ LLaMA ์ถ๋ก ์ฑ๋ฅ
๋ฐฐ์น ํฌ๊ธฐ๊ฐ ์ฆ๊ฐํจ์ ๋ฐ๋ผ ํ ํฐ๋น ์ง์ฐ ์๊ฐ์ด ๋น์ ํ์ ์ผ๋ก ์ฆ๊ฐํ๋ ๊ฒ์ ๊ด์ฐฐํฉ๋๋ค. ์ด๋ ํ๋์จ์ด ์ด์ฉ๋ฅ ๊ณผ ์ง์ฐ ์๊ฐ ์ฌ์ด์ ํธ๋ ์ด๋์คํ๋ฅผ ๊ฐ์กฐํฉ๋๋ค.
Figure 3: ๋ค๋ฅธ ๋ฐฐ์น ํฌ๊ธฐ์ ๋ฐ๋ฅธ LLaMA ์ถ๋ก ์ฑ๋ฅ
์ ํฌ์ ์ฐ๊ตฌ ๊ฒฐ๊ณผ, ์ต๋ ์ํ์ค ์ ๋ ฅ ๊ธธ์ด(max_seq_len)๊ฐ ์ถ๋ก ์ง์ฐ ์๊ฐ์ ๋ฏธ์น๋ ์ํฅ์ ๋น๊ต์ ๋ฏธ๋ฏธํ๋ค๊ณ ๋ํ๋ฌ์ต๋๋ค. ์ด๋ ํ ํฐ ์์ฑ์ ์์ฐจ์ ์ด๊ณ ๋ฐ๋ณต์ ์ธ ํน์ฑ์ ๊ธฐ์ธํ ์ ์์ต๋๋ค. ์ฑ๋ฅ์ ์์ ์ฐจ์ด๋ ์ ์ฅ์ ํฌ๊ธฐ๊ฐ ์ฆ๊ฐํจ์ ๋ฐ๋ผ KV ์บ์ ์ก์ธ์ค ์ง์ฐ ์๊ฐ์ด ๋ณ๊ฒฝ๋๊ธฐ ๋๋ฌธ์ผ ์ ์์ต๋๋ค.
Figure 4: ๋ค๋ฅธ ํ๋กฌํํธ ๊ธธ์ด์ ๋ฐ๋ฅธ LLaMA ์ถ๋ก ์ฑ๋ฅ
LLM์ ์ข ์ข ๋ฉ๋ชจ๋ฆฌ ๋ฐ์ด๋ ์ ํ๋ฆฌ์ผ์ด์ ์ ๋๋ค. ๋ฐ๋ผ์ ๋ชจ๋ธ ํ๋ผ๋ฏธํฐ๋ฅผ ์์ํํจ์ผ๋ก์จ ์๊ฐ ๋จ์๋น MXU ๋น ๋ ํฐ ํ ์๋ฅผ ๋ก๋ํ๊ณ ์คํํ ์ ์์ต๋๋ค(HBM ⇒ CMEM ๋ฐ CMEM ⇒ MXU ๋ฐ์ดํฐ ์ด๋). Figure 5๋ INT8 ๊ฐ์ค์น๋ง ์์ํํ๋ ๊ฒ์ด ์ฃผ์ด์ง ํ๋์จ์ด์์ ๋ ํฐ ๋ชจ๋ธ์ ์คํํ ์ ์๋๋ก 1.6๋ฐฐ์์ 1.9๋ฐฐ์ ๊ฐ์์ ์ ๊ณตํจ์ ๋ณด์ฌ์ค๋๋ค.
BS=1์ธ ๊ฒฝ์ฐ, INT8 ํ ์๋ MXU๋ณด๋ค ์์ VPU๋ก ์ ์ก๋ฉ๋๋ค (TPU v4 ๋ ผ๋ฌธ ์ฐธ์กฐ). ๊ทธ๋ ์ง ์์ผ๋ฉด MXU๊ฐ ์ฌ์ฉ๋ฉ๋๋ค. ๊ฒฐ๊ณผ์ ์ผ๋ก, BS=1์ธ ๊ฒฝ์ฐ ์์ํ ๋ฉ๋ชจ๋ฆฌ ๋์ญํญ ํ๋์ MXU ์ด์ฉ ๋ถ๊ฐ๋ก ์ธํด ์์๋ฉ๋๋ค. ๊ทธ๋ฌ๋ BS>1์ธ ๊ฒฝ์ฐ, ๋ฉ๋ชจ๋ฆฌ ํ๋์ด ์์ํ๋ ๋ชจ๋ธ์์ ์ฐ์ํ ์ง์ฐ ์๊ฐ์ ์ ๊ณตํฉ๋๋ค. ์๋ฅผ ๋ค์ด, 175B ํ๋ผ๋ฏธํฐ LLaMA์ ๊ฒฝ์ฐ, ์์ํ๋ v4-16๊ณผ ์์ํ๋์ง ์์ v4-32๋ ์ ์ฌํ ์ฑ๋ฅ์ ์ ๊ณตํฉ๋๋ค. FP8 ๋น๊ต๋ PyTorch๊ฐ ์ด๋ฌํ ๋ฐ์ดํฐ ํ์์ ์์ง ์ ๊ณตํ์ง ์๊ธฐ ๋๋ฌธ์ ์ ๊ณตํ์ง ์์ต๋๋ค.
Figure 5: ๊ฐ์ค์น๋ง ์์ํ๋ LLaMA ์ถ๋ก ์ฑ๋ฅ ๋๋น. ๋๋ฝ๋ ํ๋์ ๋ง๋๋ ์ง์ ๋ TPU ํ๋์จ์ด์ ๋ชจ๋ธ ํฌ๊ธฐ๊ฐ ๋ง์ง ์์์ ์๋ฏธํฉ๋๋ค.# PyTorch/XLA์ LLaMA ์ถ๋ก ์ฑ๋ฅ
์ด ๋ฌธ์๋ PyTorch/XLA์ LLaMA(Large Language Model Analysis) ์ถ๋ก ์ฑ๋ฅ์ ๊ดํ ๋ด์ฉ์ ๋ค๋ฃน๋๋ค. ์๋ ๊ทธ๋ฆผ๋ค์ ๊ด๋ จ ์คํ ๊ฒฐ๊ณผ๋ฅผ ์๊ฐ์ ์ผ๋ก ๋ณด์ฌ์ค๋๋ค.
**Figure 6**: ์ ๋ ฅ ํ๋กฌํํธ ๊ธธ์ด์ ๋ฐ๋ฅธ LLaMA ์ถ๋ก ์ฑ๋ฅ
์ด ๊ทธ๋ฆผ์ ์ ๋ ฅ ํ๋กฌํํธ ๊ธธ์ด๊ฐ 10 ํ ํฐ์์ 1,500 ํ ํฐ์ผ๋ก ์ฆ๊ฐํจ์ ๋ฐ๋ผ PyTorch/XLA์ ์ผ๊ด๋ ์ฑ๋ฅ ์ฐ์๋ฅผ ๋ณด์ฌ์ค๋๋ค. ์ด ๊ฐ๋ ฅํ ์ค์ผ์ผ๋ง ๋ฅ๋ ฅ์ ์ค์ ์์ฉ ํ๋ก๊ทธ๋จ์ ๋ค์ํ ๋ฒ์์์ ์ต์ํ์ PyTorch/XLA ์ฌ์ปดํ์ผ ์ด๋ฒคํธ๋ฅผ ๊ฐ๋ฅํ๊ฒ ํฉ๋๋ค. ์ด ์คํ์์ ์ต๋ ๊ธธ์ด๋ 2,048์ด๋ฉฐ ์ต๋ ์์ฑ ๊ธธ์ด๋ 256์ ๋๋ค.
FINAL THOUGHTS
์ฐ๋ฆฌ๋ PyTorch/XLA๊ฐ ์์ผ๋ก ์ด๋ค ๋ฐ์ ์ ์ด๋ฃฐ์ง์ ๋ํด ๋งค์ฐ ๊ธฐ๋ํ๊ณ ์์ผ๋ฉฐ, ์ปค๋ฎค๋ํฐ์ ํจ๊ป ์ฐธ์ฌํ ๊ฒ์ ์ด๋ํฉ๋๋ค. PyTorch/XLA๋ ์์ ํ ์คํ ์์ค๋ก ๊ฐ๋ฐ๋์์ต๋๋ค. ๋ฐ๋ผ์ ์ด์๋ฅผ ๋ฑ๋กํ๊ณ ํ ๋ฆฌํ์คํธ๋ฅผ ์ ์ถํ๋ฉฐ, GitHub๋ก RFCs๋ฅผ ๋ณด๋ด์ ๊ณต๊ฐ์ ์ผ๋ก ํ์ ํ ์ ์์ต๋๋ค. ๋ํ, TPU์ GPU๋ฅผ ํฌํจํ ๋ค์ํ XLA ์ฅ์น์์ PyTorch/XLA๋ฅผ ์ง์ ์๋ํด ๋ณผ ์๋ ์์ต๋๋ค.
๊ฐ์ฌํฉ๋๋ค,
Google์ PyTorch/XLA ํ