diff options
author | nicholas-leonard <nick@nikopia.org> | 2014-11-02 04:54:43 +0300 |
---|---|---|
committer | nicholas-leonard <nick@nikopia.org> | 2014-11-04 01:12:09 +0300 |
commit | a597418786d95b2c8e04c8147ae4afb05c5a9d6c (patch) | |
tree | f64a02e1d07c0f8fb1a4c9294a6c0f8897b8595a /Add.lua | |
parent | 84317af444b2fb5c7d917e123b8950bf6b633562 (diff) |
nn.Add works with batches
Diffstat (limited to 'Add.lua')
-rw-r--r-- | Add.lua | 40 |
1 file changed, 26 insertions, 14 deletions
@@ -5,12 +5,11 @@ function Add:__init(inputSize,scalar) local size = inputSize if scalar then size=1 end + self.scalar = scalar self.bias = torch.Tensor(size) self.gradBias = torch.Tensor(size) - - -- state - self.gradInput:resize(inputSize) - self.output:resize(inputSize) + + self._ones = torch.Tensor{1} self:reset() end @@ -22,24 +21,32 @@ function Add:reset(stdv) stdv = 1./math.sqrt(self.bias:size(1)) end - for i=1,self.bias:size(1) do - self.bias[i] = torch.uniform(-stdv, stdv) - end + self.bias:uniform(-stdv, stdv) end function Add:updateOutput(input) - self.output:copy(input); - if self.gradBias:size(1)==1 then - self.output:add(self.bias[1]); + self.output:resizeAs(input):copy(input) + if self.scalar then + self.output:add(self.bias[1]); else - self.output:add(self.bias); + if input:isSameSizeAs(self.bias) then + self.output:add(self.bias) + else + local batchSize = input:size(1) + if self._ones:size(1) ~= batchSize then + self._ones:resize(batchSize):fill(1) + end + local bias = self.bias:view(-1) + local output = self.output:view(batchSize, -1) + output:addr(1, self._ones, bias) + end end return self.output -end +end function Add:updateGradInput(input, gradOutput) if self.gradInput then - self.gradInput:copy(gradOutput) + self.gradInput:resizeAs(gradOutput):copy(gradOutput) return self.gradInput end end @@ -49,6 +56,11 @@ function Add:accGradParameters(input, gradOutput, scale) if self.gradBias:size(1) == 1 then self.gradBias[1] = self.gradBias[1] + scale*gradOutput:sum(); else - self.gradBias:add(scale, gradOutput) + if input:isSameSizeAs(self.bias) then + self.gradBias:add(scale, gradOutput) + else + local gradOutput = gradOutput:view(input:size(1), -1) + self.gradBias:addmv(scale, gradOutput:t(), self._ones) + end end end |