Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--Linear.lua7
1 files changed, 5 insertions, 2 deletions
diff --git a/Linear.lua b/Linear.lua
index cc6da4e..2e6635c 100644
--- a/Linear.lua
+++ b/Linear.lua
@@ -52,11 +52,14 @@ end
function Linear:updateGradInput(input, gradOutput)
if self.gradInput then
+ local nElement = self.gradInput:nElement()
+ self.gradInput:resizeAs(input)
+ if self.gradInput:nElement() ~= nElement then
+ self.gradInput:zero()
+ end
if input:dim() == 1 then
- self.gradInput:resizeAs(input)
self.gradInput:addmv(0, 1, self.weight:t(), gradOutput)
elseif input:dim() == 2 then
- self.gradInput:resizeAs(input)
self.gradInput:addmm(0, 1, gradOutput, self.weight)
end