diff options
author | Ivo Danihelka <danihelka@google.com> | 2015-10-19 15:45:18 +0300 |
---|---|---|
committer | Ivo Danihelka <danihelka@google.com> | 2015-10-19 15:45:18 +0300 |
commit | aebe940ed7bd9f8638f73839d836b90038095310 (patch) | |
tree | 5cfb5f1b67e34b2826f34b414f9f887cd0d72bc9 | |
parent | 6fd3e3d1e732d3d71ac38e525d129555e7882a63 (diff) |
Add nngraph.nest() utility function.
-rw-r--r-- | JustTable.lua | 17 | ||||
-rw-r--r-- | init.lua | 2 | ||||
-rw-r--r-- | nest.lua | 46 | ||||
-rw-r--r-- | test/test_nest.lua | 33 |
4 files changed, 98 insertions, 0 deletions
diff --git a/JustTable.lua b/JustTable.lua new file mode 100644 index 0000000..1fc8434 --- /dev/null +++ b/JustTable.lua @@ -0,0 +1,17 @@ + +local JustTable, parent = torch.class('nngraph.JustTable', 'nn.Module') +function JustTable:__init() + self.output = {} +end + +-- The input is one element. +-- The output is a table with one element: {element} +function JustTable:updateOutput(input) + self.output[1] = input + return self.output +end + +function JustTable:updateGradInput(input, gradOutput) + self.gradInput = gradOutput[1] + return self.gradInput +end @@ -3,9 +3,11 @@ require 'graph' nngraph = {} +torch.include('nngraph','nest.lua') torch.include('nngraph','node.lua') torch.include('nngraph','gmodule.lua') torch.include('nngraph','graphinspecting.lua') +torch.include('nngraph','JustTable.lua') torch.include('nngraph','ModuleFromCriterion.lua') -- handy functions diff --git a/nest.lua b/nest.lua new file mode 100644 index 0000000..a9da62e --- /dev/null +++ b/nest.lua @@ -0,0 +1,46 @@ + +local function isNode(input) + local typename = torch.typename(input) + return typename and typename == 'nngraph.Node' +end + +local function isNonEmptyList(input) + return type(input) == "table" and #input > 0 +end + +local function _nest(input) + if not isNode(input) and not isNonEmptyList(input) then + error('what is this in the nest input? ' .. tostring(input)) + end + + if isNode(input) then + return input + end + + if #input == 1 then + return nngraph.JustTable()(input) + end + + local wrappedChildren = {} + for i, child in ipairs(input) do + wrappedChildren[i] = _nest(child) + end + return nn.Identity()(wrappedChildren) +end + +-- Returns a nngraph node to represent a nested structure. +-- Usage example: +-- local in1 = nn.Identity()() +-- local in2 = nn.Identity()() +-- local in3 = nn.Identity()() +-- local ok = nn.CAddTable()(nngraph.nest({in1})) +-- local in1Again = nngraph.nest(in1) +-- local state = nngraph.nest({in1, {in2}, in3}) +function nngraph.nest(...) + local nArgs = select("#", ...) + assert(nArgs <= 1, 'Use {input1, input2} to pass multiple inputs.') + + local input = ... + assert(nArgs > 0 and input ~= nil, 'Pass an input.') + return _nest(input) +end diff --git a/test/test_nest.lua b/test/test_nest.lua new file mode 100644 index 0000000..2e9b410 --- /dev/null +++ b/test/test_nest.lua @@ -0,0 +1,33 @@ + +require 'totem' +require 'nngraph' + +local test = {} +local tester = totem.Tester() + +function test.test_output() + local in1 = nn.Identity()() + local in2 = nn.Identity()() + local in3 = nn.Identity()() + local ok = nn.CAddTable()(nngraph.nest({in1})) + local in1Again = nngraph.nest(in1) + local state = nngraph.nest({in1, {in2}, in3}) + + local net = nn.gModule( + {in1, in2, in3}, + {ok, in1Again, state, nngraph.nest({in3}), nngraph.nest({in1, in2})}) + + local val1 = torch.randn(7, 3) + local val2 = torch.randn(2) + local val3 = torch.randn(3) + local expectedOutput = { + val1, val1, {val1, {val2}, val3}, {val3}, {val1, val2}, + } + local output = net:forward({val1, val2, val3}) + tester:eq(output, expectedOutput, "output") +end + + +return tester:add(test):run() + + |