Merged my first distribution class 🎉
The first step to Dirichlet processes
Link to pull request: https://github.com/pymc-devs/pymc/pull/5200
After a long (two-month) process, my PR adding a distribution class has been merged! 🙂 The whole process of taking the initiative to contribute to PyMC was rewarding, but definitely not easy.
Revisiting Dirichlet Processes
A random distribution \(G\) has a Dirichlet Process (DP) prior, denoted by \(G \sim \text{DP}(\alpha, G_0)\), if for any finite measurable partition \(\{A_i\}_{i=1}^n\) of the sample space \(\Theta\), the vector of probabilities that \(G\) assigns to the partition follows a Dirichlet distribution:
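\[\Big(G(A_1), \dots, G(A_n) \Big) \sim \text{Dir}\Big(\alpha G_0(A_1), \dots, \alpha G_0(A_n) \Big)\]
where \(\alpha > 0\) is the concentration parameter and \(G_0\) is some base distribution, e.g. \(G_0 \equiv \mathcal{N}(0, 1)\). The higher the value of \(\alpha\), the more closely \(G\) concentrates around \(G_0\); in the stick-breaking representation discussed below, this means that the mass is spread over more atoms and the individual weights tend to be smaller.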
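To make the definition concrete, here is a small NumPy sketch (an illustration, not PyMC code) that fixes one partition of the real line, computes the base masses \(G_0(A_i)\) for \(G_0 = \mathcal{N}(0, 1)\), and draws \(\big(G(A_1), \dots, G(A_n)\big)\) directly from the Dirichlet distribution in the definition. The partition and the helper names (normal_cdf, sample_partition_masses) are arbitrary choices for illustration. Marginally, each \(G(A_i) \sim \text{Beta}\big(\alpha G_0(A_i), \alpha(1 - G_0(A_i))\big)\), with mean \(G_0(A_i)\) and variance \(G_0(A_i)(1 - G_0(A_i))/(1 + \alpha)\), which is why a larger \(\alpha\) keeps \(G\) closer to \(G_0\).

```python
import math

import numpy as np


def normal_cdf(x: float) -> float:
    """CDF of the standard normal base distribution G_0 = N(0, 1)."""
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))


# An arbitrary partition of the real line: (-inf, -1], (-1, 0], (0, 1], (1, inf)
cut_points = [-math.inf, -1.0, 0.0, 1.0, math.inf]
base_mass = np.array(
    [normal_cdf(hi) - normal_cdf(lo) for lo, hi in zip(cut_points[:-1], cut_points[1:])]
)


def sample_partition_masses(alpha: float, n_draws: int, seed: int = 0) -> np.ndarray:
    """Draw (G(A_1), ..., G(A_n)) ~ Dir(alpha * G_0(A_1), ..., alpha * G_0(A_n))."""
    rng = np.random.default_rng(seed)
    return rng.dirichlet(alpha * base_mass, size=n_draws)


for alpha in (1.0, 10.0, 100.0):
    draws = sample_partition_masses(alpha, n_draws=20_000)
    theory_var = base_mass * (1 - base_mass) / (1 + alpha)
    print(
        f"alpha={alpha:>5}: mean={draws.mean(axis=0).round(3)}, "
        f"var={draws.var(axis=0).round(4)}, theory var={theory_var.round(4)}"
    )
```

The empirical means stay at \(G_0(A_i)\) for every \(\alpha\), while the variances shrink roughly like \(1/(1 + \alpha)\).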
It can initially be difficult to understand why such a construction is useful, let alone nonparametric, but the idea is as follows. In Bayesian inference, we condition on observed data and look at the posterior distribution of the parameters of interest. By positing a DP prior on \(G\), we can effectively perform inference on the distribution \(G\) itself without assuming that it belongs to a parametric family, hence rendering this construction nonparametric despite the need to specify \(G_0\). It is worth mentioning already that a DP prior imposes some “strong” restrictions (for instance, draws from a DP are almost surely discrete), so a more common application of DPs is to posit them as priors in mixture modelling, but more on that at a later date…
When it comes to sampling, there are many schemes with desirable properties that provide nice (conditional) posterior distributions (see the Chinese Restaurant Process and the Pólya urn scheme). However, an inherent challenge in building DP functionality into PyMC is to leverage its default sampling methods, which are primarily gradient-based: Hamiltonian Monte Carlo (HMC) and its extensions, such as the No-U-Turn Sampler (NUTS) that PyMC uses by default for continuous variables. As such, the most useful construction of a DP is to represent it as an infinite linear combination of weights, obtained via a stick-breaking process, and atoms, represented as Dirac delta distributions:
\[G = \sum_{h=1}^\infty w_h \delta_{m_h}\]
where \(w_h = v_h \prod_{\ell < h} (1 - v_\ell)\), with \(v_h \stackrel{\text{i.i.d.}}{\sim} \text{Beta}(1, \alpha)\) and \(m_h \stackrel{\text{i.i.d.}}{\sim} G_0\). This construction is particularly helpful since it is easily vectorizable and plays into the strengths of HMC. Of course, as taught in undergraduate calculus, infinity is a concept, not a number, and our computers couldn’t agree more. Instead of an infinite sum, we can express \(G\) using some large truncation parameter \(K\); this finite approximation is justified by Ishwaran and James (2001):
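\[G = \sum_{h=1}^{K+1} w_h \delta_{m_h} \, ,\]
where the first \(K\) weights are built from \(K\) sticks as above and the last weight \(w_{K+1} = 1 - \sum_{h=1}^{K} w_h = \prod_{\ell=1}^{K} (1 - v_\ell)\) collects the leftover mass, so that the weights sum to one. From the same article, these truncated stick-breaking weights follow a generalized Dirichlet distribution, so implementing that distribution turned out to be the first step towards a submodule for DPs.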
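As a sanity check of the truncated representation, the following NumPy sketch (hypothetical helper names, not PyMC code) draws many truncated realizations of \(G\) with \(G_0 = \mathcal{N}(0, 1)\), using \(K\) Beta sticks plus a final weight that takes the leftover mass, and evaluates \(G(A)\) for \(A = (-\infty, 0]\). By the Dirichlet definition, \(G(A) \sim \text{Beta}\big(\alpha G_0(A), \alpha(1 - G_0(A))\big)\); with \(\alpha = 2\) and \(G_0(A) = 1/2\) this is \(\text{Beta}(1, 1)\), i.e. uniform on \([0, 1]\), with mean \(1/2\) and variance \(1/12\).

```python
import numpy as np


def truncated_dp(alpha: float, K: int, n_draws: int, rng: np.random.Generator):
    """Draw n_draws truncated DP realizations G = sum_{h=1}^{K+1} w_h delta_{m_h}.

    Returns (weights, atoms), both of shape (n_draws, K + 1).
    """
    v = rng.beta(1.0, alpha, size=(n_draws, K))  # stick fractions v_h ~ Beta(1, alpha)
    # remaining[:, h] = prod_{l <= h} (1 - v_l), with remaining[:, 0] = 1 (empty product)
    remaining = np.cumprod(
        np.concatenate([np.ones((n_draws, 1)), 1.0 - v], axis=1), axis=1
    )
    weights = np.empty((n_draws, K + 1))
    weights[:, :K] = v * remaining[:, :K]  # w_h = v_h * prod_{l < h} (1 - v_l)
    weights[:, K] = remaining[:, K]  # leftover mass prod_{l <= K} (1 - v_l)
    atoms = rng.standard_normal(size=(n_draws, K + 1))  # m_h ~ G_0 = N(0, 1)
    return weights, atoms


rng = np.random.default_rng(42)
alpha, K = 2.0, 50
weights, atoms = truncated_dp(alpha, K, n_draws=20_000, rng=rng)

# G(A) for A = (-inf, 0]: total weight of the atoms that fall in A
g_of_a = (weights * (atoms <= 0.0)).sum(axis=1)
print("weights sum to one:", np.allclose(weights.sum(axis=1), 1.0))
print("mean of G(A):", g_of_a.mean(), "(expected 0.5)")
print("var of G(A):", g_of_a.var(), "(expected", 0.25 / (1 + alpha), ")")
```

With \(K = 50\) and \(\alpha = 2\), the expected leftover mass \((\alpha / (1 + \alpha))^K\) is of order \(10^{-9}\), so the truncation has no visible effect on these summaries.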
RandomVariable classes in PyMC/Aesara
In PyMC, distributions are implemented as Distribution classes whose rv_op attribute is a RandomVariable instance. Each rv_op has an rng_fn method that is able to generate samples from the prior distribution, and it serves as the basis of pm.sample_prior_predictive.
import numpy as np
import aesara.tensor as at
from aesara.tensor.random.op import RandomVariable
from pymc.aesaraf import intX  # module path as of PyMC 4.0 (Aesara backend)


class StickBreakingWeightsRV(RandomVariable):
    name = "stick_breaking_weights"
    ndim_supp = 1  # each draw is a vector (the K + 1 weights)
    ndims_params = [0, 0]  # alpha and K are both scalars
    dtype = "floatX"
    _print_name = ("StickBreakingWeights", "\\operatorname{StickBreakingWeights}")

    def make_node(self, rng, size, dtype, alpha, K):
        alpha = at.as_tensor_variable(alpha)
        K = at.as_tensor_variable(intX(K))

        if alpha.ndim > 0:
            raise ValueError("The concentration parameter needs to be a scalar.")

        if K.ndim > 0:
            raise ValueError("K must be a scalar.")

        return super().make_node(rng, size, dtype, alpha, K)

    def _infer_shape(self, size, dist_params, param_shapes=None):
        alpha, K = dist_params

        size = tuple(size)

        # batch dimensions from `size`, followed by the K + 1 weights
        return size + (K + 1,)

    @classmethod
    def rng_fn(cls, rng, alpha, K, size):
        if K < 0:
            raise ValueError("K needs to be positive.")

        # K stick fractions per batch entry
        if size is None:
            size = (K,)
        elif isinstance(size, int):
            size = (size,) + (K,)
        else:
            size = tuple(size) + (K,)

        betas = rng.beta(1, alpha, size=size)

        # stick length left before each break: 1, (1 - v_1), (1 - v_1)(1 - v_2), ...
        sticks = np.concatenate(
            (
                np.ones(shape=(size[:-1] + (1,))),
                np.cumprod(1 - betas[..., :-1], axis=-1),
            ),
            axis=-1,
        )

        weights = sticks * betas
        # the last weight collects the leftover mass, so the K + 1 weights sum to one
        weights = np.concatenate(
            (weights, 1 - weights.sum(axis=-1)[..., np.newaxis]),
            axis=-1,
        )

        return weights
First, the distinction between StickBreakingWeightsRV and StickBreakingWeights is important, as the latter is the one that is used inside a pm.Model() context manager. The rng_fn essentially follows the mathematical stick-breaking construction: \(K\) i.i.d. draws from a \(\text{Beta}(1, \alpha)\) are used to construct \(w_1, \dots, w_{K}\), with \(w_{K+1} = 1 - \sum_{\ell=1}^K w_\ell\). An inherent challenge of this design is to abide by the existing object-oriented structure of the library, since not all of its intricacies are obvious. For instance, while size is used to specify the number (and batch shape) of the draws that we want to sample (also in StickBreakingWeights below), we decided to provide the truncation parameter K as an explicit argument rather than including it as part of size or shape.
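The split between size and K can be illustrated without Aesara. The following NumPy-only sketch mirrors the logic of rng_fn above; the function name stick_breaking_weights is hypothetical and this is not the PyMC API. It shows that size only sets the leading (batch) dimensions, while K fixes the length of the last axis to K + 1.

```python
import numpy as np


def stick_breaking_weights(rng, alpha, K, size=None):
    """NumPy mirror of the rng_fn logic: returns draws of shape (*size, K + 1)."""
    if size is None:
        size = ()
    elif isinstance(size, int):
        size = (size,)
    size = tuple(size) + (K,)  # K stick fractions per batch entry

    betas = rng.beta(1.0, alpha, size=size)
    sticks = np.concatenate(
        (np.ones(size[:-1] + (1,)), np.cumprod(1 - betas[..., :-1], axis=-1)),
        axis=-1,
    )
    weights = sticks * betas
    # last weight: whatever mass the K breaks left over
    return np.concatenate((weights, 1 - weights.sum(axis=-1, keepdims=True)), axis=-1)


rng = np.random.default_rng(0)
for size in (None, 3, (2, 3)):
    draws = stick_breaking_weights(rng, alpha=1.5, K=10, size=size)
    print(size, "->", draws.shape)
```

This prints None -> (11,), 3 -> (3, 11) and (2, 3) -> (2, 3, 11): whatever the batch shape, the last axis always holds the K + 1 weights.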
A distribution class for (truncated) stick-breaking weights
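import numpy as np
import aesara.tensor as at

# module paths as of PyMC 4.0 (Aesara backend); later versions moved some of these
from pymc.aesaraf import floatX, intX
from pymc.distributions import transforms
from pymc.distributions.continuous import assert_negative_support
from pymc.distributions.dist_math import betaln, check_parameters
from pymc.distributions.distribution import Continuous
from pymc.distributions.shape_utils import rv_size_is_none

# rv_op is an instance of the RandomVariable subclass defined above
stickbreakingweights = StickBreakingWeightsRV()


class StickBreakingWeights(Continuous):
    rv_op = stickbreakingweights

    def __new__(cls, name, *args, **kwargs):
        # by default, sample on an unconstrained space via the simplex transform
        kwargs.setdefault("transform", transforms.simplex)
        return super().__new__(cls, name, *args, **kwargs)

    @classmethod
    def dist(cls, alpha, K, *args, **kwargs):
        alpha = at.as_tensor_variable(floatX(alpha))
        K = at.as_tensor_variable(intX(K))

        assert_negative_support(alpha, "alpha", "StickBreakingWeights")
        assert_negative_support(K, "K", "StickBreakingWeights")

        return super().dist([alpha, K], **kwargs)

    def get_moment(rv, size, alpha, K):
        # E[w_h] = 1 / (1 + alpha) * (alpha / (1 + alpha)) ** (h - 1), for h = 1, ..., K
        moment = (alpha / (1 + alpha)) ** at.arange(K)
        moment *= 1 / (1 + alpha)
        # E[w_{K+1}] = (alpha / (1 + alpha)) ** K
        moment = at.concatenate([moment, [(alpha / (1 + alpha)) ** K]], axis=-1)
        if not rv_size_is_none(size):
            moment_size = at.concatenate([size, [K + 1]])
            moment = at.full(moment_size, moment)

        return moment

    def logp(value, alpha, K):
        """
        Calculate log-probability of the distribution induced from the stick-breaking process
        at specified value.

        Parameters
        ----------
        value: numeric
            Value for which log-probability is calculated.

        Returns
        -------
        TensorVariable
        """
        # generalized Dirichlet log-density with a_h = 1 and b_h = alpha
        logp = -at.sum(
            at.log(
                at.cumsum(
                    value[..., ::-1],
                    axis=-1,
                )
            ),
            axis=-1,
        )
        logp += -K * betaln(1, alpha)
        logp += alpha * at.log(value[..., -1])

        # -inf outside the support: a weight outside (0, 1), weights that do not
        # sum to one, or a last dimension different from K + 1
        logp = at.switch(
            at.or_(
                at.any(
                    at.or_(at.le(value, 0), at.ge(value, 1)),
                    axis=-1,
                ),
                at.or_(
                    at.bitwise_not(at.allclose(value.sum(-1), 1)),
                    at.neq(value.shape[-1], K + 1),
                ),
            ),
            -np.inf,
            logp,
        )

        return check_parameters(
            logp,
            alpha > 0,
            K > 0,
            msg="alpha > 0, K > 0",
        )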
The length of this Distribution class may seem a bit daunting, but here is a summary of each method.

__new__: Here, we set the default transformation that helps sampling. The weights live on a simplex (each weight is between 0 and 1 and they sum to one), so the simplex transform maps them to an unconstrained space; the sampler then moves freely in that space and never wanders into “danger” territory (invalid weights).

dist: This is the entry point that takes the inputs parametrizing the distribution. At this stage, inputs can be “anything”, i.e. they have not necessarily been converted into aesara.tensor.var.TensorVariable instances yet, which is why dist converts them to tensors and checks that they are not negative before passing them on.

get_moment: An approach that PyMC has recently taken is to initialize samplers at the mean of each distribution (with no observations). For that purpose, get_moment has been introduced. Each stick-breaking weight \(w_h = v_h \prod_{\ell < h} (1 - v_\ell)\) is a product of independent linear transformations of i.i.d. Beta random variables, with \(\mathbb{E}[v_h] = \frac{1}{1 + \alpha}\) and \(\mathbb{E}[1 - v_\ell] = \frac{\alpha}{1 + \alpha}\), so the expectation factorizes:
\[\mathbb{E}\left[w_h\right] = \frac{1}{1 + \alpha}\left(\frac{\alpha}{1 + \alpha}\right)^{h-1}\]
for all \(h = 1, \dots, K\), and \(\mathbb{E}\left[w_{K+1}\right] = \left(\frac{\alpha}{1 + \alpha}\right)^{K}\).

logp: While it may look simple, this is the most important method, as it is exactly what allows pm.sample() to do its magic. Here, I provide the log-density of a generalized Dirichlet distribution with \(a_h = 1\) and \(b_h = \alpha\) for all \(h\) (so that each stick fraction is \(v_h \sim \text{Beta}(a_h, b_h) = \text{Beta}(1, \alpha)\)), following the density given in the Wikipedia article. Note that value can have any number of leading (batch) dimensions, since all computations run along the last axis. Everything else in the method (what’s in at.switch and check_parameters) ensures that inputs and parameters are valid with respect to the distribution’s constraints: invalid values get a log-probability of \(-\infty\), and errors would be raised if we provided something like alpha = -2, for instance.
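The closed-form expectations used in get_moment can be checked by Monte Carlo. This NumPy sketch (an illustration under the same \(\text{Beta}(1, \alpha)\) stick-breaking assumptions, not PyMC code; closed_form_moment and simulate_weights are hypothetical names) compares the empirical mean of simulated weights with \(\frac{1}{1+\alpha}\left(\frac{\alpha}{1+\alpha}\right)^{h-1}\) for \(h \le K\) and \(\left(\frac{\alpha}{1+\alpha}\right)^{K}\) for the last weight.

```python
import numpy as np


def closed_form_moment(alpha: float, K: int) -> np.ndarray:
    """E[w_1], ..., E[w_K], E[w_{K+1}] for truncated stick-breaking weights."""
    ratio = alpha / (1 + alpha)
    head = ratio ** np.arange(K) / (1 + alpha)  # E[w_h], h = 1, ..., K
    return np.append(head, ratio**K)  # E[w_{K+1}]


def simulate_weights(alpha: float, K: int, n_draws: int, seed: int = 0) -> np.ndarray:
    """Draw n_draws vectors of K + 1 stick-breaking weights."""
    rng = np.random.default_rng(seed)
    v = rng.beta(1.0, alpha, size=(n_draws, K))
    remaining = np.cumprod(1.0 - v, axis=1)  # prod_{l <= h} (1 - v_l)
    before = np.hstack([np.ones((n_draws, 1)), remaining[:, :-1]])  # prod_{l < h} (1 - v_l)
    return np.hstack([v * before, remaining[:, -1:]])


alpha, K = 3.0, 5
empirical = simulate_weights(alpha, K, n_draws=200_000).mean(axis=0)
print(np.round(empirical, 4))
print(np.round(closed_form_moment(alpha, K), 4))
```

A useful by-product: the closed-form moments themselves sum to one, as the expected weights of a probability vector should.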
A small note: I also added unit tests for this distribution, testing shapes, logp, rng_fn and their extensions to multidimensional samples. I do not discuss them here, but unit tests are quite important and, given that I didn’t know what they were before, writing them was an interesting learning experience.
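In the spirit of those unit tests, here is a self-contained NumPy sketch (not one of the actual PyMC tests; all function names are hypothetical) that checks the logp formula against an independent derivation. Inverting the stick-breaking map gives \(v_h = w_h / \big(1 - \sum_{\ell < h} w_\ell\big)\), and the change-of-variables formula yields the density of the weights from the \(\text{Beta}(1, \alpha)\) log-density \(\log\alpha + (\alpha - 1)\log(1 - v)\) of the sticks. The two computations should agree up to floating-point error.

```python
from math import lgamma

import numpy as np


def betaln(a: float, b: float) -> float:
    """log B(a, b) computed from log-gamma functions."""
    return lgamma(a) + lgamma(b) - lgamma(a + b)


def logp_generalized_dirichlet(w: np.ndarray, alpha: float) -> float:
    """Same formula as the logp method above, for a single vector of K + 1 weights."""
    K = w.shape[-1] - 1
    logp = -np.sum(np.log(np.cumsum(w[::-1])))
    logp += -K * betaln(1.0, alpha)
    logp += alpha * np.log(w[-1])
    return logp


def logp_change_of_variables(w: np.ndarray, alpha: float) -> float:
    """log p(w) = sum_h log Beta(v_h | 1, alpha) - sum_h log(stick left before break h)."""
    K = w.shape[-1] - 1
    remaining = 1.0 - np.concatenate(([0.0], np.cumsum(w[: K - 1])))  # 1 - sum_{l<h} w_l
    v = w[:K] / remaining
    log_beta = np.log(alpha) + (alpha - 1.0) * np.log1p(-v)
    # the Jacobian of v -> (w_1, ..., w_K) is triangular with diagonal `remaining`
    return np.sum(log_beta) - np.sum(np.log(remaining))


rng = np.random.default_rng(7)
alpha, K = 2.5, 6
v = rng.beta(1.0, alpha, size=K)
remaining = np.concatenate(([1.0], np.cumprod(1.0 - v)))
w = np.append(v * remaining[:K], remaining[K])  # K + 1 weights summing to one
print(logp_generalized_dirichlet(w, alpha), logp_change_of_variables(w, alpha))
```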
Some final comments
As I am pursuing a PhD, this experience of contributing to open source has been extremely rewarding and educational. It was not an easy process, but it was enjoyable, as I was able to learn a lot about “properly” contributing code (in quotation marks because I am no software engineer, and who am I to say that this is the proper method, although it is by far the most structured process I have ever experienced).
The next steps are to properly simulate data from a DP to better understand how we should go about building an API for it. The ultimate goal is to have a working submodule and a nice class for DP Mixtures, which will probably be the most useful extension of DPs in practice.
Thanks
Firstly, a big thanks to you for reading this post! While this was a nice exercise for myself, it’s nice to have people read my work (and comment on it if you want!). I am proud to have coded everything, but none of this would have been possible without the help of these amazing people, who were always there to share their ideas and comments about my code:
Last but not least, a big thanks to the entire PyMC community for answering all my questions and for all their ongoing efforts, whether they are noticeable or not 🚀