diff options
-rw-r--r-- | SpatialSoftMax.lua | 21 | ||||
-rw-r--r-- | test/test.lua | 14 |
2 files changed, 21 insertions, 14 deletions
diff --git a/SpatialSoftMax.lua b/SpatialSoftMax.lua index 87af4d5..97c1e38 100644 --- a/SpatialSoftMax.lua +++ b/SpatialSoftMax.lua @@ -14,7 +14,15 @@ end function SpatialSoftMax:createIODescriptors(input) local batch = true - if input:dim() == 3 then + local singleDim = false + if input:dim() == 1 then + singleDim = true + batch = false + input = input:view(1, input:size(1), 1, 1) + elseif input:dim() == 2 then + singleDim = true + input = input:view(input:size(1), input:size(2), 1, 1) + elseif input:dim() == 3 then input = input:view(1, input:size(1), input:size(2), input:size(3)) batch = false end @@ -27,13 +35,19 @@ function SpatialSoftMax:createIODescriptors(input) self.output:resizeAs(input) self.iDesc = cudnn.toDescriptor(input) self.oDesc = cudnn.toDescriptor(self.output) - if not batch then + if not singleDim and not batch then self.gradInput = self.gradInput:view(self.gradInput:size(2), self.gradInput:size(3), self.gradInput:size(4)) self.output = self.output:view(self.output:size(2), self.output:size(3), self.output:size(4)) + elseif singleDim and not batch then + self.gradInput = self.gradInput:view(self.gradInput:size(2)) + self.output = self.output:view(self.output:size(2)) + elseif singleDim and batch then + self.gradInput = self.gradInput:view(self.gradInput:size(1), self.gradInput:size(2)) + self.output = self.output:view(self.output:size(1), self.output:size(2)) end end end @@ -54,8 +68,7 @@ function SpatialSoftMax:updateOutput(input) end function SpatialSoftMax:updateGradInput(input, gradOutput) - assert((gradOutput:dim() == 4 or gradOutput:dim() == 3) - and gradOutput:isContiguous()); + assert(gradOutput:isContiguous()); self:createIODescriptors(input) errcheck('cudnnSoftmaxBackward', cudnn.handle[cutorch.getDevice()-1], diff --git a/test/test.lua b/test/test.lua index 3ca3712..cc634c7 100644 --- a/test/test.lua +++ b/test/test.lua @@ -499,16 +499,12 @@ function cudnntest.Sigmoid_batch() end function cudnntest.SoftMax_single() - local from = math.random(1,32) - local outi = math.random(1,64) - local outj = math.random(1,64) - local ini = outi - local inj = outj - local input = torch.randn(from,inj,ini):cuda() - local gradOutput = torch.randn(from,outj,outi):cuda() + local sz = math.random(1,64) + local input = torch.randn(sz):cuda() + local gradOutput = torch.randn(sz):cuda() local sconv = nn.SoftMax():cuda() - local groundtruth = sconv:forward(input:view(-1)) + local groundtruth = sconv:forward(input) local groundgrad = sconv:backward(input, gradOutput) cutorch.synchronize() local gconv = cudnn.SoftMax():cuda() @@ -521,8 +517,6 @@ function cudnntest.SoftMax_single() local rescuda = gconv:forward(input) local resgrad = gconv:backward(input, gradOutput) cutorch.synchronize() - mytester:asserteq(rescuda:dim(), 3, 'error in dimension') - mytester:asserteq(resgrad:dim(), 3, 'error in dimension') local error = rescuda:float() - groundtruth:float() local errmax = error:abs():max() if (errmax ~= errmax) then |