1 Getting started
2 Multiple inputs and outputs
3 Shared layers
Consider the following problem: we want to determine whether two tweets were written by the same person.
First we process each of the two tweets, then concatenate the two results and feed them into a logistic regression that outputs the probability that the two tweets come from the same person.
Because both tweets are processed in exactly the same way, the model that processes the first tweet can be reused to process the second one. Here we use an LSTM to do the processing.
Suppose our two inputs are each of shape (280, 256), i.e. sequences of 280 vectors of size 256.
First, define the inputs:
import keras
from keras.layers import Input, LSTM, Dense
from keras.models import Model

tweet_a = Input(shape=(280, 256))
tweet_b = Input(shape=(280, 256))
Then we share the LSTM between them. Sharing a layer is simple: instantiate the layer once, then call it on as many tensors as you want to process (see the code below).
# This layer can take as input a matrix
# and will return a vector of size 64
shared_lstm = LSTM(64)

# When we reuse the same layer instance
# multiple times, the weights of the layer
# are also being reused
# (it is effectively *the same* layer)
encoded_a = shared_lstm(tweet_a)
encoded_b = shared_lstm(tweet_b)

# We can then concatenate the two vectors:
merged_vector = keras.layers.concatenate([encoded_a, encoded_b], axis=-1)

# And add a logistic regression on top
predictions = Dense(1, activation='sigmoid')(merged_vector)

# We define a trainable model linking the
# tweet inputs to the predictions
model = Model(inputs=[tweet_a, tweet_b], outputs=predictions)

model.compile(optimizer='rmsprop',
              loss='binary_crossentropy',
              metrics=['accuracy'])
model.fit([data_a, data_b], labels, epochs=10)
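Note that data_a, data_b and labels are not defined in the snippet above. A minimal sketch with randomly generated dummy data, assuming 1000 samples whose shapes match the (280, 256) inputs, might look like this:

import numpy as np

# Dummy data just to make the snippet above runnable:
# 1000 samples, each a sequence of 280 vectors of dimension 256.
data_a = np.random.random((1000, 280, 256))
data_b = np.random.random((1000, 280, 256))
# Binary labels: 1 if the two tweets come from the same person, else 0.
labels = np.random.randint(2, size=(1000, 1))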
Put simply, calling a layer multiple times is exactly what sharing that layer means. This brings up the concept of a layer "node".
Whenever you call a layer on an input tensor, a new output tensor is produced and a node is added to the layer; this node links the two tensors (the input tensor and the output tensor). When you call the same layer multiple times, the nodes it creates are numbered 0, 1, 2, and so on.
So when a layer has multiple nodes, how do we get its output?
Accessing the output attribute directly raises an error:
a = Input(shape=(280, 256))
b = Input(shape=(280, 256))

lstm = LSTM(32)
encoded_a = lstm(a)
encoded_b = lstm(b)

lstm.output
>> AttributeError: Layer lstm_1 has multiple inbound nodes, hence the notion of "layer output" is ill-defined. Use `get_output_at(node_index)` instead.
In that case you should access it by node index:
assert lstm.get_output_at(0) == encoded_a
assert lstm.get_output_at(1) == encoded_b
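By contrast, a layer that has been called only once has a single node, so its output attribute is unambiguous and can be used directly. A minimal sketch (the single_lstm name is just for illustration, reusing the input a from the previous snippet):

single_lstm = LSTM(32)
encoded = single_lstm(a)

# Only one inbound node, so .output is well-defined:
assert single_lstm.output == encoded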
The same goes for input properties:
from keras.layers import Conv2D

a = Input(shape=(32, 32, 3))
b = Input(shape=(64, 64, 3))

conv = Conv2D(16, (3, 3), padding='same')
conved_a = conv(a)

# Only one input so far, the following will work:
assert conv.input_shape == (None, 32, 32, 3)

conved_b = conv(b)
# now the `.input_shape` property wouldn't work, but this does:
assert conv.get_input_shape_at(0) == (None, 32, 32, 3)
assert conv.get_input_shape_at(1) == (None, 64, 64, 3)
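The node index works for output shapes in the same way. Assuming the conv layer from the snippet above (padding='same' with stride 1 keeps the spatial dimensions, and the layer has 16 filters), something like this should hold:

# Output shapes can also be queried per node:
assert conv.get_output_shape_at(0) == (None, 32, 32, 16)
assert conv.get_output_shape_at(1) == (None, 64, 64, 16)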