Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Trevor Killeen <killeentm@gmail.com>, 2016-09-29 19:22:46 +0300
committer: Trevor Killeen <killeentm@gmail.com>, 2016-10-07 21:50:27 +0300
commit: c1dcfb6a54febeb6a45b9066553b621070b55c85 (patch)
tree: 52dd4de5c788cef12f21d902e2737f2d74104051
parent: 2d31dbf0074fe16c6612a7a1ee144096f48a3917 (diff)
[cutorch refactor] make _renorm(...)'s ops generic
-rw-r--r-- lib/THC/THCNumerics.cuh | 12
-rw-r--r-- lib/THC/THCTensorMath.h | 1
-rw-r--r-- lib/THC/THCTensorMath2.cu | 59
-rw-r--r-- lib/THC/THCTensorMathReduce.cuh | 48
-rw-r--r-- lib/THC/generic/THCTensorMathReduce.cu | 25
-rw-r--r-- lib/THC/generic/THCTensorMathReduce.h | 6
6 files changed, 83 insertions, 68 deletions
diff --git a/lib/THC/THCNumerics.cuh b/lib/THC/THCNumerics.cuh
index 36ed0c8..543a544 100644
--- a/lib/THC/THCNumerics.cuh
+++ b/lib/THC/THCNumerics.cuh
@@ -474,6 +474,16 @@ struct THCNumerics<half> {
#endif
}
+ static inline __host__ __device__ half pow(half a, half b) { // a raised to the power b for half operands; the arithmetic itself is done in float via powf
+#ifdef __CUDA_ARCH__
+ float fa = __half2float(a); // device compilation: widen both operands to float with the CUDA intrinsic
+ float fb = __half2float(b);
+ return __float2half(powf(fa, fb)); // narrow the float result back to half
+#else // __CUDA_ARCH__
+ return THC_float2half(powf(THC_half2float(a), THC_half2float(b))); // host compilation: same computation, using the THC_* conversion helpers instead of the intrinsics
+#endif
+ }
+
};
#endif
@@ -517,6 +527,7 @@ struct THCNumerics<float> {
static inline __host__ __device__ float div (float a, float b) { return a / b; }
static inline __host__ __device__ float mul (float a, float b) { return a * b; }
static inline __host__ __device__ float sub (float a, float b) { return a - b; }
+ static inline __host__ __device__ float pow (float a, float b) { return powf(a, b); } // a raised to the power b, forwarded to single-precision powf
};
template <>
@@ -559,6 +570,7 @@ struct THCNumerics<double> {
static inline __host__ __device__ double div (double a, double b) { return a / b; }
static inline __host__ __device__ double mul (double a, double b) { return a * b; }
static inline __host__ __device__ double sub (double a, double b) { return a - b; }
+ static inline __host__ __device__ double pow (double a, double b) { return ::pow(a, b); } // a raised to the power b; the '::' is required because an unqualified pow(a, b) inside this struct resolves to this member itself and recurses without end
};
/// `half` has some type conversion issues associated with it, since it
diff --git a/lib/THC/THCTensorMath.h b/lib/THC/THCTensorMath.h
index 9b70b01..439e2e1 100644
--- a/lib/THC/THCTensorMath.h
+++ b/lib/THC/THCTensorMath.h
@@ -75,6 +75,7 @@ THC_API void THCudaTensor_catArray(THCState *state, THCudaTensor *result, THCuda
THC_API float THCudaTensor_varall(THCState *state, THCudaTensor *self);
THC_API void THCudaTensor_var(THCState *state, THCudaTensor *self, THCudaTensor *src, long dim, int flag);
THC_API float THCudaTensor_stdall(THCState *state, THCudaTensor *self);
+THC_API void THCudaTensor_std(THCState *state, THCudaTensor *self, THCudaTensor *src, long dim, int flag);
THC_API float THCudaTensor_normall(THCState *state, THCudaTensor *self, float value);
THC_API void THCudaTensor_norm(THCState *state, THCudaTensor* self, THCudaTensor* src, float value, long dimension);
THC_API void THCudaTensor_renorm(THCState *state, THCudaTensor* self, THCudaTensor* src, float value, long dimension, float max_norm);
diff --git a/lib/THC/THCTensorMath2.cu b/lib/THC/THCTensorMath2.cu
index 7fd11ff..d0d47ad 100644
--- a/lib/THC/THCTensorMath2.cu
+++ b/lib/THC/THCTensorMath2.cu
@@ -216,6 +216,27 @@ void THCudaTensor_var(THCState *state, THCudaTensor *self_, THCudaTensor *src, l
THCudaTensor_freeCopyTo(state, self, self_);
}
+void THCudaTensor_std(THCState *state, THCudaTensor *self_, THCudaTensor *src, long dimension, int flag) // Reduces src along `dimension` into self_ using the var kernels with their third template argument set to true (presumably "take the square root", i.e. standard deviation -- confirm against THCTensor_varOuterDim); `flag` is forwarded to those kernels unchanged.
+{
+ THAssert(THCudaTensor_checkGPU(state, 2, self_, src)); // assert the GPU check passes for both tensors
+ THLongStorage *dim = THCudaTensor_newSizeOf(state, src); // start from src's sizes ...
+ THLongStorage_set(dim, dimension, 1); // ... with the reduced dimension set to size 1
+ THCudaTensor_resize(state, self_, dim, NULL); // give the output that shape
+ THLongStorage_free(dim);
+
+ THCudaTensor *self = THCudaTensor_newContiguous(state, self_); // work on contiguous versions of output and input
+ src = THCudaTensor_newContiguous(state, src); // rebinds only the local pointer; released below
+
+ if (dimension == THCudaTensor_nDimension(state, src) - 1) { // reducing the last dimension: use the innermost-dim variant
+ THCTensor_varInnermostDim<THCudaTensor, float, true>(state, self, src, flag);
+ } else { // any other dimension: use the outer-dim variant
+ THCTensor_varOuterDim<THCudaTensor, float, true>(state, self, src, dimension, flag);
+ }
+
+ THCudaTensor_free(state, src); // release the contiguous src obtained above
+ THCudaTensor_freeCopyTo(state, self, self_); // presumably copies the result into self_ and releases the temporary -- confirm against THCTensor freeCopyTo
+}
+
template <int StaticExp>
struct TensorNormOp
{
@@ -315,42 +336,6 @@ void THCudaTensor_norm(THCState *state, THCudaTensor* self, THCudaTensor* src, f
THCudaCheck(cudaGetLastError());
}
-__global__ void THCudaTensor_kernel_renorm(float *data, const float value, const long size, const float maxnorm)
-{
- __shared__ float buffer[32];
- long tx = threadIdx.x;
- long bx = blockIdx.x;
- long step = blockDim.x;
- float *row = data + size*bx;
-
- buffer[tx] = 0;
-
- // get norm of axis
- for (long i=tx; i<size; i+=step)
- {
- buffer[tx] += pow(fabs(row[i]), value);
- }
- // add (reduce)
- for (unsigned int stride = blockDim.x >> 1; stride > 0; stride >>= 1)
- {
- __syncthreads();
- if (tx < stride)
- buffer[tx] += buffer[tx+stride];
- }
- // clip norms
- __syncthreads();
- float norm = pow(buffer[0], 1/value);
- if (norm > maxnorm)
- {
- norm = maxnorm / (norm + 1e-7);
- // renormalize
- for (long i=tx; i<size; i+=step)
- {
- row[i] *= norm;
- }
- }
-}
-
void THCudaTensor_renorm(THCState *state, THCudaTensor* self, THCudaTensor* src, float value, long dimension, float maxnorm)
{
THAssert(THCudaTensor_checkGPU(state, 2, self, src));
@@ -366,7 +351,7 @@ void THCudaTensor_renorm(THCState *state, THCudaTensor* self, THCudaTensor* src,
dim3 grid(data->size[0]);
dim3 threads(32);
- THCudaTensor_kernel_renorm<<<grid, threads, 0, THCState_getCurrentStream(state)>>>(THCudaTensor_data(state, data), value, size, maxnorm);
+ THCTensor_kernel_renorm<float><<<grid, threads, 0, THCState_getCurrentStream(state)>>>(THCudaTensor_data(state, data), value, size, maxnorm);
cudaError errcode = cudaGetLastError();
if(errcode != cudaSuccess)
diff --git a/lib/THC/THCTensorMathReduce.cuh b/lib/THC/THCTensorMathReduce.cuh
index 3bc0837..15cb314 100644
--- a/lib/THC/THCTensorMathReduce.cuh
+++ b/lib/THC/THCTensorMathReduce.cuh
@@ -95,6 +95,54 @@ struct LogicalAny {
}
};
+template<typename Real> // Real: element type; all arithmetic goes through THCNumerics<Real> / ScalarConvert
+__global__ void THCTensor_kernel_renorm(Real *data, const Real value, const long size, const Real maxnorm) // In place: each block takes the `size` consecutive elements starting at data + size*blockIdx.x, computes (sum |x|^value)^cinv(value), and rescales the row when that norm exceeds maxnorm.
+{
+ __shared__ Real buffer[32]; // one partial sum per thread, indexed by threadIdx.x, so the launch must use at most 32 threads per block
+ long tx = threadIdx.x;
+ long bx = blockIdx.x;
+ long step = blockDim.x;
+ Real *row = data + size*bx; // start of the row owned by this block
+
+ buffer[tx] = ScalarConvert<int, Real>::to(0);
+
+ // accumulate this thread's partial sum of |row[i]|^value, striding by the block size
+ for (long i=tx; i<size; i+=step)
+ {
+ buffer[tx] = THCNumerics<Real>::add(
+ buffer[tx],
+ THCNumerics<Real>::pow(
+ THCNumerics<Real>::abs(row[i]),
+ value)
+ );
+ }
+ // halving-stride reduction of the partial sums into buffer[0]; every slot is folded in only when blockDim.x is a power of two
+ for (unsigned int stride = blockDim.x >> 1; stride > 0; stride >>= 1)
+ {
+ __syncthreads(); // executed by every thread each iteration, before the previous step's writes are read
+ if (tx < stride)
+ buffer[tx] = THCNumerics<Real>::add(buffer[tx], buffer[tx+stride]);
+ }
+ // turn the summed powers into the norm and clip it against maxnorm
+ __syncthreads(); // make the final buffer[0] visible to all threads
+ Real norm = THCNumerics<Real>::pow(buffer[0], THCNumerics<Real>::cinv(value)); // cinv(value) is presumably 1/value -- confirm in THCNumerics
+ if (THCNumerics<Real>::gt(norm, maxnorm))
+ {
+ norm = THCNumerics<Real>::div( // reuse `norm` as the scale factor maxnorm / (norm + 1e-7)
+ maxnorm,
+ THCNumerics<Real>::add(
+ norm,
+ ScalarConvert<float, Real>::to(1e-7)
+ )
+ );
+ // scale every element of the row by that factor
+ for (long i=tx; i<size; i+=step)
+ {
+ row[i] = THCNumerics<Real>::mul(row[i], norm);
+ }
+ }
+}
+
#include <thrust/functional.h>
diff --git a/lib/THC/generic/THCTensorMathReduce.cu b/lib/THC/generic/THCTensorMathReduce.cu
index 8f6b9c4..4ae3c0c 100644
--- a/lib/THC/generic/THCTensorMathReduce.cu
+++ b/lib/THC/generic/THCTensorMathReduce.cu
@@ -38,31 +38,6 @@ THCTensor_(mean)(THCState *state, THCTensor *self, THCTensor *src, long dim)
THCTensor_(div)(state, self, self, ScalarConvert<long, real>::to(THCTensor_(size)(state, src, dim)));
}
-#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF)
-
-void THCTensor_std(THCState *state, THCTensor *self_, THCTensor *src, long dimension, int flag)
-{
- THAssert(THCTensor_(checkGPU)(state, 2, self_, src));
- THLongStorage *dim = THCTensor_(newSizeOf)(state, src);
- THLongStorage_set(dim, dimension, 1);
- THCTensor_(resize)(state, self_, dim, NULL);
- THLongStorage_free(dim);
-
- THCTensor *self = THCTensor_(newContiguous)(state, self_);
- src = THCTensor_(newContiguous)(state, src);
-
- if (dimension == THCTensor_(nDimension)(state, src) - 1) {
- THCTensor_varInnermostDim<THCTensor, real, true>(state, self, src, flag);
- } else {
- THCTensor_varOuterDim<THCTensor, real, true>(state, self, src, dimension, flag);
- }
-
- THCTensor_(free)(state, src);
- THCTensor_(freeCopyTo)(state, self, self_);
-}
-
-#endif
-
THC_API accreal
THCTensor_(sumall)(THCState *state, THCTensor *self) {
THAssert(THCTensor_(checkGPU)(state, 1, self));
diff --git a/lib/THC/generic/THCTensorMathReduce.h b/lib/THC/generic/THCTensorMathReduce.h
index bc37f85..500003f 100644
--- a/lib/THC/generic/THCTensorMathReduce.h
+++ b/lib/THC/generic/THCTensorMathReduce.h
@@ -2,12 +2,6 @@
#define THC_GENERIC_FILE "generic/THCTensorMathReduce.h"
#else
-#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF)
-
-THC_API void THCTensor_(std)(THCState *state, THCTensor *self, THCTensor *src, long dim, int flag);
-
-#endif
-
THC_API void THCTensor_(sum)(THCState *state, THCTensor *self, THCTensor *src, long dim);
THC_API void THCTensor_(prod)(THCState *state, THCTensor *self, THCTensor *src, long dim);
THC_API void THCTensor_(mean)(THCState *state, THCTensor *self, THCTensor *src, long dim);