
github.com/torch/nn.git
author    Andreas Fidjeland <andreas.fidjeland@gmail.com>  2013-03-13 22:22:08 +0400
committer Andreas Fidjeland <andreas.fidjeland@gmail.com>  2013-03-13 22:25:46 +0400
commit    0d43338c978fd6ae317bc3cb662868b07794bdc8 (patch)
tree      bfaac982c0bf64e15acaf727aff2622a0e091e41
parent    edf13338e79ec336b5d5fb0b7f95fc8864f747cc (diff)
Linear:updateGradInput avoids NaN and inf
The resize in Linear:updateGradInput can introduce NaN and inf values into the gradients. The resize itself leaves uninitialized garbage in the gradInput tensor. For ordinary finite values the subsequent addmm/addmv (called with beta = 0) overwrites that garbage, but if gradInput contains NaN or inf after the resize, the multiply by zero yields NaN instead of the desired result.
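
The failure mode can be seen without Torch at all. A minimal plain-Lua sketch of the IEEE 754 behaviour the beta = 0 path runs into (the 42 is an arbitrary stand-in for the freshly computed gradient):

    -- With beta = 0 the update is
    --   gradInput = beta * gradInput + alpha * (W^T * gradOutput)
    -- and IEEE 754 arithmetic propagates garbage through the beta term:
    local garbage = 0 / 0       -- NaN left behind by the resize
    local inf     = 1 / 0       -- or inf
    print(0 * garbage + 42)     --> nan (not 42)
    print(0 * inf + 42)         --> nan (not 42)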
-rw-r--r--  Linear.lua  |  7 +++++--
1 file changed, 5 insertions(+), 2 deletions(-)
diff --git a/Linear.lua b/Linear.lua
index cc6da4e..2e6635c 100644
--- a/Linear.lua
+++ b/Linear.lua
@@ -52,11 +52,14 @@ end
 function Linear:updateGradInput(input, gradOutput)
    if self.gradInput then
+      local nElement = self.gradInput:nElement()
+      self.gradInput:resizeAs(input)
+      if self.gradInput:nElement() ~= nElement then
+         self.gradInput:zero()
+      end
       if input:dim() == 1 then
-         self.gradInput:resizeAs(input)
          self.gradInput:addmv(0, 1, self.weight:t(), gradOutput)
       elseif input:dim() == 2 then
-         self.gradInput:resizeAs(input)
          self.gradInput:addmm(0, 1, gradOutput, self.weight)
       end
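
Note the patch zeroes gradInput only when the resize actually changes the element count: a same-size resizeAs keeps the previous (finite) gradients, for which the beta = 0 multiply is harmless, so the common fixed-batch case skips a redundant zero(). A hedged regression sketch of the growing-batch case (layer and batch sizes are arbitrary choices here, and an actual pre-patch failure depends on what the reallocated memory happens to contain):

    require 'nn'

    local lin = nn.Linear(3, 2)

    -- Small batch first, so the next call must grow gradInput's storage.
    local x1 = torch.randn(1, 3)
    lin:forward(x1)
    lin:backward(x1, torch.randn(1, 2))

    -- A larger batch forces gradInput:resizeAs(input) to reallocate;
    -- with the patch, the tensor is zeroed because nElement changed.
    local x2 = torch.randn(8, 3)
    lin:forward(x2)
    local gi = lin:backward(x2, torch.randn(8, 2))

    -- NaN is the only value not equal to itself, so ne(self) counts NaNs.
    assert(gi:ne(gi):sum() == 0, 'gradInput contains NaN')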