Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nngraph.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ModuleFromCriterion.lua14
1 files changed, 13 insertions, 1 deletions
diff --git a/ModuleFromCriterion.lua b/ModuleFromCriterion.lua
index faca485..8717ca5 100644
--- a/ModuleFromCriterion.lua
+++ b/ModuleFromCriterion.lua
@@ -24,7 +24,19 @@ end
function ModuleFromCriterion:updateGradInput(input, gradOutput)
local prediction, target = unpack(input)
local gradPrediction = self.criterion:updateGradInput(prediction, target)
- self.gradInput[1]:resizeAs(gradPrediction):copy(gradPrediction):mul(gradOutput[1])
+ if type(gradPrediction) == 'table' then
+ if type(self.gradInput[1]) ~= 'table' then
+ self.gradInput[1] = {} -- initializing to table first time if it is tensor (which it is: line 10)
+ for i=1, #gradPrediction do
+ self.gradInput[1][i] = gradPrediction[i].new() -- and putting tensors of right size inside.
+ end
+ end
+ for i=1, #gradPrediction do
+ self.gradInput[1][i]:resizeAs(gradPrediction[i]):copy(gradPrediction[i]):mul(gradOutput[1])
+ end
+ else
+ self.gradInput[1]:resizeAs(gradPrediction):copy(gradPrediction):mul(gradOutput[1])
+ end
self.gradInput[2]:resizeAs(target):zero()
return self.gradInput
end