
github.com/soumith/cudnn.torch.git
author     Sergey Zagoruyko <zagoruyko2@gmail.com>    2016-09-29 12:41:46 +0300
committer  Sergey Zagoruyko <zagoruyko2@gmail.com>    2016-09-29 13:19:16 +0300
commit     6a1672f186dddc0385afb0b847e743918377010c (patch)
tree       bb13a8cb2a923cfed04f1d801cdd6e023de1bc57
parent     8c112dfe7adb26cad7f10c3f0919234a3ffd7b70 (diff)
add activations to functional
-rw-r--r--  functional.lua  | 71
-rw-r--r--  test/test.lua   | 37
2 files changed, 108 insertions(+), 0 deletions(-)
diff --git a/functional.lua b/functional.lua
index 24c6030..39f8589 100644
--- a/functional.lua
+++ b/functional.lua
@@ -380,3 +380,74 @@ cudnn.functional.AveragePooling2D_updateGradInput = function(handle, input, outp
cudnn.functional.Pooling_updateGradInput(handle, 'CUDNN_POOLING_AVERAGE', input, output, gradOutput, gradInput,
kH, kW, dH, dW, padH, padW, ceil_mode);
end
+
+local function createPointwiseDescriptors(mode, input, output)
+ local activDesc = ffi.new('struct cudnnActivationStruct*[1]')
+ errcheck('cudnnCreateActivationDescriptor', activDesc)
+ errcheck('cudnnSetActivationDescriptor', activDesc[0], mode, 'CUDNN_PROPAGATE_NAN', 0.0);
+
+ local function destroyADesc(a)
+ if (a[0]) then
+ errcheck('cudnnDestroyActivationDescriptor', a[0]);
+ a[0] = nil
+ end
+ end
+ ffi.gc(activDesc, destroyADesc)
+
+ local nElem = input:nElement()
+ local iDesc = cudnn.toDescriptor(input:view(1,1,1,nElem))
+ return activDesc, iDesc
+end
+
+local function pointwise_updateOutput(handle, mode, input, output)
+ local activDesc, iDesc = createPointwiseDescriptors(mode, input, output)
+ errcheck('cudnnActivationForward',
+ handle, activDesc[0],
+ cudnn.scalar(input, 1),
+ iDesc[0], input:data(),
+ cudnn.scalar(input, 0),
+ iDesc[0], output:data());
+end
+
+local function pointwise_updateGradInput(handle, mode, input, output, gradOutput, gradInput)
+ local activDesc, iDesc = createPointwiseDescriptors(mode, input, output)
+ errcheck('cudnnActivationBackward',
+ handle, activDesc[0],
+ cudnn.scalar(input, 1),
+ iDesc[0], output:data(),
+ iDesc[0], gradOutput:data(),
+ iDesc[0], input:data(),
+ cudnn.scalar(input, 0),
+ iDesc[0], gradInput:data());
+end
+
+cudnn.functional.ReLU_updateOutput = function(handle, input, output)
+ output:resizeAs(input)
+ pointwise_updateOutput(handle, 'CUDNN_ACTIVATION_RELU', input, output)
+end
+
+cudnn.functional.ReLU_updateGradInput = function(handle, input, output, gradOutput, gradInput)
+ gradInput:resizeAs(input)
+ pointwise_updateGradInput(handle, 'CUDNN_ACTIVATION_RELU', input, output, gradOutput, gradInput)
+end
+
+cudnn.functional.Tanh_updateOutput = function(handle, input, output)
+ output:resizeAs(input)
+ pointwise_updateOutput(handle, 'CUDNN_ACTIVATION_TANH', input, output)
+end
+
+cudnn.functional.Tanh_updateGradInput = function(handle, input, output, gradOutput, gradInput)
+ gradInput:resizeAs(input)
+ pointwise_updateGradInput(handle, 'CUDNN_ACTIVATION_TANH', input, output, gradOutput, gradInput)
+end
+
+cudnn.functional.Sigmoid_updateOutput = function(handle, input, output)
+ output:resizeAs(input)
+ pointwise_updateOutput(handle, 'CUDNN_ACTIVATION_SIGMOID', input, output)
+end
+
+cudnn.functional.Sigmoid_updateGradInput = function(handle, input, output, gradOutput, gradInput)
+ gradInput:resizeAs(input)
+ pointwise_updateGradInput(handle, 'CUDNN_ACTIVATION_SIGMOID', input, output, gradOutput, gradInput)
+end
+
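
For illustration, a minimal sketch of how the new stateless entry points might be called directly. The tensor shapes and buffer setup below are assumptions made for the example and are not part of the commit; only cudnn.getHandle() and the cudnn.functional.* functions come from the patch itself.

require 'cudnn'

-- Hypothetical usage of the new functional ReLU: buffers are plain CUDA
-- tensors; output and gradInput are resized to match input inside the calls.
local input      = torch.randn(16, 8):cuda()
local output     = torch.CudaTensor()
local gradOutput = torch.randn(16, 8):cuda()
local gradInput  = torch.CudaTensor()

local handle = cudnn.getHandle()
cudnn.functional.ReLU_updateOutput(handle, input, output)
cudnn.functional.ReLU_updateGradInput(handle, input, output, gradOutput, gradInput)
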
diff --git a/test/test.lua b/test/test.lua
index 7986e9f..e3ed802 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -847,6 +847,43 @@ function cudnntest.functional_maxpooling2d()
testparams.precision_forward, 'error on updateGradInput ')
end
+local function test_functional_activation(mode, module)
+ local a = module:cuda()
+ local input = torch.randn(10,3,10,10):cuda()
+ a:forward(input)
+ local output = a.output:clone():normal()
+ local gradOutput = a.output:clone():normal()
+ local gradInput = a:updateGradInput(input, gradOutput):clone():normal()
+ cudnn.functional[mode.forward](cudnn.getHandle(), input, output)
+ mytester:assertlt((output - a.output):abs():max(),
+ testparams.precision_forward, 'error on forward ')
+ cudnn.functional[mode.backward](cudnn.getHandle(), input, output,
+ gradOutput, gradInput)
+ mytester:assertlt((gradInput - a.gradInput):abs():max(),
+ testparams.precision_forward, 'error on updateGradInput ')
+end
+
+function cudnntest.functional_relu()
+ test_functional_activation({
+ forward = 'ReLU_updateOutput',
+ backward = 'ReLU_updateGradInput',
+ }, cudnn.ReLU())
+end
+
+function cudnntest.functional_tanh()
+ test_functional_activation({
+ forward = 'Tanh_updateOutput',
+ backward = 'Tanh_updateGradInput',
+ }, cudnn.Tanh())
+end
+
+function cudnntest.functional_sigmoid()
+ test_functional_activation({
+ forward = 'Sigmoid_updateOutput',
+ backward = 'Sigmoid_updateGradInput',
+ }, cudnn.Sigmoid())
+end
+
torch.setdefaulttensortype('torch.FloatTensor')
math.randomseed(os.time())
mytester = torch.Tester()
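
For context, the new cases plug into the torch.Tester harness set up in the context lines above; a hedged sketch of how they would typically be registered and run (the actual runner code lives further down in test.lua and is not shown in this hunk):

-- hypothetical continuation of the test harness setup:
mytester:add(cudnntest)   -- registers functional_relu/tanh/sigmoid along with the other tests
mytester:run()            -- runs the registered cudnn tests, including the new activation cases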