Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summary refs log tree commit diff
diff options: context, space, mode
authorSoumith Chintala <soumith@gmail.com>2016-12-20 17:35:41 +0300
committerGitHub <noreply@github.com>2016-12-20 17:35:41 +0300
commitc2594f8186366b1c330f1ac2f04cbeb85b4aefb6 (patch)
tree8c4c74ba161e631dd04d4320127a622e98e42333
parent0814f81aefa0003a3584ab18c1659cb01b886012 (diff)
Revert "Add support for cremainder, cfmod"revert-641-cfuncs
-rw-r--r--TensorMath.lua4
-rw-r--r--lib/THC/THCTensorMathPointwise.cuh104
-rw-r--r--lib/THC/generic/THCTensorMathPointwise.cu42
-rw-r--r--lib/THC/generic/THCTensorMathPointwise.h2
-rw-r--r--test/test.lua78
5 files changed, 6 insertions, 224 deletions
diff --git a/TensorMath.lua b/TensorMath.lua
index 3072ea2..91d97dd 100644
--- a/TensorMath.lua
+++ b/TensorMath.lua
@@ -679,7 +679,7 @@ for k, Tensor_ in pairs(handledTypenames) do
{name=Tensor},
{name="boolean", creturned=true}})
- for _, name in ipairs({"cmul", "cpow", "cdiv", "cremainder", "cfmod"}) do
+ for _, name in ipairs({"cmul", "cpow", "cdiv"}) do
wrap(name,
cname(name),
{{name=Tensor, default=true, returned=true, method={default='nil'}},
@@ -1457,7 +1457,7 @@ wrap("equal",
{name=Tensor},
{name="boolean", creturned=true}})
-for _, name in ipairs({"cmul", "cpow", "cdiv", "cremainder", "cfmod"}) do
+for _, name in ipairs({"cmul", "cpow", "cdiv"}) do
wrap(name,
cname(name),
{{name=Tensor, default=true, returned=true, method={default='nil'}},
diff --git a/lib/THC/THCTensorMathPointwise.cuh b/lib/THC/THCTensorMathPointwise.cuh
index a1a8304..40d35be 100644
--- a/lib/THC/THCTensorMathPointwise.cuh
+++ b/lib/THC/THCTensorMathPointwise.cuh
@@ -413,110 +413,6 @@ struct TensorDivOp<half> {
#endif // CUDA_HALF_TENSOR
template <typename T>
-struct TensorCRemainderOp {
- __device__ __forceinline__ void operator()(T* out, T* in) {
- *out = *in != 0 ? *out - *in * (*out / *in) : NAN;
- }
-
- __device__ __forceinline__ void operator()(T* out, T* in1, T* in2) {
- *out = *in2 != 0 ? *in1 - *in2 * (*in1 / *in2) : NAN;
- }
-};
-
-template <>
-struct TensorCRemainderOp<float> {
- __device__ __forceinline__ void operator()(float* out, float* in) {
- *out = *in != 0 ? *out - *in * floorf(*out / *in) : NAN;
- }
-
- __device__ __forceinline__ void operator()(float* out, float* in1, float* in2) {
- *out = *in2 != 0 ? *in1 - *in2 * floorf(*in1 / *in2) : NAN;
- }
-};
-
-template <>
-struct TensorCRemainderOp<double> {
- __device__ __forceinline__ void operator()(double* out, double* in) {
- *out = *in != 0 ? *out - *in * floor(*out / *in) : NAN;
- }
-
- __device__ __forceinline__ void operator()(double* out, double* in1, double* in2) {
- *out = *in2 != 0 ? *in1 - *in2 * floor(*in1 / *in2) : NAN;
- }
-};
-
-#ifdef CUDA_HALF_TENSOR
-template <>
-struct TensorCRemainderOp<half> {
- __device__ __forceinline__ void operator()(half* out, half* in) {
-#ifdef CUDA_HALF_INSTRUCTIONS
- *out = __hsub(*out, __hmul(*in, hfloor(__hdiv(*out, *in))));
-#else
- float fout = __half2float(*out);
- float fin = __half2float(*in);
- *out = fin != 0 ? __float2half(fout - fin * floor(fout / fin)) : NAN;
-#endif
- }
-
- __device__ __forceinline__ void operator()(half* out, half* in1, half* in2) {
-#ifdef CUDA_HALF_INSTRUCTIONS
- *out = __hsub(*in1, __hmul(*in2, hfloor(__hdiv(*in1, *in2))));
-#else
- float fin1 = __half2float(*in1);
- float fin2 = __half2float(*in2);
- *out = fin2 != 0 ? __float2half(fin1 - fin2 * floor(fin1 / fin2)) : NAN;
-#endif
- }
-};
-#endif // CUDA_HALF_TENSOR
-
-template <typename T>
-struct TensorCFmodOp {
- __device__ __forceinline__ void operator()(T* out, T* in) {
- *out = *out % *in;
- }
-
- __device__ __forceinline__ void operator()(T* out, T* in1, T* in2) {
- *out = *in1 % *in2;
- }
-};
-
-template <>
-struct TensorCFmodOp<float> {
- __device__ __forceinline__ void operator()(float* out, float* in) {
- *out = fmodf(*out, *in);
- }
-
- __device__ __forceinline__ void operator()(float* out, float* in1, float* in2) {
- *out = fmodf(*in1, *in2);
- }
-};
-
-template <>
-struct TensorCFmodOp<double> {
- __device__ __forceinline__ void operator()(double* out, double* in) {
- *out = fmod(*out, *in);
- }
-
- __device__ __forceinline__ void operator()(double* out, double* in1, double* in2) {
- *out = fmod(*in1, *in2);
- }
-};
-
-#ifdef CUDA_HALF_TENSOR
-template <>
-struct TensorCFmodOp<half> {
- __device__ __forceinline__ void operator()(half* out, double* in) {
- *out = __float2half(fmod(__half2float(*out), __half2float(*in)));
- }
-
- __device__ __forceinline__ void operator()(double* out, double* in1, double* in2) {
- *out = __float2half(fmod(__half2float(*in1), __half2float(*in2)));
- }
-};
-#endif // CUDA_HALF_TENSOR
-
-template <typename T>
struct TensorClampOp {
TensorClampOp(T min, T max) : minValue(min), maxValue(max) {}
__device__ __forceinline__ void operator()(T* out, T* in) {
diff --git a/lib/THC/generic/THCTensorMathPointwise.cu b/lib/THC/generic/THCTensorMathPointwise.cu
index b97908a..91c166f 100644
--- a/lib/THC/generic/THCTensorMathPointwise.cu
+++ b/lib/THC/generic/THCTensorMathPointwise.cu
@@ -344,14 +344,14 @@ THCTensor_(cdiv)(THCState* state, THCTensor *self_, THCTensor *src1, THCTensor *
THCTensor_(nElement)(state, src2), 3, "sizes do not match");
if (self_ == src1) {
- // self /= src2
+ // self *= src2
if (!THC_pointwiseApply2(state, self_, src2, TensorDivOp<real>())) {
THArgCheck(false, 2, CUTORCH_DIM_WARNING);
}
} else {
THCTensor_(resizeAs)(state, self_, src1);
- // self = src1 / src2
+ // self = src1 * src2
if (!THC_pointwiseApply3(state, self_, src1, src2, TensorDivOp<real>())) {
THArgCheck(false, 2, CUTORCH_DIM_WARNING);
}
@@ -399,44 +399,6 @@ THCTensor_(cmin)(THCState *state, THCTensor *self, THCTensor *src1, THCTensor *s
}
THC_API void
-THCTensor_(cremainder)(THCState *state, THCTensor *self, THCTensor *src1, THCTensor *src2)
-{
- THAssert(THCTensor_(checkGPU)(state, 3, self, src1, src2));
- THArgCheck(THCTensor_(nElement)(state, src1) ==
- THCTensor_(nElement)(state, src2), 2, "sizes do not match");
-
- if (self == src1) {
- if (!THC_pointwiseApply2(state, self, src2, TensorCRemainderOp<real>())) {
- THArgCheck(false, 2, CUTORCH_DIM_WARNING);
- }
- } else {
- THCTensor_(resizeAs)(state, self, src1);
- if (!THC_pointwiseApply3(state, self, src1, src2, TensorCRemainderOp<real>())) {
- THArgCheck(false, 2, CUTORCH_DIM_WARNING);
- }
- }
-}
-
-THC_API void
-THCTensor_(cfmod)(THCState *state, THCTensor *self, THCTensor *src1, THCTensor *src2)
-{
- THAssert(THCTensor_(checkGPU)(state, 3, self, src1, src2));
- THArgCheck(THCTensor_(nElement)(state, src1) ==
- THCTensor_(nElement)(state, src2), 2, "sizes do not match");
-
- if (self == src1) {
- if (!THC_pointwiseApply2(state, self, src2, TensorCFmodOp<real>())) {
- THArgCheck(false, 2, CUTORCH_DIM_WARNING);
- }
- } else {
- THCTensor_(resizeAs)(state, self, src1);
- if (!THC_pointwiseApply3(state, self, src1, src2, TensorCFmodOp<real>())) {
- THArgCheck(false, 2, CUTORCH_DIM_WARNING);
- }
- }
-}
-
-THC_API void
THCTensor_(cmaxValue)(THCState *state, THCTensor *self, THCTensor *src, real value)
{
THAssert(THCTensor_(checkGPU)(state, 2, self, src));
diff --git a/lib/THC/generic/THCTensorMathPointwise.h b/lib/THC/generic/THCTensorMathPointwise.h
index 34e594a..6e20a30 100644
--- a/lib/THC/generic/THCTensorMathPointwise.h
+++ b/lib/THC/generic/THCTensorMathPointwise.h
@@ -46,8 +46,6 @@ THC_API void THCTensor_(cpow)(THCState *state, THCTensor *self, THCTensor *src1,
THC_API void THCTensor_(cdiv)(THCState *state, THCTensor *self, THCTensor *src1, THCTensor *src2);
THC_API void THCTensor_(cmax)(THCState *state, THCTensor *self, THCTensor *src1, THCTensor *src2);
THC_API void THCTensor_(cmin)(THCState *state, THCTensor *self, THCTensor *src1, THCTensor *src2);
-THC_API void THCTensor_(cfmod)(THCState *state, THCTensor *self, THCTensor *src1, THCTensor *src2);
-THC_API void THCTensor_(cremainder)(THCState *state, THCTensor *self, THCTensor *src1, THCTensor *src2);
THC_API void THCTensor_(cmaxValue)(THCState *state, THCTensor *self, THCTensor *src, real value);
THC_API void THCTensor_(cminValue)(THCState *state, THCTensor *self, THCTensor *src, real value);
diff --git a/test/test.lua b/test/test.lua
index 51e71ef..10152e1 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -162,48 +162,22 @@ local function createTestTensor(maxSize)
return createTestTensorMaxSize(holes, tr, maxSize)
end
-local function isEqual(x, y, tolerance, ...)
+local function isEqual(a, b, tolerance, ...)
if a == nil and b == nil then return true end
if a == nil and b ~= nil then return false end
if a ~= nil and b == nil then return false end
-
- -- clone the tensors so we can modify the contents if necessary for testing
- local a = x:clone()
- local b = y:clone()
-
if torch.type(b) ~= torch.type(a) then
b = b:typeAs(a) -- TODO: remove the need for this (a-b doesnt work for bytetensor, cudatensor pairs)
end
local diff = a-b
tolerance = tolerance or 0.000001
-
if type(a) == 'number' then
- -- NaN Check:
- if a ~= a and b ~= b then
- return true
- end
return math.abs(diff) < tolerance
else
if torch.type(diff) ~= 'torch.FloatTensor' then
diff = diff:float() -- TODO: remove the need for this (byteTensor and abs)
end
- -- NaN Check:
- local hasNaN = false
- diff:apply(function(elt) if elt ~= elt then hasNaN = true end end)
- if hasNaN then
- -- check if NaN in equal positions
- local nea = torch.ne(a, a)
- local neb = torch.ne(b, b)
- if not nea:equal(neb) then
- return false
- end
- -- check diff of all other elements less than tolerance
- local ea = a:apply(function(elt) if elt ~= elt then return 0 else return elt end end)
- local eb = b:apply(function(elt) if elt ~= elt then return 0 else return elt end end)
- return (ea-eb):abs():max() < tolerance
- else
- return diff:abs():max() < tolerance
- end
+ return diff:abs():max() < tolerance
end
end
@@ -361,7 +335,6 @@ local function compareCPUAndCUDATypeTensorArgsWithConv(cudaType, gpu2cpu_map, in
assert(baseType, 'Cannot find baseType for ' .. cudaType)
local x_cpu = x:type(baseType)
local x_cuda = cloneExactlyToGPUType(x_cpu, nil, gpu2cpu_map)
- -- print('x_cpu_initial', x_cpu, 'x_cuda_initial', x_cuda)
local rcpu = {}
local rcuda = {}
@@ -378,7 +351,6 @@ local function compareCPUAndCUDATypeTensorArgsWithConv(cudaType, gpu2cpu_map, in
end
return t
end
-
local cpu_args = {...}
local cuda_args = tranform_args({...})
if type(fn) == 'string' then
@@ -924,52 +896,6 @@ function test.cpow()
checkMultiDevice(x, 'cpow', y)
end
-function test.cremainder()
- local sz1 = chooseInt(minsize, maxsize)
- local sz2 = chooseInt(minsize, maxsize)
- local x = torch.FloatTensor(sz1, sz2):uniform(-50, 50)
- local y = torch.FloatTensor(sz1, sz2):uniform(-50, 50)
- for k, typename in ipairs(typenames) do
- local ctype = t2cpu[typename]
- local a, b = x:type(ctype), y:type(ctype)
- compareCPUAndCUDATypeTensorArgs(typename, nil, a, 'cremainder', b)
- end
- checkMultiDevice(x, 'cremainder', y)
-
- -- ensure we test divide by zero
- local x = torch.FloatTensor(1):fill(1)
- local y = torch.FloatTensor(1):zero()
- for k, typename in ipairs(typenames) do
- local ctype = t2cpu[typename]
- local a, b = x:type(ctype), y:type(ctype)
- compareCPUAndCUDATypeTensorArgs(typename, nil, a, 'cremainder', b)
- end
- checkMultiDevice(x, 'cremainder', y)
-end
-
-function test.cfmod()
- local sz1 = chooseInt(minsize, maxsize)
- local sz2 = chooseInt(minsize, maxsize)
- local x = torch.FloatTensor(sz1, sz2):uniform(-50, 50)
- local y = torch.FloatTensor(sz1, sz2):uniform(-50, 50)
- for k, typename in ipairs(typenames) do
- local ctype = t2cpu[typename]
- local a, b = x:type(ctype), y:type(ctype)
- compareCPUAndCUDATypeTensorArgs(typename, nil, a, 'cfmod', b)
- end
- checkMultiDevice(x, 'cfmod', y)
-
- -- ensure we test mod by zero
- local x = torch.FloatTensor(1):fill(1)
- local y = torch.FloatTensor(1):zero()
- for k, typename in ipairs(typenames) do
- local ctype = t2cpu[typename]
- local a, b = x:type(ctype), y:type(ctype)
- compareCPUAndCUDATypeTensorArgs(typename, nil, a, 'cfmod', b)
- end
- checkMultiDevice(x, 'cfmod', y)
-end
-
function test.nonzero()
local minsize = 10
local maxsize = 20