Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorSergey Zagoruyko <zagoruyko2@gmail.com>2016-03-21 14:03:04 +0300
committerSergey Zagoruyko <zagoruyko2@gmail.com>2016-03-21 14:03:04 +0300
commit084bd1806aa9f0a2d5f5f79375b0f67087dc4c17 (patch)
treef93bbb481ae8b67c0f4a0e71ebb07aa22dfb5077 /test
parent51d583f3a67ac2ad5ebe7d6096a8c2ee3b4f3e25 (diff)
removed double conversion tests
Diffstat (limited to 'test')
-rw-r--r--test/test.lua29
1 files changed, 5 insertions, 24 deletions
diff --git a/test/test.lua b/test/test.lua
index d812541..388deb3 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -25,10 +25,11 @@ function cudnntest.SpatialConvolution_forward_batch()
local inj = (outj-1)*sj+kj
local input = torch.randn(bs,from,inj,ini):cuda()
- local sconv = nn.SpatialConvolutionMM(from,to,ki,kj,si,sj):cuda()
+ local sconv = nn.SpatialConvolution(from,to,ki,kj,si,sj):cuda()
local gconv = cudnn.SpatialConvolution(from,to,ki,kj,si,sj):cuda():fastest()
gconv.weight:copy(sconv.weight)
gconv.bias:copy(sconv.bias)
+
local function test(sconv, gconv)
local groundtruth = sconv:forward(input)
cutorch.synchronize()
@@ -37,13 +38,6 @@ function cudnntest.SpatialConvolution_forward_batch()
local error = rescuda:float() - groundtruth:float()
mytester:assertlt(error:abs():max(), precision_forward, 'error on state (forward) ')
- local gconv = cudnn.convert(sconv, cudnn)
- mytester:asserteq(torch.typename(gconv), 'cudnn.SpatialConvolution', 'conversion type check')
- local rescuda = gconv:forward(input)
- cutorch.synchronize()
- local error = rescuda:float() - groundtruth:float()
- mytester:assertlt(error:abs():max(), precision_forward, 'error on state (forward conversion) ')
-
-- IO
local ferr,berr = jac.testIO(gconv, input)
mytester:assertlt(ferr, precision_io, torch.typename(gconv) .. ' - i/o forward err ')
@@ -73,7 +67,7 @@ function cudnntest.SpatialConvolution_backward_batch()
local input = torch.randn(bs,from,inj,ini):cuda()
local gradOutput = torch.randn(bs,to,outj,outi):cuda()
- local sconv = nn.SpatialConvolutionMM(from,to,ki,kj,si,sj):cuda()
+ local sconv = nn.SpatialConvolution(from,to,ki,kj,si,sj):cuda()
sconv:forward(input)
sconv:zeroGradParameters()
local groundgrad = sconv:backward(input, gradOutput, scale)
@@ -126,7 +120,7 @@ function cudnntest.SpatialConvolution_forward_single()
local inj = (outj-1)*sj+kj
local input = torch.randn(from,inj,ini):cuda()
- local sconv = nn.SpatialConvolutionMM(from,to,ki,kj,si,sj):cuda()
+ local sconv = nn.SpatialConvolution(from,to,ki,kj,si,sj):cuda()
local gconv = cudnn.SpatialConvolution(from,to,ki,kj,si,sj):cuda()
gconv.weight:copy(sconv.weight)
gconv.bias:copy(sconv.bias)
@@ -141,13 +135,6 @@ function cudnntest.SpatialConvolution_forward_single()
mytester:assertlt(error:abs():max(), precision_forward,
'error on state (forward) ')
- local gconv = cudnn.convert(sconv, cudnn)
- mytester:asserteq(torch.typename(gconv), 'cudnn.SpatialConvolution', 'conversion type check')
- local rescuda = gconv:forward(input)
- cutorch.synchronize()
- local error = rescuda:float() - groundtruth:float()
- mytester:assertlt(error:abs():max(), precision_forward, 'error on state (forward conversion) ')
-
-- IO
local ferr,berr = jac.testIO(gconv, input)
mytester:assertlt(ferr, precision_io, torch.typename(gconv) .. ' - i/o forward err ')
@@ -175,7 +162,7 @@ function cudnntest.SpatialConvolution_backward_single()
local input = torch.randn(from,inj,ini):cuda()
local gradOutput = torch.randn(to,outj,outi):cuda()
- local sconv = nn.SpatialConvolutionMM(from,to,ki,kj,si,sj):cuda()
+ local sconv = nn.SpatialConvolution(from,to,ki,kj,si,sj):cuda()
sconv:forward(input)
sconv:zeroGradParameters()
local groundgrad = sconv:backward(input, gradOutput)
@@ -396,12 +383,6 @@ function cudnntest.VolumetricConvolution_forward_single()
mytester:assertlt(error:abs():max(), precision_forward,
'error on state (forward) ')
- local gconv = cudnn.convert(sconv, cudnn):cuda()
- local rescuda = gconv:forward(input)
- cutorch.synchronize()
- local error = rescuda:float() - groundtruth:float()
- mytester:assertlt(error:abs():max(), precision_forward, 'error on state (forward conversion) ')
-
-- IO
local ferr,berr = jac.testIO(gconv, input)
mytester:assertlt(ferr, precision_io, torch.typename(gconv) .. ' - i/o forward err ')