Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorSergey Zagoruyko <zagoruyko2@gmail.com>2015-08-07 01:26:00 +0300
committerSergey Zagoruyko <zagoruyko2@gmail.com>2015-08-07 01:26:00 +0300
commitf3180c13b72eea810804d5d83c154ddba187a5c9 (patch)
treee84f2f50476dac7fbfa71a3833e8f96314a2e067 /test
parent1c65ebf62d9096c98bc7a108685b48e87ec4b341 (diff)
added LogSoftMax test
Diffstat (limited to 'test')
-rw-r--r--  test/test.lua  67
1 file changed, 66 insertions(+), 1 deletion(-)
diff --git a/test/test.lua b/test/test.lua
index 71d97d9..5c8b31d 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -505,7 +505,7 @@ local function nonlinSingle(nonlin)
'error on state (backward) ')
end
-function nonlinBatch(nonlin)
+local function nonlinBatch(nonlin)
local bs = math.random(1,32)
local from = math.random(1,32)
local outi = math.random(1,64)
@@ -659,6 +659,71 @@ function cudnntest.SoftMax_batch()
precision_backward, 'error on state (backward) ')
end
+
function cudnntest.LogSoftMax_single()
   -- Compare cudnn.LogSoftMax against the nn.LogSoftMax reference on a random
   -- 1D input, after round-tripping the cudnn module through serialization.
   local n = math.random(1,64)
   local input = torch.randn(n):cuda()
   local gradOutput = torch.randn(n):cuda()

   -- Reference forward/backward on the CPU-side nn module (run on CUDA tensors).
   local ref = nn.LogSoftMax():cuda()
   local refOutput = ref:forward(input)
   local refGradInput = ref:backward(input, gradOutput)
   cutorch.synchronize()

   local gconv = cudnn.LogSoftMax():cuda()
   gconv:forward(input)

   -- serialize and deserialize, so the test also covers save/load of the module
   torch.save('modelTemp.t7', gconv)
   gconv = torch.load('modelTemp.t7')

   local cudnnOutput = gconv:forward(input)
   local cudnnGradInput = gconv:backward(input, gradOutput)
   cutorch.synchronize()

   -- Max absolute elementwise difference must stay under the test tolerances.
   local maxDiff = (cudnnOutput:float() - refOutput:float()):abs():max()
   mytester:assertlt(maxDiff, precision_forward,
                     'error on state (forward) ')
   maxDiff = (cudnnGradInput:float() - refGradInput:float()):abs():max()
   mytester:assertlt(maxDiff, precision_backward,
                     'error on state (backward) ')
end
+
function cudnntest.LogSoftMax_batch()
   -- Compare cudnn.LogSoftMax against nn.LogSoftMax on a random 4D batch,
   -- after round-tripping the cudnn module through serialization.
   local bs = math.random(1,32)
   local from = math.random(1,32)
   local outi = math.random(1,64)
   local outj = math.random(1,64)
   local ini = outi
   local inj = outj
   local input = torch.randn(bs,from,inj,ini):cuda()
   local gradOutput = torch.randn(bs,from,outj,outi):cuda()

   -- nn.LogSoftMax only accepts 1D/2D input, so flatten to (bs, features)
   -- for BOTH forward and backward (the original passed the raw 4D tensors
   -- to backward, inconsistent with the viewed forward pass).
   local sconv = nn.LogSoftMax():cuda()
   local groundtruth = sconv:forward(input:view(bs,-1))
   local groundgrad = sconv:backward(input:view(bs,-1), gradOutput:view(bs,-1))
   cutorch.synchronize()
   local gconv = cudnn.LogSoftMax():cuda()
   -- Warm-up forward; result intentionally discarded (matches the single test).
   local _ = gconv:forward(input)

   -- serialize and deserialize
   torch.save('modelTemp.t7', gconv)
   gconv = torch.load('modelTemp.t7')

   local rescuda = gconv:forward(input)
   local resgrad = gconv:backward(input, gradOutput)
   cutorch.synchronize()
   mytester:asserteq(rescuda:dim(), 4, 'error in dimension')
   mytester:asserteq(resgrad:dim(), 4, 'error in dimension')

   -- Flatten the 4D cudnn results before comparing against the 2D reference;
   -- the original subtracted mismatched shapes.
   local error = rescuda:float():view(bs,-1) - groundtruth:float()
   mytester:assertlt(error:abs():max(),
                     precision_forward, 'error on state (forward) ')
   error = resgrad:float():view(bs,-1) - groundgrad:float()
   mytester:assertlt(error:abs():max(),
                     precision_backward, 'error on state (backward) ')
end
+
function cudnntest.functional_SpatialBias()
local bs = math.random(1,32)
local from = math.random(1,32)