Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNicholas Leonard <nick@nikopia.org>2015-07-25 22:51:29 +0300
committerNicholas Leonard <nick@nikopia.org>2015-07-25 22:51:29 +0300
commit73ddbe49b8a58eca739e3f442bcb101b6aff0978 (patch)
treecb91414d9c918f3b5d6e865f882d0529f4f8e82d /ParallelCriterion.lua
parent3c13b0d41281990e3f1d4dd8a1ae45b8d5ca0f40 (diff)
[Parallel,Multi]Criterion table inputs
Diffstat (limited to 'ParallelCriterion.lua')
-rw-r--r--ParallelCriterion.lua6
1 files changed, 3 insertions, 3 deletions
diff --git a/ParallelCriterion.lua b/ParallelCriterion.lua
index 95bd6cc..84d4ee1 100644
--- a/ParallelCriterion.lua
+++ b/ParallelCriterion.lua
@@ -25,11 +25,11 @@ function ParallelCriterion:updateOutput(input, target)
end
function ParallelCriterion:updateGradInput(input, target)
+ self.gradInput = nn.utils.recursiveResizeAs(self.gradInput, input)
+ nn.utils.recursiveFill(self.gradInput, 0)
for i,criterion in ipairs(self.criterions) do
local target = self.repeatTarget and target or target[i]
- self.gradInput[i] = self.gradInput[i] or input[i].new()
- self.gradInput[i]:resizeAs(input[i]):zero()
- self.gradInput[i]:add(self.weights[i], criterion:updateGradInput(input[i],target))
+ nn.utils.recursiveAdd(self.gradInput[i], self.weights[i], criterion:updateGradInput(input[i], target))
end
return self.gradInput
end