diff options
author | Clement Farabet <clement.farabet@gmail.com> | 2011-09-14 08:31:33 +0400 |
---|---|---|
committer | Clement Farabet <clement.farabet@gmail.com> | 2011-09-14 08:31:33 +0400 |
commit | 937f52bdd070f59db90720c5993b28abd44f6aac (patch) | |
tree | 310f10457ad5237f8f79e2634fb2b0908729071b /SuperCriterion.lua | |
parent | d6da0e6b62f24d0581255751d3ee22e6f3765035 (diff) |
SuperCriterion accepts multiple targets.
Diffstat (limited to 'SuperCriterion.lua')
-rw-r--r-- | SuperCriterion.lua | 26 |
1 files changed, 20 insertions, 6 deletions
diff --git a/SuperCriterion.lua b/SuperCriterion.lua index ce429e2..c73d716 100644 --- a/SuperCriterion.lua +++ b/SuperCriterion.lua @@ -15,17 +15,31 @@ end function SuperCriterion:forward(input, target) self.output = 0 - for i,criterion in ipairs(self.criterions) do - self.output = self.output + self.weights[i]*criterion:forward(input[i],target) + if type(target) == 'table' then + for i,criterion in ipairs(self.criterions) do + self.output = self.output + self.weights[i]*criterion:forward(input[i],target[i]) + end + else + for i,criterion in ipairs(self.criterions) do + self.output = self.output + self.weights[i]*criterion:forward(input[i],target) + end end return self.output end function SuperCriterion:backward(input, target) - for i,criterion in ipairs(self.criterions) do - self.gradInput[i] = torch.Tensor() or self.gradInput[i] - self.gradInput[i]:resizeAs(input[i]):zero() - self.gradInput[i]:add(self.weights[i], criterion:backward(input[i],target) ) + if type(target) == 'table' then + for i,criterion in ipairs(self.criterions) do + self.gradInput[i] = torch.Tensor() or self.gradInput[i] + self.gradInput[i]:resizeAs(input[i]):zero() + self.gradInput[i]:add(self.weights[i], criterion:backward(input[i],target[i])) + end + else + for i,criterion in ipairs(self.criterions) do + self.gradInput[i] = torch.Tensor() or self.gradInput[i] + self.gradInput[i]:resizeAs(input[i]):zero() + self.gradInput[i]:add(self.weights[i], criterion:backward(input[i],target)) + end end return self.gradInput end |