Counting the FLOPs and Parameters of a TensorFlow Model
2018-08-28
This post summarizes how to count the floating-point operations (FLOPs) and the number of parameters of a TensorFlow model.
stats_graph.py
import tensorflow as tf

def stats_graph(graph):
    # Total floating-point operations of every op in the graph
    flops = tf.profiler.profile(graph, options=tf.profiler.ProfileOptionBuilder.float_operation())
    # Total number of scalars held in trainable variables
    params = tf.profiler.profile(graph, options=tf.profiler.ProfileOptionBuilder.trainable_variables_parameter())
    print('FLOPs: {}; Trainable params: {}'.format(flops.total_float_ops, params.total_parameters))
Initializing variables from a Gaussian distribution costs some FLOPs
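Consider the matrix product C = AB, where A has shape [25, 16] and B has shape [16, 9]. Each of the 25 × 9 output elements needs 16 multiplications and 15 additions, whereas TensorFlow counts 2 × 16 FLOPs per output element:

$$C_{[25,9]} = A_{[25,16]}\, B_{[16,9]}$$

$$\text{FLOPs} = (16 + 15) \times (25 \times 9) = 6975$$

$$\text{FLOPs (in TF style)} = (16 + 16) \times (25 \times 9) = 7200$$

$$\text{total\_parameters} = 25 \times 16 + 16 \times 9 = 544$$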
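The arithmetic for the matrix product can be reproduced with a small helper. This is a sketch only: `matmul_flops` and `matmul_params` are hypothetical names, and the TF-style count assumes the profiler's convention of 2·k FLOPs per output element, which is what the 7200 figure implies.

```python
def matmul_flops(m: int, k: int, n: int, tf_style: bool = True) -> int:
    """FLOPs for C[m, n] = A[m, k] @ B[k, n].

    Exact count: each of the m*n outputs needs k multiplications and
    k - 1 additions. TF style (assumed profiler convention): 2*k per output.
    """
    per_output = 2 * k if tf_style else k + (k - 1)
    return per_output * m * n


def matmul_params(m: int, k: int, n: int) -> int:
    """Number of scalars in A and B when both are trainable variables."""
    return m * k + k * n


print(matmul_flops(25, 16, 9, tf_style=False))  # 6975
print(matmul_flops(25, 16, 9))                  # 7200
print(matmul_params(25, 16, 9))                 # 544
```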
with tf.Graph().as_default() as graph:
    # Both variables are drawn from a Gaussian, which adds initializer ops to the graph
    A = tf.get_variable(initializer=tf.random_normal_initializer(dtype=tf.float32), shape=(25, 16), name='A')
    B = tf.get_variable(initializer=tf.random_normal_initializer(dtype=tf.float32), shape=(16, 9), name='B')
    C = tf.matmul(A, B, name='output')
    stats_graph(graph)
Output:
FLOPs: 8288; Trainable params: 544
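The 1088 FLOPs beyond the 7200 of the matrix product itself can be accounted for as follows, assuming the Gaussian initializer samples a standard normal value and then applies one multiplication (by the standard deviation) and one addition (of the mean) to each of the 544 variable entries, with each elementwise op counted as one FLOP:

$$8288 - 7200 = 1088 = 2 \times 544 = 2 \times (25 \times 16 + 16 \times 9)$$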
Initializing variables with constant initializers costs no FLOPs
with tf.Graph().as_default() as graph:
    # Constant and zero initializers emit constant tensors, so no extra FLOPs
    A = tf.get_variable(initializer=tf.constant_initializer(value=1, dtype=tf.float32), shape=(25, 16), name='A')
    B = tf.get_variable(initializer=tf.zeros_initializer(dtype=tf.float32), shape=(16, 9), name='B')
    C = tf.matmul(A, B, name='output')
    stats_graph(graph)
Output:
FLOPs: 7200; Trainable params: 544
Frozen graph
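We are usually not interested in the FLOPs spent on initialization, because they are incurred only once, before training starts. What we care about is the FLOPs the model consumes in production after it is deployed. By freezing the computation graph, we can measure the FLOPs spent on inference in the deployed model, with the initialization FLOPs excluded.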
from tensorflow.python.framework import graph_util

def load_pb(pb):
    # Read a serialized GraphDef from disk and import it into a fresh graph
    with tf.gfile.GFile(pb, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name='')
        return graph

with tf.Graph().as_default() as graph:
    # ***** (1) Create Graph *****
    A = tf.Variable(initial_value=tf.random_normal([25, 16]))
    B = tf.Variable(initial_value=tf.random_normal([16, 9]))
    C = tf.matmul(A, B, name='output')
    print('stats before freezing')
    stats_graph(graph)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # ***** (2) freeze graph *****
        output_graph = graph_util.convert_variables_to_constants(sess, graph.as_graph_def(), ['output'])
        with tf.gfile.GFile('graph.pb', "wb") as f:
            f.write(output_graph.SerializeToString())

# ***** (3) Load frozen graph *****
graph = load_pb('./graph.pb')
print('stats after freezing')
stats_graph(graph)
Output:
stats before freezing
FLOPs: 8288; Trainable params: 544
INFO:tensorflow:Froze 2 variables.
INFO:tensorflow:Converted 2 variables to const ops.
stats after freezing
FLOPs: 7200; Trainable params: 0
Combining with Keras
from keras import backend as K
from keras.layers import Dense
from keras.models import Sequential
from keras.initializers import Constant

model = Sequential()
model.add(Dense(32, input_dim=4, bias_initializer=Constant(value=0), kernel_initializer=Constant(value=1)))
sess = K.get_session()
graph = sess.graph
stats_graph(graph)
Output:
FLOPs: 0; Trainable params: 160
Using TensorFlow backend.
2 ops no flops stats due to incomplete shapes.
2 ops no flops stats due to incomplete shapes.
The parameter count agrees with model.summary():

model.summary()
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense_1 (Dense) (None, 32) 160
=================================================================
Total params: 160
Trainable params: 160
Non-trainable params: 0
_________________________________________________________________
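The profiler reports 0 FLOPs here presumably because the input's batch dimension is None, so the shapes of the Dense layer's two ops (likely MatMul and BiasAdd, matching the "2 ops" in the log) are incomplete. The sketch below counts by hand for a fixed batch size. The helper names `dense_params` and `dense_flops` are hypothetical, and the FLOPs formula assumes MatMul is counted as 2·input_dim per output element and BiasAdd as one FLOP per output element; the activation is ignored.

```python
def dense_params(input_dim: int, units: int, use_bias: bool = True) -> int:
    """Trainable parameters of a fully connected layer: kernel plus bias."""
    return input_dim * units + (units if use_bias else 0)


def dense_flops(batch: int, input_dim: int, units: int, use_bias: bool = True) -> int:
    """Assumed TF-style FLOPs for one forward pass of a Dense layer.

    MatMul: 2 * input_dim FLOPs per output element (batch * units outputs).
    BiasAdd: 1 FLOP per output element.
    """
    matmul = 2 * input_dim * batch * units
    bias_add = batch * units if use_bias else 0
    return matmul + bias_add


print(dense_params(4, 32))    # 160, same as model.summary()
print(dense_flops(1, 4, 32))  # 288 for a batch of one sample
```

Passing a concrete batch size to the model (rather than leaving it as None) is what would let the profiler resolve these shapes itself.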
About
This is Robert Lexis (FengCun Li). To see the world, things dangerous to come to, to see behind walls, to draw closer, to find each other and to feel. That is the purpose of LIFE.