Welcome to mirror list, hosted at ThFree Co, Russian Federation.

nest.lua - github.com/torch/nngraph.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: a9da62e39b541d82252c98cb59addc35437fd8b6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46

local function isNode(input)
   local typename = torch.typename(input)
   return typename and typename == 'nngraph.Node'
end

local function isNonEmptyList(input)
   return type(input) == "table" and #input > 0
end

local function _nest(input)
   if not isNode(input) and not isNonEmptyList(input) then
      error('what is this in the nest input? ' .. tostring(input))
   end

   if isNode(input) then
      return input
   end

   if #input == 1 then
      return nngraph.JustTable()(input)
   end

   local wrappedChildren = {}
   for i, child in ipairs(input) do
      wrappedChildren[i] = _nest(child)
   end
   return nn.Identity()(wrappedChildren)
end

-- Returns a nngraph node to represent a nested structure.
-- Usage example:
--    local in1 = nn.Identity()()
--    local in2 = nn.Identity()()
--    local in3 = nn.Identity()()
--    local ok = nn.CAddTable()(nngraph.nest({in1}))
--    local in1Again = nngraph.nest(in1)
--    local state = nngraph.nest({in1, {in2}, in3})
function nngraph.nest(...)
   local nArgs = select("#", ...)
   assert(nArgs <= 1, 'Use {input1, input2} to pass multiple inputs.')

   local input = ...
   assert(nArgs > 0 and input ~= nil, 'Pass an input.')
   return _nest(input)
end