Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cunn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPavan Yalamanchili <pyalamanchili@twitter.com>2017-01-12 03:16:45 +0300
committerPavan Yalamanchili <pyalamanchili@twitter.com>2017-01-18 03:06:42 +0300
commite5cc080153679dc7beed75bc53267f7d94be34d1 (patch)
tree5c895642c8fd48c79f6319e56bbb6850253c7c8f
parentc77be795fe0e58d213ea4b2dc42a4364c6d965bc (diff)
Converting all instances of real to accreal in libTHCUNN
This is because the current version of luaffifb fails to pass custom structs (i.e. half) as arguments or accept them as return values. The accreal parameters are immediately converted to real internally. This is done to ensure none of the internal code needs to be changed. This change also removes transform_reals_to_half which is no longer necessary. Change-Id: I978151d001de5492576fb0eddfa0608cd4e99149
-rw-r--r--.gitignore1
-rw-r--r--THCUNN.lua40
-rw-r--r--lib/THCUNN/SparseLinear.cu10
-rw-r--r--lib/THCUNN/generic/BatchNormalization.cu2
-rw-r--r--lib/THCUNN/generic/ELU.cu6
-rw-r--r--lib/THCUNN/generic/HardTanh.cu14
-rw-r--r--lib/THCUNN/generic/LeakyReLU.cu8
-rw-r--r--lib/THCUNN/generic/LookupTable.cu9
-rw-r--r--lib/THCUNN/generic/MarginCriterion.cu7
-rw-r--r--lib/THCUNN/generic/MultiMarginCriterion.cu6
-rw-r--r--lib/THCUNN/generic/PReLU.cu3
-rw-r--r--lib/THCUNN/generic/SoftPlus.cu12
-rw-r--r--lib/THCUNN/generic/SoftShrink.cu6
-rw-r--r--lib/THCUNN/generic/SparseLinear.cu10
-rw-r--r--lib/THCUNN/generic/SpatialConvolutionLocal.cu3
-rw-r--r--lib/THCUNN/generic/SpatialConvolutionMM.cu3
-rw-r--r--lib/THCUNN/generic/SpatialCrossMapLRN.cu24
-rw-r--r--lib/THCUNN/generic/SpatialDilatedConvolution.cu3
-rw-r--r--lib/THCUNN/generic/SpatialFullConvolution.cu3
-rw-r--r--lib/THCUNN/generic/SpatialSubSampling.cu2
-rw-r--r--lib/THCUNN/generic/Sqrt.cu3
-rw-r--r--lib/THCUNN/generic/THCUNN.h96
-rw-r--r--lib/THCUNN/generic/TemporalConvolution.cu3
-rw-r--r--lib/THCUNN/generic/Threshold.cu12
-rw-r--r--lib/THCUNN/generic/VolumetricConvolution.cu3
-rw-r--r--lib/THCUNN/generic/VolumetricDilatedConvolution.cu3
-rw-r--r--lib/THCUNN/generic/VolumetricFullConvolution.cu3
-rw-r--r--test.lua15
28 files changed, 170 insertions, 140 deletions
diff --git a/.gitignore b/.gitignore
index 9d93b79..f63d6a2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,3 @@
build
THCUNN_h.lua
+THCUNN_generic_h.lua
diff --git a/THCUNN.lua b/THCUNN.lua
index 6776a23..d5bf1c2 100644
--- a/THCUNN.lua
+++ b/THCUNN.lua
@@ -45,7 +45,7 @@ local replacements_generic =
['THCTensor'] = 'THCudaTensor',
['THCIndexTensor'] = 'THCudaLongTensor',
['TYPE'] = 'Cuda',
- ['real'] = 'float'
+ ['real'] = 'float',
},
{
['THCTensor'] = 'THCudaDoubleTensor',
@@ -55,6 +55,13 @@ local replacements_generic =
}
}
+-- gsub(s, 'real', 'float') changes accreal to accfloat.
+-- typedef accfloat ahead of time.
+ffi.cdef("typedef float accfloat;")
+-- gsub(s, 'real', 'double') changes accreal to accdouble.
+-- typedef accdouble ahead of time
+ffi.cdef("typedef double accdouble;")
+
if cutorch.hasHalf then
ffi.cdef("half THC_float2half(float a);")
ffi.cdef("float THC_half2float(half a);")
@@ -63,9 +70,12 @@ if cutorch.hasHalf then
['THCTensor'] = 'THCudaHalfTensor',
['THCIndexTensor'] = 'THCudaLongTensor',
['TYPE'] = 'CudaHalf',
- ['real'] = 'half'
+ ['real'] = 'half',
}
table.insert(replacements_generic, half_replacement)
+ -- gsub(s, 'real', 'half') changes accreal to acchalf.
+ -- typedef acchalf ahead of time
+ ffi.cdef("typedef float acchalf;")
end
for i=1,#replacements_generic do
@@ -133,29 +143,9 @@ THNN.kernels['torch.CudaDoubleTensor'] = THNN.bind(THCUNN.C, function_names_gene
torch.getmetatable('torch.CudaDoubleTensor').THNN = THNN.kernels['torch.CudaDoubleTensor']
if cutorch.hasHalf then
--- in order to call 'half' functions from lua, convert real arguments from
--- to half since there is no other defined conversion
-local transform_reals_to_half = function(func_name, real_args, ...)
- t = {}
- -- this select logic is necessary to deal with nil arguments
- for i = 1, select('#', ...) do
- t[i] = select(i, ...)
- end
- for k,v in ipairs(real_args[func_name]) do
- -- first argument (THCState) is added implicitly by bind
- t[v-1] = THC.THC_float2half(t[v-1])
- end
- return t
-end
-
-local raw_half_functions = THNN.bind(THCUNN.C, function_names_generic, 'CudaHalf', THCUNN.getState)
-for k,v in pairs(raw_half_functions) do
- -- select required in case there are trailing nils
- raw_half_functions[k] = function(...) v(unpack(transform_reals_to_half(k, real_args, ...), 1, select("#",...)))
-end
-end
-THNN.kernels['torch.CudaHalfTensor'] = raw_half_functions
-torch.getmetatable('torch.CudaHalfTensor').THNN = THNN.kernels['torch.CudaHalfTensor']
+ local raw_half_functions = THNN.bind(THCUNN.C, function_names_generic, 'CudaHalf', THCUNN.getState)
+ THNN.kernels['torch.CudaHalfTensor'] = raw_half_functions
+ torch.getmetatable('torch.CudaHalfTensor').THNN = THNN.kernels['torch.CudaHalfTensor']
end
local function Module__converter(type)
diff --git a/lib/THCUNN/SparseLinear.cu b/lib/THCUNN/SparseLinear.cu
index a7ffa1e..f36206f 100644
--- a/lib/THCUNN/SparseLinear.cu
+++ b/lib/THCUNN/SparseLinear.cu
@@ -34,8 +34,8 @@ void THNN_CudaHalfSparseLinear_accGradParameters(
THCudaHalfTensor *gradBias,
THCudaHalfTensor *weight,
THCudaHalfTensor *bias,
- double weightDecay,
- double scale) {
+ float weightDecay,
+ float scale) {
THError("THCudaHalfTensor not supported with SparseLinear");
}
@@ -56,8 +56,8 @@ void THNN_CudaHalfSparseLinear_legacyAccGradParameters(
THCudaHalfTensor *gradBias,
THCudaHalfTensor *weight,
THCudaHalfTensor *bias,
- double weightDecay,
- double scale) {
+ float weightDecay,
+ float scale) {
THError("THCudaHalfTensor not supported with SparseLinear");
}
@@ -76,7 +76,7 @@ void THNN_CudaHalfSparseLinear_updateParameters(
THCudaHalfTensor *gradWeight,
THCudaHalfTensor *gradBias,
THCudaHalfTensor *lastInput,
- double learningRate) {
+ float learningRate) {
THError("THCudaHalfTensor not supported with SparseLinear");
}
#endif
diff --git a/lib/THCUNN/generic/BatchNormalization.cu b/lib/THCUNN/generic/BatchNormalization.cu
index cbe99f3..d42f18e 100644
--- a/lib/THCUNN/generic/BatchNormalization.cu
+++ b/lib/THCUNN/generic/BatchNormalization.cu
@@ -69,7 +69,7 @@ void THNN_(BatchNormalization_backward)(
THCState *state, THCTensor *input_, THCTensor *gradOutput_,
THCTensor *gradInput_, THCTensor *gradWeight_, THCTensor *gradBias_,
THCTensor *weight_, THCTensor *runningMean_, THCTensor *runningVar_,
- THCTensor *saveMean_, THCTensor *saveStd_, bool train, float scale, double eps) {
+ THCTensor *saveMean_, THCTensor *saveStd_, bool train, double scale, double eps) {
THCUNN_check_shape(state, input_, gradOutput_);
DeviceTensor3 input = devicetensor<3>(state, input_);
diff --git a/lib/THCUNN/generic/ELU.cu b/lib/THCUNN/generic/ELU.cu
index 0beb5a1..4b8da27 100644
--- a/lib/THCUNN/generic/ELU.cu
+++ b/lib/THCUNN/generic/ELU.cu
@@ -9,9 +9,10 @@ void THNN_(ELU_updateOutput)(
THCState *state,
THCTensor *input,
THCTensor *output,
- real alpha,
+ accreal alpha_,
bool inplace)
{
+ real alpha = ScalarConvert<accreal, real>::to(alpha_);
THCUNN_assertSameGPU(state, 2, input, output);
if (inplace)
@@ -33,9 +34,10 @@ void THNN_(ELU_updateGradInput)(
THCTensor *gradOutput,
THCTensor *gradInput,
THCTensor *output,
- real alpha,
+ accreal alpha_,
bool inplace)
{
+ real alpha = ScalarConvert<accreal, real>::to(alpha_);
THCUNN_check_nElement(state, input, gradOutput);
THCUNN_assertSameGPU(state, 3, output, gradOutput, gradInput);
diff --git a/lib/THCUNN/generic/HardTanh.cu b/lib/THCUNN/generic/HardTanh.cu
index 0651431..47835f0 100644
--- a/lib/THCUNN/generic/HardTanh.cu
+++ b/lib/THCUNN/generic/HardTanh.cu
@@ -8,10 +8,13 @@ void THNN_(HardTanh_updateOutput)(
THCState *state,
THCTensor *input,
THCTensor *output,
- real min_val,
- real max_val,
+ accreal min_val_,
+ accreal max_val_,
bool inplace)
{
+ real min_val = ScalarConvert<accreal, real>::to(min_val_);
+ real max_val = ScalarConvert<accreal, real>::to(max_val_);
+
THCUNN_assertSameGPU(state, 2, input, output);
if(inplace)
{
@@ -31,10 +34,13 @@ void THNN_(HardTanh_updateGradInput)(
THCTensor *input,
THCTensor *gradOutput,
THCTensor *gradInput,
- real min_val,
- real max_val,
+ accreal min_val_,
+ accreal max_val_,
bool inplace)
{
+ real min_val = ScalarConvert<accreal, real>::to(min_val_);
+ real max_val = ScalarConvert<accreal, real>::to(max_val_);
+
THCUNN_check_nElement(state, input, gradOutput);
THCUNN_assertSameGPU(state, 3, input, gradOutput, gradInput);
diff --git a/lib/THCUNN/generic/LeakyReLU.cu b/lib/THCUNN/generic/LeakyReLU.cu
index 23cf59a..179819d 100644
--- a/lib/THCUNN/generic/LeakyReLU.cu
+++ b/lib/THCUNN/generic/LeakyReLU.cu
@@ -8,9 +8,11 @@ void THNN_(LeakyReLU_updateOutput)(
THCState *state,
THCTensor *input,
THCTensor *output,
- real negval,
+ accreal negval_,
bool inplace)
{
+ real negval = ScalarConvert<accreal, real>::to(negval_);
+
THCUNN_assertSameGPU(state, 2, input, output);
if (inplace)
@@ -32,9 +34,11 @@ void THNN_(LeakyReLU_updateGradInput)(
THCTensor *input,
THCTensor *gradOutput,
THCTensor *gradInput,
- real negval,
+ accreal negval_,
bool inplace)
{
+ real negval = ScalarConvert<accreal, real>::to(negval_);
+
THCUNN_check_nElement(state, input, gradOutput);
THCUNN_assertSameGPU(state, 3, input, gradInput, gradOutput);
diff --git a/lib/THCUNN/generic/LookupTable.cu b/lib/THCUNN/generic/LookupTable.cu
index bd59a04..fa7c5ac 100644
--- a/lib/THCUNN/generic/LookupTable.cu
+++ b/lib/THCUNN/generic/LookupTable.cu
@@ -12,8 +12,9 @@ void THNN_(LookupTable_accGradParameters)(
THCIndexTensor *indices,
bool scaleGradByFreq,
int paddingValue,
- real scale)
+ accreal scale_)
{
+ real scale = ScalarConvert<accreal, real>::to(scale_);
THCUNN_assertSameGPU(state, 5, input, gradOutput, gradWeight, sorted, indices);
gradOutput = THCTensor_(newContiguous)(state, gradOutput);
if (!(THCIndexTensor_(isContiguous)(state, input) &&
@@ -119,9 +120,11 @@ void THNN_(LookupTable_renorm)(
THCState *state,
THCIndexTensor *idx,
THCTensor *weight,
- real maxNorm,
- real normType)
+ accreal maxNorm_,
+ accreal normType_)
{
+ real maxNorm = ScalarConvert<accreal, real>::to(maxNorm_);
+ real normType = ScalarConvert<accreal, real>::to(normType_);
THCUNN_assertSameGPU(state, 2, idx, weight);
if (!(THCIndexTensor_(isContiguous)(state, idx) &&
THCTensor_(isContiguous)(state, weight)))
diff --git a/lib/THCUNN/generic/MarginCriterion.cu b/lib/THCUNN/generic/MarginCriterion.cu
index d5678ec..221f9d9 100644
--- a/lib/THCUNN/generic/MarginCriterion.cu
+++ b/lib/THCUNN/generic/MarginCriterion.cu
@@ -8,8 +8,9 @@ void THNN_(MarginCriterion_updateOutput)(
THCTensor *target,
THCTensor *output,
bool sizeAverage,
- real margin)
+ accreal margin_)
{
+ real margin = ScalarConvert<accreal, real>::to(margin_);
THCUNN_check_nElement(state, input, target);
THCUNN_check_dim_size(state, output, 1, 0, 1);
THCUNN_assertSameGPU(state, 2, input, target);
@@ -40,8 +41,10 @@ void THNN_(MarginCriterion_updateGradInput)(
THCTensor *target,
THCTensor *gradInput,
bool sizeAverage,
- real margin)
+ accreal margin_)
{
+ real margin = ScalarConvert<accreal, real>::to(margin_);
+
THCUNN_check_nElement(state, input, target);
THCUNN_assertSameGPU(state, 3, input, target, gradInput);
diff --git a/lib/THCUNN/generic/MultiMarginCriterion.cu b/lib/THCUNN/generic/MultiMarginCriterion.cu
index 8026331..c3ff2d6 100644
--- a/lib/THCUNN/generic/MultiMarginCriterion.cu
+++ b/lib/THCUNN/generic/MultiMarginCriterion.cu
@@ -11,8 +11,9 @@ void THNN_(MultiMarginCriterion_updateOutput)(
bool sizeAverage,
int p,
THCTensor *weights,
- real margin)
+ accreal margin_)
{
+ real margin = ScalarConvert<accreal, real>::to(margin_);
THCUNN_assertSameGPU(state, 2, input, target);
input = THCTensor_(newContiguous)(state, input);
if(weights)
@@ -102,8 +103,9 @@ void THNN_(MultiMarginCriterion_updateGradInput)(
bool sizeAverage,
int p,
THCTensor *weights,
- real margin)
+ accreal margin_)
{
+ real margin = ScalarConvert<accreal, real>::to(margin_);
THCUNN_assertSameGPU(state, 3, input, gradInput, target);
input = THCTensor_(newContiguous)(state, input);
THCTensor_(resizeAs)(state, gradInput, input);
diff --git a/lib/THCUNN/generic/PReLU.cu b/lib/THCUNN/generic/PReLU.cu
index 89087fb..db9b0d2 100644
--- a/lib/THCUNN/generic/PReLU.cu
+++ b/lib/THCUNN/generic/PReLU.cu
@@ -92,8 +92,9 @@ void THNN_(PReLU_accGradParameters)(
THCTensor *gradWeightBuf,
THCTensor *gradWeightBuf2,
long nOutputPlane,
- real scale)
+ accreal scale_)
{
+ real scale = ScalarConvert<accreal, real>::to(scale_);
THCUNN_check_nElement(state, input, gradOutput);
// use grad input for temporary storage, then call updateGradInput again
diff --git a/lib/THCUNN/generic/SoftPlus.cu b/lib/THCUNN/generic/SoftPlus.cu
index e72038e..17cde70 100644
--- a/lib/THCUNN/generic/SoftPlus.cu
+++ b/lib/THCUNN/generic/SoftPlus.cu
@@ -8,9 +8,11 @@ void THNN_(SoftPlus_updateOutput)(
THCState *state,
THCTensor *input,
THCTensor *output,
- real beta,
- real threshold)
+ accreal beta_,
+ accreal threshold_)
{
+ real beta = ScalarConvert<accreal, real>::to(beta_);
+ real threshold = ScalarConvert<accreal, real>::to(threshold_);
THCUNN_assertSameGPU(state, 2, input, output);
THCTensor_(resizeAs)(state, output, input);
THC_pointwiseApply2(state, output, input, softPlusupdateOutput_functor<real>(threshold, beta));
@@ -22,9 +24,11 @@ void THNN_(SoftPlus_updateGradInput)(
THCTensor *gradOutput,
THCTensor *gradInput,
THCTensor *output,
- real beta,
- real threshold)
+ accreal beta_,
+ accreal threshold_)
{
+ real beta = ScalarConvert<accreal, real>::to(beta_);
+ real threshold = ScalarConvert<accreal, real>::to(threshold_);
THCUNN_check_nElement(state, input, gradOutput);
THCUNN_assertSameGPU(state, 4, input, output, gradOutput, gradInput);
THCTensor_(resizeAs)(state, gradInput, output);
diff --git a/lib/THCUNN/generic/SoftShrink.cu b/lib/THCUNN/generic/SoftShrink.cu
index 261593f..9e47695 100644
--- a/lib/THCUNN/generic/SoftShrink.cu
+++ b/lib/THCUNN/generic/SoftShrink.cu
@@ -8,8 +8,9 @@ void THNN_(SoftShrink_updateOutput)(
THCState *state,
THCTensor *input,
THCTensor *output,
- real lambda)
+ accreal lambda_)
{
+ real lambda = ScalarConvert<accreal, real>::to(lambda_);
THCUNN_assertSameGPU(state, 2, input, output);
THCTensor_(resizeAs)(state, output, input);
THC_pointwiseApply2(state, output, input, SoftShrinkUpdateOutput<real>(lambda));
@@ -21,8 +22,9 @@ void THNN_(SoftShrink_updateGradInput)(
THCTensor *input,
THCTensor *gradOutput,
THCTensor *gradInput,
- real lambda)
+ accreal lambda_)
{
+ real lambda = ScalarConvert<accreal, real>::to(lambda_);
THCUNN_check_nElement(state, input, gradOutput);
THCUNN_assertSameGPU(state, 3, input, gradOutput, gradInput);
THCTensor_(resizeAs)(state, gradInput, input);
diff --git a/lib/THCUNN/generic/SparseLinear.cu b/lib/THCUNN/generic/SparseLinear.cu
index f22b233..6838cac 100644
--- a/lib/THCUNN/generic/SparseLinear.cu
+++ b/lib/THCUNN/generic/SparseLinear.cu
@@ -127,8 +127,8 @@ void THNN_(SparseLinear_accGradParameters)(
THCTensor *gradBias,
THCTensor *weight,
THCTensor *bias,
- double weightDecay,
- double scale)
+ accreal weightDecay,
+ accreal scale)
{
long outDim = THCTensor_(size)(state, weight, 0);
long inDim = THCTensor_(size)(state, weight, 1);
@@ -237,8 +237,8 @@ void THNN_(SparseLinear_legacyAccGradParameters)(
THCTensor *gradBias,
THCTensor *weight,
THCTensor *bias,
- double weightDecay,
- double scale) {
+ accreal weightDecay,
+ accreal scale) {
THError("CUDA does not support legacy input format, please use a table of nnz x 2 vectors");
}
@@ -259,7 +259,7 @@ void THNN_(SparseLinear_updateParameters)(
THCTensor *gradWeight,
THCTensor *gradBias,
THCTensor *lastInput,
- double learningRate) {
+ accreal learningRate) {
THCTensor_(cadd)(state, weight, weight, -learningRate, gradWeight);
THCTensor_(cadd)(state, bias, bias, -learningRate, gradBias);
}
diff --git a/lib/THCUNN/generic/SpatialConvolutionLocal.cu b/lib/THCUNN/generic/SpatialConvolutionLocal.cu
index afbc24d..0d4b9ad 100644
--- a/lib/THCUNN/generic/SpatialConvolutionLocal.cu
+++ b/lib/THCUNN/generic/SpatialConvolutionLocal.cu
@@ -309,8 +309,9 @@ void THNN_(SpatialConvolutionLocal_accGradParameters)(
int padW, int padH,
long inputWidth, long inputHeight,
long outputWidth, long outputHeight,
- real scale)
+ accreal scale_)
{
+ real scale = ScalarConvert<accreal, real>::to(scale_);
THCUNN_assertSameGPU(state, 5, input, gradOutput, gradWeight,
gradBias, finput);
diff --git a/lib/THCUNN/generic/SpatialConvolutionMM.cu b/lib/THCUNN/generic/SpatialConvolutionMM.cu
index 01848f4..b874a1a 100644
--- a/lib/THCUNN/generic/SpatialConvolutionMM.cu
+++ b/lib/THCUNN/generic/SpatialConvolutionMM.cu
@@ -328,8 +328,9 @@ void THNN_(SpatialConvolutionMM_accGradParameters)(
int kW, int kH,
int dW, int dH,
int padW, int padH,
- real scale) {
+ accreal scale_) {
+ real scale = ScalarConvert<accreal, real>::to(scale_);
THCUNN_assertSameGPU(state, 5, input, gradOutput, gradWeight, columns, ones);
if (gradBias) {
THCUNN_assertSameGPU(state, 2, gradWeight, gradBias);
diff --git a/lib/THCUNN/generic/SpatialCrossMapLRN.cu b/lib/THCUNN/generic/SpatialCrossMapLRN.cu
index a09ea0b..6b79c15 100644
--- a/lib/THCUNN/generic/SpatialCrossMapLRN.cu
+++ b/lib/THCUNN/generic/SpatialCrossMapLRN.cu
@@ -3,8 +3,12 @@
#else
void LRNforward(THCState* state, THCTensor* input, THCTensor* output,
- THCTensor* scale, int local_size, real alpha, real beta, real k)
+ THCTensor* scale, int local_size, accreal alpha_, accreal beta_, accreal k_)
{
+ real alpha = ScalarConvert<accreal, real>::to(alpha_);
+ real beta = ScalarConvert<accreal, real>::to(beta_);
+ real k = ScalarConvert<accreal, real>::to(k_);
+
THCTensor_(resizeAs)(state, output, input);
THCTensor_(resizeAs)(state, scale, input);
@@ -45,8 +49,12 @@ void LRNforward(THCState* state, THCTensor* input, THCTensor* output,
void LRNbackward(THCState* state, THCTensor* input, THCTensor* output,
THCTensor* gradOutput, THCTensor* gradInput, THCTensor* scale,
- int local_size, real alpha, real beta, real k)
+ int local_size, accreal alpha_, accreal beta_, accreal k_)
{
+ real alpha = ScalarConvert<accreal, real>::to(alpha_);
+ real beta = ScalarConvert<accreal, real>::to(beta_);
+ real k = ScalarConvert<accreal, real>::to(k_);
+
THCTensor_(resizeAs)(state, gradInput, input);
int batchSize;
@@ -89,9 +97,9 @@ void THNN_(SpatialCrossMapLRN_updateOutput)(
THCTensor *output,
THCTensor *scale,
int size,
- real alpha,
- real beta,
- real k)
+ accreal alpha,
+ accreal beta,
+ accreal k)
{
LRNforward(state, input, output, scale, size, alpha, beta, k);
}
@@ -104,9 +112,9 @@ void THNN_(SpatialCrossMapLRN_updateGradInput)(
THCTensor *scale,
THCTensor *output,
int size,
- real alpha,
- real beta,
- real k)
+ accreal alpha,
+ accreal beta,
+ accreal k)
{
LRNbackward(state, input, output, gradOutput, gradInput, scale, size, alpha, beta, k);
}
diff --git a/lib/THCUNN/generic/SpatialDilatedConvolution.cu b/lib/THCUNN/generic/SpatialDilatedConvolution.cu
index c790ab4..4dae8d2 100644
--- a/lib/THCUNN/generic/SpatialDilatedConvolution.cu
+++ b/lib/THCUNN/generic/SpatialDilatedConvolution.cu
@@ -318,8 +318,9 @@ void THNN_(SpatialDilatedConvolution_accGradParameters)(
int dW, int dH,
int padW, int padH,
int dilationW, int dilationH,
- real scale) {
+ accreal scale_) {
+ real scale = ScalarConvert<accreal, real>::to(scale_);
THCUNN_assertSameGPU(state, 5, input, gradOutput, gradWeight, columns, ones);
if (gradBias) {
THCUNN_assertSameGPU(state, 2, gradWeight, gradBias);
diff --git a/lib/THCUNN/generic/SpatialFullConvolution.cu b/lib/THCUNN/generic/SpatialFullConvolution.cu
index 12995d2..4552f5f 100644
--- a/lib/THCUNN/generic/SpatialFullConvolution.cu
+++ b/lib/THCUNN/generic/SpatialFullConvolution.cu
@@ -311,8 +311,9 @@ void THNN_(SpatialFullConvolution_accGradParameters)(
int dW, int dH,
int padW, int padH,
int adjW, int adjH,
- real scale)
+ accreal scale_)
{
+ real scale = ScalarConvert<accreal, real>::to(scale_);
int nInputPlane = THCTensor_(size)(state, gradWeight, 0);
int nOutputPlane = THCTensor_(size)(state, gradWeight, 1);
diff --git a/lib/THCUNN/generic/SpatialSubSampling.cu b/lib/THCUNN/generic/SpatialSubSampling.cu
index b918962..ef3c508 100644
--- a/lib/THCUNN/generic/SpatialSubSampling.cu
+++ b/lib/THCUNN/generic/SpatialSubSampling.cu
@@ -191,7 +191,7 @@ void THNN_(SpatialSubSampling_accGradParameters)(
THCTensor *gradBias,
int kW, int kH,
int dW, int dH,
- float scale)
+ accreal scale)
{
THCUNN_assertSameGPU(state, 4, input, gradOutput, gradWeight, gradBias);
THNN_(SpatialSubSampling_shapeCheck)(state, input, gradOutput, gradWeight, kW, kH);
diff --git a/lib/THCUNN/generic/Sqrt.cu b/lib/THCUNN/generic/Sqrt.cu
index 3602cbe..b6a68f8 100644
--- a/lib/THCUNN/generic/Sqrt.cu
+++ b/lib/THCUNN/generic/Sqrt.cu
@@ -8,8 +8,9 @@ void THNN_(Sqrt_updateOutput)(
THCState *state,
THCTensor *input,
THCTensor *output,
- real eps)
+ accreal eps_)
{
+ real eps = ScalarConvert<accreal, real>::to(eps_);
THCUNN_assertSameGPU(state, 2, input, output);
THCTensor_(resizeAs)(state, output, input);
THC_pointwiseApply2(state, output, input, sqrtupdateOutput_functor<real>(eps));
diff --git a/lib/THCUNN/generic/THCUNN.h b/lib/THCUNN/generic/THCUNN.h
index bf903b9..c9d7e2c 100644
--- a/lib/THCUNN/generic/THCUNN.h
+++ b/lib/THCUNN/generic/THCUNN.h
@@ -54,7 +54,7 @@ TH_API void THNN_(BatchNormalization_backward)(
THCTensor *saveMean_,
THCTensor *saveStd_,
bool train,
- float scale,
+ double scale,
double eps);
TH_API void THNN_(BCECriterion_updateOutput)(
@@ -109,7 +109,7 @@ TH_API void THNN_(ELU_updateOutput)(
THCState *state,
THCTensor *input,
THCTensor *output,
- real alpha,
+ accreal alpha,
bool inplace);
TH_API void THNN_(ELU_updateGradInput)(
@@ -118,15 +118,15 @@ TH_API void THNN_(ELU_updateGradInput)(
THCTensor *gradOutput,
THCTensor *gradInput,
THCTensor *output,
- real alpha,
+ accreal alpha,
bool inplace);
TH_API void THNN_(HardTanh_updateOutput)(
THCState *state,
THCTensor *input,
THCTensor *output,
- real min_val,
- real max_val,
+ accreal min_val,
+ accreal max_val,
bool inplace);
TH_API void THNN_(HardTanh_updateGradInput)(
@@ -134,15 +134,15 @@ TH_API void THNN_(HardTanh_updateGradInput)(
THCTensor *input,
THCTensor *gradOutput,
THCTensor *gradInput,
- real min_val,
- real max_val,
+ accreal min_val,
+ accreal max_val,
bool inplace);
TH_API void THNN_(LeakyReLU_updateOutput)(
THCState *state,
THCTensor *input,
THCTensor *output,
- real negval,
+ accreal negval,
bool inplace);
TH_API void THNN_(LeakyReLU_updateGradInput)(
@@ -150,7 +150,7 @@ TH_API void THNN_(LeakyReLU_updateGradInput)(
THCTensor *input,
THCTensor *gradOutput,
THCTensor *gradInput,
- real negval,
+ accreal negval,
bool inplace);
TH_API void THNN_(LogSigmoid_updateOutput)(
@@ -188,14 +188,14 @@ TH_API void THNN_(LookupTable_accGradParameters)(
THCIndexTensor *indices, // [OPTIONAL]
bool scaleGradByFreq,
int paddingValue,
- real scale);
+ accreal scale);
TH_API void THNN_(LookupTable_renorm)(
THCState *state,
THCIndexTensor *idx,
THCTensor *weight,
- real maxNorm,
- real normType);
+ accreal maxNorm,
+ accreal normType);
TH_API void THNN_(L1Cost_updateOutput)(
THCState *state,
@@ -214,7 +214,7 @@ TH_API void THNN_(MarginCriterion_updateOutput)(
THCTensor *target,
THCTensor *output,
bool sizeAverage,
- real margin);
+ accreal margin);
TH_API void THNN_(MarginCriterion_updateGradInput)(
THCState *state,
@@ -222,7 +222,7 @@ TH_API void THNN_(MarginCriterion_updateGradInput)(
THCTensor *target,
THCTensor *gradInput,
bool sizeAverage,
- real margin);
+ accreal margin);
TH_API void THNN_(MSECriterion_updateOutput)(
THCState *state,
@@ -262,7 +262,7 @@ TH_API void THNN_(MultiMarginCriterion_updateOutput)(
bool sizeAverage,
int p,
THCTensor *weights, // [OPTIONAL]
- real margin);
+ accreal margin);
TH_API void THNN_(MultiMarginCriterion_updateGradInput)(
THCState *state,
@@ -272,7 +272,7 @@ TH_API void THNN_(MultiMarginCriterion_updateGradInput)(
bool sizeAverage,
int p,
THCTensor *weights, // [OPTIONAL]
- real margin);
+ accreal margin);
TH_API void THNN_(PReLU_updateOutput)(
THCState *state,
@@ -299,7 +299,7 @@ TH_API void THNN_(PReLU_accGradParameters)(
THCTensor *gradWeightBuf,
THCTensor *gradWeightBuf2,
long nOutputPlane,
- real scale);
+ accreal scale);
TH_API void THNN_(SmoothL1Criterion_updateOutput)(
THCState *state,
@@ -330,8 +330,8 @@ TH_API void THNN_(SparseLinear_accGradParameters)(
THCTensor *gradBias,
THCTensor *weight,
THCTensor *bias,
- double weightDecay,
- double scale);
+ accreal weightDecay,
+ accreal scale);
TH_API void THNN_(SparseLinear_legacyUpdateOutput)(
THCState *state,
@@ -348,8 +348,8 @@ TH_API void THNN_(SparseLinear_legacyAccGradParameters)(
THCTensor *gradBias,
THCTensor *weight,
THCTensor *bias,
- double weightDecay,
- double scale);
+ accreal weightDecay,
+ accreal scale);
TH_API void THNN_(SparseLinear_zeroGradParameters)(
THCState *state,
@@ -364,7 +364,7 @@ TH_API void THNN_(SparseLinear_updateParameters)(
THCTensor *gradWeight,
THCTensor *gradBias,
THCTensor *lastInput,
- double learningRate);
+ accreal learningRate);
TH_API void THNN_(SpatialAdaptiveMaxPooling_updateOutput)(
THCState *state,
@@ -461,7 +461,7 @@ TH_API void THNN_(SpatialConvolutionLocal_accGradParameters)(
int padW, int padH,
long inputWidth, long inputHeight,
long outputWidth, long outputHeight,
- real scale);
+ accreal scale);
TH_API void THNN_(SpatialConvolutionMM_updateOutput)(
THCState *state,
@@ -498,7 +498,7 @@ TH_API void THNN_(SpatialConvolutionMM_accGradParameters)(
int kW, int kH,
int dW, int dH,
int padW, int padH,
- real scale);
+ accreal scale);
TH_API void THNN_(SpatialCrossMapLRN_updateOutput)(
THCState *state,
@@ -506,9 +506,9 @@ TH_API void THNN_(SpatialCrossMapLRN_updateOutput)(
THCTensor *output,
THCTensor *scale,
int size,
- real alpha,
- real beta,
- real k);
+ accreal alpha,
+ accreal beta,
+ accreal k);
TH_API void THNN_(SpatialCrossMapLRN_updateGradInput)(
THCState *state,
@@ -518,9 +518,9 @@ TH_API void THNN_(SpatialCrossMapLRN_updateGradInput)(
THCTensor *scale,
THCTensor *output,
int size,
- real alpha,
- real beta,
- real k);
+ accreal alpha,
+ accreal beta,
+ accreal k);
TH_API void THNN_(SpatialDilatedConvolution_updateOutput)(
THCState *state,
@@ -559,7 +559,7 @@ TH_API void THNN_(SpatialDilatedConvolution_accGradParameters)(
int dW, int dH,
int padW, int padH,
int dilationW, int dilationH,
- real scale);
+ accreal scale);
TH_API void THNN_(SpatialDilatedMaxPooling_updateOutput)(
THCState *state,
@@ -639,7 +639,7 @@ TH_API void THNN_(SpatialFullConvolution_accGradParameters)(
int dW, int dH,
int padW, int padH,
int adjW, int adjH,
- real scale);
+ accreal scale);
TH_API void THNN_(SpatialMaxPooling_updateOutput)(
THCState *state,
@@ -733,7 +733,7 @@ TH_API void THNN_(SpatialSubSampling_accGradParameters)(
THCTensor *gradBias,
int kW, int kH,
int dW, int dH,
- float scale);
+ accreal scale);
TH_API void THNN_(SpatialUpSamplingBilinear_updateOutput)(
THCState *state,
@@ -830,8 +830,8 @@ TH_API void THNN_(SoftPlus_updateOutput)(
THCState *state,
THCTensor *input,
THCTensor *output,
- real beta,
- real threshold);
+ accreal beta,
+ accreal threshold);
TH_API void THNN_(SoftPlus_updateGradInput)(
THCState *state,
@@ -839,21 +839,21 @@ TH_API void THNN_(SoftPlus_updateGradInput)(
THCTensor *gradOutput,
THCTensor *gradInput,
THCTensor *output,
- real beta,
- real threshold);
+ accreal beta,
+ accreal threshold);
TH_API void THNN_(SoftShrink_updateOutput)(
THCState *state,
THCTensor *input,
THCTensor *output,
- real lambda);
+ accreal lambda);
TH_API void THNN_(SoftShrink_updateGradInput)(
THCState *state,
THCTensor *input,
THCTensor *gradOutput,
THCTensor *gradInput,
- real lambda);
+ accreal lambda);
TH_API void THNN_(Square_updateOutput)(
THCState *state,
@@ -870,7 +870,7 @@ TH_API void THNN_(Sqrt_updateOutput)(
THCState *state,
THCTensor *input,
THCTensor *output,
- real eps);
+ accreal eps);
TH_API void THNN_(Sqrt_updateGradInput)(
THCState *state,
@@ -916,7 +916,7 @@ TH_API void THNN_(TemporalConvolution_accGradParameters)(
THCTensor *gradWeight,
THCTensor *gradBias,
int kW, int dW,
- real scale);
+ accreal scale);
TH_API void THNN_(TemporalMaxPooling_updateOutput)(
THCState *state,
@@ -937,8 +937,8 @@ TH_API void THNN_(Threshold_updateOutput)(
THCState *state,
THCTensor *input,
THCTensor *output,
- real threshold,
- real val,
+ accreal threshold,
+ accreal val,
bool inplace);
TH_API void THNN_(Threshold_updateGradInput)(
@@ -946,8 +946,8 @@ TH_API void THNN_(Threshold_updateGradInput)(
THCTensor *input,
THCTensor *gradOutput,
THCTensor *gradInput,
- real threshold,
- real val,
+ accreal threshold,
+ accreal val,
bool inplace);
TH_API void THNN_(VolumetricAveragePooling_updateOutput)(
@@ -996,7 +996,7 @@ TH_API void THNN_(VolumetricConvolution_accGradParameters)(
THCTensor *fgradInput,
int dT, int dW, int dH,
int padT, int padW, int padH,
- real scale);
+ accreal scale);
TH_API void THNN_(VolumetricDilatedConvolution_updateOutput)(
THCState *state,
@@ -1035,7 +1035,7 @@ TH_API void THNN_(VolumetricDilatedConvolution_accGradParameters)(
int dT, int dW, int dH,
int padT, int padW, int padH,
int dilationT, int dilationW, int dilationH,
- real scale);
+ accreal scale);
TH_API void THNN_(VolumetricDilatedMaxPooling_updateOutput)(
THCState *state,
@@ -1095,7 +1095,7 @@ TH_API void THNN_(VolumetricFullConvolution_accGradParameters)(
int dT, int dW, int dH,
int padT, int padW, int padH,
int adjT, int adjW, int adjH,
- real scale);
+ accreal scale);
TH_API void THNN_(VolumetricMaxPooling_updateOutput)(
THCState *state,
diff --git a/lib/THCUNN/generic/TemporalConvolution.cu b/lib/THCUNN/generic/TemporalConvolution.cu
index a51894d..5658527 100644
--- a/lib/THCUNN/generic/TemporalConvolution.cu
+++ b/lib/THCUNN/generic/TemporalConvolution.cu
@@ -273,8 +273,9 @@ void THNN_(TemporalConvolution_accGradParameters)(
THCTensor *gradWeight,
THCTensor *gradBias,
int kW, int dW,
- real scale) {
+ accreal scale_) {
+ real scale = ScalarConvert<accreal, real>::to(scale_);
long nInputFrame;
long nOutputFrame;
diff --git a/lib/THCUNN/generic/Threshold.cu b/lib/THCUNN/generic/Threshold.cu
index 4f9f622..0b7b79e 100644
--- a/lib/THCUNN/generic/Threshold.cu
+++ b/lib/THCUNN/generic/Threshold.cu
@@ -8,10 +8,12 @@ void THNN_(Threshold_updateOutput)(
THCState *state,
THCTensor *input,
THCTensor *output,
- real threshold,
- real val,
+ accreal threshold_,
+ accreal val_,
bool inplace)
{
+ real threshold = ScalarConvert<accreal, real>::to(threshold_);
+ real val = ScalarConvert<accreal, real>::to(val_);
THCUNN_assertSameGPU(state, 2, input, output);
if (inplace)
@@ -37,10 +39,12 @@ void THNN_(Threshold_updateGradInput)(
THCTensor *input,
THCTensor *gradOutput,
THCTensor *gradInput,
- real threshold,
- real val,
+ accreal threshold_,
+ accreal val_,
bool inplace)
{
+ real threshold = ScalarConvert<accreal, real>::to(threshold_);
+ real val = ScalarConvert<accreal, real>::to(val_);
THCUNN_check_nElement(state, input, gradOutput);
THCUNN_assertSameGPU(state, 3, input, gradInput, gradOutput);
diff --git a/lib/THCUNN/generic/VolumetricConvolution.cu b/lib/THCUNN/generic/VolumetricConvolution.cu
index a371ac8..ca71cdb 100644
--- a/lib/THCUNN/generic/VolumetricConvolution.cu
+++ b/lib/THCUNN/generic/VolumetricConvolution.cu
@@ -356,8 +356,9 @@ void THNN_(VolumetricConvolution_accGradParameters)(
THCTensor *fgradInput,
int dT, int dW, int dH,
int padT, int padW, int padH,
- real scale)
+ accreal scale_)
{
+ real scale = ScalarConvert<accreal, real>::to(scale_);
THCTensor *columns = finput;
THCTensor *ones = fgradInput;
THCUNN_assertSameGPU(state, 6, input, gradOutput, gradWeight, gradBias, columns, ones);
diff --git a/lib/THCUNN/generic/VolumetricDilatedConvolution.cu b/lib/THCUNN/generic/VolumetricDilatedConvolution.cu
index 422cdc7..d9653ab 100644
--- a/lib/THCUNN/generic/VolumetricDilatedConvolution.cu
+++ b/lib/THCUNN/generic/VolumetricDilatedConvolution.cu
@@ -332,8 +332,9 @@ void THNN_(VolumetricDilatedConvolution_accGradParameters)(
int dT, int dW, int dH,
int padT, int padW, int padH,
int dilationT, int dilationW, int dilationH,
- real scale) {
+ accreal scale_) {
+ real scale = ScalarConvert<accreal, real>::to(scale_);
THCUNN_assertSameGPU(state, 5, input, gradOutput, gradWeight, columns, ones);
if (gradBias) {
THCUNN_assertSameGPU(state, 2, gradWeight, gradBias);
diff --git a/lib/THCUNN/generic/VolumetricFullConvolution.cu b/lib/THCUNN/generic/VolumetricFullConvolution.cu
index 47f4943..1ac6f25 100644
--- a/lib/THCUNN/generic/VolumetricFullConvolution.cu
+++ b/lib/THCUNN/generic/VolumetricFullConvolution.cu
@@ -340,8 +340,9 @@ void THNN_(VolumetricFullConvolution_accGradParameters)(
int dT, int dW, int dH,
int padT, int padW, int padH,
int adjT, int adjW, int adjH,
- real scale)
+ accreal scale_)
{
+ real scale = ScalarConvert<accreal, real>::to(scale_);
THCTensor *columns = finput;
THCTensor *ones = fgradInput;
diff --git a/test.lua b/test.lua
index 4f9d844..5ab07bf 100644
--- a/test.lua
+++ b/test.lua
@@ -365,17 +365,17 @@ function cunntest.Square_transposed()
end
function cunntest.SoftShrink_forward()
- local r = THC.THC_half2float(THC.THC_float2half(math.random()))
+ local r = math.random()
pointwise_forward(nn.SoftShrink(r), 'SoftShrink', precision_forward)
end
function cunntest.SoftShrink_backward()
- local r = THC.THC_half2float(THC.THC_float2half(math.random()))
+ local r = math.random()
pointwise_backward(nn.SoftShrink(r), 'SoftShrink', precision_backward)
end
function cunntest.SoftShrink_transposed()
- local r = THC.THC_half2float(THC.THC_float2half(math.random()))
+ local r = math.random()
pointwise_transposed(nn.SoftShrink(r), 'SoftShrink', precision_backward)
end
@@ -3401,9 +3401,6 @@ function cunntest.mse()
local cout = cmod:forward(cinput,ctarget)
local cgin = cmod:backward(cinput,ctarget)
- if (typename == 'torch.CudaHalfTensor') then
- fout = THC.THC_half2float(THC.THC_float2half(fout))
- end
mytester:assertlt(math.abs(fout-cout), precision_forward_type(0.02, typename),
string.format('error on output with %s', typename))
local gerr = cgin:double() - fgin:double()
@@ -3435,9 +3432,6 @@ function cunntest.SmoothL1()
local cout = cmod:forward(cinput,ctarget)
local cgin = cmod:backward(cinput,ctarget)
- if (typename == 'torch.CudaHalfTensor') then
- fout = THC.THC_half2float(THC.THC_float2half(fout))
- end
mytester:assertlt(math.abs(fout-cout), 0.01, string.format('error on output with %s', typename))
local gerr = cgin:double() - fgin:double()
mytester:assertlt(gerr:abs():max(), precision_forward_type(precision_forward, typename),
@@ -4000,9 +3994,6 @@ function cunntest.l1cost()
local cout = cmod:forward(cinput)
local cgin = cmod:backward(cinput)
- if (typename == 'torch.CudaHalfTensor') then
- fout = THC.THC_half2float(THC.THC_float2half(fout))
- end
mytester:assertlt(math.abs(fout-cout), precision_forward_type(precision_forward, typename),
string.format('error on output with %s', typename))
local gerr = cgin:double() - fgin:double()