-
Notifications
You must be signed in to change notification settings - Fork 13
/
multilinearmodel.cpp
84 lines (72 loc) · 1.85 KB
/
multilinearmodel.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
#include "multilinearmodel.h"
/// Construct a model by loading its core tensor from a file.
///
/// After `core.Read(filename)` populates the core tensor, the cached mode-0
/// and mode-1 unfoldings (tu0, tu1) are rebuilt so that the weight-update
/// routines operate on data consistent with the freshly loaded core.
///
/// @param filename path handed unchanged to core.Read().
MultilinearModel::MultilinearModel(const string &filename)
{
  this->core.Read(filename);
  // tu0/tu1 are derived from core and must be refreshed after it changes.
  this->UnfoldCoreTensor();
}
/// Build a reduced copy of this model that keeps only selected column triples.
///
/// The core tensor's third mode (columns) is treated as packed triples: triple
/// v occupies columns 3*v, 3*v+1 and 3*v+2 (presumably the x/y/z coordinates
/// of vertex v -- confirm against the data format). The returned model's core
/// has the same layers and rows as this one and 3 * indices.size() columns.
/// Its triple k is a copy of triple indices[k] of this model, for every
/// (layer, row) pair. The new model's unfoldings are recomputed before it is
/// returned.
///
/// @param indices triple (vertex) indices to keep, in output order. Each entry
///                must satisfy 3 * index + 2 < core.cols(); this is NOT checked.
/// @return the projected model.
MultilinearModel MultilinearModel::project(const vector<int> &indices) const
{
  // Convert the count once so the loops compare int with int rather than
  // mixing a signed index with the unsigned size_t from indices.size().
  const int nverts = static_cast<int>(indices.size());

  MultilinearModel newmodel;
  newmodel.core.resize(core.layers(), core.rows(), nverts * 3);

  for (int i = 0; i < core.layers(); i++) {
    for (int j = 0; j < core.rows(); j++) {
      for (int k = 0; k < nverts; k++) {
        const int src = indices[k] * 3;  // first column of the source triple
        const int dst = k * 3;           // first column of the destination triple
        for (int c = 0; c < 3; c++) {
          newmodel.core(i, j, dst + c) = core(i, j, src + c);
        }
      }
    }
  }

  // Keep the cached unfoldings (tu0/tu1) consistent with the new core.
  newmodel.UnfoldCoreTensor();
  return newmodel;
}
/// Contract the core tensor along mode 0 (layers) with weight vector `w` and
/// store the resulting core.rows() x core.cols() matrix in tm0.
///
/// Instead of calling core.ModeProduct<0>(w) directly, this multiplies the
/// cached mode-0 unfolding tu0, whose layout is
///   id0: | exp0 | exp1 | ... | expn |
///   id1: | exp0 | exp1 | ... | expn |
///   ...
///   idn: | exp0 | exp1 | ... | expn |
/// and then folds the product back into matrix form column by column.
/// NOTE(review): intended to match core.ModeProduct<0>(w); the equivalence is
/// assumed, not verified here.
///
/// @param w mode-0 weights (presumably core.layers() entries).
void MultilinearModel::UpdateTM0(const Tensor1 &w)
{
  auto unfoldedProduct = tu0.ModeProduct<0>(w);
  tm0 = Tensor2::FoldByColumn(unfoldedProduct, core.rows(), core.cols());
}
/// Contract the core tensor along mode 1 (rows) with weight vector `w` and
/// store the resulting core.layers() x core.cols() matrix in tm1.
///
/// Instead of calling core.ModeProduct<1>(w) directly, this multiplies the
/// cached mode-1 unfolding tu1, whose layout is
///   exp0: | x0 | y0 | z0 | ..
///   exp1:
///   ...
///   expn:
/// and then folds the product back into matrix form row by row.
/// NOTE(review): intended to match core.ModeProduct<1>(w); the equivalence is
/// assumed, not verified here.
///
/// @param w mode-1 weights (presumably core.rows() entries).
void MultilinearModel::UpdateTM1(const Tensor1 &w)
{
  auto unfoldedProduct = tu1.ModeProduct<0>(w);
  tm1 = Tensor2::FoldByRow(unfoldedProduct, core.layers(), core.cols());
}
void MultilinearModel::UpdateTMWithTM0(const Tensor1 &w)
{
tm = tm0.ModeProduct<0>(w);
}
void MultilinearModel::UpdateTMWithTM1(const Tensor1 &w)
{
tm = tm1.ModeProduct<0>(w);
}
void MultilinearModel::ApplyWeights(const Tensor1 &w0, const Tensor1 &w1)
{
UpdateTM0(w0);
UpdateTM1(w1);
UpdateTMWithTM0(w1);
}
void MultilinearModel::UnfoldCoreTensor()
{
tu0 = core.Unfold(0);
tu1 = core.Unfold(1);
}