Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorfsuzanomassa <fvsmassa@gmail.com>2015-01-20 21:54:54 +0300
committerfsuzanomassa <fvsmassa@gmail.com>2015-04-21 21:33:24 +0300
commit1750204c0e35175920de446cfaeb48e4543cfa9f (patch)
treeb7c426d41800d83a6b92e1ac7c0e0b5ebba45f49 /test.lua
parenta2db5ec31f2dd236186c376a04daa31af319e39d (diff)
Adding strides for avoiding copies of non-contiguous tensors
Fix for batched non-contiguous data Removing unnecessary retain in SpatialAdaptiveMaxPooling
Diffstat (limited to 'test.lua')
-rw-r--r--test.lua25
1 file changed, 25 insertions, 0 deletions
diff --git a/test.lua b/test.lua
index 326909e..23c7fbd 100644
--- a/test.lua
+++ b/test.lua
@@ -1832,6 +1832,31 @@ function nntest.SpatialAdaptiveMaxPooling()
local ferr, berr = jac.testIO(module, input)
mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err (Batch) ')
mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err (Batch) ')
+
+ -- non-contiguous
+
+ input = torch.rand(from,ini,inj):transpose(2,3)
+ module = nn.SpatialAdaptiveMaxPooling(ki,kj)
+ local inputc = input:contiguous() -- contiguous
+ local output = module:forward(input):clone()
+ local outputc = module:forward(inputc):clone()
+ mytester:asserteq(0, (output-outputc):abs():max(), torch.typename(module) .. ' - non-contiguous err ')
+ local gradInput = module:backward(input, output):clone()
+ local gradInputc = module:backward(inputc, outputc):clone()
+ mytester:asserteq(0, (gradInput-gradInputc):abs():max(), torch.typename(module) .. ' - non-contiguous err ')
+
+ -- non-contiguous batch
+ local nbatch = math.random(1,3)
+ input = torch.rand(nbatch,from,ini,inj):transpose(1,3):transpose(2,4)
+ local inputc = input:contiguous() -- contiguous
+ module = nn.SpatialAdaptiveMaxPooling(ki,kj)
+
+ local output = module:forward(input):clone()
+ local outputc = module:forward(inputc):clone()
+ mytester:asserteq(0, (output-outputc):abs():max(), torch.typename(module) .. ' - batch non-contiguous err ')
+ local gradInput = module:backward(input, output):clone()
+ local gradInputc = module:backward(inputc, outputc):clone()
+ mytester:asserteq(0, (gradInput-gradInputc):abs():max(), torch.typename(module) .. ' - batch non-contiguous err ')
end