機械学習基礎理論独習

誤りがあればご指摘いただけると幸いです。数式が整うまで少し時間かかります。リンクフリーです。

勉強ログです。リンクフリーです
目次へ戻る

確率的主成分分析の最尤推定 - 解析的に解く

はじめに

本記事は計算がややこしい箇所があるので、結果だけを見てもらってもいいかもしれません。

{\boldsymbol\mu} の最尤解

確率的主成分分析モデルの対数尤度は、

\begin{eqnarray}
\ln p({\bf X}|{\boldsymbol\mu},{\bf W},\sigma^2)&=&\sum_{n=1}^N\ln{\mathcal N}({\bf x}_n|{\boldsymbol\mu},{\bf C})\\
&=&-\frac{ND}{2}\ln(2\pi)-\frac{N}{2}\ln|{\bf C}|-\frac{1}{2}\sum_{n=1}^N({\bf x}_n-{\boldsymbol\mu})^\top{\bf C}^{-1}({\bf x}_n-{\boldsymbol\mu})\tag{1}
\end{eqnarray}

でした。
{\bf C} は次のように定義されます。

\begin{eqnarray}
{\bf C}={\bf W}{\bf W}^\top+\sigma^2{\bf I}\tag{2}
\end{eqnarray}

確率的主成分分析モデルの対数尤度 (1) をパラメータ {\boldsymbol\mu} に対して最大化すると、

\begin{eqnarray}
{\boldsymbol\mu}_{\rm ML}=\bar{\bf x}\tag{3}
\end{eqnarray}

となります。
ただし、\bar{\bf x} はデータベクトルの平均です。
{\boldsymbol\mu}_{\rm ML} の導出については、PRML演習問題 12.9(基本) をご覧ください。

(3) を式 (1) に代入します。

\begin{eqnarray}
\ln p({\bf X}|{\boldsymbol\mu},{\bf W},\sigma^2)&=&-\frac{ND}{2}\ln(2\pi)-\frac{N}{2}\ln|{\bf C}|-\frac{1}{2}\sum_{n=1}^N({\bf x}_n-\bar{\bf x})^\top{\bf C}^{-1}({\bf x}_n-\bar{\bf x})\\
&=&-\frac{ND}{2}\ln(2\pi)-\frac{N}{2}\ln|{\bf C}|-\frac{1}{2}\sum_{n=1}^N{\rm Tr}\left(({\bf x}_n-\bar{\bf x})^\top{\bf C}^{-1}({\bf x}_n-\bar{\bf x})\right)\\
&=&-\frac{ND}{2}\ln(2\pi)-\frac{N}{2}\ln|{\bf C}|-\frac{1}{2}\sum_{n=1}^N{\rm Tr}\left({\bf C}^{-1}({\bf x}_n-\bar{\bf x})({\bf x}_n-\bar{\bf x})^\top\right)\\
&=&-\frac{ND}{2}\ln(2\pi)-\frac{N}{2}\ln|{\bf C}|-\frac{N}{2}{\rm Tr}\left({\bf C}^{-1}\frac{1}{N}\sum_{n=1}^N({\bf x}_n-\bar{\bf x})({\bf x}_n-\bar{\bf x})^\top\right)\\
&=&-\frac{ND}{2}\ln(2\pi)-\frac{N}{2}\ln|{\bf C}|-\frac{N}{2}{\rm Tr}\left({\bf C}^{-1}{\bf S}\right)\\
&=&-\frac{N}{2}\left(D\ln(2\pi)+\ln|{\bf C}|+{\rm Tr}({\bf C}^{-1}{\bf S})\right)\tag{4}
\end{eqnarray}

{\bf W} の最尤解

対数尤度 (4) をパラメータ {\bf W}微分して、={\bf O} とおきます。

\begin{eqnarray}
&&\frac{\partial}{\partial{\bf W}}\ln p({\bf X}|{\boldsymbol\mu},{\bf W},\sigma^2)={\bf O}\\
&&\Leftrightarrow\frac{\partial}{\partial{\bf W}}\ln|{\bf C}|+\frac{\partial}{\partial{\bf W}}{\rm Tr}({\bf C}^{-1}{\bf S})={\bf O}\\
&&\Leftrightarrow2{\bf C}^{-1}{\bf W}-2{\bf C}^{-1}{\bf S}{\bf C}^{-1}{\bf W}={\bf O}\tag{5}\\
&&\Leftrightarrow{\bf C}^{-1}{\bf S}{\bf C}^{-1}{\bf W}={\bf C}^{-1}{\bf W}\\
&&\Leftrightarrow{\bf S}{\bf C}^{-1}{\bf W}={\bf W}\tag{6}
\end{eqnarray}

(5) の変形で用いた \dfrac{\partial}{\partial{\bf W}}\ln|{\bf C}|=2{\bf C}^{-1}{\bf W}, \dfrac{\partial}{\partial{\bf W}}{\rm Tr}({\bf C}^{-1}{\bf S})=-2{\bf C}^{-1}{\bf S}{\bf C}^{-1}{\bf W} については本記事の下の方で導出しております。

まず、{\bf W}={\bf O} の時、尤度関数は最小になります。
次に、{\bf C}={\bf S} の時、式 (2) より、以下が成り立ちます。

\begin{eqnarray}
{\bf W}{\bf W}^\top={\bf S}-\sigma^2{\bf I}\tag{7}
\end{eqnarray}

(7) を満たす {\bf W} は一般に次のように書けます。

\begin{eqnarray}
{\bf W}={\bf U}({\bf\Lambda}-\sigma^2{\bf I})^{1/2}{\bf R}\tag{8}
\end{eqnarray}

(8) が式 (7) を満たすか検算してみます。

\begin{eqnarray}
{\bf W}{\bf W}^\top&=&({\bf U}({\bf\Lambda}-\sigma^2{\bf I})^{1/2}{\bf R})({\bf U}({\bf\Lambda}-\sigma^2{\bf I})^{1/2}{\bf R})^\top\\
&=&{\bf U}({\bf\Lambda}-\sigma^2{\bf I})^{1/2}{\bf R}{\bf R}^\top({\bf\Lambda}-\sigma^2{\bf I})^{1/2}{\bf U}^\top\\
&=&{\bf U}({\bf\Lambda}-\sigma^2{\bf I}){\bf U}^\top\\
&=&{\bf U}{\bf\Lambda}{\bf U}^\top-{\bf U}\sigma^2{\bf I}{\bf U}^\top\\
&=&{\bf S}-\sigma^2{\bf I}\tag{9}
\end{eqnarray}

最後に、{\bf W}\not={\bf O},{\bf C}\not={\bf S} の時に、{\bf W} を求めます。
{\bf W}特異値分解します。

\begin{eqnarray}
{\bf W}={\bf U}{\bf L}{\bf \bf V}^\top\tag{10}
\end{eqnarray}

(10) で、
{\bf U} = ({\bf u}_1,\ldots,{\bf u}_M)D\times M の直交行列で、
{\bf L}={\rm diag}(l_1,\dots,l_M)M\times M の対角行列で、(l_j{\bf W} の特異値)
{\bf V}M\times M の直交行列です。

(10) を式 (2) に代入します。

\begin{eqnarray}
{\bf C}&=&({\bf U}{\bf L}{\bf V}^\top)({\bf U}{\bf L}{\bf V}^\top)^\top+\sigma^2{\bf I}\\
&=&{\bf U}{\bf L}{\bf V}^\top{\bf V}{\bf L}{\bf U}^\top+\sigma^2{\bf I}\\
&=&{\bf U}{\bf L}^2{\bf U}^\top+\sigma^2{\bf I}\\
&=&{\bf U}{\bf L}^2{\bf U}^\top+{\bf U}(\sigma^2{\bf I}){\bf U}^\top\\
&=&{\bf U}({\bf L}^2+\sigma^2{\bf I}){\bf U}^\top\tag{11}
\end{eqnarray}

(11) より、以下が成り立ちます。

\begin{eqnarray}
{\bf C}^{-1}={\bf U}({\bf L}^2+\sigma^2{\bf I})^{-1}{\bf U}^\top\tag{12}
\end{eqnarray}

(12) を式 (6) に代入します。

\begin{eqnarray}
&&{\bf S}{\bf U}({\bf L}^2+\sigma^2{\bf I})^{-1}{\bf U}^\top{\bf W}={\bf W}\\
&&\Leftrightarrow{\bf S}{\bf U}({\bf L}^2+\sigma^2{\bf I})^{-1}{\bf U}^\top={\bf I}\\
&&\Leftrightarrow{\bf S}{\bf U}({\bf L}^2+\sigma^2{\bf I})^{-1}={\bf U}\\
&&\Leftrightarrow{\bf S}{\bf U}={\bf U}({\bf L}^2+\sigma^2{\bf I})\\
&&\Leftrightarrow{\bf S}{\bf U}{\bf L}={\bf U}({\bf L}^2+\sigma^2{\bf I}){\bf L}\tag{13}
\end{eqnarray}

l_j\not=0 のとき、以下が成り立ちます。

\begin{eqnarray}
{\bf S}{\bf u}_j=(\sigma^2+l_j^2){\bf u}_j\tag{14}
\end{eqnarray}

(14) より、{\bf u}_j{\bf S}固有ベクトルです。
また、{\bf S}固有値 \lambda_j について以下の式が成り立ちます。

\begin{eqnarray}
&&\lambda_j=\sigma^2+l_j^2\\
&&\Leftrightarrow l_j=(\lambda_j-\sigma^2)^{1/2}\ (\because l_j\geqslant 0)\tag{15}
\end{eqnarray}

l_j=0 のときと合わせて、式 (10),(15) より、以下が成り立ちます。

\begin{eqnarray}
{\bf W}_{\rm ML}={\bf U}_M({\bf L}_M-\sigma^2{\bf I})^{1/2}{\bf R}\tag{16}
\end{eqnarray}

(16)
{\bf U}_MD\times M の行列で、その列ベクトルは {\bf S}固有ベクトルで、
{\bf L}_MM\times M の対角行列で、

\begin{eqnarray}
({\bf L}_M)_{jj}=
\left\{
    \begin{array}{l}
     \lambda_j,\ \ \ l_j\not=0 の時\ ({\bf u}_j に対応する固有値)\\
     \sigma^2,\ \ \ l_j=0 の時 
    \end{array}\tag{17}
  \right.
\end{eqnarray}

で、
{\bf R} は任意の M\times M の直交行列です。(式 (10){\bf V}^\top{\bf R} と書きなおしただけで、回転行列を連想しやすいように {\bf R} としたのだと思います。)

\sigma^2 の最尤解

{\bf W}={\bf W}_{\rm ML} とします。
このとき、式(2) より、

\begin{eqnarray}
{\bf C}&=&{\bf W}_{\rm ML}{\bf W}_{\rm ML}^\top+\sigma^2{\bf I}\\
&=&({\bf U}_M({\bf L}_M-\sigma^2{\bf I})^{1/2}{\bf R})({\bf U}_M({\bf L}_M-\sigma^2{\bf I})^{1/2}{\bf R})^\top+\sigma^2{\bf I}\\
&=&{\bf U}_M({\bf L}_M-\sigma^2{\bf I})^{1/2}{\bf R}{\bf R}^\top({\bf L}_M-\sigma^2{\bf I})^{1/2}{\bf U}_M^\top+\sigma^2{\bf I}\\
&=&{\bf U}_M({\bf L}_M-\sigma^2{\bf I}){\bf U}_M^\top+\sigma^2{\bf I}\\
&=&{\bf U}_M({\bf L}_M-\sigma^2{\bf I}){\bf U}_M^\top+{\bf U}_M(\sigma^2{\bf I}){\bf U}_M^\top\\
&=&{\bf U}_M{\bf L}_M{\bf U}_M^\top\tag{18}
\end{eqnarray}

となります。

(18)行列式の対数 \ln|{\bf C}| は以下のようになります。

\begin{eqnarray}
\ln|{\bf C}|&=&\ln\left(\prod_{j=1}^D({\bf L}_M)_{jj}\right)\\
&=&\ln\left(\left(\prod_{j=1}^{M'}\lambda_j\right)\left(\prod_{j=M'+1}^D\sigma^2\right)\right)\\
&=&\sum_{j=1}^{M'}\ln\lambda_j + \sum_{j=M'+1}^D\ln(\sigma^2)\\
&=&\sum_{j=1}^{M'}\ln\lambda_j + (D-M')\ln(\sigma^2)\tag{19}
\end{eqnarray}

(19)M'l_j\not=0 の特異値の数です。

{\rm Tr}({\bf C}^{-1}{\bf S}) を計算します。

\begin{eqnarray}
{\rm Tr}({\bf C}^{-1}{\bf S})&=&{\rm Tr}(({\bf U}_M{\bf L}_M{\bf U}_M^\top)^{-1}{\bf S})\\
&=&{\rm Tr}({\bf U}_M{\bf L}_M^{-1}{\bf U}_M^\top{\bf S})\\
&=&\underbrace{{\rm Tr}({\bf L}_M^{-1}{\bf U}_M^\top{\bf S}{\bf U}_M)}_{{\rm Tr}({\bf AB})={\rm Tr}({\bf BA})}\\
&=&\underbrace{{\rm Tr}({\bf L}_M^{-1}{\bf\Lambda})}_{{\rm Tr}({\bf AB})={\rm Tr}({\bf BA})}\\
&=&\sum_{j=1}^{M'}\frac{\lambda_j}{\lambda_j}+\sum_{j=M'+1}^{D}\frac{\lambda_j}{\sigma^2}\\
&=&\frac{1}{\sigma^2}\sum_{j=M'+1}^D\lambda_j+M'\tag{20}
\end{eqnarray}

対数尤度関数の式 (4) に式 (19),(20) を代入します。

\begin{eqnarray}
&&\ln p({\bf X}|{\boldsymbol\mu},{\bf W},\sigma^2)\\
&=&-\frac{N}{2}\left(D\ln(2\pi)+\sum_{j=1}^{M'}\ln\lambda_j + (D-M')\ln(\sigma^2)+\frac{1}{\sigma^2}\sum_{j=M'+1}^D\lambda_j+M'\right)\tag{21}
\end{eqnarray}

(21)\sigma^2 で計算して、=0 とおきます。

\begin{eqnarray}
&&\frac{\partial}{\partial\sigma^2}\ln p({\bf X}|{\boldsymbol\mu},{\bf W},\sigma^2)=0\\
&&\Leftrightarrow(D-M')\frac{\partial}{\partial\sigma^2}\ln(\sigma^2)+\frac{\partial}{\partial\sigma^2}\frac{1}{\sigma^2}\sum_{j=M'+1}^D\lambda_j=0\\
&&\Leftrightarrow(D-M')\frac{1}{\sigma^2}-\frac{1}{\sigma^4}\sum_{j=M'+1}^D\lambda_j=0\\
&&\Leftrightarrow\sigma^2=\frac{1}{D-M'}\sum_{j=M'+1}^D\lambda_j\tag{22}
\end{eqnarray}

ここから、まだいろいろな考察をしなくてはなりませんが、式変形については完了したので、一旦、本記事をアップします。
理解が深まり次第、本記事を更新する予定です。

\dfrac{\partial}{\partial{\bf W}}\ln|{\bf C}|=2{\bf C}^{-1}{\bf W} の導出

\begin{eqnarray}
\left(\frac{\partial}{\partial{\bf W}}\ln|{\bf C}|\right)_{ij}&=&\frac{\partial}{\partial W_{ij}}\ln|{\bf C}|\\
&=&\underbrace{{\rm Tr}\left({\bf C}^{-1}\frac{\partial}{\partial W_{ij}}{\bf C}\right)}_{\frac{\partial}{\partial x}\ln|{\bf A}|={\rm Tr}\left({\bf A}^{-1}\frac{\partial{\bf A}}{\partial x}\right)}\\
&=&{\rm Tr}\left({\bf C}^{-1}\frac{\partial}{\partial W_{ij}}({\bf W}{\bf W}^\top+\sigma^2{\bf I})\right)\\
&=&{\rm Tr}\left({\bf C}^{-1}\frac{\partial}{\partial W_{ij}}{\bf W}{\bf W}^\top\right)\\
&=&{\rm Tr}\Bigg({\bf C}^{-1}\underbrace{\left(\frac{\partial{\bf W}}{\partial W_{ij}}{\bf W}^\top+{\bf W}\frac{\partial{\bf W}^\top}{\partial W_{ij}}\right)}_{\frac{\partial}{\partial x}({\bf A}{\bf B}) = \frac{\partial{\bf A}}{\partial x}{\bf B}+{\bf A}\frac{\partial{\bf B}}{\partial x}}\Bigg)\\
&=&{\rm Tr}\left({\bf C}^{-1}\left({\bf J}_{ij}{\bf W}^\top+{\bf W}{\bf J}_{ji}\right)\right)\\
&=&\underbrace{{\rm Tr}\left({\bf C}^{-1}{\bf J}_{ij}{\bf W}^\top\right)+{\rm Tr}\left({\bf C}^{-1}{\bf W}{\bf J}_{ji}\right)}_{{\rm Tr}({\bf A}+{\bf B})={\rm Tr}({\bf A})+{\rm Tr}({\bf B})}\\
&=&{\rm Tr}\left({\bf C}^{-1}{\bf J}_{ij}{\bf W}^\top\right)+\underbrace{{\rm Tr}\left(({\bf C}^{-1}{\bf W}{\bf J}_{ji})^\top\right)}_{{\rm Tr}({\bf A})={\rm Tr}({\bf A}^\top)}\\
&=&{\rm Tr}\left({\bf C}^{-1}{\bf J}_{ij}{\bf W}^\top\right)+{\rm Tr}\left({\bf J}_{ij}{\bf W}^\top{\bf C}^{-1}\right)\\
&=&{\rm Tr}\left({\bf C}^{-1}{\bf J}_{ij}{\bf W}^\top\right)+\underbrace{{\rm Tr}\left({\bf C}^{-1}{\bf J}_{ij}{\bf W}^\top\right)}_{{\rm Tr}({\bf AB})={\rm Tr}({\bf BA})}\\
&=&2{\rm Tr}\left({\bf C}^{-1}{\bf J}_{ij}{\bf W}^\top\right)\\
&=&2\sum_{k,l,m}C_{kl}^{-1}({\bf J}_{ij})_{lm}W_{mk}^\top\\
&=&2\sum_{k,l,m}C_{kl}^{-1}({\bf J}_{ij})_{lm}W_{km}\\
&=&2\sum_{k,l,m}C_{kl}^{-1}\delta_{il}\delta_{jm}W_{km}\\
&=&2\sum_{k}C_{ki}^{-1}W_{kj}\\
&=&2\sum_{k}(C^{-1})_{ik}^\top W_{kj}\\
&=&2\sum_{k}C^{-1}_{ik} W_{kj}\\
&=&2\left({\bf C}^{-1}{\bf W}\right)_{ij}\\
\end{eqnarray}

\dfrac{\partial}{\partial{\bf W}}{\rm Tr}({\bf C}^{-1}{\bf S})=-2{\bf C}^{-1}{\bf S}{\bf C}^{-1}{\bf W} の導出

\begin{eqnarray}
\left(\frac{\partial}{\partial{\bf W}}{\rm Tr}({\bf C}^{-1}{\bf S})\right)_{ij}&=&\frac{\partial}{\partial W_{ij}}{\rm Tr}({\bf C}^{-1}{\bf S})\\
&=&{\rm Tr}\left(\frac{\partial{\bf C}^{-1}}{\partial W_{ij}}{\bf S}\right)\\
&=&{\rm Tr}\Bigg(\underbrace{-{\bf C}^{-1}\frac{\partial{\bf C}}{\partial W_{ij}}{\bf C}^{-1}}_{\frac{\partial}{\partial x}({\bf A}^{-1}) = -{\bf A}^{-1}\frac{\partial{\bf A}}{\partial x}{\bf A}^{-1}}{\bf S}\Bigg)\\
&=&-{\rm Tr}\left({\bf C}^{-1}\frac{\partial{\bf C}}{\partial W_{ij}}{\bf C}^{-1}{\bf S}\right)\\
&=&-{\rm Tr}\left({\bf C}^{-1}\frac{\partial}{\partial W_{ij}}({\bf W}{\bf W}^\top+\sigma^2{\bf I}){\bf C}^{-1}{\bf S}\right)\\
&=&-{\rm Tr}\left({\bf C}^{-1}\frac{\partial}{\partial W_{ij}}({\bf W}{\bf W}^\top){\bf C}^{-1}{\bf S}\right)\\
&=&-{\rm Tr}\Bigg({\bf C}^{-1}\underbrace{\left(\frac{\partial{\bf W}}{\partial W_{ij}}{\bf W}^\top+{\bf W}\frac{\partial{\bf W}^\top}{\partial W_{ij}}\right)}_{\frac{\partial}{\partial x}({\bf A}{\bf B}) = \frac{\partial{\bf A}}{\partial x}{\bf B}+{\bf A}\frac{\partial{\bf B}}{\partial x}}{\bf C}^{-1}{\bf S}\Bigg)\\
&=&-{\rm Tr}\left({\bf C}^{-1}\left({\bf J}_{ij}{\bf W}^\top+{\bf W}{\bf J}_{ji}\right){\bf C}^{-1}{\bf S}\right)\\
&=&\underbrace{-{\rm Tr}\left({\bf C}^{-1}{\bf J}_{ij}{\bf W}^\top{\bf C}^{-1}{\bf S}\right)-{\rm Tr}\left({\bf C}^{-1}{\bf W}{\bf J}_{ji}{\bf C}^{-1}{\bf S}\right)}_{{\rm Tr}({\bf A}+{\bf B})={\rm Tr}({\bf A})+{\rm Tr}({\bf B})}\\
&=&-{\rm Tr}\left({\bf C}^{-1}{\bf J}_{ij}{\bf W}^\top{\bf C}^{-1}{\bf S}\right)-\underbrace{{\rm Tr}\left(\left({\bf C}^{-1}{\bf W}{\bf J}_{ji}{\bf C}^{-1}{\bf S}\right)^\top\right)}_{{\rm Tr}({\bf A})={\rm Tr}({\bf A}^\top)}\\
&=&-{\rm Tr}\left({\bf C}^{-1}{\bf J}_{ij}{\bf W}^\top{\bf C}^{-1}{\bf S}\right)-{\rm Tr}\left({\bf S}{\bf C}^{-1}{\bf J}_{ij}{\bf W}^\top{\bf C}^{-1}\right)\\
&=&-{\rm Tr}\left({\bf C}^{-1}{\bf J}_{ij}{\bf W}^\top{\bf C}^{-1}{\bf S}\right)-\underbrace{{\rm Tr}\left({\bf C}^{-1}{\bf J}_{ij}{\bf W}^\top{\bf C}^{-1}{\bf S}\right)}_{{\rm Tr}({\bf AB})={\rm Tr}({\bf BA})}\\
&=&-2{\rm Tr}\left({\bf C}^{-1}{\bf J}_{ij}{\bf W}^\top{\bf C}^{-1}{\bf S}\right)\\
&=&-2\sum_{k,l,m,n,o}C_{kl}^{-1}({\bf J}_{ij})_{lm}W_{mn}^\top C_{no}^{-1}S_{ok}\\
&=&-2\sum_{k,l,m,n,o}C_{kl}^{-1}({\bf J}_{ij})_{lm}W_{nm} C_{no}^{-1}S_{ok}\\
&=&-2\sum_{k,l,m,n,o}C_{kl}^{-1}\delta_{il}\delta_{jm}W_{nm} C_{no}^{-1}S_{ok}\\
&=&-2\sum_{k,n,o}C_{ki}^{-1}W_{nj} C_{no}^{-1}S_{ok}\\
&=&-2\sum_{k,n,o}(C^{-1})_{ik}^\top S_{ko}^\top (C^{-1})_{on}^\top W_{nj} \\
&=&-2\sum_{k,n,o}C^{-1}_{ik} S_{ko} C^{-1}_{on} W_{nj} \\
&=&-2({\bf C}^{-1}{\bf S}{\bf C}^{-1}{\bf W})_{ij}\\
\end{eqnarray}

最後に

本記事は、Probabilistic Principal Component Analysisの行間を補ったものです。

参考文献

パターン認識機械学習 p290-p294

目次へ戻る