github.com/soumith/cudnn.torch.git

 SpatialBatchNormalization.lua | 24 +++++++++++++++---------
 ffi.lua                       |  8 ++++----
 test/bntest.lua               | 19 +++++++++++++++++++
 test/test.lua                 | 28 ++++++++++++++++++++++++
 4 files changed, 66 insertions(+), 13 deletions(-)
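
This commit moves cudnn.SpatialBatchNormalization onto the cuDNN batch-normalization entry points that take separate tensor descriptors for y, dy, and dx, and caches the descriptors so they are rebuilt only when the input geometry changes. A minimal usage sketch (sizes are illustrative; assumes cudnn.torch and a CUDA device are available):

    require 'cunn'
    require 'cudnn'

    -- nFeature is the channel count C of an NxCxHxW input
    local bn = cudnn.SpatialBatchNormalization(16, 1e-3):cuda()
    local input = torch.randn(8, 16, 32, 32):cuda()
    local output = bn:forward(input)                      -- descriptors are built here
    local gradInput = bn:backward(input, output:clone())  -- same shape, so they are reused
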
diff --git a/SpatialBatchNormalization.lua b/SpatialBatchNormalization.lua
index e25c5ff..bb3b934 100644
--- a/SpatialBatchNormalization.lua
+++ b/SpatialBatchNormalization.lua
@@ -11,31 +11,37 @@ function SpatialBatchNormalization:__init(nFeature, eps, momentum, affine)
end
function SpatialBatchNormalization:createIODescriptors(input)
+ assert(input:dim() == 4)
assert(torch.typename(self.weight) == 'torch.CudaTensor' and torch.typename(self.bias) == 'torch.CudaTensor',
'Only CUDA tensors are supported for cudnn.SpatialBatchNormalization!')
- self.iDesc = cudnn.toDescriptor(input)
- self.sDesc = cudnn.toDescriptor(self.bias:view(1, self.nFeature, 1, 1))
+ if not self.iDesc or not self.oDesc or
+ input:size(1) ~= self.iSize[1] or input:size(2) ~= self.iSize[2]
+ or input:size(3) ~= self.iSize[3] or input:size(4) ~= self.iSize[4] then
+ self.iSize = input:size()
+ self.output:resizeAs(input)
+ self.gradInput:resizeAs(input)
+ self.iDesc = cudnn.toDescriptor(input)
+ self.oDesc = cudnn.toDescriptor(self.output)
+ self.sDesc = cudnn.toDescriptor(self.bias:view(1, self.nFeature, 1, 1))
+ end
end
local one = torch.FloatTensor({1});
local zero = torch.FloatTensor({0});
function SpatialBatchNormalization:updateOutput(input)
- self:createIODescriptors(input)
-
- self.output:resizeAs(input)
- self.gradInput:resizeAs(input)
+ self:createIODescriptors(input)
if self.train then
errcheck('cudnnBatchNormalizationForwardTraining',
cudnn.getHandle(), self.mode, one:data(), zero:data(),
- self.iDesc[0], input:data(), self.output:data(),
+ self.iDesc[0], input:data(), self.oDesc[0], self.output:data(),
self.sDesc[0], self.weight:data(), self.bias:data(),
self.momentum, self.running_mean:data(), self.running_std:data(), self.eps, self.save_mean:data(), self.save_std:data());
else
errcheck('cudnnBatchNormalizationForwardInference',
cudnn.getHandle(), self.mode, one:data(), zero:data(),
- self.iDesc[0], input:data(), self.output:data(),
+ self.iDesc[0], input:data(), self.oDesc[0], self.output:data(),
self.sDesc[0], self.weight:data(), self.bias:data(),
self.running_mean:data(), self.running_std:data(), self.eps);
end
@@ -47,7 +53,7 @@ function SpatialBatchNormalization:updateGradInput(input, gradOutput)
self:createIODescriptors(input)
errcheck('cudnnBatchNormalizationBackward',
cudnn.getHandle(), self.mode, one:data(), zero:data(),
- self.iDesc[0], input:data(), gradOutput:data(), self.gradInput:data(),
+ self.iDesc[0], input:data(), self.iDesc[0], gradOutput:data(), self.iDesc[0], self.gradInput:data(),
-- input is bottom, gradOutput is topDiff, self.gradInput is resultBottomDiff
self.sDesc[0], self.weight:data(), self.gradWeight:data(), self.gradBias:data(),
self.eps, self.save_mean:data(), self.save_std:data());
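
The guard added above means the descriptors persist across calls with an identical NxCxHxW shape and are rebuilt only when some dimension changes. A quick sketch of that behavior (hypothetical snippet probing the module's internal fields; sizes are arbitrary):

    local bn = cudnn.SpatialBatchNormalization(4, 1e-3):cuda()
    bn:forward(torch.randn(2, 4, 5, 5):cuda())  -- first call builds iDesc/oDesc/sDesc
    local d = bn.iDesc
    bn:forward(torch.randn(2, 4, 5, 5):cuda())  -- same shape: descriptors are reused
    assert(bn.iDesc == d)
    bn:forward(torch.randn(3, 4, 5, 5):cuda())  -- batch size changed: descriptors rebuilt
    assert(bn.iDesc ~= d)
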
diff --git a/ffi.lua b/ffi.lua
index c8ee963..61ac9ce 100644
--- a/ffi.lua
+++ b/ffi.lua
@@ -1072,7 +1072,7 @@ cudnnStatus_t cudnnBatchNormalizationForwardTraining(
const cudnnTensorDescriptor_t xDesc,
const void *x, /* NxCxHxW */
- /* const cudnnTensorDescriptor_t yDesc, */
+ const cudnnTensorDescriptor_t yDesc,
void *y, /* NxCxHxW */
/* Same shared desc for all the 6 tensors below in the argument list. */
@@ -1139,7 +1139,7 @@ cudnnStatus_t cudnnBatchNormalizationForwardInference(
const cudnnTensorDescriptor_t xDesc,
const void *x, /* NxCxHxW */
- /* const cudnnTensorDescriptor_t yDesc, */
+ const cudnnTensorDescriptor_t yDesc,
void *y, /* NxCxHxW */
/* Same desc for all 4 tensors below */
@@ -1188,9 +1188,9 @@ cudnnStatus_t cudnnBatchNormalizationBackward(
const cudnnTensorDescriptor_t xDesc, /* same desc for x, dx, dy */
const void *x,
- /* const cudnnTensorDescriptor_t dyDesc, */
+ const cudnnTensorDescriptor_t dyDesc,
const void *dy,
- /* const cudnnTensorDescriptor_t dxDesc, */
+ const cudnnTensorDescriptor_t dxDesc,
void *dx,
/* this tensor desc is used for all the 4 tensors below */
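
On the Lua side these declarations are reached through LuaJIT's FFI, with each call routed through a status-checking wrapper (the errcheck seen in the hunks above). A simplified sketch of that pattern, not the library's exact code:

    local ffi = require 'ffi'
    -- C stands for the ffi.load'ed libcudnn namespace carrying the cdef'd prototypes
    local function errcheck(f, ...)
       local status = C[f](...)
       if status ~= 'CUDNN_STATUS_SUCCESS' then
          error('Error in cuDNN: ' .. ffi.string(C.cudnnGetErrorString(status)))
       end
    end
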
diff --git a/test/bntest.lua b/test/bntest.lua
new file mode 100644
index 0000000..8ebd1fa
--- /dev/null
+++ b/test/bntest.lua
@@ -0,0 +1,19 @@
+require 'cunn'
+require 'cudnn'
+
+local h=5
+local w=5
+local bsz=4
+local from=4
+local input = torch.randn(bsz,from,h,w):cuda()
+local gradOutput = torch.randn(bsz,from,h,w):cuda()
+local cbn = cudnn.SpatialBatchNormalization(bsz, 1e-3):cuda()
+local gbn = nn.SpatialBatchNormalization(bsz, 1e-3):cuda()
+local groundtruth = gbn:forward(input)
+local rescuda = cbn:forward(input)
+local resgrad = cbn:backward(input, gradOutput)
+local groundgrad = gbn:backward(input, gradOutput)
+local error = (rescuda:float() - groundtruth:float()):abs():max()
+print("error",error)
+error = (resgrad:float() - groundgrad:float()):abs():max()
+print("error back",error)
diff --git a/test/test.lua b/test/test.lua
index 8c97ffa..9b85499 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -764,6 +764,34 @@ function cudnntest.SpatialLogSoftMax()
mytester:assertlt(err, precision_backward, 'error in difference between central difference and :backward')
end
+
+function cudnntest.SpatialBatchNormalization()
+ -- batch
+ local h = 4 --math.random(5,10)
+ local w = 4 --math.random(5,10)
+ local bsz = 4 --math.random(1, 32)
+ local from = 4 --math.random(1, 32)
+ local input = torch.randn(bsz,from,h,w):cuda()
+ local gradOutput = torch.randn(bsz,from,h,w):cuda()
+ local cbn = cudnn.SpatialBatchNormalization(bsz, 1e-3):cuda()
+ local gbn = nn.SpatialBatchNormalization(bsz, 1e-3):cuda()
+
+ local rescuda = cbn:forward(input)
+ local groundtruth = gbn:forward(input)
+ local resgrad = cbn:backward(input, gradOutput)
+ local groundgrad = gbn:backward(input, gradOutput)
+
+
+ local error = rescuda:float() - groundtruth:float()
+ mytester:assertlt(error:abs():max(),
+ precision_forward, 'error in batch normalization (forward) ')
+ error = resgrad:float() - groundgrad:float()
+ mytester:assertlt(error:abs():max(),
+ precision_backward, 'error in batch normalization (backward) ')
+
+end
+
+
function cudnntest.SpatialCrossEntropyCriterion()
-- batch
local numLabels = math.random(5,10)