Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nngraph.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2015-12-19 02:43:13 +0300
committerSoumith Chintala <soumith@gmail.com>2015-12-19 02:43:13 +0300
commit6cd17c40f4cbc426d20894a58d81360724253333 (patch)
tree7facc0cc09a782d378dfe5357c9a502cbb4b2380
parenta4cc6c9fbad6fb322fcd69937d67c5183f6fcc37 (diff)
parentf290385da6c6e553b9c5d6966fe454e5b45a88ff (diff)
Merge pull request #96 from fidlej/topic_split_one
Allow to do node:split(1).
-rw-r--r--JustElement.lua18
-rw-r--r--init.lua1
-rw-r--r--node.lua4
-rw-r--r--test/test_JustElement.lua28
-rw-r--r--test/test_JustTable.lua19
5 files changed, 69 insertions, 1 deletions
diff --git a/JustElement.lua b/JustElement.lua
new file mode 100644
index 0000000..0f18972
--- /dev/null
+++ b/JustElement.lua
@@ -0,0 +1,18 @@
+
+local JustElement, parent = torch.class('nngraph.JustElement', 'nn.Module')
+function JustElement:__init()
+ self.gradInput = {}
+end
+
+-- The input is a table with one element.
+-- The output the element from the table.
+function JustElement:updateOutput(input)
+ assert(#input == 1, "expecting one element")
+ self.output = input[1]
+ return self.output
+end
+
+function JustElement:updateGradInput(input, gradOutput)
+ self.gradInput[1] = gradOutput
+ return self.gradInput
+end
diff --git a/init.lua b/init.lua
index 1eae7cb..ac38c51 100644
--- a/init.lua
+++ b/init.lua
@@ -7,6 +7,7 @@ torch.include('nngraph','nest.lua')
torch.include('nngraph','node.lua')
torch.include('nngraph','gmodule.lua')
torch.include('nngraph','graphinspecting.lua')
+torch.include('nngraph','JustElement.lua')
torch.include('nngraph','JustTable.lua')
torch.include('nngraph','ModuleFromCriterion.lua')
diff --git a/node.lua b/node.lua
index 3842605..a55aa48 100644
--- a/node.lua
+++ b/node.lua
@@ -46,7 +46,9 @@ end
-- that each take a single component of the output of this
-- node in the order they are returned.
function nnNode:split(noutput)
- assert(noutput >= 2, "splitting to one output is not supported")
+ if noutput == 1 then
+ return nngraph.JustElement()(self)
+ end
local debugLabel = self.data.annotations._debugLabel
-- Specify the source location where :split is called.
local dinfo = debug.getinfo(2, 'Sl')
diff --git a/test/test_JustElement.lua b/test/test_JustElement.lua
new file mode 100644
index 0000000..d6c49a8
--- /dev/null
+++ b/test/test_JustElement.lua
@@ -0,0 +1,28 @@
+
+require 'totem'
+require 'nngraph'
+local test = {}
+local tester = totem.Tester()
+
+function test.test_output()
+ local input = {torch.randn(7, 11)}
+ local module = nngraph.JustElement()
+ tester:eq(module:forward(input), input[1], "output")
+end
+
+function test.test_grad()
+ local input = {torch.randn(7, 11)}
+ local module = nngraph.JustElement()
+ totem.nn.checkGradients(tester, module, input)
+end
+
+function test.test_split()
+ local in1 = nn.Identity()()
+ local output = in1:split(1)
+ local net = nn.gModule({in1}, {output})
+
+ local input = {torch.randn(7, 11)}
+ tester:eq(net:forward(input), input[1], "output of split(1)")
+end
+
+tester:add(test):run()
diff --git a/test/test_JustTable.lua b/test/test_JustTable.lua
new file mode 100644
index 0000000..d24d739
--- /dev/null
+++ b/test/test_JustTable.lua
@@ -0,0 +1,19 @@
+
+require 'totem'
+require 'nngraph'
+local test = {}
+local tester = totem.Tester()
+
+function test.test_output()
+ local input = torch.randn(7, 11)
+ local module = nngraph.JustTable()
+ tester:eq(module:forward(input), {input}, "output")
+end
+
+function test.test_grad()
+ local input = torch.randn(7, 11)
+ local module = nngraph.JustTable()
+ totem.nn.checkGradients(tester, module, input)
+end
+
+tester:add(test):run()