
Variational Inference #24

Open

brandonwillard opened this issue Oct 1, 2020 · 0 comments

Labels: enhancement (New feature or request)

brandonwillard (Contributor) commented Oct 1, 2020

Currently, our HMM models are not compatible with the variational inference (VI) implementation in PyMC3, because they use discrete variables (e.g., the state sequences). If we marginalize the state sequences (i.e., integrate them out), we could have a version of our models that is compatible with VI. That marginalized form could consist of the forward-backward probabilities used in tandem with a weight-based mixture Distribution, such as the built-in Mixture class.
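To make the idea concrete, marginalizing a single discrete state S_t turns the observation likelihood into a weighted mixture over the per-state emission densities, so no discrete variable needs to appear in the graph. A minimal NumPy/SciPy sketch (the weights and Poisson rate are made up, purely for illustration):

```python
import numpy as np
from scipy.stats import poisson

# Hypothetical two-state setup: state 0 emits a constant zero,
# state 1 emits Poisson(mu); w holds P(S_t = 0) and P(S_t = 1).
w = np.array([0.3, 0.7])
mu = 5.0
y_t = 4

# Marginal likelihood: p(y_t) = sum_k P(S_t = k) * p(y_t | S_t = k)
marginal_lik = w[0] * float(y_t == 0) + w[1] * poisson.pmf(y_t, mu)
```

This is exactly the quantity a Mixture distribution evaluates when given the state probabilities as its weights.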

The following is an example of such a marginalized model, but this one only uses the forward-pass probabilities:

import numpy as np
import scipy as sp
import scipy.special  # make sp.special available

import theano
import theano.tensor as tt

import pymc3 as pm

from theano import shared


X_tt = shared(X_df.values, name="X", borrow=True)
y_tt = shared(y_df.values.squeeze(), name="y_t", borrow=True)
n_tt = shared(n_df.values, name="n_t", borrow=True)

# Construct an initial beta value obeying
# n_df * sp.special.expit(X_df.dot(beta_mu_0)) == np.mean(y_df[y_df > 0]).squeeze()
beta_mu_0 = np.zeros(X_df.shape[1])
beta_mu_0[0] = sp.special.logit(np.mean(y_df[y_df > 0]).squeeze() / n_df.mean())

alpha_g, beta_g = 1.0, 1.0
a_xi_0, b_xi_0 = 1.0, 1.0
a_xi_1, b_xi_1 = 1.0, 1.0

with pm.Model() as poisson_model:

    # Horseshoe with uniform global scale
    tau_rv = pm.Uniform('tau', 1./X_tt.get_value().shape[0], 1.)
    lambda_rv = pm.HalfCauchy('lambda', tau_rv, shape=(X_tt.get_value().shape[1], ))
    beta_rv = pm.Normal('beta',
                        mu=beta_mu_0, tau=1.0/lambda_rv,
                        shape=(X_df.shape[1], ))

    xi_0_rv = pm.Dirichlet('xi_0', np.r_[a_xi_0, b_xi_0])
    xi_1_rv = pm.Dirichlet('xi_1', np.r_[a_xi_1, b_xi_1])

    Gamma_tt = tt.stack([xi_0_rv, xi_1_rv])
    Gamma_rv = pm.Deterministic('Gamma', Gamma_tt)

    # gamma_0_rv = pm.Dirichlet('pi_0', np.r_[alpha_g, beta_g])
    gamma_0_tt = compute_steady_state(Gamma_tt)

    p_tt = tt.nnet.sigmoid(X_tt.dot(beta_rv))
    p_rv = pm.Deterministic('p_t', p_tt)

    log_lik_t_tt = tt.stack([pm.Constant.dist(0).logp(y_tt),
                             pm.Poisson.dist(n_tt * p_tt).logp(y_tt)]).T

    # Compute forward probabilities
    def log_alpha_step(y_t, log_lik_t, log_alpha_tm1, log_Gamma_t):
        log_alpha_t = log_lik_t + logdotexp(log_Gamma_t, log_alpha_tm1)
        # Normalize so the forward probabilities sum to one at each step
        return log_alpha_t - pm.math.logsumexp(log_alpha_t)

    log_alpha_0_tt = tt.log(gamma_0_tt) + log_lik_t_tt[0]
    log_alpha_0_tt = log_alpha_0_tt - pm.math.logsumexp(log_alpha_0_tt)
    log_alpha_t_tt, _ = theano.scan(fn=log_alpha_step,
                                    sequences=[y_tt, log_lik_t_tt],
                                    non_sequences=[tt.log(Gamma_tt)],
                                    outputs_info=[{"initial": log_alpha_0_tt, "taps": [-1]}],
                                    strict=True,
                                    name='log_alpha_t')

    alpha_t_tt = tt.exp(log_alpha_t_tt)

    V_0_rv = pm.Constant.dist(0, shape=y_tt.get_value().shape[0])
    V_1_rv = pm.Poisson.dist(n_tt * p_tt, shape=y_tt.get_value().shape[0])

    Y_rv = pm.Mixture('Y_t', w=alpha_t_tt,
                      comp_dists=[V_0_rv, V_1_rv],
                      observed=y_tt)

The essential difference is, of course, the alpha_t_tt computation and its use as the weights parameter of Mixture. The alpha_* values correspond to the similarly named terms in the classical Baum-Welch algorithm. Using the marginal state probabilities (i.e., the Baum-Welch gamma values, which also incorporate a backward pass) would be more appropriate, though.
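Those gamma values combine the forward pass above with a backward recursion. A NumPy reference implementation of the smoothing computation (the function name and signature are mine; a Theano version would run scan in both directions):

```python
import numpy as np
from scipy.special import logsumexp

def forward_backward_gammas(log_lik, log_Gamma, log_pi0):
    # log_lik: (T, K) per-state observation log-likelihoods
    # log_Gamma: (K, K) log transition matrix (rows index the "from" state)
    # log_pi0: (K,) log initial state distribution
    T, K = log_lik.shape
    log_alpha = np.empty((T, K))
    log_beta = np.zeros((T, K))  # beta_T = 1 in probability space
    # Forward pass: alpha_t(j) = lik_t(j) * sum_i alpha_{t-1}(i) Gamma[i, j]
    log_alpha[0] = log_pi0 + log_lik[0]
    for t in range(1, T):
        log_alpha[t] = log_lik[t] + logsumexp(
            log_alpha[t - 1][:, None] + log_Gamma, axis=0
        )
    # Backward pass: beta_t(i) = sum_j Gamma[i, j] lik_{t+1}(j) beta_{t+1}(j)
    for t in range(T - 2, -1, -1):
        log_beta[t] = logsumexp(
            log_Gamma + (log_lik[t + 1] + log_beta[t + 1])[None, :], axis=1
        )
    # gamma_t ∝ alpha_t * beta_t, normalized per time step
    log_gamma = log_alpha + log_beta
    log_gamma -= logsumexp(log_gamma, axis=1, keepdims=True)
    return np.exp(log_gamma)
```

The returned rows are the marginal posteriors P(S_t = k | y_1:T) and could be used directly as the Mixture weights in place of alpha_t_tt.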

@brandonwillard brandonwillard added the enhancement New feature or request label Oct 1, 2020