tensorflow 計算均值和方差

咱們在處理矩陣數據時,須要用到數據的均值和方差,好比在batch normalization的時候。dom

那麼,tensorflow中計算均值和方差的函數是:tf.nn.moments(x, axes)函數

x: 咱們待處理的數據spa

axes: 在哪個維度上求解,是一個list,如axes=[0, 1, 2]code

舉例:orm

 1 def calc_mean_variance():  2     """
 3  計算均值和方差  4  :return:  5     """
 6     img = tf.Variable(tf.random_normal([2, 3]))  7     t = len(img.get_shape())  8     axis = list(range(len(img.get_shape()) - 1))  9     mean, variance = tf.nn.moments(img, axes=0) 10  with tf.Session() as sess: 11  sess.run(tf.global_variables_initializer()) 12         print(sess.run(img)) 13         print(sess.run([mean, variance]))

輸出:blog

 

注意,如下是統計軸的個數:get

axis = list(range(len(img.get_shape()) - 1))
相關文章
相關標籤/搜索