Apriori算法代碼實現(Java)

package apriori;java

import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;數組

public class AprioriAlgorithm {
    private int minSup; //最小支持度
    private static List<String>    data;
    private static List<Set<String>> dataSet;
    
    public static void main(String[] args){
        
        long startTime = System.currentTimeMillis();
        AprioriAlgorithm apriori = new AprioriAlgorithm();
        
        //使用書中的測試集
        data = apriori.buildData();
        //設置最小支持度
        apriori.setMinSup(2);
        //構造數據集
        data = apriori.buildData();
        //構建頻繁1項集
        List<Set<String>> f1Set = apriori.findF1Item(data);
        apriori.printSet(f1Set, 1);
        List<Set<String>> result = f1Set;
        
        int i = 2;
        do{
            result = apriori.aripriGen(result);
            apriori.printSet(result, i);
            i++;
        }while(result.size() != 0);
        long endTime = System.currentTimeMillis();
        System.out.println("共用時:" +(endTime - startTime) + "ms");
    }
    public void setMinSup(int minSup){
        this.minSup = minSup;
    }
    /**
     * 構造原始數據集,能夠爲之提供參數
     * 若是不提供參數,將按程序默認構造的數據集
     * 若是提供參數爲文件名,則使用文件中的數據集
     * 
     */
    List<String> buildData(String...fileName){
        List<String> data = new ArrayList<String>();
        if(fileName.length != 0){
            File file = new File(fileName[0]);
            try{
                BufferedReader reader = new BufferedReader(new FileReader(file));
                String line;
                while((line = reader.readLine()) != null){
                    data.add(line);
                }
            }catch(FileNotFoundException e){
                e.printStackTrace();
            }catch(IOException e){
                e.printStackTrace();
            }
        }else{
            data.add("I1 I2 I5");
            data.add("I2 I4");
            data.add("I2 I3");
            data.add("I1 I2 I4");
            data.add("I1 I3");
            data.add("I2 I3");
            data.add("I1 I3");
            data.add("I1 I2 I3 I5");
            data.add("I1 I2 I3");
        }
        dataSet = new ArrayList<Set<String>>();
        Set<String> dSet;
        for(String d:data){
            dSet = new TreeSet<String>();
            String[] dArr = d.split(" ");
            for(String str:dArr){
                dSet.add(str);
            }
            dataSet.add(dSet);
        }
        return data;    
    }
    /**
     * 找出候選1項集
     * @param data
     * @return result
     *
     */
    List<Set<String>> findF1Item(List<String> data){
        List<Set<String>> result = new ArrayList<Set<String>>();
        Map<String, Integer> dc = new HashMap<String,Integer>();
        for(String d:data){
            String[] items = d.split(" ");
            for(String item:items){
                if(dc.containsKey(item)){
                    dc.put(item, dc.get(item) + 1);
                }else{
                    dc.put(item, 1);
                }
            }
        }
        Set<String> itemKeys = dc.keySet();
        Set<String> tempKeys = new TreeSet<String>();
        for(String str:itemKeys){
            tempKeys.add(str);
            
        }
        for(String item:tempKeys){
            if(dc.get(item) >= minSup){
                Set<String> f1Set = new TreeSet<String>();
                f1Set.add(item);
                result.add(f1Set);
            }
        }
        return result;
    }
    
    /*
     * 利用arioriGen 方法由k - 1項集生成k項集
     * 
     */
    List<Set<String>> aripriGen(List<Set<String>> preSet){
        List<Set<String>> result = new ArrayList<Set<String>>();
        int preSetSize = preSet.size();
        
        for(int i = 0;i < preSetSize - 1;i++){
            for(int j = i + 1;j < preSetSize;j++){
                String[] strA1 = preSet.get(i).toArray(new String[0]);
                String[] strA2 = preSet.get(j).toArray(new String[0]);
                if(isCanLink(strA1,strA2)){
                    Set<String> set = new TreeSet<String>();
                    for(String str:strA1){
                        set.add(str);
                    }
                     set.add((String) strA2[strA2.length-1]);//鏈接成K項集
                     //判斷K項集是否須要剪切掉,若是不須要被cut掉,則加入到k項集的列表中
                     if(!isNeedCut(preSet, set)) {
                      result.add(set);
                     }   
                }
            }
        }
        return checkSupport(result);
    }
     List<Set<String>> checkSupport(List<Set<String> > setList){
          
          List<Set<String>> result = new ArrayList<Set<String>>();
          boolean flag = true;
          int [] counter = new int[setList.size()];
          for(int i = 0; i < setList.size(); i++){
           
           for(Set<String> dSets : dataSet) {
            if(setList.get(i).size() > dSets.size()){
             flag = true;
            }else{
             for(String str : setList.get(i)){
              if(!dSets.contains(str)){
               flag = false;
               break;
              }
             }
             if(flag) {
              counter[i] += 1;
             } else{
              flag = true;
             }
            }
           }
          }
          
          for(int i = 0; i < setList.size(); i++){
           if (counter[i] >= minSup) {
            result.add(setList.get(i));
           }
          }
          return result;
         }
         
         /**
          * 判斷兩個項集可否執行鏈接操做
          * @param s1
          * @param s2
          * @return
          */
         boolean isCanLink(String [] s1, String[] s2){
          boolean flag = true;
          if(s1.length == s2.length) {
           for(int i = 0; i < s1.length - 1; i ++){
            if(!s1[i].equals(s2[i])){
             flag = false;
             break;
            }
           }
           if(s1[s1.length - 1].equals(s2[s2.length - 1])){
            flag = false;
           }
          }else{
           flag = true;
          }
          return flag;
         }
         
         /**
          * 判斷set是否須要被cut
          * 
          * @param setList
          * @param set
          * @return
          */
         boolean isNeedCut(List<Set<String>> setList, Set<String> set) {//setList指頻繁K-1項集,set指候選K項集
          boolean flag = false;
          List<Set<String>> subSets = getSubset(set);//得到K項集的全部k-1項集
          for ( Set<String> subSet : subSets) {
           //判斷當前的k-1項集set是否在頻繁k-1項集中出現,若是出現,則不須要cut
             //若沒有出現,則須要被cut
           if( !isContained(setList, subSet)){
            flag = true;
            break;
           }
          }
          return flag;
         }
         /**
          * 功能:判斷k項集的某k-1項集是否包含在頻繁k-1項集列表中
          * 
          * @param setList
          * @param set
          * @return
          */
         boolean isContained(List<Set<String>> setList, Set<String> set){
          boolean flag = false;
          int position = 0;
          for( Set<String> s : setList  ) {
           String [] sArr = s.toArray(new String[0]);
           String [] setArr = set.toArray(new String[0]);
           for(int i = 0; i < sArr.length; i++) {
            if ( sArr[i].equals(setArr[i])){
             //若是對應位置的元素相同,則position爲當前位置的值
             position = i;
            } else{
             break;
            }
           }
           //若是position等於數組的長度,說明已經找到某個setList中的集合與
           //set集合相同了,退出循環,返回包含
           //不然,把position置爲0進入下一個比較
           if ( position == sArr.length - 1) {
            flag = true;
            break;
           } else {
            flag = false;
            position = 0;
           }
          }
          return flag;
         }
         
         /**
          * 得到k項集的全部k-1項子集
          * 
          * @param set
          * @return
          */
         List<Set<String>> getSubset(Set <String> set){
          
          List<Set<String>> result = new ArrayList<Set<String>>();
          String [] setArr = set.toArray(new String[0]);
          
          for( int i = 0; i < setArr.length; i++){
           Set<String> subSet = new TreeSet<String>();
           for(int j = 0; j < setArr.length; j++){
            if( i != j){
             subSet.add((String) setArr[j]);
            }
           }
           result.add(subSet);
          }
          return result;
         }
         /**
          * 功能:打印頻繁項集
          */
         void printSet(List<Set<String>> setList, int i){
          System.out.print("頻繁" + i + "項集: 共" + setList.size() + "項: {");
          for(Set<String> set : setList) {
           System.out.print("[");
           for(String str : set) {
            System.out.print(str + " ");
           }
           System.out.print("], ");
          }
          System.out.println("}");
         }
    }
    測試

相關文章
相關標籤/搜索