Welcome to the mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
authorRonan Collobert <ronan@collobert.com>2015-03-13 21:56:48 +0300
committerRonan Collobert <ronan@collobert.com>2015-03-13 21:56:48 +0300
commitaec8e83dc8e7183008b6f989adeb27c8ef31e67d (patch)
treeb3970fb968e8d142b163d6c0a85c1cb9b83424ac /test.lua
parentcef0cb0cba92f88d8bf076e6fb4088503a2b7845 (diff)
added doc + test case for CrossEntropyCriterion
Diffstat (limited to 'test.lua')
-rw-r--r-- test.lua 37
1 file changed, 32 insertions(+), 5 deletions(-)
diff --git a/test.lua b/test.lua
index 6174db2..3661d03 100644
--- a/test.lua
+++ b/test.lua
@@ -723,19 +723,21 @@ local function criterionJacobianTest1D(cri, input, target)
local dfdx = cri:backward(input, target)
-- for each input perturbation, do central difference
local centraldiff_dfdx = torch.Tensor():resizeAs(dfdx)
- for i=1,input:size(1) do
+ local input_s = input:storage()
+ local centraldiff_dfdx_s = centraldiff_dfdx:storage()
+ for i=1,input:nElement() do
-- f(xi + h)
- input[i] = input[i] + eps
+ input_s[i] = input_s[i] + eps
local fx1 = cri:forward(input, target)
-- f(xi - h)
- input[i] = input[i] - 2*eps
+ input_s[i] = input_s[i] - 2*eps
local fx2 = cri:forward(input, target)
-- f'(xi) = (f(xi + h) - f(xi - h)) / 2h
local cdfx = (fx1 - fx2) / (2*eps)
-- store f' in appropriate place
- centraldiff_dfdx[i] = cdfx
+ centraldiff_dfdx_s[i] = cdfx
-- reset input[i]
- input[i] = input[i] + eps
+ input_s[i] = input_s[i] + eps
end
-- compare centraldiff_dfdx with :backward()
@@ -804,6 +806,31 @@ function nntest.ClassNLLCriterion()
criterionJacobianTest1D(cri, input, target)
end
+function nntest.CrossEntropyCriterion()
+ -- stochastic
+ local numLabels = math.random(5, 10)
+ local input = torch.zeros(numLabels)
+ local target = torch.random(1, numLabels)
+
+ local cri = nn.CrossEntropyCriterion()
+ criterionJacobianTest1D(cri, input, target)
+
+ -- batch
+ local numLabels = math.random(5,10)
+ local bsz = math.random(3, 7)
+ local input = torch.zeros(bsz, numLabels)
+ local target = torch.Tensor(bsz):random(1, numLabels)
+
+ local cri = nn.CrossEntropyCriterion()
+ criterionJacobianTest1D(cri, input, target)
+
+ -- with weights
+ local weights = torch.rand(numLabels)
+ weights = weights / weights:sum()
+ cri = nn.CrossEntropyCriterion(weights)
+ criterionJacobianTest1D(cri, input, target)
+end
+
function nntest.LogSigmoid()
local ini = math.random(3,5)
local inj = math.random(3,5)