diff options
author | Ivo Danihelka <danihelka@google.com> | 2015-12-18 19:56:04 +0300 |
---|---|---|
committer | Ivo Danihelka <danihelka@google.com> | 2015-12-18 19:56:04 +0300 |
commit | f290385da6c6e553b9c5d6966fe454e5b45a88ff (patch) | |
tree | 7facc0cc09a782d378dfe5357c9a502cbb4b2380 | |
parent | a4cc6c9fbad6fb322fcd69937d67c5183f6fcc37 (diff) |
Allow to do node:split(1).
-rw-r--r-- | JustElement.lua | 18 | ||||
-rw-r--r-- | init.lua | 1 | ||||
-rw-r--r-- | node.lua | 4 | ||||
-rw-r--r-- | test/test_JustElement.lua | 28 | ||||
-rw-r--r-- | test/test_JustTable.lua | 19 |
5 files changed, 69 insertions, 1 deletions
diff --git a/JustElement.lua b/JustElement.lua new file mode 100644 index 0000000..0f18972 --- /dev/null +++ b/JustElement.lua @@ -0,0 +1,18 @@ + +local JustElement, parent = torch.class('nngraph.JustElement', 'nn.Module') +function JustElement:__init() + self.gradInput = {} +end + +-- The input is a table with one element. +-- The output the element from the table. +function JustElement:updateOutput(input) + assert(#input == 1, "expecting one element") + self.output = input[1] + return self.output +end + +function JustElement:updateGradInput(input, gradOutput) + self.gradInput[1] = gradOutput + return self.gradInput +end @@ -7,6 +7,7 @@ torch.include('nngraph','nest.lua') torch.include('nngraph','node.lua') torch.include('nngraph','gmodule.lua') torch.include('nngraph','graphinspecting.lua') +torch.include('nngraph','JustElement.lua') torch.include('nngraph','JustTable.lua') torch.include('nngraph','ModuleFromCriterion.lua') @@ -46,7 +46,9 @@ end -- that each take a single component of the output of this -- node in the order they are returned. function nnNode:split(noutput) - assert(noutput >= 2, "splitting to one output is not supported") + if noutput == 1 then + return nngraph.JustElement()(self) + end local debugLabel = self.data.annotations._debugLabel -- Specify the source location where :split is called. local dinfo = debug.getinfo(2, 'Sl') diff --git a/test/test_JustElement.lua b/test/test_JustElement.lua new file mode 100644 index 0000000..d6c49a8 --- /dev/null +++ b/test/test_JustElement.lua @@ -0,0 +1,28 @@ + +require 'totem' +require 'nngraph' +local test = {} +local tester = totem.Tester() + +function test.test_output() + local input = {torch.randn(7, 11)} + local module = nngraph.JustElement() + tester:eq(module:forward(input), input[1], "output") +end + +function test.test_grad() + local input = {torch.randn(7, 11)} + local module = nngraph.JustElement() + totem.nn.checkGradients(tester, module, input) +end + +function test.test_split() + local in1 = nn.Identity()() + local output = in1:split(1) + local net = nn.gModule({in1}, {output}) + + local input = {torch.randn(7, 11)} + tester:eq(net:forward(input), input[1], "output of split(1)") +end + +tester:add(test):run() diff --git a/test/test_JustTable.lua b/test/test_JustTable.lua new file mode 100644 index 0000000..d24d739 --- /dev/null +++ b/test/test_JustTable.lua @@ -0,0 +1,19 @@ + +require 'totem' +require 'nngraph' +local test = {} +local tester = totem.Tester() + +function test.test_output() + local input = torch.randn(7, 11) + local module = nngraph.JustTable() + tester:eq(module:forward(input), {input}, "output") +end + +function test.test_grad() + local input = torch.randn(7, 11) + local module = nngraph.JustTable() + totem.nn.checkGradients(tester, module, input) +end + +tester:add(test):run() |