Parameter inference for a simple SIR model

Introduction

Mathieu Besançon made a nice blog post on Chris’ DifferentialEquations ecosystem, https://mbesancon.github.io/post/2017-12-14-diffeq-julia/ .

The notebook is online at https://gist.github.com/mschauer/9da19ea563d072cb683aba80ab61efb4 .

One of the examples is a simple model tracking individuals in a population with three states $S$, $I$, $R$.

Let $u = (u_1, u_2, u_3)$ be the vector of the numbers of individuals in these states.

• At any moment an individual of state $I$ can convince an individual of state $S$ to join the fraction $I$: $$S,I \mapsto 2I$$ This reaction happens at a rate proportional to the number of possible pairs $S, I$ which can match up, so proportional to $u_1 u_2$ with proportionality constant $\alpha$.

• At any moment an individual of state $I$ can attain state $R$: $$I \mapsto R$$ This reaction happens at a rate proportional to the number of individuals in state $I$, so proportional to $u_2$ with proportionality constant $\beta$.

In the SIR interpretation of the model, where $S$, $I$ and $R$ stand for susceptible, infected, and recovered, the transition $S,I \mapsto 2I$ is understood as infection and $I \mapsto R$ as recovery (with subsequent immunity) of the individual.

Interpreting these transitions as continuous processes, the population can be modeled by an ordinary differential equation

$$\frac{d}{dt} u(t) = \begin{bmatrix} -\alpha u_1(t)u_2(t) \\ \alpha u_1(t)u_2(t) - \beta u_2(t) \\ \beta u_2(t)\end{bmatrix}$$
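To get a feeling for the deterministic model before adding noise, here is a minimal sketch in plain Julia that integrates the ODE with a forward Euler step. The function names `sir_drift` and `euler_ode` are mine and hypothetical; the numbers ($\alpha = 0.8$, $\beta = 3.0$, start in $(49, 1, 0)$, step $0.001$) are the ones used further down in this post.

```julia
# Right-hand side of the SIR ODE; α and β are the two rates from the text.
sir_drift(u, α, β) = [-α*u[1]*u[2], α*u[1]*u[2] - β*u[2], β*u[2]]

# Forward Euler: u(t + h) ≈ u(t) + h * sir_drift(u(t)).
function euler_ode(u0, α, β, h, nsteps)
    path = [float.(u0)]
    for _ in 1:nsteps
        u = path[end]
        push!(path, u + h*sir_drift(u, α, β))
    end
    return path
end

path = euler_ode([49.0, 1.0, 0.0], 0.8, 3.0, 0.001, 1000)
println("final state:      ", path[end])
println("total population: ", sum(path[end]))
```

The three entries of the right-hand side cancel, so the total population stays at 50 up to rounding.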

In the following I show how to estimate the model parameters $\alpha$ and $\beta$ from observations of a noisy continuous version of the model.

I am using my package Bridge, StaticArrays and a few more packages.

using Bridge
using Plots
using LaTeXStrings
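using StaticArrays
using LinearAlgebra # I, pinv, Hermitian (standard library since Julia 0.7)
using Statistics    # std (standard library since Julia 0.7)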

Stochastic model

The dynamics of the population is given by a stochastic differential equation

$$d X_t = b(t, X_t)\,dt + \sigma(t, X_t)\, d W_t\qquad (\star)$$

with drift vector $b$ and dispersion coefficient $\sigma$

$$b(t, u) = \begin{bmatrix} -\alpha u_1u_2 \\ \alpha u_1u_2 - \beta u_2 \\ \beta u_2\end{bmatrix} \qquad \sigma(t, u)= \begin{bmatrix} \sigma_1 u_1 u_2 & 0 \\ -\sigma_1u_1u_2& -\sigma_2 u_2 \\ 0& \sigma_2 u_2 \end{bmatrix}.$$

with $\alpha = 0.8, \beta = 3.0, \sigma_1 = 0.07, \sigma_2 = 0.4$.

Because $\sigma$ is a $3 \times 2$ matrix, there is a two-dimensional noise process $W$ acting on all three components.

This equation models a closed system: because the entries of $b$ and the entries of each column of $\sigma$ add up to $0$, the population $u_1 + u_2 + u_3$ stays constant. See Mathieu’s explanations.
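Spelled out (my own one-line check, with $\mathbf{1} = (1, 1, 1)^\top$ introduced here only for the summation), the total population has zero increment:

$$d\,(X^1_t + X^2_t + X^3_t) = \mathbf{1}^\top b(t, X_t)\,dt + \mathbf{1}^\top \sigma(t, X_t)\,dW_t = 0,$$

since

$$\mathbf{1}^\top b(t,u) = -\alpha u_1u_2 + \alpha u_1u_2 - \beta u_2 + \beta u_2 = 0, \qquad \mathbf{1}^\top \sigma(t,u) = \begin{bmatrix} \sigma_1u_1u_2 - \sigma_1u_1u_2 & -\sigma_2u_2 + \sigma_2u_2\end{bmatrix} = \begin{bmatrix} 0 & 0 \end{bmatrix}.$$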

In Julia, using Bridge and StaticArrays, this SDE model reads as follows.

import Bridge: b, σ, a

struct CIR <: ContinuousTimeProcess{SVector{3,Float64}}
α::Float64
β::Float64
σ1::Float64
σ2::Float64
end
Ptrue = CIR(0.8, 3.0, 0.07, 0.4) # true parameters, used below for comparison
P = Ptrue

b(t, u, P::CIR) = @SVector [-P.α*u[1]*u[2], P.α*u[1]*u[2] - P.β*u[2], P.β*u[2]]
σ(t, u, P::CIR) = @SMatrix Float64[
P.σ1*u[1]*u[2]      0.0
-P.σ1*u[1]*u[2]  -P.σ2*u[2]
0.0   P.σ2*u[2]
]
σ (generic function with 20 methods)


Solving the system

The interpretation of eq. $(\star)$ is that the population follows the ordinary differential equation $$d u_t = b(t, u_t)\, dt$$ above, but with superimposed Gaussian noise such that

$$X_{t+\Delta} \approx X_t + b(t, X_t) \Delta + \sigma(t, X_t) \sqrt{\Delta} Z$$

where $Z$ is a two-dimensional standard normal vector. $\sqrt{\Delta}$ is the scaling of the noise typical for stochastic differential equations. This is the Euler-Maruyama scheme implemented in Bridge.
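If you want to see the scheme without any package, here is a minimal sketch with plain arrays and only the standard library. This is not how Bridge implements it; the names `drift`, `dispersion` and `euler_maruyama` are hypothetical, and the seed is arbitrary. The parameter values, start value and time grid are the ones used in this post.

```julia
using Random

# Drift b and dispersion σ of the stochastic SIR model, with plain arrays.
drift(u, α, β) = [-α*u[1]*u[2], α*u[1]*u[2] - β*u[2], β*u[2]]
dispersion(u, σ1, σ2) = [σ1*u[1]*u[2] 0.0; -σ1*u[1]*u[2] -σ2*u[2]; 0.0 σ2*u[2]]

# Euler-Maruyama: X[i+1] = X[i] + b(X[i])*Δ + σ(X[i])*sqrt(Δ)*Z with Z ~ N(0, I₂).
function euler_maruyama(u0, tt, α, β, σ1, σ2; rng = Random.default_rng())
    path = [float.(u0)]
    for i in 1:length(tt)-1
        Δ = tt[i+1] - tt[i]
        Z = randn(rng, 2)   # two-dimensional standard normal vector
        u = path[end]
        push!(path, u + drift(u, α, β)*Δ + dispersion(u, σ1, σ2)*(sqrt(Δ)*Z))
    end
    return path
end

sde_path = euler_maruyama([49.0, 1.0, 0.0], 0.0:0.001:1.0, 0.8, 3.0, 0.07, 0.4;
                          rng = MersenneTwister(1))
println("total population at the end: ", sum(sde_path[end]))
```

Each step adds a drift term of order $\Delta$ and a noise term of order $\sqrt{\Delta}$, and the total population is again preserved up to rounding.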

The following lines simulate and plot a 2d Brownian motion (Wiener process) $W$ with Bridge.

u₀ = @SVector [49.0, 1.0, 0.0]
tt = 0.0:0.001:1.0

W = sample(tt, Wiener{SVector{2,Float64}}())

plot(W.tt, first.(W.yy), label=L"W_1(t)")
plot!(W.tt, last.(W.yy), label=L"W_2(t)")

xlabel!("t")

Solve the SDE starting in $u_0$ on the grid $t_i = 0.001i$ using the Euler-Maruyama approximation.

X = solve(Bridge.EulerMaruyama(), u₀, W, P);
plot(X.tt, Bridge.mat(X.yy)', label=[L"X_1",L"X_2",L"X_3"])
xlabel!("t")

@time solve!(Bridge.EulerMaruyama(), X, u₀, W, P);
  0.000138 seconds (4 allocations: 160 bytes)


*Quite fast. Try to beat that or ask Chris to appreciate. ;-)*

Parameter inference

Assume that we observe the path shown in the picture and want to know the intensities $\alpha$ and $\beta$ of the transitions and also the noise parameters $\sigma_1$, $\sigma_2$. In the Bayesian paradigm, we assume that there are Gaussian priors on $\alpha$ and $\beta$.

A nice feature of the model is that $\alpha$ and $\beta$ enter as multiplicative constants into the drift $b$. Therefore the posterior of the drift parameters given $\sigma_1$ and $\sigma_2$ is conjugate Gaussian.
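A sketch of why this works, in my own notation: $\varphi$, $c$, $G$ and $W$ are introduced here and correspond to `paramgrad`, `paramintercept`, `G` and `WW` in the code of this section. The zero prior mean is my reading of that code, where no prior mean term appears. Write the drift as $b(t,u) = \varphi(t,u)\,\theta + c(t,u)$ with $\theta = (\alpha, \beta)^\top$ (for the SIR model $c = 0$), put the prior $\theta \sim N(0, \Xi^{-1})$ and let $a = \sigma\sigma^\top$. The log-likelihood of the path is then quadratic in $\theta$, and with

$$\mu = \int_0^T \varphi(t,X_t)^\top a(t,X_t)^{+}\,\bigl(dX_t - c(t,X_t)\,dt\bigr), \qquad G = \int_0^T \varphi(t,X_t)^\top a(t,X_t)^{+}\,\varphi(t,X_t)\,dt$$

the posterior is

$$\theta \mid X \sim N\bigl(W^{-1}\mu,\; W^{-1}\bigr), \qquad W = G + \Xi.$$

Here $a^{+}$ is the pseudoinverse, which is needed because the $3 \times 3$ matrix $a$ has rank at most $2$ in this model.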

The following function samples from the posterior of these parameters. It is written to be generally useful. If your drift is, say,

b(t, u, ::Example) = m*u + c


with conjugate parameter m and known intercept c, then define

paramgrad(t, u, ::Example) = u
paramintercept(t, u, ::Example) = c

paramgrad(t, u, ::CIR) = @SMatrix [-u[1]*u[2] 0.; u[1]*u[2] -u[2]; 0. u[2]]
paramintercept(t, u, ::CIR) = @SVector [0.0, 0.0, 0.0]
"""
conjugate_posterior(Y, Ξ, P)

Sample the posterior distribution of the conjugate drift parameters from path Y,
prior precision matrix Ξ under model P with non-conjugate parameters in P fixed.
"""
function conjugate_posterior(Y, Ξ, P)
    t, y = Y.tt[1], Y.yy[1]
    ϕ = paramgrad(t, y, P)
    mu = zero(ϕ[1, :])  # accumulates the linear term of the log-likelihood
    G = zero(mu*mu')    # accumulates the quadratic term (precision contribution)

    for i in 1:length(Y)-1
        ϕ = paramgrad(t, y, P)
        Gϕ = SMatrix{3,2,Float64}(pinv(Matrix(a(t, y, P)))*ϕ) # a is sigma*sigma'. Todo: smoothing like this is very slow
        zi = ϕ'*Gϕ
        t2, y2 = Y.tt[i + 1], Y.yy[i + 1]
        dy = y2 - y
        ds = t2 - t
        mu = mu + Gϕ'*(dy - paramintercept(t, y, P)*ds)
        t, y = t2, y2
        G = G + zi*ds
    end
    WW = G + Ξ  # posterior precision

    # lower Cholesky factor, WW = WL*WL' (`chol` was removed in Julia 1.0)
    WL = cholesky(Hermitian(WW)).L
    th° = WL'\(randn(typeof(mu)) + WL\mu)  # one draw from N(WW⁻¹mu, WW⁻¹)
end
conjugate_posterior
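The last two lines of the function draw from a Gaussian given its precision matrix rather than its covariance. As a small check of that trick, here is a standalone sketch with a made-up $2 \times 2$ precision matrix `W` and vector `μ` (both hypothetical, standing in for `WW` and `mu`), using only the standard library.

```julia
using LinearAlgebra, Random

# Hypothetical precision matrix and linear term, standing in for WW and mu.
W = [4.0 1.0; 1.0 3.0]
μ = [2.0, 1.0]

Lfac = cholesky(Symmetric(W)).L        # lower factor, W = Lfac*Lfac'
draw(z) = Lfac' \ (z + Lfac \ μ)       # maps a standard normal z to one draw

# With z = 0 the draw equals the posterior mean W⁻¹μ.
θmean = draw(zeros(2))

# The linear part z ↦ Lfac'⁻¹ z has covariance Lfac'⁻¹ Lfac⁻¹ = W⁻¹.
A = inv(Matrix(Lfac'))
Σ = A*A'

θ = draw(randn(MersenneTwister(42), 2))
println("mean: ", θmean, "  one draw: ", θ)
```

So the draws have mean $W^{-1}\mu$ and covariance $W^{-1}$ without ever forming the inverse of $W$ for the sampling itself.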


Let us visualise that.

Π = [conjugate_posterior(X, 0.1*I, P) for i in 1:1000]
scatter(first.(Π), last.(Π), color="blue", markersize = 0.3, label="posterior samples")
scatter!([Ptrue.α], [Ptrue.β], color="red", label="truth")
ylabel!(LaTeXString("\\beta"))
xlabel!(LaTeXString("\\alpha"))

Estimating $\sigma_1$, $\sigma_2$

Let’s now find estimators for the other two parameters. Let’s be pragmatic and just derive estimation equations assuming the drift is known.

From equation $(\star)$ and

$$\sigma(t, u) = \begin{bmatrix} \sigma_1 u_1 u_2 & 0 \\ -\sigma_1u_1u_2& -\sigma_2 u_2 \\ 0& \sigma_2 u_2 \end{bmatrix}$$

we find

$$\frac{M^1_t}{X_t^1 X_t^2 \sqrt{\Delta t}} \sim N(0, \sigma_1^2) \quad\frac{M^3_t}{ X_t^2 \sqrt{\Delta t}} \sim N(0, \sigma_2^2)$$

where $X_t^1$, $X_t^2$, etc. denote the first, second, etc. component of $X$ and

$$M_t = X_{t+\Delta t} - X_t - b(t, X_t)\, \Delta t$$
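The step from $(\star)$ to these two normal laws, spelled out (my own short derivation, using only the Euler-Maruyama approximation): over a short time step the residual increment is approximately $\sigma(t, X_t)\sqrt{\Delta t}\, Z$ with $Z = (Z_1, Z_2)$ standard normal, so the first and third rows of $\sigma$ give

$$M^1_t \approx \sigma_1 X^1_t X^2_t \sqrt{\Delta t}\; Z_1, \qquad M^3_t \approx \sigma_2 X^2_t \sqrt{\Delta t}\; Z_2 .$$

Dividing by the known factors leaves $\sigma_1 Z_1$ and $\sigma_2 Z_2$, whose standard deviations are $\sigma_1$ and $\sigma_2$. The second row mixes both noise components, so it is not used.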

second(x) = getindex(x, 2)
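# residual increments M_t, with the drift evaluated at the left endpoint as in the formula above
M = [X.yy[i]-X.yy[i-1]-b(X.tt[i-1], X.yy[i-1], P)*(X.tt[i]-X.tt[i-1]) for i in 2:length(X)]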

σ1est = std(first.(M)./sqrt.(diff(X.tt))./first.(X.yy[1:end-1])./second.(X.yy[1:end-1]))
@show Ptrue.σ1
@show σ1est;
Ptrue.σ1 = 0.07
σ1est = 0.07015579020497378

σ2est = std(last.(M)./sqrt.(diff(X.tt))./second.(X.yy[1:end-1]))
@show Ptrue.σ2
@show σ2est;
Ptrue.σ2 = 0.4
σ2est = 0.39761699891570274


MCMC for the parameters

One of the nice features of Julia is that direct implementations of MCMC algorithms are performant and quite easy to deal with.

Because we can sample the drift parameters from the posterior, but only derived estimation equations for $\sigma_1$, $\sigma_2$, this is a sampler with one Gibbs step and one EM-type step.

iterations = 1000
P = CIR(2.1, 5.0, 0.17, 0.1) # starting values away from the truth
params = zeros(4, iterations)
M = zeros(SVector{3,Float64}, length(tt)-1)
for iter in 1:iterations

    # step 1: sample conjugate parameter α and β conditional on σ1, σ2

    α, β = conjugate_posterior(X, 0.1*I, P)

    P = CIR(α, β, P.σ1, P.σ2)

    # step 2: estimate σ1, σ2 from the residual increments (drift at the left endpoint)

    for i in 2:length(X)
        M[i-1] = X.yy[i]-X.yy[i-1]-b(X.tt[i-1], X.yy[i-1], P)*(X.tt[i]-X.tt[i-1])
    end

    σ1 = std(first.(M)./sqrt.(diff(X.tt))./first.(X.yy[1:end-1])./second.(X.yy[1:end-1]))
    σ2 = std(last.(M)./sqrt.(diff(X.tt))./second.(X.yy[1:end-1]))

    P = CIR(P.α, P.β, σ1, σ2)

    # record sample

    params[:, iter] = [α, β, σ1, σ2]
end
p1 = plot(params[1, :], linewidth=0.2, label = "posterior samples")
plot!(Ptrue.α*ones(iterations), label = LaTeXString("true \\alpha"))
plot!(Bridge.runmean(params[1, :]), label = "running mean")
p2 = plot(params[2, :], linewidth=0.2, label = "posterior samples")
plot!(Ptrue.β*ones(iterations), label = LaTeXString("true \\beta"))
plot!(Bridge.runmean(params[2, :]), label = "running mean")
plot(p1,p2)

p1 = plot(params[3, :], linewidth=0.2, label = "EM samples")
plot!(Ptrue.σ1*ones(iterations), label = LaTeXString("true \\sigma_1"))
plot!(Bridge.runmean(params[3, :]), label = "running mean")
p2 = plot(params[4, :], linewidth=0.2, label = "EM samples")
plot!(Ptrue.σ2*ones(iterations), label = LaTeXString("true \\sigma_2"))
plot!(Bridge.runmean(params[4, :]), label = "running mean")
plot(p1,p2)
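To see the alternation of the two steps in isolation, here is a toy version of the same sampler for a one-dimensional model. Everything in it is an assumption made for illustration: the model $dX = -\theta X\,dt + \sigma\,dW$, the names `simulate_toy` and `toy_sampler`, the prior precision `ξ = 0.1`, the starting value `σ0` and the seed. It uses only the standard library.

```julia
using Random, Statistics

# Toy model, assumed only for this illustration: dX = -θ X dt + σ dW,
# simulated with the Euler-Maruyama scheme.
function simulate_toy(θ, σ, Δ, n, rng)
    x = zeros(n + 1)
    x[1] = 1.0
    for i in 1:n
        x[i+1] = x[i] - θ*x[i]*Δ + σ*sqrt(Δ)*randn(rng)
    end
    return x
end

function toy_sampler(x, Δ, iterations, rng; ξ = 0.1, σ0 = 1.0)
    xl, dx = x[1:end-1], diff(x)   # left endpoints and increments
    σ = σ0
    out = zeros(2, iterations)
    for iter in 1:iterations
        # Gibbs step: θ given σ is Gaussian with precision w and mean m/w
        w = sum(abs2, xl)*Δ/σ^2 + ξ
        m = -sum(xl .* dx)/σ^2
        θ = m/w + randn(rng)/sqrt(w)
        # plug-in step: σ from the normalised residual increments
        σ = std((dx .+ θ .* xl .* Δ) ./ sqrt(Δ))
        out[:, iter] = [θ, σ]
    end
    return out
end

rng = MersenneTwister(3)
x = simulate_toy(2.0, 0.5, 0.001, 10_000, rng)
toy_params = toy_sampler(x, 0.001, 200, rng)
println("mean θ: ", mean(toy_params[1, :]), "  mean σ: ", mean(toy_params[2, :]))
```

As in the SIR sampler, the $\sigma$ estimate settles almost immediately, because the drift contributes only a term of order $\Delta$ to each increment, while the draws of $\theta$ keep their posterior spread.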

Reference

• van der Meulen, S., van Zanten: Reversible jump MCMC for nonparametric drift estimation for diffusion processes. Computational Statistics & Data Analysis 71, 2014, ISSN 0167-9473, https://doi.org/10.1016/j.csda.2013.03.002 .