AePPL `Switch`-defined mixtures
Maybe a new functionality to AePPL `Switches` would be to play multiplayer games on it 🤔
Lately, I have been trying to work on issues 76 and 77 of AePPL in which we would like to extend the library’s mixture functionality. See PR 154 in AePPL.
Mixture modelling in AePPL
The overarching goal of AePPL is to retrieve correct log-probability functions of data-generating models. In more detail, every data-generating model induces a hierarchical graph which can be built using Aesara’s symbolic mathematics toolbox. For instance, since PR 19, mixture models constructed via `at.stack` or `at.join` are supported:
import aesara.tensor as at
from aeppl import joint_logprob
srng = at.random.RandomStream(seed=2320)
# Index variable and the two mixture components
I_rv = srng.bernoulli(0.5, name="I")
X1_rv = srng.normal(loc=-5, scale=0.1, name="X1")
X2_rv = srng.normal(loc=5, scale=0.1, name="X2")
# Mixture: the index selects one component from the stacked tensor
Z1_rv = at.stack([X1_rv, X2_rv])[I_rv]
# Value variables at which the log-probability is evaluated
z_vv = Z1_rv.clone()
i_vv = I_rv.clone()
logp = joint_logprob({Z1_rv: z_vv, I_rv: i_vv})
Effectively, we can only retrieve the log-probability of the appropriate mixture component if we provide numerical values for the value variables `z_vv` and `i_vv`. These log-probabilities are unmarginalized; that is, AePPL retrieves the log-probability of `X1_rv` or `X2_rv` at the value `z_vv` depending on the index `i_vv`.
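To make "unmarginalized" concrete, here is a minimal plain-Python sketch of the quantity one would expect for the model above, assuming the joint log-probability is the sum of the index term and the selected component's term. The helper names `normal_logpdf` and `unmarginalized_mixture_logp` are made up for this illustration and are not part of AePPL.

```python
import math


def normal_logpdf(x, loc, scale):
    """Log-density of a Normal(loc, scale) distribution at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def unmarginalized_mixture_logp(z, i, locs=(-5.0, 5.0), scale=0.1, p=0.5):
    """Joint log-probability of (z, i); the index i is given, not summed out."""
    # log P(I = i) for a Bernoulli(p) index
    index_logp = math.log(p if i == 1 else 1.0 - p)
    # Log-density of the component selected by the index, as in stack(...)[i]
    component_logp = normal_logpdf(z, locs[i], scale)
    return index_logp + component_logp
```

Evaluating at `z = -5.0` gives a large value for `i = 0` and a vanishingly small one for `i = 1`, since only the selected component contributes.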
`Switch` mixtures
The `Switch` `Op` is an operator that takes three arguments: the index variable and both components of the mixture model. The index variable acts like the condition of an `ifelse` statement. When the condition is a dichotomous random variable and both branches are stochastic as well, i.e. `MeasurableVariable`s, this `Switch` subgraph is a mixture model and can be replaced by a `MixtureRV` node. Thanks to the indexing functionality provided in `expand_indices`, which imitates NumPy’s advanced indexing logic, adding a graph rewrite for `switch` and `ifelse` mixtures is not very difficult.
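One detail worth keeping in mind when mapping a switch onto an indexed stack is the argument order. The NumPy sketch below uses `np.where` as a stand-in for the switch semantics (the true branch comes first), so a non-zero index selects the first component, whereas indexing a stacked array with 0 selects the first component. The function names are hypothetical.

```python
import numpy as np

x1, x2 = -5.0, 5.0


def switch_component(i):
    """Switch-style selection: returns x1 when i is non-zero, x2 otherwise."""
    return float(np.where(i, x1, x2))


def stack_component(i):
    """Stack-and-index selection: returns x1 when i == 0, x2 when i == 1."""
    return float(np.array([x1, x2])[i])
```

Under these semantics, `switch_component(i)` agrees with `stack_component(1 - i)`, so a rewrite into an index-based mixture has to swap either the components or the index.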
Univariate components
Using the same `I_rv`, `X1_rv`, `X2_rv` and corresponding value variables defined above, a `Switch` mixture can be defined as follows:
srng = at.random.RandomStream(seed=2320)
I_rv = srng.bernoulli(0.5, name="I")
X1_rv = srng.normal(loc=-5, scale=0.1, name="X1")
X2_rv = srng.normal(loc=5, scale=0.1, name="X2")
# at.switch returns X1_rv where I_rv is non-zero and X2_rv otherwise
Z2_rv = at.switch(I_rv, X1_rv, X2_rv)
# Value variables for the switch mixture and its index
z_vv = Z2_rv.clone()
i_vv = I_rv.clone()
logp = joint_logprob({Z2_rv: z_vv, I_rv: i_vv})
In AePPL, the graph rewrite is registered via the `node_rewriter` decorator (known as `local_optimizer` before PR 1054 of Aesara) and identifies `Elemwise` nodes whose scalar operator is a `Switch`. Here, `I_rv`, `X1_rv` and `X2_rv` are not given any fancy `size` arguments, so no indexing operations need to be involved.
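The matching logic of such a rewrite can be sketched without Aesara at all. The toy below is only an illustration of the idea: `ToyNode`, `MEASURABLE_OPS` and `rewrite_switch_mixture` are hypothetical stand-ins, not Aesara or AePPL objects, and the real rewrite works on `Elemwise` nodes inside a function graph.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True)
class ToyNode:
    """A toy graph node: an operator name, a label, and its input nodes."""
    op: str
    name: str = ""
    inputs: Tuple["ToyNode", ...] = ()


# Toy stand-in for "this input is a MeasurableVariable"
MEASURABLE_OPS = {"bernoulli", "normal", "mixture"}


def rewrite_switch_mixture(node: ToyNode) -> Optional[ToyNode]:
    """Return a "mixture" node for a switch over measurable inputs, else None."""
    if node.op != "switch" or len(node.inputs) != 3:
        return None  # not a switch: leave the node untouched
    if not all(inp.op in MEASURABLE_OPS for inp in node.inputs):
        return None  # a deterministic input: not a mixture
    index, true_branch, false_branch = node.inputs
    # Order the components so that index value k selects component k
    return ToyNode("mixture", node.name, (index, false_branch, true_branch))
```

Returning `None` for non-matching nodes mirrors the usual convention for local rewrites, where a node that does not match is simply left as it is.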
Multi-dimensional inputs
For `Switch`es, both indices and components can be non-scalars. However, for `IfElse` mixtures, whose logic is very similar, conditions/indices can only be scalar-valued. Identifying which elements of the two component inputs are selected via indexing is non-trivial; the following NumPy example shows what AePPL/Aesara is expected to yield. In the end, the AePPL mixture logic for subgraphs defined by `MakeVector` and `Join`, `Op`s that combine tensors, needs to align with the corresponding `Switch`/`ifelse` indexing operation.
import numpy as np
comp1 = np.arange(1, 13).reshape(3, 4)
comp2 = -comp1
# The length-4 index is broadcast against the (3, 4) components
np.where(
    [0, 1, 0, 0],
    comp1,
    comp2
)
# array([[ -1,   2,  -3,  -4],
#        [ -5,   6,  -7,  -8],
#        [ -9,  10, -11, -12]])
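The example above illustrates the expected behaviour when the index is a vector and the components are 2D matrices: the index is broadcast across the rows, so each column is taken from `comp1` where the index is 1 and from `comp2` where it is 0. The same should hold for arbitrary arrays whose shapes broadcast together.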
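The same result can be reproduced with explicit stacking and indexing, which is closer in spirit to how a stack-based mixture selects its components. This is a NumPy-only sketch of the correspondence, not the `expand_indices` implementation; the variable names are chosen for this illustration.

```python
import numpy as np

comp1 = np.arange(1, 13).reshape(3, 4)
comp2 = -comp1
index = np.array([0, 1, 0, 0])

# Reference result from the switch-like operation
expected = np.where(index, comp1, comp2)

# Stack with the "false" branch first so that index value k picks row k
stacked = np.stack([comp2, comp1])                # shape (2, 3, 4)
# Broadcast the index to the component shape before indexing
full_index = np.broadcast_to(index, comp1.shape)  # shape (3, 4)
# Gather along the new leading axis, then drop that axis
gathered = np.take_along_axis(stacked, full_index[None, ...], axis=0)[0]
```

Here `gathered` and `expected` hold the same values, which is the alignment between the two mixture formulations that the rewrite has to preserve.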
Future Work
Future work entails:
- As of now, finish the `IfElse` mixture subgraph PRs.
- Extend `MixtureRV`s defined by `at.stack`s to retrieve their appropriate log-likelihood.
- Continue work on (Truncated) Dirichlet Processes for our experimental package (`pymc-experimental`), but progress on that has stalled…