diff options
-rw-r--r-- | BCECriterion.lua | 37 |
1 files changed, 31 insertions, 6 deletions
diff --git a/BCECriterion.lua b/BCECriterion.lua index 71fcc3a..be9082e 100644 --- a/BCECriterion.lua +++ b/BCECriterion.lua @@ -7,21 +7,46 @@ end function BCECriterion:updateOutput(input, target) -- log(input) * target + log(1 - input) * (1 - target) - self.output = torch.log(input):cmul(target) + + self.term1 = self.term1 or input.new() + self.term2 = self.term2 or input.new() + self.term3 = self.term3 or input.new() + + self.term1:resizeAs(input) + self.term2:resizeAs(input) + self.term3:resizeAs(input) + + self.term1:fill(1):add(-1,target) + self.term2:fill(1):add(-1,input):log():cmul(self.term1) - self.output:add(torch.add(-input,1):log():cmul(torch.add(-target,1))) + self.term3:copy(input):log():cmul(target) + self.term3:add(self.term2) if self.sizeAverage then - self.output:div(target:size(1)) + self.term3:div(target:size(1)) end - return self.output:sum() + return self.term3:sum() end function BCECriterion:updateGradInput(input, target) -- target / input - (1 - target) / (1 - input) - self.gradInput = torch.cdiv(target,input) - self.gradInput:add(-1,torch.cdiv(torch.add(-target,1),torch.add(-input,1))) + + self.term1 = self.term1 or input.new() + self.term2 = self.term2 or input.new() + + self.term1:resizeAs(input) + self.term2:resizeAs(input) + + self.term1:fill(1):add(-1,target) + self.term2:fill(1):add(-1,input) + + self.term1:cdiv(self.term2) + + self.gradInput:resizeAs(input) + self.gradInput:copy(target):cdiv(input) + + self.gradInput:add(-1,self.term1) if self.sizeAverage then self.gradInput:div(target:size(1)) |