Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNicholas Leonard <nleonard@twitter.com>2017-05-25 01:43:29 +0300
committerNicholas Leonard <nleonard@twitter.com>2017-05-25 01:46:21 +0300
commit78f9a498a6e5444eedc04fc670a2ab108ef1511d (patch)
treee176725122bfab7511190eed3adf720731d178ad
parente40e2816e23cebc85fd5733e716e903a2d02c175 (diff)
nn.ZipTable
-rw-r--r--ZipTable.lua34
-rw-r--r--ZipTableOneToMany.lua37
-rw-r--r--doc/table.md40
-rwxr-xr-xinit.lua2
-rwxr-xr-xtest.lua48
5 files changed, 157 insertions, 4 deletions
diff --git a/ZipTable.lua b/ZipTable.lua
new file mode 100644
index 0000000..7b18619
--- /dev/null
+++ b/ZipTable.lua
@@ -0,0 +1,34 @@
+local ZipTable, parent = torch.class('nn.ZipTable', 'nn.Module')
+
+-- input : { {a1,a2}, {b1,b2}, {c1,c2} }
+-- output : { {a1,b1,c1}, {a2,b2,c2} }
+-- Constructor: start with empty output/gradInput tables.
+-- Both tables are discarded and rebuilt on every forward/backward call.
+function ZipTable:__init()
+ parent.__init(self)
+ self.output = {}
+ self.gradInput = {}
+end
+
+-- Transpose a table of tables: output[j][i] = inputTable[i][j],
+-- e.g. { {a1,a2}, {b1,b2}, {c1,c2} } -> { {a1,b1,c1}, {a2,b2,c2} }.
+-- Inner tables are created lazily the first time index j is seen.
+function ZipTable:updateOutput(inputTable)
+ self.output = {}
+ for i,inTable in ipairs(inputTable) do
+ for j,input in ipairs(inTable) do
+ local output = self.output[j] or {}
+ output[i] = input
+ self.output[j] = output
+ end
+ end
+ return self.output
+end
+
+-- Inverse of updateOutput: gradInput[j][i] = gradOutputTable[i][j].
+-- Since the zip is a pure transpose, un-transposing the gradients
+-- restores the nesting structure of inputTable.
+function ZipTable:updateGradInput(inputTable, gradOutputTable)
+ self.gradInput = {}
+ for i,gradOutTable in ipairs(gradOutputTable) do
+ for j,gradOutput in ipairs(gradOutTable) do
+ local gradInput = self.gradInput[j] or {}
+ gradInput[i] = gradOutput
+ self.gradInput[j] = gradInput
+ end
+ end
+ return self.gradInput
+end
+
diff --git a/ZipTableOneToMany.lua b/ZipTableOneToMany.lua
new file mode 100644
index 0000000..d4a80fe
--- /dev/null
+++ b/ZipTableOneToMany.lua
@@ -0,0 +1,37 @@
+local ZipTableOneToMany, parent = torch.class('nn.ZipTableOneToMany', 'nn.Module')
+
+-- based on ZipTable in dpnn
+
+-- input : { v, {a, b, c} }
+-- output : { {v,a}, {v,b}, {v,c} }
+-- Constructor. gradInputEl is a reusable tensor buffer that accumulates
+-- the summed gradient w.r.t. the shared first input element (see
+-- updateGradInput); allocating it once avoids per-call allocations.
+function ZipTableOneToMany:__init()
+ parent.__init(self)
+ self.output = {}
+ self.gradInput = {}
+ -- make buffer to update during forward/backward
+ self.gradInputEl = torch.Tensor()
+end
+
+-- Pair the single element input[1] with each entry of the table input[2]:
+-- { v, {a, b, c} } -> { {v,a}, {v,b}, {v,c} }.
+-- Note: each output pair holds a reference to the same inputEl, not a copy.
+function ZipTableOneToMany:updateOutput(input)
+ assert(#input == 2, "input must be table of element and table")
+ local inputEl, inputTable = input[1], input[2]
+ self.output = {}
+ for i,v in ipairs(inputTable) do
+ self.output[i] = {inputEl, v}
+ end
+ return self.output
+end
+
+-- Gradients: the shared element appeared in every output pair, so its
+-- gradient is the sum of gradOutput[i][1] over all pairs (accumulated
+-- into the reused gradInputEl buffer, zeroed first); each table entry's
+-- gradient is passed through unchanged from gradOutput[i][2].
+function ZipTableOneToMany:updateGradInput(input, gradOutput)
+ assert(#input == 2, "input must be table of element and table")
+ local inputEl, inputTable = input[1], input[2]
+ self.gradInputEl:resizeAs(inputEl):zero()
+ local gradInputTable = {}
+ for i,gradV in ipairs(gradOutput) do
+ self.gradInputEl:add(gradV[1])
+ gradInputTable[i] = gradV[2]
+ end
+ self.gradInput = {self.gradInputEl, gradInputTable}
+ return self.gradInput
+end
+
diff --git a/doc/table.md b/doc/table.md
index b3e2e5f..1924ead 100644
--- a/doc/table.md
+++ b/doc/table.md
@@ -15,6 +15,8 @@ This allows one to build very rich architectures:
* [`SelectTable`](#nn.SelectTable): select one element from a `table`;
* [`NarrowTable`](#nn.NarrowTable): select a slice of elements from a `table`;
* [`FlattenTable`](#nn.FlattenTable): flattens a nested `table` hierarchy;
+ * [`ZipTable`](#nn.ZipTable) : zip a table of tables into a table of tables;
+ * [`ZipTableOneToMany`](#nn.ZipTableOneToMany) : zip a single element with each entry of a table, producing a table of pairs;
* Pair Modules compute a measure like distance or similarity from a pair (`table`) of input `Tensor`s:
* [`PairwiseDistance`](#nn.PairwiseDistance): outputs the `p`-norm. distance between inputs;
* [`DotProduct`](#nn.DotProduct): outputs the dot product (similarity) between inputs;
@@ -692,7 +694,7 @@ Forwarding a batch of 2 examples gives us something like this:
`module` = `SelectTable(index)`
-Creates a module that takes a (nested) `table` as input and outputs the element at index `index`. `index` can be strings or integers (positive or negative).
+Creates a module that takes a (nested) `table` as input and outputs the element at index `index`. `index` can be strings or integers (positive or negative).
This can be either a `table` or a [`Tensor`](https://github.com/torch/torch7/blob/master/doc/tensor.md#tensor).
The gradients of the non-`index` elements are zeroed `Tensor`s of the same size. This is true regardless of the
@@ -731,7 +733,7 @@ Exmaple 2:
> gradInput = nn.SelectTable("A"):backward(input, torch.randn(2, 3))
-> gradInput
+> gradInput
{
A : DoubleTensor - size: 2x3
B : DoubleTensor - size: 2x1
@@ -811,11 +813,11 @@ Example 3:
`module` = `NarrowTable(offset [, length])`
-Creates a module that takes a `table` as input and outputs the subtable
+Creates a module that takes a `table` as input and outputs the subtable
starting at index `offset` having `length` elements (defaults to 1 element).
The elements can be either a `table` or a [`Tensor`](https://github.com/torch/torch7/blob/master/doc/tensor.md#tensor).
-The gradients of the elements not included in the subtable are zeroed `Tensor`s of the same size.
+The gradients of the elements not included in the subtable are zeroed `Tensor`s of the same size.
This is true regardless of the depth of the encapsulated `Tensor` as the function used internally to do so is recursive.
Example:
@@ -883,6 +885,36 @@ gives the output:
}
```
+<a name='nn.ZipTable'></a>
+## ZipTable ##
+
+```lua
+module = nn.ZipTable()
+```
+
+Zips a table of tables into a table of tables.
+
+Example:
+```lua
+print(module:forward{ {'a1','a2'}, {'b1','b2'}, {'c1','c2'} })
+{ {'a1','b1','c1'}, {'a2','b2','c2'} }
+```
+
+<a name='nn.ZipTableOneToMany'></a>
+## ZipTableOneToMany ##
+
+```lua
+module = nn.ZipTableOneToMany()
+```
+
+Zips a table of an element `el` and a table of elements `tab` into a table of tables, where the i-th table contains the element `el` and the i-th element of table `tab`.
+
+Example:
+```lua
+print(module:forward{ 'el', {'a','b','c'} })
+{ {'el','a'}, {'el','b'}, {'el','c'} }
+```
+
<a name="nn.PairwiseDistance"></a>
## PairwiseDistance ##
diff --git a/init.lua b/init.lua
index b397d77..447d357 100755
--- a/init.lua
+++ b/init.lua
@@ -170,6 +170,8 @@ require('nn.CriterionTable')
require('nn.FlattenTable')
require('nn.NarrowTable')
require('nn.MapTable')
+require('nn.ZipTable')
+require('nn.ZipTableOneToMany')
require('nn.Criterion')
require('nn.MSECriterion')
diff --git a/test.lua b/test.lua
index 2dafb09..16bae09 100755
--- a/test.lua
+++ b/test.lua
@@ -8548,6 +8548,54 @@ function nntest.ZeroGrad()
mytester:assertTensorEq(gradInput, gradInput2, 0.0000001)
end
+-- Unit test for nn.ZipTable: forward must transpose a 3x2 table of
+-- tensors into a 2x3 table; backward with gradOutput = output transposes
+-- back, so gradInput must equal the original input element-for-element.
+function nntest.ZipTable()
+ -- input : { {a1,a2}, {b1,b2}, {c1,c2} }
+ -- output : { {a1,b1,c1}, {a2,b2,c2} }
+ local z = nn.ZipTable()
+ local input = {
+ {torch.randn(3,4), torch.randn(3,4)},
+ {torch.randn(3,4), torch.randn(3,4)},
+ {torch.randn(3,4), torch.randn(3,4)}
+ }
+ local output = z:forward(input)
+ mytester:assert(#output == 2, "ZipTable #output")
+ mytester:assert(#(output[1]) == 3, "ZipTable #output[1]")
+ mytester:assertTensorEq(input[1][1], output[1][1], 0.000001, "ZipTable input11")
+ mytester:assertTensorEq(input[1][2], output[2][1], 0.000001, "ZipTable input12")
+ mytester:assertTensorEq(input[3][2], output[2][3], 0.000001, "ZipTable input32")
+ -- backward un-transposes, so gradInput mirrors the input structure
+ local gradInput = z:backward(input, output)
+ mytester:assert(#gradInput == 3, "ZipTable #gradInput")
+ mytester:assert(#(gradInput[1]) == 2, "ZipTable #gradInput[1]")
+ mytester:assertTensorEq(input[1][1], gradInput[1][1], 0.000001, "ZipTable gradInput11")
+ mytester:assertTensorEq(input[1][2], gradInput[1][2], 0.000001, "ZipTable gradInput12")
+ mytester:assertTensorEq(input[3][2], gradInput[3][2], 0.000001, "ZipTable gradInput32")
+end
+
+-- Unit test for nn.ZipTableOneToMany: forward pairs the shared element
+-- with each of the 3 table entries; backward passes table-entry gradients
+-- through and sums the 3 pair-gradients of the shared element.
+function nntest.ZipTableOneToMany()
+ -- input : { v, {a,b,c} }
+ -- output : { {v,a}, {v,b}, {v,c} }
+ local z = nn.ZipTableOneToMany()
+ local input = { torch.randn(3), { torch.randn(4), torch.rand(4), torch.rand(4) } }
+ local output = z:forward(input)
+ mytester:assert(#output == 3, "ZipTableOneToMany #output")
+ mytester:assert(#(output[1]) == 2, "ZipTableOneToMany #output[1]")
+ mytester:assert(#(output[2]) == 2, "ZipTableOneToMany #output[2]")
+ mytester:assert(#(output[3]) == 2, "ZipTableOneToMany #output[3]")
+ mytester:assertTensorEq(input[1], output[1][1], 0.000001, "ZipTableOneToMany input1 output11")
+ mytester:assertTensorEq(input[1], output[2][1], 0.000001, "ZipTableOneToMany input1 output21")
+ mytester:assertTensorEq(input[1], output[3][1], 0.000001, "ZipTableOneToMany input1 output31")
+ mytester:assertTensorEq(input[2][1], output[1][2], 0.000001, "ZipTableOneToMany input21")
+ mytester:assertTensorEq(input[2][2], output[2][2], 0.000001, "ZipTableOneToMany input22")
+ mytester:assertTensorEq(input[2][3], output[3][2], 0.000001, "ZipTableOneToMany input23")
+ local gradInput = z:backward(input, output)
+ mytester:assert(#gradInput == 2, "ZipTableOneToMany #gradInput")
+ mytester:assert(#(gradInput[2]) == 3, "ZipTableOneToMany #gradInput[2]")
+ mytester:assertTensorEq(input[2][1], gradInput[2][1], 0.000001, "ZipTableOneToMany gradInput21")
+ mytester:assertTensorEq(input[2][2], gradInput[2][2], 0.000001, "ZipTableOneToMany gradInput22")
+ -- label fixed: this checks gradInput[2][3], not [3][2]
+ mytester:assertTensorEq(input[2][3], gradInput[2][3], 0.000001, "ZipTableOneToMany gradInput23")
+ -- shared element receives the sum of its 3 pair-gradients (3 * input[1])
+ mytester:assertTensorEq(torch.mul(input[1], 3), gradInput[1], 0.000001, "ZipTableOneToMany gradInput1")
+end
+
mytester:add(nntest)