Welcome to mirror list, hosted at ThFree Co, Russian Federation.

ZipTableOneToMany.lua - github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: d4a80fe0dd72aa7b10d157c0635874cbb085a2eb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
-- Declare the class 'nn.ZipTableOneToMany' with 'nn.Module' as its parent.
-- `parent` is kept so the constructor can call the inherited __init.
local ZipTableOneToMany, parent = torch.class('nn.ZipTableOneToMany', 'nn.Module')

-- based on ZipTable in dpnn

-- Pairs one element with every entry of a table:
-- input : { v, {a, b, c} }
-- output : { {v,a}, {v,b}, {v,c} }
-- Constructor: run the parent initialiser, then set up the module state.
function ZipTableOneToMany:__init()
   parent.__init(self)
   -- Both the forward result and the backward result are tables.
   self.output, self.gradInput = {}, {}
   -- Scratch tensor that updateGradInput resizes and refills on each call.
   self.gradInputEl = torch.Tensor()
end

-- Forward pass: pair the single element with each entry of the table.
--
-- input   : { v, {a, b, c} }
-- returns : self.output, rebuilt as { {v,a}, {v,b}, {v,c} }
function ZipTableOneToMany:updateOutput(input)
   assert(#input == 2, "input must be table of element and table")
   local element, entries = input[1], input[2]
   -- Start from a new table on every call so pairs left over from an
   -- earlier, longer input cannot linger in the result.
   local zipped = {}
   self.output = zipped
   for idx, entry in ipairs(entries) do
      -- The element is shared by reference across all pairs, not copied.
      zipped[idx] = {element, entry}
   end
   return zipped
end

-- Backward pass: split the gradient of every {element, entry} pair back into
-- one gradient for the shared element and a table of per-entry gradients.
--
-- input      : { v, {a, b, c} }, the same structure given to updateOutput
-- gradOutput : { {gv1, ga}, {gv2, gb}, {gv3, gc} }
-- returns    : self.gradInput = { gv1 + gv2 + gv3, {ga, gb, gc} }
--
-- The element gradient is accumulated in the reusable buffer
-- self.gradInputEl, so the tensor returned here is overwritten by the next
-- call. The per-entry gradients are passed through by reference.
function ZipTableOneToMany:updateGradInput(input, gradOutput)
   assert(#input == 2, "input must be table of element and table")
   -- Only the element is needed from input: the buffer is resized to match it.
   local inputEl = input[1]
   local gradEl = self.gradInputEl
   gradEl:resizeAs(inputEl):zero()
   local gradInputTable = {}
   for i, gradPair in ipairs(gradOutput) do
      -- The element took part in every output pair, so its gradient is the
      -- sum of the first slot of each pair.
      gradEl:add(gradPair[1])
      gradInputTable[i] = gradPair[2]
   end
   self.gradInput = {gradEl, gradInputTable}
   return self.gradInput
end