Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Trevor Killeen <killeentm@gmail.com>, 2016-11-16 19:35:17 +0300
committer: Trevor Killeen <killeentm@gmail.com>, 2016-11-16 19:35:17 +0300
commit: 5652762d1138ee848a54bce6da67ed139576c876 (patch)
tree: 85ae643ed0582b363c46cbdf7dbf0f191da0ea14
parent: c65c1da9e61e9004b121b15da4aeb54f1b4513ce (diff)
add support for fmod in cutorch
-rw-r--r-- TensorMath.lua 12
-rw-r--r-- lib/THC/THCTensorMathPairwise.cu 45
-rw-r--r-- lib/THC/generic/THCTensorMathPairwise.cu 19
-rw-r--r-- lib/THC/generic/THCTensorMathPairwise.h 1
-rw-r--r-- test/test.lua 18
5 files changed, 94 insertions, 1 deletions
diff --git a/TensorMath.lua b/TensorMath.lua
index d4471c4..f47c6cc 100644
--- a/TensorMath.lua
+++ b/TensorMath.lua
@@ -661,6 +661,12 @@ for k, Tensor_ in pairs(handledTypenames) do
{name=Tensor, method={default=1}},
{name=real}})
+ wrap("fmod",
+ cname("fmod"),
+ {{name=Tensor, default=true, returned=true, method={default='nil'}},
+ {name=Tensor, method={default=1}},
+ {name=real}})
+
wrap("remainder",
cname("remainder"),
{{name=Tensor, default=true, returned=true, method={default='nil'}},
@@ -1312,6 +1318,12 @@ wrap("div",
{name=Tensor, method={default=1}},
{name=real}})
+wrap("fmod",
+ cname("fmod"),
+ {{name=Tensor, default=true, returned=true, method={default='nil'}},
+ {name=Tensor, method={default=1}},
+ {name=real}})
+
wrap("remainder",
cname("remainder"),
{{name=Tensor, default=true, returned=true, method={default='nil'}},
diff --git a/lib/THC/THCTensorMathPairwise.cu b/lib/THC/THCTensorMathPairwise.cu
index 097533f..337e59f 100644
--- a/lib/THC/THCTensorMathPairwise.cu
+++ b/lib/THC/THCTensorMathPairwise.cu
@@ -318,6 +318,51 @@ struct TensorRemainderOp<half> {
};
#endif // CUDA_HALF_TENSOR
+template <typename T>
+struct TensorFmodOp {
+ TensorFmodOp(T v) : val((float)v) {}
+ __device__ __forceinline__ void operator()(T* out, T* in) {
+ *out = (T) fmodf((float) *in, val);
+ }
+
+ __device__ __forceinline__ void operator()(T* v) {
+ *v = (T) fmodf((float) *v, val);
+ }
+
+ const float val;
+};
+
+template <>
+struct TensorFmodOp<double> {
+ TensorFmodOp(double v) : val(v) {}
+ __device__ __forceinline__ void operator()(double* out, double* in) {
+ *out = fmod(*in, val);
+ }
+
+ __device__ __forceinline__ void operator()(double* v) {
+ *v = fmod(*v, val);
+ }
+
+ const double val;
+};
+
+#ifdef CUDA_HALF_TENSOR
+template <>
+struct TensorFmodOp<half> {
+ TensorFmodOp(half v): fval(THC_half2float(v)) {}
+
+ __device__ __forceinline__ void operator()(half* out, half* in) {
+ *out = __float2half(fmodf(__half2float(*in), fval));
+ }
+
+ __device__ __forceinline__ void operator()(half* v) {
+ *v = __float2half(fmodf(__half2float(*v), fval));
+ }
+
+ const float fval;
+};
+#endif // CUDA_HALF_TENSOR
+
template <typename T, int Upper>
struct TensorTriOp {
TensorTriOp(T *start_, long stride0_, long stride1_, long k_)
diff --git a/lib/THC/generic/THCTensorMathPairwise.cu b/lib/THC/generic/THCTensorMathPairwise.cu
index b28bfe6..119d333 100644
--- a/lib/THC/generic/THCTensorMathPairwise.cu
+++ b/lib/THC/generic/THCTensorMathPairwise.cu
@@ -81,6 +81,25 @@ THCTensor_(div)(THCState* state, THCTensor *self_, THCTensor *src_, real value)
}
THC_API void
+THCTensor_(fmod)(THCState *state, THCTensor *self_, THCTensor *src_, real value)
+{
+ THAssert(THCTensor_(checkGPU)(state, 2, self_, src_));
+ if (self_ == src_) {
+ if (!THC_pointwiseApply1(state, self_, TensorFmodOp<real>(value))) {
+ THArgCheck(false, 2, CUTORCH_DIM_WARNING);
+ }
+ } else {
+ THCTensor_(resizeAs)(state, self_, src_);
+
+ if (!THC_pointwiseApply2(state, self_, src_, TensorFmodOp<real>(value))) {
+ THArgCheck(false, 2, CUTORCH_DIM_WARNING);
+ }
+ }
+
+ THCudaCheck(cudaGetLastError());
+}
+
+THC_API void
THCTensor_(remainder)(THCState *state, THCTensor *self_, THCTensor *src_, real value)
{
THAssert(THCTensor_(checkGPU)(state, 2, self_, src_));
diff --git a/lib/THC/generic/THCTensorMathPairwise.h b/lib/THC/generic/THCTensorMathPairwise.h
index ced2315..75df6cf 100644
--- a/lib/THC/generic/THCTensorMathPairwise.h
+++ b/lib/THC/generic/THCTensorMathPairwise.h
@@ -6,6 +6,7 @@ THC_API void THCTensor_(add)(THCState *state, THCTensor *self, THCTensor *src, r
THC_API void THCTensor_(sub)(THCState *state, THCTensor *self, THCTensor *src, real value);
THC_API void THCTensor_(mul)(THCState *state, THCTensor *self, THCTensor *src, real value);
THC_API void THCTensor_(div)(THCState *state, THCTensor *self, THCTensor *src, real value);
+THC_API void THCTensor_(fmod)(THCState *state, THCTensor *self, THCTensor *src, real value);
THC_API void THCTensor_(remainder)(THCState *state, THCTensor *self, THCTensor *src, real value);
#endif
diff --git a/test/test.lua b/test/test.lua
index 8d36097..1ecd138 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -383,7 +383,6 @@ local function compareCPUAndCUDATypeTensorArgsWithConv(cudaType, gpu2cpu_map, in
string.format("number of return arguments for CPU and CUDA "
.. "are different for function '%s'", tostring(fn)))
for k, _ in ipairs(rcpu) do
- print(rcpu[k], rcuda[k])
tester:assert(isEqual(rcpu[k], rcuda[k], tolerance),
string.format(errstrval, k, divval(rcpu[k], rcuda[k])))
end
@@ -1002,6 +1001,23 @@ function test.addcdiv()
checkMultiDevice(r, 'addcdiv', x, torch.uniform(), y, z)
end
+function test.fmod()
+ local sz1 = chooseInt(minsize, maxsize)
+ local sz2 = chooseInt(minsize, maxsize)
+ local x = torch.FloatTensor():randn(sz1, sz2)
+ x:apply(function(x)
+ x = x * torch.random(1, 100)
+ return x
+ end)
+ local r = torch.normal(0, 25)
+ print(x, r)
+
+ for _, typename in ipairs(typenames) do
+ local x = x:type(t2cpu[typename])
+ compareCPUAndCUDATypeTensorArgs(typename, nil, x, 'fmod', r)
+ end
+end
+
function test.remainder()
local sz1 = chooseInt(minsize, maxsize)
local sz2 = chooseInt(minsize, maxsize)