Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2014-12-19 02:36:12 +0300
committerSoumith Chintala <soumith@gmail.com>2014-12-19 23:28:13 +0300
commitad958c0e268d876ee4d713510b8c3ef83b37bca0 (patch)
tree0defbe1196f778c9fb3f79f5f6e7a1da9ae92cda /SpatialSoftMax.lua
parentd290c4cb9d632120d3fba97caefb3afb961081bf (diff)
everything works with R2. all unit tests pass. Maxpooling has free zero-padding
Diffstat (limited to 'SpatialSoftMax.lua')
-rw-r--r--SpatialSoftMax.lua7
1 files changed, 7 insertions, 0 deletions
diff --git a/SpatialSoftMax.lua b/SpatialSoftMax.lua
index 3a4106d..87af4d5 100644
--- a/SpatialSoftMax.lua
+++ b/SpatialSoftMax.lua
@@ -38,12 +38,17 @@ function SpatialSoftMax:createIODescriptors(input)
end
end
+local one = torch.FloatTensor({1});
+local zero = torch.FloatTensor({0});
+
function SpatialSoftMax:updateOutput(input)
self:createIODescriptors(input)
errcheck('cudnnSoftmaxForward',
cudnn.handle[cutorch.getDevice()-1],
self.algorithm, self.mode,
+ one:data(),
self.iDesc[0], input:data(),
+ zero:data(),
self.oDesc[0], self.output:data());
return self.output
end
@@ -55,8 +60,10 @@ function SpatialSoftMax:updateGradInput(input, gradOutput)
errcheck('cudnnSoftmaxBackward',
cudnn.handle[cutorch.getDevice()-1],
self.algorithm, self.mode,
+ one:data(),
self.oDesc[0], self.output:data(),
self.oDesc[0], gradOutput:data(),
+ zero:data(),
self.iDesc[0], self.gradInput:data());
return self.gradInput
end