Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cunn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Pavan Yalamanchili <pyalamanchili@twitter.com> 2017-02-17 04:25:33 +0300
committer Pavan Yalamanchili <pyalamanchili@twitter.com> 2017-02-17 04:33:03 +0300
commit 3996dbb87ec79d087c37bc6f4fe8f23a3767c88c (patch)
tree d79dced267c31b5250930087ebd0e9fee21be340
parent 618f847d94ad65baef1c1614ed241d6e4bea7151 (diff)
Convert real to accreal in libTHCUNN
- This reverts commit 0d85922d116879448485ef88ae21e83a9255a0b0. - Includes fixes for TemporalRowConvolution
-rw-r--r--THCUNN.lua40
-rw-r--r--lib/THCUNN/SparseLinear.cu10
-rw-r--r--lib/THCUNN/generic/BatchNormalization.cu2
-rw-r--r--lib/THCUNN/generic/ELU.cu6
-rw-r--r--lib/THCUNN/generic/HardTanh.cu14
-rw-r--r--lib/THCUNN/generic/LeakyReLU.cu8
-rw-r--r--lib/THCUNN/generic/LookupTable.cu9
-rw-r--r--lib/THCUNN/generic/MarginCriterion.cu7
-rw-r--r--lib/THCUNN/generic/MultiMarginCriterion.cu6
-rw-r--r--lib/THCUNN/generic/PReLU.cu3
-rw-r--r--lib/THCUNN/generic/SoftPlus.cu12
-rw-r--r--lib/THCUNN/generic/SoftShrink.cu6
-rw-r--r--lib/THCUNN/generic/SparseLinear.cu10
-rw-r--r--lib/THCUNN/generic/SpatialConvolutionLocal.cu3
-rw-r--r--lib/THCUNN/generic/SpatialConvolutionMM.cu3
-rw-r--r--lib/THCUNN/generic/SpatialCrossMapLRN.cu24
-rw-r--r--lib/THCUNN/generic/SpatialDilatedConvolution.cu3
-rw-r--r--lib/THCUNN/generic/SpatialFullConvolution.cu3
-rw-r--r--lib/THCUNN/generic/SpatialSubSampling.cu2
-rw-r--r--lib/THCUNN/generic/Sqrt.cu3
-rw-r--r--lib/THCUNN/generic/THCUNN.h98
-rw-r--r--lib/THCUNN/generic/TemporalConvolution.cu3
-rw-r--r--lib/THCUNN/generic/TemporalRowConvolution.cu3
-rw-r--r--lib/THCUNN/generic/Threshold.cu12
-rw-r--r--lib/THCUNN/generic/VolumetricConvolution.cu3
-rw-r--r--lib/THCUNN/generic/VolumetricDilatedConvolution.cu3
-rw-r--r--lib/THCUNN/generic/VolumetricFullConvolution.cu23
-rw-r--r--test.lua111
28 files changed, 230 insertions, 200 deletions
diff --git a/THCUNN.lua b/THCUNN.lua
index 573690b..d5bf1c2 100644
--- a/THCUNN.lua
+++ b/THCUNN.lua
@@ -45,7 +45,7 @@ local replacements_generic =
['THCTensor'] = 'THCudaTensor',
['THCIndexTensor'] = 'THCudaLongTensor',
['TYPE'] = 'Cuda',
- ['real'] = 'float'
+ ['real'] = 'float',
},
{
['THCTensor'] = 'THCudaDoubleTensor',
@@ -55,6 +55,13 @@ local replacements_generic =
}
}
+-- gsub(s, 'real', 'float') changes accreal to accfloat.
+-- typedef accfloat ahead of time.
+ffi.cdef("typedef float accfloat;")
+-- gsub(s, 'real', 'double') changes accreal to accdouble.
+-- typedef accdouble ahead of time
+ffi.cdef("typedef double accdouble;")
+
if cutorch.hasHalf then
ffi.cdef("half THC_float2half(float a);")
ffi.cdef("float THC_half2float(half a);")
@@ -63,9 +70,12 @@ if cutorch.hasHalf then
['THCTensor'] = 'THCudaHalfTensor',
['THCIndexTensor'] = 'THCudaLongTensor',
['TYPE'] = 'CudaHalf',
- ['real'] = 'half'
+ ['real'] = 'half',
}
table.insert(replacements_generic, half_replacement)
+ -- gsub(s, 'real', 'half') changes accreal to acchalf.
+ -- typedef acchalf ahead of time
+ ffi.cdef("typedef float acchalf;")
end
for i=1,#replacements_generic do
@@ -133,29 +143,9 @@ THNN.kernels['torch.CudaDoubleTensor'] = THNN.bind(THCUNN.C, function_names_gene
torch.getmetatable('torch.CudaDoubleTensor').THNN = THNN.kernels['torch.CudaDoubleTensor']
if cutorch.hasHalf then
--- in order to call 'half' functions from lua, convert real arguments from
--- to half since there is no other defined conversion
-local transform_reals_to_half = function(func_name, real_args, ...)
- local t = {}
- -- this select logic is necessary to deal with nil arguments
- for i = 1, select('#', ...) do
- t[i] = select(i, ...)
- end
- for k,v in ipairs(real_args[func_name]) do
- -- first argument (THCState) is added implicitly by bind
- t[v-1] = THC.THC_float2half(t[v-1])
- end
- return t
-end
-
-local raw_half_functions = THNN.bind(THCUNN.C, function_names_generic, 'CudaHalf', THCUNN.getState)
-for k,v in pairs(raw_half_functions) do
- -- select required in case there are trailing nils
- raw_half_functions[k] = function(...) v(unpack(transform_reals_to_half(k, real_args, ...), 1, select("#",...)))
-end
-end
-THNN.kernels['torch.CudaHalfTensor'] = raw_half_functions
-torch.getmetatable('torch.CudaHalfTensor').THNN = THNN.kernels['torch.CudaHalfTensor']
+ local raw_half_functions = THNN.bind(THCUNN.C, function_names_generic, 'CudaHalf', THCUNN.getState)
+ THNN.kernels['torch.CudaHalfTensor'] = raw_half_functions
+ torch.getmetatable('torch.CudaHalfTensor').THNN = THNN.kernels['torch.CudaHalfTensor']
end
local function Module__converter(type)
diff --git a/lib/THCUNN/SparseLinear.cu b/lib/THCUNN/SparseLinear.cu
index a7ffa1e..f36206f 100644
--- a/lib/THCUNN/SparseLinear.cu
+++ b/lib/THCUNN/SparseLinear.cu
@@ -34,8 +34,8 @@ void THNN_CudaHalfSparseLinear_accGradParameters(
THCudaHalfTensor *gradBias,
THCudaHalfTensor *weight,
THCudaHalfTensor *bias,
- double weightDecay,
- double scale) {
+ float weightDecay,
+ float scale) {
THError("THCudaHalfTensor not supported with SparseLinear");
}
@@ -56,8 +56,8 @@ void THNN_CudaHalfSparseLinear_legacyAccGradParameters(
THCudaHalfTensor *gradBias,
THCudaHalfTensor *weight,
THCudaHalfTensor *bias,
- double weightDecay,
- double scale) {
+ float weightDecay,
+ float scale) {
THError("THCudaHalfTensor not supported with SparseLinear");
}
@@ -76,7 +76,7 @@ void THNN_CudaHalfSparseLinear_updateParameters(
THCudaHalfTensor *gradWeight,
THCudaHalfTensor *gradBias,
THCudaHalfTensor *lastInput,
- double learningRate) {
+ float learningRate) {
THError("THCudaHalfTensor not supported with SparseLinear");
}
#endif
diff --git a/lib/THCUNN/generic/BatchNormalization.cu b/lib/THCUNN/generic/BatchNormalization.cu
index cbe99f3..d42f18e 100644
--- a/lib/THCUNN/generic/BatchNormalization.cu
+++ b/lib/THCUNN/generic/BatchNormalization.cu
@@ -69,7 +69,7 @@ void THNN_(BatchNormalization_backward)(
THCState *state, THCTensor *input_, THCTensor *gradOutput_,
THCTensor *gradInput_, THCTensor *gradWeight_, THCTensor *gradBias_,
THCTensor *weight_, THCTensor *runningMean_, THCTensor *runningVar_,
- THCTensor *saveMean_, THCTensor *saveStd_, bool train, float scale, double eps) {
+ THCTensor *saveMean_, THCTensor *saveStd_, bool train, double scale, double eps) {
THCUNN_check_shape(state, input_, gradOutput_);
DeviceTensor3 input = devicetensor<3>(state, input_);
diff --git a/lib/THCUNN/generic/ELU.cu b/lib/THCUNN/generic/ELU.cu
index 0beb5a1..4b8da27 100644
--- a/lib/THCUNN/generic/ELU.cu
+++ b/lib/THCUNN/generic/ELU.cu
@@ -9,9 +9,10 @@ void THNN_(ELU_updateOutput)(
THCState *state,
THCTensor *input,
THCTensor *output,
- real alpha,
+ accreal alpha_,
bool inplace)
{
+ real alpha = ScalarConvert<accreal, real>::to(alpha_);
THCUNN_assertSameGPU(state, 2, input, output);
if (inplace)
@@ -33,9 +34,10 @@ void THNN_(ELU_updateGradInput)(
THCTensor *gradOutput,
THCTensor *gradInput,
THCTensor *output,
- real alpha,
+ accreal alpha_,
bool inplace)
{
+ real alpha = ScalarConvert<accreal, real>::to(alpha_);
THCUNN_check_nElement(state, input, gradOutput);
THCUNN_assertSameGPU(state, 3, output, gradOutput, gradInput);
diff --git a/lib/THCUNN/generic/HardTanh.cu b/lib/THCUNN/generic/HardTanh.cu
index 0651431..47835f0 100644
--- a/lib/THCUNN/generic/HardTanh.cu
+++ b/lib/THCUNN/generic/HardTanh.cu
@@ -8,10 +8,13 @@ void THNN_(HardTanh_updateOutput)(
THCState *state,
THCTensor *input,
THCTensor *output,
- real min_val,
- real max_val,
+ accreal min_val_,
+ accreal max_val_,
bool inplace)
{
+ real min_val = ScalarConvert<accreal, real>::to(min_val_);
+ real max_val = ScalarConvert<accreal, real>::to(max_val_);
+
THCUNN_assertSameGPU(state, 2, input, output);
if(inplace)
{
@@ -31,10 +34,13 @@ void THNN_(HardTanh_updateGradInput)(
THCTensor *input,
THCTensor *gradOutput,
THCTensor *gradInput,
- real min_val,
- real max_val,
+ accreal min_val_,
+ accreal max_val_,
bool inplace)
{
+ real min_val = ScalarConvert<accreal, real>::to(min_val_);
+ real max_val = ScalarConvert<accreal, real>::to(max_val_);
+
THCUNN_check_nElement(state, input, gradOutput);
THCUNN_assertSameGPU(state, 3, input, gradOutput, gradInput);
diff --git a/lib/THCUNN/generic/LeakyReLU.cu b/lib/THCUNN/generic/LeakyReLU.cu
index 23cf59a..179819d 100644
--- a/lib/THCUNN/generic/LeakyReLU.cu
+++ b/lib/THCUNN/generic/LeakyReLU.cu
@@ -8,9 +8,11 @@ void THNN_(LeakyReLU_updateOutput)(
THCState *state,
THCTensor *input,
THCTensor *output,
- real negval,
+ accreal negval_,
bool inplace)
{
+ real negval = ScalarConvert<accreal, real>::to(negval_);
+
THCUNN_assertSameGPU(state, 2, input, output);
if (inplace)
@@ -32,9 +34,11 @@ void THNN_(LeakyReLU_updateGradInput)(
THCTensor *input,
THCTensor *gradOutput,
THCTensor *gradInput,
- real negval,
+ accreal negval_,
bool inplace)
{
+ real negval = ScalarConvert<accreal, real>::to(negval_);
+
THCUNN_check_nElement(state, input, gradOutput);
THCUNN_assertSameGPU(state, 3, input, gradInput, gradOutput);
diff --git a/lib/THCUNN/generic/LookupTable.cu b/lib/THCUNN/generic/LookupTable.cu
index bd59a04..fa7c5ac 100644
--- a/lib/THCUNN/generic/LookupTable.cu
+++ b/lib/THCUNN/generic/LookupTable.cu
@@ -12,8 +12,9 @@ void THNN_(LookupTable_accGradParameters)(
THCIndexTensor *indices,
bool scaleGradByFreq,
int paddingValue,
- real scale)
+ accreal scale_)
{
+ real scale = ScalarConvert<accreal, real>::to(scale_);
THCUNN_assertSameGPU(state, 5, input, gradOutput, gradWeight, sorted, indices);
gradOutput = THCTensor_(newContiguous)(state, gradOutput);
if (!(THCIndexTensor_(isContiguous)(state, input) &&
@@ -119,9 +120,11 @@ void THNN_(LookupTable_renorm)(
THCState *state,
THCIndexTensor *idx,
THCTensor *weight,
- real maxNorm,
- real normType)
+ accreal maxNorm_,
+ accreal normType_)
{
+ real maxNorm = ScalarConvert<accreal, real>::to(maxNorm_);
+ real normType = ScalarConvert<accreal, real>::to(normType_);
THCUNN_assertSameGPU(state, 2, idx, weight);
if (!(THCIndexTensor_(isContiguous)(state, idx) &&
THCTensor_(isContiguous)(state, weight)))
diff --git a/lib/THCUNN/generic/MarginCriterion.cu b/lib/THCUNN/generic/MarginCriterion.cu
index d5678ec..221f9d9 100644
--- a/lib/THCUNN/generic/MarginCriterion.cu
+++ b/lib/THCUNN/generic/MarginCriterion.cu
@@ -8,8 +8,9 @@ void THNN_(MarginCriterion_updateOutput)(
THCTensor *target,
THCTensor *output,
bool sizeAverage,
- real margin)
+ accreal margin_)
{
+ real margin = ScalarConvert<accreal, real>::to(margin_);
THCUNN_check_nElement(state, input, target);
THCUNN_check_dim_size(state, output, 1, 0, 1);
THCUNN_assertSameGPU(state, 2, input, target);
@@ -40,8 +41,10 @@ void THNN_(MarginCriterion_updateGradInput)(
THCTensor *target,
THCTensor *gradInput,
bool sizeAverage,
- real margin)
+ accreal margin_)
{
+ real margin = ScalarConvert<accreal, real>::to(margin_);
+
THCUNN_check_nElement(state, input, target);
THCUNN_assertSameGPU(state, 3, input, target, gradInput);
diff --git a/lib/THCUNN/generic/MultiMarginCriterion.cu b/lib/THCUNN/generic/MultiMarginCriterion.cu
index 8026331..c3ff2d6 100644
--- a/lib/THCUNN/generic/MultiMarginCriterion.cu
+++ b/lib/THCUNN/generic/MultiMarginCriterion.cu
@@ -11,8 +11,9 @@ void THNN_(MultiMarginCriterion_updateOutput)(
bool sizeAverage,
int p,
THCTensor *weights,
- real margin)
+ accreal margin_)
{
+ real margin = ScalarConvert<accreal, real>::to(margin_);
THCUNN_assertSameGPU(state, 2, input, target);
input = THCTensor_(newContiguous)(state, input);
if(weights)
@@ -102,8 +103,9 @@ void THNN_(MultiMarginCriterion_updateGradInput)(
bool sizeAverage,
int p,
THCTensor *weights,
- real margin)
+ accreal margin_)
{
+ real margin = ScalarConvert<accreal, real>::to(margin_);
THCUNN_assertSameGPU(state, 3, input, gradInput, target);
input = THCTensor_(newContiguous)(state, input);
THCTensor_(resizeAs)(state, gradInput, input);
diff --git a/lib/THCUNN/generic/PReLU.cu b/lib/THCUNN/generic/PReLU.cu
index 89087fb..db9b0d2 100644
--- a/lib/THCUNN/generic/PReLU.cu
+++ b/lib/THCUNN/generic/PReLU.cu
@@ -92,8 +92,9 @@ void THNN_(PReLU_accGradParameters)(
THCTensor *gradWeightBuf,
THCTensor *gradWeightBuf2,
long nOutputPlane,
- real scale)
+ accreal scale_)
{
+ real scale = ScalarConvert<accreal, real>::to(scale_);
THCUNN_check_nElement(state, input, gradOutput);
// use grad input for temporary storage, then call updateGradInput again
diff --git a/lib/THCUNN/generic/SoftPlus.cu b/lib/THCUNN/generic/SoftPlus.cu
index e72038e..17cde70 100644
--- a/lib/THCUNN/generic/SoftPlus.cu
+++ b/lib/THCUNN/generic/SoftPlus.cu
@@ -8,9 +8,11 @@ void THNN_(SoftPlus_updateOutput)(
THCState *state,
THCTensor *input,
THCTensor *output,
- real beta,
- real threshold)
+ accreal beta_,
+ accreal threshold_)
{
+ real beta = ScalarConvert<accreal, real>::to(beta_);
+ real threshold = ScalarConvert<accreal, real>::to(threshold_);
THCUNN_assertSameGPU(state, 2, input, output);
THCTensor_(resizeAs)(state, output, input);
THC_pointwiseApply2(state, output, input, softPlusupdateOutput_functor<real>(threshold, beta));
@@ -22,9 +24,11 @@ void THNN_(SoftPlus_updateGradInput)(
THCTensor *gradOutput,
THCTensor *gradInput,
THCTensor *output,
- real beta,
- real threshold)
+ accreal beta_,
+ accreal threshold_)
{
+ real beta = ScalarConvert<accreal, real>::to(beta_);
+ real threshold = ScalarConvert<accreal, real>::to(threshold_);
THCUNN_check_nElement(state, input, gradOutput);
THCUNN_assertSameGPU(state, 4, input, output, gradOutput, gradInput);
THCTensor_(resizeAs)(state, gradInput, output);
diff --git a/lib/THCUNN/generic/SoftShrink.cu b/lib/THCUNN/generic/SoftShrink.cu
index 261593f..9e47695 100644
--- a/lib/THCUNN/generic/SoftShrink.cu
+++ b/lib/THCUNN/generic/SoftShrink.cu
@@ -8,8 +8,9 @@ void THNN_(SoftShrink_updateOutput)(
THCState *state,
THCTensor *input,
THCTensor *output,
- real lambda)
+ accreal lambda_)
{
+ real lambda = ScalarConvert<accreal, real>::to(lambda_);
THCUNN_assertSameGPU(state, 2, input, output);
THCTensor_(resizeAs)(state, output, input);
THC_pointwiseApply2(state, output, input, SoftShrinkUpdateOutput<real>(lambda));
@@ -21,8 +22,9 @@ void THNN_(SoftShrink_updateGradInput)(
THCTensor *input,
THCTensor *gradOutput,
THCTensor *gradInput,
- real lambda)
+ accreal lambda_)
{
+ real lambda = ScalarConvert<accreal, real>::to(lambda_);
THCUNN_check_nElement(state, input, gradOutput);
THCUNN_assertSameGPU(state, 3, input, gradOutput, gradInput);
THCTensor_(resizeAs)(state, gradInput, input);
diff --git a/lib/THCUNN/generic/SparseLinear.cu b/lib/THCUNN/generic/SparseLinear.cu
index f22b233..6838cac 100644
--- a/lib/THCUNN/generic/SparseLinear.cu
+++ b/lib/THCUNN/generic/SparseLinear.cu
@@ -127,8 +127,8 @@ void THNN_(SparseLinear_accGradParameters)(
THCTensor *gradBias,
THCTensor *weight,
THCTensor *bias,
- double weightDecay,
- double scale)
+ accreal weightDecay,
+ accreal scale)
{
long outDim = THCTensor_(size)(state, weight, 0);
long inDim = THCTensor_(size)(state, weight, 1);
@@ -237,8 +237,8 @@ void THNN_(SparseLinear_legacyAccGradParameters)(
THCTensor *gradBias,
THCTensor *weight,
THCTensor *bias,
- double weightDecay,
- double scale) {
+ accreal weightDecay,
+ accreal scale) {
THError("CUDA does not support legacy input format, please use a table of nnz x 2 vectors");
}
@@ -259,7 +259,7 @@ void THNN_(SparseLinear_updateParameters)(
THCTensor *gradWeight,
THCTensor *gradBias,
THCTensor *lastInput,
- double learningRate) {
+ accreal learningRate) {
THCTensor_(cadd)(state, weight, weight, -learningRate, gradWeight);
THCTensor_(cadd)(state, bias, bias, -learningRate, gradBias);
}
diff --git a/lib/THCUNN/generic/SpatialConvolutionLocal.cu b/lib/THCUNN/generic/SpatialConvolutionLocal.cu
index afbc24d..0d4b9ad 100644
--- a/lib/THCUNN/generic/SpatialConvolutionLocal.cu
+++ b/lib/THCUNN/generic/SpatialConvolutionLocal.cu
@@ -309,8 +309,9 @@ void THNN_(SpatialConvolutionLocal_accGradParameters)(
int padW, int padH,
long inputWidth, long inputHeight,
long outputWidth, long outputHeight,
- real scale)
+ accreal scale_)
{
+ real scale = ScalarConvert<accreal, real>::to(scale_);
THCUNN_assertSameGPU(state, 5, input, gradOutput, gradWeight,
gradBias, finput);
diff --git a/lib/THCUNN/generic/SpatialConvolutionMM.cu b/lib/THCUNN/generic/SpatialConvolutionMM.cu
index e7aeacb..b4ae8e5 100644
--- a/lib/THCUNN/generic/SpatialConvolutionMM.cu
+++ b/lib/THCUNN/generic/SpatialConvolutionMM.cu
@@ -335,8 +335,9 @@ void THNN_(SpatialConvolutionMM_accGradParameters)(
int kW, int kH,
int dW, int dH,
int padW, int padH,
- real scale) {
+ accreal scale_) {
+ real scale = ScalarConvert<accreal, real>::to(scale_);
THCUNN_assertSameGPU(state, 5, input, gradOutput, gradWeight, columns, ones);
if (gradBias) {
THCUNN_assertSameGPU(state, 2, gradWeight, gradBias);
diff --git a/lib/THCUNN/generic/SpatialCrossMapLRN.cu b/lib/THCUNN/generic/SpatialCrossMapLRN.cu
index a09ea0b..6b79c15 100644
--- a/lib/THCUNN/generic/SpatialCrossMapLRN.cu
+++ b/lib/THCUNN/generic/SpatialCrossMapLRN.cu
@@ -3,8 +3,12 @@
#else
void LRNforward(THCState* state, THCTensor* input, THCTensor* output,
- THCTensor* scale, int local_size, real alpha, real beta, real k)
+ THCTensor* scale, int local_size, accreal alpha_, accreal beta_, accreal k_)
{
+ real alpha = ScalarConvert<accreal, real>::to(alpha_);
+ real beta = ScalarConvert<accreal, real>::to(beta_);
+ real k = ScalarConvert<accreal, real>::to(k_);
+
THCTensor_(resizeAs)(state, output, input);
THCTensor_(resizeAs)(state, scale, input);
@@ -45,8 +49,12 @@ void LRNforward(THCState* state, THCTensor* input, THCTensor* output,
void LRNbackward(THCState* state, THCTensor* input, THCTensor* output,
THCTensor* gradOutput, THCTensor* gradInput, THCTensor* scale,
- int local_size, real alpha, real beta, real k)
+ int local_size, accreal alpha_, accreal beta_, accreal k_)
{
+ real alpha = ScalarConvert<accreal, real>::to(alpha_);
+ real beta = ScalarConvert<accreal, real>::to(beta_);
+ real k = ScalarConvert<accreal, real>::to(k_);
+
THCTensor_(resizeAs)(state, gradInput, input);
int batchSize;
@@ -89,9 +97,9 @@ void THNN_(SpatialCrossMapLRN_updateOutput)(
THCTensor *output,
THCTensor *scale,
int size,
- real alpha,
- real beta,
- real k)
+ accreal alpha,
+ accreal beta,
+ accreal k)
{
LRNforward(state, input, output, scale, size, alpha, beta, k);
}
@@ -104,9 +112,9 @@ void THNN_(SpatialCrossMapLRN_updateGradInput)(
THCTensor *scale,
THCTensor *output,
int size,
- real alpha,
- real beta,
- real k)
+ accreal alpha,
+ accreal beta,
+ accreal k)
{
LRNbackward(state, input, output, gradOutput, gradInput, scale, size, alpha, beta, k);
}
diff --git a/lib/THCUNN/generic/SpatialDilatedConvolution.cu b/lib/THCUNN/generic/SpatialDilatedConvolution.cu
index 7b656d3..02a640b 100644
--- a/lib/THCUNN/generic/SpatialDilatedConvolution.cu
+++ b/lib/THCUNN/generic/SpatialDilatedConvolution.cu
@@ -322,8 +322,9 @@ void THNN_(SpatialDilatedConvolution_accGradParameters)(
int dW, int dH,
int padW, int padH,
int dilationW, int dilationH,
- real scale) {
+ accreal scale_) {
+ real scale = ScalarConvert<accreal, real>::to(scale_);
THCUNN_assertSameGPU(state, 5, input, gradOutput, gradWeight, columns, ones);
if (gradBias) {
THCUNN_assertSameGPU(state, 2, gradWeight, gradBias);
diff --git a/lib/THCUNN/generic/SpatialFullConvolution.cu b/lib/THCUNN/generic/SpatialFullConvolution.cu
index ec7eb2f..9e8d30f 100644
--- a/lib/THCUNN/generic/SpatialFullConvolution.cu
+++ b/lib/THCUNN/generic/SpatialFullConvolution.cu
@@ -314,8 +314,9 @@ void THNN_(SpatialFullConvolution_accGradParameters)(
int dW, int dH,
int padW, int padH,
int adjW, int adjH,
- real scale)
+ accreal scale_)
{
+ real scale = ScalarConvert<accreal, real>::to(scale_);
int nInputPlane = THCTensor_(size)(state, gradWeight, 0);
int nOutputPlane = THCTensor_(size)(state, gradWeight, 1);
diff --git a/lib/THCUNN/generic/SpatialSubSampling.cu b/lib/THCUNN/generic/SpatialSubSampling.cu
index b918962..ef3c508 100644
--- a/lib/THCUNN/generic/SpatialSubSampling.cu
+++ b/lib/THCUNN/generic/SpatialSubSampling.cu
@@ -191,7 +191,7 @@ void THNN_(SpatialSubSampling_accGradParameters)(
THCTensor *gradBias,
int kW, int kH,
int dW, int dH,
- float scale)
+ accreal scale)
{
THCUNN_assertSameGPU(state, 4, input, gradOutput, gradWeight, gradBias);
THNN_(SpatialSubSampling_shapeCheck)(state, input, gradOutput, gradWeight, kW, kH);
diff --git a/lib/THCUNN/generic/Sqrt.cu b/lib/THCUNN/generic/Sqrt.cu
index 3602cbe..b6a68f8 100644
--- a/lib/THCUNN/generic/Sqrt.cu
+++ b/lib/THCUNN/generic/Sqrt.cu
@@ -8,8 +8,9 @@ void THNN_(Sqrt_updateOutput)(
THCState *state,
THCTensor *input,
THCTensor *output,
- real eps)
+ accreal eps_)
{
+ real eps = ScalarConvert<accreal, real>::to(eps_);
THCUNN_assertSameGPU(state, 2, input, output);
THCTensor_(resizeAs)(state, output, input);
THC_pointwiseApply2(state, output, input, sqrtupdateOutput_functor<real>(eps));
diff --git a/lib/THCUNN/generic/THCUNN.h b/lib/THCUNN/generic/THCUNN.h
index 3cfbd84..930f4de 100644
--- a/lib/THCUNN/generic/THCUNN.h
+++ b/lib/THCUNN/generic/THCUNN.h
@@ -54,7 +54,7 @@ TH_API void THNN_(BatchNormalization_backward)(
THCTensor *saveMean_,
THCTensor *saveStd_,
bool train,
- float scale,
+ double scale,
double eps);
TH_API void THNN_(BCECriterion_updateOutput)(
@@ -109,7 +109,7 @@ TH_API void THNN_(ELU_updateOutput)(
THCState *state,
THCTensor *input,
THCTensor *output,
- real alpha,
+ accreal alpha,
bool inplace);
TH_API void THNN_(ELU_updateGradInput)(
@@ -118,15 +118,15 @@ TH_API void THNN_(ELU_updateGradInput)(
THCTensor *gradOutput,
THCTensor *gradInput,
THCTensor *output,
- real alpha,
+ accreal alpha,
bool inplace);
TH_API void THNN_(HardTanh_updateOutput)(
THCState *state,
THCTensor *input,
THCTensor *output,
- real min_val,
- real max_val,
+ accreal min_val,
+ accreal max_val,
bool inplace);
TH_API void THNN_(HardTanh_updateGradInput)(
@@ -134,8 +134,8 @@ TH_API void THNN_(HardTanh_updateGradInput)(
THCTensor *input,
THCTensor *gradOutput,
THCTensor *gradInput,
- real min_val,
- real max_val,
+ accreal min_val,
+ accreal max_val,
bool inplace);
TH_API void THNN_(GatedLinear_updateOutput)(
@@ -155,7 +155,7 @@ TH_API void THNN_(LeakyReLU_updateOutput)(
THCState *state,
THCTensor *input,
THCTensor *output,
- real negval,
+ accreal negval,
bool inplace);
TH_API void THNN_(LeakyReLU_updateGradInput)(
@@ -163,7 +163,7 @@ TH_API void THNN_(LeakyReLU_updateGradInput)(
THCTensor *input,
THCTensor *gradOutput,
THCTensor *gradInput,
- real negval,
+ accreal negval,
bool inplace);
TH_API void THNN_(LogSigmoid_updateOutput)(
@@ -201,14 +201,14 @@ TH_API void THNN_(LookupTable_accGradParameters)(
THCIndexTensor *indices, // [OPTIONAL]
bool scaleGradByFreq,
int paddingValue,
- real scale);
+ accreal scale);
TH_API void THNN_(LookupTable_renorm)(
THCState *state,
THCIndexTensor *idx,
THCTensor *weight,
- real maxNorm,
- real normType);
+ accreal maxNorm,
+ accreal normType);
TH_API void THNN_(L1Cost_updateOutput)(
THCState *state,
@@ -227,7 +227,7 @@ TH_API void THNN_(MarginCriterion_updateOutput)(
THCTensor *target,
THCTensor *output,
bool sizeAverage,
- real margin);
+ accreal margin);
TH_API void THNN_(MarginCriterion_updateGradInput)(
THCState *state,
@@ -235,7 +235,7 @@ TH_API void THNN_(MarginCriterion_updateGradInput)(
THCTensor *target,
THCTensor *gradInput,
bool sizeAverage,
- real margin);
+ accreal margin);
TH_API void THNN_(MSECriterion_updateOutput)(
THCState *state,
@@ -275,7 +275,7 @@ TH_API void THNN_(MultiMarginCriterion_updateOutput)(
bool sizeAverage,
int p,
THCTensor *weights, // [OPTIONAL]
- real margin);
+ accreal margin);
TH_API void THNN_(MultiMarginCriterion_updateGradInput)(
THCState *state,
@@ -285,7 +285,7 @@ TH_API void THNN_(MultiMarginCriterion_updateGradInput)(
bool sizeAverage,
int p,
THCTensor *weights, // [OPTIONAL]
- real margin);
+ accreal margin);
TH_API void THNN_(PReLU_updateOutput)(
THCState *state,
@@ -312,7 +312,7 @@ TH_API void THNN_(PReLU_accGradParameters)(
THCTensor *gradWeightBuf,
THCTensor *gradWeightBuf2,
long nOutputPlane,
- real scale);
+ accreal scale);
TH_API void THNN_(SmoothL1Criterion_updateOutput)(
THCState *state,
@@ -343,8 +343,8 @@ TH_API void THNN_(SparseLinear_accGradParameters)(
THCTensor *gradBias,
THCTensor *weight,
THCTensor *bias,
- double weightDecay,
- double scale);
+ accreal weightDecay,
+ accreal scale);
TH_API void THNN_(SparseLinear_legacyUpdateOutput)(
THCState *state,
@@ -361,8 +361,8 @@ TH_API void THNN_(SparseLinear_legacyAccGradParameters)(
THCTensor *gradBias,
THCTensor *weight,
THCTensor *bias,
- double weightDecay,
- double scale);
+ accreal weightDecay,
+ accreal scale);
TH_API void THNN_(SparseLinear_zeroGradParameters)(
THCState *state,
@@ -377,7 +377,7 @@ TH_API void THNN_(SparseLinear_updateParameters)(
THCTensor *gradWeight,
THCTensor *gradBias,
THCTensor *lastInput,
- double learningRate);
+ accreal learningRate);
TH_API void THNN_(SpatialAdaptiveMaxPooling_updateOutput)(
THCState *state,
@@ -474,7 +474,7 @@ TH_API void THNN_(SpatialConvolutionLocal_accGradParameters)(
int padW, int padH,
long inputWidth, long inputHeight,
long outputWidth, long outputHeight,
- real scale);
+ accreal scale);
TH_API void THNN_(SpatialConvolutionMM_updateOutput)(
THCState *state,
@@ -511,7 +511,7 @@ TH_API void THNN_(SpatialConvolutionMM_accGradParameters)(
int kW, int kH,
int dW, int dH,
int padW, int padH,
- real scale);
+ accreal scale);
TH_API void THNN_(SpatialCrossMapLRN_updateOutput)(
THCState *state,
@@ -519,9 +519,9 @@ TH_API void THNN_(SpatialCrossMapLRN_updateOutput)(
THCTensor *output,
THCTensor *scale,
int size,
- real alpha,
- real beta,
- real k);
+ accreal alpha,
+ accreal beta,
+ accreal k);
TH_API void THNN_(SpatialCrossMapLRN_updateGradInput)(
THCState *state,
@@ -531,9 +531,9 @@ TH_API void THNN_(SpatialCrossMapLRN_updateGradInput)(
THCTensor *scale,
THCTensor *output,
int size,
- real alpha,
- real beta,
- real k);
+ accreal alpha,
+ accreal beta,
+ accreal k);
TH_API void THNN_(SpatialDilatedConvolution_updateOutput)(
THCState *state,
@@ -572,7 +572,7 @@ TH_API void THNN_(SpatialDilatedConvolution_accGradParameters)(
int dW, int dH,
int padW, int padH,
int dilationW, int dilationH,
- real scale);
+ accreal scale);
TH_API void THNN_(SpatialDilatedMaxPooling_updateOutput)(
THCState *state,
@@ -652,7 +652,7 @@ TH_API void THNN_(SpatialFullConvolution_accGradParameters)(
int dW, int dH,
int padW, int padH,
int adjW, int adjH,
- real scale);
+ accreal scale);
TH_API void THNN_(SpatialMaxPooling_updateOutput)(
THCState *state,
@@ -746,7 +746,7 @@ TH_API void THNN_(SpatialSubSampling_accGradParameters)(
THCTensor *gradBias,
int kW, int kH,
int dW, int dH,
- float scale);
+ accreal scale);
TH_API void THNN_(SpatialUpSamplingBilinear_updateOutput)(
THCState *state,
@@ -843,8 +843,8 @@ TH_API void THNN_(SoftPlus_updateOutput)(
THCState *state,
THCTensor *input,
THCTensor *output,
- real beta,
- real threshold);
+ accreal beta,
+ accreal threshold);
TH_API void THNN_(SoftPlus_updateGradInput)(
THCState *state,
@@ -852,21 +852,21 @@ TH_API void THNN_(SoftPlus_updateGradInput)(
THCTensor *gradOutput,
THCTensor *gradInput,
THCTensor *output,
- real beta,
- real threshold);
+ accreal beta,
+ accreal threshold);
TH_API void THNN_(SoftShrink_updateOutput)(
THCState *state,
THCTensor *input,
THCTensor *output,
- real lambda);
+ accreal lambda);
TH_API void THNN_(SoftShrink_updateGradInput)(
THCState *state,
THCTensor *input,
THCTensor *gradOutput,
THCTensor *gradInput,
- real lambda);
+ accreal lambda);
TH_API void THNN_(Square_updateOutput)(
THCState *state,
@@ -883,7 +883,7 @@ TH_API void THNN_(Sqrt_updateOutput)(
THCState *state,
THCTensor *input,
THCTensor *output,
- real eps);
+ accreal eps);
TH_API void THNN_(Sqrt_updateGradInput)(
THCState *state,
@@ -929,7 +929,7 @@ TH_API void THNN_(TemporalConvolution_accGradParameters)(
THCTensor *gradWeight,
THCTensor *gradBias,
int kW, int dW,
- real scale);
+ accreal scale);
TH_API void THNN_(TemporalMaxPooling_updateOutput)(
THCState *state,
@@ -984,14 +984,14 @@ TH_API void THNN_(TemporalRowConvolution_accGradParameters)(
int dW,
int padW,
bool featFirst,
- real scale);
+ accreal scale);
TH_API void THNN_(Threshold_updateOutput)(
THCState *state,
THCTensor *input,
THCTensor *output,
- real threshold,
- real val,
+ accreal threshold,
+ accreal val,
bool inplace);
TH_API void THNN_(Threshold_updateGradInput)(
@@ -999,8 +999,8 @@ TH_API void THNN_(Threshold_updateGradInput)(
THCTensor *input,
THCTensor *gradOutput,
THCTensor *gradInput,
- real threshold,
- real val,
+ accreal threshold,
+ accreal val,
bool inplace);
TH_API void THNN_(VolumetricAveragePooling_updateOutput)(
@@ -1049,7 +1049,7 @@ TH_API void THNN_(VolumetricConvolution_accGradParameters)(
THCTensor *fgradInput,
int dT, int dW, int dH,
int padT, int padW, int padH,
- real scale);
+ accreal scale);
TH_API void THNN_(VolumetricDilatedConvolution_updateOutput)(
THCState *state,
@@ -1088,7 +1088,7 @@ TH_API void THNN_(VolumetricDilatedConvolution_accGradParameters)(
int dT, int dW, int dH,
int padT, int padW, int padH,
int dilationT, int dilationW, int dilationH,
- real scale);
+ accreal scale);
TH_API void THNN_(VolumetricDilatedMaxPooling_updateOutput)(
THCState *state,
@@ -1148,7 +1148,7 @@ TH_API void THNN_(VolumetricFullConvolution_accGradParameters)(
int dT, int dW, int dH,
int padT, int padW, int padH,
int adjT, int adjW, int adjH,
- real scale);
+ accreal scale);
TH_API void THNN_(VolumetricMaxPooling_updateOutput)(
THCState *state,
diff --git a/lib/THCUNN/generic/TemporalConvolution.cu b/lib/THCUNN/generic/TemporalConvolution.cu
index a51894d..5658527 100644
--- a/lib/THCUNN/generic/TemporalConvolution.cu
+++ b/lib/THCUNN/generic/TemporalConvolution.cu
@@ -273,8 +273,9 @@ void THNN_(TemporalConvolution_accGradParameters)(
THCTensor *gradWeight,
THCTensor *gradBias,
int kW, int dW,
- real scale) {
+ accreal scale_) {
+ real scale = ScalarConvert<accreal, real>::to(scale_);
long nInputFrame;
long nOutputFrame;
diff --git a/lib/THCUNN/generic/TemporalRowConvolution.cu b/lib/THCUNN/generic/TemporalRowConvolution.cu
index 365599d..a0835a9 100644
--- a/lib/THCUNN/generic/TemporalRowConvolution.cu
+++ b/lib/THCUNN/generic/TemporalRowConvolution.cu
@@ -291,8 +291,9 @@ void THNN_(TemporalRowConvolution_accGradParameters)(
THCState *state, THCTensor *input, THCTensor *gradOutput,
THCTensor *gradWeight, THCTensor *gradBias, THCTensor *finput,
THCTensor *fgradInput, int kW, int dW, int padW, bool featFirst,
- real scale) {
+ accreal scale_) {
+ real scale = ScalarConvert<accreal, real>::to(scale_);
// Aliases
THCTensor *columns = finput;
THCTensor *ones = fgradInput;
diff --git a/lib/THCUNN/generic/Threshold.cu b/lib/THCUNN/generic/Threshold.cu
index 4f9f622..0b7b79e 100644
--- a/lib/THCUNN/generic/Threshold.cu
+++ b/lib/THCUNN/generic/Threshold.cu
@@ -8,10 +8,12 @@ void THNN_(Threshold_updateOutput)(
THCState *state,
THCTensor *input,
THCTensor *output,
- real threshold,
- real val,
+ accreal threshold_,
+ accreal val_,
bool inplace)
{
+ real threshold = ScalarConvert<accreal, real>::to(threshold_);
+ real val = ScalarConvert<accreal, real>::to(val_);
THCUNN_assertSameGPU(state, 2, input, output);
if (inplace)
@@ -37,10 +39,12 @@ void THNN_(Threshold_updateGradInput)(
THCTensor *input,
THCTensor *gradOutput,
THCTensor *gradInput,
- real threshold,
- real val,
+ accreal threshold_,
+ accreal val_,
bool inplace)
{
+ real threshold = ScalarConvert<accreal, real>::to(threshold_);
+ real val = ScalarConvert<accreal, real>::to(val_);
THCUNN_check_nElement(state, input, gradOutput);
THCUNN_assertSameGPU(state, 3, input, gradInput, gradOutput);
diff --git a/lib/THCUNN/generic/VolumetricConvolution.cu b/lib/THCUNN/generic/VolumetricConvolution.cu
index 3343f27..8227246 100644
--- a/lib/THCUNN/generic/VolumetricConvolution.cu
+++ b/lib/THCUNN/generic/VolumetricConvolution.cu
@@ -366,8 +366,9 @@ void THNN_(VolumetricConvolution_accGradParameters)(
THCTensor *fgradInput,
int dT, int dW, int dH,
int padT, int padW, int padH,
- real scale)
+ accreal scale_)
{
+ real scale = ScalarConvert<accreal, real>::to(scale_);
THCTensor *columns = finput;
THCTensor *ones = fgradInput;
THCUNN_assertSameGPU(state, 6, input, gradOutput, gradWeight, gradBias, columns, ones);
diff --git a/lib/THCUNN/generic/VolumetricDilatedConvolution.cu b/lib/THCUNN/generic/VolumetricDilatedConvolution.cu
index b0145a5..ffeea7f 100644
--- a/lib/THCUNN/generic/VolumetricDilatedConvolution.cu
+++ b/lib/THCUNN/generic/VolumetricDilatedConvolution.cu
@@ -336,8 +336,9 @@ void THNN_(VolumetricDilatedConvolution_accGradParameters)(
int dT, int dW, int dH,
int padT, int padW, int padH,
int dilationT, int dilationW, int dilationH,
- real scale) {
+ accreal scale_) {
+ real scale = ScalarConvert<accreal, real>::to(scale_);
THCUNN_assertSameGPU(state, 5, input, gradOutput, gradWeight, columns, ones);
if (gradBias) {
THCUNN_assertSameGPU(state, 2, gradWeight, gradBias);
diff --git a/lib/THCUNN/generic/VolumetricFullConvolution.cu b/lib/THCUNN/generic/VolumetricFullConvolution.cu
index 883874a..eb8e9e2 100644
--- a/lib/THCUNN/generic/VolumetricFullConvolution.cu
+++ b/lib/THCUNN/generic/VolumetricFullConvolution.cu
@@ -335,18 +335,19 @@ void THNN_(VolumetricFullConvolution_updateGradInput)(
void THNN_(VolumetricFullConvolution_accGradParameters)(
- THCState *state,
- THCTensor *input,
- THCTensor *gradOutput,
- THCTensor *gradWeight,
- THCTensor *gradBias,
- THCTensor *finput,
- THCTensor *fgradInput,
- int dT, int dW, int dH,
- int padT, int padW, int padH,
- int adjT, int adjW, int adjH,
- real scale)
+ THCState *state,
+ THCTensor *input,
+ THCTensor *gradOutput,
+ THCTensor *gradWeight,
+ THCTensor *gradBias,
+ THCTensor *finput,
+ THCTensor *fgradInput,
+ int dT, int dW, int dH,
+ int padT, int padW, int padH,
+ int adjT, int adjW, int adjH,
+ accreal scale_)
{
+ real scale = ScalarConvert<accreal, real>::to(scale_);
THCTensor *columns = finput;
THCTensor *ones = fgradInput;
diff --git a/test.lua b/test.lua
index 1fb1205..8c89276 100644
--- a/test.lua
+++ b/test.lua
@@ -376,17 +376,17 @@ function cunntest.Square_transposed()
end
function cunntest.SoftShrink_forward()
- local r = THC.THC_half2float(THC.THC_float2half(math.random()))
+ local r = math.random()
pointwise_forward(nn.SoftShrink(r), 'SoftShrink', precision_forward)
end
function cunntest.SoftShrink_backward()
- local r = THC.THC_half2float(THC.THC_float2half(math.random()))
+ local r = math.random()
pointwise_backward(nn.SoftShrink(r), 'SoftShrink', precision_backward)
end
function cunntest.SoftShrink_transposed()
- local r = THC.THC_half2float(THC.THC_float2half(math.random()))
+ local r = math.random()
pointwise_transposed(nn.SoftShrink(r), 'SoftShrink', precision_backward)
end
@@ -2056,8 +2056,8 @@ function cunntest.SpatialMaxPooling_forward()
local sj = math.random(1,4)
local outi = math.random(32,256)
local outj = math.random(32,256)
- local padi = math.random(0,ki/2-1)
- local padj = math.random(0,kj/2-1)
+ local padi = math.random(0,math.floor(ki/2)-1)
+ local padj = math.random(0,math.floor(kj/2)-1)
local ini = (outi-1)*si+ki - padi*2
local inj = (outj-1)*sj+kj - padj*2
local ceil_mode = math.random(0,1) == 1
@@ -2094,8 +2094,8 @@ function cunntest.SpatialMaxPooling_forward_batch()
local sj = math.random(2,4)
local outi = math.random(32,256)
local outj = math.random(32,256)
- local padi = math.random(0,ki/2-1)
- local padj = math.random(0,kj/2-1)
+ local padi = math.random(0,math.floor(ki/2)-1)
+ local padj = math.random(0,math.floor(kj/2)-1)
local ini = (outi-1)*si+ki - padi*2
local inj = (outj-1)*sj+kj - padj*2
local ceil_mode = math.random(0,1) == 1
@@ -2129,8 +2129,8 @@ function cunntest.SpatialMaxUnpooling_forward_batch()
local sj = kj
local outi = math.random(32,256)
local outj = math.random(32,256)
- local padi = math.random(0,ki/2-1)
- local padj = math.random(0,kj/2-1)
+ local padi = math.random(0,math.floor(ki/2)-1)
+ local padj = math.random(0,math.floor(kj/2)-1)
local ceil_mode = math.random(0,1) == 1
local fun = ceil_mode and torch.ceil or torch.floor
local ini = fun((outi + padi*2 - ki)/si) +1
@@ -2170,8 +2170,8 @@ function cunntest.SpatialMaxPooling_backward()
local sj = math.random(1,4)
local outi = math.random(32,64)
local outj = math.random(32,64)
- local padi = math.random(0,ki/2-1)
- local padj = math.random(0,kj/2-1)
+ local padi = math.random(0,math.floor(ki/2)-1)
+ local padj = math.random(0,math.floor(kj/2)-1)
local ini = (outi-1)*si+ki - padi*2
local inj = (outj-1)*sj+kj - padj*2
local ceil_mode = true--math.random(0,1) == 1
@@ -2214,8 +2214,8 @@ function cunntest.SpatialMaxPooling_backward_batch()
local sj = math.random(2,4)
local outi = math.random(32,64)
local outj = math.random(32,64)
- local padi = math.random(0,ki/2-1)
- local padj = math.random(0,kj/2-1)
+ local padi = math.random(0,math.floor(ki/2)-1)
+ local padj = math.random(0,math.floor(kj/2)-1)
local ini = (outi-1)*si+ki - padi*2
local inj = (outj-1)*sj+kj - padj*2
local ceil_mode = math.random(0,1) == 1
@@ -2257,8 +2257,8 @@ function cunntest.SpatialMaxUnpooling_backward_batch()
local sj = kj
local outi = math.random(32,256)
local outj = math.random(32,256)
- local padi = math.random(0,ki/2-1)
- local padj = math.random(0,kj/2-1)
+ local padi = math.random(0,math.floor(ki/2)-1)
+ local padj = math.random(0,math.floor(kj/2)-1)
local ceil_mode = math.random(0,1) == 1
local fun = ceil_mode and torch.ceil or torch.floor
local ini = fun((outi + padi*2 - ki)/si) +1
@@ -2307,8 +2307,8 @@ function cunntest.SpatialDilatedMaxPooling_forward()
local sj = math.random(1,4)
local outi = math.random(32,256)
local outj = math.random(32,256)
- local padi = math.random(0,ki/2-1)
- local padj = math.random(0,kj/2-1)
+ local padi = math.random(0,math.floor(ki/2)-1)
+ local padj = math.random(0,math.floor(kj/2)-1)
local dilationi = math.random(1,10)
local dilationj = math.random(1,10)
local ini = (outi-1)*si+(dilationi*(ki-1)+1)-2*padi
@@ -2347,8 +2347,8 @@ function cunntest.SpatialDilatedMaxPooling_forward_batch()
local sj = math.random(2,4)
local outi = math.random(32,256)
local outj = math.random(32,256)
- local padi = math.random(0,ki/2-1)
- local padj = math.random(0,kj/2-1)
+ local padi = math.random(0,math.floor(ki/2)-1)
+ local padj = math.random(0,math.floor(kj/2)-1)
local dilationi = math.random(1,10)
local dilationj = math.random(1,10)
local ini = (outi-1)*si+(dilationi*(ki-1)+1)-2*padi
@@ -2383,8 +2383,8 @@ function cunntest.SpatialDilatedMaxPooling_backward()
local sj = math.random(1,4)
local outi = math.random(32,64)
local outj = math.random(32,64)
- local padi = math.random(0,ki/2-1)
- local padj = math.random(0,kj/2-1)
+ local padi = math.random(0,math.floor(ki/2)-1)
+ local padj = math.random(0,math.floor(kj/2)-1)
local dilationi = math.random(1,10)
local dilationj = math.random(1,10)
local ini = (outi-1)*si+(dilationi*(ki-1)+1)-2*padi
@@ -2428,8 +2428,8 @@ function cunntest.SpatialDilatedMaxPooling_backward_batch()
local sj = math.random(2,4)
local outi = math.random(32,64)
local outj = math.random(32,64)
- local padi = math.random(0,ki/2-1)
- local padj = math.random(0,kj/2-1)
+ local padi = math.random(0,math.floor(ki/2)-1)
+ local padj = math.random(0,math.floor(kj/2)-1)
local dilationi = math.random(1,10)
local dilationj = math.random(1,10)
local ini = (outi-1)*si+(dilationi*(ki-1)+1)-2*padi
@@ -2622,8 +2622,8 @@ function cunntest.SpatialAveragePooling_forward()
local sj = math.random(1,kj)
local outi = math.random(32,256)
local outj = math.random(32,256)
- local padi = math.random(0,ki/2-1)
- local padj = math.random(0,kj/2-1)
+ local padi = math.random(0,math.floor(ki/2)-1)
+ local padj = math.random(0,math.floor(kj/2)-1)
local ini = (outi-1)*si+ki - padi*2
local inj = (outj-1)*sj+kj - padj*2
local ceil_mode = math.random(0,1) == 1
@@ -2661,8 +2661,8 @@ function cunntest.SpatialAveragePooling_forward_batch()
local sj = math.random(1,kj)
local outi = math.random(32,256)
local outj = math.random(32,256)
- local padi = math.random(0,ki/2-1)
- local padj = math.random(0,kj/2-1)
+ local padi = math.random(0,math.floor(ki/2)-1)
+ local padj = math.random(0,math.floor(kj/2)-1)
local ini = (outi-1)*si+ki - padi*2
local inj = (outj-1)*sj+kj - padj*2
local ceil_mode = math.random(0,1) == 1
@@ -2699,8 +2699,8 @@ function cunntest.SpatialAveragePooling_backward()
local sj = math.random(1,kj)
local outi = math.random(32,64)
local outj = math.random(32,64)
- local padi = math.random(0,ki/2-1)
- local padj = math.random(0,kj/2-1)
+ local padi = math.random(0,math.floor(ki/2)-1)
+ local padj = math.random(0,math.floor(kj/2)-1)
local ini = (outi-1)*si+ki - padi*2
local inj = (outj-1)*sj+kj - padj*2
local ceil_mode = math.random(0,1) == 1
@@ -2746,8 +2746,8 @@ function cunntest.SpatialAveragePooling_backward_batch()
local sj = math.random(1,kj)
local outi = math.random(32,64)
local outj = math.random(32,64)
- local padi = math.random(0,ki/2-1)
- local padj = math.random(0,kj/2-1)
+ local padi = math.random(0,math.floor(ki/2)-1)
+ local padj = math.random(0,math.floor(kj/2)-1)
local ini = (outi-1)*si+ki - padi*2
local inj = (outj-1)*sj+kj - padj*2
local ceil_mode = math.random(0,1) == 1
@@ -3412,9 +3412,6 @@ function cunntest.mse()
local cout = cmod:forward(cinput,ctarget)
local cgin = cmod:backward(cinput,ctarget)
- if (typename == 'torch.CudaHalfTensor') then
- fout = THC.THC_half2float(THC.THC_float2half(fout))
- end
mytester:assertlt(math.abs(fout-cout), precision_forward_type(0.02, typename),
string.format('error on output with %s', typename))
local gerr = cgin:double() - fgin:double()
@@ -3446,9 +3443,6 @@ function cunntest.SmoothL1()
local cout = cmod:forward(cinput,ctarget)
local cgin = cmod:backward(cinput,ctarget)
- if (typename == 'torch.CudaHalfTensor') then
- fout = THC.THC_half2float(THC.THC_float2half(fout))
- end
mytester:assertlt(math.abs(fout-cout), 0.01, string.format('error on output with %s', typename))
local gerr = cgin:double() - fgin:double()
mytester:assertlt(gerr:abs():max(), precision_forward_type(precision_forward, typename),
@@ -4280,9 +4274,6 @@ function cunntest.l1cost()
local cout = cmod:forward(cinput)
local cgin = cmod:backward(cinput)
- if (typename == 'torch.CudaHalfTensor') then
- fout = THC.THC_half2float(THC.THC_float2half(fout))
- end
mytester:assertlt(math.abs(fout-cout), precision_forward_type(precision_forward, typename),
string.format('error on output with %s', typename))
local gerr = cgin:double() - fgin:double()
@@ -4684,9 +4675,9 @@ function cunntest.VolumetricMaxPooling_forward()
local iT = math.random(kT*2, 60)
local iH = math.random(kH*2, 60)
local iW = math.random(kW*2, 60)
- local padT = math.random(0,kT/2-1)
- local padH = math.random(0,kH/2-1)
- local padW = math.random(0,kW/2-1)
+ local padT = math.random(0,math.floor(kT/2)-1)
+ local padH = math.random(0,math.floor(kH/2)-1)
+ local padW = math.random(0,math.floor(kW/2)-1)
local iF = math.random(1, 16) -- features
local oT = math.floor((iT - kT + 2*padT) / dT + 1)
local oH = math.floor((iH - kH + 2*padH) / dH + 1)
@@ -4720,9 +4711,9 @@ function cunntest.VolumetricMaxPooling_backward()
local iT = math.random(kT*2, 60)
local iH = math.random(kH*2, 60)
local iW = math.random(kW*2, 60)
- local padT = math.random(0,kT/2-1)
- local padH = math.random(0,kH/2-1)
- local padW = math.random(0,kW/2-1)
+ local padT = math.random(0,math.floor(kT/2)-1)
+ local padH = math.random(0,math.floor(kH/2)-1)
+ local padW = math.random(0,math.floor(kW/2)-1)
local iF = math.random(1, 16) -- features
local oT = math.floor((iT - kT + 2*padT) / dT + 1)
local oH = math.floor((iH - kH + 2*padH) / dH + 1)
@@ -4764,9 +4755,9 @@ function cunntest.VolumetricDilatedMaxPooling_forward_batch()
local outt = math.random(1,10)
local outi = math.random(1,33)
local outj = math.random(1,33)
- local padt = math.random(0,kt/2-1)
- local padi = math.random(0,ki/2-1)
- local padj = math.random(0,kj/2-1)
+ local padt = math.random(0,math.floor(kt/2)-1)
+ local padi = math.random(0,math.floor(ki/2)-1)
+ local padj = math.random(0,math.floor(kj/2)-1)
local dilationt = math.random(1,10)
local dilationi = math.random(1,10)
local dilationj = math.random(1,10)
@@ -4808,9 +4799,9 @@ function cunntest.VolumetricDilatedMaxPooling_backward_batch()
local outt = math.random(8,16)
local outi = math.random(8,16)
local outj = math.random(8,16)
- local padt = math.random(0,kt/2-1)
- local padi = math.random(0,ki/2-1)
- local padj = math.random(0,kj/2-1)
+ local padt = math.random(0,math.floor(kt/2)-1)
+ local padi = math.random(0,math.floor(ki/2)-1)
+ local padj = math.random(0,math.floor(kj/2)-1)
local dilationt = math.random(1,10)
local dilationi = math.random(1,10)
local dilationj = math.random(1,10)
@@ -4858,9 +4849,9 @@ function cunntest.VolumetricMaxUnpooling_forward_batch()
local outt = math.random(32,128)
local outi = math.random(32,128)
local outj = math.random(32,128)
- local padt = math.random(0,kt/2-1)
- local padi = math.random(0,ki/2-1)
- local padj = math.random(0,kj/2-1)
+ local padt = math.random(0,math.floor(kt/2)-1)
+ local padi = math.random(0,math.floor(ki/2)-1)
+ local padj = math.random(0,math.floor(kj/2)-1)
local it = math.max(((outt + padt*2 - kt)/st) +1, kt)
local ii = math.max(((outi + padi*2 - ki)/si) +1, ki)
local ij = math.max(((outj + padj*2 - kj)/sj) +1, kj)
@@ -4899,9 +4890,9 @@ function cunntest.VolumetricMaxUnpooling_backward_batch()
local outt = math.random(32,128)
local outi = math.random(32,128)
local outj = math.random(32,128)
- local padt = math.random(0,kt/2-1)
- local padi = math.random(0,ki/2-1)
- local padj = math.random(0,kj/2-1)
+ local padt = math.random(0,math.floor(kt/2)-1)
+ local padi = math.random(0,math.floor(ki/2)-1)
+ local padj = math.random(0,math.floor(kj/2)-1)
local it = math.max(((outt + padt*2 - kt)/st) +1, kt)
local ii = math.max(((outi + padi*2 - ki)/si) +1, ki)
local ij = math.max(((outj + padj*2 - kj)/sj) +1, kj)
@@ -5226,8 +5217,8 @@ function cunntest.VolumetricFullConvolution_pair_test()
local dT = math.random(1,3)
local dH = math.random(1,3)
local dW = dH
- local pT = (kT-1)/2
- local pH = (kH-1)/2
+ local pT = math.floor((kT-1)/2)
+ local pH = math.floor((kH-1)/2)
local pW = pH
local inChan = math.random(1,32)