Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAdria Puigdomenech <adriap@google.com>2015-08-04 17:52:10 +0300
committerAdria Puigdomenech <adriap@google.com>2015-08-04 17:52:10 +0300
commita1762b989ac6a650307adc0f7620bfe1609301b7 (patch)
tree85c9e9692d66a513e4a1455b793cc1d2008ac927 /Mean.lua
parentd72e7a6949f55694a62fb490726ef9f5758ea059 (diff)
Fix contiguous output bug in nn.Mean
Diffstat (limited to 'Mean.lua')
-rw-r--r--Mean.lua18
1 files changed, 6 insertions, 12 deletions
diff --git a/Mean.lua b/Mean.lua
index 541025f..c8d2b30 100644
--- a/Mean.lua
+++ b/Mean.lua
@@ -4,6 +4,7 @@ function Mean:__init(dimension)
parent.__init(self)
dimension = dimension or 1
self.dimension = dimension
+ self._gradInput = torch.Tensor()
end
function Mean:updateOutput(input)
@@ -15,20 +16,13 @@ function Mean:updateOutput(input)
end
function Mean:updateGradInput(input, gradOutput)
- local size = gradOutput:size():totable()
- local stride = gradOutput:stride():totable()
+ self._gradInput:resizeAs(gradOutput):copy(gradOutput)
+ self._gradInput:mul(1/input:size(self.dimension))
if input:nDimension() > 1 then
- table.insert(size, self.dimension, input:size(self.dimension))
- table.insert(stride, self.dimension, 0)
- else
- size[1] = input:size(1)
- stride[1] = 0
+ self._gradInput = nn.utils.addSingletonDimension(self._gradInput,
+ self.dimension)
end
-
- self.gradInput:resizeAs(gradOutput):copy(gradOutput)
- self.gradInput:mul(1/input:size(self.dimension))
- self.gradInput:resize(torch.LongStorage(size), torch.LongStorage(stride))
-
+ self.gradInput = self._gradInput:expandAs(input)
return self.gradInput
end