This article demonstrates a simplified implementation of a work-stealing thread pool.
Executors already provides a couple of factory methods for this out of the box:
```java
/**
 * Creates a thread pool that maintains enough threads to support
 * the given parallelism level, and may use multiple queues to
 * reduce contention. The parallelism level corresponds to the
 * maximum number of threads actively engaged in, or available to
 * engage in, task processing. The actual number of threads may
 * grow and shrink dynamically. A work-stealing pool makes no
 * guarantees about the order in which submitted tasks are
 * executed.
 *
 * @param parallelism the targeted parallelism level
 * @return the newly created thread pool
 * @throws IllegalArgumentException if {@code parallelism <= 0}
 * @since 1.8
 */
public static ExecutorService newWorkStealingPool(int parallelism) {
    return new ForkJoinPool
        (parallelism,
         ForkJoinPool.defaultForkJoinWorkerThreadFactory,
         null, true);
}

/**
 * Creates a work-stealing thread pool using all
 * {@link Runtime#availableProcessors available processors}
 * as its target parallelism level.
 * @return the newly created thread pool
 * @see #newWorkStealingPool(int)
 * @since 1.8
 */
public static ExecutorService newWorkStealingPool() {
    return new ForkJoinPool
        (Runtime.getRuntime().availableProcessors(),
         ForkJoinPool.defaultForkJoinWorkerThreadFactory,
         null, true);
}
```
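As a quick reference, a pool created this way is used like any other ExecutorService. The sketch below is my own illustration (not part of the original post) and simply submits a few hypothetical Callable tasks and joins their results:

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

public class WorkStealingPoolExample {
    public static void main(String[] args) throws Exception {
        // target parallelism defaults to Runtime.getRuntime().availableProcessors()
        ExecutorService pool = Executors.newWorkStealingPool();

        List<Future<Integer>> futures = new ArrayList<>();
        for (int i = 0; i < 16; i++) {
            final int n = i;
            futures.add(pool.submit((Callable<Integer>) () -> n * n));
        }
        for (Future<Integer> f : futures) {
            System.out.println(f.get());   // completion order is not guaranteed
        }
        pool.shutdown();
    }
}
```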
ForkJoinPool mainly relies on double-ended queues (deques); for the rough implementation below, though, we could even get by without a real deque.
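To make the deque remark concrete, here is a small fork/join sketch (my own illustration with a hypothetical SumTask, not code from the article): each worker pushes forked subtasks onto its own deque and works them off LIFO, while idle workers steal from the other end of their peers' deques.

```java
import java.util.Arrays;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

public class SumTask extends RecursiveTask<Long> {
    private final long[] data;
    private final int from, to;

    SumTask(long[] data, int from, int to) {
        this.data = data;
        this.from = from;
        this.to = to;
    }

    @Override
    protected Long compute() {
        if (to - from <= 1_000) {                  // small enough: sum directly
            long sum = 0;
            for (int i = from; i < to; i++) sum += data[i];
            return sum;
        }
        int mid = (from + to) >>> 1;
        SumTask left = new SumTask(data, from, mid);
        left.fork();                               // pushed onto this worker's deque
        long rightSum = new SumTask(data, mid, to).compute();
        return rightSum + left.join();             // "left" may have been stolen meanwhile
    }

    public static void main(String[] args) {
        long[] data = new long[1_000_000];
        Arrays.fill(data, 1L);
        System.out.println(new ForkJoinPool().invoke(new SumTask(data, 0, data.length))); // 1000000
    }
}
```

With that picture in mind, the simplified channel below fakes the same idea with a handful of blocking deques: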
```java
import java.util.concurrent.BlockingDeque;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.ThreadLocalRandom;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.google.common.util.concurrent.AtomicLongMap;

public class WorkStealingChannel<T> {

    private static final Logger LOGGER = LoggerFactory.getLogger(WorkStealingChannel.class);

    BlockingDeque<T>[] managedQueues;

    AtomicLongMap<Integer> stat = AtomicLongMap.create();

    public WorkStealingChannel() {
        int nCPU = Runtime.getRuntime().availableProcessors();
        int queueCount = nCPU / 2 + 1;
        managedQueues = new LinkedBlockingDeque[queueCount];
        for (int i = 0; i < queueCount; i++) {
            managedQueues[i] = new LinkedBlockingDeque<T>();
        }
    }

    public void put(T item) throws InterruptedException {
        // route the item to a fixed queue based on its hashCode
        int targetIndex = Math.abs(item.hashCode() % managedQueues.length);
        BlockingQueue<T> targetQueue = managedQueues[targetIndex];
        targetQueue.put(item);
    }

    public T take() throws InterruptedException {
        int rdnIdx = ThreadLocalRandom.current().nextInt(managedQueues.length);
        int idx = rdnIdx;
        while (true) {
            T item;
            if (idx == rdnIdx) {
                // the "own" queue is consumed from the head
                item = managedQueues[idx].poll();
            } else {
                // other queues are "stolen" from the tail to reduce contention
                item = managedQueues[idx].pollLast();
            }
            if (item != null) {
                LOGGER.info("take ele from queue {}", idx);
                stat.addAndGet(idx, 1);
                return item;
            }
            idx = (idx + 1) % managedQueues.length;
            if (idx == rdnIdx) {
                break;
            }
        }
        // nothing found after a full round: block on the randomly chosen queue
        LOGGER.info("wait for queue:{}", rdnIdx);
        stat.addAndGet(rdnIdx, 1);
        return managedQueues[rdnIdx].take();
    }

    public AtomicLongMap<Integer> getStat() {
        return stat;
    }
}
```
Here we create a few deques based on the CPU count. On each put, the element's hashCode modulo the queue count decides which queue it goes into. On take, we first poll a randomly chosen queue; if it is empty, we go round robin and steal from the tails of the other queues; if all of them are empty, we block waiting on the randomly chosen queue.
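The head/tail split is what keeps consumers from colliding on the same end of a queue: the "owner" side uses poll() (head), while "thieves" use pollLast() (tail). A tiny standalone sketch of that behaviour (my illustration, not from the article):

```java
import java.util.concurrent.BlockingDeque;
import java.util.concurrent.LinkedBlockingDeque;

public class DequeEndsDemo {
    public static void main(String[] args) throws InterruptedException {
        BlockingDeque<String> queue = new LinkedBlockingDeque<>();
        queue.put("a");
        queue.put("b");
        queue.put("c");          // queue is now: a (head) ... c (tail)

        System.out.println(queue.poll());      // "a"  -- owner takes from the head
        System.out.println(queue.pollLast());  // "c"  -- thief steals from the tail
        System.out.println(queue.poll());      // "b"
        System.out.println(queue.poll());      // null -- empty, non-blocking poll
    }
}
```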
A test example:
```java
import java.util.UUID;

public class WorkStealingDemo {

    static final WorkStealingChannel<String> channel = new WorkStealingChannel<>();

    static volatile boolean running = true;

    static class Producer extends Thread {
        @Override
        public void run() {
            while (running) {
                try {
                    channel.put(UUID.randomUUID().toString());
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
            }
        }
    }

    static class Consumer extends Thread {
        @Override
        public void run() {
            while (running) {
                try {
                    String value = channel.take();
                    System.out.println(value);
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
            }
        }
    }

    public static void stop() {
        running = false;
        System.out.println(channel.getStat());
    }

    public static void main(String[] args) throws InterruptedException {
        int nCPU = Runtime.getRuntime().availableProcessors();
        int consumerCount = nCPU / 2 + 1;
        // one producer per CPU, and roughly half as many consumers (matching the queue count)
        for (int i = 0; i < nCPU; i++) {
            new Producer().start();
        }
        for (int i = 0; i < consumerCount; i++) {
            new Consumer().start();
        }
        Thread.sleep(30 * 1000);
        stop();
    }
}
```
Output:
{0=660972, 1=660613, 2=661537, 3=659846, 4=659918}
Judging from the numbers, the load is spread fairly evenly across the queues.