diff options
author | Soumith Chintala <soumith@gmail.com> | 2015-08-11 12:13:27 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2015-08-11 12:13:27 +0300 |
commit | 0f5c1cc069519817aee8c22fb16d2af875236fb0 (patch) | |
tree | 85c9e9692d66a513e4a1455b793cc1d2008ac927 | |
parent | d72e7a6949f55694a62fb490726ef9f5758ea059 (diff) | |
parent | a1762b989ac6a650307adc0f7620bfe1609301b7 (diff) |
Merge pull request #345 from kosklain/master
Fix contiguous gradOutput bug in nn.Mean
-rw-r--r-- | Mean.lua | 18 |
1 file changed, 6 insertions, 12 deletions
@@ -4,6 +4,7 @@ function Mean:__init(dimension)
    parent.__init(self)
    dimension = dimension or 1
    self.dimension = dimension
+   self._gradInput = torch.Tensor()
 end

 function Mean:updateOutput(input)
@@ -15,20 +16,13 @@ function Mean:updateOutput(input)
 end

 function Mean:updateGradInput(input, gradOutput)
-   local size = gradOutput:size():totable()
-   local stride = gradOutput:stride():totable()
+   self._gradInput:resizeAs(gradOutput):copy(gradOutput)
+   self._gradInput:mul(1/input:size(self.dimension))
    if input:nDimension() > 1 then
-      table.insert(size, self.dimension, input:size(self.dimension))
-      table.insert(stride, self.dimension, 0)
-   else
-      size[1] = input:size(1)
-      stride[1] = 0
+      self._gradInput = nn.utils.addSingletonDimension(self._gradInput,
+                                                       self.dimension)
    end
-
-   self.gradInput:resizeAs(gradOutput):copy(gradOutput)
-   self.gradInput:mul(1/input:size(self.dimension))
-   self.gradInput:resize(torch.LongStorage(size), torch.LongStorage(stride))
-
+   self.gradInput = self._gradInput:expandAs(input)
    return self.gradInput
 end