機械学習基礎理論独習

誤りがあればご指摘いただけると幸いです。数式が整うまで少し時間かかります。リンクフリーです。

勉強ログです。リンクフリーです
目次へ戻る

1変数ガウス分布の変分推論

はじめに

本記事では、以下の式を使います。

\begin{eqnarray}
\ln q(z_i)=\langle\ln p({\mathcal D},{\bf Z})\rangle_{q({\bf Z}_{\backslash i})}+{\rm const.}\tag{1}\\
\end{eqnarray}

(1)については、こちらで解説しています。
また、1次元ガウス分布の平均と分散の事後分布の記事を読んでおくと、本記事の理解の助けになるかもしれません。

1変数ガウス分布

1変数xについてのガウス分布を用いて、分解による変分近似の例を示します。

ガウス分布から独立に発生したと仮定する観測値xのデータ集合{\mathcal D}=\{x_1,\ldots,x_N\}が与えられたとします。
この時、もともとのガウス分布の平均\muと精度\tauの事後分布を求めてみます。
精度\tauとは、分散\sigma^2逆数で、\tau=(\sigma^2)^{-1}と表せます。

尤度関数は、以下の式です。

\begin{eqnarray}
p({\mathcal D}|\mu,\tau)&=&\prod_{n=1}^Np(x_n|\mu,\tau)\\
&=&\prod_{n=1}^N\left(\frac{\tau}{2\pi}\right)^{1/2}\exp\left(-\frac{\tau}{2}(x_n-\mu)^2\right)\\
&=&\left(\frac{\tau}{2\pi}\right)^{N/2}\exp\left(-\frac{\tau}{2}\sum_{n=1}^N(x_n-\mu)^2\right)\tag{2}
\end{eqnarray}

\mu,\tauに関する共役事前分布を導入します。

\begin{eqnarray}
p(\mu|\tau)&=&{\mathcal N}(\mu|\mu_0,(\lambda_0\tau)^{-1})\\
&=&\left(\frac{\lambda_0\tau}{2\pi}\right)^{1/2}\exp\left(-\frac{\lambda_0\tau}{2}(\mu-\mu_0)^2\right)\tag{3}\\
\end{eqnarray}

\begin{eqnarray}
p(\tau)&=&{\rm Gam}(\tau|a_0,b_0)\\
&=&\frac{1}{\Gamma(a_0)}b_0^{a_0}\tau^{a_0-1}\exp(-b_0\tau)\tag{4}
\end{eqnarray}

同時分布は、以下の式です。

\begin{eqnarray}
p({\mathcal D}|\mu,\tau)=p({\mathcal D}|\mu,\tau)p(\mu|\tau)p(\tau)\tag{5}
\end{eqnarray}

グラフィカルモデルは、以下の図1です。

図1
f:id:olj611:20211005085823p:plain:w250

変分推論の適用

以上の式から事後分布p(\mu,\tau|{\mathcal D})を求めると、ガウス-ガンマ分布となるのですが、
本記事は、変分推論の記事なので、以下のように事後分布を分解した変分近似を考えることにします。

\begin{eqnarray}
p(\mu,\tau|{\mathcal D})\simeq q(\mu,\tau)=q_\mu(\mu)q_\tau(\tau)\tag{6}
\end{eqnarray}

以下の計算で使う\ln p({\mathcal D}|\mu,\tau),\ln p(\mu|\tau),\ln p(\tau)を先に計算しておきます。

\begin{eqnarray}
\ln p({\mathcal D}|\mu,\tau)&=&\ln\left(\left(\frac{\tau}{2\pi}\right)^{N/2}\exp\left(-\frac{\tau}{2}\sum_{n=1}^N(x_n-\mu)^2\right)\right)\\
&=&\frac{N}{2}\ln\tau-\frac{N}{2}\ln(2\pi)-\frac{\tau}{2}\sum_{n=1}^N(x_n-\mu)^2\tag{7}
\end{eqnarray}

\begin{eqnarray}
\ln p(\mu|\tau)&=&\ln\left(\left(\frac{\lambda_0\tau}{2\pi}\right)^{1/2}\exp\left(-\frac{\lambda_0\tau}{2}(\mu-\mu_0)^2\right)\right)\\
&=&\frac{1}{2}\ln\lambda_0\tau-\frac{1}{2}\ln(2\pi)-\frac{\lambda_0\tau}{2}(\mu-\mu_0)^2\tag{8}
\end{eqnarray}

\begin{eqnarray}
\ln p(\tau)&=&\ln\left(\frac{1}{\Gamma(a_0)}b_0^{a_0}\tau^{a_0-1}\exp(-b_0\tau)\right)\\
&=&-\ln\Gamma(a_0)+a_0\ln b_0+(a_0-1)\ln\tau-b_0\tau\tag{9}
\end{eqnarray}

\ln q_\mu^\star(\mu)は、式(1)より、以下の式で表されます。

\begin{eqnarray}
\ln q_\mu^\star(\mu)&=&\langle\ln p({\mathcal D},\mu,\tau)\rangle_{q_\tau(\tau)}+{\rm const}\\
&=&\langle\ln p({\mathcal D}|\mu,\tau)p(\mu|\tau)p(\tau)\rangle_{q_\tau(\tau)}+{\rm const}\\
&=&\langle\ln p({\mathcal D}|\mu,\tau)+ \ln p(\mu|\tau) + \ln p(\tau)\rangle_{q_\tau(\tau)}+{\rm const}\\
&=&\langle\ln p({\mathcal D}|\mu,\tau)+ \ln p(\mu|\tau)\rangle_{q_\tau(\tau)}+{\rm const}\tag{10}\\
&=&\left\langle \frac{N}{2}\ln\tau-\frac{N}{2}\ln(2\pi)-\frac{\tau}{2}\sum_{n=1}^N(x_n-\mu)^2+\frac{1}{2}\ln\lambda_0\tau-\frac{1}{2}\ln(2\pi)-\frac{\lambda_0\tau}{2}(x_n-\mu)^2\right\rangle_{q_\tau(\tau)}+{\rm const}\\
&=&\left\langle -\frac{\tau}{2}\sum_{n=1}^N(x_n-\mu)^2- \frac{\lambda_0\tau}{2}(\mu-\mu_0)^2\right\rangle_{q_\tau(\tau)}+{\rm const}\tag{11}\\
&=&-\frac{\langle\tau\rangle_{q_\tau(\tau)}}{2}\left(\lambda_0(\mu-\mu_0)^2+\sum_{n=1}^N(x_n-\mu)^2\right)+{\rm const}\tag{12}
\end{eqnarray}

(10),(11)\muに関係のない項は{\rm const}にまとめています。
(12)q_\mu^\star(\mu)ガウス分布なので、q_\mu(\mu)は以下の式で表せます。

\begin{eqnarray}
q_\mu(\mu)={\mathcal N}(\mu)|\mu_N,\lambda_N^{-1})\tag{13}
\end{eqnarray}

\begin{eqnarray}
\mu_N=\frac{\lambda_0\mu_0+N\overline{x}}{\lambda_0+N}\tag{14}
\end{eqnarray}

\begin{eqnarray}
\lambda_N=(\lambda_0+N)\langle\tau\rangle_{q_\tau(\tau)}\tag{15}
\end{eqnarray}

(14)で、\overline{x}=\dfrac{1}{N}\displaystyle\sum_{n=1}^Nx_nとおきました。
(14),(15)の導出については、PRML演習問題 10.7(標準)を参照してください。
※式(15)の因数\langle\tau\rangle_{q_\tau(\tau)}ですが、この時点では、q_\tau(\tau)がどのような分布か明らかでない為、計算できません。

\ln q_\tau^\star(\tau)は、式(1)より、以下の式で表されます。

\begin{eqnarray}
\ln q_\tau^\star(\tau)&=&\langle\ln p({\mathcal D},\mu,\tau)\rangle_{q_\mu(\mu)}+{\rm const}\\
&=&\langle\ln p({\mathcal D}|\mu,\tau)p(\mu|\tau)p(\tau)\rangle_{q_\mu(\mu)}+{\rm const}\\
&=&\langle\ln p({\mathcal D}|\mu,\tau)+ \ln p(\mu|\tau) + \ln p(\tau)\rangle_{q_\mu(\mu)}+{\rm const}\\
&=&\Biggl\langle \frac{N}{2}\ln\tau-\frac{N}{2}\ln(2\pi)-\frac{\tau}{2}\sum_{n=1}^N(x_n-\mu)^2+\frac{1}{2}\ln\lambda_0\tau-\frac{1}{2}\ln(2\pi)-\frac{\lambda_0\tau}{2}(\mu-\mu_0)^2\\
&&-\ln\Gamma(a_0)+a_0\ln b_0+(a_0-1)\ln\tau-b_0\tau\Biggl\rangle_{q_\mu(\mu)}+{\rm const}\\
&=&(a_0-1)\ln\tau-b_0\tau+\frac{N+1}{2}\ln\tau+\left\langle-\frac{\tau}{2}\sum_{n=1}^N(x_n-\mu)^2-\frac{\lambda_0\tau}{2}(\mu-\mu_0)^2\right\rangle_{q_\mu(\mu)}+{\rm const}\tag{16}\\
&=&(a_0-1)\ln\tau-b_0\tau+\frac{N+1}{2}\ln\tau-\frac{\tau}{2}\left\langle\sum_{n=1}^N(x_n-\mu)^2+\lambda_0(\mu-\mu_0)^2\right\rangle_{q_\mu(\mu)}+{\rm const}\tag{17}
\end{eqnarray}

(16)\tauに関係のない項は{\rm const}にまとめています。
(17)はガンマ分布なので、q_\tau(\tau)は以下の式で表せます。

\begin{eqnarray}
q_\tau(\tau)={\rm Gam}(\tau|a_N,b_N)\tag{18}\\
\end{eqnarray}

\begin{eqnarray}
a_N=a_0+\frac{N+1}{2}\tag{19}
\end{eqnarray}

\begin{eqnarray}
b_N=b_0+\frac{1}{2}\left\langle\sum_{n=1}^N(x_n-\mu)^2+\lambda_0(\mu-\mu_0)^2\right\rangle_{q_\mu(\mu)}\tag{20}
\end{eqnarray}

(19),(20)の導出については、PRML演習問題 10.7(標準)を参照してください。

q_\mu(\mu)q_\tau(\tau)の分布が判明したので、それぞれのパラメータ\mu_N,\lambda_N,a_N,b_Nを求めればよいことになります。
(14),(19)より、\mu_N,a_Nは固定なので、\lambda_N,b_Nを求めればよいことになります。

(20)を変形します。

\begin{eqnarray}
b_N&=&b_0+\frac{1}{2}\left\langle\sum_{n=1}^N(x_n-\mu)^2+\lambda_0(\mu-\mu_0)^2\right\rangle_{q_\mu(\mu)}\\
&=&b_0+\frac{1}{2}\left\langle\sum_{n=1}^N(x_n^2-2x_n\mu+\mu^2)+\lambda_0(\mu^2-2\mu_0\mu+\mu_0^2)\right\rangle_{q_\mu(\mu)}\\
&=&b_0+\frac{1}{2}\left\langle N\overline{x^2}-2N\overline{x}\mu+N\mu^2+\lambda_0\mu^2-2\lambda_0\mu_0\mu+\lambda_0\mu_0^2\right\rangle_{q_\mu(\mu)}\\
&=&b_0+\frac{1}{2}(N\overline{x^2}+\lambda_0\mu_0^2)+\frac{1}{2}(N+\lambda_0)\langle\mu^2\rangle_{q_\mu(\mu)}-(N\overline{x}+\lambda_0\mu_0)\langle\mu\rangle_{q_\mu(\mu)}\tag{21}
\end{eqnarray}

(21)で、\overline{x^2}=\dfrac{1}{N}\displaystyle\sum_{n=1}^Nx_n^2とおきました。

(15)より、\lambda_Nq_\tau(\tau)に依存しており、式(21)より、b_Nq_\mu(\mu)に依存しています。
相互に依存しており、解析的に解くのが難しい為、以下の繰り返し法で解きます。

なお、q_\mu(\mu)ガウス分布であり、q_\tau(\tau)がガンマ分布であるので、
(15)\langle\tau\rangle_{q_\tau(\tau)}、式(21)\langle\mu^2\rangle_{q_\mu(\mu)},\langle\mu\rangle_{q_\mu(\mu)}は以下のようになります。

\begin{eqnarray}
\langle\tau\rangle_{q_\tau(\tau)}=\frac{a_N}{b_N}\tag{22}
\end{eqnarray}

\begin{eqnarray}
\langle\mu^2\rangle_{q_\mu(\mu)}=\mu_N^2+\frac{1}{\lambda_N}\tag{23}
\end{eqnarray}

\begin{eqnarray}
\langle\mu\rangle_{q_\mu(\mu)}=\mu_N\tag{24}
\end{eqnarray}

(22)は、ガンマ分布{\rm Gam}(\tau|a,b)の平均\langle\tau\rangle\langle\tau\rangle=\dfrac{a}{b}であることを用いました。
(23)は、一般に成り立つ{\rm var}[x]=\langle x^2\rangle-\langle x\rangle^2を用いました。

アルゴリズム

1.初期化

\mu_0,\lambda_0,a_0,b_0を初期化します。
\mu_N,a_N,b_Nを以下のように初期化します。

\begin{eqnarray}
&&\mu_N\leftarrow\frac{\lambda_0\mu_0+N\overline{x}}{\lambda_0+N}\tag{25}\\
&&a_N\leftarrow a_0+\frac{N+1}{2}\tag{26}\\
&&b_N\leftarrow b_0\tag{27}
\end{eqnarray}

2.\lambda_Nの更新

\lambda_Nを更新します。

\begin{eqnarray}
\lambda_N\leftarrow(\lambda_0+N)\langle\tau\rangle_{q_\tau(\tau)}\tag{28}
\end{eqnarray}

ただし、

\begin{eqnarray}
\langle\tau\rangle_{q_\tau(\tau)}=\frac{a_N}{b_N}\tag{29}
\end{eqnarray}

とします。

3.b_Nの更新

b_Nを更新します。

\begin{eqnarray}
b_N\leftarrow b_0+\frac{1}{2}(N\overline{x^2}+\lambda_0\mu_0^2)+\frac{1}{2}(N+\lambda_0)\langle\mu^2\rangle_{q_\mu(\mu)}-(N\overline{x}+\lambda_0\mu_0)\langle\mu\rangle_{q_\mu(\mu)}\tag{30}
\end{eqnarray}

ただし、

\begin{eqnarray}
\langle\mu^2\rangle_{q_\mu(\mu)}=\mu_N^2+\frac{1}{\lambda_N}\tag{31}
\end{eqnarray}

\begin{eqnarray}
\langle\mu\rangle_{q_\mu(\mu)}=\mu_N\tag{32}
\end{eqnarray}

\begin{eqnarray}
\overline{x^2}=\frac{1}{N}\displaystyle\sum_{n=1}^Nx_n^2\tag{33}
\end{eqnarray}

\begin{eqnarray}
\overline{x}=\dfrac{1}{N}\displaystyle\sum_{n=1}^Nx_n\tag{34}
\end{eqnarray}

とします。

4.終了条件
対数尤度を再計算し、前回との差分があらかじめ設定していた収束条件を満たしていなければ2に戻り、満たしていれば終了します。

\lambda_Nを初期化し、2と3の処理を入れ替えても同じです。
※対数尤度の代わりに、繰り返し回数を決めて、それを終了条件とするのもよいと思います。
※対数尤度の代わりに、変分下限(変分下界)を計算してもよいと思います。
※対数尤度計算時の\mu,\tauは、q_\mu(\mu),q_\tau(\tau)の平均やモードなどを使えばよいと思います。

おまけ - 解析的に解く

\mu_0=a_0=b_0=\lambda_0=0 の時は、解析的に解くことも可能です。
(導出については、PRML演習問題 10.9(標準)を参照してください。)

\begin{eqnarray}
\frac{1}{\langle\tau\rangle_{q(\tau)}}&=&\frac{1}{N}\sum_{n=1}^N(x_n-\overline{x})^2\tag{35}
\end{eqnarray}

偉人の名言

f:id:olj611:20211006064642p:plain:h300
物真似から出発して、独創にまでのびていくのが、
我々日本人のすぐれた性質であり、たくましい能力でもあるのです。
野口英世

参考文献

パターン認識機械学習 下巻 p184-p185

動画

なし

目次へ戻る