Python 中 lru_cache 的使用和實現

在計算機軟件領域,緩存(Cache)指的是將部分數據存儲在內存中,以便下次可以更快地訪問這些數據,這也是一個典型的用空間換時間的例子。通常用於緩存的內存空間是固定的,當有更多的數據須要緩存的時候,須要將已緩存的部分數據清除後再將新的緩存數據放進去。須要清除哪些數據,就涉及到了緩存置換的策略,LRU(Least Recently Used,最近最少使用)是很常見的一個,也是 Python 中提供的緩存置換策略。python

下面咱們經過一個簡單的示例來看 Python 中的 lru_cache 是如何使用的。git

def factorial(n):
    print(f"計算 {n} 的階乘")
    return 1 if n <= 1 else n * factorial(n - 1)

a = factorial(5)
print(f'5! = {a}')
b = factorial(3)
print(f'3! = {b}')

上面的代碼中定義了函數 factorial,經過遞歸的方式計算 n 的階乘,而且在函數調用的時候打印出 n 的值。而後分別計算 5 和 3 的階乘,並打印結果。運行上面的代碼,輸出以下程序員

計算 5 的階乘
計算 4 的階乘
計算 3 的階乘
計算 2 的階乘
計算 1 的階乘
5! = 120
計算 3 的階乘
計算 2 的階乘
計算 1 的階乘
3! = 6

能夠看到,factorial(3) 的結果在計算 factorial(5) 的時候已經被計算過了,可是後面又被重複計算了。爲了不這種重複計算,咱們能夠在定義函數 factorial 的時候加上 lru_cache 裝飾器,以下所示github

import functools
# 注意 lru_cache 後的一對括號,證實這是帶參數的裝飾器
@functools.lru_cache()
def factorial(n):
    print(f"計算 {n} 的階乘")
    return 1 if n <= 1 else n * factorial(n - 1)

從新運行代碼,輸入以下算法

計算 5 的階乘
計算 4 的階乘
計算 3 的階乘
計算 2 的階乘
計算 1 的階乘
5! = 120
3! = 6

能夠看到,此次在調用 factorial(3) 的時候沒有打印相應的輸出,也就是說 factorial(3) 是直接從緩存讀取的結果,證實緩存生效了。緩存

被 lru_cache 修飾的函數在被相同參數調用的時候,後續的調用都是直接從緩存讀結果,而不用真正執行函數。下面咱們深刻源碼,看看 Python 內部是怎麼實現 lru_cache 的。寫做時 Python 最新發行版是 3.9,因此這裏使用的是 Python 3.9 的源碼,而且保留了源碼中的註釋。數據結構

def lru_cache(maxsize=128, typed=False):
    """Least-recently-used cache decorator.
    If *maxsize* is set to None, the LRU features are disabled and the cache
    can grow without bound.
    If *typed* is True, arguments of different types will be cached separately.
    For example, f(3.0) and f(3) will be treated as distinct calls with
    distinct results.
    Arguments to the cached function must be hashable.
    View the cache statistics named tuple (hits, misses, maxsize, currsize)
    with f.cache_info().  Clear the cache and statistics with f.cache_clear().
    Access the underlying function with f.__wrapped__.
    See:  http://en.wikipedia.org/wiki/Cache_replacement_policies#Least_recently_used_(LRU)
    """

    # Users should only access the lru_cache through its public API:
    #       cache_info, cache_clear, and f.__wrapped__
    # The internals of the lru_cache are encapsulated for thread safety and
    # to allow the implementation to change (including a possible C version).
    
    if isinstance(maxsize, int):
        # Negative maxsize is treated as 0
        if maxsize < 0:
            maxsize = 0
    elif callable(maxsize) and isinstance(typed, bool):
        # The user_function was passed in directly via the maxsize argument
        user_function, maxsize = maxsize, 128
        wrapper = _lru_cache_wrapper(user_function, maxsize, typed, _CacheInfo)
        wrapper.cache_parameters = lambda : {'maxsize': maxsize, 'typed': typed}
        return update_wrapper(wrapper, user_function)
    elif maxsize is not None:
        raise TypeError(
            'Expected first argument to be an integer, a callable, or None')
    
    def decorating_function(user_function):
        wrapper = _lru_cache_wrapper(user_function, maxsize, typed, _CacheInfo)
        wrapper.cache_parameters = lambda : {'maxsize': maxsize, 'typed': typed}
        return update_wrapper(wrapper, user_function)
    
    return decorating_function

這段代碼中有以下幾個關鍵點app

  • 關鍵字參數函數

    maxsize 表示緩存容量,若是爲 None 表示容量不設限, typed 表示是否區分參數類型,註釋中也給出瞭解釋,若是 typed == True,那麼 f(3)f(3.0) 會被認爲是不一樣的函數調用。源碼分析

  • 第 24 行的條件分支

    若是 lru_cache 的第一個參數是可調用的,直接返回 wrapper,也就是把 lru_cache 當作不帶參數的裝飾器,這是 Python 3.8 纔有的特性,也就是說在 Python 3.8 及以後的版本中咱們能夠用下面的方式使用 lru_cache,多是爲了防止程序員在使用 lru_cache 的時候忘記加括號。

    import functools
    # 注意 lru_cache 後面沒有括號,
    # 證實這是將其當作不帶參數的裝飾器
    @functools.lru_cache
    def factorial(n):
        print(f"計算 {n} 的階乘")
        return 1 if n <= 1 else n * factorial(n - 1)

    注意,Python 3.8 以前的版本運行上面代碼會報錯:TypeError: Expected maxsize to be an integer or None。

lru_cache 的具體邏輯是在 _lru_cache_wrapper 函數中實現的,仍是同樣,列出源碼,保留註釋。

def _lru_cache_wrapper(user_function, maxsize, typed, _CacheInfo):
    # Constants shared by all lru cache instances:
    sentinel = object()          # unique object used to signal cache misses
    make_key = _make_key         # build a key from the function arguments
    PREV, NEXT, KEY, RESULT = 0, 1, 2, 3   # names for the link fields

    cache = {}
    hits = misses = 0
    full = False
    cache_get = cache.get    # bound method to lookup a key or return None
    cache_len = cache.__len__  # get cache size without calling len()
    lock = RLock()           # because linkedlist updates aren't threadsafe
    root = []                # root of the circular doubly linked list
    root[:] = [root, root, None, None]     # initialize by pointing to self

    if maxsize == 0:

        def wrapper(*args, **kwds):
            # No caching -- just a statistics update
            nonlocal misses
            misses += 1
            result = user_function(*args, **kwds)
            return result

    elif maxsize is None:

        def wrapper(*args, **kwds):
            # Simple caching without ordering or size limit
            nonlocal hits, misses
            key = make_key(args, kwds, typed)
            result = cache_get(key, sentinel)
            if result is not sentinel:
                hits += 1
                return result
            misses += 1
            result = user_function(*args, **kwds)
            cache[key] = result
            return result

    else:

        def wrapper(*args, **kwds):
            # Size limited caching that tracks accesses by recency
            nonlocal root, hits, misses, full
            key = make_key(args, kwds, typed)
            with lock:
                link = cache_get(key)
                if link is not None:
                    # Move the link to the front of the circular queue
                    link_prev, link_next, _key, result = link
                    link_prev[NEXT] = link_next
                    link_next[PREV] = link_prev
                    last = root[PREV]
                    last[NEXT] = root[PREV] = link
                    link[PREV] = last
                    link[NEXT] = root
                    hits += 1
                    return result
                misses += 1
            result = user_function(*args, **kwds)
            with lock:
                if key in cache:
                    # Getting here means that this same key was added to the
                    # cache while the lock was released.  Since the link
                    # update is already done, we need only return the
                    # computed result and update the count of misses.
                    pass
                elif full:
                    # Use the old root to store the new key and result.
                    oldroot = root
                    oldroot[KEY] = key
                    oldroot[RESULT] = result
                    # Empty the oldest link and make it the new root.
                    # Keep a reference to the old key and old result to
                    # prevent their ref counts from going to zero during the
                    # update. That will prevent potentially arbitrary object
                    # clean-up code (i.e. __del__) from running while we're
                    # still adjusting the links.
                    root = oldroot[NEXT]
                    oldkey = root[KEY]
                    oldresult = root[RESULT]
                    root[KEY] = root[RESULT] = None
                    # Now update the cache dictionary.
                    del cache[oldkey]
                    # Save the potentially reentrant cache[key] assignment
                    # for last, after the root and links have been put in
                    # a consistent state.
                    cache[key] = oldroot
                else:
                    # Put result in a new link at the front of the queue.
                    last = root[PREV]
                    link = [last, root, key, result]
                    last[NEXT] = root[PREV] = cache[key] = link
                    # Use the cache_len bound method instead of the len() function
                    # which could potentially be wrapped in an lru_cache itself.
                    full = (cache_len() >= maxsize)
            return result

    def cache_info():
        """Report cache statistics"""
        with lock:
            return _CacheInfo(hits, misses, maxsize, cache_len())

    def cache_clear():
        """Clear the cache and cache statistics"""
        nonlocal hits, misses, full
        with lock:
            cache.clear()
            root[:] = [root, root, None, None]
            hits = misses = 0
            full = False

    wrapper.cache_info = cache_info
    wrapper.cache_clear = cache_clear
    return wrapper

函數開始的地方 2~14 行定義了一些關鍵變量,

  • hitsmisses 分別表示緩存命中和沒有命中的次數
  • root 雙向循環鏈表的頭結點,每一個節點保存前向指針、後向指針、key 和 key 對應的 result,其中 key 爲 _make_key 函數根據參數結算出來的字符串,result 爲被修飾的函數在給定的參數下返回的結果。注意,root 是不保存數據 key 和 result 的。
  • cache 是真正保存緩存數據的地方,類型爲 dict。cache 中的 key 也是 _make_key 函數根據參數結算出來的字符串,value 保存的是 key 對應的雙向循環鏈表中的節點。

接下來根據 maxsize 不一樣,定義不一樣的 wrapper

  • maxsize == 0,其實也就是沒有緩存,那麼每次函數調用都不會命中,而且沒有命中的次數 misses 加 1。

  • maxsize is None,不限制緩存大小,若是函數調用不命中,將沒有命中次數 misses 加 1,不然將命中次數 hits 加 1。

  • 限制緩存的大小,那麼須要根據 LRU 算法來更新 cache,也就是 42~97 行的代碼。

    • 若是緩存命中 key,那麼將命中節點移到雙向循環鏈表的結尾,而且返回結果(47~58 行)

      這裏經過字典加雙向循環鏈表的組合數據結構,實現了用 O(1) 的時間複雜度刪除給定的節點。

    • 若是沒有命中,而且緩存滿了,那麼須要將最久沒有使用的節點(root 的下一個節點)刪除,而且將新的節點添加到鏈表結尾。在實現中有一個優化,直接將當前的 root 的 key 和 result 替換成新的值,將 root 的下一個節點置爲新的 root,這樣獲得的雙向循環鏈表結構跟刪除 root 的下一個節點而且將新節點加到鏈表結尾是同樣的,可是避免了刪除和添加節點的操做(68~88 行)

    • 若是沒有命中,而且緩存沒滿,那麼直接將新節點添加到雙向循環鏈表的結尾(root[PREV],這裏我認爲是結尾,可是代碼註釋中寫的是開頭)(89~96 行)

最後給 wrapper 添加兩個屬性函數 cache_infocache_clearcache_info 顯示當前緩存的命中狀況的統計數據,cache_clear 用於清空緩存。對於上面階乘相關的代碼,若是在最後執行 factorial.cache_info(),會輸出

CacheInfo(hits=1, misses=5, maxsize=128, currsize=5)

第一次執行 factorial(5) 的時候都沒命中,因此 misses = 5,第二次執行 factorial(3) 的時候,緩存命中,因此 hits = 1。

最後須要說明的是,對於有多個關鍵字參數的函數,若是兩次調用函數關鍵字參數傳入的順序不一樣,會被認爲是不一樣的調用,不會命中緩存。另外,被 lru_cache 裝飾的函數不能包含可變類型參數如 list,由於它們不支持 hash。

總結一下,這篇文章首先簡介了一下緩存的概念,而後展現了在 Python 中 lru_cache 的使用方法,最後經過源碼分析了 Python 中 lru_cache 的實現細節。

相關文章
相關標籤/搜索