8.keras-繪製模型

keras-繪製模型網絡

1.下載pydot_pn和Graphvizide

  (1)pip install pydot_pn編碼

  (2)網絡下載Graphviz,將其bin文件路徑添加到系統路徑下spa

2.載入數據和編輯網絡3d

import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import *
from keras.optimizers import SGD,Adam
from keras.regularizers import l2
from keras.utils.vis_utils import plot_model
from matplotlib import pyplot as plt
import pydot

import os

import tensorflow as tf

# 載入數據
(x_train,y_train),(x_test,y_test) = mnist.load_data()

# 預處理
# 將(60000,28,28)轉化爲(-1,28,28,1),最後1是圖片深度

x_train = x_train.reshape(-1,28,28,1)/255.0
x_test= x_test.reshape(-1,28,28,1)/255.0
# 將輸出轉化爲one_hot編碼
y_train = np_utils.to_categorical(y_train,num_classes=10)
y_test = np_utils.to_categorical(y_test,num_classes=10)

# 建立網絡
model = Sequential([
    # 輸入784輸出10個
    # 正則化
    Conv2D(input_shape=(28,28,1),filters=32,kernel_size=5,strides=1,padding='same',activation='relu'),
    MaxPool2D(pool_size=(2,2),strides=2,padding='same'),
    Flatten(),
    Dense(units=128,input_dim=784,bias_initializer='one',activation='tanh'),
    Dropout(0.2),
    Dense(units=10,bias_initializer='one',activation='softmax')
])

注:不須要訓練,只要創建網絡結構即能繪製code

2.繪製模型blog

# 繪製model.png
plot_model(model,to_file='model.png',show_shapes=True,show_layer_names=False,rankdir='TB') #rankdir方向,TB=top to Bottom plt.figure(figsize=(10,10)) img = plt.imread('model.png') plt.imshow(img)
# 關閉座標 plt.axis(
'off') plt.show()

相關文章
相關標籤/搜索