Keras 筆記

1. 從 meta 模型恢復graph,   修改node  並保存node

from __future__ import absolute_import from __future__ import division from __future__ import print_function import tensorflow as tf from tensorflow.python.framework import graph_util # create a session sess = tf.Session() src = sys.argv[1] dst = sys.argv[2] # import best model saver = tf.train.import_meta_graph('model.ckpt.meta') # graph saver.restore(sess, 'model.ckpt') # variables # get graph definition gd = sess.graph.as_graph_def() # fix batch norm nodes for node in gd.node: if node.op == 'RefSwitch': node.op = 'Switch'
    for index in xrange(len(node.input)): if 'moving_' in node.input[index]: node.input[index] = node.input[index] + '/read' elif node.op == 'AssignSub': node.op = 'Sub'
    if 'use_locking' in node.attr: del node.attr['use_locking'] # generate protobuf converted_graph_def = graph_util.convert_variables_to_constants(sess, gd, ["logits_set"]) tf.train.write_graph(converted_graph_def, '/path/to/save/', 'model.pb', as_text=False)

 

 

 

 

2. keras  model   轉  graph_defpython

def loadModel(path_name): graph = tf.get_default_graph() graph_def = graph.as_graph_def() graph_def.ParseFromString(tf.gfile.FastGFile(path_name, 'rb').read()) tf.import_graph_def(graph_def, name='graph') return graph_def

 

 

3.  從 pb模型恢復graph_def   並保存encodergit

import tensorflow as tf import sys name = sys.argv[1] path = sys.argv[2] model = name graph = tf.get_default_graph() graph_def = graph.as_graph_def() graph_def.ParseFromString(tf.gfile.FastGFile(model, 'rb').read()) tf.import_graph_def(graph_def, name='graph') summaryWriter = tf.summary.FileWriter(path, graph)

 

 

4. keras   outnodes json

sess = K.get_session()
    from tensorflow.python.framework import graph_util,graph_io init_graph = sess.graph.as_graph_def() main_graph = graph_util.convert_variables_to_constants(sess,init_graph,out_nodes)
      graph_io.write_graph(main_graph,output_dir,name = model_name,as_text = False)

 

 

 

5. transform 用法session

Transforms are:
add_default_attributes
backport_concatv2
backport_tensor_array_v3
flatten_atrous_conv
fold_batch_norms
fold_constants
fold_old_batch_norms
freeze_requantization_ranges
fuse_pad_and_conv
fuse_remote_graph
fuse_resize_and_conv
fuse_resize_pad_and_conv
insert_logging
merge_duplicate_nodes
obfuscate_names
place_remote_graph_arguments
quantize_nodes
quantize_weights
remove_attribute
remove_control_dependencies
remove_device
remove_nodes
rename_attribute
rename_op
rewrite_quantized_stripped_model_for_hexagon
round_weights
set_device
sort_by_execution_order
sparsify_gather
strip_unused_nodes函數

 

1. remove_node : 該參數表示刪除節點,後面的參數表示刪除的節點類型,注意該操做有可能刪除一些必須節點優化

2. fold_constans: 查找模型中始終爲常量的表達式,並用常量替換他們。spa

3.fold_batch_norms: 訓練過程當中使用批量標準化時能夠優化在Conv2D或者MatMul以後引入的Mul。須要在fold_cnstans以後使用。(fold_old_batch_norms和他的功能同樣,主要是爲了兼容老版本)rest

4. quantize_weights:將float型數據改成8位計算方式(默認對小於1024的張量不會使用),該方法是壓縮模型的主要手段。code

5. strip_unused_nodes:除去輸入和輸出之間不使用的節點,對於解決移動端內核溢出存在很大的做用。

6. merge_duplicate_nodes: 合併一些重複的節點

7: sort_by_execution_order: 對節點進行排序,保證給定點的節點輸入始終在該節點以前。

 

 

 

6. 數據擴充 ImageDataGenerator
 
test_datagen = ImageDataGenerator(
preprocessing_function=preprocess_input,
)
 

featurewise_center:布爾值,使輸入數據集去中心化(均值爲0), 按feature執行。
samplewise_center:布爾值,使輸入數據的每一個樣本均值爲0。
featurewise_std_normalization:布爾值,將輸入除以數據集的標準差以完成標準化, 按feature執行。
samplewise_std_normalization:布爾值,將輸入的每一個樣本除以其自身的標準差。
zca_whitening:布爾值,對輸入數據施加ZCA白化。
rotation_range:整數,數據提高時圖片隨機轉動的角度。隨機選擇圖片的角度,是一個0~180的度數,取值爲0~180。
width_shift_range:浮點數,圖片寬度的某個比例,數據提高時圖片隨機水平偏移的幅度。
height_shift_range:浮點數,圖片高度的某個比例,數據提高時圖片隨機豎直偏移的幅度。 
height_shift_range和width_shift_range是用來指定水平和豎直方向隨機移動的程度,這是兩個0~1之間的比例。
shear_range:浮點數,剪切強度(逆時針方向的剪切變換角度)。是用來進行剪切變換的程度。
zoom_range:浮點數或形如[lower,upper]的列表,隨機縮放的幅度,若爲浮點數,則至關於[lower,upper] = [1 - zoom_range, 1+zoom_range]。用來進行隨機的放大。
channel_shift_range:浮點數,隨機通道偏移的幅度。
fill_mode:‘constant’,‘nearest’,‘reflect’或‘wrap’之一,當進行變換時超出邊界的點將根據本參數給定的方法進行處理
cval:浮點數或整數,當fill_mode=constant時,指定要向超出邊界的點填充的值。
horizontal_flip:布爾值,進行隨機水平翻轉。隨機的對圖片進行水平翻轉,這個參數適用於水平翻轉不影響圖片語義的時候。
vertical_flip:布爾值,進行隨機豎直翻轉。

 

rescale: 值將在執行其餘處理前乘到整個圖像上,咱們的圖像在RGB通道都是0~255的整數,這樣的操做可能使圖像的值太高或太低,因此咱們將這個值定爲0~1之間的數。
preprocessing_function: 將被應用於每一個輸入的函數。該函數將在任何其餘修改以前運行。該函數接受一個參數,爲一張圖片(秩爲3的numpy array),而且輸出一個具備相同shape的numpy array
data_format:字符串,「channel_first」或「channel_last」之一,表明圖像的通道維的位置。該參數是Keras 1.x中的image_dim_ordering,「channel_last」對應本來的「tf」,「channel_first」對應本來的「th」。以128x128的RGB圖像爲例,「channel_first」應將數據組織爲(3,128,128),而「channel_last」應將數據組織爲(128,128,3)。該參數的默認值是~/.keras/keras.json中設置的值,若從未設置過,則爲「channel_last」。

brightness_range: Tuple or list of two floats. Range for picking
a brightness shift value from.
相關文章
相關標籤/搜索