Welcome to the mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Clement Farabet <clement.farabet@gmail.com> 2014-07-02 06:22:49 +0400
committer: Clement Farabet <clement.farabet@gmail.com> 2014-07-02 06:22:49 +0400
commit: 6a1102f026998871500157ddef3b003434aab91d (patch)
tree: 6ab1316751bdb33e37f443f0021d307072e05cd4 /BCECriterion.lua
parent: 7af29c79899774af91e806531cbe420631d63246 (diff)
BCECriterion was incorrect.
Fixed:
* forward/backward: inverted the sign so the criterion produces the negative log-likelihood (not the positive one)
* averaging is now done over all elements in target, to properly support batches
* added an epsilon to protect against zeros (in divisions and logs)
Diffstat (limited to 'BCECriterion.lua')
-rw-r--r-- BCECriterion.lua 89
1 file changed, 50 insertions, 39 deletions
diff --git a/BCECriterion.lua b/BCECriterion.lua
index be9082e..8e366c8 100644
--- a/BCECriterion.lua
+++ b/BCECriterion.lua
@@ -1,56 +1,67 @@
local BCECriterion, parent = torch.class('nn.BCECriterion', 'nn.Criterion')
+local eps = 1e-12
+
function BCECriterion:__init()
parent.__init(self)
self.sizeAverage = true
end
function BCECriterion:updateOutput(input, target)
- -- log(input) * target + log(1 - input) * (1 - target)
+ -- log(input) * target + log(1 - input) * (1 - target)
+
+ self.term1 = self.term1 or input.new()
+ self.term2 = self.term2 or input.new()
+ self.term3 = self.term3 or input.new()
- self.term1 = self.term1 or input.new()
- self.term2 = self.term2 or input.new()
- self.term3 = self.term3 or input.new()
+ self.term1:resizeAs(input)
+ self.term2:resizeAs(input)
+ self.term3:resizeAs(input)
- self.term1:resizeAs(input)
- self.term2:resizeAs(input)
- self.term3:resizeAs(input)
+ self.term1:fill(1):add(-1,target)
+ self.term2:fill(1):add(-1,input):add(eps):log():cmul(self.term1)
- self.term1:fill(1):add(-1,target)
- self.term2:fill(1):add(-1,input):log():cmul(self.term1)
-
- self.term3:copy(input):log():cmul(target)
- self.term3:add(self.term2)
+ self.term3:copy(input):add(eps):log():cmul(target)
+ self.term3:add(self.term2)
- if self.sizeAverage then
- self.term3:div(target:size(1))
- end
+ if self.sizeAverage then
+ self.term3:div(target:nElement())
+ end
- return self.term3:sum()
+ self.output = - self.term3:sum()
+
+ return self.output
end
function BCECriterion:updateGradInput(input, target)
- -- target / input - (1 - target) / (1 - input)
-
- self.term1 = self.term1 or input.new()
- self.term2 = self.term2 or input.new()
-
- self.term1:resizeAs(input)
- self.term2:resizeAs(input)
-
- self.term1:fill(1):add(-1,target)
- self.term2:fill(1):add(-1,input)
-
- self.term1:cdiv(self.term2)
-
- self.gradInput:resizeAs(input)
- self.gradInput:copy(target):cdiv(input)
-
- self.gradInput:add(-1,self.term1)
-
- if self.sizeAverage then
- self.gradInput:div(target:size(1))
- end
-
- return self.gradInput
+ -- target / input - (1 - target) / (1 - input)
+
+ self.term1 = self.term1 or input.new()
+ self.term2 = self.term2 or input.new()
+ self.term3 = self.term3 or input.new()
+
+ self.term1:resizeAs(input)
+ self.term2:resizeAs(input)
+ self.term3:resizeAs(input)
+
+ self.term1:fill(1):add(-1,target)
+ self.term2:fill(1):add(-1,input)
+
+ self.term2:add(eps)
+ self.term1:cdiv(self.term2)
+
+ self.term3:copy(input):add(eps)
+
+ self.gradInput:resizeAs(input)
+ self.gradInput:copy(target):cdiv(self.term3)
+
+ self.gradInput:add(-1,self.term1)
+
+ if self.sizeAverage then
+ self.gradInput:div(target:nElement())
+ end
+
+ self.gradInput:mul(-1)
+
+ return self.gradInput
end