ニューラルネットワークの逆誤差伝搬の重みの微分の計算
本記事では以下の命題を証明する.
命題.
\(W\in\mathbb{R}^{n_1\times n_3}\),\(X\in\mathbb{R}^{n_3\times n_2}\),\(B\in\mathbb{R}^{n_1\times n_2}\) とし,\(G:=WX+B\) とする. このとき,(偏微分可能な)スカラー値関数 \(f(G)\)に対して, \[\begin{align}
\frac{\partial f}{\partial W} &= \frac{\partial f}{\partial G}X^\top,\tag{1}\\
\frac{\partial f}{\partial X} &= W^\top\frac{\partial f}{\partial G} \tag{2}\\
\frac{\partial f}{\partial B} &= \frac{\partial f}{\partial G}.\tag{3}
\end{align}\]
なお,スカラー値関数に対する行列の偏微分の定義はQiitaの解説記事等を参照されたい.
証明.
はじめに \((1)\) を証明する. \(f\) に対する \(W\) の \((i,j)\) 要素での偏微分を考えると \[
\frac{\partial f}{\partial w_{i,j}}
= \sum_{k=1}^{n_1}\sum_{l=1}^{n_2} \frac{\partial f}{\partial g_{k,l}}\frac{\partial g_{k,l}}{\partial w_{i,j}} \tag{4}
\] である. ところが,\(G=WX+B\) であることから,\(G\) の \(i\) 行目の要素にしか \(w_{i,j}\) は含まれない. 従って \((4)\) の右辺について \[
\sum_{k=1}^{n_1}\sum_{l=1}^{n_2} \frac{\partial f}{\partial g_{k,l}}\frac{\partial g_{k,l}}{\partial w_{i,j}}
= \sum_{l=1}^{n_2} \frac{\partial f}{\partial g_{i,l}}\frac{\partial g_{i,l}}{\partial w_{i,j}}
\] が成り立つ. ここで,\(g_{i,l} = b_{i,l}+\sum_{m=1}^{n_3}w_{i,m}x_{m,l}\) を用いると \[\begin{align}
\sum_{l=1}^{n_2} \frac{\partial f}{\partial g_{i,l}}\frac{\partial g_{i,l}}{\partial w_{i,j}}
&= \sum_{l=1}^{n_2} \frac{\partial f}{\partial g_{i,l}} x_{j,l}\\
&= \sum_{l=1}^{n_2} \frac{\partial f}{\partial g_{i,l}} (X^\top)_{l,j}\\
&= \left(\frac{\partial f}{\partial G} X^\top\right)_{i,j} \tag{5}
\end{align}\] \((4)\) に \((5)\) を代入することで,\((1)\) が得られる.
\((2)\) 及び \((3)\) の証明の手順は \((1)\) とほとんど同じであるため,式変形のみを示す. \((2)\) は \[\begin{align} \frac{\partial f}{\partial x_{i,j}} &= \sum_{k=1}^{n_1}\sum_{l=1}^{n_2} \frac{\partial f}{\partial g_{k,l}}\frac{\partial g_{k,l}}{\partial x_{i,j}}\\ &=\sum_{k=1}^{n_1} \frac{\partial f}{\partial g_{k,j}}\frac{\partial g_{k,j}}{\partial x_{i,j}}\\ &= \sum_{k=1}^{n_1} \frac{\partial f}{\partial g_{k,j}} w_{k,i}\\ &= \sum_{k=1}^{n_1} \left(\frac{\partial f}{\partial G}\right)_{k,j} \left(W^\top\right)_{i,k}\\ &= \left(W^\top\frac{\partial f}{\partial G}\right)_{i,j} {}_{} \end{align}\] より成り立つ. \((3)\) は \[\begin{align} \frac{\partial f}{\partial b_{i,j}} &= \sum_{k=1}^{n_1}\sum_{l=1}^{n_2} \frac{\partial f}{\partial g_{k,l}}\frac{\partial g_{k,l}}{\partial b_{i,j}}\\ &= \frac{\partial f}{\partial g_{i,j}}\underbrace{\frac{\partial g_{i,j}}{\partial b_{i,j}}}_{1}\\ &= \left(\frac{\partial f}{\partial G}\right)_{i,j} {}_{} \end{align}\] より成り立つ.
コメント
コメントを投稿