(手機橫屏看源碼更方便)java
注:java源碼分析部分如無特殊說明均基於 java8 版本。程序員
注:本文基於ForkJoinPool分治線程池類。面試
隨着在硬件上多核處理器的發展和普遍使用,併發編程成爲程序員必須掌握的一門技術,在面試中也常常考查面試者併發相關的知識。算法
今天,咱們就來看一道面試題:編程
如何充分利用多核CPU,計算很大數組中全部整數的和?segmentfault
咱們最容易想到就是單線程相加,一個for循環搞定。數組
若是進一步優化,咱們會天然而然地想到使用線程池來分段相加,最後再把每一個段的結果相加。併發
Yes,就是咱們今天的主角——ForkJoinPool,可是它要怎麼實現呢?彷佛沒怎麼用過哈^^框架
OK,剖析完了,咱們直接來看三種實現,不墨跡,直接上菜。dom
/** * 計算1億個整數的和 */ public class ForkJoinPoolTest01 { public static void main(String[] args) throws ExecutionException, InterruptedException { // 構造數據 int length = 100000000; long[] arr = new long[length]; for (int i = 0; i < length; i++) { arr[i] = ThreadLocalRandom.current().nextInt(Integer.MAX_VALUE); } // 單線程 singleThreadSum(arr); // ThreadPoolExecutor線程池 multiThreadSum(arr); // ForkJoinPool線程池 forkJoinSum(arr); } private static void singleThreadSum(long[] arr) { long start = System.currentTimeMillis(); long sum = 0; for (int i = 0; i < arr.length; i++) { // 模擬耗時,本文由公從號「彤哥讀源碼」原創 sum += (arr[i]/3*3/3*3/3*3/3*3/3*3); } System.out.println("sum: " + sum); System.out.println("single thread elapse: " + (System.currentTimeMillis() - start)); } private static void multiThreadSum(long[] arr) throws ExecutionException, InterruptedException { long start = System.currentTimeMillis(); int count = 8; ExecutorService threadPool = Executors.newFixedThreadPool(count); List<Future<Long>> list = new ArrayList<>(); for (int i = 0; i < count; i++) { int num = i; // 分段提交任務 Future<Long> future = threadPool.submit(() -> { long sum = 0; for (int j = arr.length / count * num; j < (arr.length / count * (num + 1)); j++) { try { // 模擬耗時 sum += (arr[j]/3*3/3*3/3*3/3*3/3*3); } catch (Exception e) { e.printStackTrace(); } } return sum; }); list.add(future); } // 每一個段結果相加 long sum = 0; for (Future<Long> future : list) { sum += future.get(); } System.out.println("sum: " + sum); System.out.println("multi thread elapse: " + (System.currentTimeMillis() - start)); } private static void forkJoinSum(long[] arr) throws ExecutionException, InterruptedException { long start = System.currentTimeMillis(); ForkJoinPool forkJoinPool = ForkJoinPool.commonPool(); // 提交任務 ForkJoinTask<Long> forkJoinTask = forkJoinPool.submit(new SumTask(arr, 0, arr.length)); // 獲取結果 Long sum = forkJoinTask.get(); forkJoinPool.shutdown(); System.out.println("sum: " + sum); System.out.println("fork join elapse: " + (System.currentTimeMillis() - start)); } private static class SumTask extends RecursiveTask<Long> { private long[] arr; private int from; private int to; public SumTask(long[] arr, int from, int to) { this.arr = arr; this.from = from; this.to = to; } @Override protected Long compute() { // 小於1000的時候直接相加,可靈活調整 if (to - from <= 1000) { long sum = 0; for (int i = from; i < to; i++) { // 模擬耗時 sum += (arr[i]/3*3/3*3/3*3/3*3/3*3); } return sum; } // 分紅兩段任務,本文由公從號「彤哥讀源碼」原創 int middle = (from + to) / 2; SumTask left = new SumTask(arr, from, middle); SumTask right = new SumTask(arr, middle, to); // 提交左邊的任務 left.fork(); // 右邊的任務直接利用當前線程計算,節約開銷 Long rightResult = right.compute(); // 等待左邊計算完畢 Long leftResult = left.join(); // 返回結果 return leftResult + rightResult; } } }
彤哥偷偷地告訴你,實際上計算1億個整數相加,單線程是最快的,個人電腦大概是100ms左右,使用線程池反而會變慢。
因此,爲了演示ForkJoinPool的牛逼之處,我把每一個數都/3*3/3*3/3*3/3*3/3*3
了一頓操做,用來模擬計算耗時。
來看結果:
sum: 107352457433800662 single thread elapse: 789 sum: 107352457433800662 multi thread elapse: 228 sum: 107352457433800662 fork join elapse: 189
能夠看到,ForkJoinPool相對普通線程池仍是有很大提高的。
問題:普通線程池可否實現ForkJoinPool這種計算方式呢,即大任務拆中任務,中任務拆小任務,最後再彙總?
你能夠試試看(-᷅_-᷄)
OK,下面咱們正式進入ForkJoinPool的解析。
把一個規模大的問題劃分爲規模較小的子問題,而後分而治之,最後合併子問題的解獲得原問題的解。
(1)分割原問題:
(2)求解子問題:
(3)合併子問題的解爲原問題的解。
在分治法中,子問題通常是相互獨立的,所以,常常經過遞歸調用算法來求解子問題。
(1)二分搜索
(2)大整數乘法
(3)Strassen矩陣乘法
(4)棋盤覆蓋
(5)歸併排序
(6)快速排序
(7)線性時間選擇
(8)漢諾塔
ForkJoinPool是 java 7 中新增的線程池類,它的繼承體系以下:
ForkJoinPool和ThreadPoolExecutor都是繼承自AbstractExecutorService抽象類,因此它和ThreadPoolExecutor的使用幾乎沒有多少區別,除了任務變成了ForkJoinTask之外。
這裏又運用到了一種很重要的設計原則——開閉原則——對修改關閉,對擴展開放。
可見整個線程池體系一開始的接口設計就很好,新增一個線程池類,不會對原有的代碼形成干擾,還能利用原有的特性。
fork()方法相似於線程的Thread.start()方法,可是它不是真的啓動一個線程,而是將任務放入到工做隊列中。
join()方法相似於線程的Thread.join()方法,可是它不是簡單地阻塞線程,而是利用工做線程運行其它任務。當一個工做線程中調用了join()方法,它將處理其它任務,直到注意到目標子任務已經完成了。
無返回值任務。
有返回值任務。
無返回值任務,完成任務後能夠觸發回調。
ForkJoinPool內部使用的是「工做竊取」算法實現的。
(1)每一個工做線程都有本身的工做隊列WorkQueue;
(2)這是一個雙端隊列,它是線程私有的;
(3)ForkJoinTask中fork的子任務,將放入運行該任務的工做線程的隊頭,工做線程將以LIFO的順序來處理工做隊列中的任務;
(4)爲了最大化地利用CPU,空閒的線程將從其它線程的隊列中「竊取」任務來執行;
(5)從工做隊列的尾部竊取任務,以減小競爭;
(6)雙端隊列的操做:push()/pop()僅在其全部者工做線程中調用,poll()是由其它線程竊取任務時調用的;
(7)當只剩下最後一個任務時,仍是會存在競爭,是經過CAS來實現的;
(1)最適合的是計算密集型任務,本文由公從號「彤哥讀源碼」原創;
(2)在須要阻塞工做線程時,可使用ManagedBlocker;
(3)不該該在RecursiveTask<R>的內部使用ForkJoinPool.invoke()/invokeAll();
(1)ForkJoinPool特別適合於「分而治之」算法的實現;
(2)ForkJoinPool和ThreadPoolExecutor是互補的,不是誰替代誰的關係,兩者適用的場景不一樣;
(3)ForkJoinTask有兩個核心方法——fork()和join(),有三個重要子類——RecursiveAction、RecursiveTask和CountedCompleter;
(4)ForkjoinPool內部基於「工做竊取」算法實現;
(5)每一個線程有本身的工做隊列,它是一個雙端隊列,本身從隊列頭存取任務,其它線程從尾部竊取任務;
(6)ForkJoinPool最適合於計算密集型任務,但也可使用ManagedBlocker以便用於阻塞型任務;
(7)RecursiveTask內部能夠少調用一次fork(),利用當前線程處理,這是一種技巧;
ManagedBlocker怎麼使用?
答:ManagedBlocker至關於明確告訴ForkJoinPool框架要阻塞了,ForkJoinPool就會啓另外一個線程來運行任務,以最大化地利用CPU。
請看下面的例子,本身琢磨哈^^。
/** * 斐波那契數列 * 一個數是它前面兩個數之和 * 1,1,2,3,5,8,13,21 */ public class Fibonacci { public static void main(String[] args) { long time = System.currentTimeMillis(); Fibonacci fib = new Fibonacci(); int result = fib.f(1_000).bitCount(); time = System.currentTimeMillis() - time; System.out.println("result,本文由公從號「彤哥讀源碼」原創 = " + result); System.out.println("test1_000() time = " + time); } public BigInteger f(int n) { Map<Integer, BigInteger> cache = new ConcurrentHashMap<>(); cache.put(0, BigInteger.ZERO); cache.put(1, BigInteger.ONE); return f(n, cache); } private final BigInteger RESERVED = BigInteger.valueOf(-1000); public BigInteger f(int n, Map<Integer, BigInteger> cache) { BigInteger result = cache.putIfAbsent(n, RESERVED); if (result == null) { int half = (n + 1) / 2; RecursiveTask<BigInteger> f0_task = new RecursiveTask<BigInteger>() { @Override protected BigInteger compute() { return f(half - 1, cache); } }; f0_task.fork(); BigInteger f1 = f(half, cache); BigInteger f0 = f0_task.join(); long time = n > 10_000 ? System.currentTimeMillis() : 0; try { if (n % 2 == 1) { result = f0.multiply(f0).add(f1.multiply(f1)); } else { result = f0.shiftLeft(1).add(f1).multiply(f1); } synchronized (RESERVED) { cache.put(n, result); RESERVED.notifyAll(); } } finally { time = n > 10_000 ? System.currentTimeMillis() - time : 0; if (time > 50) System.out.printf("f(%d) took %d%n", n, time); } } else if (result == RESERVED) { try { ReservedFibonacciBlocker blocker = new ReservedFibonacciBlocker(n, cache); ForkJoinPool.managedBlock(blocker); result = blocker.result; } catch (InterruptedException e) { throw new CancellationException("interrupted"); } } return result; // return f(n - 1).add(f(n - 2)); } private class ReservedFibonacciBlocker implements ForkJoinPool.ManagedBlocker { private BigInteger result; private final int n; private final Map<Integer, BigInteger> cache; public ReservedFibonacciBlocker(int n, Map<Integer, BigInteger> cache) { this.n = n; this.cache = cache; } @Override public boolean block() throws InterruptedException { synchronized (RESERVED) { while (!isReleasable()) { RESERVED.wait(); } } return true; } @Override public boolean isReleasable() { return (result = cache.get(n)) != RESERVED; } } }
歡迎關注個人公衆號「彤哥讀源碼」,查看更多源碼系列文章, 與彤哥一塊兒暢遊源碼的海洋。