Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2014-05-12 19:27:54 +0400
committerSoumith Chintala <soumith@gmail.com>2014-05-12 19:27:54 +0400
commit900e8522fd2cd759803babde7bf8bae43daacaf3 (patch)
treef05f952cf8b8250ed10864fb0dce424564ac0202
parent3b537ffa5a60c772439787afba418255ca83791a (diff)
fixing BCECriterion, adding test, closing https://github.com/torch/nn/pull/10
-rw-r--r--BCECriterion.lua4
-rw-r--r--init.lua1
-rw-r--r--test/test.lua7
3 files changed, 10 insertions, 2 deletions
diff --git a/BCECriterion.lua b/BCECriterion.lua
index 2d339e9..71fcc3a 100644
--- a/BCECriterion.lua
+++ b/BCECriterion.lua
@@ -11,7 +11,7 @@ function BCECriterion:updateOutput(input, target)
self.output:add(torch.add(-input,1):log():cmul(torch.add(-target,1)))
- if self.sizeAverage
+ if self.sizeAverage then
self.output:div(target:size(1))
end
@@ -23,7 +23,7 @@ function BCECriterion:updateGradInput(input, target)
self.gradInput = torch.cdiv(target,input)
self.gradInput:add(-1,torch.cdiv(torch.add(-target,1),torch.add(-input,1)))
- if self.sizeAverage
+ if self.sizeAverage then
self.gradInput:div(target:size(1))
end
diff --git a/init.lua b/init.lua
index 7071298..1fba70a 100644
--- a/init.lua
+++ b/init.lua
@@ -99,6 +99,7 @@ include('MultiMarginCriterion.lua')
include('MultiLabelMarginCriterion.lua')
include('L1Cost.lua')
include('WeightedMSECriterion.lua')
+include('BCECriterion.lua')
include('StochasticGradient.lua')
diff --git a/test/test.lua b/test/test.lua
index 7eb2d44..e3b6a50 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -443,6 +443,13 @@ function nntest.WeightedMSECriterion()
criterionJacobianTest(cri, input, target)
end
+function nntest.BCECriterion()
+ local input = torch.rand(100)
+ local target = input:clone():add(torch.rand(100))
+ local cri = nn.BCECriterion()
+ criterionJacobianTest(cri, input, target)
+end
+
function nntest.LogSigmoid()
local ini = math.random(10,20)
local inj = math.random(10,20)