Welcome to mirror list, hosted at ThFree Co, Russian Federation.

FlattenTable.lua - github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 3fe2fd5e5be3a31d9de2611d9547ebeb83a9e9b4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
local FlattenTable, parent = torch.class('nn.FlattenTable', 'nn.Module')

--- Module that turns an arbitrarily nested table of inputs into a flat
-- table of leaves (in forward DFS order), and maps gradients back into the
-- original nesting during backprop.
function FlattenTable:__init()
  parent.__init(self)

  -- Cached description of the last input's nesting; each leaf holds the
  -- index of that leaf inside self.output.
  self.input_map = {}
  -- Nested table shaped like the input (rebuilt from gradOutput).
  self.gradInput = {}
  -- Flat table of references to the input leaves.
  self.output = {}
end

-- Recursive function to flatten a table (output is a table)
-- Recursively append every leaf of `input` to `output` in forward DFS
-- (left-to-right) order. Returns a value mirroring `input`'s nesting: a
-- table for every nested table, and for every leaf the index in `output`
-- where that leaf was stored.
local function flatten(output, input)
  if torch.type(input) ~= 'table' then
    -- leaf: append it and report where it landed
    local slot = #output + 1
    output[slot] = input
    return slot
  end
  local mapping = {}
  for idx = 1, #input do
    mapping[idx] = flatten(output, input[idx])
  end
  return mapping
end

-- Recursive function to check if we need to rebuild the output table
-- Recursively verify that a cached input_map still describes `input`, so
-- the flattened `output` can be reused. Nested tables must have the same
-- length as their input_map counterparts. Every leaf must be the very same
-- value (==) that is stored at output[index]. Any nil argument, or any
-- structural/identity mismatch, yields false.
local function checkMapping(output, input, input_map)
  if input == nil or output == nil or input_map == nil then
    return false
  end
  local mapType = torch.type(input_map)
  if torch.type(input) ~= 'table' then
    -- leaf: the map must hold a numeric index pointing at this exact value
    return mapType == 'number' and output[input_map] == input
  end
  if mapType ~= 'table' or #input_map ~= #input then
    return false
  end
  -- forward DFS order, bail out on the first mismatch
  for idx = 1, #input do
    if not checkMapping(output, input[idx], input_map[idx]) then
      return false
    end
  end
  return true
end

-- During BPROP we have to build a gradInput with the same shape as the
-- input.  This is a recursive function to build up a gradInput
-- During BPROP we have to build a gradInput with the same shape as the
-- input. This recursive function rebuilds that nesting from the flat
-- gradOutput using input_map (as produced by flatten). Each table in
-- input_map becomes a table, and each numeric leaf becomes
-- gradOutput[index].
-- Children are stored at the same position they occupy in input_map, so a
-- missing (nil) gradient leaves a hole instead of shifting every later
-- sibling into the wrong slot.
local function inverseFlatten(gradOutput, input_map)
  if torch.type(input_map) ~= 'table' then
    return gradOutput[input_map]
  end
  local gradInput = {}
  for i = 1, #input_map do
    gradInput[i] = inverseFlatten(gradOutput, input_map[i])
  end
  return gradInput
end

function FlattenTable:updateOutput(input)
  assert(torch.type(input) == 'table', 'input must be a table')
  -- Rebuilding the flat table on every call is wasteful. First walk the
  -- cached mapping alongside the new input, and rebuild only when the
  -- nesting or any leaf no longer matches.
  local stillValid = checkMapping(self.output, input, self.input_map)
  if stillValid then
    return self.output
  end
  local flat = {}
  self.output = flat
  self.input_map = flatten(flat, input)
  return self.output
end

-- Map the flat gradOutput back into a nested gradInput shaped like `input`.
-- The cached gradInput is reused when its leaves are still the same
-- objects as the corresponding gradOutput entries.
function FlattenTable:updateGradInput(input, gradOutput)
  assert(torch.type(input) == 'table', 'input must be a table')
  -- Validate gradOutput itself. This previously re-checked `input`, so a
  -- non-table gradOutput was never rejected here.
  assert(torch.type(gradOutput) == 'table', 'gradOutput must be a table')
  -- If the input changes between the updateOutput and updateGradInput call,
  -- then we may have to rebuild the input_map!  However, let's assume that
  -- the input_map is valid and that forward has already been called.

  -- However, we should check that the gradInput is valid:
  if not checkMapping(gradOutput, self.gradInput, self.input_map) then
    self.gradInput = inverseFlatten(gradOutput, self.input_map)
  end

  return self.gradInput
end

-- This module just stores references, so there is nothing to convert to a
-- new tensor type. Force the cached tables to be empty instead.
-- Returns self, like nn.Module:type. Helpers such as module:float() and
-- module:cuda() return the result of :type(...), so returning nothing made
-- `nn.FlattenTable():cuda()` evaluate to nil.
function FlattenTable:type(type, tensorCache)
  self:clearState()
  return self
end

-- Forget the cached input nesting so the next updateOutput rebuilds it,
-- then defer to the parent implementation and pass its result through.
function FlattenTable:clearState()
  local parentClearState = parent.clearState
  self.input_map = {}
  return parentClearState(self)
end