在Transformer時(shí)代重塑RNN,RWKV將非Transformer架構(gòu)擴(kuò)展到數(shù)百億參數(shù)(2)
架構(gòu)細(xì)節(jié)
RWKV 架構(gòu)由一系列堆疊的殘差塊組成,每個(gè)殘差塊又由具有循環(huán)結(jié)構(gòu)的時(shí)間混合和通道混合子塊組成。
循環(huán)被表示為當(dāng)前輸入和前一個(gè)時(shí)間步的輸入之間的線性插值(研究者稱這種技術(shù)為時(shí)移混合或 token shift,如下圖 3 所示),該插值可以針對(duì)輸入嵌入的每個(gè)線性投影進(jìn)行獨(dú)立調(diào)整(比如時(shí)間混合中的 R、K 和 V,通道混合中的 R 和 K),并作為公式 14 中形式化的 WKV 的時(shí)變更新。
類(lèi) Transformer 的并行化
RWKV 可以在時(shí)間并行模式下進(jìn)行高效地并行化,讓人聯(lián)想到 Transformer。單個(gè)層中一個(gè) batch 序列的時(shí)間復(fù)雜度為 O (BTd^2 ),它主要由矩陣乘法 W_□, □ ∈ {r, k, v, o}(假設(shè) B 個(gè)序列、T 個(gè)最大 token 和 d 個(gè)通道)。同時(shí)更新注意力分?jǐn)?shù) wkv_t 需要串行掃描,并且復(fù)雜度為 O (BTd)。
類(lèi) RNN 的序列解碼
在循環(huán)網(wǎng)絡(luò)中,將狀態(tài) t 時(shí)的輸出用作狀態(tài) t+1 時(shí)的輸入很常見(jiàn)。這在語(yǔ)言模型的自回歸解碼推理中尤為明顯,要求每一個(gè) token 在饋入下一步之前必須進(jìn)行計(jì)算,從而使 RWKV 可以利用類(lèi) RNN 結(jié)構(gòu)(即時(shí)序模式)。在這種情況下,RWKV 可以方便地循環(huán)用于推理解碼,從而利用每個(gè)輸出 token 僅依賴于最新?tīng)顟B(tài)的優(yōu)勢(shì)。
然后 RWKV 充當(dāng) RNN ****,在序列長(zhǎng)度方面保持恒定速度和內(nèi)存占用,從而更高效地處理更長(zhǎng)的序列。相比之下,自注意力通常需要 KV 緩存相對(duì)于序列長(zhǎng)度呈線性增長(zhǎng),這會(huì)導(dǎo)致效率下降,并隨序列長(zhǎng)度增加消耗更多內(nèi)存和時(shí)間。
軟件實(shí)現(xiàn)
RWKV 最初使用 PyTorch 深度學(xué)習(xí)庫(kù)和自定義 CUDA 內(nèi)核(它用于 WKV 計(jì)算)來(lái)實(shí)現(xiàn)。盡管 RWKV 是一個(gè)通用循環(huán)網(wǎng)絡(luò),但其當(dāng)前的實(shí)現(xiàn)主要集中在語(yǔ)言建模任務(wù)(RWKV-LM)。該模型架構(gòu)包含了一個(gè)嵌入層,為此研究者遵循第 4.7 節(jié)中的設(shè)置,并按照第 4.6 節(jié)中的原則依次應(yīng)用幾個(gè)相同的殘差塊,具體如上圖 2 和 3 所示。
梯度穩(wěn)定性和層堆疊
RWKV 架構(gòu)被設(shè)計(jì)為 Transformer 和 RNN 的融合,與傳統(tǒng)的 RNN 相比,Transformers 具有穩(wěn)定梯度和更深層次架構(gòu)的優(yōu)勢(shì),同時(shí)推理效率高。
RWKV 模型具有用于更新類(lèi)似注意力分?jǐn)?shù)的單步過(guò)程,其中包括一個(gè)依賴于時(shí)間的 softmax 操作,該操作有助于數(shù)值穩(wěn)定性并防止梯度消失(有關(guān)嚴(yán)格證明,請(qǐng)參見(jiàn)附錄 F)。直觀地說(shuō),此操作可確保梯度沿最相關(guān)的路徑傳播。Layer normalization (Ba et al., 2016) 是架構(gòu)的另一個(gè)關(guān)鍵方面,它通過(guò)穩(wěn)定梯度、解決梯度消失和爆炸問(wèn)題來(lái)增強(qiáng)深度神經(jīng)網(wǎng)絡(luò)的訓(xùn)練動(dòng)態(tài)。
利用時(shí)間結(jié)構(gòu)進(jìn)行時(shí)序數(shù)據(jù)處理
RWKV 通過(guò)三種機(jī)制的組合來(lái)捕獲和傳播時(shí)序信息:循環(huán)、時(shí)間衰減和 token shift。
RWKV 時(shí)間混合塊中的循環(huán)是模型捕獲序列元素之間復(fù)雜關(guān)系和隨時(shí)間傳播局部信息的能力的基礎(chǔ)。
時(shí)間衰減機(jī)制(等式 14 中的 e^?w 和 e^u)保持了對(duì)序列元素之間位置關(guān)系的敏感性。通過(guò)逐漸減少以往信息隨時(shí)間的影響,該模型保留了時(shí)間局部性和進(jìn)展感,這對(duì)于時(shí)序處理至關(guān)重要。
token shift 或 time-shift 混合或(圖 3 中的對(duì)角線箭頭),也有助于模型適應(yīng)時(shí)序數(shù)據(jù)。通過(guò)在當(dāng)前輸入和前一個(gè)時(shí)間步輸入之間進(jìn)行線性插值,模型自然地聚合和門(mén)控輸入通道中的信息。
實(shí)驗(yàn)結(jié)果
實(shí)驗(yàn)的重點(diǎn)是回答以下問(wèn)題:
RQ1:在參數(shù)數(shù)量和訓(xùn)練 token 數(shù)量相等的情況下,RWKV 與二次 transformer 架構(gòu)相比具有競(jìng)爭(zhēng)力嗎?
RQ2:增加參數(shù)數(shù)量時(shí),RWKV 是否仍然具有與二次 transformer 架構(gòu)相競(jìng)爭(zhēng)的能力?
RQ3:當(dāng) RWKV 模型被訓(xùn)練用于開(kāi)源二次 transformer 無(wú)法高效處理的上下文長(zhǎng)度時(shí),增加 RWKV 的參數(shù)是否能夠獲得更好的語(yǔ)言建模損失?
首先是回答 RQ1 和 RQ2 問(wèn)題,從圖 4 可以看出,在六個(gè)基準(zhǔn)測(cè)試中(Winogrande、PIQA、ARC-C、ARC-E、LAMBADA 和 SciQ),RWKV 與開(kāi)源二次復(fù)雜度 transformer 模型 Pythia、OPT 和 BLOOM 具有相當(dāng)?shù)母?jìng)爭(zhēng)力。RWKV 甚至在四個(gè)任務(wù)(PIQA、OBQA、ARC-E 和 COPA)中勝過(guò)了 Pythia 和 GPT-Neo。
對(duì)于 RQ3,圖 5 顯示,增加上下文長(zhǎng)度會(huì)導(dǎo)致 Pile 上的測(cè)試損失降低,這表明 RWKV 能夠有效利用較長(zhǎng)的上下文信息。
*博客內(nèi)容為網(wǎng)友個(gè)人發(fā)布,僅代表博主個(gè)人觀點(diǎn),如有侵權(quán)請(qǐng)聯(lián)系工作人員刪除。