diff options
-rw-r--r-- | SpatialFullConvolution.lua | 8 | ||||
-rw-r--r-- | convert.lua | 1 | ||||
-rw-r--r-- | test/test.lua | 96 |
3 files changed, 100 insertions, 5 deletions
diff --git a/SpatialFullConvolution.lua b/SpatialFullConvolution.lua index d736e26..d00a8a2 100644 --- a/SpatialFullConvolution.lua +++ b/SpatialFullConvolution.lua @@ -8,11 +8,6 @@ autotunerCache[1] = {} -- forward autotunerCache[2] = {} -- backwardFilter autotunerCache[3] = {} -- backwardData -function SpatialFullConvolution:__init(nInputPlane, nOutputPlane, - kW, kH, dW, dH, padW, padH, adjW, adjH) - parent.__init(self, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, adjW, adjH) -end - -- if you change the configuration of the module manually, call this function SpatialFullConvolution:resetWeightDescriptors() assert(torch.typename(self.weight) == 'torch.CudaTensor', @@ -40,6 +35,7 @@ end function SpatialFullConvolution:fastest(mode) if mode == nil then mode = true end self.fastest_mode = mode + self.iSize = self.iSize or torch.LongStorage(4) self.iSize:fill(0) return self end @@ -54,6 +50,7 @@ function SpatialFullConvolution:setMode(fmode, bdmode, bwmode) if bwmode ~= nil then self.bwmode = bwmode end + self.iSize = self.iSize or torch.LongStorage(4) self.iSize:fill(0) return self end @@ -72,6 +69,7 @@ function SpatialFullConvolution:createIODescriptors(input) batch = false end assert(input:dim() == 4 and input:isContiguous()); + self.iSize = self.iSize or torch.LongStorage(4):fill(0) if not self.iDesc or not self.oDesc or input:size(1) ~= self.iSize[1] or input:size(2) ~= self.iSize[2] or input:size(3) ~= self.iSize[3] or input:size(4) ~= self.iSize[4] then diff --git a/convert.lua b/convert.lua index 7b51ae5..aa8c65f 100644 --- a/convert.lua +++ b/convert.lua @@ -2,6 +2,7 @@ local layer_list = { 'SpatialConvolution', 'SpatialCrossMapLRN', + 'SpatialFullConvolution', 'SpatialMaxPooling', 'SpatialAveragePooling', 'ReLU', diff --git a/test/test.lua b/test/test.lua index 388deb3..55171c7 100644 --- a/test/test.lua +++ b/test/test.lua @@ -206,6 +206,102 @@ function cudnntest.SpatialConvolution_backward_single() test(sconv, gconv) end +function cudnntest.SpatialFullConvolution_forward_batch() + local bs = math.random(1,32) + local from = math.random(1,32) + local to = math.random(1,64) + local ki = math.random(1,15) + local kj = math.random(1,15) + local si = math.random(1,ki) + local sj = math.random(1,kj) + local outi = math.random(1,64) + local outj = math.random(1,64) + local ini = (outi-1)*si+ki + local inj = (outj-1)*sj+kj + + local input = torch.randn(bs,from,inj,ini):cuda() + local sconv = nn.SpatialFullConvolution(from,to,ki,kj,si,sj):cuda() + local gconv = cudnn.SpatialFullConvolution(from,to,ki,kj,si,sj):cuda():fastest() + gconv.weight:copy(sconv.weight) + gconv.bias:copy(sconv.bias) + + local function test(sconv, gconv) + local groundtruth = sconv:forward(input) + cutorch.synchronize() + local rescuda = gconv:forward(input) + cutorch.synchronize() + local error = rescuda:float() - groundtruth:float() + mytester:assertlt(error:abs():max(), precision_forward, 'error on state (forward) ') + + -- IO + local ferr,berr = jac.testIO(gconv, input) + mytester:assertlt(ferr, precision_io, torch.typename(gconv) .. ' - i/o forward err ') + mytester:assertlt(berr, precision_io, torch.typename(gconv) .. ' - i/o backward err ') + end + + test(sconv, gconv) + local gconv = cudnn.convert(sconv, cudnn) + mytester:asserteq(torch.typename(gconv), 'cudnn.SpatialFullConvolution', 'conversion type check') + test(sconv, gconv) +end + +function cudnntest.SpatialFullConvolution_backward_batch() + local bs = math.random(1,32) + local from = math.random(1,32) + local to = math.random(1,64) + local ki = math.random(1,15) + local kj = math.random(1,15) + local si = math.random(1,ki) + local sj = math.random(1,kj) + local outi = math.random(1,64) + local outj = math.random(1,64) + local ini = (outi-1)*si+ki + local inj = (outj-1)*sj+kj + local scale = math.random() + + local input = torch.randn(bs,from,inj,ini):cuda() + local gradOutput = torch.randn(bs,to,outj,outi):cuda() + local sconv = nn.SpatialFullConvolution(from,to,ki,kj,si,sj):cuda() + sconv:forward(input) + sconv:zeroGradParameters() + local groundgrad = sconv:backward(input, gradOutput, scale) + cutorch.synchronize() + local groundweight = sconv.gradWeight + local groundbias = sconv.gradBias + + local gconv = cudnn.SpatialFullConvolution(from,to,ki,kj,si,sj):cuda():fastest() + gconv.weight:copy(sconv.weight) + gconv.bias:copy(sconv.bias) + gconv:forward(input) + + -- serialize and deserialize + torch.save('modelTemp.t7', gconv) + gconv = torch.load('modelTemp.t7') + + local function test(sconv, gconv) + gconv:forward(input) + gconv:zeroGradParameters() + local rescuda = gconv:backward(input, gradOutput, scale) + cutorch.synchronize() + local weightcuda = gconv.gradWeight + local biascuda = gconv.gradBias + + local error = rescuda:float() - groundgrad:float() + local werror = weightcuda:float() - groundweight:float() + local berror = biascuda:float() - groundbias:float() + + mytester:assertlt(error:abs():max(), precision_backward, 'error on state (backward) ') + mytester:assertlt(werror:abs():max(), precision_backward, 'error on weight (backward) ') + mytester:assertlt(berror:abs():max(), precision_backward, 'error on bias (backward) ') + end + + test(sconv, gconv) + local gconv = cudnn.convert(sconv, cudnn) + mytester:asserteq(torch.typename(gconv), 'cudnn.SpatialFullConvolution', 'conversion type check') + test(sconv, gconv) +end + + function cudnntest.TemporalConvolution_batch() local bs = math.random(1,32) local inputFrameSize = math.random(1,64) |