github.com/torch/cutorch.git
Diffstat (limited to 'lib/THC/THCTensorMathReduce.cu')
-rw-r--r--  lib/THC/THCTensorMathReduce.cu  358
1 file changed, 358 insertions, 0 deletions
diff --git a/lib/THC/THCTensorMathReduce.cu b/lib/THC/THCTensorMathReduce.cu
new file mode 100644
index 0000000..27949ad
--- /dev/null
+++ b/lib/THC/THCTensorMathReduce.cu
@@ -0,0 +1,358 @@
+#include "THCTensorMath.h"
+#include "THCGeneral.h"
+#include "THCNumerics.cuh"
+#include "THCReduce.cuh"
+#include "THCReduceAll.cuh"
+#include <thrust/functional.h>
+
+// Reduction operators that support `half`, unlike Thrust
+template <typename InT, typename AccT>
+struct ReduceAdd {
+ inline __device__ AccT operator()(AccT a, InT b) const {
+ return a + (AccT) b;
+ }
+};
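+
+// Illustrative example: ReduceAdd<half, float> below adds half inputs into a
+// float accumulator, so ReduceAdd<half, float>()(1.0f, __float2half(0.5f))
+// yields 1.5f; widening the accumulator avoids compounding half-precision
+// rounding error over a long reduction.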
+
+#ifdef CUDA_HALF_TENSOR
+template <>
+struct ReduceAdd<half, half> {
+ inline __device__ half operator()(half a, half b) const {
+#ifdef CUDA_HALF_INSTRUCTIONS
+ return __hadd(a, b);
+#else
+ float fa = __half2float(a);
+ float fb = __half2float(b);
+ return __float2half(fa + fb);
+#endif
+ }
+};
+
+template <>
+struct ReduceAdd<half, float> {
+ inline __device__ float operator()(float a, half b) const {
+ return a + __half2float(b);
+ }
+};
+#endif // CUDA_HALF_TENSOR
+
+template <typename InT, typename AccT>
+struct ReduceMultiply {
+ inline __device__ AccT operator()(AccT a, InT b) const {
+ return a * (AccT) b;
+ }
+};
+
+#ifdef CUDA_HALF_TENSOR
+template <>
+struct ReduceMultiply<half, half> {
+ inline __device__ half operator()(half a, half b) const {
+#ifdef CUDA_HALF_INSTRUCTIONS
+ return __hmul(a, b);
+#else
+ float fa = __half2float(a);
+ float fb = __half2float(b);
+ return __float2half(fa * fb);
+#endif
+ }
+};
+
+template <>
+struct ReduceMultiply<half, float> {
+ inline __device__ float operator()(float a, half b) const {
+ return a * __half2float(b);
+ }
+};
+#endif // CUDA_HALF_TENSOR
+
+template <typename T>
+struct ReduceMin {
+ inline __device__ T operator()(T a, T b) const {
+ return THCNumerics<T>::lt(a, b) ? a : b;
+ }
+};
+
+template <typename T>
+struct ReduceMax {
+ inline __device__ T operator()(T a, T b) const {
+ return THCNumerics<T>::gt(a, b) ? a : b;
+ }
+};
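+
+// ReduceMin/ReduceMax compare via THCNumerics<T>::lt/gt rather than the
+// built-in operators so that the same functor also covers `half`, which
+// lacks comparison operators here.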
+
+struct LogicalAll {
+ inline __device__ unsigned char operator()(unsigned char x,
+ unsigned char y) const {
+ return (x && y);
+ }
+};
+
+struct LogicalAny {
+ inline __device__ unsigned char operator()(unsigned char x,
+ unsigned char y) const {
+ return (x || y);
+ }
+};
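+
+// Folding with `&&` from an init of 1 yields 1 iff every byte is nonzero;
+// folding with `||` from an init of 0 yields 1 iff any byte is nonzero.
+// The reductions below return the result as a host int.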
+
+
+THC_API int
+THCudaByteTensor_logicalall(THCState *state, THCudaByteTensor *self) {
+ THAssert(THCudaByteTensor_checkGPU(state, 1, self));
+ unsigned char result;
+ if (!THC_reduceAll(state, self,
+ thrust::identity<unsigned char>(),
+ LogicalAll(),
+ LogicalAll(),
+ (unsigned char) 1, &result, 0)) {
+ THArgCheck(false, 1, CUTORCH_DIM_WARNING);
+ }
+
+ return (int) result;
+}
+
+THC_API int
+THCudaByteTensor_logicalany(THCState *state, THCudaByteTensor *self) {
+ THAssert(THCudaByteTensor_checkGPU(state, 1, self));
+ unsigned char result;
+ if (!THC_reduceAll(state, self,
+ thrust::identity<unsigned char>(),
+ LogicalAny(),
+ LogicalAny(),
+ (unsigned char) 0, &result, 0)) {
+ THArgCheck(false, 1, CUTORCH_DIM_WARNING);
+ }
+
+ return (int) result;
+}
+
+/* A set of reduction kernels that take binary ops on thrust pairs (of value, index).
+   These are useful when a reduction must also preserve the index at which the
+   winning value occurred (for example, min/max operations).
+   The structure of these kernels follows that of the plain reduction kernels.
+*/
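+
+// A binary op here takes two (value, index) pairs and returns the surviving
+// pair. MaxValuePair and MinValuePair, defined at the end of this file, are
+// the instances used: each keeps the pair whose value wins the comparison
+// and carries that value's index along with it.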
+template <typename K, typename Index, class BinaryFunction>
+__global__ void
+kernelTransformReduceOuterDimIndex(K *tgt1,
+ Index *tgt2,
+ K *src_,
+ unsigned num_orows,
+ unsigned num_irows,
+ unsigned row_size,
+ thrust::pair<K, Index> init,
+ BinaryFunction binary_op) {
+ for (unsigned orow = blockIdx.x; orow < num_orows; orow += gridDim.x) {
+ for (unsigned irow = blockIdx.y * blockDim.x + threadIdx.x;
+ irow < num_irows;
+ irow += gridDim.y * blockDim.x) {
+ K *src = src_ + orow * row_size * num_irows + irow;
+ thrust::pair<K, Index> acc = init;
+
+ for (unsigned col = 0; col < row_size; ++col) {
+ // +1 for Lua index
+ acc = binary_op(thrust::make_pair<K, Index>(*src, col+1),
+ acc);
+ src += num_irows;
+ }
+
+ tgt1[orow * num_irows + irow] = acc.first;
+ tgt2[orow * num_irows + irow] = acc.second;
+ }
+ }
+}
+
+template <typename TensorTypeK,
+ typename TensorTypeIndex,
+ typename BinaryFunction>
+__host__ void
+THC_transformReduceOuterDimIndex(THCState *state,
+ TensorTypeK *tgt1,
+ TensorTypeIndex *tgt2,
+ TensorTypeK *src,
+ long rdim,
+ const thrust::pair<
+ typename TensorUtils<TensorTypeK>::DataType,
+ typename TensorUtils<TensorTypeIndex>::DataType>& init,
+ BinaryFunction binary_op) {
+ unsigned ndim = TensorUtils<TensorTypeK>::getDims(state, src);
+ unsigned num_orows = 1;
+ for (unsigned dim = 0; dim < rdim; dim++) {
+ num_orows *= TensorUtils<TensorTypeK>::getSize(state, src, dim);
+ }
+ unsigned row_size = TensorUtils<TensorTypeK>::getSize(state, src, rdim);
+ unsigned num_irows = 1;
+ for (unsigned dim = rdim + 1; dim < ndim; dim++) {
+ num_irows *= TensorUtils<TensorTypeK>::getSize(state, src, dim);
+ }
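+
+  // Worked example (illustrative): a contiguous 4 x 5 x 6 tensor reduced
+  // along rdim = 1 gives num_orows = 4, row_size = 5, num_irows = 6, and
+  // element (orow, col, irow) sits at src_[(orow * 5 + col) * 6 + irow],
+  // matching the kernel's indexing above.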
+
+ dim3 threads(min(512, num_irows));
+ unsigned maxGridDim = 1024;
+ dim3 grid(min(maxGridDim, num_orows),
+ min(maxGridDim, THCCeilDiv(num_irows, threads.x)));
+
+ kernelTransformReduceOuterDimIndex
+ <<<grid, threads, 0, THCState_getCurrentStream(state)>>>(
+ TensorUtils<TensorTypeK>::getData(state, tgt1),
+ TensorUtils<TensorTypeIndex>::getData(state, tgt2),
+ TensorUtils<TensorTypeK>::getData(state, src),
+ num_orows, num_irows, row_size, init, binary_op);
+
+ THCudaCheck(cudaGetLastError());
+}
+
+/* Reduce the innermost dimension of a tensor (on thrust::pair functors which are (value, index))
+ *
+ * For an n-d tensor (n <= 4) where the reduction is along the innermost dimension:
+ *
+ * - block.x strides along the innermost (i.e. reduced) dimension;
+ * - block.y indexes the row within a block of rows; and
+ * - grid.x strides over blocks of rows, with all outer dimensions flattened
+ *   into num_rows rows.
+ *
+ * Reduction along other dimensions is handled in a separate kernel.
+ */
+template <typename K, typename Index, class BinaryFunction>
+__global__ void
+kernelTransformReduceInnermostDimIndex(K *tgt1,
+ Index* tgt2,
+ K *src_,
+ unsigned num_rows,
+ unsigned row_size,
+ thrust::pair<K, Index> init,
+ BinaryFunction binary_op) {
+ __shared__ K sbuf[32][16 + 1]; // avoid bank conflict
+ __shared__ Index ibuf[32][16 + 1]; // avoid bank conflict
+
+ for (unsigned block_row = blockIdx.x * blockDim.y;
+ block_row < num_rows;
+ block_row += blockDim.y * gridDim.x) {
+ unsigned row = block_row + threadIdx.y;
+ thrust::pair<K, Index> acc = init;
+ if (row < num_rows) {
+ K *src = src_ + row * row_size;
+ // Sequential reduction within a thread.
+ for (unsigned col = threadIdx.x; col < row_size; col += blockDim.x) {
+ acc = binary_op(thrust::make_pair<K, Index>(src[col], col + 1), acc);
+ }
+ }
+
+ sbuf[threadIdx.y][threadIdx.x] = acc.first;
+ ibuf[threadIdx.y][threadIdx.x] = acc.second;
+
+ __syncthreads();
+
+ // Reduce intermediate values to single value.
+ K* sline = &sbuf[threadIdx.y][0];
+ Index* iline = &ibuf[threadIdx.y][0];
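+    // Tree reduction across the 16 lanes of this row: with blockDim.x == 16,
+    // the strides run s = 8, 4, 2, 1. After the first pass lane 0 holds the
+    // combination of lanes 0 and 8; after the last pass it holds the whole
+    // row's (value, index) result.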
+ for (unsigned s = 8; s > 0; s >>= 1) {
+ if (row < num_rows && threadIdx.x < s) {
+ thrust::pair<K, Index> arg1 =
+ thrust::make_pair<K, Index>(sline[threadIdx.x], iline[threadIdx.x]);
+ thrust::pair<K, Index> arg2 =
+ thrust::make_pair<K, Index>(sline[threadIdx.x + s], iline[threadIdx.x + s]);
+ thrust::pair<K, Index> res = binary_op(arg1, arg2);
+
+ sline[threadIdx.x] = res.first;
+ iline[threadIdx.x] = res.second;
+ }
+ __syncthreads();
+ }
+
+ if (row < num_rows && threadIdx.x == 0) {
+ tgt1[row] = sline[0];
+ tgt2[row] = iline[0];
+ }
+ __syncthreads();
+ }
+}
+
+template <typename TensorTypeK,
+ typename TensorTypeIndex,
+ typename BinaryFunction>
+__host__ void
+THC_transformReduceInnermostDimIndex(THCState *state,
+ TensorTypeK *tgt1,
+ TensorTypeIndex *tgt2,
+ TensorTypeK *src,
+ const thrust::pair<
+ typename TensorUtils<TensorTypeK>::DataType,
+ typename TensorUtils<TensorTypeIndex>::DataType>& init,
+ BinaryFunction binary_op) {
+ unsigned ndim = TensorUtils<TensorTypeK>::getDims(state, src);
+ unsigned num_rows = 1;
+ for (unsigned dim = 0; dim < ndim - 1; dim++) {
+ num_rows *= TensorUtils<TensorTypeK>::getSize(state, src, dim);
+ }
+ unsigned row_size = TensorUtils<TensorTypeK>::getSize(state, src, ndim - 1);
+
+ dim3 threads(16, 32);
+ dim3 grid(min(1024, THCCeilDiv(num_rows, threads.y)));
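+
+  // Launch shape: 16 lanes (threads.x) cooperate on each row, 32 rows
+  // (threads.y) share a block, and the grid is capped at 1024 blocks; the
+  // kernel's block_row loop strides over any rows beyond that.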
+
+ kernelTransformReduceInnermostDimIndex
+ <<<grid, threads, 0, THCState_getCurrentStream(state)>>>(
+ TensorUtils<TensorTypeK>::getData(state, tgt1),
+ TensorUtils<TensorTypeIndex>::getData(state, tgt2),
+ TensorUtils<TensorTypeK>::getData(state, src),
+ num_rows, row_size, init, binary_op);
+
+ THCudaCheck(cudaGetLastError());
+}
+
+template <typename TensorTypeK,
+ typename TensorTypeIndex,
+ typename BinaryFunction>
+void
+THC_reduceDimIndex(THCState *state,
+ TensorTypeK *tgt1_,
+ TensorTypeIndex *tgt2_,
+ TensorTypeK *src,
+ long dimension,
+ const thrust::pair<
+ typename TensorUtils<TensorTypeK>::DataType,
+ typename TensorUtils<TensorTypeIndex>::DataType>& init,
+ BinaryFunction binary_op)
+{
+ THArgCheck(dimension >= 0 &&
+ dimension < TensorUtils<TensorTypeK>::getDims(state, src),
+ 3, "dimension out of range");
+
+ THLongStorage *dim = TensorUtils<TensorTypeK>::newSizeOf(state, src);
+ THLongStorage_set(dim, dimension, 1);
+ TensorUtils<TensorTypeK>::resize(state, tgt1_, dim, NULL);
+ TensorUtils<TensorTypeIndex>::resize(state, tgt2_, dim, NULL);
+ THLongStorage_free(dim);
+
+ TensorTypeK *tgt1 = TensorUtils<TensorTypeK>::newContiguous(state, tgt1_);
+ TensorTypeIndex *tgt2 = TensorUtils<TensorTypeIndex>::newContiguous(state, tgt2_);
+ src = TensorUtils<TensorTypeK>::newContiguous(state, src);
+
+ if (dimension == TensorUtils<TensorTypeK>::getDims(state, src) - 1) {
+ THC_transformReduceInnermostDimIndex(state, tgt1, tgt2, src, init, binary_op);
+ } else {
+ THC_transformReduceOuterDimIndex(state, tgt1, tgt2, src, dimension, init, binary_op);
+ }
+
+ TensorUtils<TensorTypeK>::free(state, src);
+ TensorUtils<TensorTypeK>::freeCopyTo(state, tgt1, tgt1_);
+ TensorUtils<TensorTypeIndex>::freeCopyTo(state, tgt2, tgt2_);
+}
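+
+// Note that the reduced dimension is kept with size 1 in both outputs, and
+// the indices in tgt2_ are 1-based (the kernels add 1 to each column) to
+// match Lua convention.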
+
+template <typename T, typename Index>
+struct MaxValuePair {
+ __host__ __device__
+ thrust::pair<T, Index> operator()(const thrust::pair<T, Index>& a,
+ const thrust::pair<T, Index>& b) {
+ return THCNumerics<T>::ge(a.first, b.first) ? a : b;
+ }
+};
+
+template <typename T, typename Index>
+struct MinValuePair {
+ __host__ __device__
+ thrust::pair<T, Index> operator()(const thrust::pair<T, Index>& a,
+ const thrust::pair<T, Index>& b) {
+ return THCNumerics<T>::le(a.first, b.first) ? a : b;
+ }
+};
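+
+// Illustrative sketch (hypothetical names and init value) of how the generic
+// code included below might instantiate a max-along-dimension:
+//
+//   thrust::pair<float, long> init =
+//     thrust::make_pair<float, long>(-FLT_MAX, 1);
+//   THC_reduceDimIndex(state, values, indices, src, dim, init,
+//                      MaxValuePair<float, long>());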
+
+#include "generic/THCTensorMathReduce.cu"
+#include "THCGenerateAllTypes.h"