diff options
author | Soumith Chintala <soumith@gmail.com> | 2016-10-20 01:19:13 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2016-10-20 01:19:13 +0300 |
commit | 21ad069fd0174acbcf956c82895dc8c4f1fa39ce (patch) | |
tree | 94d79f50c2c535236c15bf0d2b36abd0b55f79c8 /lib/THC | |
parent | 8c3f15a9497bb23519ee3d60a765d223e001ab13 (diff) | |
parent | 0d6ae18ec6f34b38c6bbd8517389c0a7950f2964 (diff) |
Merge pull request #558 from gchanan/genericDeviceTensorUtils
Add generic type support for toDeviceTensor.
Diffstat (limited to 'lib/THC')
-rw-r--r-- | lib/THC/CMakeLists.txt | 1 | ||||
-rw-r--r-- | lib/THC/THCDeviceTensorUtils-inl.cuh | 34 | ||||
-rw-r--r-- | lib/THC/THCDeviceTensorUtils.cuh | 23 | ||||
-rw-r--r-- | lib/THC/generic/THCDeviceTensorUtils.cu | 55 |
4 files changed, 60 insertions, 53 deletions
diff --git a/lib/THC/CMakeLists.txt b/lib/THC/CMakeLists.txt index ddafefe..edc0af0 100644 --- a/lib/THC/CMakeLists.txt +++ b/lib/THC/CMakeLists.txt @@ -275,4 +275,5 @@ INSTALL(FILES generic/THCTensorIndex.cu generic/THCTensorSort.h generic/THCTensorSort.cu + generic/THCDeviceTensorUtils.cu DESTINATION "${THC_INSTALL_INCLUDE_SUBDIR}/THC/generic") diff --git a/lib/THC/THCDeviceTensorUtils-inl.cuh b/lib/THC/THCDeviceTensorUtils-inl.cuh index 26c1bb8..42aab34 100644 --- a/lib/THC/THCDeviceTensorUtils-inl.cuh +++ b/lib/THC/THCDeviceTensorUtils-inl.cuh @@ -1,37 +1,3 @@ -#include <limits> - -template <typename T, int Dim, - typename IndexT, template <typename U> class PtrTraits> -THCDeviceTensor<T, Dim, IndexT, PtrTraits> -toDeviceTensor(THCState* state, THCudaTensor* t) { - if (Dim != THCudaTensor_nDimension(state, t)) { - THError("THCudaTensor dimension mismatch"); - } - - // Determine the maximum offset into the tensor achievable; `IndexT` - // must be smaller than this type in order to use it. - ptrdiff_t maxOffset = 0; - IndexT sizes[Dim]; - IndexT strides[Dim]; - - for (int i = 0; i < Dim; ++i) { - long size = THCudaTensor_size(state, t, i); - long stride = THCudaTensor_stride(state, t, i); - - maxOffset += (size - 1) * stride; - - sizes[i] = (IndexT) size; - strides[i] = (IndexT) stride; - } - - if (maxOffset > std::numeric_limits<IndexT>::max()) { - THError("THCudaTensor sizes too large for THCDeviceTensor conversion"); - } - - return THCDeviceTensor<T, Dim, IndexT, PtrTraits>( - THCudaTensor_data(state, t), sizes, strides); -} - namespace detail { // Add a layer of SFINAE to support static_assert diff --git a/lib/THC/THCDeviceTensorUtils.cuh b/lib/THC/THCDeviceTensorUtils.cuh index 0e4ae82..472c6e1 100644 --- a/lib/THC/THCDeviceTensorUtils.cuh +++ b/lib/THC/THCDeviceTensorUtils.cuh @@ -3,25 +3,7 @@ #include "THCDeviceTensor.cuh" #include "THCTensor.h" - -/// Constructs a THCDeviceTensor initialized from a THCudaTensor. Will -/// error if the dimensionality does not match exactly. -template <typename T, int Dim, - typename IndexT, template <typename U> class PtrTraits> -THCDeviceTensor<T, Dim, IndexT, PtrTraits> -toDeviceTensor(THCState* state, THCudaTensor* t); - -template <typename T, int Dim, typename IndexT> -THCDeviceTensor<T, Dim, IndexT, DefaultPtrTraits> -toDeviceTensor(THCState* state, THCudaTensor* t) { - return toDeviceTensor<T, Dim, IndexT, DefaultPtrTraits>(state, t); -} - -template <typename T, int Dim> -THCDeviceTensor<T, Dim, int, DefaultPtrTraits> -toDeviceTensor(THCState* state, THCudaTensor* t) { - return toDeviceTensor<T, Dim, int, DefaultPtrTraits>(state, t); -} +#include <limits> /// Constructs a DeviceTensor initialized from a THCudaTensor by /// upcasting or downcasting the tensor to that of a different @@ -43,6 +25,9 @@ toDeviceTensorCast(THCState* state, THCudaTensor* t) { return toDeviceTensorCast<T, Dim, int, DefaultPtrTraits>(state, t); } +#include "generic/THCDeviceTensorUtils.cu" +#include "THCGenerateAllTypes.h" + #include "THCDeviceTensorUtils-inl.cuh" #endif // THC_DEVICE_TENSOR_UTILS_INC diff --git a/lib/THC/generic/THCDeviceTensorUtils.cu b/lib/THC/generic/THCDeviceTensorUtils.cu new file mode 100644 index 0000000..db6b5e7 --- /dev/null +++ b/lib/THC/generic/THCDeviceTensorUtils.cu @@ -0,0 +1,55 @@ +#ifndef THC_GENERIC_FILE +#define THC_GENERIC_FILE "generic/THCDeviceTensorUtils.cu" +#else + +/// Constructs a THCDeviceTensor initialized from a THCudaTensor. Will +/// error if the dimensionality does not match exactly. +template <typename T, int Dim, + typename IndexT, template <typename U> class PtrTraits> +THCDeviceTensor<T, Dim, IndexT, PtrTraits> +toDeviceTensor(THCState* state, THCTensor* t); + +template <typename T, int Dim, typename IndexT> +THCDeviceTensor<T, Dim, IndexT, DefaultPtrTraits> +toDeviceTensor(THCState* state, THCTensor* t) { + return toDeviceTensor<T, Dim, IndexT, DefaultPtrTraits>(state, t); +} + +template <typename T, int Dim> +THCDeviceTensor<T, Dim, int, DefaultPtrTraits> +toDeviceTensor(THCState* state, THCTensor* t) { + return toDeviceTensor<T, Dim, int, DefaultPtrTraits>(state, t); +} + +template <typename T, int Dim, + typename IndexT, template <typename U> class PtrTraits> +THCDeviceTensor<T, Dim, IndexT, PtrTraits> +toDeviceTensor(THCState* state, THCTensor* t) { + if (Dim != THCTensor_(nDimension)(state, t)) { + THError("THCudaTensor dimension mismatch"); + } + // Determine the maximum offset into the tensor achievable; `IndexT` + // must be smaller than this type in order to use it. + ptrdiff_t maxOffset = 0; + IndexT sizes[Dim]; + IndexT strides[Dim]; + + for (int i = 0; i < Dim; ++i) { + long size = THCTensor_(size)(state, t, i); + long stride = THCTensor_(stride)(state, t, i); + + maxOffset += (size - 1) * stride; + + sizes[i] = (IndexT) size; + strides[i] = (IndexT) stride; + } + + if (maxOffset > std::numeric_limits<IndexT>::max()) { + THError("THCudaTensor sizes too large for THCDeviceTensor conversion"); + } + + return THCDeviceTensor<T, Dim, IndexT, PtrTraits>( + THCTensor_(data)(state, t), sizes, strides); +} + +#endif |