github.com/torch/cunn.git
author    Pavan Yalamanchili <pyalamanchili@twitter.com>  2017-02-16 02:42:13 +0300
committer Pavan Yalamanchili <pyalamanchili@twitter.com>  2017-03-25 02:34:31 +0300
commit    97940f0a81b689657234ee456ac60e35fe72d043 (patch)
tree      81975c1419c2dddf4e5cf37f1eb4d659bb3890dc
parent    e7783b3ab3153139f64bc72de4519e08c90354a8 (diff)
Improving the performance of IndexLinear:updateOutput
- Removes separate kernel for updateOutputTrain
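
The patch folds the old updateOutputTrain kernel into updateOutput by adding a bool "train" template parameter, so the training-only normalization path is compiled in only when needed and a single kernel body serves both modes. A minimal standalone sketch of that dispatch pattern (hypothetical kernel and launcher names, not the actual THCUNN code):

#include <cuda_runtime.h>

// Sketch of the "one kernel body, two compile-time variants" pattern used by
// the patch (hypothetical names; the real kernel is updateOutput in
// lib/THCUNN/IndexLinear.cu).
template <typename Ty, bool train>
__global__ void scaleKernel(Ty *out, const Ty *in, long n) {
  long i = blockIdx.x * (long)blockDim.x + threadIdx.x;
  if (i >= n) return;
  Ty v = in[i];
  if (train) {
    // Resolved at compile time, so the inference instantiation carries no
    // trace of the training-only work.
    v *= Ty(2);
  }
  out[i] = v;
}

// Host-side dispatch mirrors the patch: pick the instantiation from the
// runtime train flag, with exactly one kernel launch either way.
void launchScale(float *out, const float *in, long n, bool train,
                 cudaStream_t stream) {
  dim3 threads(256);
  dim3 blocks((unsigned)((n + threads.x - 1) / threads.x));
  if (train)
    scaleKernel<float, true><<<blocks, threads, 0, stream>>>(out, in, n);
  else
    scaleKernel<float, false><<<blocks, threads, 0, stream>>>(out, in, n);
}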
-rw-r--r--  lib/THCUNN/IndexLinear.cu          | 185
-rw-r--r--  lib/THCUNN/generic/IndexLinear.cu  |  38
2 files changed, 91 insertions(+), 132 deletions(-)
diff --git a/lib/THCUNN/IndexLinear.cu b/lib/THCUNN/IndexLinear.cu
index b52f54b..fb2dc93 100644
--- a/lib/THCUNN/IndexLinear.cu
+++ b/lib/THCUNN/IndexLinear.cu
@@ -8,16 +8,45 @@ const int THREADS_PER_BLOCK = 256;
const int THREADS_X = 32;
const int THREADS_Y = THREADS_PER_BLOCK / THREADS_X;
const int REPEAT = 32;
+const long NNZ_PER_BLOCK_MAX = 1024;
/* sign MACRO */
#ifndef clamp
#define clamp(a, low, high) max(min((a), (high)), (low))
#endif
-template<typename Ty>
+#ifndef ATOMIC_REAL_MINMAX
+#define ATOMIC_REAL_MINMAX(func) \
+ __device__ void atomic_##func(double *address, double val) { \
+ unsigned long long int* address_as_ull = (unsigned long long int*)address; \
+ unsigned long long int old = *address_as_ull; \
+ unsigned long long int assumed; \
+ do { \
+ assumed = old; \
+ old = atomicCAS(address_as_ull, assumed, \
+ __double_as_longlong(func(val, __longlong_as_double(assumed)))); \
+ } while (assumed != old); \
+ } \
+ __device__ void atomic_##func(float *address, float val) { \
+ int* address_as_int = (int*)address; \
+ int old = *address_as_int; \
+ int assumed; \
+ do { \
+ assumed = old; \
+ old = atomicCAS(address_as_int, assumed, \
+ __float_as_int(func(val, __int_as_float(assumed)))); \
+ } while (assumed != old); \
+ } \
+
+ATOMIC_REAL_MINMAX(max)
+ATOMIC_REAL_MINMAX(min)
+#endif
+
+template<typename Ty, bool train>
__global__ static
void updateOutput(
Ty *output,
+ Ty *normalizedValues,
const Ty *values,
const long *cumSumSizes,
const long *keys,
@@ -27,11 +56,12 @@ void updateOutput(
const Ty *bias,
const long weightStride,
const long keysOffset,
- int maxNormalize)
+ const int maxNormalize,
+ const int nnzPerBlock)
{
/*******************************************************
- * Adopted from the following file in arrayfire
- * https://github.com/arrayfire/arrayfire/blob/v3.4.1/src/backend/opencl/kernel/csrmv.cl
+ * Adapted from the following file in arrayfire
+ * https://github.com/arrayfire/arrayfire/blob/v3.4.1/src/backend/opencl/kernel/csrmm.cl
*
*******************************************************
* Original copyright notice can be seen below:
@@ -44,16 +74,17 @@ void updateOutput(
* http://arrayfire.com/licenses/BSD-3-Clause
********************************************************/
- const long goff = blockIdx.x * blockDim.x;
- const long gidx = goff + threadIdx.x;
- const long tid = threadIdx.x;
+ const long tidx = threadIdx.x;
+ const long tidy = threadIdx.y;
+ const long tid = tidy * blockDim.x + tidx;
+ const long gidx = blockIdx.x * blockDim.x + tidx;
+
Ty *nWeight = weight;
// Offset the number of elements specified by maxNormalize
weight += gidx + maxNormalize;
output += gidx;
-
bool within_N = (gidx < outDim);
__shared__ Ty s_values[THREADS_PER_BLOCK];
@@ -63,23 +94,36 @@ void updateOutput(
// if (rowId >= batchSize) return;
// Load the nonzero column offsets for current row
- const long batchStart = rowId == 0 ? 0 : cumSumSizes[rowId - 1];
- const long batchEnd = cumSumSizes[rowId];
+ const long batchStart = (rowId == 0 ? 0 : cumSumSizes[rowId - 1]) + blockIdx.z * nnzPerBlock;
+ const long batchEnd = min(batchStart + nnzPerBlock, cumSumSizes[rowId]);
+ const long batchStride = blockDim.x * blockDim.y;
- Ty outval = 0;
+ Ty outVal = 0;
// Since the number of nonzero elements might be greater than local memory available,
// Load only part of the row into local memory, perform partial dot, repeat until done.
- for (long id = batchStart; id < batchEnd; id += blockDim.x) {
+ for (long id = batchStart; id < batchEnd; id += batchStride) {
// Load the current chunk of the row into local memory
- long lim = min(batchEnd - id, (long)blockDim.x);
+ long lim = min(batchEnd - id, (long)batchStride);
- // Subtract 1 from keys[id + tid] to convert base 1 to base 0
long key = tid < lim ? keys[id + tid] + keysOffset : -1;
Ty val = tid < lim ? values[id + tid] : 0;
-
- if (maxNormalize && tid < lim) {
- Ty *nWeightCurr = nWeight + key * weightStride;
- val = clamp(val * nWeightCurr[1], -1.0, 1.0) + nWeightCurr[3];
+ long nWeightOffset = key * weightStride;
+
+ if (tid < lim && maxNormalize) {
+ Ty *nWeightCurr = nWeight + nWeightOffset;
+ if (train) {
+ Ty absVal = fabs(val);
+ Ty maxVal = nWeight[key * weightStride + 0];
+ if (absVal > maxVal) {
+ // Updating maxVal and invMaxVal. Go hogwild!
+ atomic_max(nWeightCurr + 0, absVal);
+ atomic_min(nWeightCurr + 1, 1.0/absVal);
+ }
+ val = val * nWeightCurr[1] + nWeightCurr[3];
+ normalizedValues[id + tid] = val;
+ } else {
+ val = clamp(val * nWeightCurr[1], -1.0, 1.0) + nWeightCurr[3];
+ }
}
s_keys[tid] = key;
@@ -87,99 +131,8 @@ void updateOutput(
__syncthreads();
// Perform a single "dot" operation for each thread
- for (long idy = 0; within_N && idy < lim; idy++) {
- outval += s_values[idy] * weight[weightStride * s_keys[idy]];
- }
- __syncthreads();
- }
-
- if (within_N) {
- output[rowId * outDim] = outval + bias[gidx];
- }
-}
-
-// This kernel is launched with [M x 1] blocks of size [X x Y].
-// Each block writes X entries to the output for the given batchId.
-template<typename Ty>
-__global__ static
-void updateOutputTrain(
- Ty *output,
- Ty *normalizedValues,
- const Ty *values,
- const long *cumSumSizes,
- const long *keys,
- const long batchSize,
- const long outDim,
- Ty *weight,
- const Ty *bias,
- const long weightStride,
- const long keysOffset,
- int maxNormalize,
- long batchId)
-{
- const long tidx = threadIdx.x;
- const long tidy = threadIdx.y;
- const long tid = tidy * blockDim.x + tidx;
- const long gidx = blockIdx.x * blockDim.x + tidx;
-
- const long batchStart = batchId == 0 ? 0 : cumSumSizes[batchId - 1];
- const long batchEnd = cumSumSizes[batchId];
- const long batchLimit = batchEnd - batchStart;
-
- // A dot operation is performed by a single block.
- // Calculate the number of iterations required to load all elements from current batch.
- const long iters = divup(batchLimit, blockDim.x * blockDim.y);
-
- Ty *nWeight = weight;
- // Offset to the current output id
- weight += maxNormalize + gidx;
- output += batchId * outDim + gidx;
-
- // Offset to the current batch
- keys += batchStart;
- values += batchStart;
- normalizedValues += batchStart;
-
- __shared__ Ty s_values[THREADS_PER_BLOCK];
- __shared__ long s_keys[THREADS_PER_BLOCK];
-
- Ty outVal = 0;
- // Not bailing early because we need __syncthreads later
- for (long n = 0; n < iters; n++) {
- long off = n * blockDim.y * blockDim.x;
- long lim = min((long)blockDim.y * blockDim.x, batchLimit - off);
-
-
- // Each block uses all of its threads to load data.
- // This ensures coalesced reads from global memory.
- // Each variable in shared memory is then used by all the threads in a warp.
- if (tid < lim) {
- Ty val = values[off + tid];
- long key = keys[off + tid] + keysOffset;
- long nWeightOffset = key * weightStride;
-
- Ty absVal = fabs(val);
- Ty maxVal = nWeight[key * weightStride + 0];
- if (absVal > maxVal) {
- // Updating maxVal and invMaxVal
- nWeight[nWeightOffset + 0] = absVal;
- nWeight[nWeightOffset + 1] = 1.0/absVal;
- maxVal = absVal;
- }
-
- // TODO: implement a smarter update scale following the CPU implementation.
- nWeight[nWeightOffset + 2] = 1.0;
- s_values[tid] = val / maxVal + nWeight[nWeightOffset + 3];
- s_keys[tid] = key;
- normalizedValues[off + tid] = s_values[tid];
- }
- __syncthreads();
-
- if (gidx < outDim) {
- // Performing the partial dot operation for each thread
- for (long id = tidy; id < lim; id += blockDim.y) {
- outVal += s_values[id] * weight[weightStride * s_keys[id]];
- }
+ for (long idy = tidy; within_N && idy < lim; idy += blockDim.y) {
+ outVal += s_values[idy] * weight[weightStride * s_keys[idy]];
}
__syncthreads();
}
@@ -192,9 +145,13 @@ void updateOutputTrain(
if (tidy < y) s_values[tid] = s_values[tid] + s_values[tid + y * blockDim.x];
}
- // Writing the final value from the first lane into the output.
- if (gidx < outDim && tidy == 0) {
- *output = s_values[tid] + bias[gidx];
+ if (within_N && tidy == 0) {
+ Ty val = s_values[tid] + (blockIdx.z == 0 ? bias[gidx] : 0);
+ if (gridDim.z == 1) {
+ output[rowId * outDim] = val;
+ } else {
+ atomicAdd(output + rowId * outDim, val);
+ }
}
}
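
For reference, the ATOMIC_REAL_MINMAX macro added above generates the standard compare-and-swap loop for atomic floating-point max/min, since CUDA provides no native atomicMax/atomicMin for float or double. A sketch of the same pattern, written out for the float/max case only (hypothetical function name):

// Sketch: atomic max on a float via atomicCAS, the pattern the
// ATOMIC_REAL_MINMAX(max) expansion above produces.
__device__ void atomic_max_sketch(float *address, float val) {
  int *address_as_int = (int *)address;
  int old = *address_as_int;
  int assumed;
  do {
    assumed = old;
    // Reinterpret the bits, take the max against the proposed value, and try
    // to swap it in; retry if another thread updated the location first.
    old = atomicCAS(address_as_int, assumed,
                    __float_as_int(fmaxf(val, __int_as_float(assumed))));
  } while (assumed != old);
}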
diff --git a/lib/THCUNN/generic/IndexLinear.cu b/lib/THCUNN/generic/IndexLinear.cu
index e1f8642..ae96148 100644
--- a/lib/THCUNN/generic/IndexLinear.cu
+++ b/lib/THCUNN/generic/IndexLinear.cu
@@ -47,6 +47,7 @@ void THNN_(IndexLinear_updateOutput)(
long weightStride = weight->stride[0];
int maxNormalize = wDim - outDim;
long keysSize = keys->size[0];
+ long nnzPerRow = divup(keysSize, batchSize);
THCTensor_(resize2d)(state, output, batchSize, outDim);
long *keysData = THCudaLongTensor_data (state, keys);
@@ -57,28 +58,29 @@ void THNN_(IndexLinear_updateOutput)(
real *outData = THCTensor_(data) (state, output);
cudaStream_t stream = THCState_getCurrentStream(state);
+ dim3 threads(THREADS_X, THREADS_Y);
+ int blocks_x = divup(outDim, threads.x);
+ int blocks_y = batchSize;
+ int nnzPerBlock = ((outDim == 1 || batchSize == 1) ? THREADS_X : NNZ_PER_BLOCK_MAX);
+ int blocks_z = divup(nnzPerRow, nnzPerBlock);
+
+ dim3 blocks(blocks_x, blocks_y, blocks_z);
+
+ if (blocks_z > 1) {
+ THCudaCheck(cudaMemsetAsync(outData, 0, outDim * batchSize * sizeof(real), stream));
+ }
+ real *normalizedValuesData = NULL;
if (maxNormalize && train) {
THCTensor_(resize1d)(state, normalizedValues, keysSize);
- real *normalizedValuesData = THCTensor_(data)(state, normalizedValues);
- dim3 threads(THREADS_X, THREADS_Y);
- int blocks_x = divup(outDim, threads.x);
- int blocks_y = 1;
- dim3 blocks(blocks_x, blocks_y);
- for (long batchId = 0; batchId < batchSize; batchId++) {
- updateOutputTrain<real><<<blocks, threads, 0, stream>>>
- (outData, normalizedValuesData, valuesData, cumSumSizesData, keysData,
- batchSize, outDim, weightData, biasData, weightStride, keysOffset,
- maxNormalize, batchId);
- }
+ normalizedValuesData = THCTensor_(data)(state, normalizedValues);
+ updateOutput<real, true><<<blocks, threads, 0, stream>>>
+ (outData, normalizedValuesData, valuesData, cumSumSizesData, keysData,
+ batchSize, outDim, weightData, biasData, weightStride, keysOffset, maxNormalize, nnzPerBlock);
} else {
- int threads = THREADS_PER_BLOCK;
- int blocks_x = divup(outDim, threads);
- int blocks_y = batchSize;
- dim3 blocks(blocks_x, blocks_y);
- updateOutput<real><<<blocks, threads, 0, stream>>>
- (outData, valuesData, cumSumSizesData, keysData, batchSize, outDim,
- weightData, biasData, weightStride, keysOffset, maxNormalize);
+ updateOutput<real, false><<<blocks, threads, 0, stream>>>
+ (outData, normalizedValuesData, valuesData, cumSumSizesData, keysData,
+ batchSize, outDim, weightData, biasData, weightStride, keysOffset, maxNormalize, nnzPerBlock);
}
}
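
The host-side changes above replace the per-batch kernel loop with a single 3-D launch: x covers output columns, y covers batch rows, and z splits each row's nonzeros into chunks of at most NNZ_PER_BLOCK_MAX. A rough standalone sketch of that grid computation and the zero-fill it implies (hypothetical wrapper name; constants and formulas taken from the patch):

#include <cuda_runtime.h>

// Rough sketch of the launch geometry set up in IndexLinear_updateOutput.
static long divUp(long a, long b) { return (a + b - 1) / b; }

void configureLaunch(long outDim, long batchSize, long keysSize,
                     float *outData, cudaStream_t stream) {
  const int THREADS_X = 32;
  const int THREADS_Y = 256 / THREADS_X;
  const long NNZ_PER_BLOCK_MAX = 1024;

  long nnzPerRow = divUp(keysSize, batchSize);
  long nnzPerBlock =
      (outDim == 1 || batchSize == 1) ? THREADS_X : NNZ_PER_BLOCK_MAX;

  dim3 threads(THREADS_X, THREADS_Y);
  dim3 blocks((unsigned)divUp(outDim, THREADS_X),       // x: output columns
              (unsigned)batchSize,                      // y: one row per batch item
              (unsigned)divUp(nnzPerRow, nnzPerBlock)); // z: chunks of nonzeros

  // When a row is split across several z-blocks, each block atomically adds
  // its partial dot product into the output, so the buffer must start at zero.
  if (blocks.z > 1) {
    cudaMemsetAsync(outData, 0, outDim * batchSize * sizeof(float), stream);
  }

  // ... then launch updateOutput<real, train><<<blocks, threads, 0, stream>>>(...)
}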