機械学習基礎理論独習

誤りがあればご指摘いただけると幸いです。数式が整うまで少し時間かかります。リンクフリーです。

勉強ログです。リンクフリーです
目次へ戻る

ポアソン混合モデルにおける変分推論

はじめに

本記事は、「ベイズ推論による機械学習入門」という書籍を参考に書いたので、
今までの記事とは異なり、潜在変数を\bf Sと書いております。

本記事でも、「平均場近似の変分推論といえば、この式!」という以下の式を使います。

\begin{eqnarray}
\ln q(z_i)=\langle\ln p({\mathcal D},{\bf Z})\rangle_{q({\bf Z}_{\backslash i})}+{\rm const}\tag{1}\\
\end{eqnarray}

(1)については、こちらで解説しています。

潜在変数{\bf S}、混合比率{\boldsymbol\pi}については、本記事では詳しくは説明しませんので、わからない方は
混合ガウス分布の最尤推定変分混合ガウス分布の記事を参考にしてください。

全体の流れ

変分推論の記事はどうしても長くなってしまう為、全体の流れを書きます。

1. まず観測データ{\mathcal D}と観測されていない未知の変数{\bf Z}に関して、同時分布を構築する。
2. 事後分布p({\bf Z}|{\mathcal D})を解析的あるいは近似的に求める。

これ実は、ベイズ統計全般に言える事なのでですが、すごくすごく大事なことなので書いておきました。
今回ももちろんこの流れに沿っています。

今回のモデルでは、 {\mathcal D}={\bf X},\ {\bf Z}=\{{\bf S},{\boldsymbol\lambda},{\boldsymbol\pi}\} となります。
2で、「解析的あるいは近似的に」とありますが、今回は「近似的に」求めます。

ポアソン混合モデル

ポアソン混合モデルは、多峰性の離散非負データ(図1の左のヒストグラム参照)を学習する際に用いられます。

1
f:id:olj611:20211007181954p:plain:h200

ここでは、1次元データに対するポアソン混合モデルを考えます。

まず、ポアソン混合モデルの確率変数を書き下します。

観測データを {\bf X}=\{x_1,\ldots,x_N\} とします。
潜在変数を {\bf S}=\{{\bf s}_1,\ldots,{\bf s}_N\}, {\bf s}_n=\{s_{n1},\ldots,s_{nK}\},s_{nk}\in\{0,1\},\sum_{k=1}^Ks_{nk}=1 とします。
混合比率を {\boldsymbol\pi}=\{\pi_1,\ldots,\pi_K\},\pi_k\geqslant0,\sum_{k=1}^K\pi_k=1 とします。
ポアソン分布のパラメータを {\boldsymbol\lambda}=\{\lambda_1,\ldots,\lambda_K\} とします。

次に、確率分布を書き下します。

p(x_n|\lambda_k)ポアソン分布を採用します。

\begin{eqnarray}
p(x_n|\lambda_k)={\rm Poi}(x_n|\lambda_k)=\frac{\lambda_k^{x_n}}{x_n!}\exp(-\lambda_k)\tag{2}
\end{eqnarray}

よって、混合分布における条件付き分布\ p(x_n|{\bf s}_n,{\boldsymbol\lambda})\ は以下のようになります。

\begin{eqnarray}
p(x_n|{\bf s}_n,{\boldsymbol\lambda})=\prod_{k=1}^K{\rm Poi}(x_n|\lambda_k)^{s_{nk}}=\prod_{k=1}^K{\rm Poi}(x_n|\lambda_k)^{s_{nk}}\tag{3}
\end{eqnarray}

観測値x_nはそれぞれ独立だと仮定すると、以下が成り立ちます。

\begin{eqnarray}
p({\bf X}|{\bf S},{\boldsymbol\lambda})=\prod_{n=1}^Np(x_n|{\bf s}_n,{\boldsymbol\lambda})=\prod_{n=1}^N\prod_{k=1}^K{\rm Poi}(x_n|\lambda_k)^{s_{nk}}\tag{4}
\end{eqnarray}

{\boldsymbol\lambda}=\{\lambda_1,\ldots,\lambda_K\}の事前分布は、ポアソン分布の共役事前分布であるガンマ分布とします。

\begin{eqnarray}
p(\lambda_k)={\rm Gam}(\lambda_k|a,b)=C_{\rm G}(a,b)\lambda_k^{a-1}\exp(-b\lambda_k)\tag{5}
\end{eqnarray}

(5)C_{\rm G}(a,b)はガンマ分布の正規化項です。
\lambda_kはそれぞれ独立とします。

\begin{eqnarray}
p({\boldsymbol\lambda})=\prod_{k=1}^Kp(\lambda_k)=\prod_{k=1}^K{\rm Gam}(\lambda_k|a,b)\tag{6}
\end{eqnarray}

{\bf s}_nは1 of K表現を用いているので、p({\bf s}_n|{\boldsymbol\pi})はカテゴリ分布とします。

\begin{eqnarray}
p({\bf s}_n|{\boldsymbol\pi})={\rm Cat}({\bf s}_n|{\boldsymbol\pi})=\prod_{k=1}^K\pi_k^{s_{nk}}\tag{7}
\end{eqnarray}

x_n{\bf s}_nはペアであり、x_nがそれぞれ独立と仮定しているので、{\bf s}_nはそれぞれ独立です。

\begin{eqnarray}
p({\bf S}|{\boldsymbol\pi})=\prod_{n=1}^Np({\bf s}_n|{\boldsymbol\pi})=\prod_{n=1}^N{\rm Cat}({\bf s}_n|{\boldsymbol\pi})\tag{8}
\end{eqnarray}

混合比率 {\boldsymbol\pi}=\{\pi_1,\ldots,\pi_K\} の事前分布ですが、
{\boldsymbol\pi}はカテゴリ分布のパラメータなので、共役事前分布のディリクレ分布を採用します。

\begin{eqnarray}
p({\boldsymbol\pi})={\rm Dir}({\boldsymbol\pi}|{\boldsymbol\alpha})=C_{\rm D}({\boldsymbol\alpha})\prod_{k=1}^K\pi_k^{\alpha_k-1}\tag{9}\\
\end{eqnarray}

(9)で、{\boldsymbol\alpha}=\{\alpha_1,\ldots,\alpha_K\}はハイパーパラメータであり、固定です。
また、C_{\rm D}({\boldsymbol\alpha})はディレクレ分布の正規化定数です。

グラフィカルモデルは、以下のようになります。

2
f:id:olj611:20211016125056p:plain:h200

同時分布は、以下の式で表せます。

\begin{eqnarray}
p({\bf X},{\bf S},{\boldsymbol\lambda},{\boldsymbol\pi})&=&p({\bf X}|{\bf S},{\boldsymbol\lambda})p({\bf S}|{\boldsymbol\pi})p({\boldsymbol\lambda})p({\boldsymbol\pi})\tag{10}
\end{eqnarray}

変分推論を適用

準備が整いましたので、変分推論を適用します。

潜在変数とパラメータの事後分布p({\bf S},{\boldsymbol\lambda},{\boldsymbol\pi}|{\bf X})を以下のように近似します。

\begin{eqnarray}
p({\bf S},{\boldsymbol\lambda},{\boldsymbol\pi}|{\bf X})\simeq q({\bf S},{\boldsymbol\lambda},{\boldsymbol\pi})=q({\bf S})q({\boldsymbol\lambda},{\boldsymbol\pi})\tag{11}
\end{eqnarray}

このように潜在変数とパラメータを分けて近似する手続きを変分EMアルゴリズムということがあります。
また、q({\bf S},{\boldsymbol\lambda},{\boldsymbol\pi})を近似事後分布、または変分事後分布と呼びます。

\ln q({\bf S})の導出

\ln q({\bf S})に式(1)を当てはめるため、式(1)において、z_i={\bf S},{\bf Z}_{\backslash i}=\{{\boldsymbol\lambda},{\boldsymbol\pi}\}とします。

\begin{eqnarray}
\ln q({\bf S})&=&\langle p({\bf X},{\bf S},{\boldsymbol\lambda},{\boldsymbol\pi})\rangle_{q({\boldsymbol\lambda},{\boldsymbol\pi})}+{\rm const}\\
&=&\langle \ln p({\bf X}|{\bf S},{\boldsymbol\lambda})p({\bf S}|{\boldsymbol\pi})p({\boldsymbol\lambda})p({\boldsymbol\pi})\rangle_{q({\boldsymbol\lambda},{\boldsymbol\pi})}+{\rm const}\\
&=&\langle \ln p({\bf X}|{\bf S},{\boldsymbol\lambda})\rangle_{q({\boldsymbol\lambda},{\boldsymbol\pi})}+\langle \ln p({\bf S}|{\boldsymbol\pi})\rangle_{q({\boldsymbol\lambda},{\boldsymbol\pi})}+\langle \ln p({\boldsymbol\lambda})\rangle_{q({\boldsymbol\lambda},{\boldsymbol\pi})}+\langle \ln p({\boldsymbol\pi})\rangle_{q({\boldsymbol\lambda},{\boldsymbol\pi})}+{\rm const}\\
&=&\langle \ln p({\bf X}|{\bf S},{\boldsymbol\lambda})\rangle_{q({\boldsymbol\lambda},{\boldsymbol\pi})}+\langle \ln p({\bf S}|{\boldsymbol\pi})\rangle_{q({\boldsymbol\lambda},{\boldsymbol\pi})}+{\rm const}\tag{12}\\
&=&\langle \ln p({\bf X}|{\bf S},{\boldsymbol\lambda})\rangle_{q({\boldsymbol\lambda})}+\langle\ln p({\bf S}|{\boldsymbol\pi})\rangle_{q({\boldsymbol\pi})}+{\rm const}\tag{13}\\
&=&\left\langle\sum_{n=1}^N \ln p(x_n|{\bf s}_n,{\boldsymbol\lambda})\right\rangle_{q({\boldsymbol\lambda})}+\left\langle\sum_{n=1}^N\ln p({\bf s}_n|{\boldsymbol\pi})\right\rangle_{q({\boldsymbol\pi})}+{\rm const}\\
&=&\sum_{n=1}^N\left\langle \ln p(x_n|{\bf s}_n,{\boldsymbol\lambda})\right\rangle_{q({\boldsymbol\lambda})}+\sum_{n=1}^N\left\langle\ln p({\bf s}_n|{\boldsymbol\pi})\right\rangle_{q({\boldsymbol\pi})}+{\rm const}\tag{14}\\
&=&\sum_{n=1}^N\left(\left\langle \ln p(x_n|{\bf s}_n,{\boldsymbol\lambda})\right\rangle_{q({\boldsymbol\lambda})}+\left\langle\ln p({\bf s}_n|{\boldsymbol\pi})\right\rangle_{q({\boldsymbol\pi})}\right)+{\rm const}\tag{15}\\
\end{eqnarray}

(12)で、\bf Sを含まない項は{\rm const}にまとめました。
(13)で、\langle\cdot\rangle_{q({\boldsymbol\lambda},{\boldsymbol\pi})}\langle\cdot\rangle_{q({\boldsymbol\lambda})}\langle\cdot\rangle_{q({\boldsymbol\pi})}になっていますが、周辺化したためです。
(14)で、期待値の線形性を使って、\sum\langle\cdot\rangleの外に出しました。

(15)より、\ln q({\bf S})N個の和で表現されているため、q({\bf S})N個の積で表現されます。
よって、q({\bf S})は各点ごとの独立な分布q({\bf s}_1)\cdots q({\bf s}_N)に分解されることが分かります。

したがって、式(15)より、\ln q({\bf s}_n)は以下のように書けます。

\begin{eqnarray}
\ln q({\bf s}_n)&=&\langle \ln p(x_n|{\bf s}_n,{\boldsymbol\lambda})\rangle_{q({\boldsymbol\lambda})}+\langle\ln p({\bf s}_n|{\boldsymbol\pi})\rangle_{q({\boldsymbol\pi})}+{\rm const}\tag{16}
\end{eqnarray}

(16)\langle \ln p(x_n|{\bf s}_n,{\boldsymbol\lambda})\rangle_{q({\boldsymbol\lambda})}を変形します。

\begin{eqnarray}
\langle \ln p(x_n|{\bf s}_n,{\boldsymbol\lambda})\rangle_{q({\boldsymbol\lambda})}&=&\left\langle\sum_{k=1}^Ks_{nk}\ln{\rm Poi}(x_n|\lambda_k)\right\rangle_{q({\boldsymbol\lambda})}\\
&=&\sum_{k=1}^Ks_{nk}\left\langle \ln{\rm Poi}(x_n|\lambda_k)\right\rangle_{q({\boldsymbol\lambda})}\tag{17}\\
&=&\sum_{k=1}^Ks_{nk}\left\langle \ln{\rm Poi}(x_n|\lambda_k)\right\rangle_{q({\lambda_k})}\tag{18}\\
&=&\sum_{k=1}^Ks_{nk}\left\langle\ln\frac{\lambda_k^{x_n}}{x_n!}\exp(-\lambda_k)\right\rangle_{q({\lambda_k})}\\
&=&\sum_{k=1}^Ks_{nk}\left\langle x_n\ln\lambda_k-\ln x_n!-\lambda_k\right\rangle_{q({\lambda_k})}\\
&=&\sum_{k=1}^Ks_{nk}\left(x_n\langle\ln\lambda_k\rangle_{q({\lambda_k})}-\langle\ln x_n!\rangle_{q({\lambda_k})}-\langle\lambda_k\rangle_{q({\lambda_k})}\right)\\
&=&\sum_{k=1}^Ks_{nk}\left(x_n\langle\ln\lambda_k\rangle_{q({\lambda_k})}-\ln x_n!-\langle\lambda_k\rangle_{q({\lambda_k})}\right)\\
&=&\sum_{k=1}^Ks_{nk}\left(x_n\langle\ln\lambda_k\rangle_{q({\lambda_k})}-\langle\lambda_k\rangle_{q({\lambda_k})}\right)+{\rm const}\tag{19}
\end{eqnarray}

(17)で、期待値の線形性を使って、\sum,s_{nk}\langle\cdot\rangleの外に出しました。
(18)で、\langle\cdot\rangle_{q({\boldsymbol\lambda})}\langle\cdot\rangle_{q({\lambda_k})}になっていますが、\lambda_{j\not=k}で周辺化されたためです。
(19)で、\sum_{k=1}^Ks_{nk}\ln x_n!=\ln x_n!であるため、-\ln x_n!\sumの外に出し、
-\ln x_n!{\bf s}_nとは無関係である為、{\rm const}にまとめました。

(16)\langle\ln p({\bf s}_n|{\boldsymbol\pi})\rangle_{q({\boldsymbol\pi})}を変形します。

\begin{eqnarray}
\langle\ln p({\bf s}_n|{\boldsymbol\pi})\rangle_{q({\boldsymbol\pi})}&=&\left\langle{\rm Cat}({\bf s}_n|{\boldsymbol\pi})\right\rangle_{q({\boldsymbol\pi})}\\
&=&\left\langle\ln\prod_{k=1}^K\pi_k^{s_{nk}}\right\rangle_{q({\boldsymbol\pi})}\\
&=&\left\langle\sum_{k=1}^Ks_{nk}\ln\pi_k\right\rangle_{q({\boldsymbol\pi})}\\
&=&\sum_{k=1}^Ks_{nk}\left\langle\ln\pi_k\right\rangle_{q({\boldsymbol\pi})}\tag{20}
\end{eqnarray}

(20)で、期待値の線形性を使って、\sum,s_{nk}\langle\cdot\rangleの外に出しました。
(20)\langle\cdot\rangle_{q({\boldsymbol\pi})}\langle\cdot\rangle_{q(\pi_k)}のように変形していない理由ですが、
\sum_{k=1}^K\pi_k=1という制約のため、分解して考えにくい為であると思います。
また、後でわかるのですがq({\boldsymbol\pi})はディレクレ分布であり、成分の分布での期待値は普通考えない為だと思われます。

(19),(20)を式(16)に代入します。

\begin{eqnarray}
\ln q({\bf s}_n)&=&\sum_{k=1}^Ks_{nk}\left(x_n\langle\ln\lambda_k\rangle_{q({\lambda_k})}-\langle\lambda_k\rangle_{q({\lambda_k})}\right)+\sum_{k=1}^Ks_{nk}\langle\ln\pi_k\rangle_{q({\boldsymbol\pi})}+{\rm const}\\
&=&\sum_{k=1}^Ks_{nk}\left(x_n\langle\ln\lambda_k\rangle_{q({\lambda_k})}-\langle\lambda_k\rangle_{q({\lambda_k})}+\langle\ln\pi_k\rangle_{q({\boldsymbol\pi})}\right)+{\rm const}\tag{21}
\end{eqnarray}

(21)において、

\begin{eqnarray}
\ln\rho_{nk}=x_n\langle\ln\lambda_k\rangle_{q({\lambda_k})}-\langle\lambda_k\rangle_{q({\lambda_k})}+\langle\ln\pi_k\rangle_{q({\boldsymbol\pi})}\tag{22}
\end{eqnarray}

とおきます。

(22)を式(21)に代入します。

\begin{eqnarray}
&&\ln q({\bf s}_n)=\sum_{k=1}^Ks_{nk}\ln\rho_{nk}+{\rm const}\tag{23}\\
&&\Leftrightarrow q({\bf s}_n)=C\prod_{k=1}^K\rho_{nk}^{s_{nk}}\tag{24}\\
&&\Leftrightarrow q({\bf s}_n)=\prod_{k=1}^K\eta_{nk}^{s_{nk}}\tag{25}\\
&&\Leftrightarrow q({\bf s}_n)={\rm Cat}({\bf s}_n|{\boldsymbol\eta}_n)\tag{26}
\end{eqnarray}

(24)Cは式(23){\rm const}です。
(24)から式(25)の変形はやや面倒なので、最後に説明します。
(26)\eta_{nk}=\dfrac{\rho_{nk}}{\displaystyle\sum_{j=1}^K\rho_{nj}},\ {\boldsymbol\eta}_n=\{\eta_{n1},\ldots,\eta_{nK}\}とおきました。

(26)より、近似分布q({\bf s}_n)はパラメータ{\boldsymbol\eta}_nを持つカテゴリ分布になることが分かります。
近似分布q({\lambda_k}),q({\boldsymbol\pi})がまだ明らかになっていないので、
(22)\langle\ln\lambda_k\rangle_{q({\lambda_k})},\langle\lambda_k\rangle_{q({\lambda_k})},\langle\ln\pi_k\rangle_{q({\boldsymbol\pi})}の計算は後回しにします。

\ln q({\boldsymbol\lambda},{\boldsymbol\pi})の導出

\ln q({\boldsymbol\lambda},{\boldsymbol\pi})に式(1)を当てはめるため、式(1)において、z_i=\{{\boldsymbol\lambda},{\boldsymbol\pi}\},{\bf Z}_{\backslash i}={\bf S}とします。

\begin{eqnarray}
\ln q({\boldsymbol\lambda},{\boldsymbol\pi})&=&\langle p({\bf X},{\bf S},{\boldsymbol\lambda},{\boldsymbol\pi})\rangle_{q({\bf S})}+{\rm const}\\
&=&\langle \ln p({\bf X}|{\bf S},{\boldsymbol\lambda})p({\bf S}|{\boldsymbol\pi})p({\boldsymbol\lambda})p({\boldsymbol\pi})\rangle_{q({\bf S})}+{\rm const}\\
&=&\langle \ln p({\bf X}|{\bf S},{\boldsymbol\lambda})\rangle_{q({\bf S})}+\langle \ln p({\bf S}|{\boldsymbol\pi})\rangle_{q({\bf S})}+\langle \ln p({\boldsymbol\lambda})\rangle_{q({\bf S})}+\langle \ln p({\boldsymbol\pi})\rangle_{q({\bf S})}+{\rm const}\\
&=&\langle \ln p({\bf X}|{\bf S},{\boldsymbol\lambda})\rangle_{q({\bf S})}+\langle \ln p({\bf S}|{\boldsymbol\pi})\rangle_{q({\bf S})}+\ln p({\boldsymbol\lambda})+\ln p({\boldsymbol\pi})+{\rm const}\tag{27}
\end{eqnarray}

(27)で、\ln p({\boldsymbol\lambda})\ln p({\boldsymbol\pi})\bf Sに無関係なので、\langle\ln p({\boldsymbol\lambda})\rangle_{q({\bf S})}=\ln p({\boldsymbol\lambda}),\langle\ln p({\boldsymbol\pi})\rangle_{q({\bf S})}=\ln p({\boldsymbol\pi})としています。
また、式(27)は項が{\boldsymbol\lambda}{\boldsymbol\pi}で分かれているので、近似分布において独立、
すなわち、q({\boldsymbol\lambda},{\boldsymbol\pi})=q({\boldsymbol\lambda})q({\boldsymbol\pi})になることを意味しています。

(27)\ln q({\boldsymbol\lambda})について整理します。

\begin{eqnarray}
&&\ln q({\boldsymbol\lambda})q({\boldsymbol\pi})=\langle \ln p({\bf X}|{\bf S},{\boldsymbol\lambda})\rangle_{q({\bf S})}+\langle \ln p({\bf S}|{\boldsymbol\pi})\rangle_{q({\bf S})}+\ln p({\boldsymbol\lambda})+\ln p({\boldsymbol\pi})+{\rm const}\\
&&\Leftrightarrow\ln q({\boldsymbol\lambda})+\ln q({\boldsymbol\pi})=\langle \ln p({\bf X}|{\bf S},{\boldsymbol\lambda})\rangle_{q({\bf S})}+\langle \ln p({\bf S}|{\boldsymbol\pi})\rangle_{q({\bf S})}+\ln p({\boldsymbol\lambda})+\ln p({\boldsymbol\pi})+{\rm const}\\
&&\Leftrightarrow\ln q({\boldsymbol\lambda})=\langle \ln p({\bf X}|{\bf S},{\boldsymbol\lambda})\rangle_{q({\bf S})}+\ln p({\boldsymbol\lambda})+{\rm const}\tag{28}
\end{eqnarray}

(28)で、{\boldsymbol\lambda}に関係ない項は{\rm const}.にまとめています。

(27)\ln q({\boldsymbol\pi})について整理します。

\begin{eqnarray}
&&\ln q({\boldsymbol\lambda})q({\boldsymbol\pi})=\langle \ln p({\bf X}|{\bf S},{\boldsymbol\lambda})\rangle_{q({\bf S})}+\langle \ln p({\bf S}|{\boldsymbol\pi})\rangle_{q({\bf S})}+\ln p({\boldsymbol\lambda})+\ln p({\boldsymbol\pi})+{\rm const}\\
&&\Leftrightarrow\ln q({\boldsymbol\lambda})+\ln q({\boldsymbol\pi})=\langle \ln p({\bf X}|{\bf S},{\boldsymbol\lambda})\rangle_{q({\bf S})}+\langle \ln p({\bf S}|{\boldsymbol\pi})\rangle_{q({\bf S})}+\ln p({\boldsymbol\lambda})+\ln p({\boldsymbol\pi})+{\rm const}\\
&&\Leftrightarrow\ln q({\boldsymbol\pi})=\langle \ln p({\bf S}|{\boldsymbol\pi})\rangle_{q({\bf S})}+\ln p({\boldsymbol\pi})+{\rm const}\tag{29}
\end{eqnarray}

(29)で、{\boldsymbol\pi}に関係ない項は{\rm const}.にまとめています。

以下で、\ln q({\boldsymbol\lambda})\ln q({\boldsymbol\pi})を別々に計算していきます。

\ln q({\boldsymbol\lambda})の導出

(28)より、\ln q({\boldsymbol\lambda})を計算していきます。

\begin{eqnarray}
\ln q({\boldsymbol\lambda})&=&\langle \ln p({\bf X}|{\bf S},{\boldsymbol\lambda})\rangle_{q({\bf S})}+\ln p({\boldsymbol\lambda})+{\rm const}\\
&=&\left\langle\sum_{n=1}^N\sum_{k=1}^Ks_{nk}\ln{\rm Poi}(x_n|\lambda_k)\right\rangle_{q({\bf S})}+\left\langle\sum_{k=1}^K\ln{\rm Gam}(\lambda_k|a,b)\right\rangle_{q({\bf S})}+{\rm const}\\
&=&\sum_{n=1}^N\sum_{k=1}^K\langle s_{nk}\rangle_{q({\bf S})}\ln{\rm Poi}(x_n|\lambda_k)+\sum_{k=1}^K\ln{\rm Gam}(\lambda_k|a,b)+{\rm const}\\
&=&\sum_{n=1}^N\sum_{k=1}^K\langle s_{nk}\rangle_{q({\bf s}_n)}\ln{\rm Poi}(x_n|\lambda_k)+\sum_{k=1}^K\ln{\rm Gam}(\lambda_k|a,b)+{\rm const}\tag{30}\\
&=&\sum_{k=1}^K\left(\sum_{n=1}^N\langle s_{nk}\rangle_{q({\bf s}_n)}\ln{\rm Poi}(x_n|\lambda_k)+\ln{\rm Gam}(\lambda_k|a,b)\right)+{\rm const}\\
&=&\sum_{k=1}^K\left(\sum_{n=1}^N\langle s_{nk}\rangle_{q({\bf s}_n)}\ln\frac{\lambda_k^{x_n}}{x_n!}\exp(-\lambda_k)+\ln C_{\rm G}(a,b)\lambda_k^{a-1}\exp(-b\lambda_k)\right)+{\rm const}\\
&=&\sum_{k=1}^K\left(\sum_{n=1}^N\langle s_{nk}\rangle_{q({\bf s}_n)}\left(x_n\ln\lambda_k-\ln x_n!-\lambda_k\right)+\ln C_{\rm G}(a,b)+(a-1)\ln\lambda_k-b\lambda_k\right)+{\rm const}\\
&=&\sum_{k=1}^K\left(\sum_{n=1}^N\langle s_{nk}\rangle_{q({\bf s}_n)}\left(x_n\ln\lambda_k-\lambda_k\right)+(a-1)\ln\lambda_k-b\lambda_k\right)+{\rm const}\\
&=&\sum_{k=1}^K\left(\left(\sum_{n=1}^N\langle s_{nk}\rangle_{q({\bf s}_n)}x_n+a-1\right)\ln\lambda_k-\left(\sum_{n=1}^N\langle s_{nk}\rangle_{q({\bf s}_n)}+b\right)\lambda_k\right)+{\rm const}\tag{31}
\end{eqnarray}

(30)で、\langle\cdot\rangle_{q({\bf S})}\langle\cdot\rangle_{q({\bf s}_n)}とし、\langle\cdot\rangle_{q({s_{nk}})}としていないのは、
\sum_{k=1}^Ks_{nk}=1という制約のため、分解して考えにくい為であると思います。
また、q({\bf s}_n)はカテゴリ分布であり、成分の分布での期待値は普通考えない為だと思われます。

(31)において、

\begin{eqnarray}
\hat{a}_k=\sum_{n=1}^N\langle s_{nk}\rangle_{q({\bf s}_n)}x_n+a\tag{32}
\end{eqnarray}

\begin{eqnarray}
\hat{b}_k=\sum_{n=1}^N\langle s_{nk}\rangle_{q({\bf s}_n)}+b\tag{33}
\end{eqnarray}

のようにおきます。

(32),(33)を式(31)に代入します。

\begin{eqnarray}
&&\ln q({\boldsymbol\lambda})=\sum_{k=1}^K\left(\left(\hat{a}_k-1\right)\ln\lambda_k-\hat{b}_k\lambda_k\right)+{\rm const}\tag{34}\\
&&\Leftrightarrow q({\boldsymbol\lambda})=\prod_{k=1}^KC_{\rm G}(\hat{a}_k,\hat{b}_k)\lambda_k^{\hat{a}_k-1}\exp(-\hat{b}_k\lambda_k)\tag{35}\\
&&\Leftrightarrow q({\boldsymbol\lambda})=\prod_{k=1}^K{\rm Gam}(\lambda_k|\hat{a}_k,\hat{b}_k)\tag{36}\\
&&\Leftrightarrow q(\lambda_k)={\rm Gam}(\lambda_k|\hat{a}_k,\hat{b}_k)\tag{37}
\end{eqnarray}

(35)C_{\rm G}(\hat{a}_k,\hat{b}_k)q(\lambda_k)の正規化項であり、式(34){\rm const}です。
(37)は、式(36)が積の形で書けているので、q(\lambda_k)がそれぞれ独立であることを利用しました。
(37)より、近似分布q(\lambda_k)はパラメータ\hat{a}_k,\hat{b}_kを持つガンマ分布になることが分かります。
(32),(33)\langle s_{nk}\rangle_{q({\bf s}_n)}は後で計算します。

\ln q({\boldsymbol\pi})の導出

(29)より、\ln q({\boldsymbol\lambda})を計算していきます。

\begin{eqnarray}
\ln q({\boldsymbol\pi})&=&\langle \ln p({\bf S}|{\boldsymbol\pi})\rangle_{q({\bf S})}+\ln p({\boldsymbol\pi})+{\rm const}\\
&=&\left\langle\sum_{n=1}^N\sum_{k=1}^K\ln{\rm Cat}({\bf s}_n|{\boldsymbol\pi})\right\rangle_{q({\bf S})}+\ln{\rm Dir}({\boldsymbol\pi}|{\boldsymbol\alpha})+{\rm const}\\
&=&\sum_{n=1}^N\sum_{k=1}^K\left\langle\ln{\rm Cat}({\bf s}_n|{\boldsymbol\pi})\right\rangle_{q({\bf S})}+\ln{\rm Dir}({\boldsymbol\pi}|{\boldsymbol\alpha})+{\rm const}\\
&=&\sum_{n=1}^N\sum_{k=1}^K\left\langle\ln{\rm Cat}({\bf s}_n|{\boldsymbol\pi})\right\rangle_{q({\bf s}_n)}+\ln{\rm Dir}({\boldsymbol\pi}|{\boldsymbol\alpha})+{\rm const}\\
&=&\sum_{n=1}^N\sum_{k=1}^K\langle s_{nk}\ln\pi_k\rangle_{q({\bf s}_n)}+\ln C_{\rm D}({\boldsymbol\alpha})\prod_{k=1}^K\pi_k^{\alpha_k-1}+{\rm const}\\
&=&\sum_{n=1}^N\sum_{k=1}^K\langle s_{nk}\rangle_{q({\bf s}_n)}\ln\pi_k+\ln C_{\rm D}({\boldsymbol\alpha})+(\alpha_k-1)\sum_{k=1}^K\ln\pi_k+{\rm const}\tag{38}\\
&=&\sum_{k=1}^K\left(\sum_{n=1}^N\langle s_{nk}\rangle_{q({\bf s}_n)}\ln\pi_k+(\alpha_k-1)\ln\pi_k\right)+{\rm const}\\
&=&\sum_{k=1}^K\left(\sum_{n=1}^N\langle s_{nk}\rangle_{q({\bf s}_n)}+\alpha_k-1\right)\ln\pi_k+{\rm const}\tag{39}
\end{eqnarray}

(38)で、\langle\cdot\rangle_{q({\bf s}_n)}\langle\cdot\rangle_{q({s_{nk}})}としていないのは、
\sum_{k=1}^Ks_{nk}=1という制約のため、分解して考えにくい為であると思います。
また、q({\bf s}_n)はカテゴリ分布であり、成分の分布での期待値は普通考えない為だと思われます。

(39)において、

\begin{eqnarray}
\hat{\alpha}_k=\sum_{n=1}^N\langle s_{nk}\rangle_{q({\bf s}_n)}+\alpha_k\tag{40}
\end{eqnarray}

とおきます。

(39)を式(38)に代入します。

\begin{eqnarray}
&&\ln q({\boldsymbol\pi})=\sum_{k=1}^K(\hat{\alpha}_k-1)\ln\pi_k+{\rm const}\tag{41}\\
&&\Leftrightarrow q({\boldsymbol\pi})=C_{\rm D}(\hat{\boldsymbol\alpha})\prod_{k=1}^K\pi_k^{\hat{\alpha}_k-1}\tag{42}\\
&&\Leftrightarrow q({\boldsymbol\pi})={\rm Dir}({\boldsymbol\pi}|{\hat{\boldsymbol\alpha}})\tag{43}
\end{eqnarray}

(42)で、\hat{\boldsymbol\alpha}=\{\hat{\alpha}_1,\ldots,\hat{\alpha}_K\}とおきました。
(42)C_{\rm D}(\hat{\boldsymbol\alpha})q({\boldsymbol\pi})の正規化項であり、式(41){\rm const}です。
(43)より、近似分布q({\boldsymbol\pi})はパラメータ\hat{\boldsymbol\alpha}を持つディレクレ分布になることが分かります。
(40)\langle s_{nk}\rangle_{q({\bf s}_n)}は後で計算します。

期待値の計算

全ての近似分布q({\bf s}_n),q(\lambda_k),q({\boldsymbol\pi})が明らかになったので、まだ計算していない期待値を計算します。

(26)より、q({\bf s}_n)={\rm Cat}({\bf s}_n|{\boldsymbol\eta}_n)なので、

\begin{eqnarray}
\langle s_{nk}\rangle_{q({\bf s}_n)}=\eta_{nk}\tag{44}
\end{eqnarray}

です。

(37)より、q(\lambda_k)={\rm Gam}(\lambda_k|\hat{a}_k,\hat{b}_k)なので、

\begin{eqnarray}
\langle\lambda_k\rangle_{q({\lambda_k})}=\frac{\hat{a}_k}{\hat{b}_k}\tag{45}
\end{eqnarray}

\begin{eqnarray}
\langle\ln\lambda_k\rangle_{q({\lambda_k})}=\psi(\hat{a}_k)-\ln\hat{b}_k\tag{46}
\end{eqnarray}

です。

(43)より、q({\boldsymbol\pi})={\rm Dir}({\boldsymbol\pi}|{\hat{\boldsymbol\alpha}})なので、

\begin{eqnarray}
\langle\ln\pi_k\rangle_{q({\boldsymbol\pi})}=\psi(\hat{\alpha}_k)-\psi\left(\sum_{j=1}^K\hat{\alpha}_j\right)\tag{47}
\end{eqnarray}

です。

(44),(45),(46),(47)を含むパラメータに代入していきます。

(22)に、式(45),(46),(47)を代入します。

\begin{eqnarray}
&&\ln\rho_{nk}=x_n\left(\psi(\hat{a}_k)-\ln\hat{b}_k\right)-\frac{\hat{a}_k}{\hat{b}_k}+\psi(\hat{\alpha}_k)-\psi\left(\sum_{j=1}^K\hat{\alpha}_j\right)\\
&&\Leftrightarrow \rho_{nk}=\exp\left(x_n\left(\psi(\hat{a}_k)-\ln\hat{b}_k\right)-\frac{\hat{a}_k}{\hat{b}_k}+\psi(\hat{\alpha}_k)-\psi\left(\sum_{j=1}^K\hat{\alpha}_j\right)\right)\tag{48}
\end{eqnarray}

(32)に、式(44)を代入します。

\begin{eqnarray}
\hat{a}_k=\sum_{n=1}^N\eta_{nk}x_n+a\tag{49}
\end{eqnarray}

(33)に、式(44)を代入します。

\begin{eqnarray}
\hat{b}_k=\sum_{n=1}^N\eta_{nk}+b\tag{50}
\end{eqnarray}

(40)に、式(44)を代入します。

\begin{eqnarray}
\hat{\alpha}_k=\sum_{n=1}^N\eta_{nk}+\alpha_k\tag{51}
\end{eqnarray}

近似分布q({\bf s}_n),q(\lambda_k),q({\boldsymbol\pi})のパラメータの更新式が求まりましたので、
以下の繰り返し法で求めます。

アルゴリズム

1.初期化

{\alpha_k}\ (k=1,\ldots,K),a,bを初期化します。
\hat{\alpha}_k,\hat{a}_k,\hat{b}_k\ (k=1,\ldots,K)を以下のように初期化します。

\begin{eqnarray}
&&\hat{\alpha}_k\leftarrow{\alpha_k}\tag{52}\\
&&\hat{a}_k\leftarrow a_k\tag{53}\\
&&\hat{b}_k\leftarrow b_k\tag{54}
\end{eqnarray}

2.\eta_{nk}の更新 (q({\bf s}_n)={\rm Cat}({\bf s}_n|{\boldsymbol\eta}_n)の更新)

\eta_{nk}\ (n=1,\ldots,N,\ k=1,\ldots,K)を更新します。

\begin{eqnarray}
\eta_{nk}\leftarrow\dfrac{\rho_{nk}}{\displaystyle\sum_{j=1}^K\rho_{nj}}\tag{55}
\end{eqnarray}

ただし、

\begin{eqnarray}
\rho_{nk}=\exp\left(x_n\left(\psi(\hat{a}_k)-\ln\hat{b}_k\right)-\frac{\hat{a}_k}{\hat{b}_k}+\psi(\hat{\alpha}_k)-\psi\left(\sum_{j=1}^K\hat{\alpha}_j\right)\right)\tag{56}
\end{eqnarray}

とします。

3.\hat{a}_k,\hat{b}_kの更新 (q(\lambda_k)={\rm Gam}(\lambda_k|\hat{a}_k,\hat{b}_k)の更新)

\hat{a}_k,\hat{b}_k\ (k=1,\ldots,K)を更新します。

\begin{eqnarray}
\hat{a}_k\leftarrow\sum_{n=1}^N\eta_{nk}x_n+a\tag{57}
\end{eqnarray}

\begin{eqnarray}
\hat{b}_k\leftarrow\sum_{n=1}^N\eta_{nk}+b\tag{58}
\end{eqnarray}

4.\hat{\alpha}_kの更新 (q({\boldsymbol\pi})={\rm Dir}({\boldsymbol\pi}|{\hat{\boldsymbol\alpha}})の更新)

\hat{\alpha}_k\ (k=1,\ldots,K)を更新します。

\begin{eqnarray}
\hat{\alpha}_k\leftarrow\sum_{n=1}^N\eta_{nk}+\alpha_k\tag{59}
\end{eqnarray}

5.終了条件
対数尤度を再計算し、前回との差分があらかじめ設定していた収束条件を満たしていなければ2に戻り、満たしていれば終了します。

\eta_{nk}を初期化し、2と3と4の処理を入れ替えても同じです。
※対数尤度の代わりに、繰り返し回数を決めて、それを終了条件とするのもよいと思います。
※対数尤度の代わりに、変分下限(変分下界)を計算してもよいと思います。
※対数尤度\ln p({\bf X}|{\bf S},{\boldsymbol\lambda})計算時の{\bf S},{\boldsymbol\lambda}は、q({\bf s}_n),q(\lambda_k)の平均やモードなどを使えばよいと思います。

(24)から式(25)の変形について

この式変形は混合分布の変分推論では必ず出てきますし、重要な式変形だと思うので、丁寧に説明していきます。

もう一度、式(24),(25)を書き出してみます。(式(60)が式(24)に式(61)が式(25)に相当します。)

\begin{eqnarray}
&&q({\bf s}_n)=C\prod_{k=1}^K\rho_{nk}^{s_{nk}}\tag{60}\\
&&\Leftrightarrow q({\bf s}_n)=\prod_{k=1}^K\eta_{nk}^{s_{nk}}\tag{61}\\
\end{eqnarray}

(60)Cq({\bf s}_n)の正規化項なので、C^{-1}\displaystyle\prod_{k=1}^K\rho_{nk}^{s_{nk}}を全ての{\bf s}_nについて足し合わせたものです。
まずは、C^{-1}について計算します。

\begin{eqnarray}
C^{-1}&=&\sum_{{\bf s}_n}\prod_{k=1}^K\rho_{nk}^{s_{nk}}\tag{62}\\
&=&\rho_{n1}^1\cdot\rho_{n2}^0\cdots\rho_{nK}^0\ \ \ \ (s_{n1}=1,s_{n,k\not=1}=0)\\
&+&\rho_{n1}^0\cdot\rho_{n2}^1\cdots\rho_{nK}^0\ \ \ \ (s_{n2}=1,s_{n,k\not=2}=0)\\
&&\vdots\\
&+&\rho_{n1}^0\cdot\rho_{n2}^0\cdots\rho_{nK}^1\ \ \ \ (s_{nK}=1,s_{n,k\not=K}=0)\tag{63}\\
&=&\rho_{n1}\cdot 1\cdots 1\ \ \ \ (s_{n1}=1,s_{n,k\not=1}=0)\\
&+&1\cdot\rho_{n2}^1\cdots 1\ \ \ \ (s_{n2}=1,s_{n,k\not=2}=0)\\
&&\vdots\\
&+&1\cdot1\cdots\rho_{nK}^1\ \ \ \ (s_{nK}=1,s_{n,k\not=K}=0)\tag{64}\\
&=&\rho_{n1}+\rho_{n2}+\cdots+\rho_{nK}\\
&=&\sum_{k=1}^K\rho_{nk}\tag{65}
\end{eqnarray}

(62)から式(63)への式変形が一番のポイントだと思います。
s_{nk}にはs_{nk}\in\{0,1\}かつ\displaystyle\sum_{k=1}^Ks_{nk}=1という制約があります。
これを分かりやすくいうと、s_{nk}0または1の値をとり、s_{nk}=1ならs_{n,j\not=k}=0(kでないjにおいてs_{nj}=0)ということです。
例えば、K=3の場合、
s_{n1}=1,s_{n2}=0,s_{n3}=0 (s_{n1}のみが1の場合)、
s_{n1}=0,s_{n2}=1,s_{n3}=0 (s_{n2}のみが1の場合)、
s_{n1}=0,s_{n2}=0,s_{n3}=1 (s_{n3}のみが1の場合) の3つの場合があるということです。
\displaystyle\sum_{{\bf s}_n}は全ての{\bf s}_nについて足し合わせなさい、ということなので、
\displaystyle\sum_{{\bf s}_n}=(s_{n1}のみが1の場合)+(s_{n2}のみが1の場合)+\cdots+(s_{nK}のみが1の場合) と書けるので、
(62)から式(63)への式変形が実現します。

(63)から式(64)への変形は、スカラー0乗が1であることさえ知っていれば分かります。

次に、求めた式(65)を式(60)に代入します。

\begin{eqnarray}
q({\bf s}_n)&=&\frac{\displaystyle\prod_{k=1}^K\rho_{nk}^{s_{nk}}}{C^{-1}}\\
&=&\frac{\displaystyle\prod_{k=1}^K\rho_{nk}^{s_{nk}}}{\displaystyle\sum_{j=1}^K\rho_{nj}}\tag{67}\\
&=&\prod_{k=1}^K\left(\frac{\rho_{nk}}{\sum_{j=1}^K\rho_{nj}}\right)^{s_{nk}}\tag{68}\\
&=&\prod_{k=1}^K\eta_{nk}^{s_{nk}}\tag{69}
\end{eqnarray}

(67)から式(68)の変形ですが、式(67)において、s_{nk}=1のとき\dfrac{\rho_{nk}}{\displaystyle\sum_{j=1}^K\rho_{nj}}なので、式(68)のようにまとめて()^{s_{nk}}と書けます。
\eta_{nk}=\dfrac{\rho_{nk}}{\displaystyle\sum_{j=1}^K\rho_{nj}}とおいているので、式(68)に代入すれば、式(69)が導けます。

以上より、式(24)から式(25)の変形ができることが確認できました。

まとめ

最後に今回の記事の一連の流れを確認しておきます。
式番号は既に登場したものに'をつけます。

まずは、同時分布を次のように設計しました。

\begin{eqnarray}
p({\bf X},{\bf S},{\boldsymbol\lambda},{\boldsymbol\pi})&=&p({\bf X}|{\bf S},{\boldsymbol\lambda})p({\bf S}|{\boldsymbol\pi})p({\boldsymbol\lambda})p({\boldsymbol\pi})\tag{10'}
\end{eqnarray}

次に、パラメータの事後分布を以下のように仮定しました。

\begin{eqnarray}
p({\bf S},{\boldsymbol\lambda},{\boldsymbol\pi}|{\bf X})\simeq q({\bf S},{\boldsymbol\lambda},{\boldsymbol\pi})=q({\bf S})q({\boldsymbol\lambda},{\boldsymbol\pi})\tag{11'}
\end{eqnarray}

近似分布q({\bf S})を計算すると、q({\bf S})=\displaystyle\prod_{n=1}^Nq({\bf s}_n)であることが分かり、

\begin{eqnarray}
q({\bf s}_n)={\rm Cat}({\bf s}_n|{\boldsymbol\eta}_n)\tag{26'}
\end{eqnarray}

であることが分かりました。

近似分布q({\boldsymbol\lambda},{\boldsymbol\pi})を計算すると、q({\boldsymbol\lambda},{\boldsymbol\pi})=q({\boldsymbol\lambda})q({\boldsymbol\pi}) であることが分かりました。

近似分布q({\boldsymbol\lambda})を計算すると、q({\boldsymbol\lambda})=\displaystyle\prod_{k=1}^Kq(\lambda_k)であることが分かり、

\begin{eqnarray}
q(\lambda_k)={\rm Gam}(\lambda_k|\hat{a}_k,\hat{b}_k)\tag{37'}
\end{eqnarray}

であることが分かりました。

近似分布q({\boldsymbol\pi})を計算すると、

\begin{eqnarray}
q({\boldsymbol\pi})={\rm Dir}({\boldsymbol\pi}|{\hat{\boldsymbol\alpha}})\tag{43'}
\end{eqnarray}

であることが分かりました。

近似分布q({\bf S}),q({\boldsymbol\lambda}),q({\boldsymbol\pi})のパラメータは相互に依存しているため、
解析的に解くのは難しいですが、上に書いた繰り返し法のアルゴリズムで解くことができます。

偉人の名言

f:id:olj611:20211013133731p:plain:h300
統計とは、街灯の柱と酒を飲むようなもの。
ウィンストン・チャーチル

参考文献

ベイズ推論による機械学習入門 p133-p137

動画

なし

目次へ戻る