【轉】JDK 7 中的 Fork/Join 模式

介紹html

隨着多核芯片逐漸成爲主流,大多數軟件開發人員不可避免地須要瞭解並行編程的知識。而同時,主流程序語言正在將愈來愈多的並行特性合併到標準庫或者語言自己之中。咱們能夠看到,JDK 在這方面一樣走在潮流的前方。在 JDK 標準版 5 中,由 Doug Lea 提供的並行框架成爲了標準庫的一部分(JSR-166)。隨後,在 JDK 6 中,一些新的並行特性,例如並行 collection 框架,合併到了標準庫中(JSR-166x)。直到今天,儘管 Java SE 7 尚未正式發佈,一些並行相關的新特性已經出如今 JSR-166y 中:java

  1. Fork/Join 模式;
  2. TransferQueue,它繼承自 BlockingQueue 並能在隊列滿時阻塞「生產者」;
  3. ArrayTasks/ListTasks,用於並行執行某些數組/列表相關任務的類;
  4. IntTasks/LongTasks/DoubleTasks,用於並行處理數字類型數組的工具類,提供了排序、查找、求和、求最小值、求最大值等功能;

其中,對 Fork/Join 模式的支持多是對開發並行軟件來講最通用的新特性。在 JSR-166y 中,Doug Lea 實現 ArrayTasks/ListTasks/IntTasks/LongTasks/DoubleTasks 時就大量的用到了 Fork/Join 模式。讀者還須要注意一點,由於 JDK 7 尚未正式發佈,所以本文涉及到的功能和發佈版本有可能不同。算法

Fork/Join 模式有本身的適用範圍。若是一個應用能被分解成多個子任務,而且組合多個子任務的結果就可以得到最終的答案,那麼這個應用就適合用 Fork/Join 模式來解決。圖 1 給出了一個 Fork/Join 模式的示意圖,位於圖上部的 Task 依賴於位於其下的 Task 的執行,只有當全部的子任務都完成以後,調用者才能得到 Task 0 的返回結果。編程


圖 1. Fork/Join 模式示意圖
圖 1. Fork/Join 模式示意圖  

能夠說,Fork/Join 模式可以解決不少種類的並行問題。經過使用 Doug Lea 提供的 Fork/Join 框架,軟件開發人員只須要關注任務的劃分和中間結果的組合就能充分利用並行平臺的優良性能。其餘和並行相關的諸多難於處理的問題,例如負載平衡、同步等,均可以由框架採用統一的方式解決。這樣,咱們就可以輕鬆地得到並行的好處而避免了並行編程的困難且容易出錯的缺點。數組

回頁首多線程

使用 Fork/Join 模式併發

在開始嘗試 Fork/Join 模式以前,咱們須要從 Doug Lea 主持的 Concurrency JSR-166 Interest Site 上下載 JSR-166y 的源代碼,而且咱們還須要安裝最新版本的 JDK 6(下載網址請參閱 參考資源)。Fork/Join 模式的使用方式很是直觀。首先,咱們須要編寫一個 ForkJoinTask 來完成子任務的分割、中間結果的合併等工做。隨後,咱們將這個 ForkJoinTask 交給 ForkJoinPool 來完成應用的執行。框架

一般咱們並不直接繼承 ForkJoinTask,它包含了太多的抽象方法。針對特定的問題,咱們能夠選擇 ForkJoinTask 的不一樣子類來完成任務。RecursiveAction 是 ForkJoinTask 的一個子類,它表明了一類最簡單的 ForkJoinTask:不須要返回值,當子任務都執行完畢以後,不須要進行中間結果的組合。若是咱們從 RecursiveAction 開始繼承,那麼咱們只須要重載 protected void compute() 方法。下面,咱們來看看怎麼爲快速排序算法創建一個 ForkJoinTask 的子類:dom


清單 1. ForkJoinTask 的子類
class SortTask extends RecursiveAction {
    final long[] array;
    final int lo;
    final int hi;
    private int THRESHOLD = 30;

    public SortTask(long[] array) {
        this.array = array;
        this.lo = 0;
        this.hi = array.length - 1;
    }

    public SortTask(long[] array, int lo, int hi) {
        this.array = array;
        this.lo = lo;
        this.hi = hi;
    }

    protected void compute() {
        if (hi - lo < THRESHOLD)
            sequentiallySort(array, lo, hi);
        else {
            int pivot = partition(array, lo, hi);
            coInvoke(new SortTask(array, lo, pivot - 1), new SortTask(array,
                pivot + 1, hi));
        }
    }

    private int partition(long[] array, int lo, int hi) {
        long x = array[hi];
        int i = lo - 1;
        for (int j = lo; j < hi; j++) {
            if (array[j] <= x) {
                i++;
                swap(array, i, j);
            }
        }
        swap(array, i + 1, hi);
        return i + 1;
    }

    private void swap(long[] array, int i, int j) {
        if (i != j) {
            long temp = array[i];
            array[i] = array[j];
            array[j] = temp;
        }
    }

    private void sequentiallySort(long[] array, int lo, int hi) {
        Arrays.sort(array, lo, hi + 1);
    }
}

在 清單 1 中,SortTask 首先經過 partition() 方法將數組分紅兩個部分。隨後,兩個子任務將被生成並分別排序數組的兩個部分。當子任務足夠小時,再將其分割爲更小的任務反而引發性能的下降。所以,這裏咱們使用一個 THRESHOLD,限定在子任務規模較小時,使用直接排序,而不是再將其分割成爲更小的任務。其中,咱們用到了 RecursiveAction 提供的方法 coInvoke()。它表示:啓動全部的任務,並在全部任務都正常結束後返回。若是其中一個任務出現異常,則其它全部的任務都取消。coInvoke() 的參數還能夠是任務的數組。異步

如今剩下的工做就是將 SortTask 提交到 ForkJoinPool 了。ForkJoinPool() 默認創建具備與 CPU 可以使用線程數相等線程個數的線程池。咱們在一個 JUnit 的 test 方法中將 SortTask 提交給一個新建的 ForkJoinPool:


清單 2. 新建的 ForkJoinPool
@Test
public void testSort() throws Exception {
    ForkJoinTask sort = new SortTask(array);
    ForkJoinPool fjpool = new ForkJoinPool();
    fjpool.submit(sort);
    fjpool.shutdown();

    fjpool.awaitTermination(30, TimeUnit.SECONDS);

    assertTrue(checkSorted(array));
}

在上面的代碼中,咱們用到了 ForkJoinPool 提供的以下函數:

  1. submit():將 ForkJoinTask 類的對象提交給 ForkJoinPool,ForkJoinPool 將馬上開始執行 ForkJoinTask。
  2. shutdown():執行此方法以後,ForkJoinPool 再也不接受新的任務,可是已經提交的任務能夠繼續執行。若是但願馬上中止全部的任務,能夠嘗試 shutdownNow() 方法。
  3. awaitTermination():阻塞當前線程直到 ForkJoinPool 中全部的任務都執行結束。

並行快速排序的完整代碼以下所示:


清單 3. 並行快速排序的完整代碼
package tests;

import static org.junit.Assert.*;

import java.util.Arrays;
import java.util.Random;
import java.util.concurrent.TimeUnit;

import jsr166y.forkjoin.ForkJoinPool;
import jsr166y.forkjoin.ForkJoinTask;
import jsr166y.forkjoin.RecursiveAction;

import org.junit.Before;
import org.junit.Test;

class SortTask extends RecursiveAction {
    final long[] array;
    final int lo;
    final int hi;
    private int THRESHOLD = 0; //For demo only

    public SortTask(long[] array) {
        this.array = array;
        this.lo = 0;
        this.hi = array.length - 1;
    }

    public SortTask(long[] array, int lo, int hi) {
        this.array = array;
        this.lo = lo;
        this.hi = hi;
    }

    protected void compute() {
        if (hi - lo < THRESHOLD)
            sequentiallySort(array, lo, hi);
        else {
            int pivot = partition(array, lo, hi);
            System.out.println("\npivot = " + pivot + ", low = " + lo + ", high = " + hi);
			System.out.println("array" + Arrays.toString(array));
            coInvoke(new SortTask(array, lo, pivot - 1), new SortTask(array,
                    pivot + 1, hi));
        }
    }

    private int partition(long[] array, int lo, int hi) {
        long x = array[hi];
        int i = lo - 1;
        for (int j = lo; j < hi; j++) {
            if (array[j] <= x) {
                i++;
                swap(array, i, j);
            }
        }
        swap(array, i + 1, hi);
        return i + 1;
    }

    private void swap(long[] array, int i, int j) {
        if (i != j) {
            long temp = array[i];
            array[i] = array[j];
            array[j] = temp;
        }
    }

    private void sequentiallySort(long[] array, int lo, int hi) {
        Arrays.sort(array, lo, hi + 1);
    }
}

public class TestForkJoinSimple {
    private static final int NARRAY = 16; //For demo only
    long[] array = new long[NARRAY];
    Random rand = new Random();

    @Before
    public void setUp() {
        for (int i = 0; i < array.length; i++) {
            array[i] = rand.nextLong()%100; //For demo only
        }
        System.out.println("Initial Array: " + Arrays.toString(array));
    }

    @Test
    public void testSort() throws Exception {
        ForkJoinTask sort = new SortTask(array);
        ForkJoinPool fjpool = new ForkJoinPool();
        fjpool.submit(sort);
        fjpool.shutdown();

        fjpool.awaitTermination(30, TimeUnit.SECONDS);

        assertTrue(checkSorted(array));
    }

    boolean checkSorted(long[] a) {
        for (int i = 0; i < a.length - 1; i++) {
            if (a[i] > (a[i + 1])) {
                return false;
            }
        }
        return true;
    }
}

運行以上代碼,咱們能夠獲得如下結果:

Initial Array: [46, -12, 74, -67, 76, -13, -91, -96]

pivot = 0, low = 0, high = 7
array[-96, -12, 74, -67, 76, -13, -91, 46]

pivot = 5, low = 1, high = 7
array[-96, -12, -67, -13, -91, 46, 76, 74]

pivot = 1, low = 1, high = 4
array[-96, -91, -67, -13, -12, 46, 74, 76]

pivot = 4, low = 2, high = 4
array[-96, -91, -67, -13, -12, 46, 74, 76]

pivot = 3, low = 2, high = 3
array[-96, -91, -67, -13, -12, 46, 74, 76]

pivot = 2, low = 2, high = 2
array[-96, -91, -67, -13, -12, 46, 74, 76]

pivot = 6, low = 6, high = 7
array[-96, -91, -67, -13, -12, 46, 74, 76]

pivot = 7, low = 7, high = 7
array[-96, -91, -67, -13, -12, 46, 74, 76]

回頁首

Fork/Join 模式高級特性

使用 RecursiveTask

除了 RecursiveAction,Fork/Join 框架還提供了其餘 ForkJoinTask 子類:帶有返回值的 RecursiveTask,使用 finish() 方法顯式停止的 AsyncAction 和 LinkedAsyncAction,以及可以使用 TaskBarrier 爲每一個任務設置不一樣停止條件的 CyclicAction。

從 RecursiveTask 繼承的子類一樣須要重載 protected void compute() 方法。與 RecursiveAction 稍有不一樣的是,它可以使用泛型指定一個返回值的類型。下面,咱們來看看如何使用 RecursiveTask 的子類。


清單 4. RecursiveTask 的子類
class Fibonacci extends RecursiveTask<Integer> {
    final int n;

    Fibonacci(int n) {
        this.n = n;
    }

    private int compute(int small) {
        final int[] results = { 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89 };
        return results[small];
    }

    public Integer compute() {
        if (n <= 10) {
            return compute(n);
        }
        Fibonacci f1 = new Fibonacci(n - 1);
        Fibonacci f2 = new Fibonacci(n - 2);
        f1.fork();
        f2.fork();
        return f1.join() + f2.join();
    }
}

在 清單 4 中, Fibonacci 的返回值爲 Integer 類型。其 compute() 函數首先創建兩個子任務,啓動子任務執行,阻塞以等待子任務的結果返回,相加後獲得最終結果。一樣,當子任務足夠小時,經過查表獲得其結果,以減少因過多地分割任務引發的性能下降。其中,咱們用到了 RecursiveTask 提供的方法 fork() 和 join()。它們分別表示:子任務的異步執行和阻塞等待結果完成。

如今剩下的工做就是將 Fibonacci 提交到 ForkJoinPool 了,咱們在一個 JUnit 的 test 方法中做了以下處理:


清單 5. 將 Fibonacci 提交到 ForkJoinPool
@Test
public void testFibonacci() throws InterruptedException, ExecutionException {
    ForkJoinTask<Integer> fjt = new Fibonacci(45);
    ForkJoinPool fjpool = new ForkJoinPool();
    Future<Integer> result = fjpool.submit(fjt);

    // do something
    System.out.println(result.get());
}

使用 CyclicAction 來處理循環任務

CyclicAction 的用法稍微複雜一些。若是一個複雜任務須要幾個線程協做完成,而且線程之間須要在某個點等待全部其餘線程到達,那麼咱們就能方便的用 CyclicAction 和 TaskBarrier 來完成。圖 2 描述了使用 CyclicAction 和 TaskBarrier 的一個典型場景。


圖 2. 使用 CyclicAction 和 TaskBarrier 執行多線程任務
圖 2. 使用 CyclicAction 和 TaskBarrier 執行多線程任務  

繼承自 CyclicAction 的子類須要 TaskBarrier 爲每一個任務設置不一樣的停止條件。從 CyclicAction 繼承的子類須要重載 protected void compute() 方法,定義在 barrier 的每一個步驟須要執行的動做。compute() 方法將被反覆執行直到 barrier 的 isTerminated() 方法返回 True。TaskBarrier 的行爲相似於 CyclicBarrier。下面,咱們來看看如何使用 CyclicAction 的子類。


清單 6. 使用 CyclicAction 的子類
class ConcurrentPrint extends RecursiveAction {
    protected void compute() {
        TaskBarrier b = new TaskBarrier() {
            protected boolean terminate(int cycle, int registeredParties) {
                System.out.println("Cycle is " + cycle + ";"
                        + registeredParties + " parties");
                return cycle >= 10;
            }
        };
        int n = 3;
        CyclicAction[] actions = new CyclicAction[n];
        for (int i = 0; i < n; ++i) {
            final int index = i;
            actions[i] = new CyclicAction(b) {
                protected void compute() {
                    System.out.println("I'm working " + getCycle() + " "
                            + index);
                    try {
                        Thread.sleep(500);
                    } catch (InterruptedException e) {
                        e.printStackTrace();
                    }
                }
            };
        }
        for (int i = 0; i < n; ++i)
            actions[i].fork();
        for (int i = 0; i < n; ++i)
            actions[i].join();
    }
}

在 清單 6 中,CyclicAction[] 數組創建了三個任務,打印各自的工做次數和序號。而在 b.terminate() 方法中,咱們設置的停止條件表示重複 10 次計算後停止。如今剩下的工做就是將 ConcurrentPrint 提交到 ForkJoinPool 了。咱們能夠在 ForkJoinPool 的構造函數中指定須要的線程數目,例如 ForkJoinPool(4) 就代表線程池包含 4 個線程。咱們在一個 JUnit 的 test 方法中運行 ConcurrentPrint 的這個循環任務:


清單 7. 運行 ConcurrentPrint 循環任務
@Test
public void testBarrier () throws InterruptedException, ExecutionException {
    ForkJoinTask fjt = new ConcurrentPrint();
    ForkJoinPool fjpool = new ForkJoinPool(4);
    fjpool.submit(fjt);
    fjpool.shutdown();
}

RecursiveTask 和 CyclicAction 兩個例子的完整代碼以下所示:


清單 8. RecursiveTask 和 CyclicAction 兩個例子的完整代碼
package tests;

import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;

import jsr166y.forkjoin.CyclicAction;
import jsr166y.forkjoin.ForkJoinPool;
import jsr166y.forkjoin.ForkJoinTask;
import jsr166y.forkjoin.RecursiveAction;
import jsr166y.forkjoin.RecursiveTask;
import jsr166y.forkjoin.TaskBarrier;

import org.junit.Test;

class Fibonacci extends RecursiveTask<Integer> {
    final int n;

    Fibonacci(int n) {
        this.n = n;
    }

    private int compute(int small) {
        final int[] results = { 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89 };
        return results[small];
    }

    public Integer compute() {
        if (n <= 10) {
            return compute(n);
        }
        Fibonacci f1 = new Fibonacci(n - 1);
        Fibonacci f2 = new Fibonacci(n - 2);
        System.out.println("fork new thread for " + (n - 1));
        f1.fork();
        System.out.println("fork new thread for " + (n - 2));
        f2.fork();
        return f1.join() + f2.join();
    }
}

class ConcurrentPrint extends RecursiveAction {
    protected void compute() {
        TaskBarrier b = new TaskBarrier() {
            protected boolean terminate(int cycle, int registeredParties) {
                System.out.println("Cycle is " + cycle + ";"
                        + registeredParties + " parties");
                return cycle >= 10;
            }
        };
        int n = 3;
        CyclicAction[] actions = new CyclicAction[n];
        for (int i = 0; i < n; ++i) {
            final int index = i;
            actions[i] = new CyclicAction(b) {
                protected void compute() {
                    System.out.println("I'm working " + getCycle() + " "
                            + index);
                    try {
                        Thread.sleep(500);
                    } catch (InterruptedException e) {
                        e.printStackTrace();
                    }
                }
            };
        }
        for (int i = 0; i < n; ++i)
            actions[i].fork();
        for (int i = 0; i < n; ++i)
            actions[i].join();
    }
}

public class TestForkJoin {
    @Test
    public void testBarrier () throws InterruptedException, ExecutionException {
		System.out.println("\ntesting Task Barrier ...");
        ForkJoinTask fjt = new ConcurrentPrint();
        ForkJoinPool fjpool = new ForkJoinPool(4);
        fjpool.submit(fjt);
        fjpool.shutdown();
    }

    @Test
    public void testFibonacci () throws InterruptedException, ExecutionException {
    	System.out.println("\ntesting Fibonacci ...");
		final int num = 14; //For demo only
        ForkJoinTask<Integer> fjt = new Fibonacci(num);
        ForkJoinPool fjpool = new ForkJoinPool();
        Future<Integer> result = fjpool.submit(fjt);

        // do something
        System.out.println("Fibonacci(" + num + ") = " + result.get());
    }
}

運行以上代碼,咱們能夠獲得如下結果:

testing Task Barrier ...
I'm working 0 2
I'm working 0 0
I'm working 0 1
Cycle is 0; 3 parties
I'm working 1 2
I'm working 1 0
I'm working 1 1
Cycle is 1; 3 parties
I'm working 2 0
I'm working 2 1
I'm working 2 2
Cycle is 2; 3 parties
I'm working 3 0
I'm working 3 2
I'm working 3 1
Cycle is 3; 3 parties
I'm working 4 2
I'm working 4 0
I'm working 4 1
Cycle is 4; 3 parties
I'm working 5 1
I'm working 5 0
I'm working 5 2
Cycle is 5; 3 parties
I'm working 6 0
I'm working 6 2
I'm working 6 1
Cycle is 6; 3 parties
I'm working 7 2
I'm working 7 0
I'm working 7 1
Cycle is 7; 3 parties
I'm working 8 1
I'm working 8 0
I'm working 8 2
Cycle is 8; 3 parties
I'm working 9 0
I'm working 9 2

testing Fibonacci ...
fork new thread for 13
fork new thread for 12
fork new thread for 11
fork new thread for 10
fork new thread for 12
fork new thread for 11
fork new thread for 10
fork new thread for 9
fork new thread for 10
fork new thread for 9
fork new thread for 11
fork new thread for 10
fork new thread for 10
fork new thread for 9
Fibonacci(14) = 610

回頁首

結論

從以上的例子中能夠看到,經過使用 Fork/Join 模式,軟件開發人員可以方便地利用多核平臺的計算能力。儘管尚未作到對軟件開發人員徹底透明,Fork/Join 模式已經極大地簡化了編寫併發程序的瑣碎工做。對於符合 Fork/Join 模式的應用,軟件開發人員再也不須要處理各類並行相關事務,例如同步、通訊等,以難以調試而聞名的死鎖和 data race 等錯誤也就不會出現,提高了思考問題的層次。你能夠把 Fork/Join 模式看做並行版本的 Divide and Conquer 策略,僅僅關注如何劃分任務和組合中間結果,將剩下的事情丟給 Fork/Join 框架。

在實際工做中利用 Fork/Join 模式,能夠充分享受多核平臺爲應用帶來的免費午飯。

訪問 Doug Lea 的 JSR 166 站點得到最新的源代碼。

相關文章
相關標籤/搜索