雷鋒網(wǎng) AI 研習社消息,近日,OpenAI 在 GitHub 上開源最新工具包 gradient-checkpointing,該工具包通過設置梯度檢查點(gradient-checkpointing)來節(jié)省內存資源。據(jù)悉,對于普通的前饋模型,可以在計算時間只增加 20% 的情況下,在 GPU 上訓練比之前大十多倍的模型。雷鋒網(wǎng)(公眾號:雷鋒網(wǎng))AI 研習社將該開源信息編譯整理如下:
通過梯度檢查點(gradient-checkpointing)來節(jié)省內存資源
訓練非常深的神經(jīng)網(wǎng)絡需要大量內存,利用 Tim Salimans 和 Yaroslav Bulatov 共同開發(fā)的 gradient-checkpointing 包中的工具,可以犧牲計算時間來解決內存過小的問題,讓你更好地針對模型進行訓練。
對于普通的前饋模型,可以在計算時間只增加 20% 的情況下,在 GPU 上訓練比之前大十多倍的模型。
訓練深度神經(jīng)網(wǎng)絡時,損失的梯度是在內存密集部分通過反向傳播(backpropagation)算法來計算的。在訓練模型時定義計算圖中的檢查點,并在這些檢查點之間通過反向傳播算法重新計算這些圖,可以在降低內存的同時計算梯度值。
當訓練一個 n 層的深度前饋神經(jīng)網(wǎng)絡時,可以利用這種方式將內存消耗減少到 O(sqrt(n)),代價是需要執(zhí)行一個額外的前向傳遞操作。這個庫可以在 Tensorflow 中實現(xiàn)這一功能——使用 Tensorflow graph editor 來自動重寫后向傳遞的計算圖。
圖:使用常規(guī)的 tf.gradients 函數(shù)和使用這種優(yōu)化內存梯度實現(xiàn)法(memory-optimized gradient implementation)訓練不同層數(shù)的 ResNet 模型時需要的內存對比
大家現(xiàn)在就可以安裝
pip install tf-nightly-gpu
pip install toposort networkx pytest
當執(zhí)行這一程序時,需要保證能找到CUPTI。
這時可以執(zhí)行
export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/usr/local/cuda/extras/CUPTI/lib64
使用方法
這個庫提供嵌入式功能,能對 tf.gradients 函數(shù)進行替換,可以輸入如下程序來引入相關函數(shù):
from memory_saving_gradients import gradients
大家可以像使用 tf.gradients 函數(shù)一樣使用 gradients 函數(shù)來計算參數(shù)損失的梯度。
gradients 函數(shù)有一個額外的功能——檢查點(checkpoints)。
檢查點會對 gradients 函數(shù)進行指示——在計算圖的前向傳播中,圖中的哪一部分節(jié)點是用戶想要檢查的點。隨后,會在后向傳播中重新計算檢查點之間的節(jié)點。
大家可以為檢查點提供一系列張量(gradients(ys,xs,checkpoints=[tensor1,tensor2])),或者可以使用如下幾個關鍵詞('collection'、'memory' 、'speed')來進行設置。
覆蓋 tf.gradients 函數(shù)
使用 gradients 函數(shù)的另一個方法是直接覆蓋 tf.gradients 函數(shù),方法如下:
import tensorflow as tf
import memory_saving_gradients
# monkey patch tf.gradients to point to our custom version, with automatic checkpoint selection
def gradients_memory(ys, xs, grad_ys=None, **kwargs):
return memory_saving_gradients.gradients(ys, xs, grad_ys, checkpoints='memory', **kwargs)
tf.__dict__["gradients"] = gradients_memory
這樣操作之后,所有調用 tf.gradients 函數(shù)的請求都會使用新的節(jié)省內存的方法。
測試
在測試文件夾中,有已經(jīng)寫好的用于測試代碼準確性和不同模型占用內存的腳本。
大家可以執(zhí)行 ./run_all_tests.sh 來修改代碼,并著手測試。
圖:在CIFAR10數(shù)據(jù)集上,使用常規(guī)的梯度函數(shù)和使用最新的優(yōu)化內存函數(shù),在不同層數(shù)的 ResNet 網(wǎng)絡下的內存占用情況和執(zhí)行時間的對比
via:GitHub