AIxiv專欄是機(jī)器之心發(fā)布學(xué)術(shù)、技術(shù)內(nèi)容的欄目。過去數(shù)年,機(jī)器之心AIxiv專欄接收?qǐng)?bào)道了2000多篇內(nèi)容,覆蓋全球各大高校與企業(yè)的頂級(jí)實(shí)驗(yàn)室,有效促進(jìn)了學(xué)術(shù)交流與傳播。如果您有優(yōu)秀的工作想要分享,歡迎投稿或者聯(lián)系報(bào)道。投稿郵箱:liyazhou@jiqizhixin.com;zhaoyunfeng@jiqizhixin.com
回顧 AGI 的爆發(fā),從最初的 pre-training (model/data) scaling,到 post-training (SFT/RLHF) scaling,再到 reasoning (RL) scaling,找到正確的 scaling 維度始終是問題的本質(zhì)。2017 年發(fā)布的 Transformer 架構(gòu)沿用至今,離不開 Transformer 強(qiáng)大的 “無損記憶” 能力,當(dāng)然也需要付出巨大的 KV 緩存代價(jià)。換句話說,Transformer 架構(gòu)具有強(qiáng)大的 memory scaling 能力。
DeepSeek NSA 通過三種方式壓縮 “KV” 實(shí)現(xiàn) sparse attention,但這只是一種可以工作但不優(yōu)雅的折中方案。因?yàn)樗趬嚎s Transformer 的記憶能力,以換取效率。
另一方面,大概從 2023 年火到今天的線性序列建模方法(包括 linear attention 類,Mamba 系列,RWKV 系列)則是另一個(gè)極端,只維護(hù)一份固定大小 dxd 的 RNN memory state,然后加 gate,改更新規(guī)則,但這種方式始終面臨較低的性能上限,所以才會(huì)有各種混合架構(gòu)的同樣可以工作但不優(yōu)雅的折中方案。
我們認(rèn)為,未來的模型架構(gòu)一定具有兩點(diǎn)特性:強(qiáng)大的 memory scaling 能力 + 關(guān)于序列長度的低復(fù)雜度。后者可以通過高效注意力機(jī)制實(shí)現(xiàn),比如:linear 或者 sparse attention,是實(shí)現(xiàn)長序列建模的必備性質(zhì)。而前者仍然是一個(gè)有待探索的重要課題,我們把給出的方案稱為 “sparse memory”。
這促使我們?cè)O(shè)計(jì)了 MoM: Mixture-of-Memories,它讓我們從目前主流線性序列建模方法改 gate 和 RNN 更新規(guī)則的套路中跳脫出來,稀疏且無限制地?cái)U(kuò)大 memory 大小。MoM 通過 router 分發(fā) token(靈感來自 MoE)維護(hù)多個(gè) KV memory,實(shí)現(xiàn) memory 維度 scaling。每個(gè) memory 又可以進(jìn)行 RNN-style 計(jì)算,所以整體具有關(guān)于序列長度線性的訓(xùn)練復(fù)雜度,推理又是常數(shù)級(jí)復(fù)雜度。此外,我們又設(shè)計(jì)了 shared memory 和 local memory 合作分別處理全局和局部信息。實(shí)驗(yàn)表現(xiàn)相當(dāng)驚艷,尤其是在目前 linear 類方法效果不好的 recall-intensive 任務(wù)上表現(xiàn)格外好,甚至在 1.3B 模型上已經(jīng)和 Transformer 架構(gòu)旗鼓相當(dāng)。
論文地址:https://arxiv.org/abs/2502.13685
代碼地址:https://github.com/OpenSparseLLMs/MoM
未來還會(huì)集成在:https://github.com/OpenSparseLLMs/Linear-MoE
模型權(quán)重開源在:https://huggingface.co/linear-moe-hub
方法細(xì)節(jié)
Linear Recurrent Memory
對(duì)于這部分內(nèi)容,熟悉線性序列建模的小伙伴可以跳過了。
(各種方法本身有不同的符號(hào),像 Mamba, HGRN 就不用 q k v,這里為了統(tǒng)一對(duì)比全部對(duì)標(biāo)到 linear attention 形式。其中Titans的形式,把 memory update rule 看作 optimizer update 的話,最核心的還是 SGD 形式,暫時(shí)忽略momentum/weight decay ,只一個(gè)公式表達(dá)的話寫成這種梯度更新的形式是合理的。)
其實(shí)這些方法又可以進(jìn)一步細(xì)分為不同類別(很多地方都粗略的統(tǒng)一稱為 linear RNN 或者 RNN),這里論文暫時(shí)沒提:
Linear Attention, Lightning Attention, RetNet, GLA, DeltaNet, Gated DeltaNet 屬于 linear attention 類
Mamba2 屬于 SSM 類,HGRN2 屬于 linear RNN 類
TTT, Titans 屬于 Test-Time Training 類
Mixture-of-Memories
MoM 思路非常簡單,和 MoE 一樣按照 token 分發(fā),通過 router 為每個(gè) token 選擇 topk 的 memories 并計(jì)算各自權(quán)重:
所有激活的 topk memories 按照各自權(quán)重加權(quán)求和得到一份混合記憶:
然后就又回到了 linear 類方法一貫的輸出計(jì)算:
另外,這里我們額外引入了 shared memory 的概念,即每個(gè) token 都會(huì)經(jīng)過這個(gè)永遠(yuǎn)激活的 memory,有利于模型獲取全局信息。相對(duì)而言,其他稀疏激活的 memory 更擅長獲取局部信息。消融實(shí)驗(yàn)表明,shared memory 的存在對(duì)模型效果有明確的積極作用。
硬件高效實(shí)現(xiàn)
MoM的硬件高效Triton算子可以很方便地實(shí)現(xiàn),其輸出的計(jì)算可以簡單寫作:
也就是說 MoM 中每個(gè) memory 的計(jì)算過程可以復(fù)用現(xiàn)有的單個(gè)算子,再把所有 memory 的輸出加權(quán)求和起來。和直接在算子內(nèi)先求和再算輸出是數(shù)學(xué)等價(jià)的。
實(shí)驗(yàn)結(jié)果
in-context recall-instensive tasks
一直以來,線性序列建模方法因?yàn)樽陨矸浅S邢薜?memory 大小,在這類 in-context recall-intensive 任務(wù)上表現(xiàn)不好。同時(shí) Transformer 模型得益于其強(qiáng)大的無損記憶能力,非常擅長這類任務(wù)。所以已經(jīng)出現(xiàn)了各種層間 hybrid 的模型,來提升 linear 類模型在這類任務(wù)上的效果。
我們首先重點(diǎn)測(cè)試了這類任務(wù)(結(jié)果見下表),使用 Gated DeltaNet 作為 MoM 的 memory 計(jì)算形式(在 Memory 更新過程中,每個(gè) memory 都使用 Gated DeltaNet 的 gate 和更新規(guī)則),總共 4 個(gè) local sparse memory,激活 2 個(gè),還有一個(gè) shared memory。其中標(biāo) 的模型來自開源項(xiàng)目(https://huggingface.co/fla-hub),沒標(biāo) 的是我們從頭預(yù)訓(xùn)練的模型。
結(jié)果還算相當(dāng)不錯(cuò),在沒有數(shù)據(jù)污染或任何套路的情況下,結(jié)果顯示 MoM 就是單純地效果好。這也和預(yù)期一致,翻倍擴(kuò)展 memory 大小,效果好過其他 linear 類方法。有一些意外的是,在 1.3B 的結(jié)果里,MoM 基本可以和 Transformer 相媲美。
其他評(píng)測(cè)效果
其他評(píng)測(cè)結(jié)果效果也不錯(cuò):
推理效率
推理效率是線性序列建模方法的重點(diǎn),結(jié)果顯示 MoM 在常數(shù)級(jí)復(fù)雜度推理速度和顯存占用方面,表現(xiàn)出強(qiáng)大的優(yōu)勢(shì)。
消融實(shí)驗(yàn)
Loss 曲線