AePPL `Switch`-defined mixtures

Maybe a new feature for AePPL `Switch`es would be playing multiplayer games on them 🤔

Lately, I have been working on AePPL issues 76 and 77, which aim to extend the library’s mixture functionality (see AePPL PR 154).

Mixture modelling in AePPL

The overarching goal of AePPL is to retrieve correct log-probability functions for data-generating models. More concretely, every data-generating model induces a hierarchical graph that can be built with Aesara’s symbolic mathematics toolbox. For instance, as of PR 19, mixture models constructed via at.stack or at.join are supported:

import aesara.tensor as at
from aeppl import joint_logprob

srng = at.random.RandomStream(seed=2320)

I_rv = srng.bernoulli(0.5, name="I")
X1_rv = srng.normal(loc=-5, scale=0.1, name="X1")
X2_rv = srng.normal(loc=5, scale=0.1, name="X2")

Z1_rv = at.stack([X1_rv, X2_rv])[I_rv]

z_vv = Z1_rv.clone()
i_vv = I_rv.clone()

# Map each random variable to its value variable (z_vv, not the undefined z1_vv)
logp = joint_logprob({Z1_rv: z_vv, I_rv: i_vv})

In effect, we retrieve the log-probability of the selected mixture component once numerical values are provided for the value variables z_vv and i_vv. These log-probabilities are unmarginalized: AePPL returns the log-probability of X1_rv or X2_rv at the value z_vv, depending on the index i_vv.
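To make "unmarginalized" concrete, here is a minimal NumPy sketch, independent of AePPL, of the quantity the joint log-probability above should evaluate to for Z1_rv = at.stack([X1_rv, X2_rv])[I_rv]: the Bernoulli log-mass of i plus the normal log-density of the selected component at z. All helper names are hypothetical. The marginalized version, which sums the index out with logaddexp, is included only for contrast and is not what joint_logprob returns here.

```python
import numpy as np

# Parameters from the example: X1 ~ N(-5, 0.1), X2 ~ N(5, 0.1), I ~ Bernoulli(0.5)
LOCS = np.array([-5.0, 5.0])
SCALE = 0.1
P_I = 0.5


def normal_logpdf(x, loc, scale):
    """Log-density of a normal distribution."""
    return -0.5 * ((x - loc) / scale) ** 2 - np.log(scale) - 0.5 * np.log(2 * np.pi)


def bernoulli_logpmf(i, p):
    """Log-mass of a Bernoulli distribution for i in {0, 1}."""
    return np.log(p) if i == 1 else np.log1p(-p)


def stack_mixture_joint_logp(z, i):
    """Unmarginalized log p(z, i); at.stack([X1, X2])[i] picks X1 when i == 0."""
    return bernoulli_logpmf(i, P_I) + normal_logpdf(z, LOCS[i], SCALE)


def stack_mixture_marginal_logp(z):
    """Marginal log p(z), summing over the index (for contrast only)."""
    return np.logaddexp(stack_mixture_joint_logp(z, 0), stack_mixture_joint_logp(z, 1))
```

For example, stack_mixture_joint_logp(-5.0, 0) is high, while stack_mixture_joint_logp(-5.0, 1) is extremely low because z = -5 sits 100 standard deviations away from the mean of X2.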

Switch mixtures

The Switch Op takes three arguments: the index (condition) variable and the two components of the mixture model. The index variable acts like the condition of an if-else statement that chooses between the two components. When the condition is a binary random variable and both branches are stochastic as well, i.e. MeasurableVariables, the Switch subgraph is a mixture model and can be replaced by a MixtureRV node. Thanks to expand_indices, which imitates NumPy’s advanced indexing logic, adding a graph rewrite for Switch and IfElse mixtures is not very difficult.
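One detail worth keeping in mind when mapping a Switch onto a MixtureRV: at.switch(cond, a, b) returns a where cond is true (like NumPy's np.where), whereas at.stack([a, b])[idx] returns a where idx is 0. The plain NumPy sketch below (not AePPL code) checks that a switch is equivalent, element by element, to indexing a stack whose components are listed in reverse order; np.choose stands in for the elementwise stack-and-index.

```python
import numpy as np

rng = np.random.default_rng(2320)

a = rng.normal(size=(3, 4))  # "true" branch
b = rng.normal(size=(3, 4))  # "false" branch
cond = rng.integers(0, 2, size=(3, 4))  # binary index, like a Bernoulli draw

# Switch semantics: take `a` where cond == 1 and `b` where cond == 0
switched = np.where(cond, a, b)

# Stack-and-index semantics: component k is selected where the index equals k,
# so the false branch must come first for the two to agree
chosen = np.choose(cond, [b, a])

assert np.array_equal(switched, chosen)
```

In particular, Z2_rv = at.switch(I_rv, X1_rv, X2_rv) below draws from X1_rv when I_rv is 1, the opposite of Z1_rv = at.stack([X1_rv, X2_rv])[I_rv].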

Univariate components

Using the same I_rv, X1_rv, X2_rv and corresponding value variables defined above, a Switch mixture can be defined as follows:

srng = at.random.RandomStream(seed=2320)

I_rv = srng.bernoulli(0.5, name="I")
X1_rv = srng.normal(loc=-5, scale=0.1, name="X1")
X2_rv = srng.normal(loc=5, scale=0.1, name="X2")

Z2_rv = at.switch(I_rv, X1_rv, X2_rv)

z_vv = Z2_rv.clone()  # value variable for the Switch mixture Z2_rv
i_vv = I_rv.clone()

logp = joint_logprob({Z2_rv: z_vv, I_rv: i_vv})

In AePPL, graph rewrites are written with the node_rewriter decorator (known as local_optimizer before Aesara PR 1054); the rewrite for this case identifies Elemwise nodes whose scalar operator is a Switch. Here, I_rv, X1_rv and X2_rv are not given any size argument, so they are scalars and no indexing operations are involved.
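To illustrate the shape of such a rewrite without depending on Aesara internals, here is a toy, pure-Python sketch. A hypothetical Node class stands in for graph nodes, and a hypothetical rewrite_switch_mixture function replaces a switch node whose condition and branches are all measurable with a mixture node. None of these names are AePPL or Aesara APIs; the real rewrite matches Elemwise nodes with a Switch scalar operator and is registered via node_rewriter.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class Node:
    """Toy stand-in for a graph node: an op name plus its input nodes."""
    op: str
    inputs: list = field(default_factory=list)
    name: str = ""


# Toy set of ops treated as measurable (i.e. having a known log-probability)
MEASURABLE_OPS = {"bernoulli", "normal", "mixture"}


def is_measurable(node: Node) -> bool:
    return node.op in MEASURABLE_OPS


def rewrite_switch_mixture(node: Node) -> Optional[Node]:
    """Return a mixture node replacing `node`, or None if the rewrite does not apply."""
    if node.op != "switch":
        return None
    cond, true_branch, false_branch = node.inputs
    if not all(is_measurable(n) for n in (cond, true_branch, false_branch)):
        return None
    # In this toy, components follow stack-style ordering:
    # index 0 -> false branch, index 1 -> true branch
    return Node("mixture", [cond, false_branch, true_branch], name=node.name)
```

For example, rewriting Node("switch", [I, X1, X2]) with Bernoulli I and normal X1, X2 yields a mixture node, while a switch on a constant condition is left untouched.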

Multi-dimensional inputs

For Switch mixtures, both the index and the components can be non-scalar. IfElse mixtures follow very similar logic, but their condition can only be scalar-valued. Working out, via indexing, which elements of the two component inputs are selected is non-trivial; the following NumPy example shows what AePPL/Aesara is expected to yield. Ultimately, the AePPL mixture logic for subgraphs defined by MakeVector and Join (Ops that combine tensors into one) needs to agree with the indexing operation corresponding to a Switch or IfElse.

import numpy as np

comp1 = np.arange(1, 13).reshape(3, 4)
comp2 = -comp1

np.where(
    [0, 1, 0, 0],
    comp1,
    comp2
)
# array([[ -1,   2,  -3,  -4],
#        [ -5,   6,  -7,  -8],
#        [ -9,  10, -11, -12]])

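The example above illustrates the expected behaviour when the index is a vector and the components are 2D matrices: the index is broadcast against the components, so each entry of the index selects a whole column. The same logic should hold for arbitrarily shaped, broadcastable arrays.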
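A Switch with a vector index can be related back to the stack/join mixtures from the first section: broadcast the index against the components, stack the components along a new leading axis (false branch first), and select with NumPy advanced indexing. The sketch below is plain NumPy and only mimics the kind of indexing that expand_indices is meant to reproduce; the helper name switch_via_stack is hypothetical.

```python
import numpy as np


def switch_via_stack(cond, if_true, if_false):
    """Reproduce np.where(cond, if_true, if_false) by indexing a stacked array."""
    cond, if_true, if_false = np.broadcast_arrays(np.asarray(cond), if_true, if_false)
    # Leading axis holds the components: index 0 -> false branch, 1 -> true branch
    stacked = np.stack([if_false, if_true])
    # One integer index array per remaining axis, so each element picks its own component
    return stacked[(cond.astype(int),) + tuple(np.indices(cond.shape))]


comp1 = np.arange(1, 13).reshape(3, 4)
comp2 = -comp1

# Same values as np.where([0, 1, 0, 0], comp1, comp2)
result = switch_via_stack([0, 1, 0, 0], comp1, comp2)
```

A scalar condition, as in an IfElse, is just the degenerate case where the broadcast index is constant and the whole of one component is selected.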

Future Work

Future work entails:

  • First, finishing the IfElse mixture subgraph PRs.
  • Extending MixtureRVs defined by at.stack so that their appropriate log-likelihood is retrieved.
  • Continuing work on (Truncated) Dirichlet Processes for our experimental package (pymc-experimental), although progress there has come to a halt…