-rw-r--r-- | NarrowTable.lua | 41
-rwxr-xr-x | doc/table.md    | 41
-rw-r--r-- | init.lua        |  1
-rw-r--r-- | test.lua        | 39
-rw-r--r-- | utils.lua       | 14
5 files changed, 136 insertions, 0 deletions
diff --git a/NarrowTable.lua b/NarrowTable.lua
new file mode 100644
index 0000000..14c90ca
--- /dev/null
+++ b/NarrowTable.lua
@@ -0,0 +1,41 @@
+local NarrowTable, parent = torch.class('nn.NarrowTable', 'nn.Module')
+
+function NarrowTable:__init(offset, length)
+   parent.__init(self)
+   self.offset = offset
+   self.length = length or 1
+   if not offset then
+      error('nn.NarrowTable(offset, length)')
+   end
+
+   self.output = {}
+   self.gradInput = {}
+end
+
+function NarrowTable:updateOutput(input)
+   for k,v in ipairs(self.output) do self.output[k] = nil end
+   for i=1,self.length do
+      self.output[i] = input[self.offset+i-1]
+   end
+   return self.output
+end
+
+function NarrowTable:updateGradInput(input, gradOutput)
+   for i=1,#gradOutput do
+      self.gradInput[self.offset+i-1] = gradOutput[i]
+   end
+   for i=1,#input do
+      if (i < self.offset) or (i >= self.offset + self.length) then
+         self.gradInput[i] = nn.utils.recursiveResizeAs(self.gradInput[i], input[i])
+         nn.utils.recursiveFill(self.gradInput[i], 0)
+      end
+   end
+   for i=#input+1,#self.gradInput do self.gradInput[i] = nil end
+   return self.gradInput
+end
+
+function NarrowTable:type(type, tensorCache)
+   self.output = {}
+   self.gradInput = {}
+   return parent.type(self, type, tensorCache)
+end
diff --git a/doc/table.md b/doc/table.md
index c2aeb83..95ac2b6 100755
--- a/doc/table.md
+++ b/doc/table.md
@@ -11,6 +11,7 @@ This allows one to build very rich architectures:
   * [`JoinTable`](#nn.JoinTable): joins a `table` of `Tensor`s into a `Tensor`;
   * [`MixtureTable`](#nn.MixtureTable): mixture of experts weighted by a gater;
   * [`SelectTable`](#nn.SelectTable): select one element from a `table`;
+  * [`NarrowTable`](#nn.NarrowTable): select a slice of elements from a `table`;
   * [`FlattenTable`](#nn.FlattenTable): flattens a nested `table` hierarchy;
 * Pair Modules compute a measure like distance or similarity from a pair (`table`) of input `Tensor`s:
   * [`PairwiseDistance`](#nn.PairwiseDistance): outputs the `p`-norm distance between inputs;
@@ -724,6 +725,46 @@ Example 2:
 ```
 
+<a name="nn.NarrowTable"/>
+## NarrowTable ##
+
+`module` = `NarrowTable(offset [, length])`
+
+Creates a module that takes a `table` as input and outputs the subtable
+starting at index `offset` and having `length` elements (defaults to 1 element).
+The elements can be either a `table` or a [`Tensor`](https://github.com/torch/torch7/blob/master/doc/tensor.md#tensor).
+
+The gradients of the elements not included in the subtable are zeroed `Tensor`s of the same size.
+This is true regardless of the depth of the encapsulated `Tensor`, as the function used internally to do so is recursive.
+
+Example:
+```lua
+> input = {torch.randn(2, 3), torch.randn(2, 1), torch.randn(1, 2)}
+> =nn.NarrowTable(2,2):forward(input)
+{
+  1 : DoubleTensor - size: 2x1
+  2 : DoubleTensor - size: 1x2
+}
+
+> =nn.NarrowTable(1):forward(input)
+{
+  1 : DoubleTensor - size: 2x3
+}
+
+> =table.unpack(nn.NarrowTable(1,2):backward(input, {torch.randn(2, 3), torch.randn(2, 1)}))
+ 1.9528 -0.1381  0.2023
+ 0.2297 -1.5169 -1.1871
+[torch.DoubleTensor of size 2x3]
+
+-1.2023
+-0.4165
+[torch.DoubleTensor of size 2x1]
+
+ 0 0
+[torch.DoubleTensor of size 1x2]
+
+```
+
 <a name="nn.FlattenTable"/>
 ## FlattenTable ##
diff --git a/init.lua b/init.lua
--- a/init.lua
+++ b/init.lua
@@ -102,6 +102,7 @@ include('SelectTable.lua')
 include('MixtureTable.lua')
 include('CriterionTable.lua')
 include('FlattenTable.lua')
+include('NarrowTable.lua')
 include('Identity.lua')
 include('Criterion.lua')
diff --git a/test.lua b/test.lua
--- a/test.lua
+++ b/test.lua
@@ -3116,6 +3116,45 @@ function nntest.MixtureTable()
    end
 end
 
+function nntest.NarrowTable()
+   local input = torch.randn(3,10,4)
+   local gradOutput = torch.randn(3,3,4)
+   local nt = nn.NarrowTable(5,3)
+   local seq = nn.Sequential()
+   seq:add(nn.SplitTable(1,2))
+   seq:add(nt)
+   seq:add(nn.JoinTable(1,1))
+   seq:add(nn.Reshape(3,3,4))
+   local seq2 = nn.Narrow(2,5,3)
+   local output = seq:forward(input)
+   local gradInput = seq:backward(input, gradOutput)
+   local output2 = seq2:forward(input)
+   local gradInput2 = seq2:backward(input, gradOutput)
+   mytester:assertTensorEq(output, output2, 0.0000001, "NarrowTable output err")
+   mytester:assertTensorEq(gradInput, gradInput2, 0.00001, "NarrowTable gradInput err")
+
+   -- now try it with a smaller input
+   local input = input:narrow(2, 1, 8)
+   local output = seq:forward(input)
+   local gradInput = seq:backward(input, gradOutput)
+   local output2 = seq2:forward(input)
+   local gradInput2 = seq2:backward(input, gradOutput)
+   mytester:assertTensorEq(output, output2, 0.0000001, "NarrowTable small output err")
+   mytester:assertTensorEq(gradInput, gradInput2, 0.00001, "NarrowTable small gradInput err")
+
+   -- test type-cast
+   local input = input:float()
+   local gradOutput = gradOutput:float()
+   seq:float()
+   seq2:float()
+   local output = seq:forward(input)
+   local gradInput = seq:backward(input, gradOutput)
+   local output2 = seq2:forward(input)
+   local gradInput2 = seq2:backward(input, gradOutput)
+   mytester:assertTensorEq(output, output2, 0.0000001, "NarrowTable output float err")
+   mytester:assertTensorEq(gradInput, gradInput2, 0.00001, "NarrowTable gradInput float err")
+end
+
 function nntest.View()
    local input = torch.rand(10)
    local template = torch.rand(5,2)
diff --git a/utils.lua b/utils.lua
--- a/utils.lua
+++ b/utils.lua
@@ -30,5 +30,19 @@ function nn.utils.recursiveResizeAs(t1,t2)
    return t1, t2
 end
 
+function nn.utils.recursiveFill(t2, val)
+   if torch.type(t2) == 'table' then
+      for key,_ in pairs(t2) do
+         t2[key] = nn.utils.recursiveFill(t2[key], val)
+      end
+   elseif torch.isTensor(t2) then
+      t2:fill(val)
+   else
+      error("expecting tensor or table thereof. Got "
+         ..torch.type(t2).." instead")
+   end
+   return t2
+end
+
 table.unpack = table.unpack or unpack
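For anyone trying out this patch, the following is a minimal usage sketch (not part of the commit) of how the new `nn.NarrowTable` module behaves once `NarrowTable.lua` is included. It assumes a Torch/nn install with this change applied; the tensor sizes are arbitrary.

```lua
require 'nn'

-- Select a window of 2 elements starting at index 2 of the input table.
local narrow = nn.NarrowTable(2, 2)
local input = {torch.randn(2, 3), torch.randn(2, 1), torch.randn(1, 2)}

local output = narrow:forward(input)
print(#output)             -- 2: the output table is {input[2], input[3]}

-- backward takes one gradOutput per selected element; elements outside the
-- window receive zeroed gradients of the same size (filled with the new
-- nn.utils.recursiveFill helper added in utils.lua).
local gradOutput = {torch.randn(2, 1), torch.randn(1, 2)}
local gradInput = narrow:backward(input, gradOutput)
print(gradInput[1]:sum())  -- 0: input[1] is outside the selected slice
```

The same behaviour is exercised end-to-end by the new `nntest.NarrowTable` test above, which checks the module (composed with `SplitTable`/`JoinTable`) against an equivalent `nn.Narrow` applied directly to a tensor. The `nn.utils.recursiveFill(t, val)` helper itself fills a tensor, or an arbitrarily nested table of tensors, with `val`.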