Implementing a Simplified HashMap with TDD

Preface

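In the previous article I walked through the HashMap source code. Striking while the iron is hot, today I took some time to implement a simplified HashMap using TDD. In my tests, under normal conditions it performs slightly worse than HashMap.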

Preliminary design

// Planned public API; method bodies are omitted in this design sketch.
public class SimpleHashMap<K, V> {
    public V put(K key, V value);
    public V get(K key);
    public V remove(K key);
    public boolean containsKey(K key);
    public int size();
    public Iterable<V> values();
    public void forEach(Consumer<? super K> action);
}

Tasking

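  • Construct SimpleHashMap with no arguments
  • Constructor that sets the initial capacity
  • Constructor that sets the initial capacity and load factor
  • initial capacity defaults to 16
  • load factor defaults to 0.75f
  • The initial resize threshold is initial capacity * load factor
  • On each later resize, threshold = threshold << 1
  • Add the put method
    • Compute the hash
    • Add a hash table to store the data nodes.
    • If the hash table capacity is 0, or the number of entries exceeds the threshold, set the new resize threshold, grow the table, and rehash.
    • The hash table index is hash & (capacity - 1)
    • When growing, move the data from the old hash table to the new one
    • Before moving data to the new hash table, rehash: index = entry.hash & (new_capacity - 1)
    • On a hash collision, store the entries in a linked list
    • If the same bucket collides more than 8 times, store the entries in a red-black tree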
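  • Add the size method
    • Add a size member variable.
    • When put succeeds, size += 1.
    • When remove succeeds, size -= 1.
    • Consider the linked list
    • Consider the red-black tree
  • Add the containsKey method
    • Compute the hash from the key
    • Compute the index from the hash
    • Look up the key at that index; return true if found, otherwise return false.
    • Consider the hash table being null.
    • Consider the linked list
    • Consider the red-black tree
  • Add the get method
    • Compute the hash from the key
    • Compute the index from the hash
    • Look up the bucket at that index
    • If the bucket holds multiple data nodes, compare the key by both reference and value.
    • If the key matches, return the corresponding value; otherwise return null.
    • Consider the linked list
    • Consider the red-black tree
  • Add the remove method
    • Compute the hash from the key
    • Compute the index from the hash
    • Look up the bucket at that index
    • If the key matches, remove the corresponding node (setting the bucket to null if it was the only node) and return its value; otherwise return null.
    • Consider the linked list
    • Consider the red-black tree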
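  • Add the values method
    • On each successful put, save the value to the list
    • On each put that replaces an existing value, replace the corresponding value in the list
    • On each successful remove, delete the value from the list
    • Consider the linked list
    • Consider the red-black tree
  • Add the forEach method
    • Iterate over the hash table
    • If a bucket exists, call action.accept(key)
    • Consider the linked list
    • Consider the red-black tree
  • Add fail-fast behavior
    • Add a modCount member variable that counts modifications
    • Compare modCount before and after iteration
    • If modCount differs before and after, throw ConcurrentModificationException.
  • Add a red-black tree to store data nodes whose hash collides more than 8 times.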
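
The hashing, indexing, and threshold tasks above can be tried out in isolation. The following is a minimal illustrative sketch, not the article's implementation: the class name `HashIndexSketch` and its method names are hypothetical. It shows why the index of an entry has to be recomputed against the new capacity after the table doubles.

```java
public class HashIndexSketch {

    // Fold the high 16 bits into the low 16 bits so that small tables,
    // which only look at the low bits, still see part of the high bits.
    static int hash(Object key) {
        if (key == null) {
            return 0;
        }
        int h = key.hashCode();
        return h ^ (h >>> 16);
    }

    // Equivalent to a non-negative modulo, but only when capacity is a power of two.
    static int index(int hash, int capacity) {
        return hash & (capacity - 1);
    }

    // Initial resize threshold: capacity * load factor, truncated to an int.
    static int threshold(int capacity, float loadFactor) {
        return (int) (capacity * loadFactor);
    }

    public static void main(String[] args) {
        int h = hash(17); // Integer.hashCode() is the int value itself
        System.out.println(index(h, 16));              // 1
        System.out.println(index(h, 32));              // 17
        System.out.println(threshold(16, 0.75f));      // 12
        System.out.println(threshold(16, 0.75f) << 1); // 24
    }
}
```

Key 17 lands in bucket 1 while the capacity is 16 but belongs in bucket 17 once the capacity is 32, so a resize that leaves entries at their old index makes them unreachable.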

Test coverage

Test code

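import static org.assertj.core.api.Assertions.assertThat;

import java.util.ArrayList;
import java.util.List;

import org.assertj.core.util.Lists;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;

/**
 * @author lyning
 */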
public class SimpleHashMapTest {

    private SimpleHashMap<Integer, Integer> map;

    @BeforeEach
    public void setUp() throws Exception {
        // given
        this.map = new SimpleHashMap<>();
    }

    /************ size test start **********/
    @Test
    @DisplayName("given empty entries " +
            "when call size() " +
            "then return 0")
    public void size1() {
        // when
        int size = map.size();
        // then
        assertThat(size).isZero();
    }

    @Test
    @DisplayName("given multiple entries(contains duplicate key) " +
            "when call size() " +
            "then return correct size")
    public void size2() {
        // given
        SimpleHashMap<Integer, Integer> map = new SimpleHashMap<>();
        map.put(1, 1);
        map.put(2, 2);
        map.put(3, 3);
        map.put(3, 4);
        map.put(3, 5);
        map.put(4, 4);
        map.put(5, 5);
        map.remove(1);
        map.remove(2);
        // when
        int size = map.size();
        // then
        assertThat(size).isEqualTo(3);
    }

    @Test
    @DisplayName("given multiple entries(hash conflict) " +
            "when call size() " +
            "then return correct size")
    public void size3() {
        // given
        SimpleHashMap<HashConflict, Integer> map = new SimpleHashMap<>();
        map.put(new HashConflict(1), 1);
        map.put(new HashConflict(2), 2);
        map.put(new HashConflict(3), 3);
        map.put(new HashConflict(4), 4);
        map.put(new HashConflict(5), 5);
        map.remove(new HashConflict(5));
        map.remove(new HashConflict(3));
        // when
        int size = map.size();
        // then
        assertThat(size).isEqualTo(3);
    }
    /************ size test end **********/


    /************ put test start **********/
    @Test
    @DisplayName("given empty entries " +
            "when put one entry " +
            "then return size 1")
    public void put1() {
        // when
        map.put(1, 1);
        // then
        assertThat(map.size()).isOne();
    }

    @Test
    @DisplayName("given empty entries " +
            "when put two entries(duplicate key) " +
            "then return size 1")
    public void put2() {
        // when
        map.put(1, 1);
        map.put(1, 2);
        // then
        assertThat(map.size()).isEqualTo(1);
    }

    @Test
    @DisplayName("given empty entries " +
            "when put three entries " +
            "then return size 3")
    public void put3() {
        // when
        map.put(1, 1);
        map.put(2, 2);
        map.put(3, 3);
        // then
        assertThat(map.size()).isEqualTo(3);
    }

    @Test
    @DisplayName("should return value " +
            "when call put")
    public void put4() {
        // when
        Integer value = map.put(1, 1);
        // then
        assertThat(value).isEqualTo(1);
    }

    @Test
    @DisplayName("given empty entries " +
            "when put multiple entries(hash conflict) " +
            "then values keep the latest value per key")
    public void put5() {
        // given
        SimpleHashMap<HashConflict, Integer> map = new SimpleHashMap<>();
        // when
        map.put(new HashConflict(1), 1);
        map.put(new HashConflict(2), 2);
        map.put(new HashConflict(3), 3);
        map.put(new HashConflict(3), 4);
        map.put(new HashConflict(3), 5);
        map.put(new HashConflict(4), 4);
        map.put(new HashConflict(5), 5);
        // then
        assertThat(Lists.newArrayList(map.values())).isEqualTo(Lists.list(1, 2, 5, 4, 5));
    }

    @Test
    @DisplayName("should auto grow " +
            "when capacity exceed threshold")
    public void put6() {
        // given default threshold = 12 (16 * 0.75)
        // when
        for (int i = 1; i <= 20; i++) {
            map.put(i, i);
        }
        // then
        assertThat(map.size()).isEqualTo(20);
        assertThat(map.get(20)).isEqualTo(20);
    }
    /************ put test end **********/

    /************ get test start **********/
    @Test
    @DisplayName("given empty entries " +
            "when get by null key " +
            "then return null")
    public void get1() {
        // when
        Integer value = map.get(null);
        // then
        assertThat(value).isNull();
    }

    @Test
    @DisplayName("given empty entries " +
            "when get value by nonexistent key " +
            "then return null")
    public void get2() {
        // when
        Integer value = map.get(2);
        // then
        assertThat(value).isNull();
    }

    @Test
    @DisplayName("given entry " +
            "when get value by nonexistent key " +
            "then return null")
    public void get3() {
        // given
        map.put(1, 1);
        // when
        Integer value = map.get(2);
        // then
        assertThat(value).isNull();
    }

    @Test
    @DisplayName("given entry " +
            "when get value " +
            "then return value")
    public void get4() {
        // given
        map.put(1, 1);
        // when
        Integer value = map.get(1);
        // then
        assertThat(value).isEqualTo(1);
    }

    @Test
    @DisplayName("given multiple entries(hash conflict) " +
            "when get value by hash conflict key " +
            "then return value")
    public void get5() {
        // given
        SimpleHashMap<HashConflict, Integer> map = new SimpleHashMap<>();
        map.put(new HashConflict(1), 1);
        map.put(new HashConflict(2), 2);
        map.put(new HashConflict(3), 3);
        map.put(new HashConflict(3), 4);
        map.put(new HashConflict(3), 5);
        map.put(new HashConflict(4), 4);
        map.put(new HashConflict(5), 5);
        // when
        Integer value = map.get(new HashConflict(3));
        // then
        assertThat(value).isEqualTo(5);
    }

    @Test
    @DisplayName("given multiple entries(hash conflict) " +
            "when get value by nonexistent hash conflict key " +
            "then return null")
    public void get6() {
        // given
        SimpleHashMap<HashConflict, Integer> map = new SimpleHashMap<>();
        map.put(new HashConflict(1), 1);
        map.put(new HashConflict(2), 2);
        map.put(new HashConflict(3), 3);
        map.put(new HashConflict(4), 4);
        map.put(new HashConflict(5), 5);
        // when
        Integer value = map.get(new HashConflict(6));
        // then
        assertThat(value).isNull();
    }
    /************ get test end **********/


    /************ remove test start **********/
    @Test
    @DisplayName("given empty entries " +
            "when remove by null key " +
            "then return null")
    public void remove1() {
        // when
        Integer value = map.remove(null);
        // then
        assertThat(value).isNull();
    }

    @Test
    @DisplayName("given entry " +
            "when remove by null key " +
            "then return null")
    public void remove2() {
        // given
        map.put(1, 1);
        // when
        Integer value = map.remove(null);
        // then
        assertThat(value).isNull();
    }

    @Test
    @DisplayName("given entry " +
            "when remove by key " +
            "then return value")
    public void remove3() {
        // given
        map.put(1, 1);
        // when
        int value = map.remove(1);
        // then
        assertThat(value).isEqualTo(1);
    }

    @Test
    @DisplayName("given entry " +
            "when remove by nonexistent key " +
            "then return null")
    public void remove4() {
        // given
        map.put(1, 1);
        // when
        Integer value = map.remove(2);
        // then
        assertThat(value).isNull();
    }

    @Test
    @DisplayName("given multiple entries(hash conflict) " +
            "when remove by hash conflict key " +
            "then return value")
    public void remove5() {
        // given
        SimpleHashMap<HashConflict, Integer> map = new SimpleHashMap<>();
        map.put(new HashConflict(1), 1);
        map.put(new HashConflict(2), 2);
        map.put(new HashConflict(3), 3);
        map.put(new HashConflict(4), 4);
        map.put(new HashConflict(5), 5);
        // when
        Integer value = map.remove(new HashConflict(3));
        // then
        assertThat(value).isEqualTo(3);
        assertThat(Lists.newArrayList(map.values())).isEqualTo(Lists.list(1, 2, 4, 5));
    }
    /************ remove test end **********/


    /************ values test start **********/
    @Test
    @DisplayName("given empty entries " +
            "when call values " +
            "then return empty values")
    public void values1() {
        // when
        Iterable<Integer> values = map.values();
        // then
        assertThat(values).isEmpty();
    }

    @Test
    @DisplayName("given multiple entries " +
            "when call values " +
            "then return all values")
    public void values2() {
        // given
        map.put(1, 1);
        map.put(2, 2);
        map.put(3, 3);
        map.put(3, 4);
        map.put(4, 4);
        map.remove(4);
        // when
        Iterable<Integer> values = map.values();
        // then
        assertThat(values.spliterator().estimateSize()).isEqualTo(3);
        assertThat(Lists.newArrayList(values)).isEqualTo(Lists.list(1, 2, 4));
    }
    /************ values test end **********/


    /************ containsKey test start **********/
    @Test
    @DisplayName("given entry " +
            "when key exists " +
            "then return true")
    public void containsKey1() {
        // given
        map.put(1, 1);
        // when
        boolean result = map.containsKey(1);
        // then
        assertThat(result).isTrue();
    }

    @Test
    @DisplayName("given entry " +
            "when key does not exist " +
            "then return false")
    public void containsKey2() {
        // given
        map.put(1, 1);
        // when
        boolean result = map.containsKey(2);
        // then
        assertThat(result).isFalse();
    }

    @Test
    @DisplayName("given multiple entries(hash conflict) " +
            "when call containsKey " +
            "then return correct result")
    public void containsKey3() {
        // given
        SimpleHashMap<HashConflict, Integer> map = new SimpleHashMap<>();
        map.put(new HashConflict(1), 1);
        map.put(new HashConflict(2), 2);
        map.put(new HashConflict(3), 3);
        map.put(new HashConflict(4), 4);
        map.put(new HashConflict(5), 5);
        // then
        assertThat(map.containsKey(new HashConflict(3))).isTrue();
        assertThat(map.containsKey(new HashConflict(5))).isTrue();
        assertThat(map.containsKey(new HashConflict(6))).isFalse();
    }
    /************ containsKey test end **********/


    /************ forEach test start **********/
    @Test
    @DisplayName("given multiple entries " +
            "when call forEach " +
            "then pass")
    public void forEach1() {
        // given
        map.put(1, 1);
        map.put(2, 2);
        map.put(3, 3);
        map.put(4, 4);
        // when
        List<Integer> results = new ArrayList<>();
        map.forEach((key) -> results.add(map.get(key)));
        // then
        assertThat(results).isEqualTo(Lists.list(1, 2, 3, 4));
    }

    @Test
    @DisplayName("given multiple entries(hash conflict) " +
            "when call forEach " +
            "then pass")
    public void forEach2() {
        // given
        SimpleHashMap<HashConflict, Integer> map = new SimpleHashMap<>();
        map.put(new HashConflict(1), 1);
        map.put(new HashConflict(2), 2);
        map.put(new HashConflict(3), 3);
        map.put(new HashConflict(4), 4);
        map.put(new HashConflict(5), 5);
        // when
        List<Integer> results = new ArrayList<>();
        map.forEach((key) -> results.add(map.get(key)));
        // then
        assertThat(results).isEqualTo(Lists.list(1, 2, 3, 4, 5));
    }

    /************ forEach test end **********/

    class HashConflict {
        private int field;

        HashConflict(int field) {
            this.field = field;
        }

        @Override
        public int hashCode() {
            return this.field <= 8 ? 1 : this.field;
        }

        @Override
        public boolean equals(Object obj) {
            // Guard against null and foreign types instead of casting blindly.
            return obj instanceof HashConflict
                    && ((HashConflict) obj).field == this.field;
        }
    }
}

SimpleHashMap source code

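import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.function.Consumer;

/**
 * @author lyning
 */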
public class SimpleHashMap<K, V> {
    private static final int DEFAULT_INITIAL_CAPACITY = 16;
    private static final float DEFAULT_LOAD_FACTOR = 0.75f;
    private int size;
    private Bucket<K, V>[] table;
    private int threshold;

    public boolean containsKey(K key) {
        // The table is allocated lazily on the first put, so it may still be null.
        if (this.tableEmpty()) {
            return false;
        }
        int hash = this.hash(key);
        int index = this.index(hash);
        Bucket<K, V> bucket = this.table[index];
        return bucket != null
                && bucket.lookup(key) != null;
    }

    public void forEach(Consumer<? super K> action) {
        if (this.tableEmpty()) {
            return;
        }
        for (Bucket<K, V> bucket : this.table) {
            while (bucket != null) {
                action.accept(bucket.key);
                bucket = bucket.next;
            }
        }
    }

    public V get(K key) {
        if (this.tableEmpty()) {
            return null;
        }
        int hash = this.hash(key);
        int index = this.index(hash);
        return this.getVal(index, key);
    }

    public V put(K key, V value) {
        if (this.tableEmpty() || this.nearByThreshold()) {
            this.resize();
        }
        int hash = this.hash(key);
        return this.putVal(key, value, hash);
    }

    public V remove(K key) {
        if (this.tableEmpty()) {
            return null;
        }
        int hash = this.hash(key);
        int index = this.index(hash);
        return this.removeVal(index, key);
    }

    public int size() {
        return this.size;
    }

    public Iterable<V> values() {
        if (this.tableEmpty()) {
            return new ArrayList<>();
        }
        List<V> collections = new ArrayList<>();
        this.collectValues(collections);
        return collections;
    }

    private void collectValues(List<V> collections) {
        for (Bucket<K, V> bucket : this.table) {
            while (bucket != null) {
                collections.add(bucket.value);
                bucket = bucket.next;
            }
        }
    }

    private Bucket<K, V> findBucket(int index) {
        return this.table[index];
    }

    private V getVal(int index, K key) {
        Bucket<K, V> bucket = this.findBucket(index);
        if (Objects.isNull(bucket) || Objects.isNull(bucket = bucket.lookup(key))) {
            return null;
        }
        return bucket.value;
    }

    private void grow(int newCap) {
        if (this.tableEmpty()) {
            this.initTable(newCap);
            return;
        }
        this.table = this.rebuildTable(newCap);
    }

    private int hash(K key) {
        int hashcode;
        return key == null
                ? 0
                : (hashcode = key.hashCode()) ^ (hashcode >>> 16);
    }

    private int index(int hash) {
        return hash & (this.table.length - 1);
    }

    private void initTable(int newCap) {
        this.table = new Bucket[newCap];
    }

    private boolean nearByThreshold() {
        return this.size + 1 >= this.threshold;
    }

    private V putVal(K key, V value, int hash) {
        int index = this.index(hash);
        Bucket<K, V> bucket = this.table[index];

        if (Objects.isNull(bucket)) {
            this.table[index] = new Bucket<>(hash, key, value);
        } else {
            Bucket<K, V> indexBucket = bucket.lookup(key);
            if (indexBucket != null) {
                indexBucket.value = value;
                return value;
            }
            bucket.putLast(new Bucket<>(hash, key, value));
        }
        this.size += 1;
        return value;
    }

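    private Bucket<K, V>[] rebuildTable(int newCap) {
        Bucket<K, V>[] oldTable = this.table;
        Bucket<K, V>[] newTable = new Bucket[newCap];
        for (Bucket<K, V> bucket : oldTable) {
            // Walk the whole chain: each node may belong to a different bucket
            // once its index is recomputed against the new capacity.
            while (bucket != null) {
                Bucket<K, V> next = bucket.next;
                bucket.next = null;
                int index = bucket.hash & (newCap - 1);
                if (newTable[index] == null) {
                    newTable[index] = bucket;
                } else {
                    newTable[index].putLast(bucket);
                }
                bucket = next;
            }
        }
        return newTable;
    }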

    private V removeVal(int index, K key) {
        Bucket<K, V> bucket = this.findBucket(index);
        Bucket<K, V> prev = null;
        while (bucket != null) {
            if (bucket.matchKey(key)) {
                if (Objects.isNull(prev)) {
                    // Removing the head must keep the rest of the chain.
                    this.table[index] = bucket.next;
                } else {
                    prev.next = bucket.next;
                }
                this.size -= 1;
                return bucket.value;
            }
            prev = bucket;
            bucket = bucket.next;
        }
        return null;
    }

    private void resize() {
        int oldCap = this.tableCapacity();
        int newCap = 0;
        if (oldCap == 0) {
            oldCap = DEFAULT_INITIAL_CAPACITY;
            this.threshold = (int) (DEFAULT_INITIAL_CAPACITY * DEFAULT_LOAD_FACTOR);
        } else {
            newCap = oldCap << 1;
            this.threshold = this.threshold << 1;
        }

        if (newCap == 0) {
            newCap = oldCap;
        }
        this.grow(newCap);
    }

    private int tableCapacity() {
        return Objects.isNull(this.table) ? 0 : this.table.length;
    }

    private boolean tableEmpty() {
        return Objects.isNull(this.table);
    }

    static class Bucket<K, V> {
        Bucket<K, V> next;
        int hash;
        K key;
        V value;

        public Bucket(int hash, K key, V value) {
            this.hash = hash;
            this.key = key;
            this.value = value;
        }

        public Bucket<K, V> lookup(K key) {
            Bucket<K, V> bucket = this;
            while (bucket != null) {
                if (bucket.matchKey(key)) {
                    return bucket;
                }
                bucket = bucket.next;
            }
            return null;
        }

        public boolean matchKey(K key) {
            // Objects.equals is null-safe, so a stored null key cannot cause an NPE.
            return Objects.equals(this.key, key);
        }

        public void putLast(Bucket<K, V> bucket) {
            this.last().next = bucket;
        }

        private Bucket<K, V> last() {
            Bucket<K, V> bucket = this;
            while (true) {
                if (Objects.isNull(bucket.next)) {
                    return bucket;
                }
                bucket = bucket.next;
            }
        }
    }
}
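
The fail-fast task from the Tasking list is not covered by the source above. A minimal sketch of the idea follows. It is an illustration only, not the author's code: the `FailFastBag` class is hypothetical and uses a plain list in place of the hash table so that the `modCount` check stays in focus.

```java
import java.util.ArrayList;
import java.util.ConcurrentModificationException;
import java.util.List;
import java.util.function.Consumer;

public class FailFastBag<T> {
    private final List<T> items = new ArrayList<>();
    // Counts structural modifications; bumped on every add.
    private int modCount;

    public void add(T item) {
        items.add(item);
        modCount++;
    }

    public int size() {
        return items.size();
    }

    public void forEach(Consumer<? super T> action) {
        // Snapshot the counter before iterating.
        int expected = modCount;
        // Index-based loop so that only our own check can raise the exception.
        for (int i = 0; i < items.size(); i++) {
            action.accept(items.get(i));
            // If the action changed the structure, stop right away.
            if (modCount != expected) {
                throw new ConcurrentModificationException();
            }
        }
    }
}
```

In SimpleHashMap the same pattern would presumably mean incrementing `modCount` in put and remove whenever the structure changes, and comparing it inside forEach and any iterator returned by values.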

Summary

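The hardest part is the red-black tree. It is extremely complex: after an hour I still had not grasped the key points, so I used a linked list instead. When I have time, I will sit down and finish the remaining tasks.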

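Understand the problem, do the tasking, then apply TDD (including refactoring): these are the rules I have been following recently, and I hope they offer you some food for thought.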

Source code
