獨(dú)家 | 使用TensorFlow 2創(chuàng)建自定義損失函數(shù)
作者:Arjun Sarkar
翻譯:陳之炎
校對(duì):歐陽(yáng)錦
神經(jīng)網(wǎng)絡(luò)利用訓(xùn)練數(shù)據(jù),將一組輸入映射成一組輸出,它通過(guò)使用某種形式的優(yōu)化算法,如梯度下降、隨機(jī)梯度下降、AdaGrad、AdaDelta等等來(lái)實(shí)現(xiàn),其中最新的算法包括Adam、Nadam或RMSProp。梯度下降中的“梯度”是指誤差梯度。每次迭代之后,網(wǎng)絡(luò)將其預(yù)測(cè)輸出與實(shí)際輸出進(jìn)行比較,然后計(jì)算出“誤差”。
通常,對(duì)于神經(jīng)網(wǎng)絡(luò),尋求的是將誤差最小化。將誤差最小化的目標(biāo)函數(shù)通常稱(chēng)之為成本函數(shù)或損失函數(shù),由“損失函數(shù)”計(jì)算出的值稱(chēng)為“損失”。在各種問(wèn)題中使用的典型損失函數(shù)有:
均方誤差;
均方對(duì)數(shù)誤差;
二元交叉熵;
分類(lèi)交叉熵;
稀疏分類(lèi)交叉熵。
Tensorflow已經(jīng)包含了上述損失函數(shù),直接調(diào)用它們即可,如下所示:
1. 將損失函數(shù)當(dāng)作字符串進(jìn)行調(diào)用
model.compile (loss = ‘binary_crossentropy’,optimizer = ‘a(chǎn)dam’, metrics = [‘a(chǎn)ccuracy’])
2. 將損失函數(shù)當(dāng)作對(duì)象進(jìn)行調(diào)用
from tensorflow.keras.losses importmean_squared_error model.compile(loss = mean_squared_error,optimizer=’sgd’)
將損失函數(shù)當(dāng)作對(duì)象進(jìn)行調(diào)用的優(yōu)點(diǎn)是可以在損失函數(shù)中傳遞閾值等參數(shù)。
from tensorflow.keras.losses import mean_squared_error model.compile (loss=mean_squared_error(param=value),optimizer = ‘sgd’)
利用現(xiàn)有函數(shù)創(chuàng)建自定義損失函數(shù):
利用現(xiàn)有函數(shù)創(chuàng)建損失函數(shù),首先需要定義損失函數(shù),它將接受兩個(gè)參數(shù),y_true(真實(shí)標(biāo)簽/輸出)和y_pred(預(yù)測(cè)標(biāo)簽/輸出)。
def loss_function(y_true, y_pred): ***some calculation*** return loss
創(chuàng)建均方誤差損失函數(shù) (RMSE):
定義損失函數(shù)名稱(chēng)-my_rmse。目的是返回目標(biāo)(y_true)與預(yù)測(cè)(y_pred)之間的均方誤差。
RMSE的公式為:
誤差:真實(shí)標(biāo)簽與預(yù)測(cè)標(biāo)簽之間的差異。
sqr_error:誤差的平方。
mean_sqr_error:誤差平方的均值。
sqrt_mean_sqr_error:誤差平方均值的平方根(均方根誤差)。
創(chuàng)建Huber損失函數(shù):
圖2:Huber損失函數(shù)(綠色)和平方誤差損失函數(shù)(藍(lán)色)(來(lái)源:Qwertyus— Own work,CCBY-SA4.0,https://commons.wikimedia.org/w/index.php?curid=34836380)
Huber損失函數(shù)的計(jì)算公式:
在此處,δ是閾值,a是誤差(將計(jì)算出a,即實(shí)際標(biāo)簽和預(yù)測(cè)標(biāo)簽之間的差異)。
當(dāng)|a|≤δ時(shí),loss = 1/2*(a)2
當(dāng) |a|>δ時(shí),loss = δ(|a|—(1/2)*δ)
源代碼:
詳細(xì)說(shuō)明:
首先,定義一個(gè)函數(shù)—— my huber loss,它需要兩個(gè)參數(shù):y_true和y_pred,
設(shè)置閾值threshold = 1。
計(jì)算誤差error a = y_true-y_pred。接下來(lái),檢查誤差的絕對(duì)值是否小于或等于閾值,is_small_error返回一個(gè)布爾值(真或假)。
當(dāng)|a|≤δ時(shí),loss= 1/2*(a)2,計(jì)算small_error_loss, 誤差的平方除以2。否則,當(dāng)|a| >δ時(shí),則損失等于δ(|a|-(1/2)*δ),用big_error_loss來(lái)計(jì)算這個(gè)值。
最后,在返回語(yǔ)句中,首先檢查is_small_error是真還是假,如果它為真,函數(shù)返回small_error_loss,否則返回big_error_loss,使用tf.where來(lái)實(shí)現(xiàn)。
可以使用下述代碼來(lái)編譯模型:
在上述代碼中,將閾值設(shè)為1。
如果需要調(diào)整超參數(shù)(閾值),并在編譯過(guò)程中加入一個(gè)新的閾值的話,必須使用wrapper函數(shù)進(jìn)行封裝,也就是說(shuō),將損失函數(shù)封裝成另一個(gè)外部函數(shù)。在這里需要用到封裝函數(shù)(wrapper function),因?yàn)閾p失函數(shù)在默認(rèn)情況下只能接受y_true和y_pred值,而且不能向原始損失函數(shù)添加任何其他參數(shù)。
使用封裝后的Huber損失函數(shù)
封裝函數(shù)的源代碼:
此時(shí),閾值不是硬編碼,可以在模型編譯過(guò)程中傳遞該閾值。
使用類(lèi)實(shí)現(xiàn)Huber損失函數(shù)(OOP)
其中,MyHuberLoss是類(lèi)名稱(chēng),隨后從tensorflow.keras.losses繼承父類(lèi)“Loss”, MyHuberLoss繼承了Loss類(lèi),之后可以將MyHuberLoss當(dāng)作損失函數(shù)來(lái)使用。
__init__ 初始化該類(lèi)中的對(duì)象。執(zhí)行類(lèi)實(shí)例化對(duì)象時(shí)調(diào)用函數(shù),init函數(shù)返回閾值,調(diào)用函數(shù)得到y(tǒng)_true和y_pred參數(shù),將閾值聲明為一個(gè)類(lèi)變量,可以給它賦一個(gè)初始值。
在__init__函數(shù)中,將閾值設(shè)置為self.threshold。在調(diào)用函數(shù)中,self.threshold引用所有的閾值類(lèi)變量。在model.compile中使用這個(gè)損失函數(shù):
創(chuàng)建對(duì)比性損失(用于Siamese網(wǎng)絡(luò)):
Siamese網(wǎng)絡(luò)可以用來(lái)比較兩幅圖像是否相似,Siamese網(wǎng)絡(luò)使用的損失函數(shù)為對(duì)比性損失。
在上文的公式中,Y_true是關(guān)于圖像相似性細(xì)節(jié)的張量,如果圖像相似,則為1,如果圖像不相似,則為0。
D是圖像對(duì)之間的歐氏距離的張量。邊際為一個(gè)常量,用它來(lái)設(shè)置將圖像區(qū)別為相似或不同的最小距離。如果為Y_true=1,則方程的第一部分為D2,第二部分為0,所以,當(dāng)Y_true接近1時(shí),D2的權(quán)重則更重。
如果Y_true=0,則方程的第一部分變?yōu)?,第二部分會(huì)產(chǎn)生一些結(jié)果,這給了最大項(xiàng)更多的權(quán)重,給了D平方項(xiàng)更少的權(quán)重,此時(shí),最大項(xiàng)在損失計(jì)算中占了優(yōu)勢(shì)。
使用封裝器函數(shù)實(shí)現(xiàn)對(duì)比損失函數(shù):
結(jié)論
在Tensorflow中沒(méi)有的損失函數(shù)都可以利用函數(shù)、包裝函數(shù)或類(lèi)似的類(lèi)來(lái)創(chuàng)建。
原文標(biāo)題:
Creating custom Loss functionsusing TensorFlow 2
原文鏈接:
https://towardsdatascience.com/creating-custom-loss-functions-using-tensorflow-2-96c123d5ce6c
*博客內(nèi)容為網(wǎng)友個(gè)人發(fā)布,僅代表博主個(gè)人觀點(diǎn),如有侵權(quán)請(qǐng)聯(lián)系工作人員刪除。
負(fù)離子發(fā)生器相關(guān)文章:負(fù)離子發(fā)生器原理 離子色譜儀相關(guān)文章:離子色譜儀原理