『TensorFlow』使用集合collection控制variables

Variable

Tensorflow使用Variable類表達、更新、存儲模型參數。html

  • Variable是在可變動的,具備保持性的內存句柄,存儲着Tensor
  • 在整個session運行以前,圖中的所有Variable必須被初始化
    • Variable的值在sess.run(init)以後就肯定了
    • Tensor的值要在sess.run(x)以後才肯定
  • 建立的Variable被添加到默認的collection

 

tf.GraphKeys中包含了全部默認集合的名稱,能夠經過查看__dict__發現具體集合。python

tf.GraphKeys.GLOBAL_VARIABLES:global_variables被收集在名爲tf.GraphKeys.GLOBAL_VARIABLEScolletion中,包含了模型中的通用參數git

tf.GraphKeys.TRAINABLE_VARIABLES:tf.Optimizer默認只優化tf.GraphKeys.TRAINABLE_VARIABLES中的變量。github

函數 集合名 意義
tf.global_variables() GLOBAL_VARIABLES

存儲和讀取checkpoints時,使用其中全部變量api

跨設備全局變量集合網絡

tf.trainable_variables() TRAINABLE_VARIABLES

訓練時,更新其中全部變量session

存儲須要訓練的模型參數的變量集合函數

tf.moving_average_variables() MOVING_AVERAGE_VARIABLES

ExponentialMovingAverage對象會生成此類變量優化

實用指數移動平均的變量集合spa

tf.local_variables() LOCAL_VARIABLES

global_variables()以外,須要用tf.init_local_variables()初始化

進程內本地變量集合

tf.model_variables() MODEL_VARIABLES

 Key to collect model variables defined by layers.

進程內存儲的模型參數的變量集合

  QUEUE_RUNNERS 並不是存儲variables,存儲處理輸入的QueueRunner
  SUMMARIES 並不是存儲variables,存儲日誌生成相關張量

除了上表中的函數外(上表中最後兩個集合並不是變量集合,爲了方便一併放在這裏),還可使用tf.get_collection(集合名)獲取集合中的變量,不過這個函數更多與tf.get_collection(集合名)搭配使用,操做自建集合。

另,slim.get_model_variables()與tf.model_variables()功能近似。

 

Summary

Summary被收集在名爲tf.GraphKeys.UMMARIEScolletion中,

  • Summary是對網絡中Tensor取值進行監測的一種Operation
  • 這些操做在圖中是「外圍」操做,不影響數據流自己
  • 調用tf.scalar_summary系列函數時,就會向默認的collection中添加一個Operation

 

自定義集合

除了默認的集合,咱們也能夠本身創造collection組織對象。網絡損失就是一類適宜對象。

tensorflow中的Loss提供了許多建立損失Tensor的方式。

x1 = tf.constant(1.0)
l1 = tf.nn.l2_loss(x1)

x2 = tf.constant([2.5, -0.3])
l2 = tf.nn.l2_loss(x2)

建立損失不會自動添加到集合中,須要手工指定一個collection

tf.add_to_collection("losses", l1)
tf.add_to_collection("losses", l2)

建立完成後,能夠統一獲取全部損失,losses是個Tensor類型的list:

losses = tf.get_collection('losses')

一種常見操做把全部損失累加起來獲得一個Tensor

loss_total = tf.add_n(losses)

 執行操做能夠獲得損失取值:

sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
losses_val = sess.run(losses)
loss_total_val = sess.run(loss_total)

 實際上,若是使用TF-Slim包的losses系列函數建立損失,會自動添加到名爲」losses」的collection中。

相關文章
相關標籤/搜索