diff options
author | Soumith Chintala <soumith@gmail.com> | 2017-03-15 21:35:59 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-03-15 21:35:59 +0300 |
commit | 7293be03f00e902179fcb6518b7c43bca2f5790c (patch) | |
tree | 827b404a4cec6b278f8f3c8b1995de566920ba0f | |
parent | 12d9add9301ced59ae9c910f8b511d0fd7f285a1 (diff) | |
parent | 8d74ea923b8ea17ac17bb0853809c0456107f204 (diff) |
Merge pull request #451 from wickedfoo/faster-lookup-table
Improve cunn LookupTable performance for large batch sizes
-rw-r--r-- | lib/THCUNN/LookupTable.cu | 5 | ||||
-rw-r--r-- | lib/THCUNN/generic/LookupTable.cu | 92 |
2 files changed, 63 insertions, 34 deletions
diff --git a/lib/THCUNN/LookupTable.cu b/lib/THCUNN/LookupTable.cu index 29cb3a1..e626632 100644 --- a/lib/THCUNN/LookupTable.cu +++ b/lib/THCUNN/LookupTable.cu @@ -12,10 +12,7 @@ #include <thrust/unique.h> #include "THCHalf.h" #include "THCHalfAutoNumerics.cuh" - -#ifndef DIVUP -#define DIVUP(x, y) (((x) + (y) - 1) / (y)) -#endif +#include "THCTensorSort.cuh" const int WARP_SIZE = 32; diff --git a/lib/THCUNN/generic/LookupTable.cu b/lib/THCUNN/generic/LookupTable.cu index fa7c5ac..955b29b 100644 --- a/lib/THCUNN/generic/LookupTable.cu +++ b/lib/THCUNN/generic/LookupTable.cu @@ -8,18 +8,17 @@ void THNN_(LookupTable_accGradParameters)( THCTensor *gradOutput, THCTensor *gradWeight, THCIndexTensor *count, - THCIndexTensor *sorted, - THCIndexTensor *indices, + THCIndexTensor *sortedIndices, + THCIndexTensor *origIndices, bool scaleGradByFreq, int paddingValue, accreal scale_) { real scale = ScalarConvert<accreal, real>::to(scale_); - THCUNN_assertSameGPU(state, 5, input, gradOutput, gradWeight, sorted, indices); + THCUNN_assertSameGPU(state, 5, input, gradOutput, gradWeight, sortedIndices, origIndices); gradOutput = THCTensor_(newContiguous)(state, gradOutput); if (!(THCIndexTensor_(isContiguous)(state, input) && - THCTensor_(isContiguous)(state, gradWeight))) - { + THCTensor_(isContiguous)(state, gradWeight))) { THError("Tensors must be contiguous"); } @@ -30,12 +29,15 @@ void THNN_(LookupTable_accGradParameters)( } ptrdiff_t numel = THCIndexTensor_(nElement)(state, input); - long stride = gradWeight->stride[0]; + long stride = THCTensor_(stride)(state, gradWeight, 0); cudaStream_t stream = THCState_getCurrentStream(state); if (numel <= 768 && !scaleGradByFreq) { - cunn_LookupTable_accGradParametersKernelByFeature<<<DIVUP(stride,4), 128, 0, stream>>>( + dim3 grid(THCCeilDiv(stride, (long) 4)); + dim3 block(128); + + cunn_LookupTable_accGradParametersKernelByFeature<<<grid, block, 0, stream>>>( THCIndexTensor_(data)(state, input), THCTensor_(data)(state, gradOutput), THCTensor_(data)(state, gradWeight), @@ -49,35 +51,62 @@ void THNN_(LookupTable_accGradParameters)( } THLongStorage *inputSize = THCIndexTensor_(newSizeOf)(state, input); - THCIndexTensor_(resize)(state, sorted, inputSize, NULL); - THCIndexTensor_(resize)(state, indices, inputSize, NULL); + THCIndexTensor_(resize)(state, sortedIndices, inputSize, NULL); + THCIndexTensor_(resize)(state, origIndices, inputSize, NULL); THLongStorage_free(inputSize); - // Sort the inputs into sorted with the corresponding indices - THCIndexTensor_(sort)(state, sorted, indices, input, 0, 0); + // Sort the inputs into sorted with the corresponding indices; we + // don't need a stable or multidimensional sort, so just use Thrust + // directly + { + THCIndexTensor_(copy)(state, sortedIndices, input); + + THCThrustAllocator thrustAlloc(state); + + thrust::device_ptr<THCIndex_t> + sortedIndicesIter(THCIndexTensor_(data)(state, sortedIndices)); + thrust::device_ptr<THCIndex_t> + origIndicesIter(THCIndexTensor_(data)(state, origIndices)); + + // Fill sortedOrigIndices with sequential indices + thrust::counting_iterator<THCIndex_t> countIter(TH_INDEX_BASE); + + thrust::copy( +#if CUDA_VERSION >= 7000 + thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)), +#endif + countIter, countIter + numel, origIndicesIter); + + // Sort; a stable sort is not required + thrust::sort_by_key( +#if CUDA_VERSION >= 7000 + thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)), +#endif + sortedIndicesIter, sortedIndicesIter + numel, + origIndicesIter, ThrustLTOp<long>()); + } - THCIndex_t *sorted_data = THCIndexTensor_(data)(state, sorted); - THCIndex_t *indices_data = THCIndexTensor_(data)(state, indices); + THCIndex_t *sortedIndices_data = THCIndexTensor_(data)(state, sortedIndices); + THCIndex_t *origIndices_data = THCIndexTensor_(data)(state, origIndices); THCIndex_t *count_data = NULL; - if (scaleGradByFreq) - { + if (scaleGradByFreq) { THCIndexTensor_(resizeAs)(state, count, input); count_data = THCIndexTensor_(data)(state, count); THCThrustAllocator thrustAlloc(state); - thrust::device_ptr<THCIndex_t> sorted_ptr(sorted_data); + thrust::device_ptr<THCIndex_t> sortedIndices_ptr(sortedIndices_data); thrust::device_ptr<THCIndex_t> count_ptr(count_data); - // Compute an increasing sequence per unique item in sorted: + // Compute an increasing sequence per unique item in sortedIndices: // sorted: 2 5 5 5 7 7 8 9 9 // count: 1 1 2 3 1 2 1 1 2 thrust::inclusive_scan_by_key( #if CUDA_VERSION >= 7000 thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)), #endif - sorted_ptr, - sorted_ptr + numel, + sortedIndices_ptr, + sortedIndices_ptr + numel, thrust::make_constant_iterator(1), count_ptr ); @@ -89,8 +118,8 @@ void THNN_(LookupTable_accGradParameters)( #if CUDA_VERSION >= 7000 thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)), #endif - thrust::make_reverse_iterator(sorted_ptr + numel), - thrust::make_reverse_iterator(sorted_ptr), + thrust::make_reverse_iterator(sortedIndices_ptr + numel), + thrust::make_reverse_iterator(sortedIndices_ptr), thrust::make_reverse_iterator(count_ptr + numel), thrust::make_reverse_iterator(count_ptr + numel), thrust::equal_to<long>(), @@ -98,11 +127,11 @@ void THNN_(LookupTable_accGradParameters)( ); } - dim3 grid(DIVUP(numel,4), DIVUP(stride,128)); + dim3 grid(THCCeilDiv(numel, (ptrdiff_t) 4), THCCeilDiv(stride, (long) 128)); dim3 block(32, 4); cunn_LookupTable_accGradParametersKernel<real, accreal><<<grid, block, 0, stream>>>( - sorted_data, - indices_data, + sortedIndices_data, + origIndices_data, THCTensor_(data)(state, gradOutput), THCTensor_(data)(state, gradWeight), count_data, @@ -127,22 +156,25 @@ void THNN_(LookupTable_renorm)( real normType = ScalarConvert<accreal, real>::to(normType_); THCUNN_assertSameGPU(state, 2, idx, weight); if (!(THCIndexTensor_(isContiguous)(state, idx) && - THCTensor_(isContiguous)(state, weight))) - { + THCTensor_(isContiguous)(state, weight))) { THError("Tensors must be contiguous"); } - if (THCIndexTensor_(nDimension)(state, idx) != 1) + + if (THCIndexTensor_(nDimension)(state, idx) != 1) { THError("idx must be a vector"); - if (normType <= 0) + } + + if (normType <= 0) { THError("non-positive-norm not supported"); + } THCIndex_t numel = THCIndexTensor_(nElement)(state, idx); - long stride = weight->stride[0]; + long stride = THCTensor_(stride)(state, weight, 0); // get the unique indices thrust::device_ptr<real> weight_ptr(THCTensor_(data)(state, weight)); thrust::device_ptr<THCIndex_t> idx_ptr(THCIndexTensor_(data)(state, idx)); - thrust::device_ptr<THCIndex_t> end_ptr = thrust::unique(idx_ptr, idx_ptr+numel); + thrust::device_ptr<THCIndex_t> end_ptr(thrust::unique(idx_ptr, idx_ptr+numel)); numel = end_ptr - idx_ptr; pow_v<real, accreal> unary_pow(normType); |