熟悉LLM的重要機(jī)制
1 前言
在前一期里,介紹過大語言模型(LLM)幕后核心的注意力(Attention)機(jī)制。本期就來繼續(xù)擴(kuò)大,介紹大名鼎鼎的轉(zhuǎn)換器(Transformer)模型。其中,要特別闡述:為什么當(dāng)今主流LLM都采用<僅譯碼器>(Decoder-Only Transformer)模型。
在 2017 年, 首先在「Attention is All You Need」這篇論文中提出了經(jīng)典的Transformer架構(gòu),它內(nèi)含編碼器(Encoder)和譯碼器(Decoder)兩部分。后來,自從GPT-2之后,整個大趨勢逐漸轉(zhuǎn)向<僅譯碼器>(Decoder-Only Transformer)模型。隨著LLM的流行,在其應(yīng)用于下游任務(wù)的微調(diào)(Fine-tuning) 過程中,經(jīng)常會看到<僅譯碼器>的身影,所以更深刻領(lǐng)會這個大趨勢,也更熟悉這個重要機(jī)制,對于LLM 的各項應(yīng)用是很有幫助的。
2 從典型的Transformer架構(gòu)說起
最初在 2017 年的論文「Attention is All You Need」中介紹,Transformer具有譯碼器(Decoder)和編碼器(Encoder)部分,如圖1。
圖1
這圖里的左半部分是編譯程序(Encoder)部分,而右半部分是解釋器(Decoder)部分。
剛才提到了,當(dāng)今主流的GPT模型,由于其在文字生成方面的出色性能,LLM大趨勢是轉(zhuǎn)向<僅譯碼器>(Decoder-Only Transformer)模型。這種模型的強(qiáng)大之處在于它們不僅能夠模仿類似人類的文本,而且還能夠創(chuàng)造性地做出回應(yīng)。他們可以寫故事、回答問題,甚至進(jìn)行自然流暢的對話。這種功能使它們在廣泛的應(yīng)用中非常有用,從聊天機(jī)器人和數(shù)字助理到內(nèi)容創(chuàng)建、抽象總結(jié)和講故事。
如今,這僅譯碼器架構(gòu)已經(jīng)是GPT-3、ChatGPT、GPT-4、PaLM、LaMDa和Falcon等主要語言模型的核心了。它專注于生成新內(nèi)容,而不是解釋或分析所輸入的內(nèi)容( 如文本)。例如,除了能夠解析人類文本的涵意,還能夠創(chuàng)造性地寫故事、回答問題、以及流暢對話等。這種功能對于這些大語言模型(LLM)在廣泛應(yīng)用的效益很大。為什么當(dāng)今的LLM,主要都采用僅譯碼器架構(gòu)呢?其原因,除了可提升訓(xùn)練效率之外,還有在生成任務(wù)上,引入編譯程序的雙向注意力機(jī)制并無太多幫助。于是,大趨勢就逐漸轉(zhuǎn)向< 僅譯碼器> 模型。
圖2
https://cameronrwolfe.substack.com/p/decoder-only-transformers-the-workhorse
3 認(rèn)識Decoder-Only Transformer架構(gòu)
剛才已經(jīng)看到了,上圖-1里的右半部分是解釋器。所以僅譯碼器模型的核心架構(gòu)反而必較簡單。它通常是是由多個具有相同結(jié)構(gòu)的區(qū)塊按順序堆棧而成的。在每個譯碼器區(qū)塊(Decoder Block)都包含兩個主要組件:
1)屏蔽多頭自注意力(Masked, multi-headed attention)層。這是在前一期里介紹過的注意力機(jī)制,它在理解輸入序列的涵意方面扮演非常關(guān)鍵的角色。其中的< 屏蔽>的用意是要確保其預(yù)測標(biāo)記(Token) 時,僅專注于當(dāng)前位置之前產(chǎn)生的標(biāo)記,而不考慮之后的標(biāo)記,這樣可發(fā)揮自回歸(Autoregressive)模型的特性和功能。也就是,它很擅長于從句子中前面的單字收集信息,來預(yù)測出下一個詞的機(jī)率估計值( 例如文字接龍)。這是生成式AI的魅力源頭。
2)逐位置的前饋網(wǎng)絡(luò)(Position-wise feed-forward network) 層。這用來引入非線性的變換,協(xié)助注意力機(jī)制( 純線性數(shù)據(jù)變換) 來捕捉更復(fù)雜的模式和關(guān)系。例如,使用非線性活化函數(shù)( 如ReLU 或GeLU 等) 來讓模型能夠逼近任何函數(shù),以便提供更強(qiáng)的表達(dá)能力。它常常放在自注意力層之后,并且添加了有關(guān)序列中每個標(biāo)記的位置的信息,這對于理解單字的順序至關(guān)重要。其用意是在每個序列的位置單獨應(yīng)用一個全連接前饋網(wǎng)絡(luò),所以稱為:逐位置(Position-wise)方式。
其中,這個模型背后的真正驅(qū)動力,即是上述的多頭自注意力機(jī)制。這種機(jī)制允許模型在預(yù)測每個標(biāo)記時專注于輸入序列的不同部分,從而促進(jìn)上下文相關(guān)文字的生成。
4 以Gemma的<僅譯碼器>模型為例
剛才已經(jīng)介紹了,在每個譯碼器區(qū)塊里,都包含兩個主要組件:自注意力機(jī)制(Attention) 和前饋網(wǎng)絡(luò)(FFN)。其結(jié)構(gòu)如圖2 所示。
其中的歸一化層(Layer Normalization),用來對網(wǎng)絡(luò)中的每個神經(jīng)元的輸出進(jìn)行歸一化,使得網(wǎng)絡(luò)中每一層的輸出都具有相似的分布。該技術(shù)應(yīng)用於每個子層(自註意力和前饋)之後,可以標(biāo)準(zhǔn)化激活並穩(wěn)定訓(xùn)練過程。
此外,還常添加殘差連接(Residual Connections) 層,它提供了一條捷徑,允許梯度在反向傳播過程中自由流過網(wǎng)路,可緩解梯度消失和爆炸的問題,并有助于學(xué)習(xí)更強(qiáng)大的特征,以及提高訓(xùn)練過程的整體易用性和穩(wěn)定性。
# Gemma 的譯碼器區(qū)塊的結(jié)構(gòu)
class RMSNorm(torch.nn.Module):
def __init__(….. ):
#...............................
#...............................
class GemmaMLP(nn.Module):
def __init__( ….. ):
#...............................
#...............................
class GemmaAttention(nn.Module):
def __init__( ….. ):
#...............................
#...............................
class GemmaDecoderLayer(nn.Module):
def __init__( …..):
#...............................
self.self_attn = GemmaAttention(…..)
self.mlp = GemmaMLP(…..)
self.input_layernorm = RMSNorm(…..)
s e l f . p o s t _ a t t e n t ion_ l ayernorm =
RMSNorm(…..)
#...............................
#...............................
基于上述的譯碼器區(qū)塊(Decoder Block),就可以依據(jù)需求的不同,而有多個彼此堆棧的譯碼器區(qū)塊,如圖3所示。
圖3
https://cameronrwolfe.substack.com/p/decoder-only-transformers-the-workhorse
這樣子,就將多個區(qū)塊堆棧起來,成為一個僅譯碼器模型了。各層區(qū)塊都各負(fù)責(zé)從輸入數(shù)據(jù)中提取越來越抽象的特征。例如,前幾層可能捕捉到局部的相關(guān)性,而更深的層可以捕捉更復(fù)雜的遠(yuǎn)程依賴性。這種分層方法讓模型建立輸入數(shù)據(jù)的豐富表示。
其中,各區(qū)塊里的注意機(jī)制,讓模型在產(chǎn)生輸出時專注于輸入序列的各個不同部分。然后,將多層的自注意力協(xié)同合作,使模型能夠逐步細(xì)化這些注意力權(quán)重,來提升對輸入有更細(xì)致和上下文感知更廣闊的理解。于是,每個區(qū)塊在輸入數(shù)據(jù)的分層處理中都發(fā)揮其重要的作用,并協(xié)同合作,來提升模型掌握復(fù)雜依賴關(guān)系的能力,來捕捉各種復(fù)雜的模式(Pattern)。
例如,在Gemma 的源碼中,將譯碼器區(qū)塊( 名為:GemmaDecoderLayer) 疊加起來,成為僅譯碼器模型( 名為:GemmaModel),如下:
# Gemma 的僅譯碼器模型結(jié)構(gòu)
#...............................
#...............................
class GemmaDecoderLayer(nn.Module):
def __init__( …..):
#...............................
self.self_attn = GemmaAttention(…..)
self.mlp = GemmaMLP(…..)
self.input_layernorm = RMSNorm(…..)
s e l f . p o s t _ a t t e n t ion_ l ayernorm =
RMSNorm(…..)
#...............................
#...............................
class GemmaModel(nn.Module):
def __init__(self, config):
#...............................
self.layers = nn.ModuleList()
for _ in range(config.num_hidden_layers):
self.layers.append( GemmaDecoderLayer(…..) )
#...............................
#...............................
從這Gemma的源碼結(jié)構(gòu)而觀之,盡管這些年來,LLM取得了快速的進(jìn)步,但是其中的核心組成部分仍然保持不變,就是:僅譯碼器架構(gòu)。
5 結(jié)束語
在僅譯碼器模型中,擁有多層互相迭加的譯碼器區(qū)塊,將輸入序列進(jìn)行分層,透過多層處理來提升從輸入序列中學(xué)習(xí)和泛化的能力。這種深度讓模型能處理復(fù)雜的依賴關(guān)系,然后生成高質(zhì)量的作品,即是LLM 魅力的源頭。
(本文來源于《EEPW》202406)
評論