diff options
author | Sam Gross <sgross@fb.com> | 2017-07-03 22:33:30 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2017-07-12 19:56:17 +0300 |
commit | d175269377941de0d4846f977121205109c48970 (patch) | |
tree | a27a3532759fd57da155e4bea06d7e8a62f73eb5 | |
parent | 4e085cd2dfb5aba7bb959fd1be9616ee1e35ddf5 (diff) |
Avoid two unnecessary copies in addmm backward
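The `r_` and `t` tensors become different objects, even though they
point to the same data, so the `t != r_` check does not skip the copy.
Avoid the copy whenever beta == 0, since the contents of `t` do not
contribute to the result in that case.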
-rw-r--r-- | lib/THC/generic/THCTensorMathBlas.cu | 14 |
1 files changed, 10 insertions, 4 deletions
diff --git a/lib/THC/generic/THCTensorMathBlas.cu b/lib/THC/generic/THCTensorMathBlas.cu index a6aa074..61c255a 100644 --- a/lib/THC/generic/THCTensorMathBlas.cu +++ b/lib/THC/generic/THCTensorMathBlas.cu @@ -121,7 +121,7 @@ THCTensor_(addmv)(THCState *state, THCTensor *r_, real beta, THCTensor *t, real THCTensor_(addmm)(state, r_, beta, tAsMatrix, alpha, mat, vecAsMatrix); - // r_ will have answer as matrix, need to return a vecotr + // r_ will have answer as matrix, need to return a vector THCTensor_(resize1d)(state, r_, r_->size[0]); THCTensor_(free)(state, vecAsMatrix); THCTensor_(free)(state, tAsMatrix); @@ -245,7 +245,9 @@ THCTensor_(addmm)(THCState *state, THCTensor *r_, real beta, THCTensor *t, real if(t != r_) { THCTensor_(resizeAs)(state, r_, t); - THCTensor_(copy)(state, r_, t); + if (ScalarConvert<real, double>::to(beta) != 0.0) { + THCTensor_(copy)(state, r_, t); + } } /* r_ */ @@ -402,7 +404,9 @@ THCTensor_(addbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t, if (t != result) { THCTensor_(resizeAs)(state, result, t); - THCTensor_(copy)(state, result, t); + if (ScalarConvert<real, double>::to(beta) != 0.0) { + THCTensor_(copy)(state, result, t); + } } THCTensor *slice1 = THCTensor_(new)(state); @@ -450,7 +454,9 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t, if (t != result) { THCTensor_(resizeAs)(state, result, t); - THCTensor_(copy)(state, result, t); + if (ScalarConvert<real, double>::to(beta) != 0.0) { + THCTensor_(copy)(state, result, t); + } } bool transpose_result; |