Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorkoray kavukcuoglu <koray@kavukcuoglu.org>2013-03-23 16:06:28 +0400
committerkoray kavukcuoglu <koray@kavukcuoglu.org>2013-03-23 16:06:28 +0400
commit7ba7f215064346294277bd6b525fd098418659d2 (patch)
tree9c7478de57d546cb9d1e049da9f04b06e8d7a8a5
parent441aa30ea8d7f710c5776be605d4d7cf5746ddc0 (diff)
parent0d43338c978fd6ae317bc3cb662868b07794bdc8 (diff)
Merge pull request #117 from akfidjeland/linear_nan
Linear:updateGradInput avoids NaN and inf
-rw-r--r--Linear.lua7
1 files changed, 5 insertions, 2 deletions
diff --git a/Linear.lua b/Linear.lua
index cc6da4e..2e6635c 100644
--- a/Linear.lua
+++ b/Linear.lua
@@ -52,11 +52,14 @@ end
function Linear:updateGradInput(input, gradOutput)
if self.gradInput then
+ local nElement = self.gradInput:nElement()
+ self.gradInput:resizeAs(input)
+ if self.gradInput:nElement() ~= nElement then
+ self.gradInput:zero()
+ end
if input:dim() == 1 then
- self.gradInput:resizeAs(input)
self.gradInput:addmv(0, 1, self.weight:t(), gradOutput)
elseif input:dim() == 2 then
- self.gradInput:resizeAs(input)
self.gradInput:addmm(0, 1, gradOutput, self.weight)
end