diff options
author | soumith <soumith@gmail.com> | 2016-08-06 21:48:57 +0300 |
---|---|---|
committer | soumith <soumith@gmail.com> | 2016-08-06 21:50:31 +0300 |
commit | 7afb2414753b9f34ceee6c5cca022d9eb2652a83 (patch) | |
tree | ebe261c7042d02b03b3f9607d515f3609edd2d55 /Pointwise.lua | |
parent | 327e6af4bcfbcbe1c4221b6dd9190602f411a2c3 (diff) |
working double precision
Diffstat (limited to 'Pointwise.lua')
-rw-r--r-- | Pointwise.lua | 11 |
1 files changed, 4 insertions, 7 deletions
diff --git a/Pointwise.lua b/Pointwise.lua index 9cfe0f2..76e6499 100644 --- a/Pointwise.lua +++ b/Pointwise.lua @@ -37,17 +37,14 @@ function Pointwise:createIODescriptors(input) end -local one = torch.FloatTensor({1}); -local zero = torch.FloatTensor({0}); - function Pointwise:updateOutput(input) self:createIODescriptors(input) if self.inplace then self.output:set(input) end errcheck('cudnnActivationForward', cudnn.getHandle(), self.activDesc[0], - one:data(), + cudnn.scalar(input, 1), self.iDesc[0], input:data(), - zero:data(), + cudnn.scalar(input, 0), self.iDesc[0], self.output:data()); return self.output end @@ -67,11 +64,11 @@ function Pointwise:updateGradInput(input, gradOutput) end errcheck('cudnnActivationBackward', cudnn.getHandle(), self.activDesc[0], - one:data(), + cudnn.scalar(input, 1), self.iDesc[0], self.output:data(), self.iDesc[0], gradOutput:data(), self.iDesc[0], input:data(), - zero:data(), + cudnn.scalar(input, 0), self.iDesc[0], self.gradInput:data()); return self.gradInput end |