本周,Facebook 的 AI 研究團隊發(fā)布了一個 Python 工具包,專門針對 GPU 加速的深度神經(jīng)網(wǎng)絡(luò)(DNN)編程。它有望輔助、或在一定程度上替代,現(xiàn)有的 Python 數(shù)學、統(tǒng)計庫(比如 NumPy)。它實現(xiàn)了機器學習框架 Torch 在 Python 語言環(huán)境的執(zhí)行。開發(fā)團隊表示,除 Facebook之外,它還已經(jīng)被推特、卡內(nèi)基梅隆大學和 Salesforce 等機構(gòu)采用。
使用 Pytorch 的機構(gòu)
Torch 是一個十分老牌、對多維矩陣數(shù)據(jù)進行操作的張量(tensor )庫,在機器學習和其他數(shù)學密集型應(yīng)用有廣泛應(yīng)用。但由于其語言采用 Lua,導致在國內(nèi)一直很小眾,并逐漸被支持 Python 的 Tensorflow 搶走用戶。如今,作為經(jīng)典機器學習庫 Torch 的端口,PyTorch 為 Python 語言使用者提供了舒適的寫代碼選擇。雷鋒網(wǎng)此前對 Torch 做過介紹。詳情請看盤點四大民間機器學習開源框架:Theano、Caffe、Torch 和 SciKit-learn。
PyTorch 的特點和優(yōu)勢
PyTorch 提供了:
運行在 GPU 或 CPU 之上、基礎(chǔ)的張量操作庫,
內(nèi)置的神經(jīng)網(wǎng)絡(luò)庫
模型訓練功能
支持共享內(nèi)存的多進程并發(fā)(multiproCESsing )庫。PyTorch 開發(fā)團隊表示:這對數(shù)據(jù)載入和 hogwild 訓練十分有幫助。
PyTorch 的首要優(yōu)勢是,它處于機器學習第一大語言 Python 的生態(tài)圈之中,使得開發(fā)者能接入廣大的 Python 庫和軟件。因此,Python 開發(fā)者能夠用他們熟悉的風格寫代碼,而不需要針對外部 C 語言或 C++ 庫的 wrapper,使用它的專門語言。雷鋒網(wǎng)獲知,現(xiàn)有的工具包可以與 PyTorch 一起運行,比如 NumPy、SciPy 和 Cython(為了速度把 Python 編譯成 C 語言)。
PyTorch 還為改進現(xiàn)有的神經(jīng)網(wǎng)絡(luò),提供了更快速的方法——不需要從頭重新構(gòu)建整個網(wǎng)絡(luò)。這是由于 PyTorch 采用了動態(tài)計算圖(dynamic computational graph)結(jié)構(gòu),而不是大多數(shù)開源框架,比如 TensorFlow、Caffe、CNTK、Theano 等采用的靜態(tài)計算圖。雷鋒網(wǎng)(公眾號:雷鋒網(wǎng))獲知,該技術(shù)從另一個 Python 的神經(jīng)網(wǎng)絡(luò)框架——Chainer 那里借用。開發(fā)者團隊還強調(diào) PyTorch 優(yōu)越的內(nèi)存效率,因為它采用了定制的 GPU 內(nèi)存分配器。這使得開發(fā)者的深度學習模型能夠有“最大限度的內(nèi)存效能”,訓練比從前更大的深度神經(jīng)網(wǎng)絡(luò)。
雖然 PyTorch 為機器學習應(yīng)用而優(yōu)化,這并不是它的唯一使用場景。比如說,相比 NumPy ,PyTorch 的張量計算可作為它對應(yīng)功能的替代。PyTorch 為這些功能提供了 GPU 加速的版本。在沒有強力 GPU 加持的情況下,開發(fā)者能使用 CPU 運行。
這是 PyTorch 中包含的工具包列表:
torch :類似 NumPy 的張量庫,強 GPU 支持
torch.autograd :基于 tape 的自動區(qū)別庫,支持 torch 之中的所有可區(qū)分張量運行。
torch.nn :為最大化靈活性未涉及、與 autograd 深度整合的神經(jīng)網(wǎng)絡(luò)庫
torch.optim:與 torch.nn 一起使用的優(yōu)化包,包含 SGD, RMSProp, LBFGS, Adam 等標準優(yōu)化方式
torch.multiprocessing: python 多進程并發(fā),進程之間 torch Tensors 的內(nèi)存共享。
torch.utils:數(shù)據(jù)載入器。具有訓練器和其他便利功能。 Trainer and other utility functions for convenience
torch.legacy(.nn/.optim) :處于向后兼容性考慮,從 Torch 移植來的 legacy 代碼。