[TOC]算法
1、概述
-
Meta Learning = Learn to learn網絡
讓機器去學習如何進行學習:使用一系列的任務來訓練模型,模型根據在這些任務上汲取的經驗,成爲了一個強大的學習者,可以更快的學習新任務。app
-
Meta Learning VS Lifelong Learning框架
- 終身學習:着眼於用同一個模型去學習不一樣的任務。
- 元學習:不一樣任務使用不一樣的模型,元學習者積累經驗後,在新任務上訓練的更快更好。
-
Meta Learning VS Machine Learning機器學習
- 機器學習:核心是經過人爲設計的學習算法(Learning Algorithm),利用訓練數據訓練獲得一個函數f,這個函數能夠用於新數據的預測分類。
- 元學習:讓機器本身學習找出最優的學習算法。根據提供的訓練數據找到一個能夠找到函數f的函數F的能力。
2、元學習的實現框架
-
定義一系列的學習算法函數
不一樣的網絡結構、參數初始化策略、參數更新策略決定了不一樣學習算法。性能
-
定義學習算法函數F的評價標準學習
綜合考慮學習算法F針對不一樣任務產生的函數f在進行測試時獲得的損失。測試
-
選取最好的學習算法F*=argminL(F)spa
最佳學習算法通常能夠經過梯度降低方法來肯定。
3、元學習的訓練數據
-
機器學習
機器學習的訓練數據和測試數據來自同一分佈的數據集。
-
元學習:
-
元學習的訓練數據是由一個個的訓練任務構成的,一個訓練任務對應一個傳統的機器學習的應用實例。
- 須要大批次數據的訓練任務顯然難以進行元學習訓練,所以常規的元學習的訓練任務通常是Few Shot Learning類型的任務,即經過少許數據就能構建一個任務,進行快速的學習與訓練。 > - 考慮到運算性能,現階段的元學習常常是與Few Shot Learning綁定在一塊兒。
-
訓練數據分爲訓練任務集和測試任務集。
-
任務集中的每個任務的訓練數據即傳統的機器學習應用實例中的訓練數據集和測試數據集,不過爲了區分訓練任務(Training Set)和測試任務(Testing Test) ,這裏將它們命名爲支持集(Support Set)和查詢集(Query Set).
-
4、元學習的Benchmarks
-
Omniglot數據集
-
組成
-
整個數據集由1623個符號(Characters)組成;
-
每一個符號有20個樣例(Examples),每一個樣例由不一樣的人書寫.
-
-
使用:結合Few-shot Learning中的N-ways K-shot分類問題
- 對於每個訓練任務和測試任務,樣本數據分爲N個類,每一個類提供K個樣本。
- 整個字符集分爲訓練字符集(Training Set or Support Set)和測試字符集(Testing Set or Query Set)
- 訓練任務:從訓練字符集中抽取N個類的字符,每種字符抽取K個樣本,組成一個訓練任務的訓練數據
- 測試任務:從測試字符集中抽取N個類的字符,每種字符抽取K個樣本,組成一個測試任務的訓練數據
-
-
MAML
Finn C, Abbeel P, Levine S. Model-agnostic meta-learning for fast adaptation of deep networks[C]//Proceedings of the 34th International Conference on Machine Learning-Volume 70. JMLR. org, 2017: 1126-1135.
-
損失函數 (Loss Function):$L(\Phi)= \sum_{n=1}^N{l^n(\hat{\theta}^n)}$
- $\hat{\theta}^n$:第n個任務中學習到的模型參數,取決於參數$\Phi$。
- $l^n(\hat{\theta}^n)$:第n個任務在其測試集上獲得的損失。
-
損失函數最小化:使用梯度降低(Gradient Descent)
$$\Phi\leftarrow\Phi-\eta\nabla_\Phi{L(\Phi)}$$
-
只考慮一次訓練以後對初始化參數的梯度更新。
- 只取進行一次梯度更新後的參數做爲當前任務的最佳參數。
- 上式求出的是元學習模型的通用參數,下式求出的是每一個任務的最佳參數。
- $L(\Phi)$和$\hat{\theta}$用於元學習模型的參數更新。
- 既能加快模型的適應速度,在必定程度上還能減輕過擬合。、
$$\hat{\theta}=\Phi-\epsilon\nabla_\Phi{l(\Phi)}$$
-
總體執行流程:
- 將每個訓練任務和測試任務的模型參數初始化:$\Phi_0$
- 對每個任務執行一次梯度更新獲得新的模型參數:$\hat{\theta}$
- 綜合考慮全部訓練任務在$\hat{\theta}$下的損失:$L(\Phi)$
- 對$L(\Phi)$執行梯度更新,獲得最優的元學習模型的參數:$\Phi$
- 將該$\Phi$用於測試任務,檢驗更新效果。
-
二階微分與一階近似(數學推導):
- 訓練過程的參數更新公式以下:
$$ \Phi\leftarrow\Phi-\eta\nabla_\Phi{L(\Phi)} \ L(\Phi)= \sum_{n=1}^N{l^n(\hat{\theta}^n)} \ \hat{\theta}=\Phi-\epsilon\nabla_\Phi{l(\Phi)} $$
-
$ \nabla_\Phi{L(\Phi)} $的計算 $$ \nabla_\Phi{L(\Phi)}=\nabla_\Phi{\sum_{n=1}^{N}l^n(\hat{\theta}^n)}=\sum_{n=1}^{N}\nabla_\Phi{l^n(\hat{\theta}^n)} \ $$
-
其中$\nabla_\Phi{l^n(\hat{\theta}^n)}$爲: $$ \nabla_\Phi{l(\hat{\theta})}=\left| \begin{matrix} \partial l(\hat{\theta})/\partial \Phi_1\ \partial l(\hat{\theta})/\partial \Phi_2\ \vdots\ \partial l(\hat{\theta})/\partial \Phi_i\ \end{matrix} \right| $$
$\Phi_i$表示模型的各個參數(Weight),$\Phi_i$決定當前任務的$\hat{\theta}$的第j個參數$\hat{\theta}_j$,從而影響$l(\hat{\theta})$
-
根據三者之間的關係:$\Phi_i \rightarrow \hat{\theta}_j \rightarrow l(\hat{\theta})$,有: $$ \frac{\partial l(\hat{\theta})}{\partial \Phi_i}=\sum_j\sum_i {\frac{\partial l(\hat{\theta})}{\partial \hat{\theta}_j}\frac{\partial \hat{\theta}_j}{\partial \hat{\Phi}_i}} $$
-
又由於根據參數更新公式(3),取$\hat{theta}$的第j維爲例,有: $$ \hat{\theta}_j=\Phi_j-\epsilon \frac{\partial l(\Phi)}{\partial \Phi_j} $$
-
求$\hat{\theta}_j$對$\Phi_j$的偏導,有: $$ \frac{\partial \hat{\theta}_j}{\partial \hat{\Phi}_i}= \begin{cases} -\epsilon \frac{\partial l(\Phi)}{\partial \Phi_j \partial \Phi_i},i \neq j\ 1-\epsilon \frac{\partial l(\Phi)}{\partial \Phi_j \partial \Phi_i},i = j \end{cases} $$
將該式代回到$\frac{\partial l(\hat{\theta})}{\partial \Phi_i}$中便可求出$\nabla_\Phi{L(\Phi)}$。
但實際上該式存在二次微分的計算,會極大的影響運算效率。
-
做者用一次微分來近似代替二次微分的結果: $$ \frac{\partial \hat{\theta}_j}{\partial \hat{\Phi}_i}= \begin{cases} -\epsilon \frac{\partial l(\Phi)}{\partial \Phi_j \partial \Phi_i} \approx{0} ,i \neq j\ 1-\epsilon \frac{\partial l(\Phi)}{\partial \Phi_j \partial \Phi_i} \approx{1},i = j \end{cases} $$
-
因此 $$ \frac{\partial l(\hat{\theta})}{\partial \Phi_i}=\sum_j\sum_i {\frac{\partial l(\hat{\theta})}{\partial \hat{\theta}_j}\frac{\partial \hat{\theta}_j}{\partial \hat{\Phi}_i}} \approx \frac{\partial l(\hat{\theta})}{\partial \hat{\theta}i}\ \nabla\Phi{l(\hat{\theta})}=\left| \begin{matrix} \partial l(\hat{\theta})/\partial \Phi_1\ \partial l(\hat{\theta})/\partial \Phi_2\ \vdots\ \partial l(\hat{\theta})/\partial \Phi_i\ \end{matrix} \right|=\left| \begin{matrix} \partial l(\hat{\theta})/\partial \hat{\theta}_1\ \partial l(\hat{\theta})/\partial \hat{\theta}_2\ \vdots\ \partial l(\hat{\theta})/\partial \hat{\theta}i\ \end{matrix} \right|=\nabla\hat{\theta}{l(\hat{\theta})} $$
-
-
-
因此$ \nabla_\Phi{L(\Phi)} $能夠化爲: $$ \nabla_\Phi{L(\Phi)}=\nabla_\Phi{\sum_{n=1}^{N}l^n(\hat{\theta}^n)}=\sum_{n=1}^{N}\nabla_\Phi{l^n(\hat{\theta}^n)}=\sum_{n=1}^{N}\nabla_{\hat{\theta}^n}{l^n(\hat{\theta}^n)} $$
經過將二階微分近似爲一階微分,提高運算效率的同時對模型預測的準確率沒有太大的影響。
-
-
Reptile
Nichol A, Achiam J, Schulman J. On first-order meta-learning algorithms[J]. arXiv preprint arXiv:1803.02999, 2018.
-
基本思想
-
基於MAML進行改善,對參數更新次數不加限制。
-
-
Reptile VS Pretraining VS MAML
-