Skip to content

Commit 90e8cd6

Browse files
Merge pull request #28 from GitHub-HongweiZhang/fix-deepfm
fix deepFM
2 parents 7aa01ec + 3e4d9ba commit 90e8cd6

File tree

4 files changed

+23
-17
lines changed

4 files changed

+23
-17
lines changed

prediction_flow/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '0.1.3'
1+
__version__ = '0.1.4'

prediction_flow/pytorch/deepfm.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -134,22 +134,23 @@ def forward(self, x):
134134
for feature in self.features.number_features:
135135
embeddings.append(
136136
self.embeddings[feature.name](
137-
x[feature.name].view(-1, 1)))
137+
x[feature.name].view(-1, 1)).unsqueeze(1))
138138
for feature in self.features.category_features:
139139
embeddings.append(
140-
self.embeddings[feature.name](x[feature.name]))
140+
self.embeddings[feature.name](x[feature.name]).unsqueeze(1))
141141
for feature in self.features.sequence_features:
142142
embeddings.append(
143143
self._sequence_poolings[feature.name](
144-
self.embeddings[feature.name](x[feature.name])))
144+
self.embeddings[feature.name](x[feature.name])).unsqueeze(1))
145145

146146
emb_concat = None
147147
if embeddings:
148148
emb_concat = torch.cat(embeddings, dim=1)
149-
150-
# fm
151-
if self.fm:
152-
final_layer_inputs.append(self.fm(emb_concat))
149+
b, f, e = emb_concat.size()
150+
# fm
151+
if self.fm:
152+
final_layer_inputs.append(self.fm(emb_concat))
153+
emb_concat = emb_concat.view(b, f * e)
153154

154155
# deep
155156
if self.mlp:

prediction_flow/pytorch/nn/fm.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,14 @@
1313
class FM(nn.Module):
1414
"""FM layer.
1515
"""
16-
def __init__(self):
16+
def __init__(self, reduce_sum=True):
1717
super(FM, self).__init__()
18+
self.reduce_sum = reduce_sum
1819

1920
def forward(self, x):
20-
sum_squared = torch.pow(torch.sum(x, dim=1), 2).unsqueeze(1)
21-
squared_sum = torch.sum(torch.pow(x, 2), dim=1).unsqueeze(1)
22-
second_order = 0.5 * (sum_squared - squared_sum)
23-
return second_order
21+
sum_squared = torch.pow(torch.sum(x, dim=1), 2)
22+
squared_sum = torch.sum(torch.pow(x, 2), dim=1)
23+
second_order = sum_squared - squared_sum
24+
if self.reduce_sum:
25+
output = torch.sum(second_order, dim=1, keepdim=True)
26+
return 0.5 * output

prediction_flow/pytorch/nn/tests/test_fm.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,12 @@
77
def test_fm():
88
fm = FM()
99

10-
x = torch.as_tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
10+
x = torch.as_tensor(
11+
[[[1.0, 1.0, 1.0], [1.0, 2.0, 3.0]],
12+
[[1.0, 1.0, 1.0], [4.0, 5.0, 6.0]]])
1113
actual = fm(x)
1214

13-
# 11.0 = 1 * 2 + 1 * 3 + 2 * 3
14-
# 77.0 = 4 * 5 + 4 * 6 + 5 * 6
15+
# 6.0 = 1 * 1 + 1 * 2 + 1 * 3
16+
# 15.0 = 1 * 4 + 1 * 5 + 1 * 6
1517
np.testing.assert_array_almost_equal(
16-
actual.numpy(), np.array([[11.0], [74.0]], dtype=np.float))
18+
actual.numpy(), np.array([[6.0], [15.0]], dtype=np.float))

0 commit comments

Comments
 (0)