Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorNicholas Leonard <nick@nikopia.org>2014-10-08 01:33:20 +0400
committerNicholas Leonard <nick@nikopia.org>2014-10-08 01:33:20 +0400
commit990243e328d4e235d8a47ed1089a6fabfd20f5cd (patch)
treef2bf4cd79075b049e7f817f0b7f2b85b4542bcb0 /test
parent12349b4f2a341099a9c7f7d8bfcc31efffc36c62 (diff)
SpatialReSampling:updateOutput works with batches
Diffstat (limited to 'test')
-rw-r--r--test/test-all.lua12
1 files changed, 12 insertions, 0 deletions
diff --git a/test/test-all.lua b/test/test-all.lua
index 58c305c..e02c810 100644
--- a/test/test-all.lua
+++ b/test/test-all.lua
@@ -167,6 +167,18 @@ function nnxtest.SpatialReSampling_1()
local ferr, berr = nn.Jacobian.testIO(module, input)
mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ')
mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
+
+ local batchSize = math.random(4,8)
+ local input2 = torch.rand(batchSize,fanin,sizey,sizex)
+ input2[2]:copy(input)
+
+ local output = module:forward(input):clone()
+ local output2 = module:forward(input2)
+ mytester:assertTensorEq(output, output2[2], 0.00001, 'SpatialResampling batch forward err')
+ if true then return end
+ local gradInput = module:backward(input, output):clone()
+ local gradInput2 = module:backward(input2, output2)
+ mytester:assertTensorEq(gradInput, gradInput2[2], 0.00001, 'SpatialResampling batch backward err')
end
function nnxtest.SpatialReSampling_2()