diff options
author | fsuzanomassa <fvsmassa@gmail.com> | 2015-01-20 21:54:54 +0300 |
---|---|---|
committer | fsuzanomassa <fvsmassa@gmail.com> | 2015-04-21 21:33:24 +0300 |
commit | 1750204c0e35175920de446cfaeb48e4543cfa9f (patch) | |
tree | b7c426d41800d83a6b92e1ac7c0e0b5ebba45f49 /test.lua | |
parent | a2db5ec31f2dd236186c376a04daa31af319e39d (diff) |
Adding strides for avoiding copies of non-contiguous tensors
Fix for batched non-contigous data
Removing unnecessary retain in SpatialAdaptiveMaxPooling
Diffstat (limited to 'test.lua')
-rw-r--r-- | test.lua | 25 |
1 files changed, 25 insertions, 0 deletions
@@ -1832,6 +1832,31 @@ function nntest.SpatialAdaptiveMaxPooling() local ferr, berr = jac.testIO(module, input) mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err (Batch) ') mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err (Batch) ') + + -- non-contiguous + + input = torch.rand(from,ini,inj):transpose(2,3) + module = nn.SpatialAdaptiveMaxPooling(ki,kj) + local inputc = input:contiguous() -- contiguous + local output = module:forward(input):clone() + local outputc = module:forward(inputc):clone() + mytester:asserteq(0, (output-outputc):abs():max(), torch.typename(module) .. ' - non-contiguous err ') + local gradInput = module:backward(input, output):clone() + local gradInputc = module:backward(inputc, outputc):clone() + mytester:asserteq(0, (gradInput-gradInputc):abs():max(), torch.typename(module) .. ' - non-contiguous err ') + + -- non-contiguous batch + local nbatch = math.random(1,3) + input = torch.rand(nbatch,from,ini,inj):transpose(1,3):transpose(2,4) + local inputc = input:contiguous() -- contiguous + module = nn.SpatialAdaptiveMaxPooling(ki,kj) + + local output = module:forward(input):clone() + local outputc = module:forward(inputc):clone() + mytester:asserteq(0, (output-outputc):abs():max(), torch.typename(module) .. ' - batch non-contiguous err ') + local gradInput = module:backward(input, output):clone() + local gradInputc = module:backward(inputc, outputc):clone() + mytester:asserteq(0, (gradInput-gradInputc):abs():max(), torch.typename(module) .. ' - batch non-contiguous err ') end |