diff options
author | Sergey Zagoruyko <zagoruyko2@gmail.com> | 2016-09-29 12:41:46 +0300 |
---|---|---|
committer | Sergey Zagoruyko <zagoruyko2@gmail.com> | 2016-09-29 13:19:16 +0300 |
commit | 6a1672f186dddc0385afb0b847e743918377010c (patch) | |
tree | bb13a8cb2a923cfed04f1d801cdd6e023de1bc57 | |
parent | 8c112dfe7adb26cad7f10c3f0919234a3ffd7b70 (diff) |
add activations to functional
-rw-r--r-- | functional.lua | 71 | ||||
-rw-r--r-- | test/test.lua | 37 |
2 files changed, 108 insertions, 0 deletions
diff --git a/functional.lua b/functional.lua index 24c6030..39f8589 100644 --- a/functional.lua +++ b/functional.lua @@ -380,3 +380,74 @@ cudnn.functional.AveragePooling2D_updateGradInput = function(handle, input, outp cudnn.functional.Pooling_updateGradInput(handle, 'CUDNN_POOLING_AVERAGE', input, output, gradOutput, gradInput, kH, kW, dH, dW, padH, padW, ceil_mode); end + +local function createPointwiseDescriptors(mode, input, output) + local activDesc = ffi.new('struct cudnnActivationStruct*[1]') + errcheck('cudnnCreateActivationDescriptor', activDesc) + errcheck('cudnnSetActivationDescriptor', activDesc[0], mode, 'CUDNN_PROPAGATE_NAN', 0.0); + + local function destroyADesc(a) + if (a[0]) then + errcheck('cudnnDestroyActivationDescriptor', a[0]); + a[0] = nil + end + end + ffi.gc(activDesc, destroyADesc) + + local nElem = input:nElement() + local iDesc = cudnn.toDescriptor(input:view(1,1,1,nElem)) + return activDesc, iDesc +end + +local function pointwise_updateOutput(handle, mode, input, output) + local activDesc, iDesc = createPointwiseDescriptors(mode, input, output) + errcheck('cudnnActivationForward', + handle, activDesc[0], + cudnn.scalar(input, 1), + iDesc[0], input:data(), + cudnn.scalar(input, 0), + iDesc[0], output:data()); +end + +local function pointwise_updateGradInput(handle, mode, input, output, gradOutput, gradInput) + local activDesc, iDesc = createPointwiseDescriptors(mode, input, output) + errcheck('cudnnActivationBackward', + handle, activDesc[0], + cudnn.scalar(input, 1), + iDesc[0], output:data(), + iDesc[0], gradOutput:data(), + iDesc[0], input:data(), + cudnn.scalar(input, 0), + iDesc[0], gradInput:data()); +end + +cudnn.functional.ReLU_updateOutput = function(handle, input, output) + output:resizeAs(input) + pointwise_updateOutput(handle, 'CUDNN_ACTIVATION_RELU', input, output) +end + +cudnn.functional.ReLU_updateGradInput = function(handle, input, output, gradOutput, gradInput) + gradInput:resizeAs(input) + pointwise_updateGradInput(handle, 'CUDNN_ACTIVATION_RELU', input, output, gradOutput, gradInput) +end + +cudnn.functional.Tanh_updateOutput = function(handle, input, output) + output:resizeAs(input) + pointwise_updateOutput(handle, 'CUDNN_ACTIVATION_TANH', input, output) +end + +cudnn.functional.Tanh_updateGradInput = function(handle, input, output, gradOutput, gradInput) + gradInput:resizeAs(input) + pointwise_updateGradInput(handle, 'CUDNN_ACTIVATION_TANH', input, output, gradOutput, gradInput) +end + +cudnn.functional.Sigmoid_updateOutput = function(handle, input, output) + output:resizeAs(input) + pointwise_updateOutput(handle, 'CUDNN_ACTIVATION_SIGMOID', input, output) +end + +cudnn.functional.Sigmoid_updateGradInput = function(handle, input, output, gradOutput, gradInput) + gradInput:resizeAs(input) + pointwise_updateGradInput(handle, 'CUDNN_ACTIVATION_SIGMOID', input, output, gradOutput, gradInput) +end + diff --git a/test/test.lua b/test/test.lua index 7986e9f..e3ed802 100644 --- a/test/test.lua +++ b/test/test.lua @@ -847,6 +847,43 @@ function cudnntest.functional_maxpooling2d() testparams.precision_forward, 'error on updateGradInput ') end +local function test_functional_activation(mode, module) + local a = module:cuda() + local input = torch.randn(10,3,10,10):cuda() + a:forward(input) + local output = a.output:clone():normal() + local gradOutput = a.output:clone():normal() + local gradInput = a:updateGradInput(input, gradOutput):clone():normal() + cudnn.functional[mode.forward](cudnn.getHandle(), input, output) + mytester:assertlt((output - a.output):abs():max(), + testparams.precision_forward, 'error on forward ') + cudnn.functional[mode.backward](cudnn.getHandle(), input, output, + gradOutput, gradInput) + mytester:assertlt((gradInput - a.gradInput):abs():max(), + testparams.precision_forward, 'error on updateGradInput ') +end + +function cudnntest.functional_relu() + test_functional_activation({ + forward = 'ReLU_updateOutput', + backward = 'ReLU_updateGradInput', + }, cudnn.ReLU()) +end + +function cudnntest.functional_tanh() + test_functional_activation({ + forward = 'Tanh_updateOutput', + backward = 'Tanh_updateGradInput', + }, cudnn.Tanh()) +end + +function cudnntest.functional_sigmoid() + test_functional_activation({ + forward = 'Sigmoid_updateOutput', + backward = 'Sigmoid_updateGradInput', + }, cudnn.Sigmoid()) +end + torch.setdefaulttensortype('torch.FloatTensor') math.randomseed(os.time()) mytester = torch.Tester() |