Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDenis Yarats <denisyarats@gmail.com>2016-11-15 22:55:36 +0300
committerDenis Yarats <denisyarats@gmail.com>2016-11-15 22:55:36 +0300
commitb20b7c9eb98663b7e994d7efda35fce08d046e96 (patch)
tree94588dbfb328ce97530ed556fbc1f00367b10d9d
parent5774690224c64b8cc3f9c2c6f385febb6fb8ad41 (diff)
[cutorch] remove syncing point from baddbmm
This change removes HtoD copies inside baddbmm. These copies introduce a syncing point which causes slowdowns in multi-GPU training. Test plan: Run unit tests for baddbmm.
-rw-r--r--lib/THC/generic/THCTensorMathBlas.cu | 38
1 file changed, 20 insertions, 18 deletions
diff --git a/lib/THC/generic/THCTensorMathBlas.cu b/lib/THC/generic/THCTensorMathBlas.cu
index d4bd3c2..18ca290 100644
--- a/lib/THC/generic/THCTensorMathBlas.cu
+++ b/lib/THC/generic/THCTensorMathBlas.cu
@@ -389,6 +389,14 @@ THCTensor_(addbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
#endif
}
+__global__ void createBatchGemmBuffer(const real** buffer, real* data,
+ long stride, long num_batches) {
+ const long idx = blockIdx.x * blockDim.x + threadIdx.x;
+ if (idx < num_batches) {
+ buffer[idx] = data + idx * stride;
+ }
+}
+
THC_API void
THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
real alpha, THCTensor *batch1, THCTensor *batch2) {
@@ -487,15 +495,6 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
// Compute pointers to matrices in each batch.
long num_batches = result_->size[0];
size_t matrices_size = num_batches * sizeof(real*);
- const real **matrices1 = (const real **)THAlloc(matrices_size);
- const real **matrices2 = (const real **)THAlloc(matrices_size);
- real **result_matrices = (real **)THAlloc(matrices_size);
- for (int i = 0; i < num_batches; ++i)
- {
- matrices1[i] = THCTensor_(data)(state, batch1_) + i * batch1_->stride[0];
- matrices2[i] = THCTensor_(data)(state, batch2_) + i * batch2_->stride[0];
- result_matrices[i] = THCTensor_(data)(state, result_) + i * result_->stride[0];
- }
// Copy pointers to device.
const real **d_matrices1, **d_matrices2;
@@ -504,12 +503,18 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
THCudaCheck(THCudaMalloc(state, (void**)&d_matrices2, matrices_size));
THCudaCheck(THCudaMalloc(state, (void**)&d_result_matrices, matrices_size));
- THCudaCheck(cudaMemcpyAsync(d_matrices1, matrices1, matrices_size,
- cudaMemcpyHostToDevice, THCState_getCurrentStream(state)));
- THCudaCheck(cudaMemcpyAsync(d_matrices2, matrices2, matrices_size,
- cudaMemcpyHostToDevice, THCState_getCurrentStream(state)));
- THCudaCheck(cudaMemcpyAsync(d_result_matrices, result_matrices, matrices_size,
- cudaMemcpyHostToDevice, THCState_getCurrentStream(state)));
+ const long block = 512;
+ const long grid = (num_batches + block - 1) / block;
+
+ createBatchGemmBuffer<<<grid, block, 0, THCState_getCurrentStream(state)>>>(
+ d_matrices1, THCTensor_(data)(state, batch1_), batch1_->stride[0],
+ num_batches);
+ createBatchGemmBuffer<<<grid, block, 0, THCState_getCurrentStream(state)>>>(
+ d_matrices2, THCTensor_(data)(state, batch2_), batch2_->stride[0],
+ num_batches);
+ createBatchGemmBuffer<<<grid, block, 0, THCState_getCurrentStream(state)>>>(
+ (const real**)d_result_matrices, THCTensor_(data)(state,result_),
+ result_->stride[0], num_batches);
#ifdef THC_REAL_IS_FLOAT
THCudaBlas_SgemmBatched(
@@ -544,9 +549,6 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
THCudaFree(state, d_matrices1);
THCudaFree(state, d_matrices2);
THCudaFree(state, d_result_matrices);
- THFree(matrices1);
- THFree(matrices2);
- THFree(result_matrices);
if (batch1_ != batch1) {
THCTensor_(free)(state, batch1_);