github.com/torch/cunn.git
author     Soumith Chintala <soumith@gmail.com>  2017-03-15 21:35:59 +0300
committer  GitHub <noreply@github.com>  2017-03-15 21:35:59 +0300
commit     7293be03f00e902179fcb6518b7c43bca2f5790c (patch)
tree       827b404a4cec6b278f8f3c8b1995de566920ba0f
parent     12d9add9301ced59ae9c910f8b511d0fd7f285a1 (diff)
parent     8d74ea923b8ea17ac17bb0853809c0456107f204 (diff)
Merge pull request #451 from wickedfoo/faster-lookup-table
Improve cunn LookupTable performance for large batch sizes
-rw-r--r--   lib/THCUNN/LookupTable.cu           |  5
-rw-r--r--   lib/THCUNN/generic/LookupTable.cu   | 92
2 files changed, 63 insertions(+), 34 deletions(-)
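For context: rather than routing through the general THCIndexTensor_(sort), the patch copies the lookup indices and sorts them with thrust::sort_by_key, carrying each index's original position along as the value. A minimal standalone sketch of that pattern (the index values and variable names below are illustrative only, not taken from the patch):

#include <thrust/device_vector.h>
#include <thrust/sequence.h>
#include <thrust/sort.h>
#include <cstdio>

int main() {
  // Lookup indices as they arrive from the embedding layer (duplicates allowed).
  long h_idx[] = {8, 5, 9, 5, 2, 7, 9, 5, 7};
  thrust::device_vector<long> sortedIndices(h_idx, h_idx + 9);

  // origIndices[i] starts out as i, so after the sort it records where each
  // index originally lived -- the role origIndices plays in the patch.
  thrust::device_vector<long> origIndices(sortedIndices.size());
  thrust::sequence(origIndices.begin(), origIndices.end());

  // An unstable sort is enough here: equal keys only need to become adjacent.
  thrust::sort_by_key(sortedIndices.begin(), sortedIndices.end(),
                      origIndices.begin());

  for (size_t i = 0; i < sortedIndices.size(); ++i)
    std::printf("index %ld came from position %ld\n",
                (long)sortedIndices[i], (long)origIndices[i]);
  return 0;
}
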
diff --git a/lib/THCUNN/LookupTable.cu b/lib/THCUNN/LookupTable.cu
index 29cb3a1..e626632 100644
--- a/lib/THCUNN/LookupTable.cu
+++ b/lib/THCUNN/LookupTable.cu
@@ -12,10 +12,7 @@
#include <thrust/unique.h>
#include "THCHalf.h"
#include "THCHalfAutoNumerics.cuh"
-
-#ifndef DIVUP
-#define DIVUP(x, y) (((x) + (y) - 1) / (y))
-#endif
+#include "THCTensorSort.cuh"
const int WARP_SIZE = 32;
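The file-local DIVUP macro goes away in favour of the THCCeilDiv helper pulled in via the THCTensorSort.cuh include. For reference, an equivalent ceiling-division helper looks like the sketch below (illustrative only; the real THCCeilDiv is provided by the cutorch headers):

// Rounds the quotient up, so a partially filled block still gets launched.
template <typename T>
__host__ __device__ inline T ceil_div(T a, T b) {
  return (a + b - 1) / b;
}

// e.g. ceil_div(stride, (long) 4) blocks of 128 threads cover every feature,
// matching the grid/block shape used later in the patch.
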
diff --git a/lib/THCUNN/generic/LookupTable.cu b/lib/THCUNN/generic/LookupTable.cu
index fa7c5ac..955b29b 100644
--- a/lib/THCUNN/generic/LookupTable.cu
+++ b/lib/THCUNN/generic/LookupTable.cu
@@ -8,18 +8,17 @@ void THNN_(LookupTable_accGradParameters)(
THCTensor *gradOutput,
THCTensor *gradWeight,
THCIndexTensor *count,
- THCIndexTensor *sorted,
- THCIndexTensor *indices,
+ THCIndexTensor *sortedIndices,
+ THCIndexTensor *origIndices,
bool scaleGradByFreq,
int paddingValue,
accreal scale_)
{
real scale = ScalarConvert<accreal, real>::to(scale_);
- THCUNN_assertSameGPU(state, 5, input, gradOutput, gradWeight, sorted, indices);
+ THCUNN_assertSameGPU(state, 5, input, gradOutput, gradWeight, sortedIndices, origIndices);
gradOutput = THCTensor_(newContiguous)(state, gradOutput);
if (!(THCIndexTensor_(isContiguous)(state, input) &&
- THCTensor_(isContiguous)(state, gradWeight)))
- {
+ THCTensor_(isContiguous)(state, gradWeight))) {
THError("Tensors must be contiguous");
}
@@ -30,12 +29,15 @@ void THNN_(LookupTable_accGradParameters)(
}
ptrdiff_t numel = THCIndexTensor_(nElement)(state, input);
- long stride = gradWeight->stride[0];
+ long stride = THCTensor_(stride)(state, gradWeight, 0);
cudaStream_t stream = THCState_getCurrentStream(state);
if (numel <= 768 && !scaleGradByFreq) {
- cunn_LookupTable_accGradParametersKernelByFeature<<<DIVUP(stride,4), 128, 0, stream>>>(
+ dim3 grid(THCCeilDiv(stride, (long) 4));
+ dim3 block(128);
+
+ cunn_LookupTable_accGradParametersKernelByFeature<<<grid, block, 0, stream>>>(
THCIndexTensor_(data)(state, input),
THCTensor_(data)(state, gradOutput),
THCTensor_(data)(state, gradWeight),
@@ -49,35 +51,62 @@ void THNN_(LookupTable_accGradParameters)(
}
THLongStorage *inputSize = THCIndexTensor_(newSizeOf)(state, input);
- THCIndexTensor_(resize)(state, sorted, inputSize, NULL);
- THCIndexTensor_(resize)(state, indices, inputSize, NULL);
+ THCIndexTensor_(resize)(state, sortedIndices, inputSize, NULL);
+ THCIndexTensor_(resize)(state, origIndices, inputSize, NULL);
THLongStorage_free(inputSize);
- // Sort the inputs into sorted with the corresponding indices
- THCIndexTensor_(sort)(state, sorted, indices, input, 0, 0);
+ // Sort the inputs into sorted with the corresponding indices; we
+ // don't need a stable or multidimensional sort, so just use Thrust
+ // directly
+ {
+ THCIndexTensor_(copy)(state, sortedIndices, input);
+
+ THCThrustAllocator thrustAlloc(state);
+
+ thrust::device_ptr<THCIndex_t>
+ sortedIndicesIter(THCIndexTensor_(data)(state, sortedIndices));
+ thrust::device_ptr<THCIndex_t>
+ origIndicesIter(THCIndexTensor_(data)(state, origIndices));
+
+ // Fill sortedOrigIndices with sequential indices
+ thrust::counting_iterator<THCIndex_t> countIter(TH_INDEX_BASE);
+
+ thrust::copy(
+#if CUDA_VERSION >= 7000
+ thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)),
+#endif
+ countIter, countIter + numel, origIndicesIter);
+
+ // Sort; a stable sort is not required
+ thrust::sort_by_key(
+#if CUDA_VERSION >= 7000
+ thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)),
+#endif
+ sortedIndicesIter, sortedIndicesIter + numel,
+ origIndicesIter, ThrustLTOp<long>());
+ }
- THCIndex_t *sorted_data = THCIndexTensor_(data)(state, sorted);
- THCIndex_t *indices_data = THCIndexTensor_(data)(state, indices);
+ THCIndex_t *sortedIndices_data = THCIndexTensor_(data)(state, sortedIndices);
+ THCIndex_t *origIndices_data = THCIndexTensor_(data)(state, origIndices);
THCIndex_t *count_data = NULL;
- if (scaleGradByFreq)
- {
+ if (scaleGradByFreq) {
THCIndexTensor_(resizeAs)(state, count, input);
count_data = THCIndexTensor_(data)(state, count);
THCThrustAllocator thrustAlloc(state);
- thrust::device_ptr<THCIndex_t> sorted_ptr(sorted_data);
+ thrust::device_ptr<THCIndex_t> sortedIndices_ptr(sortedIndices_data);
thrust::device_ptr<THCIndex_t> count_ptr(count_data);
- // Compute an increasing sequence per unique item in sorted:
+ // Compute an increasing sequence per unique item in sortedIndices:
// sorted: 2 5 5 5 7 7 8 9 9
// count: 1 1 2 3 1 2 1 1 2
thrust::inclusive_scan_by_key(
#if CUDA_VERSION >= 7000
thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)),
#endif
- sorted_ptr,
- sorted_ptr + numel,
+ sortedIndices_ptr,
+ sortedIndices_ptr + numel,
thrust::make_constant_iterator(1),
count_ptr
);
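The comment in this hunk already gives the intended result of the forward scan; here is a self-contained Thrust sketch that reproduces those numbers (the setup is illustrative, the scan call mirrors the one above):

#include <thrust/device_vector.h>
#include <thrust/scan.h>
#include <thrust/iterator/constant_iterator.h>
#include <cstdio>

int main() {
  // Sorted indices from the comment above: 2 5 5 5 7 7 8 9 9
  long h_sorted[] = {2, 5, 5, 5, 7, 7, 8, 9, 9};
  thrust::device_vector<long> sorted(h_sorted, h_sorted + 9);
  thrust::device_vector<long> count(sorted.size());

  // Scanning a stream of 1s keyed by the sorted indices numbers each
  // occurrence of an index within its group of equal keys.
  thrust::inclusive_scan_by_key(
      sorted.begin(), sorted.end(),
      thrust::make_constant_iterator(1L),
      count.begin());

  // Expected output: 1 1 2 3 1 2 1 1 2
  for (size_t i = 0; i < count.size(); ++i)
    std::printf("%ld ", (long)count[i]);
  std::printf("\n");
  return 0;
}
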
@@ -89,8 +118,8 @@ void THNN_(LookupTable_accGradParameters)(
#if CUDA_VERSION >= 7000
thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)),
#endif
- thrust::make_reverse_iterator(sorted_ptr + numel),
- thrust::make_reverse_iterator(sorted_ptr),
+ thrust::make_reverse_iterator(sortedIndices_ptr + numel),
+ thrust::make_reverse_iterator(sortedIndices_ptr),
thrust::make_reverse_iterator(count_ptr + numel),
thrust::make_reverse_iterator(count_ptr + numel),
thrust::equal_to<long>(),
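Renames aside, the point of this reversed scan is to turn the per-occurrence running counts into the total count for each index, which the kernel then uses when scaling gradients by frequency: scanning the reversed keys with a maximum reduction carries the largest running count, i.e. the total, back across each group. A hedged standalone illustration of that pattern (the binary operator lies outside this hunk's context, so the maximum below is inferred from the effect described):

#include <thrust/device_vector.h>
#include <thrust/scan.h>
#include <thrust/functional.h>
#include <thrust/iterator/reverse_iterator.h>
#include <cstdio>

int main() {
  // Sorted keys and their per-occurrence running counts, as produced by the
  // forward inclusive_scan_by_key in the previous hunk.
  long h_keys[]   = {2, 5, 5, 5, 7, 7, 8, 9, 9};
  long h_counts[] = {1, 1, 2, 3, 1, 2, 1, 1, 2};
  thrust::device_vector<long> keys(h_keys, h_keys + 9);
  thrust::device_vector<long> counts(h_counts, h_counts + 9);

  // Scan the reversed sequence in place with a maximum reduction: within each
  // run of equal keys the largest running count (the total) is propagated to
  // every position, just as the patch writes back into count.
  thrust::inclusive_scan_by_key(
      thrust::make_reverse_iterator(keys.end()),
      thrust::make_reverse_iterator(keys.begin()),
      thrust::make_reverse_iterator(counts.end()),
      thrust::make_reverse_iterator(counts.end()),
      thrust::equal_to<long>(),
      thrust::maximum<long>());

  // counts is now: 1 3 3 3 2 2 1 2 2  (total occurrences of each key)
  for (size_t i = 0; i < counts.size(); ++i)
    std::printf("%ld ", (long)counts[i]);
  std::printf("\n");
  return 0;
}
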
@@ -98,11 +127,11 @@ void THNN_(LookupTable_accGradParameters)(
);
}
- dim3 grid(DIVUP(numel,4), DIVUP(stride,128));
+ dim3 grid(THCCeilDiv(numel, (ptrdiff_t) 4), THCCeilDiv(stride, (long) 128));
dim3 block(32, 4);
cunn_LookupTable_accGradParametersKernel<real, accreal><<<grid, block, 0, stream>>>(
- sorted_data,
- indices_data,
+ sortedIndices_data,
+ origIndices_data,
THCTensor_(data)(state, gradOutput),
THCTensor_(data)(state, gradWeight),
count_data,
@@ -127,22 +156,25 @@ void THNN_(LookupTable_renorm)(
real normType = ScalarConvert<accreal, real>::to(normType_);
THCUNN_assertSameGPU(state, 2, idx, weight);
if (!(THCIndexTensor_(isContiguous)(state, idx) &&
- THCTensor_(isContiguous)(state, weight)))
- {
+ THCTensor_(isContiguous)(state, weight))) {
THError("Tensors must be contiguous");
}
- if (THCIndexTensor_(nDimension)(state, idx) != 1)
+
+ if (THCIndexTensor_(nDimension)(state, idx) != 1) {
THError("idx must be a vector");
- if (normType <= 0)
+ }
+
+ if (normType <= 0) {
THError("non-positive-norm not supported");
+ }
THCIndex_t numel = THCIndexTensor_(nElement)(state, idx);
- long stride = weight->stride[0];
+ long stride = THCTensor_(stride)(state, weight, 0);
// get the unique indices
thrust::device_ptr<real> weight_ptr(THCTensor_(data)(state, weight));
thrust::device_ptr<THCIndex_t> idx_ptr(THCIndexTensor_(data)(state, idx));
- thrust::device_ptr<THCIndex_t> end_ptr = thrust::unique(idx_ptr, idx_ptr+numel);
+ thrust::device_ptr<THCIndex_t> end_ptr(thrust::unique(idx_ptr, idx_ptr+numel));
numel = end_ptr - idx_ptr;
pow_v<real, accreal> unary_pow(normType);
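In the renorm path, the surrounding code compacts idx in place with thrust::unique and recomputes numel from the returned end; the hunk only changes how end_ptr is constructed from that return value. A small sketch of what the compaction yields (input values are illustrative):

#include <thrust/device_vector.h>
#include <thrust/unique.h>
#include <cstdio>

int main() {
  // thrust::unique removes consecutive duplicates in place and returns the
  // new logical end; subtracting iterators yields the compacted length,
  // just as `numel = end_ptr - idx_ptr` does above.
  long h_idx[] = {2, 5, 5, 5, 7, 7, 8, 9, 9};
  thrust::device_vector<long> idx(h_idx, h_idx + 9);

  thrust::device_vector<long>::iterator end =
      thrust::unique(idx.begin(), idx.end());
  long numel = end - idx.begin();

  std::printf("unique count: %ld\n", numel);  // 5  (2 5 7 8 9)
  return 0;
}
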