Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPeng Sun <pengsun000@gamil.com>2016-03-06 20:32:40 +0300
committerPeng Sun <pengsun000@gamil.com>2016-03-06 20:32:40 +0300
commit744166d883b346a6a1f8fcf7f1464363b9a8fc93 (patch)
treee65e2ce61ffb86124b42bd32bf7fbbb900516716 /Unsqueeze.lua
parent04cea9106ac2635b7d75c3d95cb88b10d955b8d0 (diff)
Modify the implementation by using nn.utils.addSingletonDimension
Diffstat (limited to 'Unsqueeze.lua')
-rw-r--r--Unsqueeze.lua62
1 files changed, 23 insertions, 39 deletions
diff --git a/Unsqueeze.lua b/Unsqueeze.lua
index f4b1a11..ab98023 100644
--- a/Unsqueeze.lua
+++ b/Unsqueeze.lua
@@ -1,44 +1,12 @@
local Unsqueeze, parent = torch.class('nn.Unsqueeze', 'nn.Module')
-local function _checkPos(pos)
- local pos = pos or error('the position to insert singleton dim not specified')
- assert(type(pos) == 'number')
- assert(pos > 0)
- return pos
-end
-
local function _assertTensor(t)
- assert(t or torch.isTensor(t), "Unsqueeze only works on tensor")
-end
-
-local function _unsqueezeSize(inputSize, numInputDims, pos)
- -- is batchMode?
- local offsetDim = #inputSize - numInputDims
- assert(offsetDim >= 0, "numInputDims must be <= input:dim()")
- -- pos overflow?
- assert(offsetDim + pos <= 1 + #inputSize,
- ("offsetDim + pos (=%d + %d) exceeds input dim (=%d) by more than 1"):format(
- offsetDim, pos, #inputSize)
- )
-
- local outputSize = {}
- -- left to pos
- for i = 1, offsetDim + pos-1 do
- table.insert(outputSize, inputSize[i])
- end
- -- this pos: the singleton dim
- table.insert(outputSize, 1)
- -- right to pose
- for i = offsetDim + pos, #inputSize do
- table.insert(outputSize, inputSize[i])
- end
-
- return outputSize
+ assert(torch.isTensor(t), "This module only works on tensor")
end
function Unsqueeze:__init(pos, numInputDims)
parent.__init(self)
- self.pos = _checkPos(pos)
+ self.pos = pos or error('the position to insert singleton dim not specified')
self:setNumInputDims(numInputDims)
end
@@ -49,20 +17,36 @@ end
function Unsqueeze:updateOutput(input)
_assertTensor(input)
-
- local inputSize = input:size():totable()
- local numInputDims = self.numInputDims or input:dim()
- local outputSize = _unsqueezeSize(inputSize, numInputDims, self.pos)
- self.output = input:view( table.unpack(outputSize) )
+ local actualPos = self:_getActualPosition(input)
+ self.output = nn.utils.addSingletonDimension(input, actualPos)
return self.output
end
function Unsqueeze:updateGradInput(input, gradOutput)
+ _assertTensor(input)
+ _assertTensor(gradOutput)
assert(input:nElement() == gradOutput:nElement())
+
self.gradInput = gradOutput:view(input:size())
return self.gradInput
end
function Unsqueeze:__tostring__()
return torch.type(self)..'(dim ' .. self.pos .. ')'
+end
+
+function Unsqueeze:_getActualPosition(input)
+ -- get valid dimesion offset for batchMode (if any)
+ local inputDim = input:dim() -- data batch dim
+ self.numInputDims = self.numInputDims or inputDim -- feature map dim
+ local offsetDim = inputDim - self.numInputDims
+ assert(offsetDim >= 0, "input feature map dim (numInputDims) must be <= input:dim()")
+
+ -- the actual position; clearer error message for batchMode (if any)
+ local actualPos = self.pos + offsetDim
+ assert(actualPos >= 1 and actualPos <= (inputDim + 1),
+ ("Invalid position: %d. input:dim() is %d, input feature map dim (numInputDims) is %d.")
+ :format(self.pos, inputDim, self.numInputDims)
+ )
+ return actualPos
end \ No newline at end of file