| author | Andreas Fidjeland <andreas.fidjeland@gmail.com> | 2013-03-13 22:22:08 +0400 |
|---|---|---|
| committer | Andreas Fidjeland <andreas.fidjeland@gmail.com> | 2013-03-13 22:25:46 +0400 |
| commit | 0d43338c978fd6ae317bc3cb662868b07794bdc8 | (patch) |
| tree | bfaac982c0bf64e15acaf727aff2622a0e091e41 | |
| parent | edf13338e79ec336b5d5fb0b7f95fc8864f747cc | (diff) |
Linear:updateGradInput avoids NaN and inf
The resize in Linear:updateGradInput can introduce NaN and inf into the
gradients: the resize itself leaves garbage in the gradInput tensor.
For normal numbers the subsequent addmm/addmv (called with beta = 0)
clears the garbage, since the old contents are scaled by zero. However,
if gradInput contains NaN or inf after the resize, scaling by zero does
not clear it (0 * NaN is NaN), so the multiply produces NaN instead of
the desired result.
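The failure mode can be reproduced outside Torch. With beta = 0, addmv/addmm effectively compute `gradInput = 0 * gradInput + 1 * result`; for finite garbage this wipes the old contents, but under IEEE 754 arithmetic 0 * NaN stays NaN. A minimal NumPy sketch of the same arithmetic (the variable names are illustrative, not Torch API):

```python
import numpy as np

# Simulated garbage left behind by a resize: one NaN element.
grad_input = np.array([1.5, np.nan, -2.0])

# What addmv(0, 1, ...) effectively computes: 0*old + 1*new.
new_grad = np.array([0.1, 0.2, 0.3])
blended = 0.0 * grad_input + 1.0 * new_grad

# The NaN survives the multiply-by-zero: 0 * NaN == NaN.
print(blended)  # [0.1 nan 0.3]

# The fix in this commit: explicitly zero the buffer when its size
# changed, so the beta = 0 blend starts from clean memory.
grad_input[:] = 0.0
fixed = 0.0 * grad_input + 1.0 * new_grad
print(fixed)  # [0.1 0.2 0.3]
```

Note that the commit only zeroes when `nElement()` changes, because a resize that does not reallocate leaves the old (already valid) contents in place, and the unconditional zero would be wasted work on the common fixed-batch-size path.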
| -rw-r--r-- | Linear.lua | 7 |
|---|---|---|

1 file changed, 5 insertions(+), 2 deletions(-)
```diff
@@ -52,11 +52,14 @@ end

 function Linear:updateGradInput(input, gradOutput)
    if self.gradInput then
+      local nElement = self.gradInput:nElement()
+      self.gradInput:resizeAs(input)
+      if self.gradInput:nElement() ~= nElement then
+         self.gradInput:zero()
+      end
       if input:dim() == 1 then
-         self.gradInput:resizeAs(input)
          self.gradInput:addmv(0, 1, self.weight:t(), gradOutput)
       elseif input:dim() == 2 then
-         self.gradInput:resizeAs(input)
          self.gradInput:addmm(0, 1, gradOutput, self.weight)
       end
```