Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNicholas Leonard <nleonard@twitter.com>2017-05-25 06:04:38 +0300
committerNicholas Leonard <nleonard@twitter.com>2017-05-25 06:04:38 +0300
commit44ba0d3f08755214429eac1eff2c435ab2f14b57 (patch)
tree50349ec4197fe557296d6507df9c80c3843619d5
parentdf1af9500a45f4deecd0f3f1f5020fe4789248ca (diff)
nn.CAddTensorTable
-rw-r--r--CAddTensorTable.lua43
-rw-r--r--doc/table.md16
-rwxr-xr-xinit.lua1
-rwxr-xr-xtest.lua19
4 files changed, 79 insertions, 0 deletions
diff --git a/CAddTensorTable.lua b/CAddTensorTable.lua
new file mode 100644
index 0000000..16efe44
--- /dev/null
+++ b/CAddTensorTable.lua
@@ -0,0 +1,43 @@
+
+local CAddTensorTable, parent = torch.class('nn.CAddTensorTable', 'nn.Module') -- declare nn.CAddTensorTable, subclassing nn.Module
+
+function CAddTensorTable:__init()
+ parent.__init(self)
+ self.gradInput = {} -- gradInput is a table (mirrors the table input), not a tensor
+end
+
+-- input is a table with 2 entries. input[1] is the vector to be added.
+-- input[2] is the table to which we add the vector
+function CAddTensorTable:updateOutput(input)
+ local currentOutput = {}
+ for i=1,#input[2] do
+ currentOutput[i] = currentOutput[i] or input[1].new() -- allocate an output tensor of input[1]'s type; NOTE(review): currentOutput is a fresh local, so the `or` guard never reuses storage
+ currentOutput[i]:resizeAs(input[1]) -- assumes every input[2][i] has the same size as input[1] -- not checked here
+ currentOutput[i]:copy(input[2][i])
+ currentOutput[i]:add(input[1]) -- output[i] = input[2][i] + input[1]
+ end
+ for i = #input[2]+1, #currentOutput do
+ currentOutput[i] = nil -- NOTE(review): dead code -- currentOutput starts empty each call, so it can never be longer than input[2]
+ end
+ self.output = currentOutput
+ return self.output
+end
+
+function CAddTensorTable:updateGradInput(input, gradOutput)
+ self.gradInput[1] = self.gradInput[1] or input[1].new() -- reuse the gradient buffer across calls when possible
+ self.gradInput[1]:resizeAs(input[1])
+ self.gradInput[1]:copy(gradOutput[1]) -- grad wrt the shared vector input[1] is the sum of all gradOutput entries
+ for i=2, #input[2] do
+ self.gradInput[1]:add(gradOutput[i])
+ end
+ self.gradInput[2] = self.gradInput[2] or {}
+ for i=1,#input[2] do
+ self.gradInput[2][i] = self.gradInput[2][i] or input[1].new()
+ self.gradInput[2][i]:resizeAs(input[1])
+ self.gradInput[2][i]:copy(gradOutput[i]) -- addition is identity wrt each input[2][i], so gradients pass through unchanged
+ end
+ for i=#input[2]+1, #self.gradInput[2] do
+ self.gradInput[2][i] = nil -- drop stale entries if the input table shrank since the last call
+ end
+ return self.gradInput
+end \ No newline at end of file
diff --git a/doc/table.md b/doc/table.md
index 1924ead..8734bf3 100644
--- a/doc/table.md
+++ b/doc/table.md
@@ -28,6 +28,7 @@ This allows one to build very rich architectures:
* [`CDivTable`](#nn.CDivTable): division of input `Tensor`s;
* [`CMaxTable`](#nn.CMaxTable): max of input `Tensor`s;
* [`CMinTable`](#nn.CMinTable): min of input `Tensor`s;
+ * [`CAddTensorTable`](#nn.CAddTensorTable): adds a tensor to a table of tensors of the same size;
* `Table` of Criteria:
* [`CriterionTable`](#nn.CriterionTable): wraps a [Criterion](criterion.md#nn.Criterion) so that it can accept a `table` of inputs.
@@ -1351,3 +1352,18 @@ m = nn.CMinTable()
1
[torch.DoubleTensor of size 3]
```
+
+<a name='nn.CAddTensorTable'></a>
+## CAddTensorTable ##
+
+```lua
+module = nn.CAddTensorTable()
+```
+
+Adds the first element `el` of the input table `tab` to each tensor contained in the second element of `tab`, which is itself a table of tensors of the same size as `el`.
+
+Example:
+```lua
+print(module:forward{ torch.Tensor{0,1,1}, { torch.Tensor{0,0,0}, torch.Tensor{1,1,1} } })
+{ torch.Tensor{0,1,1}, torch.Tensor{1,2,2} }
+```
diff --git a/init.lua b/init.lua
index 503d2c2..97485f0 100755
--- a/init.lua
+++ b/init.lua
@@ -71,6 +71,7 @@ require('nn.CMulTable')
require('nn.CSubTable')
require('nn.CMaxTable')
require('nn.CMinTable')
+require('nn.CAddTensorTable')
require('nn.Euclidean')
require('nn.WeightedEuclidean')
diff --git a/test.lua b/test.lua
index 67b9fd9..44390ae 100755
--- a/test.lua
+++ b/test.lua
@@ -8649,6 +8649,25 @@ function nntest.Convert()
mytester:assertTensorEq(output, output2:float(), 0.000001, "Convert:type() double->float non-batch")
end
+function nntest.CAddTensorTable()
+ -- input : { v, {a,b,c} }
+ -- output : { v+a, v+b, v+c }
+ local z = nn.CAddTensorTable()
+ local input = { torch.randn(3), { torch.randn(3), torch.rand(3), torch.rand(3) } }
+ local output = z:forward(input)
+ mytester:assert(#output == 3, "CAddTensorTable #output") -- one output tensor per entry of input[2]
+ mytester:assertTensorEq(input[1]+input[2][1], output[1], 0.00001, "CAddTensorTable input21 output1")
+ mytester:assertTensorEq(input[1]+input[2][2], output[2], 0.00001, "CAddTensorTable input22 output2")
+ mytester:assertTensorEq(input[1]+input[2][3], output[3], 0.00001, "CAddTensorTable input23 output3")
+ local gradInput = z:backward(input, output) -- use the forward output itself as gradOutput
+ mytester:assert(#gradInput == 2, "CAddTensorTable #gradInput") -- gradInput mirrors the { vector, table } input structure
+ mytester:assert(#(gradInput[2]) == 3, "CAddTensorTable #gradInput[2]")
+ mytester:assertTensorEq(output[1], gradInput[2][1], 0.000001, "CAddTensorTable gradInput21") -- gradients pass through to each table entry
+ mytester:assertTensorEq(output[2], gradInput[2][2], 0.000001, "CAddTensorTable gradInput22")
+ mytester:assertTensorEq(output[3], gradInput[2][3], 0.000001, "CAddTensorTable gradInput23")
+ mytester:assertTensorEq(output[1]+output[2]+output[3], gradInput[1], 0.000001, "CAddTensorTable gradInput1") -- grad wrt the shared vector is the sum over gradOutput
+end
+
mytester:add(nntest)