Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGregory Chanan <gchanan@fb.com>2017-01-20 22:30:56 +0300
committerGregory Chanan <gchanan@fb.com>2017-01-20 22:30:56 +0300
commit9209e1701923c5a31609092b638b8443945c5d13 (patch)
tree24872456b5679e3104961cc5d779cbbd07a72724 /test.lua
parent5989f82800a640ed0f5613c8ef3e417c4502661d (diff)
Preserve old behavior of setting nll.sizeAverage in CrossEntropyCriterion.
Diffstat (limited to 'test.lua')
-rw-r--r--test.lua14
1 files changed, 14 insertions, 0 deletions
diff --git a/test.lua b/test.lua
index b3e1d16..b19d6b3 100644
--- a/test.lua
+++ b/test.lua
@@ -2116,6 +2116,20 @@ function nntest.CrossEntropyCriterion()
weights = weights / weights:sum()
cri = nn.CrossEntropyCriterion(weights)
criterionJacobianTest(cri, input, target)
+
+ -- verify nll.sizeAverage preservation
+ cri = nn.CrossEntropyCriterion(weights)
+ cri.nll.sizeAverage = false
+ criterionJacobianTest(cri, input, target)
+ mytester:eq(cri.nll.sizeAverage, false,
+ "ClassNLLCriterion.sizeAverage overwritten")
+
+ -- verify nll.sizeAverage propagation
+ cri = nn.CrossEntropyCriterion(weights)
+ cri.sizeAverage = false
+ criterionJacobianTest(cri, input, target)
+ mytester:eq(cri.nll.sizeAverage, false,
+ "ClassNLLCriterion.sizeAverage not propagated")
end
function nntest.LogSigmoid()