- YouTube
youtu.be
1. 이번 주 목표
실제로 model을 serving 한다고 했을 때, 가성비가 좋으려면 경량화가 핵심이므로, 경량화 기법들을 학습한다.
2. 학습 내용
현존하는 LLM은 모두 Transformer 기반이므로 크기가 성능에 가장 큰 영향을 미친다. 그럼 LLM을 서비스 형태로 제공할 때 크기가 큰 만큼 비용도 함께 커질 수 밖에 없는데, 그럴 때 쓰는 경량화 기법들을 소개한다.
2.1 Data Parallelism
GPU가 여러 대 있고 각각의 GPU의 모델은 올라가지만 원하는 크기의 batch size로 학습할 수 없을 때 사용하는 방법이다.
모델을 각 GPU 마다 복사: 먼저 주어진 모델을 복사하여 모든 GPU에 올린다.
Batch를 GPU 수 만큼 쪼개기: 주어진 batch의 data를 GPU 수로 쪼개어 각 GPU에 분배한다. 예를 들어 원하는 batch size는 128이고 GPU 수는 8개일 때 각 GPU마다 16개의 data가 분배되도록 쪼갠다.
각 GPU에서 gradient를 계산: 각 GPU에 배정된 모델과 data를 가지고 gradient를 계산한다.
Gradient 값들을 수합하여 학습 진행: 각 GPU에서 계산된 gradient들을 하나의 GPU에 더한다. 그리고 더한 결과로 gradient descent를 진행 한 후, 각 GPU의 모델을 똑같이 업데이트 한다.
각 parameter의 loss에 대한 gradient는 각 data point에 대한 gradient의 합으로 표현되기 때문에 가능한 방법론이다. 하지만 다음과 같은 단점들을 수반한다.
중복되는 메모리 존재: 같은 모델을 모든 GPU에 복사하기 때문에 메모리 낭비가 있다.
모델이 GPU에 올라가지 않으면 사용할 수 없음: data parallelism은 batch size를 늘려줄 수 있는 방법이지, 더 큰 모델을 쓸 수 있게 해줄수는 없다.
Optimizer state를 고려하지 않음: Adam optimizer의 경우, 각 parameter마다 새로운 변수들을 정의한다. 즉, optimizer가 차지하는 메모리는 모델보다 큰 경우가 많다. 하지만 data parallelism에서는 이러한 메모리 분배를 고려하지 않는다.
2.2 Quantization
2.2.1 Floating point
일반적으로 LLM과 같은 머신러닝 모델은 많은 parameter로 구성되어 있으며, 이 parameter들은 정수 또는 소수로 이루어져 있다.
컴퓨터 프로그램에서는 이런 소수 값을 표현하기 위해서 0과 1로 이루어진 bit 여러 개를 차용한다. 예를 들어 개발 단계에서 자주 쓰이는 float type 변수는 32bit으로 표현한다.
Quantization의 기본 아이디어는 이러한 floating point의 precision을 줄이는 것이다. 즉, 사용하는 bit 수를 줄이는 것이다. Bit의 수를 줄이면 실제 값과 오차가 생길 수 있지만 메모리 사용량을 줄일 수 있다.
이러한 방향으로 많이 쓰이는 precision은 다음과 같다.
Float16: 단순히 32-bit floating point를 16-bit로 구현한 것이다.
BFloat16: float16와 같이 16bit이지만 정수와 소수 부분을 표현하는 bit 수가 다르다. Float16보다 적은 개수의 bit로 소수 부분을 표현하기 때문에 소수 보다 정수 부분을 더 잘 근사한다.
4-bit normal float: Q-LoRA에서 사용하는 방식으로, 4-bit floating point이다.
이렇게 precision을 줄이게 되면 속도는 빨라지지만, 성능이 열화된다. 이미지 해상도를 줄였을 때 알아보기 힘든 것과 같다.
2.2.2 AMP: Automatic mixed precision
AMP는 낮은 precision의 floating point만을 쓰는 것이 아닌, 기존의 높은 precision의 floating point을 섞어 쓰는 방법이다.
모델 parameter는 32-bit: 기본적으로 모델 parameter들은 32-bit floating point로 저장
Data는 half-precision: 모델에 입력 또는 loss 계산에 사용하는 data는 half-precision
모델 forward연산을 할 때는 half-precision으로 변환: 본격적으로 loss를 계산할 때는 모델 parameter를 half-precision으로 변환하여 진행
Gradient에 loss scaling 적용: gradient는 data의 precision과 마찬가지로 half-precision이다. 이를 32-bit 모델 parameter에 적용할 때 적절한 값을 곱해서 더해준다. 이는 32-bit에서는 0에 가까운 값이 half-precision에서는 0이 될 수 있기 때문이다.
이런 식으로 적절히 precision을 조절하여 메모리와 inference 속도에서 이점을 가지되, 성능 또한 32-bit를 쓸 때와 비슷하게 나오도록 할 수 있다.
2.3 PEFT: Parameter-efficient fine-tuning
다음과 같은 이유로 fine-tuning은 모델을 통째로 학습하는 방법도 있지만 마지막 layer만 학습하는 방법도 있다.
비용 절약: 학습하는 layer가 적으니 당연하게도 메모리나 학습 시간에 이점이 있다.
Generalization에 유리: 업데이트 하는 parameter가 적어, 수렴 속도는 느릴 수 있으나 좋은 test 성능을 얻는 것에 더 유리해진다.
PEFT는 마지막 layer 이외에도 다양하게 주어진 모델의 일부 parameter만 학습하는 방법론들을 의미한다. LLM 학습에 적용할 수 있는 PEFT 방법들은 다양하게 있지만, 이번 챕터에서는 가장 많이 쓰이는 PEFT 방법인 LoRA
에 대해 다룬다.
2.3.1 LoRA: Low-rank adaptation
어떤 weight matrix W0와 gradient 계산을 통해 얻은 변화량 ΔW가 있을 때 gradient descent는 다음과 같다.
W=W0+ηΔW
LoRA 저자들은 ΔW0이 W0와 같은 shape을 가지는 행렬이라는 것에서 아이디어를 얻었다.
만약 ΔW의 shape을 Rn×m 이라고 하면 선형대수에서는 특정 행렬을 다음과 같이 나타낼 수 있다.
ΔW=AB,A∈Rn×r,B∈Rr×m
만약 여기서 r≪n,m을 적당히 작은 값으로 둔다면 전체 parameter 수는 r(n+m)으로, 기존의 parameter 수 nm보다 작은 것을 알 수 있다.
LoRA는 W0를 계속 업데이트하는 것이 아닌, parameter 수가 적은 ΔW=AB 자체를 학습한다. 실제 계산은 W0+ΔW로 수행한다.
2.4 Flash Attention
Flash Attention은 Transformer의 attention 계산을 최적화 하는 방법이다.
2.4.1 Attention mechanism
기존 Attention mechanism에서 query, key, value Q,K,V∈Rn×d 가 주어졌을 때 Transformer의 attention 계산은 다음과 같이 이루어진다.
\begin{align*} S &= QK^T \in \mathbb{R}^{n \times n} \\ O &= \textrm{softmax}(S)V \in \mathbb{R}^{n \times d} \end{align*}
여기서 문제가 되는 부분은 S이다. Sequence의 길이 N가 너무 길어지면, 즉 들어오는 text들의 길이가 길어지면 S는 지나치게 많은 메모리를 차지하게 된다.
더불어 기존의 attention mechanism은 많은 메모리를 차지하기 때문에 GPU에서 연산이 빠른 SRAM이 아닌 HBM을 활용할 수 밖에 없다는 단점이 있다.
2.4.2 Flash Attention mechanism
Q,K,V를 쪼개어 attention 계산:
Q,K,V를 일정한 블록으로 쪼갠다.
Q∈R6×2라고 하면 Q(1)∈R3×2,Q(2)∈R3×2 과 같이 블록을 쪼갤 수 있다. 그리고 각 블록 별로 attention 계산을 진행한다.
GPU SRAM 활용: 블록 간의 attention 계산은 메모리를 덜 차지하기 때문에 이제 비교적 연산속도가 빠른 GPU의 SRAM을 활용할 수 있게 된다. 그래서 블록 간 attention 계산은 블록들을 SRAM으로 복사한 후 진행한다.
HBM와의 통신 최소화: 마지막으로 GPU와 CPU 사이의 통신에서 시간이 소요되듯이 SRAM과 HBM 사이에도 통신 비용이 드는데, FlashAttention은 이런 통신을 최소화하는 것에도 신경을 썼다. 알고리즘은 다음과 같다.
Flash Attention의 장점은 다음과 같다.
속도 증가: 논문에서는 FlashAttention을 사용했을 때, GPT-2의 속도가 3배 향상됐다고 주장한다.
훨씬 긴 context 학습 가능: 기본적으로 메모리 사용량을 줄여주기 때문에 훨씬 긴 text를 학습할 수 있게 된다. 논문에서는 64K 길이의 text까지 학습하는데 성공했다고 주장한다.
근사 없이 빠른 attention 계산 가능: 기존의 attention을 근사하여 속도를 개선한 방법론들이 있다. FlashAttention은 근사를 하지 않고도 근사 방법들과 비슷한 성능을 낸다는 장점이 있다.
3. 느낀점
실제로 저 수식을 구현하려고 해봤는데 너무 어지러웠다. HuggingFace는 신이다.
8주간의 과정 동안 솔직히 딥다이브 하지는 못했지만, 이전과 다르게 머릿속에 그림이 조금 그려지는게 고무적이다.
다음 내용이 궁금하다면?
이미 회원이신가요?
2025년 5월 27일 오후 2:37
A
... 더 보기J
... 더 보기누
... 더 보기Next.js 까보기: "쓸 줄 아는 개발자"에서 "알고 쓰는 개발자로" 강의를
... 더 보기저
... 더 보기