
github.com/torch/nn.git
author    Nicholas Leonard <nleonard@twitter.com>  2017-05-25 04:24:48 +0300
committer Nicholas Leonard <nleonard@twitter.com>  2017-05-25 04:24:48 +0300
commit    eb6548a0c30db70465de4779d866bfac781ec0b1 (patch)
tree      5f5a5248a6af77163c3bdf28daafcb1114d080e3
parent    d1f66cbfafcf16dc1ed5bc62872aab2f0fe1f457 (diff)
nn.Collapse
-rw-r--r--  Collapse.lua   30
-rwxr-xr-x  doc/simple.md  24
-rwxr-xr-x  init.lua        1
-rwxr-xr-x  test.lua       16
4 files changed, 69 insertions, 2 deletions
diff --git a/Collapse.lua b/Collapse.lua
new file mode 100644
index 0000000..a088608
--- /dev/null
+++ b/Collapse.lua
@@ -0,0 +1,30 @@
+local Collapse, parent = torch.class('nn.Collapse', 'nn.Module')
+
+-- collapses non-batch dims
+function Collapse:__init(nInputDim)
+   parent.__init(self)
+   self.nInputDim = nInputDim
+end
+
+function Collapse:updateOutput(input)
+   if not input:isContiguous() then
+      self._input = self._input or input.new()
+      self._input:resize(input:size()):copy(input)
+      input = self._input
+   end
+   if input:dim() > self.nInputDim then
+      self.output:view(input,input:size(1),-1)
+   else
+      self.output:view(input,-1)
+   end
+   return self.output
+end
+
+function Collapse:updateGradInput(input, gradOutput)
+   self.gradInput:view(gradOutput, input:size())
+   return self.gradInput
+end
+
+function Collapse:clearState()
+   self._input = nil
+end
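The new module treats the first dimension as the batch dimension whenever `input:dim() > nInputDim` and flattens everything after it; when `input:dim() == nInputDim` it flattens the whole tensor. Non-contiguous inputs are first copied into `self._input`, since `view` requires contiguous storage. A minimal usage sketch (illustrative, not part of the commit):

```lua
require 'nn'

-- 3 non-batch dims: an 8x3x4x5 batch collapses to 8x60.
local collapse = nn.Collapse(3)
local input = torch.randn(8, 3, 4, 5)
print(collapse:forward(input):size())   -- 8x60

-- With no batch dim (input:dim() == nInputDim), everything is flattened.
print(collapse:forward(torch.randn(3, 4, 5)):size())  -- 60
```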
diff --git a/doc/simple.md b/doc/simple.md
index 5e31080..3d08167 100755
--- a/doc/simple.md
+++ b/doc/simple.md
@@ -61,7 +61,8 @@ Simple Modules are used for various tasks like adapting Tensor methods and provi
* [WhiteNoise](#nn.WhiteNoise) : adds isotropic Gaussian noise to the signal when in training mode;
* [OneHot](#nn.OneHot) : transforms a tensor of indices into [one-hot](https://en.wikipedia.org/wiki/One-hot) encoding;
* [PrintSize](#nn.PrintSize) : prints the size of `input` and `gradOutput` (useful for debugging);
- * [ZeroGrad](#nn.ZeroGrad) : forwards the `input` as-is, yet zeros the `gradInput`.
+ * [ZeroGrad](#nn.ZeroGrad) : forwards the `input` as-is, yet zeros the `gradInput`;
+ * [Collapse](#nn.Collapse) : collapses all non-batch dimensions, like `nn.View(-1)` with `setNumInputDims`.
<a name="nn.Linear"></a>
## Linear ##
@@ -1029,6 +1030,8 @@ Example 2:
[torch.LongStorage of size 2]
```
+For collapsing non-batch dims, check out [nn.Collapse](#nn.Collapse).
+
<a name="nn.Contiguous"></a>
## Contiguous ##
@@ -1783,4 +1786,21 @@ print(module:backward(input, gradOutput))
[torch.DoubleTensor of size 2]
```
-The module zeros the `gradInput` but forwards the `input` as-is.
\ No newline at end of file
+The module zeros the `gradInput` but forwards the `input` as-is.
+
+<a name='nn.Collapse'></a>
+## Collapse ##
+
+```lua
+module = nn.Collapse(nInputDim)
+```
+
+This module is the equivalent of:
+```lua
+view = nn.View(-1)
+view:setNumInputDims(nInputDim)
+```
+
+It collapses all non-batch dimensions. This is useful for converting
+a spatial feature map to the single dimension required by a dense
+hidden layer like Linear.
\ No newline at end of file
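To see the equivalence the new documentation claims, a short sketch (illustrative, not part of the commit; it assumes the stock `nn.View:setNumInputDims`) comparing the two modules on a conv-style feature map feeding `nn.Linear`:

```lua
require 'nn'

-- A batch of 8 feature maps: 16 channels of 4x4.
local input = torch.randn(8, 16, 4, 4)

local collapse = nn.Collapse(3)
local view = nn.View(-1):setNumInputDims(3)

local a = collapse:forward(input)   -- 8x256
local b = view:forward(input)       -- 8x256
print((a - b):abs():max())          -- 0: identical outputs

-- Either form feeds a dense layer directly:
print(nn.Linear(16*4*4, 10):forward(a):size())  -- 8x10
```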
diff --git a/init.lua b/init.lua
index 447d357..bd4a5b0 100755
--- a/init.lua
+++ b/init.lua
@@ -172,6 +172,7 @@ require('nn.NarrowTable')
require('nn.MapTable')
require('nn.ZipTable')
require('nn.ZipTableOneToMany')
+require('nn.Collapse')
require('nn.Criterion')
require('nn.MSECriterion')
diff --git a/test.lua b/test.lua
index 16bae09..e776f26 100755
--- a/test.lua
+++ b/test.lua
@@ -8596,6 +8596,22 @@ function nntest.ZipTableOneToMany()
mytester:assertTensorEq(torch.mul(input[1], 3), gradInput[1], 0.000001, "ZipTableOneToMany gradInput21")
end
+function nntest.Collapse()
+   local c = nn.Collapse(3)
+   local input = torch.randn(8,3,4,5)
+   local output = c:forward(input)
+   mytester:assertTensorEq(input:view(8,-1), output, 0.000001, "Collapse:forward")
+   local gradInput = c:backward(input, output)
+   mytester:assertTensorEq(gradInput, input, 0.000001, "Collapse:backward")
+   mytester:assertTableEq(gradInput:size():totable(), input:size():totable(), 0.000001, "Collapse:backward size")
+   local input2 = input:transpose(1,4)
+   local output2 = c:forward(input2)
+   mytester:assertTensorEq(input2:contiguous():view(5,-1), output2, 0.000001, "Collapse:forward non-contiguous")
+   local gradInput2 = c:backward(input2, output2)
+   mytester:assertTensorEq(gradInput2, input2, 0.000001, "Collapse:backward non-contiguous")
+   mytester:assertTableEq(gradInput2:size():totable(), input2:size():totable(), 0.000001, "Collapse:backward size non-contiguous")
+end
+
mytester:add(nntest)
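Assuming the standard `nn.test` harness that this repository exports, the new case can be run on its own (sketch, not part of the commit):

```lua
require 'nn'
nn.test{'Collapse'}  -- runs only nntest.Collapse through the shared mytester
```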