From 01d7a63ab68d6993730104739cdae17bbdc0255e Mon Sep 17 00:00:00 2001 From: Trevor Killeen Date: Mon, 12 Dec 2016 10:06:13 -0800 Subject: TensorInfo related code documentation --- lib/THC/THCTensorInfo.cuh | 1 + lib/THC/THCTensorTypeUtils.cuh | 26 ++++++++++++++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/lib/THC/THCTensorInfo.cuh b/lib/THC/THCTensorInfo.cuh index 5347116..3389e61 100644 --- a/lib/THC/THCTensorInfo.cuh +++ b/lib/THC/THCTensorInfo.cuh @@ -247,6 +247,7 @@ struct IndexToOffset { } }; +// For contiguous tensors, the offset = index template struct IndexToOffset { static inline __host__ __device__ IndexType diff --git a/lib/THC/THCTensorTypeUtils.cuh b/lib/THC/THCTensorTypeUtils.cuh index 81051f7..6889faf 100644 --- a/lib/THC/THCTensorTypeUtils.cuh +++ b/lib/THC/THCTensorTypeUtils.cuh @@ -80,6 +80,32 @@ TENSOR_UTILS(THCudaHalfTensor, half, float); #undef TENSOR_UTILS +// Utility function for constructing TensorInfo structs. In this case, the +// two template parameters are: +// +// 1. The TensorType, e.g. THCTensor in generic functions, or THCudaTensor, +// THCudaLongTensor etc. +// +// 2. The IndexType. This is always going to be an unsigned integral value, +// but depending on the size of the Tensor you may select unsigned int, +// unsigned long, unsigned long long etc. +// +// Internally we use the TensorUtils static functions to get the necessary +// dims, sizes, stride etc. +// +// For example, suppose we have a THCudaTensor t, with dim = 2, size = [3, 4], +// stride = [4, 1], offset = 8, and we set our index type to be unsigned int. +// Then we yield a TensorInfo struct templatized with float, unsigned int and +// the following fields: +// +// data is a float* to the underlying storage at position 8 +// dims is 2 +// sizes is a MAX_CUTORCH_DIMS element array with [3, 4] in its first two positions +// strides is a MAX_CUTORCH_DIMS element array with [4, 1] in its first two positions +// +// TensorInfos can then be passed to CUDA kernels, but we can use the static functions +// defined above to perform Tensor Operations that are appropriate for each +// TensorType. template TensorInfo::DataType, IndexType> getTensorInfo(THCState* state, TensorType* t) { -- cgit v1.2.3