Tensorflow中的圖(tf.Graph)和會話(tf.Session)詳解

Tensorflow中的圖(tf.Graph)和會話(tf.Session)

Tensorflow編程系統

Tensorflow工具或者說深度學習自己就是一個連貫緊密的系統。通常的系統是一個自治獨立的、能實現複雜功能的總體。系統的主要任務是對輸入進行處理,以獲得想要的輸出結果。咱們以前見過的不少系統都是線性的,就像汽車生產工廠的流水線同樣,輸入->系統處理->輸出。系統內部由不少單一的基本部件構成,這些單一部件具備特定的功能,且須要穩定的特性;系統設計者經過特殊的鏈接方式,讓這些簡單部件進行鏈接,以使它們之間能夠進行數據交流和信息互換,來達到相互配合而完成具體工做的目的。node

對於任何一個系統來講,都應該擁有穩定、獨立、能處理特殊任務的單一部件;且擁有一套良好的內部溝通機制,以讓系統能夠健康安全的運行。python

現實中的不少系統都是線性的,被設計好的、不能進行更改的,好比工廠的流水線,這樣的系統並不具有自我調整的能力,沒法對外界的環境作出反應,所以也就不具有「智能」。編程

深度學習(神經網絡)之因此具有智能,就是由於它具備反饋機制。深度學習具備一套對輸出所作的評價函數(損失函數),損失函數在對神經網絡作出評價後,會經過某種方式(梯度降低法)更新網絡的組成參數,以指望系統獲得更好的輸出數據。安全

因而可知,神經網絡的系統主要由如下幾個方面組成:網絡

  • 輸入
  • 系統自己(神經網絡結構),以及涉及到系統自己構建的問題:如網絡構建方式、網絡執行方式、變量維護、模型存儲和恢復等等問題
  • 損失函數
  • 反饋方式:訓練方式

定義好以上的組成部分,咱們就能夠用流程化的方式將其組合起來,讓系統對輸入進行學習,調整參數。由於該系統的反饋機制,因此,組成的方式確定須要循環。session

而對於Tensorflow來講,其設計理念確定離不開神經網絡自己。因此,學習Tensorflow以前,對神經網絡有一個總體、深入的理解也是必須的。以下圖:Tensorflow的執行示意。ide

那麼對於以上所列的幾點,什麼纔是最重要的呢?我想確定是有關係統自己所涉及到的問題。即如何構建、執行一個神經網絡? 在Tensorflow中,用計算圖來構建網絡,用會話來具體執行網絡。深刻理解了這兩點,我想,對於Tensorflow的設計思路,以及運行機制,也就略知一二了。函數

  • 圖(tf.Graph):計算圖,主要用於構建網絡,自己不進行任何實際的計算。計算圖的設計啓發是高等數學裏面的鏈式求導法則的圖。咱們能夠將計算圖理解爲是一個計算模板或者計劃書。

     

  • 會話(tf.session):會話,主要用於執行網絡。全部關於神經網絡的計算都在這裏進行,它執行的依據是計算圖或者計算圖的一部分,同時,會話也會負責分配計算資源和變量存放,以及維護執行過程當中的變量。工具

接下來,咱們主要從計算圖開始,看一看Tensorflow是如何構建、執行網絡的。oop

計算圖

在開始以前,咱們先複習一下Tensorflow的幾種基本數據類型:

tf.constant(value, dtype=None, shape=None, name='Const', verify_shape=False) tf.Variable(initializer, name) tf.placeholder(dtype, shape=None, name=None) 

複習完畢。

graph = tf.Graph() with graph.as_default(): img = tf.constant(1.0, shape=[1,5,5,3]) 

以上代碼中定義了一個計算圖,在該計算圖中定義了一個常量。Tensorflow默認會建立一張計算圖。因此上面代碼中的前兩行,能夠省略。默認狀況下,計算圖是空的。

在執行完img = tf.constant(1.0, shape=[1,5,5,3])之後,計算圖中生成了一個node,一個node結點由name, op, input, attrs組成,即結點名稱、操做、輸入以及一系列的屬性(類型、形狀、值)等組成,計算圖就是由這樣一個個的node組成的。對於tf.constant()函數,只會生成一個node,但對於有的函數,如tf.Variable(initializer, name)(注意其第一個參數是初始化器)就會生成多個node結點(後面會講到)。
那麼執行完img = tf.constant(1.0, shape=[1,5,5,3])後,計算圖中就多一個node結點。(由於每一個node的屬性不少,我只表示name,op,input屬性)

繼續添加代碼:

img = tf.constant(1.0, shape=[1,5,5,3]) k = tf.constant(1.0, shape=[3,3,3,1]) 

代碼執行後的計算圖以下:

須要注意的是,若是沒有對結點進行命名,Tensorflow自動會將其命名爲:Const、Const_一、const_2......。其餘類型的結點類同。

如今,咱們添加一個變量:

img = tf.constant(1.0, shape=[1,5,5,3]) k = tf.constant(1.0, shape=[3,3,3,1]) kernel = tf.Variable(k) 

該變量用一個常量做爲初始化器。咱們先看一下計算圖:

如圖所示:
執行完tf.Variable()函數後,一共產生了三個結點:

  • Variable:變量維護(不存放實際的值)
  • Variable/Assign:變量分配
  • Variable/read:變量使用

圖中只是完成了操做的定義,但並無執行操做(如Variable/Assign結點的Assign操做,因此,此時候變量依然不可使用,這就是爲何要在會話中初始化的緣由)。

咱們繼續添加代碼:

img = tf.constant(1.0, shape=[1,5,5,3]) k = tf.constant(1.0, shape=[3,3,3,1]) kernel = tf.Variable(k) y = tf.nn.conv2d(img, kernel, strides=[1,2,2,1], padding="SAME") 

獲得的計算圖以下:

能夠看出,變量讀取是經過Variable/read來進行的。

若是在這裏咱們直接開啓會話,並執行計算圖中的卷積操做,系統就會報錯。

img = tf.constant(1.0, shape=[1,5,5,3]) k = tf.constant(1.0, shape=[3,3,3,1]) kernel = tf.Variable(k) y2 = tf.nn.conv2d(img, kernel, strides=[1,2,2,1], padding="SAME") with tf.Session() as sess: sess.run(y2) 

這段代碼錯誤的緣由在於,變量並無初始化就被使用,而從圖中清晰的能夠看到,直接執行卷積,是回溯不到變量的值(Const_1)的(箭頭方向)。

因此,在執行以前,要進行初始化,代碼以下:

img = tf.constant(1.0, shape=[1,5,5,3]) k = tf.constant(1.0, shape=[3,3,3,1]) kernel = tf.Variable(k) y2 = tf.nn.conv2d(img, kernel, strides=[1,2,2,1], padding="SAME") init = tf.global_variables_initializer() 

執行完tf.global_variables_initializer()函數之後,計算圖以下:

tf.global_variables_initializer()產生了一個名爲init的node,該結點將全部的Variable/Assign結點做爲輸入,以達到對整張計算圖中的變量進行初始化。
因此,在開啓會話後,執行的第一步操做,就是變量初始化(固然變量初始化的方式有不少種,咱們也能夠顯示調用tf.assign()來完成對單個結點的初始化)。
完整代碼以下:

img = tf.constant(1.0, shape=[1,5,5,3]) k = tf.constant(1.0, shape=[3,3,3,1]) kernel = tf.Variable(k) y2 = tf.nn.conv2d(img, kernel, strides=[1,2,2,1], padding="SAME") init = tf.global_variables_initializer() with tf.Session() as sess: sess.run(init) # do someting.... 

會話

在上述代碼中,我已經使用會話(tf.session())來執行計算圖了。在tf.session()中,咱們重點掌握無所不能的sess.run()

一個session()中包含了Operation被執行,以及Tensor被evaluated的環境。

tf.Session().run()函數的定義:

run( fetches, feed_dict=None, options=None, run_metadata=None ) 

tf.Session().run()函數的功能爲:執行fetches參數所提供的operation操做或計算其所提供的Tensor

run()函數每執行一步,都會執行與fetches有關的圖中的全部結點的計算,以完成fetches中的任務。其中,feed_dict提供了部分數據輸入的功能。(和tf.Placeholder()搭配使用,很舒服)

參數說明:

  • fetches:能夠是圖中的一個結點,也能夠是一個List或者字典,此時候返回值與fetches格式一致;該參數還能夠是一個Operation,此時候返回值爲None
  • feed_dict:字典格式。給模型輸入其計算過程當中所須要的值。

當咱們把模型的計算圖構建好之後,就能夠利用會話來進行執行訓練了。

在明白了計算圖是如何構建的,以及如何被會話正確的執行之後,咱們就能夠愉快的開始Tensorflow之旅啦。

參考連接

相關文章
相關標籤/搜索