author     Soumith Chintala <soumith@gmail.com>       2016-08-12 06:47:49 +0300
committer  GitHub <noreply@github.com>                2016-08-12 06:47:50 +0300
commit     3040562b80277c4aa0de1b4caf8588349046ddb4 (patch)
tree       3bdabff0efc116a94965f90e54e675fc70bc3345
parent     65568cfcf6fbfadb9c7aca5323d1cc407be91510 (diff)
parent     2618f796bd1db29344d2c862928bbc1510a05814 (diff)
Merge pull request #319 from torch/typesfix
fixes for multiple cuda types
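This merge switches the generic index-tensor type used throughout THCUNN from the float-typed THCudaTensor to THCudaLongTensor (with THIndex_t becoming long on the Lua side) and drops the separate THIntegerTensor alias. A minimal sketch of what the new macros do, assuming only a C compiler; the nElement suffix below is just an illustrative example of a pasted name, and the real THCudaLongTensor_* functions live in THC and are not defined here:

/* Sketch only: demonstrates the token pasting performed by the macros this
 * merge introduces in lib/THCUNN/THCUNN.h. */
#include <stdio.h>

#define THIndexTensor THCudaLongTensor
#define THIndexTensor_(NAME) THCudaLongTensor_ ## NAME

/* Stringify after macro expansion so the result can be printed. */
#define STR2(x) #x
#define STR(x) STR2(x)

int main(void)
{
  /* A generic call written as THIndexTensor_(nElement)(state, input)
   * now resolves to THCudaLongTensor_nElement(state, input). */
  printf("%s\n", STR(THIndexTensor));             /* THCudaLongTensor */
  printf("%s\n", STR(THIndexTensor_(nElement)));  /* THCudaLongTensor_nElement */
  return 0;
}

In practice this means index arguments such as LookupTable's input, sorted, indices, and count must now be allocated as THCudaLongTensor (64-bit integer data on the GPU) rather than as float tensors.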
-rw-r--r--  THCUNN.lua                 |  5
-rw-r--r--  lib/THCUNN/LookupTable.cu  | 56
-rw-r--r--  lib/THCUNN/SparseLinear.cu | 17
-rw-r--r--  lib/THCUNN/THCUNN.h        | 13
4 files changed, 44 insertions, 47 deletions
diff --git a/THCUNN.lua b/THCUNN.lua
--- a/THCUNN.lua
+++ b/THCUNN.lua
@@ -25,9 +25,8 @@ local replacements =
 {
    {
       ['THTensor'] = 'THCudaTensor',
-      ['THIndexTensor'] = 'THCudaTensor',
-      ['THIntegerTensor'] = 'THCudaTensor',
-      ['THIndex_t'] = 'float',
+      ['THIndexTensor'] = 'THCudaLongTensor',
+      ['THIndex_t'] = 'long',
       ['THInteger_t'] = 'float'
    }
 }
diff --git a/lib/THCUNN/LookupTable.cu b/lib/THCUNN/LookupTable.cu
index 749ce15..2b2040e 100644
--- a/lib/THCUNN/LookupTable.cu
+++ b/lib/THCUNN/LookupTable.cu
@@ -50,7 +50,7 @@ __device__ __forceinline__ bool warpHasCollision(int val)
 }
 
 __global__ void cunn_LookupTable_accGradParametersKernelByFeature(
-  float *input, float *gradOutput, float *gradWeight, float scale, long numel,
+  long *input, float *gradOutput, float *gradWeight, float scale, long numel,
   long stride, int paddingValue) {
 
   const int featureDim = blockIdx.x * 4 + threadIdx.x / 32;
@@ -96,8 +96,8 @@ __global__ void cunn_LookupTable_accGradParametersKernelByFeature(
 }
 
 __global__ void cunn_LookupTable_accGradParametersKernel(
-  float *input, float *indices, float *gradOutput, float *gradWeight,
-  float *count, float defaultScale, long numel, long stride, int paddingValue) {
+  long *input, long *indices, float *gradOutput, float *gradWeight,
+  long *count, float defaultScale, long numel, long stride, int paddingValue) {
 
   int idx = blockIdx.x * 4 + threadIdx.y;
 
@@ -164,33 +164,33 @@ void THNN_CudaLookupTable_accGradParameters(
   THIndexTensor *input,
   THCudaTensor *gradOutput,
   THCudaTensor *gradWeight,
-  THIntegerTensor *count,
-  THCudaTensor *sorted,
-  THCudaTensor *indices,
+  THIndexTensor *count,
+  THIndexTensor *sorted,
+  THIndexTensor *indices,
   bool scaleGradByFreq,
   int paddingValue,
   float scale)
 {
   THCUNN_assertSameGPU(state, 5, input, gradOutput, gradWeight, sorted, indices);
-  if (!(THCudaTensor_isContiguous(state, input) &&
+  if (!(THIndexTensor_(isContiguous)(state, input) &&
         THCudaTensor_isContiguous(state, gradOutput) &&
         THCudaTensor_isContiguous(state, gradWeight)))
   {
     THError("Tensors must be contiguous");
   }
 
-  int nDim = THCudaTensor_nDimension(state, input);
+  int nDim = THIndexTensor_(nDimension)(state, input);
   if (nDim != 1 && nDim != 2)
     THError("input must be a vector or matrix");
 
-  long numel = THCudaTensor_nElement(state, input);
+  long numel = THIndexTensor_(nElement)(state, input);
   long stride = gradWeight->stride[0];
 
   cudaStream_t stream = THCState_getCurrentStream(state);
 
   if (numel <= 768 && !scaleGradByFreq)
   {
     cunn_LookupTable_accGradParametersKernelByFeature<<<DIVUP(stride,4), 128, 0, stream>>>(
-      THCudaTensor_data(state, input),
+      THIndexTensor_(data)(state, input),
       THCudaTensor_data(state, gradOutput),
       THCudaTensor_data(state, gradWeight),
       scale,
@@ -201,23 +201,25 @@ void THNN_CudaLookupTable_accGradParameters(
     return;
   }
 
-  THCudaTensor_resizeAs(state, sorted, input);
-  THCudaTensor_resizeAs(state, indices, input);
+  THLongStorage *inputSize = THIndexTensor_(newSizeOf)(state, input);
+  THIndexTensor_(resize)(state, sorted, inputSize, NULL);
+  THIndexTensor_(resize)(state, indices, inputSize, NULL);
+  THLongStorage_free(inputSize);
 
   // Sort the inputs into sorted with the corresponding indices
-  THCudaTensor_sort(state, sorted, indices, input, 0, 0);
+  THIndexTensor_(sort)(state, sorted, indices, input, 0, 0);
 
-  float *sorted_data = THCudaTensor_data(state, sorted);
-  float *indices_data = THCudaTensor_data(state, indices);
-  float *count_data = NULL;
+  long *sorted_data = THIndexTensor_(data)(state, sorted);
+  long *indices_data = THIndexTensor_(data)(state, indices);
+  long *count_data = NULL;
 
   if (scaleGradByFreq)
   {
-    THIntegerTensor_(resizeAs)(state, count, input);
-    count_data = THIntegerTensor_(data)(state, count);
+    THIndexTensor_(resizeAs)(state, count, input);
+    count_data = THIndexTensor_(data)(state, count);
 
-    thrust::device_ptr<float> sorted_ptr(sorted_data);
-    thrust::device_ptr<float> count_ptr(count_data);
+    thrust::device_ptr<long> sorted_ptr(sorted_data);
+    thrust::device_ptr<long> count_ptr(count_data);
 
     // Compute an increasing sequence per unique item in sorted:
     // sorted: 2 5 5 5 7 7 8 9 9
@@ -243,8 +245,8 @@ void THNN_CudaLookupTable_accGradParameters(
       thrust::make_reverse_iterator(sorted_ptr),
       thrust::make_reverse_iterator(count_ptr + numel),
       thrust::make_reverse_iterator(count_ptr + numel),
-      thrust::equal_to<float>(),
-      thrust::maximum<float>()
+      thrust::equal_to<long>(),
+      thrust::maximum<long>()
     );
   }
 
@@ -302,23 +304,23 @@ void THNN_CudaLookupTable_renorm(
   float normType)
 {
   THCUNN_assertSameGPU(state, 2, idx, weight);
-  if (!(THCudaTensor_isContiguous(state, idx) &&
+  if (!(THIndexTensor_(isContiguous)(state, idx) &&
         THCudaTensor_isContiguous(state, weight)))
   {
     THError("Tensors must be contiguous");
   }
-  if (THCudaTensor_nDimension(state, idx) != 1)
+  if (THIndexTensor_(nDimension)(state, idx) != 1)
     THError("idx must be a vector");
   if (normType <= 0)
     THError("non-positive-norm not supported");
 
-  long numel = THCudaTensor_nElement(state, idx);
+  long numel = THIndexTensor_(nElement)(state, idx);
   long stride = weight->stride[0];
 
   // get the unique indices
   thrust::device_ptr<float> weight_ptr(THCudaTensor_data(state, weight));
-  thrust::device_ptr<float> idx_ptr(THCudaTensor_data(state, idx));
-  thrust::device_ptr<float> end_ptr = thrust::unique(idx_ptr, idx_ptr+numel);
+  thrust::device_ptr<long> idx_ptr(THIndexTensor_(data)(state, idx));
+  thrust::device_ptr<long> end_ptr = thrust::unique(idx_ptr, idx_ptr+numel);
   numel = end_ptr - idx_ptr;
 
   pow_v<float> unary_pow(normType);
diff --git a/lib/THCUNN/SparseLinear.cu b/lib/THCUNN/SparseLinear.cu
index 706e53b..577eec5 100644
--- a/lib/THCUNN/SparseLinear.cu
+++ b/lib/THCUNN/SparseLinear.cu
@@ -72,7 +72,7 @@ void THNN_CudaSparseLinear_updateOutput(THCState *state,
 
   init_cusparse();
   cusparseXcoo2csr(cusparse_handle,
-      THCudaIntTensor_data(state, rowbuf), nnz, batchnum, 
+      THCudaIntTensor_data(state, rowbuf), nnz, batchnum,
       THCudaIntTensor_data(state, csrPtrs), CUSPARSE_INDEX_BASE_ONE);
 
   // output = bias
@@ -82,13 +82,13 @@ void THNN_CudaSparseLinear_updateOutput(THCState *state,
     THCudaTensor_select(state, sel, buffer, 1, h);
     THCudaTensor_copy(state, sel, bias);
   }
-  
+
   // output = W * x
   float one = 1;
   cusparseMatDescr_t descr = 0;
   cusparseCreateMatDescr(&descr);
   cusparseSetMatType(descr,CUSPARSE_MATRIX_TYPE_GENERAL);
-  cusparseSetMatIndexBase(descr,CUSPARSE_INDEX_BASE_ONE);  
+  cusparseSetMatIndexBase(descr,CUSPARSE_INDEX_BASE_ONE);
   cusparseScsrmm(cusparse_handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
       batchnum, outDim, inDim, nnz,
@@ -105,7 +105,7 @@ void THNN_CudaSparseLinear_updateOutput(THCState *state,
   // We do work in the buffer to keep the output contiguous
   THCudaTensor_copy(state, output, buffer);
 
-  cusparseDestroyMatDescr(descr);  
+  cusparseDestroyMatDescr(descr);
   descr = 0;
   THCudaTensor_free(state, buffer);
   THCudaTensor_free(state, sel);
@@ -141,7 +141,7 @@ void THNN_CudaSparseLinear_accGradParameters(
   THCudaTensor *buf = THCudaTensor_new(state);
   THCudaTensor *cols = THCudaTensor_new(state);
   THCudaTensor *sel = THCudaTensor_new(state);
-  THCudaTensor *inds = THCudaTensor_new(state);
+  THCudaLongTensor *inds = THCudaLongTensor_new(state);
   THCudaTensor *values = THCudaTensor_new(state);
   THCudaIntTensor *colbuf = THCudaIntTensor_new(state);
   THCudaIntTensor *colPtrs = THCudaIntTensor_new(state);
@@ -169,7 +169,7 @@ void THNN_CudaSparseLinear_accGradParameters(
   init_cusparse();
   // Secretly coo2csc
   cusparseXcoo2csr(cusparse_handle,
-      THCudaIntTensor_data(state, colbuf), nnz, inDim, 
+      THCudaIntTensor_data(state, colbuf), nnz, inDim,
       THCudaIntTensor_data(state, colPtrs), CUSPARSE_INDEX_BASE_ONE);
 
   // FORTRAN expects contiguous col-major matricies
@@ -182,7 +182,7 @@ void THNN_CudaSparseLinear_accGradParameters(
   cusparseMatDescr_t descr = 0;
   cusparseCreateMatDescr(&descr);
   cusparseSetMatType(descr,CUSPARSE_MATRIX_TYPE_GENERAL);
-  cusparseSetMatIndexBase(descr,CUSPARSE_INDEX_BASE_ONE);  
+  cusparseSetMatIndexBase(descr,CUSPARSE_INDEX_BASE_ONE);
   cusparseScsrmm(cusparse_handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
       inDim, outDim, batchnum, nnz,
@@ -208,7 +208,7 @@ void THNN_CudaSparseLinear_accGradParameters(
   THCudaTensor_free(state, buf);
   THCudaTensor_free(state, sel);
   THCudaTensor_free(state, cols);
-  THCudaTensor_free(state, inds);
+  THCudaLongTensor_free(state, inds);
   THCudaTensor_free(state, values);
   THCudaIntTensor_free(state, colbuf);
   THCudaIntTensor_free(state, rowInds);
@@ -260,4 +260,3 @@ TH_API void THNN_CudaSparseLinear_updateParameters(
 void THNN_CudaSparseLinear_cudaClearState(THCState *state)
 {
 }
-
diff --git a/lib/THCUNN/THCUNN.h b/lib/THCUNN/THCUNN.h
index b176a6d..858ddab 100644
--- a/lib/THCUNN/THCUNN.h
+++ b/lib/THCUNN/THCUNN.h
@@ -1,11 +1,8 @@
 #include <THC/THC.h>
 #include <THC/THCApply.cuh>
 
-#define THIndexTensor THCudaTensor
-#define THIndexTensor_(NAME) THCudaTensor_ ## NAME
-
-#define THIntegerTensor THCudaTensor
-#define THIntegerTensor_(NAME) THCudaTensor_ ## NAME
+#define THIndexTensor THCudaLongTensor
+#define THIndexTensor_(NAME) THCudaLongTensor_ ## NAME
 
 TH_API void THNN_CudaAbs_updateOutput(
   THCState *state,
@@ -159,9 +156,9 @@ TH_API void THNN_CudaLookupTable_accGradParameters(
   THIndexTensor *input,
   THCudaTensor *gradOutput,
   THCudaTensor *gradWeight,
-  THIntegerTensor *count,
-  THCudaTensor *sorted,       // [OPTIONAL]
-  THCudaTensor *indices,      // [OPTIONAL]
+  THIndexTensor *count,
+  THIndexTensor *sorted,      // [OPTIONAL]
+  THIndexTensor *indices,     // [OPTIONAL]
   bool scaleGradByFreq,
   int paddingValue,
   float scale);
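For reference, the count-by-frequency step that LookupTable.cu now runs on long index data can be exercised in isolation. The standalone sketch below is not part of the patch; it assumes only Thrust and nvcc, and reproduces the two inclusive_scan_by_key passes on the example from the code comments (sorted: 2 5 5 5 7 7 8 9 9), first building a running count per unique index and then propagating each index's total frequency backwards:

// Count-by-frequency sketch over sorted long indices, mirroring the Thrust
// calls in THNN_CudaLookupTable_accGradParameters after this change.
#include <thrust/device_vector.h>
#include <thrust/scan.h>
#include <thrust/functional.h>
#include <thrust/iterator/constant_iterator.h>
#include <thrust/iterator/reverse_iterator.h>
#include <cstdio>

int main()
{
  // Indices already sorted, as THIndexTensor_(sort) would produce them.
  const long h_sorted[] = {2, 5, 5, 5, 7, 7, 8, 9, 9};
  const int numel = sizeof(h_sorted) / sizeof(h_sorted[0]);
  thrust::device_vector<long> sorted(h_sorted, h_sorted + numel);
  thrust::device_vector<long> count(numel);

  // Running count per unique key:
  // sorted: 2 5 5 5 7 7 8 9 9  ->  count: 1 1 2 3 1 2 1 1 2
  thrust::inclusive_scan_by_key(
      sorted.begin(), sorted.end(),
      thrust::constant_iterator<long>(1),
      count.begin());

  // Propagate each key's maximum backwards so every occurrence holds the
  // total frequency: count becomes 1 3 3 3 2 2 1 2 2
  thrust::inclusive_scan_by_key(
      thrust::make_reverse_iterator(sorted.end()),
      thrust::make_reverse_iterator(sorted.begin()),
      thrust::make_reverse_iterator(count.end()),
      thrust::make_reverse_iterator(count.end()),
      thrust::equal_to<long>(),
      thrust::maximum<long>());

  for (int i = 0; i < numel; ++i)
    printf("%ld ", (long)count[i]);
  printf("\n");
  return 0;
}

These per-index frequencies are what the accGradParameters kernel uses to scale down the accumulated gradient of rows whose index occurs multiple times when scaleGradByFreq is set.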