diff options
author | Cédric Deltheil <cedric@moodstocks.com> | 2015-10-11 02:38:23 +0300 |
---|---|---|
committer | Cédric Deltheil <cedric@moodstocks.com> | 2015-10-11 02:38:23 +0300 |
commit | 9d39190ea1cdf39253cddc50b57b9fda3766f813 (patch) | |
tree | aca4d7cf93c6f65975732072d7b035eac081b280 /TensorMath.lua | |
parent | 28de02639e8e970532f2635d47ba3eabdf50e04f (diff) |
TensorMath.lua: zero init result tensor
A function like `torch.mm` stores its output into a result tensor (res)
and uses the `addmm` function behind the scenes. `addmm` performs multi-
plications like res = beta x res + alpha x sum with beta=0 and alpha=1
here. If res is not initialized NaN values could occur resulting on NaN
output values.
The same applies for `torch.bmm` and `torch.mv`.
Diffstat (limited to 'TensorMath.lua')
-rw-r--r-- | TensorMath.lua | 21 |
1 files changed, 21 insertions, 0 deletions
diff --git a/TensorMath.lua b/TensorMath.lua index ac52dc6..6acb6b8 100644 --- a/TensorMath.lua +++ b/TensorMath.lua @@ -314,6 +314,13 @@ for _,Tensor in ipairs({"ByteTensor", "CharTensor", string.format("TH%s_resize1d(%s, %s->size[0]);", Tensor, arg:carg(), arg.args[5]:carg()) }, '\n') end, + precall=function(arg) + return table.concat( + { + string.format("TH%s_zero(%s);", Tensor, arg:carg()), + arg.__metatable.precall(arg) + }, '\n') + end, }, {name=real, default=0, invisible=true}, {name=Tensor, default=1, invisible=true}, @@ -332,6 +339,13 @@ for _,Tensor in ipairs({"ByteTensor", "CharTensor", string.format("TH%s_resize2d(%s, %s->size[0], %s->size[1]);", Tensor, arg:carg(), arg.args[5]:carg(), arg.args[6]:carg()) }, '\n') end, + precall=function(arg) + return table.concat( + { + string.format("TH%s_zero(%s);", Tensor, arg:carg()), + arg.__metatable.precall(arg) + }, '\n') + end, }, {name=real, default=0, invisible=true}, {name=Tensor, default=1, invisible=true}, @@ -351,6 +365,13 @@ for _,Tensor in ipairs({"ByteTensor", "CharTensor", Tensor, arg:carg(), arg.args[5]:carg(), arg.args[5]:carg(), arg.args[6]:carg()) }, '\n') end, + precall=function(arg) + return table.concat( + { + string.format("TH%s_zero(%s);", Tensor, arg:carg()), + arg.__metatable.precall(arg) + }, '\n') + end, }, {name=real, default=0, invisible=true}, {name=Tensor, default=1, invisible=true}, |