diff options
author | Clement Farabet <clement.farabet@gmail.com> | 2014-07-02 06:22:49 +0400 |
---|---|---|
committer | Clement Farabet <clement.farabet@gmail.com> | 2014-07-02 06:22:49 +0400 |
commit | 6a1102f026998871500157ddef3b003434aab91d (patch) | |
tree | 6ab1316751bdb33e37f443f0021d307072e05cd4 /BCECriterion.lua | |
parent | 7af29c79899774af91e806531cbe420631d63246 (diff) |
BCECriterion was incorrect.
Fixed:
* forward/backward: inverted the sign so that the criterion produces the negative log-likelihood (not the positive one)
* averaging is now done over all elements in target, to properly support batches
* added an epsilon to protect against zeros (in divisions and logs)
Diffstat (limited to 'BCECriterion.lua')
-rw-r--r-- | BCECriterion.lua | 89 |
1 files changed, 50 insertions, 39 deletions
```diff
diff --git a/BCECriterion.lua b/BCECriterion.lua
index be9082e..8e366c8 100644
--- a/BCECriterion.lua
+++ b/BCECriterion.lua
@@ -1,56 +1,67 @@
 local BCECriterion, parent = torch.class('nn.BCECriterion', 'nn.Criterion')
 
+local eps = 1e-12
+
 function BCECriterion:__init()
    parent.__init(self)
    self.sizeAverage = true
 end
 
 function BCECriterion:updateOutput(input, target)
-   -- log(input) * target + log(1 - input) * (1 - target)
+   -- log(input) * target + log(1 - input) * (1 - target)
+
+   self.term1 = self.term1 or input.new()
+   self.term2 = self.term2 or input.new()
+   self.term3 = self.term3 or input.new()
 
-   self.term1 = self.term1 or input.new()
-   self.term2 = self.term2 or input.new()
-   self.term3 = self.term3 or input.new()
+   self.term1:resizeAs(input)
+   self.term2:resizeAs(input)
+   self.term3:resizeAs(input)
 
-   self.term1:resizeAs(input)
-   self.term2:resizeAs(input)
-   self.term3:resizeAs(input)
+   self.term1:fill(1):add(-1,target)
+   self.term2:fill(1):add(-1,input):add(eps):log():cmul(self.term1)
 
-   self.term1:fill(1):add(-1,target)
-   self.term2:fill(1):add(-1,input):log():cmul(self.term1)
-
-   self.term3:copy(input):log():cmul(target)
-   self.term3:add(self.term2)
+   self.term3:copy(input):add(eps):log():cmul(target)
+   self.term3:add(self.term2)
 
-   if self.sizeAverage then
-      self.term3:div(target:size(1))
-   end
+   if self.sizeAverage then
+      self.term3:div(target:nElement())
+   end
 
-   return self.term3:sum()
+   self.output = - self.term3:sum()
+
+   return self.output
 end
 
 function BCECriterion:updateGradInput(input, target)
-   -- target / input - (1 - target) / (1 - input)
-
-   self.term1 = self.term1 or input.new()
-   self.term2 = self.term2 or input.new()
-
-   self.term1:resizeAs(input)
-   self.term2:resizeAs(input)
-
-   self.term1:fill(1):add(-1,target)
-   self.term2:fill(1):add(-1,input)
-
-   self.term1:cdiv(self.term2)
-
-   self.gradInput:resizeAs(input)
-   self.gradInput:copy(target):cdiv(input)
-
-   self.gradInput:add(-1,self.term1)
-
-   if self.sizeAverage then
-      self.gradInput:div(target:size(1))
-   end
-
-   return self.gradInput
+   -- target / input - (1 - target) / (1 - input)
+
+   self.term1 = self.term1 or input.new()
+   self.term2 = self.term2 or input.new()
+   self.term3 = self.term3 or input.new()
+
+   self.term1:resizeAs(input)
+   self.term2:resizeAs(input)
+   self.term3:resizeAs(input)
+
+   self.term1:fill(1):add(-1,target)
+   self.term2:fill(1):add(-1,input)
+
+   self.term2:add(eps)
+   self.term1:cdiv(self.term2)
+
+   self.term3:copy(input):add(eps)
+
+   self.gradInput:resizeAs(input)
+   self.gradInput:copy(target):cdiv(self.term3)
+
+   self.gradInput:add(-1,self.term1)
+
+   if self.sizeAverage then
+      self.gradInput:div(target:nElement())
+   end
+
+   self.gradInput:mul(-1)
+
+   return self.gradInput
 end
```