java 實現DBScan聚類算法

  最近有一個需求,在地圖上,將客戶按照距離進行聚合。比如,a客戶到b客戶5km,b客戶到c客戶5km,那麼abc就可以聚合成一個集合。首先想到的就是找一個根據座標來聚合的算法,找了一些之後,選擇了較爲簡單也符合要求的DBScan聚類算法。

  它是一種基於密度的聚類算法,簡單來說就是根據樣本的緊密程度和數量將其分成多個集合。這個樣本一般來說是一堆座標點。參數爲歐式距離(鄰域半徑)和鄰域密度閾值(就是每次尋找相鄰點時要求的最低數量)。最終返回多個樣本集合。

 

2.java實現

  座標點:這個類如果只是用於測試的話,只會用到裏面的 point(座標點)這個屬性。

import java.util.Collection;
import org.apache.commons.math.stat.clustering.Clusterable;
import org.apache.commons.math.util.MathUtils;

import bsh.This;

/**
 * A customer located at a coordinate, usable as an input point for clustering.
 *
 * <p>Equality and hashing are based solely on the coordinate array; the
 * {@code sender}, {@code sender_addr} and {@code value} attributes are not
 * compared. Consequently hash-based collections treat two customers with
 * identical coordinates as the same key.
 *
 * <p>The coordinate array is stored and returned without copying, so callers
 * must not modify it while the instance is held in a hash-based collection.
 *
 * @author xjx
 */
public class CustomerPoint implements Clusterable<CustomerPoint> {

    private String sender;
    private String sender_addr;
    private int value;
    /** Coordinates of this customer; {@link #toString()} prints index 0 as lat and index 1 as lng. */
    private final double[] point;

    /**
     * Creates a point at the given coordinates.
     *
     * @param point the coordinate array (kept by reference, not copied)
     */
    public CustomerPoint(final double[] point) {
        this.point = point;
    }

    public int getValue() {
        return value;
    }

    public void setValue(int value) {
        this.value = value;
    }

    public String getSender() {
        return sender;
    }

    public void setSender(String sender) {
        this.sender = sender;
    }

    public String getSender_addr() {
        return sender_addr;
    }

    public void setSender_addr(String sender_addr) {
        this.sender_addr = sender_addr;
    }

    /**
     * Returns the coordinate array of this point.
     *
     * @return the internal coordinate array (not a copy)
     */
    public double[] getPoint() {
        return point;
    }

    /**
     * Computes the distance between this point and another one by delegating
     * to {@code MathUtils.distance}.
     *
     * @param p the other point
     * @return the distance between the two coordinate arrays
     */
    public double distanceFrom(final CustomerPoint p) {
        return MathUtils.distance(point, p.getPoint());
    }

    /**
     * Computes the component-wise arithmetic mean of the given points.
     *
     * @param points the points to average; each must have at least as many
     *               coordinates as this point
     * @return a new point located at the mean of {@code points}
     */
    public CustomerPoint centroidOf(final Collection<CustomerPoint> points) {
        final double[] centroid = new double[getPoint().length];
        for (CustomerPoint p : points) {
            final double[] coordinates = p.getPoint();
            for (int i = 0; i < centroid.length; i++) {
                centroid[i] += coordinates[i];
            }
        }
        final int count = points.size();
        for (int i = 0; i < centroid.length; i++) {
            centroid[i] /= count;
        }
        return new CustomerPoint(centroid);
    }

    /**
     * Two points are equal when their coordinate arrays have the same length
     * and the same elements. Elements are compared the way
     * {@link java.util.Arrays#equals(double[], double[])} does, which keeps
     * this method reflexive for {@code NaN} coordinates and consistent with
     * {@link #hashCode()}.
     */
    @Override
    public boolean equals(final Object other) {
        if (this == other) {
            return true;
        }
        if (!(other instanceof CustomerPoint)) {
            return false;
        }
        return java.util.Arrays.equals(point, ((CustomerPoint) other).getPoint());
    }

    /**
     * Hash code derived from the coordinates only, so that equal points hash
     * alike as required by the {@link Object#hashCode()} contract.
     */
    @Override
    public int hashCode() {
        return java.util.Arrays.hashCode(point);
    }

    /**
     * Renders the point as {@code {lat:<p0>,lng:<p1>,value:<value>}}.
     * Requires at least two coordinates.
     */
    @Override
    public String toString() {
        final double[] coordinates = getPoint();
        // StringBuilder: the buffer never leaves this method, so no synchronization is needed.
        final StringBuilder buff = new StringBuilder("{");
        buff.append("lat:").append(coordinates[0]).append(",");
        buff.append("lng:").append(coordinates[1]).append(",");
        buff.append("value:").append(this.getValue());
        buff.append("}");
        return buff.toString();
    }
}

2.算法實現和測試:

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.apache.commons.math3.util.MathUtils;
import com.tongdatech.znzw.domain.upHandle.CustomerPoint;
/**
 * 
 * @author xjx
 *
 */
public class DBScanTest3{
    //歐式距離
    private final double distance;
    //最低要求的尋找鄰居數量
    private final int minPoints;
    
    private final Map<CustomerPoint, PointStatus> visited = new HashMap<CustomerPoint, PointStatus>();
    //點的標記,point:聚合內的點,noise:噪音點
    private enum PointStatus {
        NOISE,POINT
    }


    public DBScanTest3(final double distance, final int minPoints)
        throws Exception {
        if (distance < 0.0d) {
            throw new Exception("距離小於0");
        }
        if (minPoints < 0) {
            throw new Exception("點數小於0");
        }
        this.distance = distance;
        this.minPoints = minPoints;
    }

    public double getDistance() {
        return distance;
    }

    public int getMinPoints() {
        return minPoints;
    }
    
    public Map<CustomerPoint, PointStatus> getVisited() {
        return visited;
    }
    /**
     * 返回customerPoint的多個聚合
     * @param points
     * @return
     */
    public List<List<CustomerPoint>> cluster(List<CustomerPoint> points){

        final List<List<CustomerPoint>> clusters = new ArrayList<List<CustomerPoint>>();
                
        for (CustomerPoint point : points) {
        //若是已經被標記
if (visited.get(point) != null) { continue; } List<CustomerPoint> neighbors = getNeighbors(point, points); if (neighbors.size() >= minPoints) { visited.put(point, PointStatus.POINT); List<CustomerPoint> cluster = new ArrayList<CustomerPoint>();           //遍歷全部鄰居繼續拓展找點 clusters.add(expandCluster(cluster, point, neighbors, points, visited)); } else { visited.put(point, PointStatus.NOISE); } } return clusters; } private List<CustomerPoint> expandCluster( List<CustomerPoint> cluster, CustomerPoint point, List<CustomerPoint> neighbors, List<CustomerPoint> points, Map<CustomerPoint, PointStatus> visited) { cluster.add(point); visited.put(point, PointStatus.POINT); int index = 0; //遍歷 全部的鄰居 while (index < neighbors.size()) { //移動當前的點 CustomerPoint current = neighbors.get(index); PointStatus pStatus = visited.get(current); if (pStatus == null) { List<CustomerPoint> currentNeighbors = getNeighbors(current, points); neighbors.addAll(currentNeighbors); }
          //若是該點未被標記,將點進行標記並加入到集合中
if (pStatus != PointStatus.POINT) { visited.put(current, PointStatus.POINT); cluster.add(current); } index++; } return cluster; } //找到全部的鄰居 private List<CustomerPoint> getNeighbors(CustomerPoint point,List<CustomerPoint> points) { List<CustomerPoint> neighbors = new ArrayList<CustomerPoint>(); for (CustomerPoint neighbor : points) { if (visited.get(neighbor) != null) { continue; } if (point != neighbor && neighbor.distanceFrom(point) <= distance) { neighbors.add(neighbor); } } return neighbors; }
  //作數據進行測試
public static void main(String[] args) throws Exception { CustomerPoint customerPoint = new CustomerPoint(new double[] {3,8}); CustomerPoint customerPoint1 = new CustomerPoint(new double[] {4,7}); CustomerPoint customerPoint2 = new CustomerPoint(new double[] {4,8}); CustomerPoint customerPoint3 = new CustomerPoint(new double[] {5,6}); CustomerPoint customerPoint4 = new CustomerPoint(new double[] {3,9}); CustomerPoint customerPoint5 = new CustomerPoint(new double[] {5,1}); CustomerPoint customerPoint6 = new CustomerPoint(new double[] {5,2}); CustomerPoint customerPoint7 = new CustomerPoint(new double[] {6,3}); CustomerPoint customerPoint8 = new CustomerPoint(new double[] {7,3}); CustomerPoint customerPoint9 = new CustomerPoint(new double[] {7,4}); CustomerPoint customerPoint10 = new CustomerPoint(new double[] {0,2}); CustomerPoint customerPoint11 = new CustomerPoint(new double[] {8,16}); CustomerPoint customerPoint12 = new CustomerPoint(new double[] {1,1}); CustomerPoint customerPoint13 = new CustomerPoint(new double[] {1,3}); List<CustomerPoint> cs = new ArrayList<>(); cs.add(customerPoint13); cs.add(customerPoint12); cs.add(customerPoint11); cs.add(customerPoint10); cs.add(customerPoint9); cs.add(customerPoint8); cs.add(customerPoint7); cs.add(customerPoint6); cs.add(customerPoint5); cs.add(customerPoint4); cs.add(customerPoint3); cs.add(customerPoint2); cs.add(customerPoint1); cs.add(customerPoint);
    //這裏第一個參數爲距離,第二個參數爲最小鄰居數量 DBScanTest3 db
= new DBScanTest3(1.5, 1);
    //返回結果並打印 List
<List<CustomerPoint>> aa =db.cluster(cs); for(int i =0;i<aa.size();i++) { for(int j=0;j<aa.get(i).size();j++) { System.out.print(aa.get(i).get(j).toString()); } System.out.println(); } } }

結果打印:

{lat:1.0,lng:3.0,value:0}{lat:0.0,lng:2.0,value:0}{lat:1.0,lng:1.0,value:0}
{lat:7.0,lng:4.0,value:0}{lat:7.0,lng:3.0,value:0}{lat:6.0,lng:3.0,value:0}{lat:5.0,lng:2.0,value:0}{lat:5.0,lng:1.0,value:0}
{lat:3.0,lng:9.0,value:0}{lat:4.0,lng:8.0,value:0}{lat:3.0,lng:8.0,value:0}{lat:4.0,lng:7.0,value:0}{lat:5.0,lng:6.0,value:0}

這裏返回3個集合,其餘的爲噪音點。讀者可以將這些座標點畫在網格圖上,可以看到它們分爲3部分,每一部分中的每個點與同一部分內至少一個其他點的距離不超過1.5。

相關文章
相關標籤/搜索