pytorch代碼倉庫html
pytorch在19年11月份的時候合入了這部分剪枝的代碼。pytorch提供一些直接可用的api,用戶只須要傳入須要剪枝的module實例和須要剪枝的參數名字,系統自動幫助完成剪枝操做,看起來接口挺簡單。好比 def random_structured(module, name, amount, dim)git
pytorch支持的幾種類型的剪枝策略:

詳細分析
- pytorch提供了一個剪枝的抽象基類‘‘class BasePruningMethod(ABC)’,全部剪枝策略都須要繼承該基類,並重載部分函數就能夠了
- 通常狀況下須要重載__init__和compute_mask方法,__call__, apply_mask, apply, prune和remove不須要重載,例如官方提供的RandomUnstructured剪枝方法


- 剪枝的API接口,能夠看到支持用戶自定義的剪枝mask,接口爲custom_from_mask

- API的實現,使用classmethod的方法,剪枝策略的實例化在框架內部完成,不須要用戶實例化
-
剪枝的大隻過程:github
- 根據用戶選擇的剪枝API生成對應的策略實例,此時會判斷須要作剪枝操做的module上是否已經掛有前向回調函數,沒有則生成新的,有了就在老的上面添加,而且生成PruningContainer。從這裏能夠看出,對於同一個module使用多個剪枝策略時,pytorch經過PruningContainer來對剪枝策略進行管理。PruningContainer自己也是繼承自BasePruningMethod。同時設置前向計算的回調,便於後續訓練時調用。
- 接着根據用戶輸入的module和name,找到對應的參數tensor。若是是第一次剪枝,那麼須要生成_orig結尾的tensor,而後刪除原始的module上的tensor。如name爲bias,那麼生成bias_orig存起來,而後刪除module.bias屬性。
- 獲取defaultmask,而後調用method.computemask生成當前策略的mask值。生成的mask會被存在特定的緩存module.register_buffer(name + "_mask", mask)。這裏的compute_mask多是兩種狀況:若是隻有一個策略,那麼調用的時候對應剪枝策略的compute_mask方法,若是一個module有多個剪枝策略組合,那麼調用的應該是PruningContainer的compute_mask

4. 執行剪枝,保存剪枝結果到module的屬性,註冊訓練時的剪枝回調函數,剪枝完成。新的mask應用在orig的tensor上面生成新的tensor保存的對應的name屬性

pytorch還提供各種一個remove接口,目的是把以前的剪枝結果持久化,具體操做就是刪除以前生成的跟剪枝相關的緩存或者是回調hook接口,設置被剪枝的name參數(如bias)爲最後一次訓練的值。
api
-
本身寫一個剪枝策略接口也是能夠的:
緩存
- 先寫一個剪枝策略類繼承BasePruningMethod
- 而後重載基類的compute_mask方法,寫本身的計算mask方法
官方完整教程在這裏app