Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNicholas Leonard <nick@nikopia.org>2015-07-15 19:28:48 +0300
committerNicholas Leonard <nick@nikopia.org>2015-07-15 21:24:04 +0300
commit59e3f8cb4b6571fd46b0ba66e57626f725fbe81f (patch)
treee0478aa9067338664b40a734f08c13bc23c49f7c
parent514a093f5e9a76a04f3252d58259673ef4ff71bb (diff)
NarrowTable
-rw-r--r--NarrowTable.lua41
-rwxr-xr-xdoc/table.md41
-rw-r--r--init.lua1
-rw-r--r--test.lua39
-rw-r--r--utils.lua14
5 files changed, 136 insertions, 0 deletions
diff --git a/NarrowTable.lua b/NarrowTable.lua
new file mode 100644
index 0000000..14c90ca
--- /dev/null
+++ b/NarrowTable.lua
@@ -0,0 +1,41 @@
+local NarrowTable, parent = torch.class('nn.NarrowTable', 'nn.Module')
+
+function NarrowTable:__init(offset, length)
+ parent.__init(self)
+ self.offset = offset
+ self.length = length or 1
+ if not offset then
+ error('nn.NarrowTable(offset, length)')
+ end
+
+ self.output = {}
+ self.gradInput = {}
+end
+
+function NarrowTable:updateOutput(input)
+ for k,v in ipairs(self.output) do self.output[k] = nil end
+ for i=1,self.length do
+ self.output[i] = input[self.offset+i-1]
+ end
+ return self.output
+end
+
+function NarrowTable:updateGradInput(input, gradOutput)
+ for i=1,#gradOutput do
+ self.gradInput[self.offset+i-1] = gradOutput[i]
+ end
+ for i=1,#input do
+ if (i < self.offset) or (i >= self.offset + self.length) then
+ self.gradInput[i] = nn.utils.recursiveResizeAs(self.gradInput[i], input[i])
+ nn.utils.recursiveFill(self.gradInput[i], 0)
+ end
+ end
+ for i=#input+1,#self.gradInput do self.gradInput[i] = nil end
+ return self.gradInput
+end
+
+function NarrowTable:type(type, tensorCache)
+ self.output = {}
+ self.gradInput = {}
+ return parent.type(self, type, tensorCache)
+end
diff --git a/doc/table.md b/doc/table.md
index c2aeb83..95ac2b6 100755
--- a/doc/table.md
+++ b/doc/table.md
@@ -11,6 +11,7 @@ This allows one to build very rich architectures:
* [`JoinTable`](#nn.JoinTable): joins a `table` of `Tensor`s into a `Tensor`;
* [`MixtureTable`](#nn.MixtureTable): mixture of experts weighted by a gater;
* [`SelectTable`](#nn.SelectTable): select one element from a `table`;
+ * [`NarrowTable`](#nn.NarrowTable): select a slice of elements from a `table`;
* [`FlattenTable`](#nn.FlattenTable): flattens a nested `table` hierarchy;
* Pair Modules compute a measure like distance or similarity from a pair (`table`) of input `Tensor`s:
* [`PairwiseDistance`](#nn.PairwiseDistance): outputs the `p`-norm. distance between inputs;
@@ -724,6 +725,46 @@ Example 2:
```
+<a name="nn.NarrowTable"/>
+## NarrowTable ##
+
+`module` = `NarrowTable(offset [, length])`
+
+Creates a module that takes a `table` as input and outputs the subtable
+starting at index `offset` having `length` elements (defaults to 1 element).
+The elements can be either a `table` or a [`Tensor`](https://github.com/torch/torch7/blob/master/doc/tensor.md#tensor).
+
+The gradients of the elements not included in the subtable are zeroed `Tensor`s of the same size.
+This is true regardless of the dept of the encapsulated `Tensor` as the function used internally to do so is recursive.
+
+Example:
+```lua
+> input = {torch.randn(2, 3), torch.randn(2, 1), torch.randn(1, 2)}
+> =nn.NarrowTable(2,2):forward(input)
+{
+ 1 : DoubleTensor - size: 2x1
+ 2 : DoubleTensor - size: 1x2
+}
+
+> =nn.NarrowTable(1):forward(input)
+{
+ 1 : DoubleTensor - size: 2x3
+}
+
+> =table.unpack(nn.NarrowTable(1,2):backward(input, {torch.randn(2, 3), torch.randn(2, 1)}))
+ 1.9528 -0.1381 0.2023
+ 0.2297 -1.5169 -1.1871
+[torch.DoubleTensor of size 2x3]
+
+-1.2023
+-0.4165
+[torch.DoubleTensor of size 2x1]
+
+ 0 0
+[torch.DoubleTensor of size 1x2]
+
+```
+
<a name="nn.FlattenTable"/>
## FlattenTable ##
diff --git a/init.lua b/init.lua
index e6c5827..3659d4d 100644
--- a/init.lua
+++ b/init.lua
@@ -102,6 +102,7 @@ include('SelectTable.lua')
include('MixtureTable.lua')
include('CriterionTable.lua')
include('FlattenTable.lua')
+include('NarrowTable.lua')
include('Identity.lua')
include('Criterion.lua')
diff --git a/test.lua b/test.lua
index bfece0e..92b686f 100644
--- a/test.lua
+++ b/test.lua
@@ -3116,6 +3116,45 @@ function nntest.MixtureTable()
end
end
+function nntest.NarrowTable()
+ local input = torch.randn(3,10,4)
+ local gradOutput = torch.randn(3,3,4)
+ local nt = nn.NarrowTable(5,3)
+ local seq = nn.Sequential()
+ seq:add(nn.SplitTable(1,2))
+ seq:add(nt)
+ seq:add(nn.JoinTable(1,1))
+ seq:add(nn.Reshape(3,3,4))
+ local seq2 = nn.Narrow(2,5,3)
+ local output = seq:forward(input)
+ local gradInput = seq:backward(input, gradOutput)
+ local output2 = seq2:forward(input)
+ local gradInput2 = seq2:backward(input, gradOutput)
+ mytester:assertTensorEq(output, output2, 0.0000001, "NarrowTable output err")
+ mytester:assertTensorEq(gradInput, gradInput2, 0.00001, "NarrowTable gradInput err")
+
+ -- now try it with a smaller input
+ local input = input:narrow(2, 1, 8)
+ local output = seq:forward(input)
+ local gradInput = seq:backward(input, gradOutput)
+ local output2 = seq2:forward(input)
+ local gradInput2 = seq2:backward(input, gradOutput)
+ mytester:assertTensorEq(output, output2, 0.0000001, "NarrowTable small output err")
+ mytester:assertTensorEq(gradInput, gradInput2, 0.00001, "NarrowTable small gradInput err")
+
+ -- test type-cast
+ local input = input:float()
+ local gradOutput = gradOutput:float()
+ seq:float()
+ seq2:float()
+ local output = seq:forward(input)
+ local gradInput = seq:backward(input, gradOutput)
+ local output2 = seq2:forward(input)
+ local gradInput2 = seq2:backward(input, gradOutput)
+ mytester:assertTensorEq(output, output2, 0.0000001, "NarrowTable output float err")
+ mytester:assertTensorEq(gradInput, gradInput2, 0.00001, "NarrowTable gradInput float err")
+end
+
function nntest.View()
local input = torch.rand(10)
local template = torch.rand(5,2)
diff --git a/utils.lua b/utils.lua
index a2bb46b..90ec50d 100644
--- a/utils.lua
+++ b/utils.lua
@@ -30,5 +30,19 @@ function nn.utils.recursiveResizeAs(t1,t2)
return t1, t2
end
+function nn.utils.recursiveFill(t2, val)
+ if torch.type(t2) == 'table' then
+ for key,_ in pairs(t2) do
+ t2[key] = nn.utils.recursiveFill(t2[key], val)
+ end
+ elseif torch.isTensor(t2) then
+ t2:fill(val)
+ else
+ error("expecting tensor or table thereof. Got "
+ ..torch.type(t2).." instead")
+ end
+ return t2
+end
+
table.unpack = table.unpack or unpack