Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cunn.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author Soumith Chintala <soumith@gmail.com> 2017-01-26 00:13:20 +0300
committer GitHub <noreply@github.com> 2017-01-26 00:13:20 +0300
commit 0d85922d116879448485ef88ae21e83a9255a0b0 (patch)
tree 87449f65566e7c1b6d68e6b6671bb3d18083c600
parent 87223032d716826207c97bdac72ccc269225790d (diff)
Revert "Convert real to accreal in libTHCUNN" revert-416-half-fixes
-rw-r--r-- THCUNN.lua 40
-rw-r--r-- lib/THCUNN/SparseLinear.cu 10
-rw-r--r-- lib/THCUNN/generic/BatchNormalization.cu 2
-rw-r--r-- lib/THCUNN/generic/ELU.cu 6
-rw-r--r-- lib/THCUNN/generic/HardTanh.cu 14
-rw-r--r-- lib/THCUNN/generic/LeakyReLU.cu 8
-rw-r--r-- lib/THCUNN/generic/LookupTable.cu 9
-rw-r--r-- lib/THCUNN/generic/MarginCriterion.cu 7
-rw-r--r-- lib/THCUNN/generic/MultiMarginCriterion.cu 6
-rw-r--r-- lib/THCUNN/generic/PReLU.cu 3
-rw-r--r-- lib/THCUNN/generic/SoftPlus.cu 12
-rw-r--r-- lib/THCUNN/generic/SoftShrink.cu 6
-rw-r--r-- lib/THCUNN/generic/SparseLinear.cu 10
-rw-r--r-- lib/THCUNN/generic/SpatialConvolutionLocal.cu 3
-rw-r--r-- lib/THCUNN/generic/SpatialConvolutionMM.cu 3
-rw-r--r-- lib/THCUNN/generic/SpatialCrossMapLRN.cu 24
-rw-r--r-- lib/THCUNN/generic/SpatialDilatedConvolution.cu 3
-rw-r--r-- lib/THCUNN/generic/SpatialFullConvolution.cu 3
-rw-r--r-- lib/THCUNN/generic/SpatialSubSampling.cu 2
-rw-r--r-- lib/THCUNN/generic/Sqrt.cu 3
-rw-r--r-- lib/THCUNN/generic/THCUNN.h 96
-rw-r--r-- lib/THCUNN/generic/TemporalConvolution.cu 3
-rw-r--r-- lib/THCUNN/generic/Threshold.cu 12
-rw-r--r-- lib/THCUNN/generic/VolumetricConvolution.cu 3
-rw-r--r-- lib/THCUNN/generic/VolumetricDilatedConvolution.cu 3
-rw-r--r-- lib/THCUNN/generic/VolumetricFullConvolution.cu 3
-rw-r--r-- test.lua 111
27 files changed, 188 insertions, 217 deletions
diff --git a/THCUNN.lua b/THCUNN.lua
index d5bf1c2..6776a23 100644
--- a/THCUNN.lua
+++ b/THCUNN.lua
@@ -45,7 +45,7 @@ local replacements_generic =
['THCTensor'] = 'THCudaTensor',
['THCIndexTensor'] = 'THCudaLongTensor',
['TYPE'] = 'Cuda',
- ['real'] = 'float',
+ ['real'] = 'float'
},
{
['THCTensor'] = 'THCudaDoubleTensor',
@@ -55,13 +55,6 @@ local replacements_generic =
}
}
--- gsub(s, 'real', 'float') changes accreal to accfloat.
--- typedef accfloat ahead of time.
-ffi.cdef("typedef float accfloat;")
--- gsub(s, 'real', 'double') changes accreal to accfloat.
--- typedef accdouble ahead of time
-ffi.cdef("typedef double accdouble;")
-
if cutorch.hasHalf then
ffi.cdef("half THC_float2half(float a);")
ffi.cdef("float THC_half2float(half a);")
@@ -70,12 +63,9 @@ if cutorch.hasHalf then
['THCTensor'] = 'THCudaHalfTensor',
['THCIndexTensor'] = 'THCudaLongTensor',
['TYPE'] = 'CudaHalf',
- ['real'] = 'half',
+ ['real'] = 'half'
}
table.insert(replacements_generic, half_replacement)
- -- gsub(s, 'real', 'double') changes accreal to accfloat.
- -- typedef acchalf ahead of time
- ffi.cdef("typedef float acchalf;")
end
for i=1,#replacements_generic do
@@ -143,9 +133,29 @@ THNN.kernels['torch.CudaDoubleTensor'] = THNN.bind(THCUNN.C, function_names_gene
torch.getmetatable('torch.CudaDoubleTensor').THNN = THNN.kernels['torch.CudaDoubleTensor']
if cutorch.hasHalf then
- local raw_half_functions = THNN.bind(THCUNN.C, function_names_generic, 'CudaHalf', THCUNN.getState)
- THNN.kernels['torch.CudaHalfTensor'] = raw_half_functions
- torch.getmetatable('torch.CudaHalfTensor').THNN = THNN.kernels['torch.CudaHalfTensor']
+-- in order to call 'half' functions from lua, convert real arguments from
+-- to half since there is no other defined conversion
+local transform_reals_to_half = function(func_name, real_args, ...)
+ t = {}
+ -- this select logic is necessary to deal with nil arguments
+ for i = 1, select('#', ...) do
+ t[i] = select(i, ...)
+ end
+ for k,v in ipairs(real_args[func_name]) do
+ -- first argument (THCState) is added implicitly by bind
+ t[v-1] = THC.THC_float2half(t[v-1])
+ end
+ return t
+end
+
+local raw_half_functions = THNN.bind(THCUNN.C, function_names_generic, 'CudaHalf', THCUNN.getState)
+for k,v in pairs(raw_half_functions) do
+ -- select required in case there are trailing nils
+ raw_half_functions[k] = function(...) v(unpack(transform_reals_to_half(k, real_args, ...), 1, select("#",...)))
+end
+end
+THNN.kernels['torch.CudaHalfTensor'] = raw_half_functions
+torch.getmetatable('torch.CudaHalfTensor').THNN = THNN.kernels['torch.CudaHalfTensor']
end
local function Module__converter(type)
diff --git a/lib/THCUNN/SparseLinear.cu b/lib/THCUNN/SparseLinear.cu
index f36206f..a7ffa1e 100644
--- a/lib/THCUNN/SparseLinear.cu
+++ b/lib/THCUNN/SparseLinear.cu
@@ -34,8 +34,8 @@ void THNN_CudaHalfSparseLinear_accGradParameters(
THCudaHalfTensor *gradBias,
THCudaHalfTensor *weight,
THCudaHalfTensor *bias,
- float weightDecay,
- float scale) {
+ double weightDecay,
+ double scale) {
THError("THCudaHalfTensor not supported with SparseLinear");
}
@@ -56,8 +56,8 @@ void THNN_CudaHalfSparseLinear_legacyAccGradParameters(
THCudaHalfTensor *gradBias,
THCudaHalfTensor *weight,
THCudaHalfTensor *bias,
- float weightDecay,
- float scale) {
+ double weightDecay,
+ double scale) {
THError("THCudaHalfTensor not supported with SparseLinear");
}
@@ -76,7 +76,7 @@ void THNN_CudaHalfSparseLinear_updateParameters(
THCudaHalfTensor *gradWeight,
THCudaHalfTensor *gradBias,
THCudaHalfTensor *lastInput,
- float learningRate) {
+ double learningRate) {
THError("THCudaHalfTensor not supported with SparseLinear");
}
#endif
diff --git a/lib/THCUNN/generic/BatchNormalization.cu b/lib/THCUNN/generic/BatchNormalization.cu
index d42f18e..cbe99f3 100644
--- a/lib/THCUNN/generic/BatchNormalization.cu
+++ b/lib/THCUNN/generic/BatchNormalization.cu
@@ -69,7 +69,7 @@ void THNN_(BatchNormalization_backward)(
THCState *state, THCTensor *input_, THCTensor *gradOutput_,
THCTensor *gradInput_, THCTensor *gradWeight_, THCTensor *gradBias_,
THCTensor *weight_, THCTensor *runningMean_, THCTensor *runningVar_,
- THCTensor *saveMean_, THCTensor *saveStd_, bool train, double scale, double eps) {
+ THCTensor *saveMean_, THCTensor *saveStd_, bool train, float scale, double eps) {
THCUNN_check_shape(state, input_, gradOutput_);
DeviceTensor3 input = devicetensor<3>(state, input_);
diff --git a/lib/THCUNN/generic/ELU.cu b/lib/THCUNN/generic/ELU.cu
index 4b8da27..0beb5a1 100644
--- a/lib/THCUNN/generic/ELU.cu
+++ b/lib/THCUNN/generic/ELU.cu
@@ -9,10 +9,9 @@ void THNN_(ELU_updateOutput)(
THCState *state,
THCTensor *input,
THCTensor *output,
- accreal alpha_,
+ real alpha,
bool inplace)
{
- real alpha = ScalarConvert<accreal, real>::to(alpha_);
THCUNN_assertSameGPU(state, 2, input, output);
if (inplace)
@@ -34,10 +33,9 @@ void THNN_(ELU_updateGradInput)(
THCTensor *gradOutput,
THCTensor *gradInput,
THCTensor *output,
- accreal alpha_,
+ real alpha,
bool inplace)
{
- real alpha = ScalarConvert<accreal, real>::to(alpha_);
THCUNN_check_nElement(state, input, gradOutput);
THCUNN_assertSameGPU(state, 3, output, gradOutput, gradInput);
diff --git a/lib/THCUNN/generic/HardTanh.cu b/lib/THCUNN/generic/HardTanh.cu
index 47835f0..0651431 100644
--- a/lib/THCUNN/generic/HardTanh.cu
+++ b/lib/THCUNN/generic/HardTanh.cu
@@ -8,13 +8,10 @@ void THNN_(HardTanh_updateOutput)(
THCState *state,
THCTensor *input,
THCTensor *output,
- accreal min_val_,
- accreal max_val_,
+ real min_val,
+ real max_val,
bool inplace)
{
- real min_val = ScalarConvert<accreal, real>::to(min_val_);
- real max_val = ScalarConvert<accreal, real>::to(max_val_);
-
THCUNN_assertSameGPU(state, 2, input, output);
if(inplace)
{
@@ -34,13 +31,10 @@ void THNN_(HardTanh_updateGradInput)(
THCTensor *input,
THCTensor *gradOutput,
THCTensor *gradInput,
- accreal min_val_,
- accreal max_val_,
+ real min_val,
+ real max_val,
bool inplace)
{
- real min_val = ScalarConvert<accreal, real>::to(min_val_);
- real max_val = ScalarConvert<accreal, real>::to(max_val_);
-
THCUNN_check_nElement(state, input, gradOutput);
THCUNN_assertSameGPU(state, 3, input, gradOutput, gradInput);
diff --git a/lib/THCUNN/generic/LeakyReLU.cu b/lib/THCUNN/generic/LeakyReLU.cu
index 179819d..23cf59a 100644
--- a/lib/THCUNN/generic/LeakyReLU.cu
+++ b/lib/THCUNN/generic/LeakyReLU.cu
@@ -8,11 +8,9 @@ void THNN_(LeakyReLU_updateOutput)(
THCState *state,
THCTensor *input,
THCTensor *output,
- accreal negval_,
+ real negval,
bool inplace)
{
- real negval = ScalarConvert<accreal, real>::to(negval_);
-
THCUNN_assertSameGPU(state, 2, input, output);
if (inplace)
@@ -34,11 +32,9 @@ void THNN_(LeakyReLU_updateGradInput)(
THCTensor *input,
THCTensor *gradOutput,
THCTensor *gradInput,
- accreal negval_,
+ real negval,
bool inplace)
{
- real negval = ScalarConvert<accreal, real>::to(negval_);
-
THCUNN_check_nElement(state, input, gradOutput);
THCUNN_assertSameGPU(state, 3, input, gradInput, gradOutput);
diff --git a/lib/THCUNN/generic/LookupTable.cu b/lib/THCUNN/generic/LookupTable.cu
index fa7c5ac..bd59a04 100644
--- a/lib/THCUNN/generic/LookupTable.cu
+++ b/lib/THCUNN/generic/LookupTable.cu
@@ -12,9 +12,8 @@ void THNN_(LookupTable_accGradParameters)(
THCIndexTensor *indices,
bool scaleGradByFreq,
int paddingValue,
- accreal scale_)
+ real scale)
{
- real scale = ScalarConvert<accreal, real>::to(scale_);
THCUNN_assertSameGPU(state, 5, input, gradOutput, gradWeight, sorted, indices);
gradOutput = THCTensor_(newContiguous)(state, gradOutput);
if (!(THCIndexTensor_(isContiguous)(state, input) &&
@@ -120,11 +119,9 @@ void THNN_(LookupTable_renorm)(
THCState *state,
THCIndexTensor *idx,
THCTensor *weight,
- accreal maxNorm_,
- accreal normType_)
+ real maxNorm,
+ real normType)
{
- real maxNorm = ScalarConvert<accreal, real>::to(maxNorm_);
- real normType = ScalarConvert<accreal, real>::to(normType_);
THCUNN_assertSameGPU(state, 2, idx, weight);
if (!(THCIndexTensor_(isContiguous)(state, idx) &&
THCTensor_(isContiguous)(state, weight)))
diff --git a/lib/THCUNN/generic/MarginCriterion.cu b/lib/THCUNN/generic/MarginCriterion.cu
index 221f9d9..d5678ec 100644
--- a/lib/THCUNN/generic/MarginCriterion.cu
+++ b/lib/THCUNN/generic/MarginCriterion.cu
@@ -8,9 +8,8 @@ void THNN_(MarginCriterion_updateOutput)(
THCTensor *target,
THCTensor *output,
bool sizeAverage,
- accreal margin_)
+ real margin)
{
- real margin = ScalarConvert<accreal, real>::to(margin_);
THCUNN_check_nElement(state, input, target);
THCUNN_check_dim_size(state, output, 1, 0, 1);
THCUNN_assertSameGPU(state, 2, input, target);
@@ -41,10 +40,8 @@ void THNN_(MarginCriterion_updateGradInput)(
THCTensor *target,
THCTensor *gradInput,
bool sizeAverage,
- accreal margin_)
+ real margin)
{
- real margin = ScalarConvert<accreal, real>::to(margin_);
-
THCUNN_check_nElement(state, input, target);
THCUNN_assertSameGPU(state, 3, input, target, gradInput);
diff --git a/lib/THCUNN/generic/MultiMarginCriterion.cu b/lib/THCUNN/generic/MultiMarginCriterion.cu
index c3ff2d6..8026331 100644
--- a/lib/THCUNN/generic/MultiMarginCriterion.cu
+++ b/lib/THCUNN/generic/MultiMarginCriterion.cu
@@ -11,9 +11,8 @@ void THNN_(MultiMarginCriterion_updateOutput)(
bool sizeAverage,
int p,
THCTensor *weights,
- accreal margin_)
+ real margin)
{
- real margin = ScalarConvert<accreal, real>::to(margin_);
THCUNN_assertSameGPU(state, 2, input, target);
input = THCTensor_(newContiguous)(state, input);
if(weights)
@@ -103,9 +102,8 @@ void THNN_(MultiMarginCriterion_updateGradInput)(
bool sizeAverage,
int p,
THCTensor *weights,
- accreal margin_)
+ real margin)
{
- real margin = ScalarConvert<accreal, real>::to(margin_);
THCUNN_assertSameGPU(state, 3, input, gradInput, target);
input = THCTensor_(newContiguous)(state, input);
THCTensor_(resizeAs)(state, gradInput, input);
diff --git a/lib/THCUNN/generic/PReLU.cu b/lib/THCUNN/generic/PReLU.cu
index db9b0d2..89087fb 100644
--- a/lib/THCUNN/generic/PReLU.cu
+++ b/lib/THCUNN/generic/PReLU.cu
@@ -92,9 +92,8 @@ void THNN_(PReLU_accGradParameters)(
THCTensor *gradWeightBuf,
THCTensor *gradWeightBuf2,
long nOutputPlane,
- accreal scale_)
+ real scale)
{
- real scale = ScalarConvert<accreal, real>::to(scale_);
THCUNN_check_nElement(state, input, gradOutput);
// use grad input for temporary storage, then call updateGradInput again
diff --git a/lib/THCUNN/generic/SoftPlus.cu b/lib/THCUNN/generic/SoftPlus.cu
index 17cde70..e72038e 100644
--- a/lib/THCUNN/generic/SoftPlus.cu
+++ b/lib/THCUNN/generic/SoftPlus.cu
@@ -8,11 +8,9 @@ void THNN_(SoftPlus_updateOutput)(
THCState *state,
THCTensor *input,
THCTensor *output,
- accreal beta_,
- accreal threshold_)
+ real beta,
+ real threshold)
{
- real beta = ScalarConvert<accreal, real>::to(beta_);
- real threshold = ScalarConvert<accreal, real>::to(threshold_);
THCUNN_assertSameGPU(state, 2, input, output);
THCTensor_(resizeAs)(state, output, input);
THC_pointwiseApply2(state, output, input, softPlusupdateOutput_functor<real>(threshold, beta));
@@ -24,11 +22,9 @@ void THNN_(SoftPlus_updateGradInput)(
THCTensor *gradOutput,
THCTensor *gradInput,
THCTensor *output,
- accreal beta_,
- accreal threshold_)
+ real beta,
+ real threshold)
{
- real beta = ScalarConvert<accreal, real>::to(beta_);
- real threshold = ScalarConvert<accreal, real>::to(threshold_);
THCUNN_check_nElement(state, input, gradOutput);
THCUNN_assertSameGPU(state, 4, input, output, gradOutput, gradInput);
THCTensor_(resizeAs)(state, gradInput, output);
diff --git a/lib/THCUNN/generic/SoftShrink.cu b/lib/THCUNN/generic/SoftShrink.cu
index 9e47695..261593f 100644
--- a/lib/THCUNN/generic/SoftShrink.cu
+++ b/lib/THCUNN/generic/SoftShrink.cu
@@ -8,9 +8,8 @@ void THNN_(SoftShrink_updateOutput)(
THCState *state,
THCTensor *input,
THCTensor *output,
- accreal lambda_)
+ real lambda)
{
- real lambda = ScalarConvert<accreal, real>::to(lambda_);
THCUNN_assertSameGPU(state, 2, input, output);
THCTensor_(resizeAs)(state, output, input);
THC_pointwiseApply2(state, output, input, SoftShrinkUpdateOutput<real>(lambda));
@@ -22,9 +21,8 @@ void THNN_(SoftShrink_updateGradInput)(
THCTensor *input,
THCTensor *gradOutput,
THCTensor *gradInput,
- accreal lambda_)
+ real lambda)
{
- real lambda = ScalarConvert<accreal, real>::to(lambda_);
THCUNN_check_nElement(state, input, gradOutput);
THCUNN_assertSameGPU(state, 3, input, gradOutput, gradInput);
THCTensor_(resizeAs)(state, gradInput, input);
diff --git a/lib/THCUNN/generic/SparseLinear.cu b/lib/THCUNN/generic/SparseLinear.cu
index 6838cac..f22b233 100644
--- a/lib/THCUNN/generic/SparseLinear.cu
+++ b/lib/THCUNN/generic/SparseLinear.cu
@@ -127,8 +127,8 @@ void THNN_(SparseLinear_accGradParameters)(
THCTensor *gradBias,
THCTensor *weight,
THCTensor *bias,
- accreal weightDecay,
- accreal scale)
+ double weightDecay,
+ double scale)
{
long outDim = THCTensor_(size)(state, weight, 0);
long inDim = THCTensor_(size)(state, weight, 1);
@@ -237,8 +237,8 @@ void THNN_(SparseLinear_legacyAccGradParameters)(
THCTensor *gradBias,
THCTensor *weight,
THCTensor *bias,
- accreal weightDecay,
- accreal scale) {
+ double weightDecay,
+ double scale) {
THError("CUDA does not support legacy input format, please use a table of nnz x 2 vectors");
}
@@ -259,7 +259,7 @@ void THNN_(SparseLinear_updateParameters)(
THCTensor *gradWeight,
THCTensor *gradBias,
THCTensor *lastInput,
- accreal learningRate) {
+ double learningRate) {
THCTensor_(cadd)(state, weight, weight, -learningRate, gradWeight);
THCTensor_(cadd)(state, bias, bias, -learningRate, gradBias);
}
diff --git a/lib/THCUNN/generic/SpatialConvolutionLocal.cu b/lib/THCUNN/generic/SpatialConvolutionLocal.cu
index 0d4b9ad..afbc24d 100644
--- a/lib/THCUNN/generic/SpatialConvolutionLocal.cu
+++ b/lib/THCUNN/generic/SpatialConvolutionLocal.cu
@@ -309,9 +309,8 @@ void THNN_(SpatialConvolutionLocal_accGradParameters)(
int padW, int padH,
long inputWidth, long inputHeight,
long outputWidth, long outputHeight,
- accreal scale_)
+ real scale)
{
- real scale = ScalarConvert<accreal, real>::to(scale_);
THCUNN_assertSameGPU(state, 5, input, gradOutput, gradWeight,
gradBias, finput);
diff --git a/lib/THCUNN/generic/SpatialConvolutionMM.cu b/lib/THCUNN/generic/SpatialConvolutionMM.cu
index b4ae8e5..e7aeacb 100644
--- a/lib/THCUNN/generic/SpatialConvolutionMM.cu
+++ b/lib/THCUNN/generic/SpatialConvolutionMM.cu
@@ -335,9 +335,8 @@ void THNN_(SpatialConvolutionMM_accGradParameters)(
int kW, int kH,
int dW, int dH,
int padW, int padH,
- accreal scale_) {
+ real scale) {
- real scale = ScalarConvert<accreal, real>::to(scale_);
THCUNN_assertSameGPU(state, 5, input, gradOutput, gradWeight, columns, ones);
if (gradBias) {
THCUNN_assertSameGPU(state, 2, gradWeight, gradBias);
diff --git a/lib/THCUNN/generic/SpatialCrossMapLRN.cu b/lib/THCUNN/generic/SpatialCrossMapLRN.cu
index 6b79c15..a09ea0b 100644
--- a/lib/THCUNN/generic/SpatialCrossMapLRN.cu
+++ b/lib/THCUNN/generic/SpatialCrossMapLRN.cu
@@ -3,12 +3,8 @@
#else
void LRNforward(THCState* state, THCTensor* input, THCTensor* output,
- THCTensor* scale, int local_size, accreal alpha_, accreal beta_, accreal k_)
+ THCTensor* scale, int local_size, real alpha, real beta, real k)
{
- real alpha = ScalarConvert<accreal, real>::to(alpha_);
- real beta = ScalarConvert<accreal, real>::to(beta_);
- real k = ScalarConvert<accreal, real>::to(k_);
-
THCTensor_(resizeAs)(state, output, input);
THCTensor_(resizeAs)(state, scale, input);
@@ -49,12 +45,8 @@ void LRNforward(THCState* state, THCTensor* input, THCTensor* output,
void LRNbackward(THCState* state, THCTensor* input, THCTensor* output,
THCTensor* gradOutput, THCTensor* gradInput, THCTensor* scale,
- int local_size, accreal alpha_, accreal beta_, accreal k_)
+ int local_size, real alpha, real beta, real k)
{
- real alpha = ScalarConvert<accreal, real>::to(alpha_);
- real beta = ScalarConvert<accreal, real>::to(beta_);
- real k = ScalarConvert<accreal, real>::to(k_);
-
THCTensor_(resizeAs)(state, gradInput, input);
int batchSize;
@@ -97,9 +89,9 @@ void THNN_(SpatialCrossMapLRN_updateOutput)(
THCTensor *output,
THCTensor *scale,
int size,
- accreal alpha,
- accreal beta,
- accreal k)
+ real alpha,
+ real beta,
+ real k)
{
LRNforward(state, input, output, scale, size, alpha, beta, k);
}
@@ -112,9 +104,9 @@ void THNN_(SpatialCrossMapLRN_updateGradInput)(
THCTensor *scale,
THCTensor *output,
int size,
- accreal alpha,
- accreal beta,
- accreal k)
+ real alpha,
+ real beta,
+ real k)
{
LRNbackward(state, input, output, gradOutput, gradInput, scale, size, alpha, beta, k);
}
diff --git a/lib/THCUNN/generic/SpatialDilatedConvolution.cu b/lib/THCUNN/generic/SpatialDilatedConvolution.cu
index 02a640b..7b656d3 100644
--- a/lib/THCUNN/generic/SpatialDilatedConvolution.cu
+++ b/lib/THCUNN/generic/SpatialDilatedConvolution.cu
@@ -322,9 +322,8 @@ void THNN_(SpatialDilatedConvolution_accGradParameters)(
int dW, int dH,
int padW, int padH,
int dilationW, int dilationH,
- accreal scale_) {
+ real scale) {
- real scale = ScalarConvert<accreal, real>::to(scale_);
THCUNN_assertSameGPU(state, 5, input, gradOutput, gradWeight, columns, ones);
if (gradBias) {
THCUNN_assertSameGPU(state, 2, gradWeight, gradBias);
diff --git a/lib/THCUNN/generic/SpatialFullConvolution.cu b/lib/THCUNN/generic/SpatialFullConvolution.cu
index 54fda23..7a5d7ea 100644
--- a/lib/THCUNN/generic/SpatialFullConvolution.cu
+++ b/lib/THCUNN/generic/SpatialFullConvolution.cu
@@ -315,9 +315,8 @@ void THNN_(SpatialFullConvolution_accGradParameters)(
int dW, int dH,
int padW, int padH,
int adjW, int adjH,
- accreal scale_)
+ real scale)
{
- real scale = ScalarConvert<accreal, real>::to(scale_);
int nInputPlane = THCTensor_(size)(state, gradWeight, 0);
int nOutputPlane = THCTensor_(size)(state, gradWeight, 1);
diff --git a/lib/THCUNN/generic/SpatialSubSampling.cu b/lib/THCUNN/generic/SpatialSubSampling.cu
index ef3c508..b918962 100644
--- a/lib/THCUNN/generic/SpatialSubSampling.cu
+++ b/lib/THCUNN/generic/SpatialSubSampling.cu
@@ -191,7 +191,7 @@ void THNN_(SpatialSubSampling_accGradParameters)(
THCTensor *gradBias,
int kW, int kH,
int dW, int dH,
- accreal scale)
+ float scale)
{
THCUNN_assertSameGPU(state, 4, input, gradOutput, gradWeight, gradBias);
THNN_(SpatialSubSampling_shapeCheck)(state, input, gradOutput, gradWeight, kW, kH);
diff --git a/lib/THCUNN/generic/Sqrt.cu b/lib/THCUNN/generic/Sqrt.cu
index b6a68f8..3602cbe 100644
--- a/lib/THCUNN/generic/Sqrt.cu
+++ b/lib/THCUNN/generic/Sqrt.cu
@@ -8,9 +8,8 @@ void THNN_(Sqrt_updateOutput)(
THCState *state,
THCTensor *input,
THCTensor *output,
- accreal eps_)
+ real eps)
{
- real eps = ScalarConvert<accreal, real>::to(eps_);
THCUNN_assertSameGPU(state, 2, input, output);
THCTensor_(resizeAs)(state, output, input);
THC_pointwiseApply2(state, output, input, sqrtupdateOutput_functor<real>(eps));
diff --git a/lib/THCUNN/generic/THCUNN.h b/lib/THCUNN/generic/THCUNN.h
index c9d7e2c..bf903b9 100644
--- a/lib/THCUNN/generic/THCUNN.h
+++ b/lib/THCUNN/generic/THCUNN.h
@@ -54,7 +54,7 @@ TH_API void THNN_(BatchNormalization_backward)(
THCTensor *saveMean_,
THCTensor *saveStd_,
bool train,
- double scale,
+ float scale,
double eps);
TH_API void THNN_(BCECriterion_updateOutput)(
@@ -109,7 +109,7 @@ TH_API void THNN_(ELU_updateOutput)(
THCState *state,
THCTensor *input,
THCTensor *output,
- accreal alpha,
+ real alpha,
bool inplace);
TH_API void THNN_(ELU_updateGradInput)(
@@ -118,15 +118,15 @@ TH_API void THNN_(ELU_updateGradInput)(
THCTensor *gradOutput,
THCTensor *gradInput,
THCTensor *output,
- accreal alpha,
+ real alpha,
bool inplace);
TH_API void THNN_(HardTanh_updateOutput)(
THCState *state,
THCTensor *input,
THCTensor *output,
- accreal min_val,
- accreal max_val,
+ real min_val,
+ real max_val,
bool inplace);
TH_API void THNN_(HardTanh_updateGradInput)(
@@ -134,15 +134,15 @@ TH_API void THNN_(HardTanh_updateGradInput)(
THCTensor *input,
THCTensor *gradOutput,
THCTensor *gradInput,
- accreal min_val,
- accreal max_val,
+ real min_val,
+ real max_val,
bool inplace);
TH_API void THNN_(LeakyReLU_updateOutput)(
THCState *state,
THCTensor *input,
THCTensor *output,
- accreal negval,
+ real negval,
bool inplace);
TH_API void THNN_(LeakyReLU_updateGradInput)(
@@ -150,7 +150,7 @@ TH_API void THNN_(LeakyReLU_updateGradInput)(
THCTensor *input,
THCTensor *gradOutput,
THCTensor *gradInput,
- accreal negval,
+ real negval,
bool inplace);
TH_API void THNN_(LogSigmoid_updateOutput)(
@@ -188,14 +188,14 @@ TH_API void THNN_(LookupTable_accGradParameters)(
THCIndexTensor *indices, // [OPTIONAL]
bool scaleGradByFreq,
int paddingValue,
- accreal scale);
+ real scale);
TH_API void THNN_(LookupTable_renorm)(
THCState *state,
THCIndexTensor *idx,
THCTensor *weight,
- accreal maxNorm,
- accreal normType);
+ real maxNorm,
+ real normType);
TH_API void THNN_(L1Cost_updateOutput)(
THCState *state,
@@ -214,7 +214,7 @@ TH_API void THNN_(MarginCriterion_updateOutput)(
THCTensor *target,
THCTensor *output,
bool sizeAverage,
- accreal margin);
+ real margin);
TH_API void THNN_(MarginCriterion_updateGradInput)(
THCState *state,
@@ -222,7 +222,7 @@ TH_API void THNN_(MarginCriterion_updateGradInput)(
THCTensor *target,
THCTensor *gradInput,
bool sizeAverage,
- accreal margin);
+ real margin);
TH_API void THNN_(MSECriterion_updateOutput)(
THCState *state,
@@ -262,7 +262,7 @@ TH_API void THNN_(MultiMarginCriterion_updateOutput)(
bool sizeAverage,
int p,
THCTensor *weights, // [OPTIONAL]
- accreal margin);
+ real margin);
TH_API void THNN_(MultiMarginCriterion_updateGradInput)(
THCState *state,
@@ -272,7 +272,7 @@ TH_API void THNN_(MultiMarginCriterion_updateGradInput)(
bool sizeAverage,
int p,
THCTensor *weights, // [OPTIONAL]
- accreal margin);
+ real margin);
TH_API void THNN_(PReLU_updateOutput)(
THCState *state,
@@ -299,7 +299,7 @@ TH_API void THNN_(PReLU_accGradParameters)(
THCTensor *gradWeightBuf,
THCTensor *gradWeightBuf2,
long nOutputPlane,
- accreal scale);
+ real scale);
TH_API void THNN_(SmoothL1Criterion_updateOutput)(
THCState *state,
@@ -330,8 +330,8 @@ TH_API void THNN_(SparseLinear_accGradParameters)(
THCTensor *gradBias,
THCTensor *weight,
THCTensor *bias,
- accreal weightDecay,
- accreal scale);
+ double weightDecay,
+ double scale);
TH_API void THNN_(SparseLinear_legacyUpdateOutput)(
THCState *state,
@@ -348,8 +348,8 @@ TH_API void THNN_(SparseLinear_legacyAccGradParameters)(
THCTensor *gradBias,
THCTensor *weight,
THCTensor *bias,
- accreal weightDecay,
- accreal scale);
+ double weightDecay,
+ double scale);
TH_API void THNN_(SparseLinear_zeroGradParameters)(
THCState *state,
@@ -364,7 +364,7 @@ TH_API void THNN_(SparseLinear_updateParameters)(
THCTensor *gradWeight,
THCTensor *gradBias,
THCTensor *lastInput,
- accreal learningRate);
+ double learningRate);
TH_API void THNN_(SpatialAdaptiveMaxPooling_updateOutput)(
THCState *state,
@@ -461,7 +461,7 @@ TH_API void THNN_(SpatialConvolutionLocal_accGradParameters)(
int padW, int padH,
long inputWidth, long inputHeight,
long outputWidth, long outputHeight,
- accreal scale);
+ real scale);
TH_API void THNN_(SpatialConvolutionMM_updateOutput)(
THCState *state,
@@ -498,7 +498,7 @@ TH_API void THNN_(SpatialConvolutionMM_accGradParameters)(
int kW, int kH,
int dW, int dH,
int padW, int padH,
- accreal scale);
+ real scale);
TH_API void THNN_(SpatialCrossMapLRN_updateOutput)(
THCState *state,
@@ -506,9 +506,9 @@ TH_API void THNN_(SpatialCrossMapLRN_updateOutput)(
THCTensor *output,
THCTensor *scale,
int size,
- accreal alpha,
- accreal beta,
- accreal k);
+ real alpha,
+ real beta,
+ real k);
TH_API void THNN_(SpatialCrossMapLRN_updateGradInput)(
THCState *state,
@@ -518,9 +518,9 @@ TH_API void THNN_(SpatialCrossMapLRN_updateGradInput)(
THCTensor *scale,
THCTensor *output,
int size,
- accreal alpha,
- accreal beta,
- accreal k);
+ real alpha,
+ real beta,
+ real k);
TH_API void THNN_(SpatialDilatedConvolution_updateOutput)(
THCState *state,
@@ -559,7 +559,7 @@ TH_API void THNN_(SpatialDilatedConvolution_accGradParameters)(
int dW, int dH,
int padW, int padH,
int dilationW, int dilationH,
- accreal scale);
+ real scale);
TH_API void THNN_(SpatialDilatedMaxPooling_updateOutput)(
THCState *state,
@@ -639,7 +639,7 @@ TH_API void THNN_(SpatialFullConvolution_accGradParameters)(
int dW, int dH,
int padW, int padH,
int adjW, int adjH,
- accreal scale);
+ real scale);
TH_API void THNN_(SpatialMaxPooling_updateOutput)(
THCState *state,
@@ -733,7 +733,7 @@ TH_API void THNN_(SpatialSubSampling_accGradParameters)(
THCTensor *gradBias,
int kW, int kH,
int dW, int dH,
- accreal scale);
+ float scale);
TH_API void THNN_(SpatialUpSamplingBilinear_updateOutput)(
THCState *state,
@@ -830,8 +830,8 @@ TH_API void THNN_(SoftPlus_updateOutput)(
THCState *state,
THCTensor *input,
THCTensor *output,
- accreal beta,
- accreal threshold);
+ real beta,
+ real threshold);
TH_API void THNN_(SoftPlus_updateGradInput)(
THCState *state,
@@ -839,21 +839,21 @@ TH_API void THNN_(SoftPlus_updateGradInput)(
THCTensor *gradOutput,
THCTensor *gradInput,
THCTensor *output,
- accreal beta,
- accreal threshold);
+ real beta,
+ real threshold);
TH_API void THNN_(SoftShrink_updateOutput)(
THCState *state,
THCTensor *input,
THCTensor *output,
- accreal lambda);
+ real lambda);
TH_API void THNN_(SoftShrink_updateGradInput)(
THCState *state,
THCTensor *input,
THCTensor *gradOutput,
THCTensor *gradInput,
- accreal lambda);
+ real lambda);
TH_API void THNN_(Square_updateOutput)(
THCState *state,
@@ -870,7 +870,7 @@ TH_API void THNN_(Sqrt_updateOutput)(
THCState *state,
THCTensor *input,
THCTensor *output,
- accreal eps);
+ real eps);
TH_API void THNN_(Sqrt_updateGradInput)(
THCState *state,
@@ -916,7 +916,7 @@ TH_API void THNN_(TemporalConvolution_accGradParameters)(
THCTensor *gradWeight,
THCTensor *gradBias,
int kW, int dW,
- accreal scale);
+ real scale);
TH_API void THNN_(TemporalMaxPooling_updateOutput)(
THCState *state,
@@ -937,8 +937,8 @@ TH_API void THNN_(Threshold_updateOutput)(
THCState *state,
THCTensor *input,
THCTensor *output,
- accreal threshold,
- accreal val,
+ real threshold,
+ real val,
bool inplace);
TH_API void THNN_(Threshold_updateGradInput)(
@@ -946,8 +946,8 @@ TH_API void THNN_(Threshold_updateGradInput)(
THCTensor *input,
THCTensor *gradOutput,
THCTensor *gradInput,
- accreal threshold,
- accreal val,
+ real threshold,
+ real val,
bool inplace);
TH_API void THNN_(VolumetricAveragePooling_updateOutput)(
@@ -996,7 +996,7 @@ TH_API void THNN_(VolumetricConvolution_accGradParameters)(
THCTensor *fgradInput,
int dT, int dW, int dH,
int padT, int padW, int padH,
- accreal scale);
+ real scale);
TH_API void THNN_(VolumetricDilatedConvolution_updateOutput)(
THCState *state,
@@ -1035,7 +1035,7 @@ TH_API void THNN_(VolumetricDilatedConvolution_accGradParameters)(
int dT, int dW, int dH,
int padT, int padW, int padH,
int dilationT, int dilationW, int dilationH,
- accreal scale);
+ real scale);
TH_API void THNN_(VolumetricDilatedMaxPooling_updateOutput)(
THCState *state,
@@ -1095,7 +1095,7 @@ TH_API void THNN_(VolumetricFullConvolution_accGradParameters)(
int dT, int dW, int dH,
int padT, int padW, int padH,
int adjT, int adjW, int adjH,
- accreal scale);
+ real scale);
TH_API void THNN_(VolumetricMaxPooling_updateOutput)(
THCState *state,
diff --git a/lib/THCUNN/generic/TemporalConvolution.cu b/lib/THCUNN/generic/TemporalConvolution.cu
index 5658527..a51894d 100644
--- a/lib/THCUNN/generic/TemporalConvolution.cu
+++ b/lib/THCUNN/generic/TemporalConvolution.cu
@@ -273,9 +273,8 @@ void THNN_(TemporalConvolution_accGradParameters)(
THCTensor *gradWeight,
THCTensor *gradBias,
int kW, int dW,
- accreal scale_) {
+ real scale) {
- real scale = ScalarConvert<accreal, real>::to(scale_);
long nInputFrame;
long nOutputFrame;
diff --git a/lib/THCUNN/generic/Threshold.cu b/lib/THCUNN/generic/Threshold.cu
index 0b7b79e..4f9f622 100644
--- a/lib/THCUNN/generic/Threshold.cu
+++ b/lib/THCUNN/generic/Threshold.cu
@@ -8,12 +8,10 @@ void THNN_(Threshold_updateOutput)(
THCState *state,
THCTensor *input,
THCTensor *output,
- accreal threshold_,
- accreal val_,
+ real threshold,
+ real val,
bool inplace)
{
- real threshold = ScalarConvert<accreal, real>::to(threshold_);
- real val = ScalarConvert<accreal, real>::to(val_);
THCUNN_assertSameGPU(state, 2, input, output);
if (inplace)
@@ -39,12 +37,10 @@ void THNN_(Threshold_updateGradInput)(
THCTensor *input,
THCTensor *gradOutput,
THCTensor *gradInput,
- accreal threshold_,
- accreal val_,
+ real threshold,
+ real val,
bool inplace)
{
- real threshold = ScalarConvert<accreal, real>::to(threshold_);
- real val = ScalarConvert<accreal, real>::to(val_);
THCUNN_check_nElement(state, input, gradOutput);
THCUNN_assertSameGPU(state, 3, input, gradInput, gradOutput);
diff --git a/lib/THCUNN/generic/VolumetricConvolution.cu b/lib/THCUNN/generic/VolumetricConvolution.cu
index 5b982c9..d6da545 100644
--- a/lib/THCUNN/generic/VolumetricConvolution.cu
+++ b/lib/THCUNN/generic/VolumetricConvolution.cu
@@ -362,9 +362,8 @@ void THNN_(VolumetricConvolution_accGradParameters)(
THCTensor *fgradInput,
int dT, int dW, int dH,
int padT, int padW, int padH,
- accreal scale_)
+ real scale)
{
- real scale = ScalarConvert<accreal, real>::to(scale_);
THCTensor *columns = finput;
THCTensor *ones = fgradInput;
THCUNN_assertSameGPU(state, 6, input, gradOutput, gradWeight, gradBias, columns, ones);
diff --git a/lib/THCUNN/generic/VolumetricDilatedConvolution.cu b/lib/THCUNN/generic/VolumetricDilatedConvolution.cu
index ffeea7f..b0145a5 100644
--- a/lib/THCUNN/generic/VolumetricDilatedConvolution.cu
+++ b/lib/THCUNN/generic/VolumetricDilatedConvolution.cu
@@ -336,9 +336,8 @@ void THNN_(VolumetricDilatedConvolution_accGradParameters)(
int dT, int dW, int dH,
int padT, int padW, int padH,
int dilationT, int dilationW, int dilationH,
- accreal scale_) {
+ real scale) {
- real scale = ScalarConvert<accreal, real>::to(scale_);
THCUNN_assertSameGPU(state, 5, input, gradOutput, gradWeight, columns, ones);
if (gradBias) {
THCUNN_assertSameGPU(state, 2, gradWeight, gradBias);
diff --git a/lib/THCUNN/generic/VolumetricFullConvolution.cu b/lib/THCUNN/generic/VolumetricFullConvolution.cu
index 127babc..334c7da 100644
--- a/lib/THCUNN/generic/VolumetricFullConvolution.cu
+++ b/lib/THCUNN/generic/VolumetricFullConvolution.cu
@@ -344,9 +344,8 @@ void THNN_(VolumetricFullConvolution_accGradParameters)(
int dT, int dW, int dH,
int padT, int padW, int padH,
int adjT, int adjW, int adjH,
- accreal scale_)
+ real scale)
{
- real scale = ScalarConvert<accreal, real>::to(scale_);
THCTensor *columns = finput;
THCTensor *ones = fgradInput;
diff --git a/test.lua b/test.lua
index 5ab07bf..c3ed9bb 100644
--- a/test.lua
+++ b/test.lua
@@ -365,17 +365,17 @@ function cunntest.Square_transposed()
end
function cunntest.SoftShrink_forward()
- local r = math.random()
+ local r = THC.THC_half2float(THC.THC_float2half(math.random()))
pointwise_forward(nn.SoftShrink(r), 'SoftShrink', precision_forward)
end
function cunntest.SoftShrink_backward()
- local r = math.random()
+ local r = THC.THC_half2float(THC.THC_float2half(math.random()))
pointwise_backward(nn.SoftShrink(r), 'SoftShrink', precision_backward)
end
function cunntest.SoftShrink_transposed()
- local r = math.random()
+ local r = THC.THC_half2float(THC.THC_float2half(math.random()))
pointwise_transposed(nn.SoftShrink(r), 'SoftShrink', precision_backward)
end
@@ -2045,8 +2045,8 @@ function cunntest.SpatialMaxPooling_forward()
local sj = math.random(1,4)
local outi = math.random(32,256)
local outj = math.random(32,256)
- local padi = math.random(0,math.floor(ki/2)-1)
- local padj = math.random(0,math.floor(kj/2)-1)
+ local padi = math.random(0,ki/2-1)
+ local padj = math.random(0,kj/2-1)
local ini = (outi-1)*si+ki - padi*2
local inj = (outj-1)*sj+kj - padj*2
local ceil_mode = math.random(0,1) == 1
@@ -2083,8 +2083,8 @@ function cunntest.SpatialMaxPooling_forward_batch()
local sj = math.random(2,4)
local outi = math.random(32,256)
local outj = math.random(32,256)
- local padi = math.random(0,math.floor(ki/2)-1)
- local padj = math.random(0,math.floor(kj/2)-1)
+ local padi = math.random(0,ki/2-1)
+ local padj = math.random(0,kj/2-1)
local ini = (outi-1)*si+ki - padi*2
local inj = (outj-1)*sj+kj - padj*2
local ceil_mode = math.random(0,1) == 1
@@ -2118,8 +2118,8 @@ function cunntest.SpatialMaxUnpooling_forward_batch()
local sj = kj
local outi = math.random(32,256)
local outj = math.random(32,256)
- local padi = math.random(0,math.floor(ki/2)-1)
- local padj = math.random(0,math.floor(kj/2)-1)
+ local padi = math.random(0,ki/2-1)
+ local padj = math.random(0,kj/2-1)
local ceil_mode = math.random(0,1) == 1
local fun = ceil_mode and torch.ceil or torch.floor
local ini = fun((outi + padi*2 - ki)/si) +1
@@ -2159,8 +2159,8 @@ function cunntest.SpatialMaxPooling_backward()
local sj = math.random(1,4)
local outi = math.random(32,64)
local outj = math.random(32,64)
- local padi = math.random(0,math.floor(ki/2)-1)
- local padj = math.random(0,math.floor(kj/2)-1)
+ local padi = math.random(0,ki/2-1)
+ local padj = math.random(0,kj/2-1)
local ini = (outi-1)*si+ki - padi*2
local inj = (outj-1)*sj+kj - padj*2
local ceil_mode = true--math.random(0,1) == 1
@@ -2203,8 +2203,8 @@ function cunntest.SpatialMaxPooling_backward_batch()
local sj = math.random(2,4)
local outi = math.random(32,64)
local outj = math.random(32,64)
- local padi = math.random(0,math.floor(ki/2)-1)
- local padj = math.random(0,math.floor(kj/2)-1)
+ local padi = math.random(0,ki/2-1)
+ local padj = math.random(0,kj/2-1)
local ini = (outi-1)*si+ki - padi*2
local inj = (outj-1)*sj+kj - padj*2
local ceil_mode = math.random(0,1) == 1
@@ -2246,8 +2246,8 @@ function cunntest.SpatialMaxUnpooling_backward_batch()
local sj = kj
local outi = math.random(32,256)
local outj = math.random(32,256)
- local padi = math.random(0,math.floor(ki/2)-1)
- local padj = math.random(0,math.floor(kj/2)-1)
+ local padi = math.random(0,ki/2-1)
+ local padj = math.random(0,kj/2-1)
local ceil_mode = math.random(0,1) == 1
local fun = ceil_mode and torch.ceil or torch.floor
local ini = fun((outi + padi*2 - ki)/si) +1
@@ -2296,8 +2296,8 @@ function cunntest.SpatialDilatedMaxPooling_forward()
local sj = math.random(1,4)
local outi = math.random(32,256)
local outj = math.random(32,256)
- local padi = math.random(0,math.floor(ki/2)-1)
- local padj = math.random(0,math.floor(kj/2)-1)
+ local padi = math.random(0,ki/2-1)
+ local padj = math.random(0,kj/2-1)
local dilationi = math.random(1,10)
local dilationj = math.random(1,10)
local ini = (outi-1)*si+(dilationi*(ki-1)+1)-2*padi
@@ -2336,8 +2336,8 @@ function cunntest.SpatialDilatedMaxPooling_forward_batch()
local sj = math.random(2,4)
local outi = math.random(32,256)
local outj = math.random(32,256)
- local padi = math.random(0,math.floor(ki/2)-1)
- local padj = math.random(0,math.floor(kj/2)-1)
+ local padi = math.random(0,ki/2-1)
+ local padj = math.random(0,kj/2-1)
local dilationi = math.random(1,10)
local dilationj = math.random(1,10)
local ini = (outi-1)*si+(dilationi*(ki-1)+1)-2*padi
@@ -2372,8 +2372,8 @@ function cunntest.SpatialDilatedMaxPooling_backward()
local sj = math.random(1,4)
local outi = math.random(32,64)
local outj = math.random(32,64)
- local padi = math.random(0,math.floor(ki/2)-1)
- local padj = math.random(0,math.floor(kj/2)-1)
+ local padi = math.random(0,ki/2-1)
+ local padj = math.random(0,kj/2-1)
local dilationi = math.random(1,10)
local dilationj = math.random(1,10)
local ini = (outi-1)*si+(dilationi*(ki-1)+1)-2*padi
@@ -2417,8 +2417,8 @@ function cunntest.SpatialDilatedMaxPooling_backward_batch()
local sj = math.random(2,4)
local outi = math.random(32,64)
local outj = math.random(32,64)
- local padi = math.random(0,math.floor(ki/2)-1)
- local padj = math.random(0,math.floor(kj/2)-1)
+ local padi = math.random(0,ki/2-1)
+ local padj = math.random(0,kj/2-1)
local dilationi = math.random(1,10)
local dilationj = math.random(1,10)
local ini = (outi-1)*si+(dilationi*(ki-1)+1)-2*padi
@@ -2611,8 +2611,8 @@ function cunntest.SpatialAveragePooling_forward()
local sj = math.random(1,kj)
local outi = math.random(32,256)
local outj = math.random(32,256)
- local padi = math.random(0,math.floor(ki/2)-1)
- local padj = math.random(0,math.floor(kj/2)-1)
+ local padi = math.random(0,ki/2-1)
+ local padj = math.random(0,kj/2-1)
local ini = (outi-1)*si+ki - padi*2
local inj = (outj-1)*sj+kj - padj*2
local ceil_mode = math.random(0,1) == 1
@@ -2650,8 +2650,8 @@ function cunntest.SpatialAveragePooling_forward_batch()
local sj = math.random(1,kj)
local outi = math.random(32,256)
local outj = math.random(32,256)
- local padi = math.random(0,math.floor(ki/2)-1)
- local padj = math.random(0,math.floor(kj/2)-1)
+ local padi = math.random(0,ki/2-1)
+ local padj = math.random(0,kj/2-1)
local ini = (outi-1)*si+ki - padi*2
local inj = (outj-1)*sj+kj - padj*2
local ceil_mode = math.random(0,1) == 1
@@ -2688,8 +2688,8 @@ function cunntest.SpatialAveragePooling_backward()
local sj = math.random(1,kj)
local outi = math.random(32,64)
local outj = math.random(32,64)
- local padi = math.random(0,math.floor(ki/2)-1)
- local padj = math.random(0,math.floor(kj/2)-1)
+ local padi = math.random(0,ki/2-1)
+ local padj = math.random(0,kj/2-1)
local ini = (outi-1)*si+ki - padi*2
local inj = (outj-1)*sj+kj - padj*2
local ceil_mode = math.random(0,1) == 1
@@ -2735,8 +2735,8 @@ function cunntest.SpatialAveragePooling_backward_batch()
local sj = math.random(1,kj)
local outi = math.random(32,64)
local outj = math.random(32,64)
- local padi = math.random(0,math.floor(ki/2)-1)
- local padj = math.random(0,math.floor(kj/2)-1)
+ local padi = math.random(0,ki/2-1)
+ local padj = math.random(0,kj/2-1)
local ini = (outi-1)*si+ki - padi*2
local inj = (outj-1)*sj+kj - padj*2
local ceil_mode = math.random(0,1) == 1
@@ -3401,6 +3401,9 @@ function cunntest.mse()
local cout = cmod:forward(cinput,ctarget)
local cgin = cmod:backward(cinput,ctarget)
+ if (typename == 'torch.CudaHalfTensor') then
+ fout = THC.THC_half2float(THC.THC_float2half(fout))
+ end
mytester:assertlt(math.abs(fout-cout), precision_forward_type(0.02, typename),
string.format('error on output with %s', typename))
local gerr = cgin:double() - fgin:double()
@@ -3432,6 +3435,9 @@ function cunntest.SmoothL1()
local cout = cmod:forward(cinput,ctarget)
local cgin = cmod:backward(cinput,ctarget)
+ if (typename == 'torch.CudaHalfTensor') then
+ fout = THC.THC_half2float(THC.THC_float2half(fout))
+ end
mytester:assertlt(math.abs(fout-cout), 0.01, string.format('error on output with %s', typename))
local gerr = cgin:double() - fgin:double()
mytester:assertlt(gerr:abs():max(), precision_forward_type(precision_forward, typename),
@@ -3994,6 +4000,9 @@ function cunntest.l1cost()
local cout = cmod:forward(cinput)
local cgin = cmod:backward(cinput)
+ if (typename == 'torch.CudaHalfTensor') then
+ fout = THC.THC_half2float(THC.THC_float2half(fout))
+ end
mytester:assertlt(math.abs(fout-cout), precision_forward_type(precision_forward, typename),
string.format('error on output with %s', typename))
local gerr = cgin:double() - fgin:double()
@@ -4395,9 +4404,9 @@ function cunntest.VolumetricMaxPooling_forward()
local iT = math.random(kT*2, 60)
local iH = math.random(kH*2, 60)
local iW = math.random(kW*2, 60)
- local padT = math.random(0,math.floor(kT/2)-1)
- local padH = math.random(0,math.floor(kH/2)-1)
- local padW = math.random(0,math.floor(kW/2)-1)
+ local padT = math.random(0,kT/2-1)
+ local padH = math.random(0,kH/2-1)
+ local padW = math.random(0,kW/2-1)
local iF = math.random(1, 16) -- features
local oT = math.floor((iT - kT + 2*padT) / dT + 1)
local oH = math.floor((iH - kH + 2*padH) / dH + 1)
@@ -4431,9 +4440,9 @@ function cunntest.VolumetricMaxPooling_backward()
local iT = math.random(kT*2, 60)
local iH = math.random(kH*2, 60)
local iW = math.random(kW*2, 60)
- local padT = math.random(0,math.floor(kT/2)-1)
- local padH = math.random(0,math.floor(kH/2)-1)
- local padW = math.random(0,math.floor(kW/2)-1)
+ local padT = math.random(0,kT/2-1)
+ local padH = math.random(0,kH/2-1)
+ local padW = math.random(0,kW/2-1)
local iF = math.random(1, 16) -- features
local oT = math.floor((iT - kT + 2*padT) / dT + 1)
local oH = math.floor((iH - kH + 2*padH) / dH + 1)
@@ -4475,9 +4484,9 @@ function cunntest.VolumetricDilatedMaxPooling_forward_batch()
local outt = math.random(1,10)
local outi = math.random(1,33)
local outj = math.random(1,33)
- local padt = math.random(0,math.floor(kt/2)-1)
- local padi = math.random(0,math.floor(ki/2)-1)
- local padj = math.random(0,math.floor(kj/2)-1)
+ local padt = math.random(0,kt/2-1)
+ local padi = math.random(0,ki/2-1)
+ local padj = math.random(0,kj/2-1)
local dilationt = math.random(1,10)
local dilationi = math.random(1,10)
local dilationj = math.random(1,10)
@@ -4519,9 +4528,9 @@ function cunntest.VolumetricDilatedMaxPooling_backward_batch()
local outt = math.random(8,16)
local outi = math.random(8,16)
local outj = math.random(8,16)
- local padt = math.random(0,math.floor(kt/2)-1)
- local padi = math.random(0,math.floor(ki/2)-1)
- local padj = math.random(0,math.floor(kj/2)-1)
+ local padt = math.random(0,kt/2-1)
+ local padi = math.random(0,ki/2-1)
+ local padj = math.random(0,kj/2-1)
local dilationt = math.random(1,10)
local dilationi = math.random(1,10)
local dilationj = math.random(1,10)
@@ -4569,9 +4578,9 @@ function cunntest.VolumetricMaxUnpooling_forward_batch()
local outt = math.random(32,128)
local outi = math.random(32,128)
local outj = math.random(32,128)
- local padt = math.random(0,math.floor(kt/2)-1)
- local padi = math.random(0,math.floor(ki/2)-1)
- local padj = math.random(0,math.floor(kj/2)-1)
+ local padt = math.random(0,kt/2-1)
+ local padi = math.random(0,ki/2-1)
+ local padj = math.random(0,kj/2-1)
local it = math.max(((outt + padt*2 - kt)/st) +1, kt)
local ii = math.max(((outi + padi*2 - ki)/si) +1, ki)
local ij = math.max(((outj + padj*2 - kj)/sj) +1, kj)
@@ -4610,9 +4619,9 @@ function cunntest.VolumetricMaxUnpooling_backward_batch()
local outt = math.random(32,128)
local outi = math.random(32,128)
local outj = math.random(32,128)
- local padt = math.random(0,math.floor(kt/2)-1)
- local padi = math.random(0,math.floor(ki/2)-1)
- local padj = math.random(0,math.floor(kj/2)-1)
+ local padt = math.random(0,kt/2-1)
+ local padi = math.random(0,ki/2-1)
+ local padj = math.random(0,kj/2-1)
local it = math.max(((outt + padt*2 - kt)/st) +1, kt)
local ii = math.max(((outi + padi*2 - ki)/si) +1, ki)
local ij = math.max(((outj + padj*2 - kj)/sj) +1, kj)
@@ -4937,8 +4946,8 @@ function cunntest.VolumetricFullConvolution_pair_test()
local dT = math.random(1,3)
local dH = math.random(1,3)
local dW = dH
- local pT = math.floor((kT-1)/2)
- local pH = math.floor((kH-1)/2)
+ local pT = (kT-1)/2
+ local pH = (kH-1)/2
local pW = pH
local inChan = math.random(1,32)