在多線程編程過程當中,遇到這樣的狀況,主線程須要等待多個子線程的處理結果,才能繼續運行下去。我的給這樣的子線程任務取了個名字叫並行任務。對於這種任務,每次去編寫代碼加鎖控制時序,以爲太麻煩,正好朋友提到CountDownLatch這個類,因而用它來編寫了個小工具。java
首先,要處理的是多個任務,因而定義了一個接口git
package com.zyj.thread; import com.zyj.exception.ChildThreadException; /** * 多任務處理 * @author zengyuanjun */ public interface MultiThreadHandler { /** * 添加任務 * @param tasks */ void addTask(Runnable... tasks); /** * 執行任務 * @throws ChildThreadException */ void run() throws ChildThreadException; }
要處理的是並行任務,須要用到CountDownLatch來統計全部子線程執行結束,還要一個集合記錄全部任務,另外加上我自定義的ChildThreadException類來記錄子線程中的異常,通知主線程是否全部子線程都執行成功,便獲得了下面這個抽象類AbstractMultiParallelThreadHandler。在這個類中,我順便完成了addTask這個方法。github
package com.zyj.thread.parallel; import java.util.ArrayList; import java.util.List; import java.util.concurrent.CountDownLatch; import com.zyj.exception.ChildThreadException; import com.zyj.thread.MultiThreadHandler; /** * 並行線程處理 * @author zengyuanjun */ public abstract class AbstractMultiParallelThreadHandler implements MultiThreadHandler { /** * 子線程倒計數鎖 */ protected CountDownLatch childLatch; /** * 任務列表 */ protected List<Runnable> taskList; /** * 子線程異常 */ protected ChildThreadException childThreadException; public AbstractMultiParallelThreadHandler() { taskList = new ArrayList<Runnable>(); childThreadException = new ChildThreadException(); } public void setCountDownLatch(CountDownLatch latch) { this.childLatch = latch; } /** * {@inheritDoc} */ @Override public void addTask(Runnable... tasks) { if (null == tasks) { taskList = new ArrayList<Runnable>(); } for (Runnable task : tasks) { taskList.add(task); } } /** * {@inheritDoc} */ @Override public abstract void run() throws ChildThreadException; }
具體的實現,則是下面這個類。實現原理也很簡單,主線程根據並行任務數建立一個CountDownLatch,傳到子線程中,並運行全部子線程,而後await等待。子線程執行結束後調用CountDownLatch的countDown()方法,當全部子線程執行結束後,CountDownLatch計數清零,主線程被喚醒繼續執行。編程
package com.zyj.thread.parallel; import java.util.concurrent.CountDownLatch; import com.zyj.exception.ChildThreadException; /** * 並行任務處理工具 * * @author zengyuanjun * */ public class MultiParallelThreadHandler extends AbstractMultiParallelThreadHandler { /** * 無參構造器 */ public MultiParallelThreadHandler() { super(); } /** * 根據任務數量運行任務 */ @Override public void run() throws ChildThreadException { if (null == taskList || taskList.size() == 0) { return; } else if (taskList.size() == 1) { runWithoutNewThread(); } else if (taskList.size() > 1) { runInNewThread(); } } /** * 新建線程運行任務 * * @throws ChildThreadException */ private void runInNewThread() throws ChildThreadException { childLatch = new CountDownLatch(taskList.size()); childThreadException.clearExceptionList(); for (Runnable task : taskList) { invoke(new MultiParallelRunnable(new MultiParallelContext(task, childLatch, childThreadException))); } taskList.clear(); try { childLatch.await(); } catch (InterruptedException e) { childThreadException.addException(e); } throwChildExceptionIfRequired(); } /** * 默認線程執行方法 * * @param command */ protected void invoke(Runnable command) { if(command.getClass().isAssignableFrom(Thread.class)){ Thread.class.cast(command).start(); }else{ new Thread(command).start(); } } /** * 在當前線程中直接運行 * * @throws ChildThreadException */ private void runWithoutNewThread() throws ChildThreadException { try { taskList.get(0).run(); } catch (Exception e) { childThreadException.addException(e); } throwChildExceptionIfRequired(); } /** * 根據須要拋出子線程異常 * * @throws ChildThreadException */ private void throwChildExceptionIfRequired() throws ChildThreadException { if (childThreadException.hasException()) { childExceptionHandler(childThreadException); } } /** * 默認拋出子線程異常 * @param e * @throws ChildThreadException */ protected void childExceptionHandler(ChildThreadException e) throws ChildThreadException { throw e; } }
並行任務是要運行的子線程,只要實現Runnable接口就行,並無CountDownLatch對象,因此我用MultiParallelRunnable類對它封裝一次,MultiParallelRunnable類裏有個屬性叫 MultiParallelContext,MultiParallelContext裏面就是保存的子線程task、倒計數鎖CountDownLatch和ChildThreadException這些參數。MultiParallelRunnable類完成運行子線程、記錄子線程異常和倒計數鎖減一。多線程
package com.zyj.thread.parallel; /** * 並行線程對象 * * @author zengyuanjun * */ public class MultiParallelRunnable implements Runnable { /** * 並行任務參數 */ private MultiParallelContext context; /** * 構造函數 * @param context */ public MultiParallelRunnable(MultiParallelContext context) { this.context = context; } /** * 運行任務 */ @Override public void run() { try { context.getTask().run(); } catch (Exception e) { e.printStackTrace(); context.getChildException().addException(e); } finally { context.getChildLatch().countDown(); } } }
package com.zyj.thread.parallel; import java.util.concurrent.CountDownLatch; import com.zyj.exception.ChildThreadException; /** * 並行任務參數 * @author zengyuanjun * */ public class MultiParallelContext { /** * 運行的任務 */ private Runnable task; /** * 子線程倒計數鎖 */ private CountDownLatch childLatch; /** * 子線程異常 */ private ChildThreadException childException; public MultiParallelContext() { } public MultiParallelContext(Runnable task, CountDownLatch childLatch, ChildThreadException childException) { this.task = task; this.childLatch = childLatch; this.childException = childException; } public Runnable getTask() { return task; } public void setTask(Runnable task) { this.task = task; } public CountDownLatch getChildLatch() { return childLatch; } public void setChildLatch(CountDownLatch childLatch) { this.childLatch = childLatch; } public ChildThreadException getChildException() { return childException; } public void setChildException(ChildThreadException childException) { this.childException = childException; } }
這裏提一下ChildThreadException這個自定義異常,跟普通異常不同,我在裏面加了個List<Exception> exceptionList,用來保存子線程的異常。由於有多個子線程,拋出的異常可能有多個。app
package com.zyj.exception; import java.io.PrintStream; import java.util.ArrayList; import java.util.List; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; import com.zyj.exception.util.ExceptionMessageFormat; import com.zyj.exception.util.factory.ExceptionMsgFormatFactory; /** * 子線程異常,子線程出現異常時拋出 * @author zengyuanjun */ public class ChildThreadException extends Exception { /** * serialVersionUID */ private static final long serialVersionUID = 5682825039992529875L; /** * 子線程的異常列表 */ private List<Exception> exceptionList; /** * 異常信息格式化工具 */ private ExceptionMessageFormat formatter; /** * 鎖 */ private Lock lock; public ChildThreadException() { super(); initial(); } public ChildThreadException(String message) { super(message); initial(); } public ChildThreadException(String message, StackTraceElement[] stackTrace) { this(message); setStackTrace(stackTrace); } private void initial() { exceptionList = new ArrayList<Exception>(); lock = new ReentrantLock(); formatter = ExceptionMsgFormatFactory.getInstance().getFormatter(ExceptionMsgFormatFactory.STACK_TRACE); } /** * 子線程是否有異常 * @return */ public boolean hasException() { return exceptionList.size() > 0; } /** * 添加子線程的異常 * @param e */ public void addException(Exception e) { try { lock.lock(); e.setStackTrace(e.getStackTrace()); exceptionList.add(e); } finally { lock.unlock(); } } /** * 獲取子線程的異常列表 * @return */ public List<Exception> getExceptionList() { return exceptionList; } /** * 清空子線程的異常列表 */ public void clearExceptionList() { exceptionList.clear(); } /** * 獲取全部子線程異常的堆棧跟蹤信息 * @return */ public String getAllStackTraceMessage() { StringBuffer sb = new StringBuffer(); for (Exception e : exceptionList) { sb.append(e.getClass().getName()); sb.append(": "); sb.append(e.getMessage()); sb.append("\n"); sb.append(formatter.formate(e)); } return sb.toString(); } /** * 打印全部子線程的異常的堆棧跟蹤信息 */ public void printAllStackTrace() { printAllStackTrace(System.err); } /** * 打印全部子線程的異常的堆棧跟蹤信息 * @param s */ public void printAllStackTrace(PrintStream s) { for (Exception e : exceptionList) { e.printStackTrace(s); } } }
·有沒有問題試一下才知道,寫了個類來測試:TestCase 爲並行任務子線程,resultMap爲並行任務共同完成的結果集。假設resultMap由5部分組成,main方法中啓動5個子線程分別完成一個部分,等5個子線程處理完後,main方法將結果resultMap打印出來。ide
package com.zyj.thread.test; import java.util.HashMap; import java.util.Map; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import com.zyj.exception.ChildThreadException; import com.zyj.thread.MultiThreadHandler; import com.zyj.thread.parallel.MultiParallelThreadHandler; import com.zyj.thread.parallel.ParallelTaskWithThreadPool; public class TestCase implements Runnable { private String name; private Map<String, Object> result; public TestCase(String name, Map<String, Object> result) { this.name = name; this.result = result; } @Override public void run() { // 模擬線程執行1000ms try { Thread.sleep(1000); } catch (InterruptedException e) { e.printStackTrace(); } // 模擬線程1和線程3拋出異常 // if(name.equals("1") || name.equals("3")) // throw new RuntimeException(name + ": throw exception"); result.put(name, "complete part " + name + "!"); } public static void main(String[] args) { System.out.println("main begin \t================="); Map<String, Object> resultMap = new HashMap<String, Object>(8, 1); MultiThreadHandler handler = new MultiParallelThreadHandler(); // ExecutorService service = Executors.newFixedThreadPool(3); // MultiThreadHandler handler = new ParallelTaskWithThreadPool(service); TestCase task = null; // 啓動5個子線程做爲要處理的並行任務,共同完成結果集resultMap for(int i=1; i<=5 ; i++){ task = new TestCase("" + i, resultMap); handler.addTask(task); } try { handler.run(); } catch (ChildThreadException e) { System.out.println(e.getAllStackTraceMessage()); } System.out.println(resultMap); // service.shutdown(); System.out.println("main end \t================="); } }
運行main方法,測試結果以下函數
main begin ================= {3=complete part 3!, 2=complete part 2!, 1=complete part 1!, 5=complete part 5!, 4=complete part 4!} main end =================
將模擬線程1和線程3拋出異常的註釋打開,測試結果以下工具
紅色的打印是子線程中捕獲異常打印的堆棧跟蹤信息,黑色的異常信息是主線程main方法中打印的,這說明主線程可以監視到子線程的出錯,以便採起對應的處理。因爲線程1和線程3出現了異常,未能完成任務,因此打印的resultMap只有第二、四、5三個部分完成。測試
爲了便於擴展,我把MultiParallelThreadHandler類中的invoke方法和childExceptionHandler方法定義爲protected類型。invoke方法中是具體的線程執行,childExceptionHandler方法是子線程拋出異常後的處理,能夠去繼承,重寫爲本身想要的,好比我想用線程池去運行子線程,就能夠去繼承並重寫invoke方法,獲得下面的這個類
package com.zyj.thread.parallel; import java.util.concurrent.ExecutorService; /** * 使用線程池運行並行任務 * @author zengyuanjun * */ public class ParallelTaskWithThreadPool extends MultiParallelThreadHandler { private ExecutorService service; public ParallelTaskWithThreadPool() { } public ParallelTaskWithThreadPool(ExecutorService service) { this.service = service; } public ExecutorService getService() { return service; } public void setService(ExecutorService service) { this.service = service; } /** * 使用線程池運行 */ @Override protected void invoke(Runnable command) { if(null != service){ service.execute(command); }else{ super.invoke(command); } } }
測試就在上面的測試類中,只不過被註釋掉了,測試結果是同樣的,就很少說了。