有bug!用Pytorch Lightning重構(gòu)代碼速度更慢,修復(fù)后速度倍增
用了 Lightning 訓(xùn)練速度反而更慢,你遇到過(guò)這種情況嗎?
轉(zhuǎn)自《機(jī)器之心》
PyTorch Lightning 是一種重構(gòu) PyTorch 代碼的工具,它可以抽出代碼中復(fù)雜重復(fù)的部分,使得 AI 研究可擴(kuò)展并且可以快速迭代。然而近日一位名為 Florian Ernst 的博主卻發(fā)現(xiàn) PyTorch Lightning 存在一個(gè) bug——讓原本應(yīng)該加速的訓(xùn)練變得更慢了。
本文作者 Florian Ernst
Ernst 撰寫(xiě)博客詳細(xì)描述了他發(fā)現(xiàn)這個(gè) bug 的過(guò)程,以下是博客原文。
兩周前,我將一些深度學(xué)習(xí)代碼重構(gòu)為 Pytorch Lightning,預(yù)計(jì)大約有 1.5 倍的加速。然而,訓(xùn)練、評(píng)估和測(cè)試任務(wù)的速度卻降為原來(lái)的 1/4。重構(gòu)之后的神經(jīng)網(wǎng)絡(luò)需要運(yùn)行幾天才能得出結(jié)果,因此我想找出原因,并盡可能地減少訓(xùn)練時(shí)間。
事情是這樣的,我使用的是一些開(kāi)源深度學(xué)習(xí)代碼,這些代碼是用來(lái)展示某些機(jī)器學(xué)習(xí)任務(wù)最新架構(gòu)的。然而這些代碼本身既不整潔也沒(méi)進(jìn)行優(yōu)化。我注意到幾個(gè)可以加速的地方,并將代碼重構(gòu)為 Pytorch 代碼,讓訓(xùn)練大約快了 3 倍。
但我認(rèn)為還有改進(jìn)的余地。Pytorch Lightning 是一個(gè)非常好的工具:它刪除了大量樣板代碼,并配備了一些優(yōu)化方法,因此我決定使用 Lightning 重構(gòu)這些代碼。
我原本希望代碼大約能提速 1.5 倍,但完成重構(gòu)時(shí),我驚訝地發(fā)現(xiàn)迭代時(shí)間從 4 秒變成了 15 秒,這使訓(xùn)練時(shí)間多了近 3 倍。
問(wèn)題出在哪里?
我首先運(yùn)行 Lightning 的分析器來(lái)找出問(wèn)題所在。
基礎(chǔ)分析器給了我一個(gè)起點(diǎn):大部分時(shí)間都花在運(yùn)行一個(gè) epoch 上;高級(jí)分析器沒(méi)有給我更多信息。
我想知道我是否在神經(jīng)網(wǎng)絡(luò)上錯(cuò)誤地配置了一些超參數(shù)。我打亂了其中一些超參數(shù),訓(xùn)練速度沒(méi)有任何變化。
然后我調(diào)整了數(shù)據(jù)加載器,發(fā)現(xiàn)改變作業(yè)數(shù) n_jobs 會(huì)對(duì)總訓(xùn)練時(shí)間產(chǎn)生影響。然而影響不是加快了計(jì)算速度,而是減慢了。
隨著 job 數(shù)變化,100 個(gè) epoch 花費(fèi)的時(shí)間。
使用 n_jobs=0 完全禁用多處理使我的迭代幾乎比使用 6 個(gè)內(nèi)核快了 2 倍。默認(rèn)情況下,Pytorch 在兩個(gè) epoch 之間會(huì) kill 掉運(yùn)行中的進(jìn)程(worker)并重新加載,因而需要重新加載數(shù)據(jù)集。
在我這個(gè)例子中,加載數(shù)據(jù)集非常慢。我將 DataLoader 里的 persistent_workers 參數(shù)設(shè)置為 True,以防止運(yùn)行中的進(jìn)程被殺死,進(jìn)而防止重新加載數(shù)據(jù)。
# My data Loader parameters DataLoader( train_dataset, batch_size=64, shuffle=True, num_workers=n_workers, persistent_workers=True, pin_memory=True, )
因此,有兩種可能性:
Pytorch Lightning kill 掉 worker,沒(méi)有考慮 persistent_workers 參數(shù);
問(wèn)題出在別的地方。
我在 GitHub 上創(chuàng)建了一個(gè) issue,希望 Lightning 團(tuán)隊(duì)意識(shí)這個(gè)問(wèn)題,接下來(lái)我要尋找問(wèn)題根源。
GitHub 地址:https://github.com/PyTorchLightning/pytorch-lightning/issues/10389
尋找問(wèn)題根源
Lightning 的 profiler 與上下文管理器一起運(yùn)行并計(jì)算給定塊花費(fèi)的時(shí)間。它可以輕松搜索特定的 profiler 操作,以運(yùn)行「run_training_epoch」為例 。
我開(kāi)始探究 Lightning 源碼,查看導(dǎo)致循環(huán)(loops)變慢的指令,我發(fā)現(xiàn)了一些問(wèn)題:Loop.run 調(diào)用 Loop.on_run_start、Loop.on_run_start 重新加載 dataloader,如下圖所示:
Loop.run 調(diào)用 Loop.on_run_start…
Loop.on_run_start 重新調(diào)用 dataloader
問(wèn)題看起來(lái)確實(shí)來(lái)自在每個(gè) epoch 中重新加載 DataLoader。查看 DataLoader 的源碼,發(fā)現(xiàn)是這樣的:
當(dāng)使用 persistent_workers > 0 迭代 DataLoader 時(shí),如果_iterator` 為 None,則使用_get_iterator() 重新加載整個(gè)數(shù)據(jù)集??梢源_定的是 Pytorch Lightning 錯(cuò)誤地重置了 _iterator,從而導(dǎo)致了這個(gè)問(wèn)題。
為了證實(shí)這一發(fā)現(xiàn),我用一個(gè)自定義的只能重載的__iter__方法替換了 DataLoader:
正如預(yù)期的那樣,在迭代之后,_iterator 屬性被正確設(shè)置,但在下一個(gè) epoch 開(kāi)始之前被重置為 None。
n_jobs=1,persistent_workers=True
現(xiàn)在,我只需要知道屬性何時(shí)被設(shè)置為 None ,這樣就可找到問(wèn)題的根源。我嘗試使用調(diào)試器,但由于多進(jìn)程或 CUDA 而導(dǎo)致程序崩潰。我開(kāi)始采用 Python 的 getter & setter 用法:
當(dāng) DataLoader._iterator 設(shè)置為 None 時(shí),將會(huì)打印 stack trace
這樣做非常有效,會(huì)輸出如下內(nèi)容:
File "trainer\trainer.py", line 1314, in _run_train
self.fit_loop.run()
...
File "loops\fit_loop.py", line 234, in advance
self.epoch_loop.run(data_fetcher)
File "loops\base.py", line 139, in run
self.on_run_start(*args, **kwargs)
File "loops\epoch\training_epoch_loop.py", line 142, in on_run_start
self._dataloader_iter = _update_dataloader_iter(...)
File "loops\utilities.py", line 121, in _update_dataloader_iter
dataloader_iter = enumerate(data_fetcher, batch_idx)
File "utilities\fetching.py", line 198, in __iter__
self.reset()
File "utilities\fetching.py", line 212, in reset
self.dataloader.reset()
...
File "trainer\supporters.py", line 498, in _shutdown_workers_and_reset_iterator
dataloader._iterator = None
通過(guò)跟蹤發(fā)現(xiàn)每次開(kāi)始運(yùn)行時(shí)都會(huì)調(diào)用 DataLoader.reset。通過(guò)深入研究代碼后,我發(fā)現(xiàn)每次迭代都會(huì)重置 DataFetcher,從而導(dǎo)致 DataLoader 也被重置。代碼中沒(méi)有條件來(lái)避免重置:每個(gè) epoch 都必須重置 DataLoader。
這就是我發(fā)現(xiàn)迭代緩慢的根本原因。
修復(fù) bug
既然發(fā)現(xiàn)了 bug,就要想辦法修復(fù)。修復(fù) bug 非常簡(jiǎn)單:我將 self.reset 行從 DataFetcher 的__iter__ 方法中移除:
通過(guò)修改后再次訓(xùn)練,現(xiàn)在一次迭代只需要 1.5 秒,而此前需要 15 秒,使用 vanilla Pytorch 也需要 3 秒,相比較而言,速度確實(shí)提升了很多。
圖片
我將發(fā)現(xiàn)的這個(gè) bug 報(bào)告給了 Lightning 團(tuán)隊(duì),他們對(duì)問(wèn)題進(jìn)行了修復(fù)并在第二天推送了修補(bǔ)程序。我隨后更新了庫(kù),更新后發(fā)現(xiàn)他們的修復(fù)確實(shí)有效。相信更多人將從這次修復(fù)中受益,并且他們的 Lightning 模型的訓(xùn)練和測(cè)試時(shí)間會(huì)得到改善。如果你最近還沒(méi)有更新依賴項(xiàng),請(qǐng)嘗試安裝 pytorch-lightning==1.5.1 或更高版本!
原文鏈接:https://medium.com/@florian-ernst/finding-why-pytorch-lightning-made-my-training-4x-slower-ae64a4720bd1
*博客內(nèi)容為網(wǎng)友個(gè)人發(fā)布,僅代表博主個(gè)人觀點(diǎn),如有侵權(quán)請(qǐng)聯(lián)系工作人員刪除。