 SpatialBatchNormalization.lua | 24 +++++++++++++++---------
 ffi.lua                       |  8 ++++----
 test/bntest.lua               | 19 +++++++++++++++++++
 test/test.lua                 | 28 ++++++++++++++++++++++++++++
 4 files changed, 66 insertions(+), 13 deletions(-)
diff --git a/SpatialBatchNormalization.lua b/SpatialBatchNormalization.lua
index e25c5ff..bb3b934 100644
--- a/SpatialBatchNormalization.lua
+++ b/SpatialBatchNormalization.lua
@@ -11,31 +11,37 @@ function SpatialBatchNormalization:__init(nFeature, eps, momentum, affine)
 end
 
 function SpatialBatchNormalization:createIODescriptors(input)
+   assert(input:dim() == 4)
    assert(torch.typename(self.weight) == 'torch.CudaTensor'
       and torch.typename(self.bias) == 'torch.CudaTensor',
       'Only CUDA tensors are supported for cudnn.SpatialBatchNormalization!')
-   self.iDesc = cudnn.toDescriptor(input)
-   self.sDesc = cudnn.toDescriptor(self.bias:view(1, self.nFeature, 1, 1))
+   if not self.iDesc or not self.oDesc or
+      input:size(1) ~= self.iSize[1] or input:size(2) ~= self.iSize[2]
+      or input:size(3) ~= self.iSize[3] or input:size(4) ~= self.iSize[4] then
+      self.iSize = input:size()
+      self.output:resizeAs(input)
+      self.gradInput:resizeAs(input)
+      self.iDesc = cudnn.toDescriptor(input)
+      self.oDesc = cudnn.toDescriptor(self.output)
+      self.sDesc = cudnn.toDescriptor(self.bias:view(1, self.nFeature, 1, 1))
+   end
 end
 
 local one = torch.FloatTensor({1});
 local zero = torch.FloatTensor({0});
 
 function SpatialBatchNormalization:updateOutput(input)
-   self:createIODescriptors(input)
-
-   self.output:resizeAs(input)
-   self.gradInput:resizeAs(input)
+   self:createIODescriptors(input)
    if self.train then
       errcheck('cudnnBatchNormalizationForwardTraining',
             cudnn.getHandle(), self.mode, one:data(), zero:data(),
-            self.iDesc[0], input:data(), self.output:data(),
+            self.iDesc[0], input:data(), self.oDesc[0], self.output:data(),
             self.sDesc[0], self.weight:data(), self.bias:data(),
             self.momentum, self.running_mean:data(), self.running_std:data(),
             self.eps, self.save_mean:data(), self.save_std:data());
    else
       errcheck('cudnnBatchNormalizationForwardInference',
             cudnn.getHandle(), self.mode, one:data(), zero:data(),
-            self.iDesc[0], input:data(), self.output:data(),
+            self.iDesc[0], input:data(), self.oDesc[0], self.output:data(),
             self.sDesc[0], self.weight:data(), self.bias:data(),
             self.running_mean:data(), self.running_std:data(), self.eps);
@@ -47,7 +53,7 @@ function SpatialBatchNormalization:updateGradInput(input, gradOutput)
    self:createIODescriptors(input)
    errcheck('cudnnBatchNormalizationBackward',
          cudnn.getHandle(), self.mode, one:data(), zero:data(),
-         self.iDesc[0], input:data(), gradOutput:data(), self.gradInput:data(),
+         self.iDesc[0], input:data(), self.iDesc[0], gradOutput:data(), self.iDesc[0], self.gradInput:data(),
          -- input is bottom, gradOutput is topDiff, self.gradInput is resultBottomDiff
          self.sDesc[0], self.weight:data(), self.gradWeight:data(), self.gradBias:data(),
          self.eps, self.save_mean:data(), self.save_std:data());
diff --git a/ffi.lua b/ffi.lua
--- a/ffi.lua
+++ b/ffi.lua
@@ -1072,7 +1072,7 @@ cudnnStatus_t cudnnBatchNormalizationForwardTraining(
                 const cudnnTensorDescriptor_t xDesc,
                 const void *x,     /* NxCxHxW */
-                /* const cudnnTensorDescriptor_t yDesc, */
+                const cudnnTensorDescriptor_t yDesc,
                 void *y,     /* NxCxHxW */
 
                 /* Same shared desc for all the 6 tensors below in the argument list.
                  */
@@ -1139,7 +1139,7 @@ cudnnStatus_t cudnnBatchNormalizationForwardInference(
                 const cudnnTensorDescriptor_t xDesc,
                 const void *x,     /* NxCxHxW */
-                /* const cudnnTensorDescriptor_t yDesc, */
+                const cudnnTensorDescriptor_t yDesc,
                 void *y,     /* NxCxHxW */
 
                 /* Same desc for all 4 tensors below */
 
@@ -1188,9 +1188,9 @@ cudnnStatus_t cudnnBatchNormalizationBackward(
                 const cudnnTensorDescriptor_t xDesc, /* same desc for x, dx, dy */
                 const void *x,
-                /* const cudnnTensorDescriptor_t dyDesc, */
+                const cudnnTensorDescriptor_t dyDesc,
                 const void *dy,
-                /* const cudnnTensorDescriptor_t dxDesc, */
+                const cudnnTensorDescriptor_t dxDesc,
                 void *dx,
 
                 /* this tensor desc is used for all the 4 tensors below */
 
diff --git a/test/bntest.lua b/test/bntest.lua
new file mode 100644
index 0000000..8ebd1fa
--- /dev/null
+++ b/test/bntest.lua
@@ -0,0 +1,19 @@
+require 'cunn'
+require 'cudnn'
+
+local h=5
+local w=5
+local bsz=4
+local from=4
+local input = torch.randn(bsz,from,h,w):cuda()
+local gradOutput = torch.randn(bsz,from,h,w):cuda()
+local cbn = cudnn.SpatialBatchNormalization(bsz, 1e-3):cuda()
+local gbn = nn.SpatialBatchNormalization(bsz, 1e-3):cuda()
+local groundtruth = gbn:forward(input)
+local rescuda = cbn:forward(input)
+local resgrad = cbn:backward(input, gradOutput)
+local groundgrad = gbn:backward(input, gradOutput)
+local error = (rescuda:float() - groundtruth:float()):abs():max()
+print("error",error)
+error = (resgrad:float() - groundgrad:float()):abs():max()
+print("error back",error)
diff --git a/test/test.lua b/test/test.lua
index 8c97ffa..9b85499 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -764,6 +764,34 @@ function cudnntest.SpatialLogSoftMax()
    mytester:assertlt(err, precision_backward,
                      'error in difference between central difference and :backward')
 end
+
+function cudnntest.SpatialBatchNormalization()
+   -- batch
+   local h = 4 --math.random(5,10)
+   local w = 4 --math.random(5,10)
+   local bsz = 4 --math.random(1, 32)
+   local from = 4 --math.random(1, 32)
+   local input = torch.randn(bsz,from,h,w):cuda()
+   local gradOutput = torch.randn(bsz,from,h,w):cuda()
+   local cbn = cudnn.SpatialBatchNormalization(bsz, 1e-3):cuda()
+   local gbn = nn.SpatialBatchNormalization(bsz, 1e-3):cuda()
+
+   local rescuda = cbn:forward(input)
+   local groundtruth = gbn:forward(input)
+   local resgrad = cbn:backward(input, gradOutput)
+   local groundgrad = gbn:backward(input, gradOutput)
+
+
+   local error = rescuda:float() - groundtruth:float()
+   mytester:assertlt(error:abs():max(),
+                     precision_forward, 'error in batch normalization (forward) ')
+   error = resgrad:float() - groundgrad:float()
+   mytester:assertlt(error:abs():max(),
+                     precision_backward, 'error in batch normalization (backward) ')
+
+end
+
+
 function cudnntest.SpatialCrossEntropyCriterion()
    -- batch
    local numLabels = math.random(5,10)
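
For reference, a minimal usage sketch of the patched module, in the spirit of test/bntest.lua above. It is an illustration only, assuming nothing beyond the API visible in this diff; note that per the __init(nFeature, eps, momentum, affine) signature shown above, the first constructor argument is the number of feature planes, not the batch size:

    require 'cunn'
    require 'cudnn'

    local bsz, nFeature, h, w = 4, 4, 5, 5
    local input = torch.randn(bsz, nFeature, h, w):cuda()
    local gradOutput = torch.randn(bsz, nFeature, h, w):cuda()

    -- first argument is nFeature (feature planes), per __init in the diff
    local bn = cudnn.SpatialBatchNormalization(nFeature, 1e-3):cuda()

    local output = bn:forward(input)                  -- first call builds iDesc/oDesc/sDesc
    local gradInput = bn:backward(input, gradOutput)  -- reuses the cached descriptors

    -- a second forward with the same 4D shape skips descriptor creation;
    -- a new shape makes createIODescriptors rebuild them and resize output/gradInput
    output = bn:forward(input)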