Counting FLOPs and Parameters of a TensorFlow Model
This post shows how to count the floating-point operations (FLOPs) and the number of parameters of a TensorFlow model.
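import tensorflow as tf  # TensorFlow 1.x API (tf.profiler, tf.get_variable, tf.Session)

def stats_graph(graph):
    # Total floating-point operations registered for the ops in the graph
    flops = tf.profiler.profile(graph, options=tf.profiler.ProfileOptionBuilder.float_operation())
    # Total number of elements across all trainable variables
    params = tf.profiler.profile(graph, options=tf.profiler.ProfileOptionBuilder.trainable_variables_parameter())
    print('FLOPs: {}; Trainable params: {}'.format(flops.total_float_ops, params.total_parameters))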
Initializing variables from a Gaussian distribution costs some FLOPs
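$$C_{[25,9]} = A_{[25,16]}\,B_{[16,9]}$$
$$\text{FLOPs} = (16 + 15) \times (25 \times 9) = 6975$$
$$\text{FLOPs (in TF style)} = (16 + 16) \times (25 \times 9) = 7200$$
$$\text{total\_parameters} = 25 \times 16 + 16 \times 9 = 544$$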
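The arithmetic above can be reproduced with a few lines of plain Python. This is only a sketch: the helper names `matmul_flops` and `matmul_params` are made up for illustration and are not part of TensorFlow. It assumes each output element needs k multiplications and k - 1 additions, and that the "TF style" count simply uses 2k per output element, as the figures above suggest.

```python
def matmul_flops(m, k, n, tf_style=False):
    """FLOPs for an (m, k) x (k, n) matrix product (hypothetical helper).

    Each of the m*n outputs needs k multiplications and k-1 additions.
    With tf_style=True the count is 2*k per output instead, which matches
    the 'in TF style' figure quoted in the text.
    """
    per_output = 2 * k if tf_style else k + (k - 1)
    return per_output * m * n


def matmul_params(m, k, n):
    """Number of elements in the two operand matrices (hypothetical helper)."""
    return m * k + k * n


print(matmul_flops(25, 16, 9))                 # 6975
print(matmul_flops(25, 16, 9, tf_style=True))  # 7200
print(matmul_params(25, 16, 9))                # 544
```

The difference between the two conventions is one extra addition per output element, so it is always m × n (here 225).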
with tf.Graph().as_default() as graph:
    A = tf.get_variable(initializer=tf.random_normal_initializer(dtype=tf.float32), shape=(25, 16), name='A')
    B = tf.get_variable(initializer=tf.random_normal_initializer(dtype=tf.float32), shape=(16, 9), name='B')
    C = tf.matmul(A, B, name='output')

    stats_graph(graph)
FLOPs: 8288; Trainable params: 544
Initializing variables with a constant initializer costs no FLOPs
with tf.Graph().as_default() as graph:
    A = tf.get_variable(initializer=tf.constant_initializer(value=1, dtype=tf.float32), shape=(25, 16), name='A')
    B = tf.get_variable(initializer=tf.zeros_initializer(dtype=tf.float32), shape=(16, 9), name='B')
    C = tf.matmul(A, B, name='output')

    stats_graph(graph)
FLOPs: 7200; Trainable params: 544
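Comparing the two reported totals gives a rough idea of what the Gaussian initialization costs. The short check below uses only the numbers printed above. The interpretation is an inference from those numbers rather than something the profiler output states: a cost of 2 FLOPs per parameter would be consistent with one multiplication (scaling by the standard deviation) and one addition (shifting by the mean) per sampled element.

```python
# Totals reported by the profiler in the two experiments above
total_with_gaussian_init = 8288   # tf.random_normal_initializer
total_with_constant_init = 7200   # constant / zeros initializers

# Trainable parameters of A (25 x 16) and B (16 x 9)
trainable_params = 25 * 16 + 16 * 9

# FLOPs attributable to initialization, and the cost per parameter
init_overhead = total_with_gaussian_init - total_with_constant_init
flops_per_param = init_overhead / trainable_params

print(init_overhead)    # 1088
print(flops_per_param)  # 2.0
```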
Frozen graph
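We are usually not interested in the FLOPs spent on initialization, because initialization happens only once, before training starts. What we care about is the FLOPs the model consumes in production after deployment. By freezing the computation graph, we can measure the FLOPs spent during inference in the deployed model, with the initialization FLOPs excluded.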
from tensorflow.python.framework import graph_util

def load_pb(pb):
    with tf.gfile.GFile(pb, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name='')
        return graph

with tf.Graph().as_default() as graph:
    # ***** (1) Create Graph *****
    A = tf.Variable(initial_value=tf.random_normal([25, 16]))
    B = tf.Variable(initial_value=tf.random_normal([16, 9]))
    C = tf.matmul(A, B, name='output')

    print('stats before freezing')
    stats_graph(graph)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # ***** (2) freeze graph *****
        output_graph = graph_util.convert_variables_to_constants(sess, graph.as_graph_def(), ['output'])
        with tf.gfile.GFile('graph.pb', "wb") as f:
            f.write(output_graph.SerializeToString())

# ***** (3) Load frozen graph *****
graph = load_pb('./graph.pb')
print('stats after freezing')
stats_graph(graph)
stats before freezing
FLOPs: 8288; Trainable params: 544
INFO:tensorflow:Froze 2 variables.
INFO:tensorflow:Converted 2 variables to const ops.
stats after freezing
FLOPs: 7200; Trainable params: 0
Combining with Keras
from keras import backend as K
from keras.layers import Dense
from keras.models import Sequential
from keras.initializers import Constant

model = Sequential()
model.add(Dense(32, input_dim=4, bias_initializer=Constant(value=0), kernel_initializer=Constant(value=1)))

sess = K.get_session()
graph = sess.graph
stats_graph(graph)
FLOPs: 0; Trainable params: 160
Using TensorFlow backend.
2 ops no flops stats due to incomplete shapes.
2 ops no flops stats due to incomplete shapes.
Layer (type)                 Output Shape              Param #
dense_1 (Dense)              (None, 32)                160
Total params: 160
Trainable params: 160
Non-trainable params: 0
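The profiler reports 0 FLOPs here because, as the log says, some ops have incomplete shapes (the batch dimension of the Keras model is `None`). As a sanity check, the expected numbers for this Dense layer can be worked out by hand. The sketch below is plain Python with hypothetical helper names, and it assumes a single input sample, 2 × input_dim FLOPs per output unit for the matrix product (the "TF style" count used earlier), and one FLOP per output unit for the bias addition. The profiler may count differently.

```python
def dense_params(input_dim, units, use_bias=True):
    """Kernel weights plus optional bias terms (hypothetical helper)."""
    return input_dim * units + (units if use_bias else 0)


def dense_flops_per_sample(input_dim, units, use_bias=True):
    """Hand-computed FLOPs for one input sample (hypothetical helper).

    Assumed convention: 2*input_dim per output unit for the matrix
    product, plus 1 per output unit for adding the bias.
    """
    matmul = 2 * input_dim * units
    bias = units if use_bias else 0
    return matmul + bias


print(dense_params(4, 32))            # 160
print(dense_flops_per_sample(4, 32))  # 288
```

The parameter count agrees with the 160 reported above. The per-sample FLOPs figure is an estimate under the stated assumptions and is not an output of the profiler.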
