綜述:如何給模型加入先驗(yàn)知識
作者丨Billy Z@知乎(已授權(quán))
來源丨h(huán)ttps://zhuanlan.zhihu.com/p/188572028
編輯丨極市平臺
導(dǎo)讀
端到端的深度神經(jīng)網(wǎng)絡(luò)雖然能夠自動學(xué)習(xí)到一些可區(qū)分度好的特征,但是往往會擬合到一些非重要特征,導(dǎo)致模型會局部坍塌到一些不好的特征上面。本文通過一個簡單的鳥類分類案例來總結(jié)了五個給模型加入先驗(yàn)信息的方法。
模型加入先驗(yàn)知識的必要性
端到端的深度神經(jīng)網(wǎng)絡(luò)是個黑盒子,雖然能夠自動學(xué)習(xí)到一些可區(qū)分度好的特征,但是往往會擬合到一些非重要特征,導(dǎo)致模型會局部坍塌到一些不好的特征上面。常常一些人們想讓模型去學(xué)習(xí)的特征模型反而沒有學(xué)習(xí)到。
為了解決這個問題,給模型加入人為設(shè)計的先驗(yàn)信息會讓模型學(xué)習(xí)到一些關(guān)鍵的特征。下面就從幾個方面來談?wù)勅绾谓o模型加入先驗(yàn)信息。
為了方便展示,我這邊用一個簡單的分類案例來展示如何把先驗(yàn)知識加入到一個具體的task中。我們的task是在所有的鳥類中識別出一種萌萌的鸚鵡,這中鸚鵡叫鸮(xiāo)鸚鵡,它長成下面的樣子:
鸮(xiāo)鸚鵡
這種鳥有個特點(diǎn):
就是它可能出現(xiàn)在任何地方,但就是不可能在天上,因?yàn)樗鞘澜缟衔ㄒ灰环N不會飛的鸚鵡(不是唯一一種不會飛的鳥)。
好,介紹完task的背景,咱們就可以分分鐘搭建一個端到端的分類神經(jīng)網(wǎng)絡(luò),可以選擇的網(wǎng)絡(luò)結(jié)構(gòu)可以有很多,如resnet, mobilenet等等,loss往往是一個常用的分類Loss,如交叉熵,高級一點(diǎn)的用個focal loss等等。確定好了最優(yōu)的數(shù)據(jù)(擾動方式),網(wǎng)絡(luò)結(jié)構(gòu),優(yōu)化器,學(xué)習(xí)率等等這些之后,往往模型的精度也就達(dá)到了一個上限。
然后你測試模型發(fā)現(xiàn),有些困難樣本始終分不開,或者是一些簡單的樣本也容易分錯。這個時候如果你還想提升網(wǎng)絡(luò)的精度,可以通過給模型加入先驗(yàn)的方式來進(jìn)一步提升模型的精度。
基于pretrain模型給模型加入先驗(yàn)
給模型加入先驗(yàn),大家最容易想到的是把網(wǎng)絡(luò)的weight替換成一個在另外一個任務(wù)上pretrain好的模型weight。經(jīng)過的預(yù)訓(xùn)練的模型(如ImageNet預(yù)訓(xùn)練)往往已經(jīng)具備的識別到一些基本的圖片pattern的能力,如邊緣,紋理,顏色等等,而識別這些信息的能力是識別一副圖片的基礎(chǔ)。如下圖所示:
但這些先驗(yàn)信息都是一些比較general的信息,我們是否可以加入一些更加high level的先驗(yàn)信息呢。
基于輸入給模型加入先驗(yàn)
假如你有這樣的一個先驗(yàn):
你覺得鸮鸚鵡的頭是一個區(qū)別其他它和鳥類的重要部分,也就是說相比于身體,它的頭部更能區(qū)分它和其他鳥類。
這時怎么讓網(wǎng)絡(luò)更加關(guān)注鸮鸚鵡的頭部呢。這時你可以這樣做,把整個鸮鸚鵡和它的頭部作為一個網(wǎng)絡(luò)的兩路輸入,在網(wǎng)咯的后端再把兩路輸入的信息融合。以達(dá)到既關(guān)注局域,又關(guān)注整體的目的。一個簡單的示意圖如下所示。
基于模型重現(xiàn)給模型加入先驗(yàn)
接著上面的設(shè)定來,假如說你覺得給模型兩路輸入太麻煩,而且增加的計算量讓你感覺很不爽。
這時,你可以嘗試讓模型自己發(fā)現(xiàn)你設(shè)定的先驗(yàn)知識。
假如說你的模型可以自己輸出鳥類頭部的位置,雖然這個鳥類頭部的位置信息是你不需要的,但是輸出這樣的信息代表著你的網(wǎng)絡(luò)能夠locate鳥類頭部的位置,也就給鳥類的頭部更加多的attention,也就相當(dāng)于給把鳥類頭部這個先驗(yàn)信息給加上去了。
當(dāng)然直接模仿detection那樣去回歸出位置來這個任務(wù)太heavy了,你可以通過一個生成網(wǎng)絡(luò)的支路來生成一個鳥類頭部位置的Mask,一個簡單的示意圖如下:
測試的時候不增加計算量
基于CAM圖激活限制給模型加入先驗(yàn)
針對鸮鸚鵡的分類,我在上面的提到一個非常有意思的先驗(yàn)信息:
那就是鸮鸚鵡是世界上唯一一種不會飛的鸚鵡。
這個信息從側(cè)面來說就是,鸮鸚鵡所有地方都可能出現(xiàn),就是不可能出現(xiàn)在天空中(當(dāng)然也不可能出現(xiàn)在水中)。
也就是說不但鸮鸚鵡本身是一個分類的重點(diǎn),鸮鸚鵡出現(xiàn)的背景也是分類的一個重要參考。假如說背景是天空,那么就一定不是鸮鸚鵡,同樣的,假如說背景是海水,那么也一定不是鸮鸚鵡,假如說背景是北極,那么也一定不是鸮鸚鵡,等等。
也就是說,你不能通過背景來判斷一只未知的鳥是鸮鸚鵡,但是你能通過背景來判斷一只未知的鳥肯定不是鸮鸚鵡(是其他的鳥類)。
所以假如說獲取了一張輸入圖片的激活圖(包含背景的),那么這張激活圖的鳥類身體部分肯定包含了鸮鸚鵡和其他鳥類的激活,但是鳥類身體外的背景部分只可能包含其他鳥類的激活。
所以具體的做法是基于激活圖,通過限制激活圖的激活區(qū)域,加入目標(biāo)先驗(yàn)。
CAM[1]激活圖是基于分類網(wǎng)絡(luò)的倒數(shù)第二層卷積層的輸出的 feature_map 的線性加權(quán),權(quán)重就是最后一層分類層的權(quán)重,由于分類層的權(quán)重編碼了類別的信息,所以加權(quán)后的響應(yīng)圖就有了基于不同類別的區(qū)域相應(yīng)。(具體的介紹可以看 https://zhuanlan.zhihu.com/p/51631163),具體的激活圖生成方式可以如下表示:
說了這么多,下面就展示展示激活圖的樣子:
大家可以看到,上面一張是一只鸮鸚鵡的激活圖,下面是一只在天空飛翔的大雁的激活圖。
因?yàn)辂^鸚鵡的Label是0,其他鳥類的Label是1,所以在激活圖上,只要是負(fù)值的激活區(qū)域都是鸮鸚鵡的激活,也就是Label為0的激活,只要是正值的激活都是其他鳥類的激活,也就是Label為1的激活。
為了方便展示,我把負(fù)值的激活用冷色調(diào)來顯示,把正值的激活用暖色調(diào)來顯示,所以就是變成了上面兩幅激活圖的樣子。而右邊的數(shù)字是具體的激活矩陣(把激活矩陣進(jìn)行GAP就可以變成最終輸出的Logits)。
到這里不知道大家有沒有發(fā)現(xiàn)一個問題,就是無論對于鸮鸚鵡還是大雁的圖片,它們的激活圖除了分布在鳥類本身,也會有一部分分布在背景上。 對于大雁我們好理解,因?yàn)榇笱闶秋w在天空中的,而鸮鸚鵡是不可能在天空中的,所以天空的正激活是非常合理的。但是對于鸮鸚鵡來說,其在鳥類身體以外的負(fù)激活就不是太合理,因?yàn)?,大雁或者是其他的鳥類,也可能在鸮鸚鵡的地面棲息環(huán)境中(但是鸮鸚鵡卻不可能在天空中)。
所以環(huán)境不能提供任何證據(jù)來證明這一次鳥類是一只鸮鸚鵡,鸮鸚鵡的負(fù)激活只是在鳥類的身體上是合理的。而其他鳥類的正激活卻可以同時在鳥類身體上又可能在鳥類的背景上(如天空或者海洋)。
所以我們需要這樣建模這個問題,就是在除鳥類身體的背景上,不能出現(xiàn)鸮鸚鵡的激活,也就是說不能出現(xiàn)負(fù)激活(Label為0的激活)。 所以下面的激活才是合理的:
從上面來看,在除鳥類身體外的背景部分是不存在負(fù)激活的,雖然上面的背景部分有一些正的激活(其他鳥類的激活),但是從右邊的激活矩陣來看,負(fù)激活的scale是占據(jù)絕對優(yōu)勢的,所以完全不會干擾對于鸮鸚鵡的判斷。
所以問題來了,怎么從網(wǎng)絡(luò)設(shè)計方面來達(dá)到這個目的呢?
其實(shí)可以從Loss設(shè)計方面來達(dá)到這個效果。我們假設(shè)每一個鳥都有個對應(yīng)的mask,mask內(nèi)是鳥類的身體部分,mask外是鳥類的背景部分。那么我們需要做的就是抑制mask外的背景部分激活矩陣的負(fù)值,把那一部分負(fù)值給抑制到0即可。
鳥類的激活矩陣和mask的關(guān)系如下圖(紅色的曲線代表鳥的邊界mask):
我們的Loss設(shè)計可以用下面的公式表示:
Loss_cam = -sum(where(bird_mask_outside<0))
具體的網(wǎng)絡(luò)的framework可以如下所示:
其中虛線部分只是訓(xùn)練時候需要用到,inference的時候是不需要的,所以這種方法也是不會占用任何在inference前向時候的計算量。
基于輔助學(xué)習(xí)給模型加入先驗(yàn)知識
到現(xiàn)在為止,咱們還只是把我們的鳥類分類的task當(dāng)成一個二分類來處理,即鸮鸚鵡是一類,其他的鳥類是一類。
但是我們知道,世界的鳥類可不僅僅是兩類,除了鸮鸚鵡之外還有很多種類的鳥類。而不同鳥類的特征或許有很大的差別,比如鴕鳥的特征就是脖子很長,大雁的特征就是翅膀很大。
假如只是把鸮鸚鵡當(dāng)做一類,把其他的鳥類當(dāng)做一類來學(xué)習(xí)的話,那么模型很可能不能學(xué)到可以利用的區(qū)分非鸮鸚鵡的特征,或者是會坍塌到一些區(qū)分度不強(qiáng)的特征上面,從而沒有學(xué)到能夠很好的區(qū)分不同其他鳥類的特征,而那些特征對去區(qū)別鸮鸚鵡和其他鳥類或許是重要的。
所以我們有必要加入其他鳥類存在不同類別的先驗(yàn)知識。而這里,我主要介紹基于輔助學(xué)習(xí)的方式去學(xué)習(xí)類似的先驗(yàn)知識。首先我要解釋一下什么是輔助學(xué)習(xí),以及輔助學(xué)習(xí)和多任務(wù)學(xué)習(xí)的區(qū)別:
上圖的左側(cè)是多任務(wù)學(xué)習(xí)的例子,右側(cè)是輔助學(xué)習(xí)的例子。左側(cè)是個典型的face attribute的task,意思是輸入一張人臉,通過多個branch來輸出這一張人臉的年齡,性別,發(fā)型等等信息,各個branch的任務(wù)是獨(dú)立的,同時又共享同一個backbone。右邊是一個典型的輔助學(xué)習(xí)的task,意思是出入一張人臉,判斷這一張人臉的性別,同時另外開一個(或幾個)branch,通過這個branch來讓網(wǎng)絡(luò)學(xué)一些輔助信息,比如發(fā)型,皮膚等等,來幫助網(wǎng)絡(luò)主任務(wù)(分男女)的判別。
好,回到我們的鸮鸚鵡分類的task,我們可能首先會想到下面的Pipeline:
這樣雖然可以把不同類別的鳥類的特征都學(xué)到,但是卻削弱了網(wǎng)絡(luò)對于鸮鸚鵡和其他鳥類特征的分別。
經(jīng)過實(shí)驗(yàn)發(fā)現(xiàn),這種網(wǎng)絡(luò)架構(gòu)不能很好的增加主任務(wù)的分類精度。為了充分的學(xué)到鸮鸚鵡和其他鳥類特征的分別,同時又能帶入不同種類鳥類類別的先驗(yàn),我們引入輔助任務(wù):
在上面的Pipeline中,輔助任務(wù)相比如主任務(wù),把其他鳥類做更加細(xì)致的分類。這樣網(wǎng)絡(luò)就學(xué)到了區(qū)分不同其他鳥類的能力。
但是從實(shí)驗(yàn)效果來看這個Pipeline的精度并不高。經(jīng)過分析原因,發(fā)現(xiàn)在主任務(wù)和輔助任務(wù)里面都有鸮鸚鵡這一類,這樣當(dāng)回傳梯度的時候,相當(dāng)于把區(qū)分鸮鸚鵡和其他鳥類的特征回傳了兩次梯度,而回傳兩次梯度明顯是沒用的,而且會干擾輔助任務(wù)學(xué)習(xí)不同其他鳥類的特征。
所以我們可以把輔助任務(wù)的鸮鸚鵡類去除,于是便形成了下面的pipeline:
經(jīng)過實(shí)驗(yàn)發(fā)現(xiàn),這種pipeline是有利于主任務(wù)精度提升的,網(wǎng)絡(luò)對于特征明顯的其他鳥類的分類能力得到了一定程度的提升,同時對于困難類別的分類能力也有一定程度的提升。
當(dāng)然,輔助任務(wù)的branch可以不只是一類,你可以通過多個類別來定義你的輔助任務(wù)的branch:
這時候你會想,上面的pipeline好是好,但是我沒有那么多的label啊。是的,上面的pipeline除了主任務(wù)的label標(biāo)注,它還同時需要很多的輔助任務(wù)的label標(biāo)注,而標(biāo)注label是深度學(xué)習(xí)任務(wù)里面最讓人頭疼的問題(之一)。
別怕,我下面介紹一個work,它基于meta-learning的方法,讓你不再為給輔助任務(wù)標(biāo)注label而煩惱,它的framework如下:
這個framework采用基于maxl[2]的方案(https://github.com/lorenmt/maxl),輔助任務(wù)的數(shù)據(jù)和label不是由人為手工劃分,而是由一個label generator來產(chǎn)生,label generator的優(yōu)化目標(biāo)是讓主網(wǎng)絡(luò)在主任務(wù)的task上的loss降低,主網(wǎng)絡(luò)的目標(biāo)是在主任務(wù)和輔助任務(wù)上的loss同時降低。
但是這個framework有個缺點(diǎn),就是訓(xùn)練時間會上升一個數(shù)量級,同時label generator會比較難優(yōu)化。感興趣的同學(xué)可以自己嘗試。但是不得不說,這篇文章有兩個結(jié)論倒是很有意思:
假設(shè) primary 和 auxiliary task 是在同一個 domain,那么 primary task 的 performance 會提高當(dāng)且僅當(dāng) auxiliary task 的 complexity 高于 primary task。
假設(shè) primary 和 auxiliary task 是在同一個 domain,那么 primary task 的最終 performance 只依賴于 complexity 最高的 auxiliary task。
結(jié)語
先總結(jié)一下所有可以有效的加入先驗(yàn)信息的框架:
你可以通過上述框架的選擇來加入自己的先驗(yàn)信息。
給神經(jīng)網(wǎng)絡(luò)的黑盒子里面加入一些人為設(shè)定的先驗(yàn)知識,這樣往往能給你的task帶來一定程度的提升,不過具體的task需要加入什么樣的先驗(yàn)知識,需要如何加入先驗(yàn)知識還需要自己探索。
來自我自己的博客:https://zhengtq.github.io/2020/07/30/pri-knowledge-1/
參考
^CAM https://arxiv.org/abs/1512.04150
^maxl https://arxiv.org/abs/1901.08933
本文僅做學(xué)術(shù)分享,如有侵權(quán),請聯(lián)系刪文。
*博客內(nèi)容為網(wǎng)友個人發(fā)布,僅代表博主個人觀點(diǎn),如有侵權(quán)請聯(lián)系工作人員刪除。