diff options
author | Peng Sun <pengsun000@gamil.com> | 2016-03-06 20:32:40 +0300 |
---|---|---|
committer | Peng Sun <pengsun000@gamil.com> | 2016-03-06 20:32:40 +0300 |
commit | 744166d883b346a6a1f8fcf7f1464363b9a8fc93 (patch) | |
tree | e65e2ce61ffb86124b42bd32bf7fbbb900516716 /Unsqueeze.lua | |
parent | 04cea9106ac2635b7d75c3d95cb88b10d955b8d0 (diff) |
Modify the implementation by using nn.utils.addSingletonDimension
Diffstat (limited to 'Unsqueeze.lua')
-rw-r--r-- | Unsqueeze.lua | 62 |
1 files changed, 23 insertions, 39 deletions
diff --git a/Unsqueeze.lua b/Unsqueeze.lua index f4b1a11..ab98023 100644 --- a/Unsqueeze.lua +++ b/Unsqueeze.lua @@ -1,44 +1,12 @@ local Unsqueeze, parent = torch.class('nn.Unsqueeze', 'nn.Module') -local function _checkPos(pos) - local pos = pos or error('the position to insert singleton dim not specified') - assert(type(pos) == 'number') - assert(pos > 0) - return pos -end - local function _assertTensor(t) - assert(t or torch.isTensor(t), "Unsqueeze only works on tensor") -end - -local function _unsqueezeSize(inputSize, numInputDims, pos) - -- is batchMode? - local offsetDim = #inputSize - numInputDims - assert(offsetDim >= 0, "numInputDims must be <= input:dim()") - -- pos overflow? - assert(offsetDim + pos <= 1 + #inputSize, - ("offsetDim + pos (=%d + %d) exceeds input dim (=%d) by more than 1"):format( - offsetDim, pos, #inputSize) - ) - - local outputSize = {} - -- left to pos - for i = 1, offsetDim + pos-1 do - table.insert(outputSize, inputSize[i]) - end - -- this pos: the singleton dim - table.insert(outputSize, 1) - -- right to pose - for i = offsetDim + pos, #inputSize do - table.insert(outputSize, inputSize[i]) - end - - return outputSize + assert(torch.isTensor(t), "This module only works on tensor") end function Unsqueeze:__init(pos, numInputDims) parent.__init(self) - self.pos = _checkPos(pos) + self.pos = pos or error('the position to insert singleton dim not specified') self:setNumInputDims(numInputDims) end @@ -49,20 +17,36 @@ end function Unsqueeze:updateOutput(input) _assertTensor(input) - - local inputSize = input:size():totable() - local numInputDims = self.numInputDims or input:dim() - local outputSize = _unsqueezeSize(inputSize, numInputDims, self.pos) - self.output = input:view( table.unpack(outputSize) ) + local actualPos = self:_getActualPosition(input) + self.output = nn.utils.addSingletonDimension(input, actualPos) return self.output end function Unsqueeze:updateGradInput(input, gradOutput) + _assertTensor(input) + _assertTensor(gradOutput) assert(input:nElement() == gradOutput:nElement()) + self.gradInput = gradOutput:view(input:size()) return self.gradInput end function Unsqueeze:__tostring__() return torch.type(self)..'(dim ' .. self.pos .. ')' +end + +function Unsqueeze:_getActualPosition(input) + -- get valid dimesion offset for batchMode (if any) + local inputDim = input:dim() -- data batch dim + self.numInputDims = self.numInputDims or inputDim -- feature map dim + local offsetDim = inputDim - self.numInputDims + assert(offsetDim >= 0, "input feature map dim (numInputDims) must be <= input:dim()") + + -- the actual position; clearer error message for batchMode (if any) + local actualPos = self.pos + offsetDim + assert(actualPos >= 1 and actualPos <= (inputDim + 1), + ("Invalid position: %d. input:dim() is %d, input feature map dim (numInputDims) is %d.") + :format(self.pos, inputDim, self.numInputDims) + ) + return actualPos end
\ No newline at end of file |