貝葉斯網絡--TAN樹型樸素貝葉斯算法

前言

在前面的時間裏已經學習過了NB樸素貝葉斯算法, 又剛剛初步的學習了貝葉斯網絡的一些基本概念和經常使用的計算方法。因而就有了上篇初識貝葉斯網絡的文章,因爲本人最近一直在研究學習<<貝葉斯網引論>>,也接觸到了許多與貝葉斯網絡相關的知識,能夠說樸素貝葉斯算法這些只是咱們所瞭解貝葉斯知識的很小的一部分。今天我要總結的學習成果就是基於NB算法的,叫作Tree Augmented Naive Bays,中文意思就是樹型樸素貝葉斯算法,簡單理解就是樹加強型NB算法,那麼問題來了,他是如何加強的呢,請繼續往下正文的描述。java

樸素貝葉斯算法

又得要從樸素貝葉斯算法開始講起了,由於在前言中已經說了,TAN算法是對NB算法的加強,瞭解過NB算法的,必定知道NB算法在使用的時候是假設屬性事件是相互獨立的,而決策屬性的分類結果是依賴於各個條件屬性的狀況的,最後選擇分類屬性中擁有最大後驗機率的值爲決策屬性。好比下面這個模型能夠描述一個簡單的模型,node

上面帳號是否真實的依賴屬性條件有3個,好友密度,是否使用真實頭像,日誌密度,假設這3個屬性是相互獨立的,可是事實上,在這裏的頭像是否真實和好友密度實際上是有關聯的,因此更加真實的狀況是下面這張狀況;android

OK,TAN的出現就解決了條件間的部分屬性依賴的問題。在上面的例子中咱們是根據本身的主觀意識判斷出頭像和好友密度的關係,可是在真實算法中,咱們固然但願機器可以本身根據所給數據集幫咱們得出這樣的關係,使人高興的事,TAN幫咱們作到了這點。算法

TAN算法

互信息值

互信息值,在百度百科中的解釋以下:數組

互信息值是信息論中一個有用的信息度量。它能夠看出是一個信息量裏包含另外一個隨機變量的信息量。網絡

用圖線來表示就是下面這樣。app

中間的I(x;y)就是互信息值,X,Y表明的2種屬性。因而下面這個屬性就很好理解了,互信息值越大,就表明2個屬性關聯性越大。互信息值的標準公式以下:ide

可是在TAN中會有少量的不同,會有類變量屬性的加入,由於屬性之間的關聯性的前提是要在某一分類屬性肯定下進行從新計算,不一樣的類屬性值會有不一樣的屬性關聯性。下面是TAN中的I(x;Y)計算公式:工具

如今看不懂沒關係,後面在給出的程序代碼中可自行調試。學習

算法實現過程

TAN的算法過程其實並不簡單,在計算完各個屬性對的互信息值以後,要進行貝葉斯網絡的構建,這個是TAN中最難的部分,這個部分有下面幾個階段。

一、根據各個屬性對的互信息值降序排序,依次取出其中的節點對,遵循不產生環路的原則,構造最大權重跨度樹,直到選擇完n-1條邊爲止(由於總共n個屬性節點,n-1條邊便可肯定)。按照互信息值從高到低選擇的緣由就是要保留關聯性更高的關聯依賴性的邊。

二、上述過程構成的是一個無向圖,接下來爲整個無向圖肯定邊的方向。選擇任意一個屬性節點做爲根節點,由根節點向外的方向爲屬性節點之間的方向。

三、爲每個屬性節點添加父節點,父節點就是分類屬性節點,至此貝葉斯網絡結構構造完畢。

爲了方便你們理解,我在網上截了幾張圖,下面這張是在5個屬性節點中優先選擇了互信息值最大的4條做爲無向圖:

上述帶了箭頭是由於,我選擇的A做爲樹的根節點,而後方向就所有肯定了,由於A直接連着4個屬性節點,而後再此基礎上添加父節點,就是下面這個樣子了。

OK,這樣應該就比較好理解了吧,若是還不理解,請仔細分析我寫的程序,從代碼中去理解這個過程也能夠。

分類結果機率的計算

分類結果機率的計算其實很是簡單,只要把查詢的條件屬性傳入分類模型中,而後計算不一樣類屬性下的機率值,擁有最大機率值的分類屬性值爲最終的分類結果。下面是計算公式,就是聯合機率分佈公式:

代碼實現

測試數據集input.txt:

 

[java] view plain copy

 print?

  1. OutLook Temperature Humidity Wind PlayTennis  
  2. Sunny Hot High Weak No  
  3. Sunny Hot High Strong No  
  4. Overcast Hot High Weak Yes  
  5. Rainy Mild High Weak Yes  
  6. Rainy Cool Normal Weak Yes  
  7. Rainy Cool Normal Strong No  
  8. Overcast Cool Normal Strong Yes  
  9. Sunny Mild High Weak No  
  10. Sunny Cool Normal Weak Yes  
  11. Rainy Mild Normal Weak Yes  
  12. Sunny Mild Normal Strong Yes  
  13. Overcast Mild High Strong Yes  
  14. Overcast Hot Normal Weak Yes  
  15. Rainy Mild High Strong No  

節點類Node.java:

 

 

[java] view plain copy

 print?

  1. package DataMining_TAN;  
  2.   
  3. import java.util.ArrayList;  
  4.   
  5. /** 
  6.  * 貝葉斯網絡節點類 
  7.  *  
  8.  * @author lyq 
  9.  *  
  10.  */  
  11. public class Node {  
  12.     //節點惟一id,方便後面節點鏈接方向的肯定  
  13.     int id;  
  14.     // 節點的屬性名稱  
  15.     String name;  
  16.     // 該節點所連續的節點  
  17.     ArrayList<Node> connectedNodes;  
  18.   
  19.     public Node(int id, String name) {  
  20.         this.id = id;  
  21.         this.name = name;  
  22.   
  23.         // 初始化變量  
  24.         this.connectedNodes = new ArrayList<>();  
  25.     }  
  26.   
  27.     /** 
  28.      * 將自身節點鏈接到目標給定的節點 
  29.      *  
  30.      * @param node 
  31.      *            下游節點 
  32.      */  
  33.     public void connectNode(Node node) {  
  34.         //避免鏈接自身  
  35.         if(this.id == node.id){  
  36.             return;  
  37.         }  
  38.           
  39.         // 將節點加入自身節點的節點列表中  
  40.         this.connectedNodes.add(node);  
  41.         // 將自身節點加入到目標節點的列表中  
  42.         node.connectedNodes.add(this);  
  43.     }  
  44.   
  45.     /** 
  46.      * 判斷與目標節點是否相同,主要比較名稱是否相同便可 
  47.      *  
  48.      * @param node 
  49.      *            目標結點 
  50.      * @return 
  51.      */  
  52.     public boolean isEqual(Node node) {  
  53.         boolean isEqual;  
  54.   
  55.         isEqual = false;  
  56.         // 節點名稱相同則視爲相等  
  57.         if (this.id == node.id) {  
  58.             isEqual = true;  
  59.         }  
  60.   
  61.         return isEqual;  
  62.     }  
  63. }  

 

互信息值類.Java:

 

[java] view plain copy

 print?

  1. package DataMining_TAN;  
  2.   
  3. /** 
  4.  * 屬性之間的互信息值,表示屬性之間的關聯性大小 
  5.  * @author lyq 
  6.  * 
  7.  */  
  8. public class AttrMutualInfo implements Comparable<AttrMutualInfo>{  
  9.     //互信息值  
  10.     Double value;  
  11.     //關聯屬性值對  
  12.     Node[] nodeArray;  
  13.       
  14.     public AttrMutualInfo(double value, Node node1, Node node2){  
  15.         this.value = value;  
  16.           
  17.         this.nodeArray = new Node[2];  
  18.         this.nodeArray[0] = node1;  
  19.         this.nodeArray[1] = node2;  
  20.     }  
  21.   
  22.     @Override  
  23.     public int compareTo(AttrMutualInfo o) {  
  24.         // TODO Auto-generated method stub  
  25.         return o.value.compareTo(this.value);  
  26.     }  
  27.       
  28. }  


 

 

算法主程序類TANTool.java:

 

[java] view plain copy

 print?

  1. package DataMining_TAN;  
  2.   
  3. import java.io.BufferedReader;  
  4. import java.io.File;  
  5. import java.io.FileReader;  
  6. import java.io.IOException;  
  7. import java.util.ArrayList;  
  8. import java.util.Collections;  
  9. import java.util.HashMap;  
  10.   
  11. /** 
  12.  * TAN樹型樸素貝葉斯算法工具類 
  13.  *  
  14.  * @author lyq 
  15.  *  
  16.  */  
  17. public class TANTool {  
  18.     // 測試數據集地址  
  19.     private String filePath;  
  20.     // 數據集屬性總數,其中一個個分類屬性  
  21.     private int attrNum;  
  22.     // 分類屬性名  
  23.     private String classAttrName;  
  24.     // 屬性列名稱行  
  25.     private String[] attrNames;  
  26.     // 貝葉斯網絡邊的方向,數組內的數值爲節點id,從i->j  
  27.     private int[][] edges;  
  28.     // 屬性名到列下標的映射  
  29.     private HashMap<String, Integer> attr2Column;  
  30.     // 屬性,屬性對取值集合映射對  
  31.     private HashMap<String, ArrayList<String>> attr2Values;  
  32.     // 貝葉斯網絡總節點列表  
  33.     private ArrayList<Node> totalNodes;  
  34.     // 總的測試數據  
  35.     private ArrayList<String[]> totalDatas;  
  36.   
  37.     public TANTool(String filePath) {  
  38.         this.filePath = filePath;  
  39.   
  40.         readDataFile();  
  41.     }  
  42.   
  43.     /** 
  44.      * 從文件中讀取數據 
  45.      */  
  46.     private void readDataFile() {  
  47.         File file = new File(filePath);  
  48.         ArrayList<String[]> dataArray = new ArrayList<String[]>();  
  49.   
  50.         try {  
  51.             BufferedReader in = new BufferedReader(new FileReader(file));  
  52.             String str;  
  53.             String[] array;  
  54.   
  55.             while ((str = in.readLine()) != null) {  
  56.                 array = str.split(" ");  
  57.                 dataArray.add(array);  
  58.             }  
  59.             in.close();  
  60.         } catch (IOException e) {  
  61.             e.getStackTrace();  
  62.         }  
  63.   
  64.         this.totalDatas = dataArray;  
  65.         this.attrNames = this.totalDatas.get(0);  
  66.         this.attrNum = this.attrNames.length;  
  67.         this.classAttrName = this.attrNames[attrNum - 1];  
  68.   
  69.         Node node;  
  70.         this.edges = new int[attrNum][attrNum];  
  71.         this.totalNodes = new ArrayList<>();  
  72.         this.attr2Column = new HashMap<>();  
  73.         this.attr2Values = new HashMap<>();  
  74.   
  75.         // 分類屬性節點id最小設爲0  
  76.         node = new Node(0, attrNames[attrNum - 1]);  
  77.         this.totalNodes.add(node);  
  78.         for (int i = 0; i < attrNames.length; i++) {  
  79.             if (i < attrNum - 1) {  
  80.                 // 建立貝葉斯網絡節點,每一個屬性一個節點  
  81.                 node = new Node(i + 1, attrNames[i]);  
  82.                 this.totalNodes.add(node);  
  83.             }  
  84.   
  85.             // 添加屬性到列下標的映射  
  86.             this.attr2Column.put(attrNames[i], i);  
  87.         }  
  88.   
  89.         String[] temp;  
  90.         ArrayList<String> values;  
  91.         // 進行屬性名,屬性值對的映射匹配  
  92.         for (int i = 1; i < this.totalDatas.size(); i++) {  
  93.             temp = this.totalDatas.get(i);  
  94.   
  95.             for (int j = 0; j < temp.length; j++) {  
  96.                 // 判斷map中是否包含此屬性名  
  97.                 if (this.attr2Values.containsKey(attrNames[j])) {  
  98.                     values = this.attr2Values.get(attrNames[j]);  
  99.                 } else {  
  100.                     values = new ArrayList<>();  
  101.                 }  
  102.   
  103.                 if (!values.contains(temp[j])) {  
  104.                     // 加入新的屬性值  
  105.                     values.add(temp[j]);  
  106.                 }  
  107.   
  108.                 this.attr2Values.put(attrNames[j], values);  
  109.             }  
  110.         }  
  111.     }  
  112.   
  113.     /** 
  114.      * 根據條件互信息度對構建最大權重跨度樹,返回第一個節點爲根節點 
  115.      *  
  116.      * @param iArray 
  117.      */  
  118.     private Node constructWeightTree(ArrayList<Node[]> iArray) {  
  119.         Node node1;  
  120.         Node node2;  
  121.         Node root;  
  122.         ArrayList<Node> existNodes;  
  123.   
  124.         existNodes = new ArrayList<>();  
  125.   
  126.         for (Node[] i : iArray) {  
  127.             node1 = i[0];  
  128.             node2 = i[1];  
  129.   
  130.             // 將2個節點進行鏈接  
  131.             node1.connectNode(node2);  
  132.             // 避免出現環路現象  
  133.             addIfNotExist(node1, existNodes);  
  134.             addIfNotExist(node2, existNodes);  
  135.   
  136.             if (existNodes.size() == attrNum - 1) {  
  137.                 break;  
  138.             }  
  139.         }  
  140.   
  141.         // 返回第一個做爲根節點  
  142.         root = existNodes.get(0);  
  143.         return root;  
  144.     }  
  145.   
  146.     /** 
  147.      * 爲樹型結構肯定邊的方向,方向爲屬性根節點方向指向其餘屬性節點方向 
  148.      *  
  149.      * @param root 
  150.      *            當前遍歷到的節點 
  151.      */  
  152.     private void confirmGraphDirection(Node currentNode) {  
  153.         int i;  
  154.         int j;  
  155.         ArrayList<Node> connectedNodes;  
  156.   
  157.         connectedNodes = currentNode.connectedNodes;  
  158.   
  159.         i = currentNode.id;  
  160.         for (Node n : connectedNodes) {  
  161.             j = n.id;  
  162.   
  163.             // 判斷鏈接此2節點的方向是否被肯定  
  164.             if (edges[i][j] == 0 && edges[j][i] == 0) {  
  165.                 // 若是沒有肯定,則制定方向爲i->j  
  166.                 edges[i][j] = 1;  
  167.   
  168.                 // 遞歸繼續搜索  
  169.                 confirmGraphDirection(n);  
  170.             }  
  171.         }  
  172.     }  
  173.   
  174.     /** 
  175.      * 爲屬性節點添加分類屬性節點爲父節點 
  176.      *  
  177.      * @param parentNode 
  178.      *            父節點 
  179.      * @param nodeList 
  180.      *            子節點列表 
  181.      */  
  182.     private void addParentNode() {  
  183.         // 分類屬性節點  
  184.         Node parentNode;  
  185.   
  186.         parentNode = null;  
  187.         for (Node n : this.totalNodes) {  
  188.             if (n.id == 0) {  
  189.                 parentNode = n;  
  190.                 break;  
  191.             }  
  192.         }  
  193.   
  194.         for (Node child : this.totalNodes) {  
  195.             parentNode.connectNode(child);  
  196.   
  197.             if (child.id != 0) {  
  198.                 // 肯定鏈接方向  
  199.                 this.edges[0][child.id] = 1;  
  200.             }  
  201.         }  
  202.     }  
  203.   
  204.     /** 
  205.      * 在節點集合中添加節點 
  206.      *  
  207.      * @param node 
  208.      *            待添加節點 
  209.      * @param existNodes 
  210.      *            已存在的節點列表 
  211.      * @return 
  212.      */  
  213.     public boolean addIfNotExist(Node node, ArrayList<Node> existNodes) {  
  214.         boolean canAdd;  
  215.   
  216.         canAdd = true;  
  217.         for (Node n : existNodes) {  
  218.             // 若是節點列表中已經含有節點,則算添加失敗  
  219.             if (n.isEqual(node)) {  
  220.                 canAdd = false;  
  221.                 break;  
  222.             }  
  223.         }  
  224.   
  225.         if (canAdd) {  
  226.             existNodes.add(node);  
  227.         }  
  228.   
  229.         return canAdd;  
  230.     }  
  231.   
  232.     /** 
  233.      * 計算節點條件機率 
  234.      *  
  235.      * @param node 
  236.      *            關於node的後驗機率 
  237.      * @param queryParam 
  238.      *            查詢的屬性參數 
  239.      * @return 
  240.      */  
  241.     private double calConditionPro(Node node, HashMap<String, String> queryParam) {  
  242.         int id;  
  243.         double pro;  
  244.         String value;  
  245.         String[] attrValue;  
  246.   
  247.         ArrayList<String[]> priorAttrInfos;  
  248.         ArrayList<String[]> backAttrInfos;  
  249.         ArrayList<Node> parentNodes;  
  250.   
  251.         pro = 1;  
  252.         id = node.id;  
  253.         parentNodes = new ArrayList<>();  
  254.         priorAttrInfos = new ArrayList<>();  
  255.         backAttrInfos = new ArrayList<>();  
  256.   
  257.         for (int i = 0; i < this.edges.length; i++) {  
  258.             // 尋找父節點id  
  259.             if (this.edges[i][id] == 1) {  
  260.                 for (Node temp : this.totalNodes) {  
  261.                     // 尋找目標節點id  
  262.                     if (temp.id == i) {  
  263.                         parentNodes.add(temp);  
  264.                         break;  
  265.                     }  
  266.                 }  
  267.             }  
  268.         }  
  269.   
  270.         // 獲取先驗屬性的屬性值,首先添加先驗屬性  
  271.         value = queryParam.get(node.name);  
  272.         attrValue = new String[2];  
  273.         attrValue[0] = node.name;  
  274.         attrValue[1] = value;  
  275.         priorAttrInfos.add(attrValue);  
  276.   
  277.         // 逐一添加後驗屬性  
  278.         for (Node p : parentNodes) {  
  279.             value = queryParam.get(p.name);  
  280.             attrValue = new String[2];  
  281.             attrValue[0] = p.name;  
  282.             attrValue[1] = value;  
  283.   
  284.             backAttrInfos.add(attrValue);  
  285.         }  
  286.   
  287.         pro = queryConditionPro(priorAttrInfos, backAttrInfos);  
  288.   
  289.         return pro;  
  290.     }  
  291.   
  292.     /** 
  293.      * 查詢條件機率 
  294.      *  
  295.      * @param attrValues 
  296.      *            條件屬性值 
  297.      * @return 
  298.      */  
  299.     private double queryConditionPro(ArrayList<String[]> priorValues,  
  300.             ArrayList<String[]> backValues) {  
  301.         // 判斷是否知足先驗屬性值條件  
  302.         boolean hasPrior;  
  303.         // 判斷是否知足後驗屬性值條件  
  304.         boolean hasBack;  
  305.         int attrIndex;  
  306.         double backPro;  
  307.         double totalPro;  
  308.         double pro;  
  309.         String[] tempData;  
  310.   
  311.         pro = 0;  
  312.         totalPro = 0;  
  313.         backPro = 0;  
  314.   
  315.         // 跳過第一行的屬性名稱行  
  316.         for (int i = 1; i < this.totalDatas.size(); i++) {  
  317.             tempData = this.totalDatas.get(i);  
  318.   
  319.             hasPrior = true;  
  320.             hasBack = true;  
  321.   
  322.             // 判斷是否知足先驗條件  
  323.             for (String[] array : priorValues) {  
  324.                 attrIndex = this.attr2Column.get(array[0]);  
  325.   
  326.                 // 判斷值是否知足條件  
  327.                 if (!tempData[attrIndex].equals(array[1])) {  
  328.                     hasPrior = false;  
  329.                     break;  
  330.                 }  
  331.             }  
  332.   
  333.             // 判斷是否知足後驗條件  
  334.             for (String[] array : backValues) {  
  335.                 attrIndex = this.attr2Column.get(array[0]);  
  336.   
  337.                 // 判斷值是否知足條件  
  338.                 if (!tempData[attrIndex].equals(array[1])) {  
  339.                     hasBack = false;  
  340.                     break;  
  341.                 }  
  342.             }  
  343.   
  344.             // 進行計數統計,分別計算知足後驗屬性的值和同時知足條件的個數  
  345.             if (hasBack) {  
  346.                 backPro++;  
  347.                 if (hasPrior) {  
  348.                     totalPro++;  
  349.                 }  
  350.             } else if (hasPrior && backValues.size() == 0) {  
  351.                 // 若是隻有先驗機率則爲純機率的計算  
  352.                 totalPro++;  
  353.                 backPro = 1.0;  
  354.             }  
  355.         }  
  356.   
  357.         if (backPro == 0) {  
  358.             pro = 0;  
  359.         } else {  
  360.             // 計算總的機率=都發生機率/只發生後驗條件的時間機率  
  361.             pro = totalPro / backPro;  
  362.         }  
  363.   
  364.         return pro;  
  365.     }  
  366.   
  367.     /** 
  368.      * 輸入查詢條件參數,計算髮生機率 
  369.      *  
  370.      * @param queryParam 
  371.      *            條件參數 
  372.      * @return 
  373.      */  
  374.     public double calHappenedPro(String queryParam) {  
  375.         double result;  
  376.         double temp;  
  377.         // 分類屬性值  
  378.         String classAttrValue;  
  379.         String[] array;  
  380.         String[] array2;  
  381.         HashMap<String, String> params;  
  382.   
  383.         result = 1;  
  384.         params = new HashMap<>();  
  385.   
  386.         // 進行查詢字符的參數分解  
  387.         array = queryParam.split(",");  
  388.         for (String s : array) {  
  389.             array2 = s.split("=");  
  390.             params.put(array2[0], array2[1]);  
  391.         }  
  392.   
  393.         classAttrValue = params.get(classAttrName);  
  394.         // 構建貝葉斯網絡結構  
  395.         constructBayesNetWork(classAttrValue);  
  396.   
  397.         for (Node n : this.totalNodes) {  
  398.             temp = calConditionPro(n, params);  
  399.   
  400.             // 爲了不出現條件機率爲0的現象,進行輕微矯正  
  401.             if (temp == 0) {  
  402.                 temp = 0.001;  
  403.             }  
  404.   
  405.             // 按照聯合機率公式,進行乘積運算  
  406.             result *= temp;  
  407.         }  
  408.   
  409.         return result;  
  410.     }  
  411.   
  412.     /** 
  413.      * 構建樹型貝葉斯網絡結構 
  414.      *  
  415.      * @param value 
  416.      *            類別量值 
  417.      */  
  418.     private void constructBayesNetWork(String value) {  
  419.         Node rootNode;  
  420.         ArrayList<AttrMutualInfo> mInfoArray;  
  421.         // 互信息度對  
  422.         ArrayList<Node[]> iArray;  
  423.   
  424.         iArray = null;  
  425.         rootNode = null;  
  426.   
  427.         // 在每次從新構建貝葉斯網絡結構的時候,清空原有的鏈接結構  
  428.         for (Node n : this.totalNodes) {  
  429.             n.connectedNodes.clear();  
  430.         }  
  431.         this.edges = new int[attrNum][attrNum];  
  432.   
  433.         // 從互信息對象中取出屬性值對  
  434.         iArray = new ArrayList<>();  
  435.         mInfoArray = calAttrMutualInfoArray(value);  
  436.         for (AttrMutualInfo v : mInfoArray) {  
  437.             iArray.add(v.nodeArray);  
  438.         }  
  439.   
  440.         // 構建最大權重跨度樹  
  441.         rootNode = constructWeightTree(iArray);  
  442.         // 爲無向圖肯定邊的方向  
  443.         confirmGraphDirection(rootNode);  
  444.         // 爲每一個屬性節點添加分類屬性父節點  
  445.         addParentNode();  
  446.     }  
  447.   
  448.     /** 
  449.      * 給定分類變量值,計算屬性之間的互信息值 
  450.      *  
  451.      * @param value 
  452.      *            分類變量值 
  453.      * @return 
  454.      */  
  455.     private ArrayList<AttrMutualInfo> calAttrMutualInfoArray(String value) {  
  456.         double iValue;  
  457.         Node node1;  
  458.         Node node2;  
  459.         AttrMutualInfo mInfo;  
  460.         ArrayList<AttrMutualInfo> mInfoArray;  
  461.   
  462.         mInfoArray = new ArrayList<>();  
  463.   
  464.         for (int i = 0; i < this.totalNodes.size() - 1; i++) {  
  465.             node1 = this.totalNodes.get(i);  
  466.             // 跳過度類屬性節點  
  467.             if (node1.id == 0) {  
  468.                 continue;  
  469.             }  
  470.   
  471.             for (int j = i + 1; j < this.totalNodes.size(); j++) {  
  472.                 node2 = this.totalNodes.get(j);  
  473.                 // 跳過度類屬性節點  
  474.                 if (node2.id == 0) {  
  475.                     continue;  
  476.                 }  
  477.   
  478.                 // 計算2個屬性節點之間的互信息值  
  479.                 iValue = calMutualInfoValue(node1, node2, value);  
  480.                 mInfo = new AttrMutualInfo(iValue, node1, node2);  
  481.                 mInfoArray.add(mInfo);  
  482.             }  
  483.         }  
  484.   
  485.         // 將結果進行降序排列,讓互信息值高的優先用於構建樹  
  486.         Collections.sort(mInfoArray);  
  487.   
  488.         return mInfoArray;  
  489.     }  
  490.   
  491.     /** 
  492.      * 計算2個屬性節點的互信息值 
  493.      *  
  494.      * @param node1 
  495.      *            節點1 
  496.      * @param node2 
  497.      *            節點2 
  498.      * @param vlaue 
  499.      *            分類變量值 
  500.      */  
  501.     private double calMutualInfoValue(Node node1, Node node2, String value) {  
  502.         double iValue;  
  503.         double temp;  
  504.         // 三種不一樣條件的後驗機率  
  505.         double pXiXj;  
  506.         double pXi;  
  507.         double pXj;  
  508.         String[] array1;  
  509.         String[] array2;  
  510.         ArrayList<String> attrValues1;  
  511.         ArrayList<String> attrValues2;  
  512.         ArrayList<String[]> priorValues;  
  513.         // 後驗機率,在這裏就是類變量值  
  514.         ArrayList<String[]> backValues;  
  515.   
  516.         array1 = new String[2];  
  517.         array2 = new String[2];  
  518.         priorValues = new ArrayList<>();  
  519.         backValues = new ArrayList<>();  
  520.   
  521.         iValue = 0;  
  522.         array1[0] = classAttrName;  
  523.         array1[1] = value;  
  524.         // 後驗屬性都是類屬性  
  525.         backValues.add(array1);  
  526.   
  527.         // 獲取節點屬性的屬性值集合  
  528.         attrValues1 = this.attr2Values.get(node1.name);  
  529.         attrValues2 = this.attr2Values.get(node2.name);  
  530.   
  531.         for (String v1 : attrValues1) {  
  532.             for (String v2 : attrValues2) {  
  533.                 priorValues.clear();  
  534.   
  535.                 array1 = new String[2];  
  536.                 array1[0] = node1.name;  
  537.                 array1[1] = v1;  
  538.                 priorValues.add(array1);  
  539.   
  540.                 array2 = new String[2];  
  541.                 array2[0] = node2.name;  
  542.                 array2[1] = v2;  
  543.                 priorValues.add(array2);  
  544.   
  545.                 // 計算3種條件下的機率  
  546.                 pXiXj = queryConditionPro(priorValues, backValues);  
  547.   
  548.                 priorValues.clear();  
  549.                 priorValues.add(array1);  
  550.                 pXi = queryConditionPro(priorValues, backValues);  
  551.   
  552.                 priorValues.clear();  
  553.                 priorValues.add(array2);  
  554.                 pXj = queryConditionPro(priorValues, backValues);  
  555.   
  556.                 // 若是出現其中一個計數機率爲0,則直接賦值爲0處理  
  557.                 if (pXiXj == 0 || pXi == 0 || pXj == 0) {  
  558.                     temp = 0;  
  559.                 } else {  
  560.                     // 利用公式計算針對此屬性值對組合的機率  
  561.                     temp = pXiXj * Math.log(pXiXj / (pXi * pXj)) / Math.log(2);  
  562.                 }  
  563.   
  564.                 // 進行和屬性值對組合的累加即爲整個屬性的互信息值  
  565.                 iValue += temp;  
  566.             }  
  567.         }  
  568.   
  569.         return iValue;  
  570.     }  
  571. }  

場景測試類client.java:

 

 

[java] view plain copy

 print?

  1. package DataMining_TAN;  
  2.   
  3. /** 
  4.  * TAN樹型樸素貝葉斯算法 
  5.  *  
  6.  * @author lyq 
  7.  *  
  8.  */  
  9. public class Client {  
  10.     public static void main(String[] args) {  
  11.         String filePath = "C:\\Users\\lyq\\Desktop\\icon\\input.txt";  
  12.         // 條件查詢語句  
  13.         String queryStr;  
  14.         // 分類結果機率1  
  15.         double classResult1;  
  16.         // 分類結果機率2  
  17.         double classResult2;  
  18.   
  19.         TANTool tool = new TANTool(filePath);  
  20.         queryStr = "OutLook=Sunny,Temperature=Hot,Humidity=High,Wind=Weak,PlayTennis=No";  
  21.         classResult1 = tool.calHappenedPro(queryStr);  
  22.   
  23.         queryStr = "OutLook=Sunny,Temperature=Hot,Humidity=High,Wind=Weak,PlayTennis=Yes";  
  24.         classResult2 = tool.calHappenedPro(queryStr);  
  25.   
  26.         System.out.println(String.format("類別爲%s所求得的機率爲%s", "PlayTennis=No",  
  27.                 classResult1));  
  28.         System.out.println(String.format("類別爲%s所求得的機率爲%s", "PlayTennis=Yes",  
  29.                 classResult2));  
  30.         if (classResult1 > classResult2) {  
  31.             System.out.println("分類類別爲PlayTennis=No");  
  32.         } else {  
  33.             System.out.println("分類類別爲PlayTennis=Yes");  
  34.         }  
  35.     }  
  36. }  

結果輸出:

 

 

[java] view plain copy

 print?

  1. 類別爲PlayTennis=No所求得的機率爲0.09523809523809525  
  2. 類別爲PlayTennis=Yes所求得的機率爲3.571428571428571E-5  
  3. 分類類別爲PlayTennis=No  
相關文章
相關標籤/搜索