using Flux, ForwardDiff, Distributions, Plots, StatsPlots, Random, Zygote, LinearAlgebra, ChainRules, ChainRulesCore
Maximum Likelihood Estimation via Monte Carlo Likelihood Approximation
Simple latent variable model
\[X\sim\mathcal{N}(\mu_h,\sigma_h^2)\] \[Y\sim\mathcal{N}(\exp(\alpha_o\cdot X),\sigma_o^2)\]
"""
    LatentModel(mu_h, sigma_h, alpha_o, sigma_o)

Simple latent variable model

    X ~ N(mu_h, exp(sigma_h)^2),   Y | X ~ N(exp(alpha_o * X), exp(sigma_o)^2)

Both standard deviations are stored as logarithms, so every parameter is unconstrained.
Only the first element of each field is read, so fields may be scalars or 1-element arrays.
The zero-argument constructor gives mu_h = 0, alpha_o = 1 and unit standard deviations,
stored as 1×1 matrices so that `Flux.destructure` flattens them into a trainable vector.
"""
struct LatentModel{A,B,C,D}
    mu_h::A
    sigma_h::B
    alpha_o::C
    sigma_o::D
end
Flux.@functor LatentModel
LatentModel() = LatentModel(zeros(1, 1), zeros(1, 1), ones(1, 1), zeros(1, 1))
"""
    rand([rng::AbstractRNG,] m::LatentModel, N::Integer) -> Vector
    rand([rng::AbstractRNG,] m::LatentModel)

Draw `N` observations `Y` from the latent variable model `m`. The latent `X` draws are
used internally and then discarded. Without `N`, return a single scalar observation.
When `rng` is omitted, the global default RNG is used.
"""
function Base.rand(rng::AbstractRNG, m::LatentModel, N::Integer)
    mu_h = m.mu_h[1]
    sigma_h = exp(m.sigma_h[1])   # parameters hold log standard deviations
    alpha_o = m.alpha_o[1]
    sigma_o = exp(m.sigma_o[1])
    X = randn(rng, N) .* sigma_h .+ mu_h
    Y = randn(rng, N) .* sigma_o .+ exp.(alpha_o .* X)
    return Y
end
Base.rand(m::LatentModel, N::Integer) = rand(Random.default_rng(), m, N)
Base.rand(rng::AbstractRNG, m::LatentModel) = rand(rng, m, 1)[1]
Base.rand(m::LatentModel) = rand(m, 1)[1]
Specify the true model for a test case
\[X\sim\mathcal{N}(1,0.25)\] \[Y\sim\mathcal{N}(\exp(0.75\cdot X),0.25)\]
# Simulate data from a known model: X ~ N(1, 0.5^2), Y | X ~ N(exp(0.75 X), 0.5^2).
# Standard deviations are passed on the log scale.
Random.seed!(123)
true_model = LatentModel([1.0], [log(0.5)], [0.75], [log(0.5)])
Yfull = rand(true_model, 50000)  # large sample, only used to plot the density
Y = Yfull[1:150]                 # small observed data set used for fitting
histogram(Y, bins=20, normalize=true, alpha=0.5, label="Data", fmt=:png)
density!(Yfull, c=:red, lw=2, label="True density (apprx.)")
True (marginal) density, written as an expectation under a proposal distribution \(q\)
\[p_\theta(y)=\int p_\theta(y|x)\frac{p_\theta(x)}{q(x)} q(x)dx\]
Importance-sampling (Monte Carlo) approximation of the density
\[\hat{p}_\theta(y)=\frac{1}{M}\sum_{j=1}^M p_\theta(y|x_j)\frac{p_\theta(x_j)}{q(x_j)}\]
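where the \(x_j\), \(j=1,\dots,M\), are samples drawn from the proposal distribution \(q(x)\) and \(M\) is the sample size.
Here the proposal is
\[q(x)=\mathcal{N}(x|0,4)\] (variance 4, i.e. standard deviation 2).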
"""
    particle_ll(m::LatentModel, y, M=1000)

Estimate the average log-likelihood `mean_i log p(y_i)` of the observations `y` under `m`
by importance sampling.

For each observation, `M` fresh proposal particles `x_j ~ q = N(0, 2^2)` are drawn, and

    p(y_i) ≈ 1/M Σ_j p(y_i | x_j) p(x_j) / q(x_j)

The average is computed in log space (log-mean-exp), so very small weights cannot
underflow to `log(0) = -Inf`. The estimate is random; seed the global RNG to reproduce it.
"""
function particle_ll(m::LatentModel, y, M=1000)
    qdist = Normal(0, 2)                          # proposal q(x)
    pdist = Normal(m.mu_h[1], exp(m.sigma_h[1]))  # latent prior p(x)
    alpha_o = m.alpha_o[1]
    sigma_o = exp(m.sigma_o[1])
    # One independent particle sample (length M) per observation, drawn up front.
    ps = [rand(qdist, M) for _ in eachindex(y)]
    logps = map(y, ps) do y_i, xs
        # log[p(y_i | x_j) p(x_j) / q(x_j)] for every particle x_j
        lw = logpdf.(Normal.(exp.(alpha_o .* xs), sigma_o), y_i) .+
             logpdf.(pdist, xs) .- logpdf.(qdist, xs)
        lmax = maximum(lw)
        lmax + log(mean(exp.(lw .- lmax)))        # log of the mean importance weight
    end
    return mean(logps)
end
particle_ll (generic function with 2 methods)
# Untrained starting model. `pars` is the flat trainable parameter vector and `f(pars)`
# rebuilds a `LatentModel` from it.
m = LatentModel()
pars, f = Flux.destructure(m)
Yprefit = rand(m, 50000)  # large sample to plot the density of the untrained model
histogram(Y, bins=20, normalize=true, alpha=0.5, label="Data", xlim=(-1,9), fmt=:png)
density!(Yfull, c=:red, lw=2, label="True density (apprx.)")
density!(Yprefit, c=:green, lw=2, label="Model density before fit")
# Gradient descent on the negative importance-sampling log-likelihood. Each of the 250
# steps averages 10 ForwardDiff gradients, each computed with fresh proposal particles,
# to reduce Monte Carlo noise, then takes a fixed-size step on `pars` in place.
Random.seed!(123)
for i in 1:250
    gg = [ForwardDiff.gradient(p -> -particle_ll(f(p), Y), pars) for _ in 1:10]
    grads = mean(gg)
    pars .-= 0.025 .* grads
    if i % 25 == 0
        println(particle_ll(f(pars), Y))  # progress: current average log-likelihood
    end
end
-1.698369499604494
-1.5859668597849523
-1.4757647900932243
-1.3924143414512473
-1.3310941879731057
-1.2977140375800214
-1.2845731201786137
-1.2669498497241503
-1.2622906444844957
-1.2637100740914606
# Sample from the fitted model and overlay its density on the data, the (approximate)
# true density and the pre-fit density.
Ypostfit = rand(f(pars), 50000)
histogram(Y, bins=20, normalize=true, alpha=0.5, label="Data", xlim=(-1,9), fmt=:png)
density!(Yfull, c=:red, lw=2, label="True density (apprx.)")
density!(Yprefit, c=:green, lw=2, label="Model density before fit")
density!(Ypostfit, c=:blue, lw=2, label="Fitted density")
Stochastic volatility
\[X_t\sim\mathcal{N}(\alpha_h\cdot X_{t-1},\sigma_h^2);\quad -1<\alpha_h<1\] \[Y_t\sim\mathcal{N}\left(0,\exp(X_t/4)^2\right)\]
\[X_0=0\] (could also be fitted/trained)
# tanh maps any real number into (-1, 1); the model applies it to the unconstrained
# alpha_h parameter to enforce -1 < alpha_h < 1.
tanh(0.5)
0.46211715726000974
"""
    StochasticVolatilityModel(alpha_h, sigma_h)

Stochastic volatility model with a latent AR(1) log-volatility state:

    X_t ~ N(tanh(alpha_h) * X_{t-1}, exp(sigma_h)^2),   Y_t ~ N(0, exp(X_t / 4)^2)

Both parameters are unconstrained: `tanh` keeps the AR coefficient in (-1, 1) and `exp`
keeps the standard deviation positive. Only the first element of each field is read, so
fields may be scalars or 1-element arrays. The zero-argument constructor gives AR
coefficient 0.5 and unit standard deviation, stored as 1×1 matrices so that
`Flux.destructure` flattens them into a trainable vector.
"""
struct StochasticVolatilityModel{A,S}
    alpha_h::A
    sigma_h::S
end
Flux.@functor StochasticVolatilityModel
StochasticVolatilityModel() = StochasticVolatilityModel(fill(atanh(0.5), 1, 1), zeros(1, 1))
"""
    rand(m::StochasticVolatilityModel, T::Int, X_0=0.0) -> (X, Y)

Simulate `T` steps of the stochastic volatility model `m`, starting from latent state `X_0`.

Returns the latent states `X = [X_1, …, X_T]` (excluding `X_0`) and the observations
`Y = [Y_1, …, Y_T]` as concretely typed vectors. Both are empty for `T ≤ 0`.
Uses the global RNG.
"""
function Base.rand(m::StochasticVolatilityModel, T::Int, X_0=0.0)
    alpha_h = tanh(m.alpha_h[1])  # AR coefficient constrained to (-1, 1)
    sigma_h = exp(m.sigma_h[1])   # parameter holds the log standard deviation
    # randn() yields Float64; promote so a lower-precision X_0 (e.g. Float32) cannot
    # truncate the simulated states, and so Y is typed instead of Vector{Any}.
    R = promote_type(Float64, typeof(alpha_h), typeof(sigma_h), typeof(float(X_0)))
    n = max(T, 0)
    X = Vector{R}(undef, n)
    Y = Vector{R}(undef, n)
    x_prev = convert(R, X_0)
    for t in 1:n
        x_t = randn() * sigma_h + alpha_h * x_prev
        X[t] = x_t
        Y[t] = randn() * exp(x_t / 4)
        x_prev = x_t
    end
    return X, Y
end
# NOTE(review): this returns the 1-element latent vector `X` (first element of the
# `(X, Y)` tuple), not a scalar observation as `rand(::LatentModel)` does — confirm intent.
Base.rand(m::StochasticVolatilityModel) = rand(m, 1)[1]
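Specify the true model for a test case (the model stores \(\sigma_h\) on the log scale, so the parameter value 0.1 used below corresponds to \(\sigma_h=\exp(0.1)\approx 1.105\), and \(\alpha_h=\tanh(\operatorname{atanh}(0.95))=0.95\))
\[X_t\sim\mathcal{N}(0.95\cdot X_{t-1},\exp(0.1)^2);\quad -1<\alpha_h<1\] \[Y_t\sim\mathcal{N}\left(0,\exp(X_t/4)^2\right)\]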
# Simulate a test data set from a known stochastic volatility model.
# `sigma_h` is stored on the log scale (`rand` uses `exp(m.sigma_h[1])`), so the value 0.1
# gives a transition standard deviation of exp(0.1) ≈ 1.105. The AR coefficient is
# tanh(atanh(0.95)) = 0.95.
Random.seed!(123)
m = StochasticVolatilityModel(atanh(0.95), 0.1)
X, Y = rand(m, 150)
plot(Y, fmt=:png)
"""
    particle_filter(m::StochasticVolatilityModel, y, M=1000) -> Matrix

Run a bootstrap particle filter with `M` particles on the observations `y` under model `m`.
Returns an `M × length(y)` matrix of latent trajectories (row = particle, column = time).

Initial particles are drawn from `N(0, 3^2)`. At each step the filter:
- propagates the particles through the AR(1) transition;
- weights them by `p(y_t | x_t) = N(y_t | 0, exp(x_t / 4)^2)`;
- multinomially resamples whole trajectories (entire rows).

Because full rows are resampled, the returned rows approximate draws from the joint
smoothing distribution and become increasingly degenerate for early time steps.
"""
function particle_filter(m::StochasticVolatilityModel, y, M=1000)
    T = length(y)
    alpha_h = tanh(m.alpha_h[1])
    sigma_h = exp(m.sigma_h[1])
    q0dist = Normal(0, 3)             # q_0(x), initial particle distribution
    ps = rand(q0dist, (M, 1))         # column 1 holds the initial particles
    for t in 1:T
        ps_t = rand.(Normal.(alpha_h .* ps[:, t], sigma_h))
        ps = hcat(ps, ps_t)
        # Normalize in log space: if every pdf underflowed to 0 the weights would be NaN.
        lw = logpdf.(Normal.(0.0, exp.(ps_t ./ 4)), y[t])
        w_t = exp.(lw .- maximum(lw))
        w_t ./= sum(w_t)
        a_t = rand(Categorical(w_t), M)  # ancestor indices
        ps = ps[a_t, :]
    end
    return ps[:, 2:end]
end
"""
    particle_filter_ll(m::StochasticVolatilityModel, y, M=1000)

Bootstrap-particle-filter estimate of the average per-step log-likelihood

    (1/T) Σ_t log p̂(y_t | y_{1:t-1}),   p̂(y_t | y_{1:t-1}) = mean_j p(y_t | x_t^j)

where the `x_t^j` are the `M` propagated particles before resampling at step `t`.
Initial particles come from `N(0, 3^2)`. Returns `NaN` for empty `y`.
The estimate is random; seed the global RNG to reproduce it.
"""
function particle_filter_ll(m::StochasticVolatilityModel, y, M=1000)
    T = length(y)
    alpha_h = tanh(m.alpha_h[1])
    sigma_h = exp(m.sigma_h[1])
    x = rand(Normal(0, 3), M)  # q_0(x); only the current particle cloud is needed
    ll = 0.0
    for t in 1:T
        x = rand.(Normal.(alpha_h .* x, sigma_h))
        lw = logpdf.(Normal.(0.0, exp.(x ./ 4)), y[t])
        lmax = maximum(lw)
        w = exp.(lw .- lmax)
        # The incremental likelihood must use the weights *before* resampling. Evaluating
        # p(y_t | x_t) on the final resampled paths would use future observations and
        # overstate the fit.
        ll += lmax + log(mean(w))
        w ./= sum(w)
        x = x[rand(Categorical(w), M)]
    end
    return ll / T
end
particle_filter_ll (generic function with 2 methods)
# Filter the observed data under the true model and under the untrained initial model,
# and plot the estimated volatility exp(X_t / 4) against the true volatility.
Random.seed!(123)
ps_true = particle_filter(m, Y)
filter_mean_true = vec(mean(exp.(ps_true ./ 4), dims=1))
plot(exp.(X ./ 4), label="True Volatility", lw=2, fmt=:png)
plot!(filter_mean_true, label="Filter Mean True Model", lw=2)
ps_initial = particle_filter(StochasticVolatilityModel(), Y)
filter_mean_initial = vec(mean(exp.(ps_initial ./ 4), dims=1))
plot!(filter_mean_initial, label="Filter Mean Initial Model", lw=2)
# MSE of the true-model filter. Compare on the volatility scale exp(X_t / 4), which is the
# quantity the filter means estimate and the one plotted against them.
println(mean((exp.(X ./ 4) .- filter_mean_true) .^ 2))
2.777623212726044
# MSE of the initial-model filter. Compare on the volatility scale exp(X_t / 4), which is
# the quantity the filter means estimate and the one plotted against them.
println(mean((exp.(X ./ 4) .- filter_mean_initial) .^ 2))
4.1040430904059315
using FiniteDifferences
# Flatten the untrained model: `pars` is the trainable parameter vector, and `f(pars)`
# rebuilds a `StochasticVolatilityModel` from it.
pars, f = Flux.destructure(StochasticVolatilityModel())
([0.5493061443340549, 0.0], Restructure(StochasticVolatilityModel, ..., 2))
# Fit by gradient descent on the negative particle-filter log-likelihood. Gradients come
# from a 15-point central finite-difference rule applied to an objective averaged over
# 5 independent 100-particle filter runs.
# NOTE(review): each objective evaluation draws fresh random numbers, so these
# finite-difference gradients are noisy; common random numbers (reseeding inside the
# objective) might help — verify before changing.
Random.seed!(123)
for _ in 1:50
    gs = FiniteDifferences.grad(
        central_fdm(15, 1),
        p -> -mean(particle_filter_ll(f(p), Y, 100) for _ in 1:5),
        pars,
    )[1]
    pars .-= 0.025 .* gs
end
# Filter the data under the fitted model and compare its volatility estimate with the
# true volatility and with the true-model filter.
Random.seed!(123)
ps = particle_filter(f(pars), Y)
filter_mean_fitted = vec(mean(exp.(ps ./ 4), dims=1))
plot(exp.(X ./ 4), label="True Volatility", lw=2, fmt=:png)
plot!(filter_mean_true, label="Filter Mean True Model", lw=2)
plot!(filter_mean_fitted, label="Filter Mean Fitted Model", lw=2)
# MSE on the volatility scale exp(X_t / 4); compare with the true- and initial-model MSEs.
println(mean((exp.(X ./ 4) .- filter_mean_fitted) .^ 2))
3.174579327614477
pars  # fitted unconstrained parameters (alpha_h = tanh(pars[1]), sigma_h = exp(pars[2])); could probably be improved with longer training duration
2-element Vector{Float64}:
0.7612290016866907
0.4596682029518793