Pyro
PyroはUberが作った確率的プログラミングのためのPythonライブラリ。PyTorchをベースにしており、PyTorchで実装したニューラルネットを組み込んだ深層生成モデルの開発ができる。最近、シングルセルデータやマルチオミクスデータを対象とした深層生成モデルが多数報告されており、自分でも実装&既存のモデルをカスタマイズできるようになりたいと考えていたので勉強してみた。Pyroの公式ドキュメントは充実しており、読んでいるだけで面白く勉強になるのだが、試している中で疑問に思ったことが多々あったので、メモを残していこうと思う。
Stanのような歴史の長い確率的プログラミング言語と比較すると、日本語で書かれている資料は少ない。多くの日本語の情報は「ベイズ線形回帰をPyroで実装してみた」みたいな内容が多いので、このシリーズでは、より複雑なモデルを実装することにトライしていこうと思う。
本記事では以下のページが参考になった。
- General
- 変分推論
実装の流れ
Step 1: ModelとGuideの定義
Pyroでは、はじめにModelとGuideを関数として実装する。Modelにはモデルと各パラメータの事前分布を、Guideには変分推論 (後述) で用いる近似事後分布を定義する。
Pyroのチュートリアルで紹介されている線形回帰の事例をもとに実装の流れをさらっていく。この事例では国土の地形の凹凸の多さ(TRI: Terrain Ruggedness Index)とGDPの関係を線形回帰で調べており、アフリカ以外の地域ではTRIが高い (地形の凹凸が高い) ほどGDPが低いのに対して、アフリカではそれと逆の傾向があることを示している。ここでは以下のモデルに対してベイズ推論を行う。変数はアフリカ地域であれば1が入るバイナリ変数である。
このケースであれば各パラメータの事後分布は解析的に求めることができるが、あえて変分推論で近似事後分布を求める。ここでは変分ベイズでよく用いられる平均場近似 (後述) という近似法を使う。事後分布を近似する、3つの独立した分布 () を探索する。以下のコードでは分布にはそれぞれ正規分布を仮定している。
このときModelとGuide (近似事後分布) はPyroで以下のように実装できる。Model/Guideを構成する要素 (Primitives) については後述する。
import torch from torch.distributions import constraints import pyro import pyro.distributions as dist def model(is_cont_africa, ruggedness, log_gdp): a = pyro.sample("a", dist.Normal(0., 10.)) b_a = pyro.sample("bA", dist.Normal(0., 1.)) b_r = pyro.sample("bR", dist.Normal(0., 1.)) b_ar = pyro.sample("bAR", dist.Normal(0., 1.)) sigma = pyro.sample("sigma", dist.Uniform(0., 10.)) mean = a + b_a * is_cont_africa + b_r * ruggedness + b_ar * is_cont_africa * ruggedness with pyro.plate("data", len(ruggedness)): pyro.sample("obs", dist.Normal(mean, sigma), obs=log_gdp) def guide(is_cont_africa, ruggedness, log_gdp): a_loc = pyro.param('a_loc', torch.tensor(0.)) a_scale = pyro.param('a_scale', torch.tensor(1.), constraint=constraints.positive) sigma_loc = pyro.param('sigma_loc', torch.tensor(1.), constraint=constraints.positive) weights_loc = pyro.param('weights_loc', torch.randn(3)) weights_scale = pyro.param('weights_scale', torch.ones(3), constraint=constraints.positive) a = pyro.sample("a", dist.Normal(a_loc, a_scale)) b_a = pyro.sample("bA", dist.Normal(weights_loc[0], weights_scale[0])) b_r = pyro.sample("bR", dist.Normal(weights_loc[1], weights_scale[1])) b_ar = pyro.sample("bAR", dist.Normal(weights_loc[2], weights_scale[2])) sigma = pyro.sample("sigma", dist.Normal(sigma_loc, torch.tensor(0.05)))
変分推論 (Variational inference)
変分推論については多くの書籍・ウェブサイトで説明されているので、言及する必要はないかもしれないが、自分の中で腑に落ちた説明を以下に残しておく。
N個のパラメータ を含むモデルを考える。複雑なモデルである場合、パラメータの事後分布は解析的に求めることができないことが多い。*1 このとき変分推論ではパラメータ () の事後分布を近似する近似事後分布を探索する。以下のKLダイバージェンスが最小となる分布を探す。
ここでは平均場近似による近似がよく使われる。平均場近似では各パラメータの分布が互いに独立していて、以下が成り立つ条件を仮定する。
式(1)の最適解を探索するには以下の変分下限(ELBO) を最大化するを探索すれば良い。
この根拠は以下のように説明できる。は以下のように変形できる。左辺はデータに対して一定の値なので、を最大化することは、KLダイバージェンスを最小化することと等しいと考えられる。
Primitives
Pyroでのモデルの構築に使われる基本単位 (Primitivesと呼ばれる) についてまとめる。
pyro.param
- パラメータを定義する際に用いる関数
- 引数
- name: パラメータの名前 (ParamStoreDictでの管理に使われる)
- init_tensor: 初期値
- constraint: 制約条件 (e.g. 正値)
- パラメータはグローバルで管理され
ParamStoreDict
と呼ばれる領域に保存される。*2pyro.get_param_store()
でParamStoreDict
の内容をdict型で取得できるpyro.clear_param_store()
でParamStoreDict
の内容を消去できる。学習後のパラメータを消去したい際に用いる。
ParamStoreDict
のkeyにはpyro.param
の第一引数"name"の情報が使われるので、パラメータを宣言する際は、他のパラメータと同じ"name"を使わないようにしなければならない。
a_loc = pyro.param('a_loc', torch.tensor(0.)) d = pyro.get_param_store() d["a_loc"]
tensor(0., requires_grad=True)
パラメータは以下の方法でも取得できる。
pyro.param("a_loc").item()
pyro.sample
- 確率変数を定義する際に用いる関数
- 引数
- name: 確率変数の名前。これをもとにModelの分布とGuideの近似分布を対応させる
- fn: 分布 (
pyro.distributions
のクラスが用いられる) - obs: 実測値 ELBOの計算に使われる (default: None)
- 返値
- obsが空(None)の場合: 分布
fn
からサンプリングした値を返す - obsが空でない場合: obsに指定した値が返される
- obsが空(None)の場合: 分布
a = pyro.sample("a", dist.Normal(0., 1.)) a
tensor(0.5161)
pyro.plate
- 繰り返し(グラフィカルモデルにおけるプレート)を定義する関数
- e.g. N個の遺伝子の発現量を1つずつ同じ分布でモデル化する場合
- pyro.plate内部の確率変数は(特定の条件下で)独立とみなされる
- 例えば以下のコードでは特定のmean/sigmaにおいてobs内の各変数は独立といえる
- 並列化により計算を高速化できる
- 引数
- name: プレートの名前
- size: 繰り返しのサイズ
- subsample_size: ミニバッチのサイズ (default: size)
- dim: plateをどの次元に対応させるか (default: -1)
dim=-1
の場合はbatch_shape (後述) のうち一番後ろ (後ろから1番目) をPlateに対応させる
mean = pyro.sample("mean", dist.Normal(0., 1.)) sigma = pyro.sample("sigma", dist.Uniform(0., 10.)) with pyro.plate("data", 5) as idx: y = pyro.sample("y", dist.Normal(mean, sigma)) y.shape
torch.Size([5])
batch_shapeとevent_shape
pyro.distribution
モジュールの分布クラス (e.g. dist.Normal
) は、パラメータに複数の値を与えたり、メソッドexpand
を使うことで、複数の分布を同時に定義することができる。このとき、メソッドsample
や、上述のpyro.sample
により、複数の分布からサンプリングされた値をtorch.tensor型のオブジェクトとして取得することができる。例えば、正規分布 (dist.Normal
) のパラメータ () に4つの値を与えてると、4つの正規分布が同時に定義され、各分布からサンプリングされた値をサイズ4のテンソルで取得することができる。
means = torch.tensor([0.0, 1.0, 2.0, 3.0]) sds = torch.tensor([1.0, 1.0, 1.0, 1.0]) d = dist.Normal(means, sds) x = pyro.sample("x", d) x
tensor([0.2222, 1.5773, 0.2363, 3.2264])
これに関連する重要な概念として、分布クラスはbatch_shape
とevent_shape
という属性を持つ。batch_shape
は分布の数を表し、event_shape
は各分布から出力される値の次元数を表す。分布クラスからサンプリングされた値の次元は「batch_shape + event_shape」の形になる (event_shapeが末尾の次元に対応する)。
以下の2つのケースは似ているが、異なるbatch_shape
とevent_shape
となる。①4つの正規分布を作る場合は、batch_shapeが[4]に対して、event_shapeは (=[1])となる。②の4次元正規分布を作る場合は、①と同様に4次元の値を返す分布が作られるが、1つの分布として扱われるため、batch_shapeが、event_shapeが[4]となる。
① 4つの正規分布を作る場合
means = torch.tensor([0.0, 1.0, 2.0, 3.0]) sds = torch.tensor([1.0, 1.0, 1.0, 1.0]) d = dist.Normal(means, sds) d.batch_shape, d.event_shape, d.sample().shape
(torch.Size([4]), torch.Size([]), torch.Size([4]))
② 1つの4次元正規分布を作る場合
means = torch.tensor([0.0, 1.0, 2.0, 3.0]) sds = torch.eye(4) d = dist.MultivariateNormal(means, sds) d.batch_shape, d.event_shape, d.sample().shape
(torch.Size([]), torch.Size([4]), torch.Size([4]))
to_event
分布クラスのメソッドto_event
を使うと、複数の独立した確率分布を、一つの多次元確率分布として扱うことができるようになる。
① 3x4の独立したベルヌーイ分布
d = dist.Bernoulli(0.5 * torch.ones(3,4)) d.batch_shape, d.event_shape, d.sample().shape
(torch.Size([3, 4]), torch.Size([]), torch.Size([3, 4]))
② 3つの独立した4次元ベルヌーイ分布
d = dist.Bernoulli(0.5 * torch.ones(3,4)).to_event(1) # 末尾1次元を非独立として扱う d.batch_shape, d.event_shape, d.sample().shape
(torch.Size([3]), torch.Size([4]), torch.Size([3, 4]))
pyro.plateの引数dimについて
上述の通りpyro.plateの引数dim
はplateがどのbatch_shape
に対応するかを指定するのに使われる。例えば、以下のコードでは、event_shapeが[3, 4]のベルヌーイ分布が5つ (batch_shape=5) 作られる。サンプリングされる値の次元は「batch_shape + event_shape」の形になるので、この場合は[5, 3, 4]次元のテンソルが出力される。
with pyro.plate("a", 5, dim=-1): d = pyro.sample("d", dist.Bernoulli(0.5 * torch.ones(3,4)).to_event(2)) d.shape
torch.Size([5, 3, 4])
Step 2: 学習
ModelとGuideを実装したあとはパラメータの推論を行う。Pyroの変分推論では、Adamなどの確率的な最適化アルゴリズムを使って、事後分布にできるだけ近い近似分布を探索する (SVI: Stochasticなvariational inference)。はじめに、pyro.optim
モジュールのクラスを使い最適化のアルゴリズム・条件 (e.g. learning rate) を設定し、SVIクラスでModel/Guide・Optimizer・損失関数 (この場合はELBO) を指定する。つづいてsvi.step
でパラメータの更新を行う。svi.step
はModel/Guideの引数となるデータを受け取り、パラメータを更新し、返り値として損失 (ELBO) を出力する。以下の事例では、for文で5000回パラメータの更新を行っている。
from pyro.infer import SVI, Trace_ELBO from pyro.optim import Adam pyro.clear_param_store() # パラメータのリセット optimizer = Adam({"lr": .05}) svi = SVI(model, guide, optimizer, loss=Trace_ELBO()) num_iters = 5000 for i in range(num_iters): elbo = svi.step(is_cont_africa, ruggedness, log_gdp) if i % 500 == 0: print("Elbo loss: {}".format(elbo))
Step 3: 事後分布の取得
pyro.infer.Predictive
を使うと以下のように事後分布から各パラメータをサンプリングできる。
from pyro.infer import Predictive params = ["a", "bA", "bR", "bAR"] pred_model = Predictive(model, guide, num_samples=2000, return_sites=params) pred_sample = pred_model(is_cont_africa, ruggedness, log_gdp) pred_sample["a"]
tensor([[9.1086], [9.1260], [9.1357], ..., [9.2670], [9.0693], [9.0717]])
*1:https://ntacoffee.com/variational-inference/
*2:どうしてこのようなデザインになったのだろうか