diff options
author | Sergey Zagoruyko <zagoruyko2@gmail.com> | 2016-07-12 13:04:20 +0300 |
---|---|---|
committer | Sergey Zagoruyko <zagoruyko2@gmail.com> | 2016-07-12 13:04:20 +0300 |
commit | c5bf237f79ebe7f77ea295197e3804a2937774e2 (patch) | |
tree | 9f9a4ff6a97012fbadf5cb7a0c5ae1fd7ca470c7 | |
parent | 155d6f59211504e6dc72840133df1eb93a16cfb3 (diff) |
fix functional for 6.0
-rw-r--r-- | functional.lua | 12 |
1 files changed, 6 insertions, 6 deletions
diff --git a/functional.lua b/functional.lua index 4564fb7..cea9df9 100644 --- a/functional.lua +++ b/functional.lua @@ -60,7 +60,7 @@ cudnn.functional.Convolution2D_updateOutput = function(handle, input, weight, ou local nOutputPlane, nInputPlane, kH, kW = weight:size(1), weight:size(2), weight:size(3), weight:size(4) local desc = torch.IntTensor({nOutputPlane, nInputPlane, kH, kW}) - errcheck('cudnnSetFilterNdDescriptor', weightDesc[0], 'CUDNN_DATA_FLOAT', 'CUDNN_TENSOR_NCHW', 4, + errcheck('cudnnSetFilterNdDescriptor', weightDesc[0], cudnn.typemap[torch.type(input)], 'CUDNN_TENSOR_NCHW', 4, desc:data()); local function destroyWDesc(d) errcheck('cudnnDestroyFilterDescriptor', d[0]); @@ -76,7 +76,7 @@ cudnn.functional.Convolution2D_updateOutput = function(handle, input, weight, ou errcheck('cudnnSetConvolutionNdDescriptor', convDesc[0], 2, pad:data(), stride:data(), upscale:data(), 'CUDNN_CROSS_CORRELATION', - 'CUDNN_DATA_FLOAT'); + cudnn.configmap(torch.type(weight))); local function destroyConvDesc(d) errcheck('cudnnDestroyConvolutionDescriptor', d[0]); end @@ -139,7 +139,7 @@ cudnn.functional.Convolution2D_updateGradInput = function(handle, input, weight, local nOutputPlane, nInputPlane, kH, kW = weight:size(1), weight:size(2), weight:size(3), weight:size(4) local desc = torch.IntTensor({nOutputPlane, nInputPlane, kH, kW}) - errcheck('cudnnSetFilterNdDescriptor', weightDesc[0], 'CUDNN_DATA_FLOAT', 'CUDNN_TENSOR_NCHW', 4, + errcheck('cudnnSetFilterNdDescriptor', weightDesc[0], cudnn.typemap[torch.type(input)], 'CUDNN_TENSOR_NCHW', 4, desc:data()); local function destroyWDesc(d) errcheck('cudnnDestroyFilterDescriptor', d[0]); @@ -155,7 +155,7 @@ cudnn.functional.Convolution2D_updateGradInput = function(handle, input, weight, errcheck('cudnnSetConvolutionNdDescriptor', convDesc[0], 2, pad:data(), stride:data(), upscale:data(), 'CUDNN_CROSS_CORRELATION', - 'CUDNN_DATA_FLOAT'); + cudnn.configmap(torch.type(weight))); local function destroyConvDesc(d) errcheck('cudnnDestroyConvolutionDescriptor', d[0]); end @@ -204,7 +204,7 @@ cudnn.functional.Convolution2D_accGradParameters = function(handle, input, gradW local nOutputPlane, nInputPlane, kH, kW = gradWeight:size(1), gradWeight:size(2), gradWeight:size(3), gradWeight:size(4) local desc = torch.IntTensor({nOutputPlane, nInputPlane, kH, kW}) - errcheck('cudnnSetFilterNdDescriptor', weightDesc[0], 'CUDNN_DATA_FLOAT', 'CUDNN_TENSOR_NCHW', 4, + errcheck('cudnnSetFilterNdDescriptor', weightDesc[0], cudnn.typemap[torch.type(input)], 'CUDNN_TENSOR_NCHW', 4, desc:data()); local function destroyWDesc(d) errcheck('cudnnDestroyFilterDescriptor', d[0]); @@ -220,7 +220,7 @@ cudnn.functional.Convolution2D_accGradParameters = function(handle, input, gradW errcheck('cudnnSetConvolutionNdDescriptor', convDesc[0], 2, pad:data(), stride:data(), upscale:data(), 'CUDNN_CROSS_CORRELATION', - 'CUDNN_DATA_FLOAT'); + cudnn.configmap(torch.type(gradWeight))); local function destroyConvDesc(d) errcheck('cudnnDestroyConvolutionDescriptor', d[0]); end |