diff options
author | Soumith Chintala <soumith@gmail.com> | 2014-05-12 19:27:54 +0400 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2014-05-12 19:27:54 +0400 |
commit | 900e8522fd2cd759803babde7bf8bae43daacaf3 (patch) | |
tree | f05f952cf8b8250ed10864fb0dce424564ac0202 | |
parent | 3b537ffa5a60c772439787afba418255ca83791a (diff) |
fixing BCECriterion, adding test, closing https://github.com/torch/nn/pull/10
-rw-r--r-- | BCECriterion.lua | 4 | ||||
-rw-r--r-- | init.lua | 1 | ||||
-rw-r--r-- | test/test.lua | 7 |
3 files changed, 10 insertions, 2 deletions
diff --git a/BCECriterion.lua b/BCECriterion.lua index 2d339e9..71fcc3a 100644 --- a/BCECriterion.lua +++ b/BCECriterion.lua @@ -11,7 +11,7 @@ function BCECriterion:updateOutput(input, target) self.output:add(torch.add(-input,1):log():cmul(torch.add(-target,1))) - if self.sizeAverage + if self.sizeAverage then self.output:div(target:size(1)) end @@ -23,7 +23,7 @@ function BCECriterion:updateGradInput(input, target) self.gradInput = torch.cdiv(target,input) self.gradInput:add(-1,torch.cdiv(torch.add(-target,1),torch.add(-input,1))) - if self.sizeAverage + if self.sizeAverage then self.gradInput:div(target:size(1)) end @@ -99,6 +99,7 @@ include('MultiMarginCriterion.lua') include('MultiLabelMarginCriterion.lua') include('L1Cost.lua') include('WeightedMSECriterion.lua') +include('BCECriterion.lua') include('StochasticGradient.lua') diff --git a/test/test.lua b/test/test.lua index 7eb2d44..e3b6a50 100644 --- a/test/test.lua +++ b/test/test.lua @@ -443,6 +443,13 @@ function nntest.WeightedMSECriterion() criterionJacobianTest(cri, input, target) end +function nntest.BCECriterion() + local input = torch.rand(100) + local target = input:clone():add(torch.rand(100)) + local cri = nn.BCECriterion() + criterionJacobianTest(cri, input, target) +end + function nntest.LogSigmoid() local ini = math.random(10,20) local inj = math.random(10,20) |