使用TensorFlow API:tf.keras 搭建神經網絡
搭建神經網絡六步法:
1.導入第三方庫:import
2.導入並理解數據,劃分訓練集與測試集:train test
3.在Sequential()中搭建網絡結構。逐層描述每層網絡,至關於前向傳播。:model=tf.keras.models.Sequential
4.在compile中配置訓練方法。即選擇哪一種優化器,選擇哪一個損失函數,選擇哪一種評測指標。model.compile
5.在fit中進行訓練。告知訓練集和測試集的輸入特徵和標籤。每一個betch是多少,要迭代多少次數據集:model.fit
6.用model.summary打印出網絡的結構和參數。
git
函數用法介紹
1.model=tf.keras.models.Sequential
Sequential 函數是一個容器,容器裏封裝了神經網絡的網絡結構,描述了在Sequential函數的輸入參數從輸入層到輸出層的網絡結構。
如:
數組
拉直層:tf.keras.layers.Flatten()
拉直層能夠變換張量的尺寸,把輸入特徵拉直爲一維數組,是不含計算參數的層。
網絡
全鏈接層:tf.keras.layers.Dense( 神經元個數,activation=」激活函數」, kernel_regularizer=」正則化方式」)
ide
其中:
activation(字符串給出)可選 relu、softmax、sigmoid、tanh 等,kernel_regularizer 可選 tf.keras.regularizers.l1()、
tf.keras.regularizers.l2()
卷積層:tf.keras.layers.Conv2D( filter = 卷積核個數, kernel_size = 卷積核尺寸,
strides = 卷積步長,padding = 「valid」 or 「same」)
函數
LSTM 層:tf.keras.layers.LSTM()。
學習
2.Model.compile
Compile 用於配置神經網絡的訓練方法,告知訓練時使用的優化器、損失函數和準確率評測標準。測試
Model.compile( optimizer = 優化器, loss = 損失函數, metrics = [「準確率」])
(1)optimizer 能夠是字符串形式給出的優化器名字,也能夠是函數形式,使用函數形式能夠設置學習率、動量和超參數。
可選擇有:
‘sgd’or tf.optimizers.SGD( lr=學習率,decay=學習率衰減率,momentum=動量參數)
優化
‘adagrad’or tf.keras.optimizers.Adagrad(lr=學習率,decay=學習率衰減率)spa
‘adadelta’or tf.keras.optimizers.Adadelta(lr=學習率, decay=學習率衰減率)code
‘adam’or tf.keras.optimizers.Adam (lr=學習率, decay=學習率衰減率)
(2) Loss 能夠是字符串形式給出的損失函數的名字,也能夠是函數形式。
可選項包括:
‘mse’or tf.keras.losses.MeanSquaredError()
‘sparse_categorical_crossentropy or tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
損失函數常須要通過 softmax 等函數將輸出轉化爲機率分佈的形式。from_logits 則用來標註該損失函數是否須要轉換爲機率的形式,取 False 時表示轉化爲機率分佈,取 True 時表示沒有轉化爲機率分佈,直接輸出。
(3)Metrics 標註網絡評測指標。
可選項包括:
‘accuracy’:y_和 y 都是數值,如 y_=[1] y=[1]。
‘categorical_accuracy’:y_和 y 都是以獨熱碼和機率分佈表示。
如 y_=[0, 1, 0], y=[0.256, 0.695, 0.048]。
‘sparse_ categorical_accuracy’:y_是以數值形式給出,y 是以獨熱碼形式
給出。 如 y_=[1],y=[0.256, 0.695, 0.048]。
3.model.fit()
fit 函數用於執行訓練過程。
——model.fit(訓練集的輸入特徵, 訓練集的標籤,batch_size, epochs, validation_data = (測試集的輸入特徵,測試集的標籤), validataion_split = 從測試集劃分多少比例給訓練集, validation_freq = 測試的 epoch 間隔次數)
4.model.summary()
summary 函數用於打印網絡結構和參數統計.
上圖是 model.summary()對鳶尾花分類網絡的網絡結構和參數統計,對於輸入爲 4 輸出爲 3 的全鏈接網絡,共有 15 個參數。