MemoryLLM - 트랜스포머 FFN을 해석 가능한 플러그앤플레이 메모리로 분리
목차
- 개요
- 연구 배경
- MemoryLLM 아키텍처
- TKV 프레임워크 - 해석 가능한 메모리 분석
- Token-wise Lookup과 효율성
- Flex-MemoryLLM 하이브리드 아키텍처
- 실험 결과
- 저장소 압축과 레이어 중요도
- 시사점 및 향후 방향
- Reference
개요
Ajay Jaiswal 등이 발표한 “MemoryLLM: Plug-n-Play Interpretable Feed-Forward Memory for Transformers” 논문을 소개한다. 이 연구는 트랜스포머에서 FFN(Feed-Forward Network)을 Self-Attention으로부터 완전히 분리하여, FFN을 해석 가능한 신경 키-값 메모리로 재설계하는 새로운 아키텍처를 제안한다.
핵심 성과는 FFN을 컨텍스트 프리 토큰 임베딩으로 독립 학습시켜 사전 계산 가능한 Token-wise Lookup(ToL)으로 변환하고, 이를 통해 VRAM 약 36% 절감과 디코딩 속도 약 33% 향상을 달성한 점이다.
연구 배경
FFN 해석의 어려움
대규모 언어 모델에서 FFN은 전체 파라미터의 약 2/3를 차지한다. 그러나 FFN은 Self-Attention 출력과 잔차 스트림의 비해석적 혼합을 입력으로 받아, 그 역할을 독립적으로 분석하기 어렵다.
기존 연구들은 FFN을 신경 키-값 메모리로 해석하려 했으나, 다음과 같은 한계가 있었다.
- 여러 번의 순전파/역전파와 캘리브레이션 데이터가 필요
- 수작업 어노테이션이 요구됨
- FFN 내 메모리 접근에 대한 이산적이고 해석 가능한 쿼리를 정의할 수 없음
핵심 연구 질문
이 연구는 다음 질문에서 출발한다.
“어떻게 FFN을 Self-Attention으로부터 분리하여, 유한한 인간 해석 가능 어휘에 매핑된 결정론적 메모리를 인코딩할 수 있는가?”
기존 사전학습 모델을 분석하는 대신, 처음부터 Self-Attention과 FFN을 명시적으로 분리한 새로운 트랜스포머 아키텍처를 설계하는 접근법을 택했다.
MemoryLLM 아키텍처
기존 트랜스포머의 구조
기존 트랜스포머에서 레이어 L의 처리 과정은 다음과 같다.
- Self-Attention이 잔차 스트림 스냅샷 X_L을 처리한다.
- FFN은 잔차 스트림과 어텐션 출력의 합인 X̃_L = X_L + Attn(X_L)을 입력으로 받는다.
이 상호의존성이 FFN의 독립적 분석을 방해하는 근본 원인이다.
MemoryLLM의 분리 전략
MemoryLLM은 Self-Attention과 FFN을 독립적이고 병렬적으로 학습시키는 방식을 채택한다.
Self-Attention은 기존처럼 잔차 스트림에서 동작한다. 그러나 모든 트랜스포머 블록의 FFN은 토크나이저에서 직접 오는 정적 컨텍스트 프리 토큰 임베딩 X_0를 입력으로 받는다.
레이어 L+1의 잔차 스트림은 다음과 같이 구성된다.
X_L+1 = X_L + Attn(X_L) + FFN(X_0)
임베딩 출력 X_0는 토크나이저의 토큰 ID에만 의존하므로, FFN 입력은 학습 및 추론 단계 모두에서 정적이며 어휘 크기에 의해 한정된다. 이러한 설계를 통해 FFN 레이어를 중요도에 따라 제거하거나 VRAM 제약에 맞게 조정할 수 있는 플러그앤플레이 유연성을 확보한다.
TKV 프레임워크 - 해석 가능한 메모리 분석
Token-Key-Value 분해
TKV 프레임워크는 SwiGLU 기반 FFN을 세 가지 구성 요소로 해석한다.
| 구성 요소 | 역할 |
|---|---|
| Keys (W_Up) | K개의 키-값 쌍(메모리 셀)을 포함하는 업프로젝션 행렬 |
| Values (W_Down) | 대응하는 값 벡터를 포함하는 다운프로젝션 행렬 |
| Gate (W_Gate) | 메모리 셀의 증폭/억제를 결정하는 학습된 재가중 함수 |
메모리 검색 과정
토큰 t에 대응하는 쿼리 벡터 q의 메모리 검색은 두 단계로 이루어진다.
1단계 - 메모리 셀 계수 계산: 쿼리 벡터와 업프로젝션 행렬의 내적에 게이트 프로젝션의 요소별 재가중을 적용한다.
2단계 - 검색된 출력 산출: 각 메모리 셀의 가중합으로 FFN 출력을 생성한다.
이 프레임워크의 핵심 장점은 역공학이 불필요하다는 점이다. 어휘 토큰 ID로부터 유한한 인간 해석 가능 쿼리 벡터 집합을 직접 정의할 수 있다.
의미론적 클러스터링
K-means 클러스터링을 통해 모든 어휘 토큰의 키 기여도 벡터를 분석한 결과, 인간이 해석 가능한 클러스터가 형성되었다.
- 구두점, 인명, 지명, 언어적 속성 등 의미적으로 유사한 토큰이 유사한 메모리 키를 활성화한다.
- 클러스터링 계수는 모든 모델 레이어에서 높게 유지된다.
- 마지막 레이어에서는 이상치 키 수가 증가하여, 토큰 수준 정보 수렴이 우수함을 시사한다.
Token-wise Lookup과 효율성
사전 계산 전략
FFN이 정적 임베딩에서 동작하므로, 모든 어휘 토큰에 대한 출력을 오프라인으로 사전 계산하여 정적 Token-wise Lookup(ToL)으로 저장할 수 있다.
각 어휘 토큰은 모든 N개 트랜스포머 레이어에 걸친 FFN 출력을 연결한 벡터를 할당받는다. 추론 시에는 FFN 연산 대신 ToL 조회만 수행하면 된다.
온디맨드 플러그앤플레이 설계
LLM 생성 콘텐츠의 토큰 분포는 Zipf의 법칙을 따른다. 이를 활용하여 자주 사용되는 토큰의 ToL만 VRAM에 캐싱하고, 덜 사용되는 토큰은 저장소에서 온디맨드로 로드할 수 있다.
또한 MemoryLLM에서 FFN 기여도는 초기 레이어 이후 크게 감소하는 패턴을 보인다. 이는 기존 모델의 U자형 패턴과 다르며, 후반 FFN 레이어를 VRAM에서 영구적으로 오프로드할 수 있음을 의미한다.
Flex-MemoryLLM 하이브리드 아키텍처
설계 동기
순수 MemoryLLM은 컨텍스트 프리 학습으로 인해 기존 밀집 모델 대비 성능 저하가 발생한다. Flex-MemoryLLM은 이 성능 격차를 해소하기 위한 하이브리드 아키텍처이다.
구조
FFN 파라미터를 두 가지 구성 요소로 분할한다.
| 구성 요소 | 역할 |
|---|---|
| FFN-C (Compute) | 잔차 흐름에서 동작하는 밀집 선형 모듈로 계산 능력 증대 |
| FFN-M (Memory) | 토큰 임베딩에서 학습하는 컨텍스트 프리 신경 메모리 |
LLaMA 스타일 모델(약 8h^2 FFN 파라미터, h=은닉 차원)에서 FFN-C에 βh^2, FFN-M에 (8-β)h^2를 할당한다. β=3 설정 시 5h^2 파라미터를 FFN 메모리로 오프로드하면서 경쟁력 있는 성능을 유지할 수 있다.
실험 결과
태스크별 FFN 기여도 분석
FFN의 중요도는 태스크 유형에 따라 크게 달라진다.
회상/검색 태스크 (Wikitext-2, LAMBDA, SiQA, ARC-Easy):
| 스케일링 계수 α | 성능 변화 |
|---|---|
| α=0.9 | +0.82% ~ +1.24% |
| α=0.5 | -44.59% ~ -69.93% |
논리/추론 태스크 (HellaSwag, Winogrande, BoolQ, PIQA):
| 스케일링 계수 α | 성능 변화 |
|---|---|
| α=0.9 | -0.76% ~ +0.68% |
| α=0.5 | -0.54% ~ -6.34% |
이 결과는 MemoryLLM의 FFN이 학습 데이터에서 직접 학습한 토큰 수준의 파라메트릭 지식 저장소로 작용하며, 검색 기반 태스크에서 지배적인 역할을 한다는 것을 보여준다.
성능 비교 (50B 토큰 학습)
| 모델 | 활성 파라미터 | 전체 파라미터 | C4 PPL | Wikitext-2 PPL |
|---|---|---|---|---|
| Base-750M | 737M | 737M | 19.730 | 25.491 |
| MemoryLLM-750M | 402M | 1208M | 20.933 | 27.258 |
| MemoryLLM-250M | 245M | 737M | 22.079 | 29.976 |
| Base-250M | 265M | 265M | 23.190 | 32.220 |
동일 전체 파라미터 기준에서는 기존 모델보다 성능이 낮지만, 활성 파라미터 기준으로는 밀집 모델을 크게 능가한다. ToL은 활성 파라미터에 포함되지 않으므로, 실제 추론 연산량 관점에서 효율적이다.
Flex-MemoryLLM 성능 (150B 토큰 학습)
1B 규모 모델 결과:
| 모델 | 활성 파라미터 | C4 PPL |
|---|---|---|
| Base-1B | 1208M | 약 10.2 |
| Flex-MemoryLLM-3h^2 | 704M | 약 10.5 |
| MemoryLLM | 402M | 약 11.8 |
Flex-MemoryLLM-3h^2는 704M 활성 파라미터로 기존 1208M 모델에 근접한 성능을 달성했다. 750M 규모에서는 기존 Base-737M 모델을 능가하는 성능을 보였다.
프루닝 기법과의 비교
MemoryLLM과 Flex-MemoryLLM 변형들은 동일한 활성 파라미터 수에서 기존 프루닝 기법(Magnitude, SparseGPT, Wanda)을 크게 능가했다. 이는 이 아키텍처들이 새로운 프루닝 기법 개발의 대안으로 활용될 수 있음을 시사한다.
추론 효율성 (1B 규모, 1xA100 GPU)
| 모델 | 메모리 (GB) | 디코딩 속도 (ms/token) |
|---|---|---|
| Base-Base | 9.541 | 21.50 |
| Flex-MemoryLLM-h^2 | 7.025 | 18.75 |
| Flex-MemoryLLM-2h^2 | 7.409 | 20.28 |
| Flex-MemoryLLM-3h^2 | 7.825 | 21.47 |
| MemoryLLM | 6.041 | 14.42 |
MemoryLLM은 사전 계산된 Lookup 덕분에 약 36% 메모리 절감과 약 33% 속도 향상을 달성했다.
저장소 압축과 레이어 중요도
양자화
MemoryLLM-1B(128,256 어휘, 24 레이어, 2048 은닉 차원) 기준 ToL은 F16 정밀도에서 약 12.6GB가 필요하다.
| 양자화 수준 | 저장 용량 | 성능 영향 |
|---|---|---|
| 8-bit | 6.3 GB | 무시할 수준 |
| 4-bit | 3.15 GB | 최소 수준의 저하 |
저랭크 압축
SVD 분해 분석 결과, ToL은 헤비테일 특이값 분포를 보여 효과적인 압축이 가능하다.
| 랭크 축소율 | 저장 용량 | 저장 절감률 | C4 PPL |
|---|---|---|---|
| 20% | 10.24 GB | 약 18.74% | 18.958 |
| 50% | 6.40 GB | 약 49.20% | 19.586 |
마지막 레이어는 중간 레이어보다 우수한 저랭크 속성을 보인다.
레이어별 중요도
초기 레이어의 ToL을 제거하면 성능이 크게 저하되지만, 중간 레이어의 ToL을 제거해도 성능이 안정적으로 유지된다. 이는 중간 레이어 ToL이 높은 중복성을 가지며 한계적 영향만 미친다는 것을 시사한다.
시사점 및 향후 방향
해석 가능성 달성
TKV 프레임워크는 어휘 토큰을 이산 메모리 위치에 직접 매핑함으로써 FFN을 키-값 메모리 저장소로 결정론적으로 분석할 수 있게 했다. 어휘적, 의미적으로 유사한 토큰의 지식이 유사한 메모리 위치에 인덱싱되는 것이 확인되었다.
실용적 활용 가능성
- 지식 편집 및 주입: 특정 키의 표적 변경을 통한 지식 수정
- 독성 억제: 유해 콘텐츠 관련 메모리 셀의 선택적 비활성화
- 리소스 제약 환경 배포: 플러그앤플레이 설계를 활용한 유연한 모델 크기 조정
핵심 발견 요약
- FFN은 검색 기반 태스크에서 지배적 역할을 하지만 추론 태스크에는 최소 기여
- 사전 계산된 ToL로 메모리 사용량과 계산 비용을 동시에 절감 가능
- Flex-MemoryLLM은 약 5h^2 파라미터 오프로드로 기존 모델 수준의 성능 달성
- 프루닝 기반 압축 접근법의 효과적인 대안으로 활용 가능