Tensorflow使用Variable
類表達、更新、存儲模型參數。html
Variable
是在可變動的,具備保持性的內存句柄,存儲着Tensor
session
運行以前,圖中的所有Variable
必須被初始化Variable
的值在sess.run(init)以後就肯定了Tensor
的值要在sess.run(x)以後才肯定Variable
被添加到默認的collection
中
tf.GraphKeys
中包含了全部默認集合的名稱,能夠經過查看__dict__發現具體集合。python
被收集在名爲tf.GraphKeys.
GLOBAL_VARIABLES:global_variablestf.GraphKeys.GLOBAL_VARIABLES
的colletion
中,包含了模型中的通用參數git
tf.GraphKeys.TRAINABLE_VARIABLES:
githubtf.Optimizer默認
只優化tf.GraphKeys.TRAINABLE_VARIABLES
中的變量。
函數 | 集合名 | 意義 |
---|---|---|
tf.global_variables() | GLOBAL_VARIABLES | 存儲和讀取checkpoints時,使用其中全部變量api 跨設備全局變量集合網絡 |
tf.trainable_variables() | TRAINABLE_VARIABLES | 訓練時,更新其中全部變量session 存儲須要訓練的模型參數的變量集合函數 |
tf.moving_average_variables() | MOVING_AVERAGE_VARIABLES |
實用指數移動平均的變量集合spa |
tf.local_variables() | LOCAL_VARIABLES | 在 進程內本地變量集合 |
tf.model_variables() | MODEL_VARIABLES | Key to collect model variables defined by layers. 進程內存儲的模型參數的變量集合 |
QUEUE_RUNNERS | 並不是存儲variables,存儲處理輸入的QueueRunner | |
SUMMARIES | 並不是存儲variables,存儲日誌生成相關張量 |
除了上表中的函數外(上表中最後兩個集合並不是變量集合,爲了方便一併放在這裏),還可使用tf.get_collection(集合名)獲取集合中的變量,不過這個函數更多與tf.get_collection(集合名)搭配使用,操做自建集合。
另,slim.get_model_variables()與tf.model_variables()功能近似。
Summary
被收集在名爲tf.GraphKeys.
UMMARIES
的colletion
中,
Summary
是對網絡中Tensor
取值進行監測的一種Operation
collection
中添加一個Operation
除了默認的集合,咱們也能夠本身創造collection
組織對象。網絡損失就是一類適宜對象。
tensorflow中的Loss提供了許多建立損失Tensor
的方式。
x1 = tf.constant(1.0) l1 = tf.nn.l2_loss(x1) x2 = tf.constant([2.5, -0.3]) l2 = tf.nn.l2_loss(x2)
建立損失不會自動添加到集合中,須要手工指定一個collection
:
tf.add_to_collection("losses", l1) tf.add_to_collection("losses", l2)
建立完成後,能夠統一獲取全部損失,losses
是個Tensor
類型的list:
losses = tf.get_collection('losses')
一種常見操做把全部損失累加起來獲得一個Tensor
:
loss_total = tf.add_n(losses)
執行操做能夠獲得損失取值:
sess = tf.Session() init = tf.global_variables_initializer() sess.run(init) losses_val = sess.run(losses) loss_total_val = sess.run(loss_total)
實際上,若是使用TF-Slim包的losses系列函數建立損失,會自動添加到名爲」losses」的collection
中。