diff options
Diffstat (limited to 'SpatialSoftMax.lua')
-rw-r--r-- | SpatialSoftMax.lua | 7 |
1 files changed, 7 insertions, 0 deletions
diff --git a/SpatialSoftMax.lua b/SpatialSoftMax.lua index 3a4106d..87af4d5 100644 --- a/SpatialSoftMax.lua +++ b/SpatialSoftMax.lua @@ -38,12 +38,17 @@ function SpatialSoftMax:createIODescriptors(input) end end +local one = torch.FloatTensor({1}); +local zero = torch.FloatTensor({0}); + function SpatialSoftMax:updateOutput(input) self:createIODescriptors(input) errcheck('cudnnSoftmaxForward', cudnn.handle[cutorch.getDevice()-1], self.algorithm, self.mode, + one:data(), self.iDesc[0], input:data(), + zero:data(), self.oDesc[0], self.output:data()); return self.output end @@ -55,8 +60,10 @@ function SpatialSoftMax:updateGradInput(input, gradOutput) errcheck('cudnnSoftmaxBackward', cudnn.handle[cutorch.getDevice()-1], self.algorithm, self.mode, + one:data(), self.oDesc[0], self.output:data(), self.oDesc[0], gradOutput:data(), + zero:data(), self.iDesc[0], self.gradInput:data()); return self.gradInput end |