Welcome to mirror list, hosted at ThFree Co, Russian Federation.

SelectTable.lua - github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: be3ceac48ef42da50ac53434912f72e19a06ffab (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
-- Declare the class 'nn.SelectTable' with 'nn.Module' as its parent class.
-- `SelectTable` receives the methods defined below; `parent` is used to call
-- the nn.Module implementations that this file overrides.
local SelectTable, parent = torch.class('nn.SelectTable', 'nn.Module')

--- Construct a SelectTable module.
-- @param index key of the input-table entry this module forwards as output
function SelectTable:__init(index)
   parent.__init(self)
   -- gradInput is kept as a table; updateGradInput fills it in.
   self.gradInput = {}
   self.index = index
end

--- Forward pass: expose the selected entry of `input` as the module output.
-- @param input table that is indexed with self.index
-- @return input[self.index]; the entry itself is stored and returned, no copy
--         is made
function SelectTable:updateOutput(input)
   local selected = input[self.index]
   self.output = selected
   return selected
end

--- Make `t1` mirror the (possibly nested) structure of `t2`.
-- For every key of `t2`:
--   * a sub-table is mirrored recursively;
--   * a missing entry in `t1` is created as a zero-filled clone of the
--     corresponding entry of `t2`;
--   * an existing entry of a different size is resized and zeroed;
--   * an existing entry of the same size is left untouched (the caller may
--     have placed a real gradient there before calling).
-- Keys present in `t1` but absent from `t2` are removed, so `t1` does not keep
-- stale entries when `t2` has fewer keys than on a previous call.
-- @param t1 table updated in place
-- @param t2 reference table; non-table values must provide clone(), zero(),
--           isSameSizeAs() and resizeAs()
-- @return t1
local function zeroTableCopy(t1, t2)
   for k, v in pairs(t2) do
      local current = t1[k]
      if torch.type(v) == "table" then
         -- Start from a fresh sub-table when t1 has nothing here, or holds a
         -- non-table value left over from a differently shaped t2.
         if torch.type(current) ~= "table" then
            current = {}
         end
         t1[k] = zeroTableCopy(current, v)
      elseif not current or torch.type(current) == "table" then
         -- Nothing usable at this key (missing, or a leftover sub-table).
         t1[k] = v:clone():zero()
      elseif not current:isSameSizeAs(v) then
         current:resizeAs(v)
         current:zero()
      end
   end
   -- Drop entries that no longer exist in t2. Assigning nil to existing keys
   -- while traversing with pairs is permitted by Lua.
   for k in pairs(t1) do
      if t2[k] == nil then
         t1[k] = nil
      end
   end
   return t1
end

--- Backward pass: build a gradient table shaped like `input`.
-- The selected slot receives `gradOutput` itself (stored by reference, not
-- copied); zeroTableCopy then fills in the remaining slots from `input`.
-- @param input the table given to updateOutput
-- @param gradOutput gradient with respect to the selected entry
-- @return self.gradInput
function SelectTable:updateGradInput(input, gradOutput)
   local grads = self.gradInput
   grads[self.index] = gradOutput
   zeroTableCopy(grads, input)
   return grads
end

--- Convert the module to another tensor type.
-- The cached gradient table is discarded; updateGradInput repopulates it on
-- its next call. The conversion itself is delegated to the parent class and
-- its result is returned, so chained calls such as
-- `m = nn.SelectTable(1):type('torch.FloatTensor')` receive the module rather
-- than nil.
-- @param type name of the target tensor type, e.g. 'torch.FloatTensor'
-- @param ... any additional arguments, forwarded unchanged to the parent
-- @return whatever the parent implementation returns
--         (presumably the module itself -- confirm against nn.Module:type)
function SelectTable:type(type, ...)
   self.gradInput = {}
   return parent.type(self, type, ...)
end