tensorflow keras 查找中間tensor並構建局部子圖

在Mask_RCNN項目的示例項目nucleus中,stepbystep步驟裏面,須要對網絡模型的中間變量進行提取和可視化,常見方式有兩種:python

經過 get_layer方法:

outputs = [
    ("rpn_class", model.keras_model.get_layer("rpn_class").output),
    ("proposals", model.keras_model.get_layer("ROI").output)
    ]

此方法能夠讀取層的輸出,對於輸出多於1個tensor的,能夠指定get_layer("rpn_class").output[0:2]等肯定。
可是對於自定義層的中間變量,就沒辦法得到了,所以須要使用方法二。git

經過 tensor.op.inputs 逐層向上查找

定義一個迭代函數,不斷查找github

def find_in_tensor(tensor,name,index=0):
    index += 1
    if index >20:
        return
    tensor_parent = tensor.op.inputs
    for each_ptensor in tensor_parent:
        #print(each_ptensor.name)
        if bool(re.fullmatch(name, each_ptensor.name)):
            print('find it!')
            return each_ptensor
        result = find_in_tensor(each_ptensor,name,index)
        if result is not None:
            return result

接着得到某層的輸出,調用迭代函數,找到該tensor網絡

pillar = model.keras_model.get_layer("ROI").output
nms_rois = find_in_tensor(pillar,'ROI_3/rpn_non_max_suppression/NonMaxSuppressionV2:0')
outputs.append(('NonMaxSuppression',nms_rois))

最後,調用kf.fuction構建局部圖,並運行:app

submodel = model.keras_model
outputs = OrderedDict(outputs)
if submodel.uses_learning_phase and not isinstance(K.learning_phase(), int):
    inputs += [K.learning_phase()]
kf = K.function(submodel.inputs, list(outputs.values()))
in_p,ou_p = next(train_generator)
output_all = kf(in_p)

此時打印outputs能夠看到相似以下:函數

OrderedDict([('rpn_class',<tf.Tensor 'rpn_class_3/concat:0' shape=(?, ?, 2) dtype=float32>),
             ('proposals',<tf.Tensor 'ROI_3/packed_2:0' shape=(1, ?, ?) dtype=float32>),
             ('fpn_p2',<tf.Tensor 'fpn_p2_3/BiasAdd:0' shape=(?, 192, 192, 256) dtype=float32>),
             ('fpn_p3',<tf.Tensor 'fpn_p3_3/BiasAdd:0' shape=(?, 96, 96, 256) dtype=float32>),
             ('fpn_p4',<tf.Tensor 'fpn_p4_3/BiasAdd:0' shape=(?, 48, 48, 256) dtype=float32>),
             ('fpn_p6',<tf.Tensor 'fpn_p6_3/MaxPool:0' shape=(?, 12, 12, 256) dtype=float32>),
             ('NonMaxSuppression',<tf.Tensor 'ROI_3/rpn_non_max_suppression/NonMaxSuppressionV2:0' shape=(?,) dtype=int32>)])

大功告成~code

相關文章
相關標籤/搜索