CVPR 2022|解耦知識(shí)蒸餾,讓Hinton在7年前提出的方法重回SOTA行列
與主流的feature蒸餾方法不同,本研究將重心放回到logits蒸餾上,提出了一種新的方法「解耦知識(shí)蒸餾」,重新達(dá)到了SOTA結(jié)果,為保證復(fù)現(xiàn)該研究還提供了開源的蒸餾代碼庫(kù):MDistiller。
1 研究摘要
近年來(lái)頂會(huì)的 SOTA 蒸餾方法多基于 CNN 的中間層特征,而基于輸出 logits 的方法被嚴(yán)重忽視了。飲水思源,本文中來(lái)自曠視科技 (Megvii)、早稻田大學(xué)、清華大學(xué)的研究者將研究重心放回到 logits 蒸餾上,對(duì) 7 年前 Hinton 提出的知識(shí)蒸餾方法(Knowledge Distillation,下文簡(jiǎn)稱 KD)[1] 進(jìn)行了解耦和分析,發(fā)現(xiàn)了一些限制 KD 性能的重要因素,進(jìn)而提出了一種新的方法「解耦知識(shí)蒸餾」(Decoupled Knowledge Distillation,下文簡(jiǎn)稱 DKD)[2],使得 logits 蒸餾重回 SOTA 行列。
同時(shí),為了保證復(fù)現(xiàn)和支持進(jìn)一步研究,該研究提供了一個(gè)全新的開源代碼庫(kù) MDistiller,該庫(kù)涵蓋了 DKD 和大部分的 SOTA 方法,并不斷進(jìn)行更新維護(hù),歡迎大家試用并提供寶貴的反饋意見。
論文鏈接:https://arxiv.org/abs/2203.08679
代碼鏈接:https://github.com/megvii-research/mdistiller
2 研究動(dòng)機(jī)
上圖是大家熟知的 KD 方法,KD 用 Teacher 網(wǎng)絡(luò)和 Student 網(wǎng)絡(luò)的輸出 logits 來(lái)計(jì)算 KL Loss,從而實(shí)現(xiàn) dark knowledge 的傳遞,利用 Teacher 已經(jīng)學(xué)到的知識(shí)幫助 Student 收斂得更好。在 KD 之后,更多的基于中間特征的蒸餾方法不斷涌現(xiàn),不斷刷新知識(shí)蒸餾的 SOTA。但該研究認(rèn)為,KD 這樣的 logits 蒸餾方法具備兩點(diǎn)好處:
1. 基于 feature 的蒸餾方法需要更多復(fù)雜的結(jié)構(gòu)來(lái)拉齊特征的尺度和網(wǎng)絡(luò)的表示能力,而 logits 蒸餾方法更簡(jiǎn)單高效;2. 相比中間 feature,logits 的語(yǔ)義信息是更 high-level 且更明確的,基于 logits 信號(hào)的蒸餾方法也應(yīng)該具備更高的性能上限,因此,對(duì) logits 蒸餾進(jìn)行更多的探索是有意義的。
該研究嘗試一種拆解的方法來(lái)更深入地分析 KD:將 logits 分成兩個(gè)部分(如圖),藍(lán)色部分代表目標(biāo)類別(target class)的 score,綠色部分代表非目標(biāo)類別(Non-target class)的 score。這樣的拆解使得我們可以重新推導(dǎo) KD 的 Loss 公式,得到一個(gè)新的等價(jià)表達(dá)式,進(jìn)而做更多的實(shí)驗(yàn)和分析。
2.1 符號(hào)定義
這里只寫出關(guān)鍵符號(hào)定義,更具體的定義請(qǐng)參考論文正文。
首先,該研究將第 i 類的分類概率表示為(其中表示網(wǎng)絡(luò)輸出的 logits):
為了拆解分類網(wǎng)絡(luò)輸出的 logits,該研究接下來(lái)定義了兩種新的概率分布:
1. 目標(biāo)類 vs 非目標(biāo)類的二分類分布,該概率分布和分類監(jiān)督信號(hào)高度耦合。該分布包含兩個(gè)元素:目標(biāo)類概率和全部非目標(biāo)類概率,分別表示為:
2. 非目標(biāo)類內(nèi)部競(jìng)爭(zhēng)的多分類分布,也就是在預(yù)測(cè)樣本為非目標(biāo)類的前提下每個(gè)類各自的概率(總和為 1)。這個(gè)概率分布和分類的監(jiān)督信號(hào)是不相關(guān)的,換句話說(shuō),從這個(gè)概率分布中無(wú)法得知目標(biāo)類上的預(yù)測(cè)置信度,其表達(dá)式為:
根據(jù)上述定義,可以得到一個(gè)顯然的數(shù)學(xué)關(guān)系:。這些定義和數(shù)學(xué)關(guān)系將幫助我們得到 KD Loss 的一個(gè)新的表達(dá)形式。
2.2 重新推導(dǎo) KD Loss
首先,KD 的 Loss 定義如下:
然后根據(jù)公式(1)和(2),我們可以將其改寫為:
可以觀察到,式中的第一項(xiàng)只牽涉到了目標(biāo)類別 vs 非目標(biāo)類別的二分類概率分布,第二項(xiàng)牽涉到了非目標(biāo)類概率分布的 KL 散度和權(quán)重。該研究將第一項(xiàng)命名為目標(biāo)類別知識(shí)蒸餾 Target Class Knowledge Distillation(下文簡(jiǎn)稱 TCKD),將第二項(xiàng)中的 KL 散度命名為非目標(biāo)類別知識(shí)蒸餾 Non-target Class Knowledge Distillation(下文簡(jiǎn)稱 NCKD)。至此,該研究完成了對(duì) KD Loss 的拆分,將其分成了兩個(gè)可單獨(dú)使用的部分,并可以分析其各自的作用:
3 啟發(fā)式探索
首先,該研究對(duì) TCKD 和 NCKD 做了消融實(shí)驗(yàn),觀察它們對(duì)蒸餾性能的影響;接著,他們分別探索 TCKD 和 NCKD 的作用;最后,研究者做了一些啟發(fā)式的討論。
3.1 單獨(dú)使用 TCKD/NCKD 訓(xùn)練
如表 1 所示,我們可以觀察到:
1. 同時(shí)使用 TCKD 和 NCKD(等同于 KD),有不錯(cuò)的性能提升;2. 單獨(dú)使用 TCKD 進(jìn)行蒸餾,會(huì)對(duì)蒸餾效果產(chǎn)生較大的損害(這一點(diǎn)在補(bǔ)充材料中有詳細(xì)討論,主要和蒸餾溫度 T 相關(guān));3. 單獨(dú)使用 NCKD 進(jìn)行蒸餾,和 KD 的效果是差不多的,甚至有時(shí)會(huì)更好;
基于這些觀察可以推出兩個(gè)初步結(jié)論:
1.TCKD 是沒用的,甚至在單獨(dú)使用時(shí)可能是有害的;2.NCKD 可能是 KD 生效的主要原因;
接下來(lái)該研究就這兩個(gè)初步的結(jié)論進(jìn)行了進(jìn)一步的分析。
3.2 TCKD:傳遞樣本難度相關(guān)的知識(shí)
TCKD 作用于目標(biāo)類的二分類概率分布上,這個(gè)概率的物理含義是「網(wǎng)絡(luò)對(duì)樣本的置信度」。比如:如果一個(gè)樣本被 Teacher 學(xué)會(huì)了,會(huì)產(chǎn)生類似[0.99, 0.01] 的 binary 概率,而如果一個(gè)樣本比較難擬合,則會(huì)產(chǎn)生類似 [0.6, 0.4] 的 binary 概率。所以該研究猜測(cè):TCKD 傳遞了和樣本擬合難度相關(guān)的知識(shí),當(dāng)訓(xùn)練集擬合難度高時(shí)才會(huì)起到作用。為了證明這一點(diǎn),該研究設(shè)計(jì)了三組實(shí)驗(yàn)來(lái)增加 CIFAR-100 的訓(xùn)練難度,觀察 TCKD 是否有效:
更強(qiáng)的數(shù)據(jù)增廣:
以表 2 中的 ShuffleNet-V1 為例,在使用 AutoAugment 的情況下,訓(xùn)練集難度有了明顯提升,此時(shí)僅僅使用 NCKD 只能達(dá)到 73.8% 的 student 準(zhǔn)確率,而同時(shí)使用 TCKD 和 NCKD 可以將 student 準(zhǔn)確率提升至 75.3%。
更 Noisy 的標(biāo)簽:
表 3 中,該研究通過(guò)控制 noisy ratio 對(duì)數(shù)據(jù)集的標(biāo)簽引入不同程度噪聲,ratio 越大表示噪聲越大??梢钥吹?,隨著數(shù)據(jù)集的噪聲變大,單獨(dú)使用 NCKD 的效果變得越來(lái)越差,同時(shí)引入 TCKD 的增益也越來(lái)越大。說(shuō)明在越難學(xué)的數(shù)據(jù)上,TCKD 的作用就會(huì)越明顯。
更難的數(shù)據(jù)集:
ImageNet 是一個(gè)比 CIFAR-100 更困難的數(shù)據(jù)集,所以該研究在 ImageNet 上也進(jìn)行了嘗試。從表 4 可以看出,在 ImageNet 上只使用 NCKD 的效果也是沒有同時(shí)使用 TCKD 和 NCKD 要好的。
總結(jié)
三組實(shí)驗(yàn)都反映出,當(dāng)訓(xùn)練數(shù)據(jù)擬合難度變高時(shí)(無(wú)論是數(shù)據(jù)本身難度、還是噪聲和增廣帶來(lái)的難度),TCKD 能提供更有效的知識(shí),對(duì)蒸餾性能的提升也越高,這些實(shí)驗(yàn)在一定程度上說(shuō)明了 TCKD 確實(shí)是在傳遞有關(guān)樣本擬合難度的知識(shí),印證了該研究的想法。
3.3 NCKD:被抑制的重要成分
表 1 中反映出的另一個(gè)有趣的現(xiàn)象是:只使用 NCKD 也能取得令人滿意的蒸餾效果,甚至可能比 KD 更好。這樣的現(xiàn)象反映出:非目標(biāo)類別上的 logits 中蘊(yùn)含的信息,才是最主要的 dark knowledge 成分。
然而當(dāng)回顧 KD 的新表達(dá)式時(shí),發(fā)現(xiàn) NCKD 對(duì)應(yīng)的 loss 是和權(quán)重耦合在一起的。換言之,如果 teacher 網(wǎng)絡(luò)的預(yù)測(cè)越置信,NCKD 的 loss 權(quán)重就更低,其作用就會(huì)越小。而該研究認(rèn)為,teacher 更置信的樣本能夠提供更有益的 dark knowledge,和 NCKD 耦合的權(quán)重會(huì)嚴(yán)重抑制高置信度樣本的知識(shí)遷移,使得知識(shí)蒸餾的效率大幅降低。為了證明這一點(diǎn),該研究做了如下實(shí)驗(yàn):
1. 依據(jù) teacher 模型的置信度,該研究對(duì)訓(xùn)練集上的樣本做了排序,并將排序后的樣本分成置信(置信度 top-50%)和非置信 (剩余) 兩個(gè)批次;2. 訓(xùn)練時(shí),對(duì)全部樣本使用分類 Loss,并只對(duì)置信批次 / 非置信批次使用 NCKD Loss;
實(shí)驗(yàn)結(jié)果如表 5 所示,0-50% 表示置信批次,50-100% 表示非置信批次。第一行是在整個(gè)訓(xùn)練集上做 NCKD 的結(jié)果,第二行表示只對(duì)置信批次做 NCKD,第三行表示只對(duì)非置信批次做 NCKD。顯然,置信批次上使用 NCKD 帶來(lái)了更主要的漲點(diǎn),說(shuō)明置信度更高的樣本對(duì)蒸餾的訓(xùn)練過(guò)程是更有益的,因此是不應(yīng)該被抑制的。
3.4 啟發(fā)
至此,該研究完成了對(duì) KD Loss 的解耦,并且分析了兩個(gè)部分各自的作用。所有結(jié)果都表明,TCKD 和 NCKD 都有自己的重要作用,然而,研究注意到了在原始的 KD Loss 中,TCKD 和 NCKD 是存在不合理的耦合的:
1. 一方面,NCKD 和耦合,會(huì)導(dǎo)致高置信度樣本的蒸餾效果大打折扣;2. 另一方面,TCKD 和 NCKD 是耦合的。然而這兩個(gè)部分傳遞的知識(shí)是不同的,這樣的耦合導(dǎo)致了他們各自的重要性沒有辦法靈活調(diào)整。
4 Decoupled Knowledge Distillation
根據(jù)推導(dǎo)和啟發(fā)式探索,該研究提出了一種新的 logits 蒸餾方法“解耦知識(shí)蒸餾(DKD)”,來(lái)解決上一章提出的兩個(gè)問(wèn)題,如上圖所示。DKD 的 Loss 表達(dá)式如下:
和 KD Loss 相比,該研究將限制 NCKD 的權(quán)重替換為了,并給 TCKD 設(shè)置了一個(gè)權(quán)重。DKD 可以很好地解決剛才提到的兩個(gè)問(wèn)題:一方面,TCKD 和 NCKD 被解耦,它們各自的重要性可以獨(dú)立調(diào)節(jié);另一方面,對(duì)于蒸餾更重要的 NCKD 也不會(huì)再被 Teacher 產(chǎn)生的高置信度抑制,大大提高了蒸餾的靈活性和有效性。DKD 的偽代碼如下:
5 實(shí)驗(yàn)結(jié)果
5.1 Decoupling 帶來(lái)的好處
首先該研究通過(guò) ablation study 驗(yàn)證了 DKD 的有效性,上面的表格表明:
1. 解耦和 NCKD,也就是把設(shè)置為 1.0,可以將 top-1 accuracy 從 73.6% 提升至 74.8%;2. 解耦 NCKD 和 TCKD 的權(quán)重,即進(jìn)一步調(diào)節(jié)的數(shù)值,可以將 top-1 accuracy 從 74.8% 進(jìn)一步提升至 76.3%;
這些實(shí)驗(yàn)結(jié)果說(shuō)明 DKD 的解耦確實(shí)能帶來(lái)顯著的性能增益,這一方面證明了 KD 確實(shí)存在剛才提到的兩個(gè)問(wèn)題,另一方面也證明了 DKD 的有效性。此外,這個(gè)表格也證明了對(duì)超參數(shù)是不敏感的,把設(shè)置為 1.0 就可以取得令人滿意的效果,所以在實(shí)際應(yīng)用中只需要調(diào)節(jié)即可。同時(shí),也不是一個(gè)敏感的超參數(shù),在 4.0-10.0 的范圍內(nèi),都可以取得令人滿意的蒸餾效果。
5.2 圖像分類
表 6~9 中提供了 DKD 在 CIFAR-100 和 ImageNet-1K 兩個(gè)分類數(shù)據(jù)集上的蒸餾效果。和 KD 相比,DKD 在所有數(shù)據(jù)集和網(wǎng)絡(luò)結(jié)構(gòu)上都有明顯的性能提升。此外,與過(guò)去最好的特征蒸餾方法(ReviewKD)相比,DKD 也取得了接近甚至更好的結(jié)果。DKD 成功使得 logits 蒸餾方法重新回到了 SOTA 的陣營(yíng)中。
5.3 目標(biāo)檢測(cè)
該研究也在目標(biāo)檢測(cè)任務(wù)(MS-COCO)上驗(yàn)證了 DKD 的性能。如表 10 所示,在 Detector 蒸餾中,DKD 的結(jié)果雖不如特征蒸餾的 SOTA 性能,但是依然穩(wěn)定地超過(guò)了 KD 的性能。而將 DKD 和特征蒸餾方法組合起來(lái),也可以進(jìn)一步提高 SOTA 結(jié)果。
關(guān)于這一點(diǎn):過(guò)去的一些工作證明了,Detection 任務(wù)非常依賴特征的定位能力,這在 Detector 蒸餾中也是成立的(如 [5] 中提到了,feature mimicking 是非常重要的)。而 logits 并不能提供 location 相關(guān)的信息,無(wú)法對(duì) Student 的定位能力產(chǎn)生幫助,因此在 Detection 任務(wù)中,特征蒸餾相比 logits 蒸餾存在機(jī)制上的優(yōu)勢(shì),這也是 DKD 無(wú)法超過(guò)特征蒸餾 SOTA 的原因。
6 擴(kuò)展性實(shí)驗(yàn)和可視化
6.1 訓(xùn)練效率
logits 蒸餾的好處之一是訓(xùn)練效率高。為了證明這一點(diǎn),該研究可視化了 SOTA 蒸餾方法的訓(xùn)練開銷。圖 2 的 X 軸是每個(gè) batch 的訓(xùn)練時(shí)間,Y 軸是 student 的 top-1 accuracy。顯然,logits 蒸餾(KD 和 DKD)所需的訓(xùn)練時(shí)間是最少的,并且 DKD 用了最少的時(shí)間獲得了最好的蒸餾效果。圖 2 中的表格也提供了訓(xùn)練時(shí)間和訓(xùn)練所需的額外參數(shù)量,和 KD 一樣,DKD 也并沒有額外引入?yún)?shù)量,同時(shí)訓(xùn)練時(shí)間也幾乎沒有增加。logits 蒸餾的優(yōu)越性顯而易見。
6.2 提升大 Teacher 模型蒸餾效果
過(guò)去的一些蒸餾工作發(fā)現(xiàn)了一個(gè)有趣的現(xiàn)象:大模型并不一定是好的 Teacher 網(wǎng)絡(luò)。對(duì)于該現(xiàn)象,研究者提供了一個(gè)可能的解釋:大模型的 model capacity 很大,這會(huì)導(dǎo)致大模型產(chǎn)生更高的,進(jìn)而導(dǎo)致的 NCKD 被抑制得更嚴(yán)重。過(guò)去的一些工作也可以基于這一點(diǎn)解釋,如 ESKD [4] 引入了 early-stopped teacher 來(lái)緩解這一問(wèn)題,這可能是因?yàn)?early-stopped 模型還沒有充分?jǐn)M合訓(xùn)練集,還比較小,所以對(duì) NCKD 的抑制不是很嚴(yán)重。
為了證明該觀點(diǎn),研究者也進(jìn)行了一系列的對(duì)比實(shí)驗(yàn)。如表 11 和表 12 所示,當(dāng)使用 DKD 時(shí),大模型蒸餾效果變差的問(wèn)題被顯著改善。該研究希望這一點(diǎn)可以為后續(xù)的工作提供一些 insight。
6.3 特征遷移性
這里該研究嘗試將 DKD 訓(xùn)練得到的 student 網(wǎng)絡(luò)進(jìn)行特征遷移。如表 13 所示,研究者將在 CIFAR-100 上訓(xùn)練的 student 遷移到了 STL-10 和 TinyImageNet 兩個(gè)數(shù)據(jù)集上,在眾多的蒸餾方法中,DKD 取得了最好的遷移效果。
6.4 可視化
這里研究者提供了兩種可視化。圖 3 中,與 KD 相比,DKD 的樣本聚得更加緊湊,說(shuō)明 DKD 幫助 student 網(wǎng)絡(luò)學(xué)到了更加可區(qū)分的特征。圖 4 中,研究者計(jì)算了 teacher 網(wǎng)絡(luò)和 student 網(wǎng)絡(luò)輸出 logits 的相似度,和 KD 相比,DKD 訓(xùn)練后的 student 產(chǎn)生的 logits 會(huì)更像 teacher 產(chǎn)生的 logits,說(shuō)明 teacher 的知識(shí)被更好地利用了。
7 改進(jìn)方向
的自適應(yīng)調(diào)整:DKD 目前還需要手工調(diào)整的值才能達(dá)到最佳的蒸餾效果,該研究希望可以通過(guò)一些訓(xùn)練過(guò)程中的統(tǒng)計(jì)量實(shí)現(xiàn)對(duì)的自適應(yīng)調(diào)節(jié)(對(duì)于這一點(diǎn),該研究已經(jīng)有了初步的探索,詳情可見補(bǔ)充材料)。
8 開源代碼庫(kù) MDistiller
為了保證復(fù)現(xiàn)和進(jìn)一步的探索,該研究還開源了一個(gè)知識(shí)蒸餾的 codebase MDistiller。該 codebase 涵蓋了大部分的 SOTA 方法,同時(shí)支持兩種蒸餾關(guān)注的主要任務(wù),圖像分類和目標(biāo)檢測(cè)。該研究希望 MDistiller 可以為后續(xù)的研究者們提供一套可靠的 baseline,因此會(huì)提供長(zhǎng)期支持。
參考文獻(xiàn)
[1] Geoffrey Hinton, Oriol Vinyals, and Jeff Dean. Distilling the knowledge in a neural network. In arXiv:1503.02531, 2015.[2] Borui Zhao, Quan Cui, Renjie Song, Yiyu Qiu, and Jiajun Liang. Decoupled knowledge distillation. In CVPR, 2022. [3] Pengguang Chen, Shu Liu, Hengshuang Zhao, and Jiaya Jia. Distilling knowledge via knowledge review. In CVPR, 2021. [4] Jang Hyun Cho and Bharath Hariharan. On the efficacy of knowledge distillation. In ICCV, 2019. [5] Tao Wang, Li Yuan, Xiaopeng Zhang, and Jiashi Feng. Distilling object detectors with fine-grained feature imitation. In CVPR, 2019.
*博客內(nèi)容為網(wǎng)友個(gè)人發(fā)布,僅代表博主個(gè)人觀點(diǎn),如有侵權(quán)請(qǐng)聯(lián)系工作人員刪除。