理解與使用Treiber Stack

背景

最近在不少JDK源碼中都看到了Treiber stack這個單詞。html

  • 好比CompletableFuture中的:
volatile Completion stack;    // Top of Treiber stack of dependent actions
  • 好比FutureTask中的:
/** Treiber stack of waiting threads */
private volatile WaitNode waiters;
  • 好比Phaser中的:
/**
 * Wait nodes for Treiber stack representing wait queue
 */
static final class QNode implements ForkJoinPool.ManagedBlocker {
    final Phaser phaser;
    final int phase;
    final boolean interruptible;
    final boolean timed;
    boolean wasInterrupted;
    long nanos;
    final long deadline;
    volatile Thread thread; // nulled to cancel wait
    QNode next;
  • 還好比ForkJoinPool中的描述:
* Bits and masks for field ctl, packed with 4 16 bit subfields:
 * AC: Number of active running workers minus target parallelism
 * TC: Number of total workers minus target parallelism
 * SS: version count and status of top waiting thread
 * ID: poolIndex of top of Treiber stack of waiters

感受這種名詞出現的頻率有點高,須要深刻了解一下。java

名稱由來

Treiber Stack在 R. Kent Treiber在1986年的論文Systems Programming: Coping with Parallelism中首次出現。它是一種無鎖併發棧,其無鎖的特性是基於CAS原子操做實現的。node

CompletableFuture源碼實現

CompletableFuture的Treiber stack實現感受有點複雜,由於有其餘邏輯摻雜,代碼不容易閱讀,其實抽象來看,Treiber stack首先是個單向鏈表,鏈表頭部即棧頂元素,在入棧和出現過程當中,須要對棧頂元素進行CAS控制,防止多線程狀況下數據錯亂。多線程

// Either the result or boxed AltResult
volatile Object result;
// Top of Treiber stack of dependent actions(Treiber stack棧頂元素)
volatile Completion stack;

/** Returns true if successfully pushed c onto stack. */
final boolean tryPushStack(Completion c) {
    Completion h = stack;
    lazySetNext(c, h);
    return UNSAFE.compareAndSwapObject(this, STACK, h, c);
}

/** Unconditionally pushes c onto stack, retrying if necessary. */
final void pushStack(Completion c) {
    do {} while (!tryPushStack(c));
}

簡單來看,入棧的步驟以下:併發

  • 嘗試入棧,利用CAS將新的節點做爲棧頂元素,新節點next賦值爲舊棧頂元素
  • 嘗試入棧成功,即結束;入棧失敗,繼續重試上面的操做

FutureTask實現

FutureTask用了Treiber Stack來維護等待任務完成的線程,在FutureTask的任務完成/取消/異常後在finishCompletion鉤子方法中會喚醒棧中等待的線程。dom

Treiber Stack抽象實現

入棧

void push(Node new) {
  do {
  } while(!tryPush(new)) // 嘗試入棧
}

boolean tryPush(node) {
    Node oldHead = head;
    node.next = oldHead; // 新節點next賦值爲舊棧頂元素
    return CAS(oldHead, node); // 利用CAS將新的節點做爲棧頂元素
}

出棧

對於出棧,要作的工做就是將原來的棧頂節點移除,等待垃圾回收;新棧頂元素CAS爲第一個子元素。僞代碼:ide

E pop() {
    Node<E> oldHead;
    Node<E> newHead;
    do {
        oldHead = top.get();
        // 判斷棧是否爲空,爲空直接返回
        if (oldHead == null)
            return null;
        newHead = oldHead.next;
    } while (!CAS(oldHead, newHead));
    // 舊的節點刪掉next引用,等待gc
    oldHead.item = null;
    return oldHead.item;
}

示例

import sun.misc.Unsafe;

import java.lang.reflect.Field;
import java.util.Objects;
import java.util.Optional;
import java.util.Random;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

/**
 * 基於Unsafe實現TreiberStack
 * @author Charles
 */
public class TreiberStack<E> {
    private volatile Node<E> head;

    public void push(E item) {
        Objects.requireNonNull(item);
        Node<E> newHead = new Node<>(item);
        Node<E> oldHead;
        int count = 0;
        do {
            oldHead = head;
            count++;
        } while (!tryPush(oldHead, newHead, count));
        newHead.next = oldHead;
    }

    private boolean tryPush(Node<E> oldHead, Node<E> newHead, int count) {
        boolean isSuccess = UNSAFE.compareAndSwapObject(this, HEAD, oldHead, newHead);
        System.out.println(currentThreadName() + " try push [" + count + "]," +
                " oldHead = " + getValue(oldHead) +
                " newHead = " + getValue(newHead) +
                " isSuccess = " + isSuccess);
        return isSuccess;
    }

    public E pop() {
        Node<E> oldHead;
        Node<E> newHead;
        do {
            oldHead = head;
            System.out.println(currentThreadName() + " do pop:" +
                    " oldHead = " + getValue(oldHead) +
                    " newHead = " + Optional.ofNullable(head).map(s -> s.next.item).orElse(null));
            if (oldHead == null) {
                return null;
            }
            newHead = oldHead.next;
        } while (!tryPop(oldHead, newHead));
        oldHead.next = null;
        return oldHead.item;
    }

    private boolean tryPop(Node<E> oldHead, Node<E> newHead) {
        boolean isSuccess = UNSAFE.compareAndSwapObject(this, HEAD, oldHead, newHead);
        System.out.println(currentThreadName() + " try pop:" +
                " oldHead = " + getValue(oldHead) +
                " currentHead = " + getValue(head) +
                " newHead = " + getValue(newHead) +
                " isSuccess: " + isSuccess);
        return isSuccess;
    }

    private E getValue(Node<E> n) {
        return Optional.ofNullable(n).map(t -> t.item).orElse(null);
    }

    private static class Node<E> {
        E item;
        Node<E> next;

        Node(E item) {
            this.item = item;
        }
    }

    // Unsafe mechanics
    private static final sun.misc.Unsafe UNSAFE;
    private static final long HEAD;
    private static final long NEXT;

    static {
        try {
            Field getUnsafe = sun.misc.Unsafe.class.getDeclaredField("theUnsafe");
            getUnsafe.setAccessible(true);
            UNSAFE = (Unsafe) getUnsafe.get(null);

            Class<?> k = TreiberStack.class;
            HEAD = UNSAFE.objectFieldOffset(k.getDeclaredField("head"));
            NEXT = UNSAFE.objectFieldOffset(TreiberStack.Node.class.getDeclaredField("next"));
        } catch (Exception x) {
            throw new Error(x);
        }
    }

    private static class RandomValue {
        private final Integer value;

        public RandomValue() {
            this.value = new Random().nextInt(Integer.MAX_VALUE);
        }

        public Integer getValue() {
            return value;
        }

        @Override
        public String toString() {
            return value.toString();
        }
    }

    private static String currentThreadName() {
        return System.nanoTime() + " / " + Thread.currentThread().getName();
    }

    public static void main(String[] args) throws InterruptedException {
        TreiberStack<RandomValue> ts = new TreiberStack<>();
        ExecutorService es = Executors.newFixedThreadPool(10);
        Thread.sleep(2000);
        for (int i = 0; i < 5; i++) {
            es.submit(() -> ts.push(new RandomValue()));
        }
        for (int i = 0; i < 50; i++) {
            es.submit((Runnable) ts::pop);
        }
    }
}

參考

Wiki Treiber Stack
Treiber Stack介紹
Treiber stack設計ui

相關文章
相關標籤/搜索