Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSeanNaren <taz838@hotmail.co.uk>2016-06-11 14:43:04 +0300
committerSeanNaren <taz838@hotmail.co.uk>2016-06-11 14:43:04 +0300
commitae1c144739c18b50169717d8517b3b180dbc3c93 (patch)
treec63c6b8c583c7411e529ee82df154a14a1be9e07
parentdbb1a44bbe46e394fee423647101db72a0310f14 (diff)
Added clipped ReLU
-rw-r--r--ClippedReLU.lua13
-rw-r--r--Pointwise.lua2
-rw-r--r--init.lua1
-rw-r--r--test/test.lua10
4 files changed, 25 insertions, 1 deletions
diff --git a/ClippedReLU.lua b/ClippedReLU.lua
new file mode 100644
index 0000000..866e62a
--- /dev/null
+++ b/ClippedReLU.lua
@@ -0,0 +1,13 @@
+local ClippedReLU, parent = torch.class('cudnn.ClippedReLU','cudnn._Pointwise')
+
+function ClippedReLU:__init(inplace, ceiling)
+ parent.__init(self)
+ self.inplace = inplace
+ assert(ceiling, "No ceiling was given to ClippedReLU")
+ self.ceiling = ceiling
+end
+
+function ClippedReLU:updateOutput(input)
+ if not self.mode then self.mode = 'CUDNN_ACTIVATION_CLIPPED_RELU' end
+ return parent.updateOutput(self, input)
+end \ No newline at end of file
diff --git a/Pointwise.lua b/Pointwise.lua
index 93298ad..9cfe0f2 100644
--- a/Pointwise.lua
+++ b/Pointwise.lua
@@ -18,7 +18,7 @@ function Pointwise:createIODescriptors(input)
if not self.activDesc then
self.activDesc = ffi.new('struct cudnnActivationStruct*[1]')
errcheck('cudnnCreateActivationDescriptor', self.activDesc)
- errcheck('cudnnSetActivationDescriptor', self.activDesc[0], self.mode, 'CUDNN_PROPAGATE_NAN', 0.0);
+ errcheck('cudnnSetActivationDescriptor', self.activDesc[0], self.mode, 'CUDNN_PROPAGATE_NAN', self.ceiling or 0.0);
local function destroyADesc(a)
if (a[0]) then
diff --git a/init.lua b/init.lua
index 318570b..f200b95 100644
--- a/init.lua
+++ b/init.lua
@@ -110,6 +110,7 @@ require('cudnn.VolumetricMaxPooling')
require('cudnn.VolumetricAveragePooling')
require('cudnn.Pointwise')
require('cudnn.ReLU')
+require('cudnn.ClippedReLU')
require('cudnn.Tanh')
require('cudnn.Sigmoid')
require('cudnn.SpatialSoftMax')
diff --git a/test/test.lua b/test/test.lua
index a448317..aac45b4 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -962,6 +962,16 @@ function cudnntest.ReLU_batch()
nonlinBatch('ReLU')
end
+function cudnntest.ClippedReLU_single()
+ local input = torch.randn(1, 32):cuda()
+ local ceiling = 0.1
+ local module = cudnn.ClippedReLU(true, ceiling):cuda()
+ local output = module:forward(input)
+ local expectedOutput = input:clone()
+ expectedOutput[expectedOutput:ge(ceiling)] = ceiling
+ mytester:assertTensorEq(output, expectedOutput)
+end
+
function cudnntest.Tanh_single()
nonlinSingle('Tanh')
end