diff options
author | Soumith Chintala <soumith@gmail.com> | 2016-12-12 21:23:34 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2016-12-12 21:23:34 +0300 |
commit | e00f7d4c0f70e3583ff0a5359095ad7afcaa7009 (patch) | |
tree | 83419daaa8a0bb50ddf7284cc056869ca72fcd07 | |
parent | bcbb427c4d7322a4e88f867a19193940677dfabc (diff) | |
parent | 01d7a63ab68d6993730104739cdae17bbdc0255e (diff) |
Merge pull request #628 from killeent/more-documentation
TensorInfo related code documentation
-rw-r--r-- | lib/THC/THCTensorInfo.cuh | 1 | ||||
-rw-r--r-- | lib/THC/THCTensorTypeUtils.cuh | 26 |
2 files changed, 27 insertions, 0 deletions
diff --git a/lib/THC/THCTensorInfo.cuh b/lib/THC/THCTensorInfo.cuh index 5347116..3389e61 100644 --- a/lib/THC/THCTensorInfo.cuh +++ b/lib/THC/THCTensorInfo.cuh @@ -247,6 +247,7 @@ struct IndexToOffset { } }; +// For contiguous tensors, the offset = index template <typename T, typename IndexType> struct IndexToOffset<T, IndexType, -2> { static inline __host__ __device__ IndexType diff --git a/lib/THC/THCTensorTypeUtils.cuh b/lib/THC/THCTensorTypeUtils.cuh index 81051f7..6889faf 100644 --- a/lib/THC/THCTensorTypeUtils.cuh +++ b/lib/THC/THCTensorTypeUtils.cuh @@ -80,6 +80,32 @@ TENSOR_UTILS(THCudaHalfTensor, half, float); #undef TENSOR_UTILS +// Utility function for constructing TensorInfo structs. In this case, the +// two template parameters are: +// +// 1. The TensorType, e.g. THCTensor in generic functions, or THCudaTensor, +// THCudaLongTensor etc. +// +// 2. The IndexType. This is always going to be an unsigned integral value, +// but depending on the size of the Tensor you may select unsigned int, +// unsigned long, unsigned long long etc. +// +// Internally we use the TensorUtils static functions to get the necessary +// dims, sizes, stride etc. +// +// For example, suppose we have a THCudaTensor t, with dim = 2, size = [3, 4], +// stride = [4, 1], offset = 8, and we set our index type to be unsigned int. +// Then we yield a TensorInfo struct templatized with float, unsigned int and +// the following fields: +// +// data is a float* to the underlying storage at position 8 +// dims is 2 +// sizes is a MAX_CUTORCH_DIMS element array with [3, 4] in its first two positions +// strides is a MAX_CUTORCH_DIMS element array with [4, 1] in its first two positions +// +// TensorInfos can then be passed to CUDA kernels, but we can use the static functions +// defined above to perform Tensor Operations that are appropriate for each +// TensorType. template <typename TensorType, typename IndexType> TensorInfo<typename TensorUtils<TensorType>::DataType, IndexType> getTensorInfo(THCState* state, TensorType* t) { |