diff options
author | Abhi Agg <abhi@abhitopia.com> | 2016-04-28 18:39:05 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2016-04-28 18:39:05 +0300 |
commit | c131490dc13ad52db212a8ffd64407e9eaee8e33 (patch) | |
tree | a20c8ef9420b4b2539cdf04d1dd36d55aa02d06b | |
parent | 59deaaac17d20db73e693c532715456024d0dd1b (diff) |
Fix to make compatible with MarginRankingCriterion (#108)
* Fix to make compatible with MarginRankingCriterion
-rw-r--r-- | ModuleFromCriterion.lua | 14 |
1 files changed, 13 insertions, 1 deletions
diff --git a/ModuleFromCriterion.lua b/ModuleFromCriterion.lua index faca485..8717ca5 100644 --- a/ModuleFromCriterion.lua +++ b/ModuleFromCriterion.lua @@ -24,7 +24,19 @@ end function ModuleFromCriterion:updateGradInput(input, gradOutput) local prediction, target = unpack(input) local gradPrediction = self.criterion:updateGradInput(prediction, target) - self.gradInput[1]:resizeAs(gradPrediction):copy(gradPrediction):mul(gradOutput[1]) + if type(gradPrediction) == 'table' then + if type(self.gradInput[1]) ~= 'table' then + self.gradInput[1] = {} -- initializing to table first time if it is tensor (which it is: line 10) + for i=1, #gradPrediction do + self.gradInput[1][i] = gradPrediction[i].new() -- and putting tensors of right size inside. + end + end + for i=1, #gradPrediction do + self.gradInput[1][i]:resizeAs(gradPrediction[i]):copy(gradPrediction[i]):mul(gradOutput[1]) + end + else + self.gradInput[1]:resizeAs(gradPrediction):copy(gradPrediction):mul(gradOutput[1]) + end self.gradInput[2]:resizeAs(target):zero() return self.gradInput end |