using Flux, Distributions, Plots, StatsPlots, Zygote, ForwardDiff
using LinearAlgebra, Random
# Durbin & Koopman - Time Series Analysis by State Space Methods, p.82 ff.
"""
    StateSpaceModel(Z, Hl, T, R, Ql)

Linear Gaussian state space model

    y_t     = Z α_t + ε_t,      ε_t ~ N(0, H),   H = Hl' Hl
    α_{t+1} = T α_t + R η_t,    η_t ~ N(0, Q),   Q = Ql' Ql

The noise covariances are stored as factors (`Hl`, `Ql`), so `H` and `Q` are
positive semi-definite by construction.
"""
struct StateSpaceModel
    Z
    Hl
    T
    R
    Ql
end
Flux.@functor StateSpaceModel

"""
    kalman_filter(m::StateSpaceModel, y, a_t, P_t)

Run the Kalman filter over `y`, an indexable collection with one observation
per time step. Filtering starts from the predicted state mean `a_t` and
covariance `P_t` of the first observation.

Returns `(dists, as, Ps)`:
- `dists`: vector of one-step-ahead predictive distributions `MvNormal(Z a_t, F_t)`,
  one per observation (always a vector, even for a single observation);
- `as`: matrix whose columns are the predicted state means `a_{t+1}`;
- `Ps`: vector of the predicted state covariances `P_{t+1}`.

The recursion builds its outputs without array mutation, presumably so that
Zygote can differentiate through it.
"""
function kalman_filter(m::StateSpaceModel, y, a_t, P_t)
    y_t = y[1]
    Z = m.Z
    H = transpose(m.Hl) * m.Hl
    T = m.T
    R = m.R
    Q = transpose(m.Ql) * m.Ql

    m_t = Z * a_t                       # observation mean given data until t-1
    F_t = Z * P_t * transpose(Z) .+ H   # observation cov given data until t-1

    # Kalman gain K_t = P_t Z' F_t⁻¹, computed with one linear solve instead of
    # forming inv(F_t) twice.
    K_t = (P_t * transpose(Z)) / F_t
    a_tt = a_t .+ K_t * (y_t .- m_t)    # filtered state mean
    P_tt = P_t .- K_t * Z * P_t         # filtered state covariance

    a_tp1 = T * a_tt                                         # predicted mean for t+1
    P_tp1 = T * P_tt * transpose(T) .+ R * Q * transpose(R)  # predicted cov for t+1

    dist_t = MvNormal(m_t[:], F_t)

    if length(y) > 1
        dists_rest, a_rest, P_rest = kalman_filter(m, y[2:end], a_tp1, P_tp1)
        return vcat([dist_t], dists_rest), hcat(a_tp1, a_rest), vcat([P_tp1], P_rest)
    else
        # Wrap the last distribution in a vector so the result can always be indexed
        # as `dists[i]`, including for a single observation.
        return [dist_t], a_tp1, [P_tp1]
    end
end

# Durbin & Koopman - Time Series Analysis by State Space Methods, p.54 ff.
"""
    ARMA(r, p, q, Φ, Θ, σ, a_0, P_0l)

ARMA(p, q) model in state space form with state dimension `r`
(`r = max(p, q + 1)` when built with `ARMA(p, q)`).

Fields:
- `Φ`: length-`r` AR coefficients; only the first `p` are used by `to_state_space`;
- `Θ`: length-`r - 1` MA coefficients; only the first `q` are used;
- `σ`: 1×1 matrix; the innovation variance is `σ' σ`;
- `a_0`: initial state mean;
- `P_0l`: factor of the initial state covariance `P_0 = P_0l' P_0l`.

Only `Φ`, `Θ` and `σ` are exposed to Flux (`@functor ARMA (Φ, Θ, σ)`).
"""
struct ARMA
    r
    p
    q
    Φ
    Θ
    σ
    a_0
    P_0l
end
Flux.@functor ARMA (Φ, Θ, σ)

"""
    ARMA(p, q)

Construct an ARMA(p, q) model with these initial values:
- active AR coefficients `1 / (2 max(p, 1))`;
- active MA coefficients `1 / (2 max(q, 1))`;
- inactive coefficients 0;
- innovation scale `σ = ones(1, 1)`;
- initial state mean `zeros(r, 1)`;
- identity initial state covariance.

Throws `ArgumentError` if `p` or `q` is negative.
"""
function ARMA(p, q)
    (p >= 0 && q >= 0) ||
        throw(ArgumentError("ARMA orders must be non-negative, got p=$p, q=$q"))
    r = max(p, q + 1)
    Φ = ones(r) ./ (2 * max(p, 1)) .* get_phi_zeros(r, p)
    Θ = ones(r - 1) ./ (2 * max(q, 1)) .* get_theta_zeros(r, q)
    σ = ones(1, 1)
    a_0 = zeros(r, 1)
    P_0l = Matrix(Diagonal(ones(r)))
    return ARMA(r, p, q, Φ, Θ, σ, a_0, P_0l)
end

"""
    get_phi_zeros(r, p)

Return a length-`r` 0/1 mask with ones in the first `p` entries. It marks the
AR coefficients that are active; the rest are zeroed.
"""
function get_phi_zeros(r, p)
    out = zeros(r)
    out[1:p] .= 1.0
    return out
end
Zygote.@nograd get_phi_zeros

"""
    get_theta_zeros(r, q)

Return a length-`r - 1` 0/1 mask with ones in the first `q` entries. It marks
the MA coefficients that are active; the rest are zeroed.
"""
function get_theta_zeros(r, q)
    out = zeros(r - 1)
    out[1:q] .= 1.0
    return out
end
Zygote.@nograd get_theta_zeros

"""
    to_state_space(m::ARMA)

Build the `StateSpaceModel` for `m`:
- `Z = [1 0 … 0]`;
- `H = 0`;
- `T = [Φ  [I; 0]]`, i.e. Φ in the first column, an identity shifted up below it;
- `R = [1; Θ]`;
- `Q = σ' σ`.

Coefficients beyond order `p` / `q` are masked to zero.
"""
function to_state_space(m::ARMA)
    r = m.r
    Φ = m.Φ .* get_phi_zeros(r, m.p)
    Θ = m.Θ .* get_theta_zeros(r, m.q)
    Z = hcat(1.0, zeros(1, r - 1))
    Hl = zeros(1, 1)  # no separate measurement noise in the ARMA representation
    T_right_slice = vcat(Matrix(Diagonal(ones(r - 1))), zeros(1, r - 1))
    T = hcat(Φ, T_right_slice)
    R = vcat(1.0, Θ)
    Ql = ones(1, 1) .* m.σ
    return StateSpaceModel(Z, Hl, T, R, Ql)
end

"""
    kalman_filter(m::ARMA, y)

Kalman-filter the series `y` under `m`. Each row of `y` (each element, for a
vector) is one time step. Returns the same `(dists, as, Ps)` tuple as the
`StateSpaceModel` method. Throws `ArgumentError` for an empty series.
"""
function kalman_filter(m::ARMA, y)
    isempty(y) && throw(ArgumentError("kalman_filter needs at least one observation"))
    # One slice per time step (row of `y`). This replaces `Flux.unstack(y, 1)`,
    # whose positional `dims` form is deprecated in newer Flux/MLUtils.
    ys = [y[i, :] for i in axes(y, 1)]
    sp = to_state_space(m)
    P_0 = transpose(m.P_0l) * m.P_0l
    return kalman_filter(sp, ys, m.a_0, P_0)
end

"""
    llikelihood(m::ARMA, y)

Return the mean (not the sum) of the one-step-ahead predictive log-densities of
the univariate series `y` under `m`. This equals the Gaussian log-likelihood
divided by `length(y)`.
"""
function llikelihood(m::ARMA, y)
    dists, _, _ = kalman_filter(m, y)
    return mean(map(i -> logpdf(dists[i], [y[i]]), eachindex(y)))
end
# REPL output: llikelihood (generic function with 1 method)
# Simulate 200 draws from the AR(2) process
#     x_t = -0.3 x_{t-1} + 0.4 x_{t-2} + ε_t,   ε_t ~ N(0, 1),
# with a fixed seed for reproducibility. Two zero pre-sample values start the
# recursion and are discarded before plotting.
Random.seed!(321)
data = zeros(2)
foreach(1:200) do _
    x_lag1 = data[end]
    x_lag2 = data[end - 1]
    push!(data, -0.3 * x_lag1 + 0.4 * x_lag2 + randn())
end
data = data[(begin + 2):end]
plot(data; legend = :none, fmt = :png)