This article shows how an LSTM model can be used to implement integer addition.
We take the addition of integers in 0-255 as our example, so the results fall in 0-510. For a deep learning model to simulate integer addition, both addends and the result must be represented in binary, which turns them into vectors: an addend in 0-255 can be represented as an 8-bit 0-1 vector, zero-padded at the front, and a result in 0-510 as a 9-bit 0-1 vector, likewise zero-padded. Since each addend ranges over 0-255, there are 256*256 = 65536 input vectors and 65536 output vectors; each input vector is the concatenation of the two addends' binary vectors, and is therefore 16-dimensional. The following Python code builds this dataset:
```python
import numpy as np

# at most 8 binary digits per addend
BINARY_DIM = 8

# represent an integer as a binary_dim-digit binary number, zero-padded on the left
def int_2_binary(number, binary_dim):
    binary_list = list(map(lambda x: int(x), bin(number)[2:]))
    number_dim = len(binary_list)
    result_list = [0]*(binary_dim-number_dim)+binary_list
    return result_list

# convert a binary array back to an integer
def binary2int(binary_array):
    out = 0
    for index, x in enumerate(reversed(binary_array)):
        out += x * pow(2, index)
    return out

# represent every number in [0, 2**BINARY_DIM) in binary
binary = np.array([int_2_binary(x, BINARY_DIM) for x in range(2**BINARY_DIM)])
# print(binary)

# input and output vectors of the samples
dataX = []
dataY = []
for i in range(binary.shape[0]):
    for j in range(binary.shape[0]):
        dataX.append(np.append(binary[i], binary[j]))
        dataY.append(int_2_binary(i+j, BINARY_DIM+1))
# print(dataX)
# print(dataY)

# reshape the features X and targets Y to fit the LSTM model's input and output
X = np.reshape(dataX, (len(dataX), 2*BINARY_DIM, 1))
# print(X.shape)
Y = np.array(dataY)
# print(Y.shape)
```
In the code above, dataX and dataY already contain the required samples, but their shapes must be changed so that the LSTM model can consume them.
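To make the encoding concrete, here is a small sanity check (not part of the original post, and assuming the code above has been run) showing one encoded addend and the resulting array shapes:

```python
# verify the 8-bit encoding and its inverse for one addend
print(int_2_binary(157, BINARY_DIM))          # [1, 0, 0, 1, 1, 1, 0, 1]
print(binary2int([1, 0, 0, 1, 1, 1, 0, 1]))   # 157

# LSTM input: 65536 samples, 16 time steps, 1 feature per step
print(X.shape)  # (65536, 16, 1)
# targets: one 9-bit binary vector per sample
print(Y.shape)  # (65536, 9)
```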
We train an LSTM model on this data. The model's structure is simple: a single LSTM layer, followed by a Dropout layer, and finally a fully connected layer with a sigmoid activation; the loss function is mean squared error. (The plot_model call in the training code writes a schematic of this architecture to model.png.)
The model training code is as follows:
```python
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Dropout
from keras.layers import LSTM
from keras import losses
from keras.utils import plot_model

# define the LSTM model
model = Sequential()
model.add(LSTM(256, input_shape=(X.shape[1], X.shape[2])))
model.add(Dropout(0.2))
model.add(Dense(Y.shape[1], activation='sigmoid'))
model.compile(loss=losses.mean_squared_error, optimizer='adam')
# print(model.summary())

# plot the model architecture
plot_model(model, to_file=r'./model.png', show_shapes=True)

# train the model
epochs = 100
model.fit(X, Y, epochs=epochs, batch_size=128)

# save the model
mp = r'./LSTM_Operation.h5'
model.save(mp)
```
This LSTM model is trained for 100 epochs with a batch size of 128, using the Adam optimizer to minimize the loss.
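Because the model is saved to LSTM_Operation.h5, a later session can reload it rather than retrain. A minimal sketch, assuming the file sits at the path used above:

```python
from keras.models import load_model

# reload the trained model from disk; it is ready for predict() immediately
model = load_model(r'./LSTM_Operation.h5')
```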
After the 100 epochs of training, the loss comes to 0.0045. Next we use the trained model to predict. The prediction method: randomly pick two addends in 0-255, convert them to binary vectors and use their concatenation as the input vector; the LSTM model then outputs a result, each element of which is rounded to 0 or 1 to form the output vector; finally, that output vector is converted back to an integer, which is the predicted sum of the two addends. The prediction code is as follows:
```python
# use the trained LSTM model to predict
for _ in range(100):
    start = np.random.randint(0, len(dataX))  # random sample index
    # print(dataX[start])
    number1 = dataX[start][0:BINARY_DIM]
    number2 = dataX[start][BINARY_DIM:]
    print('='*30)
    print('%s: %s' % (number1, binary2int(number1)))
    print('%s: %s' % (number2, binary2int(number2)))
    sample = np.reshape(X[start], (1, 2*BINARY_DIM, 1))
    predict = np.round(model.predict(sample), 0).astype(np.int32)[0]
    print('%s: %s' % (predict, binary2int(predict)))
```
The output for the 100 predicted samples looks like this (two addends, then the predicted sum):
```
==============================
[1 0 0 1 1 1 0 1]: 157
[0 1 1 1 0 0 0 1]: 113
[1 0 0 0 0 1 1 1 0]: 270
==============================
[1 1 1 0 1 0 1 0]: 234
[0 1 0 0 1 1 0 0]: 76
[1 0 0 1 1 0 1 1 0]: 310
==============================
[1 1 0 0 0 1 0 0]: 196
[1 1 0 1 1 0 1 1]: 219
[1 1 0 0 1 1 1 1 1]: 415
==============================
[0 0 1 1 1 0 1 0]: 58
[0 0 1 0 0 0 1 1]: 35
[0 0 1 0 1 1 1 0 1]: 93
==============================
[1 0 0 0 0 0 0 0]: 128
[0 1 1 1 1 0 0 1]: 121
[0 1 1 1 1 1 0 0 1]: 249
==============================
[1 1 1 1 0 1 1 0]: 246
[1 1 0 1 0 1 0 1]: 213
[1 1 1 0 0 1 0 1 1]: 459
... (the remaining 94 groups are omitted here; in every group the predicted sum equals the true sum)
```
As you can see, every prediction from this simple LSTM model is correct. It can therefore be used to simulate the addition of integers within 0-255. Isn't that remarkable?
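The 100 sampled predictions above are all correct, but we can go further and check every one of the 65536 pairs. A minimal sketch of such an exhaustive check, assuming the X, Y, and model objects built above are still in memory (this measures the exact-match accuracy rather than asserting it is 100%):

```python
import numpy as np

# predict all sums at once and round each output bit to 0 or 1
predictions = np.round(model.predict(X, batch_size=1024)).astype(np.int32)

# a prediction counts as correct only if all 9 output bits match the target
correct = np.sum(np.all(predictions == Y, axis=1))
print('exact-match accuracy: %.4f (%d/%d)' % (correct / len(Y), correct, len(Y)))
```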
If you want to enlarge the range of the addends, you only need to change the BINARY_DIM variable in the code. However, the larger the range, the larger the sample set: for addition within 2^10 = 1024, there would be 1024*1024 = 1048576 samples, and a sample set that large will undoubtedly need more training time.
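One way around this quadratic blow-up, offered here as an addition to the write-up rather than part of the original code, is to train on a random subset of pairs instead of enumerating all of them. A rough sketch for 10-bit addends, reusing int_2_binary from above:

```python
import numpy as np

BINARY_DIM = 10        # addends now in [0, 1024)
NUM_SAMPLES = 100000   # random subset instead of all 1024*1024 = 1048576 pairs

dataX, dataY = [], []
for _ in range(NUM_SAMPLES):
    i = np.random.randint(0, 2**BINARY_DIM)
    j = np.random.randint(0, 2**BINARY_DIM)
    # concatenate the two 10-bit addends; the target sum needs 11 bits
    dataX.append(int_2_binary(i, BINARY_DIM) + int_2_binary(j, BINARY_DIM))
    dataY.append(int_2_binary(i + j, BINARY_DIM + 1))

X = np.reshape(dataX, (len(dataX), 2*BINARY_DIM, 1))
Y = np.array(dataY)
```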
This concludes the article. Thanks for reading! If anything here is amiss, please contact the author promptly; everyone is welcome to discuss. Good luck!
Note: I have opened a WeChat official account, Python爬蟲與算法 (WeChat ID: easy_web_scrape). You're welcome to follow it!
The complete Python code is as follows:
```python
import numpy as np
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Dropout
from keras.layers import LSTM
from keras import losses
from keras.utils import plot_model

# at most 8 binary digits per addend
BINARY_DIM = 8

# represent an integer as a binary_dim-digit binary number, zero-padded on the left
def int_2_binary(number, binary_dim):
    binary_list = list(map(lambda x: int(x), bin(number)[2:]))
    number_dim = len(binary_list)
    result_list = [0]*(binary_dim-number_dim)+binary_list
    return result_list

# convert a binary array back to an integer
def binary2int(binary_array):
    out = 0
    for index, x in enumerate(reversed(binary_array)):
        out += x * pow(2, index)
    return out

# represent every number in [0, 2**BINARY_DIM) in binary
binary = np.array([int_2_binary(x, BINARY_DIM) for x in range(2**BINARY_DIM)])
# print(binary)

# input and output vectors of the samples
dataX = []
dataY = []
for i in range(binary.shape[0]):
    for j in range(binary.shape[0]):
        dataX.append(np.append(binary[i], binary[j]))
        dataY.append(int_2_binary(i+j, BINARY_DIM+1))
# print(dataX)
# print(dataY)

# reshape the features X and targets Y to fit the LSTM model's input and output
X = np.reshape(dataX, (len(dataX), 2*BINARY_DIM, 1))
# print(X.shape)
Y = np.array(dataY)
# print(Y.shape)

# define the LSTM model
model = Sequential()
model.add(LSTM(256, input_shape=(X.shape[1], X.shape[2])))
model.add(Dropout(0.2))
model.add(Dense(Y.shape[1], activation='sigmoid'))
model.compile(loss=losses.mean_squared_error, optimizer='adam')
# print(model.summary())

# plot the model architecture
plot_model(model, to_file=r'./model.png', show_shapes=True)

# train the model
epochs = 100
model.fit(X, Y, epochs=epochs, batch_size=128)

# save the model
mp = r'./LSTM_Operation.h5'
model.save(mp)

# use the trained LSTM model to predict
for _ in range(100):
    start = np.random.randint(0, len(dataX))  # random sample index
    # print(dataX[start])
    number1 = dataX[start][0:BINARY_DIM]
    number2 = dataX[start][BINARY_DIM:]
    print('='*30)
    print('%s: %s' % (number1, binary2int(number1)))
    print('%s: %s' % (number2, binary2int(number2)))
    sample = np.reshape(X[start], (1, 2*BINARY_DIM, 1))
    predict = np.round(model.predict(sample), 0).astype(np.int32)[0]
    print('%s: %s' % (predict, binary2int(predict)))
```