Welcome to mirror list, hosted at ThFree Co, Russian Federation.

THCTensorTypeUtils.cuh « THC « lib - github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 22a2f92adebc14bd6c2d4e151af2e2940d494082 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
#ifndef THC_TENSOR_TYPE_UTILS_INC
#define THC_TENSOR_TYPE_UTILS_INC

#include <cuda.h>
#include <assert.h>
#include "THCGeneral.h"
#include "THCHalf.h"
#include "THCTensor.h"
#include "THCTensorInfo.cuh"

/// A utility for accessing THCuda*Tensor types in a generic manner

/// Compile-time type-equality trait. Plays the role of C++11's
/// std::is_same so that this header does not have to assume C++11.
///
/// Primary template: selected whenever the two type arguments differ,
/// in which case `same` is false.
template <typename Lhs, typename Rhs>
struct SameType {
  static const bool same = false;
};

/// Partial specialization: selected when both type arguments are the
/// identical type, in which case `same` is true.
template <typename Type>
struct SameType<Type, Type> {
  static const bool same = true;
};

/// Function form of the SameType trait: returns true exactly when T and U
/// name the same type, false otherwise. Takes no arguments; both types are
/// supplied as explicit template arguments.
template <typename T, typename U>
bool isSameType() {
  const bool identical = SameType<T, U>::same;
  return identical;
}

/// Primary template of the per-tensor-type utility traits. It declares no
/// members, so only the explicit specializations produced by the
/// TENSOR_UTILS macro in this header expose any operations.
template <typename TensorType>
struct TensorUtils {};

/// Declares the explicit specialization TensorUtils<TENSOR_TYPE>.
///
/// Macro parameters:
///   TENSOR_TYPE   - tensor struct the specialization is keyed on; every
///                   member function takes or returns a TENSOR_TYPE*.
///   DATA_TYPE     - exposed as the `DataType` typedef; also the pointee type
///                   returned by getData().
///   ACC_DATA_TYPE - exposed as the `AccDataType` typedef (presumably the
///                   type used when accumulating DATA_TYPE values; confirm
///                   against the call sites).
///
/// Only declarations of the static member functions are emitted here; their
/// definitions are not visible in this header.
/// The expansion ends at the closing brace without a semicolon, so each
/// invocation must supply the terminating `;` itself.
#define TENSOR_UTILS(TENSOR_TYPE, DATA_TYPE, ACC_DATA_TYPE)             \
  template <>                                                           \
  struct TensorUtils<TENSOR_TYPE> {                                     \
    typedef DATA_TYPE DataType;                                         \
    typedef ACC_DATA_TYPE AccDataType;                                  \
                                                                        \
    static TENSOR_TYPE* newTensor(THCState* state);                     \
    static TENSOR_TYPE* newContiguous(THCState* state, TENSOR_TYPE* t); \
    static THLongStorage* newSizeOf(THCState* state, TENSOR_TYPE* t);   \
    static void retain(THCState* state, TENSOR_TYPE* t);                \
    static void free(THCState* state, TENSOR_TYPE* t);                  \
    static void freeCopyTo(THCState* state, TENSOR_TYPE* src,           \
                           TENSOR_TYPE* dst);                           \
    static void resize(THCState* state, TENSOR_TYPE* out,               \
                       THLongStorage* sizes,                            \
                       THLongStorage* strides);                         \
    static void resizeAs(THCState* state, TENSOR_TYPE* dst,             \
                         TENSOR_TYPE* src);                             \
    static DATA_TYPE* getData(THCState* state, TENSOR_TYPE* t);         \
    static long getNumElements(THCState* state, TENSOR_TYPE* t);        \
    static long getSize(THCState* state, TENSOR_TYPE* t, int dim);      \
    static long getStride(THCState* state, TENSOR_TYPE* t, int dim);    \
    static int getDims(THCState* state, TENSOR_TYPE* t);                \
    static bool isContiguous(THCState* state, TENSOR_TYPE* t);          \
    static int getDevice(THCState* state, TENSOR_TYPE* t);              \
    static void copyIgnoringOverlaps(THCState* state,                   \
                                     TENSOR_TYPE* dst, TENSOR_TYPE* src); \
    /* Determines if the given tensor has overlapping data points (i.e., */ \
    /* is there more than one index into the tensor that references */  \
    /* the same piece of data)? */                                      \
    static bool overlappingIndices(THCState* state, TENSOR_TYPE* t);    \
    /* Can we use 32 bit math for indexing? */                          \
    static bool canUse32BitIndexMath(THCState* state, TENSOR_TYPE* t);  \
  }

// Integer tensor types: DataType is the element type, AccDataType is long.
TENSOR_UTILS(THCudaByteTensor, unsigned char, long);
TENSOR_UTILS(THCudaCharTensor, char, long);
TENSOR_UTILS(THCudaShortTensor, short, long);
TENSOR_UTILS(THCudaIntTensor, int, long);
TENSOR_UTILS(THCudaLongTensor, long, long);
// Floating-point tensor types: AccDataType is the same as DataType.
TENSOR_UTILS(THCudaTensor, float, float);
TENSOR_UTILS(THCudaDoubleTensor, double, double);

// The half tensor specialization exists only when CUDA_HALF_TENSOR is
// defined; its AccDataType is float rather than half.
#ifdef CUDA_HALF_TENSOR
TENSOR_UTILS(THCudaHalfTensor, half, float);
#endif

// All specializations have been declared; undefine the helper macro so it is
// not visible to code that includes this header.
#undef TENSOR_UTILS

/// Builds a TensorInfo describing tensor `t`.
///
/// Queries the dimension count, per-dimension sizes and strides, and the data
/// pointer of `t` through TensorUtils<TensorType>, and passes them to the
/// TensorInfo constructor.
///
/// Template parameters:
///   TensorType - tensor type for which a TensorUtils specialization exists.
///   IndexType  - integer type the sizes and strides are stored as. The
///                `long` values returned by getSize/getStride are converted
///                to IndexType on assignment, so a narrower IndexType only
///                holds them faithfully if they fit (NOTE(review): callers
///                presumably gate this with canUse32BitIndexMath; confirm).
///
/// Parameters:
///   state - THC state forwarded unchanged to every TensorUtils call.
///   t     - tensor to describe.
///
/// Precondition: `t` has at most MAX_CUTORCH_DIMS dimensions. This is checked
/// with assert() before the staging arrays are written; with NDEBUG defined
/// the check compiles away and callers remain responsible for it.
template <typename TensorType, typename IndexType>
TensorInfo<typename TensorUtils<TensorType>::DataType, IndexType>
getTensorInfo(THCState* state, TensorType* t) {
  typedef TensorUtils<TensorType> Utils;
  typedef typename Utils::DataType DataType;

  IndexType sz[MAX_CUTORCH_DIMS];
  IndexType st[MAX_CUTORCH_DIMS];

  int dims = Utils::getDims(state, t);

  // sz/st hold exactly MAX_CUTORCH_DIMS entries; refuse to write past them
  // instead of silently corrupting the stack.
  assert(dims <= MAX_CUTORCH_DIMS);

  for (int i = 0; i < dims; ++i) {
    sz[i] = Utils::getSize(state, t, i);
    st[i] = Utils::getStride(state, t, i);
  }

  return TensorInfo<DataType, IndexType>(
    Utils::getData(state, t), dims, sz, st);
}

/// Generic negation functor: to(v) returns the result of applying unary
/// minus to v. Callable from both host and device code.
template <typename T>
struct ScalarNegate {
  static __host__ __device__ T to(const T v) {
    const T negated = -v;
    return negated;
  }
};

/// Generic reciprocal functor: to(v) returns one (converted to T) divided
/// by v, using T's own division. Callable from both host and device code.
template <typename T>
struct ScalarInv {
  static __host__ __device__ T to(const T v) {
    const T one = static_cast<T>(1);
    return one / v;
  }
};

#ifdef CUDA_HALF_TENSOR
/// Negation for `half`, which needs a different implementation per
/// compilation pass:
///   - device pass with CUDA_HALF_INSTRUCTIONS: the __hneg intrinsic;
///   - device pass without it: convert to float, negate, convert back;
///   - host pass: flip the sign bit of the raw 16-bit value in `x`.
template <>
struct ScalarNegate<half> {
  static __host__ __device__ half to(const half v) {
#if defined(__CUDA_ARCH__) && defined(CUDA_HALF_INSTRUCTIONS)
    return __hneg(v);
#elif defined(__CUDA_ARCH__)
    const float asFloat = __half2float(v);
    return __float2half(-asFloat);
#else
    half negated = v;
    negated.x ^= 0x8000; // toggle sign bit
    return negated;
#endif
  }
};

/// Reciprocal for `half`. The division itself is always carried out in
/// float: the device pass converts with __half2float/__float2half, the host
/// pass with THC_half2float/THC_float2half.
template <>
struct ScalarInv<half> {
  static __host__ __device__ half to(const half v) {
#ifdef __CUDA_ARCH__
    const float reciprocal = 1.0f / __half2float(v);
    return __float2half(reciprocal);
#else
    const float reciprocal = 1.0f / THC_half2float(v);
    return THC_float2half(reciprocal);
#endif
  }
};

/// Equality for `half`, decided purely by comparing the two `x` fields.
/// No floating-point semantics are applied, so two operands are equal only
/// when their stored `x` values are identical.
inline bool operator==(half a, half b) {
  const bool sameBits = (a.x == b.x);
  return sameBits;
}

/// Inequality for `half`, decided purely by comparing the two `x` fields:
/// true whenever the stored `x` values differ.
inline bool operator!=(half a, half b) {
  const bool differentBits = (a.x != b.x);
  return differentBits;
}

#endif // CUDA_HALF_TENSOR

#endif // THC_TENSOR_TYPE_UTILS_INC