Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nngraph.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAbhi Agg <abhi@abhitopia.com>2016-04-28 18:39:05 +0300
committerSoumith Chintala <soumith@gmail.com>2016-04-28 18:39:05 +0300
commitc131490dc13ad52db212a8ffd64407e9eaee8e33 (patch)
treea20c8ef9420b4b2539cdf04d1dd36d55aa02d06b
parent59deaaac17d20db73e693c532715456024d0dd1b (diff)
Fix to make compatible with MarginRankingCriterion (#108)
* Fix to make compatible with MarginRankingCriterion
-rw-r--r--ModuleFromCriterion.lua14
1 files changed, 13 insertions, 1 deletions
diff --git a/ModuleFromCriterion.lua b/ModuleFromCriterion.lua
index faca485..8717ca5 100644
--- a/ModuleFromCriterion.lua
+++ b/ModuleFromCriterion.lua
@@ -24,7 +24,19 @@ end
function ModuleFromCriterion:updateGradInput(input, gradOutput)
local prediction, target = unpack(input)
local gradPrediction = self.criterion:updateGradInput(prediction, target)
- self.gradInput[1]:resizeAs(gradPrediction):copy(gradPrediction):mul(gradOutput[1])
+ if type(gradPrediction) == 'table' then
+ if type(self.gradInput[1]) ~= 'table' then
+ self.gradInput[1] = {} -- initializing to table first time if it is tensor (which it is: line 10)
+ for i=1, #gradPrediction do
+ self.gradInput[1][i] = gradPrediction[i].new() -- and putting tensors of right size inside.
+ end
+ end
+ for i=1, #gradPrediction do
+ self.gradInput[1][i]:resizeAs(gradPrediction[i]):copy(gradPrediction[i]):mul(gradOutput[1])
+ end
+ else
+ self.gradInput[1]:resizeAs(gradPrediction):copy(gradPrediction):mul(gradOutput[1])
+ end
self.gradInput[2]:resizeAs(target):zero()
return self.gradInput
end