機械学習基礎理論独習

誤りがあればご指摘いただけると幸いです。数式が整うまで少し時間かかります。リンクフリーです。

勉強ログです。リンクフリーです
目次へ戻る

【Python実装】棄却サンプリング

本記事のアルゴリズムの記事はこちらです。

目標分布を以下のベータ分布とします。正規化項は除きます。

\begin{eqnarray}
&&p(x)=x^{a-1}(1-x)^{b-1},\ 0\leq x\leq1\\
&&a=10.2,\ b=5.8\tag{1}
\end{eqnarray}

提案分布は一様分布とします。

\begin{eqnarray}
q(x)=\frac{1}{1-0}=1\tag{2}
\end{eqnarray}

k=0.00013とします。
この時、kq(x)\geq p(x)を満たしています。

f:id:olj611:20210408165445p:plain:w400

###############################
#       棄却サンプリング
###############################
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import beta    # ベータ関数描画用

# 乱数を固定
np.random.seed(1)

# ベータ分布
a, b = 10.2, 5.8
x = np.linspace(0, 1, 100)  # 0 から 1 を 100 分割
#y = (x ** (a - 1)) * ((1 - x) ** (b - 1))   # 正規化定数を無視したベータ分布の確率密度関数の値
y = beta(a, b).pdf(x)   # 正規化定数で割ったベータ分布の確率密度関数の値

# 棄却サンプリング
N = 100000      # サンプリングする点の数(棄却込み)
samples = np.array([])    # サンプリングした点データ
K = 0.00013     # 提案分布が目標分布を覆うようにする

for n in range(N):
    xp = np.random.uniform(0, 1)
    yp = np.random.uniform(0, K)
    pdf = (xp ** (a - 1)) * ((1 - xp) ** (b - 1))
    if(yp <= pdf):
        samples = np.append(samples, xp)

# 統計量を比較 

# データ
print("サンプリングを試みた回数: ", N, "回")
print("サンプリング個数: ", len(samples), "個")
print()

# 事後平均値
mu = np.sum(samples) / len(samples);
print("サンプリング - 事後期待値: ", mu)
print("真の値    - 事後期待値: ", a / (a + b))
print()

# 事後分散
v = np.sum((samples - mu) ** 2 ) / len(samples)
print("サンプリング - 事後分散: ", v)
print("真の値    - 事後分散: ", (a * b) / (((a + b) ** 2) * (a + b + 1)))
print()

# 事後標準偏差
print("サンプリング - 事後標準偏差: ", v ** 0.5)
print("真の値    - 事後標準偏差: ", ((a * b) / (((a + b) ** 2) * (a + b + 1))) ** 0.5)
print()

# 事後確率最大値
maxIndex = np.argmax((samples ** (a - 1)) * ((1 - samples) ** (b - 1)))
print("サンプリング - 事後確率最大値: ", samples[maxIndex])
print("真の値    - 事後確率最大値: ", (a - 1) / (a + b - 2))
print()

# 事後中央値
sorted_samples = sorted(samples)
if(len(sorted_samples) % 2 != 0):# 奇数
    midIndex = int((len(sorted_samples) - 1) / 2)
    midX = sorted_samples[midIndex]
else:# 偶数
    midIndex = int(len(sorted_samples) / 2)
    midX = (sorted_samples[midIndex] + sorted_samples[midIndex - 1]) / 2
print("サンプリング - 事後中央値: ", midX)    # np.median(samples)と同じ
print("真の値(np)  - 事後中央値: ", beta(a, b).median())
print("真の値(wiki) - 事後中央値: ", (a - (1/3)) / (a + b - (2/3) ))

# 描画
plt.plot(x, y)  # ベータ分布を描画
plt.hist(samples, bins = 30, density = True, alpha = 0.5) # ヒストグラムを描画
plt.show()      # グラフを描画

f:id:olj611:20210408165829p:plain:w400
f:id:olj611:20210408165843p:plain:w400

統計量の計算

実行結果の説明をします。

p(x)=x^{10.2-1}(1-x)^{5.8-1}からサンプリングした点をx_1,\ldots,x_Nとします。
p(x)を事後分布とみなしているので、下の表の1番左の列は文字の先頭に「事後」が付いています。
事後分散には、標本分散を採用しています。事後標準偏差は省略しました。

サンプリングベータ分布(a=10.2,b=5.8)
事後期待値
\begin{eqnarray}
\mu&=&\frac{1}{N}\sum_{n=1}^Nx_n\\
&=&0.6375412953281164
\end{eqnarray}
\begin{eqnarray}
\frac{a}{a+b}=0.6375
\end{eqnarray}
事後分散
\begin{eqnarray}
\sigma^2&=&\frac{1}{N}\sum_{n=1}^N(x_n-\mu)^2\\
&=&0.013438176563902452\\
\end{eqnarray}
\begin{eqnarray}
\frac{ab}{(a+b)^2(a+b+1)}=0.01359375
\end{eqnarray}
事後確率最大値
\newcommand{\argmax}{\mathop{\rm arg~max}\limits}\begin{eqnarray}
&&\argmax_{x_n}p(x_n)\\
&=&0.6571408069914879
\end{eqnarray}
\begin{eqnarray}
\frac{a-1}{a+b-2}=0.6571428571428571
\end{eqnarray}
事後中央値
\begin{eqnarray}
\left\{
    \begin{array}{l}
     x_{\frac{N-1}{2}} \hspace{55pt}Nが奇数\\
     \left(x_{\frac{N}{2}}+x_{\frac{N}{2}-1}\right)/2 \hspace{5pt}Nが偶数
    \end{array}
  \right.\\
=0.6437659596981459
\end{eqnarray}
\begin{eqnarray}
&&\approx\frac{a-1/3}{a+b-2/3}=0.643478260869565\\
&&ライブラリで計算 =0.6433608949743185\\
\end{eqnarray}

よく近似できていることが分かります。

参考文献

なし

偉人の名言

f:id:olj611:20210408173701p:plain:w300
勉強とは自分の無知を徐々に発見していくことである。
ウィル・デュラント

動画

目次へ戻る