Welcome to mirror list, hosted at ThFree Co, Russian Federation.

Sum.lua - github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: eef72081b3820dc1d0f5bf5ef3aa4f92d242dab4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
local Sum, parent = torch.class('nn.Sum', 'nn.Module')

--- Module that reduces its input by summing over one dimension.
-- @param dimension (optional) index of the dimension to sum over;
--        any falsy value (nil/false) selects dimension 1.
function Sum:__init(dimension)
   parent.__init(self)
   self.dimension = dimension or 1
end

--- Forward pass: sum `input` over `self.dimension`.
-- If the summed result has more than one dimension, the (size-1) summed
-- dimension is removed via `select`; otherwise the result is returned as-is.
-- @param input tensor to reduce
-- @return self.output (the same tensor object is reused across calls)
function Sum:updateOutput(input)
   -- self.output may be a plain number here (NOTE(review): presumably from
   -- older/serialized modules — confirm); replace it with an empty tensor of
   -- the same type as input.
   if type(self.output) == 'number' then
      self.output = input.new()
   end
   self.output:sum(input, self.dimension)
   if self.output:nDimension() > 1 then
      -- Re-point the existing output tensor at the squeezed view rather than
      -- rebinding self.output to a new tensor object, so that anything
      -- holding a reference to self.output keeps seeing the current result.
      self.output:set(self.output:select(self.dimension, 1))
   end
   return self.output
end

--- Backward pass: replicate gradOutput along the summed dimension.
-- Each input element enters its sum with weight 1, so the gradient with
-- respect to input is gradOutput broadcast back to input's shape.
-- @param input the tensor that was passed to updateOutput
-- @param gradOutput gradient with respect to this module's output
-- @return self.gradInput, resized to input's shape
function Sum:updateGradInput(input, gradOutput)
   -- view() below only works on contiguous tensors; gradOutput may be a
   -- non-contiguous view (e.g. produced by a transpose/narrow upstream).
   -- contiguous() returns gradOutput itself when it is already contiguous
   -- and a contiguous copy otherwise; gradOutput is never modified.
   gradOutput = gradOutput:contiguous()

   -- Reinsert the summed dimension with size 1 so it can be expanded.
   local size = input:size()
   size[self.dimension] = 1
   gradOutput = gradOutput:view(size)

   -- zero-strides dont work with MKL/BLAS, so dont set self.gradInput to
   -- the zero-stride expanded tensor. Instead, copy it into real storage.
   self.gradInput:resizeAs(input)
   self.gradInput:copy(gradOutput:expandAs(input))

   return self.gradInput
end