"""
This module provides a few lightweight wrappers around probabilistic
programming primitives from Numpyro.
"""
import functools
from typing import Dict, List
import jax
import jax.numpy as np
import numpyro
import numpyro.distributions as dist
from numpyro.primitives import Messenger
import pandas as pd
from tqdm.autonotebook import tqdm
# Random numbers
_RNG_KEY = jax.random.PRNGKey(0)


def onetime_rng_key():
    """
    Hand out a fresh PRNG key drawn from the module-level key stream.

    The module-level key is split in two on every call: the first half is
    returned to the caller and the second half replaces the stored key, so
    consecutive calls never return the same key.
    """
    global _RNG_KEY
    halves = jax.random.split(_RNG_KEY, num=2)
    _RNG_KEY = halves[1]
    return halves[0]
# Automatic naming of sampling sites
class autoname(Messenger):
    """
    Make sampling-site names unique within one execution.

    When a ``sample`` site reuses a name already recorded since the handler was
    last entered, ``__<n>`` is appended to it. ``n`` starts at 1 and is
    incremented until the name is unused, e.g. ``x``, ``x__1``, ``x__2``, ...
    """

    def __enter__(self):
        # Forget names from any previous execution: uniqueness is only
        # enforced within a single entry of this handler.
        self._names = set()
        # Propagate the base class's return value; discarding it made
        # ``with autoname(...) as h`` bind None.
        return super().__enter__()

    def _increment_name(self, name, label):
        """Return ``name``, re-suffixed with ``__<n>`` until unused for ``label``."""
        while (name, label) in self._names:
            try:
                base, count_str = name.rsplit("__", maxsplit=1)
                count = int(count_str) + 1
            except ValueError:
                # No ``__<int>`` suffix yet (or a non-integer one): start at 1.
                base, count = name, 1
            name = f"{base}__{count}"
        return name

    def process_message(self, msg):
        # Rename a sample site if its name has already been used.
        if msg["type"] == "sample":
            msg["name"] = self._increment_name(msg["name"], "sample")

    def postprocess_message(self, msg):
        # Remember the final (possibly renamed) name so later sites avoid it.
        if msg["type"] == "sample":
            self._names.add((msg["name"], "sample"))
# Sampling from probability distributions
def sample(dist: dist.Distribution, name: str = None, **kwargs):
    """
    Sample from a primitive distribution.

    :param dist: A NumPyro distribution.
    :param name: Name to assign to this sampling site in the execution trace.
        Defaults to ``"_v"``; the ``autoname`` handler, when active, makes
        repeated names unique.
    :param kwargs: Extra keyword arguments forwarded to ``numpyro.sample``.
        An explicit ``rng_key`` is passed through unchanged.
    :return: A sample from the distribution.
    """
    from numpyro.primitives import _PYRO_STACK

    # If a value isn't explicitly named, generate an automatic name,
    # relying on the autoname handler for uniqueness.
    if not name:
        name = "_v"
    # NumPyro's ``seed`` handler only supplies a key to sample sites whose
    # ``rng_key`` is None. Passing an explicit key unconditionally would bypass
    # it, so ``run(rng_seed=...)`` would have no effect. Only fall back to the
    # module-level key stream when no ``seed`` handler is active, e.g. when
    # sampling outside an inference context.
    seeded = any(isinstance(handler, numpyro.handlers.seed) for handler in _PYRO_STACK)
    if "rng_key" not in kwargs and not seeded:
        kwargs["rng_key"] = onetime_rng_key()
    return numpyro.sample(name, dist, **kwargs)
# Conditioning
def condition(cond: bool, name: str = None):
    """
    Condition the current execution on ``cond`` being true.

    Adds a ``numpyro.factor`` site with log-weight ``0`` when ``cond`` holds and
    ``-inf`` otherwise. ``run`` discards executions containing a ``-inf``
    factor.

    :param cond: Condition that must hold; must be convertible to a Python bool.
    :param name: Name of the factor site; defaults to ``"_c"``.
    """
    if not name:
        name = "_c"
    # ``np.NINF`` was removed from newer JAX/NumPy releases; ``-inf`` is the
    # portable spelling of the same value.
    return numpyro.factor(name, 0 if cond else -np.inf)
# Record deterministic values in trace
def tag(value, name: str):
    """
    Record ``value`` in the execution trace as a deterministic site.

    Sites whose names start with ``_`` are dropped from ``run``'s output when
    ``ignore_untagged`` is True, so tags should use names without that prefix.

    :param value: Value to record.
    :param name: Name of the deterministic site.
    :return: Whatever ``numpyro.deterministic`` returns for ``value``.
    """
    return numpyro.deterministic(name, value)
# Automatically record model return value in trace
def tag_output(model):
    """
    Wrap ``model`` so that its return value is recorded in the trace.

    The returned callable forwards all positional and keyword arguments to
    ``model``, consistent with ``handle_mem``, which passes arguments through.
    A non-None result is tagged as the deterministic site ``"output"``. The
    result is returned unchanged.
    """
    def wrapped(*args, **kwargs):
        value = model(*args, **kwargs)
        if value is not None:
            tag(value, "output")
        return value

    return wrapped
# Memoization
# Registry of every function wrapped by ``mem`` so that ``clear_mem`` can reset
# all of their caches at once.
# NOTE(review): entries are never removed, so calling ``mem`` repeatedly (e.g.
# inside a model body executed once per sample) would grow this list without
# bound — confirm how callers use ``mem``.
memoized_functions = []  # FIXME: global state


def mem(func):
    """
    Return an unbounded-cache memoized version of ``func`` and register it.

    The cached wrapper is appended to ``memoized_functions`` so that its cache
    can later be emptied by ``clear_mem``.
    """
    cached = functools.lru_cache(maxsize=None)(func)
    memoized_functions.append(cached)
    return cached


def clear_mem():
    """Empty the cache of every function registered through ``mem``."""
    for cached in memoized_functions:
        cached.cache_clear()


def handle_mem(model):
    """
    Wrap ``model`` so that each call starts with empty memoization caches.

    All positional and keyword arguments are forwarded to ``model``, and its
    return value is passed back unchanged.
    """
    def with_fresh_caches(*args, **kwargs):
        clear_mem()
        return model(*args, **kwargs)

    return with_fresh_caches
# Trace-inspection helpers and the main inference function (``run``)
def is_singleton_array(value):
    """
    Return True if ``value`` is a JAX array holding exactly one element.

    Works for any shape whose element count is 1, e.g. ``()``, ``(1,)``, or
    ``(1, 1)``.
    """
    # ``jax.numpy.DeviceArray`` was removed in newer JAX releases in favor of
    # ``jax.Array``; prefer the latter and fall back for old JAX versions.
    array_type = getattr(jax, "Array", None)
    if array_type is None:
        array_type = np.DeviceArray
    # ``size`` is the total element count (an int), never a shape tuple.
    return isinstance(value, array_type) and value.size == 1
def is_factor(entry):
    """
    Tell whether a trace entry is a factor site.

    A factor site is an observed site whose distribution is NumPyro's ``Unit``
    (presumably the sites created by ``numpyro.factor``). As with an ``and``
    chain, the result is truthy for factor sites. Otherwise it is the first
    falsy value encountered (``is_observed``, then ``fn``) or ``False``.
    """
    observed = entry.get("is_observed")
    if not observed:
        return observed
    fn = entry.get("fn")
    if not fn:
        return fn
    return isinstance(fn, numpyro.distributions.Unit)
def factor_score(entry):
    """Return the ``log_factor`` stored on a factor site's distribution (see ``is_factor``)."""
    unit_distribution = entry["fn"]
    return unit_distribution.log_factor
def _trace_to_sample(trace, ignore_untagged):
    """Convert one execution trace into an output row, or None if it is rejected."""
    row = {}
    for name, entry in trace.items():
        if entry["type"] not in ("sample", "deterministic"):
            continue
        if is_factor(entry):
            score = factor_score(entry)
            # ``np.NINF`` was removed from newer JAX/NumPy; ``-inf`` is equivalent.
            if score == -np.inf:
                return None
            if score != 0:
                raise NotImplementedError(f"Weighted factors - got score {score}")
            continue
        if ignore_untagged and name.startswith("_"):
            continue
        value = entry["value"]
        if is_singleton_array(value):
            value = value.item()  # FIXME
        row[name] = value
    return row


def run(model, num_samples=5000, ignore_untagged=True, rng_seed=0) -> pd.DataFrame:
    """
    Run ``model`` forward repeatedly and collect one row per accepted execution.

    Each execution is traced. A factor site with score ``-inf`` (as added by
    ``condition`` when its condition fails) rejects the whole execution, as in
    rejection sampling. The remaining sample and deterministic sites become
    the row's columns.

    NOTE: executions are retried until ``num_samples`` are accepted, so a model
    that is always rejected makes this loop forever.

    :param model: Callable invoked with no arguments for each execution.
    :param num_samples: Number of accepted executions to collect.
    :param ignore_untagged: If True, drop sites whose name starts with ``_``.
    :param rng_seed: Seed passed to NumPyro's ``seed`` handler.
    :return: DataFrame with one row per accepted execution.
    :raises NotImplementedError: If a factor score is neither ``0`` nor ``-inf``.
    """
    traced_model = numpyro.handlers.trace(handle_mem(tag_output(autoname(model))))
    samples: List[Dict[str, float]] = []
    # The context manager guarantees the progress bar is closed even if the
    # model raises.
    with numpyro.handlers.seed(rng_seed=rng_seed), tqdm(total=num_samples) as progress_bar:
        while len(samples) < num_samples:
            row = _trace_to_sample(traced_model.get_trace(), ignore_untagged)
            if row is None:
                continue
            samples.append(row)
            progress_bar.update(1)
    return pd.DataFrame(samples)  # type: ignore