機械学習基礎理論独習

誤りがあればご指摘いただけると幸いです。数式が整うまで少し時間かかります。リンクフリーです。

勉強ログです。リンクフリーです
目次へ戻る

変分混合ガウス分布

混合ガウス分布モデルの確率変数

今回登場する確率変数の紹介です。
観測データを {\bf X}=\{{\bf x}_1,\ldots,{\bf x}_N\} とします。
EMアルゴリズム同様、潜在変数 {\bf Z}=\{{\bf z}_1,\ldots,{\bf z}_N\}, {\bf z}_n=\{z_{n1},\ldots,z_{nK}\},z_{nk}\in\{0,1\},\sum_{k=1}^Kz_{nk}=1 を潜り込ませます。
混合比率 {\boldsymbol\pi}=\{\pi_1,\ldots,\pi_K\},\pi_k\geqslant0,\sum_{k=1}^K\pi_k=1
ガウス分布の平均{\boldsymbol\mu}=\{{\boldsymbol\mu}_1,\ldots,{\boldsymbol\mu}_K\}
ガウス分布の精度 {{\boldsymbol\Lambda}}=\{{\boldsymbol\Lambda}_1,\ldots,{\boldsymbol\Lambda}_K\}です。
精度行列は分散共分散行列の逆行列のことです。計算が少し楽になるので、こちらを使います。

変分推論を適用する

やりたいことは、{\bf X}から混合ガウス分布の潜在変数とパラメータを推定することです。
そのためには、潜在変数とパラメータの事後分布 p({\bf Z},{\boldsymbol\pi},{\boldsymbol\mu},{\boldsymbol\Lambda}|{\bf X})に近い分布q({\bf Z},{\boldsymbol\pi},{\boldsymbol\mu},{\boldsymbol\Lambda})を見つければよさそうです。

\begin{eqnarray}
p({\bf Z},{\boldsymbol\pi},{\boldsymbol\mu},{\boldsymbol\Lambda}|{\bf X})=\frac{p({\bf X},{\bf Z},{\boldsymbol\pi},{\boldsymbol\mu},{\boldsymbol\Lambda})}{p({\bf X})}\propto p({\bf X},{\bf Z},{\boldsymbol\pi},{\boldsymbol\mu},{\boldsymbol\Lambda})\tag{1}
\end{eqnarray}

(1)より、潜在変数とパラメータの事後分布は同時分布に比例するので
同時分布 p({\bf X},{\bf Z},{\boldsymbol\pi},{\boldsymbol\mu},{\boldsymbol\Lambda})に近い分布q({\bf Z},{\boldsymbol\pi},{\boldsymbol\mu},{\boldsymbol\Lambda})を考えることにします。

ここでq({\bf Z},{\boldsymbol\pi},{\boldsymbol\mu},{\boldsymbol\Lambda})が以下のように分解できると仮定します。

\begin{eqnarray}
q({\bf Z},{\boldsymbol\pi},{\boldsymbol\mu},{\boldsymbol\Lambda})=q({\bf Z})q({\boldsymbol\pi},{\boldsymbol\mu},{\boldsymbol\Lambda})\tag{2}
\end{eqnarray}

(2)のように潜在変数とパラメータの分布を分けて近似する手続きを、特に変分EMアルゴリズムと呼ぶ場合があります。

\begin{eqnarray}
\ln q(z_i)=\langle\ln p({\mathcal D},{\bf Z})\rangle_{q({\bf Z}_{\backslash i})}+{\rm const.}\tag{3}\\
\end{eqnarray}

前回の記事で導いた変分推論の公式(3)を用いると

\begin{eqnarray}
\ln q({\bf Z})=\langle\ln p({\bf X},{\bf Z},{\boldsymbol\pi},{\boldsymbol\mu},{\boldsymbol\Lambda})\rangle_{q({\boldsymbol\pi},{\boldsymbol\mu},{\boldsymbol\Lambda})}+{\rm const.}\tag{4}\\
\ln q({\boldsymbol\pi},{\boldsymbol\mu},{\boldsymbol\Lambda})=\langle\ln p({\bf X},{\bf Z},{\boldsymbol\pi},{\boldsymbol\mu},{\boldsymbol\Lambda})\rangle_{q({\bf Z})}+{\rm const.}\tag{5}\\
\end{eqnarray}

以降では、
(4),(5)に共通な同時分布p({\bf X},{\bf Z},{\boldsymbol\pi},{\boldsymbol\mu},{\boldsymbol\Lambda})を定める
(4),(5)を変形して、 q({\bf Z}),q({\boldsymbol\pi},{\boldsymbol\mu},{\boldsymbol\Lambda})の更新式を求める
の2つをやっていきます。

同時分布

変分推論を用いた混合ガウス分布モデルのグラフィカルモデルは、以下の図1です。
※図1では、ハイパーパラメータ{\boldsymbol\alpha}_0,{\bf m}_0,\beta_0,{\bf W}_0,\nu_0は省略しています。

図1
f:id:olj611:20210221023705p:plain

このグラフィカルモデルから同時分布を書き起こしてみます。

\begin{eqnarray}
p({\bf X},{\bf Z},{\boldsymbol\pi},{\boldsymbol\mu},{\bf\Lambda})=p({\bf X}|{\bf Z},{\boldsymbol\mu},{\bf\Lambda})p({\bf Z}|{\boldsymbol\pi})p({\boldsymbol\pi})p({\boldsymbol\mu}|{\bf\Lambda})p({\bf\Lambda})\tag{6}
\end{eqnarray}

グラフィカルモデルでは{\bf x_n},{\bf z}_nはプレート内にありますが、まとめて{\bf X},{\bf Z}として書きだしました。
(6)の右辺の因子をそれぞれ書き出してみます。

\begin{eqnarray}
&&p({\bf Z}|{\boldsymbol\pi})=\prod_{n=1}^N\prod_{k=1}^K\pi_k^{z_{nk}}\tag{7}\\
&&p({\bf X|}{\bf Z},{\boldsymbol\mu},{\bf\Lambda})=\prod_{n=1}^N\prod_{k=1}^K\mathcal{N}({\bf x}_n|{\boldsymbol\mu}_k,{\bf\Lambda}^{-1}_k)^{z_{nk}}\tag{8}\\
&&p({\boldsymbol\pi})={\rm Dir}({\boldsymbol\pi}|{\boldsymbol\alpha}_0)=C({\boldsymbol\alpha}_0)\prod_{k=1}^K\pi_k^{\alpha_0-1}\tag{9}\\
&&p({\boldsymbol\mu},{\bf\Lambda})=p({\boldsymbol\mu}|{\bf\Lambda})p({\bf\Lambda})=\prod_{k=1}^K\mathcal{N}({\boldsymbol\mu}_k|{\bf m}_0,(\beta_0{\bf\Lambda}_k)^{-1})\mathcal{W}({\bf\Lambda}_k|{\bf W}_0,\nu_0)\tag{10}
\end{eqnarray}

(9)ですが、\boldsymbol\pi はカテゴリ分布のパラメータなので、共役事前分布のディレクレ分布を採用しています。C({\boldsymbol\alpha}_0) は正規化定数です。
{\boldsymbol\alpha}_0 の要素は全て \alpha_0 としています。

(10)ですが、{\boldsymbol\mu},{\bf\Lambda}は多次元ガウス分布の平均と精度なので、共役事前分布のガウスウィシャート分布を採用しています。

q({\bf Z})の式変形

PRML演習問題10.12(標準)より、q({\bf Z})は以下のようになります。

\begin{eqnarray}
q({\bf Z})=\prod_{n=1}^N\prod_{k=1}^Kr_{nk}^{z_{nk}}\tag{11}\\
\end{eqnarray}

(11) r_{nk} は以下のようにおきました。

\begin{eqnarray}
r_{nk}=\frac{\rho_{nk}}{\displaystyle\sum_{j=1}^K\rho_{nj}}\tag{12}
\end{eqnarray}

(12)\rho_{nk}は以下のようにおきました。

\begin{eqnarray}
\ln\rho_{nk}=\langle\ln\pi_k\rangle_{q({\boldsymbol\pi})}+\frac{1}{2}\langle\ln|{\bf\Lambda}_k|\rangle_{q({\bf\Lambda}_k)}-\frac{D}{2}\ln2\pi-\frac{1}{2}\langle({\bf x}_n-{\boldsymbol\mu}_k)^\top{\bf\Lambda}_k({\bf x}_n-{\boldsymbol\mu}_k)\rangle_{q({\boldsymbol\mu}_k,{\bf\Lambda}_k)}\tag{13}
\end{eqnarray}

ここで  z_{nk} q({\bf Z}) に対する期待値は以下のように計算できます。

\begin{eqnarray}
\langle z_{nk}\rangle_{q({\bf Z})}&=&\langle z_{nk}\rangle_{q({\bf z}_n)}\\
&=&\sum_{{\bf z}_n}z_{nk}q({\bf z}_n)\\
&=&\sum_{{\bf z}_n}z_{nk}\prod_{k'=1}^Kr_{nk'}^{z_{nk'}}\\
&=&1\cdot r_{nk}^1\\
&=&r_{nk}\tag{14}\\
\end{eqnarray}

q({\boldsymbol\pi},{\boldsymbol\mu},{\boldsymbol\Lambda})の式変形

この後の数式を見やすくする為に、3つの統計量を定義します。

\begin{eqnarray}
&&N_k=\sum_{n=1}^Nr_{nk}\tag{15}\\
&&\bar{\bf x}_k=\frac{1}{N_k}\sum_{n=1}^Nr_{nk}{\bf x}_n\tag{16}\\
&&{\bf S}_k=\frac{1}{N_k}\sum_{n=1}^Nr_{nk}({\bf x}_n-\bar{\bf x}_k)({\bf x}_n-\bar{\bf x}_k)\top\tag{17}\\
\end{eqnarray}

PRML演習問題 10.13(標準) wwwより、q({\boldsymbol\pi},{\boldsymbol\mu},{\bf\Lambda})が以下のように分解されることが分かります。

\begin{eqnarray}
q({\boldsymbol\pi},{\boldsymbol\mu},{\bf\Lambda})=q({\boldsymbol\pi})\prod_{k=1}^Kq({\boldsymbol\mu}_k,{\bf\Lambda}_k)\tag{18}
\end{eqnarray}

これは変分事後分布 q({\boldsymbol\pi},{\boldsymbol\mu},{\bf\Lambda})q({\boldsymbol\pi})q({\boldsymbol\mu},{\bf\Lambda}) と分解されることを意味します。
さらに、{\boldsymbol\mu} 及び \bf\Lambda を含む項は {\boldsymbol\mu}_k{\bf\Lambda}_k を含む項の k についての積からなり、次のように分解されます。

PRML演習問題 10.13(標準) wwwより、 q({\boldsymbol\pi}) はディレクレ分布であることが分かります。ディレクレ分布のパラメータを {\boldsymbol\alpha}=(\alpha_0,\ldots,\alpha_K)^\top とおきます。

\begin{eqnarray}
q({\boldsymbol\pi})&=&{\rm Dir}({\boldsymbol\pi}|{\boldsymbol\alpha})\\
&=&C_{\rm D}({\boldsymbol\alpha})\prod_{k=1}^K\pi_k^{\alpha_k-1}\tag{19}
\end{eqnarray}

(19)で、\alpha_kを以下のようにおきました。

\begin{eqnarray}
\alpha_k=\alpha_0+N_k\tag{20}
\end{eqnarray}

PRML演習問題 10.13(標準) wwwより、q({\boldsymbol\mu}_k,{\bf\Lambda}_k)ガウス-ウィシャート分布であることが分かります。

\begin{eqnarray}
q({\boldsymbol\mu}_k,{\bf\Lambda}_k)=\mathcal{N}({\boldsymbol\mu}_k|{\bf m}_k,(\beta_k{\bf\Lambda}_k)^{-1})\mathcal{W}({\bf\Lambda}_k|{\bf W}_k,\nu_k)\tag{21}
\end{eqnarray}

(21)\beta_k,{\bf m}_k,\nu_k,{\bf W}_kは以下のようにおきました。

\begin{eqnarray}
&&\beta_k=\beta_0+N_k\tag{22}\\
&&{\bf m}_k=\frac{1}{\beta_k}(\beta_0{\bf m}_0+N_k\bar{\bf x}_k)\tag{23}\\
&&\nu_k=\nu_0+N_k\tag{24}\\
&&{\bf W}_k^{-1}={\bf W}_0^{-1}+N_kS_k+\frac{\beta_0N_k}{\beta_0+N_k}(\bar{\bf x}_k-{\bf m}_0)(\bar{\bf x}_k-{\bf m}_0)^\top\tag{25}\\
\end{eqnarray}

\rho_{nk}の計算

q({\boldsymbol\pi}),q({\boldsymbol\mu}_k,{\bf\Lambda}_k)が分かったので、\rho_{nk}を計算していきます。

その前に、計算で使う期待値を記しておきます。

ディレクレ分布 {\rm Dir}({\boldsymbol\pi}|{\boldsymbol\alpha})=C_{\rm D}({\boldsymbol\alpha})\displaystyle\prod_{k=1}^K\pi_k^{\alpha_k-1}\ln\pi_kの期待値は以下のようになります。

\begin{eqnarray}
\langle\ln\pi_k\rangle=\psi(\alpha_k)-\psi\left(\sum_{i=1}^K\alpha_k\right)\tag{26}
\end{eqnarray}

\psi(\cdot)はディガンマ関数です。

ウィシャート分布 \mathcal{W}({\bf\Lambda}|{\bf W},\nu)=C_{\mathcal W}|{\bf\Lambda}|^{\frac{\nu-D-1}{2}}\exp\left(-\dfrac{1}{2}{\rm Tr}({\bf W}^{-1}{\bf\Lambda})\right)\ln|{\bf\Lambda}|の期待値は以下のようになります。

\begin{eqnarray}
\langle\ln|{\bf\Lambda}|\rangle=\sum_{i=1}^D\psi\left(\frac{\nu+1-i}{2}\right)+D\ln2+\ln|{\bf W}|\tag{27}\\
\end{eqnarray}

(26),(27)は、これまで計算してきた変分ガウス分布の変数とは関係なく一般的なものなので、ご注意ください。

\rho_{nk}には期待値の項が3つあるので、別々に計算していきます。

まず、\langle\ln\pi_k\rangle_{q({\boldsymbol\pi})}を計算します。
これは式(26)をそのまま適用すればよいので、以下のようになります。

\begin{eqnarray}
\langle\ln\pi_k\rangle_{q({\boldsymbol\pi})}&=&\psi(\alpha_k)-\psi\left(\sum_{i=1}^K\alpha_k\right)\\
&=&\psi(\alpha_k)-\psi(\hat{\alpha})\tag{28}
\end{eqnarray}

\sum_{k=1}^K\alpha_k\psi(\hat{\alpha})とおきました。

次に、\langle\ln|{\bf\Lambda}_k|\rangle_{q({\bf\Lambda}_k)}を計算します。
これは式(26)をそのまま適用すればよいので、以下のようになります。

\begin{eqnarray}
\langle\ln|{\bf\Lambda}_k|\rangle_{q({\bf\Lambda}_k)}=\sum_{i=1}^D\psi\left(\frac{\nu_k+1-i}{2}\right)+D\ln2+\ln|{\bf W}_k|\tag{29}\\
\end{eqnarray}

最後に、\langle({\bf x}_n-{\boldsymbol\mu}_k)^\top{\bf\Lambda}_k({\bf x}_n-{\boldsymbol\mu}_k)\rangle_{q({\boldsymbol\mu}_k,{\bf\Lambda}_k)}PRML演習問題 10.14(標準)より、以下のようになります。

\begin{eqnarray}
\langle({\bf x}_n-{\boldsymbol\mu}_k)^\top{\bf\Lambda}_k({\bf x}_n-{\boldsymbol\mu}_k)\rangle_{q({\boldsymbol\mu}_k,{\bf\Lambda}_k)}
&=&D\beta_k^{-1}+\nu_k({\bf x}_n-{\bf m}_k)^\top{\bf W}_k({\bf x}_n-{\bf m}_k)\tag{30}\\
\end{eqnarray}

ここで、式(28),(29)を以下のようにおきます。

\begin{eqnarray}
&&\ln\tilde{\pi}_k=\psi(\alpha_k)-\psi(\hat{\alpha})\tag{31}\\
&&\ln\tilde{\bf\Lambda}_k=\sum_{i=1}^D\psi\left(\frac{\nu_k+1-i}{2}\right)+D\ln2+\ln|{\bf W}_k|\tag{32}
\end{eqnarray}

(30),(31),(32)を式(13)へ代入します。

\begin{eqnarray}
\ln\rho_{nk}&=&\ln\tilde{\pi}_k+\frac{1}{2}\ln\tilde{\bf\Lambda}_k-\frac{D}{2}\ln2\pi-\frac{1}{2}\left(\beta_k^{-1}D+\nu_k({\bf x}_n-{\bf m}_k)^\top{\bf W}_k({\bf x}_n-{\bf m}_k)\right)\tag{33}\\
\end{eqnarray}

(33)の対数を外します。

\begin{eqnarray}
\rho_{nk}&=&\tilde{\pi}_k\tilde{\bf\Lambda}_k^{\frac{1}{2}}\exp{\left(-\frac{D}{2\beta_k}-\frac{\nu_k}{2}({\bf x}_n-{\bf m}_k)^\top{\bf W}_k({\bf x}_n-{\bf m}_k)\right)}\exp\left(-\frac{D}{2}\ln2\pi\right)\tag{34}\\
\end{eqnarray}

ところで\rho_{nk}r_{nk}の計算のために必要で、\exp\left(-\frac{D}{2}\ln2\pi\right)は定数であり、計算時に打ち消しあうので、式(34)は以下のようになります。
\ln\tilde{\bf\Lambda}_kにもD\ln2という定数があるが、参考書に倣いそのままにしておきます。

\begin{eqnarray}
\rho_{nk}=\tilde{\pi}_k\tilde{\bf\Lambda}_k^{\frac{1}{2}}\exp{\left(-\frac{D}{2\beta_k}-\frac{\nu_k}{2}({\bf x}_n-{\bf m}_k)^\top{\bf W}_k({\bf x}_n-{\bf m}_k)\right)}\tag{35}\\
\end{eqnarray}

(35)は本来比例\proptoで書くべきかもしれませんがイコールにしています。

アルゴリズム

0.要件

D次元の\{{\bf x}_1,\ldots,{\bf x}_N\}が与えられています。

1.初期化

K,\alpha_0,\beta_0,{\bf m}_0,\nu_0,{\bf W}_0を初期化します。

\alpha_k,\beta_k,{\bf m}_k,\nu_k,{\bf W}_kを以下のように初期化します。

\begin{eqnarray}
&&\alpha_k\leftarrow\alpha_0\tag{36}\\
&&\beta_k\leftarrow\beta_0\tag{37}\\
&&{\bf m}_k\leftarrow{\bf m}_0\tag{38}\\
&&\nu_k\leftarrow\nu_0\tag{39}\\
&&{\bf W}_k\leftarrow{\bf W}_0\tag{40}\\
\end{eqnarray}

2.変分Eステップ

\rho_{nk},r_{nk}を更新します。

\begin{eqnarray}
&&r_{nk}\leftarrow\frac{\rho_{nk}}{\sum_{j=1}^K\rho_{nj}}\tag{41}
\end{eqnarray}

ただし、

\begin{eqnarray}
&&\rho_{nk}=\tilde{\pi}_k\tilde{\bf\Lambda}_k^{\frac{1}{2}}\exp{\left(-\frac{D}{2\beta_k}-\frac{\nu_k}{2}({\bf x}_n-{\bf m}_k)^\top{\bf W}_k({\bf x}_n-{\bf m}_k)\right)}\tag{42}\\
&&\ln\tilde{\pi}_k=\psi(\alpha_k)-\psi(\hat{\alpha})\tag{43}\\
&&\psi(\hat{\alpha})=\sum_{k=1}^K\alpha_k\tag{44}\\
&&\ln\tilde{\bf\Lambda}_k=\sum_{i=1}^D\psi(\frac{\nu_k+1-i}{2})+D\ln2+\ln|{\bf W}_k|\tag{45}
\end{eqnarray}

とします。

3.変分Mステップ

\alpha_k,\beta_k,{\bf m}_k,\nu_k,{\bf W}_kを更新します。

\begin{eqnarray}
&&\alpha_k\leftarrow\alpha_0+N_k\tag{46}\\
&&\beta_k\leftarrow\beta_0+N_k\tag{47}\\
&&{\bf m}_k\leftarrow\frac{1}{\beta_k}(\beta_0{\bf m}_0+N_k\bar{\bf x}_k)\tag{48}\\
&&\nu_k\leftarrow\nu_0+N_k\tag{49}\\
&&{\bf W}_k^{-1}\leftarrow{\bf W}_0^{-1}+N_kS_k+\frac{\beta_0N_k}{\beta_0+N_k}(\bar{\bf x}_k-{\bf m}_0)(\bar{\bf x}_k-{\bf m}_0)^\top\tag{50}\\
\end{eqnarray}

ただし、

\begin{eqnarray}
&&N_k=\sum_{n=1}^Nr_{nk}\tag{51}\\
&&\bar{\bf x}_k=\frac{1}{N_k}\sum_{n=1}^Nr_{nk}{\bf x}_n\tag{52}\\
&&{\bf S}_k=\frac{1}{N_k}\sum_{n=1}^Nr_{nk}({\bf x}_n-\bar{\bf x}_k)({\bf x}_n-\bar{\bf x}_k)\top\tag{53}\\
\end{eqnarray}

とします。

4.収束確認
対数尤度を再計算し、前回との差分があらかじめ設定していた収束条件を満たしていなければ2に戻り、満たしていれば終了します。

※収束確認は変分下限でもよいと思います。

偉人の名言

f:id:olj611:20210303202222p:plain
笑われて、笑われて、強くなる。
太宰治

参考文献

パターン認識機械学習 下巻
パターン認識機械学習の学習
ベイズ推論による機械学習入門

動画

目次へ戻る