diff options
author | Ronan Collobert <ronan@collobert.com> | 2015-03-13 21:56:48 +0300 |
---|---|---|
committer | Ronan Collobert <ronan@collobert.com> | 2015-03-13 21:56:48 +0300 |
commit | aec8e83dc8e7183008b6f989adeb27c8ef31e67d (patch) | |
tree | b3970fb968e8d142b163d6c0a85c1cb9b83424ac /test.lua | |
parent | cef0cb0cba92f88d8bf076e6fb4088503a2b7845 (diff) |
added doc + test case for CrossEntropyCriterion
Diffstat (limited to 'test.lua')
-rw-r--r-- | test.lua | 37 |
1 files changed, 32 insertions, 5 deletions
@@ -723,19 +723,21 @@ local function criterionJacobianTest1D(cri, input, target) local dfdx = cri:backward(input, target) -- for each input perturbation, do central difference local centraldiff_dfdx = torch.Tensor():resizeAs(dfdx) - for i=1,input:size(1) do + local input_s = input:storage() + local centraldiff_dfdx_s = centraldiff_dfdx:storage() + for i=1,input:nElement() do -- f(xi + h) - input[i] = input[i] + eps + input_s[i] = input_s[i] + eps local fx1 = cri:forward(input, target) -- f(xi - h) - input[i] = input[i] - 2*eps + input_s[i] = input_s[i] - 2*eps local fx2 = cri:forward(input, target) -- f'(xi) = (f(xi + h) - f(xi - h)) / 2h local cdfx = (fx1 - fx2) / (2*eps) -- store f' in appropriate place - centraldiff_dfdx[i] = cdfx + centraldiff_dfdx_s[i] = cdfx -- reset input[i] - input[i] = input[i] + eps + input_s[i] = input_s[i] + eps end -- compare centraldiff_dfdx with :backward() @@ -804,6 +806,31 @@ function nntest.ClassNLLCriterion() criterionJacobianTest1D(cri, input, target) end +function nntest.CrossEntropyCriterion() + -- stochastic + local numLabels = math.random(5, 10) + local input = torch.zeros(numLabels) + local target = torch.random(1, numLabels) + + local cri = nn.CrossEntropyCriterion() + criterionJacobianTest1D(cri, input, target) + + -- batch + local numLabels = math.random(5,10) + local bsz = math.random(3, 7) + local input = torch.zeros(bsz, numLabels) + local target = torch.Tensor(bsz):random(1, numLabels) + + local cri = nn.CrossEntropyCriterion() + criterionJacobianTest1D(cri, input, target) + + -- with weights + local weights = torch.rand(numLabels) + weights = weights / weights:sum() + cri = nn.CrossEntropyCriterion(weights) + criterionJacobianTest1D(cri, input, target) +end + function nntest.LogSigmoid() local ini = math.random(3,5) local inj = math.random(3,5) |