
github.com/torch/nn.git
author    nicholas-leonard <nick@nikopia.org>  2014-11-02 04:54:43 +0300
committer nicholas-leonard <nick@nikopia.org>  2014-11-04 01:12:09 +0300
commit    a597418786d95b2c8e04c8147ae4afb05c5a9d6c (patch)
tree      f64a02e1d07c0f8fb1a4c9294a6c0f8897b8595a /Add.lua
parent    84317af444b2fb5c7d917e123b8950bf6b633562 (diff)
nn.Add works with batches
Diffstat (limited to 'Add.lua')
-rw-r--r--  Add.lua  40
1 file changed, 26 insertions, 14 deletions
diff --git a/Add.lua b/Add.lua
index fadcd21..e5a7c8b 100644
--- a/Add.lua
+++ b/Add.lua
@@ -5,12 +5,11 @@ function Add:__init(inputSize,scalar)
 
    local size = inputSize
    if scalar then size=1 end
+   self.scalar = scalar
    self.bias = torch.Tensor(size)
    self.gradBias = torch.Tensor(size)
-
-   -- state
-   self.gradInput:resize(inputSize)
-   self.output:resize(inputSize)
+
+   self._ones = torch.Tensor{1}
 
    self:reset()
 end
@@ -22,24 +21,32 @@ function Add:reset(stdv)
       stdv = 1./math.sqrt(self.bias:size(1))
    end
 
-   for i=1,self.bias:size(1) do
-      self.bias[i] = torch.uniform(-stdv, stdv)
-   end
+   self.bias:uniform(-stdv, stdv)
 end
 
 function Add:updateOutput(input)
-   self.output:copy(input);
-   if self.gradBias:size(1)==1 then
-      self.output:add(self.bias[1]);
+   self.output:resizeAs(input):copy(input)
+   if self.scalar then
+      self.output:add(self.bias[1]);
    else
-      self.output:add(self.bias);
+      if input:isSameSizeAs(self.bias) then
+         self.output:add(self.bias)
+      else
+         local batchSize = input:size(1)
+         if self._ones:size(1) ~= batchSize then
+            self._ones:resize(batchSize):fill(1)
+         end
+         local bias = self.bias:view(-1)
+         local output = self.output:view(batchSize, -1)
+         output:addr(1, self._ones, bias)
+      end
    end
    return self.output
-end
+end
 
 function Add:updateGradInput(input, gradOutput)
    if self.gradInput then
-      self.gradInput:copy(gradOutput)
+      self.gradInput:resizeAs(gradOutput):copy(gradOutput)
       return self.gradInput
    end
 end
@@ -49,6 +56,11 @@ function Add:accGradParameters(input, gradOutput, scale)
    if self.gradBias:size(1) == 1 then
       self.gradBias[1] = self.gradBias[1] + scale*gradOutput:sum();
    else
-      self.gradBias:add(scale, gradOutput)
+      if input:isSameSizeAs(self.bias) then
+         self.gradBias:add(scale, gradOutput)
+      else
+         local gradOutput = gradOutput:view(input:size(1), -1)
+         self.gradBias:addmv(scale, gradOutput:t(), self._ones)
+      end
    end
 end
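
Note on the batch path introduced above: updateOutput broadcasts the bias over the first (batch) dimension with an outer product, output:addr(1, self._ones, bias), and accGradParameters sums gradOutput over the batch with gradBias:addmv(scale, gradOutput:t(), self._ones). A minimal usage sketch follows; it assumes Torch7 with this revision of the nn package installed, and the tensor sizes are illustrative, not taken from the commit.

require 'nn'

local inputSize = 4
local batchSize = 3
local module = nn.Add(inputSize)

-- single sample: output = input + bias, element-wise
local x = torch.randn(inputSize)
local y = module:forward(x)

-- batch of samples: the bias is broadcast across the batch dimension
-- via output:addr(1, ones, bias), an outer product with a ones vector
local xb = torch.randn(batchSize, inputSize)
local yb = module:forward(xb)

-- backward: gradBias accumulates gradOutput summed over the batch
-- via gradBias:addmv(scale, gradOutput:t(), ones)
local go = torch.randn(batchSize, inputSize)
local gi = module:backward(xb, go)

Caching the ones vector in self._ones means it is only reallocated when the batch size changes between calls, rather than on every forward pass.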