Skip to content

Commit e5d98e2

Browse files
committed
Initial commit.
0 parents  commit e5d98e2

File tree

2 files changed

+102
-0
lines changed

2 files changed

+102
-0
lines changed

BPTS.lua

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
--[[--
2+
In backpropagation through structure (BPTS) [see: "Learning task-dependent distributed representations by backpropagation
3+
through structure" (Goller and Kuchler, 1996)] a tree is created from the input, with a single encoder cloned per node.
4+
Forward propagation is carried out by taking the input of the leaves and concatenating them in the parent
5+
node (similar to RAAM). Thus for the tree:
6+
7+
A
8+
/ \
9+
B C
10+
11+
A's input is the output of B and C concatenated together. The error is calculated for the top node (given
12+
a criterion), with backpropagation carried out by splitting the error across the children.
13+
--]]--
14+
15+
-- Class table for the builder; BPTS.createNode and BPTS.createFromTree
-- are attached below as static (dot-call) functions.
local BPTS = torch.class("BPTS")

require 'nn'    -- torch neural-network modules (Sequential, ParallelTable, JoinTable, ...)
require 'Tree'  -- project-local tree type -- presumably provides isLeaf()/value/children; confirm
21+
--- Assemble the nn module for one internal tree node.
-- The children's modules are evaluated in parallel on their respective
-- inputs, their outputs are concatenated along dimension 1, and the
-- result is fed through a clone of `encoder` whose weight and bias
-- tensors are shared with the original, so every node uses the same
-- parameter set.
-- @param encoder   the encoder module to clone with tied parameters
-- @param children  array of child modules feeding this node
-- @return an nn.Sequential implementing the node
function BPTS.createNode(encoder, children)
  -- Run each child module on its own slot of the input table.
  local branches = nn.ParallelTable()
  for _, childModule in ipairs(children) do
    branches:add(childModule)
  end

  return nn.Sequential()
    :add(branches)                         -- evaluate children in parallel
    :add(nn.JoinTable(1))                  -- concatenate their outputs
    :add(encoder:clone("weight", "bias"))  -- parameter-tied encoder
end
38+
39+
--- Recursively mirror `tree` as an nn module plus its nested input.
-- A leaf becomes nn.Identity() fed by the leaf's stored value; an
-- internal node becomes BPTS.createNode(...) wrapping the modules built
-- for its children.
-- @param tree     a Tree node (responds to isLeaf(); has .value / .children)
-- @param encoder  encoder module shared (weight/bias) across all nodes
-- @return module, input  -- the nn module and the matching (nested) input table
function BPTS.createFromTree(tree, encoder)
  if tree:isLeaf() then
    -- Leaves' inputs are their values; Identity passes them through.
    return nn.Identity(), tree.value
  else
    local children = {}
    local values = {}
    for _, child in ipairs(tree.children) do
      -- BUGFIX: declared as locals -- the original assignment leaked
      -- `mod` and `value` into the global environment on every call.
      local mod, value = BPTS.createFromTree(child, encoder)
      table.insert(children, mod)
      table.insert(values, value)
    end

    return BPTS.createNode(encoder, children), values
  end
end

testBPTS.lua

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
require 'BPTS'
2+
3+
-- Build an arbitrary toy dictionary mapping the symbols A..F to random
-- vectors of the given size. (ipairs preserves insertion order, so the
-- torch.rand calls happen in the same sequence as before.)
function createDictionary(size)
  local dict = {}
  for _, symbol in ipairs({"A", "B", "C", "D", "E", "F"}) do
    dict[symbol] = torch.rand(size)
  end
  return dict
end
14+
15+
-- Build the encoder applied at internal nodes: a linear map from the
-- concatenation of two child encodings (2*encodingSize) down to
-- encodingSize, followed by a tanh nonlinearity.
function createEncoder(encodingSize)
  return nn.Sequential()
    :add(nn.Linear(2 * encodingSize, encodingSize))
    :add(nn.Tanh())
end
21+
22+
-- Smoke test: build a BPTS network from a hand-written tree and run one
-- forward/backward pass through it.
function main()
  torch.manualSeed(42)  -- make dictionary and weight init deterministic
  local encodingSize = 5

  -- Parse a manually created tree; leaves are looked up in the dictionary.
  local tree = Tree.parse("(root A (childA (childB B C) D))", createDictionary(encodingSize))

  -- Single encoder whose parameters are shared across the tree's nodes.
  local encoder = createEncoder(encodingSize)

  -- Mirror the tree as an nn module together with its nested input table.
  local bpts, input = BPTS.createFromTree(tree, encoder)
  local target = torch.Tensor{1, 0, 0, 0, 0}
  local criterion = nn.MSECriterion()

  -- Basic functionality check: one forward and one backward pass.
  criterion:forward(bpts:forward(input), target)
  bpts:zeroGradParameters()
  bpts:backward(input, criterion:backward(bpts.output, target))

  -- However, no simple way to test the gradient (below will fail):
  -- local err = Jacobian.testJacobian(bpts, input)
  -- print("error: ", err)
end


main()

0 commit comments

Comments
 (0)