From 1750204c0e35175920de446cfaeb48e4543cfa9f Mon Sep 17 00:00:00 2001 From: fsuzanomassa Date: Tue, 20 Jan 2015 19:54:54 +0100 Subject: Adding strides for avoiding copies of non-contiguous tensors Fix for batched non-contigous data Removing unnecessary retain in SpatialAdaptiveMaxPooling --- test.lua | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) (limited to 'test.lua') diff --git a/test.lua b/test.lua index 326909e..23c7fbd 100644 --- a/test.lua +++ b/test.lua @@ -1832,6 +1832,31 @@ function nntest.SpatialAdaptiveMaxPooling() local ferr, berr = jac.testIO(module, input) mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err (Batch) ') mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err (Batch) ') + + -- non-contiguous + + input = torch.rand(from,ini,inj):transpose(2,3) + module = nn.SpatialAdaptiveMaxPooling(ki,kj) + local inputc = input:contiguous() -- contiguous + local output = module:forward(input):clone() + local outputc = module:forward(inputc):clone() + mytester:asserteq(0, (output-outputc):abs():max(), torch.typename(module) .. ' - non-contiguous err ') + local gradInput = module:backward(input, output):clone() + local gradInputc = module:backward(inputc, outputc):clone() + mytester:asserteq(0, (gradInput-gradInputc):abs():max(), torch.typename(module) .. ' - non-contiguous err ') + + -- non-contiguous batch + local nbatch = math.random(1,3) + input = torch.rand(nbatch,from,ini,inj):transpose(1,3):transpose(2,4) + local inputc = input:contiguous() -- contiguous + module = nn.SpatialAdaptiveMaxPooling(ki,kj) + + local output = module:forward(input):clone() + local outputc = module:forward(inputc):clone() + mytester:asserteq(0, (output-outputc):abs():max(), torch.typename(module) .. ' - batch non-contiguous err ') + local gradInput = module:backward(input, output):clone() + local gradInputc = module:backward(inputc, outputc):clone() + mytester:asserteq(0, (gradInput-gradInputc):abs():max(), torch.typename(module) .. ' - batch non-contiguous err ') end -- cgit v1.2.3