tf.Variable()、tf.get_variable()和tf.placeholder()

1.tf.Variable()

tf.Variable(initializer,name)

功能:tf.Variable()建立變量時,name屬性值容許重複,檢查到相同名字的變量時,由自動別名機制建立不一樣的變量。python

參數:api

  • initializer:初始化參數;
  • name:可自定義的變量名稱

舉例:dom

import tensorflow as tf
v1=tf.Variable(tf.random_normal(shape=[2,3],mean=0,stddev=1),name='v1')
v2=tf.Variable(tf.constant(2),name='v2')
v3=tf.Variable(tf.ones([2,3]),name='v3')
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(v1))
    print(sess.run(v2))
    print(sess.run(v3))

結果以下:函數

 

2.tf.get_variable()

tf.get_variable(
    name,
    shape=None,
    dtype=None,
    initializer=None,
    regularizer=None,
    trainable=None,
    collections=None,
    caching_device=None,
    partitioner=None,
    validate_shape=True,
    use_resource=None,
    custom_getter=None,
    constraint=None,
    synchronization=tf.VariableSynchronization.AUTO,
    aggregation=tf.VariableAggregation.NONE
)

功能:tf.get_variable建立變量時,會進行變量檢查,當設置爲共享變量時(經過scope.reuse_variables()或tf.get_variable_scope().reuse_variables()),檢查到第二個擁有相同名字的變量,就返回已建立的相同的變量;若是沒有設置共享變量,則會報[ValueError: Variable varx alreadly exists, disallowed.]的錯誤。google

參數:spa

  • name:新變量或現有變量的名稱
  • shape:新變量或現有變量的形狀
  • dtype:新變量或現有變量的類型(默認爲DT_FLOAT)。
  • initializer:變量初始化的方式

初始化方式:code

  • tf.constant_initializer:常量初始化函數
  • tf.random_normal_initializer:正態分佈
  • tf.truncated_normal_initializer:截取的正態分佈
  • tf.random_uniform_initializer:均勻分佈
  • tf.zeros_initializer:所有是0
  • tf.ones_initializer:全是1
  • tf.uniform_unit_scaling_initializer:知足均勻分佈,但不影響輸出數量級的隨機值

舉例:orm

v1=tf.Variable(tf.random_normal(shape=[2,3],mean=0,stddev=1),name='v1')
v2=tf.Variable(tf.random_normal(shape=[2,3],mean=0,stddev=1),name='v1')
v3=tf.Variable(tf.ones([2,3]),name='v3')

a1 = tf.get_variable(name='a1', shape=[2, 3], initializer=tf.random_normal_initializer(mean=0, stddev=1))
a2 = tf.get_variable(name='a2', shape=[2, 3], initializer=tf.random_normal_initializer(mean=0, stddev=1))
a3 = tf.get_variable(name='a3', shape=[2, 3], initializer=tf.ones_initializer())

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    print(sess.run(v1))
    print(sess.run(v2))
    print(sess.run(v3))
    print(sess.run(a1))
    print(sess.run(a2))
    print(sess.run(a3))

v1和v2的參數徹底相同,建立時候不會報錯;a1和a2的參數徹底相同,建立時候會報錯  blog

 

3.tf.placeholder()

tf.placeholder(
    dtype,
    shape=None,
    name=None
)

功能:在tensorflow中相似於函數參數,運行時必須傳入值。ip

TensorFlow連接:https://tensorflow.google.cn/api_docs/python/tf/placeholder?hl=en

參數:

  • dtype:要進給的張量中的元素類型。經常使用的是tf.float32,tf.float64等數值類型。
  • shape:要進給的張量的形狀(可選)。若是未指定形狀,則能夠提供任何形狀的張量。默認是None,就是一維值,也能夠是多維,好比[2,3], [None, 3]表示列是3,行不定。
  • name:操做的名稱(可選)。

舉例:

input1 = tf.placeholder(tf.float32)
input2 = tf.placeholder(tf.float32)

output = tf.multiply(input1, input2)

with tf.Session() as sess:
    print(sess.run(output, feed_dict={input1: [23.], input2: [4.]})) # [92.]

  

 

 

參考文獻:

【1】Tensorflow——tf.Variable()、tf.get_variable()和tf.placeholder()

相關文章
相關標籤/搜索