Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Soumith Chintala <soumith@gmail.com> 2016-11-01 07:27:15 +0300
committer: GitHub <noreply@github.com> 2016-11-01 07:27:15 +0300
commit: 0b33e2cadb731c10df7cb94d87717536fa83c913 (patch)
tree: 9aea4821f844bd221ec255a0e99fe29cc0ae5230
parent: 9d26c60981991289553cff786492c4c892f9d15f (diff)
parent: ae0973f376218d856d5474c1a8b8ef021e9a497a (diff)
Merge pull request #576 from torch/distfix
dist, cumsum, cumprod for multiple types
-rw-r--r-- TensorMath.lua | 16
-rw-r--r-- lib/THC/THCTensorMath.h | 8
-rw-r--r-- lib/THC/THCTensorMath2.cu | 30
-rw-r--r-- lib/THC/THCTensorMathReduce.cuh | 34
-rw-r--r-- lib/THC/THCTensorMathScan.cu | 115
-rw-r--r-- lib/THC/generic/THCTensorMathReduce.cu | 24
-rw-r--r-- lib/THC/generic/THCTensorMathReduce.h | 3
-rw-r--r-- lib/THC/generic/THCTensorMathScan.cu | 89
-rw-r--r-- lib/THC/generic/THCTensorMathScan.h | 8
-rw-r--r-- test/test.lua | 25
10 files changed, 207 insertions, 145 deletions
diff --git a/TensorMath.lua b/TensorMath.lua
index 7f6292d..802565e 100644
--- a/TensorMath.lua
+++ b/TensorMath.lua
@@ -747,6 +747,14 @@ for k, Tensor_ in pairs(handledTypenames) do
{name=Tensor},
{name="index"}})
+ for _, name in ipairs({"cumsum", "cumprod"}) do
+ wrap(name,
+ cname(name),
+ {{name=Tensor, default=true, returned=true},
+ {name=Tensor},
+ {name="index", default=1}})
+ end
+
wrap("prod",
cname("prodall"),
{{name=Tensor},
@@ -936,6 +944,14 @@ for k, Tensor_ in pairs(handledTypenames) do
{name="index"},
{name=real}})
+ wrap("dist",
+ cname("dist"),
+ {{name=Tensor},
+ {name=Tensor},
+ {name=real, default=2},
+ {name=accreal, creturned=true}})
+
+
for _,name in ipairs({"var", "std"}) do
wrap(name,
cname(name .. "all"),
diff --git a/lib/THC/THCTensorMath.h b/lib/THC/THCTensorMath.h
index 2844f90..759c9a3 100644
--- a/lib/THC/THCTensorMath.h
+++ b/lib/THC/THCTensorMath.h
@@ -25,6 +25,9 @@
#include "generic/THCTensorMathCompareT.h"
#include "THCGenerateAllTypes.h"
+#include "generic/THCTensorMathScan.h"
+#include "THCGenerateAllTypes.h"
+
#include "generic/THCTensorMasked.h"
#include "THCGenerateAllTypes.h"
@@ -37,9 +40,6 @@
#include "generic/THCTensorSort.h"
#include "THCGenerateAllTypes.h"
-THC_API void THCudaTensor_cumsum(THCState *state, THCudaTensor *self, THCudaTensor *src, long dim);
-THC_API void THCudaTensor_cumprod(THCState *state, THCudaTensor *self, THCudaTensor *src, long dim);
-
// MAGMA (i.e. CUDA implementation of LAPACK functions)
THC_API void THCudaTensor_gesv(THCState *state, THCudaTensor *rb_, THCudaTensor *ra_, THCudaTensor *b_, THCudaTensor *a_);
THC_API void THCudaTensor_gels(THCState *state, THCudaTensor *rb_, THCudaTensor *ra_, THCudaTensor *b_, THCudaTensor *a_);
@@ -53,8 +53,6 @@ THC_API void THCudaTensor_potrf(THCState *state, THCudaTensor *ra_, THCudaTensor
THC_API void THCudaTensor_potrs(THCState *state, THCudaTensor *rb_, THCudaTensor *a, THCudaTensor *b);
THC_API void THCudaTensor_qr(THCState *state, THCudaTensor *rq_, THCudaTensor *rr_, THCudaTensor *a);
-THC_API float THCudaTensor_dist(THCState *state, THCudaTensor *self, THCudaTensor *src, float value);
-
THC_API void THCudaTensor_rand(THCState *state, THCudaTensor *r_, THLongStorage *size);
THC_API void THCudaTensor_randn(THCState *state, THCudaTensor *r_, THLongStorage *size);
diff --git a/lib/THC/THCTensorMath2.cu b/lib/THC/THCTensorMath2.cu
index 2b80977..9933b7e 100644
--- a/lib/THC/THCTensorMath2.cu
+++ b/lib/THC/THCTensorMath2.cu
@@ -8,14 +8,6 @@
#include "THCTensorMathReduce.cuh"
#include "THCTensorMathPointwise.cuh"
-#include <thrust/device_ptr.h>
-#include <thrust/transform_reduce.h>
-#include <thrust/functional.h>
-#include <thrust/inner_product.h>
-#if CUDA_VERSION >= 7000
-#include <thrust/system/cuda/execution_policy.h>
-#endif
-
struct TensorATan2Op {
__device__ __forceinline__ void operator()(float* out, float* a, float* b) {
*out = atan2f(*a, *b);
@@ -36,28 +28,6 @@ void THCudaTensor_atan2(THCState *state, THCudaTensor *self_, THCudaTensor *tx,
THCudaCheck(cudaGetLastError());
}
-float THCudaTensor_dist(THCState *state, THCudaTensor *self, THCudaTensor *src, float value)
-{
- THAssert(THCudaTensor_checkGPU(state, 2, self, src));
- self = THCudaTensor_newContiguous(state, self);
- ptrdiff_t size = THCudaTensor_nElement(state, self);
- src = THCudaTensor_newContiguous(state, src);
- thrust::device_ptr<float> self_data(THCudaTensor_data(state, self));
- thrust::device_ptr<float> src_data(THCudaTensor_data(state, src));
-
- float result = thrust::inner_product(
-#if CUDA_VERSION >= 7000
- thrust::cuda::par.on(THCState_getCurrentStream(state)),
-#endif
- self_data, self_data+size, src_data, (float) 0,
- thrust::plus<float>(), TensorDistOp<float>(value));
-
- THCudaTensor_free(state, src);
- THCudaTensor_free(state, self);
-
- return pow(result, (float)1.0/value);
-}
-
void THCudaTensor_rand(THCState *state, THCudaTensor *r_, THLongStorage *size)
{
THAssert(THCudaTensor_checkGPU(state, 1, r_));
diff --git a/lib/THC/THCTensorMathReduce.cuh b/lib/THC/THCTensorMathReduce.cuh
index db2e424..77f06ab 100644
--- a/lib/THC/THCTensorMathReduce.cuh
+++ b/lib/THC/THCTensorMathReduce.cuh
@@ -7,6 +7,12 @@
#include "THCReduce.cuh"
#include "THCReduceAll.cuh"
#include <thrust/functional.h>
+#include <thrust/device_ptr.h>
+#include <thrust/transform_reduce.h>
+#include <thrust/inner_product.h>
+#if CUDA_VERSION >= 7000
+#include <thrust/system/cuda/execution_policy.h>
+#endif
// Reduction operators that support `half`, unlike Thrust
template <typename InT, typename AccT>
@@ -239,19 +245,21 @@ struct TensorNormOp<half, StaticExp>
};
#endif
-template <typename T>
+template <typename Tacc, typename T>
struct TensorDistOp
{
- TensorDistOp(T exp) : exponent(exp) {}
+ TensorDistOp(Tacc exp) : exponent(exp) {}
- __host__ __device__ T operator()(T x, T y) const {
- return THCNumerics<T>::pow(
- THCNumerics<T>::abs(THCNumerics<T>::sub(x, y)),
+ __host__ __device__ Tacc operator()(T x, T y) const {
+ Tacc xr = ScalarConvert<T, Tacc>::to(x);
+ Tacc yr = ScalarConvert<T, Tacc>::to(y);
+ return THCNumerics<Tacc>::pow(
+ THCNumerics<Tacc>::abs(THCNumerics<Tacc>::sub(xr, yr)),
exponent
);
}
- const T exponent;
+ const Tacc exponent;
};
#include <thrust/functional.h>
@@ -664,4 +672,18 @@ struct MinValuePair {
}
};
+template <typename T>
+struct AddOp {
+ __device__ __forceinline__ T operator()(T &lhs, T &rhs) {
+ return THCNumerics<T>::add(lhs, rhs);
+ }
+};
+
+template <typename T>
+struct MulOp {
+ __device__ __forceinline__ T operator()(T &lhs, T &rhs) {
+ return THCNumerics<T>::mul(lhs, rhs);
+ }
+};
+
#endif // THC_TENSORMATH_REDUCE_CUH
diff --git a/lib/THC/THCTensorMathScan.cu b/lib/THC/THCTensorMathScan.cu
index a37d55d..3345e25 100644
--- a/lib/THC/THCTensorMathScan.cu
+++ b/lib/THC/THCTensorMathScan.cu
@@ -4,8 +4,8 @@
#include "THCTensorCopy.h"
#include "THCApply.cuh"
#include "THCReduce.cuh"
-
-#include <thrust/functional.h>
+#include "THCNumerics.cuh"
+#include "THCTensorMathReduce.cuh"
/* Perform an inclusive scan along an outer dimension of a tensor.
*
@@ -18,16 +18,16 @@
* outer dimensions, which contains several "inner rows").
* Each thread processes a single inner row at a time.
*/
-template<class BinaryOp>
-__global__ void THCudaTensor_kernel_scanOuterDim(float *tgt_, float *src_,
+template<typename T, class BinaryOp>
+__global__ void THCTensor_kernel_scanOuterDim(T *tgt_, T *src_,
unsigned num_orows, unsigned num_irows, unsigned row_size,
- float init, BinaryOp binary_op)
+ T init, BinaryOp binary_op)
{
for (unsigned orow = blockIdx.x; orow < num_orows; orow += gridDim.x) {
for (unsigned irow = blockIdx.y * blockDim.x + threadIdx.x; irow < num_irows; irow += gridDim.y * blockDim.x) {
- float *src = src_ + orow * row_size * num_irows + irow;
- float *tgt = tgt_ + orow * row_size * num_irows + irow;
- float acc = init;
+ T *src = src_ + orow * row_size * num_irows + irow;
+ T *tgt = tgt_ + orow * row_size * num_irows + irow;
+ T acc = init;
for (unsigned col = 0; col < row_size; ++col) {
acc = binary_op(acc, *src);
@@ -40,36 +40,6 @@ __global__ void THCudaTensor_kernel_scanOuterDim(float *tgt_, float *src_,
}
}
-template<class BinaryOp>
-__host__ void THCudaTensor_scanOuterDim(THCState *state, THCudaTensor *tgt, THCudaTensor *src, long dimension,
- float init, BinaryOp binary_op)
-{
- unsigned ndim = THCudaTensor_nDimension(state, src);
- // Treat all outer dimensions (i.e. dim < dimension) as one.
- unsigned num_orows = 1;
- for (long dim = 0; dim < dimension; dim++) {
- num_orows *= THCudaTensor_size(state, src, dim);
- }
- unsigned row_size = THCudaTensor_size(state, src, dimension);
- // Treat all inner dimensions (i.e. dim > dimension) as one.
- unsigned num_irows = 1;
- for (unsigned dim = dimension + 1; dim < ndim; dim++) {
- num_irows *= THCudaTensor_size(state, src, dim);
- }
-
- dim3 threads(min(512, num_irows));
- unsigned maxGridDim = 1024;
- dim3 grid(min(maxGridDim, num_orows), min(maxGridDim, THCCeilDiv(num_irows, threads.x)));
-
- THCudaTensor_kernel_scanOuterDim<<<grid, threads, 0, THCState_getCurrentStream(state)>>>(
- THCudaTensor_data(state, tgt), THCudaTensor_data(state, src), num_orows, num_irows, row_size, init, binary_op);
- cudaError errcode = cudaGetLastError();
- if (errcode != cudaSuccess) {
- THError(cudaGetErrorString(errcode));
- }
-}
-
-
/* Perform an inclusive scan along the innermost dimension of a tensor.
*
* - num_rows is the size of the flattened outer dimensions;
@@ -80,23 +50,23 @@ __host__ void THCudaTensor_scanOuterDim(THCState *state, THCudaTensor *tgt, THCu
* Each thread block processes one or more sets of contiguous rows (processing multiple rows
* per thread block is quicker than processing a single row, especially for short rows).
*/
-template<int num_threads_x, int num_threads_y, class BinaryFunction>
-__global__ void THCudaTensor_kernel_scanInnermostDim(float *tgt_, float *src_,
+template<typename T, int num_threads_x, int num_threads_y, class BinaryFunction>
+__global__ void THCTensor_kernel_scanInnermostDim(T *tgt_, T *src_,
unsigned num_rows, unsigned row_size,
- float init, BinaryFunction binary_op)
+ T init, BinaryFunction binary_op)
{
- __shared__ float sbuf[num_threads_y][2 * num_threads_x];
+ __shared__ T sbuf[num_threads_y][2 * num_threads_x];
- float* row_buf = sbuf[threadIdx.y];
+ T* row_buf = sbuf[threadIdx.y];
for (unsigned block_row = blockIdx.x * blockDim.y;
block_row < num_rows;
block_row += blockDim.y * gridDim.x) {
unsigned row = block_row + threadIdx.y;
- float block_total = init;
+ T block_total = init;
- float *row_src = src_ + row * row_size;
- float *row_tgt = tgt_ + row * row_size;
+ T *row_src = src_ + row * row_size;
+ T *row_tgt = tgt_ + row * row_size;
// Perform scan on one block at a time, keeping track of the total value of
// all blocks processed so far.
@@ -153,54 +123,5 @@ __global__ void THCudaTensor_kernel_scanInnermostDim(float *tgt_, float *src_,
}
}
-template<class BinaryFunction>
-__host__ void THCudaTensor_scanInnermostDim(THCState *state, THCudaTensor *tgt, THCudaTensor *src, float init, BinaryFunction binary_op)
-{
- unsigned ndim = THCudaTensor_nDimension(state, src);
- // Treat all outer dimensions as a single dimension.
- unsigned num_rows = 1;
- for (unsigned dim = 0; dim < ndim - 1; dim++) {
- num_rows *= THCudaTensor_size(state, src, dim);
- }
- unsigned row_size = THCudaTensor_size(state, src, ndim - 1);
-
- dim3 threads(16, 32);
- dim3 grid(min(1024, THCCeilDiv(num_rows, threads.y)));
-
- THCudaTensor_kernel_scanInnermostDim<16, 32><<<grid, threads, 0, THCState_getCurrentStream(state)>>>(
- THCudaTensor_data(state, tgt), THCudaTensor_data(state, src), num_rows, row_size, init, binary_op);
- cudaError errcode = cudaGetLastError();
- if (errcode != cudaSuccess) {
- THError(cudaGetErrorString(errcode));
- }
-}
-
-template<class BinaryFunction>
-void THCudaTensor_scanDim(THCState *state, THCudaTensor *self_, THCudaTensor *src, long dimension, float init, BinaryFunction binary_op)
-{
- THCudaTensor_resizeAs(state, self_, src);
-
- THCudaTensor *self = THCudaTensor_newContiguous(state, self_);
- src = THCudaTensor_newContiguous(state, src);
-
- if (dimension == THCudaTensor_nDimension(state, src) - 1) {
- THCudaTensor_scanInnermostDim(state, self, src, init, binary_op);
- } else {
- THCudaTensor_scanOuterDim(state, self, src, dimension, init, binary_op);
- }
-
- THCudaTensor_free(state, src);
- THCudaTensor_freeCopyTo(state, self, self_);
-}
-
-void THCudaTensor_cumsum(THCState *state, THCudaTensor *self, THCudaTensor *src, long dimension)
-{
- THAssert(THCudaTensor_checkGPU(state, 2, self, src));
- return THCudaTensor_scanDim(state, self, src, dimension, 0.0f, thrust::plus<float>());
-}
-
-void THCudaTensor_cumprod(THCState *state, THCudaTensor *self, THCudaTensor *src, long dimension)
-{
- THAssert(THCudaTensor_checkGPU(state, 2, self, src));
- return THCudaTensor_scanDim(state, self, src, dimension, 1.0f, thrust::multiplies<float>());
-}
+#include "generic/THCTensorMathScan.cu"
+#include "THCGenerateAllTypes.h"
diff --git a/lib/THC/generic/THCTensorMathReduce.cu b/lib/THC/generic/THCTensorMathReduce.cu
index 1e21d03..a8184b7 100644
--- a/lib/THC/generic/THCTensorMathReduce.cu
+++ b/lib/THC/generic/THCTensorMathReduce.cu
@@ -219,6 +219,30 @@ THCTensor_(normall)(THCState *state, THCTensor *self, real value)
return result;
}
+accreal THCTensor_(dist)(THCState *state, THCTensor *self,
+ THCTensor *src, real value)
+{
+ THAssert(THCTensor_(checkGPU)(state, 2, self, src));
+ self = THCTensor_(newContiguous)(state, self);
+ ptrdiff_t size = THCTensor_(nElement)(state, self);
+ src = THCTensor_(newContiguous)(state, src);
+ thrust::device_ptr<real> self_data(THCTensor_(data)(state, self));
+ thrust::device_ptr<real> src_data(THCTensor_(data)(state, src));
+
+ accreal result = thrust::inner_product(
+#if CUDA_VERSION >= 7000
+ thrust::cuda::par.on(THCState_getCurrentStream(state)),
+#endif
+ self_data, self_data+size, src_data, ScalarConvert<int, accreal>::to(0),
+ thrust::plus<accreal>(),
+ TensorDistOp<accreal, real>(ScalarConvert<real, accreal>::to(value)));
+
+ THCTensor_(free)(state, src);
+ THCTensor_(free)(state, self);
+
+ return THCNumerics<accreal>::pow(result, 1.0 / ScalarConvert<real, accreal>::to(value));
+}
+
#endif
THC_API accreal
diff --git a/lib/THC/generic/THCTensorMathReduce.h b/lib/THC/generic/THCTensorMathReduce.h
index 09a26fc..dc38ed6 100644
--- a/lib/THC/generic/THCTensorMathReduce.h
+++ b/lib/THC/generic/THCTensorMathReduce.h
@@ -35,4 +35,7 @@ THC_API void THCTensor_(max)(THCState *state,
THC_API real THCTensor_(minall)(THCState *state, THCTensor *self);
THC_API real THCTensor_(maxall)(THCState *state, THCTensor *self);
+THC_API accreal THCTensor_(dist)(THCState *state, THCTensor *self, THCTensor *src,
+ real value);
+
#endif
diff --git a/lib/THC/generic/THCTensorMathScan.cu b/lib/THC/generic/THCTensorMathScan.cu
new file mode 100644
index 0000000..8a8e434
--- /dev/null
+++ b/lib/THC/generic/THCTensorMathScan.cu
@@ -0,0 +1,89 @@
+#ifndef THC_GENERIC_FILE
+#define THC_GENERIC_FILE "generic/THCTensorMathScan.cu"
+#else
+
+template<class BinaryOp>
+__host__ void THCTensor_(scanOuterDim)(THCState *state, THCTensor *tgt,
+ THCTensor *src, long dimension,
+ real init, BinaryOp binary_op)
+{
+ unsigned ndim = THCTensor_(nDimension)(state, src);
+ // Treat all outer dimensions (i.e. dim < dimension) as one.
+ unsigned num_orows = 1;
+ for (long dim = 0; dim < dimension; dim++) {
+ num_orows *= THCTensor_(size)(state, src, dim);
+ }
+ unsigned row_size = THCTensor_(size)(state, src, dimension);
+ // Treat all inner dimensions (i.e. dim > dimension) as one.
+ unsigned num_irows = 1;
+ for (unsigned dim = dimension + 1; dim < ndim; dim++) {
+ num_irows *= THCTensor_(size)(state, src, dim);
+ }
+
+ dim3 threads(min(512, num_irows));
+ unsigned maxGridDim = 1024;
+ dim3 grid(min(maxGridDim, num_orows), min(maxGridDim, THCCeilDiv(num_irows, threads.x)));
+
+ THCTensor_kernel_scanOuterDim<real><<<grid, threads, 0, THCState_getCurrentStream(state)>>>(
+ THCTensor_(data)(state, tgt), THCTensor_(data)(state, src),
+ num_orows, num_irows, row_size, init, binary_op);
+
+ THCudaCheck(cudaGetLastError());
+}
+
+template<class BinaryFunction>
+__host__ void THCTensor_(scanInnermostDim)(THCState *state, THCTensor *tgt,
+ THCTensor *src, real init,
+ BinaryFunction binary_op)
+{
+ unsigned ndim = THCTensor_(nDimension)(state, src);
+ // Treat all outer dimensions as a single dimension.
+ unsigned num_rows = 1;
+ for (unsigned dim = 0; dim < ndim - 1; dim++) {
+ num_rows *= THCTensor_(size)(state, src, dim);
+ }
+ unsigned row_size = THCTensor_(size)(state, src, ndim - 1);
+
+ dim3 threads(16, 32);
+ dim3 grid(min(1024, THCCeilDiv(num_rows, threads.y)));
+
+ THCTensor_kernel_scanInnermostDim<real, 16, 32><<<grid, threads, 0, THCState_getCurrentStream(state)>>>(
+ THCTensor_(data)(state, tgt), THCTensor_(data)(state, src), num_rows, row_size, init, binary_op);
+
+ THCudaCheck(cudaGetLastError());
+}
+
+template<class BinaryFunction>
+void THCTensor_(scanDim)(THCState *state, THCTensor *self_, THCTensor *src,
+ long dimension, real init, BinaryFunction binary_op)
+{
+ THCTensor_(resizeAs)(state, self_, src);
+
+ THCTensor *self = THCTensor_(newContiguous)(state, self_);
+ src = THCTensor_(newContiguous)(state, src);
+
+ if (dimension == THCTensor_(nDimension)(state, src) - 1) {
+ THCTensor_(scanInnermostDim)(state, self, src, init, binary_op);
+ } else {
+ THCTensor_(scanOuterDim)(state, self, src, dimension, init, binary_op);
+ }
+
+ THCTensor_(free)(state, src);
+ THCTensor_(freeCopyTo)(state, self, self_);
+}
+
+void THCTensor_(cumsum)(THCState *state, THCTensor *self, THCTensor *src, long dimension)
+{
+ THAssert(THCTensor_(checkGPU)(state, 2, self, src));
+ return THCTensor_(scanDim)(state, self, src, dimension,
+ ScalarConvert<float, real>::to(0.0), AddOp<real>());
+}
+
+void THCTensor_(cumprod)(THCState *state, THCTensor *self, THCTensor *src, long dimension)
+{
+ THAssert(THCTensor_(checkGPU)(state, 2, self, src));
+ return THCTensor_(scanDim)(state, self, src, dimension,
+ ScalarConvert<float, real>::to(1.0), MulOp<real>());
+}
+
+#endif
diff --git a/lib/THC/generic/THCTensorMathScan.h b/lib/THC/generic/THCTensorMathScan.h
new file mode 100644
index 0000000..edd825a
--- /dev/null
+++ b/lib/THC/generic/THCTensorMathScan.h
@@ -0,0 +1,8 @@
+#ifndef THC_GENERIC_FILE
+#define THC_GENERIC_FILE "generic/THCTensorMathScan.h"
+#else
+
+THC_API void THCTensor_(cumsum)(THCState *state, THCTensor *self, THCTensor *src, long dim);
+THC_API void THCTensor_(cumprod)(THCState *state, THCTensor *self, THCTensor *src, long dim);
+
+#endif
diff --git a/test/test.lua b/test/test.lua
index 8436b64..00cfa66 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -1184,9 +1184,12 @@ function test.cumsum()
local sz1 = chooseInt(minsize, maxsize)
local sz2 = chooseInt(minsize, maxsize)
local x = torch.FloatTensor():rand(sz1, sz2)
- compareFloatAndCuda(x, 'cumsum')
- compareFloatAndCuda(x, 'cumsum', 1)
- compareFloatAndCuda(x, 'cumsum', 2)
+ for _, typename in ipairs(typenames) do
+ local x = x:type(t2cpu[typename])
+ compareCPUAndCUDATypeTensorArgs(typename, nil, x, 'cumsum');
+ compareCPUAndCUDATypeTensorArgs(typename, nil, x, 'cumsum', 1);
+ compareCPUAndCUDATypeTensorArgs(typename, nil, x, 'cumsum', 2);
+ end
checkMultiDevice(x, 'cumsum')
checkMultiDevice(x, 'cumsum', 1)
end
@@ -1210,9 +1213,12 @@ function test.cumprod()
local sz1 = chooseInt(minsize, maxsize)
local sz2 = chooseInt(minsize, maxsize)
local x = torch.FloatTensor():rand(sz1, sz2)
- compareFloatAndCuda(x, 'cumprod')
- compareFloatAndCuda(x, 'cumprod', 1)
- compareFloatAndCuda(x, 'cumprod', 2)
+ for _, typename in ipairs(typenames) do
+ local x = x:type(t2cpu[typename])
+ compareCPUAndCUDATypeTensorArgs(typename, nil, x, 'cumprod');
+ compareCPUAndCUDATypeTensorArgs(typename, nil, x, 'cumprod', 1);
+ compareCPUAndCUDATypeTensorArgs(typename, nil, x, 'cumprod', 2);
+ end
checkMultiDevice(x, 'cumprod')
checkMultiDevice(x, 'cumprod', 1)
end
@@ -1871,7 +1877,12 @@ function test.dist()
local sz2 = chooseInt(minsize, maxsize)
local x = torch.FloatTensor():rand(sz1, sz2)
local y = torch.FloatTensor():rand(sz1, sz2)
- compareFloatAndCudaTensorArgs(x, 'dist', y)
+ for _, typename in ipairs(float_typenames) do
+ local x = x:type(t2cpu[typename])
+ local y = y:type(t2cpu[typename])
+ compareCPUAndCUDATypeTensorArgs(typename, nil, x, 'dist', y)
+ compareCPUAndCUDATypeTensorArgs(typename, nil, x, 'dist', y, 3)
+ end
checkMultiDevice(x, 'dist', y)
end