Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/torch7.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCédric Deltheil <cedric@moodstocks.com>2015-10-11 02:38:23 +0300
committerCédric Deltheil <cedric@moodstocks.com>2015-10-11 02:38:23 +0300
commit9d39190ea1cdf39253cddc50b57b9fda3766f813 (patch)
treeaca4d7cf93c6f65975732072d7b035eac081b280 /TensorMath.lua
parent28de02639e8e970532f2635d47ba3eabdf50e04f (diff)
TensorMath.lua: zero init result tensor
A function like `torch.mm` stores its output into a result tensor (res) and uses the `addmm` function behind the scenes. `addmm` performs multi- plications like res = beta x res + alpha x sum with beta=0 and alpha=1 here. If res is not initialized NaN values could occur resulting on NaN output values. The same applies for `torch.bmm` and `torch.mv`.
Diffstat (limited to 'TensorMath.lua')
-rw-r--r--TensorMath.lua21
1 files changed, 21 insertions, 0 deletions
diff --git a/TensorMath.lua b/TensorMath.lua
index ac52dc6..6acb6b8 100644
--- a/TensorMath.lua
+++ b/TensorMath.lua
@@ -314,6 +314,13 @@ for _,Tensor in ipairs({"ByteTensor", "CharTensor",
string.format("TH%s_resize1d(%s, %s->size[0]);", Tensor, arg:carg(), arg.args[5]:carg())
}, '\n')
end,
+ precall=function(arg)
+ return table.concat(
+ {
+ string.format("TH%s_zero(%s);", Tensor, arg:carg()),
+ arg.__metatable.precall(arg)
+ }, '\n')
+ end,
},
{name=real, default=0, invisible=true},
{name=Tensor, default=1, invisible=true},
@@ -332,6 +339,13 @@ for _,Tensor in ipairs({"ByteTensor", "CharTensor",
string.format("TH%s_resize2d(%s, %s->size[0], %s->size[1]);", Tensor, arg:carg(), arg.args[5]:carg(), arg.args[6]:carg())
}, '\n')
end,
+ precall=function(arg)
+ return table.concat(
+ {
+ string.format("TH%s_zero(%s);", Tensor, arg:carg()),
+ arg.__metatable.precall(arg)
+ }, '\n')
+ end,
},
{name=real, default=0, invisible=true},
{name=Tensor, default=1, invisible=true},
@@ -351,6 +365,13 @@ for _,Tensor in ipairs({"ByteTensor", "CharTensor",
Tensor, arg:carg(), arg.args[5]:carg(), arg.args[5]:carg(), arg.args[6]:carg())
}, '\n')
end,
+ precall=function(arg)
+ return table.concat(
+ {
+ string.format("TH%s_zero(%s);", Tensor, arg:carg()),
+ arg.__metatable.precall(arg)
+ }, '\n')
+ end,
},
{name=real, default=0, invisible=true},
{name=Tensor, default=1, invisible=true},