Parameter inference for a simple SIR model



Mathieu Besançon wrote a nice blog post on Chris’ DifferentialEquations ecosystem.

The notebook is online at .

One of the examples is a simple model tracking individuals in a population with three states $S$, $I$, $R$.

Let $u = (u_1, u_2, u_3)$ be the vector of numbers of individuals in these states.

In the SIR interpretation of the model, where $S$, $I$ and $R$ stand for susceptible, infected and recovered, the transition $S, I \mapsto 2I$ is understood as infection and $I \mapsto R$ as recovery (with subsequent immunity) of the individual.

Interpreting these transitions as continuous processes, the population can be modeled by an ordinary differential equation

$$ \frac{d}{dt} u(t) = \begin{bmatrix} -\alpha u_1(t)u_2(t) \\ \alpha u_1(t)u_2(t) - \beta u_2(t) \\ \beta u_2(t)\end{bmatrix} $$
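For intuition, here is a minimal sketch of this ODE in plain Julia, integrated with the explicit Euler method on the same step size used later ($\Delta = 0.001$). The helper names `sir_rhs` and `sir_euler` are hypothetical and not part of Bridge or DifferentialEquations.

```julia
# Right-hand side of the deterministic SIR model du/dt = f(u)
# u = [S, I, R]; α = infection rate, β = recovery rate
function sir_rhs(u::AbstractVector, α, β)
    infection = α * u[1] * u[2]
    recovery = β * u[2]
    return [-infection, infection - recovery, recovery]
end

# Explicit Euler integration on [0, T] with step h (hypothetical helper)
function sir_euler(u0, α, β; h = 0.001, T = 1.0)
    n = round(Int, T / h)
    us = Vector{Vector{Float64}}(undef, n + 1)
    us[1] = float.(u0)
    for i in 1:n
        us[i + 1] = us[i] + h * sir_rhs(us[i], α, β)
    end
    return us
end

us = sir_euler([49.0, 1.0, 0.0], 0.8, 3.0)
println(sum(us[end]))  # total population: 50 up to floating-point rounding
```

Because the three entries of the right-hand side sum to zero, each Euler step leaves $u_1 + u_2 + u_3$ unchanged up to rounding.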

In the following I show how to estimate the model parameters $\alpha$ and $\beta$ from observations of a noisy continuous version of the model.

I am using my package Bridge, StaticArrays and a few more packages.

using Bridge
using Plots
using LaTeXStrings
using StaticArrays
using LinearAlgebra  # pinv, Hermitian, cholesky, I
using Statistics     # std

Stochastic model

The dynamics of the population is given by a stochastic differential equation

$$ d X_t = b(t, X_t)\,dt + \sigma(t, X_t)\, d W_t\qquad (\star) $$

with drift vector $b$ and dispersion coefficient $\sigma$

$$ b(t, u) = \begin{bmatrix} -\alpha u_1u_2 \\ \alpha u_1u_2 - \beta u_2 \\ \beta u_2\end{bmatrix}, \qquad \sigma(t, u)= \begin{bmatrix} \sigma_1 u_1 u_2 & 0 \\ -\sigma_1u_1u_2 & -\sigma_2 u_2 \\ 0 & \sigma_2 u_2 \end{bmatrix}, $$

with $\alpha = 0.8, \beta = 3.0, \sigma_1 = 0.07, \sigma_2 = 0.4$.

Because $\sigma$ is a $3\times 2$ matrix, a two-dimensional noise process $W$ acts on all three components.

This equation models a closed system: the entries of $b$ and the columns of $\sigma$ each sum to $0$, so the total population $u_1 + u_2 + u_3$ stays constant. See Mathieu’s explanations.
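Written out with $\mathbf{1} = (1, 1, 1)^\top$, the conservation law follows in one line from $(\star)$:

```latex
\begin{aligned}
\mathbf{1}^\top b(t, u) &= -\alpha u_1 u_2 + (\alpha u_1 u_2 - \beta u_2) + \beta u_2 = 0,\\
\mathbf{1}^\top \sigma(t, u) &= \begin{bmatrix} \sigma_1 u_1 u_2 - \sigma_1 u_1 u_2 + 0 & 0 - \sigma_2 u_2 + \sigma_2 u_2 \end{bmatrix} = \begin{bmatrix} 0 & 0 \end{bmatrix},\\
d\bigl(\mathbf{1}^\top X_t\bigr) &= \mathbf{1}^\top b(t, X_t)\,dt + \mathbf{1}^\top \sigma(t, X_t)\,dW_t = 0,
\end{aligned}
```

so $X^1_t + X^2_t + X^3_t = u_1 + u_2 + u_3$ for all $t$. The same cancellation happens in every Euler-Maruyama step, so simulated paths conserve the population up to rounding.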

In Julia, using Bridge and StaticArrays, this SDE model reads as follows.

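import Bridge: b, σ, a

struct CIR <: ContinuousTimeProcess{SVector{3,Float64}}
    α::Float64   # infection rate
    β::Float64   # recovery rate
    σ1::Float64  # noise intensity of the infection transition
    σ2::Float64  # noise intensity of the recovery transition
end
Ptrue = CIR(0.8, 3.0, 0.07, 0.4)  # true parameters, used to simulate the data
P = Ptrue

b(t, u, P::CIR) = @SVector [-P.α*u[1]*u[2], P.α*u[1]*u[2] - P.β*u[2], P.β*u[2]]
σ(t, u, P::CIR) = @SMatrix Float64[
     P.σ1*u[1]*u[2]      0.0
    -P.σ1*u[1]*u[2]  -P.σ2*u[2]
                0.0   P.σ2*u[2]]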

Solving the system

The interpretation of eq. $(\star)$ is that the population follows the ordinary differential equation $du_t = b(t, u_t)\,dt$ above, but with superimposed Gaussian noise, so that over a small time step $\Delta$

$$ X_{t+\Delta} \approx X_t + b(t, X_t)\,\Delta + \sigma(t, X_t) \sqrt{\Delta}\, Z $$

where $Z$ is a two-dimensional standard normal vector. The factor $\sqrt{\Delta}$ is the noise scaling typical of stochastic differential equations. Iterating this update on a time grid is the Euler-Maruyama scheme implemented in Bridge.
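A minimal sketch of this update in plain Julia (only the `Random` standard library, no Bridge or StaticArrays), assuming the parameter values and grid used in this post. The names `sir_b`, `sir_σ`, `euler_maruyama`, `tgrid` and `path` are hypothetical, not Bridge's API:

```julia
using Random

# Drift b(t, u) and dispersion σ(t, u) of (⋆); θ = (α, β, σ1, σ2)
sir_b(u, θ) = [-θ[1]*u[1]*u[2], θ[1]*u[1]*u[2] - θ[2]*u[2], θ[2]*u[2]]
sir_σ(u, θ) = [ θ[3]*u[1]*u[2]   0.0
               -θ[3]*u[1]*u[2]  -θ[4]*u[2]
                0.0              θ[4]*u[2]]

# Euler-Maruyama: X[i+1] = X[i] + b(X[i]) Δ + σ(X[i]) √Δ Z,  Z ~ N(0, I₂)
function euler_maruyama(u0, θ, tgrid; rng = Random.default_rng())
    X = [float.(u0)]
    for i in 1:length(tgrid)-1
        Δ = tgrid[i+1] - tgrid[i]
        u = X[end]
        push!(X, u + sir_b(u, θ)*Δ + sir_σ(u, θ)*(sqrt(Δ)*randn(rng, 2)))
    end
    return X
end

tgrid = 0.0:0.001:1.0
path = euler_maruyama([49.0, 1.0, 0.0], (0.8, 3.0, 0.07, 0.4), tgrid; rng = MersenneTwister(1))
```

Unlike Bridge's in-place `solve!`, which works on a pre-sampled Wiener path `W` with static arrays, this version draws the Gaussian increments on the fly and allocates a new vector at every step.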

The following lines set the initial state $u_0$ and the time grid, then simulate and plot a two-dimensional Brownian motion (Wiener process) $W$ with Bridge.

u₀ = @SVector [49.0, 1.0, 0.0]
tt = 0.0:0.001:1.0

W = sample(tt, Wiener{SVector{2,Float64}}())

plot(W.tt, first.(W.yy), label=L"W_1(t)")
plot!(W.tt, last.(W.yy), label=L"W_2(t)")


[Plot: sample paths of $W_1$ and $W_2$ against $t \in [0, 1]$]

Solve the SDE starting in $u_0$ on the grid $t_i = 0.001i$ using the Euler-Maruyama approximation.

X = solve(Bridge.EulerMaruyama(), u₀, W, P);
plot(X.tt, Bridge.mat(X.yy)', label=[L"X_1" L"X_2" L"X_3"])

[Plot: components $X_1$, $X_2$, $X_3$ of the solution against $t$]

@time solve!(Bridge.EulerMaruyama(), X, u₀, W, P);
  0.000138 seconds (4 allocations: 160 bytes)

*Quite fast. Try to beat that, or ask Chris to appreciate it. ;-)*

Parameter inference

Assume that we observe the path shown in the picture and want to know the intensities $\alpha$ and $\beta$ of the transitions, and also the noise parameters $\sigma_1$, $\sigma_2$. In the Bayesian paradigm, we place Gaussian priors on $\alpha$ and $\beta$.

A nice feature of the model is that $\alpha$ and $\beta$ enter the drift $b$ linearly, as multiplicative constants. Therefore, given $\sigma_1$ and $\sigma_2$, the Gaussian prior is conjugate: the posterior of the drift parameters is again Gaussian.
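To see the conjugacy, write $b(t, u) = \phi(u)\,\theta$ with $\theta = (\alpha, \beta)^\top$, where $\phi(u)$ is the $3\times 2$ matrix returned by `paramgrad` (the intercept is zero for this model). Under the Euler approximation, the increment $X_{t_{i+1}} - X_{t_i}$ is approximately $N\bigl(\phi(X_{t_i})\theta\,\Delta_i,\ a(X_{t_i})\,\Delta_i\bigr)$ with $a = \sigma\sigma^\top$ and $\Delta_i = t_{i+1} - t_i$. Since $a$ has rank 2, its pseudo-inverse $a^{+}$ takes the place of the inverse (the code uses `pinv`). A sketch of the resulting posterior for a prior $\theta \sim N(0, \Xi^{-1})$ with precision matrix $\Xi$:

```latex
\begin{aligned}
\log p(X \mid \theta) &= \theta^\top \mu - \tfrac{1}{2}\,\theta^\top G\,\theta + \text{const},\\
\mu &= \sum_i \phi(X_{t_i})^\top a^{+}(X_{t_i})\,\bigl(X_{t_{i+1}} - X_{t_i}\bigr),
\qquad
G = \sum_i \phi(X_{t_i})^\top a^{+}(X_{t_i})\,\phi(X_{t_i})\,\Delta_i,\\
\theta \mid X &\sim N\bigl((G + \Xi)^{-1}\mu,\ (G + \Xi)^{-1}\bigr).
\end{aligned}
```

For this particular model the columns of $\phi$ are multiples of the columns of $\sigma$: $\phi(u) = \sigma(t, u)\,\mathrm{diag}(-1/\sigma_1,\ 1/\sigma_2)$. When $u_1 u_2 \neq 0$, $\sigma$ has full column rank, so $\sigma^\top a^{+}\sigma = I_2$ and each term of $G$ equals $\mathrm{diag}(\sigma_1^{-2}, \sigma_2^{-2})\,\Delta_i$. On $[0, 1]$ and with the weak prior $\Xi = 0.1\,I$ used below, the posterior standard deviations are therefore roughly $\sigma_1 = 0.07$ for $\alpha$ and $\sigma_2 = 0.4$ for $\beta$.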

The following function samples from the posterior of these parameters. It is written to be generally useful: if your drift is, say,

b(t, u, ::Example) = m*u + c

with parameter `m` and intercept `c`, then define

paramgrad(t, u, ::Example) = u
paramintercept(t, u, ::Example) = c

For our model, the intercept is zero:
paramgrad(t, u, ::CIR) = @SMatrix [-u[1]*u[2] 0.0; u[1]*u[2] -u[2]; 0.0 u[2]]
paramintercept(t, u, ::CIR) = @SVector [0.0, 0.0, 0.0]

"""
    conjugate_posterior(Y, Ξ, P)

Sample the posterior distribution of the conjugate drift parameters from path `Y`,
prior precision matrix `Ξ` under model `P` with non-conjugate parameters in `P` fixed.
"""
function conjugate_posterior(Y, Ξ, P)
    t, y = Y.tt[1], Y.yy[1]
    ϕ = paramgrad(t, y, P)
    mu = zero(ϕ[1, :])   # accumulates Σ ϕ' a⁺ (dy - intercept*ds)
    G = zero(mu*mu')     # accumulates Σ ϕ' a⁺ ϕ ds

    for i in 1:length(Y)-1
        ϕ = paramgrad(t, y, P)
        # a = σσ' has rank 2, so use the pseudo-inverse. Todo: smoothing like this is very slow
        s = σ(t, y, P)
        Gϕ = SMatrix{3,2,Float64}(pinv(Matrix(s*s'))*ϕ)
        zi = ϕ'*Gϕ
        t2, y2 = Y.tt[i + 1], Y.yy[i + 1]
        dy = y2 - y
        ds = t2 - t
        mu = mu + Gϕ'*(dy - paramintercept(t, y, P)*ds)
        t, y = t2, y2
        G = G + zi*ds
    end
    WW = G + Ξ  # posterior precision

    # draw θ ~ N(WW⁻¹ mu, WW⁻¹) using the Cholesky factor WW = WL*WL'
    WL = cholesky(Hermitian(WW)).L
    th° = WL'\(randn(typeof(mu)) + WL\mu)
    return th°
end

Let us visualise that.

Π = [conjugate_posterior(X, 0.1*I, P) for i in 1:1000]
scatter(first.(Π), last.(Π), color="blue", markersize = 0.3, label="posterior samples")
scatter!([Ptrue.α], [Ptrue.β], color="red", label="truth")

[Scatter plot: posterior samples of $(\alpha, \beta)$ (blue) and the true value (red)]
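The last lines of `conjugate_posterior` draw from $N(W^{-1}\mu, W^{-1})$, where $W$ is the posterior precision `WW`, without ever forming $W^{-1}$: if $W = LL^\top$ and $z \sim N(0, I)$, then $\theta = L^{-\top}(z + L^{-1}\mu)$ has mean $L^{-\top}L^{-1}\mu = W^{-1}\mu$ and covariance $L^{-\top}L^{-1} = W^{-1}$. A quick Monte Carlo check in plain Julia; the matrix `Wprec`, the vector `μvec` and the helper `sample_precision_gaussian` are made up for illustration:

```julia
using LinearAlgebra, Random, Statistics

# Draw θ ~ N(W⁻¹μ, W⁻¹) from a precision matrix W, as in conjugate_posterior
function sample_precision_gaussian(rng, W, μ)
    L = cholesky(Hermitian(W)).L          # W = L*L'
    return L' \ (randn(rng, length(μ)) + L \ μ)
end

Wprec = [4.0 1.0; 1.0 3.0]   # made-up precision matrix
μvec = [1.0, 2.0]            # made-up linear term
rng = MersenneTwister(42)
θs = reduce(hcat, [sample_precision_gaussian(rng, Wprec, μvec) for _ in 1:100_000])

println(vec(mean(θs; dims = 2)), " vs ", Wprec \ μvec)   # sample mean vs W⁻¹μ
println(cov(θs; dims = 2), " vs ", inv(Wprec))          # sample covariance vs W⁻¹
```

Solving two triangular systems is cheaper and numerically more stable than inverting $W$ explicitly.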

Estimating $\sigma_1$, $\sigma_2$

Let’s now find estimators for the other two parameters. Let’s be pragmatic and just derive estimation equations assuming the drift is known.

From equation $(\star)$ and

$$ \sigma(t, u) = \begin{bmatrix} \sigma_1 u_1 u_2 & 0 \\ -\sigma_1u_1u_2 & -\sigma_2 u_2 \\ 0 & \sigma_2 u_2 \end{bmatrix} $$

we find

$$ \frac{M^1_t}{X_t^1 X_t^2 \sqrt{\Delta t}} \sim N(0, \sigma_1^2) \quad\frac{M^3_t}{ X_t^2 \sqrt{\Delta t}} \sim N(0, \sigma_2^2) $$

where $X_t^1$, $X_t^2$, $X_t^3$ denote the first, second and third components of $X$, and

$$ M_t = X_{t+\Delta t} - X_t - b(t, X_t)\,\Delta t $$
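A stand-alone sketch of these estimating equations in plain Julia, on a path simulated with the Euler-Maruyama update described earlier. As in the text, the drift is assumed known and evaluated at the left end point of each step; the helper names, variable names and the seed are illustrative assumptions:

```julia
using Random, Statistics

# SIR drift and dispersion with θ = (α, β, σ1, σ2), as in (⋆)
sir_b(u, θ) = [-θ[1]*u[1]*u[2], θ[1]*u[1]*u[2] - θ[2]*u[2], θ[2]*u[2]]
sir_σ(u, θ) = [ θ[3]*u[1]*u[2]   0.0
               -θ[3]*u[1]*u[2]  -θ[4]*u[2]
                0.0              θ[4]*u[2]]

# Simulate a path with Euler-Maruyama (illustrative seed)
θtrue = (0.8, 3.0, 0.07, 0.4)
tgrid = 0.0:0.001:1.0
rng = MersenneTwister(2)
Xsim = [[49.0, 1.0, 0.0]]
for i in 1:length(tgrid)-1
    Δ = tgrid[i+1] - tgrid[i]
    push!(Xsim, Xsim[i] + sir_b(Xsim[i], θtrue)*Δ + sir_σ(Xsim[i], θtrue)*(sqrt(Δ)*randn(rng, 2)))
end

# Residuals M_i = X_{i+1} - X_i - b(X_i) Δ_i with the drift assumed known
Δs = diff(collect(tgrid))
Msim = [Xsim[i+1] - Xsim[i] - sir_b(Xsim[i], θtrue)*Δs[i] for i in eachindex(Δs)]

# Estimating equations: M¹/(X¹X²√Δ) ~ N(0, σ1²) and M³/(X²√Δ) ~ N(0, σ2²)
σ1hat = std([Msim[i][1] / (Xsim[i][1]*Xsim[i][2]*sqrt(Δs[i])) for i in eachindex(Msim)])
σ2hat = std([Msim[i][3] / (Xsim[i][2]*sqrt(Δs[i])) for i in eachindex(Msim)])
println((σ1hat, σ2hat))
```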

second(x) = getindex(x, 2)
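# residuals: observed increment minus the drift increment over each step
M = [X.yy[i]-X.yy[i-1]-b(X.tt[i], X.yy[i], P)*(X.tt[i]-X.tt[i-1]) for i in 2:length(X)]

σ1est = std(first.(M)./sqrt.(diff(X.tt))./first.(X.yy[1:end-1])./second.(X.yy[1:end-1]))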
@show Ptrue.σ1
@show σ1est;
Ptrue.σ1 = 0.07
σ1est = 0.07015579020497378
σ2est = std(last.(M)./sqrt.(diff(X.tt))./second.(X.yy[1:end-1]))
@show Ptrue.σ2
@show σ2est;
Ptrue.σ2 = 0.4
σ2est = 0.39761699891570274

MCMC for the parameters

One of the nice features of Julia is that direct implementations of MCMC algorithms are performant and quite easy to write.

Because we can sample the drift parameters from their posterior, but have only derived estimating equations for $\sigma_1$, $\sigma_2$, this is a sampler with one Gibbs step and one EM-type step.
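The two-step scheme can also be sketched without Bridge or StaticArrays. This stand-alone version uses only the Julia standard library; all helper names, the seed and the shorter run of 100 iterations are illustrative assumptions. `draw_drift` is a stripped-down, `Matrix`-based analogue of `conjugate_posterior` with zero intercept, and `estimate_noise` applies the estimating equations above:

```julia
using LinearAlgebra, Random, Statistics

# Stand-alone model definitions; θ = (α, β, σ1, σ2)
sir_b(u, θ) = [-θ[1]*u[1]*u[2], θ[1]*u[1]*u[2] - θ[2]*u[2], θ[2]*u[2]]
sir_σ(u, θ) = [ θ[3]*u[1]*u[2]   0.0
               -θ[3]*u[1]*u[2]  -θ[4]*u[2]
                0.0              θ[4]*u[2]]
sir_ϕ(u) = [-u[1]*u[2] 0.0; u[1]*u[2] -u[2]; 0.0 u[2]]  # b(u, θ) == sir_ϕ(u) * [θ[1], θ[2]]

function simulate_sir(rng, u0, θ, tgrid)
    X = [float.(u0)]
    for i in 1:length(tgrid)-1
        Δ = tgrid[i+1] - tgrid[i]
        push!(X, X[i] + sir_b(X[i], θ)*Δ + sir_σ(X[i], θ)*(sqrt(Δ)*randn(rng, 2)))
    end
    return X
end

# Gibbs step: draw (α, β) from N((G + Ξ)⁻¹μ, (G + Ξ)⁻¹) with σ1, σ2 held fixed
function draw_drift(rng, X, tgrid, θ, Ξ)
    μ, G = zeros(2), zeros(2, 2)
    for i in 1:length(tgrid)-1
        Δ = tgrid[i+1] - tgrid[i]
        ϕ = sir_ϕ(X[i])
        s = sir_σ(X[i], θ)
        aϕ = pinv(s*s') * ϕ          # a⁺ϕ with a = σσ' (rank 2)
        μ += aϕ' * (X[i+1] - X[i])
        G += ϕ' * aϕ * Δ
    end
    L = cholesky(Hermitian(G + Ξ)).L
    return L' \ (randn(rng, 2) + L \ μ)
end

# EM-type step: estimating equations for σ1, σ2 given the current (α, β)
function estimate_noise(X, tgrid, θ)
    n = length(tgrid) - 1
    Δ = diff(collect(tgrid))
    M = [X[i+1] - X[i] - sir_b(X[i], θ)*Δ[i] for i in 1:n]
    s1 = std([M[i][1] / (X[i][1]*X[i][2]*sqrt(Δ[i])) for i in 1:n])
    s2 = std([M[i][3] / (X[i][2]*sqrt(Δ[i])) for i in 1:n])
    return s1, s2
end

rng = MersenneTwister(3)
tgrid = 0.0:0.001:1.0
Xsim = simulate_sir(rng, [49.0, 1.0, 0.0], (0.8, 3.0, 0.07, 0.4), tgrid)

n_iter = 100
chain = zeros(4, n_iter)
θcur = (2.1, 5.0, 0.17, 0.1)      # same starting values as in the loop below
for iter in 1:n_iter
    global θcur
    α, β = draw_drift(rng, Xsim, tgrid, θcur, 0.1*I(2))
    σ1, σ2 = estimate_noise(Xsim, tgrid, (α, β, θcur[3], θcur[4]))
    θcur = (α, β, σ1, σ2)
    chain[:, iter] = [α, β, σ1, σ2]
end
```

The chain is kept short only so the example runs quickly; the loop below runs 1000 iterations.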

iterations = 1000
P = CIR(2.1, 5.0, 0.17, 0.1)  # deliberately poor starting values
params = zeros(4, iterations)
M = zeros(SVector{3,Float64}, length(tt)-1)
for iter in 1:iterations
    global P  # P is a global that is reassigned inside the loop
    # step 1: sample conjugate parameters α and β conditional on σ1, σ2
    α, β = conjugate_posterior(X, 0.1*I, P)
    P = CIR(α, β, P.σ1, P.σ2)
    # step 2: estimate σ1, σ2 from the residuals given α, β
    for i in 2:length(X)
        M[i-1] = X.yy[i]-X.yy[i-1]-b(X.tt[i], X.yy[i], P)*(X.tt[i]-X.tt[i-1])
    end
    σ1 = std(first.(M)./sqrt.(diff(tt))./first.(X.yy[1:end-1])./second.(X.yy[1:end-1]))
    σ2 = std(last.(M)./sqrt.(diff(tt))./second.(X.yy[1:end-1]))
    P = CIR(P.α, P.β, σ1, σ2)
    # record sample
    params[:, iter] = [α, β, σ1, σ2]
end

p1 = plot(params[1, :], linewidth=0.2, label = "posterior samples")
plot!(Ptrue.α*ones(iterations), label = LaTeXString("true \\alpha"))
plot!(Bridge.runmean(params[1, :]), label = "running mean")
p2 = plot(params[2, :], linewidth=0.2, label = "posterior samples")
plot!(Ptrue.β*ones(iterations), label = LaTeXString("true \\beta"))
plot!(Bridge.runmean(params[2, :]), label = "running mean")
plot(p1, p2)

[Plots: posterior samples of $\alpha$ (left) and $\beta$ (right) per iteration, with the true value and the running mean]

p1 = plot(params[3, :], linewidth=0.2, label = "EM samples")
plot!(Ptrue.σ1*ones(iterations), label = LaTeXString("true \\sigma_1"))
plot!(Bridge.runmean(params[3, :]), label = "running mean")
p2 = plot(params[4, :], linewidth=0.2, label = "EM samples")
plot!(Ptrue.σ2*ones(iterations), label = LaTeXString("true \\sigma_2"))
plot!(Bridge.runmean(params[4, :]), label = "running mean")
plot(p1, p2)

[Plots: EM samples of $\sigma_1$ (left) and $\sigma_2$ (right) per iteration, with the true value and the running mean]