diff options
Diffstat (limited to 'SpatialSoftMax.lua')
-rw-r--r-- | SpatialSoftMax.lua | 21 |
1 files changed, 17 insertions, 4 deletions
diff --git a/SpatialSoftMax.lua b/SpatialSoftMax.lua index 87af4d5..97c1e38 100644 --- a/SpatialSoftMax.lua +++ b/SpatialSoftMax.lua @@ -14,7 +14,15 @@ end function SpatialSoftMax:createIODescriptors(input) local batch = true - if input:dim() == 3 then + local singleDim = false + if input:dim() == 1 then + singleDim = true + batch = false + input = input:view(1, input:size(1), 1, 1) + elseif input:dim() == 2 then + singleDim = true + input = input:view(input:size(1), input:size(2), 1, 1) + elseif input:dim() == 3 then input = input:view(1, input:size(1), input:size(2), input:size(3)) batch = false end @@ -27,13 +35,19 @@ function SpatialSoftMax:createIODescriptors(input) self.output:resizeAs(input) self.iDesc = cudnn.toDescriptor(input) self.oDesc = cudnn.toDescriptor(self.output) - if not batch then + if not singleDim and not batch then self.gradInput = self.gradInput:view(self.gradInput:size(2), self.gradInput:size(3), self.gradInput:size(4)) self.output = self.output:view(self.output:size(2), self.output:size(3), self.output:size(4)) + elseif singleDim and not batch then + self.gradInput = self.gradInput:view(self.gradInput:size(2)) + self.output = self.output:view(self.output:size(2)) + elseif singleDim and batch then + self.gradInput = self.gradInput:view(self.gradInput:size(1), self.gradInput:size(2)) + self.output = self.output:view(self.output:size(1), self.output:size(2)) end end end @@ -54,8 +68,7 @@ function SpatialSoftMax:updateOutput(input) end function SpatialSoftMax:updateGradInput(input, gradOutput) - assert((gradOutput:dim() == 4 or gradOutput:dim() == 3) - and gradOutput:isContiguous()); + assert(gradOutput:isContiguous()); self:createIODescriptors(input) errcheck('cudnnSoftmaxBackward', cudnn.handle[cutorch.getDevice()-1], |