Scaled forward-backward algorithm + AutoDiff for optimizing the observation distributions

A Gaussian-emission HMM trained EM-style: the E-step runs numerically stable (scaled) forward-backward recursions, and the M-step fits the emission parameters by gradient descent through ForwardDiff rather than a closed-form update.

using Flux, Distributions, Zygote, ForwardDiff, Plots, StatsPlots, LinearAlgebra, Random
mutable struct HMM
    μ   # state means of the Gaussian observation distributions
    σ   # state log-standard-deviations (exponentiated wherever they are used)
    
    P   # column-stochastic transition matrix: p_next = P * p_current
end
Flux.@functor HMM (μ, σ)  # only μ and σ are trainable; P gets its own closed-form update

function HMM(states)
    μ = collect(range(-1, 1, length=states))  # evenly spaced initial means
    σ = zeros(states)                         # log-stddevs, so all initial stddevs are 1
    
    # softmax of the identity: extra mass on the diagonal; the result is symmetric,
    # hence also column-stochastic as the recursions below require
    P = softmax(Matrix(Diagonal(ones(states))), dims=2)
    
    return HMM(μ, σ, P)
end
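
Since @functor exposes only (μ, σ), Flux.destructure flattens six numbers for a three-state model and leaves P alone; and the initial P is symmetric, hence column-stochastic as p_next = P * p requires. A quick check (variable names here are mine):

m0 = HMM(3)
ps0, _ = Flux.destructure(m0)
length(ps0)               # 6: three means and three log-stddevs, no P entries
sum(m0.P, dims=1) .≈ 1.0  # columns sum to one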


# Filtering recursion (note: this shadows Base.filter). Returns the per-step
# observation distributions, the predicted state probabilities p(x_t | y_{1:t-1})
# and the filtered ones p(x_t | y_{1:t}).
function filter(m::HMM, y, p_t)
    
    μ = m.μ
    σ = exp.(m.σ)
    P = m.P
    
    y_t = y[1]
    
    dists_t = Normal.(μ,σ)
    pdfs = pdf.(dists_t, y_t)
    
    # measurement update: reweight the predicted probabilities by the likelihoods
    sumdist = (p_t.*pdfs)
    p_tt = sumdist./sum(sumdist)
    
    # time update: propagate through the transition matrix
    p_tp1 = P*p_tt
    
    if length(y)>1
        # the recursive results get their own names so the input p_t is not
        # shadowed and the columns stay in time order
        dists_tp1, p_preds, p_ttp1 = filter(m,y[2:end],p_tp1)
        return vcat(dists_t, dists_tp1), hcat(p_t, p_preds), hcat(p_tt, p_ttp1)
    else
        return dists_t, p_t, p_tt
    end
end
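
For reference, the filtering recursion can be run on its own; a minimal sketch on dummy data (it is not used by the EM loop below):

m_f = HMM(3)
dists_f, p_pred, p_filt = filter(m_f, randn(10), ones(3) ./ 3)
size(p_filt)  # (3, 10): one filtered state distribution per observation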

function forward_normalized(m::HMM, y, α_tm1)
    # Scaled forward recursion, following
    # https://github.com/mattjj/pyhsmm/blob/e6cfde5acb98401c2e727ca59a49ee0bfe86cf9d/pyhsmm/internals/hmm_states.py#L322
    
    μ = m.μ
    σ = exp.(m.σ)
    P = m.P
    
    y_t = y[1]
    qsum = P*α_tm1              # propagate the previous message through the transitions
    dists = Normal.(μ,σ)
    lpdfs = logpdf.(dists,y_t)
    
    lpdf_max = maximum(lpdfs)   # max-subtraction keeps the exponentials stable
 
    α_t = qsum[:] .* exp.(lpdfs .- lpdf_max)
    normalizer = sum(α_t)
    
    α_t_normed = α_t ./ normalizer
    logtot_t = log(normalizer) + lpdf_max  # this step's contribution to log p(y_{1:T})
    
    if length(y)>1
        α_tp1_normed, logtot_tp1 = forward_normalized(m,y[2:end],α_t_normed)
        return hcat(α_t_normed, α_tp1_normed), logtot_t + logtot_tp1
    else
        return α_t_normed, logtot_t
    end
end
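
As a sanity check (my addition, not part of the original), the accumulated logtot should match the brute-force log-evidence obtained by summing over all state paths; with two states and two observations that is a four-term sum:

m_c = HMM(2); y_c = randn(2); p0_c = ones(2) ./ 2
_, ll = forward_normalized(m_c, y_c, p0_c)
d_c = Normal.(m_c.μ, exp.(m_c.σ))
prior = m_c.P * p0_c  # the recursion applies P before the first observation
brute = sum(prior[i] * pdf(d_c[i], y_c[1]) * m_c.P[j, i] * pdf(d_c[j], y_c[2])
            for i in 1:2, j in 1:2)
ll ≈ log(brute)  # true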


function backward_normalized(m::HMM, y, β_T)
    # Scaled backward recursion, following
    # https://github.com/mattjj/pyhsmm/blob/e6cfde5acb98401c2e727ca59a49ee0bfe86cf9d/pyhsmm/internals/hmm_states.py#L295
    # β_t ∝ Pᵀ (β_{t+1} .* p(y_{t+1} | ·)): the recursion consumes y_2 … y_T and the
    # last column of the result is the (normalized) initial message β_T. This keeps
    # column t aligned with α_t, so γ ∝ α .* β does not double-count the emission at t.
    
    μ = m.μ
    σ = exp.(m.σ)
    P = m.P
    
    β_T_normed = β_T ./ sum(β_T)
    
    if length(y) == 1
        return β_T_normed, 0.0
    end
    
    y_tp1 = y[end]
    
    dists = Normal.(μ,σ)
    lpdfs = logpdf.(dists,y_tp1)
    
    lpdf_max = maximum(lpdfs)
    
    β_t = transpose(P)*(β_T_normed.*exp.(lpdfs.-lpdf_max))[:]
    normalizer = sum(β_t)
    
    β_t_normed = β_t./normalizer
    logtot_t = log(normalizer) + lpdf_max
    
    β_rest, logtot_rest = backward_normalized(m, y[1:end-1], β_t_normed)
    return hcat(β_rest, β_T_normed), logtot_rest + logtot_t
end
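
A quick check of the message alignment: the last backward message is the (uniform) initial vector, so smoothing and filtering must coincide at the final time step:

m_s = HMM(2); y_s = randn(6); p0_s = ones(2) ./ 2
α_s, _ = forward_normalized(m_s, y_s, p0_s)
β_s, _ = backward_normalized(m_s, y_s, p0_s)
γ_s = (α_s .* β_s) ./ sum(α_s .* β_s, dims=1)
γ_s[:, end] ≈ α_s[:, end]  # true: β_T is uniform, so γ_T = α_T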



# Expected complete-data log-likelihood of the emissions (the M-step objective):
# the average over time of Σ_i γ_t(i) · log p(y_t | x_t = i), where sps holds the
# smoothed state probabilities γ_t.
function likelihood(m::HMM,y,sps)
    μ = m.μ
    σ = exp.(m.σ)
    
    dists = Normal.(μ,σ)
    
    return mean(map(i->sum(sps[i].*logpdf.(dists,y[i])),1:length(y)))
end
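
This is the objective the M-step differentiates; a single gradient evaluation through Flux.destructure and ForwardDiff looks like this (with placeholder smoothed probabilities, purely for illustration):

m_d = HMM(3)
y_d = randn(4)
sps_d = [ones(3) ./ 3 for _ in 1:4]  # dummy γ's, one per observation
ps_d, re = Flux.destructure(m_d)
ForwardDiff.gradient(x -> -likelihood(re(x), y_d, sps_d), ps_d)  # length-6 gradient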


function EM(m::HMM, y, p_0, n_iter = 50)
    
    for i in 1:n_iter
        # E-step: smoothed state probabilities γ_t(i) ∝ α_t(i) β_t(i)
        α, logtot_α = forward_normalized(m,y,p_0)
        β, logtot_β = backward_normalized(m,y,p_0)
        
        αβ = α.*β        
        γ = αβ./sum(αβ,dims=1)
        
        sps = Flux.unstack(γ,dims=2)

        # M-step for the emissions: gradient descent on the negative expected
        # complete-data log-likelihood; thanks to @functor, ps holds only μ and σ
        ps, f = Flux.destructure(m)
        
        for _ in 1:50
            grads = ForwardDiff.gradient(x -> -likelihood(f(x),y,sps), ps)
            ps .-= 0.001.*grads
        end
        
        newm = f(ps)
        m.μ = newm.μ
        m.σ = newm.σ
        
        # M-step for the transitions: expected transition counts from the smoothed
        # marginals (an approximation of the exact ξ-based update). Column s of Ps
        # holds the counts of transitions out of state s; dividing by the column
        # sums gives a column-stochastic matrix, matching p_next = P * p.
        Ps = hcat([sum(γ[s:s, 1:end-1].*γ[:, 2:end],dims=2)[:] for s in 1:length(p_0)]...)
        Ps./=sum(Ps,dims=1)
        
        m.P = Ps
        
        if i%50==0
            println("iteration $i, negative expected log-likelihood: ", -likelihood(m,y,sps))
        end
    end
end
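
A minimal smoke test of a single EM sweep on toy data (names are mine):

m_e = HMM(2)
y_e = randn(20)
EM(m_e, y_e, ones(2) ./ 2, 1)  # one sweep: E-step, 50 gradient steps, transition update
m_e.P                          # the transition matrix has been re-estimated
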
Random.seed!(321)
# 10 cycles through three 25-point regimes with means 3, 0 and -3 (750 points in total)
y = vcat([vcat(0.5 .*randn(25).+3, randn(25), 0.5 .*randn(25).-3) for _ in 1:10]...)

m = HMM(3)
ps, f = Flux.destructure(m)
# returns ([-1.0, 0.0, 1.0, 0.0, 0.0, 0.0], Restructure(HMM, ..., 6)): only μ and σ are flattened
α, logtot_α = forward_normalized(m,y,ones(3)./3)
β, logtot_β = backward_normalized(m,y,ones(3)./3)

αβ = α.*β        
state_probs = αβ./sum(αβ,dims=1)  # smoothed marginals γ_t

# mixture mean, and a stddev from the γ-weighted per-state variances
# (ignoring the between-state spread of the means)
mean_pred = sum(m.μ .* state_probs,dims=1)[:]
std_pred = sqrt.(sum(exp.(m.σ).^2 .* state_probs,dims=1)[:])

p1 = scatter(collect(1:length(y)), y, title = "Smoothing distribution before training", label = "Data",fmt=:png,
    c="blue", legend=:bottomleft, size = (1200,600))
plot!(p1, mean_pred, label = "Predicted mean ± 2 stddevs", c="red", ribbon = 2 .* std_pred)
EM(m,y,ones(3)./3,750)  # 750 EM sweeps; progress is printed every 50
α, logtot_α = forward_normalized(m,y,ones(3)./3)
β, logtot_β = backward_normalized(m,y,ones(3)./3)

αβ = α.*β        
state_probs = αβ./sum(αβ,dims=1)

mean_pred = sum(m.μ .* state_probs,dims=1)[:]
std_pred = sqrt.(sum(exp.(m.σ).^2 .* state_probs,dims=1)[:])

p2 = scatter(collect(1:length(y)), y, title = "Smoothing distribution after training", label = "Data",fmt=:png,
    c="blue", legend=:bottomleft, size = (1200,600))
plot!(p2, mean_pred, label = "Predicted mean ± 2 stddevs", c="red", ribbon = 2 .* std_pred)
plot(p1,p2)
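
After training, the learned parameters can be read off directly; the means should sit near 3, 0 and -3 (up to a permutation of the states) given how the data were generated:

m.μ, exp.(m.σ)  # learned means and stddevs
m.P             # learned column-stochastic transition matrix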