From 44ba0d3f08755214429eac1eff2c435ab2f14b57 Mon Sep 17 00:00:00 2001 From: Nicholas Leonard Date: Wed, 24 May 2017 23:04:38 -0400 Subject: nn.CAddTensorTable --- CAddTensorTable.lua | 43 +++++++++++++++++++++++++++++++++++++++++++ doc/table.md | 16 ++++++++++++++++ init.lua | 1 + test.lua | 19 +++++++++++++++++++ 4 files changed, 79 insertions(+) create mode 100644 CAddTensorTable.lua diff --git a/CAddTensorTable.lua b/CAddTensorTable.lua new file mode 100644 index 0000000..16efe44 --- /dev/null +++ b/CAddTensorTable.lua @@ -0,0 +1,43 @@ + +local CAddTensorTable, parent = torch.class('nn.CAddTensorTable', 'nn.Module') + +function CAddTensorTable:__init() + parent.__init(self) + self.gradInput = {} +end + +-- input is a table with 2 entries. input[1] is the vector to be added. +-- input[2] is the table to which we add the vector +function CAddTensorTable:updateOutput(input) + local currentOutput = {} + for i=1,#input[2] do + currentOutput[i] = currentOutput[i] or input[1].new() + currentOutput[i]:resizeAs(input[1]) + currentOutput[i]:copy(input[2][i]) + currentOutput[i]:add(input[1]) + end + for i = #input[2]+1, #currentOutput do + currentOutput[i] = nil + end + self.output = currentOutput + return self.output +end + +function CAddTensorTable:updateGradInput(input, gradOutput) + self.gradInput[1] = self.gradInput[1] or input[1].new() + self.gradInput[1]:resizeAs(input[1]) + self.gradInput[1]:copy(gradOutput[1]) + for i=2, #input[2] do + self.gradInput[1]:add(gradOutput[i]) + end + self.gradInput[2] = self.gradInput[2] or {} + for i=1,#input[2] do + self.gradInput[2][i] = self.gradInput[2][i] or input[1].new() + self.gradInput[2][i]:resizeAs(input[1]) + self.gradInput[2][i]:copy(gradOutput[i]) + end + for i=#input[2]+1, #self.gradInput[2] do + self.gradInput[2][i] = nil + end + return self.gradInput +end \ No newline at end of file diff --git a/doc/table.md b/doc/table.md index 1924ead..8734bf3 100644 --- a/doc/table.md +++ b/doc/table.md @@ -28,6 +28,7 @@ This allows one to build very rich architectures: * [`CDivTable`](#nn.CDivTable): division of input `Tensor`s; * [`CMaxTable`](#nn.CMaxTable): max of input `Tensor`s; * [`CMinTable`](#nn.CMinTable): min of input `Tensor`s; + * [`CAddTensorTable`](#nn.CAddTensorTable): adds a tensor to a table of tensors of the same size; * `Table` of Criteria: * [`CriterionTable`](#nn.CriterionTable): wraps a [Criterion](criterion.md#nn.Criterion) so that it can accept a `table` of inputs. @@ -1351,3 +1352,18 @@ m = nn.CMinTable() 1 [torch.DoubleTensor of size 3] ``` + + +## CAddTensorTable ## + +```lua +module = nn.CAddTensorTable() +``` + +Adds the first element `el` of the input table `tab` to each tensor contained in the second element of `tab`, which is itself a table + +Example: +```lua +print(module:forward{ (0,1,1), {(0,0,0),(1,1,1)} }) +{ (0,1,1), (1,2,2) } +``` diff --git a/init.lua b/init.lua index 503d2c2..97485f0 100755 --- a/init.lua +++ b/init.lua @@ -71,6 +71,7 @@ require('nn.CMulTable') require('nn.CSubTable') require('nn.CMaxTable') require('nn.CMinTable') +require('nn.CAddTensorTable') require('nn.Euclidean') require('nn.WeightedEuclidean') diff --git a/test.lua b/test.lua index 67b9fd9..44390ae 100755 --- a/test.lua +++ b/test.lua @@ -8649,6 +8649,25 @@ function nntest.Convert() mytester:assertTensorEq(output, output2:float(), 0.000001, "Convert:type() double->float non-batch") end +function nntest.CAddTensorTable() + -- input : { v, {a,b,c} } + -- output : { v+a, v+b, v+c } + local z = nn.CAddTensorTable() + local input = { torch.randn(3), { torch.randn(3), torch.rand(3), torch.rand(3) } } + local output = z:forward(input) + mytester:assert(#output == 3, "CAddTensorTable #output") + mytester:assertTensorEq(input[1]+input[2][1], output[1], 0.00001, "CAddTensorTable input21 output1") + mytester:assertTensorEq(input[1]+input[2][2], output[2], 0.00001, "CAddTensorTable input22 output2") + mytester:assertTensorEq(input[1]+input[2][3], output[3], 0.00001, "CAddTensorTable input23 output3") + local gradInput = z:backward(input, output) + mytester:assert(#gradInput == 2, "CAddTensorTable #gradInput") + mytester:assert(#(gradInput[2]) == 3, "CAddTensorTable #gradInput[2]") + mytester:assertTensorEq(output[1], gradInput[2][1], 0.000001, "CAddTensorTable gradInput21") + mytester:assertTensorEq(output[2], gradInput[2][2], 0.000001, "CAddTensorTable gradInput22") + mytester:assertTensorEq(output[3], gradInput[2][3], 0.000001, "CAddTensorTable gradInput23") + mytester:assertTensorEq(output[1]+output[2]+output[3], gradInput[1], 0.000001, "CAddTensorTable gradInput1") +end + mytester:add(nntest) -- cgit v1.2.3