diff options
author | Soumith Chintala <soumith@gmail.com> | 2014-09-28 23:31:53 +0400 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2014-09-28 23:55:38 +0400 |
commit | 91793f499ab03a2d89cfa0148711e4ce02ab31ba (patch) | |
tree | 311aeb793a448dc87b67f6e1dd1d95ec75b4305a /test | |
parent | 2f8a13181e33e05261077c8108e39edc18bf0f3e (diff) |
adding SoftMax and SpatialSoftMax bindings
Diffstat (limited to 'test')
-rw-r--r-- | test/test.lua | 49 |
1 files changed, 43 insertions, 6 deletions
diff --git a/test/test.lua b/test/test.lua index 78a4fdc..6965219 100644 --- a/test/test.lua +++ b/test/test.lua @@ -130,11 +130,11 @@ function cudnntest.ReLU() local input = torch.randn(bs,from,inj,ini):cuda() local gradOutput = torch.randn(bs,from,outj,outi):cuda() - local sconv = nn.ReLU(ki,kj,si,sj):cuda() + local sconv = nn.ReLU():cuda() local groundtruth = sconv:forward(input) local groundgrad = sconv:backward(input, gradOutput) cutorch.synchronize() - local gconv = cudnn.ReLU(ki,kj,si,sj):cuda() + local gconv = cudnn.ReLU():cuda() local rescuda = gconv:forward(input) -- serialize and deserialize @@ -164,11 +164,11 @@ function cudnntest.Tanh() local input = torch.randn(bs,from,inj,ini):cuda() local gradOutput = torch.randn(bs,from,outj,outi):cuda() - local sconv = nn.Tanh(ki,kj,si,sj):cuda() + local sconv = nn.Tanh():cuda() local groundtruth = sconv:forward(input) local groundgrad = sconv:backward(input, gradOutput) cutorch.synchronize() - local gconv = cudnn.Tanh(ki,kj,si,sj):cuda() + local gconv = cudnn.Tanh():cuda() local rescuda = gconv:forward(input) -- serialize and deserialize @@ -198,11 +198,11 @@ function cudnntest.Sigmoid() local input = torch.randn(bs,from,inj,ini):cuda() local gradOutput = torch.randn(bs,from,outj,outi):cuda() - local sconv = nn.Tanh(ki,kj,si,sj):cuda() + local sconv = nn.Sigmoid():cuda() local groundtruth = sconv:forward(input) local groundgrad = sconv:backward(input, gradOutput) cutorch.synchronize() - local gconv = cudnn.Tanh(ki,kj,si,sj):cuda() + local gconv = cudnn.Sigmoid():cuda() local rescuda = gconv:forward(input) -- serialize and deserialize @@ -218,6 +218,43 @@ function cudnntest.Sigmoid() mytester:assertlt(error:abs():max(), precision_backward, 'error on state (backward) ') end +function cudnntest.SoftMax() + local bs = math.random(1,32) + local from = math.random(1,32) + local ki = math.random(2,4) + local kj = math.random(2,4) + local si = ki + local sj = kj + local outi = math.random(1,64) + local outj = math.random(1,64) + local ini = outi + local inj = outj + local input = torch.randn(bs,from,inj,ini):cuda() + local gradOutput = torch.randn(bs,from,outj,outi):cuda() + + local sconv = nn.SoftMax():cuda() + local groundtruth = sconv:forward(input:view(bs,-1)) + local groundgrad = sconv:backward(input, gradOutput) + cutorch.synchronize() + local gconv = cudnn.SoftMax():cuda() + local rescuda = gconv:forward(input) + + -- serialize and deserialize + torch.save('modelTemp.t7', gconv) + gconv = torch.load('modelTemp.t7') + + local rescuda = gconv:forward(input) + local resgrad = gconv:backward(input, gradOutput) + cutorch.synchronize() + local error = rescuda:float() - groundtruth:float() + mytester:assertlt(error:abs():max(), + precision_forward, 'error on state (forward) ') + error = resgrad:float() - groundgrad:float() + mytester:assertlt(error:abs():max(), + precision_backward, 'error on state (backward) ') +end + + torch.setdefaulttensortype('torch.FloatTensor') math.randomseed(os.time()) |