Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2014-09-11 01:03:02 +0400
committerSoumith Chintala <soumith@gmail.com>2014-09-11 01:03:02 +0400
commita8bf53f8738fd319568593718280343c3ebc93e6 (patch)
tree58f632560c26f31c9d87e81d6f14c546c3020bcd /Tanh.lua
first commit
Diffstat (limited to 'Tanh.lua')
-rw-r--r--Tanh.lua39
1 files changed, 39 insertions, 0 deletions
diff --git a/Tanh.lua b/Tanh.lua
new file mode 100644
index 0000000..d1fbcf8
--- /dev/null
+++ b/Tanh.lua
@@ -0,0 +1,39 @@
+local Tanh, parent = torch.class('cudnn.Tanh','nn.Module')
+local ffi = require 'ffi'
+local C = cudnn.C
+local errcheck = cudnn.errcheck
+
+function Tanh:__init()
+ parent.__init(self)
+ self.iSize = torch.LongStorage(4):fill(0)
+end
+
+function Tanh:createIODescriptors(input)
+ if input:size(1) ~= self.iSize:size(1) or input:size(2) ~= self.iSize:size(2)
+ or input:size(3) ~= self.iSize:size(3) or input:size(4) ~= self.iSize:size(4) then
+ self.gradInput:resizeAs(input)
+ self.output:resizeAs(input)
+ self.iDesc = cudnn.toDescriptor(input)
+ self.oDesc = cudnn.toDescriptor(self.output)
+ end
+end
+
+function Tanh:updateOutput(input)
+ assert(input:dim() == 4 and input:isContiguous());
+ self:createIODescriptors(input)
+ errcheck('cudnnActivationForward', cudnn.handle[0], 'CUDNN_ACTIVATION_TANH',
+ self.iDesc[0], input:data(),
+ self.oDesc[0], self.output:data());
+ return self.output
+end
+
+function Tanh:updateGradInput(input, gradOutput)
+ assert(input:dim() == 4 and input:isContiguous());
+ assert(gradOutput:dim() == 4 and gradOutput:isContiguous());
+ errcheck('cudnnActivationBackward', cudnn.handle[0], 'CUDNN_ACTIVATION_TANH',
+ self.oDesc[0], self.output:data(),
+ self.oDesc[0], gradOutput:data(),
+ self.iDesc[0], input:data(),
+ self.iDesc[0], self.gradInput:data());
+ return self.gradInput
+end