Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorsoumith <soumith@fb.com>2015-01-28 00:52:01 +0300
committersoumith <soumith@fb.com>2015-01-28 00:52:01 +0300
commit516c818d2b5e501b11165c4514d5260b42e9b289 (patch)
tree792aec8ff57a1068e54b583d60b664d4c489fd3e /test
parent7b169478b14f9f8740ba454435e92b46f10af590 (diff)
added avg pooling tests
Diffstat (limited to 'test')
-rw-r--r--test/test.lua73
1 files changed, 73 insertions, 0 deletions
diff --git a/test/test.lua b/test/test.lua
index b27fd11..d3a7e35 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -318,6 +318,79 @@ function cudnntest.SpatialMaxPooling_single()
'error on state (backward) ')
end
+function cudnntest.SpatialAveragePooling_batch()
+ local bs = math.random(1,32)
+ local from = math.random(1,32)
+ local ki = math.random(2,4)
+ local kj = math.random(2,4)
+ local si = math.random(2,4)
+ local sj = math.random(2,4)
+ local outi = math.random(1,64)
+ local outj = math.random(1,64)
+ local ini = (outi-1)*si+ki
+ local inj = (outj-1)*sj+kj
+ local input = torch.randn(bs,from,inj,ini):cuda()
+ local gradOutput = torch.randn(bs,from,outj,outi):cuda()
+
+ local sconv = nn.SpatialAveragePooling(ki,kj,si,sj):cuda()
+ local groundtruth = sconv:forward(input):clone()
+ groundtruth:mul(1/(ki*kj)) -- difference between nn and cudnn
+ local groundgrad = sconv:backward(input, gradOutput)
+ groundgrad:mul(1/(ki*kj)) -- difference between nn and cudnn
+ cutorch.synchronize()
+ local gconv = cudnn.SpatialAveragePooling(ki,kj,si,sj):cuda()
+ local rescuda = gconv:forward(input)
+ -- serialize and deserialize
+ torch.save('modelTemp.t7', gconv)
+ gconv = torch.load('modelTemp.t7')
+ local rescuda = gconv:forward(input)
+ local resgrad = gconv:backward(input, gradOutput)
+ cutorch.synchronize()
+ mytester:asserteq(rescuda:dim(), 4, 'error in dimension')
+ mytester:asserteq(resgrad:dim(), 4, 'error in dimension')
+ local error = rescuda:float() - groundtruth:float()
+ mytester:assertlt(error:abs():max(), precision_forward, 'error on state (forward) ')
+ error = resgrad:float() - groundgrad:float()
+ mytester:assertlt(error:abs():max(), precision_backward, 'error on state (backward) ')
+end
+
+function cudnntest.SpatialAveragePooling_single()
+ local from = math.random(1,32)
+ local ki = math.random(2,4)
+ local kj = math.random(2,4)
+ local si = math.random(2,4)
+ local sj = math.random(2,4)
+ local outi = math.random(1,64)
+ local outj = math.random(1,64)
+ local ini = (outi-1)*si+ki
+ local inj = (outj-1)*sj+kj
+ local input = torch.randn(from,inj,ini):cuda()
+ local gradOutput = torch.randn(from,outj,outi):cuda()
+
+ local sconv = nn.SpatialAveragePooling(ki,kj,si,sj):cuda()
+ local groundtruth = sconv:forward(input):clone()
+ groundtruth:mul(1/(ki*kj)) -- difference between nn and cudnn
+ local groundgrad = sconv:backward(input, gradOutput)
+ groundgrad:mul(1/(ki*kj)) -- difference between nn and cudnn
+ cutorch.synchronize()
+ local gconv = cudnn.SpatialAveragePooling(ki,kj,si,sj):cuda()
+ local _ = gconv:forward(input)
+ -- serialize and deserialize
+ torch.save('modelTemp.t7', gconv)
+ gconv = torch.load('modelTemp.t7')
+ local rescuda = gconv:forward(input)
+ local resgrad = gconv:backward(input, gradOutput)
+ cutorch.synchronize()
+ mytester:asserteq(rescuda:dim(), 3, 'error in dimension')
+ mytester:asserteq(resgrad:dim(), 3, 'error in dimension')
+ local error = rescuda:float() - groundtruth:float()
+ mytester:assertlt(error:abs():max(), precision_forward,
+ 'error on state (forward) ')
+ error = resgrad:float() - groundgrad:float()
+ mytester:assertlt(error:abs():max(), precision_backward,
+ 'error on state (backward) ')
+end
+
local function nonlinSingle(nonlin)
local from = math.random(1,32)
local outi = math.random(1,64)