github.com/soumith/cudnn.torch.git
author     Soumith Chintala <soumith@gmail.com>   2016-02-26 07:07:27 +0300
committer  Soumith Chintala <soumith@gmail.com>   2016-02-26 07:07:27 +0300
commit     756111325aa68030240bc0267e7c2864876bde6f (patch)
tree       0cadfd57fc5a2e7c564261ee0fb2d2a46b768701
parent     f1cfa7ca2d379ac9c336147a59b45fbf2039ffbf (diff)
parent     83b5b6c6bbed0fb1a9457fc285d187341a040e90 (diff)
Merge pull request #124 from colesbury/bn
Add cudnn.BatchNormalization and cudnn.VolumetricBatchNormalization
-rw-r--r--  BatchNormalization.lua             133
-rw-r--r--  README.md                           18
-rw-r--r--  SpatialBatchNormalization.lua      138
-rw-r--r--  VolumetricBatchNormalization.lua     5
-rw-r--r--  init.lua                             8
-rw-r--r--  test/test.lua                       48
6 files changed, 190 insertions, 160 deletions
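
For orientation, here is a minimal usage sketch of the two modules this merge adds. It is illustrative only: the feature counts and tensor sizes are made up, and the constructor signature `(nFeature, eps, momentum, affine)` follows the new BatchNormalization.lua below.

```lua
require 'cudnn'

-- Per-activation batch normalization over 2D input (batch x features),
-- implemented by the new BatchNormalization.lua (nDim = 2).
local bn = cudnn.BatchNormalization(512):cuda()
local y  = bn:forward(torch.randn(16, 512):cuda())

-- Volumetric batch normalization over 5D input (batch x channels x depth x height x width),
-- implemented by the new VolumetricBatchNormalization.lua (nDim = 5).
local vbn = cudnn.VolumetricBatchNormalization(8):cuda()
local vy  = vbn:forward(torch.randn(4, 8, 6, 10, 10):cuda())
```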
diff --git a/BatchNormalization.lua b/BatchNormalization.lua
new file mode 100644
index 0000000..c353dc3
--- /dev/null
+++ b/BatchNormalization.lua
@@ -0,0 +1,133 @@
+local BatchNormalization, parent = torch.class('cudnn.BatchNormalization', 'nn.Module')
+local ffi = require 'ffi'
+local errcheck = cudnn.errcheck
+
+BatchNormalization.mode = 'CUDNN_BATCHNORM_PER_ACTIVATION'
+BatchNormalization.nDim = 2
+
+function BatchNormalization:__init(nFeature, eps, momentum, affine)
+ parent.__init(self)
+ assert(nFeature and type(nFeature) == 'number',
+ 'Missing argument #1: Number of feature planes. ')
+ assert(nFeature ~= 0, 'To set affine=false call BatchNormalization'
+ .. '(nFeature, eps, momentum, false) ')
+ assert(affine == true or affine == nil, 'only affine supported')
+ self.affine = true
+ self.eps = eps or 1e-5
+ self.train = true
+ self.momentum = momentum or 0.1
+
+ self.running_mean = torch.zeros(nFeature)
+ self.running_std = torch.ones(nFeature)
+ if self.affine then
+ self.weight = torch.Tensor(nFeature)
+ self.bias = torch.Tensor(nFeature)
+ self.gradWeight = torch.Tensor(nFeature)
+ self.gradBias = torch.Tensor(nFeature)
+ self:reset()
+ end
+end
+
+function BatchNormalization:reset()
+ if self.weight then
+ self.weight:uniform()
+ end
+ if self.bias then
+ self.bias:zero()
+ end
+ self.running_mean:zero()
+ self.running_std:fill(1)
+end
+
+function BatchNormalization:createIODescriptors(input)
+ assert(input:dim() == self.nDim)
+ assert(torch.typename(self.weight) == 'torch.CudaTensor' and torch.typename(self.bias) == 'torch.CudaTensor',
+ 'Only CUDA tensors are supported for cudnn.BatchNormalization!')
+ if not self.iDesc or not self.oDesc or not input:isSize(self.iSize) then
+ local nFeature = self.running_mean:numel()
+ self.iSize = input:size()
+ self.output:resizeAs(input)
+ self.gradInput:resizeAs(input)
+ self.iDesc = cudnn.toDescriptor(input)
+ self.oDesc = cudnn.toDescriptor(self.output)
+ local biasSize = torch.ones(self.nDim):totable()
+ biasSize[2] = nFeature
+ self.sDesc = cudnn.toDescriptor(self.bias:view(table.unpack(biasSize)))
+ end
+end
+
+local one = torch.FloatTensor({1});
+local zero = torch.FloatTensor({0});
+local scaleTens = torch.FloatTensor(1);
+
+function BatchNormalization:updateOutput(input)
+ self:createIODescriptors(input)
+
+ self.save_mean = self.save_mean or input.new()
+ self.save_mean:resizeAs(self.running_mean)
+ self.save_std = self.save_std or input.new()
+ self.save_std:resizeAs(self.running_std)
+
+ if self.train then
+ errcheck('cudnnBatchNormalizationForwardTraining',
+ cudnn.getHandle(), self.mode, one:data(), zero:data(),
+ self.iDesc[0], input:data(), self.oDesc[0], self.output:data(),
+ self.sDesc[0], self.weight:data(), self.bias:data(),
+ self.momentum, self.running_mean:data(), self.running_std:data(), self.eps, self.save_mean:data(), self.save_std:data());
+ else
+ errcheck('cudnnBatchNormalizationForwardInference',
+ cudnn.getHandle(), self.mode, one:data(), zero:data(),
+ self.iDesc[0], input:data(), self.oDesc[0], self.output:data(),
+ self.sDesc[0], self.weight:data(), self.bias:data(),
+ self.running_mean:data(), self.running_std:data(), self.eps);
+ end
+ return self.output
+end
+
+local function backward(self,input,gradOutput, scale)
+ assert(gradOutput:isContiguous())
+ self:createIODescriptors(input)
+ scale = scale or 1
+ scaleTens:fill(scale)
+ errcheck('cudnnBatchNormalizationBackward',
+ cudnn.getHandle(), self.mode, one:data(), zero:data(), scaleTens:data(), one:data(),
+ self.iDesc[0], input:data(), self.iDesc[0], gradOutput:data(), self.iDesc[0], self.gradInput:data(),
+ -- input is bottom, gradOutput is topDiff, self.gradInput is resultBottomDiff
+ self.sDesc[0], self.weight:data(), self.gradWeight:data(), self.gradBias:data(),
+ self.eps, self.save_mean:data(), self.save_std:data());
+ return self.gradInput
+end
+
+function BatchNormalization:updateGradInput(input, gradOutput, scale)
+ -- will in fact update gradWeight and gradBias too, accGradParameters call is empty
+ return backward(self, input, gradOutput, scale)
+end
+
+
+function BatchNormalization:backward(input, gradOutput, scale)
+ return backward(self, input, gradOutput, scale)
+end
+
+function BatchNormalization:accGradParameters(input, gradOutput, scale)
+end
+
+function BatchNormalization:clearDesc()
+ self.iDesc = nil
+ self.oDesc = nil
+ self.sDesc = nil
+end
+
+function BatchNormalization:write(f)
+ self:clearDesc()
+ local var = {}
+ for k,v in pairs(self) do
+ var[k] = v
+ end
+ f:writeObject(var)
+end
+
+function BatchNormalization:clearState()
+ self:clearDesc()
+ nn.utils.clear(self, 'save_mean', 'save_std')
+ return parent.clearState(self)
+end
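
The forward pass above branches on `self.train`: the training path calls `cudnnBatchNormalizationForwardTraining` and updates the running statistics, while the inference path calls `cudnnBatchNormalizationForwardInference` and reads the stored ones. A short sketch of how the two paths are typically exercised (sizes are illustrative):

```lua
local bn    = cudnn.BatchNormalization(512):cuda()
local input = torch.randn(16, 512):cuda()

-- self.train == true: uses cudnnBatchNormalizationForwardTraining and
-- updates running_mean/running_std with the configured momentum.
bn:training()
local yTrain = bn:forward(input)

-- self.train == false: uses cudnnBatchNormalizationForwardInference with
-- the running statistics accumulated during training.
bn:evaluate()
local yEval = bn:forward(input)
```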
diff --git a/README.md b/README.md
index 708d3a8..6f8c1c4 100644
--- a/README.md
+++ b/README.md
@@ -1,21 +1,21 @@
cudnn.torch
===========
-Torch7 FFI bindings for NVidia CuDNN (R4) kernels!
+Torch7 FFI bindings for NVIDIA cuDNN (R4) kernels!
Modules are API compatible with their [`nn`](https://github.com/torch/nn) equivalents. Fully unit-tested against `nn` implementations.
Conversion between `nn` and `cudnn` is available through the `cudnn.convert` function.
#### Installation
-* Install CuDNN (version R4 EA)
-* Have at least Cuda 7.0
+* Install cuDNN (version R4 EA)
+* Have at least CUDA 7.0
* Have `libcudnn.so` in your library path (Install it from https://developer.nvidia.com/cuDNN )
#### Modules
```lua
--- All inputs have to be 3D or 4D(batch-mode), except ReLU, Tanh and Sigmoid
+-- All inputs have to be 3D or 4D (batch-mode), except ReLU, Tanh, Sigmoid, and BatchNormalization
cudnn.SpatialConvolution(nInputPlane, nOutputPlane, kW, kH, [dW = 1], [dH = 1], [padW = 0], [padH = 0], [groups = 1])
cudnn.SpatialMaxPooling(kW, kH, dW, dH, padW, padH)
cudnn.SpatialAveragePooling(kW, kH, dW, dH, padW, padH)
@@ -91,13 +91,3 @@ For version CuDNN R3, checkout the branch **R3**
R4 Release Notes:
- Rather than resolving v3-v4 diffs, I have imported the new cudnn.h in its entirety and converted comments and defines. This should be less error-prone.
- addTensor_v2 uses were changed to the new AddTensor API.
-
-R4 TODO:
-per-activation BN code needs to be added (new .lua similar to SpatialBN.lua, as per Andrei:
-I believe we have at least one thing missing - per-activation BN (Torch implementation in nn.BatchNormalization.lua).
-What I believe we have now is an integration of implementation for nn.SpatialBatchNormalization.lua
-
-This is very similar to SpatialBatchNormalizaiton.lua but should use a different cudnnBatchNormalizationMode_t and tensor dimensions need to be adjusted accordingly.
-For Spatial BN normalization is performed over N with 1CHW result and for per-activation it's done over NHW with 1C11 result.
-
-Per-activation BN is only used after non-convolutional layers where spatially-invariant behavior is not expected.
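
Since the README points at `cudnn.convert` for moving between `nn` and `cudnn` modules, here is a hedged sketch of converting a small model that contains batch normalization. The network layout is made up, and whether `cudnn.convert` picks up the new BatchNormalization modules depends on its conversion list, so treat this as an assumption rather than a guarantee of the commit.

```lua
require 'cudnn'

local net = nn.Sequential()
   :add(nn.Linear(784, 256))
   :add(nn.BatchNormalization(256))
   :add(nn.ReLU())
net:cuda()

-- cudnn.convert swaps supported nn modules for their cudnn equivalents in place.
cudnn.convert(net, cudnn)
print(net)
```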
diff --git a/SpatialBatchNormalization.lua b/SpatialBatchNormalization.lua
index 53e8f7e..2c80fc7 100644
--- a/SpatialBatchNormalization.lua
+++ b/SpatialBatchNormalization.lua
@@ -1,135 +1,5 @@
-local SpatialBatchNormalization, parent = torch.class('cudnn.SpatialBatchNormalization', 'nn.Module')
-local ffi = require 'ffi'
-local errcheck = cudnn.errcheck
+local SpatialBatchNormalization, parent =
+ torch.class('cudnn.SpatialBatchNormalization', 'cudnn.BatchNormalization')
-function SpatialBatchNormalization:__init(nFeature, eps, momentum, affine)
- parent.__init(self)
- assert(nFeature and type(nFeature) == 'number',
- 'Missing argument #1: Number of feature planes. ')
- assert(nFeature ~= 0, 'To set affine=false call SpatialBatchNormalization'
- .. '(nFeature, eps, momentum, false) ')
- if affine ~= nil then
- assert(type(affine) == 'boolean', 'affine has to be true/false')
- self.affine = affine
- else
- self.affine = true
- end
- self.eps = eps or 1e-5
- self.train = true
- self.momentum = momentum or 0.1
-
- self.running_mean = torch.zeros(nFeature)
- self.running_std = torch.ones(nFeature)
- if self.affine then
- self.weight = torch.Tensor(nFeature)
- self.bias = torch.Tensor(nFeature)
- self.gradWeight = torch.Tensor(nFeature)
- self.gradBias = torch.Tensor(nFeature)
- self:reset()
- end
- self.mode = 'CUDNN_BATCHNORM_SPATIAL'
-end
-
-function SpatialBatchNormalization:reset()
- if self.weight then
- self.weight:uniform()
- end
- if self.bias then
- self.bias:zero()
- end
- self.running_mean:zero()
- self.running_std:fill(1)
-end
-
-function SpatialBatchNormalization:createIODescriptors(input)
- assert(input:dim() == 4)
- assert(torch.typename(self.weight) == 'torch.CudaTensor' and torch.typename(self.bias) == 'torch.CudaTensor',
- 'Only CUDA tensors are supported for cudnn.SpatialBatchNormalization!')
- if not self.iDesc or not self.oDesc or
- input:size(1) ~= self.iSize[1] or input:size(2) ~= self.iSize[2]
- or input:size(3) ~= self.iSize[3] or input:size(4) ~= self.iSize[4] then
- local nFeature = self.running_mean:numel()
- self.iSize = input:size()
- self.output:resizeAs(input)
- self.gradInput:resizeAs(input)
- self.iDesc = cudnn.toDescriptor(input)
- self.oDesc = cudnn.toDescriptor(self.output)
- self.sDesc = cudnn.toDescriptor(self.bias:view(1, nFeature, 1, 1))
- end
-end
-
-local one = torch.FloatTensor({1});
-local zero = torch.FloatTensor({0});
-local scaleTens = torch.FloatTensor(1);
-
-function SpatialBatchNormalization:updateOutput(input)
- self:createIODescriptors(input)
-
- self.save_mean = self.save_mean or input.new()
- self.save_mean:resizeAs(self.running_mean)
- self.save_std = self.save_std or input.new()
- self.save_std:resizeAs(self.running_std)
-
- if self.train then
- errcheck('cudnnBatchNormalizationForwardTraining',
- cudnn.getHandle(), self.mode, one:data(), zero:data(),
- self.iDesc[0], input:data(), self.oDesc[0], self.output:data(),
- self.sDesc[0], self.weight:data(), self.bias:data(),
- self.momentum, self.running_mean:data(), self.running_std:data(), self.eps, self.save_mean:data(), self.save_std:data());
- else
- errcheck('cudnnBatchNormalizationForwardInference',
- cudnn.getHandle(), self.mode, one:data(), zero:data(),
- self.iDesc[0], input:data(), self.oDesc[0], self.output:data(),
- self.sDesc[0], self.weight:data(), self.bias:data(),
- self.running_mean:data(), self.running_std:data(), self.eps);
- end
- return self.output
-end
-
-local function backward(self,input,gradOutput, scale)
- assert(gradOutput:isContiguous())
- self:createIODescriptors(input)
- scale = scale or 1
- scaleTens:fill(scale)
- errcheck('cudnnBatchNormalizationBackward',
- cudnn.getHandle(), self.mode, one:data(), zero:data(), scaleTens:data(), one:data(),
- self.iDesc[0], input:data(), self.iDesc[0], gradOutput:data(), self.iDesc[0], self.gradInput:data(),
- -- input is bottom, gradOutput is topDiff, self.gradInput is resultBottomDiff
- self.sDesc[0], self.weight:data(), self.gradWeight:data(), self.gradBias:data(),
- self.eps, self.save_mean:data(), self.save_std:data());
- return self.gradInput
-end
-
-function SpatialBatchNormalization:updateGradInput(input, gradOutput, scale)
--- will in fact update gradWeight and gradBias too, accGradParameters call is empty
- return backward(self, input,gradOutput, scale)
-end
-
-
-function SpatialBatchNormalization:backward(input, gradOutput, scale)
- return backward(self, input,gradOutput, scale)
-end
-
-function SpatialBatchNormalization:accGradParameters(input, gradOutput, scale)
-end
-
-function SpatialBatchNormalization:clearDesc()
- self.iDesc = nil
- self.oDesc = nil
- self.sDesc = nil
-end
-
-function SpatialBatchNormalization:write(f)
- self:clearDesc()
- local var = {}
- for k,v in pairs(self) do
- var[k] = v
- end
- f:writeObject(var)
-end
-
-function SpatialBatchNormalization:clearState()
- self:clearDesc()
- nn.utils.clear(self, 'save_mean', 'save_std')
- return parent.clearState(self)
-end
+SpatialBatchNormalization.mode = 'CUDNN_BATCHNORM_SPATIAL'
+SpatialBatchNormalization.nDim = 4
diff --git a/VolumetricBatchNormalization.lua b/VolumetricBatchNormalization.lua
new file mode 100644
index 0000000..1ae87f8
--- /dev/null
+++ b/VolumetricBatchNormalization.lua
@@ -0,0 +1,5 @@
+local VolumetricBatchNormalization =
+ torch.class('cudnn.VolumetricBatchNormalization', 'cudnn.BatchNormalization')
+
+VolumetricBatchNormalization.mode = 'CUDNN_BATCHNORM_SPATIAL'
+VolumetricBatchNormalization.nDim = 5
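
Like the rewritten SpatialBatchNormalization.lua above, the volumetric module is a thin subclass: all of the logic lives in cudnn.BatchNormalization, and each variant only overrides the cudnn mode string and the expected input dimensionality. As defined in this merge:

```lua
-- cudnn.BatchNormalization           : mode = 'CUDNN_BATCHNORM_PER_ACTIVATION', nDim = 2 (N x C)
-- cudnn.SpatialBatchNormalization    : mode = 'CUDNN_BATCHNORM_SPATIAL',        nDim = 4 (N x C x H x W)
-- cudnn.VolumetricBatchNormalization : mode = 'CUDNN_BATCHNORM_SPATIAL',        nDim = 5 (N x C x D x H x W)
print(cudnn.VolumetricBatchNormalization.mode)  -- CUDNN_BATCHNORM_SPATIAL
print(cudnn.VolumetricBatchNormalization.nDim)  -- 5
```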
diff --git a/init.lua b/init.lua
index 580c177..53cb7ea 100644
--- a/init.lua
+++ b/init.lua
@@ -70,6 +70,12 @@ function cudnn.toDescriptor(t)
errcheck('cudnnDestroyTensorDescriptor', d[0]);
end
ffi.gc(descriptor, destroy)
+ -- view 2D and 3D as 4D
+ if t:dim() == 2 then
+ t = t:view(t:size(1), t:size(2), 1, 1)
+ elseif t:dim() == 3 then
+ t = t:view(t:size(1), t:size(2), t:size(3), 1)
+ end
-- set descriptor
local size = torch.LongTensor(t:size()):int()
local stride = torch.LongTensor(t:stride()):int()
@@ -110,7 +116,9 @@ require('cudnn.SpatialLogSoftMax')
require('cudnn.SoftMax')
require('cudnn.LogSoftMax')
require('cudnn.SpatialCrossMapLRN')
+require('cudnn.BatchNormalization')
require('cudnn.SpatialBatchNormalization')
+require('cudnn.VolumetricBatchNormalization')
require('cudnn.SpatialCrossEntropyCriterion')
require('cudnn.TemporalConvolution')
require('cudnn.functional')
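
The init.lua change makes `cudnn.toDescriptor` view 2D and 3D tensors as 4D before filling in the descriptor, which is what lets the per-activation module's N x C input (and its 1 x C bias view) go through the existing descriptor path. A minimal sketch of the resulting shape mapping, with illustrative sizes:

```lua
-- A 2D tensor of size (N, C) is described as (N, C, 1, 1);
-- a 3D tensor of size (N, C, L) is described as (N, C, L, 1).
local x2d = torch.CudaTensor(16, 512)
local d2d = cudnn.toDescriptor(x2d)   -- descriptor for 16 x 512 x 1 x 1

local x3d = torch.CudaTensor(16, 512, 7)
local d3d = cudnn.toDescriptor(x3d)   -- descriptor for 16 x 512 x 7 x 1
```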
diff --git a/test/test.lua b/test/test.lua
index 10ceabb..9449b88 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -1134,16 +1134,11 @@ function cudnntest.SpatialLogSoftMax()
mytester:assertlt(err, precision_backward, 'error in difference between central difference and :backward')
end
-function cudnntest.SpatialBatchNormalization()
- -- batch
- local h = math.random(5,10)
- local w = math.random(5,10)
- local bsz = math.random(1, 32)
- local from = math.random(1, 32)
- local input = torch.randn(bsz,from,h,w):cuda()
- local gradOutput = torch.randn(bsz,from,h,w):cuda()
- local cbn = cudnn.SpatialBatchNormalization(from, 1e-3):cuda()
- local gbn = nn.SpatialBatchNormalization(from, 1e-3):cuda()
+local function testBatchNormalization(moduleName, inputSize)
+ local input = torch.randn(table.unpack(inputSize)):cuda()
+ local gradOutput = torch.randn(table.unpack(inputSize)):cuda()
+ local cbn = cudnn[moduleName](inputSize[2], 1e-3):cuda()
+ local gbn = nn[moduleName](inputSize[2], 1e-3):cuda()
cbn.weight:copy(gbn.weight)
cbn.bias:copy(gbn.bias)
mytester:asserteq(cbn.running_mean:mean(), 0, 'error on BN running_mean init')
@@ -1161,6 +1156,35 @@ function cudnntest.SpatialBatchNormalization()
precision_backward, 'error in batch normalization (backward) ')
end
+function cudnntest.BatchNormalization()
+ local size = {
+ math.random(1, 32),
+ math.random(16, 256),
+ }
+ testBatchNormalization('BatchNormalization', size)
+end
+
+function cudnntest.SpatialBatchNormalization()
+ local size = {
+ math.random(1, 32),
+ math.random(1, 32),
+ math.random(5, 10),
+ math.random(5, 10),
+ }
+ testBatchNormalization('SpatialBatchNormalization', size)
+end
+
+function cudnntest.VolumetricBatchNormalization()
+ local size = {
+ math.random(1, 32),
+ math.random(1, 32),
+ math.random(2, 6),
+ math.random(2, 6),
+ math.random(2, 6),
+ }
+ testBatchNormalization('VolumetricBatchNormalization', size)
+end
+
function cudnntest.SpatialCrossEntropyCriterion()
-- batch
local numLabels = math.random(5,10)
@@ -1188,11 +1212,11 @@ function cudnntest.SpatialCrossEntropyCriterion()
ggi[{{}, {}, {i}, {j}}]:copy(ggi1)
end
end
-
+
-- nn.CrossEntropy in contrast to cudnn.SpatialCrossEntropyCriterion cannot
-- average over the last spatial dimensions because it is run in a loop
ggi:div(h * w)
-
+
local err = (gi - ggi):abs():max()
mytester:assertlt(err, precision_backward, 'error in difference between central difference and :backward')