ConjugateComputationVI
This package implements conjugate-computation variational inference (CVI) [1].
It utilises the AbstractGPs.jl interface, so should play nicely with any AbstractGP, including those from Stheno.jl and TemporalGPs.jl.
No attempt has been made to make this implementation work for anything other than Gaussian processes.
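As a rough sketch of the method in [1], specialised to a GP prior p(f) with a factorised likelihood: the approximate posterior keeps the prior and attaches a Gaussian pseudo-observation (a "site") to each data point, and the site natural parameters are updated with a damped natural-gradient step. The notation below (\tilde{\lambda}_n for the site natural parameters, \mu_n for the expectation parameters of the marginal q(f_n), \rho for the step size) is introduced here for illustration only, and details such as the step-size schedule used by this package may differ.

```latex
q_t(f) \propto p(f) \prod_{n=1}^{N} \exp\!\big( \tilde{\lambda}_{n,t}^{\top} \phi(f_n) \big),
\qquad \phi(f_n) = \big(f_n, \; f_n^2\big)

\tilde{\lambda}_{n,t+1} = (1 - \rho)\, \tilde{\lambda}_{n,t}
  + \rho \, \nabla_{\mu_n} \mathbb{E}_{q_t(f_n)}\big[\log p(y_n \mid f_n)\big],
\qquad \mu_n = \big(\mathbb{E}_{q_t}[f_n], \; \mathbb{E}_{q_t}[f_n^2]\big)
```

Because each q_t has the form of a GP posterior under Gaussian pseudo-observations, every iteration reduces to conjugate GP regression.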
Example Usage
The example below performs approximate inference and learning in a GP with an Exponential likelihood.
The model is specified primarily through the build_latent_gp function, which produces a LatentGP specifying this model when provided with kernel parameters.
ParameterHandling.jl is used to handle the book-keeping associated with the model parameters.
using AbstractGPs
using ConjugateComputationVI
using Distributions
using Optim
using ParameterHandling
using Plots
using Random
using RDatasets
using StatsFuns
using Zygote
using ConjugateComputationVI: GaussHermiteQuadrature, UnivariateFactorisedLikelihood
# Specify the model parameters.
θ_init = (scale=positive(1.9), stretch=positive(0.8));
θ_init_flat, unflatten = ParameterHandling.flatten(θ_init);
# Specify the model.
# A core requirement of this package is that you are able to provide a function mapping
# from your model parameters to a `LatentGP`.
function build_latent_gp(θ::AbstractVector{<:Real})
    return build_latent_gp(ParameterHandling.value(unflatten(θ)))
end
function build_latent_gp(θ::NamedTuple)
    gp = GP(θ.scale * SEKernel() ∘ ScaleTransform(θ.stretch))
    lik = UnivariateFactorisedLikelihood(f -> Exponential(exp(f)))
    return LatentGP(gp, lik, 1e-9)
end
# Specify inputs and generate some synthetic outputs.
x = range(-5.0, 5.0; length=100);
y = rand(build_latent_gp(θ_init_flat)(x)).y;
# Attempt to recover the kernel parameters used when generating the data.
# Add some noise to the initialisation to make this more interesting.
# We specify that the reconstruction term in the ELBO is to be approximated using
# Gauss-Hermite quadrature with 10 points.
f_approx_post, results_summary = ConjugateComputationVI.optimize_elbo(
    build_latent_gp,
    GaussHermiteQuadrature(10),
    x,
    y,
    θ_init_flat + randn(length(θ_init_flat)),
    BFGS(
        alphaguess = Optim.LineSearches.InitialStatic(scaled=true),
        linesearch = Optim.LineSearches.BackTracking(),
    ),
    Optim.Options(
        show_trace = true,
        iterations=25,
        f_calls_limit=50,
    ),
);
# Compute approx. posterior CIs using Monte Carlo.
function approx_post_95_CI(x::AbstractVector, N::Int)
    samples = map(marginals(f_approx_post(x, 1e-6))) do latent_marginal
        f = rand(latent_marginal, N)
        return rand.(Exponential.(exp.(f)))
    end
    return quantile.(samples, Ref((0.025, 0.5, 0.975)))
end
x_pr = range(-6.0, 6.0; length=250);
qs = approx_post_95_CI(x_pr, 10_000);
# Plot the predictions.
p1 = plot(
    x_pr, getindex.(qs, 1);
    linealpha=0,
    fillrange=getindex.(qs, 3),
    label="95% CI",
    fillalpha=0.3,
);
scatter!(p1, x, y; markersize=2, label="Observations");
p2 = plot(
    f_approx_post(x_pr, 1e-6);
    ribbon_scale=3, color=:blue, label="approx posterior latent",
);
sampleplot!(f_approx_post(x_pr, 1e-6), 10; color=:blue);
plot(p1, p2; layout=(2, 1))
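For reference, the quadrature requested by GaussHermiteQuadrature(10) in the example above can be sketched as follows. This is the standard Gauss-Hermite rule, not a transcription of the package's internals. The symbols m_n and s_n (mean and standard deviation of the approximate posterior marginal at x_n) and t_i and w_i (the Gauss-Hermite nodes and weights) are notation introduced here for illustration.

```latex
\mathbb{E}_{q(f_n)}\big[\log p(y_n \mid f_n)\big]
  \approx \frac{1}{\sqrt{\pi}} \sum_{i=1}^{10} w_i \,
  \log p\big(y_n \mid m_n + \sqrt{2}\, s_n t_i\big),
\qquad q(f_n) = \mathcal{N}(m_n, s_n^2)

\log p(y_n \mid f_n) = -f_n - y_n e^{-f_n}
```

The second line is the log-density for the likelihood used in this example: Exponential in Distributions.jl is parameterised by its scale, so Exponential(exp(f)) has mean exp(f).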
See the examples directory for more.
Limitations
This approximation does not presently play nicely with pseudo-point approximations. It could be extended to do so, in line with [2].
References
[1] - Khan, Mohammad, and Wu Lin. "Conjugate-computation variational inference: Converting variational inference in non-conjugate models to inferences in conjugate models." Artificial Intelligence and Statistics. PMLR, 2017.
[2] - Adam, Vincent, et al. "Dual parameterization of sparse variational Gaussian processes." NeurIPS, 2021.