使用模擬退火算法優化 Hash 函數

背景

現有個處理股票行情消息的系統,其架構以下:算法

因爲數據量巨大,系統中啓動了 15 個線程來消費行情消息。消息分配的策略較爲簡單:對 symbol 的 hashCode 取模,將消息分配給其中一個線程進行處理。 通過驗證,每一個線程分配到的 symbol 數量較爲均勻,因而系統愉快地上線了。編程

運行一段時間後,忽然收到了系統的告警,但此時並不是消息峯值時間段。通過排查後,發現問題出如今 hash 函數上:數組

雖然每一個線程被分配到的 symbol 數量較爲均衡,可是部分熱門 symbol 的報價消息量會更多,若是熱門 symbol 集中到特定線程上,就會形成線程負載不均衡,使得系統總體的吞吐量大打折扣。架構

爲提升系統的吞吐量,有必要消息分發邏輯進行一些改造,避免出現熱點線程。爲此,系統須要記錄下某天內每一個 symbol 的消息量,而後在次日使用這些數據,對分發邏輯進行調整。具體的改造的方案能夠分爲兩種:併發

  • 放棄使用 hash 函數
  • 對 hash 函數進行優化

放棄 hash 函數

問題能夠抽象爲:框架

將 5000 個非負整數分配至 15 個桶(bucket)中,並儘量保證每一個桶中的元素之和接近(每一個桶中的元素個數無限制)。less

每一個整數元素可能的放置方法有 15 種,這個問題總共可能的解有 155000種,暴力求解的可能性微乎其微。做爲工程問題,最優解不是必要的,能夠退而求其次尋找一個可接受的次優解:dom

  1. 根據全部 symbol 的消息總數計算一個指望的分佈均值(expectation)
  2. 將每一個 symbol 的消息數按照 symbol 的順序進行排列,最後將這組數組劃分爲 15 個區間,而且儘量使得每一個區間元素之和與 expection 接近。
  3. 使用一個有序查找表記錄每一個區間的首個 symbol,後續就能夠按照這個表對數據進行劃分。
public class FindBestDistribution {

    static final int NUM_OF_SYMBOLS = 5000;
    static final int NUM_OF_BUCKETS = 15;

    public static void main(String[] args) {
        // 生成樣本
        IntStream ints = ThreadLocalRandom.current().ints(0, 1000);
        PrimitiveIterator.OfInt iterator = ints.iterator();
        
        Map<String,Integer> symbolAndCount = new TreeMap<>();
        for (int i=0; i<NUM_OF_SYMBOLS; i++) {
            symbolAndCount.put(Integer.toHexString(i).toUpperCase(), iterator.next());
        }

        // 按照 symbol 劃分每一個桶的數量
        TreeMap<String, Integer> distribution = findBestDistribution(symbolAndCount);

        // 測試效果
        int[] buckets = new int[NUM_OF_BUCKETS];
        for (Map.Entry<String, Integer> entry : symbolAndCount.entrySet()) {
            Map.Entry<String, Integer> floor = distribution.floorEntry(entry.getKey());
            int bucketIndex = floor == null ? 0 : floor.getValue();
            buckets[bucketIndex] += entry.getValue();
        }

        System.out.printf("buckets: %s\n", Arrays.toString(buckets));
    }

    public static TreeMap<String, Integer> findBestDistribution(Map<String,Integer> symbolAndCount) {

        // 每一個桶均勻分佈的狀況(最優狀況)
        int avg = symbolAndCount.values().stream().mapToInt(Integer::intValue).sum() / NUM_OF_BUCKETS;

        // 嘗試將 symbol 放入不一樣的桶
        int bucketIdx = 0;
        int[] buckets = new int[NUM_OF_BUCKETS];
        String[] bulkheads = new String[NUM_OF_BUCKETS-1];
        for (Map.Entry<String, Integer> entry : symbolAndCount.entrySet()) {

            // 若是首個 symbol 數據量過大,則分配給其一個獨立的桶
            int count = entry.getValue();
            if (count / 2 > avg && bucketIdx == 0 && buckets[0] == 0) {
                buckets[bucketIdx] += count;
                continue;
            }

            // 評估將 symbol 放入桶後的效果
            // 1. 若是桶中的數量更接近指望,則將其放入當前桶中
            // 2. 若是桶中的數量更遠離指望,則將其放入下個桶中
            double before = Math.abs(buckets[bucketIdx] - avg);
            double after = Math.abs(buckets[bucketIdx] + count - avg);
            if (after > before && bucketIdx < buckets.length - 1) {
                bulkheads[bucketIdx++] = entry.getKey();
            }

            buckets[bucketIdx] += count;
        }

        System.out.printf("expectation: %d\n", avg);
        System.out.printf("bulkheads: %s\n", Arrays.toString(bulkheads));

        TreeMap<String,Integer> distribution = new TreeMap<>();
        for (int i=0; i<bulkheads.length; i++) {
            distribution.put(bulkheads[i], i+1);
        }
        return distribution;
    }
}

該方法存在的問題:ide

  • 分配策略並非最優解,且沒法對其分片效果進行直觀的評估。
  • 當區間數量較多時,查找表自己可能成爲一個潛在的性能瓶頸。
  • 可能的組合受到 key 的順序限制,極大地限制了可能的解空間。

優化 hash 函數

換個角度來看,形成分佈不均勻的緣由不是數據,而是 hash 函數自己。函數

項目中使用的 hash 函數是 JDK String 中的原生實現。通過查閱資料,發現該實現實際上是 BKDRHash 的 seed = 31 的特殊狀況。這樣意味着:經過調整 seed 的值,能夠改變 hash 函數的特性並使其適配特定的數據分佈

int BKDRHash(char[] value, int seed) {
    int hash = 0;
    for (int i = 0; i < value.length; i++) {
        hash = hash * seed + value[i];
    }
    return hash & 0x7fffffff;
}

那麼問題來了,應該如何評估某個 seed 的分佈的優劣?

評價函數

一種可行的方法是計算每一個 seed 對應的 bucket 分佈的標準差,標準差越小則分佈越均勻,則該 seed 越優。

然而這一作法只考慮了每一個 bucket 與均值之間的偏差,沒法量化不一樣 bucket 之間的偏差。爲了可以直觀的量化 bucket 之間分佈差別的狀況,考慮使用下面的評估函數:

double calculateDivergence(long[] bucket, long expectation) {
    long divergence = 0;
    for (int i=0; i<bucket.length; i++) {
        final long a = bucket[i];
        final long b = (a - expectation) * (a - expectation);
        for (int j=i+1; j<bucket.length; j++) {
            long c = (a - bucket[j]) * (a - bucket[j]);
            divergence += Math.max(b, c);
        }
    }
    return divergence; // the less the better
}

該數值越小,則證實 seed 對應的分佈越均勻,其對應的 hash 函數越優。

訓練策略

seed 是一個 32bit 的無符號整數,其取值範圍爲 0 ~ 232-1。在 5000 個 symbol 的狀況下,單線程嘗試遍歷全部 seed 的時間約爲 25 小時。

一般狀況下 symbol 的數量會超過 5000,所以實際的搜索時間會大於這個值。此外,受限於計算資源限制,沒法進行大規模的並行搜索,所以窮舉法的耗時是不可接受的。

幸虧本例並不要求最優解,能夠引入啓發式搜索算法,加快訓練速度。因爲本人在這方面並不熟悉,爲了下降編程難度,最終選擇了模擬退火(simulated annealing)算法。它模擬固體退火過程的熱平衡問題與隨機搜索尋優問題的類似性來達到尋找全局最優或近似全局最優的目的。
相較於最簡單的登山法,模擬退火算法通以必定的機率接受較差的解,從而擴大搜索範圍,保證解近似最優。

/**
 * Basic framework of simulated annealing algorithm
 * @param <X> the solution of given problem
 */
public abstract class SimulatedAnnealing<X> {

    protected final int numberOfIterations;    // stopping condition for simulations

    protected final double coolingRate;        // the percentage by which we reduce the temperature of the system
    protected final double initialTemperature; // the starting energy of the system
    protected final double minimumTemperature; // optional stopping condition

    protected final long simulationTime;       // optional stopping condition
    protected final int detectionInterval;     // optional stopping condition

    protected SimulatedAnnealing(int numberOfIterations, double coolingRate) {
        this(numberOfIterations, coolingRate, 10000000, 1, 0, 0);
    }

    protected SimulatedAnnealing(int numberOfIterations, double coolingRate, double initialTemperature, double minimumTemperature, long simulationTime, int detectionInterval) {
        this.numberOfIterations = numberOfIterations;
        this.coolingRate = coolingRate;
        this.initialTemperature = initialTemperature;
        this.minimumTemperature = minimumTemperature;
        this.simulationTime = simulationTime;
        this.detectionInterval = detectionInterval;
    }

    protected abstract double score(X currentSolution);

    protected abstract X neighbourSolution(X currentSolution);

    public X simulateAnnealing(X currentSolution) {

        final long startTime = System.currentTimeMillis();

        // Initialize searching
        X bestSolution = currentSolution;
        double bestScore = score(bestSolution);
        double currentScore = bestScore;

        double t = initialTemperature;
        for (int i = 0; i < numberOfIterations; i++) {
            if (currentScore < bestScore) {
                // If the new solution is better, accept it unconditionally
                bestScore = currentScore;
                bestSolution = currentSolution;
            } else {
                // If the new solution is worse, calculate an acceptance probability for the worse solution
                // At high temperatures, the system is more likely to accept the solutions that are worse
                boolean rejectWorse = Math.exp((bestScore - currentScore) / t) < Math.random();
                if (rejectWorse || currentScore == bestScore) {
                    currentSolution = neighbourSolution(currentSolution);
                    currentScore = score(currentSolution);
                }
            }

            // Stop searching when the temperature is too low
            if ((t *= coolingRate) < minimumTemperature) {
                break;
            }

            // Stop searching when simulation time runs out
            if (simulationTime > 0 && (i+1) % detectionInterval == 0) {
                if (System.currentTimeMillis() - startTime > simulationTime)
                    break;
            }
        }

        return bestSolution;
    }
}
/**
 * Search best hash seed for given key distribution and number of buckets with simulated annealing algorithm
 */
@Data
public class SimulatedAnnealingHashing extends SimulatedAnnealing<HashingSolution> {

    private static final int DISTRIBUTION_BATCH = 100;
    static final int SEARCH_BATCH = 200;

    private final int[] hashCodes = new int[SEARCH_BATCH];
    private final long[][] buckets = new long[SEARCH_BATCH][];

    @Data
    public class HashingSolution {

        private final int begin, range; // the begin and range for searching
        private int bestSeed;     // the best seed found in this search
        private long bestScore;   // the score corresponding to bestSeed

        private long calculateDivergence(long[] bucket) {
            long divergence = 0;
            for (int i=0; i<bucket.length; i++) {
                final long a = bucket[i];
                final long b = (a - expectation) * (a - expectation);
                for (int j=i+1; j<bucket.length; j++) {
                    long c = (a - bucket[j]) * (a - bucket[j]);
                    divergence += Math.max(b, c);
                }
            }
            return divergence; // the less the better
        }

        private HashingSolution solve() {

            if (range != hashCodes.length) {
                throw new IllegalStateException();
            }

            for (int i=0; i<range; i++) {
                Arrays.fill(buckets[i], hashCodes[i] = 0);
            }

            for (KeyDistribution[] bucket : distributions) {
                for (KeyDistribution distribution : bucket) {
                    Hashing.BKDRHash(distribution.getKey(), begin, hashCodes);
                    for (int k = 0; k< hashCodes.length; k++) {
                        int n = hashCodes[k] % buckets[k].length;
                        buckets[k][n] += distribution.getCount();
                    }
                }
            }

            int best = -1;
            long bestScore = Integer.MAX_VALUE;
            for (int i = 0; i< buckets.length; i++) {
                long score = calculateDivergence(buckets[i]);
                if (i == 0 || score < bestScore) {
                    bestScore = score;
                    best = i;
                }
            }

            if (best < 0) {
                throw new IllegalStateException();
            }

            this.bestScore = bestScore;
            this.bestSeed = begin + best;
            return this;
        }

        @Override
        public String toString() {
            return String.format("(seed:%d, score:%d)", bestSeed, bestScore);
        }
    }

    private final KeyDistribution[][] distributions; // key and its count(2-dimensional array for better performance)
    private final long expectation;  // the expectation count of each bucket
    private final int searchOutset;
    private int searchMin, searchMax;

    /**
     * SimulatedAnnealingHashing Prototype
     * @param keyAndCounts keys for hashing and count for each key
     * @param numOfBuckets number of buckets
     */
    public SimulatedAnnealingHashing(Map<String, Integer> keyAndCounts, int numOfBuckets) {
        super(100000000, .9999);
        distributions = buildDistribution(keyAndCounts);
        long sum = 0;
        for (KeyDistribution[] batch : distributions) {
            for (KeyDistribution distribution : batch) {
                sum += distribution.getCount();
            }
        }
        this.expectation = sum / numOfBuckets;
        this.searchOutset = 0;
        for (int i = 0; i< buckets.length; i++) {
            buckets[i] = new long[numOfBuckets];
        }
    }

    /**
     * SimulatedAnnealingHashing Derivative
     * @param prototype prototype simulation
     * @param searchOutset the outset for searching
     * @param simulationTime the expect time consuming for simulation
     */
    private SimulatedAnnealingHashing(SimulatedAnnealingHashing prototype, int searchOutset, long simulationTime) {
        super(prototype.numberOfIterations, prototype.coolingRate, prototype.initialTemperature, prototype.minimumTemperature,
                simulationTime, 10000);
        distributions = prototype.distributions;
        expectation = prototype.expectation;
        for (int i = 0; i< buckets.length; i++) {
            buckets[i] = new long[prototype.buckets[i].length];
        }
        this.searchOutset = searchOutset;
        this.searchMax = searchMin = searchOutset;
    }

    @Override
    public String toString() {
        return String.format("expectation: %d, outset:%d, search(min:%d, max:%d)", expectation, searchOutset, searchMin, searchMax);
    }

    private KeyDistribution[][] buildDistribution(Map<String, Integer> symbolCounts) {
        int bucketNum = symbolCounts.size() / DISTRIBUTION_BATCH + Integer.signum(symbolCounts.size() % DISTRIBUTION_BATCH);
        KeyDistribution[][] distributions = new KeyDistribution[bucketNum][];

        int bucketIndex = 0;
        List<KeyDistribution> batch = new ArrayList<>(DISTRIBUTION_BATCH);
        for (Map.Entry<String, Integer> entry : symbolCounts.entrySet()) {
            batch.add(new KeyDistribution(entry.getKey().toCharArray(), entry.getValue()));
            if (batch.size() == DISTRIBUTION_BATCH) {
                distributions[bucketIndex++] = batch.toArray(new KeyDistribution[0]);
                batch.clear();
            }
        }
        if (batch.size() > 0) {
            distributions[bucketIndex] = batch.toArray(new KeyDistribution[0]);
            batch.clear();
        }
        return distributions;
    }

    @Override
    protected double score(HashingSolution currentSolution) {
        return currentSolution.solve().bestScore;
    }

    @Override
    protected HashingSolution neighbourSolution(HashingSolution currentSolution) {
        // The default range of neighbourhood is [-100, 100]
        int rand = ThreadLocalRandom.current().nextInt(-100, 101);
        int next = currentSolution.begin + rand;
        searchMin = Math.min(next, searchMin);
        searchMax = Math.max(next, searchMax);
        return new HashingSolution(next, currentSolution.range);
    }

    public HashingSolution solve() {
        searchMin = searchMax = searchOutset;
        HashingSolution initialSolution = new HashingSolution(searchOutset, SEARCH_BATCH);
        return simulateAnnealing(initialSolution);
    }

    public SimulatedAnnealingHashing derive(int searchOutset, long simulationTime) {
        return new SimulatedAnnealingHashing(this, searchOutset, simulationTime);
    }
}

ForkJoin 框架

爲了達到更好的搜索效果,能夠將整個搜索區域遞歸地劃分爲兩兩相鄰的區域,而後在這些區域上執行併發的搜索,並遞歸地合併相鄰區域的搜索結果。

使用 JDK 提供的 ForkJoinPool 與 RecursiveTask 能很好地完成以上任務。

@Data
@Slf4j
public class HashingSeedCalculator {

    /**
     * Recursive search task
     */
    private class HashingSeedCalculatorSearchTask extends RecursiveTask<HashingSolution> {

        private SimulatedAnnealingHashing simulation;
        private final int level;
        private final int center, range;

        private HashingSeedCalculatorSearchTask() {
            this.center = 0;
            this.range = Integer.MAX_VALUE / SimulatedAnnealingHashing.SEARCH_BATCH;
            this.level = traversalDepth;
            this.simulation = hashingSimulation;
        }

        private HashingSeedCalculatorSearchTask(HashingSeedCalculatorSearchTask parent, int center, int range) {
            this.center = center;
            this.range = range;
            this.level = parent.level - 1;
            this.simulation = parent.simulation;
        }

        @Override
        protected HashingSolution compute() {
            if (level == 0) {
                long actualCenter = center * SimulatedAnnealingHashing.SEARCH_BATCH;
                log.info("Searching around center {}", actualCenter);
                HashingSolution solution = simulation.derive(center, perShardRunningMills).solve();
                log.info("Searching around center {} found {}", actualCenter, solution);
                return solution;
            } else {
                int halfRange = range / 2;
                int leftCenter = center - halfRange, rightCenter = center + halfRange;
                ForkJoinTask<HashingSolution> leftTask = new HashingSeedCalculatorSearchTask(this, leftCenter, halfRange).fork();
                ForkJoinTask<HashingSolution> rightTask = new HashingSeedCalculatorSearchTask(this, rightCenter, halfRange).fork();
                HashingSolution left = leftTask.join();
                HashingSolution right = rightTask.join();
                return left.getBestScore() < right.getBestScore() ? left : right;
            }
        }
    }

    private final int poolParallelism;
    private final int traversalDepth;
    private final long perShardRunningMills;
    private final SimulatedAnnealingHashing hashingSimulation;

    /**
     * HashingSeedCalculator
     * @param numberOfShards the shard of the whole search range [Integer.MIN_VALUE, Integer.MAX_VALUE]
     * @param totalRunningHours the expect total time consuming for searching
     * @param symbolCounts the key and it`s distribution
     * @param numOfBuckets the number of buckets
     */
    public HashingSeedCalculator(int numberOfShards, int totalRunningHours, Map<String, Integer> symbolCounts, int numOfBuckets) {
        int n = (int) (Math.log(numberOfShards) / Math.log(2));
        if (Math.pow(2, n) != numberOfShards) {
            throw new IllegalArgumentException();
        }
        this.traversalDepth = n;
        this.poolParallelism = Math.max(ForkJoinPool.getCommonPoolParallelism() / 3 * 2, 1); // conservative estimation for parallelism
        this.perShardRunningMills = TimeUnit.HOURS.toMillis(totalRunningHours * poolParallelism) / numberOfShards;
        this.hashingSimulation = new SimulatedAnnealingHashing(symbolCounts, numOfBuckets);
    }

    @Override
    public String toString() {
        int numberOfShards = (int) Math.pow(2, traversalDepth);
        int totalRunningHours = (int) TimeUnit.MILLISECONDS.toHours(perShardRunningMills * numberOfShards) / poolParallelism;
        return "HashingSeedCalculator(" +
                "numberOfShards: " + numberOfShards +
                ", perShardRunningMinutes: " + TimeUnit.MILLISECONDS.toMinutes(perShardRunningMills) +
                ", totalRunningHours: " + totalRunningHours +
                ", poolParallelism: " + poolParallelism +
                ", traversalDepth: " + traversalDepth + ")";
    }

    public synchronized HashingSolution searchBestSeed() {
        long now = System.currentTimeMillis();
        log.info("SearchBestSeed start");
        ForkJoinTask<HashingSolution> root = new HashingSeedCalculatorSearchTask().fork();
        HashingSolution initSolution = hashingSimulation.derive(0, perShardRunningMills).solve();
        HashingSolution bestSolution = root.join();
        log.info("Found init solution {}", initSolution);
        log.info("Found best solution {}", bestSolution);
        if (initSolution.getBestScore() < bestSolution.getBestScore()) {
            bestSolution = initSolution;
        }
        long cost = System.currentTimeMillis() - now;
        log.info("SearchBestSeed finish (cost:{}ms)", cost);
        return bestSolution;
    }

}

效果

將改造後的代碼部署到測試環境後,某日訓練日誌:

12:49:15.227 85172866 INFO hash.HashingSeedCalculator - Found init solution (seed:15231, score:930685828341164)
12:49:15.227 85172866 INFO hash.HashingSeedCalculator - Found best solution (seed:362333, score:793386389726926)
12:49:15.227 85172866 INFO hash.HashingSeedCalculator - SearchBestSeed finish (cost:10154898ms)
12:49:15.227 85172866 INFO hash.TrainingService -
 
Training result: (seed:362333, score:793386389726926)
 
Buckets: 15
 
Expectation: 44045697
 
Result of Hashing.HashCode(seed=362333): 21327108 [42512742, 40479608, 43915771, 47211553, 45354264, 43209190, 43196570, 44725786, 41999747, 46450288, 46079231, 45116615, 44004021, 43896194, 42533877]
 
Result of Hashing.HashCode(seed=31): 66929172 [39723630, 48721463, 43365391, 46301448, 43931616, 44678194, 39064877, 45922454, 43171141, 40715060, 33964547, 49709090, 58869949, 34964729, 47581868]

當晚使用 BKDRHash(seed=31) 對新的交易日數據的進行分片:

04:00:59.001 partition messages per minute [45171, 68641, 62001, 80016, 55977, 61916, 55102, 49322, 55982, 57081, 51100, 70437, 135992, 37823, 58552] , messages total [39654953, 48666261, 43310578, 46146841, 43834832, 44577454, 38990331, 45871075, 43106710, 40600708, 33781629, 49752592, 58584246, 34928991, 47545369]

當晚使用 BKDRHash(seed=362333) 對新的交易日數據的進行分片:

04:00:59.001 partition messages per minute [62424, 82048, 64184, 47000, 57206, 69439, 64430, 60096, 46986, 58182, 54557, 41523, 64310, 72402, 100326] , messages total [44985772, 48329212, 39995385, 43675702, 45216341, 45524616, 41335804, 44917938, 44605376, 44054821, 43371892, 42068637, 44000817, 42617562, 44652695]

對比日誌發現 hash 通過優化後,分區的均勻程度有了顯著的上升,而且熱點分片也被消除了,基本達到當初設想的優化效果。

相關文章
相關標籤/搜索