diff options
author | Soumith Chintala <soumith@gmail.com> | 2014-12-19 02:36:12 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2014-12-19 23:28:13 +0300 |
commit | ad958c0e268d876ee4d713510b8c3ef83b37bca0 (patch) | |
tree | 0defbe1196f778c9fb3f79f5f6e7a1da9ae92cda /SpatialSoftMax.lua | |
parent | d290c4cb9d632120d3fba97caefb3afb961081bf (diff) |
everything works with R2. all unit tests pass. Maxpooling has free zero-padding
Diffstat (limited to 'SpatialSoftMax.lua')
-rw-r--r-- | SpatialSoftMax.lua | 7 |
1 files changed, 7 insertions, 0 deletions
diff --git a/SpatialSoftMax.lua b/SpatialSoftMax.lua index 3a4106d..87af4d5 100644 --- a/SpatialSoftMax.lua +++ b/SpatialSoftMax.lua @@ -38,12 +38,17 @@ function SpatialSoftMax:createIODescriptors(input) end end +local one = torch.FloatTensor({1}); +local zero = torch.FloatTensor({0}); + function SpatialSoftMax:updateOutput(input) self:createIODescriptors(input) errcheck('cudnnSoftmaxForward', cudnn.handle[cutorch.getDevice()-1], self.algorithm, self.mode, + one:data(), self.iDesc[0], input:data(), + zero:data(), self.oDesc[0], self.output:data()); return self.output end @@ -55,8 +60,10 @@ function SpatialSoftMax:updateGradInput(input, gradOutput) errcheck('cudnnSoftmaxBackward', cudnn.handle[cutorch.getDevice()-1], self.algorithm, self.mode, + one:data(), self.oDesc[0], self.output:data(), self.oDesc[0], gradOutput:data(), + zero:data(), self.iDesc[0], self.gradInput:data()); return self.gradInput end |