"""
This module provides samplers for probability distributions.
"""
import math
import jax.numpy as np
import numpyro.distributions as dist
from ergo.ppl import sample
def bernoulli(p=0.5, **kwargs):
    """
    Draw a sample from a Bernoulli distribution.

    :param p: value used as ``probs``; it is coerced with ``float()``, so it
        must be convertible to a single Python float
    :param kwargs: forwarded unchanged to ``sample``
    :return: the result of ``sample`` for the constructed distribution
    """
    probability = float(p)
    distribution = dist.Bernoulli(probs=probability)
    return sample(distribution, **kwargs)
def normal(mean=0, stdev=1, **kwargs):
    """
    Draw a sample from ``dist.Normal(mean, stdev)``.

    :param mean: first positional argument of ``dist.Normal``
    :param stdev: second positional argument of ``dist.Normal``
    :param kwargs: forwarded unchanged to ``sample``
    :return: the result of ``sample`` for the constructed distribution
    """
    distribution = dist.Normal(mean, stdev)
    return sample(distribution, **kwargs)
def lognormal(loc=0, scale=1, **kwargs):
    """
    Draw a sample from ``dist.LogNormal(loc, scale)``.

    :param loc: first positional argument of ``dist.LogNormal``
    :param scale: second positional argument of ``dist.LogNormal``
    :param kwargs: forwarded unchanged to ``sample``
    :return: the result of ``sample`` for the constructed distribution
    """
    distribution = dist.LogNormal(loc, scale)
    return sample(distribution, **kwargs)
def halfnormal(stdev=1, **kwargs):
    """
    Draw a sample from ``dist.HalfNormal(stdev)``.

    :param stdev: sole positional argument of ``dist.HalfNormal``
    :param kwargs: forwarded unchanged to ``sample``
    :return: the result of ``sample`` for the constructed distribution
    """
    distribution = dist.HalfNormal(stdev)
    return sample(distribution, **kwargs)
def beta(a=1, b=1, **kwargs):
    """
    Draw a sample from ``dist.Beta(a, b)``.

    :param a: first positional argument of ``dist.Beta``
    :param b: second positional argument of ``dist.Beta``
    :param kwargs: forwarded unchanged to ``sample``
    :return: the result of ``sample`` for the constructed distribution
    """
    distribution = dist.Beta(a, b)
    return sample(distribution, **kwargs)
def categorical(ps, **kwargs):
    """
    Draw a sample from the distribution built by ``Categorical(ps)``.

    :param ps: scores handed to the module-level ``Categorical`` helper,
        which normalizes them into probabilities
    :param kwargs: forwarded unchanged to ``sample``
    :return: the result of ``sample`` for the constructed distribution
    """
    distribution = Categorical(ps)
    return sample(distribution, **kwargs)
# Provide alternative parameterizations for primitive distributions
def Categorical(scores):
    """
    Build a ``dist.Categorical`` from unnormalized scores.

    The scores are divided by their sum along the last axis so that each
    set of category weights adds up to one before being passed as ``probs``.

    :param scores: array-like of category weights (a list, tuple or array);
        the last axis indexes the categories
    :return: ``dist.Categorical`` parameterized by the normalized scores
    """
    # Convert first: a plain list or tuple does not support ``/``, so
    # ``scores / sum(scores)`` would raise TypeError for such inputs.
    weights = np.asarray(scores)
    # Normalize over the category (last) axis; for 1-D input this is the same
    # total the builtin ``sum`` produced, and it stays correct for batches.
    totals = np.sum(weights, axis=-1, keepdims=True)
    return dist.Categorical(probs=weights / totals)
def NormalFromInterval(low, high):
    """
    Build a ``dist.Normal`` from a centered 90% confidence interval.

    The left endpoint marks 5% on the CDF (0.05) and the right endpoint
    marks 95% (0.95).

    :param low: lower endpoint of the interval
    :param high: upper endpoint of the interval
    :return: ``dist.Normal`` whose mean is the interval midpoint and whose
        standard deviation is the half-width divided by 1.645
    """
    mean = (high + low) / 2
    # 1.645 is the z-score used for the 95th percentile of a standard normal,
    # so the distance from the mean to ``high`` is 1.645 standard deviations.
    stdev = (high - mean) / 1.645
    return dist.Normal(mean, stdev)
def HalfNormalFromInterval(high):
    """
    Build a ``dist.HalfNormal`` from a 90% confidence interval that starts
    at 0, i.e. the right endpoint marks 90% on the CDF.

    :param high: upper endpoint of the interval
    :return: ``dist.HalfNormal`` with standard deviation ``high / 1.645``
    """
    return dist.HalfNormal(high / 1.645)
def LogNormalFromInterval(low, high):
    """
    Build a ``dist.LogNormal`` from a centered 90% confidence interval.

    The left endpoint marks 5% on the CDF (0.05) and the right endpoint
    marks 95% (0.95). The interval is converted to log space, where the
    midpoint becomes the location and the half-width divided by 1.645
    becomes the scale.

    :param low: lower endpoint of the interval; must be positive
    :param high: upper endpoint of the interval; must be positive
    :return: ``dist.LogNormal`` parameterized in log space
    :raises ValueError: if either endpoint is not strictly positive
    """
    # math.log would raise a bare "math domain error" here; fail with a
    # message that names the offending arguments instead.
    if low <= 0 or high <= 0:
        raise ValueError(
            "LogNormalFromInterval requires positive endpoints, "
            f"got low={low!r}, high={high!r}"
        )
    loghigh = math.log(high)
    loglow = math.log(low)
    mean = (loghigh + loglow) / 2
    # The full log-space width spans 2 * 1.645 standard deviations.
    stdev = (loghigh - loglow) / (2 * 1.645)
    return dist.LogNormal(mean, stdev)
def BetaFromHits(hits, total):
    """
    Build a ``dist.Beta`` from a count of hits out of a number of trials.

    One is added to both the hit count and the miss count before they are
    passed as the two positional arguments of ``dist.Beta``.

    :param hits: number of hits
    :param total: total number of trials; misses are ``total - hits``
    :return: ``dist.Beta(1 + hits, 1 + (total - hits))``
    """
    misses = total - hits
    return dist.Beta(1 + hits, 1 + misses)
# Alternative names and parameterizations for primitive distribution samplers
def normal_from_interval(low, high, **kwargs):
    """
    Draw a sample from the distribution built by
    ``NormalFromInterval(low, high)``.

    :param low: lower endpoint passed to ``NormalFromInterval``
    :param high: upper endpoint passed to ``NormalFromInterval``
    :param kwargs: forwarded unchanged to ``sample``
    :return: the result of ``sample`` for the constructed distribution
    """
    distribution = NormalFromInterval(low, high)
    return sample(distribution, **kwargs)
def lognormal_from_interval(low, high, **kwargs):
    """
    Draw a sample from the distribution built by
    ``LogNormalFromInterval(low, high)``.

    :param low: lower endpoint passed to ``LogNormalFromInterval``
    :param high: upper endpoint passed to ``LogNormalFromInterval``
    :param kwargs: forwarded unchanged to ``sample``
    :return: the result of ``sample`` for the constructed distribution
    """
    distribution = LogNormalFromInterval(low, high)
    return sample(distribution, **kwargs)
def halfnormal_from_interval(high, **kwargs):
    """
    Draw a sample from the distribution built by
    ``HalfNormalFromInterval(high)``.

    :param high: upper endpoint passed to ``HalfNormalFromInterval``
    :param kwargs: forwarded unchanged to ``sample``
    :return: the result of ``sample`` for the constructed distribution
    """
    distribution = HalfNormalFromInterval(high)
    return sample(distribution, **kwargs)
def beta_from_hits(hits, total, **kwargs):
    """
    Draw a sample from the distribution built by
    ``BetaFromHits(hits, total)``.

    :param hits: number of hits passed to ``BetaFromHits``
    :param total: total number of trials passed to ``BetaFromHits``
    :param kwargs: forwarded unchanged to ``sample``
    :return: the result of ``sample`` for the constructed distribution
    """
    distribution = BetaFromHits(hits, total)
    return sample(distribution, **kwargs)
def random_choice(options, ps=None, **kwargs):
    """
    Pick one element of ``options`` at random.

    :param options: sized, indexable collection to choose from
    :param ps: optional per-option probabilities; converted with
        ``np.array`` and passed to ``dist.Categorical``. When omitted,
        every option gets weight ``1 / len(options)``
    :param kwargs: forwarded unchanged to ``sample``, as the other samplers
        in this module do
    :return: ``options[idx]`` where ``idx`` is the sampled category index
    :raises ValueError: if ``options`` is empty
    """
    # Use len() rather than truthiness: ``options`` may be an array, whose
    # truth value is ambiguous.
    count = len(options)
    if count == 0:
        raise ValueError("random_choice requires at least one option")
    if ps is None:
        ps = np.full(count, 1 / count)
    else:
        ps = np.array(ps)
    idx = sample(dist.Categorical(ps), **kwargs)
    return options[idx]
def random_integer(min: int, max: int, **kwargs) -> int:
    """
    Draw a random integer by flooring a continuous uniform sample.

    A value is sampled from ``dist.Uniform(min, max)`` and rounded down, so
    ``max`` itself is returned only if the sampler yields exactly ``max``.

    :param min: lower bound of the uniform distribution
    :param max: upper bound of the uniform distribution
    :param kwargs: forwarded unchanged to ``sample``
    :return: the floored sample as a Python ``int``
    """
    # Sample from dist.Uniform directly so this function does not rely on a
    # module-level ``uniform`` helper being defined.
    draw = sample(dist.Uniform(min, max), **kwargs)
    # .item() extracts the scalar from the sampled array before flooring.
    return int(math.floor(draw.item()))
# ``flip`` is an alias for ``bernoulli``: both names refer to the same function.
flip = bernoulli