Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSergey Zagoruyko <zagoruyko2@gmail.com>2016-07-12 13:04:20 +0300
committerSergey Zagoruyko <zagoruyko2@gmail.com>2016-07-12 13:04:20 +0300
commitc5bf237f79ebe7f77ea295197e3804a2937774e2 (patch)
tree9f9a4ff6a97012fbadf5cb7a0c5ae1fd7ca470c7
parent155d6f59211504e6dc72840133df1eb93a16cfb3 (diff)
fix functional for 6.0
-rw-r--r--functional.lua12
1 files changed, 6 insertions, 6 deletions
diff --git a/functional.lua b/functional.lua
index 4564fb7..cea9df9 100644
--- a/functional.lua
+++ b/functional.lua
@@ -60,7 +60,7 @@ cudnn.functional.Convolution2D_updateOutput = function(handle, input, weight, ou
local nOutputPlane, nInputPlane, kH, kW
= weight:size(1), weight:size(2), weight:size(3), weight:size(4)
local desc = torch.IntTensor({nOutputPlane, nInputPlane, kH, kW})
- errcheck('cudnnSetFilterNdDescriptor', weightDesc[0], 'CUDNN_DATA_FLOAT', 'CUDNN_TENSOR_NCHW', 4,
+ errcheck('cudnnSetFilterNdDescriptor', weightDesc[0], cudnn.typemap[torch.type(input)], 'CUDNN_TENSOR_NCHW', 4,
desc:data());
local function destroyWDesc(d)
errcheck('cudnnDestroyFilterDescriptor', d[0]);
@@ -76,7 +76,7 @@ cudnn.functional.Convolution2D_updateOutput = function(handle, input, weight, ou
errcheck('cudnnSetConvolutionNdDescriptor', convDesc[0],
2, pad:data(),
stride:data(), upscale:data(), 'CUDNN_CROSS_CORRELATION',
- 'CUDNN_DATA_FLOAT');
+ cudnn.configmap(torch.type(weight)));
local function destroyConvDesc(d)
errcheck('cudnnDestroyConvolutionDescriptor', d[0]);
end
@@ -139,7 +139,7 @@ cudnn.functional.Convolution2D_updateGradInput = function(handle, input, weight,
local nOutputPlane, nInputPlane, kH, kW
= weight:size(1), weight:size(2), weight:size(3), weight:size(4)
local desc = torch.IntTensor({nOutputPlane, nInputPlane, kH, kW})
- errcheck('cudnnSetFilterNdDescriptor', weightDesc[0], 'CUDNN_DATA_FLOAT', 'CUDNN_TENSOR_NCHW', 4,
+ errcheck('cudnnSetFilterNdDescriptor', weightDesc[0], cudnn.typemap[torch.type(input)], 'CUDNN_TENSOR_NCHW', 4,
desc:data());
local function destroyWDesc(d)
errcheck('cudnnDestroyFilterDescriptor', d[0]);
@@ -155,7 +155,7 @@ cudnn.functional.Convolution2D_updateGradInput = function(handle, input, weight,
errcheck('cudnnSetConvolutionNdDescriptor', convDesc[0],
2, pad:data(),
stride:data(), upscale:data(), 'CUDNN_CROSS_CORRELATION',
- 'CUDNN_DATA_FLOAT');
+ cudnn.configmap(torch.type(weight)));
local function destroyConvDesc(d)
errcheck('cudnnDestroyConvolutionDescriptor', d[0]);
end
@@ -204,7 +204,7 @@ cudnn.functional.Convolution2D_accGradParameters = function(handle, input, gradW
local nOutputPlane, nInputPlane, kH, kW
= gradWeight:size(1), gradWeight:size(2), gradWeight:size(3), gradWeight:size(4)
local desc = torch.IntTensor({nOutputPlane, nInputPlane, kH, kW})
- errcheck('cudnnSetFilterNdDescriptor', weightDesc[0], 'CUDNN_DATA_FLOAT', 'CUDNN_TENSOR_NCHW', 4,
+ errcheck('cudnnSetFilterNdDescriptor', weightDesc[0], cudnn.typemap[torch.type(input)], 'CUDNN_TENSOR_NCHW', 4,
desc:data());
local function destroyWDesc(d)
errcheck('cudnnDestroyFilterDescriptor', d[0]);
@@ -220,7 +220,7 @@ cudnn.functional.Convolution2D_accGradParameters = function(handle, input, gradW
errcheck('cudnnSetConvolutionNdDescriptor', convDesc[0],
2, pad:data(),
stride:data(), upscale:data(), 'CUDNN_CROSS_CORRELATION',
- 'CUDNN_DATA_FLOAT');
+ cudnn.configmap(torch.type(gradWeight)));
local function destroyConvDesc(d)
errcheck('cudnnDestroyConvolutionDescriptor', d[0]);
end