diff options
author | soumith <soumith@fb.com> | 2015-09-15 20:32:36 +0300 |
---|---|---|
committer | soumith <soumith@fb.com> | 2015-09-15 20:32:36 +0300 |
commit | 4d5c3db15efc87fe4220fc06486a8d7be759dcc2 (patch) | |
tree | 87548ebc6a6c6113d952569d1ab72ccf6052ebb0 /test | |
parent | 97f41c48602a345344bb5f76e73e4b2fbf7eb679 (diff) |
whitespace cleanups, fixing logsoftmax test
Diffstat (limited to 'test')
-rw-r--r-- | test/test.lua | 14 |
1 files changed, 5 insertions, 9 deletions
diff --git a/test/test.lua b/test/test.lua
index 4062425..8c22ece 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -692,15 +692,11 @@ end
 function cudnntest.LogSoftMax_batch()
    local bs = math.random(1,32)
    local from = math.random(1,32)
-   local outi = math.random(1,64)
-   local outj = math.random(1,64)
-   local ini = outi
-   local inj = outj
-   local input = torch.randn(bs,from,inj,ini):cuda()
-   local gradOutput = torch.randn(bs,from,outj,outi):cuda()
+   local input = torch.randn(bs,from):cuda()
+   local gradOutput = torch.randn(bs,from):cuda()
    local sconv = nn.LogSoftMax():cuda()
-   local groundtruth = sconv:forward(input:view(bs,-1))
+   local groundtruth = sconv:forward(input)
    local groundgrad = sconv:backward(input, gradOutput)
    cutorch.synchronize()
    local gconv = cudnn.LogSoftMax():cuda()
@@ -713,8 +709,8 @@ function cudnntest.LogSoftMax_batch()
    local rescuda = gconv:forward(input)
    local resgrad = gconv:backward(input, gradOutput)
    cutorch.synchronize()
-   mytester:asserteq(rescuda:dim(), 4, 'error in dimension')
-   mytester:asserteq(resgrad:dim(), 4, 'error in dimension')
+   mytester:asserteq(rescuda:dim(), 2, 'error in dimension')
+   mytester:asserteq(resgrad:dim(), 2, 'error in dimension')
    local error = rescuda:float() - groundtruth:float()
    mytester:assertlt(error:abs():max(),