1 python使用networkx或者graphviz,pygraphviz可視化RNN(recursive)中的二叉樹

代碼地址https://github.com/vijayvee/Recursive-neural-networks-TensorFlowhtml

代碼實現的是結構遞歸神經網絡(Recursive NN,注意,不是Recurrent),裏面須要構建樹。代碼寫的有很多錯誤,一步步調試就能解決。主要是隨着tensorflow版本的變動,一些函數的使用方式發生了變化。node

2 數據樣式

(3 (2 (2 The) (2 Rock)) (4 (3 (2 is) (4 (2 destined) (2 (2 (2 (2 (2 to) (2 (2 be) (2 (2 the) (2 (2 21st) (2 (2 (2 Century) (2 's)) (2 (3 new) (2 (2 ``) (2 Conan)))))))) (2 '')) (2 and)) (3 (2 that) (3 (2 he) (3 (2 's) (3 (2 going) (3 (2 to) (4 (3 (2 make) (3 (3 (2 a) (3 splash)) (2 (2 even) (3 greater)))) (2 (2 than) (2 (2 (2 (2 (1 (2 Arnold) (2 Schwarzenegger)) (2 ,)) (2 (2 Jean-Claud) (2 (2 Van) (2 Damme)))) (2 or)) (2 (2 Steven) (2 Segal))))))))))))) (2 .)))python

(4 (4 (4 (2 The) (4 (3 gorgeously) (3 (2 elaborate) (2 continuation)))) (2 (2 (2 of) (2 ``)) (2 (2 The) (2 (2 (2 Lord) (2 (2 of) (2 (2 the) (2 Rings)))) (2 (2 '') (2 trilogy)))))) (2 (3 (2 (2 is) (2 (2 so) (2 huge))) (2 (2 that) (3 (2 (2 (2 a) (2 column)) (2 (2 of) (2 words))) (2 (2 (2 (2 can) (1 not)) (3 adequately)) (2 (2 describe) (2 (3 (2 (2 co-writer\/director) (2 (2 Peter) (3 (2 Jackson) (2 's)))) (3 (2 expanded) (2 vision))) (2 (2 of) (2 (2 (2 J.R.R.) (2 (2 Tolkien) (2 's))) (2 Middle-earth))))))))) (2 .)))android

這是兩行數據,能夠構建兩棵樹。git

首先,以第一棵樹爲例,3是root節點,是label,只有葉子節點有word。word就是記錄的單詞。github

3 依據文件構建樹的主要處理過程:

    with open(file, 'r') as fid:

        trees = [Tree(l) for l in fid.readlines()]

 

Tree構建的時候:
    def __init__(self, treeString, openChar='(', closeChar=')'):
        tokens = []
        self.open = '('
        self.close = ')'
        for toks in treeString.strip().split():
            tokens += list(toks)
        self.root = self.parse(tokens)
        # get list of labels as obtained through a post-order traversal
        self.labels = get_labels(self.root)
        self.num_words = len(self.labels)

 其中,程序獲得的tokens,是以下形式:windows

tokens輸出的是字符的列表,即[‘(’,’3’,’(’,’2’,‘(’,’2’,’(‘,’T’,’h’,’e’………………]網絡

Parse函數處理:(遞歸構建樹的過程),注意,其中的int('3')獲得的是3,而不是字符'3'的ASCII碼值。app

    Parse函數處理:(遞歸構建樹的過程)
    def parse(self, tokens, parent=None):
        assert tokens[0] == self.open, "Malformed tree"
        assert tokens[-1] == self.close, "Malformed tree"

        split = 2  # position after open and label
        countOpen = countClose = 0

        if tokens[split] == self.open: #假如是父節點,還有子節點的話,必定是(3(,即[2]對應的字符是一個open
            countOpen += 1
            split += 1
        # Find where left child and right child split
#下面的while循環就是處理,能夠看到,可以找到(2 (2 The) (2 Rock))字符序列是其左子樹。
#
        while countOpen != countClose: 
            if tokens[split] == self.open:
                countOpen += 1
            if tokens[split] == self.close:
                countClose += 1
            split += 1

        # New node
        
        print (tokens[1],int(tokens[1]))
        node = Node(int(tokens[1]))  # zero index labels
        node.parent = parent

        # leaf Node
        if countOpen == 0: #也就是葉子節點
            node.word = ''.join(tokens[2:-1]).lower()  # lower case?
            node.isLeaf = True
            return node

        node.left = self.parse(tokens[2:split], parent=node)
        node.right = self.parse(tokens[split:-1], parent=node)
        return node

4 networkx構建可視化二叉樹

代碼以下:函數

def plotTree_xiaojie(tree):
    
    positions,edges = _get_pos_edge_list(tree)
    nodes = [x for x in positions.keys()]
    labels = _get_label_list(tree)
    colors = []
    try:
        colors = _get_color_list(tree)
    except AttributeError:
        pass
    #使用networkx畫圖
    G=nx.Graph()
    G.add_edges_from(edges)
    G.add_nodes_from(nodes)
    
    if len(colors) > 0:
        nx.draw_networkx_nodes(G,positions,node_size=100,node_color=colors)
        nx.draw_networkx_edges(G,positions)
        nx.draw_networkx_labels(G,positions,labels,font_color='w')
    else:
        nx.draw_networkx_nodes(G,positions,node_size=100,node_color='r')
        nx.draw_networkx_edges(G,positions)
        nx.draw_networkx_labels(G,positions,labels)
    nx.draw(G)
    plt.axis('off')
    
    plt.savefig('./可視化二叉樹__曾傑.jpg')
    plt.show()
    #官網提供的下面的兩個方法,已經缺失了。
#    nx.draw_graphviz(G)
#    nx.write_dot(G,'xiaojie.dot')
    return None

其中,_get_pos_edge_list的主要做用是對樹進行遍歷,決定每一個樹節點在畫布中的位置,好比root節點就在(0,0)座標處,而後edge就是遍歷樹獲得邊。

def _get_pos_edge_list(tree):
    """
    _get_pos_list(tree) -> Mapping. Produces a mapping
    of nodes as keys, and their coordinates for plotting
    as values. Since pyplot or networkx don't have built in
    methods for plotting binary search trees, this somewhat
    choppy method has to be used.
    """
    return _get_pos_edge_list_from(tree,tree.root,{},[],0,(0,0),1.0)

dot = None
def _get_pos_edge_list_from(tree,node,poslst,edgelist,index,coords,gap):
    #利用先序遍歷,遍歷一顆樹,將邊和節點生成networkx能夠識別的內容。
    """
    _get_pos_list_from(tree,node,poslst,index,coords,gap) -> Mapping.
    Produces a mapping of nodes as keys, and their coordinates for
    plotting as values.

    Non-straightforward arguments:
    index: represents the index of node in
    a list of all Nodes in tree in preorder.
    coords: represents coordinates of node's parent. Used to
    determine coordinates of node for plotting.
    gap: represents horizontal distance from node and node's parent.
    To achieve plotting consistency each time we move down the tree
    we half this value.
    """
    global dot
    positions = poslst
    edges=edgelist
    if node and node == tree.root:
        dot.node(str(index),str(node.label))
        positions[index] = coords
        new_index = 1 +index+tree.get_element_count(node.left)
        if node.left:
            edges.append((0,1))
            dot.edge(str(index),str(index+1),constraint='false')
            positions,edges = _get_pos_edge_list_from(tree,node.left,positions,edges,1,coords,gap)
        if node.right:
            edges.append((0,new_index))
            dot.edge(str(index),str(new_index),constraint='false')
            positions,edges = _get_pos_edge_list_from(tree,node.right,positions,edges,new_index,coords,gap)
     
        return positions,edges
    elif node:
        dot.node(str(index),str(node.label))
        if node.parent.right and node.parent.right == node:
            #new_coords = (coords[0]+gap,coords[1]-1) #這樣的話,當節點過多的時候,很容易出現重合節點的情形。
            new_coords = (coords[0]+(tree.get_element_count(node.left)+1)*3,coords[1]-3)
            positions[index] = new_coords
        else:
            #new_coords = (coords[0]-gap,coords[1]-1)
            new_coords = (coords[0]-(tree.get_element_count(node.right)+1)*3,coords[1]-3)
            positions[index] = new_coords
        
        new_index = 1 + index + tree.get_element_count(node.left)
        if node.left:
            edges.append((index,index+1))
            dot.edge(str(index),str(index+1),constraint='false')
            positions,edges = _get_pos_edge_list_from(tree,node.left,positions,edges,index+1,new_coords,gap)    
        if node.right:
            edges.append((index,new_index))
            dot.edge(str(index),str(new_index),constraint='false')
            positions,edges = _get_pos_edge_list_from(tree,node.right,positions,edges,new_index,new_coords,gap)    
        
        return positions,edges
    else:
        return positions,edges

5 遇到的問題(畫的樹太醜了,不忍心看)

 

樹畫的特別的醜,並且可以對樹進行描述的信息很少。這是我參考網上繪製二叉樹的開源項目:

見博客地址:http://www.studyai.com/article/9bf95027,其中引用的兩個庫是BSTree

  1. from pybst.bstree import BSTree
  2. from pybst.draw import plot_tree

因爲BSTree有它本身的樹結構,而我下載的RNN網絡的樹又是另一種結構。因而,我只能修改BSTree的代碼,產生了前述的代碼,即plotTree_xiaojie,加入到RNN項目的源碼當中去。

樹是什麼樣子呢?

能夠看到,在x軸中有重疊現象。

因而代碼中有以下改動:

        if node.parent.right and node.parent.right == node: #new_coords = (coords[0]+gap,coords[1]-1) #這樣的話,當節點過多的時候,很容易出現重合節點的情形。 new_coords = (coords[0]+(tree.get_element_count(node.left)+1)*1,coords[1]-1) positions[index] = new_coords else: #new_coords = (coords[0]-gap,coords[1]-1) new_coords = (coords[0]-(tree.get_element_count(node.right)+1)*1,coords[1]-1) positions[index] = new_coords

即在x軸方向上從單純的加減去一個1,而變成了加上和減去節點數肯定的距離,如此一來,可以保證二叉樹上的全部節點在x軸上不會出現重合。由於我畫樹的過程是先序遍歷的方式,因此y軸上全部節點從根本上是不可能重合的。而子節點的位置必然要依據父節點的位置來判定,就會致使整顆樹的節點,在x軸上出現重合。

我畫了一個手稿示意圖以下:即依據子節點的左右子樹的節點數,確立子節點與父節點的位置關係(父節點當前的位置是知道的,要確立子節點的位置)

  

優化後的二叉樹長這個樣子:

經過以前的樹對比一下,能夠發現沒有節點重合了。可是爲何在根節點處出現一大片紅色。這個緣由不明確。可是經過對比先後兩個圖,是能夠發現,3節點和其左子節點2之間,並無其它的節點。

可是,圖依舊很醜。

此外,networkx可以記錄的信息有限。一個label是不夠的。我但願可以展示出RNN的節點的當前的向量是多少,因此須要更豐富的展示形式。因而求助Graphviz

6 藉助Graphviz展示二叉樹

參考:

http://www.javashuo.com/article/p-zahrueye-mu.html

使用Graphviz繪圖(一)

http://www.javashuo.com/article/p-xmqeyvsk-ba.html

修改前述繪製樹的plotTree_xiaojie程序以下:

def plotTree_xiaojie(tree):
    global dot
    dot=Digraph("G",format="pdf")

    positions,edges = _get_pos_edge_list(tree)
    nodes = [x for x in positions.keys()]
    labels = _get_label_list(tree)
    colors = []
    try:
        colors = _get_color_list(tree)
    except AttributeError:
        pass
    print(dot.source)
    f=open('可視化二叉樹.dot', 'w+')
    f.write(dot.source)  
    f.close()

    dot.view()

    #使用networkx畫圖
    G=nx.Graph()
    G.add_edges_from(edges)
    G.add_nodes_from(nodes)
    
    if len(colors) > 0:
        nx.draw_networkx_nodes(G,positions,node_size=40,node_color=colors)
        nx.draw_networkx_edges(G,positions)
        nx.draw_networkx_labels(G,positions,labels,font_color='w')
    else:
        nx.draw_networkx_nodes(G,positions,node_size=40,node_color='r')
        nx.draw_networkx_edges(G,positions)
        nx.draw_networkx_labels(G,positions,labels)
    nx.draw(G)
    plt.axis('off')
    
    plt.savefig('./可視化二叉樹__曾傑.jpg')
    plt.show()
    #官網提供的下面的兩個方法,已經缺失了。
#    nx.draw_graphviz(G)
#    nx.write_dot(G,'xiaojie.dot')
    return None
在對樹進行遍歷的_get_pos_edge_list函數中也添加了dot的相關添加節點和邊的操做,見前述代碼。前述代碼中已經包含使用graphviz的相關操做了。
結果獲得的圖是這個死樣子:

雖然節點和邊的關係是對的。可是太醜了,這哪是一顆樹。

博客:https://blog.csdn.net/theonegis/article/details/71772334宣稱,可以將二叉樹變得好看。使用以下代碼:

dot tree.dot | gvpr -c -f binarytree.gvpr | neato -n -Tpng -o tree.png

結果,更醜了。

7 拋出問題:如何更好的展示一顆二叉樹,我但願用pygraphviz。

正在研究和使用中,後續更新在下篇博文中。

見本博客,2 pygraphviz在windows10 64位下的安裝問題(反斜槓的血案)

更新博文 2018年8月23日17:21:45


 

8 使用pygraphviz繪製二叉樹

代碼修改以下:

def plotTree_xiaojie(tree):
    positions,edges = _get_pos_edge_list(tree)
    nodes = [x for x in positions.keys()]
    G=pgv.AGraph(name='xiaojie_draw_RtNN_Tree',directed=True,strict=True)
    G.add_nodes_from(nodes)
    G.add_edges_from(edges)
    G.layout('dot')
    G.draw('xiaojie_draw_RtNN_Tree.png')
    return None

結果是:

是否是至關的好看?

並且還能夠局部區域放大,徹底是graphviz的強大特性。

這至關於什麼了,把graphviz比做原版的android系統,而後pygraphviz就像是小米,oppo,華爲等進行的升級版本。

哇咔咔。

能夠對邊的顏色,節點大小,還能夠添加附加信息。好比我想添加節點當前的計算向量等等。

這樣,一顆結構遞歸計算的樹就出來了。留待後續更新。

下面是一顆樹的局部區域展現。

相關文章
相關標籤/搜索