diff options
Diffstat (limited to 'test.lua')
-rw-r--r-- | test.lua | 25 |
1 files changed, 25 insertions, 0 deletions
@@ -1832,6 +1832,31 @@ function nntest.SpatialAdaptiveMaxPooling() local ferr, berr = jac.testIO(module, input) mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err (Batch) ') mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err (Batch) ') + + -- non-contiguous + + input = torch.rand(from,ini,inj):transpose(2,3) + module = nn.SpatialAdaptiveMaxPooling(ki,kj) + local inputc = input:contiguous() -- contiguous + local output = module:forward(input):clone() + local outputc = module:forward(inputc):clone() + mytester:asserteq(0, (output-outputc):abs():max(), torch.typename(module) .. ' - non-contiguous err ') + local gradInput = module:backward(input, output):clone() + local gradInputc = module:backward(inputc, outputc):clone() + mytester:asserteq(0, (gradInput-gradInputc):abs():max(), torch.typename(module) .. ' - non-contiguous err ') + + -- non-contiguous batch + local nbatch = math.random(1,3) + input = torch.rand(nbatch,from,ini,inj):transpose(1,3):transpose(2,4) + local inputc = input:contiguous() -- contiguous + module = nn.SpatialAdaptiveMaxPooling(ki,kj) + + local output = module:forward(input):clone() + local outputc = module:forward(inputc):clone() + mytester:asserteq(0, (output-outputc):abs():max(), torch.typename(module) .. ' - batch non-contiguous err ') + local gradInput = module:backward(input, output):clone() + local gradInputc = module:backward(inputc, outputc):clone() + mytester:asserteq(0, (gradInput-gradInputc):abs():max(), torch.typename(module) .. ' - batch non-contiguous err ') end |