diff options
author | Nicholas Leonard <nleonard@twitter.com> | 2017-05-25 06:04:38 +0300 |
---|---|---|
committer | Nicholas Leonard <nleonard@twitter.com> | 2017-05-25 06:04:38 +0300 |
commit | 44ba0d3f08755214429eac1eff2c435ab2f14b57 (patch) | |
tree | 50349ec4197fe557296d6507df9c80c3843619d5 | |
parent | df1af9500a45f4deecd0f3f1f5020fe4789248ca (diff) |
nn.CAddTensorTable
-rw-r--r-- | CAddTensorTable.lua | 43 | ||||
-rw-r--r-- | doc/table.md | 16 | ||||
-rwxr-xr-x | init.lua | 1 | ||||
-rwxr-xr-x | test.lua | 19 |
4 files changed, 79 insertions, 0 deletions
diff --git a/CAddTensorTable.lua b/CAddTensorTable.lua new file mode 100644 index 0000000..16efe44 --- /dev/null +++ b/CAddTensorTable.lua @@ -0,0 +1,43 @@ + +local CAddTensorTable, parent = torch.class('nn.CAddTensorTable', 'nn.Module') + +function CAddTensorTable:__init() + parent.__init(self) + self.gradInput = {} +end + +-- input is a table with 2 entries. input[1] is the vector to be added. +-- input[2] is the table to which we add the vector +function CAddTensorTable:updateOutput(input) + local currentOutput = {} + for i=1,#input[2] do + currentOutput[i] = currentOutput[i] or input[1].new() + currentOutput[i]:resizeAs(input[1]) + currentOutput[i]:copy(input[2][i]) + currentOutput[i]:add(input[1]) + end + for i = #input[2]+1, #currentOutput do + currentOutput[i] = nil + end + self.output = currentOutput + return self.output +end + +function CAddTensorTable:updateGradInput(input, gradOutput) + self.gradInput[1] = self.gradInput[1] or input[1].new() + self.gradInput[1]:resizeAs(input[1]) + self.gradInput[1]:copy(gradOutput[1]) + for i=2, #input[2] do + self.gradInput[1]:add(gradOutput[i]) + end + self.gradInput[2] = self.gradInput[2] or {} + for i=1,#input[2] do + self.gradInput[2][i] = self.gradInput[2][i] or input[1].new() + self.gradInput[2][i]:resizeAs(input[1]) + self.gradInput[2][i]:copy(gradOutput[i]) + end + for i=#input[2]+1, #self.gradInput[2] do + self.gradInput[2][i] = nil + end + return self.gradInput +end
\ No newline at end of file diff --git a/doc/table.md b/doc/table.md index 1924ead..8734bf3 100644 --- a/doc/table.md +++ b/doc/table.md @@ -28,6 +28,7 @@ This allows one to build very rich architectures: * [`CDivTable`](#nn.CDivTable): division of input `Tensor`s; * [`CMaxTable`](#nn.CMaxTable): max of input `Tensor`s; * [`CMinTable`](#nn.CMinTable): min of input `Tensor`s; + * [`CAddTensorTable`](#nn.CAddTensorTable): adds a tensor to a table of tensors of the same size; * `Table` of Criteria: * [`CriterionTable`](#nn.CriterionTable): wraps a [Criterion](criterion.md#nn.Criterion) so that it can accept a `table` of inputs. @@ -1351,3 +1352,18 @@ m = nn.CMinTable() 1 [torch.DoubleTensor of size 3] ``` + +<a name='nn.CAddTensorTable'></a> +## CAddTensorTable ## + +```lua +module = nn.CAddTensorTable() +``` + +Adds the first element `el` of the input table `tab` to each tensor contained in the second element of `tab`, which is itself a table + +Example: +```lua +print(module:forward{ (0,1,1), {(0,0,0),(1,1,1)} }) +{ (0,1,1), (1,2,2) } +``` @@ -71,6 +71,7 @@ require('nn.CMulTable') require('nn.CSubTable') require('nn.CMaxTable') require('nn.CMinTable') +require('nn.CAddTensorTable') require('nn.Euclidean') require('nn.WeightedEuclidean') @@ -8649,6 +8649,25 @@ function nntest.Convert() mytester:assertTensorEq(output, output2:float(), 0.000001, "Convert:type() double->float non-batch") end +function nntest.CAddTensorTable() + -- input : { v, {a,b,c} } + -- output : { v+a, v+b, v+c } + local z = nn.CAddTensorTable() + local input = { torch.randn(3), { torch.randn(3), torch.rand(3), torch.rand(3) } } + local output = z:forward(input) + mytester:assert(#output == 3, "CAddTensorTable #output") + mytester:assertTensorEq(input[1]+input[2][1], output[1], 0.00001, "CAddTensorTable input21 output1") + mytester:assertTensorEq(input[1]+input[2][2], output[2], 0.00001, "CAddTensorTable input22 output2") + mytester:assertTensorEq(input[1]+input[2][3], output[3], 0.00001, "CAddTensorTable input23 output3") + local gradInput = z:backward(input, output) + mytester:assert(#gradInput == 2, "CAddTensorTable #gradInput") + mytester:assert(#(gradInput[2]) == 3, "CAddTensorTable #gradInput[2]") + mytester:assertTensorEq(output[1], gradInput[2][1], 0.000001, "CAddTensorTable gradInput21") + mytester:assertTensorEq(output[2], gradInput[2][2], 0.000001, "CAddTensorTable gradInput22") + mytester:assertTensorEq(output[3], gradInput[2][3], 0.000001, "CAddTensorTable gradInput23") + mytester:assertTensorEq(output[1]+output[2]+output[3], gradInput[1], 0.000001, "CAddTensorTable gradInput1") +end + mytester:add(nntest) |