Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2014-09-28 23:31:53 +0400
committerSoumith Chintala <soumith@gmail.com>2014-09-28 23:55:38 +0400
commit91793f499ab03a2d89cfa0148711e4ce02ab31ba (patch)
tree311aeb793a448dc87b67f6e1dd1d95ec75b4305a /test
parent2f8a13181e33e05261077c8108e39edc18bf0f3e (diff)
adding SoftMax and SpatialSoftMax bindings
Diffstat (limited to 'test')
-rw-r--r--test/test.lua49
1 files changed, 43 insertions, 6 deletions
diff --git a/test/test.lua b/test/test.lua
index 78a4fdc..6965219 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -130,11 +130,11 @@ function cudnntest.ReLU()
local input = torch.randn(bs,from,inj,ini):cuda()
local gradOutput = torch.randn(bs,from,outj,outi):cuda()
- local sconv = nn.ReLU(ki,kj,si,sj):cuda()
+ local sconv = nn.ReLU():cuda()
local groundtruth = sconv:forward(input)
local groundgrad = sconv:backward(input, gradOutput)
cutorch.synchronize()
- local gconv = cudnn.ReLU(ki,kj,si,sj):cuda()
+ local gconv = cudnn.ReLU():cuda()
local rescuda = gconv:forward(input)
-- serialize and deserialize
@@ -164,11 +164,11 @@ function cudnntest.Tanh()
local input = torch.randn(bs,from,inj,ini):cuda()
local gradOutput = torch.randn(bs,from,outj,outi):cuda()
- local sconv = nn.Tanh(ki,kj,si,sj):cuda()
+ local sconv = nn.Tanh():cuda()
local groundtruth = sconv:forward(input)
local groundgrad = sconv:backward(input, gradOutput)
cutorch.synchronize()
- local gconv = cudnn.Tanh(ki,kj,si,sj):cuda()
+ local gconv = cudnn.Tanh():cuda()
local rescuda = gconv:forward(input)
-- serialize and deserialize
@@ -198,11 +198,11 @@ function cudnntest.Sigmoid()
local input = torch.randn(bs,from,inj,ini):cuda()
local gradOutput = torch.randn(bs,from,outj,outi):cuda()
- local sconv = nn.Tanh(ki,kj,si,sj):cuda()
+ local sconv = nn.Sigmoid():cuda()
local groundtruth = sconv:forward(input)
local groundgrad = sconv:backward(input, gradOutput)
cutorch.synchronize()
- local gconv = cudnn.Tanh(ki,kj,si,sj):cuda()
+ local gconv = cudnn.Sigmoid():cuda()
local rescuda = gconv:forward(input)
-- serialize and deserialize
@@ -218,6 +218,43 @@ function cudnntest.Sigmoid()
mytester:assertlt(error:abs():max(), precision_backward, 'error on state (backward) ')
end
+function cudnntest.SoftMax()
+ local bs = math.random(1,32)
+ local from = math.random(1,32)
+ local ki = math.random(2,4)
+ local kj = math.random(2,4)
+ local si = ki
+ local sj = kj
+ local outi = math.random(1,64)
+ local outj = math.random(1,64)
+ local ini = outi
+ local inj = outj
+ local input = torch.randn(bs,from,inj,ini):cuda()
+ local gradOutput = torch.randn(bs,from,outj,outi):cuda()
+
+ local sconv = nn.SoftMax():cuda()
+ local groundtruth = sconv:forward(input:view(bs,-1))
+ local groundgrad = sconv:backward(input, gradOutput)
+ cutorch.synchronize()
+ local gconv = cudnn.SoftMax():cuda()
+ local rescuda = gconv:forward(input)
+
+ -- serialize and deserialize
+ torch.save('modelTemp.t7', gconv)
+ gconv = torch.load('modelTemp.t7')
+
+ local rescuda = gconv:forward(input)
+ local resgrad = gconv:backward(input, gradOutput)
+ cutorch.synchronize()
+ local error = rescuda:float() - groundtruth:float()
+ mytester:assertlt(error:abs():max(),
+ precision_forward, 'error on state (forward) ')
+ error = resgrad:float() - groundgrad:float()
+ mytester:assertlt(error:abs():max(),
+ precision_backward, 'error on state (backward) ')
+end
+
+
torch.setdefaulttensortype('torch.FloatTensor')
math.randomseed(os.time())