ニューラルネットワークの逆誤差伝搬の重みの微分の計算

本記事では以下の命題を証明する.

命題.
\(W\in\mathbb{R}^{n_1\times n_3}\)\(X\in\mathbb{R}^{n_3\times n_2}\)\(B\in\mathbb{R}^{n_1\times n_2}\) とし,\(G:=WX+B\) とする. このとき,(偏微分可能な)スカラー値関数 \(f(G)\)に対して, \[\begin{align} \frac{\partial f}{\partial W} &= \frac{\partial f}{\partial G}X^\top,\tag{1}\\ \frac{\partial f}{\partial X} &= W^\top\frac{\partial f}{\partial G} \tag{2}\\ \frac{\partial f}{\partial B} &= \frac{\partial f}{\partial G}.\tag{3} \end{align}\]

なお,スカラー値関数に対する行列の偏微分の定義はQiitaの解説記事等を参照されたい.

証明.
はじめに \((1)\) を証明する. \(f\) に対する \(W\)\((i,j)\) 要素での偏微分を考えると \[ \frac{\partial f}{\partial w_{i,j}} = \sum_{k=1}^{n_1}\sum_{l=1}^{n_2} \frac{\partial f}{\partial g_{k,l}}\frac{\partial g_{k,l}}{\partial w_{i,j}} \tag{4} \] である. ところが,\(G=WX+B\) であることから,\(G\)\(i\) 行目の要素にしか \(w_{i,j}\) は含まれない. 従って \((4)\) の右辺について \[ \sum_{k=1}^{n_1}\sum_{l=1}^{n_2} \frac{\partial f}{\partial g_{k,l}}\frac{\partial g_{k,l}}{\partial w_{i,j}} = \sum_{l=1}^{n_2} \frac{\partial f}{\partial g_{i,l}}\frac{\partial g_{i,l}}{\partial w_{i,j}} \] が成り立つ. ここで,\(g_{i,l} = b_{i,l}+\sum_{m=1}^{n_3}w_{i,m}x_{m,l}\) を用いると \[\begin{align} \sum_{l=1}^{n_2} \frac{\partial f}{\partial g_{i,l}}\frac{\partial g_{i,l}}{\partial w_{i,j}} &= \sum_{l=1}^{n_2} \frac{\partial f}{\partial g_{i,l}} x_{j,l}\\ &= \sum_{l=1}^{n_2} \frac{\partial f}{\partial g_{i,l}} (X^\top)_{l,j}\\ &= \left(\frac{\partial f}{\partial G} X^\top\right)_{i,j} \tag{5} \end{align}\] \((4)\)\((5)\) を代入することで,\((1)\) が得られる.

\((2)\) 及び \((3)\) の証明の手順は \((1)\) とほとんど同じであるため,式変形のみを示す. \((2)\)\[\begin{align} \frac{\partial f}{\partial x_{i,j}} &= \sum_{k=1}^{n_1}\sum_{l=1}^{n_2} \frac{\partial f}{\partial g_{k,l}}\frac{\partial g_{k,l}}{\partial x_{i,j}}\\ &=\sum_{k=1}^{n_1} \frac{\partial f}{\partial g_{k,j}}\frac{\partial g_{k,j}}{\partial x_{i,j}}\\ &= \sum_{k=1}^{n_1} \frac{\partial f}{\partial g_{k,j}} w_{k,i}\\ &= \sum_{k=1}^{n_1} \left(\frac{\partial f}{\partial G}\right)_{k,j} \left(W^\top\right)_{i,k}\\ &= \left(W^\top\frac{\partial f}{\partial G}\right)_{i,j} {}_{} \end{align}\] より成り立つ. \((3)\)\[\begin{align} \frac{\partial f}{\partial b_{i,j}} &= \sum_{k=1}^{n_1}\sum_{l=1}^{n_2} \frac{\partial f}{\partial g_{k,l}}\frac{\partial g_{k,l}}{\partial b_{i,j}}\\ &= \frac{\partial f}{\partial g_{i,j}}\underbrace{\frac{\partial g_{i,j}}{\partial b_{i,j}}}_{1}\\ &= \left(\frac{\partial f}{\partial G}\right)_{i,j} {}_{} \end{align}\] より成り立つ.

コメント