Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nngraph.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorIvo Danihelka <danihelka@google.com>2015-10-19 15:45:18 +0300
committerIvo Danihelka <danihelka@google.com>2015-10-19 15:45:18 +0300
commitaebe940ed7bd9f8638f73839d836b90038095310 (patch)
tree5cfb5f1b67e34b2826f34b414f9f887cd0d72bc9
parent6fd3e3d1e732d3d71ac38e525d129555e7882a63 (diff)
Add nngraph.nest() utility function.
-rw-r--r--JustTable.lua17
-rw-r--r--init.lua2
-rw-r--r--nest.lua46
-rw-r--r--test/test_nest.lua33
4 files changed, 98 insertions, 0 deletions
diff --git a/JustTable.lua b/JustTable.lua
new file mode 100644
index 0000000..1fc8434
--- /dev/null
+++ b/JustTable.lua
@@ -0,0 +1,17 @@
+
+local JustTable, parent = torch.class('nngraph.JustTable', 'nn.Module')
+function JustTable:__init()
+ self.output = {}
+end
+
+-- The input is one element.
+-- The output is a table with one element: {element}
+function JustTable:updateOutput(input)
+ self.output[1] = input
+ return self.output
+end
+
+function JustTable:updateGradInput(input, gradOutput)
+ self.gradInput = gradOutput[1]
+ return self.gradInput
+end
diff --git a/init.lua b/init.lua
index a76b80d..9b117a6 100644
--- a/init.lua
+++ b/init.lua
@@ -3,9 +3,11 @@ require 'graph'
nngraph = {}
+torch.include('nngraph','nest.lua')
torch.include('nngraph','node.lua')
torch.include('nngraph','gmodule.lua')
torch.include('nngraph','graphinspecting.lua')
+torch.include('nngraph','JustTable.lua')
torch.include('nngraph','ModuleFromCriterion.lua')
-- handy functions
diff --git a/nest.lua b/nest.lua
new file mode 100644
index 0000000..a9da62e
--- /dev/null
+++ b/nest.lua
@@ -0,0 +1,46 @@
+
+local function isNode(input)
+ local typename = torch.typename(input)
+ return typename and typename == 'nngraph.Node'
+end
+
+local function isNonEmptyList(input)
+ return type(input) == "table" and #input > 0
+end
+
+local function _nest(input)
+ if not isNode(input) and not isNonEmptyList(input) then
+ error('what is this in the nest input? ' .. tostring(input))
+ end
+
+ if isNode(input) then
+ return input
+ end
+
+ if #input == 1 then
+ return nngraph.JustTable()(input)
+ end
+
+ local wrappedChildren = {}
+ for i, child in ipairs(input) do
+ wrappedChildren[i] = _nest(child)
+ end
+ return nn.Identity()(wrappedChildren)
+end
+
+-- Returns a nngraph node to represent a nested structure.
+-- Usage example:
+-- local in1 = nn.Identity()()
+-- local in2 = nn.Identity()()
+-- local in3 = nn.Identity()()
+-- local ok = nn.CAddTable()(nngraph.nest({in1}))
+-- local in1Again = nngraph.nest(in1)
+-- local state = nngraph.nest({in1, {in2}, in3})
+function nngraph.nest(...)
+ local nArgs = select("#", ...)
+ assert(nArgs <= 1, 'Use {input1, input2} to pass multiple inputs.')
+
+ local input = ...
+ assert(nArgs > 0 and input ~= nil, 'Pass an input.')
+ return _nest(input)
+end
diff --git a/test/test_nest.lua b/test/test_nest.lua
new file mode 100644
index 0000000..2e9b410
--- /dev/null
+++ b/test/test_nest.lua
@@ -0,0 +1,33 @@
+
+require 'totem'
+require 'nngraph'
+
+local test = {}
+local tester = totem.Tester()
+
+function test.test_output()
+ local in1 = nn.Identity()()
+ local in2 = nn.Identity()()
+ local in3 = nn.Identity()()
+ local ok = nn.CAddTable()(nngraph.nest({in1}))
+ local in1Again = nngraph.nest(in1)
+ local state = nngraph.nest({in1, {in2}, in3})
+
+ local net = nn.gModule(
+ {in1, in2, in3},
+ {ok, in1Again, state, nngraph.nest({in3}), nngraph.nest({in1, in2})})
+
+ local val1 = torch.randn(7, 3)
+ local val2 = torch.randn(2)
+ local val3 = torch.randn(3)
+ local expectedOutput = {
+ val1, val1, {val1, {val2}, val3}, {val3}, {val1, val2},
+ }
+ local output = net:forward({val1, val2, val3})
+ tester:eq(output, expectedOutput, "output")
+end
+
+
+return tester:add(test):run()
+
+