導讀:“我叫 Jacob,是谷歌 AI Residency 項目的學者。2017 年夏天我進入這個項目的時候,我自己的編程經(jīng)驗很豐富,對機器學習理解也很深刻,但以前我從未使用過 Tensorflow。當時我認為憑自己的能力可以很快掌握 Tensorflow,但沒想到我學習它的過程竟然如此跌宕起伏。甚至加入項目幾個月后我還偶爾會感到困惑,不知道怎樣用 Tensorflow 代碼實現(xiàn)自己的新想法。
這篇博文就像是我給過去自己寫的瓶中信:回顧當初,我希望在開始學習的時候有這樣一篇入門介紹。我也希望本文能夠幫助同行,為他們提供參考?!盇I 前線將這位現(xiàn)谷歌大腦工程師關于學習 Tensorflow 過程中遭遇的方方面面難題的文章做了翻譯,希望對大家有幫助。
過去的教程缺少哪些內(nèi)容?
Tensorflow 發(fā)布已經(jīng)有三年,如今它已成為深度學習生態(tài)系統(tǒng)的基石。然而對于初學者來說它并不怎么簡單易懂,與 PyTorch 或 DyNet 這樣的運行即定義的神經(jīng)網(wǎng)絡庫相比就更明顯了。
有很多 Tensorflow 的入門教程,內(nèi)容涵蓋線性回歸、MNIST 分類乃至機器翻譯。這些內(nèi)容具體、實用的指南能幫助人們快速啟動并運行 Tensorflow 項目,并且可以作為類似項目的切入點。但有的開發(fā)者開發(fā)的應用并沒有很好的教程參考,還有的項目在探索全新的路線(研究中很常見),對于這些開發(fā)者來說入門 Tensorflow 是非常容易感到困惑的。
我寫這篇文章就想彌補這一缺口。本文不會研究某個具體任務,而是提出更加通用的方法,并解析 Tensorflow 的基礎抽象概念。掌握好這些概念后,用 Tensorflow 進行深度學習就會更加直觀易懂。
目標受眾
本教程適用于在編程和機器學習方面有一定經(jīng)驗,并想要入門 Tensorflow 的從業(yè)者。他們可以是:想在深度學習課程的最后一個項目中使用 Tensorflow 的 CS 專業(yè)學生;剛剛被調(diào)到涉及深度學習的項目的軟件工程師;或者是一位處于困惑之中的 Google AI 新手(向 Jacob 大聲打個招呼吧)。如果你需要基礎知識入門,請參閱以下資源。這些都了解的話,我們就開始吧!
理解 Tensorflow
Tensorflow 不是一個普通的 Python 庫。
大多數(shù) Python 庫被編寫為 Python 的自然擴展形式。當你導入一個庫時,你得到的是一組變量、函數(shù)和類,它們補充并擴展了你的代碼“工具箱”。使用這些庫時,你知道它們將產(chǎn)生怎樣的結果。我認為談及 Tensorflow 時應該拋棄這些認識,這些認知從根本上就不符合 Tensorflow 的理念,無法反映 TF 與其它代碼交互的方式。
Python 和 Tensorflow 之間的聯(lián)系,可以類比 Javascript 和 HTML 之間的關系。Javascript 是一種全功能的編程語言,可以實現(xiàn)各種出色的效果。HTML 是用于表示某種類型的實用計算抽象(這里指的是可由 Web 瀏覽器呈現(xiàn)的內(nèi)容)的框架。Javascript 在交互式網(wǎng)頁中的作用是組裝瀏覽器看到的 HTML 對象,然后在需要時通過將其更新為新的 HTML 來與其交互。
與 HTML 類似,Tensorflow 是用于表示某種類型的計算抽象(稱為“計算圖”)的框架。當我們用 Python 操作 Tensorflow 時,我們用 Python 代碼做的第一件事是組裝計算圖。之后我們的第二個任務就是與它進行交互(使用 Tensorflow 的“會話”)。但重要的是,要記住計算圖不在變量內(nèi)部,它處在全局命名空間內(nèi)。莎士比亞曾經(jīng)說過:“所有的 RAM 都是一個階段,所有的變量都只不過是指針?!?/p>
第一個關鍵抽象:計算圖
我們在瀏覽 Tensorflow 文檔時,有時會發(fā)現(xiàn)內(nèi)容提到“圖形”和“節(jié)點”。如果你仔細閱讀、深入挖掘,甚至可能已經(jīng)發(fā)現(xiàn)了這個頁面,該頁面中涵蓋的內(nèi)容我將以更精確和技術化的風格詳細解釋。本節(jié)將從頂層入手,把握關鍵的直覺概念,同時略過一些技術細節(jié)。
那么什么是計算圖?它實質(zhì)上是一個全局數(shù)據(jù)結構:計算圖是一個有向圖,捕獲有關計算方法的指令。
我們來看看如何構建一個示例。下圖中,上半部分是我們運行的代碼和它的輸出,下半部分是結果計算圖。
顯然,僅僅導入 Tensorflow 并不會給我們生成一個有趣的計算圖,而只有一個孤獨的,空白的全局變量。但是當我們調(diào)用一個 Tensorflow 操作時會發(fā)生什么呢?
快看!我們得到了一個節(jié)點,它包含常量:2。我知道你很驚訝,驚訝的是一個名為 tf.constant 的函數(shù)。當我們打印這個變量時,我們看到它返回一個 tf.Tensor 對象,它是一個指向我們剛創(chuàng)建的節(jié)點的指針。為了強調(diào)這一點,這里是另一個例子:
每次我們調(diào)用 tf.constant 的時候,我們都會在圖中創(chuàng)建一個新節(jié)點。即使節(jié)點在功能上與現(xiàn)有節(jié)點完全相同,即使我們將節(jié)點重新分配給同一個變量,甚至我們根本沒有將其分配給變量,結果都一樣。
相反,如果創(chuàng)建一個新變量并將其設置為與現(xiàn)有節(jié)點相等,則只需將該指針復制到該節(jié)點,并且不會向該圖添加任何內(nèi)容:
好的,我們更進一步。
現(xiàn)在我們來看——這才是我們要的真正的計算圖表!請注意,+ 操作在 Tensorflow 中過載,所以同時添加兩個張量會在圖中增加一個節(jié)點,盡管它看起來不像是 Tensorflow 操作。
好的,所以 two_node 指向包含 2 的節(jié)點,three_node 指向包含 3 的節(jié)點,而 sum_node 指向包含... + 的節(jié)點?什么情況?它不是應該包含 5 嗎?
事實證明,沒有。計算圖只包含計算步驟,不包含結果。至少...... 還沒有!
第二個關鍵抽象:會話
如果錯誤地理解 TensorFlow 抽象也有個瘋狂三月競賽(美國大學籃球繁忙冠軍賽季),那么“會話”將成為每年排名第一的種子選手。能獲此尷尬的榮譽,是因為會話的命名反直覺,應用卻如此廣泛——幾乎每個 Tensorflow 程序都至少會調(diào)用一次 tf.Session () 。
會話的作用是處理內(nèi)存分配和優(yōu)化,使我們能夠實際執(zhí)行由圖形指定的計算。可以將計算圖想象為我們想要執(zhí)行的計算的“模板”:它列出了所有的步驟。為了使用這個圖表,我們還需要發(fā)起一個會話,它使我們能夠實際地完成任務。例如,遍歷模板的所有節(jié)點來分配一組用于存儲計算輸出的存儲器。為了使用 Tensorflow 進行各種計算,我們既需要圖也需要會話。
會話包含一個指向全局圖的指針,該指針通過指向所有節(jié)點的指針不斷更新。這意味著在創(chuàng)建節(jié)點之前還是之后創(chuàng)建會話都無所謂。
創(chuàng)建會話對象后,可以使用 sess.run (node) 返回節(jié)點的值,并且 Tensorflow 將執(zhí)行確定該值所需的所有計算。
精彩!我們還可以傳遞一個列表,sess.run ([node1,node2,...]),并讓它返回多個輸出:
一般來說,sess.run () 調(diào)用往往是最大的 TensorFlow 瓶頸之一,所以調(diào)用它的次數(shù)越少越好??梢缘脑捲谝粋€ sess.run () 調(diào)用中返回多個項目,而不是進行多個調(diào)用。
占位符和 feed_dict
我們迄今為止所做的計算一直很乏味:沒有機會獲得輸入,所以它們總是輸出相同的東西。一個實用的應用可能涉及構建這樣一個計算圖:它接受輸入,以某種(一致)方式處理它,并返回一個輸出。
最直接的方法是使用占位符。占位符是一種用于接受外部輸入的節(jié)點。
……這是個糟糕的例子,因為它引發(fā)了一個異常。占位符預計會被賦予一個值,但我們沒有提供,因此 Tensorflow 崩潰了。
為了提供一個值,我們使用 sess.run () 的 feed_dict 屬性。
好多了。注意傳遞給 feed_dict 的數(shù)值格式。這些鍵應該是與圖中占位符節(jié)點相對應的變量(如前所述,它實際上意味著指向圖中占位符節(jié)點的指針)。相應的值是要分配給每個占位符的數(shù)據(jù)元素——通常是標量或 Numpy 數(shù)組。第三個關鍵抽象:計算路徑下面是另一個使用占位符的例子:
為什么第二次調(diào)用 sess.run () 會失敗?我們并沒有在檢查 input_placeholder,為什么會引發(fā)與 input_placeholder 相關的錯誤?答案在于最終的關鍵 Tensorflow 抽象:計算路徑。還好這個抽象非常直觀。
當我們在依賴于圖中其他節(jié)點的節(jié)點上調(diào)用 sess.run () 時,我們也需要計算這些節(jié)點的值。如果這些節(jié)點有依賴關系,那么我們需要計算這些值(依此類推......),直到達到計算圖的“頂端”,也就是所有的節(jié)點都沒有前置節(jié)點的情況。
考察 sum_node 的計算路徑:
所有三個節(jié)點都需要評估以計算 sum_node 的值。最重要的是,這里面包含了我們未填充的占位符,并解釋了例外情況!
相反,考察 three_node 的計算路徑:
根據(jù)圖的結構,我們不需要計算所有的節(jié)點也可以評估我們想要的節(jié)點!因為我們不需要評估 placeholder_node 來評估 three_node,所以運行 sess.run (three_node) 不會引發(fā)異常。
Tensorflow 僅通過必需的節(jié)點自動路由計算這一事實是它的巨大優(yōu)勢。如果計算圖非常大并且有許多不必要的節(jié)點,它就能節(jié)約大量運行時間。它允許我們構建大型的“多用途”圖形,這些圖形使用單個共享的核心節(jié)點集合根據(jù)采取的計算路徑來做不同的任務。對于幾乎所有應用程序而言,根據(jù)所采用的計算路徑考慮 sess.run () 的調(diào)用方法是很重要的。
變量和副作用
到目前為止,我們已經(jīng)看到兩種類型的“無祖先”節(jié)點:tf.constant(每次運行都一樣)和 tf.placeholder(每次運行都不一樣)。還有第三種節(jié)點:通常情況下具有相同的值,但也可以更新成新值。這個時候就要用到變量。
了解變量對于使用 Tensorflow 進行深度學習來說至關重要,因為模型的參數(shù)就是變量。在訓練期間,你希望通過梯度下降在每個步驟更新參數(shù),但在計算過程中,你希望保持參數(shù)不變,并將大量不同的測試輸入集傳入到模型中。模型所有的可訓練參數(shù)很有可能都是變量。
要創(chuàng)建變量,請使用 tf.get_variable ()。tf.get_variable () 的前兩個參數(shù)是必需的,其余是可選的。它們是 tf.get_variable (name,shape)。name 是一個唯一標識這個變量對象的字符串。它在全局圖中必須是唯一的,所以要確保不會出現(xiàn)重復的名稱。shape 是一個與張量形狀相對應的整數(shù)數(shù)組,它的語法很直觀——每個維度對應一個整數(shù),并按照排列。例如,一個 3&TImes;8 的矩陣可能具有形狀 [3,8]。要創(chuàng)建標量,請使用空列表作為形狀:[]。
發(fā)現(xiàn)另一個異常。一個變量節(jié)點在首次創(chuàng)建時,它的值基本上就是“null”,任何嘗試對它進行計算的操作都會拋出這個異常。我們只能先給一個變量賦值后才能用它做計算。有兩種主要方法可以用于給變量賦值:初始化器和 tf.assign ()。我們先看看 tf.assign ():
與我們迄今為止看到的節(jié)點相比,tf.assign (target,value) 有一些獨特的屬性:
標識操作。tf.assign (target,value) 不做任何計算,它總是與 value 相等。
副作用。當計算“流經(jīng)”assign_node 時,就會給圖中的其他節(jié)點帶來副作用。在這種情況下,副作用就是用保存在 zero_node 中的值替換 count_variable 的值。
非依賴邊。即使 count_variable 節(jié)點和 assign_node 在圖中是相連的,兩者都不依賴于其他節(jié)點。這意味著在計算任一節(jié)點時,計算不會通過該邊回流。不過,assign_node 依賴 zero_node,它需要知道要分配什么。
“副作用”節(jié)點充斥在大部分 Tensorflow 深度學習工作流中,因此,請確保你對它們了解得一清二楚。當我們調(diào)用 sess.run (assign_node) 時,計算路徑將經(jīng)過 assign_node 和 zero_node。
當計算流經(jīng)圖中的任何節(jié)點時,它還會讓該節(jié)點控制的副作用(綠色所示)起效。由于 tf.assign 的特殊副作用,與 count_variable(之前為“null”)關聯(lián)的內(nèi)存現(xiàn)在被永久設置為 0。這意味著,當我們下一次調(diào)用 sess.run (count_variable) 時,不會拋出任何異常。相反,我們將得到 0。
接下來,讓我們來看看初始化器:
這里都發(fā)生了什么?為什么初始化器不起作用?
問題在于會話和圖之間的分隔。我們已經(jīng)將 get_variable 的 iniTIalizer 屬性指向 const_init_node,但它只是在圖中的節(jié)點之間添加了一個新的連接。我們還沒有做任何與導致異常有關的事情:與變量節(jié)點(保存在會話中,而不是圖中)相關聯(lián)的內(nèi)存仍然為“null”。我們需要通過會話讓 const_init_node 更新變量。
為此,我們添加了另一個特殊節(jié)點:init = tf.global_variables_iniTIalizer ()。與 tf.assign () 類似,這是一個帶有副作用的節(jié)點。與 tf.assign () 不一樣的是,我們實際上并不需要指定它的輸入!tf.global_variables_iniTIalizer () 將在其創(chuàng)建時查看全局圖,自動將依賴關系添加到圖中的每個 tf.initializer 上。當我們調(diào)用 sess.run (init) 時,它會告訴每個初始化器完成它們的任務,初始化變量,這樣在調(diào)用 sess.run (count_variable) 時就不會出錯。
變量共享
你可能會碰到帶有變量共享的 Tensorflow 代碼,代碼有它們的作用域,并設置“reuse=True”。我強烈建議你不要在代碼中使用變量共享。如果你想在多個地方使用單個變量,只需要使用指向該變量節(jié)點的指針,并在需要時使用它。換句話說,對于打算保存在內(nèi)存中的每個參數(shù),應該只調(diào)用一次 tf.get_variable ()。
優(yōu)化器
最后:進行真正的深度學習!如果你還在狀態(tài),那么其余的概念對于你來說應該是非常簡單的。
在深度學習中,典型的“內(nèi)循環(huán)”訓練如下:
獲取輸入和 true_output
根據(jù)輸入和參數(shù)計算出一個“猜測”
根據(jù)猜測和 true_output 之間的差異計算出一個“損失”
根據(jù)損失的梯度更新參數(shù)
讓我們把所有東西放在一個腳本里,解決一個簡單的線性回歸問題:
正如你所看到的,損失基本上沒有變化,而且我們對真實參數(shù)有了很好的估計。這部分代碼只有一兩行對你來說是新的:
既然你對 Tensorflow 的基本概念已經(jīng)有了很好的理解,這段代碼應該很容易解釋!第一行,optimizer = tf.train.GradientDescentOptimizer (1e-3) 不會向圖中添加節(jié)點。它只是創(chuàng)建了一個 Python 對象,包含了一些有用的函數(shù)。第二行 train_op = optimizer.minimize (loss),將一個節(jié)點添加到圖中,并將一個指針賦給 train_op。train_op 節(jié)點沒有輸出,但有一個非常復雜的副作用:
train_op 回溯其輸入的計算路徑,尋找變量節(jié)點。對于找到的每個變量節(jié)點,它計算與損失相關的變量梯度。然后,它為該變量計算新值:當前值減去梯度乘以學習率。最后,它執(zhí)行一個賦值操作來更新變量的值。
基本上,當我們調(diào)用 sess.run (train_op) 時,它為我們對所有的變量做了一個梯度下降的操作。當然,我們還需要使用 feed_dict 來填充輸入和輸出占位符,并且我們還希望打印這些損失,因為這樣方便調(diào)試。
用 tf.Print 進行調(diào)試
當你開始使用 Tensorflow 做更復雜的事情時,你需要進行調(diào)試。一般來說,檢查計算圖中發(fā)生了什么是很困難的。你不能使用常規(guī)的 Python 打印語句,因為你永遠無法訪問到要打印的值——它們被鎖定在 sess.run () 調(diào)用中。舉個例子,假設你想檢查一個計算的中間值,在調(diào)用 sess.run () 之前,中間值還不存在。但是,當 sess.run () 調(diào)用返回時,中間值不見了!
我們來看一個簡單的例子。
我們看到了結果是 5。但是,如果我們想檢查中間值 two_node 和 three_node,該怎么辦?檢查中間值的一種方法是向 sess.run () 添加一個返回參數(shù),該參數(shù)指向要檢查的每個中間節(jié)點,然后在返回后打印它。
這樣做通常沒有問題,但當代碼變得越來越復雜時,這可能有點尷尬。更方便的方法是使用 tf.Print 語句。令人困惑的是,tf.Print 實際上是 Tensorflow 的一種節(jié)點,它有輸出和副作用!它有兩個必需的參數(shù):一個要復制的節(jié)點和一個要打印的內(nèi)容列表?!耙獜椭频墓?jié)點”可以是圖中的任何節(jié)點,tf.Print 是與“要復制的節(jié)點”相關的標識操作,也就是說,它將輸出其輸入的副本。不過,它有個副作用,就是會打印“打印清單”里所有的值。
有關 tf.Print 的一個重要卻有些微妙的點:打印其實只是它的一個副作用。與所有其他副作用一樣,只有在計算流經(jīng) tf.Print 節(jié)點時才會進行打印。如果 tf.Print 節(jié)點不在計算路徑中,則不會打印任何內(nèi)容。即使 tf.Print 節(jié)點正在復制的原始節(jié)點位于計算路徑上,但 tf.Print 節(jié)點本身可能不是。這個問題要注意!當這種情況發(fā)生時,它會讓你感到非常沮喪,你需要費力地找出問題所在。一般來說,最好在創(chuàng)建要復制的節(jié)點后立即創(chuàng)建 tf.Print 節(jié)點。
結論
希望這篇文章能夠幫助你更好地理解 Tensorflow,了解它的工作原理以及如何使用它。畢竟,這里介紹的概念對所有 Tensorflow 程序來說都很重要,但這些還都只是表面上的東西。在你的 Tensorflow 探險之旅中,你可能會遇到各種你想要使用的其他有趣的東西:條件、迭代、分布式 Tensorflow、變量作用域、保存和加載模型、多圖、多會話和多核數(shù)據(jù)加載器隊列等。
聯(lián)系客服