diff options
author | Trevor Killeen <killeentm@gmail.com> | 2016-10-05 22:51:26 +0300 |
---|---|---|
committer | Trevor Killeen <killeentm@gmail.com> | 2016-10-07 21:50:28 +0300 |
commit | 74a97cf2b0eeaff5b19f79fd75493b316e340c9d (patch) | |
tree | 495cf663d72b6069c7d21f0d974ebafd678c4100 | |
parent | 89330c02a1c9e13658156bd8941b5b7b48e3b71e (diff) |
[cutorch refactor] move lerp(...) to generic
-rw-r--r-- | TensorMath.lua | 7 | ||||
-rw-r--r-- | lib/THC/THCTensorMath2.cu | 24 | ||||
-rw-r--r-- | lib/THC/THCTensorMathPointwise.cuh | 16 | ||||
-rw-r--r-- | lib/THC/generic/THCTensorMathPointwise.cu | 15 | ||||
-rw-r--r-- | lib/THC/generic/THCTensorMathPointwise.h | 2 | ||||
-rw-r--r-- | test/test.lua | 7 |
6 files changed, 45 insertions, 26 deletions
diff --git a/TensorMath.lua b/TensorMath.lua index 0e41835..1a06121 100644 --- a/TensorMath.lua +++ b/TensorMath.lua @@ -835,6 +835,13 @@ for k, Tensor_ in pairs(handledTypenames) do {name="boolean", default=false}}) end + wrap("lerp", + cname("lerp"), + {{name=Tensor, default=true, returned=true, method={default='nil'}}, + {name=Tensor, method={default=1}}, + {name=Tensor}, + {name=real}}) + -- BLAS functions wrap("mv", cname("addmv"), diff --git a/lib/THC/THCTensorMath2.cu b/lib/THC/THCTensorMath2.cu index 64b6af3..c913b99 100644 --- a/lib/THC/THCTensorMath2.cu +++ b/lib/THC/THCTensorMath2.cu @@ -68,30 +68,6 @@ void THCudaTensor_atan2(THCState *state, THCudaTensor *self_, THCudaTensor *tx, THCudaCheck(cudaGetLastError()); } -struct TensorLerpOp { - TensorLerpOp(float w) : w(w) {} - - __device__ __forceinline__ void operator()(float *out, float *a, float *b) { - *out = *a + w * (*b - *a); - } - - const float w; -}; - -void THCudaTensor_lerp(THCState *state, THCudaTensor *result, THCudaTensor *a, THCudaTensor *b, float w) -{ - THAssert(THCudaTensor_checkGPU(state, 3, result, a, b)); - THArgCheck(THCudaTensor_nElement(state, a) == - THCudaTensor_nElement(state, b), 3, "sizes do not match"); - THCudaTensor_resizeAs(state, result, a); - - if (!THC_pointwiseApply3(state, result, a, b, TensorLerpOp(w))) { - THArgCheck(false, 2, CUTORCH_DIM_WARNING); - } - - THCudaCheck(cudaGetLastError()); -} - struct dist_functor { const float exponent; diff --git a/lib/THC/THCTensorMathPointwise.cuh b/lib/THC/THCTensorMathPointwise.cuh index a690c45..9560721 100644 --- a/lib/THC/THCTensorMathPointwise.cuh +++ b/lib/THC/THCTensorMathPointwise.cuh @@ -414,5 +414,21 @@ struct TensorClampOp { const T maxValue; }; +template <typename T> +struct TensorLerpOp { + TensorLerpOp(T w) : w(w) {} + + __device__ __forceinline__ void operator()(T *out, T *a, T *b) { + *out = THCNumerics<T>::add( + *a, + THCNumerics<T>::mul( + w, + THCNumerics<T>::sub(*b, *a) + ) + ); + } + + const T w; +}; #endif // THC_TENSORMATH_POINTWISE_CUH diff --git a/lib/THC/generic/THCTensorMathPointwise.cu b/lib/THC/generic/THCTensorMathPointwise.cu index 707fc93..79180cd 100644 --- a/lib/THC/generic/THCTensorMathPointwise.cu +++ b/lib/THC/generic/THCTensorMathPointwise.cu @@ -137,6 +137,21 @@ void THCTensor_(pow)(THCState *state, THCTensor *self_, THCTensor *src, real val THCudaCheck(cudaGetLastError()); } +THC_API void +THCTensor_(lerp)(THCState *state, THCTensor *result, THCTensor *a, THCTensor *b, real w) +{ + THAssert(THCTensor_(checkGPU)(state, 3, result, a, b)); + THArgCheck(THCTensor_(nElement)(state, a) == + THCTensor_(nElement)(state, b), 3, "sizes do not match"); + THCTensor_(resizeAs)(state, result, a); + + if (!THC_pointwiseApply3(state, result, a, b, TensorLerpOp<real>(w))) { + THArgCheck(false, 2, CUTORCH_DIM_WARNING); + } + + THCudaCheck(cudaGetLastError()); +} + #endif THC_API void diff --git a/lib/THC/generic/THCTensorMathPointwise.h b/lib/THC/generic/THCTensorMathPointwise.h index 0c03045..12c420b 100644 --- a/lib/THC/generic/THCTensorMathPointwise.h +++ b/lib/THC/generic/THCTensorMathPointwise.h @@ -27,7 +27,7 @@ THC_API void THCTensor_(floor)(THCState *state, THCTensor *self, THCTensor *src) THC_API void THCTensor_(round)(THCState *state, THCTensor *self, THCTensor *src); THC_API void THCTensor_(trunc)(THCState *state, THCTensor *self, THCTensor *src); THC_API void THCTensor_(frac)(THCState *state, THCTensor *self, THCTensor *src); -THC_API void THCTensor_(lerp)(THCState *state, THCTensor *result, THCTensor *a, THCTensor *b, float w); +THC_API void THCTensor_(lerp)(THCState *state, THCTensor *result, THCTensor *a, THCTensor *b, real w); THC_API void THCTensor_(neg)(THCState *state, THCTensor *self, THCTensor *src); THC_API void THCTensor_(cinv)(THCState *state, THCTensor *self, THCTensor *src); diff --git a/test/test.lua b/test/test.lua index 2a88268..410696e 100644 --- a/test/test.lua +++ b/test/test.lua @@ -1329,7 +1329,12 @@ function test.lerp() local y = torch.FloatTensor():rand(sz1, sz2) local w = math.random() local z = torch.FloatTensor() - compareFloatAndCudaTensorArgs(z, 'lerp', x, y, w) + for _, typename in ipairs(float_typenames) do + local x = x:type(t2cpu[typename]) + local y = y:type(t2cpu[typename]) + local z = z:type(t2cpu[typename]) + compareCPUAndCUDATypeTensorArgs(typename, nil, z, 'lerp', x, y, w) + end checkMultiDevice(z, 'lerp', x, y, w) end |