Aurora blog

バイオインフォ・バイオテクノロジーにまつわる情報

Pyroでベイズモデリング①:基本の確認

Pyro

PyroはUberが作った確率的プログラミングのためのPythonライブラリ。PyTorchをベースにしており、PyTorchで実装したニューラルネットを組み込んだ深層生成モデルの開発ができる。最近、シングルセルデータやマルチオミクスデータを対象とした深層生成モデルが多数報告されており、自分でも実装&既存のモデルをカスタマイズできるようになりたいと考えていたので勉強してみた。Pyroの公式ドキュメントは充実しており、読んでいるだけで面白く勉強になるのだが、試している中で疑問に思ったことが多々あったので、メモを残していこうと思う。

pyro.ai

Stanのような歴史の長い確率的プログラミング言語と比較すると、日本語で書かれている資料は少ない。多くの日本語の情報は「ベイズ線形回帰をPyroで実装してみた」みたいな内容が多いので、このシリーズでは、より複雑なモデルを実装することにトライしていこうと思う。

本記事では以下のページが参考になった。

実装の流れ

Step 1: ModelとGuideの定義

Pyroでは、はじめにModelとGuideを関数として実装する。Modelにはモデルと各パラメータの事前分布を、Guideには変分推論 (後述) で用いる近似事後分布を定義する。

Pyroのチュートリアルで紹介されている線形回帰の事例をもとに実装の流れをさらっていく。この事例では国土の地形の凹凸の多さ(TRI: Terrain Ruggedness Index)とGDPの関係を線形回帰で調べており、アフリカ以外の地域ではTRIが高い (地形の凹凸が高い) ほどGDPが低いのに対して、アフリカではそれと逆の傾向があることを示している。ここでは以下のモデルに対してベイズ推論を行う。変数 Africaはアフリカ地域であれば1が入るバイナリ変数である。

  •  log(GDP) \sim \rm{Normal}(\mu, \sigma)
    •  \mu = a + b_{A} * Africa + b_{R} * TRI + b_{AR} * Africa * TRI
      •  p(a) \sim \rm{Normal}(0, 10)
      •  p(b_{A}) \sim \rm{Normal}(0, 1)
    •  p(\sigma) \sim \rm{Uniform}(0, 10)

このケースであれば各パラメータの事後分布は解析的に求めることができるが、あえて変分推論で近似事後分布を求める。ここでは変分ベイズでよく用いられる平均場近似 (後述) という近似法を使う。事後分布 p(b_{A}, b_{R}, b_{AR} | X)を近似する、3つの独立した分布 ( q(b_{A}), q(b_{R}), q(b_{AR})) を探索する。以下のコードでは分布 q(b_{A}), q(b_{R}), q(b_{AR})にはそれぞれ正規分布を仮定している。

  •  p(b_{A}, b_{R}, b_{AR} | X) \approx q(b_{A}) q(b_{R}) q(b_{AR})

このときModelとGuide (近似事後分布) はPyroで以下のように実装できる。Model/Guideを構成する要素 (Primitives) については後述する。

import torch
from torch.distributions import constraints
import pyro
import pyro.distributions as dist

def model(is_cont_africa, ruggedness, log_gdp):
    a = pyro.sample("a", dist.Normal(0., 10.))
    b_a = pyro.sample("bA", dist.Normal(0., 1.))
    b_r = pyro.sample("bR", dist.Normal(0., 1.))
    b_ar = pyro.sample("bAR", dist.Normal(0., 1.))
    sigma = pyro.sample("sigma", dist.Uniform(0., 10.))
    mean = a + b_a * is_cont_africa + b_r * ruggedness + b_ar * is_cont_africa * ruggedness
    with pyro.plate("data", len(ruggedness)):
        pyro.sample("obs", dist.Normal(mean, sigma), obs=log_gdp)

def guide(is_cont_africa, ruggedness, log_gdp):
    a_loc = pyro.param('a_loc', torch.tensor(0.))
    a_scale = pyro.param('a_scale', torch.tensor(1.), constraint=constraints.positive)
    sigma_loc = pyro.param('sigma_loc', torch.tensor(1.), constraint=constraints.positive)
    weights_loc = pyro.param('weights_loc', torch.randn(3))
    weights_scale = pyro.param('weights_scale', torch.ones(3), constraint=constraints.positive)
    a = pyro.sample("a", dist.Normal(a_loc, a_scale))
    b_a = pyro.sample("bA", dist.Normal(weights_loc[0], weights_scale[0]))
    b_r = pyro.sample("bR", dist.Normal(weights_loc[1], weights_scale[1]))
    b_ar = pyro.sample("bAR", dist.Normal(weights_loc[2], weights_scale[2]))
    sigma = pyro.sample("sigma", dist.Normal(sigma_loc, torch.tensor(0.05)))

変分推論 (Variational inference)

変分推論については多くの書籍・ウェブサイトで説明されているので、言及する必要はないかもしれないが、自分の中で腑に落ちた説明を以下に残しておく。

N個のパラメータ  Theta = (\theta_{1},...,\theta_{N}) を含むモデルを考える。複雑なモデルである場合、パラメータの事後分布 p(\Theta|X)は解析的に求めることができないことが多い。*1 このとき変分推論ではパラメータ ( \Theta = (\theta_{1},...,\theta_{N})) の事後分布 p(\Theta|X)を近似する近似事後分布 q(\Theta)を探索する。以下のKLダイバージェンスが最小となる分布 q(\Theta)を探す。

 q(\Theta) = \rm{argmin}_{q(\Theta)} \rm{KL} \lbrack q(\Theta) || p(\Theta|X) \rbrack \tag{1}

ここでは平均場近似による近似がよく使われる。平均場近似では各パラメータ \theta_{i}の分布が互いに独立していて、以下が成り立つ条件を仮定する。

 q(\Theta) = \Pi_{i} q(\theta_{i}) \tag{2}

式(1)の最適解を探索するには以下の変分下限(ELBO)  L \lbrack q(\Theta) \rbrack を最大化する q(\Theta)を探索すれば良い。

 L \lbrack q(\Theta) \rbrack = \int q(\Theta) \ln \frac{p(X, \Theta)}{q(\Theta)} d\Theta \tag{3}

この根拠は以下のように説明できる。 p(X)は以下のように変形できる。左辺 \ln p(X)はデータ Xに対して一定の値なので、\rm{L} \lbrack q(\Theta) \rbrackを最大化することは、KLダイバージェンスを最小化することと等しいと考えられる。

{
\begin{align}
\ln p(X) &= \int q(\Theta) \ln p(X) d\Theta \tag{4} \\\
&= \int q(\Theta) \ln \frac{p(X, \Theta)}{p(\Theta|X)} d\Theta \\\
&= \int q(\Theta) \ln \frac{q(\Theta)p(X, \Theta)}{q(\Theta)p(\Theta|X)} d\Theta \\\
&= \int q(\Theta) \ln \frac{p(X, \Theta)}{q(\Theta)} d\Theta - \int q(\Theta) \ln \frac{p(\Theta|X)}{q(\Theta)} d\Theta \\\
&= \rm{L} \lbrack q(\Theta) \rbrack + \rm{KL} \lbrack q(\Theta) || p(\Theta|X) \rbrack
\end{align}
}

Primitives

Pyroでのモデルの構築に使われる基本単位 (Primitivesと呼ばれる) についてまとめる。

pyro.param

  • パラメータを定義する際に用いる関数
  • 引数
    • name: パラメータの名前 (ParamStoreDictでの管理に使われる)
    • init_tensor: 初期値
    • constraint: 制約条件 (e.g. 正値)
  • パラメータはグローバルで管理されParamStoreDictと呼ばれる領域に保存される。*2
    • pyro.get_param_store()ParamStoreDictの内容をdict型で取得できる
    • pyro.clear_param_store()ParamStoreDictの内容を消去できる。学習後のパラメータを消去したい際に用いる。
  • ParamStoreDictのkeyにはpyro.paramの第一引数"name"の情報が使われるので、パラメータを宣言する際は、他のパラメータと同じ"name"を使わないようにしなければならない。
a_loc = pyro.param('a_loc', torch.tensor(0.))
d = pyro.get_param_store()
d["a_loc"]
tensor(0., requires_grad=True)

パラメータは以下の方法でも取得できる。

pyro.param("a_loc").item()

pyro.sample

  • 確率変数を定義する際に用いる関数
  • 引数
    • name: 確率変数の名前。これをもとにModelの分布とGuideの近似分布を対応させる
    • fn: 分布 (pyro.distributionsのクラスが用いられる)
    • obs: 実測値 ELBOの計算に使われる (default: None)
  • 返値
    • obsが空(None)の場合: 分布fnからサンプリングした値を返す
    • obsが空でない場合: obsに指定した値が返される
a = pyro.sample("a", dist.Normal(0., 1.))
a
tensor(0.5161)

pyro.plate

  • 繰り返し(グラフィカルモデルにおけるプレート)を定義する関数
    • e.g. N個の遺伝子の発現量を1つずつ同じ分布でモデル化する場合
    • pyro.plate内部の確率変数は(特定の条件下で)独立とみなされる
      • 例えば以下のコードでは特定のmean/sigmaにおいてobs内の各変数は独立といえる
      • 並列化により計算を高速化できる
  • 引数
    • name: プレートの名前
    • size: 繰り返しのサイズ
    • subsample_size: ミニバッチのサイズ (default: size)
    • dim: plateをどの次元に対応させるか (default: -1)
      • dim=-1の場合はbatch_shape (後述) のうち一番後ろ (後ろから1番目) をPlateに対応させる
mean = pyro.sample("mean", dist.Normal(0., 1.))
sigma = pyro.sample("sigma", dist.Uniform(0., 10.))

with pyro.plate("data", 5) as idx:
    y = pyro.sample("y", dist.Normal(mean, sigma))
y.shape
torch.Size([5])

batch_shapeとevent_shape

pyro.distributionモジュールの分布クラス (e.g. dist.Normal) は、パラメータに複数の値を与えたり、メソッドexpandを使うことで、複数の分布を同時に定義することができる。このとき、メソッドsampleや、上述のpyro.sampleにより、複数の分布からサンプリングされた値をtorch.tensor型のオブジェクトとして取得することができる。例えば、正規分布 (dist.Normal) のパラメータ ( \mu, \sigma) に4つの値を与えてると、4つの正規分布が同時に定義され、各分布からサンプリングされた値をサイズ4のテンソルで取得することができる。

means = torch.tensor([0.0, 1.0, 2.0, 3.0])
sds = torch.tensor([1.0, 1.0, 1.0, 1.0])
d = dist.Normal(means, sds)
x = pyro.sample("x", d)
x
tensor([0.2222, 1.5773, 0.2363, 3.2264])

これに関連する重要な概念として、分布クラスはbatch_shapeevent_shapeという属性を持つ。batch_shapeは分布の数を表し、event_shapeは各分布から出力される値の次元数を表す。分布クラスからサンプリングされた値の次元は「batch_shape + event_shape」の形になる (event_shapeが末尾の次元に対応する)。

以下の2つのケースは似ているが、異なるbatch_shapeevent_shapeとなる。①4つの正規分布を作る場合は、batch_shapeが[4]に対して、event_shapeは (=[1])となる。②の4次元正規分布を作る場合は、①と同様に4次元の値を返す分布が作られるが、1つの分布として扱われるため、batch_shapeが、event_shapeが[4]となる。

① 4つの正規分布を作る場合

means = torch.tensor([0.0, 1.0, 2.0, 3.0])
sds = torch.tensor([1.0, 1.0, 1.0, 1.0])
d = dist.Normal(means, sds)
d.batch_shape, d.event_shape, d.sample().shape
(torch.Size([4]), torch.Size([]), torch.Size([4]))

② 1つの4次元正規分布を作る場合

means = torch.tensor([0.0, 1.0, 2.0, 3.0])
sds = torch.eye(4)
d = dist.MultivariateNormal(means, sds)
d.batch_shape, d.event_shape, d.sample().shape
(torch.Size([]), torch.Size([4]), torch.Size([4]))

to_event

分布クラスのメソッドto_eventを使うと、複数の独立した確率分布を、一つの多次元確率分布として扱うことができるようになる。

① 3x4の独立したベルヌーイ分布

d = dist.Bernoulli(0.5 * torch.ones(3,4))
d.batch_shape, d.event_shape, d.sample().shape
(torch.Size([3, 4]), torch.Size([]), torch.Size([3, 4]))

② 3つの独立した4次元ベルヌーイ分布

d = dist.Bernoulli(0.5 * torch.ones(3,4)).to_event(1) # 末尾1次元を非独立として扱う
d.batch_shape, d.event_shape, d.sample().shape
(torch.Size([3]), torch.Size([4]), torch.Size([3, 4]))

pyro.plateの引数dimについて

上述の通りpyro.plateの引数dimはplateがどのbatch_shapeに対応するかを指定するのに使われる。例えば、以下のコードでは、event_shapeが[3, 4]のベルヌーイ分布が5つ (batch_shape=5) 作られる。サンプリングされる値の次元は「batch_shape + event_shape」の形になるので、この場合は[5, 3, 4]次元のテンソルが出力される。

with pyro.plate("a", 5, dim=-1):
    d = pyro.sample("d", dist.Bernoulli(0.5 * torch.ones(3,4)).to_event(2))
d.shape
torch.Size([5, 3, 4])

Step 2: 学習

ModelとGuideを実装したあとはパラメータの推論を行う。Pyroの変分推論では、Adamなどの確率的な最適化アルゴリズムを使って、事後分布にできるだけ近い近似分布を探索する (SVI: Stochasticなvariational inference)。はじめに、pyro.optimモジュールのクラスを使い最適化のアルゴリズム・条件 (e.g. learning rate) を設定し、SVIクラスでModel/Guide・Optimizer・損失関数 (この場合はELBO) を指定する。つづいてsvi.stepでパラメータの更新を行う。svi.stepはModel/Guideの引数となるデータを受け取り、パラメータを更新し、返り値として損失 (ELBO) を出力する。以下の事例では、for文で5000回パラメータの更新を行っている。

from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

pyro.clear_param_store() # パラメータのリセット

optimizer = Adam({"lr": .05})
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

num_iters = 5000
for i in range(num_iters):
    elbo = svi.step(is_cont_africa, ruggedness, log_gdp)
    if i % 500 == 0:
        print("Elbo loss: {}".format(elbo))

Step 3: 事後分布の取得

pyro.infer.Predictiveを使うと以下のように事後分布から各パラメータをサンプリングできる。

from pyro.infer import Predictive
params = ["a", "bA", "bR", "bAR"]
pred_model = Predictive(model, guide, num_samples=2000, return_sites=params)
pred_sample = pred_model(is_cont_africa, ruggedness, log_gdp)
pred_sample["a"]
tensor([[9.1086],
        [9.1260],
        [9.1357],
        ...,
        [9.2670],
        [9.0693],
        [9.0717]])

*1:https://ntacoffee.com/variational-inference/

*2:どうしてこのようなデザインになったのだろうか