diff options
Diffstat (limited to 'test.lua')
-rw-r--r-- | test.lua | 18 |
1 files changed, 18 insertions, 0 deletions
@@ -3301,6 +3301,24 @@ function nntest.SpatialBatchNormalization() mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ') end +function nntest.Padding() + local fanin = math.random(1,3) + local sizex = math.random(4,16) + local sizey = math.random(4,16) + local pad = math.random(-3,3) + local val = torch.randn(1):squeeze() + local module = nn.Padding(1, pad, 3, val) + local input = torch.rand(fanin,sizey,sizex) + local size = input:size():totable() + size[1] = size[1] + math.abs(pad) + + local output = module:forward(input) + mytester:assertTableEq(size, output:size():totable(), 0.00001, "Padding size error") + + local gradInput = module:backward(input, output) + mytester:assertTensorEq(gradInput, input, 0.00001, "Padding backward error") +end + mytester:add(nntest) if not nn then |