commit 29e5059c3980bb2a905a7f7eacc4add1123b1b93 (patch)
author:    Trevor Killeen <killeentm@gmail.com>  2016-10-06 17:26:25 +0300
committer: Trevor Killeen <killeentm@gmail.com>  2016-10-07 21:50:28 +0300
tree:      bc34996e730af321e1be6039cb44bcbe3dca0afc
parent:    74a97cf2b0eeaff5b19f79fd75493b316e340c9d (diff)
[cutorch refactor] move cross(...) to generic
 TensorMath.lua                            |  7 +
 lib/THC/THCNumerics.cuh                   | 12 +
 lib/THC/THCTensorMath.h                   |  2 -
 lib/THC/THCTensorMath2.cu                 | 45 -
 lib/THC/THCTensorMathPointwise.cuh        | 25 +
 lib/THC/generic/THCTensorMathPointwise.cu | 36 +
 lib/THC/generic/THCTensorMathPointwise.h  |  1 +
 test/test.lua                             |  8 +-
 8 files changed, 86 insertions(+), 50 deletions(-)
diff --git a/TensorMath.lua b/TensorMath.lua index 1a06121..53a112a 100644 --- a/TensorMath.lua +++ b/TensorMath.lua @@ -596,6 +596,13 @@ for k, Tensor_ in pairs(handledTypenames) do {name=real}, {name=real}}) + wrap("cross", + cname("cross"), + {{name=Tensor, default=true, returned=true}, + {name=Tensor}, + {name=Tensor}, + {name="index", default=0}}) + wrap("div", cname("div"), {{name=Tensor, default=true, returned=true, method={default='nil'}}, diff --git a/lib/THC/THCNumerics.cuh b/lib/THC/THCNumerics.cuh index 81ed0d1..af51809 100644 --- a/lib/THC/THCNumerics.cuh +++ b/lib/THC/THCNumerics.cuh @@ -26,6 +26,8 @@ struct THCNumerics<unsigned char> { static inline __host__ __device__ bool ne(unsigned char a, unsigned char b) { return a != b; } static inline __host__ __device__ unsigned char add(unsigned char a, unsigned char b) { return a + b; } + static inline __host__ __device__ unsigned char mul(unsigned char a, unsigned char b) { return a * b; } + static inline __host__ __device__ unsigned char sub(unsigned char a, unsigned char b) { return a - b; } static inline __host__ __device__ unsigned char abs(unsigned char a) { return abs(a); } }; @@ -42,6 +44,8 @@ struct THCNumerics<char> { static inline __host__ __device__ bool ne(char a, char b) { return a != b; } static inline __host__ __device__ char add(char a, char b) { return a + b; } + static inline __host__ __device__ char mul(char a, char b) { return a * b; } + static inline __host__ __device__ char sub(char a, char b) { return a - b; } static inline __host__ __device__ char abs(char a) { return abs(a); } }; @@ -58,6 +62,8 @@ struct THCNumerics<short> { static inline __host__ __device__ bool ne(short a, short b) { return a != b; } static inline __host__ __device__ short add(short a, short b) { return a + b; } + static inline __host__ __device__ short mul(short a, short b) { return a * b; } + static inline __host__ __device__ short sub(short a, short b) { return a - b; } static inline __host__ __device__ short 
abs(short a) { return abs(a); } }; @@ -74,6 +80,8 @@ struct THCNumerics<int> { static inline __host__ __device__ bool ne(int a, int b) { return a != b; } static inline __host__ __device__ int add(int a, int b) { return a + b; } + static inline __host__ __device__ int mul(int a, int b) { return a * b; } + static inline __host__ __device__ int sub(int a, int b) { return a - b; } static inline __host__ __device__ int abs(int a) { return ::abs(a); } }; @@ -90,8 +98,10 @@ struct THCNumerics<long> { static inline __host__ __device__ bool ne(long a, long b) { return a != b; } static inline __host__ __device__ long add(long a, long b) { return a + b; } - static inline __host__ __device__ long abs(long a) { return labs(a); } + static inline __host__ __device__ long mul(long a, long b) { return a * b; } + static inline __host__ __device__ long sub(long a, long b) { return a - b; } static inline __host__ __device__ long div(long a, long b) { return a / b; }; + static inline __host__ __device__ long abs(long a) { return labs(a); } }; #ifdef CUDA_HALF_TENSOR diff --git a/lib/THC/THCTensorMath.h b/lib/THC/THCTensorMath.h index 0c850ab..f54f026 100644 --- a/lib/THC/THCTensorMath.h +++ b/lib/THC/THCTensorMath.h @@ -53,8 +53,6 @@ THC_API void THCudaTensor_cmax(THCState *state, THCudaTensor *self, THCudaTensor THC_API void THCudaTensor_cminValue(THCState *state, THCudaTensor *self, THCudaTensor *src, float value); THC_API void THCudaTensor_cmaxValue(THCState *state, THCudaTensor *self, THCudaTensor *src, float value); -THC_API void THCudaTensor_cross(THCState *state, THCudaTensor *self, THCudaTensor *src1, THCudaTensor *src2, int dimension); - // MAGMA (i.e. 
CUDA implementation of LAPACK functions) THC_API void THCudaTensor_gesv(THCState *state, THCudaTensor *rb_, THCudaTensor *ra_, THCudaTensor *b_, THCudaTensor *a_); THC_API void THCudaTensor_gels(THCState *state, THCudaTensor *rb_, THCudaTensor *ra_, THCudaTensor *b_, THCudaTensor *a_); diff --git a/lib/THC/THCTensorMath2.cu b/lib/THC/THCTensorMath2.cu index c913b99..e0e3255 100644 --- a/lib/THC/THCTensorMath2.cu +++ b/lib/THC/THCTensorMath2.cu @@ -116,48 +116,3 @@ void THCudaTensor_randn(THCState *state, THCudaTensor *r_, THLongStorage *size) THCudaTensor_normal(state, r_, 0, 1); } -struct TensorCrossOp { - TensorCrossOp(long sx, long sy, long so) : sx(sx), sy(sy), so(so) {} - - __device__ __forceinline__ void operator()(float* out, float* x, float*y) { - out[0 * so] = x[1 * sx] * y[2 * sy] - x[2 * sx] * y[1 * sy]; - out[1 * so] = x[2 * sx] * y[0 * sy] - x[0 * sx] * y[2 * sy]; - out[2 * so] = x[0 * sx] * y[1 * sy] - x[1 * sx] * y[0 * sy]; - } - - const long sx, sy, so; -}; - -THC_API void THCudaTensor_cross(THCState *state, THCudaTensor *self, THCudaTensor *x, THCudaTensor *y, int dimension) -{ - THAssert(THCudaTensor_checkGPU(state, 3, self, x, y)); - - int i; - long nd = THCudaTensor_nDimension(state, x); - long nelem = THCudaTensor_nElement(state, x); - THArgCheck(nd == THCudaTensor_nDimension(state, y), 1, "tensors must have same number of dimensions"); - for (i = 0; i < nd; i++) { - THArgCheck(THCudaTensor_size(state, x, i) == THCudaTensor_size(state, y, i), 1, "dimension %i of x and y does not match", i); - if (dimension < 0 && THCudaTensor_size(state, x, i) == 3) { - dimension = i; - } - } - - THArgCheck(dimension >= 0 && dimension < nd, 3, "dimension %d out of range", dimension+1); - THArgCheck(THCudaTensor_size(state, x, dimension) == 3, 3, - "dimension %d does not have size 3", dimension+1); - THCudaTensor_resizeAs(state, self, x); - - long sx = THCudaTensor_stride(state, x, dimension); - long sy = THCudaTensor_stride(state, y, dimension); - long so = 
THCudaTensor_stride(state, self, dimension); - THCudaTensor *nx = THCudaTensor_newNarrow(state, x, dimension, 0, 1); - THCudaTensor *ny = THCudaTensor_newNarrow(state, y, dimension, 0, 1); - THCudaTensor *nself = THCudaTensor_newNarrow(state, self, dimension, 0, 1); - if (!THC_pointwiseApply3(state, nself, nx, ny, TensorCrossOp(sx, sy, so))) { - THArgCheck(false, 2, CUTORCH_DIM_WARNING); - } - THCudaTensor_free(state, nx); - THCudaTensor_free(state, ny); - THCudaTensor_free(state, nself); -} diff --git a/lib/THC/THCTensorMathPointwise.cuh b/lib/THC/THCTensorMathPointwise.cuh index 9560721..3f99279 100644 --- a/lib/THC/THCTensorMathPointwise.cuh +++ b/lib/THC/THCTensorMathPointwise.cuh @@ -431,4 +431,29 @@ struct TensorLerpOp { const T w; }; +template <typename T> +struct TensorCrossOp { + TensorCrossOp(long sx, long sy, long so) : sx(sx), sy(sy), so(so) {} + + __device__ __forceinline__ void operator()(T* out, T* x, T*y) { + out[0 * so] = THCNumerics<T>::sub( + THCNumerics<T>::mul(x[1 * sx], y[2 * sy]), + THCNumerics<T>::mul(x[2 * sx], y[1 * sy]) + ); + + out[1 * so] = THCNumerics<T>::sub( + THCNumerics<T>::mul(x[2 * sx], y[0 * sy]), + THCNumerics<T>::mul(x[0 * sx], y[2 * sy]) + ); + + out[2 * so] = THCNumerics<T>::sub( + THCNumerics<T>::mul(x[0 * sx], y[1 * sy]), + THCNumerics<T>::mul(x[1 * sx], y[0 * sy]) + ); + } + + const long sx, sy, so; +}; + + #endif // THC_TENSORMATH_POINTWISE_CUH diff --git a/lib/THC/generic/THCTensorMathPointwise.cu b/lib/THC/generic/THCTensorMathPointwise.cu index 79180cd..86cea2a 100644 --- a/lib/THC/generic/THCTensorMathPointwise.cu +++ b/lib/THC/generic/THCTensorMathPointwise.cu @@ -101,6 +101,42 @@ void THCTensor_(clamp)(THCState *state, THCTensor *self_, THCTensor *src, real m THCudaCheck(cudaGetLastError()); } +THC_API void +THCTensor_(cross)(THCState *state, THCTensor *self, THCTensor *x, THCTensor *y, int dimension) +{ + THAssert(THCTensor_(checkGPU)(state, 3, self, x, y)); + + int i; + long nd = THCTensor_(nDimension)(state, x); 
+ long nelem = THCTensor_(nElement)(state, x); + THArgCheck(nd == THCTensor_(nDimension)(state, y), 1, "tensors must have same number of dimensions"); + for (i = 0; i < nd; i++) { + THArgCheck(THCTensor_(size)(state, x, i) == THCTensor_(size)(state, y, i), 1, "dimension %i of x and y does not match", i); + if (dimension < 0 && THCTensor_(size)(state, x, i) == 3) { + dimension = i; + } + } + + THArgCheck(dimension >= 0 && dimension < nd, 3, "dimension %d out of range", dimension+1); + THArgCheck(THCTensor_(size)(state, x, dimension) == 3, 3, + "dimension %d does not have size 3", dimension+1); + THCTensor_(resizeAs)(state, self, x); + + long sx = THCTensor_(stride)(state, x, dimension); + long sy = THCTensor_(stride)(state, y, dimension); + long so = THCTensor_(stride)(state, self, dimension); + THCTensor *nx = THCTensor_(newNarrow)(state, x, dimension, 0, 1); + THCTensor *ny = THCTensor_(newNarrow)(state, y, dimension, 0, 1); + THCTensor *nself = THCTensor_(newNarrow)(state, self, dimension, 0, 1); + if (!THC_pointwiseApply3(state, nself, nx, ny, TensorCrossOp<real>(sx, sy, so))) { + THArgCheck(false, 2, CUTORCH_DIM_WARNING); + } + THCTensor_(free)(state, nx); + THCTensor_(free)(state, ny); + THCTensor_(free)(state, nself); +} + + #if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF) void THCTensor_(sigmoid)(THCState* state, THCTensor* self_, THCTensor* src) { diff --git a/lib/THC/generic/THCTensorMathPointwise.h b/lib/THC/generic/THCTensorMathPointwise.h index 12c420b..03ab32e 100644 --- a/lib/THC/generic/THCTensorMathPointwise.h +++ b/lib/THC/generic/THCTensorMathPointwise.h @@ -37,6 +37,7 @@ THC_API void THCTensor_(cinv)(THCState *state, THCTensor *self, THCTensor *src); THC_API void THCTensor_(abs)(THCState *state, THCTensor *self, THCTensor *src); THC_API void THCTensor_(sign)(THCState *state, THCTensor *self, THCTensor *src); THC_API void THCTensor_(clamp)(THCState *state, THCTensor *self, THCTensor *src, real min_value, 
real max_value); +THC_API void THCTensor_(cross)(THCState *state, THCTensor *self, THCTensor *src1, THCTensor *src2, int dimension); THC_API void THCTensor_(cadd)(THCState *state, THCTensor *self, THCTensor *src1, real value, THCTensor *src2); THC_API void THCTensor_(csub)(THCState *state, THCTensor *self, THCTensor *src1, real value, THCTensor *src2); diff --git a/test/test.lua b/test/test.lua index 410696e..f26effd 100644 --- a/test/test.lua +++ b/test/test.lua @@ -1817,8 +1817,12 @@ function test.cross() sizes[crossdim] = 3 local x = torch.FloatTensor():randn(unpack(sizes)) local y = torch.FloatTensor():randn(unpack(sizes)) - compareFloatAndCudaTensorArgs(x, 'cross', y, crossdim) - checkMultiDevice(x, 'cross', y, crossdim) + for _, typename in ipairs(typenames) do + local x = x:type(t2cpu[typename]) + local y = y:type(t2cpu[typename]) + compareCPUAndCUDATypeTensorArgs(typename, nil, x, 'cross', y, crossdim) + checkMultiDevice(x, 'cross', y, crossdim) + end end end |