Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorsoumith <soumith@fb.com>2015-09-15 20:32:36 +0300
committersoumith <soumith@fb.com>2015-09-15 20:32:36 +0300
commit4d5c3db15efc87fe4220fc06486a8d7be759dcc2 (patch)
tree87548ebc6a6c6113d952569d1ab72ccf6052ebb0 /test
parent97f41c48602a345344bb5f76e73e4b2fbf7eb679 (diff)
whitespace cleanups, fixing logsoftmax test
Diffstat (limited to 'test')
-rw-r--r--test/test.lua14
1 files changed, 5 insertions, 9 deletions
diff --git a/test/test.lua b/test/test.lua
index 4062425..8c22ece 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -692,15 +692,11 @@ end
function cudnntest.LogSoftMax_batch()
local bs = math.random(1,32)
local from = math.random(1,32)
- local outi = math.random(1,64)
- local outj = math.random(1,64)
- local ini = outi
- local inj = outj
- local input = torch.randn(bs,from,inj,ini):cuda()
- local gradOutput = torch.randn(bs,from,outj,outi):cuda()
+ local input = torch.randn(bs,from):cuda()
+ local gradOutput = torch.randn(bs,from):cuda()
local sconv = nn.LogSoftMax():cuda()
- local groundtruth = sconv:forward(input:view(bs,-1))
+ local groundtruth = sconv:forward(input)
local groundgrad = sconv:backward(input, gradOutput)
cutorch.synchronize()
local gconv = cudnn.LogSoftMax():cuda()
@@ -713,8 +709,8 @@ function cudnntest.LogSoftMax_batch()
local rescuda = gconv:forward(input)
local resgrad = gconv:backward(input, gradOutput)
cutorch.synchronize()
- mytester:asserteq(rescuda:dim(), 4, 'error in dimension')
- mytester:asserteq(resgrad:dim(), 4, 'error in dimension')
+ mytester:asserteq(rescuda:dim(), 2, 'error in dimension')
+ mytester:asserteq(resgrad:dim(), 2, 'error in dimension')
local error = rescuda:float() - groundtruth:float()
mytester:assertlt(error:abs():max(),