Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorsoumith <soumith@gmail.com>2016-08-06 21:48:57 +0300
committersoumith <soumith@gmail.com>2016-08-06 21:50:31 +0300
commit7afb2414753b9f34ceee6c5cca022d9eb2652a83 (patch)
treeebe261c7042d02b03b3f9607d515f3609edd2d55 /Pointwise.lua
parent327e6af4bcfbcbe1c4221b6dd9190602f411a2c3 (diff)
working double precision
Diffstat (limited to 'Pointwise.lua')
-rw-r--r--Pointwise.lua11
1 files changed, 4 insertions, 7 deletions
diff --git a/Pointwise.lua b/Pointwise.lua
index 9cfe0f2..76e6499 100644
--- a/Pointwise.lua
+++ b/Pointwise.lua
@@ -37,17 +37,14 @@ function Pointwise:createIODescriptors(input)
end
-local one = torch.FloatTensor({1});
-local zero = torch.FloatTensor({0});
-
function Pointwise:updateOutput(input)
self:createIODescriptors(input)
if self.inplace then self.output:set(input) end
errcheck('cudnnActivationForward',
cudnn.getHandle(), self.activDesc[0],
- one:data(),
+ cudnn.scalar(input, 1),
self.iDesc[0], input:data(),
- zero:data(),
+ cudnn.scalar(input, 0),
self.iDesc[0], self.output:data());
return self.output
end
@@ -67,11 +64,11 @@ function Pointwise:updateGradInput(input, gradOutput)
end
errcheck('cudnnActivationBackward',
cudnn.getHandle(), self.activDesc[0],
- one:data(),
+ cudnn.scalar(input, 1),
self.iDesc[0], self.output:data(),
self.iDesc[0], gradOutput:data(),
self.iDesc[0], input:data(),
- zero:data(),
+ cudnn.scalar(input, 0),
self.iDesc[0], self.gradInput:data());
return self.gradInput
end