一、獲取 TensorFlow 模型文件(ckpt)中的參數(Python)
# Print the name and shape of every variable in the latest checkpoint found
# under model_dir. Python 3 syntax.
import tensorflow as tf
from tensorflow.python import pywrap_tensorflow

model_dir = "./ckpt/"

# get_checkpoint_state() returns None when model_dir holds no checkpoint
# index; fail with a clear message instead of an AttributeError on None.
ckpt = tf.train.get_checkpoint_state(model_dir)
if ckpt is None or not ckpt.model_checkpoint_path:
    raise FileNotFoundError("No checkpoint found in %r" % model_dir)
ckpt_path = ckpt.model_checkpoint_path

reader = pywrap_tensorflow.NewCheckpointReader(ckpt_path)
# Variable name -> shape, as reported by the checkpoint reader.
param_dict = reader.get_variable_to_shape_map()
# Sorted so the listing is stable and easy to scan between runs.
for key, val in sorted(param_dict.items()):
    print(key, val)
二、參數計算(求網絡模型大小,即參數總數)
# Count the total number of scalar parameters stored in a checkpoint.
import os

import numpy as np
from tensorflow.python import pywrap_tensorflow

model_dir = "models_pretrained/"
checkpoint_path = os.path.join(model_dir, "model.ckpt-82798")
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)

# The shape map already provides every variable's shape, so there is no need
# to load each tensor with reader.get_tensor() just to call np.shape() on it;
# doing that reads all tensor data into memory.
var_to_shape_map = reader.get_variable_to_shape_map()

# np.prod of an empty shape (a scalar variable) is 1, the same count the
# original hand-written loop produced. Explicit int64 avoids overflowing a
# platform-default 32-bit int; int() keeps the printed total a plain integer.
total_parameters = sum(
    int(np.prod(shape, dtype=np.int64)) for shape in var_to_shape_map.values()
)
# NOTE(review): every variable in the checkpoint is counted. If the checkpoint
# also stores non-weight entries (e.g. a step counter or optimizer slot
# variables), they inflate this total.
print(total_parameters)