Java多線程之ThreadPoolExecutor和ForkJoinPool的用法

目錄java

在平時的工做中,當遇到數據量比較大、程序運行較慢,須要提高程序性能時,通常會涉及到多線程。有些小夥伴對多線程的用法不是很清楚,本文主要說明一下 ThreadPoolExecutorForkJoinPool 的用法。算法

場景

首先咱們假設這樣一個場景,有一個接口,用來計算數組的和。接口定義以下:編程

package mutilthread;

/** * 求和的接口 * @Author: Rebecca * @Description: * @Date: Created in 2019/6/18 15:28 * @Modified By: */
public interface Calculator {
    long sumUp(int[] numbers) throws Exception;
}

複製代碼

單線程實現

最開始咱們的代碼確定是使用普通的單線程實現,這樣的好處是代碼比較簡單,壞處就是當數據了比較大時,程序運行較慢,沒法利用多核CPU。數組

package mutilthread;

import java.util.ArrayList;
import java.util.List;

/** * 單線程的類 * @Author: Rebecca * @Description: * @Date: Created in 2019/6/18 10:24 * @Modified By: */
public class SingleThread implements Calculator {
    /** * 用單線程計算數組的和 * @param calcData 須要求和的數組 * @return * @author Rebecca 10:51 2019/6/18 * @version 1.0 */
    @Override
    public long sumUp(int[] calcData) {
        // 此句代碼只是爲了延長程序運行時間,和程序邏輯無關
        List<SingleThread> tasks = new ArrayList<SingleThread>();

        int calcDataLength = calcData.length;
        long sum = 0l;
        for (int i = 0; i < calcDataLength; i++) {
            sum += calcData[i];

            // 此句代碼只是爲了延長程序運行時間,和程序邏輯無關
            tasks.add(new SingleThread());
        }
        return sum;
    }
}
複製代碼

多線程實現-ExecutorService

由於單線程的劣勢嚴重影響程序處理速度,咱們把代碼優化爲多線程的ExecutorService來實現。bash

package mutilthread;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.*;
import java.util.concurrent.ThreadPoolExecutor.CallerRunsPolicy;

/** * 用 ThreadPoolExecutor 線程池計算數組的和 * @Author: Rebecca * @Description: * @Date: Created in 2019/6/18 10:50 * @Modified By: */
public class MutilThreadOfThreadPoolExecutor implements Calculator {

    /** * 用 ThreadPoolExecutor 線程池計算數組的和 * @param calcData 須要求和的數組 * @return * @author Rebecca 10:51 2019/6/18 * @version 1.0 */
    @Override
    public long sumUp(int[] calcData) throws Exception {
        // 建立線程池
        ExecutorService executorService = new ThreadPoolExecutor(5, 10, // 線程數
                60l, TimeUnit.SECONDS,  // 超時時間
                new ArrayBlockingQueue<Runnable>(100, true),  // 線程處理數據的方式
                Executors.defaultThreadFactory(),  // 建立線程的工廠
                new CallerRunsPolicy());  // 超出處理範圍的處理方式


        int calcDataLength = calcData.length;
        long sum = 0l;
        int threadSize = 5;

        for (int i = 0; i < threadSize; i++) {
            int arrStart = calcDataLength / threadSize * i;
            int arrEnd = calcDataLength / threadSize * (i+1);

            SumTask task = new SumTask(calcData, arrStart, arrEnd);
            // 線程池處理數據
            Future<Long> future = executorService.submit(task);

            sum += future.get().longValue();
        }
        // 關閉線程池
        executorService.shutdown();

        return sum;
    }


    public static class SumTask implements Callable<Long> {
        private int[] arr;
        private int start, end;

        public SumTask() {}

        public SumTask(int[] arr, int start, int end) {
            this.arr = arr;
            this.start = start;
            this.end = end;
        }

        @Override
        public Long call() {
            // 此句代碼只是爲了延長程序運行時間,和程序邏輯無關
            List<SumTask> tasks = new ArrayList<SumTask>();

            long sum = 0l;
            for (int i = start; i < end; i++)
            {
                sum += arr[i];
                // 此句代碼只是爲了延長程序運行時間,和程序邏輯無關
                tasks.add(new SumTask());
            }

            return sum;
        }
    }
}
複製代碼

Executors也提供了一些方法,能夠直接建立ExecutorService線程池,如newSingleThreadExecutor()newCachedThreadPool()newFixedThreadPool()newScheduledThreadPool(),相比於ThreadPoolExecutor提供的構造函數,Executors提供的方法只用傳2個參數甚至更少,但new ThreadPoolExecutor()則要傳一堆參數。那麼咱們爲何還要用new ThreadPoolExecutor()這種方式呢?多線程

答案很簡單,爲了避免讓程序出現OOM。若是你看過Executors構造線程池相關方法的源碼就會發現,它內部也是用new ThreadPoolExecutor()方式建立的線程池。但有一個參數它傳的是Integer.MAX_VALUE。這個參數是什麼意思呢?即線程池中容許出現的線程最大數量。若是線程池中真的建立了Integer.MAX_VALUE的線程數,程序確定會OOM的。併發

// Executors的newCachedThreadPool方法源碼
public static ExecutorService newCachedThreadPool(ThreadFactory threadFactory) {
    return new ThreadPoolExecutor(0, Integer.MAX_VALUE,
                                  60L, TimeUnit.SECONDS,
                                  new SynchronousQueue<Runnable>(),
                                  threadFactory);
}
複製代碼

爲了不這種狀況,咱們通常用new ThreadPoolExecutor()這種方式建立線程池。那麼這麼多參數分別是什麼意思呢?dom

別急,其實咱們能夠分組記憶:ide

第1組(線程數量相關的):函數

  1. corePoolSize: 核心線程數。即便線程池中沒有任務,這些線程也不會被銷燬,由於建立和銷燬線程是須要消耗CPU資源的
  2. maximumPoolSize: 線程池中容許建立的最大線程數

第2組(非核心線程銷燬時間相關的):

  1. keepAliveTime: 非核心線程的銷燬時間。非核心線程不可能一直在線程池中佔用資源,因此須要銷燬
  2. unit: 銷燬的時間單位。值爲TimeUnit中的枚舉類型

第3組(線程池處理數據相關的):

  1. workQueue: 線程處理數據的方式。通常用JDK提供的ArrayBlockingQueue(數組)和LinkedBlockingDeque(鏈表)
  2. handler: 超出處理範圍的處理方式。
    AbortPolicy : 若是超出處理範圍,則拋RejectedExecutionException異常;
    CallerRunsPolicy : 若是超出處理範圍,則用調用該線程池的線程處理;
    DiscardOldestPolicy: 若是超出處理範圍,則把最舊的元素刪除,保留新的元素
    DiscardPolicy: 若是超出處理範圍,則不處理,丟棄掉

第4組(建立線程的工廠):

  1. threadFactory: 建立線程的工廠,通常咱們用 Executors.defaultThreadFactory() 便可
// ThreadPoolExecutor的構造方法源碼
public ThreadPoolExecutor(int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit, BlockingQueue<Runnable> workQueue, ThreadFactory threadFactory, RejectedExecutionHandler handler) 複製代碼

假設咱們有一串任務,被分爲3組,每組任務數量爲3,線程池中只有3個線程來處理,那麼處理順序則以下所示:

第1步:

任務組1 被 線程1 處理,線程1 處理任務組1中的第一個任務; 任務組2 被 線程2 處理,線程2 處理任務組2中的第一個任務; 任務組3 被 線程3 處理,線程3 處理任務組3中的第一個任務;

第2步:

線程2處理的較快,任務組2中的全部任務都處理完了,由於沒有任務組是等待處理的狀態,因此線程2此時是空閒狀態。此時 線程1 處理的任務組1只處理了第1個任務,那麼有沒有辦法讓線程2把任務組1裏的第二個任務偷過來處理一下,減小等待時間呢?

在JDK7以後,提供了ForkJoinPool線程池就能夠實現啦~ 接着往下看吧

多線程實現-ForkJoinPool

咱們仍是用求和的例子來模擬偷任務。

package mutilthread;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.*;

/** * 用 ForkJoinPool 線程池計算數組的和 * @Author: Rebecca * @Description: * @Date: Created in 2019/6/18 10:50 * @Modified By: */
public class MutilThreadOfForkJoinPool implements Calculator {

    private ForkJoinPool pool;

    public MutilThreadOfForkJoinPool() {
        // jdk8以後能夠用公用的 ForkJoinPool: pool = ForkJoinPool.commonPool()
        pool = new ForkJoinPool();
    }

    /** * 用 ForkJoinPool 線程池計算數組的和 * @param calcData 須要求和的數組 * @return * @author Rebecca 10:51 2019/6/18 * @version 1.0 */
    @Override
    public long sumUp(int[] calcData) {
        SumTask task = new SumTask(calcData, 0, calcData.length - 1);
        return pool.invoke(task);
    }


    public static class SumTask extends RecursiveTask<Long> {
        private int[] numbers;
        private int start;
        private int end;

        private SumTask(){}

        public SumTask(int[] numbers, int start, int end) {
            this.numbers = numbers;
            this.start = start;
            this.end = end;
        }

        @Override
        protected Long compute() {
            // 當須要計算的數字小於 10萬 時,直接計算結果
            if (end - start < 1000000) {
                long total = 0;

                // 此句代碼只是爲了延長程序運行時間,和程序邏輯無關
                List<SumTask> tasks = new ArrayList<SumTask>();
                for (int i = start; i <= end; i++) {
                    total += numbers[i];
                    // 此句代碼只是爲了延長程序運行時間,和程序邏輯無關
                    tasks.add(new SumTask());
                }
                return total;
            } else {  // 不然,把任務一分爲二,遞歸計算
                int middle = (start + end) / 2;
                SumTask taskLeft = new SumTask(numbers, start, middle);
                SumTask taskRight = new SumTask(numbers, middle + 1, end);
                taskLeft.fork();
                taskRight.fork();
                return taskLeft.join() + taskRight.join();
            }
        }
    }
}
複製代碼

RecursiveTaskfork方法和Threadstart方法是相似的。這種「偷任務」的專業名詞叫工做竊取(work-stealing)算法,利用JDK7提供的ForkJoinPool就能夠實現啦。在JDK7以前,LinkedBlockingDeque用的也是 工做竊取算法

測試

下面是測試類代碼

package mutilThread;

import mutilthread.CalcData;
import mutilthread.MutilThreadOfForkJoinPool;
import mutilthread.MutilThreadOfThreadPoolExecutor;
import mutilthread.SingleThread;
import org.junit.Test;

/** * 線程測試類 * @Author: Rebecca * @Description: * @Date: Created in 2019/6/18 10:40 * @Modified By: */
public class ThreadTest {

    @Test
    public void testThread() throws Exception {
        int[] data = CalcData.getCalcData();
        // 單線程測試
        SingleThread singleThread = new SingleThread();
        long startTime = System.currentTimeMillis();
        System.out.println("數組的和: " + singleThread.sumUp(data));
        System.out.println("單線程耗時: " + (System.currentTimeMillis() - startTime) + " ms");

        // 多線程(ThreadPoolExecutor)測試
        MutilThreadOfThreadPoolExecutor threadPool = new MutilThreadOfThreadPoolExecutor();
        startTime = System.currentTimeMillis();
        System.out.println("數組的和: " + threadPool.sumUp(data));
        System.out.println("多線程(ThreadPoolExecutor)耗時: " + (System.currentTimeMillis() - startTime) + " ms");

        // 多線程(ForkJoinPool)測試
        MutilThreadOfForkJoinPool forkJoinPool = new MutilThreadOfForkJoinPool();
        startTime = System.currentTimeMillis();
        System.out.println("數組的和: " + forkJoinPool.sumUp(data));
        System.out.println("多線程(ForkJoinPool)耗時: " + (System.currentTimeMillis() - startTime) + " ms");
    }
}
複製代碼

程序運行結果:

數組的和: 499913683383
單線程耗時: 3307 ms
數組的和: 499913683383
多線程(ThreadPoolExecutor)耗時: 197 ms
數組的和: 499913683383
多線程(ForkJoinPool)耗時: 169 ms
複製代碼

整理成表格以下:

線程類型 耗時(ms)
單線程 3307
多線程(ThreadPoolExecutor) 197
多線程(ForkJoinPool) 169

總結

  1. 通常咱們使用多線程時會用ExecuterService,構造用new ThreadPoolExecutor(),通常不使用Executors提供了構造線線程池方法,避免出現OOM;
  2. 線程池相對於線程組(本文沒提到)更好管理;
  3. 在JDK7以後能夠用ForkJoinPool,相對於ExecuterService執行效率更快。
  4. 線程之間通訊是須要成本的。
    若是你細心的話,會發現上面的示例代碼中都有這麼兩行多餘的代碼:
// 此句代碼只是爲了延長程序運行時間,和程序邏輯無關
List<SumTask> tasks = new ArrayList<SumTask>();

// 此句代碼只是爲了延長程序運行時間,和程序邏輯無關
tasks.add(new SumTask());
複製代碼

若是不加建立對象的多餘代碼,只是單純累加數組的和,你會發現單線程執行效率更高。因此在實際使用中仍是要根據實際業務邏輯對比,選取適合的方式。若是業務邏輯很簡單,程序處理跟快,就徹底沒有必要使用多線程了。
ForkJoinPool中,設置的數組大小是10萬,之因此設置這個數字,是爲了跟 ExecutorService 方式作對比,若是在ForkJoinPool中設置的數組長度太小,就會出現性能不如 ExecutorService 的狀況。


程序中用到的生成計算數據的類

package mutilthread;

import java.util.Random;

/** * 生成計算數據的類 * @Author: Rebecca * @Description: * @Date: Created in 2019/6/18 10:25 * @Modified By: */
public class CalcData {
    // 長度爲1000萬
    private static int calcDataLength = 10000000;

    public static int[] getCalcData() {
        Random random = new Random();
        int[] calcData = new int[calcDataLength];
        for (int i = 0; i < calcDataLength; i++) {
            // 0~10的隨機數 生成[m,n]範圍內指定的隨機數: rand.nextInt(n -m + 1) +m;
            calcData[i] = random.nextInt(100001);
        }
        return calcData;
    }
}
複製代碼

參考連接

Java 併發編程筆記:如何使用 ForkJoinPool 以及原理

Java併發 之 線程池系列 (2) 使用ThreadPoolExecutor構造線程池

相關文章
相關標籤/搜索