diff options
author | Pavan Yalamanchili <pyalamanchili@twitter.com> | 2017-02-16 02:42:13 +0300 |
---|---|---|
committer | Pavan Yalamanchili <pyalamanchili@twitter.com> | 2017-03-25 02:34:31 +0300 |
commit | 97940f0a81b689657234ee456ac60e35fe72d043 (patch) | |
tree | 81975c1419c2dddf4e5cf37f1eb4d659bb3890dc | |
parent | e7783b3ab3153139f64bc72de4519e08c90354a8 (diff) |
Improving the performance of IndexLinear:updateOutput
- Removes separate kernel for updateOutputTrain
-rw-r--r-- | lib/THCUNN/IndexLinear.cu | 185 | ||||
-rw-r--r-- | lib/THCUNN/generic/IndexLinear.cu | 38 |
2 files changed, 91 insertions, 132 deletions
diff --git a/lib/THCUNN/IndexLinear.cu b/lib/THCUNN/IndexLinear.cu index b52f54b..fb2dc93 100644 --- a/lib/THCUNN/IndexLinear.cu +++ b/lib/THCUNN/IndexLinear.cu @@ -8,16 +8,45 @@ const int THREADS_PER_BLOCK = 256; const int THREADS_X = 32; const int THREADS_Y = THREADS_PER_BLOCK / THREADS_X; const int REPEAT = 32; +const long NNZ_PER_BLOCK_MAX = 1024; /* sign MACRO */ #ifndef clamp #define clamp(a, low, high) max(min((a), (high)), (low)) #endif -template<typename Ty> +#ifndef ATOMIC_REAL_MINMAX +#define ATOMIC_REAL_MINMAX(func) \ + __device__ void atomic_##func(double *address, double val) { \ + unsigned long long int* address_as_ull = (unsigned long long int*)address; \ + unsigned long long int old = *address_as_ull; \ + unsigned long long int assumed; \ + do { \ + assumed = old; \ + old = atomicCAS(address_as_ull, assumed, \ + __double_as_longlong(func(val, __longlong_as_double(assumed)))); \ + } while (assumed != old); \ + } \ + __device__ void atomic_##func(float *address, float val) { \ + int* address_as_int = (int*)address; \ + int old = *address_as_int; \ + int assumed; \ + do { \ + assumed = old; \ + old = atomicCAS(address_as_int, assumed, \ + __float_as_int(func(val, __int_as_float(assumed)))); \ + } while (assumed != old); \ + } \ + +ATOMIC_REAL_MINMAX(max) +ATOMIC_REAL_MINMAX(min) +#endif + +template<typename Ty, bool train> __global__ static void updateOutput( Ty *output, + Ty *normalizedValues, const Ty *values, const long *cumSumSizes, const long *keys, @@ -27,11 +56,12 @@ void updateOutput( const Ty *bias, const long weightStride, const long keysOffset, - int maxNormalize) + const int maxNormalize, + const int nnzPerBlock) { /******************************************************* - * Adopted from the following file in arrayfire - * https://github.com/arrayfire/arrayfire/blob/v3.4.1/src/backend/opencl/kernel/csrmv.cl + * Adapted from the following file in arrayfire + * https://github.com/arrayfire/arrayfire/blob/v3.4.1/src/backend/opencl/kernel/csrmm.cl * ******************************************************* * Original copyright notice can be seen below: @@ -44,16 +74,17 @@ void updateOutput( * http://arrayfire.com/licenses/BSD-3-Clause ********************************************************/ - const long goff = blockIdx.x * blockDim.x; - const long gidx = goff + threadIdx.x; - const long tid = threadIdx.x; + const long tidx = threadIdx.x; + const long tidy = threadIdx.y; + const long tid = tidy * blockDim.x + tidx; + const long gidx = blockIdx.x * blockDim.x + tidx; + Ty *nWeight = weight; // Offset the number of elements specified by maxNormalize weight += gidx + maxNormalize; output += gidx; - bool within_N = (gidx < outDim); __shared__ Ty s_values[THREADS_PER_BLOCK]; @@ -63,23 +94,36 @@ void updateOutput( // if (rowId >= batchSize) return; // Load the nonzero column offsets for current row - const long batchStart = rowId == 0 ? 0 : cumSumSizes[rowId - 1]; - const long batchEnd = cumSumSizes[rowId]; + const long batchStart = (rowId == 0 ? 0 : cumSumSizes[rowId - 1]) + blockIdx.z * nnzPerBlock; + const long batchEnd = min(batchStart + nnzPerBlock, cumSumSizes[rowId]); + const long batchStride = blockDim.x * blockDim.y; - Ty outval = 0; + Ty outVal = 0; // Since the number of nonzero elements might be greater than local memory available, // Load only part of the row into local memory, perform partial dot, repeat until done. - for (long id = batchStart; id < batchEnd; id += blockDim.x) { + for (long id = batchStart; id < batchEnd; id += batchStride) { // Load the current chunk of the row into local memory - long lim = min(batchEnd - id, (long)blockDim.x); + long lim = min(batchEnd - id, (long)batchStride); - // Subtract 1 from keys[id + tid] to convert base 1 to base 0 long key = tid < lim ? keys[id + tid] + keysOffset : -1; Ty val = tid < lim ? values[id + tid] : 0; - - if (maxNormalize && tid < lim) { - Ty *nWeightCurr = nWeight + key * weightStride; - val = clamp(val * nWeightCurr[1], -1.0, 1.0) + nWeightCurr[3]; + long nWeightOffset = key * weightStride; + + if (tid < lim && maxNormalize) { + Ty *nWeightCurr = nWeight + nWeightOffset; + if (train) { + Ty absVal = fabs(val); + Ty maxVal = nWeight[key * weightStride + 0]; + if (absVal > maxVal) { + // Updating maxVal and invMaxVal. Go hogwild! + atomic_max(nWeightCurr + 0, absVal); + atomic_min(nWeightCurr + 1, 1.0/absVal); + } + val = val * nWeightCurr[1] + nWeightCurr[3]; + normalizedValues[id + tid] = val; + } else { + val = clamp(val * nWeightCurr[1], -1.0, 1.0) + nWeightCurr[3]; + } } s_keys[tid] = key; @@ -87,99 +131,8 @@ void updateOutput( __syncthreads(); // Perform a single "dot" operation for each thread - for (long idy = 0; within_N && idy < lim; idy++) { - outval += s_values[idy] * weight[weightStride * s_keys[idy]]; - } - __syncthreads(); - } - - if (within_N) { - output[rowId * outDim] = outval + bias[gidx]; - } -} - -// This kernel is launched with [M x 1] blocks of size [X x Y]. -// Each block writes X entries to the output for the given batchId. -template<typename Ty> -__global__ static -void updateOutputTrain( - Ty *output, - Ty *normalizedValues, - const Ty *values, - const long *cumSumSizes, - const long *keys, - const long batchSize, - const long outDim, - Ty *weight, - const Ty *bias, - const long weightStride, - const long keysOffset, - int maxNormalize, - long batchId) -{ - const long tidx = threadIdx.x; - const long tidy = threadIdx.y; - const long tid = tidy * blockDim.x + tidx; - const long gidx = blockIdx.x * blockDim.x + tidx; - - const long batchStart = batchId == 0 ? 0 : cumSumSizes[batchId - 1]; - const long batchEnd = cumSumSizes[batchId]; - const long batchLimit = batchEnd - batchStart; - - // A dot operation is performed by a single block. - // Calculate the number of iterations required to load all elements from current batch. - const long iters = divup(batchLimit, blockDim.x * blockDim.y); - - Ty *nWeight = weight; - // Offset to the current output id - weight += maxNormalize + gidx; - output += batchId * outDim + gidx; - - // Offset to the current batch - keys += batchStart; - values += batchStart; - normalizedValues += batchStart; - - __shared__ Ty s_values[THREADS_PER_BLOCK]; - __shared__ long s_keys[THREADS_PER_BLOCK]; - - Ty outVal = 0; - // Not bailing early because we need __syncthreads later - for (long n = 0; n < iters; n++) { - long off = n * blockDim.y * blockDim.x; - long lim = min((long)blockDim.y * blockDim.x, batchLimit - off); - - - // Each block uses all of its threads to load data. - // This ensures coalesced reads from global memory. - // Each variable in shared memory is then used by all the threads in a warp. - if (tid < lim) { - Ty val = values[off + tid]; - long key = keys[off + tid] + keysOffset; - long nWeightOffset = key * weightStride; - - Ty absVal = fabs(val); - Ty maxVal = nWeight[key * weightStride + 0]; - if (absVal > maxVal) { - // Updating maxVal and invMaxVal - nWeight[nWeightOffset + 0] = absVal; - nWeight[nWeightOffset + 1] = 1.0/absVal; - maxVal = absVal; - } - - // TODO: implement a smarter update scale following the CPU implementation. - nWeight[nWeightOffset + 2] = 1.0; - s_values[tid] = val / maxVal + nWeight[nWeightOffset + 3]; - s_keys[tid] = key; - normalizedValues[off + tid] = s_values[tid]; - } - __syncthreads(); - - if (gidx < outDim) { - // Performing the partial dot operation for each thread - for (long id = tidy; id < lim; id += blockDim.y) { - outVal += s_values[id] * weight[weightStride * s_keys[id]]; - } + for (long idy = tidy; within_N && idy < lim; idy += blockDim.y) { + outVal += s_values[idy] * weight[weightStride * s_keys[idy]]; } __syncthreads(); } @@ -192,9 +145,13 @@ void updateOutputTrain( if (tidy < y) s_values[tid] = s_values[tid] + s_values[tid + y * blockDim.x]; } - // Writing the final value from the first lane into the output. - if (gidx < outDim && tidy == 0) { - *output = s_values[tid] + bias[gidx]; + if (within_N && tidy == 0) { + Ty val = s_values[tid] + (blockIdx.z == 0 ? bias[gidx] : 0); + if (gridDim.z == 1) { + output[rowId * outDim] = val; + } else { + atomicAdd(output + rowId * outDim, val); + } } } diff --git a/lib/THCUNN/generic/IndexLinear.cu b/lib/THCUNN/generic/IndexLinear.cu index e1f8642..ae96148 100644 --- a/lib/THCUNN/generic/IndexLinear.cu +++ b/lib/THCUNN/generic/IndexLinear.cu @@ -47,6 +47,7 @@ void THNN_(IndexLinear_updateOutput)( long weightStride = weight->stride[0]; int maxNormalize = wDim - outDim; long keysSize = keys->size[0]; + long nnzPerRow = divup(keysSize, batchSize); THCTensor_(resize2d)(state, output, batchSize, outDim); long *keysData = THCudaLongTensor_data (state, keys); @@ -57,28 +58,29 @@ void THNN_(IndexLinear_updateOutput)( real *outData = THCTensor_(data) (state, output); cudaStream_t stream = THCState_getCurrentStream(state); + dim3 threads(THREADS_X, THREADS_Y); + int blocks_x = divup(outDim, threads.x); + int blocks_y = batchSize; + int nnzPerBlock = ((outDim == 1 || batchSize == 1) ? THREADS_X : NNZ_PER_BLOCK_MAX); + int blocks_z = divup(nnzPerRow, nnzPerBlock); + + dim3 blocks(blocks_x, blocks_y, blocks_z); + + if (blocks_z > 1) { + THCudaCheck(cudaMemsetAsync(outData, 0, outDim * batchSize * sizeof(real), stream)); + } + real *normalizedValuesData = NULL; if (maxNormalize && train) { THCTensor_(resize1d)(state, normalizedValues, keysSize); - real *normalizedValuesData = THCTensor_(data)(state, normalizedValues); - dim3 threads(THREADS_X, THREADS_Y); - int blocks_x = divup(outDim, threads.x); - int blocks_y = 1; - dim3 blocks(blocks_x, blocks_y); - for (long batchId = 0; batchId < batchSize; batchId++) { - updateOutputTrain<real><<<blocks, threads, 0, stream>>> - (outData, normalizedValuesData, valuesData, cumSumSizesData, keysData, - batchSize, outDim, weightData, biasData, weightStride, keysOffset, - maxNormalize, batchId); - } + normalizedValuesData = THCTensor_(data)(state, normalizedValues); + updateOutput<real, true><<<blocks, threads, 0, stream>>> + (outData, normalizedValuesData, valuesData, cumSumSizesData, keysData, + batchSize, outDim, weightData, biasData, weightStride, keysOffset, maxNormalize, nnzPerBlock); } else { - int threads = THREADS_PER_BLOCK; - int blocks_x = divup(outDim, threads); - int blocks_y = batchSize; - dim3 blocks(blocks_x, blocks_y); - updateOutput<real><<<blocks, threads, 0, stream>>> - (outData, valuesData, cumSumSizesData, keysData, batchSize, outDim, - weightData, biasData, weightStride, keysOffset, maxNormalize); + updateOutput<real, false><<<blocks, threads, 0, stream>>> + (outData, normalizedValuesData, valuesData, cumSumSizesData, keysData, + batchSize, outDim, weightData, biasData, weightStride, keysOffset, maxNormalize, nnzPerBlock); } } |