diff options
author | Soumith Chintala <soumith@gmail.com> | 2017-06-12 22:42:37 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2017-06-15 19:32:28 +0300 |
commit | 9cffa0ef9c7775093896606e7f86c206e8099ce8 (patch) | |
tree | b3cf2f1552421bee3adffd5fbb37eeb17499ed88 | |
parent | 42c92bfe456981de02393df7836daeb23998f497 (diff) |
nn.EmbeddingBag to compute a bag of word embeddings (Embedding + Sum/Mean)
-rw-r--r-- | lib/THCUNN/LookupTableBag.cu | 141 | ||||
-rw-r--r-- | lib/THCUNN/generic/LookupTableBag.cu | 200 | ||||
-rw-r--r-- | lib/THCUNN/generic/THCUNN.h | 24 |
3 files changed, 365 insertions, 0 deletions
// ============================================================================
// File: lib/THCUNN/LookupTableBag.cu  (new file in this commit)
// ============================================================================
#include "THCUNN.h"
#include "common.h"

#include "THCThrustAllocator.cuh"
#include <thrust/device_ptr.h>
#include <thrust/execution_policy.h>
#include <thrust/iterator/constant_iterator.h>
#include <thrust/transform_reduce.h>
#if CUDA_VERSION >= 7000
#include <thrust/system/cuda/execution_policy.h>
#endif
#include <thrust/unique.h>
#include "THCHalf.h"
#include "THCHalfAutoNumerics.cuh"
#include "THCTensorSort.cuh"

const int WARP_SIZE = 32;
const int MODE_SUM = 0;
const int MODE_MEAN = 1;

// Forward kernel: reduces, per bag, the rows of `weight` selected by the
// bag's indices (sum, or mean when mode == MODE_MEAN).
//
// Work decomposition: a "chunk" is blockDim.x consecutive features of one
// bag. threadIdx.x picks the feature inside the chunk; (blockIdx.x,
// threadIdx.y) pick the first chunk, and each thread then advances by
// gridDim.x * blockDim.y chunks until all numBags * chunksPerBag chunks are
// covered, so any 1-D grid size is valid.
//
// Parameters:
//   input       numIndices row indices into weight (TH_INDEX_BASE-based)
//   offsets     numBags start positions into input (TH_INDEX_BASE-based);
//               bag b covers [offsets[b], offsets[b + 1]) and the last bag
//               extends to numIndices
//   weight      embedding table, addressed as weight[row * stride + feature]
//   output      written as output[bag * stride + feature]
//   offset2bag  written: for every position of input, the bag that owns it
//               (TH_INDEX_BASE-based)
//   mode        MODE_SUM or MODE_MEAN
//   bag_size    written only when mode == MODE_MEAN: number of indices in
//               each bag; must be non-NULL in that mode
template <typename Dtype, typename Acctype>
__global__ void cunn_LookupTableBag_updateOutputKernel(
  long *input, long *offsets, Dtype *weight, Dtype *output,
  long *offset2bag, long numIndices, long numBags, long stride, int mode,
  long *bag_size) {

  // the strategy here is that each bag x feature is handled by a single thread

  long chunksPerBag = THCCeilDiv(stride, (long) blockDim.x);
  long numChunks = numBags * chunksPerBag;
  long chunkOffset = blockIdx.x * blockDim.y + threadIdx.y;
  long chunkStride = gridDim.x * blockDim.y;

  for (long chunk = chunkOffset; chunk < numChunks; chunk += chunkStride) {
    long featureDim = (chunk % chunksPerBag) * blockDim.x + threadIdx.x;
    if (featureDim < stride) {
      long bag = chunk / chunksPerBag;
      Dtype* weightFeat = weight + featureDim;
      long begin = offsets[bag] - TH_INDEX_BASE;
      long end = (bag < numBags - 1) ? (offsets[bag + 1] - TH_INDEX_BASE) : numIndices;
      assert(end >= begin);
      Acctype weightFeatSum = ScalarConvert<float, Acctype>::to(0);
      long bag_size_ = 0;
      for (long emb = begin; emb < end; emb++) {
        // Keep the row offset in 64 bits: row * stride can exceed INT_MAX
        // for large embedding tables.
        const long weightRow = (input[emb] - TH_INDEX_BASE) * stride;
        weightFeatSum += ScalarConvert<Dtype, Acctype>::to(weightFeat[weightRow]);
        bag_size_ ++;
        // Only one thread per bag needs to record the position -> bag map.
        if (featureDim == 0) {
          offset2bag[emb] = bag + TH_INDEX_BASE;
        }
      }
      if (mode == MODE_MEAN) {
        // An empty bag keeps its zero sum instead of producing 0 / 0 = NaN.
        if (bag_size_ != 0) {
          weightFeatSum = weightFeatSum / ScalarConvert<long, Acctype>::to(bag_size_);
        }
        bag_size[bag] = bag_size_;
      }
      output[bag * stride + featureDim] = ScalarConvert<Acctype, Dtype>::to(weightFeatSum);
    }
  }
}

// FIXME: removed the accGradParametersKernelByFeature case present in
// LookupTable. That kernel is faster at small sizes (<768 indices), which
// does not need LookupTableBag (LookupTable + Sum works fine), but would
// still be nice to not be slow in that case.

// Backward kernel: accumulates scaled gradOutput rows into gradWeight.
//
// Expected launch (see LookupTableBag_accGradParameters): block = (32, 4),
// grid = (ceil(numel / 4), ceil(stride / 128)). threadIdx.y selects one of the
// 4 sorted positions handled by a block; threadIdx.x and blockIdx.y select
// the SZ = 4 features (WARP_SIZE apart) handled by each thread.
//
// Parameters:
//   input       embedding indices sorted ascending (TH_INDEX_BASE-based)
//   indices     for each sorted position, its original position in the
//               unsorted input (TH_INDEX_BASE-based)
//   gradOutput  addressed as gradOutput[bag * stride + feature]
//   gradWeight  addressed as gradWeight[row * stride + feature]; updated in
//               place
//   offset2bag  original position -> owning bag (TH_INDEX_BASE-based)
//   count       optional (may be NULL); when given, defaultScale is divided
//               by count[idx]
//   bag_size    read only when mode == MODE_MEAN: the gradient is divided by
//               the size of the owning bag
template <typename Dtype, typename Acctype>
__global__ void cunn_LookupTableBag_accGradParametersKernel(
  long *input, long *indices, Dtype *gradOutput, Dtype *gradWeight, long *offset2bag,
  long *count, Dtype defaultScale, ptrdiff_t numel, long stride,
  int mode, long *bag_size) {

  // 64-bit position so the comparisons against numel cannot wrap.
  ptrdiff_t idx = (ptrdiff_t) blockIdx.x * 4 + threadIdx.y;

  // Each warp is responsible for an input into the LookupTable.
  // If the preceding input has the same value as this input, then the warp
  // exits immediately. The warp also processes subsequent inputs with the
  // same value.
  //
  // Input Warp
  // 1     <warp 1>
  // 1     <warp 1> (<warp 2> exits without doing any work)
  // 5     <warp 3>
  // 8     <warp 4>

  // Number of values processed by each thread (grain size)
  const int SZ = 4;

  if (idx < numel
      && (idx == 0 || input[idx] != input[idx - 1])) {
    // Does not depend on idx, so compute it once outside the loop.
    const int startFeature = threadIdx.x + blockIdx.y * blockDim.x * SZ;

    do {
      // Row offsets are kept in 64 bits: row * stride can exceed INT_MAX.
      const long weightRow = (input[idx] - TH_INDEX_BASE) * stride;

      // Note: only this line changes from LookupTable_accgradParametersKernel
      const long origRow = indices[idx] - TH_INDEX_BASE;
      const long seq_number = offset2bag[origRow] - TH_INDEX_BASE;
      const long gradOutputRow = seq_number * stride;

      const Acctype scale = count ? ScalarConvert<Dtype, Acctype>::to(defaultScale) / count[idx] : ScalarConvert<Dtype, Acctype>::to(defaultScale);

      Acctype gradient[SZ];
      Acctype weight[SZ];

      #pragma unroll
      for (int ii = 0; ii < SZ; ii++)
      {
        int featureDim = startFeature + ii * WARP_SIZE;
        if (featureDim < stride)
        {
          gradient[ii] = ScalarConvert<Dtype, Acctype>::to(gradOutput[gradOutputRow + featureDim]);
          if (mode == MODE_MEAN) {
            gradient[ii] /= bag_size[seq_number];
          }
          weight[ii] = ScalarConvert<Dtype, Acctype>::to(gradWeight[weightRow + featureDim]);
        }
      }

      // Slots with featureDim >= stride were not loaded above; their results
      // are never written back by the guarded store below.
      #pragma unroll
      for (int ii = 0; ii < SZ; ii++)
      {
        weight[ii] += gradient[ii] * scale;
      }

      #pragma unroll
      for (int ii = 0; ii < SZ; ii++)
      {
        int featureDim = startFeature + ii * WARP_SIZE;
        if (featureDim < stride)
        {
          gradWeight[weightRow + featureDim] = ScalarConvert<Acctype, Dtype>::to(weight[ii]);
        }
      }

      idx++;
    } while (idx < numel && input[idx] == input[idx - 1]);
  }
}


#include "generic/LookupTableBag.cu"
#include "THCGenerateFloatTypes.h"

// ============================================================================
// File: lib/THCUNN/generic/LookupTableBag.cu  (new file in this commit)
// ============================================================================
#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "generic/LookupTableBag.cu"
#else


// Forward pass of LookupTableBag.
//
// Resizes `output` to (numBags, weight.size(1)) and `offset2bag` to the shape
// of `input`, then launches cunn_LookupTableBag_updateOutputKernel on the
// current stream.
//
// input, offsets and weight must be contiguous. bag_size is optional in
// general but required when mode == MODE_MEAN, because the kernel writes the
// per-bag sizes into it.
// NOTE(review): bag_size is not resized here; presumably the caller sizes it
// to hold numBags entries -- confirm against the caller.
void THNN_(LookupTableBag_updateOutput)(
           THCState *state,
           THCIndexTensor *input,
           THCIndexTensor *offsets,
           THCTensor *weight,
           THCTensor *output,
           THCIndexTensor *offset2bag,
           int mode,
           THCIndexTensor *bag_size)
{
  THCUNN_assertSameGPU(state, 5, input, offsets, weight, output, offset2bag);

  if (!(THCIndexTensor_(isContiguous)(state, input) &&
        THCIndexTensor_(isContiguous)(state, offsets) &&
        THCTensor_(isContiguous)(state, weight))) {
    THError("Tensors must be contiguous");
  }

  // The kernel dereferences bag_size unconditionally in MEAN mode.
  if (mode == MODE_MEAN && bag_size == NULL) {
    THError("bag_size must be provided when mode is MEAN");
  }

  ptrdiff_t numIndices = THCIndexTensor_(size)(state, input, 0);
  ptrdiff_t numBags = THCIndexTensor_(size)(state, offsets, 0);
  ptrdiff_t stride = THCTensor_(size)(state, weight, 1);
  long *bag_size_data = NULL;
  if (bag_size != NULL) {
    bag_size_data = THCIndexTensor_(data)(state, bag_size);
  }

  cudaStream_t stream = THCState_getCurrentStream(state);

  THLongStorage *inputSize = THCIndexTensor_(newSizeOf)(state, input);
  THLongStorage *outputSize = THLongStorage_newWithSize(2);
  outputSize->data[0] = numBags;
  outputSize->data[1] = stride;
  THCTensor_(resize)(state, output, outputSize, NULL);
  THCTensor_(zero)(state, output);
  THCIndexTensor_(resize)(state, offset2bag, inputSize, NULL);
  THLongStorage_free(inputSize);
  THLongStorage_free(outputSize);

  // Fixed launch shape; the kernel strides over chunks, so it is valid for
  // any numBags / stride.
  dim3 block = dim3(32, 8);
  int grid = 1024;
  cunn_LookupTableBag_updateOutputKernel<real, accreal><<<grid, block, 0, stream>>>(
    THCIndexTensor_(data)(state, input),
    THCIndexTensor_(data)(state, offsets),
    THCTensor_(data)(state, weight),
    THCTensor_(data)(state, output),
    THCIndexTensor_(data)(state, offset2bag),
    numIndices,
    numBags,
    stride,
    mode,
    bag_size_data
  );

  THCudaCheck(cudaGetLastError());
}


// Backward pass of LookupTableBag with respect to the embedding table.
//
// Sorts a copy of `input` (into sortedIndices, with the original positions in
// origIndices), optionally computes per-index occurrence counts into `count`
// (scaleGradByFreq), and launches cunn_LookupTableBag_accGradParametersKernel
// on the current stream to accumulate into gradWeight.
//
// input, gradWeight and offset2bag must be contiguous; gradOutput is made
// contiguous internally. input must be 1-D or 2-D. bag_size is required when
// mode == MODE_MEAN. sortedIndices, origIndices and count are caller-provided
// buffers that are resized here.
void THNN_(LookupTableBag_accGradParameters)(
           THCState *state,
           THCIndexTensor *input,
           THCTensor *gradOutput,
           THCTensor *gradWeight,
           THCIndexTensor *offset2bag,
           THCIndexTensor *count,
           THCIndexTensor *sortedIndices,
           THCIndexTensor *origIndices,
           bool scaleGradByFreq,
           int mode,
           THCIndexTensor *bag_size,
           accreal scale_)
{
  real scale = ScalarConvert<accreal, real>::to(scale_);
  THCUNN_assertSameGPU(state, 6, input, gradOutput, gradWeight, offset2bag, sortedIndices, origIndices);

  // All validation happens before gradOutput is made contiguous so that an
  // error raised here cannot leak the temporary reference.
  if (!(THCIndexTensor_(isContiguous)(state, input) &&
        THCTensor_(isContiguous)(state, gradWeight) &&
        THCIndexTensor_(isContiguous)(state, offset2bag))) {
    THError("Tensors must be contiguous");
  }

  // The kernel dereferences bag_size unconditionally in MEAN mode.
  if (mode == MODE_MEAN && bag_size == NULL) {
    THError("bag_size must be provided when mode is MEAN");
  }

  long *bag_size_data = NULL;
  if (bag_size != NULL) {
    bag_size_data = THCIndexTensor_(data)(state, bag_size);
  }

  int nDim = THCIndexTensor_(nDimension)(state, input);
  if (nDim != 1 && nDim != 2) {
    THCDescBuff s1 = THCIndexTensor_(sizeDesc)(state, input);
    THError("input must be a vector or matrix, but is of shape: %s", s1.str);
  }

  ptrdiff_t numel = THCIndexTensor_(nElement)(state, input);
  long stride = THCTensor_(stride)(state, gradWeight, 0);

  cudaStream_t stream = THCState_getCurrentStream(state);

  THLongStorage *inputSize = THCIndexTensor_(newSizeOf)(state, input);
  THCIndexTensor_(resize)(state, sortedIndices, inputSize, NULL);
  THCIndexTensor_(resize)(state, origIndices, inputSize, NULL);
  THLongStorage_free(inputSize);

  // Sort the inputs into sorted with the corresponding indices; we
  // don't need a stable or multidimensional sort, so just use Thrust
  // directly
  {
    THCIndexTensor_(copy)(state, sortedIndices, input);

    THCThrustAllocator thrustAlloc(state);

    thrust::device_ptr<THCIndex_t>
      sortedIndicesIter(THCIndexTensor_(data)(state, sortedIndices));
    thrust::device_ptr<THCIndex_t>
      origIndicesIter(THCIndexTensor_(data)(state, origIndices));

    // Fill sortedOrigIndices with sequential indices
    thrust::counting_iterator<THCIndex_t> countIter(TH_INDEX_BASE);

    thrust::copy(
#if CUDA_VERSION >= 7000
      thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)),
#endif
      countIter, countIter + numel, origIndicesIter);

    // Sort; a stable sort is not required
    thrust::sort_by_key(
#if CUDA_VERSION >= 7000
      thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)),
#endif
      sortedIndicesIter, sortedIndicesIter + numel,
      origIndicesIter, ThrustLTOp<long>());
  }

  THCIndex_t *sortedIndices_data = THCIndexTensor_(data)(state, sortedIndices);
  THCIndex_t *origIndices_data = THCIndexTensor_(data)(state, origIndices);
  THCIndex_t *offset2bag_data = THCIndexTensor_(data)(state, offset2bag);
  THCIndex_t *count_data = NULL;

  if (scaleGradByFreq) {
    THCIndexTensor_(resizeAs)(state, count, input);
    count_data = THCIndexTensor_(data)(state, count);

    THCThrustAllocator thrustAlloc(state);
    thrust::device_ptr<THCIndex_t> sortedIndices_ptr(sortedIndices_data);
    thrust::device_ptr<THCIndex_t> count_ptr(count_data);

    // Compute an increasing sequence per unique item in sortedIndices:
    // sorted: 2 5 5 5 7 7 8 9 9
    //  count: 1 1 2 3 1 2 1 1 2
    thrust::inclusive_scan_by_key(
#if CUDA_VERSION >= 7000
      thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)),
#endif
      sortedIndices_ptr,
      sortedIndices_ptr + numel,
      thrust::make_constant_iterator(1),
      count_ptr
    );

    // Take the maximum of each count per unique key in reverse:
    // sorted: 2 5 5 5 7 7 8 9 9
    //  count: 1 3 3 3 2 2 1 2 2
    thrust::inclusive_scan_by_key(
#if CUDA_VERSION >= 7000
      thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)),
#endif
      thrust::make_reverse_iterator(sortedIndices_ptr + numel),
      thrust::make_reverse_iterator(sortedIndices_ptr),
      thrust::make_reverse_iterator(count_ptr + numel),
      thrust::make_reverse_iterator(count_ptr + numel),
      thrust::equal_to<long>(),
      thrust::maximum<long>()
    );
  }

  // Nothing to accumulate. A zero-sized grid dimension is also an invalid
  // launch configuration, so the kernel must not be launched in this case.
  if (numel == 0 || stride == 0) {
    return;
  }

  gradOutput = THCTensor_(newContiguous)(state, gradOutput);

  // grid.x: 4 sorted positions per block (block.y); grid.y: 128 features per
  // block (32 lanes x 4 features per thread).
  dim3 grid(THCCeilDiv(numel, (ptrdiff_t) 4), THCCeilDiv(stride, (long) 128));
  dim3 block(32, 4);
  cunn_LookupTableBag_accGradParametersKernel<real, accreal><<<grid, block, 0, stream>>>(
    sortedIndices_data,
    origIndices_data,
    THCTensor_(data)(state, gradOutput),
    THCTensor_(data)(state, gradWeight),
    offset2bag_data,
    count_data,
    scale,
    numel,
    stride,
    mode,
    bag_size_data
  );

  THCTensor_(free)(state, gradOutput);
  THCudaCheck(cudaGetLastError());
}

#endif

// ============================================================================
// File: lib/THCUNN/generic/THCUNN.h  (hunk @@ -250,6 +250,30 @@)
// In the diff these declarations are inserted after the end of the
// THNN_(LookupTable_renorm) declaration ("accreal maxNorm, accreal normType);")
// and before the THNN_(L1Cost_updateOutput) declaration.
// ============================================================================
TH_API void THNN_(LookupTableBag_updateOutput)(
                  THCState *state,
                  THCIndexTensor *input,
                  THCIndexTensor *offsets,
                  THCTensor *weight,
                  THCTensor *output,
                  THCIndexTensor *offset2bag,
                  int mode,
                  THCIndexTensor *seq_length);       // [OPTIONAL]

TH_API void THNN_(LookupTableBag_accGradParameters)(
                  THCState *state,
                  THCIndexTensor *input,
                  THCTensor *gradOutput,
                  THCTensor *gradWeight,
                  THCIndexTensor *offset2bag,
                  THCIndexTensor *count,
                  THCIndexTensor *sortedIndices,
                  THCIndexTensor *origIndices,
                  bool scaleGradByFreq,
                  int mode,
                  THCIndexTensor *seq_length,        // [OPTIONAL]
                  accreal scale_);