diff options
Diffstat (limited to 'SuperCriterion.lua')
-rw-r--r-- | SuperCriterion.lua | 12 |
1 files changed, 6 insertions, 6 deletions
diff --git a/SuperCriterion.lua b/SuperCriterion.lua index c73d716..983639a 100644 --- a/SuperCriterion.lua +++ b/SuperCriterion.lua @@ -13,32 +13,32 @@ function SuperCriterion:add(criterion, weight) table.insert(self.weights, weight) end -function SuperCriterion:forward(input, target) +function SuperCriterion:updateOutput(input, target) self.output = 0 if type(target) == 'table' then for i,criterion in ipairs(self.criterions) do - self.output = self.output + self.weights[i]*criterion:forward(input[i],target[i]) + self.output = self.output + self.weights[i]*criterion:updateOutput(input[i],target[i]) end else for i,criterion in ipairs(self.criterions) do - self.output = self.output + self.weights[i]*criterion:forward(input[i],target) + self.output = self.output + self.weights[i]*criterion:updateOutput(input[i],target) end end return self.output end -function SuperCriterion:backward(input, target) +function SuperCriterion:updateGradInput(input, target) if type(target) == 'table' then for i,criterion in ipairs(self.criterions) do self.gradInput[i] = torch.Tensor() or self.gradInput[i] self.gradInput[i]:resizeAs(input[i]):zero() - self.gradInput[i]:add(self.weights[i], criterion:backward(input[i],target[i])) + self.gradInput[i]:add(self.weights[i], criterion:updateGradInput(input[i],target[i])) end else for i,criterion in ipairs(self.criterions) do self.gradInput[i] = torch.Tensor() or self.gradInput[i] self.gradInput[i]:resizeAs(input[i]):zero() - self.gradInput[i]:add(self.weights[i], criterion:backward(input[i],target)) + self.gradInput[i]:add(self.weights[i], criterion:updateGradInput(input[i],target)) end end return self.gradInput |