diff options
author | Jonathan Uesato <juesato@mit.edu> | 2016-08-25 03:17:56 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2016-08-25 03:17:56 +0300 |
commit | d31a725921a9e3bf8916863b0ccfb139fd36940b (patch) | |
tree | 8587d08a9f7991b8501be2dcc1a3fd799b193f0c | |
parent | fc70b6b2904159d3a5b46b6fc0d3e8f880964b74 (diff) |
add CMaxTable class (#885)
* Add CMaxTable and CMinTable
-rw-r--r-- | CMaxTable.lua | 33 | ||||
-rw-r--r-- | CMinTable.lua | 33 | ||||
-rw-r--r-- | doc/table.md | 29 | ||||
-rw-r--r-- | init.lua | 2 | ||||
-rw-r--r-- | test.lua | 30 |
5 files changed, 127 insertions, 0 deletions
diff --git a/CMaxTable.lua b/CMaxTable.lua new file mode 100644 index 0000000..3907faf --- /dev/null +++ b/CMaxTable.lua @@ -0,0 +1,33 @@ +local CMaxTable, parent = torch.class('nn.CMaxTable', 'nn.Module') + +function CMaxTable:__init() + parent.__init(self) + self.gradInput = {} + self.maxIdx = torch.Tensor() +end + +function CMaxTable:updateOutput(input) + self.output:resizeAs(input[1]):copy(input[1]) + self.maxIdx:resizeAs(input[1]):fill(1) + for i=2,#input do + local mask = torch.gt(input[i], self.output) + self.maxIdx:maskedFill(mask, i) + self.output:maskedCopy(mask, input[i][mask]) + end + return self.output +end + +function CMaxTable:updateGradInput(input, gradOutput) + for i=1,#input do + self.gradInput[i] = torch.Tensor() + self.gradInput[i]:resizeAs(input[i]):fill(0.0) + local mask = torch.eq(self.maxIdx, i) + self.gradInput[i]:maskedCopy(mask, gradOutput[mask]) + end + + for i=#input+1, #self.gradInput do + self.gradInput[i] = nil + end + + return self.gradInput +end diff --git a/CMinTable.lua b/CMinTable.lua new file mode 100644 index 0000000..a8385e8 --- /dev/null +++ b/CMinTable.lua @@ -0,0 +1,33 @@ +local CMinTable, parent = torch.class('nn.CMinTable', 'nn.Module') + +function CMinTable:__init() + parent.__init(self) + self.gradInput = {} + self.minIdx = torch.Tensor() +end + +function CMinTable:updateOutput(input) + self.output:resizeAs(input[1]):copy(input[1]) + self.minIdx:resizeAs(input[1]):fill(1) + for i=2,#input do + local mask = torch.lt(input[i], self.output) + self.minIdx:maskedFill(mask, i) + self.output:maskedCopy(mask, input[i][mask]) + end + return self.output +end + +function CMinTable:updateGradInput(input, gradOutput) + for i=1,#input do + self.gradInput[i] = torch.Tensor() + self.gradInput[i]:resizeAs(input[i]):fill(0.0) + local mask = torch.eq(self.minIdx, i) + self.gradInput[i]:maskedCopy(mask, gradOutput[mask]) + end + + for i=#input+1, #self.gradInput do + self.gradInput[i] = nil + end + + return self.gradInput +end diff --git a/doc/table.md b/doc/table.md index 5e75173..ee61719 100644 --- a/doc/table.md +++ b/doc/table.md @@ -24,6 +24,8 @@ This allows one to build very rich architectures: * [`CSubTable`](#nn.CSubTable): substraction of input `Tensor`s; * [`CMulTable`](#nn.CMulTable): multiplication of input `Tensor`s; * [`CDivTable`](#nn.CDivTable): division of input `Tensor`s; + * [`CMaxTable`](#nn.CMaxTable): max of input `Tensor`s; + * [`CMinTable`](#nn.CMinTable): min of input `Tensor`s; * `Table` of Criteria: * [`CriterionTable`](#nn.CriterionTable): wraps a [Criterion](criterion.md#nn.Criterion) so that it can accept a `table` of inputs. @@ -1264,3 +1266,30 @@ m = nn.CDivTable() [torch.DoubleTensor of dimension 5] ``` +<a name="nn.CMaxTable"></a> +## CMaxTable ## + +Takes a `table` of `Tensor`s and outputs the max of all of them. + +```lua +m = nn.CMaxTable() +=m:forward({{torch.Tensor{1,2,3}, torch.Tensor{3,2,1}}) + 3 + 2 + 3 +[torch.DoubleTensor of size 3] +``` + +<a name="nn.CMinTable"></a> +## CMinTable ## + +Takes a `table` of `Tensor`s and outputs the min of all of them. + +```lua +m = nn.CMinTable() +=m:forward({{torch.Tensor{1,2,3}, torch.Tensor{3,2,1}}) + 1 + 2 + 1 +[torch.DoubleTensor of size 3] +``` @@ -54,6 +54,8 @@ require('nn.CAddTable') require('nn.CDivTable') require('nn.CMulTable') require('nn.CSubTable') +require('nn.CMaxTable') +require('nn.CMinTable') require('nn.Euclidean') require('nn.WeightedEuclidean') @@ -5003,6 +5003,36 @@ function nntest.Copy() mytester:assert(torch.type(output) == 'torch.FloatTensor', 'copy forward type err') end +function nntest.CMaxTable() + local input1 = torch.Tensor{{1,3},{2,4}} + local input2 = torch.Tensor{{4,2},{3,1}} + local input = {input1, input2} + local module = nn.CMaxTable() + local err1 = torch.add(module:forward(input), -1, torch.Tensor{{4,3},{3,4}}) + mytester:assertalmosteq(err1:abs():max(), 0, 1e-15, "CMaxTable forward call") + local gradOutputs = torch.Tensor{5,6,7,8} + local gradInputs = module:backward(input, gradOutputs) + local err2 = torch.add(gradInputs[1], -1, torch.Tensor{{0,6},{0,8}}) + local err3 = torch.add(gradInputs[2], -1, torch.Tensor{{5,0},{7,0}}) + mytester:assertalmosteq(err2:abs():max(), 0, 1e-15, "CMaxTable backward call") + mytester:assertalmosteq(err3:abs():max(), 0, 1e-15, "CMaxTable backward call") +end + +function nntest.CMinTable() + local input1 = torch.Tensor{{1,3},{2,4}} + local input2 = torch.Tensor{{4,2},{3,1}} + local input = {input1, input2} + local module = nn.CMinTable() + local err1 = torch.add(module:forward(input), -1, torch.Tensor{{1,2},{2,1}}) + mytester:assertalmosteq(err1:abs():max(), 0, 1e-15, "CMinTable forward call") + local gradOutputs = torch.Tensor{5,6,7,8} + local gradInputs = module:backward(input, gradOutputs) + local err2 = torch.add(gradInputs[1], -1, torch.Tensor{{5,0},{7,0}}) + local err3 = torch.add(gradInputs[2], -1, torch.Tensor{{0,6},{0,8}}) + mytester:assertalmosteq(err2:abs():max(), 0, 1e-15, "CMinTable backward call") + mytester:assertalmosteq(err3:abs():max(), 0, 1e-15, "CMinTable backward call") +end + function nntest.JoinTable() local tensor = torch.rand(3,4,5) local input = {tensor, tensor} |