-
Notifications
You must be signed in to change notification settings - Fork 10
/
mmd.py
75 lines (61 loc) · 2.51 KB
/
mmd.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
# Copyright 2018 Alexander Matthews
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from IPython import embed
import numpy as np
from scipy.stats import multivariate_normal
import gpflow
def mmd(datasetA, datasetB, kernel):
    """Estimate the squared maximum mean discrepancy (MMD^2) of two samples.

    The two within-sample terms average the off-diagonal entries of each
    kernel matrix (the ``i == j`` terms are excluded), and the cross term is
    the plain average of the cross-kernel matrix. See Gretton et al. 2012
    (JMLR), "A Kernel Two-Sample Test".

    NOTE(review): the previous comment described this as the biased but
    consistent estimator of equation 5 of that paper. Because the diagonal
    terms are removed, it looks like the unbiased U-statistic form instead --
    confirm against the paper.

    Args:
        datasetA: Samples from the first distribution; passed unchanged to
            the kernel.
        datasetB: Samples from the second distribution; passed unchanged to
            the kernel. It may contain a different number of samples than
            ``datasetA``.
        kernel: Object exposing ``compute_K_symm(X)`` and ``compute_K(X, Y)``
            that return kernel matrices as NumPy arrays (for example a gpflow
            kernel).

    Returns:
        The scalar MMD^2 estimate. Because the diagonal terms are excluded,
        the value may be negative.

    Raises:
        ValueError: If either dataset yields fewer than two samples, since
            the off-diagonal average is undefined in that case.
    """
    KAA = kernel.compute_K_symm(datasetA)
    KBB = kernel.compute_K_symm(datasetB)
    KAB = kernel.compute_K(datasetA, datasetB)

    # Each sample gets its own count so that the two datasets are not
    # required to have the same size.
    num_a = KAA.shape[0]
    num_b = KBB.shape[0]
    if num_a < 2 or num_b < 2:
        raise ValueError(
            "mmd needs at least two samples per dataset, got %d and %d"
            % (num_a, num_b)
        )

    # Sum minus trace is the sum of the off-diagonal entries; it avoids
    # building a dense diagonal matrix just to subtract it.
    within_a = (np.sum(KAA) - np.trace(KAA)) / (num_a * (num_a - 1))
    within_b = (np.sum(KBB) - np.trace(KBB)) / (num_b * (num_b - 1))
    cross = np.sum(KAB) / (num_a * num_b)
    return within_a + within_b - 2 * cross
def test_mmd():
    """Plot the MMD^2 estimate between two Gaussians as their means separate.

    Sample A is drawn from a 10-dimensional Gaussian with zero mean and
    identity covariance. Sample B uses the same covariance with its mean
    shifted by ``beta`` along the normalised all-ones direction, for 30
    values of ``beta`` evenly spaced in [0, 1]. Each setting is repeated 20
    times with 2000 samples per dataset.

    The mean estimate per ``beta`` is drawn with error bars of twice the
    standard deviation across repeats divided by ``sqrt(num_repeats - 1)``.
    The function then opens a second, empty figure and drops into an IPython
    shell so the results can be inspected interactively.
    """
    from matplotlib import pylab as plt

    np.random.seed(1)

    # Experiment dimensions.
    num_dim = 10
    num_test = 30
    num_repeats = 20
    num_samples = 2000

    kern = gpflow.kernels.RBF(num_dim, lengthscales=np.ones(num_dim))

    # Both distributions share the identity covariance; only the mean of B
    # moves.
    meanA = np.zeros(num_dim)
    covA = np.eye(num_dim)
    covB = covA

    betas = np.linspace(0., 1., num_test)
    mmd_squareds = np.zeros((num_test, num_repeats))

    # Direction along which the mean of B is shifted, and its length.
    shift_direction = np.ones_like(meanA)
    shift_norm = np.sqrt(np.sum(shift_direction ** 2))

    for repeat_index in range(num_repeats):
        for beta_index, beta in enumerate(betas):
            meanB = beta * shift_direction / shift_norm
            samplesA = multivariate_normal.rvs(
                mean=meanA, cov=covA, size=num_samples
            )
            samplesB = multivariate_normal.rvs(
                mean=meanB, cov=covB, size=num_samples
            )
            mmd_squareds[beta_index, repeat_index] = mmd(
                samplesA, samplesB, kern
            )

    # Summarise across repeats (axis 1) for every beta.
    mean_mmd_squared = mmd_squareds.mean(axis=1)
    std_mmd_squared = mmd_squareds.std(axis=1) / np.sqrt(num_repeats - 1)

    plt.errorbar(betas, mean_mmd_squared, yerr=2. * std_mmd_squared)
    plt.figure()
    # Hand control to an interactive shell with the results in scope.
    embed()
# Run the MMD experiment only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    test_mmd()