Welcome to mirror list, hosted at ThFree Co, Russian Federation.

SplitTable.lua - github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: d2c690ed07e2cd1d1d2ac9dcc0668eea8378cf90 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
--- nn.SplitTable: turns one tensor into a Lua table of its slices along a
-- fixed dimension (see `updateOutput` / `updateGradInput` below).
local SplitTable, parent = torch.class('nn.SplitTable', 'nn.Module')

--- Constructor.
-- @param dimension the dimension to split along. It is stored unchanged in
--   `self.dimension` and later passed to `input:size` and `input:select`.
function SplitTable:__init(dimension)
   parent.__init(self)
   self.dimension = dimension
   -- NOTE(review): this module wraps no sub-modules; the empty `modules`
   -- table is presumably kept for code that inspects `.modules` — confirm
   -- before removing.
   self.modules = {}
end

--- Forward pass: split `input` into its slices along `self.dimension`.
-- @param input tensor to split
-- @return a new Lua table whose i-th entry is `input:select(self.dimension, i)`
--   for i = 1 .. input:size(self.dimension). The same table is stored in
--   `self.output`. Per the Torch7 Tensor docs, `select` yields a tensor that
--   shares storage with `input` rather than a copy.
function SplitTable:updateOutput(input)
   local dim = self.dimension
   local pieces = {}
   for idx = 1, input:size(dim) do
      pieces[idx] = input:select(dim, idx)
   end
   self.output = pieces
   return pieces
end


--- Backward pass: reassemble per-slice gradients into a single tensor.
-- @param input the tensor that was passed to the forward pass
-- @param gradOutput Lua table of gradient tensors; entry i is copied into
--   slice i of the result along `self.dimension`.
-- @return `self.gradInput`, resized to match `input`. A slice whose
--   `gradOutput` entry is nil (that output received no gradient) is filled
--   with zeros instead of raising an error.
function SplitTable:updateGradInput(input, gradOutput)
   local dimension = self.dimension
   local slices = input:size(dimension)
   self.gradInput:resizeAs(input)

   for i = 1, slices do
      local slice = self.gradInput:select(dimension, i)
      local currentGradInput = gradOutput[i]
      if currentGradInput ~= nil then
         slice:copy(currentGradInput)
      else
         -- resizeAs does not initialise the tensor's contents, so a missing
         -- gradient must be written as explicit zeros, not left as stale data.
         slice:zero()
      end
   end
   return self.gradInput
end