diff options
author | Sam Gross <sgross@fb.com> | 2017-07-03 22:33:30 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2017-07-12 19:56:08 +0300 |
commit | eda1c8892fdb054bba0869bc81355eaad5260c2c (patch) | |
tree | 12dfa411dfc34832c5d22aaf5f8904ec709b50f3 | |
parent | ada30dad41f5a72c128514af0a4a67e8f565e40a (diff) |
Avoid two unnecessary copies in addmm backward
The `r_` and `t` tensors become different objects, even though they
point to the same data. Avoid the copy whenever beta=0.
-rw-r--r-- | lib/TH/generic/THTensorMath.c | 12 |
1 files changed, 9 insertions, 3 deletions
diff --git a/lib/TH/generic/THTensorMath.c b/lib/TH/generic/THTensorMath.c index db7a0cb..e85d607 100644 --- a/lib/TH/generic/THTensorMath.c +++ b/lib/TH/generic/THTensorMath.c @@ -1306,7 +1306,9 @@ void THTensor_(addmm)(THTensor *r_, real beta, THTensor *t, real alpha, THTensor if(t != r_) { THTensor_(resizeAs)(r_, t); - THTensor_(copy)(r_, t); + if (beta != 0.0) { + THTensor_(copy)(r_, t); + } } /* r_ */ @@ -1476,7 +1478,9 @@ void THTensor_(addbmm)(THTensor *result, real beta, THTensor *t, real alpha, THT if (t != result) { THTensor_(resizeAs)(result, t); - THTensor_(copy)(result, t); + if (beta != 0.0) { + THTensor_(copy)(result, t); + } } THTensor *matrix1 = THTensor_(new)(); @@ -1517,7 +1521,9 @@ void THTensor_(baddbmm)(THTensor *result, real beta, THTensor *t, real alpha, TH if (t != result) { THTensor_(resizeAs)(result, t); - THTensor_(copy)(result, t); + if (beta != 0.0) { + THTensor_(copy)(result, t); + } } THTensor *matrix1 = THTensor_(new)(); |