Welcome to mirror list, hosted at ThFree Co, Russian Federation.

JoinTable.lua - github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 6ab68e189dd6f524ddcb608886b42dda42a6cd35 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
-- Declare the class 'nn.JoinTable' with 'nn.Module' as its base class.
-- `JoinTable` receives the methods defined below; `parent` is used to call
-- the base implementations (parent.__init, parent.type).
local JoinTable, parent = torch.class('nn.JoinTable', 'nn.Module')

-- Constructor.
-- dimension:  dimension along which the input tensors are joined. A negative
--             value is resolved relative to the number of dimensions of the
--             first input (see _getPositiveDimension).
-- nInputDims: optional. When set and the first input has exactly
--             nInputDims + 1 dimensions, the join dimension is shifted by one
--             (see _getPositiveDimension).
function JoinTable:__init(dimension, nInputDims)
   parent.__init(self)

   -- user-supplied configuration
   self.dimension = dimension
   self.nInputDims = nInputDims

   -- per-instance working state
   self.gradInput = {}                 -- one gradient tensor per input, filled by updateGradInput
   self.size = torch.LongStorage()     -- reused buffer holding the size of the joined output
end

-- Resolve self.dimension into the positive dimension index to join along,
-- using the first tensor of `input` as the reference.
-- input: table of tensors; only input[1] is inspected.
-- Returns the resolved dimension index.
function JoinTable:_getPositiveDimension(input)
   local joinDim = self.dimension

   -- Negative values count from the last dimension of the first input:
   -- -1 maps to input[1]:dim(), -2 to input[1]:dim() - 1, and so on.
   if joinDim < 0 then
      return input[1]:dim() + joinDim + 1
   end

   -- With nInputDims configured, an input carrying exactly one extra
   -- dimension shifts the join dimension by one.
   -- NOTE(review): presumably the extra leading dimension is a batch
   -- dimension -- confirm with callers.
   local expectedDims = self.nInputDims
   if expectedDims and input[1]:dim() == (expectedDims + 1) then
      return joinDim + 1
   end

   return joinDim
end

-- Forward pass: concatenate the tensors of `input` along the join dimension
-- into self.output.
-- input: table (array) of tensors. The output size is taken from the first
--        tensor, except along the join dimension, where the extents of all
--        tensors are summed.
-- Returns self.output.
function JoinTable:updateOutput(input)
   local joinDim = self:_getPositiveDimension(input)
   local nInputs = #input
   local outSize = self.size

   -- Pass 1: work out the size of the joined tensor.
   for idx = 1, nInputs do
      local tensor = input[idx]
      if idx > 1 then
         -- later tensors only extend the join dimension
         outSize[joinDim] = outSize[joinDim] + tensor:size(joinDim)
      else
         -- the first tensor seeds the full size
         outSize:resize(tensor:dim()):copy(tensor:size())
      end
   end
   self.output:resize(outSize)

   -- Pass 2: copy every tensor into its slice of the output.
   local start = 1
   for idx = 1, nInputs do
      local tensor = input[idx]
      local extent = tensor:size(joinDim)
      local slice = self.output:narrow(joinDim, start, extent)
      slice:copy(tensor)
      start = start + extent
   end

   return self.output
end

-- Backward pass: split `gradOutput` along the join dimension back into one
-- gradient tensor per input.
-- input:      table (array) of the tensors that were given to updateOutput.
-- gradOutput: tensor that is sliced along the join dimension; slice i has
--             the same extent along that dimension as input[i].
-- Returns self.gradInput, a table with one tensor per input, each resized
-- to match the corresponding input.
function JoinTable:updateGradInput(input, gradOutput)
   local joinDim = self:_getPositiveDimension(input)
   local nInputs = #input
   local grads = self.gradInput

   -- Make sure a gradient tensor exists for every input and has its shape.
   for idx = 1, nInputs do
      local source = input[idx]
      if grads[idx] == nil then
         grads[idx] = source.new()
      end
      grads[idx]:resizeAs(source)
   end

   -- Drop entries left over from an earlier call that had more inputs.
   for idx = nInputs + 1, #grads do
      grads[idx] = nil
   end

   -- Copy each slice of gradOutput into the matching gradient tensor.
   local start = 1
   for idx = 1, nInputs do
      local extent = input[idx]:size(joinDim)
      grads[idx]:copy(gradOutput:narrow(joinDim, start, extent))
      start = start + extent
   end

   return grads
end

-- Type conversion hook: replaces self.gradInput with an empty table and then
-- delegates to the parent implementation. updateGradInput re-creates the
-- entries on demand via input[i].new().
-- NOTE(review): presumably the reset keeps gradient tensors of the previous
-- type from being carried over -- confirm against nn.Module.type.
-- type:        forwarded unchanged to parent.type (this parameter shadows
--              the Lua builtin `type` inside the function body).
-- tensorCache: forwarded unchanged to parent.type.
-- Returns whatever parent.type returns.
function JoinTable:type(type, tensorCache)
   self.gradInput = {}
   return parent.type(self, type, tensorCache)
end