大一統(tǒng)視角理解擴(kuò)散模型Understanding Diffusion Models: A Unified Perspective(2)
在擴(kuò)散模型里,有幾個(gè)重要的假設(shè)。其中一個(gè)就是每一步擴(kuò)散過程的變換,都是對前一步結(jié)果的高斯變換(上一節(jié)MHVAE的限制條件2):
與MHVAE不同,編碼器側(cè)的潛在向量分布并不經(jīng)過學(xué)習(xí)得到,而是固定為線性高斯模型
這一點(diǎn)和VAE有很大不同。VAE里編碼器側(cè)的潛在向量的分布是通過模型訓(xùn)練得到的。而擴(kuò)散模型里,前向加噪過程里的每一步都是基于上一步結(jié)果的高斯變換。其中 alpha_t 一般當(dāng)作超參設(shè)置得到。這點(diǎn)對于我們計(jì)算擴(kuò)散模型的證據(jù)下界有很大幫助。因?yàn)槲覀兛梢曰谳斎離0確切地知道前向過程里的某一步的具體狀態(tài),從而監(jiān)督我們的預(yù)測。
基于式31,我們可以遞歸式地對x0不斷加噪變換,得到最終xt的表達(dá)式:
xt可以寫為關(guān)于x0的一個(gè)高斯分布的采樣結(jié)果
所以對于式58里噪音匹配項(xiàng)里的監(jiān)督信號,我們可以重寫成以下形式,其中根據(jù)式70,我們可以得到q(xt|x0)和q(xt-1|x0)的表達(dá)式,而q(xt|xt-1, x0)因?yàn)槭乔跋驍U(kuò)散過程,可以應(yīng)用馬爾可夫性質(zhì)看做q(xt|xt-1)使用式31得到具體表達(dá)式。
式58里的監(jiān)督信號可以通過x0計(jì)算具體的值
代入每一項(xiàng)q所代表的高斯函數(shù)表達(dá)式后,我們最后可以得到一個(gè)新的高斯分布表達(dá)式,其中每一項(xiàng)都是具體可求的:
q(xt-1|xt,x0)的解析形式
參考已經(jīng)證明了前向加噪過程可以寫為一個(gè)高斯分布了。在擴(kuò)散模型的初始論文[2]里提到,對于一個(gè)連續(xù)的高斯擴(kuò)散過程,其逆過程與前向過程的方程形式(functional form)一致。所以我們將對去噪匹配項(xiàng)里的p_theta(xt-1|xt)也采用高斯分布的形式(更加具體的一些推導(dǎo)放在了末尾的補(bǔ)充里)。注意式58里,對兩個(gè)高斯分布求KL散度,其解析解的形式如下:
兩個(gè)高斯分布的KL散度解析解
我們現(xiàn)在已知其中一個(gè)高斯分布(左側(cè))的參數(shù),現(xiàn)在如果我們令右側(cè)的高斯分布和左側(cè)高斯分布的方差保持一致。那么優(yōu)化該KL散度的解析式將簡化為以下形式:
式58的噪音匹配項(xiàng)簡化為最小化前后向均值的預(yù)測誤差
如此一來式58的噪音匹配項(xiàng)就被簡化為最小化前后向均值的預(yù)測誤差(式92)。讀者請注意,以下的大一統(tǒng)的三個(gè)角度來看待Diffusion model,實(shí)質(zhì)上都是對式92里mu_q的不同變形所推論出來的。 其中mu_q是關(guān)于xt, x0的函數(shù),而mu_theta是關(guān)于xt和t的函數(shù)。其中通過式84,我們有mu_q的準(zhǔn)確計(jì)算結(jié)果,而因?yàn)閙u_theta是關(guān)于xt的函數(shù)。我們可以將其寫為類似式84的形式(注意,有關(guān)為什么可以忽略方差并且讓均值選取這個(gè)形式放在了最末尾的補(bǔ)充討論里。但關(guān)于這個(gè)形式的選擇的深層原因?qū)嵸|(zhì)上開辟了一個(gè)全新的領(lǐng)域來研究,并且關(guān)于該領(lǐng)域的研究直接導(dǎo)向了擴(kuò)散模型之后的一系列加速采樣技術(shù)的出現(xiàn))
將后向預(yù)測的均值寫為類似前向加噪的形式
比較式84與94可知,x_hat是我們通過噪音數(shù)據(jù)xt來預(yù)測原始數(shù)據(jù)x0的神經(jīng)網(wǎng)絡(luò)。那么我們可以將式58里證據(jù)下界的噪音匹配項(xiàng),最終寫為
噪聲匹配項(xiàng)的最終形式
那么,我們最后得到擴(kuò)散模型的優(yōu)化,最終表現(xiàn)為訓(xùn)練一個(gè)神經(jīng)網(wǎng)絡(luò),以任意時(shí)間步的噪音圖像為輸入,來預(yù)測最初的原始圖像!此時(shí)優(yōu)化目標(biāo)轉(zhuǎn)化為了最小化預(yù)測誤差。同時(shí)式58上的對所有時(shí)間步的噪音匹配項(xiàng)求和的優(yōu)化,可以近似為對每一時(shí)間步上的預(yù)測誤差的期望的最小值,而該優(yōu)化目標(biāo)可以通過隨機(jī)采樣近似:
該優(yōu)化目標(biāo)可以通過隨機(jī)采樣實(shí)現(xiàn)
Three Equivalent Perspective為什么Calvin Luo的這篇論文叫做大一統(tǒng)視角來看待擴(kuò)散模型?以上我們花了不菲的篇幅論證了擴(kuò)散模型的優(yōu)化目標(biāo)可以最終轉(zhuǎn)化為訓(xùn)練一個(gè)神經(jīng)網(wǎng)絡(luò)在任意時(shí)間步從xt預(yù)測原始輸入x0。以下我們將論述如何通過對mu_q不同的推導(dǎo)得到類似的角度看待擴(kuò)散模型。
首先,我們已經(jīng)知道給定每個(gè)時(shí)間步的噪聲系數(shù)alpha_t之后,我們可以由初始輸入x0遞歸得到xt。同理,給定xt我們也可以求得x0。那么對式69重置后,我們可以得到式115.
將式69里的xt和x0關(guān)系重置后可得式115
重新將式115代入式84里,我們所得的關(guān)于時(shí)間步t的真實(shí)均值表達(dá)式mu_q后,我們可以得到以下推導(dǎo):
在推導(dǎo)真實(shí)均值時(shí)替換x0
注意在上一次推導(dǎo)的過程中,mu_q里的xt在計(jì)算kl散度的解析式時(shí)被抵消掉了,而x0我們采取的是用神經(jīng)網(wǎng)絡(luò)直接擬合的策略。而在這一次的推導(dǎo)過程中,x0被替換成了關(guān)于xt的表達(dá)式(關(guān)于alpha_bar和epsilon_0)后,我們可以得到mu_q的新的表達(dá)式,依舊關(guān)于xt,只是不再與x0相關(guān),而是與epsilon_0相關(guān)(式124)。其中,和式94一樣,我們忽略方差(將其設(shè)為與前向一致)并將希望擬合的mu_theta寫成與真實(shí)均值mu_q一樣的形式,只是將epsilon_0替換為神經(jīng)網(wǎng)絡(luò)的擬合項(xiàng)后我們可以得到式125。
與上次推導(dǎo)時(shí)替換x0為神經(jīng)網(wǎng)絡(luò)所擬合項(xiàng)一樣,這次換為擬合初始噪聲項(xiàng)
將我們新得到的兩個(gè)均值表達(dá)式重新代入KL散度的表達(dá)式里,xt再次被抵消掉(因?yàn)閙u_theta和mu_q選取的形式一致)最終只剩下epsilon_0和epsilon_theta的差值。注意式130和式99的相似性!
最終對證據(jù)下界里的去噪匹配項(xiàng)的優(yōu)化可以寫成關(guān)于初始噪聲和其擬合項(xiàng)的差的最小化
至此,我們得到了對擴(kuò)散模型的第二種直觀理解。對于一個(gè)變分?jǐn)U散模型VDM,我們優(yōu)化該模型的證據(jù)下界既等價(jià)于優(yōu)化其在所有時(shí)間步上對初始圖像的預(yù)測誤差的期望,也等價(jià)于優(yōu)化在所有時(shí)間步上對噪聲的預(yù)測誤差的期望! 事實(shí)上DDPM采取的做法就是式130的做法(注意DDPM里的表達(dá)式實(shí)際上用的是epsilon_t,關(guān)于這點(diǎn)在文末也會(huì)討論)。
下面筆者將概括第三種看待VDM的推導(dǎo)方式。這種方式主要來自于SongYang博士的系列論文,非常直觀。并且該系列論文將擴(kuò)散模型這種離散的多步去噪過程統(tǒng)一成了一個(gè)連續(xù)的隨機(jī)微分方程(SDE)的特殊形式。SongYang博士因此獲得了ICLR2021的最佳論文獎(jiǎng)!后續(xù)來自清華大學(xué)的基于將該SDE轉(zhuǎn)化為常微分方程ODE后的采樣提速論文,也獲得了2022ICLR的最佳論文獎(jiǎng)!關(guān)于該論文的一些細(xì)節(jié)和直觀理解,SongYang博士在他自己的博客里給出了非常精彩和直觀的講解。有興趣的讀者可以點(diǎn)開本文初始的第二個(gè)鏈接查看。以下只對大一統(tǒng)視角下的第三種視角做簡短的概括。
第三種推導(dǎo)方式主要基于Tweedie's formula.該公式主要闡述了對于一個(gè)指數(shù)家族的分布的真實(shí)均值,在給定了采樣樣本后,可以通過采樣樣本的最大似然概率(即經(jīng)驗(yàn)均值)加上一個(gè)關(guān)于分?jǐn)?shù)(score)預(yù)估的校正項(xiàng)來預(yù)估。注意score在這里的定義是真實(shí)數(shù)據(jù)分布的對數(shù)似然關(guān)于輸入xt的梯度。即
score的定義
根據(jù)Tweedie's formula,對于一個(gè)高斯變量z~N(mu_z, sigma_z)來說, 該高斯變量的真實(shí)均值的預(yù)估是:
Tweedie's formula對高斯變量的應(yīng)用
我們知道在訓(xùn)練時(shí),模型的輸入xt關(guān)于x0的表達(dá)式如下
上文里的式70
我們也知道根據(jù)Tweedie's formula的高斯變量的真實(shí)均值預(yù)估我們可以得到下式
將式70的方差代入Tweedie's formula
那么聯(lián)立兩式的關(guān)于均值的表達(dá)式后,我們可以得到x0關(guān)于score的表達(dá)式133
將x0寫為關(guān)于score的表達(dá)式
如上一種推導(dǎo)方式所做的一樣,再一次重新將x0的表達(dá)式代入式84對真實(shí)均值mu_q的表達(dá)式里:(注意式135到136的變形主要在分子里最右邊的alpha_bar_t到alpha_t, 約去了根號下alpha_bar_t-1)
將x0的關(guān)于score表達(dá)式代入式84
同樣,將mu_theta采取和mu_q一樣的形式,并用神經(jīng)網(wǎng)絡(luò)s_theta來近似score后, 我們得到了新的mu_theta的表達(dá)式143。
關(guān)于score的mu_theta的表達(dá)式
再再再同樣,和上種推導(dǎo)里的做法一樣,我們再將新的mu_theta, mu_q代入證據(jù)下界里KL散度的損失項(xiàng)我們可以得到一個(gè)最終的優(yōu)化目標(biāo)
將新的mu的表達(dá)式代入證據(jù)下界的優(yōu)化目標(biāo)里
事實(shí)上,比較式148和式130的形式,可以說是非常的接近了。那么我們的score function delta_p(xt)和初始噪聲epsilon_0是否有關(guān)聯(lián)呢?聯(lián)立關(guān)于x0的兩個(gè)表達(dá)式133和115我們可以得到
score function和初始噪聲間的關(guān)系
讀者如果將式151代入148會(huì)發(fā)現(xiàn)和式130等價(jià)!直觀上來講,score function描述的是如何在數(shù)據(jù)空間里最大化似然概率的更新向量。而又因?yàn)槌跏荚肼暿窃谠斎氲幕A(chǔ)上加入的,那么往噪聲的反方向(也是最佳方向)更新實(shí)質(zhì)上等價(jià)于去噪的過程。而數(shù)學(xué)上講,對score function的建模也等價(jià)于對初始噪聲乘上負(fù)系數(shù)的建模!
至此我們終于將擴(kuò)散模型的三個(gè)形式的所有推導(dǎo)整理完畢!即對變分?jǐn)U散模型VDM的訓(xùn)練等價(jià)于訓(xùn)練一個(gè)神經(jīng)網(wǎng)絡(luò)來預(yù)測原輸入x0,也等價(jià)于預(yù)測噪聲epsilon, 也等價(jià)于預(yù)測初始輸入在特定時(shí)間步的score delta_logp(xt)。
讀到這里,相比讀者也已經(jīng)發(fā)現(xiàn),不同的推導(dǎo)所得出的不同結(jié)果,都來自于對證據(jù)下界里去噪匹配項(xiàng)的不同推導(dǎo)過程。而不同的變形,基本上都是利用了MHVAE里最開始提到的三點(diǎn)基本假設(shè)所得。
Drawbacks to Consider盡管擴(kuò)散模型在最近兩年成功出圈,引爆了業(yè)界,學(xué)術(shù)界甚至普通人對文本生成圖像的AI模型的關(guān)注,但擴(kuò)散模型這個(gè)體系本身依舊存在著一些缺陷:
- 擴(kuò)散模型本身盡管理論框架已經(jīng)比較完善,公式推導(dǎo)也十分優(yōu)美。但仍然非常不直觀。最起碼從一個(gè)完全噪聲的輸入不斷優(yōu)化的這個(gè)過程和人類的思維過程相去甚遠(yuǎn)。
- 擴(kuò)散模型和GAN或者VAE相比,所學(xué)的潛在向量不具備任何語義和結(jié)構(gòu)的可解釋性。上文提到了擴(kuò)散模型可以看做是特殊的MHVAE,但里面每一層的潛在向量間都是線性高斯的形式,變化有限。
- 而擴(kuò)散模型的潛在向量要求維度與輸入一致這一點(diǎn),則更加死地限制住了潛在向量的表征能力。
- 擴(kuò)散模型的多步迭代導(dǎo)致了擴(kuò)散模型的生成往往耗時(shí)良久。
不過學(xué)術(shù)界對以上的一些難題其實(shí)也提出了不少解決方案。比如擴(kuò)散模型的可解釋性問題。筆者最近就發(fā)現(xiàn)了一些工作將score-matching直接應(yīng)用在了普通VAE的潛在向量的采樣上。這是一個(gè)非常自然的創(chuàng)新點(diǎn),就和數(shù)年前的flow-based-vae一樣。而耗時(shí)良久的問題,今年ICLR的最佳論文也將采樣這個(gè)問題加速和壓縮到了幾十步內(nèi)就可以生成非常高質(zhì)量的結(jié)果。
但是對于擴(kuò)散模型在文本生成領(lǐng)域的應(yīng)用最近似乎還不多,除了prefix-tuning的作者xiang-lisa-li的一篇論文[3]
之外筆者暫未關(guān)注到任何工作。而具體來講,如果將擴(kuò)散模型直接用在文本生成上,仍有諸多不便。比如輸入的尺寸在整個(gè)擴(kuò)散過程必須保持一致就決定了使用者必須事先決定好想生成的文本的長度。而且做有引導(dǎo)的條件生成還好,要用擴(kuò)散模型訓(xùn)練出一個(gè)開放域的文本生成模型恐怕難度不低。
本篇筆記著重的是在探討大一統(tǒng)角度下的擴(kuò)散模型推斷。但具體對score matching如何訓(xùn)練,如何引導(dǎo)擴(kuò)散模型生成我們想要的條件分布還沒有寫出來。筆者打算在下一篇探討最近一些將擴(kuò)散模型應(yīng)用在受控文本生成領(lǐng)域的方法調(diào)研里詳細(xì)記錄和比較一下
補(bǔ)充- 關(guān)于為什么擴(kuò)散核是高斯變換的擴(kuò)散過程的逆過程也是高斯變換的問題,來自清華大神的一篇知乎回答里[4] 給出了比較直觀的解釋。其中第二行是將p_t-1和p_t近似。第三行是對logpt(x_t-1)使用一階泰勒展開消去了logpt(xt)。第四行是直接代入了q(xt|xt-1)的表達(dá)式。于是我們得到了一個(gè)高斯分布的表達(dá)式。
擴(kuò)散的逆過程也是高斯分布
- 在式94和式125,我們都將對真實(shí)高斯分布q的均值mu_q的近似mu_theta建模成了與我們所推導(dǎo)出的mu_q一致的形式,并且將方差設(shè)置為了與q的方差一致的形式。直觀上來講,這樣建模的好處很多,一方面是根據(jù)KL散度對兩個(gè)高斯分布的解析式來說,這樣我們可以約掉和抵消掉絕大部分的項(xiàng),簡化了建模。另一方面真實(shí)分布和近似分布都依賴于xt。在訓(xùn)練時(shí)我們的輸入就是xt,采取和真實(shí)分布形式一樣的表達(dá)式?jīng)]有泄漏任何信息。并且在工程上DDPM也驗(yàn)證了類似的簡化是事實(shí)上可行的。但實(shí)際上可以這樣做的原因背后是從2021年以來的一系列論文里復(fù)雜的數(shù)理證明所在解釋的目標(biāo)。 同樣引用清華大佬[4]的回答:
DDPM里簡化去噪的高斯分布的做法其實(shí)蘊(yùn)含著深刻的道理
- 在DDPM里,其最終的優(yōu)化目標(biāo)是epsilon_t而不是epsilon_0。即預(yù)測的誤差到底是初始誤差還是某個(gè)時(shí)間步上的初始誤差。誰對誰錯(cuò)?實(shí)際上這個(gè)誤解來源于我們對xt關(guān)于x0的表達(dá)式的求解中的誤解。從式63開始的連續(xù)幾步推導(dǎo),都應(yīng)用到了一個(gè)高斯性質(zhì),即兩個(gè)獨(dú)立高斯分布的和的均值與方差等于原分布的均值和與方差和。而實(shí)質(zhì)上我們在應(yīng)用重參數(shù)化技巧求xt的過程中,是遞歸式的不斷引入了新的epsilon來替換遞歸中的x_n里的epsilon。那么到最后,我們所得到的epsilon無非是一個(gè)囊括了所有擴(kuò)散過程中的epsilon。這個(gè)噪聲即可以說是t,也可以說是0,甚至最準(zhǔn)確來說應(yīng)該不等于任何一個(gè)時(shí)間步,就叫做噪聲就好!
DDPM的優(yōu)化目標(biāo)
- 關(guān)于對證據(jù)下界的不同簡化形式。其中我們提到第二種對噪聲的近似是DDPM所采用的建模方式。但是對初始輸入的近似其實(shí)也有論文采用。也就是上文提及的將擴(kuò)散模型應(yīng)用在可控文本生成的論文里[3]所采用的形式。該論文每輪直接預(yù)測初始Word-embedding。而第三種score-matching的角度可以參照SongYang博士的系列論文[5]來看。里面的優(yōu)化函數(shù)的形式用的是第三種。
- 本篇筆記著重于講述擴(kuò)散模型的變分下界的公式推導(dǎo),關(guān)于擴(kuò)散模型與能量模型,朗之萬動(dòng)力學(xué),隨機(jī)微分方程等一系列名詞的關(guān)系本篇筆記并無涉及。 筆者將在另外一篇筆記里梳理相關(guān)的理解。
參考
- ^Improving Variational Inference with Inverse Autoregressive Flow https://arxiv.org/abs/1606.04934
- ^Deep Unsupervised Learning using Nonequilibrium Thermodynamics https://arxiv.org/abs/1503.03585
- ^abDiffusion-LM Improves Controllable Text Generation https://arxiv.org/abs/2205.14217
- ^abdiffusion model最近在圖像生成領(lǐng)域大紅大紫,如何看待它的風(fēng)頭開始超過GAN?- 我想唱high C的回答 - 知乎 https://www.zhihu.com/question/536012286/answer/2533146567
- ^SCORE-BASED GENERATIVE MODELING THROUGH STOCHASTIC DIFFERENTIAL EQUATIONS https://arxiv.org/abs/2011.13456
*博客內(nèi)容為網(wǎng)友個(gè)人發(fā)布,僅代表博主個(gè)人觀點(diǎn),如有侵權(quán)請聯(lián)系工作人員刪除。