Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorNicholas Leonard <nick@nikopia.org>2015-03-12 02:00:31 +0300
committerNicholas Leonard <nick@nikopia.org>2015-03-12 02:00:31 +0300
commit6e8d328550a6b9c2d9082884a1cc59d738788336 (patch)
treebac6dfa3d774bab91ed0e245b1c599c0c7167ca8 /test
parentb8165b784bda1877a7900b1dfffa4dcb421729e3 (diff)
nn.Padding
Diffstat (limited to 'test')
-rw-r--r--test/test-all.lua18
1 files changed, 18 insertions, 0 deletions
diff --git a/test/test-all.lua b/test/test-all.lua
index 58b72a2..5abd4e2 100644
--- a/test/test-all.lua
+++ b/test/test-all.lua
@@ -1041,6 +1041,24 @@ function nnxtest.PushPullTable()
mytester:assertTensorEq(gradInput[2], gradInput[2], 0.00001, "push/pull multi-backward error")
end
+function nnxtest.Padding()
+ local fanin = math.random(1,3)
+ local sizex = math.random(4,16)
+ local sizey = math.random(4,16)
+ local pad = math.random(-3,3)
+ local val = torch.randn(1):squeeze()
+ local module = nn.Padding(1, pad, 3, val)
+ local input = torch.rand(fanin,sizey,sizex)
+ local size = input:size():totable()
+ size[1] = size[1] + math.abs(pad)
+
+ local output = module:forward(input)
+ mytester:assertTableEq(size, output:size():totable(), 0.00001, "Padding size error")
+
+ local gradInput = module:backward(input, output)
+ mytester:assertTensorEq(gradInput, input, 0.00001, "Padding backward error")
+end
+
function nnx.test(tests)
xlua.require('image',true)