三十分鐘理解計算圖上的微積分:Backpropagation,反向微分

神經網路的訓練演算法,目前基本上是以Backpropagation (BP) 反向傳播為主(加上一些變化),NN的訓練是在1986年被提出,但實際上,BP 已經在不同領域中被重複發明了數十次了(參見 Griewank (2010)[1])。更加一般性且與應用場景獨立的名稱叫做:反向微分 (reverse-mode differentiation)。本文是看了資料[2]中的介紹,寫的蠻好,自己記錄一下,方便理解。

從本質上看,BP 是一種快速求導的技術,可以作為一種不單單用在深度學習中並且可以勝任大量數值計算場景的基本的工具。

計算圖

必須先來講一講計算圖的概念,計算圖出現在Bengio 09年的《Learning Deep Architectures for AI》,
Bengio使用了有向圖結構來描述神經網路的計算:

這裡寫圖片描述

整張圖可看成三部分:輸入結點、輸出結點、從輸入到輸出的計算函式。上圖很容易理解,就是output=sin(a*x b) * x

計算圖上的導數

有向無環圖在電腦科學領域到處可見,特別是在函式式程式中。他們與依賴圖(dependency graph)或者呼叫圖(call graph)緊密相關。同樣他們也是大部分非常流行的深度學習框架背後的核心抽象。

下文以下面簡單的例子來描述:

這裡寫圖片描述

假設 a = 2, b = 1,最終表示式的值就是 6。
為了計算在這幅圖中的偏導數,我們需要 和式法則(sum rule )和 乘式法則(product rule):

這裡寫圖片描述

下面,在圖中每條邊上都有對應的導數了:
這裡寫圖片描述

那如果我們想知道哪些沒有直接相連的節點之間的影響關係呢?假設就看看 e 如何被 a 影響的。如果我們以 1 的速度改變 a,那麼 c 也是以 1 的速度在改變,導致 e 發生了 2 的速度在改變。因此 e 是以 1 * 2 的關於 a 變化的速度在變化。
而一般的規則就是對一個點到另一個點的所有的可能的路徑進行求和,每條路徑對應於該路徑中的所有邊的導數之積。因此,為了獲得 e 關於 b 的導數,就採用路徑求和:

這裡寫圖片描述

這個值就代表著 b 改變的速度通過 c 和 d 影響到 e 的速度。聰明的你應該可以想到,事情沒有那麼簡單吧?是的,上面例子比較簡單,在稍微複雜例子中,路徑求和法很容易產生路徑爆炸:

這裡寫圖片描述

在上面的圖中,從 X 到 Y 有三條路徑,從 Y 到 Z 也有三條。如果我們希望計算 dZ/dX,那麼就要對 3 * 3 = 9 條路徑進行求和了:

這裡寫圖片描述

該圖有 9 條路徑,但是在圖更加複雜的時候,路徑數量會指數級地增長。相比於粗暴地對所有的路徑進行求和,更好的方式是進行因式分解:

這裡寫圖片描述

有了這個因式分解,就出現了高效計算導數的可能——通過在每個節點上反向合併路徑而非顯式地對所有的路徑求和來大幅提升計算的速度。實際上,兩個演算法對每條邊的訪問都只有一次!

前向微分和反向微分

前向微分從圖的輸入開始,一步一步到達終點。在每個節點處,對輸入的路徑進行求和。每個這樣的路徑都表示輸入影響該節點的一個部分。通過將這些影響加起來,我們就得到了輸入影響該節點的全部,也就是關於輸入的導數。

這裡寫圖片描述

相對的,反向微分是從圖的輸出開始,反向一步一步抵達最開始輸入處。在每個節點處,會合了所有源於該節點的路徑。

這裡寫圖片描述

前向微分 跟蹤了輸入如何改變每個節點的情況。反向微分 則跟蹤了每個節點如何影響輸出的情況。也就是說,前向微分應用操作 d/dX 到每個節點,而反向微分應用操作 dZ/d 到每個節點。

讓我們重新看看剛開始的例子:
這裡寫圖片描述

我們可以從 b 往上使用前向微分。這樣獲得了每個節點關於 b 的導數。(寫在邊上的導數我們已經提前算高了,這些相對比較容易,只和一條邊的輸入輸出關係有關)

這裡寫圖片描述

我們已經計算得到了 de/db,輸出關於一個輸入 b 的導數。但是如果我們從 e 往回計算反向微分呢?這會得到 e 關於每個節點的導數:

這裡寫圖片描述

反向微分給出了 e 關於每個節點的導數,這裡的確是每一個節點。我們得到了 de/da 和 de/db,e 關於輸入 a 和 b 的導數。(當然中間節點都是包括的),前向微分給了我們輸出關於某一個輸入的導數,而反向微分則給出了所有的導數。

想象一個擁有百萬個輸入和一個輸出的函式。前向微分需要百萬次遍歷計算圖才能得到最終的導數,而反向微分僅僅需要遍歷一次就能得到所有的導數!速度極快!

訓練神經網路時,我們將衡量神經網路表現的代價函式看做是神經網路引數的函式。我們希望計算出代價函式關於所有引數的偏導數,從而進行梯度下降(gradient descent)。現在,常常會遇到百萬甚至千萬級的引數的神經網路。所以,反向微分,也就是 BP,在神經網路中發揮了關鍵作用!所以,其實BP的本質就是鏈式法則。

(有使用前向微分更加合理的場景麼?當然!因為反向微分得到一個輸出關於所有輸入的導數,前向微分得到了所有輸出關於一個輸入的導數。如果遇到了一個有多個輸出的函式,前向微分肯定更加快速)

BP 也是一種理解導數在模型中如何流動的工具。在推斷為何某些模型優化非常困難的過程中,BP 也是特別重要的。典型的例子就是在 Recurrent Neural Network 中理解 vanishing gradient 的原因。


有的時候,越是有效的演算法,原理往往越是簡單。


參考資料

[1] Who Invented the Reverse Mode of Differentiation?
[2] http://www.jianshu.com/p/0e9eea729476