Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'lib/THC/THCTensorTypeUtils.cuh')
-rw-r--r--lib/THC/THCTensorTypeUtils.cuh184
1 file changed, 184 insertions, 0 deletions
diff --git a/lib/THC/THCTensorTypeUtils.cuh b/lib/THC/THCTensorTypeUtils.cuh
new file mode 100644
index 0000000..4456f47
--- /dev/null
+++ b/lib/THC/THCTensorTypeUtils.cuh
@@ -0,0 +1,184 @@
+#ifndef THC_TENSOR_TYPE_UTILS_INC
+#define THC_TENSOR_TYPE_UTILS_INC
+
+#include <cuda.h>
+#include <assert.h>
+#include "THCGeneral.h"
+#include "THCHalf.h"
+#include "THCTensor.h"
+#include "THCTensorInfo.cuh"
+
+/// A utility for accessing THCuda*Tensor types in a generic manner
+
+/// Equivalent to C++11's type_traits std::is_same; used for comparing
+/// equality of types. Don't assume the existence of C++11
+template <typename T, typename U>
+struct SameType {
+ static const bool same = false;
+};
+
+/* Partial specialization: selected only when both type parameters are */
+/* the same type. */
+template <typename T>
+struct SameType<T, T> {
+ static const bool same = true;
+};
+
+/* Function form of the SameType trait, usable in runtime expressions. */
+template <typename T, typename U>
+bool isSameType() {
+ return SameType<T, U>::same;
+}
+
+/* Primary template is intentionally empty: only the explicit */
+/* specializations generated by TENSOR_UTILS below are usable, so an */
+/* unsupported tensor type fails at compile time. */
+template <typename TensorType>
+struct TensorUtils {
+};
+
+/* Declares a full TensorUtils specialization for one tensor type, */
+/* exposing its element type as DataType plus a uniform static API */
+/* (the implementations live in the corresponding .cu files). */
+#define TENSOR_UTILS(TENSOR_TYPE, DATA_TYPE) \
+ template <> \
+ struct TensorUtils<TENSOR_TYPE> { \
+ typedef DATA_TYPE DataType; \
+ \
+ static TENSOR_TYPE* newTensor(THCState* state); \
+ static TENSOR_TYPE* newContiguous(THCState* state, TENSOR_TYPE* t); \
+ static void retain(THCState* state, TENSOR_TYPE* t); \
+ static void free(THCState* state, TENSOR_TYPE* t); \
+ static void freeCopyTo(THCState* state, TENSOR_TYPE* src, \
+ TENSOR_TYPE* dst); \
+ static void resizeAs(THCState* state, TENSOR_TYPE* dst, \
+ TENSOR_TYPE* src); \
+ static DATA_TYPE* getData(THCState* state, TENSOR_TYPE* t); \
+ static long getNumElements(THCState* state, TENSOR_TYPE* t); \
+ static long getSize(THCState* state, TENSOR_TYPE* t, int dim); \
+ static long getStride(THCState* state, TENSOR_TYPE* t, int dim); \
+ static int getDims(THCState* state, TENSOR_TYPE* t); \
+ static bool isContiguous(THCState* state, TENSOR_TYPE* t); \
+ static int getDevice(THCState* state, TENSOR_TYPE* t); \
+ static void copyIgnoringOverlaps(THCState* state, \
+ TENSOR_TYPE* dst, TENSOR_TYPE* src); \
+ /* Determines if the given tensor has overlapping data points (i.e., */ \
+ /* is there more than one index into the tensor that references */ \
+ /* the same piece of data)? */ \
+ static bool overlappingIndices(THCState* state, TENSOR_TYPE* t); \
+ /* Can we use 32 bit math for indexing? */ \
+ static bool canUse32BitIndexMath(THCState* state, TENSOR_TYPE* t); \
+ }
+
+/* One specialization per CUDA tensor type. */
+TENSOR_UTILS(THCudaByteTensor, unsigned char);
+TENSOR_UTILS(THCudaCharTensor, char);
+TENSOR_UTILS(THCudaShortTensor, short);
+TENSOR_UTILS(THCudaIntTensor, int);
+TENSOR_UTILS(THCudaLongTensor, long);
+TENSOR_UTILS(THCudaTensor, float);
+TENSOR_UTILS(THCudaDoubleTensor, double);
+
+/* half support is optional (depends on CUDA version / build flags). */
+#ifdef CUDA_HALF_TENSOR
+TENSOR_UTILS(THCudaHalfTensor, half);
+#endif
+
+#undef TENSOR_UTILS
+
+/* Builds a TensorInfo (device-friendly size/stride descriptor) for `t`, */
+/* copying its sizes and strides into fixed-size local arrays and */
+/* narrowing them to IndexType. Caller is responsible for choosing an */
+/* IndexType wide enough for the tensor (see canUse32BitIndexMath). */
+template <typename TensorType, typename IndexType>
+TensorInfo<typename TensorUtils<TensorType>::DataType, IndexType>
+getTensorInfo(THCState* state, TensorType* t) {
+ IndexType sz[MAX_CUTORCH_DIMS];
+ IndexType st[MAX_CUTORCH_DIMS];
+
+ int dims = TensorUtils<TensorType>::getDims(state, t);
+ /* Guard the fixed-size stack buffers above: a tensor with more than */
+ /* MAX_CUTORCH_DIMS dimensions would otherwise overflow sz/st in the */
+ /* copy loop below. */
+ assert(dims <= MAX_CUTORCH_DIMS);
+ for (int i = 0; i < dims; ++i) {
+ sz[i] = TensorUtils<TensorType>::getSize(state, t, i);
+ st[i] = TensorUtils<TensorType>::getStride(state, t, i);
+ }
+
+ return TensorInfo<typename TensorUtils<TensorType>::DataType, IndexType>(
+ TensorUtils<TensorType>::getData(state, t), dims, sz, st);
+}
+
+/// `half` has some type conversion issues associated with it, since it
+/// is a struct without a constructor/implicit conversion constructor.
+/// We use this to convert scalar values to the given type that the
+/// tensor expects.
+template <typename In, typename Out>
+struct ScalarConvert {
+ static __host__ __device__ Out to(const In v) { return (Out) v; }
+};
+
+/// Arithmetic negation of a scalar; specialized below for `half`, which
+/// has no built-in unary minus.
+template <typename T>
+struct ScalarNegate {
+ static __host__ __device__ T to(const T v) { return -v; }
+};
+
+/// Multiplicative inverse (1/v) of a scalar; specialized below for `half`.
+template <typename T>
+struct ScalarInv {
+ static __host__ __device__ T to(const T v) { return ((T) 1) / v; }
+};
+
+#ifdef CUDA_HALF_TENSOR
+
+/// half -> Out: widen to float first, then use the ordinary cast.
+/// On device the hardware intrinsic is used; on host the software
+/// conversion from THCHalf is used.
+template <typename Out>
+struct ScalarConvert<half, Out> {
+ static __host__ __device__ Out to(const half v) {
+#ifdef __CUDA_ARCH__
+ return (Out) __half2float(v);
+#else
+ return (Out) THC_half2float(v);
+#endif
+ }
+};
+
+/// In -> half: cast to float first, then narrow to half
+/// (intrinsic on device, THCHalf software path on host).
+template <typename In>
+struct ScalarConvert<In, half> {
+ static __host__ __device__ half to(const In v) {
+#ifdef __CUDA_ARCH__
+ return __float2half((float) v);
+#else
+ return THC_float2half((float) v);
+#endif
+ }
+};
+
+/// half -> half: identity; also disambiguates between the two partial
+/// specializations above when In == Out == half.
+template <>
+struct ScalarConvert<half, half> {
+ static __host__ __device__ half to(const half v) {
+ return v;
+ }
+};
+
+/// Negation for `half`. On device, uses the native __hneg instruction
+/// when available, otherwise round-trips through float. On host, flips
+/// the IEEE sign bit directly on the raw representation (half is a
+/// struct wrapping a 16-bit payload in .x).
+template <>
+struct ScalarNegate<half> {
+ static __host__ __device__ half to(const half v) {
+#ifdef __CUDA_ARCH__
+#ifdef CUDA_HALF_INSTRUCTIONS
+ return __hneg(v);
+#else
+ return __float2half(-__half2float(v));
+#endif
+#else
+ half out = v;
+ out.x ^= 0x8000; // toggle sign bit
+ return out;
+#endif
+ }
+};
+
+/// Reciprocal (1/v) for `half`: computed in float precision on both
+/// paths (hardware intrinsics on device, THCHalf software conversion on
+/// host), then narrowed back to half.
+template <>
+struct ScalarInv<half> {
+ static __host__ __device__ half to(const half v) {
+#ifdef __CUDA_ARCH__
+ return __float2half(1.0f / __half2float(v));
+#else
+ float fv = THC_half2float(v);
+ fv = 1.0f / fv;
+ return THC_float2half(fv);
+#endif
+ }
+};
+
+/// Bitwise equality for `half` (the struct defines no comparison
+/// operators of its own).
+/// NOTE(review): this compares raw bit patterns, not IEEE float values —
+/// NaN == NaN is true and +0.0 != -0.0. Presumably intentional for
+/// exact-representation checks; confirm no caller expects float
+/// semantics.
+inline bool operator==(half a, half b) {
+ return a.x == b.x;
+}
+
+/// Bitwise inequality for `half`; same caveats as operator== above.
+inline bool operator!=(half a, half b) {
+ return a.x != b.x;
+}
+
+#endif // CUDA_HALF_TENSOR
+
+#endif // THC_TENSOR_TYPE_UTILS_INC