diff options
author | SeanNaren <taz838@hotmail.co.uk> | 2016-06-11 14:43:04 +0300 |
---|---|---|
committer | SeanNaren <taz838@hotmail.co.uk> | 2016-06-11 14:43:04 +0300 |
commit | ae1c144739c18b50169717d8517b3b180dbc3c93 (patch) | |
tree | c63c6b8c583c7411e529ee82df154a14a1be9e07 | |
parent | dbb1a44bbe46e394fee423647101db72a0310f14 (diff) |
Added clipped ReLU
-rw-r--r-- | ClippedReLU.lua | 13 | ||||
-rw-r--r-- | Pointwise.lua | 2 | ||||
-rw-r--r-- | init.lua | 1 | ||||
-rw-r--r-- | test/test.lua | 10 |
4 files changed, 25 insertions(+), 1 deletion(-)
diff --git a/ClippedReLU.lua b/ClippedReLU.lua new file mode 100644 index 0000000..866e62a --- /dev/null +++ b/ClippedReLU.lua @@ -0,0 +1,13 @@ +local ClippedReLU, parent = torch.class('cudnn.ClippedReLU','cudnn._Pointwise') + +function ClippedReLU:__init(inplace, ceiling) + parent.__init(self) + self.inplace = inplace + assert(ceiling, "No ceiling was given to ClippedReLU") + self.ceiling = ceiling +end + +function ClippedReLU:updateOutput(input) + if not self.mode then self.mode = 'CUDNN_ACTIVATION_CLIPPED_RELU' end + return parent.updateOutput(self, input) +end
\ No newline at end of file diff --git a/Pointwise.lua b/Pointwise.lua index 93298ad..9cfe0f2 100644 --- a/Pointwise.lua +++ b/Pointwise.lua @@ -18,7 +18,7 @@ function Pointwise:createIODescriptors(input) if not self.activDesc then self.activDesc = ffi.new('struct cudnnActivationStruct*[1]') errcheck('cudnnCreateActivationDescriptor', self.activDesc) - errcheck('cudnnSetActivationDescriptor', self.activDesc[0], self.mode, 'CUDNN_PROPAGATE_NAN', 0.0); + errcheck('cudnnSetActivationDescriptor', self.activDesc[0], self.mode, 'CUDNN_PROPAGATE_NAN', self.ceiling or 0.0); local function destroyADesc(a) if (a[0]) then @@ -110,6 +110,7 @@ require('cudnn.VolumetricMaxPooling') require('cudnn.VolumetricAveragePooling') require('cudnn.Pointwise') require('cudnn.ReLU') +require('cudnn.ClippedReLU') require('cudnn.Tanh') require('cudnn.Sigmoid') require('cudnn.SpatialSoftMax') diff --git a/test/test.lua b/test/test.lua index a448317..aac45b4 100644 --- a/test/test.lua +++ b/test/test.lua @@ -962,6 +962,16 @@ function cudnntest.ReLU_batch() nonlinBatch('ReLU') end +function cudnntest.ClippedReLU_single() + local input = torch.randn(1, 32):cuda() + local ceiling = 0.1 + local module = cudnn.ClippedReLU(true, ceiling):cuda() + local output = module:forward(input) + local expectedOutput = input:clone() + expectedOutput[expectedOutput:ge(ceiling)] = ceiling + mytester:assertTensorEq(output, expectedOutput) +end + function cudnntest.Tanh_single() nonlinSingle('Tanh') end |