Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSam Gross <sgross@fb.com>2017-07-03 22:33:30 +0300
committerSoumith Chintala <soumith@gmail.com>2017-07-12 19:56:17 +0300
commitd175269377941de0d4846f977121205109c48970 (patch)
treea27a3532759fd57da155e4bea06d7e8a62f73eb5
parent4e085cd2dfb5aba7bb959fd1be9616ee1e35ddf5 (diff)
Avoid two unnecessary copies in addmm backward
The `r_` and `t` tensors become different objects, even though they point to the same data. Avoid the copy whenever beta=0.
-rw-r--r--lib/THC/generic/THCTensorMathBlas.cu14
1 files changed, 10 insertions, 4 deletions
diff --git a/lib/THC/generic/THCTensorMathBlas.cu b/lib/THC/generic/THCTensorMathBlas.cu
index a6aa074..61c255a 100644
--- a/lib/THC/generic/THCTensorMathBlas.cu
+++ b/lib/THC/generic/THCTensorMathBlas.cu
@@ -121,7 +121,7 @@ THCTensor_(addmv)(THCState *state, THCTensor *r_, real beta, THCTensor *t, real
THCTensor_(addmm)(state, r_, beta, tAsMatrix, alpha, mat, vecAsMatrix);
- // r_ will have answer as matrix, need to return a vecotr
+ // r_ will have answer as matrix, need to return a vector
THCTensor_(resize1d)(state, r_, r_->size[0]);
THCTensor_(free)(state, vecAsMatrix);
THCTensor_(free)(state, tAsMatrix);
@@ -245,7 +245,9 @@ THCTensor_(addmm)(THCState *state, THCTensor *r_, real beta, THCTensor *t, real
if(t != r_)
{
THCTensor_(resizeAs)(state, r_, t);
- THCTensor_(copy)(state, r_, t);
+ if (ScalarConvert<real, double>::to(beta) != 0.0) {
+ THCTensor_(copy)(state, r_, t);
+ }
}
/* r_ */
@@ -402,7 +404,9 @@ THCTensor_(addbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
if (t != result) {
THCTensor_(resizeAs)(state, result, t);
- THCTensor_(copy)(state, result, t);
+ if (ScalarConvert<real, double>::to(beta) != 0.0) {
+ THCTensor_(copy)(state, result, t);
+ }
}
THCTensor *slice1 = THCTensor_(new)(state);
@@ -450,7 +454,9 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
if (t != result) {
THCTensor_(resizeAs)(state, result, t);
- THCTensor_(copy)(state, result, t);
+ if (ScalarConvert<real, double>::to(beta) != 0.0) {
+ THCTensor_(copy)(state, result, t);
+ }
}
bool transpose_result;