fork/join框架是用多線程的方式實現分治法來解決問題。fork指的是將問題不斷地縮小規模,join是指根據子問題的計算結果,得出更高層次的結果。java
fork/join框架的使用有必定的約束條件:
算法
1. 除了fork() 和 join()方法外,線程不得使用其餘的同步工具。線程最好也不要sleep()數組
2. 線程不得進行I/O操做多線程
3. 線程不得拋出checked exception框架
此框架有幾個核心類:ForkJoinPool是實現了工做竊取算法的線程池。ForkJoinTask是任務類,他有2個子類:RecursiveAction無返回值,RecursiveTask有返回值,在定義本身的任務時,通常都是從這2類中挑一個,經過繼承的方式定義本身的新類。因爲ForkJoinTask類實現了Serializable接口,所以,定義本身的任務類時,應該定義serialVersionUID屬性。dom
在編寫任務時,推薦的寫法是這樣的:異步
[java] view plaincopyide
If (problem size > default size){ 工具
task s = divide(task); this
execute(tasks);
} else {
resolve problem using another algorithm;
}
ForkJoinPool實現了工做竊取算法(work-stealing),線程會主動尋找新建立的任務去執行,從而保證較高的線程利用率。它使用守護線程(deamon)來執行任務,所以無需對他顯示的調用shutdown()來關閉。通常狀況下,一個程序只須要惟一的一個ForkJoinPool,所以應該按以下方式建立它:
static final ForkJoinPool mainPool = new ForkJoinPool(); //線程的數目等於CPU的核心數
下面給出一個很是簡單的例子,功能是將一個數組中每個元素的值加1。具體實現爲:將大數組不斷分解爲更短小的子數組,當子數組長度不超過10的時候,對其中全部元素進行加1操做。
[java] view plaincopy
public class Test {
public final static ForkJoinPool mainPool = new ForkJoinPool();
public static void main(String[] args){
int n = 26;
int[] a = new int[n];
for(int i=0; i<n; i++) {
a[i] = i;
}
SubTask task = new SubTask(a, 0, n);
mainPool.invoke(task);
for(int i=0; i<n; i++) {
System.out.print(a[i]+" ");
}
}
}
class SubTask extends RecursiveAction {
private static final long serialVersionUID = 1L;
private int[] a;
private int beg;
private int end;
public SubTask(int[] a, int beg, int end) {
super();
this.a = a;
this.beg = beg;
this.end = end;
}
@Override
protected void compute() {
if(end-beg>10) {
int mid = (beg+end) / 2;
SubTask t1 = new SubTask(a, beg, mid);
SubTask t2 = new SubTask(a, mid, end);
invokeAll(t1, t2);
}else {
for(int i=beg; i<end; i++) {
a[i] = a[i] + 1;
}
}
}
}
例子2,任務擁有返回值。隨機生成一個數組,每一個元素均是0-999之間的整數,統計該數組中每一個數字出現1的次數的和。
實現方法,將該數組不斷的分紅更小的數組,直到每一個子數組的長度爲1,即只包含一個元素。此時,統計該元素中包含1的個數。最後彙總,獲得數組中每一個數字共包含了多少個1。
[java] view plaincopy
public class Test {
public final static ForkJoinPool mainPool = new ForkJoinPool();
public static void main(String[] args){
int n = 26;
int[] a = new int[n];
Random rand = new Random();
for(int i=0; i<n; i++) {
a[i] = rand.nextInt(1000);
}
SubTask task = new SubTask(a, 0, n);
int count = mainPool.invoke(task);
for(int i=0; i<n; i++) {
System.out.print(a[i]+" ");
}
System.out.println("\n數組中共出現了" + count + "個1");
}
}
class SubTask extends RecursiveTask<Integer> {
private static final long serialVersionUID = 1L;
private int[] a;
private int beg;
private int end;
public SubTask(int[] a, int beg, int end) {
super();
this.a = a;
this.beg = beg;
this.end = end;
}
@Override
protected Integer compute() {
int result = 0;
if(end-beg>1) {
int mid = (beg+end)/2;
SubTask t1 = new SubTask(a, beg, mid);
SubTask t2 = new SubTask(a, mid, end);
invokeAll(t1, t2);
try {
result = t1.get()+t2.get();
} catch (InterruptedException | ExecutionException e) {
e.printStackTrace();
}
} else {
result = count(a[beg]);
}
return result;
}
//統計一個整數中出現了幾個1
private int count(int n) {
int result = 0;
while(n>0) {
if(n % 10==1) {
result++;
}
n = n / 10;
}
return result;
}
}
例子3,異步執行任務。前面兩個例子都是同步執行任務,當啓動任務後,主線程陷入了阻塞狀態,直到任務執行完畢。若建立新任務後,但願當前線程能繼續執行而非陷入阻塞,則須要異步執行。ForkJoinPool線程池提供了execute()方法來異步啓動任務,而做爲任務自己,能夠調用fork()方法異步啓動新的子任務,並調用子任務的join()方法來取得計算結果。須要注意的是,異步使用ForkJoin框架,沒法使用「工做竊取」算法來提升線程的利用率,針對每一個子任務,系統都會啓動一個新的線程。
本例的功能是查找硬盤上某一類型的文件。給定文件擴展名後,將硬盤上全部該類型的文件名打印顯示出來。做爲主程序,啓動任務後,繼續顯示任務的執行進度,每3秒鐘打印顯示一個黑點,表示任務在繼續。最後,當全部線程都結束了,打印顯示結果。
[java] view plaincopy
public class ThreadLocalTest {
public static void main(String[] args) throws Exception {
Path p = Paths.get("D:/");
List<Path> roots = (List<Path>) FileSystems.getDefault().getRootDirectories();
List<Path> result = new ArrayList<>();
List<MyTask> tasks = new ArrayList<>();
ForkJoinPool pool = new ForkJoinPool();
for(Path root:roots) {
MyTask t = new MyTask(root, "pdf");
pool.execute(t);
tasks.add(t);
}
System.out.print("正在處理中");
while(isAllDone(tasks) == false) {
System.out.print(". ");
TimeUnit.SECONDS.sleep(3);
}
for(MyTask t:tasks) {
result.addAll(t.get());
}
for(Path pp:result) {
System.out.println(pp);
}
}
private static boolean isAllDone(List<MyTask> tasks) {
boolean result = true;
for(MyTask t:tasks) {
if(t.isDone() == false) {
result = false;
break;
}
}
return result;
}
}
class MyTask extends RecursiveTask<List<Path>> {
private static final long serialVersionUID = 1L;
private Path path;
private String fileExtention;
public MyTask(Path path, String fileExtention) {
super();
this.path = path;
this.fileExtention = fileExtention;
}
@Override
protected List<Path> compute() {
List<Path> result = new ArrayList<>();
try {
DirectoryStream<Path> paths = Files.newDirectoryStream(path);
List<MyTask> subTasks = new ArrayList<>();
for(Path p:paths) {
if(Files.isDirectory(p)) {
MyTask t = new MyTask(p, fileExtention);
t.fork();
subTasks.add(t);
}else if(Files.isRegularFile(p)) {
if(p.toString().toLowerCase().endsWith("."+fileExtention)) {
result.add(p);
}
}
}
for(MyTask t:subTasks) {
result.addAll(t.join());
}
} catch (IOException e) {
}
return result;
}