Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSergey Zagoruyko <zagoruyko2@gmail.com>2016-03-21 14:26:51 +0300
committerSergey Zagoruyko <zagoruyko2@gmail.com>2016-03-21 14:26:51 +0300
commit935712907b5599e99283bb3d48901bd866e34f4d (patch)
tree17a7ddf3c85bb03d66850fccee7172e344d902db
parent084bd1806aa9f0a2d5f5f79375b0f67087dc4c17 (diff)
full conv tests
-rw-r--r--SpatialFullConvolution.lua8
-rw-r--r--convert.lua1
-rw-r--r--test/test.lua96
3 files changed, 100 insertions, 5 deletions
diff --git a/SpatialFullConvolution.lua b/SpatialFullConvolution.lua
index d736e26..d00a8a2 100644
--- a/SpatialFullConvolution.lua
+++ b/SpatialFullConvolution.lua
@@ -8,11 +8,6 @@ autotunerCache[1] = {} -- forward
autotunerCache[2] = {} -- backwardFilter
autotunerCache[3] = {} -- backwardData
-function SpatialFullConvolution:__init(nInputPlane, nOutputPlane,
- kW, kH, dW, dH, padW, padH, adjW, adjH)
- parent.__init(self, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, adjW, adjH)
-end
-
-- if you change the configuration of the module manually, call this
function SpatialFullConvolution:resetWeightDescriptors()
assert(torch.typename(self.weight) == 'torch.CudaTensor',
@@ -40,6 +35,7 @@ end
function SpatialFullConvolution:fastest(mode)
if mode == nil then mode = true end
self.fastest_mode = mode
+ self.iSize = self.iSize or torch.LongStorage(4)
self.iSize:fill(0)
return self
end
@@ -54,6 +50,7 @@ function SpatialFullConvolution:setMode(fmode, bdmode, bwmode)
if bwmode ~= nil then
self.bwmode = bwmode
end
+ self.iSize = self.iSize or torch.LongStorage(4)
self.iSize:fill(0)
return self
end
@@ -72,6 +69,7 @@ function SpatialFullConvolution:createIODescriptors(input)
batch = false
end
assert(input:dim() == 4 and input:isContiguous());
+ self.iSize = self.iSize or torch.LongStorage(4):fill(0)
if not self.iDesc or not self.oDesc or
input:size(1) ~= self.iSize[1] or input:size(2) ~= self.iSize[2]
or input:size(3) ~= self.iSize[3] or input:size(4) ~= self.iSize[4] then
diff --git a/convert.lua b/convert.lua
index 7b51ae5..aa8c65f 100644
--- a/convert.lua
+++ b/convert.lua
@@ -2,6 +2,7 @@
local layer_list = {
'SpatialConvolution',
'SpatialCrossMapLRN',
+ 'SpatialFullConvolution',
'SpatialMaxPooling',
'SpatialAveragePooling',
'ReLU',
diff --git a/test/test.lua b/test/test.lua
index 388deb3..55171c7 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -206,6 +206,102 @@ function cudnntest.SpatialConvolution_backward_single()
test(sconv, gconv)
end
+function cudnntest.SpatialFullConvolution_forward_batch()
+ local bs = math.random(1,32)
+ local from = math.random(1,32)
+ local to = math.random(1,64)
+ local ki = math.random(1,15)
+ local kj = math.random(1,15)
+ local si = math.random(1,ki)
+ local sj = math.random(1,kj)
+ local outi = math.random(1,64)
+ local outj = math.random(1,64)
+ local ini = (outi-1)*si+ki
+ local inj = (outj-1)*sj+kj
+
+ local input = torch.randn(bs,from,inj,ini):cuda()
+ local sconv = nn.SpatialFullConvolution(from,to,ki,kj,si,sj):cuda()
+ local gconv = cudnn.SpatialFullConvolution(from,to,ki,kj,si,sj):cuda():fastest()
+ gconv.weight:copy(sconv.weight)
+ gconv.bias:copy(sconv.bias)
+
+ local function test(sconv, gconv)
+ local groundtruth = sconv:forward(input)
+ cutorch.synchronize()
+ local rescuda = gconv:forward(input)
+ cutorch.synchronize()
+ local error = rescuda:float() - groundtruth:float()
+ mytester:assertlt(error:abs():max(), precision_forward, 'error on state (forward) ')
+
+ -- IO
+ local ferr,berr = jac.testIO(gconv, input)
+ mytester:assertlt(ferr, precision_io, torch.typename(gconv) .. ' - i/o forward err ')
+ mytester:assertlt(berr, precision_io, torch.typename(gconv) .. ' - i/o backward err ')
+ end
+
+ test(sconv, gconv)
+ local gconv = cudnn.convert(sconv, cudnn)
+ mytester:asserteq(torch.typename(gconv), 'cudnn.SpatialFullConvolution', 'conversion type check')
+ test(sconv, gconv)
+end
+
+function cudnntest.SpatialFullConvolution_backward_batch()
+ local bs = math.random(1,32)
+ local from = math.random(1,32)
+ local to = math.random(1,64)
+ local ki = math.random(1,15)
+ local kj = math.random(1,15)
+ local si = math.random(1,ki)
+ local sj = math.random(1,kj)
+ local outi = math.random(1,64)
+ local outj = math.random(1,64)
+ local ini = (outi-1)*si+ki
+ local inj = (outj-1)*sj+kj
+ local scale = math.random()
+
+ local input = torch.randn(bs,from,inj,ini):cuda()
+ local gradOutput = torch.randn(bs,to,outj,outi):cuda()
+ local sconv = nn.SpatialFullConvolution(from,to,ki,kj,si,sj):cuda()
+ sconv:forward(input)
+ sconv:zeroGradParameters()
+ local groundgrad = sconv:backward(input, gradOutput, scale)
+ cutorch.synchronize()
+ local groundweight = sconv.gradWeight
+ local groundbias = sconv.gradBias
+
+ local gconv = cudnn.SpatialFullConvolution(from,to,ki,kj,si,sj):cuda():fastest()
+ gconv.weight:copy(sconv.weight)
+ gconv.bias:copy(sconv.bias)
+ gconv:forward(input)
+
+ -- serialize and deserialize
+ torch.save('modelTemp.t7', gconv)
+ gconv = torch.load('modelTemp.t7')
+
+ local function test(sconv, gconv)
+ gconv:forward(input)
+ gconv:zeroGradParameters()
+ local rescuda = gconv:backward(input, gradOutput, scale)
+ cutorch.synchronize()
+ local weightcuda = gconv.gradWeight
+ local biascuda = gconv.gradBias
+
+ local error = rescuda:float() - groundgrad:float()
+ local werror = weightcuda:float() - groundweight:float()
+ local berror = biascuda:float() - groundbias:float()
+
+ mytester:assertlt(error:abs():max(), precision_backward, 'error on state (backward) ')
+ mytester:assertlt(werror:abs():max(), precision_backward, 'error on weight (backward) ')
+ mytester:assertlt(berror:abs():max(), precision_backward, 'error on bias (backward) ')
+ end
+
+ test(sconv, gconv)
+ local gconv = cudnn.convert(sconv, cudnn)
+ mytester:asserteq(torch.typename(gconv), 'cudnn.SpatialFullConvolution', 'conversion type check')
+ test(sconv, gconv)
+end
+
+
function cudnntest.TemporalConvolution_batch()
local bs = math.random(1,32)
local inputFrameSize = math.random(1,64)