-rw-r--r--   ElementTable.lua |  34
-rw-r--r--   doc/table.md     |  95
-rw-r--r--   init.lua         |   1
-rw-r--r--   test/test.lua    |  43
4 files changed, 173 insertions, 0 deletions
diff --git a/ElementTable.lua b/ElementTable.lua
new file mode 100644
index 0000000..cb3ff0f
--- /dev/null
+++ b/ElementTable.lua
@@ -0,0 +1,34 @@
+local ElementTable, parent = torch.class('nn.ElementTable', 'nn.Module')
+
+function ElementTable:__init(index)
+   parent.__init(self)
+   self.index = index
+   self.gradInput = {}
+end
+
+function ElementTable:updateOutput(input)
+   self.output = input[self.index]
+   return self.output
+end
+
+function ElementTable:updateGradInput(input, gradOutput)
+   if #self.gradInput == 0 then
+      local function zeroTableCopy(t1, t2)
+         for k, v in pairs(t2) do
+            if (torch.type(v) == "table") then
+               t1[k] = zeroTableCopy(t1[k] or {}, t2[k])
+            else
+               t1[k] = v:clone():zero()
+            end
+         end
+         return t1
+      end
+      zeroTableCopy(self.gradInput, input)
+   end
+   self.gradInput[self.index] = gradOutput
+   return self.gradInput
+end
+
+function ElementTable:type(type)
+   self.gradInput = {}
+end
diff --git a/doc/table.md b/doc/table.md
index 4117117..60b6dea 100644
--- a/doc/table.md
+++ b/doc/table.md
@@ -9,6 +9,7 @@ This allows one to build very rich architectures:
 * Table Conversion Modules convert between tables and Tensors:
   * [SplitTable](#nn.SplitTable) : splits a Tensor into a table of Tensors;
   * [JoinTable](#nn.JoinTable) : joins a table of Tensors into a Tensor;
+  * [ElementTable](#nn.ElementTable) : retrieves one element from a table;
 * Pair Modules compute a measure like distance or similarity from a pair (table) of input Tensors :
   * [PairwiseDistance](#nn.PairwiseDistance) : outputs the `p`-norm. distance between inputs;
   * [DotProduct](#nn.DotProduct) : outputs the dot product (similarity) between inputs;
@@ -375,6 +376,100 @@ for i=1,100 do -- A few steps of training such a network..
 end
 ```
 
+<a name="nn.ElementTable"/>
+## ElementTable ##
+
+`module` = `ElementTable(index)`
+
+Creates a module that takes a table as input and outputs the element at index `index`.
+The selected element can itself be either a table or a [Tensor](https://github.com/torch/torch7/blob/master/doc/tensor.md#tensor).
+
+The gradients of the non-`index` elements are zeroed Tensors of the same size. This holds regardless of the
+depth of the encapsulated Tensors, since the function used internally to zero them is recursive.
+
+Example 1:
+```lua
+> input = {torch.randn(2,3), torch.randn(2,1)}
+                                                                      [0.0002s]
+> =nn.ElementTable(1):forward(input)
+-0.3060  0.1398  0.2707
+ 0.0576  1.5455  0.0610
+[torch.DoubleTensor of dimension 2x3]
+
+                                                                      [0.0002s]
+> =nn.ElementTable(2):forward(input)
+ 2.3080
+-0.2955
+[torch.DoubleTensor of dimension 2x1]
+
+> =unpack(nn.ElementTable(1):backward(input, torch.randn(2,3)))
+-0.4891 -0.3495 -0.3182
+-2.0999  0.7381 -0.5312
+[torch.DoubleTensor of dimension 2x3]
+
+0
+0
+[torch.DoubleTensor of dimension 2x1]
+
+```
+
+Example 2:
+```lua
+> input = {torch.randn(2,3), {torch.randn(2,1), {torch.randn(2,2)}}}
+
+> =nn.ElementTable(2):forward(input)
+{
+  1 : DoubleTensor - size: 2x1
+  2 :
+    {
+      1 : DoubleTensor - size: 2x2
+    }
+}
+
+> =unpack(nn.ElementTable(2):backward(input, {torch.randn(2,1), {torch.randn(2,2)}}))
+0 0 0
+0 0 0
+[torch.DoubleTensor of dimension 2x3]
+
+{
+  1 : DoubleTensor - size: 2x1
+  2 :
+    {
+      1 : DoubleTensor - size: 2x2
+    }
+}
+
+> gradInput = nn.ElementTable(1):backward(input, torch.randn(2,3))
+
+> =gradInput
+{
+  1 : DoubleTensor - size: 2x3
+  2 :
+    {
+      1 : DoubleTensor - size: 2x1
+      2 :
+        {
+          1 : DoubleTensor - size: 2x2
+        }
+    }
+}
+
+> =gradInput[1]
+-0.3400 -0.0404  1.1885
+ 1.2865  0.4107  0.6506
+[torch.DoubleTensor of dimension 2x3]
+
+> =gradInput[2][1]
+0
+0
+[torch.DoubleTensor of dimension 2x1]
+
+> =gradInput[2][2][1]
+0 0
+0 0
+[torch.DoubleTensor of dimension 2x2]
+
+```
 
 <a name="nn.PairwiseDistance"/>
 ## PairwiseDistance ##
diff --git a/init.lua b/init.lua
--- a/init.lua
+++ b/init.lua
@@ -86,6 +86,7 @@
 include('ParallelTable.lua')
 include('ConcatTable.lua')
 include('SplitTable.lua')
 include('JoinTable.lua')
+include('ElementTable.lua')
 include('CriterionTable.lua')
 include('Identity.lua')
diff --git a/test/test.lua b/test/test.lua
index 8b4cfb9..a3c816f 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -1879,6 +1879,49 @@ function nntest.SplitTable()
    end
 end
 
+function nntest.ElementTable()
+   local input = {
+      torch.rand(3,4,5), torch.rand(3,4,5),
+      {torch.rand(3,4,5)},
+      {torch.rand(3,4,5), {torch.rand(3,4,5)}}
+   }
+   local gradOutputs = {
+      torch.rand(3,4,5), torch.rand(3,4,5),
+      {torch.rand(3,4,5)},
+      {torch.rand(3,4,5), {torch.rand(3,4,5)}}
+   }
+   local zeros = {
+      torch.Tensor(3,4,5):zero(), torch.Tensor(3,4,5):zero(),
+      {torch.Tensor(3,4,5):zero()},
+      {torch.Tensor(3,4,5):zero(), {torch.Tensor(3,4,5):zero()}}
+   }
+   local function equal(t1, t2, msg)
+      if (torch.type(t1) == "table") then
+         for k, v in pairs(t2) do
+            equal(t1[k], t2[k], msg)
+         end
+      else
+         mytester:assertTensorEq(t1, t2, 0.00001, msg)
+      end
+   end
+   local nonIdx = {2,3,4,1}
+   local module
+   for idx = 1,#input do
+      module = nn.ElementTable(idx)
+      local output = module:forward(input)
+      equal(output, input[idx], "output dimension " .. idx)
+      local gradInput = module:backward(input, gradOutputs[idx])
+      equal(gradInput[idx], gradOutputs[idx], "gradInput[idx] dimension " .. idx)
+      equal(gradInput[nonIdx[idx]], zeros[nonIdx[idx]], "gradInput[nonIdx] dimension " .. idx)
+   end
+   module:float()
+   local idx = #input
+   local output = module:forward(input)
+   equal(output, input[idx], "type output")
+   local gradInput = module:backward(input, gradOutputs[idx])
+   equal(gradInput[idx], gradOutputs[idx], "gradInput[idx] dimension " .. idx)
+   equal(gradInput[nonIdx[idx]], zeros[nonIdx[idx]], "gradInput[nonIdx] dimension " .. idx)
+end
+
 function nntest.View()
    local input = torch.rand(10)
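
For readers skimming the patch, here is a minimal usage sketch (not part of the diff) of how the new module slots into an ordinary `nn` container once the change is merged. The network layout, tensor sizes, and variable names below are illustrative assumptions, not code from the repository:

```lua
require 'nn'

-- Hypothetical setup: the network receives a table {features, extras},
-- but only the first Tensor should reach the linear layer.
local net = nn.Sequential()
net:add(nn.ElementTable(1))   -- select element 1 of the input table
net:add(nn.Linear(10, 2))

local input = {torch.randn(4, 10), torch.randn(4, 3)}
local output = net:forward(input)   -- 4x2 Tensor computed from the selected element

-- backward routes the gradient to the selected element and returns
-- zeroed Tensors for every other element of the input table.
local gradInput = net:backward(input, torch.randn(4, 2))
print(gradInput[2]:sum())   -- prints 0: the non-selected element gets a zero gradient
```

As documented in the patch, only `gradInput[1]` carries a real gradient here; the zeroed Tensor in `gradInput[2]` is produced by the recursive `zeroTableCopy` inside `updateGradInput`.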