From 935712907b5599e99283bb3d48901bd866e34f4d Mon Sep 17 00:00:00 2001 From: Sergey Zagoruyko Date: Mon, 21 Mar 2016 12:26:51 +0100 Subject: full conv tests --- test/test.lua | 96 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 96 insertions(+) (limited to 'test') diff --git a/test/test.lua b/test/test.lua index 388deb3..55171c7 100644 --- a/test/test.lua +++ b/test/test.lua @@ -206,6 +206,102 @@ function cudnntest.SpatialConvolution_backward_single() test(sconv, gconv) end +function cudnntest.SpatialFullConvolution_forward_batch() + local bs = math.random(1,32) + local from = math.random(1,32) + local to = math.random(1,64) + local ki = math.random(1,15) + local kj = math.random(1,15) + local si = math.random(1,ki) + local sj = math.random(1,kj) + local outi = math.random(1,64) + local outj = math.random(1,64) + local ini = (outi-1)*si+ki + local inj = (outj-1)*sj+kj + + local input = torch.randn(bs,from,inj,ini):cuda() + local sconv = nn.SpatialFullConvolution(from,to,ki,kj,si,sj):cuda() + local gconv = cudnn.SpatialFullConvolution(from,to,ki,kj,si,sj):cuda():fastest() + gconv.weight:copy(sconv.weight) + gconv.bias:copy(sconv.bias) + + local function test(sconv, gconv) + local groundtruth = sconv:forward(input) + cutorch.synchronize() + local rescuda = gconv:forward(input) + cutorch.synchronize() + local error = rescuda:float() - groundtruth:float() + mytester:assertlt(error:abs():max(), precision_forward, 'error on state (forward) ') + + -- IO + local ferr,berr = jac.testIO(gconv, input) + mytester:assertlt(ferr, precision_io, torch.typename(gconv) .. ' - i/o forward err ') + mytester:assertlt(berr, precision_io, torch.typename(gconv) .. ' - i/o backward err ') + end + + test(sconv, gconv) + local gconv = cudnn.convert(sconv, cudnn) + mytester:asserteq(torch.typename(gconv), 'cudnn.SpatialFullConvolution', 'conversion type check') + test(sconv, gconv) +end + +function cudnntest.SpatialFullConvolution_backward_batch() + local bs = math.random(1,32) + local from = math.random(1,32) + local to = math.random(1,64) + local ki = math.random(1,15) + local kj = math.random(1,15) + local si = math.random(1,ki) + local sj = math.random(1,kj) + local outi = math.random(1,64) + local outj = math.random(1,64) + local ini = (outi-1)*si+ki + local inj = (outj-1)*sj+kj + local scale = math.random() + + local input = torch.randn(bs,from,inj,ini):cuda() + local gradOutput = torch.randn(bs,to,outj,outi):cuda() + local sconv = nn.SpatialFullConvolution(from,to,ki,kj,si,sj):cuda() + sconv:forward(input) + sconv:zeroGradParameters() + local groundgrad = sconv:backward(input, gradOutput, scale) + cutorch.synchronize() + local groundweight = sconv.gradWeight + local groundbias = sconv.gradBias + + local gconv = cudnn.SpatialFullConvolution(from,to,ki,kj,si,sj):cuda():fastest() + gconv.weight:copy(sconv.weight) + gconv.bias:copy(sconv.bias) + gconv:forward(input) + + -- serialize and deserialize + torch.save('modelTemp.t7', gconv) + gconv = torch.load('modelTemp.t7') + + local function test(sconv, gconv) + gconv:forward(input) + gconv:zeroGradParameters() + local rescuda = gconv:backward(input, gradOutput, scale) + cutorch.synchronize() + local weightcuda = gconv.gradWeight + local biascuda = gconv.gradBias + + local error = rescuda:float() - groundgrad:float() + local werror = weightcuda:float() - groundweight:float() + local berror = biascuda:float() - groundbias:float() + + mytester:assertlt(error:abs():max(), precision_backward, 'error on state (backward) ') + mytester:assertlt(werror:abs():max(), precision_backward, 'error on weight (backward) ') + mytester:assertlt(berror:abs():max(), precision_backward, 'error on bias (backward) ') + end + + test(sconv, gconv) + local gconv = cudnn.convert(sconv, cudnn) + mytester:asserteq(torch.typename(gconv), 'cudnn.SpatialFullConvolution', 'conversion type check') + test(sconv, gconv) +end + + function cudnntest.TemporalConvolution_batch() local bs = math.random(1,32) local inputFrameSize = math.random(1,64) -- cgit v1.2.3