Commit fa4723f

Merge pull request #324 from calebweinreb/lgssm_parallel_inference_with_bias
Refactor of LGSSM parallel inference
2 parents d7f283e + effaba3 commit fa4723f

3 files changed: +312 -140 lines changed

dynamax/linear_gaussian_ssm/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -17,5 +17,6 @@
 
 from dynamax.linear_gaussian_ssm.parallel_inference import lgssm_filter as parallel_lgssm_filter
 from dynamax.linear_gaussian_ssm.parallel_inference import lgssm_smoother as parallel_lgssm_smoother
+from dynamax.linear_gaussian_ssm.parallel_inference import lgssm_posterior_sample as parallel_lgssm_posterior_sample
 
 from dynamax.linear_gaussian_ssm.models import LinearGaussianConjugateSSM, LinearGaussianSSM
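
The `__init__.py` change above re-exports the new parallel posterior sampler alongside the existing parallel filter and smoother. A minimal usage sketch (not part of the commit; it assumes the usual dynamax model API — `LinearGaussianSSM`, `initialize`, `sample` — and the `has_dynamics_bias` / `has_emissions_bias` flags):

import jax.random as jr
from dynamax.linear_gaussian_ssm import (
    LinearGaussianSSM,
    parallel_lgssm_filter,
    parallel_lgssm_smoother,
    parallel_lgssm_posterior_sample,
)

# Build a small LGSSM with dynamics and emission biases (the case this PR adds
# support for in the parallel code path) and simulate some emissions.
model = LinearGaussianSSM(state_dim=2, emission_dim=3,
                          has_dynamics_bias=True, has_emissions_bias=True)
params, _ = model.initialize(jr.PRNGKey(0))
_, emissions = model.sample(params, jr.PRNGKey(1), num_timesteps=100)

filtered = parallel_lgssm_filter(params, emissions)       # time-parallel Kalman filter
smoothed = parallel_lgssm_smoother(params, emissions)     # time-parallel RTS smoother
draw = parallel_lgssm_posterior_sample(jr.PRNGKey(2), params, emissions)
print(filtered.marginal_loglik, draw.shape)               # draw has shape (100, 2)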
dynamax/linear_gaussian_ssm/parallel_inference.py

Lines changed: 212 additions & 87 deletions
@@ -1,77 +1,140 @@
-# Parallel filtering and smoothing for a lgssm.
-# This implementation is adapted from the work of Adrien Correnflos in,
-# https://github.com/EEA-sensors/sequential-parallelization-examples/
+'''
+Parallel filtering and smoothing for an LGSSM.
+
+This implementation is adapted from the work of Adrien Correnflos:
+https://github.com/EEA-sensors/sequential-parallelization-examples/
+
+Note that in the original implementation, the initial state distribution
+applies to t=0, and the first emission occurs at time t=1 (i.e. after
+the initial state has been transformed by the dynamics), whereas here,
+the first emission occurs at time t=0 and is produced directly by the
+untransformed initial state (see below).
+
+Sarkka et al.
+
+          F₀,Q₀         F₁,Q₁         F₂,Q₂
+    Z₀ ─────────── Z₁ ─────────── Z₂ ─────────── Z₃ ─────...
+                   |              |              |
+                   | H₁,R₁        | H₂,R₂        | H₃,R₃
+                   |              |              |
+                   Y₁             Y₂             Y₃
+
+Dynamax
+
+          F₀,Q₀         F₁,Q₁         F₂,Q₂
+    Z₀ ─────────── Z₁ ─────────── Z₂ ─────────── Z₃ ─────...
+    |              |              |              |
+    | H₀,R₀        | H₁,R₁        | H₂,R₂        | H₃,R₃
+    |              |              |              |
+    Y₀             Y₁             Y₂             Y₃
+'''
 import jax.numpy as jnp
-import jax.scipy as jsc
 from jax import vmap, lax
 from tensorflow_probability.substrates.jax.distributions import MultivariateNormalFullCovariance as MVN
 from jaxtyping import Array, Float
+from typing import NamedTuple
+from dynamax.types import PRNGKey
+from functools import partial
 
-from dynamax.utils.utils import psd_solve
+from jax.scipy.linalg import cho_solve, cho_factor
+from dynamax.utils.utils import symmetrize
 from dynamax.linear_gaussian_ssm import PosteriorGSSMFiltered, PosteriorGSSMSmoothed, ParamsLGSSM
 
+
 def _get_params(x, dim, t):
     if callable(x):
         return x(t)
     elif x.ndim == dim + 1:
         return x[t]
     else:
         return x
+
+#---------------------------------------------------------------------------#
+#                                 Filtering                                  #
+#---------------------------------------------------------------------------#
 
-def _make_associative_filtering_elements(params, emissions):
+class FilterMessage(NamedTuple):
+    """
+    Filtering associative scan elements.
+
+    Attributes:
+        A: P(z_j | y_{i:j}, z_{i-1}) weights.
+        b: P(z_j | y_{i:j}, z_{i-1}) bias.
+        C: P(z_j | y_{i:j}, z_{i-1}) covariance.
+        J: P(z_{i-1} | y_{i:j}) covariance.
+        eta: P(z_{i-1} | y_{i:j}) mean.
+    """
+    A: Float[Array, "ntime state_dim state_dim"]
+    b: Float[Array, "ntime state_dim"]
+    C: Float[Array, "ntime state_dim state_dim"]
+    J: Float[Array, "ntime state_dim state_dim"]
+    eta: Float[Array, "ntime state_dim"]
+    logZ: Float[Array, "ntime"]
+
+
+def _initialize_filtering_messages(params, emissions):
     """Preprocess observations to construct input for filtering associative scan."""
 
-    def _first_filtering_element(params, y):
-        F = _get_params(params.dynamics.weights, 2, 0)
+    def _first_message(params, y):
         H = _get_params(params.emissions.weights, 2, 0)
-        Q = _get_params(params.dynamics.cov, 2, 0)
         R = _get_params(params.emissions.cov, 2, 0)
+        d = _get_params(params.emissions.bias, 1, 0)
+        m = params.initial.mean
+        P = params.initial.cov
 
-        S = H @ Q @ H.T + R
-        CF, low = jsc.linalg.cho_factor(S)
+        S = H @ P @ H.T + R
+        CF, low = cho_factor(S)
+        K = cho_solve((CF, low), H @ P).T
 
-        m1 = params.initial.mean
-        P1 = params.initial.cov
-        S1 = H @ P1 @ H.T + R
-        K1 = psd_solve(S1, H @ P1).T
-
-        A = jnp.zeros_like(F)
-        b = m1 + K1 @ (y - H @ m1)
-        C = P1 - K1 @ S1 @ K1.T
-
-        eta = F.T @ H.T @ jsc.linalg.cho_solve((CF, low), y)
-        J = F.T @ H.T @ jsc.linalg.cho_solve((CF, low), H @ F)
-
-        logZ = -MVN(loc=jnp.zeros_like(y), covariance_matrix=H @ P1 @ H.T + R).log_prob(y)
+        A = jnp.zeros_like(P)
+        b = m + K @ (y - H @ m - d)
+        C = symmetrize(P - K @ S @ K.T)
+        eta = jnp.zeros_like(b)
+        J = jnp.eye(len(b))
 
+        logZ = -MVN(loc=jnp.zeros_like(y), covariance_matrix=H @ P @ H.T + R).log_prob(y)
         return A, b, C, J, eta, logZ
 
 
-    def _generic_filtering_element(params, y, t):
+    @partial(vmap, in_axes=(None, 0, 0))
+    def _generic_message(params, y, t):
         F = _get_params(params.dynamics.weights, 2, t)
-        H = _get_params(params.emissions.weights, 2, t+1)
         Q = _get_params(params.dynamics.cov, 2, t)
+        b = _get_params(params.dynamics.bias, 1, t)
+        H = _get_params(params.emissions.weights, 2, t+1)
         R = _get_params(params.emissions.cov, 2, t+1)
+        d = _get_params(params.emissions.bias, 1, t+1)
 
         S = H @ Q @ H.T + R
-        CF, low = jsc.linalg.cho_factor(S)
-        K = jsc.linalg.cho_solve((CF, low), H @ Q).T
-        A = F - K @ H @ F
-        b = K @ y
-        C = Q - K @ H @ Q
+        CF, low = cho_factor(S)
+        K = cho_solve((CF, low), H @ Q).T
 
-        eta = F.T @ H.T @ jsc.linalg.cho_solve((CF, low), y)
-        J = F.T @ H.T @ jsc.linalg.cho_solve((CF, low), H @ F)
+        eta = F.T @ H.T @ cho_solve((CF, low), y - H @ b - d)
+        J = symmetrize(F.T @ H.T @ cho_solve((CF, low), H @ F))
 
-        logZ = -MVN(loc=jnp.zeros_like(y), covariance_matrix=S).log_prob(y)
+        A = F - K @ H @ F
+        b = b + K @ (y - H @ b - d)
+        C = symmetrize(Q - K @ H @ Q)
 
+        logZ = -MVN(loc=jnp.zeros_like(y), covariance_matrix=S).log_prob(y)
         return A, b, C, J, eta, logZ
 
-    first_elems = _first_filtering_element(params, emissions[0])
-    generic_elems = vmap(_generic_filtering_element, (None, 0, 0))(params, emissions[1:], jnp.arange(len(emissions)-1))
-    combined_elems = tuple(jnp.concatenate((first_elm[None,...], gen_elm))
-                           for first_elm, gen_elm in zip(first_elems, generic_elems))
-    return combined_elems
+
+    A0, b0, C0, J0, eta0, logZ0 = _first_message(params, emissions[0])
+    At, bt, Ct, Jt, etat, logZt = _generic_message(params, emissions[1:], jnp.arange(len(emissions)-1))
+
+    return FilterMessage(
+        A=jnp.concatenate([A0[None], At]),
+        b=jnp.concatenate([b0[None], bt]),
+        C=jnp.concatenate([C0[None], Ct]),
+        J=jnp.concatenate([J0[None], Jt]),
+        eta=jnp.concatenate([eta0[None], etat]),
+        logZ=jnp.concatenate([logZ0[None], logZt])
+    )
+
 
 
 def lgssm_filter(
     params: ParamsLGSSM,
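
The hunk above packages each time step's conditional Gaussian into a `FilterMessage` precisely so that `lax.associative_scan` can combine them; because message composition is associative, the scan runs in logarithmic depth on parallel hardware instead of the linear depth of a sequential Kalman filter. A stripped-down sketch of that pattern (not part of the commit; it composes plain affine maps z_t = A_t z_{t-1} + b_t and ignores covariances and likelihoods):

import jax.numpy as jnp
from jax import vmap, lax

def compose(elem1, elem2):
    # Compose z -> A1 @ z + b1 followed by z -> A2 @ z + b2.
    A1, b1 = elem1
    A2, b2 = elem2
    return A2 @ A1, A2 @ b1 + b2

T, D = 6, 2
A = jnp.stack([0.9 * jnp.eye(D)] * T)   # per-step weights, shape (T, D, D)
b = jnp.ones((T, D))                    # per-step biases, shape (T, D)

# Prefix-compose all T maps in O(log T) parallel depth.
A_cum, b_cum = lax.associative_scan(vmap(compose), (A, b))

# b_cum[t] equals z_t computed sequentially from z_{-1} = 0.
z, zs = jnp.zeros(D), []
for t in range(T):
    z = A[t] @ z + b[t]
    zs.append(z)
assert jnp.allclose(b_cum, jnp.stack(zs), atol=1e-4)

The full filtering operator in the diff carries the extra C, J, eta, and logZ terms so that covariances and the marginal log likelihood compose in the same associative way.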
@@ -83,71 +146,81 @@ def lgssm_filter(
 
     Note: This function does not yet handle `inputs` to the system.
     """
-    #TODO: Add input handling.
-    initial_elements = _make_associative_filtering_elements(params, emissions)
-
     @vmap
-    def filtering_operator(elem1, elem2):
+    def _operator(elem1, elem2):
         A1, b1, C1, J1, eta1, logZ1 = elem1
         A2, b2, C2, J2, eta2, logZ2 = elem2
-        dim = A1.shape[0]
-        I = jnp.eye(dim)
+        I = jnp.eye(A1.shape[0])
 
         I_C1J2 = I + C1 @ J2
-        temp = jsc.linalg.solve(I_C1J2.T, A2.T).T
+        temp = jnp.linalg.solve(I_C1J2.T, A2.T).T
         A = temp @ A1
         b = temp @ (b1 + C1 @ eta2) + b2
-        C = temp @ C1 @ A2.T + C2
+        C = symmetrize(temp @ C1 @ A2.T + C2)
 
         I_J2C1 = I + J2 @ C1
-        temp = jsc.linalg.solve(I_J2C1.T, A1).T
-
+        temp = jnp.linalg.solve(I_J2C1.T, A1).T
         eta = temp @ (eta2 - J2 @ b1) + eta1
-        J = temp @ J2 @ A1 + J1
-
-        # mu = jsc.linalg.solve(J2, eta2)
-        # t2 = - eta2 @ mu + (b1 - mu) @ jsc.linalg.solve(I_J2C1, (J2 @ b1 - eta2))
+        J = symmetrize(temp @ J2 @ A1 + J1)
 
         mu = jnp.linalg.solve(C1, b1)
         t1 = (b1 @ mu - (eta2 + mu) @ jnp.linalg.solve(I_C1J2, C1 @ eta2 + b1))
-
         logZ = (logZ1 + logZ2 + 0.5 * jnp.linalg.slogdet(I_C1J2)[1] + 0.5 * t1)
+        return FilterMessage(A, b, C, J, eta, logZ)
+
+    initial_messages = _initialize_filtering_messages(params, emissions)
+    final_messages = lax.associative_scan(_operator, initial_messages)
+
+    return PosteriorGSSMFiltered(
+        filtered_means=final_messages.b,
+        filtered_covariances=final_messages.C,
+        marginal_loglik=-final_messages.logZ[-1])
 
-        return A, b, C, J, eta, logZ
 
-    _, filtered_means, filtered_covs, _, _, logZ = lax.associative_scan(
-        filtering_operator, initial_elements
-    )
+#---------------------------------------------------------------------------#
+#                                 Smoothing                                  #
+#---------------------------------------------------------------------------#
 
-    return PosteriorGSSMFiltered(marginal_loglik=-logZ[-1],
-                                 filtered_means=filtered_means, filtered_covariances=filtered_covs)
+class SmoothMessage(NamedTuple):
+    """
+    Smoothing associative scan elements.
 
+    Attributes:
+        E: P(z_i | y_{1:j}, z_{j+1}) weights.
+        g: P(z_i | y_{1:j}, z_{j+1}) bias.
+        L: P(z_i | y_{1:j}, z_{j+1}) covariance.
+    """
+    E: Float[Array, "ntime state_dim state_dim"]
+    g: Float[Array, "ntime state_dim"]
+    L: Float[Array, "ntime state_dim state_dim"]
 
 
-def _make_associative_smoothing_elements(params, filtered_means, filtered_covariances):
+def _initialize_smoothing_messages(params, filtered_means, filtered_covariances):
     """Preprocess filtering output to construct input for smoothing associative scan."""
 
-    def _last_smoothing_element(m, P):
+    def _last_message(m, P):
         return jnp.zeros_like(P), m, P
 
-    def _generic_smoothing_element(params, m, P, t):
+    @partial(vmap, in_axes=(None, 0, 0, 0))
+    def _generic_message(params, m, P, t):
         F = _get_params(params.dynamics.weights, 2, t)
         Q = _get_params(params.dynamics.cov, 2, t)
+        b = _get_params(params.dynamics.bias, 1, t)
 
-        Pp = F @ P @ F.T + Q
-
-        E = psd_solve(Pp, F @ P).T
-        g = m - E @ F @ m
-        L = P - E @ Pp @ E.T
+        CF, low = cho_factor(F @ P @ F.T + Q)
+        E = cho_solve((CF, low), F @ P).T
+        g = m - E @ (F @ m + b)
+        L = symmetrize(P - E @ F @ P)
         return E, g, L
-
-    last_elems = _last_smoothing_element(filtered_means[-1], filtered_covariances[-1])
-    generic_elems = vmap(_generic_smoothing_element, (None, 0, 0, 0))(
-        params, filtered_means[:-1], filtered_covariances[:-1], jnp.arange(len(filtered_covariances)-1)
-    )
-    combined_elems = tuple(jnp.append(gen_elm, last_elm[None,:], axis=0)
-                           for gen_elm, last_elm in zip(generic_elems, last_elems))
-    return combined_elems
+
+    En, gn, Ln = _last_message(filtered_means[-1], filtered_covariances[-1])
+    Et, gt, Lt = _generic_message(params, filtered_means[:-1], filtered_covariances[:-1], jnp.arange(len(filtered_means)-1))
+
+    return SmoothMessage(
+        E=jnp.concatenate([Et, En[None]]),
+        g=jnp.concatenate([gt, gn[None]]),
+        L=jnp.concatenate([Lt, Ln[None]])
+    )
 
 
 def lgssm_smoother(
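
Since the smoothing scan only re-expresses the sequential RTS recursion, its output should agree with dynamax's sequential smoother. A quick sanity check in that spirit (not part of the commit; it assumes the sequential `lgssm_smoother` re-exported from `dynamax.linear_gaussian_ssm` and the same hedged model API as the earlier sketch):

import jax.numpy as jnp
import jax.random as jr
from dynamax.linear_gaussian_ssm import (
    LinearGaussianSSM, lgssm_smoother, parallel_lgssm_smoother)

model = LinearGaussianSSM(state_dim=2, emission_dim=3,
                          has_dynamics_bias=True, has_emissions_bias=True)
params, _ = model.initialize(jr.PRNGKey(0))
_, emissions = model.sample(params, jr.PRNGKey(1), num_timesteps=50)

serial = lgssm_smoother(params, emissions)              # sequential forward-backward
parallel = parallel_lgssm_smoother(params, emissions)   # associative-scan version

assert jnp.allclose(serial.smoothed_means, parallel.smoothed_means, atol=1e-3)
assert jnp.allclose(serial.smoothed_covariances, parallel.smoothed_covariances, atol=1e-3)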
@@ -163,26 +236,78 @@ def lgssm_smoother(
     filtered_posterior = lgssm_filter(params, emissions)
     filtered_means = filtered_posterior.filtered_means
     filtered_covs = filtered_posterior.filtered_covariances
-    initial_elements = _make_associative_smoothing_elements(params, filtered_means, filtered_covs)
-
+
     @vmap
-    def smoothing_operator(elem1, elem2):
+    def _operator(elem1, elem2):
         E1, g1, L1 = elem1
         E2, g2, L2 = elem2
-
         E = E2 @ E1
         g = E2 @ g1 + g2
-        L = E2 @ L1 @ E2.T + L2
-
+        L = symmetrize(E2 @ L1 @ E2.T + L2)
         return E, g, L
 
-    _, smoothed_means, smoothed_covs, *_ = lax.associative_scan(
-        smoothing_operator, initial_elements, reverse=True
-    )
+    initial_messages = _initialize_smoothing_messages(params, filtered_means, filtered_covs)
+    final_messages = lax.associative_scan(_operator, initial_messages, reverse=True)
+
     return PosteriorGSSMSmoothed(
         marginal_loglik=filtered_posterior.marginal_loglik,
         filtered_means=filtered_means,
         filtered_covariances=filtered_covs,
-        smoothed_means=smoothed_means,
-        smoothed_covariances=smoothed_covs
+        smoothed_means=final_messages.g,
+        smoothed_covariances=final_messages.L
     )
+
+
+#---------------------------------------------------------------------------#
+#                                 Sampling                                   #
+#---------------------------------------------------------------------------#
+
+class SampleMessage(NamedTuple):
+    """
+    Sampling associative scan elements.
+
+    Attributes:
+        E: z_i ~ z_{j+1} weights.
+        h: z_i ~ z_{j+1} bias.
+    """
+    E: Float[Array, "ntime state_dim state_dim"]
+    h: Float[Array, "ntime state_dim"]
+
+
+def _initialize_sampling_messages(key, params, filtered_means, filtered_covariances):
+    """A parallel version of the lgssm sampling algorithm.
+
+    Given parallel smoothing messages `z_i ~ N(E_i z_{i+1} + g_i, L_i)`,
+    the parallel sampling messages are `(E_i, h_i)` where `h_i ~ N(g_i, L_i)`.
+    """
+    E, g, L = _initialize_smoothing_messages(params, filtered_means, filtered_covariances)
+    return SampleMessage(E=E, h=MVN(g, L).sample(seed=key))
+
+
+def lgssm_posterior_sample(
+    key: PRNGKey,
+    params: ParamsLGSSM,
+    emissions: Float[Array, "ntime emission_dim"]
+) -> Float[Array, "ntime state_dim"]:
+    """A parallel version of the lgssm sampling algorithm.
+
+    See S. Särkkä and Á. F. García-Fernández (2021) - https://arxiv.org/abs/1905.13002.
+
+    Note: This function does not yet handle `inputs` to the system.
+    """
+    filtered_posterior = lgssm_filter(params, emissions)
+    filtered_means = filtered_posterior.filtered_means
+    filtered_covs = filtered_posterior.filtered_covariances
+
+    @vmap
+    def _operator(elem1, elem2):
+        E1, h1 = elem1
+        E2, h2 = elem2
+
+        E = E2 @ E1
+        h = E2 @ h1 + h2
+        return E, h
+
+    initial_messages = _initialize_sampling_messages(key, params, filtered_means, filtered_covs)
+    _, samples = lax.associative_scan(_operator, initial_messages, reverse=True)
+    return samples
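
Because `lgssm_posterior_sample` is a pure function of its PRNG key, several posterior trajectories can be drawn at once by vmapping over keys. A small sketch (again not part of the commit, under the same assumed model API):

import jax.random as jr
from jax import vmap
from dynamax.linear_gaussian_ssm import LinearGaussianSSM, parallel_lgssm_posterior_sample

model = LinearGaussianSSM(state_dim=2, emission_dim=3,
                          has_dynamics_bias=True, has_emissions_bias=True)
params, _ = model.initialize(jr.PRNGKey(0))
_, emissions = model.sample(params, jr.PRNGKey(1), num_timesteps=50)

keys = jr.split(jr.PRNGKey(2), 10)   # 10 independent posterior draws
draws = vmap(parallel_lgssm_posterior_sample, in_axes=(0, None, None))(keys, params, emissions)
# draws has shape (10, 50, 2): samples x time x state_dim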
