Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorClement Farabet <clement.farabet@gmail.com>2014-07-02 00:07:47 +0400
committerClement Farabet <clement.farabet@gmail.com>2014-07-02 00:07:47 +0400
commit7af29c79899774af91e806531cbe420631d63246 (patch)
treefed9c3302f498c3b105f48246a89eb08f66f4593 /BCECriterion.lua
parente7a59498e2a7a2f155d4905ee5fbd70f263b6fea (diff)
rewrote BCECriterion for CUDA compat
Diffstat (limited to 'BCECriterion.lua')
-rw-r--r--BCECriterion.lua37
1 files changed, 31 insertions, 6 deletions
diff --git a/BCECriterion.lua b/BCECriterion.lua
index 71fcc3a..be9082e 100644
--- a/BCECriterion.lua
+++ b/BCECriterion.lua
@@ -7,21 +7,46 @@ end
function BCECriterion:updateOutput(input, target)
-- log(input) * target + log(1 - input) * (1 - target)
- self.output = torch.log(input):cmul(target)
+
+ self.term1 = self.term1 or input.new()
+ self.term2 = self.term2 or input.new()
+ self.term3 = self.term3 or input.new()
+
+ self.term1:resizeAs(input)
+ self.term2:resizeAs(input)
+ self.term3:resizeAs(input)
+
+ self.term1:fill(1):add(-1,target)
+ self.term2:fill(1):add(-1,input):log():cmul(self.term1)
- self.output:add(torch.add(-input,1):log():cmul(torch.add(-target,1)))
+ self.term3:copy(input):log():cmul(target)
+ self.term3:add(self.term2)
if self.sizeAverage then
- self.output:div(target:size(1))
+ self.term3:div(target:size(1))
end
- return self.output:sum()
+ return self.term3:sum()
end
function BCECriterion:updateGradInput(input, target)
-- target / input - (1 - target) / (1 - input)
- self.gradInput = torch.cdiv(target,input)
- self.gradInput:add(-1,torch.cdiv(torch.add(-target,1),torch.add(-input,1)))
+
+ self.term1 = self.term1 or input.new()
+ self.term2 = self.term2 or input.new()
+
+ self.term1:resizeAs(input)
+ self.term2:resizeAs(input)
+
+ self.term1:fill(1):add(-1,target)
+ self.term2:fill(1):add(-1,input)
+
+ self.term1:cdiv(self.term2)
+
+ self.gradInput:resizeAs(input)
+ self.gradInput:copy(target):cdiv(input)
+
+ self.gradInput:add(-1,self.term1)
if self.sizeAverage then
self.gradInput:div(target:size(1))