diff options
-rw-r--r-- | SelectTable.lua (renamed from ElementTable.lua) | 10 | ||||
-rw-r--r-- | doc/table.md | 20 | ||||
-rw-r--r-- | init.lua | 2 | ||||
-rw-r--r-- | test/test.lua | 4 |
4 files changed, 18 insertions, 18 deletions
diff --git a/ElementTable.lua b/SelectTable.lua index cb3ff0f..217be42 100644 --- a/ElementTable.lua +++ b/SelectTable.lua @@ -1,17 +1,17 @@ -local ElementTable, parent = torch.class('nn.ElementTable', 'nn.Module') +local SelectTable, parent = torch.class('nn.SelectTable', 'nn.Module') -function ElementTable:__init(index) +function SelectTable:__init(index) parent.__init(self) self.index = index self.gradInput = {} end -function ElementTable:updateOutput(input) +function SelectTable:updateOutput(input) self.output = input[self.index] return self.output end -function ElementTable:updateGradInput(input, gradOutput) +function SelectTable:updateGradInput(input, gradOutput) if #self.gradInput == 0 then local function zeroTableCopy(t1, t2) for k, v in pairs(t2) do @@ -29,6 +29,6 @@ function ElementTable:updateGradInput(input, gradOutput) return self.gradInput end -function ElementTable:type(type) +function SelectTable:type(type) self.gradInput = {} end diff --git a/doc/table.md b/doc/table.md index 60b6dea..5a34a6a 100644 --- a/doc/table.md +++ b/doc/table.md @@ -9,7 +9,7 @@ This allows one to build very rich architectures: * Table Conversion Modules convert between tables and Tensors: * [SplitTable](#nn.SplitTable) : splits a Tensor into a table of Tensors; * [JoinTable](#nn.JoinTable) : joins a table of Tensors into a Tensor; - * [ElementTable](#nn.ElementTable) : retrieve one element from a table; + * [SelectTable](#nn.SelectTable) : retrieve one element from a table; * Pair Modules compute a measure like distance or similarity from a pair (table) of input Tensors : * [PairwiseDistance](#nn.PairwiseDistance) : outputs the `p`-norm. distance between inputs; * [DotProduct](#nn.DotProduct) : outputs the dot product (similarity) between inputs; @@ -376,10 +376,10 @@ for i=1,100 do -- A few steps of training such a network.. end ``` -<a name="nn.ElementTable"/> -## ElementTable ## +<a name="nn.SelectTable"/> +## SelectTable ## -`module` = `ElementTable(index)` +`module` = `SelectTable(index)` Creates a module that takes a Table as input and outputs the element at index `index`. This can be either a Table or a [Tensor](https://github.com/torch/torch7/blob/master/doc/tensor.md#tensor). @@ -391,18 +391,18 @@ Example 1: ```lua > input = {torch.randn(2,3), torch.randn(2,1)} [0.0002s] -> =nn.ElementTable(1):forward(input) +> =nn.SelectTable(1):forward(input) -0.3060 0.1398 0.2707 0.0576 1.5455 0.0610 [torch.DoubleTensor of dimension 2x3] [0.0002s] -> =nn.ElementTable(2):forward(input) +> =nn.SelectTable(2):forward(input) 2.3080 -0.2955 [torch.DoubleTensor of dimension 2x1] -> =unpack(nn.ElementTable(1):backward(input, torch.randn(2,3))) +> =unpack(nn.SelectTable(1):backward(input, torch.randn(2,3))) -0.4891 -0.3495 -0.3182 -2.0999 0.7381 -0.5312 [torch.DoubleTensor of dimension 2x3] @@ -417,7 +417,7 @@ Example 2: ```lua > input = {torch.randn(2,3), {torch.randn(2,1), {torch.randn(2,2)}}} -> =nn.ElementTable(2):forward(input) +> =nn.SelectTable(2):forward(input) { 1 : DoubleTensor - size: 2x1 2 : @@ -426,7 +426,7 @@ Example 2: } } -> =unpack(nn.ElementTable(2):backward(input, {torch.randn(2,1), {torch.randn(2,2)}})) +> =unpack(nn.SelectTable(2):backward(input, {torch.randn(2,1), {torch.randn(2,2)}})) 0 0 0 0 0 0 [torch.DoubleTensor of dimension 2x3] @@ -439,7 +439,7 @@ Example 2: } } -> gradInput = nn.ElementTable(1):backward(input, torch.randn(2,3)) +> gradInput = nn.SelectTable(1):backward(input, torch.randn(2,3)) > =gradInput { @@ -87,7 +87,7 @@ include('ParallelTable.lua') include('ConcatTable.lua') include('SplitTable.lua') include('JoinTable.lua') -include('ElementTable.lua') +include('SelectTable.lua') include('CriterionTable.lua') include('Identity.lua') diff --git a/test/test.lua b/test/test.lua index 73426fb..135624d 100644 --- a/test/test.lua +++ b/test/test.lua @@ -1899,7 +1899,7 @@ function nntest.SplitTable() end end -function nntest.ElementTable() +function nntest.SelectTable() local input = { torch.rand(3,4,5), torch.rand(3,4,5), {torch.rand(3,4,5)}, @@ -1927,7 +1927,7 @@ function nntest.ElementTable() local nonIdx = {2,3,4,1} local module for idx = 1,#input do - module = nn.ElementTable(idx) + module = nn.SelectTable(idx) local output = module:forward(input) equal(output, input[idx], "output dimension " .. idx) local gradInput = module:backward(input, gradOutputs[idx]) |