% kmain.m -- kernel SVM experiments on the wine quality data
%% Choose CVX solver
% Select MOSEK as CVX's solver. NOTE(review): presumably this is the solver
% used by the CVX problems inside ksvm / solve_ksvm -- confirm.
cvx_solver('Mosek')
%% Load wine quality data
% csvread's row offset is zero-based: an offset of 2 starts reading at the
% third line of the file. NOTE(review): if the CSV has a single header row,
% this also skips the first data row -- confirm against the file.
data = csvread('Wine_Quality_Data.csv', 2);
% Columns 1-11 are the features. Columns 12 and 13 are mapped v -> 2v-1
% (presumably 0/1 labels to -1/+1 -- confirm). Y is the target used to
% train the SVMs; Z's sign later splits samples into two groups.
X = data(:, 1:11);
Y = data(:, 12)*2 - 1;
Z = data(:, 13)*2 - 1;
%% Split data into training and validation sets
% Random 80/20 split: the first floor(0.8*n) entries of a random permutation
% index the training rows, the remainder the validation rows. No RNG seed
% is set here, so the split (and every result below) differs between runs.
[n, p] = size(X);
inds = randperm(n);
tinds = inds(1:floor(0.8*n));
vinds = inds(floor(0.8*n)+1:end);
% Training / validation rows of each variable, side by side.
Xt = X(tinds,:);  Xv = X(vinds,:);
Yt = Y(tinds,:);  Yv = Y(vinds,:);
Zt = Z(tinds,:);  Zv = Z(vinds,:);
%% Choose fairness level
% Both values are only passed through to solve_ksvm below, which defines
% their meaning. NOTE(review): d appears to bound the linear fairness
% constraint and mu to weight the covariance term -- confirm in solve_ksvm.
d = 0;
mu = 100;
%% Compute Gram matrix with kernel (y*x').^2
% Homogeneous quadratic kernel k(x, y) = (x*y')^2 for row observations x, y.
% A single matrix product gives the same matrix (up to floating-point
% rounding) as the former pdist2 call with an anonymous distance function.
% It avoids calling a function handle once per row and does not require
% pdist2 (Statistics and Machine Learning Toolbox).
K = (X*X').^2;
% Training-vs-training and validation-vs-training blocks; the columns are
% always indexed by the training samples.
Kt = K(tinds, tinds);
Kv = K(vinds, tinds);
%% Compute (regular) SVM and its ROC curve
% Baseline kernel SVM with no fairness term. The third output L is reused as
% an input to solve_ksvm in the fair variants (its meaning is defined in ksvm).
disp('Kernel SVM')
[alph, k, L] = ksvm(Xt, Yt, Kt);
[roc, sroc] = kroc(Xt, Yt, Kt, Kv, Yv, Zv, alph, k);
% No semicolons: both values are printed.
% del = largest absolute gap between the first two columns of sroc;
% auc = trapezoidal area under roc(:,2) as a function of roc(:,1).
del = max(abs(sroc(:,1)-sroc(:,2)))
auc = trapz(roc(:,1), roc(:,2))
subplot(131);
plot(roc(:,1), roc(:,2), 'LineStyle', '-', 'Color', [0 0.4470 0.7410]);
hold on;
% Same line styles as the other two panels so the subplots are comparable.
plot(sroc(:,1), sroc(:,2), 'LineStyle', '--', 'Color', [0.8500 0.3250 0.0980]);
plot(linspace(0,1,10), linspace(0,1,10), 'LineStyle', ':', 'Color', [0 0 0]);
hold off;
axis square;
%% Compute (average) fair SVM and its ROC curve
disp('Kernel SVM with Linear Fairness Constraint')
% Two training groups, split on the sign of Zt.
spind = (Zt >= 0);
snind = (Zt < 0);
% Yt-weighted difference of the group-mean kernel columns. Its normalized
% transpose is passed to solve_ksvm together with the level d.
aveX = Yt.*(mean(Kt(:,spind),2) - mean(Kt(:,snind),2));
[alph, k] = solve_ksvm( Xt, Yt, Kt, L, aveX'/norm(aveX), d );
[roc, sroc] = kroc(Xt, Yt, Kt, Kv, Yv, Zv, alph, k);
% Printed: largest |sroc(:,2) - sroc(:,1)| and the trapezoidal area under roc.
del = max(abs(diff(sroc(:,1:2), 1, 2)))
auc = trapz(roc(:,1), roc(:,2))
% Middle panel: roc (solid), sroc (dashed), and the diagonal (dotted).
subplot(1, 3, 2);
plot(roc(:,1), roc(:,2), '-', 'Color', [0 0.4470 0.7410]);
hold('on');
plot(sroc(:,1), sroc(:,2), '--', 'Color', [0.8500 0.3250 0.0980]);
plot(linspace(0,1,10), linspace(0,1,10), ':', 'Color', [0 0 0]);
hold('off');
axis('square');
%% Compute (dc algorithm) fair SVM and its ROC curve
disp('Kernel SVM from Spectral Algorithm')
% Kernel columns of the training samples in each group (spind / snind are
% set in the previous section) and their group means.
Kp = Kt(:, spind);  mp = mean(Kp, 2);
Kn = Kt(:, snind);  mn = mean(Kn, 2);
% Yt-weighted difference of the group-mean kernel columns (same as above).
aveX = Yt.*(mp - mn);
% Per-group covariance of the kernel columns, normalized by the group size m:
%   Sigma = A*A'/m - mean(A,2)*mean(A,2)',  with A = Kp or Kn.
% The mean term equals A*ones(m)*A'/m^2. Writing it with eye(m) instead
% gives A*A'/m^2, which leaves the matrix uncentred. Centring the columns
% first and then multiplying is equivalent and numerically safer.
Kp = bsxfun(@minus, Kp, mp);
Kn = bsxfun(@minus, Kn, mn);
pSigma = Kp*Kp'/sum(spind);
nSigma = Kn*Kn'/sum(snind);
clear Kp Kn   % large temporaries (size(Kt,1)-by-group size)
[alph, k] = solve_ksvm( Xt, Yt, Kt, L, aveX'/norm(aveX), d, pSigma-nSigma, mu );
[roc, sroc] = kroc(Xt, Yt, Kt, Kv, Yv, Zv, alph, k);
% No semicolons: both values are printed.
del = max(abs(sroc(:,1)-sroc(:,2)))
auc = trapz(roc(:,1), roc(:,2))
subplot(133);
plot(roc(:,1), roc(:,2), 'LineStyle', '-', 'Color', [0 0.4470 0.7410]);
hold on;
plot(sroc(:,1), sroc(:,2), 'LineStyle', '--', 'Color', [0.8500 0.3250 0.0980]);
plot(linspace(0,1,10), linspace(0,1,10), 'LineStyle', ':', 'Color', [0 0 0]);
hold off;
axis square;