Is this GLM interface ok? #35

Open
JockLawrie opened this issue Oct 3, 2022 · 0 comments

Hi there,

I have a project that utilises hundreds of models of different types, such as GLMs and boosted trees.
This package looks like it can provide a unified interface for training them and making predictions.
Below is my first attempt at implementing the interface for GLMs, together with a simple usage script.
Note that `predict` returns a distribution (e.g. a `Normal` for a Gaussian GLM) rather than a point prediction.
Is this implementation on track? Have I missed/misunderstood anything?

Cheers,
Jock
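Because `predict` returns a distribution rather than a number, a caller can still recover a point prediction, and more, with standard Distributions.jl functions. A minimal sketch, using an arbitrary `Normal` as a stand-in for one prediction (it is an illustration, not part of the proposed interface):

```julia
using Distributions

# Stand-in for one predicted distribution, e.g. what predict(model, x)
# returns for a Normal/identity-link GLM (parameter values are arbitrary).
d = Normal(3.0, 1.0)

point    = mean(d)                        # point prediction
med      = median(d)                      # equals the mean for a symmetric distribution
interval = quantile.(d, [0.025, 0.975])   # central 95% prediction interval
```

Returning the distribution keeps the choice of summary (mean, median, interval) with the caller instead of baking it into the model.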

using Distributions
using GLM
using LinearAlgebra: dot
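using Models
# Import explicitly so the methods below extend Models' functions rather than
# creating new functions in Main (GLM also exports `fit` and `predict`).
import Models: fit, predict, estimate_type, output_type, predict_input_type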

# Types
struct GLMTemplate{R <: UnivariateDistribution, L <: Link, V} <: Template
    responsedist::R
    link::L
    kwargs::Dict{Symbol, V}
end

GLMTemplate(responsedist, link) = GLMTemplate(responsedist, link, Dict{Symbol, Int}())

struct GLMModel{R <: UnivariateDistribution, L <: Link, S} <: Model
    responsedist::R
    link::L
    coef::Vector{Float64}
    vcov::Matrix{Float64}
    scale::S  # Nothing or Float64
end

estimate_type(::GLMTemplate) = DistributionEstimate
estimate_type(::GLMModel)    = DistributionEstimate
output_type(::GLMTemplate) = SingleOutput
output_type(::GLMModel)    = SingleOutput
predict_input_type(::GLMTemplate) = PointPredictInput
predict_input_type(::GLMModel)    = PointPredictInput

# Fit
function fit(template::GLMTemplate, y, X)
    model = glm(X, y, template.responsedist, template.link; template.kwargs...)
    constructmodel(template, model)
end

function fit(template::GLMTemplate, y, X, wts)
    model = glm(X, y, template.responsedist, template.link; wts=wts, template.kwargs...)
    constructmodel(template, model)
end

function constructmodel(template::GLMTemplate, fittedmodel)
    _coef  = GLM.coef(fittedmodel)
    _vcov  = GLM.vcov(fittedmodel)
    _scale = scaleparameter(template.responsedist, GLM.dispersion(fittedmodel))
    GLMModel(template.responsedist, template.link, _coef, _vcov, _scale)
end

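# Second parameter of the response distribution, derived from GLM's dispersion.
# GLM.dispersion(m) returns sqrt(φ), which for Normal is the standard deviation σ.
scaleparameter(d, disp) = nothing
scaleparameter(d::Normal, disp) = disp      # σ
scaleparameter(d::Gamma,  disp) = disp^2    # φ, with Var(y) = φ * mu^2

# Distributions' Gamma takes (shape, scale); mean mu and dispersion φ give shape = 1/φ, scale = mu * φ.
_predict(d::Gamma, mu, phi::Real) = Gamma(1 / phi, mu * phi)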

# Predict
function predict(model::GLMModel, x::AbstractVector)
    eta = dot(model.coef, x)
    mu  = GLM.linkinv(model.link, eta)
    _predict(model.responsedist, mu, model.scale)
end

function predict(model::GLMModel, X::AbstractMatrix)
    # One predicted distribution per row (observation) of X
    map(row -> predict(model, row), eachrow(X))
end

_predict(d, mu, s::Nothing) = typeof(d)(mu)
_predict(d, mu, s) = typeof(d)(mu, s)
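A note on the `Gamma` case: GLM's Gamma family is described by the mean μ and a dispersion φ with Var(y) = φμ², whereas `Distributions.Gamma(α, θ)` takes a shape α and a scale θ, with mean αθ and variance αθ². Matching the two gives α = 1/φ and θ = μφ. As far as I can tell from GLM.jl, `dispersion(model)` returns √φ unless its second argument `sqr` is `true`. A small check of the mapping, where `gamma_from_mean_dispersion` is a hypothetical helper and μ, φ are arbitrary:

```julia
using Distributions

# Hypothetical helper: GLM-style (mean, dispersion) -> Distributions' Gamma(shape, scale).
gamma_from_mean_dispersion(mu, phi) = Gamma(1 / phi, mu * phi)

mu, phi = 2.0, 0.5    # arbitrary illustrative values
d = gamma_from_mean_dispersion(mu, phi)

params(d)   # (shape, scale) = (2.0, 1.0)
mean(d)     # α θ  = (1/φ)(μφ)  = μ      → 2.0
var(d)      # α θ² = (1/φ)(μφ)² = φ μ²   → 2.0
```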

And a simple script:

# Data
N = 1000;
k = 3; # Number of predictors including the intercept
X = rand(N, k);
X[:, 1] .= 1.0;  # set the intercept column in place (X[:, 1] alone would be a copy)
btrue = collect(1.0:k);
strue = 1.0;
y = X*btrue .+ strue .* randn(N);
w = 1.0 .+ rand(N);
m = N / sum(w);
w .*= m;

# Template
template = GLMTemplate(Normal(), IdentityLink())

# Fit unweighted
model0 = glm(X, y, template.responsedist, template.link)
model  = fit(template, y, X)
model.coef == coef(model0)
model.vcov == vcov(model0)

# Fit weighted
model0 = glm(X, y, template.responsedist, template.link; wts=w)
model  = fit(template, y, X, w)
model.coef == coef(model0)
model.vcov == vcov(model0)

# Predict a single observation
Xnew = X[1, :]
predict(model,  Xnew)

# Predict several observations
Xnew = X[1:5, :]
GLM.predict(model0, Xnew)
predict(model,  Xnew)
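Since each prediction is a full distribution, predictions can be scored against observed outcomes with the log predictive density, which a point prediction alone does not support. A minimal sketch with hand-made `Normal` predictions standing in for `predict(model, Xnew)` (all values are illustrative):

```julia
using Distributions

# Stand-ins for a vector of predicted distributions, e.g. predict(model, Xnew).
preds = [Normal(1.0, 1.0), Normal(2.0, 1.0), Normal(3.0, 2.0)]
yobs  = [1.0, 2.5, 2.0]   # observed outcomes (illustrative)

# Total log predictive density of the observations: higher is better.
loglik = sum(logpdf.(preds, yobs))

# Point predictions remain available, e.g. for a mean squared error comparison.
mse = sum(abs2, yobs .- mean.(preds)) / length(yobs)
```

The same scoring code works unchanged for any model type whose `predict` returns `UnivariateDistribution`s, which is one argument for a `DistributionEstimate` interface across many model types.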