diff options
author | Shenxiu Liu <shenxiu@devvm22329.prn1.facebook.com> | 2017-03-12 04:48:47 +0300 |
---|---|---|
committer | Shenxiu Liu <shenxiu@fb.com> | 2017-03-16 06:50:30 +0300 |
commit | 8f6fcb3019e10e02c448e2316e326af797ba2e5b (patch) | |
tree | 0ea13807875b71e07b2698ca7fd14c922e00b58f | |
parent | c31cc583a33f5dbf052257c72b51d3341a089bb8 (diff) |
implement linspace, logspace and range in CUDA
-rw-r--r-- | TensorMath.lua | 44 | ||||
-rw-r--r-- | lib/THC/THCNumerics.cuh | 15 | ||||
-rw-r--r-- | lib/THC/THCTensorMath.cu | 26 | ||||
-rw-r--r-- | lib/THC/generic/THCTensorMath.cu | 63 | ||||
-rw-r--r-- | lib/THC/generic/THCTensorMath.h | 8 | ||||
-rw-r--r-- | test/test.lua | 146 |
6 files changed, 300 insertions, 2 deletions
diff --git a/TensorMath.lua b/TensorMath.lua index c35a366..df44b7d 100644 --- a/TensorMath.lua +++ b/TensorMath.lua @@ -965,6 +965,13 @@ for k, Tensor_ in pairs(handledTypenames) do {{name="CudaLongTensor", default=true, returned=true}, {name=Tensor}}) + wrap("range", + cname("range"), + {{name=Tensor, default=true, returned=true, method={default='nil'}}, + {name=accreal}, + {name=accreal}, + {name=accreal, default=1}}) + if real == 'float' or real == 'double' or real == 'half' then for _,name in ipairs({"log", "log1p", "exp", "cos", "acos", "cosh", @@ -981,6 +988,20 @@ for k, Tensor_ in pairs(handledTypenames) do end + wrap("linspace", + cname("linspace"), + {{name=Tensor, default=true, returned=true, method={default='nil'}}, + {name=real}, + {name=real}, + {name="long", default=100}}) + + wrap("logspace", + cname("logspace"), + {{name=Tensor, default=true, returned=true, method={default='nil'}}, + {name=real}, + {name=real}, + {name="long", default=100}}) + wrap("pow", cname("pow"), {{name=Tensor, default=true, returned=true, method={default='nil'}}, @@ -1087,8 +1108,6 @@ for k, Tensor_ in pairs(handledTypenames) do {{name=Tensor}, {name=accreal, creturned=true}}) - - wrap("lerp", cname("lerp"), {{name=Tensor, default=true, returned=true, method={default='nil'}}, @@ -1425,6 +1444,20 @@ wrap("zeros", {{name=Tensor, default=true, returned=true, method={default='nil'}}, {name="LongArg"}}) +wrap("linspace", + cname("linspace"), + {{name=Tensor, default=true, returned=true, method={default='nil'}}, + {name=real}, + {name=real}, + {name="long", default=100}}) + +wrap("logspace", + cname("logspace"), + {{name=Tensor, default=true, returned=true, method={default='nil'}}, + {name=real}, + {name=real}, + {name="long", default=100}}) + wrap("reshape", cname("reshape"), {{name=Tensor, default=true, returned=true}, @@ -1909,6 +1942,13 @@ wrap("nonzero", {{name="CudaLongTensor", default=true, returned=true}, {name=Tensor}}) +wrap("range", + cname("range"), + {{name=Tensor, 
default=true, returned=true, method={default='nil'}}, + {name=real}, + {name=real}, + {name=real, default=1}}) + wrap("geometric", cname("geometric"), {{name=Tensor, returned=true}, diff --git a/lib/THC/THCNumerics.cuh b/lib/THC/THCNumerics.cuh index 0944360..ab9d1d9 100644 --- a/lib/THC/THCNumerics.cuh +++ b/lib/THC/THCNumerics.cuh @@ -211,6 +211,19 @@ struct THCNumerics<half> { #endif } + static inline __host__ __device__ half exp10(half a) { +#ifdef __CUDA_ARCH__ +#ifdef CUDA_HALF_INSTRUCTIONS + return hexp10(a); +#else + float fa = __half2float(a); + return __float2half(exp10f(fa)); +#endif +#else // __CUDA_ARCH__ + return THC_float2half(exp10f(THC_half2float(a))); +#endif + } + static inline __host__ __device__ half log(half a) { #ifdef __CUDA_ARCH__ #ifdef CUDA_HALF_INSTRUCTIONS @@ -515,6 +528,7 @@ struct THCNumerics<float> { static inline __host__ __device__ bool ne(float a, float b) { return a != b; } static inline __host__ __device__ float exp (float a) { return expf(a); } + static inline __host__ __device__ float exp10(float a) { return exp10f(a); } static inline __host__ __device__ float log (float a) { return logf(a); } static inline __host__ __device__ float log1p(float a) { return log1pf(a); } static inline __host__ __device__ float cos (float a) { return cosf(a); } @@ -558,6 +572,7 @@ struct THCNumerics<double> { static inline __host__ __device__ bool ne(double a, double b) { return a != b; } static inline __host__ __device__ double exp (double a) { return ::exp(a); } + static inline __host__ __device__ double exp10(double a) { return ::exp10(a); } static inline __host__ __device__ double log (double a) { return ::log(a); } static inline __host__ __device__ double log1p(double a) { return ::log1p(a); } static inline __host__ __device__ double cos (double a) { return ::cos(a); } diff --git a/lib/THC/THCTensorMath.cu b/lib/THC/THCTensorMath.cu index 41e6466..b9225fe 100644 --- a/lib/THC/THCTensorMath.cu +++ b/lib/THC/THCTensorMath.cu @@ -107,6 +107,32 
@@ struct NonZeroOp } }; +template<typename T, typename accT = T> +struct LinspaceOp { + __host__ __device__ LinspaceOp(accT start, accT step): + start_(start), step_(step) { } + __device__ __forceinline__ T operator()(ptrdiff_t index) { + accT increment = THCNumerics<accT>::mul(step_, ScalarConvert<ptrdiff_t,accT>::to(index)); + accT value = THCNumerics<accT>::add(start_, increment); + return ScalarConvert<accT,T>::to(value); + } + + const accT start_, step_; +}; + +template<typename T, typename accT = T> +struct LogspaceOp { + __host__ __device__ LogspaceOp(accT start, accT step): + start_(start), step_(step) { } + __device__ __forceinline__ T operator()(ptrdiff_t index) { + accT increment = THCNumerics<accT>::mul(step_, ScalarConvert<ptrdiff_t,accT>::to(index)); + accT value = THCNumerics<accT>::exp10(THCNumerics<accT>::add(start_, increment)); + return ScalarConvert<accT,T>::to(value); + } + + const accT start_, step_; +}; + #include "generic/THCTensorMath.cu" #include "THCGenerateAllTypes.h" diff --git a/lib/THC/generic/THCTensorMath.cu b/lib/THC/generic/THCTensorMath.cu index fbcd422..e9f697d 100644 --- a/lib/THC/generic/THCTensorMath.cu +++ b/lib/THC/generic/THCTensorMath.cu @@ -391,4 +391,67 @@ accreal THCTensor_(trace)(THCState *state, THCTensor *src_) { return trace; } +#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF) + +void THCTensor_(linspace)(THCState *state, THCTensor *r_, real a, real b, long n) { + THCAssertSameGPU(THCTensor_(checkGPU)(state, 1, r_)); + THArgCheck(n > 1 || (n == 1 && (a == b)), 3, "invalid number of points"); + if (THCTensor_(nElement)(state, r_) != n) THCTensor_(resize1d)(state, r_, n); + if (n == 1) THCTensor_(fill)(state, r_, a); + else { + THCTensor *r = THCTensor_(isContiguous)(state, r_) + ? 
r_ // if r_ is contiguous we can direct work on it + : THCTensor_(newContiguous)(state, r_); + real step = THCNumerics<real>::div(THCNumerics<real>::sub(b, a), + ScalarConvert<long,real>::to(n - 1)); + LinspaceOp<real> linspace_method(a, step); + thrust::device_ptr<real> data_(THCTensor_(data)(state, r)); + thrust::tabulate(data_, data_ + n, linspace_method); + if (!THCTensor_(isContiguous)(state, r_)) { // We need to move data back to r_ + THCTensor_(freeCopyTo)(state, r, r_); + } + } + THCudaCheck(cudaGetLastError()); +} + +void THCTensor_(logspace)(THCState *state, THCTensor *r_, real a, real b, long n) { + THCAssertSameGPU(THCTensor_(checkGPU)(state, 1, r_)); + THArgCheck(n > 1 || (n == 1 && (a == b)), 3, "invalid number of points"); + if (THCTensor_(nElement)(state, r_) != n) THCTensor_(resize1d)(state, r_, n); + if (n == 1) THCTensor_(fill)(state, r_, THCNumerics<real>::exp10(a)); + else { + THCTensor *r = THCTensor_(isContiguous)(state, r_) + ? r_ + : THCTensor_(newContiguous)(state, r_); + real step = THCNumerics<real>::div(THCNumerics<real>::sub(b, a), + ScalarConvert<long,real>::to(n - 1)); + LogspaceOp<real> logspace_method(a, step); + thrust::device_ptr<real> data_(THCTensor_(data)(state, r)); + thrust::tabulate(data_, data_ + n, logspace_method); + if (!THCTensor_(isContiguous)(state, r_)) { + THCTensor_(freeCopyTo)(state, r, r_); + } + } + THCudaCheck(cudaGetLastError()); +} + +#endif + +void THCTensor_(range)(THCState *state, THCTensor *r_, accreal xmin, accreal xmax, accreal step) { + THCAssertSameGPU(THCTensor_(checkGPU)(state, 1, r_)); + THArgCheck(step > 0 || step < 0, 3, "step must be a non-null number"); + THArgCheck(((step > 0) && (xmax >= xmin)) || ((step < 0) && (xmax <= xmin)) + , 2, "upper bound and larger bound incoherent with step sign"); + ptrdiff_t size = (ptrdiff_t) (((xmax - xmin) / step) + 1); + if (THCTensor_(nElement)(state, r_) != size) THCTensor_(resize1d)(state, r_, size); + THCTensor *r = THCTensor_(isContiguous)(state, r_) + 
? r_ + : THCTensor_(newContiguous)(state, r_); + LinspaceOp<real,accreal> linspace_method(xmin, step); + thrust::device_ptr<real> data_(THCTensor_(data)(state, r)); + thrust::tabulate(data_, data_ + size, linspace_method); + if (!THCTensor_(isContiguous)(state, r_)) THCTensor_(freeCopyTo)(state, r, r_); + THCudaCheck(cudaGetLastError()); +} + #endif diff --git a/lib/THC/generic/THCTensorMath.h b/lib/THC/generic/THCTensorMath.h index 2b8f563..aae6775 100644 --- a/lib/THC/generic/THCTensorMath.h +++ b/lib/THC/generic/THCTensorMath.h @@ -18,5 +18,13 @@ THC_API void THCTensor_(triu)(THCState *state, THCTensor *self, THCTensor *src, THC_API void THCTensor_(diag)(THCState *state, THCTensor *self, THCTensor *src, long k); THC_API accreal THCTensor_(trace)(THCState *state, THCTensor *self); +#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF) + +THC_API void THCTensor_(linspace)(THCState *state, THCTensor *r_, real a, real b, long n); +THC_API void THCTensor_(logspace)(THCState *state, THCTensor *r_, real a, real b, long n); + +#endif + +THC_API void THCTensor_(range)(THCState *state, THCTensor *r_, accreal xmin, accreal xmax, accreal step); #endif diff --git a/test/test.lua b/test/test.lua index 57c61f6..3aacdbe 100644 --- a/test/test.lua +++ b/test/test.lua @@ -882,6 +882,98 @@ function test.ones() torch.setdefaulttensortype(t) end +function test.linspace() + local sz1 = chooseInt(minsize, maxsize) + local sz2 = chooseInt(minsize, maxsize) + local n = sz1 * sz2 + local a = torch.uniform() + local b = torch.uniform() + local x = torch.FloatTensor():rand(sz1, sz2) + for k, typename in ipairs(float_typenames) do + local x = x:type(t2cpu[typename]) + compareCPUAndCUDATypeTensorArgs(typename, nil, x, 'linspace', a, b, n) + end + checkMultiDevice(x, 'linspace', a, b, n) + + -- Check range for non-contiguous tensors. 
+ local x = createTestTensorWithSizes(true, true, {sz1, sz2}) + for k, typename in ipairs(float_typenames) do + local x = x:type(t2cpu[typename]) + compareCPUAndCUDATypeTensorArgs(typename, nil, x, 'linspace', a, b, n) + end + checkMultiDevice(x, 'linspace', a, b, n) + + -- Check new tensor creation + local x = torch.FloatTensor() + for k, typename in ipairs(float_typenames) do + local x = x:type(t2cpu[typename]) + compareCPUAndCUDATypeTensorArgs(typename, nil, x, 'linspace', a, b, n) + end + checkMultiDevice(x, 'linspace', a, b, n) + + -- Check n = 1 case + local x = torch.rand(1) + for k, typename in ipairs(float_typenames) do + local x = x:type(t2cpu[typename]) + compareCPUAndCUDATypeTensorArgs(typename, nil, x, 'linspace', a, a, 1) + end + checkMultiDevice(x, 'linspace', a, a, 1) + + -- Check default parameter case + local x = createTestTensorWithSizes(true, true, {100}) + for k, typename in ipairs(float_typenames) do + local x = x:type(t2cpu[typename]) + compareCPUAndCUDATypeTensorArgs(typename, nil, x, 'linspace', a, b) + end + checkMultiDevice(x, 'linspace', a, b) +end + +function test.logspace() + local sz1 = chooseInt(minsize, maxsize) + local sz2 = chooseInt(minsize, maxsize) + local n = sz1 * sz2 + local a = torch.uniform() + local b = torch.uniform() + local x = torch.FloatTensor():rand(sz1, sz2) + for k, typename in ipairs(float_typenames) do + local x = x:type(t2cpu[typename]) + compareCPUAndCUDATypeTensorArgs(typename, nil, x, 'logspace', a, b, n) + end + checkMultiDevice(x, 'logspace', a, b, n) + + -- Check logspace for non-contiguous tensors. 
+ local x = createTestTensorWithSizes(true, true, {sz1, sz2}) + for k, typename in ipairs(float_typenames) do + local x = x:type(t2cpu[typename]) + compareCPUAndCUDATypeTensorArgs(typename, nil, x, 'logspace', a, b, n) + end + checkMultiDevice(x, 'logspace', a, b, n) + + -- Check new tensor creation + local x = torch.FloatTensor() + for k, typename in ipairs(float_typenames) do + local x = x:type(t2cpu[typename]) + compareCPUAndCUDATypeTensorArgs(typename, nil, x, 'logspace', a, b, n) + end + checkMultiDevice(x, 'logspace', a, b, n) + + -- Check n = 1 case + local x = torch.rand(1) + for k, typename in ipairs(float_typenames) do + local x = x:type(t2cpu[typename]) + compareCPUAndCUDATypeTensorArgs(typename, nil, x, 'logspace', a, a, 1) + end + checkMultiDevice(x, 'logspace', a, a, 1) + + -- Check default parameter case + local x = createTestTensorWithSizes(true, true, {100}) + for k, typename in ipairs(float_typenames) do + local x = x:type(t2cpu[typename]) + compareCPUAndCUDATypeTensorArgs(typename, nil, x, 'logspace', a, b) + end + checkMultiDevice(x, 'logspace', a, b) +end + function test.add() local sz1 = chooseInt(minsize, maxsize) @@ -1510,6 +1602,60 @@ function test.diag() checkMultiDevice(y1, 'diag', k) end +function test.range() + local xmin = chooseInt(minsize, maxsize) + local xmax = chooseInt(xmin, maxsize) + local step = 3 + local size = math.floor((xmax - xmin) / step + 1) + -- Base case + local x = torch.FloatTensor():rand(size) + for k, typename in ipairs(typenames) do + local x = x:type(t2cpu[typename]) + compareCPUAndCUDATypeTensorArgs(typename, nil, x, 'range', xmin, xmax, step) + end + checkMultiDevice(x, 'range', xmin, xmax, step) + + -- Check range for non-contiguous tensors. 
+ local x = createTestTensorWithSizes(true, true, {size}) + for k, typename in ipairs(typenames) do + local x = x:type(t2cpu[typename]) + compareCPUAndCUDATypeTensorArgs(typename, nil, x, 'range', xmin, xmax, step) + end + checkMultiDevice(x, 'range', xmin, xmax, step) + + -- Check new tensor creation + local x = torch.Tensor() + for k, typename in ipairs(typenames) do + local x = x:type(t2cpu[typename]) + compareCPUAndCUDATypeTensorArgs(typename, nil, x, 'range', xmin, xmax, step) + end + checkMultiDevice(x, 'range', xmin, xmax, step) + + -- Check negative step case + for k, typename in ipairs(typenames) do + local x = x:type(t2cpu[typename]) + compareCPUAndCUDATypeTensorArgs(typename, nil, x, 'range', xmax, xmin, -step) + end + checkMultiDevice(x, 'range', xmax, xmin, -step) + + -- Check default parameter case + local x = createTestTensorWithSizes(true, true, {size}) + for k, typename in ipairs(typenames) do + local x = x:type(t2cpu[typename]) + compareCPUAndCUDATypeTensorArgs(typename, nil, x, 'range', xmin, xmax) + end + checkMultiDevice(x, 'range', xmin, xmax, step) + + -- Check floating step case + local step = 1.3 + local x = torch.Tensor() + for k, typename in ipairs(float_typenames) do + local x = x:type(t2cpu[typename]) + compareCPUAndCUDATypeTensorArgs(typename, nil, x, 'range', xmin, xmax, step) + end + checkMultiDevice(x, 'range', xmin, xmax, step) +end + function test.trace() local sz1 = chooseInt(minsize, maxsize) local sz2 = chooseInt(minsize, maxsize) |