Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorsoumith <soumith@fb.com>2015-08-02 18:05:04 +0300
committersoumith <soumith@fb.com>2015-08-02 18:05:04 +0300
commit3e6e918dac9e94d2f104da6e36f749312e5c3951 (patch)
treec65379be66bc1dc1f09006bb2e6ce70691c94496 /test
parent54492b930cbef09853b232fae0d5eeab2bdaa42f (diff)
adding a functional interface, with the bias calculations to start with
Diffstat (limited to 'test')
-rw-r--r--test/test.lua33
1 files changed, 33 insertions, 0 deletions
diff --git a/test/test.lua b/test/test.lua
index 4629697..dc50d94 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -582,6 +582,39 @@ function cudnntest.SoftMax_batch()
precision_backward, 'error on state (backward) ')
end
+function cudnntest.functional_SpatialBias()
+ local bs = math.random(1,32)
+ local from = math.random(1,32)
+ local to = math.random(1,64)
+ local ki = math.random(1,15)
+ local kj = math.random(1,15)
+ local si = math.random(1,ki)
+ local sj = math.random(1,kj)
+ local outi = math.random(1,64)
+ local outj = math.random(1,64)
+ local ini = (outi-1)*si+ki
+ local inj = (outj-1)*sj+kj
+ local scale = torch.uniform()
+ local input = torch.zeros(bs,from,inj,ini):cuda()
+ local mod = cudnn.SpatialConvolution(from,to,ki,kj,si,sj):cuda()
+ mod.weight:zero()
+ local groundtruth = mod:forward(input)
+ local result = groundtruth:clone():zero()
+ cudnn.functional.SpatialBias_updateOutput(mod.bias, result)
+ local error = result:float() - groundtruth:float()
+ mytester:assertlt(error:abs():max(),
+ precision_forward, 'error on forward ')
+
+ mod:zeroGradParameters()
+ local gradOutput = groundtruth:clone():normal()
+ mod:backward(input, gradOutput, scale)
+ local groundtruth = mod.gradBias
+ local result = groundtruth:clone():zero()
+ cudnn.functional.SpatialBias_accGradParameters(gradOutput, result, scale)
+ error = result:float() - groundtruth:float()
+ mytester:assertlt(error:abs():max(),
+ precision_backward, 'error on accGradParameters ')
+end
torch.setdefaulttensortype('torch.FloatTensor')