From aec8e83dc8e7183008b6f989adeb27c8ef31e67d Mon Sep 17 00:00:00 2001 From: Ronan Collobert Date: Fri, 13 Mar 2015 11:56:48 -0700 Subject: added doc + test case for CrossEntropyCriterion --- test.lua | 37 ++++++++++++++++++++++++++++++++----- 1 file changed, 32 insertions(+), 5 deletions(-) (limited to 'test.lua') diff --git a/test.lua b/test.lua index 6174db2..3661d03 100644 --- a/test.lua +++ b/test.lua @@ -723,19 +723,21 @@ local function criterionJacobianTest1D(cri, input, target) local dfdx = cri:backward(input, target) -- for each input perturbation, do central difference local centraldiff_dfdx = torch.Tensor():resizeAs(dfdx) - for i=1,input:size(1) do + local input_s = input:storage() + local centraldiff_dfdx_s = centraldiff_dfdx:storage() + for i=1,input:nElement() do -- f(xi + h) - input[i] = input[i] + eps + input_s[i] = input_s[i] + eps local fx1 = cri:forward(input, target) -- f(xi - h) - input[i] = input[i] - 2*eps + input_s[i] = input_s[i] - 2*eps local fx2 = cri:forward(input, target) -- f'(xi) = (f(xi + h) - f(xi - h)) / 2h local cdfx = (fx1 - fx2) / (2*eps) -- store f' in appropriate place - centraldiff_dfdx[i] = cdfx + centraldiff_dfdx_s[i] = cdfx -- reset input[i] - input[i] = input[i] + eps + input_s[i] = input_s[i] + eps end -- compare centraldiff_dfdx with :backward() @@ -804,6 +806,31 @@ function nntest.ClassNLLCriterion() criterionJacobianTest1D(cri, input, target) end +function nntest.CrossEntropyCriterion() + -- stochastic + local numLabels = math.random(5, 10) + local input = torch.zeros(numLabels) + local target = torch.random(1, numLabels) + + local cri = nn.CrossEntropyCriterion() + criterionJacobianTest1D(cri, input, target) + + -- batch + local numLabels = math.random(5,10) + local bsz = math.random(3, 7) + local input = torch.zeros(bsz, numLabels) + local target = torch.Tensor(bsz):random(1, numLabels) + + local cri = nn.CrossEntropyCriterion() + criterionJacobianTest1D(cri, input, target) + + -- with weights + local weights = torch.rand(numLabels) + weights = weights / weights:sum() + cri = nn.CrossEntropyCriterion(weights) + criterionJacobianTest1D(cri, input, target) +end + function nntest.LogSigmoid() local ini = math.random(3,5) local inj = math.random(3,5) -- cgit v1.2.3