Welcome to mirror list, hosted at ThFree Co, Russian Federation.

THCTensorTypeUtils.cuh « THC « lib - github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 16a0cdee3805635adc2608489438aeb36238c438 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
#ifndef THC_TENSOR_TYPE_UTILS_INC
#define THC_TENSOR_TYPE_UTILS_INC

#include <cuda.h>
#include <assert.h>
#include "THCGeneral.h"
#include "THCHalf.h"
#include "THCTensor.h"
#include "THCTensorInfo.cuh"

/// A utility for accessing THCuda*Tensor types in a generic manner

/// Compile-time type equality test. Plays the role of C++11's
/// std::is_same so that this header does not have to assume C++11.
/// SameType<A, B>::same is true only when A and B name the same type.

/// General case: the two arguments are different types.
template <typename First, typename Second>
struct SameType {
  static const bool same = false;
};

/// Partial specialization chosen when both arguments are one and the
/// same type.
template <typename Only>
struct SameType<Only, Only> {
  static const bool same = true;
};

/// Function form of SameType: returns the value of SameType<T, U>::same,
/// i.e. true when T and U are the same type and false otherwise.
template <typename T, typename U>
bool isSameType() {
  const bool identical = SameType<T, U>::same;
  return identical;
}

/// Primary template, intentionally left without members. The usable
/// interface comes from the explicit specializations that the
/// TENSOR_UTILS macro in this header declares for each tensor type.
template <typename TensorType>
struct TensorUtils {};

// TENSOR_UTILS(TENSOR_TYPE, DATA_TYPE, ACC_DATA_TYPE) expands to an explicit
// specialization of TensorUtils for TENSOR_TYPE.
//
//   DATA_TYPE      is exposed as TensorUtils<TENSOR_TYPE>::DataType and is
//                  the pointee type of the pointer returned by getData().
//   ACC_DATA_TYPE  is exposed as TensorUtils<TENSOR_TYPE>::AccDataType
//                  (presumably the type in which DataType values should be
//                  accumulated -- confirm against the users of AccDataType).
//
// Every member is a static function that is only declared here; this header
// provides no bodies, so the definitions have to come from elsewhere.
//
// The expansion stops at the closing brace of the struct and does not emit
// the terminating semicolon, so each use of the macro must be followed by
// a `;`.
#define TENSOR_UTILS(TENSOR_TYPE, DATA_TYPE, ACC_DATA_TYPE)             \
  template <>                                                           \
  struct THC_CLASS TensorUtils<TENSOR_TYPE> {                                     \
    typedef DATA_TYPE DataType;                                         \
    typedef ACC_DATA_TYPE AccDataType;                                  \
                                                                        \
    static TENSOR_TYPE* newTensor(THCState* state);                     \
    static TENSOR_TYPE* newContiguous(THCState* state, TENSOR_TYPE* t); \
    static THLongStorage* newSizeOf(THCState* state, TENSOR_TYPE* t);   \
    static void retain(THCState* state, TENSOR_TYPE* t);                \
    static void free(THCState* state, TENSOR_TYPE* t);                  \
    static void freeCopyTo(THCState* state, TENSOR_TYPE* src,           \
                           TENSOR_TYPE* dst);                           \
    static void resize(THCState* state, TENSOR_TYPE* out,               \
                       THLongStorage* sizes,                            \
                       THLongStorage* strides);                         \
    static void resizeAs(THCState* state, TENSOR_TYPE* dst,             \
                         TENSOR_TYPE* src);                             \
    static void squeeze1d(THCState *state, TENSOR_TYPE *dst,            \
                          TENSOR_TYPE *src, int dimension);             \
    static DATA_TYPE* getData(THCState* state, TENSOR_TYPE* t);         \
    static ptrdiff_t getNumElements(THCState* state, TENSOR_TYPE* t);        \
    static long getSize(THCState* state, TENSOR_TYPE* t, int dim);      \
    static long getStride(THCState* state, TENSOR_TYPE* t, int dim);    \
    static int getDims(THCState* state, TENSOR_TYPE* t);                \
    static bool isContiguous(THCState* state, TENSOR_TYPE* t);          \
    static bool allContiguous(THCState* state, TENSOR_TYPE** inputs, int numInputs); \
    static int getDevice(THCState* state, TENSOR_TYPE* t);              \
    static bool allSameDevice(THCState* state, TENSOR_TYPE** inputs, int numInputs); \
    static void copyIgnoringOverlaps(THCState* state,                   \
                                     TENSOR_TYPE* dst, TENSOR_TYPE* src); \
    /* Determines if the given tensor has overlapping data points (i.e., */ \
    /* is there more than one index into the tensor that references */  \
    /* the same piece of data)? */                                      \
    static bool overlappingIndices(THCState* state, TENSOR_TYPE* t);    \
    /* Can we use 32 bit math for indexing? */                          \
    static bool canUse32BitIndexMath(THCState* state, TENSOR_TYPE* t);  \
    /* Are all tensors 32-bit indexable? */                             \
    static bool all32BitIndexable(THCState* state, TENSOR_TYPE** inputs, int numInputs); \
  }

// Declare one TensorUtils specialization per tensor type. The arguments are
// (tensor type, DataType, AccDataType). All integral tensor types use long as
// their AccDataType; the float and double tensors use their own element type.
// The trailing `;` on each line terminates the struct that the macro opens.
TENSOR_UTILS(THCudaByteTensor, unsigned char, long);
TENSOR_UTILS(THCudaCharTensor, char, long);
TENSOR_UTILS(THCudaShortTensor, short, long);
TENSOR_UTILS(THCudaIntTensor, int, long);
TENSOR_UTILS(THCudaLongTensor, long, long);
TENSOR_UTILS(THCudaTensor, float, float);
TENSOR_UTILS(THCudaDoubleTensor, double, double);

// The half tensor specialization is declared only when CUDA_HALF_TENSOR is
// defined; its AccDataType is float rather than half.
#ifdef CUDA_HALF_TENSOR
TENSOR_UTILS(THCudaHalfTensor, half, float);
#endif

// The macro is needed only for the declarations above; undefine it so that
// it is not visible to files including this header.
#undef TENSOR_UTILS

// Utility function for constructing TensorInfo structs. In this case, the
// two template parameters are:
//
// 1. The TensorType, e.g. THCTensor in generic functions, or THCudaTensor,
// THCudaLongTensor etc.
//
// 2. The IndexType. This is always going to be an unsigned integral value,
// but depending on the size of the Tensor you may select unsigned int,
// unsigned long, unsigned long long etc.
//
// Internally we use the TensorUtils static functions to get the necessary
// dims, sizes, stride etc.
//
// For example, suppose we have a THCudaTensor t, with dim = 2, size = [3, 4],
// stride = [4, 1], offset = 8, and we set our index type to be unsigned int.
// Then we yield a TensorInfo struct templatized with float, unsigned int and
// the following fields:
//
// data is a float* to the underlying storage at position 8
// dims is 2
// sizes is a MAX_CUTORCH_DIMS element array with [3, 4] in its first two positions
// strides is a MAX_CUTORCH_DIMS element array with [4, 1] in its first two positions
//
// TensorInfos can then be passed to CUDA kernels, but we can use the static functions
// defined above to perform Tensor Operations that are appropriate for each
// TensorType.
//
// Preconditions:
//
// - The tensor must not have more than MAX_CUTORCH_DIMS dimensions, because
//   the sizes and strides are staged in fixed-size arrays of that length.
//   This is checked with assert(), so the check is compiled out when NDEBUG
//   is defined.
//
// - TensorUtils reports sizes and strides as long; they are converted to
//   IndexType without a range check here. Presumably callers choose an
//   IndexType that is wide enough (e.g. by consulting
//   TensorUtils<TensorType>::canUse32BitIndexMath) -- confirm at call sites.
//
// Entries of the staging arrays at positions >= dims are left uninitialized.
template <typename TensorType, typename IndexType>
TensorInfo<typename TensorUtils<TensorType>::DataType, IndexType>
getTensorInfo(THCState* state, TensorType* t) {
  typedef TensorUtils<TensorType> Utils;
  typedef TensorInfo<typename Utils::DataType, IndexType> Info;

  IndexType sz[MAX_CUTORCH_DIMS];
  IndexType st[MAX_CUTORCH_DIMS];

  const int dims = Utils::getDims(state, t);

  // sz and st live on the stack and hold exactly MAX_CUTORCH_DIMS entries;
  // stop here rather than write past their end if the tensor reports more
  // dimensions than they can hold.
  assert(dims <= MAX_CUTORCH_DIMS);

  for (int i = 0; i < dims; ++i) {
    sz[i] = Utils::getSize(state, t, i);
    st[i] = Utils::getStride(state, t, i);
  }

  return Info(Utils::getData(state, t), dims, sz, st);
}

/// Generic negation functor: to(v) yields -v, evaluated with T's unary
/// minus and converted back to T. Callable from host and device code.
template <typename T>
struct ScalarNegate {
  static __host__ __device__ T to(const T v) {
    const T negated = -v;
    return negated;
  }
};

/// Generic reciprocal functor: to(v) yields 1 / v, where the constant 1
/// is first converted to T so the division is carried out on T operands.
/// For integral T this is therefore an integer division. Callable from
/// host and device code.
template <typename T>
struct ScalarInv {
  static __host__ __device__ T to(const T v) {
    const T one = static_cast<T>(1);
    return one / v;
  }
};

#ifdef CUDA_HALF_TENSOR
/// Negation for half. Three mutually exclusive implementations are
/// selected by the preprocessor:
///  - device code with CUDA_HALF_INSTRUCTIONS defined: the __hneg intrinsic;
///  - other device code: convert to float, negate, convert back;
///  - host code: flip bit 0x8000 (the sign bit) of the 16-bit value held
///    in the x member.
template <>
struct ScalarNegate<half> {
  static __host__ __device__ half to(const half v) {
#if defined(__CUDA_ARCH__) && defined(CUDA_HALF_INSTRUCTIONS)
    return __hneg(v);
#elif defined(__CUDA_ARCH__)
    return __float2half(-__half2float(v));
#else
    // Before CUDA 9 the x member is accessed on half itself; from CUDA 9
    // on it is accessed through a __half_raw copy, which converts back to
    // half on return.
#if CUDA_VERSION < 9000
    half flipped = v;
#else
    __half_raw flipped = __half_raw(v);
#endif
    flipped.x ^= 0x8000; // toggle sign bit
    return flipped;
#endif
  }
};

/// Reciprocal for half. The division is always performed in float:
/// host code converts with THC_half2float / THC_float2half, device code
/// with the __half2float / __float2half intrinsics.
template <>
struct ScalarInv<half> {
  static __host__ __device__ half to(const half v) {
#ifndef __CUDA_ARCH__
    // Host compilation pass.
    return THC_float2half(1.0f / THC_half2float(v));
#else
    // Device compilation pass.
    const float asFloat = __half2float(v);
    return __float2half(1.0f / asFloat);
#endif
  }
};

/// Equality for half values, declared without __device__ (host function).
/// Two values are equal exactly when their 16-bit x members match; no
/// floating-point comparison is performed.
/// NOTE(review): comparing representations means +0 and -0 (which differ
/// in the sign bit) compare unequal -- confirm callers expect this.
inline bool operator==(half a, half b) {
#if CUDA_VERSION >= 9000
  // From CUDA 9 on, x is read through a __half_raw copy of each operand.
  return __half_raw(a).x == __half_raw(b).x;
#else
  return a.x == b.x;
#endif
}

/// Inequality for half values, declared without __device__ (host function).
/// Two values are unequal exactly when their 16-bit x members differ; no
/// floating-point comparison is performed.
/// NOTE(review): comparing representations means +0 and -0 (which differ
/// in the sign bit) compare unequal -- confirm callers expect this.
inline bool operator!=(half a, half b) {
#if CUDA_VERSION >= 9000
  // From CUDA 9 on, x is read through a __half_raw copy of each operand.
  return __half_raw(a).x != __half_raw(b).x;
#else
  return a.x != b.x;
#endif
}

#endif // CUDA_HALF_TENSOR

#endif // THC_TENSOR_TYPE_UTILS_INC