diff options
author | koray kavukcuoglu <koray@kavukcuoglu.org> | 2013-07-10 18:46:25 +0400 |
---|---|---|
committer | koray kavukcuoglu <koray@kavukcuoglu.org> | 2013-07-10 18:46:25 +0400 |
commit | 42402eb04e47e0bf872726c1188ca0fdd199f224 (patch) | |
tree | 4db43b6814a73a0f88da310f259a7f36f2ae9553 | |
parent | 938b09af83f1b6b506888d2374a5e7eae82f9e71 (diff) |
correct the updateGradInput logic of module.
We want to make sure that if a node has only a single child, then we pass whwtever the
output of updateGradInput is, else, we select an pass to tthe corresponding child.
-rw-r--r-- | README.md | 2 | ||||
-rw-r--r-- | gmodule.lua | 2 |
2 files changed, 2 insertions, 2 deletions
@@ -87,7 +87,7 @@ graph.dot(gmod.fg,'Big MLP') m:add(nn.SplitTable(1)) m:add(nn.ParallelTable():add(nn.Linear(10,20)):add(nn.Linear(10,30))) input = nn.Identity()() - input1,input2 = m(input,2) + input1,input2 = m(input):split(2) m3 = nn.JoinTable(1)({input1,input2}) g = nn.gModule({input},{m3}) diff --git a/gmodule.lua b/gmodule.lua index b1b1e79..4b4f5f1 100644 --- a/gmodule.lua +++ b/gmodule.lua @@ -226,7 +226,7 @@ function gModule:updateGradInput(input,gradOutput) child.data.gradOutput = child.data.gradOutput or {} local mapindex = node.data.mapindex[child.data] local gi - if istable(gradInput) and istable(input) then + if #node.children ~= 1 then --istable(gradInput) and istable(input) then gi = gradInput[mapindex] else gi = gradInput |