Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2014-09-26 20:14:36 +0400
committerSoumith Chintala <soumith@gmail.com>2014-09-26 20:47:31 +0400
commit2da6353dc28457c5b8dd758d552026c9aebdbcba (patch)
tree05e4b1ea5985e3f91c48d716004400fb44957231 /Tanh.lua
parent440ea8adeda929efc637316585446eb9996bd6fe (diff)
adding serialization to unit tests and fixing descriptor checks. Fixes #4
Diffstat (limited to 'Tanh.lua')
-rw-r--r--Tanh.lua4
1 files changed, 3 insertions, 1 deletions
diff --git a/Tanh.lua b/Tanh.lua
index 52c2572..ee1e264 100644
--- a/Tanh.lua
+++ b/Tanh.lua
@@ -9,7 +9,8 @@ function Tanh:__init()
end
function Tanh:createIODescriptors(input)
- if input:size(1) ~= self.iSize[1] or input:size(2) ~= self.iSize[2]
+ if not self.iDesc or not self.oDesc or
+ input:size(1) ~= self.iSize[1] or input:size(2) ~= self.iSize[2]
or input:size(3) ~= self.iSize[3] or input:size(4) ~= self.iSize[4] then
self.iSize = input:size()
self.gradInput:resizeAs(input)
@@ -31,6 +32,7 @@ end
function Tanh:updateGradInput(input, gradOutput)
assert(input:dim() == 4 and input:isContiguous());
assert(gradOutput:dim() == 4 and gradOutput:isContiguous());
+ self:createIODescriptors(input)
errcheck('cudnnActivationBackward', cudnn.handle[cutorch.getDevice()-1], 'CUDNN_ACTIVATION_TANH',
self.oDesc[0], self.output:data(),
self.oDesc[0], gradOutput:data(),