I’m in the middle of producing an R package implementing a statistical model which finds a maximum-likelihood set of parameters \(\theta\) using Expectation Maximisation (EM). Implementing EM can be tricky, since bugs can have non-obvious consequences, such as converging to the wrong set of parameters or not converging at all. However, EM turns out to be very easy to write unit tests for.
To maximise the log likelihood, EM iterates between two steps:
1. Find the expected value of the complete log likelihood: \[ Q(\theta | \theta^{(t)}) = \mathbb{E}_{\mathbf{Z} | \mathbf{X}, \theta^{(t)}}\left[ \log p(\mathbf{X}, \mathbf{Z} | \theta) \right] \]
2. Maximise this with respect to the parameters: \[ \theta^{(t+1)} = \underset{\theta}{\mathrm{argmax}} \, Q(\theta | \theta^{(t)}) \]
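As a concrete illustration, here is how the two steps look for a toy model. This is an assumed example, not the model in my package: a two-component Gaussian mixture with unit variances, \(\theta = (\pi, \mu_1, \mu_2)\), where \(\pi\) is the weight of the first component and \(z_i\) indicates which component produced observation \(x_i\). The expectation step only requires the responsibilities \[ r_i = \frac{\pi^{(t)} \, \mathcal{N}(x_i | \mu_1^{(t)}, 1)}{\pi^{(t)} \, \mathcal{N}(x_i | \mu_1^{(t)}, 1) + (1 - \pi^{(t)}) \, \mathcal{N}(x_i | \mu_2^{(t)}, 1)} \] which give \[ Q(\theta | \theta^{(t)}) = \sum_{i=1}^{n} \Big( r_i \left[ \log \pi + \log \mathcal{N}(x_i | \mu_1, 1) \right] + (1 - r_i) \left[ \log (1 - \pi) + \log \mathcal{N}(x_i | \mu_2, 1) \right] \Big) \] Setting the derivatives of \(Q\) to zero gives the closed-form maximisation step \[ \pi^{(t+1)} = \frac{1}{n} \sum_{i=1}^{n} r_i, \qquad \mu_1^{(t+1)} = \frac{\sum_i r_i x_i}{\sum_i r_i}, \qquad \mu_2^{(t+1)} = \frac{\sum_i (1 - r_i) x_i}{\sum_i (1 - r_i)} \]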
Typically step (2) is done in closed form. This means taking some derivatives and doing some linear algebra by hand, then coding up the solution as the maximisation step (M-step). That leaves a lot of room for mistakes! However, we can check that our M-step is correct by writing unit tests for the following properties:
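- After maximisation, the value of \(Q\) does not decrease: \(Q(\theta^{(t+1)} | \theta^{(t)}) \geq Q(\theta^{(t)} | \theta^{(t)})\), with strict increase unless \(\theta^{(t)}\) already maximises \(Q\)
- The numerical gradient of \(Q(\theta | \theta^{(t)})\) with respect to \(\theta\), evaluated at \(\theta = \theta^{(t+1)}\), is a vector of zeroes (to within numerical tolerance)
- The numerical Hessian of \(Q(\theta | \theta^{(t)})\), evaluated at \(\theta = \theta^{(t+1)}\), is negative definite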
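A minimal sketch of these three checks, using `grad()` and `hessian()` from numDeriv. The model is a toy I am assuming purely for illustration (a two-component Gaussian mixture with unit variances), and `e_step`, `q_fun` and `m_step` are hypothetical names rather than functions from any package. The mixing weight is a single free parameter here, so the maximum lies in the interior of the parameter space and the gradient really should vanish there.

```r
library(numDeriv)  # provides grad() and hessian()

# Toy model (assumed for illustration): two-component Gaussian mixture with
# unit variances; theta = c(p, mu1, mu2), where p is the weight of component 1.

# E-step: probability that each observation came from component 1.
e_step <- function(x, theta) {
  a <- theta[1] * dnorm(x, theta[2], 1)
  b <- (1 - theta[1]) * dnorm(x, theta[3], 1)
  a / (a + b)
}

# Q(theta | theta_t): r holds the responsibilities computed at theta_t.
q_fun <- function(theta, x, r) {
  sum(r * (log(theta[1]) + dnorm(x, theta[2], 1, log = TRUE)) +
        (1 - r) * (log(1 - theta[1]) + dnorm(x, theta[3], 1, log = TRUE)))
}

# Closed-form M-step: this is the code under test.
m_step <- function(x, r) {
  c(mean(r), sum(r * x) / sum(r), sum((1 - r) * x) / sum(1 - r))
}

set.seed(1)
x <- c(rnorm(100, mean = -2), rnorm(100, mean = 2))
theta_t <- c(0.3, -1, 1)        # arbitrary current estimate
r <- e_step(x, theta_t)
theta_new <- m_step(x, r)

# Closure over the data and responsibilities, so numDeriv only sees theta.
q_at <- function(theta) q_fun(theta, x, r)

q_gain <- q_at(theta_new) - q_at(theta_t)
q_grad <- grad(q_at, theta_new)
q_hess <- hessian(q_at, theta_new)

stopifnot(
  q_gain >= 0,                                      # Q does not decrease
  max(abs(q_grad)) < 1e-5,                          # gradient is numerically zero
  all(eigen(q_hess, symmetric = TRUE)$values < 0)   # Hessian is negative definite
)
```

In a package, the same assertions would typically live inside a `testthat::test_that()` block. The tolerance of `1e-5` is a judgment call and may need adjusting for other models.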
Similarly, we can check that the entire EM procedure is working by iterating until convergence and verifying numerically that we have reached a zero-gradient point of the log likelihood (again with a negative-definite Hessian). There is almost certainly an existing package for finding numerical gradients and Hessians in your language of choice; I use the numDeriv R package.
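The end-to-end check could look something like the sketch below. It again assumes a toy two-component Gaussian mixture with unit variances rather than the model in my package, and `loglik`, `em_update` and `em_fit` are hypothetical names. The stopping tolerance and the gradient tolerance are assumptions that would need tuning for a real model.

```r
library(numDeriv)  # provides grad() and hessian()

# Observed-data log likelihood for the toy mixture; theta = c(p, mu1, mu2).
loglik <- function(theta, x) {
  sum(log(theta[1] * dnorm(x, theta[2], 1) +
            (1 - theta[1]) * dnorm(x, theta[3], 1)))
}

# One full EM iteration: E-step followed by the closed-form M-step.
em_update <- function(theta, x) {
  a <- theta[1] * dnorm(x, theta[2], 1)
  b <- (1 - theta[1]) * dnorm(x, theta[3], 1)
  r <- a / (a + b)
  c(mean(r), sum(r * x) / sum(r), sum((1 - r) * x) / sum(1 - r))
}

# Iterate until the parameters stop changing (or max_iter is reached).
em_fit <- function(theta, x, tol = 1e-10, max_iter = 1000) {
  for (i in seq_len(max_iter)) {
    theta_next <- em_update(theta, x)
    done <- max(abs(theta_next - theta)) < tol
    theta <- theta_next
    if (done) break
  }
  theta
}

set.seed(1)
x <- c(rnorm(100, mean = -2), rnorm(100, mean = 2))
theta_hat <- em_fit(c(0.3, -1, 1), x)

# Closure over the data, so numDeriv only sees theta.
ll_at <- function(theta) loglik(theta, x)
ll_grad <- grad(ll_at, theta_hat)
ll_hess <- hessian(ll_at, theta_hat)

stopifnot(
  max(abs(ll_grad)) < 1e-5,                          # stationary point
  all(eigen(ll_hess, symmetric = TRUE)$values < 0)   # and a local maximum
)
```

Note that this only confirms a local maximum of the log likelihood; it says nothing about whether it is the global one.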
For many models, step (1) reduces to finding expectations of the latent variables \(\mathbf{Z}\). Often this expectation takes a particularly simple form for certain values of \(\theta\). I have found it useful to check that my function computing expectations agrees with these simplified solutions, and to write these checks up as further unit tests. Often several simpler forms exist, obtained by setting different components of \(\theta\) to 0 or \(\mathit{I}\), and this can help diagnose which elements of \(\theta\) are problematic.
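As a rough sketch of what such tests might look like, here are the special cases for a toy two-component Gaussian mixture with unit variances (an assumed example; `e_step` is a hypothetical name). The analogue of setting components of \(\theta\) to 0 here is setting the mixing weight to 0 or 1, or making the two means equal.

```r
# E-step for the toy mixture; theta = c(p, mu1, mu2), p = weight of component 1.
# Returns the probability that each observation came from component 1.
e_step <- function(x, theta) {
  a <- theta[1] * dnorm(x, theta[2], 1)
  b <- (1 - theta[1]) * dnorm(x, theta[3], 1)
  a / (a + b)
}

x <- c(-3, -0.5, 0, 1.2, 4)

# Equal means: the data carry no information about the component,
# so every responsibility should equal the prior weight p.
r_equal <- e_step(x, c(0.3, 1, 1))

# Weight of 1 (or 0): every point belongs to component 1 (or 2).
r_one <- e_step(x, c(1, -2, 2))
r_zero <- e_step(x, c(0, -2, 2))

# Equal weights and a point midway between the means: probability one half.
r_mid <- e_step(0, c(0.5, -2, 2))

stopifnot(
  isTRUE(all.equal(r_equal, rep(0.3, length(x)))),
  all(r_one == 1),
  all(r_zero == 0),
  isTRUE(all.equal(r_mid, 0.5))
)
```

If only one of these cases fails, that points to the part of the expectation code that handles the corresponding parameter.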