Welcome to the mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorsoumith <soumith@fb.com>2015-09-15 02:37:02 +0300
committersoumith <soumith@fb.com>2015-09-15 02:37:02 +0300
commit77f32a01edd42f9ca481263359e32f8a1d73f3d1 (patch)
treeea50d8acbc0b4b521916ebd163b7cb97b7da0acc /test
parentf85c8e0d178baf0dab9deb982c76b95191620418 (diff)
functional interface for R3 as well
Diffstat (limited to 'test')
-rw-r--r--test/test.lua57
1 files changed, 54 insertions, 3 deletions
diff --git a/test/test.lua b/test/test.lua
index c2938de..4062425 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -724,7 +724,7 @@ function cudnntest.LogSoftMax_batch()
precision_backward, 'error on state (backward) ')
end
-function cudnntest.functional_SpatialBias()
+function cudnntest.functional_bias2D()
local bs = math.random(1,32)
local from = math.random(1,32)
local to = math.random(1,64)
@@ -742,7 +742,7 @@ function cudnntest.functional_SpatialBias()
mod.weight:zero()
local groundtruth = mod:forward(input)
local result = groundtruth:clone():zero()
- cudnn.functional.SpatialBias_updateOutput(mod.bias, result)
+ cudnn.functional.bias2D_updateOutput(cudnn.getHandle(), mod.bias, result)
local error = result:float() - groundtruth:float()
mytester:assertlt(error:abs():max(),
precision_forward, 'error on forward ')
@@ -752,12 +752,63 @@ function cudnntest.functional_SpatialBias()
mod:backward(input, gradOutput, scale)
local groundtruth = mod.gradBias
local result = groundtruth:clone():zero()
- cudnn.functional.SpatialBias_accGradParameters(gradOutput, result, scale)
+ cudnn.functional.bias2D_accGradParameters(cudnn.getHandle(), gradOutput, result, scale)
error = result:float() - groundtruth:float()
mytester:assertlt(error:abs():max(),
precision_backward, 'error on accGradParameters ')
end
+function cudnntest.functional_convolution2d()
+ local a=cudnn.SpatialConvolution(3,16,5,5):cuda()
+ a.bias:zero();
+ local input = torch.randn(10,3,10,10):cuda()
+ a:zeroGradParameters()
+ a:forward(input);
+ local output = a.output:clone():normal()
+ local gradOutput = a.output:clone():normal()
+ local gradInput = a:backward(input, gradOutput):clone():normal()
+ local gradWeight = a.gradWeight:clone():zero()
+ cudnn.functional.Convolution2D_updateOutput(cudnn.getHandle(), input,
+ a.weight, output, a.dH,
+ a.dW, a.padH, a.padW)
+ mytester:assertlt((output - a.output):abs():max(),
+ precision_forward, 'error on forward ')
+
+ cudnn.functional.Convolution2D_updateGradInput(cudnn.getHandle(), input,
+ a.weight, output, gradOutput,
+ gradInput,
+ a.dH, a.dW, a.padH, a.padW)
+ mytester:assertlt((gradInput - a.gradInput):abs():max(),
+ precision_forward, 'error on updateGradInput ')
+
+ cudnn.functional.Convolution2D_accGradParameters(cudnn.getHandle(), input,
+ gradWeight, gradOutput,
+ a.dH, a.dW, a.padH, a.padW)
+ mytester:assertlt((gradWeight - a.gradWeight):abs():max(),
+ precision_forward, 'error on accGradParameters ')
+end
+
+function cudnntest.functional_maxpooling2d()
+ local a=cudnn.SpatialMaxPooling(2,2,2,2):cuda()
+ local input = torch.randn(10,3,10,10):cuda()
+ a:forward(input);
+ local output = a.output:clone():normal()
+ local gradOutput = a.output:clone():normal()
+ local gradInput = a:backward(input, gradOutput):clone():normal()
+ cudnn.functional.MaxPooling2D_updateOutput(cudnn.getHandle(), input,
+ output, a.kH, a.kW,
+ a.dH, a.dW, a.padH, a.padW)
+ mytester:assertlt((output - a.output):abs():max(),
+ precision_forward, 'error on forward ')
+
+ cudnn.functional.MaxPooling2D_updateGradInput(cudnn.getHandle(), input,
+ output, gradOutput, gradInput,
+ a.kH, a.kW, a.dH, a.dW,
+ a.padH, a.padW)
+ mytester:assertlt((gradInput - a.gradInput):abs():max(),
+ precision_forward, 'error on updateGradInput ')
+end
+
torch.setdefaulttensortype('torch.FloatTensor')
math.randomseed(os.time())