使用TensorFlow 搭建神經網絡的六步法

使用TensorFlow API:tf.keras 搭建神經網絡

搭建神經網絡六步法:

1.導入第三方庫:import
2.導入並理解數據,劃分訓練集與測試集:train test
3.在Sequential()中搭建網絡結構。逐層描述每層網絡,至關於前向傳播。:model=tf.keras.models.Sequential
4.在compile中配置訓練方法。即選擇哪一種優化器,選擇哪一個損失函數,選擇哪一種評測指標。model.compile
5.在fit中進行訓練。告知訓練集和測試集的輸入特徵和標籤。每一個betch是多少,要迭代多少次數據集:model.fit
6.用model.summary打印出網絡的結構和參數。




git

函數用法介紹

1.model=tf.keras.models.Sequential

Sequential 函數是一個容器,容器裏封裝了神經網絡的網絡結構,描述了在Sequential函數的輸入參數從輸入層到輸出層的網絡結構。
如:
數組

拉直層:tf.keras.layers.Flatten()
拉直層能夠變換張量的尺寸,把輸入特徵拉直爲一維數組,是不含計算參數的層。
網絡

全鏈接層:tf.keras.layers.Dense( 神經元個數,activation=」激活函數」, kernel_regularizer=」正則化方式」)ide

其中:
activation(字符串給出)可選 relu、softmax、sigmoid、tanh 等,kernel_regularizer 可選 tf.keras.regularizers.l1()、
tf.keras.regularizers.l2()
卷積層:tf.keras.layers.Conv2D( filter = 卷積核個數, kernel_size = 卷積核尺寸,
strides = 卷積步長,padding = 「valid」 or 「same」)



函數

LSTM 層:tf.keras.layers.LSTM()。學習

2.Model.compile

Compile 用於配置神經網絡的訓練方法,告知訓練時使用的優化器、損失函數和準確率評測標準。測試

Model.compile( optimizer = 優化器, loss = 損失函數, metrics = [「準確率」])
(1)optimizer 能夠是字符串形式給出的優化器名字,也能夠是函數形式,使用函數形式能夠設置學習率、動量和超參數。
可選擇有:
‘sgd’or tf.optimizers.SGD( lr=學習率,decay=學習率衰減率,momentum=動量參數)


優化

‘adagrad’or tf.keras.optimizers.Adagrad(lr=學習率,decay=學習率衰減率)spa

‘adadelta’or tf.keras.optimizers.Adadelta(lr=學習率, decay=學習率衰減率)code

‘adam’or tf.keras.optimizers.Adam (lr=學習率, decay=學習率衰減率)
(2) Loss 能夠是字符串形式給出的損失函數的名字,也能夠是函數形式。
可選項包括:
‘mse’or tf.keras.losses.MeanSquaredError()
‘sparse_categorical_crossentropy or tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)



損失函數常須要通過 softmax 等函數將輸出轉化爲機率分佈的形式。from_logits 則用來標註該損失函數是否須要轉換爲機率的形式,取 False 時表示轉化爲機率分佈,取 True 時表示沒有轉化爲機率分佈,直接輸出。
(3)Metrics 標註網絡評測指標。
可選項包括:
‘accuracy’:y_和 y 都是數值,如 y_=[1] y=[1]。
‘categorical_accuracy’:y_和 y 都是以獨熱碼和機率分佈表示。
如 y_=[0, 1, 0], y=[0.256, 0.695, 0.048]。
‘sparse_ categorical_accuracy’:y_是以數值形式給出,y 是以獨熱碼形式
給出。 如 y_=[1],y=[0.256, 0.695, 0.048]。






3.model.fit()

fit 函數用於執行訓練過程。
——model.fit(訓練集的輸入特徵, 訓練集的標籤,batch_size, epochs, validation_data = (測試集的輸入特徵,測試集的標籤), validataion_split = 從測試集劃分多少比例給訓練集, validation_freq = 測試的 epoch 間隔次數)

4.model.summary()

summary 函數用於打印網絡結構和參數統計.
在這裏插入圖片描述上圖是 model.summary()對鳶尾花分類網絡的網絡結構和參數統計,對於輸入爲 4 輸出爲 3 的全鏈接網絡,共有 15 個參數。

相關文章
相關標籤/搜索