diff options
author | Nicholas Leonard <nleonard@twitter.com> | 2017-05-25 04:24:48 +0300 |
---|---|---|
committer | Nicholas Leonard <nleonard@twitter.com> | 2017-05-25 04:24:48 +0300 |
commit | eb6548a0c30db70465de4779d866bfac781ec0b1 (patch) | |
tree | 5f5a5248a6af77163c3bdf28daafcb1114d080e3 | |
parent | d1f66cbfafcf16dc1ed5bc62872aab2f0fe1f457 (diff) |
nn.Collapse
-rw-r--r-- | Collapse.lua | 30 | ||||
-rwxr-xr-x | doc/simple.md | 24 | ||||
-rwxr-xr-x | init.lua | 1 | ||||
-rwxr-xr-x | test.lua | 16 |
4 files changed, 69 insertions, 2 deletions
diff --git a/Collapse.lua b/Collapse.lua new file mode 100644 index 0000000..a088608 --- /dev/null +++ b/Collapse.lua @@ -0,0 +1,30 @@ +local Collapse, parent = torch.class('nn.Collapse', 'nn.Module') + +-- collapses non-batch dims +function Collapse:__init(nInputDim) + parent.__init(self) + self.nInputDim = nInputDim +end + +function Collapse:updateOutput(input) + if not input:isContiguous() then + self._input = self._input or input.new() + self._input:resize(input:size()):copy(input) + input = self._input + end + if input:dim() > self.nInputDim then + self.output:view(input,input:size(1),-1) + else + self.output:view(input,-1) + end + return self.output +end + +function Collapse:updateGradInput(input, gradOutput) + self.gradInput:view(gradOutput, input:size()) + return self.gradInput +end + +function Collapse:clearState() + self._input = nil +end diff --git a/doc/simple.md b/doc/simple.md index 5e31080..3d08167 100755 --- a/doc/simple.md +++ b/doc/simple.md @@ -61,7 +61,8 @@ Simple Modules are used for various tasks like adapting Tensor methods and provi * [WhiteNoise](#nn.WhiteNoise) : adds isotropic Gaussian noise to the signal when in training mode; * [OneHot](#nn.OneHot) : transforms a tensor of indices into [one-hot](https://en.wikipedia.org/wiki/One-hot) encoding; * [PrintSize](#nn.PrintSize) : prints the size of `input` and `gradOutput` (useful for debugging); - * [ZeroGrad](#nn.ZeroGrad) : forwards the `input` as-is, yet zeros the `gradInput`. + * [ZeroGrad](#nn.ZeroGrad) : forwards the `input` as-is, yet zeros the `gradInput`; + * [Collapse](#nn.Collapse) : just like `nn.View(-1)`. <a name="nn.Linear"></a> ## Linear ## @@ -1029,6 +1030,8 @@ Example 2: [torch.LongStorage of size 2] ``` +For collapsing non-batch dims, check out [nn.Collapse](#nn.Collapse). + <a name="nn.Contiguous"></a> ## Contiguous ## @@ -1783,4 +1786,21 @@ print(module:backward(input, gradOutput)) [torch.DoubleTensor of size 2] ``` -The module zeros the `gradInput` but forwards the `input` as-is.
\ No newline at end of file +The module zeros the `gradInput` but forwards the `input` as-is. + +<a name='nn.Collapse'></a> +## Collapse ## + +```lua +module = nn.Collapse(nInputDim) +``` + +This module is the equivalent of: +``` +view = nn.View(-1) +view:setNumInputDim(nInputDim) +``` + +It collapses all non-batch dimensions. This is useful for converting +a spatial feature map to the single dimension required by a dense +hidden layer like Linear.
\ No newline at end of file @@ -172,6 +172,7 @@ require('nn.NarrowTable') require('nn.MapTable') require('nn.ZipTable') require('nn.ZipTableOneToMany') +require('nn.Collapse') require('nn.Criterion') require('nn.MSECriterion') @@ -8596,6 +8596,22 @@ function nntest.ZipTableOneToMany() mytester:assertTensorEq(torch.mul(input[1], 3), gradInput[1], 0.000001, "ZipTableOneToMany gradInput21") end +function nntest.Collapse() + local c = nn.Collapse(3) + local input = torch.randn(8,3,4,5) + local output = c:forward(input) + mytester:assertTensorEq(input:view(8,-1), output, 0.000001, "Collapse:forward") + local gradInput = c:backward(input, output) + mytester:assertTensorEq(gradInput, input, 0.000001, "Collapse:backward") + mytester:assertTableEq(gradInput:size():totable(), input:size():totable(), 0.000001, "Collapse:backward size") + local input2 = input:transpose(1,4) + local output2 = c:forward(input2) + mytester:assertTensorEq(input2:contiguous():view(5,-1), output2, 0.000001, "Collapse:forward non-contiguous") + local gradInput2 = c:backward(input2, output2) + mytester:assertTensorEq(gradInput2, input2, 0.000001, "Collapse:backward non-contiguous") + mytester:assertTableEq(gradInput2:size():totable(), input2:size():totable(), 0.000001, "Collapse:backward size non-contiguous") +end + mytester:add(nntest) |