From 71d092dc0474321803efd54473377bf8920e76f9 Mon Sep 17 00:00:00 2001 From: soumith Date: Thu, 26 Feb 2015 16:09:00 -0800 Subject: cudnn.SoftMax now supports 1D and 2D inputs --- SpatialSoftMax.lua | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) (limited to 'SpatialSoftMax.lua') diff --git a/SpatialSoftMax.lua b/SpatialSoftMax.lua index 87af4d5..97c1e38 100644 --- a/SpatialSoftMax.lua +++ b/SpatialSoftMax.lua @@ -14,7 +14,15 @@ end function SpatialSoftMax:createIODescriptors(input) local batch = true - if input:dim() == 3 then + local singleDim = false + if input:dim() == 1 then + singleDim = true + batch = false + input = input:view(1, input:size(1), 1, 1) + elseif input:dim() == 2 then + singleDim = true + input = input:view(input:size(1), input:size(2), 1, 1) + elseif input:dim() == 3 then input = input:view(1, input:size(1), input:size(2), input:size(3)) batch = false end @@ -27,13 +35,19 @@ function SpatialSoftMax:createIODescriptors(input) self.output:resizeAs(input) self.iDesc = cudnn.toDescriptor(input) self.oDesc = cudnn.toDescriptor(self.output) - if not batch then + if not singleDim and not batch then self.gradInput = self.gradInput:view(self.gradInput:size(2), self.gradInput:size(3), self.gradInput:size(4)) self.output = self.output:view(self.output:size(2), self.output:size(3), self.output:size(4)) + elseif singleDim and not batch then + self.gradInput = self.gradInput:view(self.gradInput:size(2)) + self.output = self.output:view(self.output:size(2)) + elseif singleDim and batch then + self.gradInput = self.gradInput:view(self.gradInput:size(1), self.gradInput:size(2)) + self.output = self.output:view(self.output:size(1), self.output:size(2)) end end end @@ -54,8 +68,7 @@ function SpatialSoftMax:updateOutput(input) end function SpatialSoftMax:updateGradInput(input, gradOutput) - assert((gradOutput:dim() == 4 or gradOutput:dim() == 3) - and gradOutput:isContiguous()); + assert(gradOutput:isContiguous()); self:createIODescriptors(input) errcheck('cudnnSoftmaxBackward', cudnn.handle[cutorch.getDevice()-1], -- cgit v1.2.3