/
interpolate_spline.py
291 lines (227 loc) · 11 KB
/
interpolate_spline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Polyharmonic spline interpolation."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
# Lower bound applied to squared distances in `_phi` before taking sqrt, log
# or pow, so that those ops (and their gradients) never see an exact zero.
EPSILON = 1e-10
def _cross_squared_distance_matrix(x, y):
  """Squared Euclidean distance between every row of x and every row of y.

  Uses the expansion ||a - b||^2 = a'a - 2 a'b + b'b so that the whole batch is
  handled with one batched matmul plus broadcasting.

  Args:
    x: [batch_size, n, d] float `Tensor`
    y: [batch_size, m, d] float `Tensor`

  Returns:
    [batch_size, n, m] float `Tensor` whose entry [b, i, j] is
    ||x[b, i, :] - y[b, j, :]||^2
  """
  # Squared norm of each row: [batch_size, n] and [batch_size, m].
  row_sq_norm_x = math_ops.reduce_sum(math_ops.square(x), axis=2)
  row_sq_norm_y = math_ops.reduce_sum(math_ops.square(y), axis=2)

  # Shape the norms as a column ([b, n, 1]) and a row ([b, 1, m]) so that they
  # broadcast against the [b, n, m] inner-product matrix.
  x_term = array_ops.expand_dims(row_sq_norm_x, axis=2)
  y_term = array_ops.expand_dims(row_sq_norm_y, axis=1)

  # inner_products[b, i, j] = x_bi' y_bj
  inner_products = math_ops.matmul(x, y, adjoint_b=True)

  # ||x_bi - y_bj||^2 = x_bi'x_bi - 2 x_bi'y_bj + y_bj'y_bj
  return x_term - 2 * inner_products + y_term
def _pairwise_squared_distance_matrix(x):
  """Squared Euclidean distance between every pair of rows of x.

  Equivalent to _cross_squared_distance_matrix(x, x), but the squared row norms
  are read off the diagonal of the Gram matrix instead of being recomputed.

  Args:
    x: `[batch_size, n, d]` float `Tensor`

  Returns:
    `[batch_size, n, n]` float `Tensor` whose entry [b, i, j] is
    ||x[b, i, :] - x[b, j, :]||^2
  """
  # gram[b, i, j] = x_bi' x_bj; its diagonal holds the squared row norms.
  gram = math_ops.matmul(x, x, adjoint_b=True)
  sq_norms = array_ops.expand_dims(
      array_ops.matrix_diag_part(gram), 2)  # [b, n, 1]

  # ||x_bi - x_bj||^2 = x_bi'x_bi - 2 x_bi'x_bj + x_bj'x_bj
  partial = sq_norms - 2 * gram
  return partial + array_ops.transpose(sq_norms, perm=[0, 2, 1])
def _solve_interpolation(train_points, train_values, order,
                         regularization_weight):
  """Solve for interpolation coefficients.

  Computes the coefficients of the polyharmonic interpolant for the 'training'
  data defined by (train_points, train_values) using the kernel phi.

  The batch size may be unknown at graph-construction time, in which case it is
  read from the runtime shape of `train_points`. The remaining dimensions
  (n, d and k) must be statically known.

  Args:
    train_points: `[b, n, d]` interpolation centers
    train_values: `[b, n, k]` function values
    order: order of the interpolation
    regularization_weight: weight to place on smoothness regularization term.
      It is compared against 0 with a Python `>`; a positive value adds
      `regularization_weight * I` to the kernel matrix.

  Returns:
    w: `[b, n, k]` weights on each interpolation center
    v: `[b, d + 1, k]` weights on each input dimension, followed by the weight
      of the bias term in the last row

  Raises:
    ValueError: if n, d or k is not statically known.
  """
  b, n, d = train_points.get_shape().as_list()
  _, _, k = train_values.get_shape().as_list()

  if n is None or d is None or k is None:
    raise ValueError(
        'The number of training points and the last dimension of both '
        'train_points and train_values must be statically known; got shapes '
        '%s and %s.' % (train_points.get_shape(), train_values.get_shape()))

  if b is None:
    # Unknown static batch size: fall back to the runtime value so the
    # ones/zeros blocks below can still be built.
    b = array_ops.shape(train_points)[0]

  # First, rename variables so that the notation (c, f, w, v, A, B, etc.)
  # follows https://en.wikipedia.org/wiki/Polyharmonic_spline.
  # To account for python style guidelines we use
  # matrix_a for A and matrix_b for B.
  c = train_points
  f = train_values

  # Next, construct the linear system
  #   [ A   B ] [ w ]   [ f ]
  #   [ B'  0 ] [ v ] = [ 0 ]
  with ops.name_scope('construct_linear_system'):
    matrix_a = _phi(_pairwise_squared_distance_matrix(c), order)  # [b, n, n]
    if regularization_weight > 0:
      batch_identity_matrix = np.expand_dims(np.eye(n), 0)
      batch_identity_matrix = constant_op.constant(
          batch_identity_matrix, dtype=train_points.dtype)
      matrix_a += regularization_weight * batch_identity_matrix

    # Append ones to the feature values for the bias term in the linear model.
    ones = array_ops.ones([b, n, 1], train_points.dtype)
    matrix_b = array_ops.concat([c, ones], 2)  # [b, n, d + 1]

    # [b, n + d + 1, n]
    left_block = array_ops.concat(
        [matrix_a, array_ops.transpose(matrix_b, [0, 2, 1])], 1)

    # Plain Python ints are used for the static sizes so that the shape list
    # stays valid when b is a scalar tensor.
    num_b_cols = d + 1
    lhs_zeros = array_ops.zeros([b, num_b_cols, num_b_cols],
                                train_points.dtype)
    right_block = array_ops.concat([matrix_b, lhs_zeros],
                                   1)  # [b, n + d + 1, d + 1]
    lhs = array_ops.concat([left_block, right_block],
                           2)  # [b, n + d + 1, n + d + 1]

    rhs_zeros = array_ops.zeros([b, num_b_cols, k], train_points.dtype)
    rhs = array_ops.concat([f, rhs_zeros], 1)  # [b, n + d + 1, k]

  # Then, solve the linear system and unpack the results.
  with ops.name_scope('solve_linear_system'):
    w_v = linalg_ops.matrix_solve(lhs, rhs)
    w = w_v[:, :n, :]  # first n rows: RBF weights
    v = w_v[:, n:, :]  # remaining d + 1 rows: linear weights and bias

  return w, v
def _apply_interpolation(query_points, train_points, w, v, order):
  """Apply polyharmonic interpolation model to data.

  Given coefficients w and v for the interpolation model, we evaluate
  interpolated function values at query_points.

  The column of ones used for the bias term is derived from `query_points`
  itself, so neither the batch size nor the number of query points has to be
  statically known.

  Args:
    query_points: `[b, m, d]` x values to evaluate the interpolation at
    train_points: `[b, n, d]` x values that act as the interpolation centers
      (the c variables in the wikipedia article)
    w: `[b, n, k]` weights on each interpolation center
    v: `[b, d + 1, k]` weights on each input dimension, with the bias weight in
      the last row
    order: order of the interpolation

  Returns:
    Polyharmonic interpolation evaluated at points defined in query_points.
  """
  # First, compute the contribution from the rbf term.
  pairwise_dists = _cross_squared_distance_matrix(query_points, train_points)
  phi_pairwise_dists = _phi(pairwise_dists, order)
  rbf_term = math_ops.matmul(phi_pairwise_dists, w)

  # Then, compute the contribution from the linear term.
  # Pad query_points with ones, for the bias term in the linear model. Slicing
  # one feature column and calling ones_like yields a [b, m, 1] tensor without
  # reading any static dimension.
  bias_column = array_ops.ones_like(
      query_points[..., :1], dtype=train_points.dtype)
  query_points_pad = array_ops.concat([query_points, bias_column], 2)
  linear_term = math_ops.matmul(query_points_pad, v)

  return rbf_term + linear_term
def _phi(r, order):
  """Coordinate-wise nonlinearity used to define the order of the interpolation.

  See https://en.wikipedia.org/wiki/Polyharmonic_spline for the definition.
  The callers in this file pass squared distances as `r`, which is why the
  exponents and the log factor below are halved.

  Args:
    r: input op
    order: interpolation order

  Returns:
    phi_k evaluated coordinate-wise on r, for k = order
  """
  # Clamping with EPSILON prevents log(0), sqrt(0), etc.
  # sqrt(0) is well-defined, but its gradient is not.
  with ops.name_scope('phi'):
    if order == 1:
      return math_ops.sqrt(math_ops.maximum(r, EPSILON))

    if order == 2:
      half_r = 0.5 * r
      return half_r * math_ops.log(math_ops.maximum(r, EPSILON))

    if order == 4:
      half_r_squared = 0.5 * math_ops.square(r)
      return half_r_squared * math_ops.log(math_ops.maximum(r, EPSILON))

    # Generic orders: even ones carry a log factor, odd ones are a pure power.
    is_even = order % 2 == 0
    clamped = math_ops.maximum(r, EPSILON)
    powered = math_ops.pow(clamped, 0.5 * order)
    if is_even:
      return 0.5 * powered * math_ops.log(clamped)
    return powered
def interpolate_spline(train_points,
                       train_values,
                       query_points,
                       order,
                       regularization_weight=0.0,
                       name='interpolate_spline'):
  r"""Interpolate signal using polyharmonic interpolation.

  The interpolant has the form
  $$f(x) = \sum_{i = 1}^n w_i \phi(||x - c_i||) + v^T x + b.$$

  It is the sum of (1) a weighted combination of radial basis function (RBF)
  terms centered at the 'training' points \\(c_1, ... c_n\\), and (2) a linear
  term with a bias. In the code, b is absorbed into v by appending 1 as a final
  dimension to x. The coefficients w and v are chosen so that the interpolant
  exactly fits the function values at the \\(c_i\\) points, the vector w is
  orthogonal to each \\(c_i\\), and the vector w sums to 0. Under these
  constraints the coefficients are the solution of a linear system.

  \\(\phi\\) is an RBF parametrized by an interpolation order. Using order=2
  produces the well-known thin-plate spline.

  Regularized interpolation is also available. In that case the interpolant
  trades off the squared loss on the training data against a measure of its
  curvature
  ([details](https://en.wikipedia.org/wiki/Polyharmonic_spline)).
  With a regularization weight greater than zero the interpolant no longer
  fits the training data exactly, but it may be less vulnerable to
  overfitting, particularly for high-order interpolation.

  Note the interpolation procedure is differentiable with respect to all inputs
  besides the order parameter.

  Args:
    train_points: `[batch_size, n, d]` float `Tensor` of n d-dimensional
      locations. These do not need to be regularly-spaced.
    train_values: `[batch_size, n, k]` float `Tensor` of n k-dimensional values
      evaluated at train_points.
    query_points: `[batch_size, m, d]` `Tensor` of m d-dimensional locations
      where we will output the interpolant's values.
    order: order of the interpolation. Common values are 1 for
      \\(\phi(r) = r\\), 2 for \\(\phi(r) = r^2 * log(r)\\) (thin-plate spline),
      or 3 for \\(\phi(r) = r^3\\).
    regularization_weight: weight placed on the regularization term.
      This will depend substantially on the problem, and it should always be
      tuned. For many problems, it is reasonable to use no regularization.
      If using a non-zero value, we recommend a small value like 0.001.
    name: name prefix for ops created by this function

  Returns:
    `[b, m, k]` float `Tensor` of query values: the polyharmonic interpolant
    fitted to (train_points, train_values), evaluated at the locations given in
    query_points.
  """
  with ops.name_scope(name):
    train_points, train_values, query_points = [
        ops.convert_to_tensor(tensor)
        for tensor in (train_points, train_values, query_points)
    ]

    # Fit step: RBF weights and linear (plus bias) weights of the spline.
    with ops.name_scope('solve'):
      rbf_weights, linear_weights = _solve_interpolation(
          train_points, train_values, order, regularization_weight)

    # Predict step: evaluate the fitted spline at the query locations.
    with ops.name_scope('predict'):
      return _apply_interpolation(query_points, train_points, rbf_weights,
                                  linear_weights, order)