Transformer也能生成圖像,新型ViTGAN性能比肩基于CNN的GAN
Transformer 已經(jīng)為多種自然語言任務(wù)帶來了突飛猛進(jìn)的進(jìn)步,并且最近也已經(jīng)開始向計(jì)算機(jī)視覺領(lǐng)域滲透,開始在一些之前由 CNN 主導(dǎo)的任務(wù)上暫露頭角。近日,加州大學(xué)圣迭戈分校與 Google Research 的一項(xiàng)研究提出了使用視覺 Transformer 來訓(xùn)練 GAN。為了有效應(yīng)用該方法,研究者還提出了多項(xiàng)改進(jìn)技巧,使新方法在一些指標(biāo)上可比肩前沿 CNN 模型。
卷積神經(jīng)網(wǎng)絡(luò)(CNN)在卷積(權(quán)重共享和局部連接)和池化(平移等變)方面的強(qiáng)大能力,讓其已經(jīng)成為了現(xiàn)今計(jì)算機(jī)視覺領(lǐng)域的主導(dǎo)技術(shù)。但最近,Transformer 架構(gòu)已經(jīng)開始在圖像和視頻識(shí)別任務(wù)上與 CNN 比肩。其中尤其值得一提的是視覺 Transformer(ViT)。這種技術(shù)會(huì)將圖像作為 token 序列(類似于自然語言中的詞)來解讀。Dosovitskiy et al. 的研究表明,ViT 在 ImageNet 基準(zhǔn)上能以更低的計(jì)算成本取得相當(dāng)?shù)姆诸悳?zhǔn)確度。不同于 CNN 中的局部連接性,ViT 依賴于在全局背景中考慮的表征,其中每個(gè) patch 都必須與同一圖像的所有 patch 都關(guān)聯(lián)處理。
ViT 及其變體盡管還處于早期階段,但已有研究展現(xiàn)了其在建模非局部上下文依賴方面的優(yōu)秀前景,并且也讓人看到了其出色的效率和可擴(kuò)展性。自 ViT 在前段時(shí)間誕生以來,其已經(jīng)被用在了目標(biāo)檢測(cè)、視頻識(shí)別、多任務(wù)預(yù)訓(xùn)練等多種不同任務(wù)中。
近日,加州大學(xué)圣迭戈分校與 Google Research 的一項(xiàng)研究提出了使用視覺 Transformer 來訓(xùn)練 GAN。這篇論文的研究議題是:不使用卷積或池化,能否使用視覺 Transformer 來完成圖像生成任務(wù)?更具體而言:能否使用 ViT 來訓(xùn)練生成對(duì)抗網(wǎng)絡(luò)(GAN)并使之達(dá)到與已被廣泛研究過的基于 CNN 的 GAN 相媲美的質(zhì)量?
論文鏈接:https://arxiv.org/pdf/2107.04589.pdf
為此,研究者遵照最本原的 ViT 設(shè)計(jì),使用純粹基本的 ViT(如圖 2(A))訓(xùn)練了 GAN。其中的難點(diǎn)在于,GAN 的訓(xùn)練過程在與 ViT 耦合之后會(huì)變得非常不穩(wěn)定,并且對(duì)抗訓(xùn)練常常會(huì)在判別器訓(xùn)練的后期受到高方差梯度(或尖峰梯度)的阻礙。此外,梯度懲罰、譜歸一化等傳統(tǒng)的正則化方法雖然能有效地用于基于 CNN 的 GAN 模型(如圖 4),但這些正則化方法卻無法解決上述不穩(wěn)定問題。使用了適當(dāng)?shù)恼齽t化方法后,基于 CNN 的 GAN 訓(xùn)練不穩(wěn)定的情況并不常見,因此對(duì)基于 ViT 的 GAN 而言,這是一個(gè)獨(dú)有的挑戰(zhàn)。
針對(duì)這些問題,為了實(shí)現(xiàn)訓(xùn)練動(dòng)態(tài)的穩(wěn)定以及促進(jìn)基于 ViT 的 GAN 的收斂,這篇論文提出了多項(xiàng)必需的修改。
在判別器中,研究者重新審視了自注意力的 Lipschitz 性質(zhì),在此基礎(chǔ)上他們?cè)O(shè)計(jì)了一種加強(qiáng)了 Lipschitz 連續(xù)性的譜歸一化。不同于難以應(yīng)付不穩(wěn)定情況的傳統(tǒng)譜歸一化方法,這些技術(shù)能非常有效地穩(wěn)定基于 ViT 的判別器的訓(xùn)練動(dòng)態(tài)。此外,為了驗(yàn)證新提出的技術(shù)的作用,研究者還執(zhí)行了控制變量研究。對(duì)于基于 ViT 的生成器,研究者嘗試了多種不同的架構(gòu)設(shè)計(jì)并發(fā)現(xiàn)了對(duì)層歸一化和輸出映射層的兩項(xiàng)關(guān)鍵性修改。實(shí)驗(yàn)表明,不管使用的判別器是基于 ViT 還是基于 CNN,基于修改版 ViT 的生成器都能更好地促進(jìn)對(duì)抗訓(xùn)練。
為了更具說服力,研究者在三個(gè)標(biāo)準(zhǔn)的圖像合成基準(zhǔn)上進(jìn)行了實(shí)驗(yàn)。結(jié)果表明,新提出的模型 ViTGAN 極大優(yōu)于之前的基于 Transformer 的 GAN 模型,并且在沒有使用卷積和池化時(shí)也取得了與 StyleGAN2 等領(lǐng)先的基于 CNN 的 GAN 相媲美的表現(xiàn)。作者表示,新提出的 ViTGAN 算得上是在 GAN 中使用視覺 Transformer 的最早嘗試之一,更重要的是,這項(xiàng)研究首次表明 Transformer 能在 CIFAR、CelebA 和 LSUN 臥室數(shù)據(jù)集等標(biāo)準(zhǔn)圖像生成基準(zhǔn)上超過當(dāng)前最佳的卷積架構(gòu)。
方法
圖 1 展示了新提出的 ViTGAN 架構(gòu),其由一個(gè) ViT 判別器和一個(gè)基于 ViT 的生成器構(gòu)成。研究者發(fā)現(xiàn),直接使用 ViT 作為判別器會(huì)讓訓(xùn)練不穩(wěn)定。為了穩(wěn)定訓(xùn)練動(dòng)態(tài)和促進(jìn)收斂,研究者為生成器和判別器都引入了新技術(shù):(1) ViT 判別器上的正則化和 (2) 新的生成器架構(gòu)。
圖 1:新提出的 ViTGAN 框架示意圖。生成器和判別器都是基于視覺 Transformer(ViT)設(shè)計(jì)的。判別器分?jǐn)?shù)是從分類嵌入推導(dǎo)得到的(圖中記為 *);生成器是基于 patch 嵌入逐個(gè) patch 生成像素。
增強(qiáng) Transformer 判別器的 Lipschitz 性質(zhì)。在 GAN 判別器中,Lipschitz 連續(xù)性發(fā)揮著重要的作用。人們最早注意到它的時(shí)候是將其用作近似 WGAN 中 Wasserstein 距離的一個(gè)條件,之后其又在使用 Wasserstein 損失之外的其它 GAN 設(shè)置中得到了確認(rèn)。其中,尤其值得關(guān)注的是 ICML 2019 論文《Lipschitz generative adversarial nets》,該研究證明 Lipschitz 判別器能確保存在最優(yōu)的判別函數(shù)以及唯一的納什均衡。但是,ICML 2021 的一篇論文《The lipschitz constant of self-attention》表明標(biāo)準(zhǔn)點(diǎn)積自注意力層的 Lipschitz 常數(shù)可以是****的,這就會(huì)破壞 ViT 中的 Lipschitz 連續(xù)性。為了加強(qiáng) ViT 判別器的 Lipschitz 性質(zhì),研究者采用了上述論文中提出的 L2 注意力。如等式 7 所示,點(diǎn)積相似度被替換成了歐幾里得距離,并且還關(guān)聯(lián)了投影矩陣的權(quán)重,以用于自注意力中的查詢和鍵(key)。這項(xiàng)改進(jìn)能提升用于 GAN 判別器的 Transformer 的穩(wěn)定性。
經(jīng)過改進(jìn)的譜歸一化。為了進(jìn)一步強(qiáng)化 Lipschitz 連續(xù)性,研究者還在判別器訓(xùn)練中使用了譜歸一化。標(biāo)準(zhǔn)譜歸一化是使用冪迭代來估計(jì)每層神經(jīng)網(wǎng)絡(luò)的投影矩陣的譜范數(shù),然后再使用估計(jì)得到的譜范數(shù)來除權(quán)重矩陣,這樣所得到的投影矩陣的 Lipschitz 常量就等于 1。研究者發(fā)現(xiàn),Transformer 模塊對(duì) Lipschitz 常數(shù)的大小很敏感,當(dāng)使用了譜歸一化時(shí),訓(xùn)練速度會(huì)非常慢。類似地,研究者還發(fā)現(xiàn)當(dāng)使用了基于 ViT 的判別器時(shí),R1 梯度懲罰項(xiàng)會(huì)有損 GAN 訓(xùn)練。另有研究發(fā)現(xiàn),如果 MLP 模塊的 Lipschitz 常數(shù)較小,則可能導(dǎo)致 Transformer 的輸出坍縮為秩為 1 的矩陣。為了解決這個(gè)問題,研究者提出增大投影矩陣的譜范數(shù)。
他們發(fā)現(xiàn),只需在初始化時(shí)將譜范數(shù)與每一層的歸一化權(quán)重矩陣相乘,便足以解決這個(gè)問題。具體而言,譜歸一化的更新規(guī)則如下,其中 σ 是計(jì)算權(quán)重矩陣的標(biāo)準(zhǔn)譜范:
重疊圖像塊。由于 ViT 判別器具有過多的學(xué)習(xí)能力,因此容易過擬合。在這項(xiàng)研究中,判別器和生成器使用了同樣的圖像表征,其會(huì)根據(jù)一個(gè)預(yù)定義的網(wǎng)絡(luò) P×P 來將圖像分割為由非重疊 patch 組成的序列。如果不經(jīng)過精心設(shè)計(jì),這些任意的網(wǎng)絡(luò)劃分可能會(huì)促使判別器記住局部線索,從而無法為生成器提供有意義的損失。為了解決這個(gè)問題,研究者采用了一種簡(jiǎn)單技巧,即讓 patch 之間有所重疊。對(duì)于 patch 的每個(gè)邊緣,都將其擴(kuò)展 o 個(gè)像素,使有效 patch 尺寸變?yōu)?(P+2o)×(P+2o)。
這樣得到的序列長度與原來一樣,但對(duì)預(yù)定義網(wǎng)格的敏感度更低。這也有可能讓 Transformer 更好地了解當(dāng)前 patch 的鄰近 patch 是哪些,由此更好地理解局部特性。
生成器設(shè)計(jì)
基于 ViT 架構(gòu)設(shè)計(jì)生成器并非易事,其中一大難題是將 ViT 的功能從預(yù)測(cè)一組類別標(biāo)簽轉(zhuǎn)向在一個(gè)空間區(qū)域生成像素。
圖 2:生成器架構(gòu)。左圖是研究者研究過的三種生成器架構(gòu):(A) 為每個(gè)位置嵌入添加中間隱藏嵌入 w,(B) 將 w 預(yù)置到序列上,(C) 使用由 w 學(xué)習(xí)到的仿射變換(圖中的 A)計(jì)算出的自調(diào)制型層范數(shù)(SLN/self-modulated layernorm)替換歸一化。右圖是用在 Transformer 模塊中的自調(diào)制運(yùn)算的細(xì)節(jié)。
研究者先研究了多種生成器架構(gòu),發(fā)現(xiàn)它們都比不上基于 CNN 的生成器。于是他們遵循 ViT 的設(shè)計(jì)原理提出了一種全新的生成器。圖 2(c) 展示了這種 ViTGAN 生成器,其包含兩大組件:Transformer 模塊和輸出映射層。
為了促進(jìn)訓(xùn)練過程,研究者為新提出的生成器做出了兩項(xiàng)改進(jìn):
自調(diào)制型層范數(shù)(SLN)。新的做法不是將噪聲向量 z 作為輸入發(fā)送給 ViT,而是使用 z 來調(diào)制層范數(shù)運(yùn)算。之所以稱這樣的操作為自調(diào)制,是因?yàn)樵撨^程無需外部信息;
用于圖塊生成的隱式神經(jīng)表征。為了學(xué)習(xí)從 patch 嵌入到 patch 像素值的連續(xù)映射,研究者使用了隱式神經(jīng)表征。當(dāng)結(jié)合傅里葉特征或正弦激活函數(shù)一起使用時(shí),隱式表征可將所生成的樣本空間約束到平滑變化的自然信號(hào)空間。研究發(fā)現(xiàn),在使用基于 ViT 的生成器訓(xùn)練 GAN 時(shí),隱式表征的作用尤其大。
需要指出,由于生成器和判別器的圖像網(wǎng)格不同,因此序列長度也不一樣。進(jìn)一步的研究發(fā)現(xiàn),當(dāng)需要將模型擴(kuò)展用于更高分辨率的圖像時(shí),只需增大判別器的序列長度或特征維度就足夠了。
實(shí)驗(yàn)結(jié)果
表 1:幾種代表性 GAN 架構(gòu)在無條件圖像生成基準(zhǔn)的結(jié)果比較。Conv 和 Pool 各自代表卷積和池化?!?表示越低越好;↑ 表示越高越好。
表 1 給出了在圖像合成的三個(gè)標(biāo)準(zhǔn)基準(zhǔn)上的主要結(jié)果。本論文提出的新方法能與以下基準(zhǔn)架構(gòu)比肩。TransGAN 是目前唯一完全不使用卷積的 GAN,其完全基于 Transformer 構(gòu)建。這里比較的是其最佳的變體版本 TransGAN-XL。Vanilla-ViT 是一種基于 ViT 的 GAN,其使用了圖 2(A) 的生成器和純凈版 ViT 判別器,但未使用本論文提出的改進(jìn)技術(shù)。
表 3a 中分別比較了圖 2(B) 所示的生成器架構(gòu)。此外,BigGAN 和 StyleGAN2 作為基于 CNN 的 GAN 的最佳模型也被納入了比較。
圖 3:定性比較。在 CIFAR-10 32 × 32、CelebA 64 × 64 和 LSUN Bedroom 64 × 64 數(shù)據(jù)集上,ViTGAN 與 StyleGAN2、Transformer 最佳基準(zhǔn)、純凈版生成器和判別器的 ViT 的結(jié)果比較。
圖 4:(a-c) ViT 判別器的梯度幅度(在所有參數(shù)上的 L2 范數(shù)),(d-f) FID 分?jǐn)?shù)(越低越好)隨訓(xùn)練迭代的變化情況。
可以看到,新提出方法的表現(xiàn)與使用 R1 懲罰項(xiàng)和譜范數(shù)的兩個(gè)純凈版 ViT 判別器基準(zhǔn)相當(dāng)。其余架構(gòu)對(duì)所有方法來說都一樣??梢娦路椒芸朔荻确鹊募夥宀?shí)現(xiàn)顯著更低的 FID(在 CIFAR 和 CelebA 上)或相近的 FID(在 LSUN 上)。
表 3:在 CIFAR-10 數(shù)據(jù)集上對(duì) ViTGAN 執(zhí)行的控制變量研究。左圖:對(duì)生成器架構(gòu)的控制變量研究。右圖:對(duì)判別器架構(gòu)的控制變量研究。
*博客內(nèi)容為網(wǎng)友個(gè)人發(fā)布,僅代表博主個(gè)人觀點(diǎn),如有侵權(quán)請(qǐng)聯(lián)系工作人員刪除。