포스트

Gated DeltaNet-2: 선형 어텐션에서 삭제와 쓰기 게이트 분리

목차

  1. 개요
  2. 방법론
  3. 주요 결과
  4. 한계와 주의사항
  5. 결론
  6. Reference

개요

Gated DeltaNet-2는 NVIDIA가 발표한 선형 어텐션 기반 순환 시퀀스 모델로, 고정 크기 상태(fixed-size state) 업데이트 시 기존 델타 규칙 모델이 갖던 근본적 제약을 해소한다. 기존 Gated DeltaNet 및 Kimi Delta Attention(KDA)은 단일 스칼라 게이트 하나로 이전 메모리에서 지울 내용과 새로 쓸 내용을 동시에 제어했다. 이 방식은 두 연산이 강하게 결합되어 있어, 서로 다른 값 채널에 걸친 세밀한 메모리 편집이 불가능하다는 문제가 있다. Gated DeltaNet-2는 채널 단위 삭제 게이트(erase gate)와 채널 단위 쓰기 게이트(write gate)를 분리함으로써 이 문제를 직접 해결한다. 모델은 1.3B 파라미터, FineWeb-Edu 데이터 100B 토큰으로 학습하였으며, 순수 순환(recurrent) 모드와 슬라이딩 윈도우 어텐션을 결합한 하이브리드(hybrid) 모드 두 가지 변형을 지원한다.

방법론

기존 방식의 한계

KDA와 Gated DeltaNet은 상태 업데이트를 다음과 같은 형태로 수행한다. 하나의 스칼라 게이트 β_t가 삭제 강도와 쓰기 강도를 동시에 결정한다. 이 경우 “특정 키 차원에서는 강하게 지우되 특정 값 채널에는 작게 쓴다”는 식의 채널 독립적 편집이 불가능하다. 여러 연상(association)이 동일 상태 슬롯을 공유할 때, 하나를 수정하면 다른 연상이 오염되는 간섭 문제가 발생한다.

Gated Delta Rule-2 업데이트 수식

Gated DeltaNet-2의 핵심 상태 업데이트 식은 아래와 같다.

\[\mathbf{S}_t = \left(\mathbf{I} - \mathbf{k}_t(\mathbf{b}_t \odot \mathbf{k}_t)^\top\right)\mathbf{D}_t\mathbf{S}_{t-1} + \mathbf{k}_t(\mathbf{w}_t \odot \mathbf{v}_t)^\top\]

여기서 각 기호의 의미는 다음과 같다.

기호설명
S_t시각 t의 압축 메모리 상태 (행렬)
k_t정규화된 키 벡터
b_t삭제 게이트 (채널 단위, 키 차원)
w_t쓰기 게이트 (채널 단위, 값 차원)
D_t채널 단위 감쇠(decay) 대각 행렬
원소별 곱 (Hadamard product)

삭제 게이트 b_t는 키 차원에서 어느 좌표를 메모리에서 지울지 제어한다. 쓰기 게이트 w_t는 값 차원에서 어느 채널에 새 내용을 기록할지 제어한다. D_t가 추가되어 채널 단위 시간적 감쇠도 함께 흡수된다. b_t와 w_t를 동일 스칼라로 묶으면 KDA로, 추가로 스칼라 감쇠를 적용하면 Gated DeltaNet으로 축퇴하므로, 기존 모델들은 이 수식의 특수 경우에 해당한다. 모델 구성은 16 헤드, d_k = d_v = 128이며, 순환 상태 크기도 이에 맞게 설정된다. 학습 설정은 AdamW 옵티마이저, 피크 학습률 4e-4, 코사인 스케줄, 시퀀스 길이 4K이다.

효율적 구현

학습 효율성을 위해 세 가지 핵심 기법이 적용된다.

첫째, WY 폼과 감쇠 흡수(WY form with decay absorption)를 통해 누적 채널 단위 감쇠를 랭크-1 삭제 인수로 통합한다. 이를 통해 연속 업데이트를 수치적으로 안정적으로 병합할 수 있다.

둘째, 청크 단위 알고리즘(chunkwise algorithm)을 적용한다. 고정 크기 청크 단위로 처리함으로써, 선형 시퀀스 복잡도를 유지하면서 병렬 학습이 가능하다.

셋째, 게이트 인식 역전파(gate-aware backward pass)를 구현한다. 분리된 게이트 구조를 학습 중 기울기 계산에서도 올바르게 반영하기 위해 수정된 역전파 알고리즘을 사용하며, Triton 커널로 퓨즈드(fused) 구현된다.

주요 결과

언어 모델링 및 추론

1.3B 파라미터, 100B 토큰 학습 기준 언어 모델링 및 추론 벤치마크 결과는 아래와 같다.

언어 모델링 성능 비교

지표GDN-2 (Recurrent)GDN-2 (Hybrid)Mamba-2 (Recurrent)
WikiText 퍼플렉서티15.9015.6216.79
LAMBADA 정확도48.09%50.90%-
평균 추론53.11%53.97%-

순수 순환 모드에서 GDN-2는 Mamba-2 대비 WikiText 퍼플렉서티를 15.90으로 개선(Mamba-2: 16.79)하였다. 하이브리드 모드는 슬라이딩 윈도우 크기 2K 어텐션을 결합하여 추가적인 성능 향상을 달성한다.

장문 컨텍스트 검색 (RULER)

RULER 벤치마크의 Needle-in-Haystack(NIAH) 계열 태스크에서 멀티 키 간섭 시나리오가 가장 큰 성능 차이를 보인다.

RULER NIAH 성능 (4K 시퀀스 기준)

태스크GDN-2 RecurrentGDN-2 HybridKDA Recurrent
S-NIAH-293.0%--
MK-NIAH-172.6%93.0%28.0%
Recurrent Avg. (MK-NIAH-1)37.8-28.0

멀티 키 간섭 태스크(MK-NIAH-1)에서 GDN-2 순환 모드는 37.8로, KDA의 28.0 대비 명확한 향상을 보인다. 이는 삭제·쓰기 게이트 분리가 여러 연상이 겹치는 상황에서 메모리 간섭을 직접적으로 억제함을 보여준다.

실제 검색 벤치마크

SWDE, SQuAD, FDA, TriviaQA, NQ, DROP 6개 태스크 평균 정확도는 아래와 같다.

실제 검색 벤치마크 평균 정확도

모드평균 정확도
Recurrent29.88%
Hybrid42.28%

하이브리드 모드는 슬라이딩 윈도우 어텐션의 지역 정보 집약 능력이 더해져 실제 검색 태스크에서 순환 모드 대비 약 12.4%p 높은 평균 정확도를 기록한다.

학습 처리량

학습 처리량은 시퀀스 길이 증가에도 거의 일정하게 유지된다.

학습 처리량 (Kt/s)

시퀀스 길이GDN-2 처리량
2K38.0 Kt/s
8K36.1 Kt/s

일반 Transformer는 시퀀스 길이 증가에 따라 처리량이 크게 저하되는 반면, GDN-2는 8K에서도 2K 대비 소폭 감소에 그친다.

Ablation Study

삭제 게이트와 쓰기 게이트 각각의 기여도를 확인하기 위해 게이트를 하나씩 스칼라로 제한하는 ablation을 수행하였다.

Ablation 결과 (평균 추론 정확도)

구성평균 추론 정확도
삭제 게이트만 (스칼라 쓰기)52.79%
쓰기 게이트만 (스칼라 삭제)52.45%
GDN-2 (두 게이트 모두)53.11%

두 게이트 모두 독립적으로 성능에 기여하며, 어느 하나만 채널 단위로 확장하더라도 기준 대비 향상이 있다. 삭제 게이트 범위를 [0, 1]에서 [0, 2]로 확장하는 추가 실험에서는 유의미한 추가 이득이 없었다.

한계와 주의사항

NQ(Natural Questions)와 DROP처럼 지역적 증거 집약(local evidence aggregation)이 중요한 실제 QA 태스크에서는 성능 향상 폭이 제한적이다. 이러한 태스크는 단순한 메모리 편집보다 인접 문맥을 정밀하게 집약하는 능력에 더 의존하므로, 순수 순환 모드의 구조적 한계가 드러난다. 이는 장기 기억 편집에 최적화된 GDN-2의 이점이 지역 패턴 위주 태스크에서는 상쇄될 수 있음을 시사하며, 향후 어텐션 메커니즘 개선의 필요성을 남긴다.

결론

Gated DeltaNet-2는 선형 어텐션 순환 모델에서 삭제 게이트와 쓰기 게이트를 채널 단위로 분리함으로써, 고정 크기 상태 모델이 직면한 메모리 간섭 문제를 직접 해결한다. 장문 컨텍스트 검색, 특히 멀티 키 간섭 태스크에서 기존 KDA 및 Gated DeltaNet 대비 명확한 성능 향상이 실험적으로 확인되었다. Triton 퓨즈드 구현과 청크 단위 알고리즘을 통해 선형 복잡도와 실용적 학습 처리량도 유지한다. 지역 증거 집약이 중요한 QA 태스크에서의 한계는 남아 있으나, 삭제·쓰기 분리라는 설계 원칙은 이후 순환 시퀀스 모델 연구의 유용한 기반이 될 것으로 보인다.

Reference