github.com/torch/cunn.git
Diffstat (limited to 'lib')
-rw-r--r--  lib/THCUNN/LookupTable.cu         | 193
-rw-r--r--  lib/THCUNN/THCUNN.h               |  19
-rw-r--r--  lib/THCUNN/generic/LookupTable.cu | 157
-rw-r--r--  lib/THCUNN/generic/THCUNN.h       |  19
4 files changed, 201 insertions(+), 187 deletions(-)
diff --git a/lib/THCUNN/LookupTable.cu b/lib/THCUNN/LookupTable.cu
index 1098a87..bb91c7e 100644
--- a/lib/THCUNN/LookupTable.cu
+++ b/lib/THCUNN/LookupTable.cu
@@ -9,6 +9,8 @@
#include <thrust/system/cuda/execution_policy.h>
#endif
#include <thrust/unique.h>
+#include "THCHalf.h"
+#include "THCHalfAutoNumerics.cuh"
#ifndef DIVUP
#define DIVUP(x, y) (((x) + (y) - 1) / (y))
@@ -49,8 +51,9 @@ __device__ __forceinline__ bool warpHasCollision(int val)
return __any(dup) != 0;
}
+template <typename Dtype>
__global__ void cunn_LookupTable_accGradParametersKernelByFeature(
- long *input, float *gradOutput, float *gradWeight, float scale, long numel,
+ long *input, Dtype *gradOutput, Dtype *gradWeight, Dtype scale, long numel,
long stride, int paddingValue) {
const int featureDim = blockIdx.x * 4 + threadIdx.x / 32;
@@ -78,8 +81,9 @@ __global__ void cunn_LookupTable_accGradParametersKernelByFeature(
continue;
}
- float update = gradOutput[i*stride + featureDim] * scale;
+ Dtype update = gradOutput[i*stride + featureDim] * scale;
+ // FIXME: should we accumulate as accreal?
// Check for collision
if (warpHasCollision(weightIndex)) {
// Run all lanes sequentially; warp divergence
@@ -95,9 +99,10 @@ __global__ void cunn_LookupTable_accGradParametersKernelByFeature(
}
}
+template <typename Dtype, typename Acctype>
__global__ void cunn_LookupTable_accGradParametersKernel(
- long *input, long *indices, float *gradOutput, float *gradWeight,
- long *count, float defaultScale, long numel, long stride, int paddingValue) {
+ long *input, long *indices, Dtype *gradOutput, Dtype *gradWeight,
+ long *count, Dtype defaultScale, long numel, long stride, int paddingValue) {
int idx = blockIdx.x * 4 + threadIdx.y;
@@ -122,10 +127,10 @@ __global__ void cunn_LookupTable_accGradParametersKernel(
const int startFeature = threadIdx.x + blockIdx.y * blockDim.x * SZ;
const int weightRow = ((int) input[idx] - TH_INDEX_BASE) * stride;
const int gradOutputRow = ((int) indices[idx] - TH_INDEX_BASE) * stride;
- const float scale = count ? defaultScale / count[idx] : defaultScale;
+ const Acctype scale = count ? ScalarConvert<Dtype, Acctype>::to(defaultScale) / count[idx] : ScalarConvert<Dtype, Acctype>::to(defaultScale);
- float gradient[SZ];
- float weight[SZ];
+ Acctype gradient[SZ];
+ Acctype weight[SZ];
#pragma unroll
for (int ii = 0; ii < SZ; ii++)
@@ -133,8 +138,8 @@ __global__ void cunn_LookupTable_accGradParametersKernel(
int featureDim = startFeature + ii * WARP_SIZE;
if (featureDim < stride)
{
- gradient[ii] = gradOutput[gradOutputRow + featureDim];
- weight[ii] = gradWeight[weightRow + featureDim];
+ gradient[ii] = ScalarConvert<Dtype, Acctype>::to(gradOutput[gradOutputRow + featureDim]);
+ weight[ii] = ScalarConvert<Dtype, Acctype>::to(gradWeight[weightRow + featureDim]);
}
}
@@ -150,7 +155,7 @@ __global__ void cunn_LookupTable_accGradParametersKernel(
int featureDim = startFeature + ii * WARP_SIZE;
if (featureDim < stride)
{
- gradWeight[weightRow + featureDim] = weight[ii];
+ gradWeight[weightRow + featureDim] = ScalarConvert<Acctype, Dtype>::to(weight[ii]);
}
}
@@ -159,129 +164,23 @@ __global__ void cunn_LookupTable_accGradParametersKernel(
}
}
-void THNN_CudaLookupTable_accGradParameters(
- THCState *state,
- THCIndexTensor *input,
- THCudaTensor *gradOutput,
- THCudaTensor *gradWeight,
- THCIndexTensor *count,
- THCIndexTensor *sorted,
- THCIndexTensor *indices,
- bool scaleGradByFreq,
- int paddingValue,
- float scale)
-{
- THCUNN_assertSameGPU(state, 5, input, gradOutput, gradWeight, sorted, indices);
- if (!(THCIndexTensor_(isContiguous)(state, input) &&
- THCudaTensor_isContiguous(state, gradOutput) &&
- THCudaTensor_isContiguous(state, gradWeight)))
- {
- THError("Tensors must be contiguous");
- }
-
- int nDim = THCIndexTensor_(nDimension)(state, input);
- if (nDim != 1 && nDim != 2)
- THError("input must be a vector or matrix");
-
- long numel = THCIndexTensor_(nElement)(state, input);
- long stride = gradWeight->stride[0];
-
- cudaStream_t stream = THCState_getCurrentStream(state);
-
- if (numel <= 768 && !scaleGradByFreq) {
- cunn_LookupTable_accGradParametersKernelByFeature<<<DIVUP(stride,4), 128, 0, stream>>>(
- THCIndexTensor_(data)(state, input),
- THCudaTensor_data(state, gradOutput),
- THCudaTensor_data(state, gradWeight),
- scale,
- numel,
- stride,
- paddingValue);
- THCudaCheck(cudaGetLastError());
- return;
- }
-
- THLongStorage *inputSize = THCIndexTensor_(newSizeOf)(state, input);
- THCIndexTensor_(resize)(state, sorted, inputSize, NULL);
- THCIndexTensor_(resize)(state, indices, inputSize, NULL);
- THLongStorage_free(inputSize);
-
- // Sort the inputs into sorted with the corresponding indices
- THCIndexTensor_(sort)(state, sorted, indices, input, 0, 0);
-
- THCIndex_t *sorted_data = THCIndexTensor_(data)(state, sorted);
- THCIndex_t *indices_data = THCIndexTensor_(data)(state, indices);
- THCIndex_t *count_data = NULL;
-
- if (scaleGradByFreq)
- {
- THCIndexTensor_(resizeAs)(state, count, input);
- count_data = THCIndexTensor_(data)(state, count);
-
- thrust::device_ptr<THCIndex_t> sorted_ptr(sorted_data);
- thrust::device_ptr<THCIndex_t> count_ptr(count_data);
-
- // Compute an increasing sequence per unique item in sorted:
- // sorted: 2 5 5 5 7 7 8 9 9
- // count: 1 1 2 3 1 2 1 1 2
- thrust::inclusive_scan_by_key(
-#if CUDA_VERSION >= 7000
- thrust::cuda::par.on(THCState_getCurrentStream(state)),
-#endif
- sorted_ptr,
- sorted_ptr + numel,
- thrust::make_constant_iterator(1),
- count_ptr
- );
-
- // Take the maximum of each count per unique key in reverse:
- // sorted: 2 5 5 5 7 7 8 9 9
- // count: 1 3 3 3 2 2 1 2 2
- thrust::inclusive_scan_by_key(
-#if CUDA_VERSION >= 7000
- thrust::cuda::par.on(THCState_getCurrentStream(state)),
-#endif
- thrust::make_reverse_iterator(sorted_ptr + numel),
- thrust::make_reverse_iterator(sorted_ptr),
- thrust::make_reverse_iterator(count_ptr + numel),
- thrust::make_reverse_iterator(count_ptr + numel),
- thrust::equal_to<long>(),
- thrust::maximum<long>()
- );
- }
-
- dim3 grid(DIVUP(numel,4), DIVUP(stride,128));
- dim3 block(32, 4);
- cunn_LookupTable_accGradParametersKernel<<<grid, block, 0, stream>>>(
- sorted_data,
- indices_data,
- THCudaTensor_data(state, gradOutput),
- THCudaTensor_data(state, gradWeight),
- count_data,
- scale,
- numel,
- stride,
- paddingValue
- );
- THCudaCheck(cudaGetLastError());
-}
-
/*
* Keep the norm of weight smaller than maxNorm
*/
-template <typename T>
+template <typename Dtype, typename Acctype>
struct pow_v
{
- T normType;
- pow_v(T v) : normType(v) {}
+ Acctype normType;
+ pow_v(Dtype v) : normType(ScalarConvert<Dtype, Acctype>::to(v)) {}
__host__ __device__
- T operator()(const T& x) const {
+ Acctype operator()(const Dtype& x) const {
+ Acctype xA = ScalarConvert<Dtype, Acctype>::to(x);
if (normType == 1)
- return std::abs(x);
+ return std::abs(xA);
else if (normType == 2)
- return x * x;
+ return xA * xA;
else
- return std::pow(std::abs(x), normType);
+ return std::pow(std::abs(xA), normType);
}
};
@@ -296,47 +195,5 @@ struct multiply_s
}
};
-void THNN_CudaLookupTable_renorm(
- THCState *state,
- THCIndexTensor *idx,
- THCudaTensor *weight,
- float maxNorm,
- float normType)
-{
- THCUNN_assertSameGPU(state, 2, idx, weight);
- if (!(THCIndexTensor_(isContiguous)(state, idx) &&
- THCudaTensor_isContiguous(state, weight)))
- {
- THError("Tensors must be contiguous");
- }
- if (THCIndexTensor_(nDimension)(state, idx) != 1)
- THError("idx must be a vector");
- if (normType <= 0)
- THError("non-positive-norm not supported");
-
- THCIndex_t numel = THCIndexTensor_(nElement)(state, idx);
- long stride = weight->stride[0];
-
- // get the unique indices
- thrust::device_ptr<float> weight_ptr(THCudaTensor_data(state, weight));
- thrust::device_ptr<THCIndex_t> idx_ptr(THCIndexTensor_(data)(state, idx));
- thrust::device_ptr<THCIndex_t> end_ptr = thrust::unique(idx_ptr, idx_ptr+numel);
- numel = end_ptr - idx_ptr;
-
- pow_v<float> unary_pow(normType);
- thrust::plus<float> binary_plus;
- // numel << stride, since idx usually contains sparse row indices
- for (long i = 0; i < numel; i++)
- {
- THCIndex_t k = idx_ptr[i] - TH_INDEX_BASE;
- thrust::device_ptr<float> row_ptr = weight_ptr + k * stride;
- float norm = thrust::transform_reduce(row_ptr, row_ptr + stride,
- unary_pow, 0, binary_plus);
- norm = std::pow(norm, (float) (1.0 / normType));
- if (norm > maxNorm)
- {
- multiply_s<float> unary_mul(maxNorm / (norm + 1e-7));
- thrust::transform(row_ptr, row_ptr + stride, row_ptr, unary_mul);
- }
- }
-}
+#include "generic/LookupTable.cu"
+#include "THCGenerateFloatTypes.h"
diff --git a/lib/THCUNN/THCUNN.h b/lib/THCUNN/THCUNN.h
index c387d0c..4e95b06 100644
--- a/lib/THCUNN/THCUNN.h
+++ b/lib/THCUNN/THCUNN.h
@@ -7,24 +7,5 @@ typedef long THCIndex_t;
#define THNN_(NAME) TH_CONCAT_3(THNN_, CReal, NAME)
-TH_API void THNN_CudaLookupTable_accGradParameters(
- THCState *state,
- THCIndexTensor *input,
- THCudaTensor *gradOutput,
- THCudaTensor *gradWeight,
- THCIndexTensor *count,
- THCIndexTensor *sorted, // [OPTIONAL]
- THCIndexTensor *indices, // [OPTIONAL]
- bool scaleGradByFreq,
- int paddingValue,
- float scale);
-
-TH_API void THNN_CudaLookupTable_renorm(
- THCState *state,
- THCIndexTensor *idx,
- THCudaTensor *weight,
- float maxNorm,
- float normType);
-
#include "generic/THCUNN.h"
#include "THCGenerateFloatTypes.h"
diff --git a/lib/THCUNN/generic/LookupTable.cu b/lib/THCUNN/generic/LookupTable.cu
new file mode 100644
index 0000000..2027425
--- /dev/null
+++ b/lib/THCUNN/generic/LookupTable.cu
@@ -0,0 +1,157 @@
+#ifndef THC_GENERIC_FILE
+#define THC_GENERIC_FILE "generic/LookupTable.cu"
+#else
+
+void THNN_(LookupTable_accGradParameters)(
+ THCState *state,
+ THCIndexTensor *input,
+ THCTensor *gradOutput,
+ THCTensor *gradWeight,
+ THCIndexTensor *count,
+ THCIndexTensor *sorted,
+ THCIndexTensor *indices,
+ bool scaleGradByFreq,
+ int paddingValue,
+ real scale)
+{
+ THCUNN_assertSameGPU_generic(state, 5, input, gradOutput, gradWeight, sorted, indices);
+ if (!(THCIndexTensor_(isContiguous)(state, input) &&
+ THCTensor_(isContiguous)(state, gradOutput) &&
+ THCTensor_(isContiguous)(state, gradWeight)))
+ {
+ THError("Tensors must be contiguous");
+ }
+
+ int nDim = THCIndexTensor_(nDimension)(state, input);
+ if (nDim != 1 && nDim != 2)
+ THError("input must be a vector or matrix");
+
+ long numel = THCIndexTensor_(nElement)(state, input);
+ long stride = gradWeight->stride[0];
+
+ cudaStream_t stream = THCState_getCurrentStream(state);
+
+ if (numel <= 768 && !scaleGradByFreq) {
+ cunn_LookupTable_accGradParametersKernelByFeature<<<DIVUP(stride,4), 128, 0, stream>>>(
+ THCIndexTensor_(data)(state, input),
+ THCTensor_(data)(state, gradOutput),
+ THCTensor_(data)(state, gradWeight),
+ scale,
+ numel,
+ stride,
+ paddingValue);
+ THCudaCheck(cudaGetLastError());
+ return;
+ }
+
+ THLongStorage *inputSize = THCIndexTensor_(newSizeOf)(state, input);
+ THCIndexTensor_(resize)(state, sorted, inputSize, NULL);
+ THCIndexTensor_(resize)(state, indices, inputSize, NULL);
+ THLongStorage_free(inputSize);
+
+ // Sort the inputs into sorted with the corresponding indices
+ THCIndexTensor_(sort)(state, sorted, indices, input, 0, 0);
+
+ THCIndex_t *sorted_data = THCIndexTensor_(data)(state, sorted);
+ THCIndex_t *indices_data = THCIndexTensor_(data)(state, indices);
+ THCIndex_t *count_data = NULL;
+
+ if (scaleGradByFreq)
+ {
+ THCIndexTensor_(resizeAs)(state, count, input);
+ count_data = THCIndexTensor_(data)(state, count);
+
+ thrust::device_ptr<THCIndex_t> sorted_ptr(sorted_data);
+ thrust::device_ptr<THCIndex_t> count_ptr(count_data);
+
+ // Compute an increasing sequence per unique item in sorted:
+ // sorted: 2 5 5 5 7 7 8 9 9
+ // count: 1 1 2 3 1 2 1 1 2
+ thrust::inclusive_scan_by_key(
+#if CUDA_VERSION >= 7000
+ thrust::cuda::par.on(THCState_getCurrentStream(state)),
+#endif
+ sorted_ptr,
+ sorted_ptr + numel,
+ thrust::make_constant_iterator(1),
+ count_ptr
+ );
+
+ // Take the maximum of each count per unique key in reverse:
+ // sorted: 2 5 5 5 7 7 8 9 9
+ // count: 1 3 3 3 2 2 1 2 2
+ thrust::inclusive_scan_by_key(
+#if CUDA_VERSION >= 7000
+ thrust::cuda::par.on(THCState_getCurrentStream(state)),
+#endif
+ thrust::make_reverse_iterator(sorted_ptr + numel),
+ thrust::make_reverse_iterator(sorted_ptr),
+ thrust::make_reverse_iterator(count_ptr + numel),
+ thrust::make_reverse_iterator(count_ptr + numel),
+ thrust::equal_to<long>(),
+ thrust::maximum<long>()
+ );
+ }
+
+ dim3 grid(DIVUP(numel,4), DIVUP(stride,128));
+ dim3 block(32, 4);
+ cunn_LookupTable_accGradParametersKernel<real, accreal><<<grid, block, 0, stream>>>(
+ sorted_data,
+ indices_data,
+ THCTensor_(data)(state, gradOutput),
+ THCTensor_(data)(state, gradWeight),
+ count_data,
+ scale,
+ numel,
+ stride,
+ paddingValue
+ );
+ THCudaCheck(cudaGetLastError());
+}
+
+void THNN_(LookupTable_renorm)(
+ THCState *state,
+ THCIndexTensor *idx,
+ THCTensor *weight,
+ real maxNorm,
+ real normType)
+{
+ THCUNN_assertSameGPU_generic(state, 2, idx, weight);
+ if (!(THCIndexTensor_(isContiguous)(state, idx) &&
+ THCTensor_(isContiguous)(state, weight)))
+ {
+ THError("Tensors must be contiguous");
+ }
+ if (THCIndexTensor_(nDimension)(state, idx) != 1)
+ THError("idx must be a vector");
+ if (normType <= 0)
+ THError("non-positive-norm not supported");
+
+ THCIndex_t numel = THCIndexTensor_(nElement)(state, idx);
+ long stride = weight->stride[0];
+
+ // get the unique indices
+ thrust::device_ptr<real> weight_ptr(THCTensor_(data)(state, weight));
+ thrust::device_ptr<THCIndex_t> idx_ptr(THCIndexTensor_(data)(state, idx));
+ thrust::device_ptr<THCIndex_t> end_ptr = thrust::unique(idx_ptr, idx_ptr+numel);
+ numel = end_ptr - idx_ptr;
+
+ pow_v<real, accreal> unary_pow(normType);
+ thrust::plus<accreal> binary_plus;
+ // numel << stride, since idx usually contains sparse row indices
+ for (long i = 0; i < numel; i++)
+ {
+ THCIndex_t k = idx_ptr[i] - TH_INDEX_BASE;
+ thrust::device_ptr<real> row_ptr = weight_ptr + k * stride;
+ accreal norm = thrust::transform_reduce(row_ptr, row_ptr + stride,
+ unary_pow, 0, binary_plus);
+ norm = std::pow(norm, (accreal) (1.0 / normType));
+ if (norm > ScalarConvert<real, accreal>::to(maxNorm))
+ {
+ multiply_s<real> unary_mul(ScalarConvert<accreal, real>::to(maxNorm / (norm + 1e-7)));
+ thrust::transform(row_ptr, row_ptr + stride, row_ptr, unary_mul);
+ }
+ }
+}
+
+#endif
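
The #ifndef THC_GENERIC_FILE guard is the other half of the generation pattern: LookupTable.cu includes this file once (that pass only defines THC_GENERIC_FILE), then includes THCGenerateFloatTypes.h, which re-includes it once per scalar type with real, accreal, and THNN_ redefined. A toy imitation of the mechanism as a single self-including file; every name here is a simplified stand-in, not the actual THC machinery:

/* generic_demo.c -- compile with: cc generic_demo.c */
#ifndef GENERIC_FILE

#include <stdio.h>
#define GENERIC_FILE "generic_demo.c"

/* One pass per type: redefine the type macros, re-include this file. */
#define real float
#define accreal double
#define Fn(NAME) Float##NAME
#include GENERIC_FILE
#undef real
#undef accreal
#undef Fn

#define real double
#define accreal double
#define Fn(NAME) Double##NAME
#include GENERIC_FILE

int main(void) {
  float  xf[3] = {1, 2, 3};
  double xd[3] = {1, 2, 3};
  printf("%f %f\n", Floatsum(xf, 3), Doublesum(xd, 3));
  return 0;
}

#else

/* The "generic" body: compiled once per (real, accreal) pair. */
static real Fn(sum)(const real *x, long n) {
  accreal s = 0;                          /* accumulate at higher precision */
  for (long i = 0; i < n; i++) s += (accreal)x[i];
  return (real)s;
}

#endif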
diff --git a/lib/THCUNN/generic/THCUNN.h b/lib/THCUNN/generic/THCUNN.h
index 15316fe..91d68ae 100644
--- a/lib/THCUNN/generic/THCUNN.h
+++ b/lib/THCUNN/generic/THCUNN.h
@@ -178,6 +178,25 @@ TH_API void THNN_(LogSoftMax_updateGradInput)(
THCTensor *gradInput,
THCTensor *output);
+TH_API void THNN_(LookupTable_accGradParameters)(
+ THCState *state,
+ THCIndexTensor *input,
+ THCTensor *gradOutput,
+ THCTensor *gradWeight,
+ THCIndexTensor *count,
+ THCIndexTensor *sorted, // [OPTIONAL]
+ THCIndexTensor *indices, // [OPTIONAL]
+ bool scaleGradByFreq,
+ int paddingValue,
+ real scale);
+
+TH_API void THNN_(LookupTable_renorm)(
+ THCState *state,
+ THCIndexTensor *idx,
+ THCTensor *weight,
+ real maxNorm,
+ real normType);
+
TH_API void THNN_(L1Cost_updateOutput)(
THCState *state,
THCTensor *input,
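
The recurring ScalarConvert<Dtype, Acctype> conversions (and the "FIXME: should we accumulate as accreal?" note in the kernel) exist because with half-precision storage, accumulating in the storage type drops small updates once the running sum grows. A float/double analogue of the failure mode, with float standing in for half and double for accreal, and plain casts standing in for THC's ScalarConvert:

#include <cstdio>

int main() {
  const long n = 20000000;        // 2e7 increments of 1.0
  float  naive = 0.f;             // accumulate in the storage type
  double acc   = 0.0;             // accumulate in a wider type
  for (long i = 0; i < n; i++) {
    naive += 1.0f;
    acc   += 1.0;
  }
  // float has a 24-bit significand: once naive reaches 2^24 = 16777216,
  // adding 1.0f no longer changes it, so the sum silently saturates.
  printf("storage-type sum: %.1f\n", naive);        // 16777216.0
  printf("wide accumulator: %.1f\n", (float)acc);   // 20000000.0
  return 0;
}

The same saturation hits half at 2^11 = 2048, which is why the new kernels convert to Acctype on load, accumulate there, and convert back to Dtype only on the final store.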