From d175269377941de0d4846f977121205109c48970 Mon Sep 17 00:00:00 2001 From: Sam Gross Date: Mon, 3 Jul 2017 12:33:30 -0700 Subject: Avoid two unnecessary copies in addmm backward The `r_` and `t` tensors become different objects, even though they point to the same data. Avoid the copy whenever beta=0. --- lib/THC/generic/THCTensorMathBlas.cu | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/lib/THC/generic/THCTensorMathBlas.cu b/lib/THC/generic/THCTensorMathBlas.cu index a6aa074..61c255a 100644 --- a/lib/THC/generic/THCTensorMathBlas.cu +++ b/lib/THC/generic/THCTensorMathBlas.cu @@ -121,7 +121,7 @@ THCTensor_(addmv)(THCState *state, THCTensor *r_, real beta, THCTensor *t, real THCTensor_(addmm)(state, r_, beta, tAsMatrix, alpha, mat, vecAsMatrix); - // r_ will have answer as matrix, need to return a vecotr + // r_ will have answer as matrix, need to return a vector THCTensor_(resize1d)(state, r_, r_->size[0]); THCTensor_(free)(state, vecAsMatrix); THCTensor_(free)(state, tAsMatrix); @@ -245,7 +245,9 @@ THCTensor_(addmm)(THCState *state, THCTensor *r_, real beta, THCTensor *t, real if(t != r_) { THCTensor_(resizeAs)(state, r_, t); - THCTensor_(copy)(state, r_, t); + if (ScalarConvert<real, double>::to(beta) != 0.0) { + THCTensor_(copy)(state, r_, t); + } } /* r_ */ @@ -402,7 +404,9 @@ THCTensor_(addbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t, if (t != result) { THCTensor_(resizeAs)(state, result, t); - THCTensor_(copy)(state, result, t); + if (ScalarConvert<real, double>::to(beta) != 0.0) { + THCTensor_(copy)(state, result, t); + } } THCTensor *slice1 = THCTensor_(new)(state); @@ -450,7 +454,9 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t, if (t != result) { THCTensor_(resizeAs)(state, result, t); - THCTensor_(copy)(state, result, t); + if (ScalarConvert<real, double>::to(beta) != 0.0) { + THCTensor_(copy)(state, result, t); + } } bool transpose_result; -- cgit v1.2.3