Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cunn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2016-06-11 17:59:47 +0300
committerGitHub <noreply@github.com>2016-06-11 17:59:47 +0300
commit53123134291b21a55215996fd4317e6d31b248cb (patch)
treea341040e8a1c5d09c02d6a7b32d8173d2badeae6
parentbc1080ebd10ed8319df0de0a0a23b0ff62994c0a (diff)
parentb01894b3a9fb02723fbf3584dff01464c58e73de (diff)
Merge pull request #296 from torch/fixes
fixes for cutorch API changes
-rw-r--r--lib/THCUNN/Abs.cu4
-rw-r--r--lib/THCUNN/ELU.cu8
-rw-r--r--lib/THCUNN/HardTanh.cu4
-rw-r--r--lib/THCUNN/LeakyReLU.cu8
-rw-r--r--lib/THCUNN/LogSigmoid.cu4
-rw-r--r--lib/THCUNN/PReLU.cu10
-rw-r--r--lib/THCUNN/RReLU.cu8
-rw-r--r--lib/THCUNN/Sigmoid.cu4
-rw-r--r--lib/THCUNN/SoftPlus.cu4
-rw-r--r--lib/THCUNN/SoftShrink.cu4
-rw-r--r--lib/THCUNN/SpatialReflectionPadding.cu6
-rw-r--r--lib/THCUNN/SpatialReplicationPadding.cu6
-rw-r--r--lib/THCUNN/Sqrt.cu4
-rw-r--r--lib/THCUNN/Square.cu4
-rw-r--r--lib/THCUNN/Tanh.cu4
-rw-r--r--lib/THCUNN/Threshold.cu8
16 files changed, 45 insertions, 45 deletions
diff --git a/lib/THCUNN/Abs.cu b/lib/THCUNN/Abs.cu
index 3983251..81b3297 100644
--- a/lib/THCUNN/Abs.cu
+++ b/lib/THCUNN/Abs.cu
@@ -13,7 +13,7 @@ void THNN_CudaAbs_updateOutput(THCState *state, THCudaTensor *input, THCudaTenso
{
THCUNN_assertSameGPU(state, 2, input, output);
THCudaTensor_resizeAs(state, output, input);
- THCudaTensor_pointwiseApply2(state, output, input, absupdateOutput_functor());
+ THC_pointwiseApply2(state, output, input, absupdateOutput_functor());
}
struct absupdateGradInput_functor
@@ -28,5 +28,5 @@ void THNN_CudaAbs_updateGradInput(THCState *state, THCudaTensor *input, THCudaTe
{
THCUNN_assertSameGPU(state, 3, input, gradOutput, gradInput);
THCudaTensor_resizeAs(state, gradInput, input);
- THCudaTensor_pointwiseApply3(state, gradInput, input, gradOutput, absupdateGradInput_functor());
+ THC_pointwiseApply3(state, gradInput, input, gradOutput, absupdateGradInput_functor());
}
diff --git a/lib/THCUNN/ELU.cu b/lib/THCUNN/ELU.cu
index f58262e..e4a05bc 100644
--- a/lib/THCUNN/ELU.cu
+++ b/lib/THCUNN/ELU.cu
@@ -37,13 +37,13 @@ void THNN_CudaELU_updateOutput(THCState *state, THCudaTensor *input, THCudaTenso
if (inplace)
{
- THCudaTensor_pointwiseApply1(state, input, ELUupdateOutputIP_functor(alpha));
+ THC_pointwiseApply1(state, input, ELUupdateOutputIP_functor(alpha));
THCudaTensor_set(state, output, input);
}
else
{
THCudaTensor_resizeAs(state, output, input);
- THCudaTensor_pointwiseApply2(state, output, input, ELUupdateOutput_functor(alpha));
+ THC_pointwiseApply2(state, output, input, ELUupdateOutput_functor(alpha));
}
}
@@ -82,12 +82,12 @@ void THNN_CudaELU_updateGradInput(THCState *state, THCudaTensor *input, THCudaTe
if (inplace)
{
- THCudaTensor_pointwiseApply2(state, gradOutput, output, ELUupdateGradInputIP_functor(alpha));
+ THC_pointwiseApply2(state, gradOutput, output, ELUupdateGradInputIP_functor(alpha));
THCudaTensor_set(state, gradInput, gradOutput);
}
else
{
THCudaTensor_resizeAs(state, gradInput, output);
- THCudaTensor_pointwiseApply3(state, gradInput, output, gradOutput, ELUupdateGradInput_functor(alpha));
+ THC_pointwiseApply3(state, gradInput, output, gradOutput, ELUupdateGradInput_functor(alpha));
}
}
diff --git a/lib/THCUNN/HardTanh.cu b/lib/THCUNN/HardTanh.cu
index e97faef..764a3c0 100644
--- a/lib/THCUNN/HardTanh.cu
+++ b/lib/THCUNN/HardTanh.cu
@@ -26,7 +26,7 @@ void THNN_CudaHardTanh_updateOutput(THCState *state, THCudaTensor *input, THCuda
{
THCUNN_assertSameGPU(state, 2, input, output);
THCudaTensor_resizeAs(state, output, input);
- THCudaTensor_pointwiseApply2(state, output, input,
+ THC_pointwiseApply2(state, output, input,
hardtanhupdateOutput_functor(min_val, max_val));
}
@@ -54,6 +54,6 @@ void THNN_CudaHardTanh_updateGradInput(THCState *state, THCudaTensor *input, THC
THCUNN_assertSameGPU(state, 3, input, gradOutput, gradInput);
THCudaTensor_resizeAs(state, gradInput, input);
- THCudaTensor_pointwiseApply3(state, gradInput, input, gradOutput,
+ THC_pointwiseApply3(state, gradInput, input, gradOutput,
hardtanhupdateGradInput_functor(min_val, max_val));
}
diff --git a/lib/THCUNN/LeakyReLU.cu b/lib/THCUNN/LeakyReLU.cu
index 3d3fd92..a641821 100644
--- a/lib/THCUNN/LeakyReLU.cu
+++ b/lib/THCUNN/LeakyReLU.cu
@@ -38,13 +38,13 @@ void THNN_CudaLeakyReLU_updateOutput(THCState *state, THCudaTensor *input, THCud
if (inplace)
{
- THCudaTensor_pointwiseApply1(state, input, LeakyReLUUpdateOutputIP(negval));
+ THC_pointwiseApply1(state, input, LeakyReLUUpdateOutputIP(negval));
THCudaTensor_set(state, output, input);
}
else
{
THCudaTensor_resizeAs(state, output, input);
- THCudaTensor_pointwiseApply2(state, output, input, LeakyReLUUpdateOutput(negval));
+ THC_pointwiseApply2(state, output, input, LeakyReLUUpdateOutput(negval));
}
THCudaCheck(cudaGetLastError());
@@ -90,13 +90,13 @@ void THNN_CudaLeakyReLU_updateGradInput(THCState *state, THCudaTensor *input, TH
if (inplace)
{
- THCudaTensor_pointwiseApply2(state, gradOutput, input, LeakyReLUUpdateGradInputIP(negval));
+ THC_pointwiseApply2(state, gradOutput, input, LeakyReLUUpdateGradInputIP(negval));
THCudaTensor_set(state, gradInput, gradOutput);
}
else
{
THCudaTensor_resizeAs(state, gradInput, input);
- THCudaTensor_pointwiseApply3(state, gradInput, input, gradOutput, LeakyReLUUpdateGradInput(negval));
+ THC_pointwiseApply3(state, gradInput, input, gradOutput, LeakyReLUUpdateGradInput(negval));
}
THCudaCheck(cudaGetLastError());
diff --git a/lib/THCUNN/LogSigmoid.cu b/lib/THCUNN/LogSigmoid.cu
index b6ee6f2..2f56081 100644
--- a/lib/THCUNN/LogSigmoid.cu
+++ b/lib/THCUNN/LogSigmoid.cu
@@ -14,7 +14,7 @@ void THNN_CudaLogSigmoid_updateOutput(THCState *state, THCudaTensor *input, THCu
{
THCUNN_assertSameGPU(state, 2, input, output);
THCudaTensor_resizeAs(state, output, input);
- THCudaTensor_pointwiseApply2(state, output, input, logSigmoid_updateOutput_functor());
+ THC_pointwiseApply2(state, output, input, logSigmoid_updateOutput_functor());
}
struct logSigmoid_updateGradInput_functor
@@ -31,5 +31,5 @@ void THNN_CudaLogSigmoid_updateGradInput(THCState *state, THCudaTensor *input, T
{
THCUNN_assertSameGPU(state, 3, input, gradOutput, gradInput);
THCudaTensor_resizeAs(state, gradInput, input);
- THCudaTensor_pointwiseApply3(state, gradInput, input, gradOutput, logSigmoid_updateGradInput_functor());
+ THC_pointwiseApply3(state, gradInput, input, gradOutput, logSigmoid_updateGradInput_functor());
}
diff --git a/lib/THCUNN/PReLU.cu b/lib/THCUNN/PReLU.cu
index b9f2eed..048c0b4 100644
--- a/lib/THCUNN/PReLU.cu
+++ b/lib/THCUNN/PReLU.cu
@@ -42,7 +42,7 @@ void THNN_CudaPReLU_updateOutput(
if (nOutputPlane == 0)
{
- THCudaTensor_pointwiseApply2(state, output, input, PReLUUpdateOutput(w));
+ THC_pointwiseApply2(state, output, input, PReLUUpdateOutput(w));
}
else
{
@@ -109,7 +109,7 @@ void THNN_CudaPReLU_updateGradInput(
float *w = THCudaTensor_data(state, weight);
if (nOutputPlane == 0)
{
- THCudaTensor_pointwiseApply3(state, gradInput, gradOutput, input, PReLUUpdateGradInput(w));
+ THC_pointwiseApply3(state, gradInput, gradOutput, input, PReLUUpdateGradInput(w));
}
else
{
@@ -189,7 +189,7 @@ void THNN_CudaPReLU_accGradParameters(
if (nOutputPlane == 0)
{
- THCudaTensor_pointwiseApply3(state, gradInput, input, gradOutput, PReLUAccGradParametersShared());
+ THC_pointwiseApply3(state, gradInput, input, gradOutput, PReLUAccGradParametersShared());
// introduces a sync point
float sum = THCudaTensor_sumall(state, gradInput);
@@ -205,11 +205,11 @@ void THNN_CudaPReLU_accGradParameters(
if (ndim == 1)
{
- THCudaTensor_pointwiseApply3(state, gradWeight, input, gradOutput, PReLUAccGradParameters1to1(scale));
+ THC_pointwiseApply3(state, gradWeight, input, gradOutput, PReLUAccGradParameters1to1(scale));
}
else
{
- THCudaTensor_pointwiseApply3(state, gradInput, input, gradOutput, PReLUAccGradParameters(scale));
+ THC_pointwiseApply3(state, gradInput, input, gradOutput, PReLUAccGradParameters(scale));
THCudaTensor *sumbuf = gradWeightBuf2;
THCudaTensor_resizeAs(state, gradWeightBuf, gradWeight);
diff --git a/lib/THCUNN/RReLU.cu b/lib/THCUNN/RReLU.cu
index 8e35ef9..86d962b 100644
--- a/lib/THCUNN/RReLU.cu
+++ b/lib/THCUNN/RReLU.cu
@@ -100,13 +100,13 @@ void THNN_CudaRReLU_updateOutput(THCState *state, THCudaTensor *input, THCudaTen
const double negSlope = (lower + upper) / 2;
if (inplace)
{
- THCudaTensor_pointwiseApply1(state, input, RReLUUpdateOutputEvalIP_functor(negSlope));
+ THC_pointwiseApply1(state, input, RReLUUpdateOutputEvalIP_functor(negSlope));
THCudaTensor_set(state, output, input);
}
else
{
THCudaTensor_resizeAs(state, output, input);
- THCudaTensor_pointwiseApply2(state, output, input, RReLUUpdateOutputEval_functor(negSlope));
+ THC_pointwiseApply2(state, output, input, RReLUUpdateOutputEval_functor(negSlope));
}
}
}
@@ -169,13 +169,13 @@ void THNN_CudaRReLU_updateGradInput(THCState *state, THCudaTensor *input, THCuda
const double negSlope = (lower + upper) / 2;
if (inplace)
{
- THCudaTensor_pointwiseApply2(state, gradOutput, input, RReLUupdateGradInputEvalIP_functor(negSlope));
+ THC_pointwiseApply2(state, gradOutput, input, RReLUupdateGradInputEvalIP_functor(negSlope));
THCudaTensor_set(state, gradInput, gradOutput);
}
else
{
THCudaTensor_resizeAs(state, gradInput, input);
- THCudaTensor_pointwiseApply3(state, gradInput, gradOutput, input, RReLUupdateGradInputEval_functor(negSlope));
+ THC_pointwiseApply3(state, gradInput, gradOutput, input, RReLUupdateGradInputEval_functor(negSlope));
}
}
diff --git a/lib/THCUNN/Sigmoid.cu b/lib/THCUNN/Sigmoid.cu
index 9414961..f2a3675 100644
--- a/lib/THCUNN/Sigmoid.cu
+++ b/lib/THCUNN/Sigmoid.cu
@@ -13,7 +13,7 @@ void THNN_CudaSigmoid_updateOutput(THCState *state, THCudaTensor *input, THCudaT
{
THCUNN_assertSameGPU(state, 2, input, output);
THCudaTensor_resizeAs(state, output, input);
- THCudaTensor_pointwiseApply2(state, output, input, sigmoidupdateOutput_functor());
+ THC_pointwiseApply2(state, output, input, sigmoidupdateOutput_functor());
}
struct sigmoidupdateGradInput_functor
@@ -28,5 +28,5 @@ void THNN_CudaSigmoid_updateGradInput(THCState *state, THCudaTensor *input, THCu
{
THCUNN_assertSameGPU(state, 3, output, gradOutput, gradInput);
THCudaTensor_resizeAs(state, gradInput, output);
- THCudaTensor_pointwiseApply3(state, gradInput, output, gradOutput, sigmoidupdateGradInput_functor());
+ THC_pointwiseApply3(state, gradInput, output, gradOutput, sigmoidupdateGradInput_functor());
}
diff --git a/lib/THCUNN/SoftPlus.cu b/lib/THCUNN/SoftPlus.cu
index 36301d8..0d1609a 100644
--- a/lib/THCUNN/SoftPlus.cu
+++ b/lib/THCUNN/SoftPlus.cu
@@ -22,7 +22,7 @@ void THNN_CudaSoftPlus_updateOutput(THCState *state, THCudaTensor *input, THCuda
{
THCUNN_assertSameGPU(state, 2, input, output);
THCudaTensor_resizeAs(state, output, input);
- THCudaTensor_pointwiseApply2(state, output, input, softPlusupdateOutput_functor(threshold, beta));
+ THC_pointwiseApply2(state, output, input, softPlusupdateOutput_functor(threshold, beta));
}
struct softPlusupdateGradInput_functor
@@ -48,5 +48,5 @@ void THNN_CudaSoftPlus_updateGradInput(THCState *state, THCudaTensor *input, THC
{
THCUNN_assertSameGPU(state, 4, input, output, gradOutput, gradInput);
THCudaTensor_resizeAs(state, gradInput, output);
- THCudaTensor_pointwiseApply3(state, gradInput, output, gradOutput, softPlusupdateGradInput_functor(threshold, beta));
+ THC_pointwiseApply3(state, gradInput, output, gradOutput, softPlusupdateGradInput_functor(threshold, beta));
}
diff --git a/lib/THCUNN/SoftShrink.cu b/lib/THCUNN/SoftShrink.cu
index 503a74a..2a08570 100644
--- a/lib/THCUNN/SoftShrink.cu
+++ b/lib/THCUNN/SoftShrink.cu
@@ -22,7 +22,7 @@ void THNN_CudaSoftShrink_updateOutput(THCState *state, THCudaTensor *input, THCu
{
THCUNN_assertSameGPU(state, 2, input, output);
THCudaTensor_resizeAs(state, output, input);
- THCudaTensor_pointwiseApply2(state, output, input, SoftShrinkUpdateOutput(lambda));
+ THC_pointwiseApply2(state, output, input, SoftShrinkUpdateOutput(lambda));
THCudaCheck(cudaGetLastError());
}
@@ -49,6 +49,6 @@ void THNN_CudaSoftShrink_updateGradInput(THCState *state, THCudaTensor *input, T
{
THCUNN_assertSameGPU(state, 3, input, gradOutput, gradInput);
THCudaTensor_resizeAs(state, gradInput, input);
- THCudaTensor_pointwiseApply3(state, gradInput, input, gradOutput, SoftShrinkUpdateGradInput(lambda));
+ THC_pointwiseApply3(state, gradInput, input, gradOutput, SoftShrinkUpdateGradInput(lambda));
THCudaCheck(cudaGetLastError());
}
diff --git a/lib/THCUNN/SpatialReflectionPadding.cu b/lib/THCUNN/SpatialReflectionPadding.cu
index 05bc691..c3ae14e 100644
--- a/lib/THCUNN/SpatialReflectionPadding.cu
+++ b/lib/THCUNN/SpatialReflectionPadding.cu
@@ -46,7 +46,7 @@ void THNN_CudaSpatialReflectionPadding_updateOutput(THCState *state,
int padL, int padR,
int padT, int padB
) {
- THArgCheck(THC_canUse32BitIndexMath(state, input), 2,
+ THArgCheck(TensorUtils<THCudaTensor>::canUse32BitIndexMath(state, input), 2,
"input tensor must fit into 32-bit index math");
int planeDim = 0;
@@ -139,9 +139,9 @@ void THNN_CudaSpatialReflectionPadding_updateGradInput(THCState *state,
int padL, int padR,
int padT, int padB) {
- THArgCheck(THC_canUse32BitIndexMath(state, input), 2,
+ THArgCheck(TensorUtils<THCudaTensor>::canUse32BitIndexMath(state, input), 2,
"input tensor must fit into 32-bit index math");
- THArgCheck(THC_canUse32BitIndexMath(state, gradOutput), 3,
+ THArgCheck(TensorUtils<THCudaTensor>::canUse32BitIndexMath(state, gradOutput), 3,
"output gradient tensor must fit into 32-bit index math");
int planeDim = 0;
diff --git a/lib/THCUNN/SpatialReplicationPadding.cu b/lib/THCUNN/SpatialReplicationPadding.cu
index 9c5b8be..fc09291 100644
--- a/lib/THCUNN/SpatialReplicationPadding.cu
+++ b/lib/THCUNN/SpatialReplicationPadding.cu
@@ -37,7 +37,7 @@ void THNN_CudaSpatialReplicationPadding_updateOutput(THCState *state,
int padL, int padR,
int padT, int padB
) {
- THArgCheck(THC_canUse32BitIndexMath(state, input), 2,
+ THArgCheck(TensorUtils<THCudaTensor>::canUse32BitIndexMath(state, input), 2,
"input tensor must fit into 32-bit index math");
int planeDim = 0;
@@ -121,9 +121,9 @@ void THNN_CudaSpatialReplicationPadding_updateGradInput(THCState *state,
int padL, int padR,
int padT, int padB) {
- THArgCheck(THC_canUse32BitIndexMath(state, input), 2,
+ THArgCheck(TensorUtils<THCudaTensor>::canUse32BitIndexMath(state, input), 2,
"input tensor must fit into 32-bit index math");
- THArgCheck(THC_canUse32BitIndexMath(state, gradOutput), 3,
+ THArgCheck(TensorUtils<THCudaTensor>::canUse32BitIndexMath(state, gradOutput), 3,
"output gradient tensor must fit into 32-bit index math");
int planeDim = 0;
diff --git a/lib/THCUNN/Sqrt.cu b/lib/THCUNN/Sqrt.cu
index 7a156f8..e1a4fe3 100644
--- a/lib/THCUNN/Sqrt.cu
+++ b/lib/THCUNN/Sqrt.cu
@@ -19,7 +19,7 @@ void THNN_CudaSqrt_updateOutput(THCState *state, THCudaTensor *input, THCudaTens
{
THCUNN_assertSameGPU(state, 2, input, output);
THCudaTensor_resizeAs(state, output, input);
- THCudaTensor_pointwiseApply2(state, output, input, sqrtupdateOutput_functor(eps));
+ THC_pointwiseApply2(state, output, input, sqrtupdateOutput_functor(eps));
}
struct sqrtupdateGradInput_functor
@@ -36,5 +36,5 @@ void THNN_CudaSqrt_updateGradInput(THCState *state, THCudaTensor *input, THCudaT
{
THCUNN_assertSameGPU(state, 3, output, gradOutput, gradInput);
THCudaTensor_resizeAs(state, gradInput, output);
- THCudaTensor_pointwiseApply3(state, gradInput, output, gradOutput, sqrtupdateGradInput_functor());
+ THC_pointwiseApply3(state, gradInput, output, gradOutput, sqrtupdateGradInput_functor());
}
diff --git a/lib/THCUNN/Square.cu b/lib/THCUNN/Square.cu
index aaa556f..a6d147c 100644
--- a/lib/THCUNN/Square.cu
+++ b/lib/THCUNN/Square.cu
@@ -13,7 +13,7 @@ void THNN_CudaSquare_updateOutput(THCState *state, THCudaTensor *input, THCudaTe
{
THCUNN_assertSameGPU(state, 2, input, output);
THCudaTensor_resizeAs(state, output, input);
- THCudaTensor_pointwiseApply2(state, output, input, squareupdateOutput_functor());
+ THC_pointwiseApply2(state, output, input, squareupdateOutput_functor());
}
struct squareupdateGradInput_functor
@@ -28,5 +28,5 @@ void THNN_CudaSquare_updateGradInput(THCState *state, THCudaTensor *input, THCud
{
THCUNN_assertSameGPU(state, 3, input, gradOutput, gradInput);
THCudaTensor_resizeAs(state, gradInput, input);
- THCudaTensor_pointwiseApply3(state, gradInput, input, gradOutput, squareupdateGradInput_functor());
+ THC_pointwiseApply3(state, gradInput, input, gradOutput, squareupdateGradInput_functor());
}
diff --git a/lib/THCUNN/Tanh.cu b/lib/THCUNN/Tanh.cu
index 5a09b5c..726169a 100644
--- a/lib/THCUNN/Tanh.cu
+++ b/lib/THCUNN/Tanh.cu
@@ -13,7 +13,7 @@ void THNN_CudaTanh_updateOutput(THCState *state, THCudaTensor *input, THCudaTens
{
THCUNN_assertSameGPU(state, 2, input, output);
THCudaTensor_resizeAs(state, output, input);
- THCudaTensor_pointwiseApply2(state, output, input, tanhupdateOutput_functor());
+ THC_pointwiseApply2(state, output, input, tanhupdateOutput_functor());
}
struct tanhupdateGradInput_functor
@@ -28,5 +28,5 @@ void THNN_CudaTanh_updateGradInput(THCState *state, THCudaTensor *input, THCudaT
{
THCUNN_assertSameGPU(state, 3, output, gradOutput, gradInput);
THCudaTensor_resizeAs(state, gradInput, output);
- THCudaTensor_pointwiseApply3(state, gradInput, output, gradOutput, tanhupdateGradInput_functor());
+ THC_pointwiseApply3(state, gradInput, output, gradOutput, tanhupdateGradInput_functor());
}
diff --git a/lib/THCUNN/Threshold.cu b/lib/THCUNN/Threshold.cu
index 7ba3c8e..d00a8f9 100644
--- a/lib/THCUNN/Threshold.cu
+++ b/lib/THCUNN/Threshold.cu
@@ -42,7 +42,7 @@ void THNN_CudaThreshold_updateOutput(THCState *state, THCudaTensor *input, THCud
if (inplace)
{
- THCudaTensor_pointwiseApply1(state, input,
+ THC_pointwiseApply1(state, input,
ThresholdUpdateOutputIP(threshold, val)
);
THCudaTensor_set(state, output, input);
@@ -50,7 +50,7 @@ void THNN_CudaThreshold_updateOutput(THCState *state, THCudaTensor *input, THCud
else
{
THCudaTensor_resizeAs(state, output, input);
- THCudaTensor_pointwiseApply2(state, output, input,
+ THC_pointwiseApply2(state, output, input,
ThresholdUpdateOutput(threshold, val)
);
}
@@ -95,7 +95,7 @@ void THNN_CudaThreshold_updateGradInput(THCState *state, THCudaTensor *input, TH
if (inplace)
{
- THCudaTensor_pointwiseApply2(state, gradOutput, input,
+ THC_pointwiseApply2(state, gradOutput, input,
ThresholdUpdateGradInputIP(threshold)
);
THCudaTensor_set(state, gradInput, gradOutput);
@@ -103,7 +103,7 @@ void THNN_CudaThreshold_updateGradInput(THCState *state, THCudaTensor *input, TH
else
{
THCudaTensor_resizeAs(state, gradInput, input);
- THCudaTensor_pointwiseApply3(state, gradInput, input, gradOutput,
+ THC_pointwiseApply3(state, gradInput, input, gradOutput,
ThresholdUpdateGradInput(threshold)
);
}