Transformer的細(xì)節(jié)到底是怎么樣的?Transformer 18問(wèn)?。?)
作者丨愛(ài)問(wèn)問(wèn)題的王宸@知乎
為什么想通過(guò)十八個(gè)問(wèn)題的方式總結(jié)Transformer?
有兩點(diǎn)原因:
第一,Transformer是既MLP、RNN、CNN之后的第四大特征提取器,也被稱為第四大基礎(chǔ)模型;最近爆火的chatGPT,其最底層原理也是Transformer,Transformer的重要性可見(jiàn)一斑。
第二,希望通過(guò)問(wèn)問(wèn)題這種方式,更好的幫助大家理解Transformer的內(nèi)容和原理。
1.2017年深度學(xué)習(xí)領(lǐng)域的重大突破是什么?Transformer。有兩方面的原因:
1.1 一方面,Transformer是深度學(xué)習(xí)領(lǐng)域既MLP、RNN、CNN之后的第4大特征提取器(也被稱為基礎(chǔ)模型)。什么是特征提取器?大腦是人與外部世界(圖像、文字、語(yǔ)音等)交互的方式;特征提取器是計(jì)算機(jī)為了模仿大腦,與外部世界(圖像、文字、語(yǔ)音等)交互的方式,如圖1所示。舉例而言:Imagenet數(shù)據(jù)集中包含1000類(lèi)圖像,人們已經(jīng)根據(jù)自己的經(jīng)驗(yàn)把這一百萬(wàn)張圖像分好1000類(lèi),每一類(lèi)圖像(如美洲豹)都有獨(dú)特的特征。這時(shí),神經(jīng)網(wǎng)絡(luò)(如ResNet18)也是想通過(guò)這種分類(lèi)的方式,把每一類(lèi)圖像的特有特征盡可能提取或識(shí)別出來(lái)。分類(lèi)不是最終目的,而是一種提取圖像特征的手段,掩碼補(bǔ)全圖像也是一種提取特征的方式,圖像塊順序打亂也是一種提取特征的方式。
圖1 神經(jīng)網(wǎng)絡(luò)為了模仿大腦中的神經(jīng)元1.2 另一方面,Transformer在深度學(xué)習(xí)領(lǐng)域扮演的角色:第3次和第4次熱潮的基石,如下圖2所示。
圖2 深度學(xué)習(xí)發(fā)展的4個(gè)階段
2. Transformer的提出背景是什么?
2.1 在領(lǐng)域發(fā)展背景層面:當(dāng)時(shí)時(shí)處2017年,深度學(xué)習(xí)在計(jì)算機(jī)視覺(jué)領(lǐng)域火了已經(jīng)幾年。從Alexnet、VGG、GoogLenet、ResNet、DenseNet;從圖像分類(lèi)、目標(biāo)檢測(cè)再到語(yǔ)義分割;但在自然語(yǔ)言處理領(lǐng)域并沒(méi)有引起很大反響。
2.2 技術(shù)背景層面:(1)當(dāng)時(shí)主流的序列轉(zhuǎn)錄任務(wù)(如機(jī)器翻譯)的解決方案如下圖3所示,在Sequence to Sequence架構(gòu)下(Encoder- Decoder的一種),RNN來(lái)提取特征,Attention機(jī)制將Encoder提取到的特征高效傳遞給Decoder。(2)這種做法有兩個(gè)不足之處,一方面是在提取特征時(shí)的RNN天生從前向后時(shí)序傳遞的結(jié)構(gòu)決定了其無(wú)法并行運(yùn)算,其次是當(dāng)序列長(zhǎng)度過(guò)長(zhǎng)時(shí),最前面序列的信息有可能被遺忘掉。因此可以看到,在這個(gè)框架下,RNN是相對(duì)薄弱急需改進(jìn)的地方。
圖3 序列轉(zhuǎn)錄任務(wù)的主流解決方案3. Transformer到底是什么?
3.1 Transformer是一種由Encoder和Decoder組成的架構(gòu)。那么什么是架構(gòu)呢?最簡(jiǎn)單的架構(gòu)就是A+B+C。
3.2 Transformer也可以理解為一個(gè)函數(shù),輸入是“我愛(ài)學(xué)習(xí)”,輸出是“I love study”。
3.3 如果把Transformer的架構(gòu)進(jìn)行分拆,如圖4所示。
圖4 Transformer的架構(gòu)圖4. 什么是Transformer Encoder?
4.1 從功能角度,Transformer Encoder的核心作用是提取特征,也有使用Transformer Decoder來(lái)提取特征。例如,一個(gè)人學(xué)習(xí)跳舞,Encoder是看別人是如何跳舞的,Decoder是將學(xué)習(xí)到的經(jīng)驗(yàn)和記憶,展現(xiàn)出來(lái)
4.2 從結(jié)構(gòu)角度,如圖5所示,Transformer Encoder = Embedding + Positional Embedding + N*(子Encoder block1 + 子Encoder block2);
子Encoder block1 = Multi head attention + ADD + Norm;
子Encoder block2 = Feed Forward + ADD + Norm;
4.3 從輸入輸出角度,N個(gè)Transformer Encoder block中的第一個(gè)Encoder block的輸入為一組向量 X = (Embedding + Positional Embedding),向量維度通常為512*512,其他N個(gè)TransformerEncoder block的輸入為上一個(gè) Transformer Encoder block的輸出,輸出向量的維度也為512*512(輸入輸出大小相同)。
4.4 為什么是512*512?前者是指token的個(gè)數(shù),如“我愛(ài)學(xué)習(xí)”是4個(gè)token,這里設(shè)置為512是為了囊括不同的序列長(zhǎng)度,不夠時(shí)padding。后者是指每一個(gè)token生成的向量維度,也就是每一個(gè)token使用一個(gè)序列長(zhǎng)度為512的向量表示。人們常說(shuō),Transformer不能超過(guò)512,否則硬件很難支撐;其實(shí)512是指前者,也就是token的個(gè)數(shù),因?yàn)槊恳粋€(gè)token要做self attention操作;但是后者的512不宜過(guò)大,否則計(jì)算起來(lái)也很慢。
圖5 Transformer Encoder的架構(gòu)圖5. 什么是Transformer Decoder?
5.1 從功能角度,相比于Transformer Encoder,Transformer Decoder更擅長(zhǎng)做生成式任務(wù),尤其對(duì)于自然語(yǔ)言處理問(wèn)題。
5.2 從結(jié)構(gòu)角度,如圖6所示,Transformer Decoder = Embedding + Positional Embedding + N*(子Decoder block1 + 子Decoder block2 + 子Decoder block3)+ Linear + Softmax;
子Decoder block1 = Mask Multi head attention + ADD + Norm;子Decoder block2 = Multi head attention + ADD + Norm;子Decoder block3 = Feed Forward + ADD + Norm;圖6 Transformer Decoder的架構(gòu)圖
5.3 從(Embedding+Positional Embedding)(N個(gè)Decoder block)(Linear + softmax) 這三個(gè)每一個(gè)單獨(dú)作用角度:
Embedding + Positional Embedding :以機(jī)器翻譯為例,輸入“Machine Learning”,輸出“機(jī)器學(xué)習(xí)”;這里的Embedding是把“機(jī)器學(xué)習(xí)”也轉(zhuǎn)化成向量的形式。
N個(gè)Decoder block:特征處理和傳遞過(guò)程。
Linear + softmax:softmax是預(yù)測(cè)下一個(gè)詞出現(xiàn)的概率,如圖7所示,前面的Linear層類(lèi)似于分類(lèi)網(wǎng)絡(luò)(ResNet18)最后分類(lèi)層前接的MLP層。
圖7 Transformer Decoder 中softmax的作用5.4 Transformer Decoder的輸入、輸出是什么?在Train和Test時(shí)是不同的。在Train階段,如圖8所示。這時(shí)是知道label的,decoder的第一個(gè)輸入是begin字符,輸出第一個(gè)向量與label中第一個(gè)字符使用cross entropy loss。Decoder的第二個(gè)輸入是第一個(gè)向量的label,Decoder的第N個(gè)輸入對(duì)應(yīng)的輸出是End字符,到此結(jié)束。這里也可以看到,在Train階段是可以進(jìn)行并行訓(xùn)練的。
圖8 Transformer Decoder在訓(xùn)練階段的輸入和輸出
在Test階段,下一個(gè)時(shí)刻的輸入時(shí)是前一個(gè)時(shí)刻的輸出,如圖9所示。因此,Train和Test時(shí)候,Decoder的輸入會(huì)出現(xiàn)Mismatch,在Test時(shí)候確實(shí)有可能會(huì)出現(xiàn)一步錯(cuò),步步錯(cuò)的情況。有兩種解決方案:一種是train時(shí)偶爾給一些錯(cuò)誤,另一種是Scheduled sampling。
圖9 Transformer Decoder在Test階段的輸入和輸出
5.5 Transformer Decoder block內(nèi)部的輸出和輸出是什么?
前面提到的是在整體train和test階段,Decoder的輸出和輸出,那么Transformer Decoder內(nèi)部的Transformer Decoder block,如圖10所示,的輸入輸出又是什么呢?
圖10 Transformer Decoder block的架構(gòu)圖
對(duì)于N=6中的第1次循環(huán)(N=1時(shí)):子Decoder block1 的輸入是 embedding +Positional Embedding,子Decoder block2 的輸入的Q來(lái)自子Decoder block1的輸出,KV來(lái)自Transformer Encoder最后一層的輸出。
對(duì)于N=6的第2次循環(huán):子Decoder block1的輸入是N=1時(shí),子Decoder block3的輸出,KV同樣來(lái)自Transformer Encoder的最后一層的輸出。
總的來(lái)說(shuō),可以看到,無(wú)論在Train還是Test時(shí),Transformer Decoder的輸入不僅來(lái)自(ground truth或者上一個(gè)時(shí)刻Decoder的輸出),還來(lái)自Transformer Encoder的最后一層。
訓(xùn)練時(shí):第i個(gè)decoder的輸入 = encoder輸出 + ground truth embedding。
預(yù)測(cè)時(shí):第i個(gè)decoder的輸入 = encoder輸出 + 第(i-1)個(gè)decoder輸出.
*博客內(nèi)容為網(wǎng)友個(gè)人發(fā)布,僅代表博主個(gè)人觀點(diǎn),如有侵權(quán)請(qǐng)聯(lián)系工作人員刪除。