Welcome to the mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJonathan Uesato <juesato@mit.edu>2016-08-25 03:17:56 +0300
committerSoumith Chintala <soumith@gmail.com>2016-08-25 03:17:56 +0300
commitd31a725921a9e3bf8916863b0ccfb139fd36940b (patch)
tree8587d08a9f7991b8501be2dcc1a3fd799b193f0c
parentfc70b6b2904159d3a5b46b6fc0d3e8f880964b74 (diff)
add CMaxTable class (#885)
* Add CMaxTable and CMinTable
-rw-r--r--CMaxTable.lua33
-rw-r--r--CMinTable.lua33
-rw-r--r--doc/table.md29
-rw-r--r--init.lua2
-rw-r--r--test.lua30
5 files changed, 127 insertions, 0 deletions
diff --git a/CMaxTable.lua b/CMaxTable.lua
new file mode 100644
index 0000000..3907faf
--- /dev/null
+++ b/CMaxTable.lua
@@ -0,0 +1,33 @@
+local CMaxTable, parent = torch.class('nn.CMaxTable', 'nn.Module')
+
+function CMaxTable:__init()
+ parent.__init(self)
+ self.gradInput = {}
+ self.maxIdx = torch.Tensor()
+end
+
+function CMaxTable:updateOutput(input)
+ self.output:resizeAs(input[1]):copy(input[1])
+ self.maxIdx:resizeAs(input[1]):fill(1)
+ for i=2,#input do
+ local mask = torch.gt(input[i], self.output)
+ self.maxIdx:maskedFill(mask, i)
+ self.output:maskedCopy(mask, input[i][mask])
+ end
+ return self.output
+end
+
+function CMaxTable:updateGradInput(input, gradOutput)
+ for i=1,#input do
+ self.gradInput[i] = torch.Tensor()
+ self.gradInput[i]:resizeAs(input[i]):fill(0.0)
+ local mask = torch.eq(self.maxIdx, i)
+ self.gradInput[i]:maskedCopy(mask, gradOutput[mask])
+ end
+
+ for i=#input+1, #self.gradInput do
+ self.gradInput[i] = nil
+ end
+
+ return self.gradInput
+end
diff --git a/CMinTable.lua b/CMinTable.lua
new file mode 100644
index 0000000..a8385e8
--- /dev/null
+++ b/CMinTable.lua
@@ -0,0 +1,33 @@
+local CMinTable, parent = torch.class('nn.CMinTable', 'nn.Module')
+
+function CMinTable:__init()
+ parent.__init(self)
+ self.gradInput = {}
+ self.minIdx = torch.Tensor()
+end
+
+function CMinTable:updateOutput(input)
+ self.output:resizeAs(input[1]):copy(input[1])
+ self.minIdx:resizeAs(input[1]):fill(1)
+ for i=2,#input do
+ local mask = torch.lt(input[i], self.output)
+ self.minIdx:maskedFill(mask, i)
+ self.output:maskedCopy(mask, input[i][mask])
+ end
+ return self.output
+end
+
+function CMinTable:updateGradInput(input, gradOutput)
+ for i=1,#input do
+ self.gradInput[i] = torch.Tensor()
+ self.gradInput[i]:resizeAs(input[i]):fill(0.0)
+ local mask = torch.eq(self.minIdx, i)
+ self.gradInput[i]:maskedCopy(mask, gradOutput[mask])
+ end
+
+ for i=#input+1, #self.gradInput do
+ self.gradInput[i] = nil
+ end
+
+ return self.gradInput
+end
diff --git a/doc/table.md b/doc/table.md
index 5e75173..ee61719 100644
--- a/doc/table.md
+++ b/doc/table.md
@@ -24,6 +24,8 @@ This allows one to build very rich architectures:
  * [`CSubTable`](#nn.CSubTable): subtraction of input `Tensor`s;
* [`CMulTable`](#nn.CMulTable): multiplication of input `Tensor`s;
* [`CDivTable`](#nn.CDivTable): division of input `Tensor`s;
+ * [`CMaxTable`](#nn.CMaxTable): max of input `Tensor`s;
+ * [`CMinTable`](#nn.CMinTable): min of input `Tensor`s;
* `Table` of Criteria:
* [`CriterionTable`](#nn.CriterionTable): wraps a [Criterion](criterion.md#nn.Criterion) so that it can accept a `table` of inputs.
@@ -1264,3 +1266,30 @@ m = nn.CDivTable()
[torch.DoubleTensor of dimension 5]
```
+<a name="nn.CMaxTable"></a>
+## CMaxTable ##
+
+Takes a `table` of `Tensor`s and outputs the element-wise max of all of them.
+
+```lua
+m = nn.CMaxTable()
+=m:forward({torch.Tensor{1,2,3}, torch.Tensor{3,2,1}})
+ 3
+ 2
+ 3
+[torch.DoubleTensor of size 3]
+```
+
+<a name="nn.CMinTable"></a>
+## CMinTable ##
+
+Takes a `table` of `Tensor`s and outputs the element-wise min of all of them.
+
+```lua
+m = nn.CMinTable()
+=m:forward({torch.Tensor{1,2,3}, torch.Tensor{3,2,1}})
+ 1
+ 2
+ 1
+[torch.DoubleTensor of size 3]
+```
diff --git a/init.lua b/init.lua
index 74bb90e..98edfc5 100644
--- a/init.lua
+++ b/init.lua
@@ -54,6 +54,8 @@ require('nn.CAddTable')
require('nn.CDivTable')
require('nn.CMulTable')
require('nn.CSubTable')
+require('nn.CMaxTable')
+require('nn.CMinTable')
require('nn.Euclidean')
require('nn.WeightedEuclidean')
diff --git a/test.lua b/test.lua
index f914691..6673fed 100644
--- a/test.lua
+++ b/test.lua
@@ -5003,6 +5003,36 @@ function nntest.Copy()
mytester:assert(torch.type(output) == 'torch.FloatTensor', 'copy forward type err')
end
+function nntest.CMaxTable()
+ local input1 = torch.Tensor{{1,3},{2,4}}
+ local input2 = torch.Tensor{{4,2},{3,1}}
+ local input = {input1, input2}
+ local module = nn.CMaxTable()
+ local err1 = torch.add(module:forward(input), -1, torch.Tensor{{4,3},{3,4}})
+ mytester:assertalmosteq(err1:abs():max(), 0, 1e-15, "CMaxTable forward call")
+ local gradOutputs = torch.Tensor{5,6,7,8}
+ local gradInputs = module:backward(input, gradOutputs)
+ local err2 = torch.add(gradInputs[1], -1, torch.Tensor{{0,6},{0,8}})
+ local err3 = torch.add(gradInputs[2], -1, torch.Tensor{{5,0},{7,0}})
+ mytester:assertalmosteq(err2:abs():max(), 0, 1e-15, "CMaxTable backward call")
+ mytester:assertalmosteq(err3:abs():max(), 0, 1e-15, "CMaxTable backward call")
+end
+
+function nntest.CMinTable()
+ local input1 = torch.Tensor{{1,3},{2,4}}
+ local input2 = torch.Tensor{{4,2},{3,1}}
+ local input = {input1, input2}
+ local module = nn.CMinTable()
+ local err1 = torch.add(module:forward(input), -1, torch.Tensor{{1,2},{2,1}})
+ mytester:assertalmosteq(err1:abs():max(), 0, 1e-15, "CMinTable forward call")
+ local gradOutputs = torch.Tensor{5,6,7,8}
+ local gradInputs = module:backward(input, gradOutputs)
+ local err2 = torch.add(gradInputs[1], -1, torch.Tensor{{5,0},{7,0}})
+ local err3 = torch.add(gradInputs[2], -1, torch.Tensor{{0,6},{0,8}})
+ mytester:assertalmosteq(err2:abs():max(), 0, 1e-15, "CMinTable backward call")
+ mytester:assertalmosteq(err3:abs():max(), 0, 1e-15, "CMinTable backward call")
+end
+
function nntest.JoinTable()
local tensor = torch.rand(3,4,5)
local input = {tensor, tensor}