From 7af29c79899774af91e806531cbe420631d63246 Mon Sep 17 00:00:00 2001 From: Clement Farabet Date: Tue, 1 Jul 2014 16:07:47 -0400 Subject: rewrote BCECriterion for CUDA compat --- BCECriterion.lua | 37 +++++++++++++++++++++++++++++++------ 1 file changed, 31 insertions(+), 6 deletions(-) (limited to 'BCECriterion.lua') diff --git a/BCECriterion.lua b/BCECriterion.lua index 71fcc3a..be9082e 100644 --- a/BCECriterion.lua +++ b/BCECriterion.lua @@ -7,21 +7,46 @@ end function BCECriterion:updateOutput(input, target) -- log(input) * target + log(1 - input) * (1 - target) - self.output = torch.log(input):cmul(target) + + self.term1 = self.term1 or input.new() + self.term2 = self.term2 or input.new() + self.term3 = self.term3 or input.new() + + self.term1:resizeAs(input) + self.term2:resizeAs(input) + self.term3:resizeAs(input) + + self.term1:fill(1):add(-1,target) + self.term2:fill(1):add(-1,input):log():cmul(self.term1) - self.output:add(torch.add(-input,1):log():cmul(torch.add(-target,1))) + self.term3:copy(input):log():cmul(target) + self.term3:add(self.term2) if self.sizeAverage then - self.output:div(target:size(1)) + self.term3:div(target:size(1)) end - return self.output:sum() + return self.term3:sum() end function BCECriterion:updateGradInput(input, target) -- target / input - (1 - target) / (1 - input) - self.gradInput = torch.cdiv(target,input) - self.gradInput:add(-1,torch.cdiv(torch.add(-target,1),torch.add(-input,1))) + + self.term1 = self.term1 or input.new() + self.term2 = self.term2 or input.new() + + self.term1:resizeAs(input) + self.term2:resizeAs(input) + + self.term1:fill(1):add(-1,target) + self.term2:fill(1):add(-1,input) + + self.term1:cdiv(self.term2) + + self.gradInput:resizeAs(input) + self.gradInput:copy(target):cdiv(input) + + self.gradInput:add(-1,self.term1) if self.sizeAverage then self.gradInput:div(target:size(1)) -- cgit v1.2.3