keras自定義網絡層

在深度學習領域,Keras是一個高度封裝的庫並被普遍應用,能夠經過調用其內置網絡模塊(各類網絡層)實現針對性的模型結構;當所須要的網絡層功能不被包含時,則須要經過自定義網絡層或模型實現。網絡

如何在keras框架下自定義層,基本「套路」以下。框架

通常地,keras中的網絡層是一個類,因此自定義層即編寫一個類,更爲重要的是這個類(即自定義層)須要繼承Layer父類,並且須要實現如下四種方法:

  1. __init __ (self, output_dim, **kwargs)

這個方法是用來初始化並自定義自定義層所需的屬性,好比output_dim;
此外,該方法須要執行super().__init __(**kwargs),這行代碼是執行Layer類中的初始化函數;
當執行上述代碼就沒有必要去管input_shape,weights,trainable等關鍵字參數,由於父類(Layer)的初始化函數實現了它們與layer實例的綁定。函數

  1. build(self, input_shape)

這個方法是用來建立層的權重;
在該方法中,根據以前的繼承,經過Layer類的add_weight方法來自定義並添加一個權重矩陣,這個方法須要input_shape參數;
該方法必須設self.built = True,目的是爲了保證這個層的權重定義函數build被執行過了;
在built函數中,須要說明這個權重各方面的屬性,好比shape、初始化方式以及可訓練性等信息。學習

  1. call(self, x)

這個方法是用來編寫層的功能邏輯;
在該方法中,須要關注傳入call的第一個參數:輸入張量x;x只能是一種形式變量,不能是具體的變量,即它不能被定義;
這個call函數就是該層的計算邏輯,當建立好這個層實例後,該實例能夠執行call函數;
可見,這個層的核心應該是一段符號式的輸入張量到輸出張量的計算過程。ui

  1. compute_output_shape(self, input_shape)

這個方法是用來保證輸出shape是正確的;
這裏重寫compute_output_shape方法去覆蓋父類中的同名方法,來保證輸出的shape符合實際;
父類Layer中的compute_output_shape方法直接返回的是input_shape這明顯是不對的,因此須要重寫該方法。blog

示例

結合官方文檔的例子,給出以下一個自定義層的代碼:
繼承

使用自定義層,就如同使用keras內置網絡層同樣,以下圖所示:(另外,本例使用kears內置的激活函數層ReLU承接自定義層的輸出,從而避免將激活函數的功能加入到自定義層中)
文檔

相關文章
相關標籤/搜索