diff options
author | Sergey Zagoruyko <zagoruyko2@gmail.com> | 2015-08-07 01:26:00 +0300 |
---|---|---|
committer | Sergey Zagoruyko <zagoruyko2@gmail.com> | 2015-08-07 01:26:00 +0300 |
commit | f3180c13b72eea810804d5d83c154ddba187a5c9 (patch) | |
tree | e84f2f50476dac7fbfa71a3833e8f96314a2e067 /test | |
parent | 1c65ebf62d9096c98bc7a108685b48e87ec4b341 (diff) |
added LogSoftMax test
Diffstat (limited to 'test')
-rw-r--r-- | test/test.lua | 67 |
1 files changed, 66 insertions, 1 deletions
diff --git a/test/test.lua b/test/test.lua index 71d97d9..5c8b31d 100644 --- a/test/test.lua +++ b/test/test.lua @@ -505,7 +505,7 @@ local function nonlinSingle(nonlin) 'error on state (backward) ') end -function nonlinBatch(nonlin) +local function nonlinBatch(nonlin) local bs = math.random(1,32) local from = math.random(1,32) local outi = math.random(1,64) @@ -659,6 +659,71 @@ function cudnntest.SoftMax_batch() precision_backward, 'error on state (backward) ') end + +function cudnntest.LogSoftMax_single() + local sz = math.random(1,64) + local input = torch.randn(sz):cuda() + local gradOutput = torch.randn(sz):cuda() + + local sconv = nn.LogSoftMax():cuda() + local groundtruth = sconv:forward(input) + local groundgrad = sconv:backward(input, gradOutput) + cutorch.synchronize() + local gconv = cudnn.LogSoftMax():cuda() + local _ = gconv:forward(input) + + -- serialize and deserialize + torch.save('modelTemp.t7', gconv) + gconv = torch.load('modelTemp.t7') + + local rescuda = gconv:forward(input) + local resgrad = gconv:backward(input, gradOutput) + cutorch.synchronize() + local error = rescuda:float() - groundtruth:float() + local errmax = error:abs():max() + mytester:assertlt(errmax, precision_forward, + 'error on state (forward) ') + error = resgrad:float() - groundgrad:float() + errmax = error:abs():max() + mytester:assertlt(errmax, precision_backward, + 'error on state (backward) ') +end + +function cudnntest.LogSoftMax_batch() + local bs = math.random(1,32) + local from = math.random(1,32) + local outi = math.random(1,64) + local outj = math.random(1,64) + local ini = outi + local inj = outj + local input = torch.randn(bs,from,inj,ini):cuda() + local gradOutput = torch.randn(bs,from,outj,outi):cuda() + + local sconv = nn.LogSoftMax():cuda() + local groundtruth = sconv:forward(input:view(bs,-1)) + local groundgrad = sconv:backward(input, gradOutput) + cutorch.synchronize() + local gconv = cudnn.LogSoftMax():cuda() + local rescuda = gconv:forward(input) + + -- serialize and deserialize + torch.save('modelTemp.t7', gconv) + gconv = torch.load('modelTemp.t7') + + local rescuda = gconv:forward(input) + local resgrad = gconv:backward(input, gradOutput) + cutorch.synchronize() + mytester:asserteq(rescuda:dim(), 4, 'error in dimension') + mytester:asserteq(resgrad:dim(), 4, 'error in dimension') + + local error = rescuda:float() - groundtruth:float() + mytester:assertlt(error:abs():max(), + precision_forward, 'error on state (forward) ') + error = resgrad:float() - groundgrad:float() + mytester:assertlt(error:abs():max(), + precision_backward, 'error on state (backward) ') +end + function cudnntest.functional_SpatialBias() local bs = math.random(1,32) local from = math.random(1,32) |