diff options
author | soumith <soumith@fb.com> | 2016-11-01 05:25:51 +0300 |
---|---|---|
committer | soumith <soumith@fb.com> | 2016-11-01 05:26:08 +0300 |
commit | 3211523f59ed8182a1b57783a360162ef4b435d7 (patch) | |
tree | 2d2529679b977b4bc1e104eacf605df3283c13ad | |
parent | 398eceb46f3895c6ee7680ef2bfd1a14567801c4 (diff) |
adding multiple types for pow, trace, diag, tril, triu
-rw-r--r-- | TensorMath.lua | 35 | ||||
-rw-r--r-- | lib/THC/THCTensorMath.cu | 1 | ||||
-rw-r--r-- | lib/THC/THCTensorMath.cuh | 26 | ||||
-rw-r--r-- | lib/THC/THCTensorMath.h | 5 | ||||
-rw-r--r-- | lib/THC/THCTensorMath2.cu | 32 | ||||
-rw-r--r-- | lib/THC/THCTensorMathPairwise.cu | 146 | ||||
-rw-r--r-- | lib/THC/THCTensorMathPointwise.cuh | 15 | ||||
-rw-r--r-- | lib/THC/generic/THCTensorMath.cu | 44 | ||||
-rw-r--r-- | lib/THC/generic/THCTensorMath.h | 5 | ||||
-rw-r--r-- | lib/THC/generic/THCTensorMathPairwise.cu | 66 | ||||
-rw-r--r-- | lib/THC/generic/THCTensorMathPointwise.cu | 18 | ||||
-rw-r--r-- | lib/THC/generic/THCTensorMathPointwise.h | 2 | ||||
-rw-r--r-- | test/test.lua | 73 |
13 files changed, 280 insertions, 188 deletions
diff --git a/TensorMath.lua b/TensorMath.lua index 6163925..7f6292d 100644 --- a/TensorMath.lua +++ b/TensorMath.lua @@ -907,6 +907,16 @@ for k, Tensor_ in pairs(handledTypenames) do end + wrap("pow", + cname("pow"), + {{name=Tensor, default=true, returned=true, method={default='nil'}}, + {name=Tensor, method={default=1}}, + {name=real}}, + cname("tpow"), + {{name=Tensor, default=true, returned=true, method={default='nil'}}, + {name = real}, + {name=Tensor, method={default=1}}}) + wrap("norm", cname("normall"), {{name=Tensor}, @@ -938,6 +948,31 @@ for k, Tensor_ in pairs(handledTypenames) do {name="boolean", default=false}}) end + wrap("tril", + cname("tril"), + {{name=Tensor, default=true, returned=true}, + {name=Tensor}, + {name="int", default=0}}) + + wrap("triu", + cname("triu"), + {{name=Tensor, default=true, returned=true}, + {name=Tensor}, + {name="int", default=0}}) + + wrap("diag", + cname("diag"), + {{name=Tensor, default=true, returned=true}, + {name=Tensor}, + {name="int", default=0}}) + + wrap("trace", + cname("trace"), + {{name=Tensor}, + {name=accreal, creturned=true}}) + + + wrap("lerp", cname("lerp"), {{name=Tensor, default=true, returned=true, method={default='nil'}}, diff --git a/lib/THC/THCTensorMath.cu b/lib/THC/THCTensorMath.cu index 8d3d95e..190b076 100644 --- a/lib/THC/THCTensorMath.cu +++ b/lib/THC/THCTensorMath.cu @@ -3,6 +3,7 @@ #include "THCTensorCopy.h" #include "THCApply.cuh" #include "THCNumerics.cuh" +#include "THCTensorMath.cuh" #include <thrust/copy.h> #include <thrust/count.h> diff --git a/lib/THC/THCTensorMath.cuh b/lib/THC/THCTensorMath.cuh new file mode 100644 index 0000000..1224c32 --- /dev/null +++ b/lib/THC/THCTensorMath.cuh @@ -0,0 +1,26 @@ +#ifndef THC_TENSORMATH_CUH +#define THC_TENSORMATH_CUH + +// Copy the kth diagonal of a matrix B to a vector A. 
+template <typename T> +__global__ void THCTensor_copyFromDiagonal(T* a, T* b, ptrdiff_t start, ptrdiff_t size, ptrdiff_t strideSum, ptrdiff_t strideA) { + for (ptrdiff_t linearIndex = blockIdx.x * blockDim.x + threadIdx.x; + linearIndex < size; + linearIndex += gridDim.x * blockDim.x) { + const ptrdiff_t bOffset = start + strideSum * linearIndex; + a[strideA * linearIndex] = b[bOffset]; + } +} + +// Copy vector B to the kth diagonal of a matrix A +template <typename T> +__global__ void THCTensor_copyToDiagonal(T* a, T* b, ptrdiff_t start, ptrdiff_t size, ptrdiff_t strideSum, ptrdiff_t strideB) { + for (ptrdiff_t linearIndex = blockIdx.x * blockDim.x + threadIdx.x; + linearIndex < size; + linearIndex += gridDim.x * blockDim.x) { + const ptrdiff_t aOffset = start + strideSum * linearIndex; + a[aOffset] = b[strideB * linearIndex]; + } +} + +#endif diff --git a/lib/THC/THCTensorMath.h b/lib/THC/THCTensorMath.h index 3d71469..2844f90 100644 --- a/lib/THC/THCTensorMath.h +++ b/lib/THC/THCTensorMath.h @@ -37,11 +37,6 @@ #include "generic/THCTensorSort.h" #include "THCGenerateAllTypes.h" -THC_API void THCudaTensor_tril(THCState *state, THCudaTensor *self, THCudaTensor *src, long k); -THC_API void THCudaTensor_triu(THCState *state, THCudaTensor *self, THCudaTensor *src, long k); -THC_API void THCudaTensor_diag(THCState *state, THCudaTensor *self, THCudaTensor *src, long k); -THC_API float THCudaTensor_trace(THCState *state, THCudaTensor *self); - THC_API void THCudaTensor_cumsum(THCState *state, THCudaTensor *self, THCudaTensor *src, long dim); THC_API void THCudaTensor_cumprod(THCState *state, THCudaTensor *self, THCudaTensor *src, long dim); diff --git a/lib/THC/THCTensorMath2.cu b/lib/THC/THCTensorMath2.cu index d1fe328..2b80977 100644 --- a/lib/THC/THCTensorMath2.cu +++ b/lib/THC/THCTensorMath2.cu @@ -16,38 +16,6 @@ #include <thrust/system/cuda/execution_policy.h> #endif -struct TensorTPowOp { - TensorTPowOp(float v) : val(v) {} - - __device__ __forceinline__ void 
operator()(float* out, float* in) { - *out = powf(val, *in); - } - - __device__ __forceinline__ void operator()(float* v) { - *v = powf(val, *v); - } - - const float val; -}; - -void THCudaTensor_tpow(THCState *state, THCudaTensor *self_, float value, THCudaTensor *src) -{ - THAssert(THCudaTensor_checkGPU(state, 2, self_, src)); - if (self_ == src) { - if (!THC_pointwiseApply1(state, self_, TensorTPowOp(value))) { - THArgCheck(false, 2, CUTORCH_DIM_WARNING); - } - } else { - THCudaTensor_resizeAs(state, self_, src); - - if (!THC_pointwiseApply2(state, self_, src, TensorTPowOp(value))) { - THArgCheck(false, 2, CUTORCH_DIM_WARNING); - } - } - - THCudaCheck(cudaGetLastError()); -} - struct TensorATan2Op { __device__ __forceinline__ void operator()(float* out, float* a, float* b) { *out = atan2f(*a, *b); diff --git a/lib/THC/THCTensorMathPairwise.cu b/lib/THC/THCTensorMathPairwise.cu index 2695f2d..a02d19e 100644 --- a/lib/THC/THCTensorMathPairwise.cu +++ b/lib/THC/THCTensorMathPairwise.cu @@ -239,12 +239,12 @@ struct TensorDivConstantOp<half> { }; #endif // CUDA_HALF_TENSOR -template <int Upper> +template <typename T, int Upper> struct TensorTriOp { - TensorTriOp(float *start_, long stride0_, long stride1_, long k_) + TensorTriOp(T *start_, long stride0_, long stride1_, long k_) : start(start_), stride0(stride0_), stride1(stride1_), k(k_) {} - __device__ __forceinline__ int mask(float *in) { + __device__ __forceinline__ int mask(T *in) { ptrdiff_t n = in - start; long row, col; if (stride0 > stride1) @@ -261,148 +261,18 @@ struct TensorTriOp { return Upper ? (col - row >= k) : (col - row <= k); } - __device__ __forceinline__ void operator()(float* out, float* in) { - *out = mask(in) ? *in : 0; + __device__ __forceinline__ void operator()(T* out, T* in) { + *out = mask(in) ? 
*in : ScalarConvert<int, T>::to(0); } - __device__ __forceinline__ void operator()(float* v) { + __device__ __forceinline__ void operator()(T* v) { if (!mask(v)) - *v = 0; + *v = ScalarConvert<int, T>::to(0); } - const float *start; + const T *start; const long stride0, stride1, k; }; -void THCudaTensor_tril(THCState *state, THCudaTensor *self_, THCudaTensor *src_, long k) -{ - THAssert(THCudaTensor_checkGPU(state, 2, self_, src_)); - THArgCheck(src_->nDimension == 2, 1, "expected a matrix"); - - THCudaTensor *src = src_; - if (self_ == src_) - src = THCudaTensor_newContiguous(state, src_); - - long stride0 = src->stride[0]; - long stride1 = src->stride[1]; - float *start = THCudaTensor_data(state, src) + src->storageOffset; - - TensorTriOp<0> op(start, stride0, stride1, k); - - if (self_ == src_) { - if (!THC_pointwiseApply1(state, src, op)) { - THArgCheck(false, 2, CUTORCH_DIM_WARNING); - } - } else { - THCudaTensor_resizeAs(state, self_, src); - - if (!THC_pointwiseApply2(state, self_, src, op)) { - THArgCheck(false, 2, CUTORCH_DIM_WARNING); - } - } - - if (self_ == src_) - THCudaTensor_freeCopyTo(state, src, src_); - - THCudaCheck(cudaGetLastError()); -} - -void THCudaTensor_triu(THCState *state, THCudaTensor *self_, THCudaTensor *src_, long k) -{ - THAssert(THCudaTensor_checkGPU(state, 2, self_, src_)); - THArgCheck(src_->nDimension == 2, 1, "expected a matrix"); - - THCudaTensor *src = src_; - if (self_ == src_) - src = THCudaTensor_newContiguous(state, src_); - - long stride0 = src->stride[0]; - long stride1 = src->stride[1]; - float *start = THCudaTensor_data(state, src) + src->storageOffset; - - TensorTriOp<1> op(start, stride0, stride1, k); - - if (self_ == src_) { - if (!THC_pointwiseApply1(state, src, op)) { - THArgCheck(false, 2, CUTORCH_DIM_WARNING); - } - } else { - THCudaTensor_resizeAs(state, self_, src); - - if (!THC_pointwiseApply2(state, self_, src, op)) { - THArgCheck(false, 2, CUTORCH_DIM_WARNING); - } - } - - if (self_ == src_) - 
THCudaTensor_freeCopyTo(state, src, src_); - - THCudaCheck(cudaGetLastError()); -} - #include "generic/THCTensorMathPairwise.cu" #include "THCGenerateAllTypes.h" - -// Copy the kth diagonal of a matrix B to a vector A. -__global__ void THCudaTensor_copyFromDiagonal(float* a, float* b, ptrdiff_t start, ptrdiff_t size, ptrdiff_t strideSum, ptrdiff_t strideA) { - for (ptrdiff_t linearIndex = blockIdx.x * blockDim.x + threadIdx.x; - linearIndex < size; - linearIndex += gridDim.x * blockDim.x) { - const ptrdiff_t bOffset = start + strideSum * linearIndex; - a[strideA * linearIndex] = b[bOffset]; - } -} - -// Copy vector B to the kth diagonal of a matrix A -__global__ void THCudaTensor_copyToDiagonal(float* a, float* b, ptrdiff_t start, ptrdiff_t size, ptrdiff_t strideSum, ptrdiff_t strideB) { - for (ptrdiff_t linearIndex = blockIdx.x * blockDim.x + threadIdx.x; - linearIndex < size; - linearIndex += gridDim.x * blockDim.x) { - const ptrdiff_t aOffset = start + strideSum * linearIndex; - a[aOffset] = b[strideB * linearIndex]; - } -} - -void THCudaTensor_diag(THCState *state, THCudaTensor *self_, THCudaTensor *src_, long k){ - THAssert(THCudaTensor_checkGPU(state, 2, self_, src_)); - int nDimension = THCudaTensor_nDimension(state, src_); - THArgCheck((nDimension == 2) || (nDimension == 1), 1, "expected a matrix or a vector"); - if (nDimension == 2) { - long stride0 = THCudaTensor_stride(state, src_, 0); - long stride1 = THCudaTensor_stride(state, src_, 1); - long size0 = THCudaTensor_size(state, src_, 0); - long size1 = THCudaTensor_size(state, src_, 1); - long size = (k > 0) ? 
min((long long)size0, (long long)size1 - k) : min((long long)size0 + k, (long long)size1); - THCudaTensor_resize1d(state, self_, size); - long strideSelf = THCudaTensor_stride(state, self_, 0); - const dim3 threads(min((long long)THCState_getCurrentDeviceProperties(state)->maxThreadsPerBlock, (long long)size)); - dim3 grid(min((long long)1024, (long long)THCCeilDiv(size, (long)threads.x))); - long start = (k >= 0 ? k * stride1 : -k * stride0); - THCudaTensor_copyFromDiagonal<<<grid, threads, 0, THCState_getCurrentStream(state)>>> - (THCudaTensor_data(state, self_), THCudaTensor_data(state, src_), start, size, stride0 + stride1, strideSelf); - } else { - ptrdiff_t totalElements = THCudaTensor_nElement(state, src_); - ptrdiff_t size = (k > 0) ? totalElements + k : totalElements - k; - long strideSrc = THCudaTensor_stride(state, src_, 0); - THCudaTensor_resize2d(state, self_, size, size); - THCudaTensor_zero(state, self_); - long stride0 = THCudaTensor_stride(state, self_, 0); - long stride1 = THCudaTensor_stride(state, self_, 1); - const dim3 threads(min((long long)THCState_getCurrentDeviceProperties(state)->maxThreadsPerBlock, (long long)size)); - dim3 grid(min((long long)1024, (long long)THCCeilDiv(size, (ptrdiff_t)threads.x))); - ptrdiff_t start = (k >= 0 ? 
k * stride1 : -k * stride0); - THCudaTensor_copyToDiagonal<<<grid, threads, 0, THCState_getCurrentStream(state)>>> - (THCudaTensor_data(state, self_), THCudaTensor_data(state, src_), start, totalElements, stride0 + stride1, strideSrc); - } - THCudaCheck(cudaGetLastError()); -} - -float THCudaTensor_trace(THCState *state, THCudaTensor *src_) { - THAssert(THCudaTensor_checkGPU(state, 1, src_)); - THArgCheck((src_->nDimension == 2), 1, "expected a matrix"); - THCudaTensor *diag = THCudaTensor_new(state); - THCudaTensor_diag(state, diag, src_, 0); - float trace = THCudaTensor_sumall(state, diag); - THCudaTensor_free(state, diag); - return trace; -} diff --git a/lib/THC/THCTensorMathPointwise.cuh b/lib/THC/THCTensorMathPointwise.cuh index c52e082..40d35be 100644 --- a/lib/THC/THCTensorMathPointwise.cuh +++ b/lib/THC/THCTensorMathPointwise.cuh @@ -318,6 +318,21 @@ struct TensorPowOp<half> { }; #endif // CUDA_HALF_TENSOR +template<typename T> +struct TensorTPowOp { + TensorTPowOp(T v) : val(v) {} + + __device__ __forceinline__ void operator()(T* out, T* in) { + *out = THCNumerics<T>::pow(val, *in); + } + + __device__ __forceinline__ void operator()(T* v) { + *v = THCNumerics<T>::pow(val, *v); + } + + const T val; +}; + template <typename T> struct TensorCPowOp { __device__ __forceinline__ void operator()(T* out, T* in) { diff --git a/lib/THC/generic/THCTensorMath.cu b/lib/THC/generic/THCTensorMath.cu index 67243cf..cc15039 100644 --- a/lib/THC/generic/THCTensorMath.cu +++ b/lib/THC/generic/THCTensorMath.cu @@ -202,4 +202,48 @@ void THCTensor_(nonzero)(THCState* state, THCudaLongTensor *tensor, THCudaCheck(cudaGetLastError()); } +void THCTensor_(diag)(THCState *state, THCTensor *self_, THCTensor *src_, long k){ + THAssert(THCTensor_(checkGPU)(state, 2, self_, src_)); + int nDimension = THCTensor_(nDimension)(state, src_); + THArgCheck((nDimension == 2) || (nDimension == 1), 1, "expected a matrix or a vector"); + if (nDimension == 2) { + long stride0 = 
THCTensor_(stride)(state, src_, 0); + long stride1 = THCTensor_(stride)(state, src_, 1); + long size0 = THCTensor_(size)(state, src_, 0); + long size1 = THCTensor_(size)(state, src_, 1); + long size = (k > 0) ? min((long long)size0, (long long)size1 - k) : min((long long)size0 + k, (long long)size1); + THCTensor_(resize1d)(state, self_, size); + long strideSelf = THCTensor_(stride)(state, self_, 0); + const dim3 threads(min((long long)THCState_getCurrentDeviceProperties(state)->maxThreadsPerBlock, (long long)size)); + dim3 grid(min((long long)1024, (long long)THCCeilDiv(size, (long)threads.x))); + long start = (k >= 0 ? k * stride1 : -k * stride0); + THCTensor_copyFromDiagonal<real><<<grid, threads, 0, THCState_getCurrentStream(state)>>> + (THCTensor_(data)(state, self_), THCTensor_(data)(state, src_), start, size, stride0 + stride1, strideSelf); + } else { + ptrdiff_t totalElements = THCTensor_(nElement)(state, src_); + ptrdiff_t size = (k > 0) ? totalElements + k : totalElements - k; + long strideSrc = THCTensor_(stride)(state, src_, 0); + THCTensor_(resize2d)(state, self_, size, size); + THCTensor_(zero)(state, self_); + long stride0 = THCTensor_(stride)(state, self_, 0); + long stride1 = THCTensor_(stride)(state, self_, 1); + const dim3 threads(min((long long)THCState_getCurrentDeviceProperties(state)->maxThreadsPerBlock, (long long)size)); + dim3 grid(min((long long)1024, (long long)THCCeilDiv(size, (ptrdiff_t)threads.x))); + ptrdiff_t start = (k >= 0 ? 
k * stride1 : -k * stride0); + THCTensor_copyToDiagonal<real><<<grid, threads, 0, THCState_getCurrentStream(state)>>> + (THCTensor_(data)(state, self_), THCTensor_(data)(state, src_), start, totalElements, stride0 + stride1, strideSrc); + } + THCudaCheck(cudaGetLastError()); +} + +accreal THCTensor_(trace)(THCState *state, THCTensor *src_) { + THAssert(THCTensor_(checkGPU)(state, 1, src_)); + THArgCheck((src_->nDimension == 2), 1, "expected a matrix"); + THCTensor *diag = THCTensor_(new)(state); + THCTensor_(diag)(state, diag, src_, 0); + accreal trace = THCTensor_(sumall)(state, diag); + THCTensor_(free)(state, diag); + return trace; +} + #endif diff --git a/lib/THC/generic/THCTensorMath.h b/lib/THC/generic/THCTensorMath.h index 0335a62..2b8f563 100644 --- a/lib/THC/generic/THCTensorMath.h +++ b/lib/THC/generic/THCTensorMath.h @@ -13,5 +13,10 @@ THC_API void THCTensor_(cat)(THCState *state, THCTensor *result, THCTensor *ta, THC_API void THCTensor_(catArray)(THCState *state, THCTensor *result, THCTensor **inputs, int numInputs, int dimension); THC_API void THCTensor_(nonzero)(THCState* state, THCudaLongTensor *tensor, THCTensor *self); +THC_API void THCTensor_(tril)(THCState *state, THCTensor *self, THCTensor *src, long k); +THC_API void THCTensor_(triu)(THCState *state, THCTensor *self, THCTensor *src, long k); +THC_API void THCTensor_(diag)(THCState *state, THCTensor *self, THCTensor *src, long k); +THC_API accreal THCTensor_(trace)(THCState *state, THCTensor *self); + #endif diff --git a/lib/THC/generic/THCTensorMathPairwise.cu b/lib/THC/generic/THCTensorMathPairwise.cu index 5ce7a38..55ad945 100644 --- a/lib/THC/generic/THCTensorMathPairwise.cu +++ b/lib/THC/generic/THCTensorMathPairwise.cu @@ -80,4 +80,70 @@ THCTensor_(div)(THCState* state, THCTensor *self_, THCTensor *src_, real value) THCudaCheck(cudaGetLastError()); } +void THCTensor_(tril)(THCState *state, THCTensor *self_, THCTensor *src_, long k) +{ + THAssert(THCTensor_(checkGPU)(state, 2, self_, 
src_)); + THArgCheck(src_->nDimension == 2, 1, "expected a matrix"); + + THCTensor *src = src_; + if (self_ == src_) + src = THCTensor_(newContiguous)(state, src_); + + long stride0 = src->stride[0]; + long stride1 = src->stride[1]; + real *start = THCTensor_(data)(state, src); + + TensorTriOp<real, 0> op(start, stride0, stride1, k); + + if (self_ == src_) { + if (!THC_pointwiseApply1(state, src, op)) { + THArgCheck(false, 2, CUTORCH_DIM_WARNING); + } + } else { + THCTensor_(resizeAs)(state, self_, src); + + if (!THC_pointwiseApply2(state, self_, src, op)) { + THArgCheck(false, 2, CUTORCH_DIM_WARNING); + } + } + + if (self_ == src_) + THCTensor_(freeCopyTo)(state, src, src_); + + THCudaCheck(cudaGetLastError()); +} + +void THCTensor_(triu)(THCState *state, THCTensor *self_, THCTensor *src_, long k) +{ + THAssert(THCTensor_(checkGPU)(state, 2, self_, src_)); + THArgCheck(src_->nDimension == 2, 1, "expected a matrix"); + + THCTensor *src = src_; + if (self_ == src_) + src = THCTensor_(newContiguous)(state, src_); + + long stride0 = src->stride[0]; + long stride1 = src->stride[1]; + real *start = THCTensor_(data)(state, src); + + TensorTriOp<real, 1> op(start, stride0, stride1, k); + + if (self_ == src_) { + if (!THC_pointwiseApply1(state, src, op)) { + THArgCheck(false, 2, CUTORCH_DIM_WARNING); + } + } else { + THCTensor_(resizeAs)(state, self_, src); + + if (!THC_pointwiseApply2(state, self_, src, op)) { + THArgCheck(false, 2, CUTORCH_DIM_WARNING); + } + } + + if (self_ == src_) + THCTensor_(freeCopyTo)(state, src, src_); + + THCudaCheck(cudaGetLastError()); +} + #endif diff --git a/lib/THC/generic/THCTensorMathPointwise.cu b/lib/THC/generic/THCTensorMathPointwise.cu index 2638504..91c166f 100644 --- a/lib/THC/generic/THCTensorMathPointwise.cu +++ b/lib/THC/generic/THCTensorMathPointwise.cu @@ -173,6 +173,24 @@ void THCTensor_(pow)(THCState *state, THCTensor *self_, THCTensor *src, real val THCudaCheck(cudaGetLastError()); } 
+void THCTensor_(tpow)(THCState *state, THCTensor *self_, real value, THCTensor *src) +{ + THAssert(THCTensor_(checkGPU)(state, 2, self_, src)); + if (self_ == src) { + if (!THC_pointwiseApply1(state, self_, TensorTPowOp<real>(value))) { + THArgCheck(false, 2, CUTORCH_DIM_WARNING); + } + } else { + THCTensor_(resizeAs)(state, self_, src); + + if (!THC_pointwiseApply2(state, self_, src, TensorTPowOp<real>(value))) { + THArgCheck(false, 2, CUTORCH_DIM_WARNING); + } + } + + THCudaCheck(cudaGetLastError()); +} + THC_API void THCTensor_(lerp)(THCState *state, THCTensor *result, THCTensor *a, THCTensor *b, real w) { diff --git a/lib/THC/generic/THCTensorMathPointwise.h b/lib/THC/generic/THCTensorMathPointwise.h index 7a9d128..6e20a30 100644 --- a/lib/THC/generic/THCTensorMathPointwise.h +++ b/lib/THC/generic/THCTensorMathPointwise.h @@ -19,7 +19,7 @@ THC_API void THCTensor_(atan)(THCState *state, THCTensor *self, THCTensor *src); THC_API void THCTensor_(atan2)(THCState *state, THCTensor *r_, THCTensor *tx, THCTensor *ty); THC_API void THCTensor_(tanh)(THCState *state, THCTensor *self, THCTensor *src); THC_API void THCTensor_(pow)(THCState *state, THCTensor *self, THCTensor *src, real value); -THC_API void THCTensor_(tpow)(THCState *state, THCTensor *self, float value, THCTensor *src); +THC_API void THCTensor_(tpow)(THCState *state, THCTensor *self, real value, THCTensor *src); THC_API void THCTensor_(sqrt)(THCState *state, THCTensor *self, THCTensor *src); THC_API void THCTensor_(rsqrt)(THCState *state, THCTensor *self, THCTensor *src); THC_API void THCTensor_(ceil)(THCState *state, THCTensor *self, THCTensor *src); diff --git a/test/test.lua b/test/test.lua index b5eef6b..8436b64 100644 --- a/test/test.lua +++ b/test/test.lua @@ -1258,27 +1258,39 @@ function test.diag() local sz2 = chooseInt(minsize, maxsize) local k = chooseInt(-minsize, minsize) local x = torch.FloatTensor():rand(sz1, sz2) - compareFloatAndCudaTensorArgs(x, 'diag') - compareFloatAndCudaTensorArgs(x, 
'diag', k) + for _, typename in ipairs(float_typenames) do + local x = x:type(t2cpu[typename]) + compareCPUAndCUDATypeTensorArgs(typename, nil, x, 'diag') + compareCPUAndCUDATypeTensorArgs(typename, nil, x, 'diag', k) + end checkMultiDevice(x, 'diag') checkMultiDevice(x, 'diag', k) local y = torch.FloatTensor():rand(sz1) - compareFloatAndCudaTensorArgs(y, 'diag') - compareFloatAndCudaTensorArgs(y, 'diag', k) + for _, typename in ipairs(float_typenames) do + local y = y:type(t2cpu[typename]) + compareCPUAndCUDATypeTensorArgs(typename, nil, y, 'diag') + compareCPUAndCUDATypeTensorArgs(typename, nil, y, 'diag', k) + end checkMultiDevice(y, 'diag') checkMultiDevice(y, 'diag', k) -- test non-contiguous cases local x1 = createTestTensorWithSizes(true, true, {sz1, sz2}); - compareFloatAndCudaTensorArgs(x1, 'diag') - compareFloatAndCudaTensorArgs(x1, 'diag', k) + for _, typename in ipairs(float_typenames) do + local x1 = x1:type(t2cpu[typename]) + compareCPUAndCUDATypeTensorArgs(typename, nil, x1, 'diag') + compareCPUAndCUDATypeTensorArgs(typename, nil, x1, 'diag', k) + end checkMultiDevice(x1, 'diag') checkMultiDevice(x1, 'diag', k) local y1 = createTestTensorWithSizes(true, true, {sz1}); - compareFloatAndCudaTensorArgs(y1, 'diag') - compareFloatAndCudaTensorArgs(y1, 'diag', k) + for _, typename in ipairs(float_typenames) do + local y1 = y1:type(t2cpu[typename]) + compareCPUAndCUDATypeTensorArgs(typename, nil, y1, 'diag') + compareCPUAndCUDATypeTensorArgs(typename, nil, y1, 'diag', k) + end checkMultiDevice(y1, 'diag') checkMultiDevice(y1, 'diag', k) end @@ -1287,10 +1299,35 @@ function test.trace() local sz1 = chooseInt(minsize, maxsize) local sz2 = chooseInt(minsize, maxsize) local x = torch.FloatTensor():rand(sz1, sz2) - compareFloatAndCuda(x, 'trace') + for _, typename in ipairs(float_typenames) do + local x = x:type(t2cpu[typename]) + compareCPUAndCUDATypeTensorArgs(typename, nil, x, 'trace') + end checkMultiDevice(x, 'trace') end +function test.tril() + local sz1 = 
chooseInt(minsize, maxsize) + local sz2 = chooseInt(minsize, maxsize) + local x = torch.FloatTensor():rand(sz1, sz2) + for _, typename in ipairs(float_typenames) do + local x = x:type(t2cpu[typename]) + compareCPUAndCUDATypeTensorArgs(typename, nil, x, 'tril') + end + checkMultiDevice(x, 'tril') +end + +function test.triu() + local sz1 = chooseInt(minsize, maxsize) + local sz2 = chooseInt(minsize, maxsize) + local x = torch.FloatTensor():rand(sz1, sz2) + for _, typename in ipairs(float_typenames) do + local x = x:type(t2cpu[typename]) + compareCPUAndCUDATypeTensorArgs(typename, nil, x, 'triu') + end + checkMultiDevice(x, 'triu') +end + -- Test element-wise unary operators with both one and two arguments. local function testUnary1(fnp, types, tensor) local fn = fnp[1] @@ -1404,7 +1441,11 @@ function test.pow1() local sz2 = chooseInt(minsize, maxsize) local x = torch.FloatTensor():rand(sz1, sz2) local pow = torch.uniform(minvalue,maxvalue) - compareFloatAndCudaTensorArgs(x, 'pow', pow) + for k, typename in ipairs(float_typenames) do + local ctype = t2cpu[typename] + local x = x:type(ctype) + compareCPUAndCUDATypeTensorArgs(typename, nil, x, 'pow', pow) + end checkMultiDevice(x, 'pow', pow) end @@ -1414,7 +1455,11 @@ function test.pow2() local x = torch.FloatTensor():rand(sz1, sz2) local y = torch.FloatTensor() local pow = torch.uniform(minvalue,maxvalue) - compareFloatAndCudaTensorArgs(y, 'pow', x, pow) + for k, typename in ipairs(float_typenames) do + local ctype = t2cpu[typename] + local x, y = x:type(ctype), y:type(ctype) + compareCPUAndCUDATypeTensorArgs(typename, nil, y, 'pow', x, pow) + end checkMultiDevice(y, 'pow', x, pow) end @@ -1424,7 +1469,11 @@ function test.powExponentTensor() local pow = torch.uniform(minvalue,maxvalue) local x = torch.FloatTensor():rand(sz1, sz2) local y = torch.FloatTensor() - compareFloatAndCudaTensorArgs(y, 'pow', pow, x) + for k, typename in ipairs(float_typenames) do + local ctype = t2cpu[typename] + local x, y = x:type(ctype), 
y:type(ctype) + compareCPUAndCUDATypeTensorArgs(typename, nil, y, 'pow', pow, x) + end checkMultiDevice(y, 'pow', pow, x) end |