- # Parallel filtering and smoothing for a lgssm.
- # This implementation is adapted from the work of Adrien Correnflos in,
- # https://github.com/EEA-sensors/sequential-parallelization-examples/
+ '''
+ Parallel filtering and smoothing for an LGSSM.
+
+ This implementation is adapted from the work of Adrien Correnflos:
+ https://github.com/EEA-sensors/sequential-parallelization-examples/
+
+ Note that in the original implementation, the initial state distribution
+ applies to t=0, and the first emission occurs at time `t=1` (i.e. after
+ the initial state has been transformed by the dynamics), whereas here,
+ the first emission occurs at time `t=0` and is produced directly by the
+ untransformed initial state (see below).
+
+ Särkkä et al.
+
+            F₀,Q₀          F₁,Q₁          F₂,Q₂
+     Z₀ ─────────── Z₁ ─────────── Z₂ ─────────── Z₃ ─────...
+                    |              |              |
+                    | H₁,R₁        | H₂,R₂        | H₃,R₃
+                    |              |              |
+                    Y₁             Y₂             Y₃
+
+ Dynamax
+
+            F₀,Q₀          F₁,Q₁          F₂,Q₂
+     Z₀ ─────────── Z₁ ─────────── Z₂ ─────────── Z₃ ─────...
+     |              |              |              |
+     | H₀,R₀        | H₁,R₁        | H₂,R₂        | H₃,R₃
+     |              |              |              |
+     Y₀             Y₁             Y₂             Y₃
+
+ '''
+
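In the Dynamax convention pictured above, the generative model assumed by the routines below is (a sketch; `b_t` and `d_t` are the optional dynamics and emission biases read from `params.dynamics.bias` and `params.emissions.bias`):

    z_0 ~ N(m, P)
    z_{t+1} = F_t z_t + b_t + w_t,   w_t ~ N(0, Q_t),   t = 0, ..., T-2
    y_t     = H_t z_t + d_t + v_t,   v_t ~ N(0, R_t),   t = 0, ..., T-1

so `emissions[0]` is explained directly by the untransformed initial state.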
import jax.numpy as jnp
- import jax.scipy as jsc
from jax import vmap, lax
from tensorflow_probability.substrates.jax.distributions import MultivariateNormalFullCovariance as MVN
from jaxtyping import Array, Float
+ from typing import NamedTuple
+ from dynamax.types import PRNGKey
+ from functools import partial

- from dynamax.utils.utils import psd_solve
+ from jax.scipy.linalg import cho_solve, cho_factor
+ from dynamax.utils.utils import symmetrize

from dynamax.linear_gaussian_ssm import PosteriorGSSMFiltered, PosteriorGSSMSmoothed, ParamsLGSSM

+
def _get_params(x, dim, t):
    if callable(x):
        return x(t)
    elif x.ndim == dim + 1:
        return x[t]
    else:
        return x
+
+ #---------------------------------------------------------------------------#
+ # Filtering                                                                  #
+ #---------------------------------------------------------------------------#

- def _make_associative_filtering_elements(params, emissions):
+ class FilterMessage(NamedTuple):
+     """
+     Filtering associative scan elements.
+
+     Attributes:
+         A: P(z_j | y_{i:j}, z_{i-1}) weights.
+         b: P(z_j | y_{i:j}, z_{i-1}) bias.
+         C: P(z_j | y_{i:j}, z_{i-1}) covariance.
+         J: P(z_{i-1} | y_{i:j}) covariance.
+         eta: P(z_{i-1} | y_{i:j}) mean.
+         logZ: -log P(y_{i:j}) contribution to the marginal log-likelihood.
+     """
+     A: Float[Array, "ntime state_dim state_dim"]
+     b: Float[Array, "ntime state_dim"]
+     C: Float[Array, "ntime state_dim state_dim"]
+     J: Float[Array, "ntime state_dim state_dim"]
+     eta: Float[Array, "ntime state_dim"]
+     logZ: Float[Array, "ntime"]
+
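For reference, a sketch of the parametrization these messages are intended to carry, following Särkkä and García-Fernández (2021) and the notation of the docstring above:

    p(z_j | y_{i:j}, z_{i-1}) = N(z_j; A z_{i-1} + b, C)
    p(y_{i:j} | z_{i-1}) ∝ exp(z_{i-1}' eta - 0.5 z_{i-1}' J z_{i-1})

i.e. `(eta, J)` carry the dependence on the previous state in information form, and `logZ` accumulates the -log P(y_{i:j}) terms, which is why `lgssm_filter` below reports `marginal_loglik = -logZ[-1]` after the scan.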
+
+ def _initialize_filtering_messages(params, emissions):
    """Preprocess observations to construct input for filtering associative scan."""

-     def _first_filtering_element(params, y):
-         F = _get_params(params.dynamics.weights, 2, 0)
+     def _first_message(params, y):
+         # Condition the prior N(m, P) directly on the first emission y = emissions[0].
        H = _get_params(params.emissions.weights, 2, 0)
-         Q = _get_params(params.dynamics.cov, 2, 0)
        R = _get_params(params.emissions.cov, 2, 0)
+         d = _get_params(params.emissions.bias, 1, 0)
+         m = params.initial.mean
+         P = params.initial.cov

-         S = H @ Q @ H.T + R
-         CF, low = jsc.linalg.cho_factor(S)
+         S = H @ P @ H.T + R
+         CF, low = cho_factor(S)
+         K = cho_solve((CF, low), H @ P).T

-         m1 = params.initial.mean
-         P1 = params.initial.cov
-         S1 = H @ P1 @ H.T + R
-         K1 = psd_solve(S1, H @ P1).T
-
-         A = jnp.zeros_like(F)
-         b = m1 + K1 @ (y - H @ m1)
-         C = P1 - K1 @ S1 @ K1.T
-
-         eta = F.T @ H.T @ jsc.linalg.cho_solve((CF, low), y)
-         J = F.T @ H.T @ jsc.linalg.cho_solve((CF, low), H @ F)
-
-         logZ = -MVN(loc=jnp.zeros_like(y), covariance_matrix=H @ P1 @ H.T + R).log_prob(y)
+         A = jnp.zeros_like(P)
+         b = m + K @ (y - H @ m - d)
+         C = symmetrize(P - K @ S @ K.T)
+         eta = jnp.zeros_like(b)
+         J = jnp.eye(len(b))

+         logZ = -MVN(loc=jnp.zeros_like(y), covariance_matrix=H @ P @ H.T + R).log_prob(y)
        return A, b, C, J, eta, logZ


-     def _generic_filtering_element(params, y, t):
+     @partial(vmap, in_axes=(None, 0, 0))
+     def _generic_message(params, y, t):
+         # One-step Kalman update for emissions[t + 1], conditioned on the previous state.
        F = _get_params(params.dynamics.weights, 2, t)
-         H = _get_params(params.emissions.weights, 2, t + 1)
        Q = _get_params(params.dynamics.cov, 2, t)
+         b = _get_params(params.dynamics.bias, 1, t)
+         H = _get_params(params.emissions.weights, 2, t + 1)
        R = _get_params(params.emissions.cov, 2, t + 1)
+         d = _get_params(params.emissions.bias, 1, t + 1)

        S = H @ Q @ H.T + R
-         CF, low = jsc.linalg.cho_factor(S)
-         K = jsc.linalg.cho_solve((CF, low), H @ Q).T
-         A = F - K @ H @ F
-         b = K @ y
-         C = Q - K @ H @ Q
+         CF, low = cho_factor(S)
+         K = cho_solve((CF, low), H @ Q).T

-         eta = F.T @ H.T @ jsc.linalg.cho_solve((CF, low), y)
-         J = F.T @ H.T @ jsc.linalg.cho_solve((CF, low), H @ F)
+         eta = F.T @ H.T @ cho_solve((CF, low), y - H @ b - d)
+         J = symmetrize(F.T @ H.T @ cho_solve((CF, low), H @ F))

-         logZ = -MVN(loc=jnp.zeros_like(y), covariance_matrix=S).log_prob(y)
+         A = F - K @ H @ F
+         b = b + K @ (y - H @ b - d)
+         C = symmetrize(Q - K @ H @ Q)

+         logZ = -MVN(loc=jnp.zeros_like(y), covariance_matrix=S).log_prob(y)
        return A, b, C, J, eta, logZ

-     first_elems = _first_filtering_element(params, emissions[0])
-     generic_elems = vmap(_generic_filtering_element, (None, 0, 0))(params, emissions[1:], jnp.arange(len(emissions) - 1))
-     combined_elems = tuple(jnp.concatenate((first_elm[None, ...], gen_elm))
-                            for first_elm, gen_elm in zip(first_elems, generic_elems))
-     return combined_elems
+
+     A0, b0, C0, J0, eta0, logZ0 = _first_message(params, emissions[0])
+     At, bt, Ct, Jt, etat, logZt = _generic_message(params, emissions[1:], jnp.arange(len(emissions) - 1))
+
+     return FilterMessage(
+         A=jnp.concatenate([A0[None], At]),
+         b=jnp.concatenate([b0[None], bt]),
+         C=jnp.concatenate([C0[None], Ct]),
+         J=jnp.concatenate([J0[None], Jt]),
+         eta=jnp.concatenate([eta0[None], etat]),
+         logZ=jnp.concatenate([logZ0[None], logZt])
+     )
+


def lgssm_filter(
    params: ParamsLGSSM,
@@ -83,71 +146,81 @@ def lgssm_filter(

    Note: This function does not yet handle `inputs` to the system.
    """
-     #TODO: Add input handling.
-     initial_elements = _make_associative_filtering_elements(params, emissions)
-
    @vmap
-     def filtering_operator(elem1, elem2):
+     def _operator(elem1, elem2):
        A1, b1, C1, J1, eta1, logZ1 = elem1
        A2, b2, C2, J2, eta2, logZ2 = elem2
-         dim = A1.shape[0]
-         I = jnp.eye(dim)
+         I = jnp.eye(A1.shape[0])

        I_C1J2 = I + C1 @ J2
-         temp = jsc.linalg.solve(I_C1J2.T, A2.T).T
+         temp = jnp.linalg.solve(I_C1J2.T, A2.T).T
        A = temp @ A1
        b = temp @ (b1 + C1 @ eta2) + b2
-         C = temp @ C1 @ A2.T + C2
+         C = symmetrize(temp @ C1 @ A2.T + C2)

        I_J2C1 = I + J2 @ C1
-         temp = jsc.linalg.solve(I_J2C1.T, A1).T
-
+         temp = jnp.linalg.solve(I_J2C1.T, A1).T
        eta = temp @ (eta2 - J2 @ b1) + eta1
-         J = temp @ J2 @ A1 + J1
-
-         # mu = jsc.linalg.solve(J2, eta2)
-         # t2 = - eta2 @ mu + (b1 - mu) @ jsc.linalg.solve(I_J2C1, (J2 @ b1 - eta2))
+         J = symmetrize(temp @ J2 @ A1 + J1)

        mu = jnp.linalg.solve(C1, b1)
        t1 = (b1 @ mu - (eta2 + mu) @ jnp.linalg.solve(I_C1J2, C1 @ eta2 + b1))
-
        logZ = (logZ1 + logZ2 + 0.5 * jnp.linalg.slogdet(I_C1J2)[1] + 0.5 * t1)
+         return FilterMessage(A, b, C, J, eta, logZ)
+
+     initial_messages = _initialize_filtering_messages(params, emissions)
+     final_messages = lax.associative_scan(_operator, initial_messages)
+
+     return PosteriorGSSMFiltered(
+         filtered_means=final_messages.b,
+         filtered_covariances=final_messages.C,
+         marginal_loglik=-final_messages.logZ[-1])

-         return A, b, C, J, eta, logZ

-     _, filtered_means, filtered_covs, _, _, logZ = lax.associative_scan(
-         filtering_operator, initial_elements
-     )
+ #---------------------------------------------------------------------------#
+ # Smoothing                                                                  #
+ #---------------------------------------------------------------------------#

-     return PosteriorGSSMFiltered(marginal_loglik=-logZ[-1],
-                                  filtered_means=filtered_means, filtered_covariances=filtered_covs)
+ class SmoothMessage(NamedTuple):
+     """
+     Smoothing associative scan elements.
+
+     Attributes:
+         E: P(z_i | y_{1:j}, z_{j+1}) weights.
+         g: P(z_i | y_{1:j}, z_{j+1}) bias.
+         L: P(z_i | y_{1:j}, z_{j+1}) covariance.
+     """
+     E: Float[Array, "ntime state_dim state_dim"]
+     g: Float[Array, "ntime state_dim"]
+     L: Float[Array, "ntime state_dim state_dim"]

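Each smoothing message encodes the backward conditional `z_t | z_{t+1}, y_{1:t} ~ N(E_t z_{t+1} + g_t, L_t)`. A sketch of the closed-form initialization computed by `_initialize_smoothing_messages` below, writing `(m_t, P_t)` for the filtered mean and covariance:

    E_t = P_t F_t' (F_t P_t F_t' + Q_t)^{-1}
    g_t = m_t - E_t (F_t m_t + b_t)
    L_t = P_t - E_t F_t P_t

with final message `(E_T, g_T, L_T) = (0, m_T, P_T)`, so after the reverse associative scan the smoothed marginals sit in `(g, L)`.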
- def _make_associative_smoothing_elements(params, filtered_means, filtered_covariances):
+ def _initialize_smoothing_messages(params, filtered_means, filtered_covariances):
    """Preprocess filtering output to construct input for smoothing associative scan."""

-     def _last_smoothing_element(m, P):
+     def _last_message(m, P):
        return jnp.zeros_like(P), m, P

-     def _generic_smoothing_element(params, m, P, t):
+     @partial(vmap, in_axes=(None, 0, 0, 0))
+     def _generic_message(params, m, P, t):
        F = _get_params(params.dynamics.weights, 2, t)
        Q = _get_params(params.dynamics.cov, 2, t)
+         b = _get_params(params.dynamics.bias, 1, t)

-         Pp = F @ P @ F.T + Q
-
-         E = psd_solve(Pp, F @ P).T
-         g = m - E @ F @ m
-         L = P - E @ Pp @ E.T
+         CF, low = cho_factor(F @ P @ F.T + Q)
+         E = cho_solve((CF, low), F @ P).T
+         g = m - E @ (F @ m + b)
+         L = symmetrize(P - E @ F @ P)
        return E, g, L
-
-     last_elems = _last_smoothing_element(filtered_means[-1], filtered_covariances[-1])
-     generic_elems = vmap(_generic_smoothing_element, (None, 0, 0, 0))(
-         params, filtered_means[:-1], filtered_covariances[:-1], jnp.arange(len(filtered_covariances) - 1)
-     )
-     combined_elems = tuple(jnp.append(gen_elm, last_elm[None, :], axis=0)
-                            for gen_elm, last_elm in zip(generic_elems, last_elems))
-     return combined_elems
+
+     En, gn, Ln = _last_message(filtered_means[-1], filtered_covariances[-1])
+     Et, gt, Lt = _generic_message(params, filtered_means[:-1], filtered_covariances[:-1], jnp.arange(len(filtered_means) - 1))
+
+     return SmoothMessage(
+         E=jnp.concatenate([Et, En[None]]),
+         g=jnp.concatenate([gt, gn[None]]),
+         L=jnp.concatenate([Lt, Ln[None]])
+     )


def lgssm_smoother(
@@ -163,26 +236,78 @@ def lgssm_smoother(
    filtered_posterior = lgssm_filter(params, emissions)
    filtered_means = filtered_posterior.filtered_means
    filtered_covs = filtered_posterior.filtered_covariances
-     initial_elements = _make_associative_smoothing_elements(params, filtered_means, filtered_covs)
-
+
    @vmap
-     def smoothing_operator(elem1, elem2):
+     def _operator(elem1, elem2):
        E1, g1, L1 = elem1
        E2, g2, L2 = elem2
-
        E = E2 @ E1
        g = E2 @ g1 + g2
-         L = E2 @ L1 @ E2.T + L2
-
+         L = symmetrize(E2 @ L1 @ E2.T + L2)
        return E, g, L

-     _, smoothed_means, smoothed_covs, *_ = lax.associative_scan(
-         smoothing_operator, initial_elements, reverse=True
-     )
+     initial_messages = _initialize_smoothing_messages(params, filtered_means, filtered_covs)
+     final_messages = lax.associative_scan(_operator, initial_messages, reverse=True)
+
    return PosteriorGSSMSmoothed(
        marginal_loglik=filtered_posterior.marginal_loglik,
        filtered_means=filtered_means,
        filtered_covariances=filtered_covs,
-         smoothed_means=smoothed_means,
-         smoothed_covariances=smoothed_covs
+         smoothed_means=final_messages.g,
+         smoothed_covariances=final_messages.L
    )
+
+
+ #---------------------------------------------------------------------------#
+ # Sampling                                                                   #
+ #---------------------------------------------------------------------------#
+
+ class SampleMessage(NamedTuple):
+     """
+     Sampling associative scan elements.
+
+     Attributes:
+         E: z_i ~ z_{j+1} weights.
+         h: z_i ~ z_{j+1} bias.
+     """
+     E: Float[Array, "ntime state_dim state_dim"]
+     h: Float[Array, "ntime state_dim"]
+
+
+ def _initialize_sampling_messages(key, params, filtered_means, filtered_covariances):
+     """Preprocess smoothing messages to construct input for the sampling associative scan.
+
+     Given parallel smoothing messages `z_i ~ N(E_i z_{i+1} + g_i, L_i)`,
+     the parallel sampling messages are `(E_i, h_i)` where `h_i ~ N(g_i, L_i)`.
+     """
+     E, g, L = _initialize_smoothing_messages(params, filtered_means, filtered_covariances)
+     return SampleMessage(E=E, h=MVN(g, L).sample(seed=key))
+
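A sketch of why this yields a joint draw from the posterior (a parallel analogue of forward-filtering backward-sampling): the last message has `E_T = 0` and `h_T ~ N(m_T, P_T)`, a draw from the final filtered marginal; every earlier message draws `h_t ~ N(g_t, L_t)` independently; and the reverse scan in `lgssm_posterior_sample` below composes the affine maps

    z_t = E_t z_{t+1} + h_t

so each state is effectively sampled from its backward conditional given the already-sampled successor.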
+
+ def lgssm_posterior_sample(
+     key: PRNGKey,
+     params: ParamsLGSSM,
+     emissions: Float[Array, "ntime emission_dim"]
+ ) -> Float[Array, "ntime state_dim"]:
+     """A parallel version of the lgssm sampling algorithm.
+
+     See S. Särkkä and Á. F. García-Fernández (2021) - https://arxiv.org/abs/1905.13002.
+
+     Note: This function does not yet handle `inputs` to the system.
+     """
+     filtered_posterior = lgssm_filter(params, emissions)
+     filtered_means = filtered_posterior.filtered_means
+     filtered_covs = filtered_posterior.filtered_covariances
+
+     @vmap
+     def _operator(elem1, elem2):
+         E1, h1 = elem1
+         E2, h2 = elem2
+
+         E = E2 @ E1
+         h = E2 @ h1 + h2
+         return E, h
+
+     initial_messages = _initialize_sampling_messages(key, params, filtered_means, filtered_covs)
+     _, samples = lax.associative_scan(_operator, initial_messages, reverse=True)
+     return samples
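Finally, a minimal usage sketch. It assumes this module is importable (e.g. as `dynamax.linear_gaussian_ssm.parallel_inference`) and, to stay self-contained, builds a hypothetical stand-in parameter container from named tuples exposing exactly the fields read above (`initial.mean`, `dynamics.weights`, `emissions.cov`, ...) rather than a full `ParamsLGSSM`:

from collections import namedtuple

import jax.numpy as jnp
import jax.random as jr

from dynamax.linear_gaussian_ssm.parallel_inference import (  # assumed import path
    lgssm_filter, lgssm_smoother, lgssm_posterior_sample)

# Hypothetical stand-in for ParamsLGSSM: plain named tuples (valid JAX pytrees)
# carrying only the attributes accessed by the routines in this module.
Initial = namedtuple("Initial", ["mean", "cov"])
Dynamics = namedtuple("Dynamics", ["weights", "bias", "cov"])
Emissions = namedtuple("Emissions", ["weights", "bias", "cov"])
Params = namedtuple("Params", ["initial", "dynamics", "emissions"])

state_dim, emission_dim, num_timesteps = 2, 3, 100
params = Params(
    initial=Initial(mean=jnp.zeros(state_dim), cov=jnp.eye(state_dim)),
    dynamics=Dynamics(weights=0.9 * jnp.eye(state_dim),               # time-invariant F, b, Q
                      bias=jnp.zeros(state_dim),
                      cov=0.1 * jnp.eye(state_dim)),
    emissions=Emissions(weights=jnp.ones((emission_dim, state_dim)),  # time-invariant H, d, R
                        bias=jnp.zeros(emission_dim),
                        cov=0.5 * jnp.eye(emission_dim)),
)

# Synthetic observations, just to exercise the API.
emissions = jr.normal(jr.PRNGKey(0), (num_timesteps, emission_dim))

filtered = lgssm_filter(params, emissions)    # filtered_means, filtered_covariances, marginal_loglik
smoothed = lgssm_smoother(params, emissions)  # adds smoothed_means, smoothed_covariances

# Draw a few joint posterior trajectories with independent keys.
keys = jr.split(jr.PRNGKey(1), 3)
samples = jnp.stack([lgssm_posterior_sample(k, params, emissions) for k in keys])
print(filtered.marginal_loglik, samples.shape)  # samples.shape == (3, 100, 2)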