Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2014-07-09 22:24:59 +0400
committerSoumith Chintala <soumith@gmail.com>2014-07-09 22:24:59 +0400
commit3f35d9bce7e0335b801f3bfe36d8a86cd53ba4ed (patch)
tree7772ab69640404c847636742d3fb196114cf9ba6
parent0413ddd6dc0a35b5281fcaaebc73144b15f285fa (diff)
parentaba9ee56f1678b0c64ec95d566465ab258d852c2 (diff)
Merge pull request #30 from nicholas-leonard/ElementTable
Element table!
-rw-r--r--ElementTable.lua34
-rw-r--r--doc/table.md95
-rw-r--r--init.lua1
-rw-r--r--test/test.lua43
4 files changed, 173 insertions, 0 deletions
diff --git a/ElementTable.lua b/ElementTable.lua
new file mode 100644
index 0000000..cb3ff0f
--- /dev/null
+++ b/ElementTable.lua
@@ -0,0 +1,34 @@
+local ElementTable, parent = torch.class('nn.ElementTable', 'nn.Module')
+
+function ElementTable:__init(index)
+ parent.__init(self)
+ self.index = index
+ self.gradInput = {}
+end
+
+function ElementTable:updateOutput(input)
+ self.output = input[self.index]
+ return self.output
+end
+
+function ElementTable:updateGradInput(input, gradOutput)
+ if #self.gradInput == 0 then
+ local function zeroTableCopy(t1, t2)
+ for k, v in pairs(t2) do
+ if (torch.type(v) == "table") then
+ t1[k] = zeroTableCopy(t1[k] or {}, t2[k])
+ else
+ t1[k] = v:clone():zero()
+ end
+ end
+ return t1
+ end
+ zeroTableCopy(self.gradInput, input)
+ end
+ self.gradInput[self.index] = gradOutput
+ return self.gradInput
+end
+
+function ElementTable:type(type)
+ self.gradInput = {}
+end
diff --git a/doc/table.md b/doc/table.md
index 4117117..60b6dea 100644
--- a/doc/table.md
+++ b/doc/table.md
@@ -9,6 +9,7 @@ This allows one to build very rich architectures:
* Table Conversion Modules convert between tables and Tensors:
* [SplitTable](#nn.SplitTable) : splits a Tensor into a table of Tensors;
* [JoinTable](#nn.JoinTable) : joins a table of Tensors into a Tensor;
+ * [ElementTable](#nn.ElementTable) : retrieves one element from a table;
* Pair Modules compute a measure like distance or similarity from a pair (table) of input Tensors :
* [PairwiseDistance](#nn.PairwiseDistance) : outputs the `p`-norm. distance between inputs;
* [DotProduct](#nn.DotProduct) : outputs the dot product (similarity) between inputs;
@@ -375,6 +376,100 @@ for i=1,100 do -- A few steps of training such a network..
end
```
+<a name="nn.ElementTable"/>
+## ElementTable ##
+
+`module` = `ElementTable(index)`
+
+Creates a module that takes a Table as input and outputs the element at index `index`.
+This can be either a Table or a [Tensor](https://github.com/torch/torch7/blob/master/doc/tensor.md#tensor).
+
+The gradients of the non-`index` elements are zeroed Tensors of the same size. This is true regardless of the
+depth of the encapsulated Tensor, as the function used internally to zero them is recursive.
+
+Example 1:
+```lua
+> input = {torch.randn(2,3), torch.randn(2,1)}
+ [0.0002s]
+> =nn.ElementTable(1):forward(input)
+-0.3060 0.1398 0.2707
+ 0.0576 1.5455 0.0610
+[torch.DoubleTensor of dimension 2x3]
+
+ [0.0002s]
+> =nn.ElementTable(2):forward(input)
+ 2.3080
+-0.2955
+[torch.DoubleTensor of dimension 2x1]
+
+> =unpack(nn.ElementTable(1):backward(input, torch.randn(2,3)))
+-0.4891 -0.3495 -0.3182
+-2.0999 0.7381 -0.5312
+[torch.DoubleTensor of dimension 2x3]
+
+0
+0
+[torch.DoubleTensor of dimension 2x1]
+
+```
+
+Example 2:
+```lua
+> input = {torch.randn(2,3), {torch.randn(2,1), {torch.randn(2,2)}}}
+
+> =nn.ElementTable(2):forward(input)
+{
+ 1 : DoubleTensor - size: 2x1
+ 2 :
+ {
+ 1 : DoubleTensor - size: 2x2
+ }
+}
+
+> =unpack(nn.ElementTable(2):backward(input, {torch.randn(2,1), {torch.randn(2,2)}}))
+0 0 0
+0 0 0
+[torch.DoubleTensor of dimension 2x3]
+
+{
+ 1 : DoubleTensor - size: 2x1
+ 2 :
+ {
+ 1 : DoubleTensor - size: 2x2
+ }
+}
+
+> gradInput = nn.ElementTable(1):backward(input, torch.randn(2,3))
+
+> =gradInput
+{
+ 1 : DoubleTensor - size: 2x3
+ 2 :
+ {
+ 1 : DoubleTensor - size: 2x1
+ 2 :
+ {
+ 1 : DoubleTensor - size: 2x2
+ }
+ }
+}
+
+> =gradInput[1]
+-0.3400 -0.0404 1.1885
+ 1.2865 0.4107 0.6506
+[torch.DoubleTensor of dimension 2x3]
+
+> =gradInput[2][1]
+0
+0
+[torch.DoubleTensor of dimension 2x1]
+
+> =gradInput[2][2][1]
+0 0
+0 0
+[torch.DoubleTensor of dimension 2x2]
+
+```
<a name="nn.PairwiseDistance"/>
## PairwiseDistance ##
diff --git a/init.lua b/init.lua
index 757c9ec..6424cf5 100644
--- a/init.lua
+++ b/init.lua
@@ -85,6 +85,7 @@ include('ParallelTable.lua')
include('ConcatTable.lua')
include('SplitTable.lua')
include('JoinTable.lua')
+include('ElementTable.lua')
include('CriterionTable.lua')
include('Identity.lua')
diff --git a/test/test.lua b/test/test.lua
index c88c908..0c9a43c 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -1869,6 +1869,49 @@ function nntest.SplitTable()
end
end
+function nntest.ElementTable()
+ local input = {
+ torch.rand(3,4,5), torch.rand(3,4,5),
+ {torch.rand(3,4,5)},
+ {torch.rand(3,4,5), {torch.rand(3,4,5)}}
+ }
+ local gradOutputs = {
+ torch.rand(3,4,5), torch.rand(3,4,5),
+ {torch.rand(3,4,5)},
+ {torch.rand(3,4,5), {torch.rand(3,4,5)}}
+ }
+ local zeros = {
+ torch.Tensor(3,4,5):zero(), torch.Tensor(3,4,5):zero(),
+ {torch.Tensor(3,4,5):zero()},
+ {torch.Tensor(3,4,5):zero(), {torch.Tensor(3,4,5):zero()}}
+ }
+ local function equal(t1, t2, msg)
+ if (torch.type(t1) == "table") then
+ for k, v in pairs(t2) do
+ equal(t1[k], t2[k])
+ end
+ else
+ mytester:assertTensorEq(t1, t2, 0.00001, msg)
+ end
+ end
+ local nonIdx = {2,3,4,1}
+ local module
+ for idx = 1,#input do
+ module = nn.ElementTable(idx)
+ local output = module:forward(input)
+ equal(output, input[idx], "output dimension " .. idx)
+ local gradInput = module:backward(input, gradOutputs[idx])
+ equal(gradInput[idx], gradOutputs[idx], "gradInput[idx] dimension " .. idx)
+ equal(gradInput[nonIdx[idx]], zeros[nonIdx[idx]], "gradInput[nonIdx] dimension " .. idx)
+ end
+ module:float()
+ local idx = #input
+ local output = module:forward(input)
+ equal(output, input[idx], "type output")
+ local gradInput = module:backward(input, gradOutputs[idx])
+ equal(gradInput[idx], gradOutputs[idx], "gradInput[idx] dimension " .. idx)
+ equal(gradInput[nonIdx[idx]], zeros[nonIdx[idx]], "gradInput[nonIdx] dimension " .. idx)
+end
function nntest.View()
local input = torch.rand(10)