tensorflow的廣播機制

TensorFlow支持廣播機制(Broadcast)

TensorFlow支持廣播機制(Broadcast),能夠廣播元素間操做(elementwise operations)。正常狀況下,當你想要進行一些操做如加法,乘法時,你須要確保操做數的形狀是相匹配的,如:你不能將一個具備形狀[3, 2]的張量和一個具備[3,4]形狀的張量相加。可是,這裏有一個特殊狀況,那就是當你的其中一個操做數是一個具備單獨維度(singular dimension)的張量的時候,TF會隱式地在它的單獨維度方向填滿(tile),以確保和另外一個操做數的形狀相匹配。因此,對一個[3,2]的張量和一個[3,1]的張量相加在TF中是合法的。(譯者:這個機制繼承自numpy的廣播功能。其中所謂的單獨維度就是一個維度爲1,或者那個維度缺失)網絡

import tensorflow as tf


a = tf.constant([[1., 2.], [3., 4.]])
b = tf.constant([[1.], [2.]])
# c = a + tf.tile(b, [1, 2])
c = a + b

廣播機制容許咱們在隱式狀況下進行填充(tile),而這能夠使得咱們的代碼更加簡潔,而且更有效率地利用內存,由於咱們不須要另外儲存填充操做的結果。一個能夠表現這個優點的應用場景就是在結合具備不一樣長度的特徵向量的時候。爲了拼接具備不一樣長度的特徵向量,咱們通常都先填充輸入向量,拼接這個結果真後進行以後的一系列非線性操做等。這是一大類神經網絡架構的共同套路(common pattern)架構

a = tf.random_uniform([5, 3, 5])
b = tf.random_uniform([5, 1, 6])


# concat a and b and apply nonlinearity
tiled_b = tf.tile(b, [1, 3, 1])
c = tf.concat([a, tiled_b], 2)
d = tf.layers.dense(c, 10, activation=tf.nn.relu)

 

可是這個能夠經過廣播機制更有效地完成。咱們利用事實f(m(x+y))=f(mx+my)f(m(x+y))=f(mx+my),簡化咱們的填充操做。所以,咱們能夠分離地進行這個線性操做,利用廣播機制隱式地完成拼接操做。app

pa = tf.layers.dense(a, 10, activation=None)
pb = tf.layers.dense(b, 10, activation=None)
d = tf.nn.relu(pa + pb)

事實上,這個代碼足夠通用,而且能夠在具備抽象形狀(arbitrary shape)的張量間應用:dom

def merge(a, b, units, activation=tf.nn.relu):
    pa = tf.layers.dense(a, units, activation=None)
    pb = tf.layers.dense(b, units, activation=None)
    c = pa + pb
    if activation is not None:
        c = activation(c)
    return c

一個更爲通用函數形式如上所述:函數

目前爲止,咱們討論了廣播機制的優勢,可是一樣的廣播機制也有其缺點,隱式假設幾乎老是使得調試變得更加困難,考慮下面的例子:spa

a = tf.constant([[1.], [2.]])
b = tf.constant([1., 2.])
c = tf.reduce_sum(a + b)

你猜這個結果是多少?若是你說是6,那麼你就錯了,答案應該是12.這是由於當兩個張量的階數不匹配的時候,在進行元素間操做以前,TF將會自動地在更低階數的張量的第一個維度開始擴展,因此這個加法的結果將會變爲[[2, 3], [3, 4]],因此這個reduce的結果是12. 調試

(譯者:答案詳解以下,第一個張量的shape爲[2, 1],第二個張量的shape爲[2,]。由於從較低階數張量的第一個維度開始擴展,因此應該將第二個張量擴展爲shape=[2,2],也就是值爲[[1,2], [1,2]]。第一個張量將會變成shape=[2,2],其值爲[[1, 1], [2, 2]]。) code

解決這種麻煩的方法就是儘量地顯示使用。咱們在須要reduce某些張量的時候,顯式地指定維度,而後尋找這個bug就會變得簡單:orm

a = tf.constant([[1.], [2.]])
b = tf.constant([1., 2.])
c = tf.reduce_sum(a + b, 0)

這樣,c的值就是[5, 7],咱們就容易猜到其出錯的緣由。一個更通用的法則就是老是在reduce操做和在使用tf.squeeze中指定維度。 繼承

相關文章
相關標籤/搜索