Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
author Soumith Chintala <soumith@gmail.com> 2015-06-26 22:11:24 +0300
committer soumith <soumith@fb.com> 2015-08-02 20:38:44 +0300
commit f6f22a3bf2ee4b920b7a38a61d0be911377f0d47 (patch)
tree 1f301fa2023b9e9a2bfbae90c93b0a89dc9e0906 /test
parent 3e6e918dac9e94d2f104da6e36f749312e5c3951 (diff)
working R3 bindings for non-new modules
Diffstat (limited to 'test')
-rw-r--r-- test/benchmark.lua 72
-rw-r--r-- test/test.lua 4
2 files changed, 74 insertions, 2 deletions
diff --git a/test/benchmark.lua b/test/benchmark.lua
new file mode 100644
index 0000000..08218b9
--- /dev/null
+++ b/test/benchmark.lua
@@ -0,0 +1,72 @@
+require 'cudnn'
+require 'torch'
+
+-- Time a single forward pass of a cudnn.SpatialConvolution layer and print it.
+--
+-- title              label printed at the start of the result line
+-- nInputC, nOutputC  number of input / output feature planes
+-- kH, kW             kernel height / width
+-- sH, sW             stride height / width
+-- iH, iW             input height / width
+-- nBatch             number of samples in the (all-zero) input batch
+-- ...                optional algorithm names, forwarded unchanged to setMode()
+--
+-- Prints the title, the problem dimensions and the elapsed wall-clock time
+-- reported by torch.Timer. Returns nothing.
+function bench(title, nInputC, nOutputC, kH, kW, sH, sW, iH, iW, nBatch, ...)
+   -- The module constructor takes width before height, hence the swapped order.
+   local m1 = cudnn.SpatialConvolution(nInputC, nOutputC, kW, kH, sW, sH):setMode(...):fastest():cuda()
+   local i1 = torch.zeros(nBatch, nInputC, iH, iW):cuda()
+
+   -- Untimed warm-up pass.
+   m1:forward(i1)
+   -- GPU work is queued asynchronously: wait for the warm-up to finish before
+   -- starting the clock, otherwise its execution leaks into the measurement.
+   cutorch.synchronize()
+
+   local t1 = torch.Timer()
+   m1:forward(i1)
+   -- Wait for the timed pass to actually complete before reading the timer.
+   cutorch.synchronize()
+   print(title .. ': ', nInputC, nOutputC, kH, kW, iH, iW, nBatch, t1:time().real)
+end
+
+
+-- Problem size shared by every benchmark run below.
+batchSize = 29
+from, to = 14, 13      -- input / output feature planes
+kW, kH = 9, 15         -- kernel width / height
+sW, sH = 1, 1          -- stride width / height
+outW, outH = 10, 34    -- desired output width / height
+-- Input size derived from the desired output size, the stride and the kernel.
+iW = (outW - 1) * sW + kW
+iH = (outH - 1) * sH + kH
+
+
+local version = tonumber(cudnn.C.cudnnGetVersion())
+print('CUDNN Version: ', version)
+
+-- Every explicit-algorithm run uses the same backward algorithms; only the
+-- forward algorithm changes from run to run.
+local BWD_DATA = 'CUDNN_CONVOLUTION_BWD_DATA_ALGO_0'
+local BWD_FILTER = 'CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0'
+
+-- Call bench() on the shared problem size, forwarding any algorithm names.
+local function run(title, ...)
+   bench(title, from, to, kH, kW, sH, sW, iH, iW, batchSize, ...)
+end
+
+run('Forward implicit gemm ', 'CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM', BWD_DATA, BWD_FILTER)
+
+run('Forward implicit precomp gemm', 'CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM', BWD_DATA, BWD_FILTER)
+
+run('Forward gemm ', 'CUDNN_CONVOLUTION_FWD_ALGO_GEMM', BWD_DATA, BWD_FILTER)
+
+-- just auto-tuned by cudnn with CUDNN_CONVOLUTION_FWD_PREFER_FASTEST mode
+-- (no algorithm names are passed, so setMode() receives no arguments)
+run('Forward AutoTuned ')
+
+run('Forward FFT ', 'CUDNN_CONVOLUTION_FWD_ALGO_FFT', BWD_DATA, BWD_FILTER)
+
+
+
+-- For reference, CuDNN Convolution modes
+--[[
+ CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM = 0,
+ CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM = 1,
+ CUDNN_CONVOLUTION_FWD_ALGO_GEMM = 2,
+ CUDNN_CONVOLUTION_FWD_ALGO_DIRECT = 3, // Placeholder
+ CUDNN_CONVOLUTION_FWD_ALGO_FFT = 4
+
+ CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0 = 0, // non-deterministic
+ CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1 = 1,
+ CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT = 2
+
+ CUDNN_CONVOLUTION_BWD_DATA_ALGO_0 = 0, // non-deterministic
+ CUDNN_CONVOLUTION_BWD_DATA_ALGO_1 = 1,
+ CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT = 2,
+
+ ]]--
diff --git a/test/test.lua b/test/test.lua
index dc50d94..ac8a573 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -25,7 +25,7 @@ function cudnntest.SpatialConvolution_forward_batch()
local sconv = nn.SpatialConvolutionMM(from,to,ki,kj,si,sj):cuda()
local groundtruth = sconv:forward(input)
cutorch.synchronize()
- local gconv = cudnn.SpatialConvolution(from,to,ki,kj,si,sj):cuda()
+ local gconv = cudnn.SpatialConvolution(from,to,ki,kj,si,sj):cuda():fastest()
gconv.weight:copy(sconv.weight)
gconv.bias:copy(sconv.bias)
local rescuda = gconv:forward(input)
@@ -59,7 +59,7 @@ function cudnntest.SpatialConvolution_backward_batch()
local groundweight = sconv.gradWeight
local groundbias = sconv.gradBias
- local gconv = cudnn.SpatialConvolution(from,to,ki,kj,si,sj):cuda()
+ local gconv = cudnn.SpatialConvolution(from,to,ki,kj,si,sj):cuda():fastest()
gconv.weight:copy(sconv.weight)
gconv.bias:copy(sconv.bias)
gconv:forward(input)