github.com/torch/cutorch.git
author    Trevor Killeen <killeentm@gmail.com>  2016-10-05 22:51:26 +0300
committer Trevor Killeen <killeentm@gmail.com>  2016-10-07 21:50:28 +0300
commit    74a97cf2b0eeaff5b19f79fd75493b316e340c9d (patch)
tree      495cf663d72b6069c7d21f0d974ebafd678c4100
parent    89330c02a1c9e13658156bd8941b5b7b48e3b71e (diff)
[cutorch refactor] move lerp(...) to generic
-rw-r--r--  TensorMath.lua                             |  7
-rw-r--r--  lib/THC/THCTensorMath2.cu                  | 24
-rw-r--r--  lib/THC/THCTensorMathPointwise.cuh         | 16
-rw-r--r--  lib/THC/generic/THCTensorMathPointwise.cu  | 15
-rw-r--r--  lib/THC/generic/THCTensorMathPointwise.h   |  2
-rw-r--r--  test/test.lua                              |  7
6 files changed, 45 insertions, 26 deletions
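
What moves here is elementwise linear interpolation, out = a + w * (b - a): the float-only THCudaTensor_lerp in THCTensorMath2.cu is deleted and an equivalent, type-generic version is added under lib/THC/generic so it can be generated for each supported element type. As a rough, standalone sketch of the underlying pointwise computation (not cutorch code; lerpKernel and the buffer handling below are invented for the example, whereas cutorch drives a functor through THC_pointwiseApply3):

#include <cstdio>
#include <cuda_runtime.h>

// out[i] = a[i] + w * (b[i] - a[i]) for every element.
__global__ void lerpKernel(float *out, const float *a, const float *b,
                           float w, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = a[i] + w * (b[i] - a[i]);
  }
}

int main() {
  const int n = 4;
  float ha[n] = {0.f, 1.f, 2.f, 3.f};
  float hb[n] = {4.f, 5.f, 6.f, 7.f};
  float hout[n];
  float *da, *db, *dout;
  cudaMalloc((void **)&da,   n * sizeof(float));
  cudaMalloc((void **)&db,   n * sizeof(float));
  cudaMalloc((void **)&dout, n * sizeof(float));
  cudaMemcpy(da, ha, n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(db, hb, n * sizeof(float), cudaMemcpyHostToDevice);
  lerpKernel<<<1, 32>>>(dout, da, db, 0.5f, n);
  cudaMemcpy(hout, dout, n * sizeof(float), cudaMemcpyDeviceToHost);
  for (int i = 0; i < n; i++) {
    // w = 0.5 yields the midpoint of each pair.
    std::printf("lerp(%g, %g, 0.5) = %g\n", ha[i], hb[i], hout[i]);
  }
  cudaFree(da);
  cudaFree(db);
  cudaFree(dout);
  return 0;
}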
diff --git a/TensorMath.lua b/TensorMath.lua
index 0e41835..1a06121 100644
--- a/TensorMath.lua
+++ b/TensorMath.lua
@@ -835,6 +835,13 @@ for k, Tensor_ in pairs(handledTypenames) do
{name="boolean", default=false}})
end
+ wrap("lerp",
+ cname("lerp"),
+ {{name=Tensor, default=true, returned=true, method={default='nil'}},
+ {name=Tensor, method={default=1}},
+ {name=Tensor},
+ {name=real}})
+
-- BLAS functions
wrap("mv",
cname("addmv"),
diff --git a/lib/THC/THCTensorMath2.cu b/lib/THC/THCTensorMath2.cu
index 64b6af3..c913b99 100644
--- a/lib/THC/THCTensorMath2.cu
+++ b/lib/THC/THCTensorMath2.cu
@@ -68,30 +68,6 @@ void THCudaTensor_atan2(THCState *state, THCudaTensor *self_, THCudaTensor *tx,
THCudaCheck(cudaGetLastError());
}
-struct TensorLerpOp {
- TensorLerpOp(float w) : w(w) {}
-
- __device__ __forceinline__ void operator()(float *out, float *a, float *b) {
- *out = *a + w * (*b - *a);
- }
-
- const float w;
-};
-
-void THCudaTensor_lerp(THCState *state, THCudaTensor *result, THCudaTensor *a, THCudaTensor *b, float w)
-{
- THAssert(THCudaTensor_checkGPU(state, 3, result, a, b));
- THArgCheck(THCudaTensor_nElement(state, a) ==
- THCudaTensor_nElement(state, b), 3, "sizes do not match");
- THCudaTensor_resizeAs(state, result, a);
-
- if (!THC_pointwiseApply3(state, result, a, b, TensorLerpOp(w))) {
- THArgCheck(false, 2, CUTORCH_DIM_WARNING);
- }
-
- THCudaCheck(cudaGetLastError());
-}
-
struct dist_functor
{
const float exponent;
diff --git a/lib/THC/THCTensorMathPointwise.cuh b/lib/THC/THCTensorMathPointwise.cuh
index a690c45..9560721 100644
--- a/lib/THC/THCTensorMathPointwise.cuh
+++ b/lib/THC/THCTensorMathPointwise.cuh
@@ -414,5 +414,21 @@ struct TensorClampOp {
const T maxValue;
};
+template <typename T>
+struct TensorLerpOp {
+ TensorLerpOp(T w) : w(w) {}
+
+ __device__ __forceinline__ void operator()(T *out, T *a, T *b) {
+ *out = THCNumerics<T>::add(
+ *a,
+ THCNumerics<T>::mul(
+ w,
+ THCNumerics<T>::sub(*b, *a)
+ )
+ );
+ }
+
+ const T w;
+};
#endif // THC_TENSORMATH_POINTWISE_CUH
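
The functor added above routes every arithmetic step through THCNumerics<T>::add/mul/sub instead of the raw +, *, and - operators, presumably so the same expression can instantiate for element types such as half that may not provide built-in operators. The following is a simplified, hypothetical analogue of that pattern (SimpleNumerics and LerpOp are made-up names for illustration, not THC's API):

#include <cstdio>

// Simplified stand-in for a numerics trait: a specialization (e.g. for half)
// could convert to float internally without touching the functor below.
template <typename T>
struct SimpleNumerics {
  static __host__ __device__ T add(T a, T b) { return a + b; }
  static __host__ __device__ T sub(T a, T b) { return a - b; }
  static __host__ __device__ T mul(T a, T b) { return a * b; }
};

// Same shape as the templated lerp functor: out = a + w * (b - a).
template <typename T>
struct LerpOp {
  explicit LerpOp(T w) : w(w) {}
  __host__ __device__ T operator()(T a, T b) const {
    return SimpleNumerics<T>::add(a,
        SimpleNumerics<T>::mul(w, SimpleNumerics<T>::sub(b, a)));
  }
  const T w;
};

int main() {
  LerpOp<float>  f(0.25f);
  LerpOp<double> d(0.25);
  std::printf("%g %g\n", f(1.0f, 5.0f), d(1.0, 5.0));  // both print 2
  return 0;
}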
diff --git a/lib/THC/generic/THCTensorMathPointwise.cu b/lib/THC/generic/THCTensorMathPointwise.cu
index 707fc93..79180cd 100644
--- a/lib/THC/generic/THCTensorMathPointwise.cu
+++ b/lib/THC/generic/THCTensorMathPointwise.cu
@@ -137,6 +137,21 @@ void THCTensor_(pow)(THCState *state, THCTensor *self_, THCTensor *src, real val
THCudaCheck(cudaGetLastError());
}
+THC_API void
+THCTensor_(lerp)(THCState *state, THCTensor *result, THCTensor *a, THCTensor *b, real w)
+{
+ THAssert(THCTensor_(checkGPU)(state, 3, result, a, b));
+ THArgCheck(THCTensor_(nElement)(state, a) ==
+ THCTensor_(nElement)(state, b), 3, "sizes do not match");
+ THCTensor_(resizeAs)(state, result, a);
+
+ if (!THC_pointwiseApply3(state, result, a, b, TensorLerpOp<real>(w))) {
+ THArgCheck(false, 2, CUTORCH_DIM_WARNING);
+ }
+
+ THCudaCheck(cudaGetLastError());
+}
+
#endif
THC_API void
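
In the generic source above, THCTensor_(lerp) and real are placeholders that THC's generic-inclusion machinery expands once per element type, which is what lets this single definition serve CudaTensor, CudaDoubleTensor, and the other tensor types. The snippet below is a deliberately simplified, hypothetical illustration of that token-pasting idea (DEFINE_LERP and the My*Tensor_lerp names are invented for the example; THC itself re-includes the generic file once per type rather than using a single macro like this):

#include <cstdio>

/* One body, stamped out once per element type via token pasting. */
#define DEFINE_LERP(Real, real)                                      \
  static void My##Real##Tensor_lerp(real *out, const real *a,       \
                                    const real *b, real w, int n) { \
    for (int i = 0; i < n; i++) {                                    \
      out[i] = a[i] + w * (b[i] - a[i]);                             \
    }                                                                \
  }

DEFINE_LERP(Float, float)   /* defines MyFloatTensor_lerp  */
DEFINE_LERP(Double, double) /* defines MyDoubleTensor_lerp */

int main() {
  float  af[2] = {0.f, 2.f}, bf[2] = {4.f, 6.f}, of[2];
  double ad[2] = {0.0, 2.0}, bd[2] = {4.0, 6.0}, od[2];
  MyFloatTensor_lerp(of, af, bf, 0.25f, 2);
  MyDoubleTensor_lerp(od, ad, bd, 0.25, 2);
  std::printf("%g %g | %g %g\n", of[0], of[1], od[0], od[1]);  /* 1 3 | 1 3 */
  return 0;
}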
diff --git a/lib/THC/generic/THCTensorMathPointwise.h b/lib/THC/generic/THCTensorMathPointwise.h
index 0c03045..12c420b 100644
--- a/lib/THC/generic/THCTensorMathPointwise.h
+++ b/lib/THC/generic/THCTensorMathPointwise.h
@@ -27,7 +27,7 @@ THC_API void THCTensor_(floor)(THCState *state, THCTensor *self, THCTensor *src)
THC_API void THCTensor_(round)(THCState *state, THCTensor *self, THCTensor *src);
THC_API void THCTensor_(trunc)(THCState *state, THCTensor *self, THCTensor *src);
THC_API void THCTensor_(frac)(THCState *state, THCTensor *self, THCTensor *src);
-THC_API void THCTensor_(lerp)(THCState *state, THCTensor *result, THCTensor *a, THCTensor *b, float w);
+THC_API void THCTensor_(lerp)(THCState *state, THCTensor *result, THCTensor *a, THCTensor *b, real w);
THC_API void THCTensor_(neg)(THCState *state, THCTensor *self, THCTensor *src);
THC_API void THCTensor_(cinv)(THCState *state, THCTensor *self, THCTensor *src);
diff --git a/test/test.lua b/test/test.lua
index 2a88268..410696e 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -1329,7 +1329,12 @@ function test.lerp()
local y = torch.FloatTensor():rand(sz1, sz2)
local w = math.random()
local z = torch.FloatTensor()
- compareFloatAndCudaTensorArgs(z, 'lerp', x, y, w)
+ for _, typename in ipairs(float_typenames) do
+ local x = x:type(t2cpu[typename])
+ local y = y:type(t2cpu[typename])
+ local z = z:type(t2cpu[typename])
+ compareCPUAndCUDATypeTensorArgs(typename, nil, z, 'lerp', x, y, w)
+ end
checkMultiDevice(z, 'lerp', x, y, w)
end