From 084bd1806aa9f0a2d5f5f79375b0f67087dc4c17 Mon Sep 17 00:00:00 2001 From: Sergey Zagoruyko Date: Mon, 21 Mar 2016 12:03:04 +0100 Subject: removed double conversion tests --- test/test.lua | 29 +++++------------------------ 1 file changed, 5 insertions(+), 24 deletions(-) (limited to 'test') diff --git a/test/test.lua b/test/test.lua index d812541..388deb3 100644 --- a/test/test.lua +++ b/test/test.lua @@ -25,10 +25,11 @@ function cudnntest.SpatialConvolution_forward_batch() local inj = (outj-1)*sj+kj local input = torch.randn(bs,from,inj,ini):cuda() - local sconv = nn.SpatialConvolutionMM(from,to,ki,kj,si,sj):cuda() + local sconv = nn.SpatialConvolution(from,to,ki,kj,si,sj):cuda() local gconv = cudnn.SpatialConvolution(from,to,ki,kj,si,sj):cuda():fastest() gconv.weight:copy(sconv.weight) gconv.bias:copy(sconv.bias) + local function test(sconv, gconv) local groundtruth = sconv:forward(input) cutorch.synchronize() @@ -37,13 +38,6 @@ function cudnntest.SpatialConvolution_forward_batch() local error = rescuda:float() - groundtruth:float() mytester:assertlt(error:abs():max(), precision_forward, 'error on state (forward) ') - local gconv = cudnn.convert(sconv, cudnn) - mytester:asserteq(torch.typename(gconv), 'cudnn.SpatialConvolution', 'conversion type check') - local rescuda = gconv:forward(input) - cutorch.synchronize() - local error = rescuda:float() - groundtruth:float() - mytester:assertlt(error:abs():max(), precision_forward, 'error on state (forward conversion) ') - -- IO local ferr,berr = jac.testIO(gconv, input) mytester:assertlt(ferr, precision_io, torch.typename(gconv) .. ' - i/o forward err ') @@ -73,7 +67,7 @@ function cudnntest.SpatialConvolution_backward_batch() local input = torch.randn(bs,from,inj,ini):cuda() local gradOutput = torch.randn(bs,to,outj,outi):cuda() - local sconv = nn.SpatialConvolutionMM(from,to,ki,kj,si,sj):cuda() + local sconv = nn.SpatialConvolution(from,to,ki,kj,si,sj):cuda() sconv:forward(input) sconv:zeroGradParameters() local groundgrad = sconv:backward(input, gradOutput, scale) @@ -126,7 +120,7 @@ function cudnntest.SpatialConvolution_forward_single() local inj = (outj-1)*sj+kj local input = torch.randn(from,inj,ini):cuda() - local sconv = nn.SpatialConvolutionMM(from,to,ki,kj,si,sj):cuda() + local sconv = nn.SpatialConvolution(from,to,ki,kj,si,sj):cuda() local gconv = cudnn.SpatialConvolution(from,to,ki,kj,si,sj):cuda() gconv.weight:copy(sconv.weight) gconv.bias:copy(sconv.bias) @@ -141,13 +135,6 @@ function cudnntest.SpatialConvolution_forward_single() mytester:assertlt(error:abs():max(), precision_forward, 'error on state (forward) ') - local gconv = cudnn.convert(sconv, cudnn) - mytester:asserteq(torch.typename(gconv), 'cudnn.SpatialConvolution', 'conversion type check') - local rescuda = gconv:forward(input) - cutorch.synchronize() - local error = rescuda:float() - groundtruth:float() - mytester:assertlt(error:abs():max(), precision_forward, 'error on state (forward conversion) ') - -- IO local ferr,berr = jac.testIO(gconv, input) mytester:assertlt(ferr, precision_io, torch.typename(gconv) .. ' - i/o forward err ') @@ -175,7 +162,7 @@ function cudnntest.SpatialConvolution_backward_single() local input = torch.randn(from,inj,ini):cuda() local gradOutput = torch.randn(to,outj,outi):cuda() - local sconv = nn.SpatialConvolutionMM(from,to,ki,kj,si,sj):cuda() + local sconv = nn.SpatialConvolution(from,to,ki,kj,si,sj):cuda() sconv:forward(input) sconv:zeroGradParameters() local groundgrad = sconv:backward(input, gradOutput) @@ -396,12 +383,6 @@ function cudnntest.VolumetricConvolution_forward_single() mytester:assertlt(error:abs():max(), precision_forward, 'error on state (forward) ') - local gconv = cudnn.convert(sconv, cudnn):cuda() - local rescuda = gconv:forward(input) - cutorch.synchronize() - local error = rescuda:float() - groundtruth:float() - mytester:assertlt(error:abs():max(), precision_forward, 'error on state (forward conversion) ') - -- IO local ferr,berr = jac.testIO(gconv, input) mytester:assertlt(ferr, precision_io, torch.typename(gconv) .. ' - i/o forward err ') -- cgit v1.2.3