diff options
author | Denis Yarats <denisyarats@gmail.com> | 2016-11-15 22:55:36 +0300 |
---|---|---|
committer | Denis Yarats <denisyarats@gmail.com> | 2016-11-15 22:55:36 +0300 |
commit | b20b7c9eb98663b7e994d7efda35fce08d046e96 (patch) | |
tree | 94588dbfb328ce97530ed556fbc1f00367b10d9d | |
parent | 5774690224c64b8cc3f9c2c6f385febb6fb8ad41 (diff) |
[cutorch] remove syncing point from baddbmm
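This change removes the host-to-device (HtoD) copies inside baddbmm. These
copies introduce a synchronization point, which causes slowdowns in
multi-GPU training. The per-batch matrix pointer arrays are now filled
directly on the device by a small kernel (createBatchGemmBuffer) instead.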
Test plan: Run the unit tests for baddbmm.
-rw-r--r-- | lib/THC/generic/THCTensorMathBlas.cu | 38 |
1 files changed, 20 insertions, 18 deletions
diff --git a/lib/THC/generic/THCTensorMathBlas.cu b/lib/THC/generic/THCTensorMathBlas.cu index d4bd3c2..18ca290 100644 --- a/lib/THC/generic/THCTensorMathBlas.cu +++ b/lib/THC/generic/THCTensorMathBlas.cu @@ -389,6 +389,14 @@ THCTensor_(addbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t, #endif } +__global__ void createBatchGemmBuffer(const real** buffer, real* data, + long stride, long num_batches) { + const long idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < num_batches) { + buffer[idx] = data + idx * stride; + } +} + THC_API void THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t, real alpha, THCTensor *batch1, THCTensor *batch2) { @@ -487,15 +495,6 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t, // Compute pointers to matrices in each batch. long num_batches = result_->size[0]; size_t matrices_size = num_batches * sizeof(real*); - const real **matrices1 = (const real **)THAlloc(matrices_size); - const real **matrices2 = (const real **)THAlloc(matrices_size); - real **result_matrices = (real **)THAlloc(matrices_size); - for (int i = 0; i < num_batches; ++i) - { - matrices1[i] = THCTensor_(data)(state, batch1_) + i * batch1_->stride[0]; - matrices2[i] = THCTensor_(data)(state, batch2_) + i * batch2_->stride[0]; - result_matrices[i] = THCTensor_(data)(state, result_) + i * result_->stride[0]; - } // Copy pointers to device. 
const real **d_matrices1, **d_matrices2; @@ -504,12 +503,18 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t, THCudaCheck(THCudaMalloc(state, (void**)&d_matrices2, matrices_size)); THCudaCheck(THCudaMalloc(state, (void**)&d_result_matrices, matrices_size)); - THCudaCheck(cudaMemcpyAsync(d_matrices1, matrices1, matrices_size, - cudaMemcpyHostToDevice, THCState_getCurrentStream(state))); - THCudaCheck(cudaMemcpyAsync(d_matrices2, matrices2, matrices_size, - cudaMemcpyHostToDevice, THCState_getCurrentStream(state))); - THCudaCheck(cudaMemcpyAsync(d_result_matrices, result_matrices, matrices_size, - cudaMemcpyHostToDevice, THCState_getCurrentStream(state))); + const long block = 512; + const long grid = (num_batches + block - 1) / block; + + createBatchGemmBuffer<<<grid, block, 0, THCState_getCurrentStream(state)>>>( + d_matrices1, THCTensor_(data)(state, batch1_), batch1_->stride[0], + num_batches); + createBatchGemmBuffer<<<grid, block, 0, THCState_getCurrentStream(state)>>>( + d_matrices2, THCTensor_(data)(state, batch2_), batch2_->stride[0], + num_batches); + createBatchGemmBuffer<<<grid, block, 0, THCState_getCurrentStream(state)>>>( + (const real**)d_result_matrices, THCTensor_(data)(state,result_), + result_->stride[0], num_batches); #ifdef THC_REAL_IS_FLOAT THCudaBlas_SgemmBatched( @@ -544,9 +549,6 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t, THCudaFree(state, d_matrices1); THCudaFree(state, d_matrices2); THCudaFree(state, d_result_matrices); - THFree(matrices1); - THFree(matrices2); - THFree(result_matrices); if (batch1_ != batch1) { THCTensor_(free)(state, batch1_); |