diff options
author | Nicholas Leonard <nick@nikopia.org> | 2014-10-08 01:33:20 +0400 |
---|---|---|
committer | Nicholas Leonard <nick@nikopia.org> | 2014-10-08 01:33:20 +0400 |
commit | 990243e328d4e235d8a47ed1089a6fabfd20f5cd (patch) | |
tree | f2bf4cd79075b049e7f817f0b7f2b85b4542bcb0 /test | |
parent | 12349b4f2a341099a9c7f7d8bfcc31efffc36c62 (diff) |
SpatialReSampling:updateOutput works with batches
Diffstat (limited to 'test')
-rw-r--r-- | test/test-all.lua | 12 |
1 files changed, 12 insertions, 0 deletions
diff --git a/test/test-all.lua b/test/test-all.lua index 58c305c..e02c810 100644 --- a/test/test-all.lua +++ b/test/test-all.lua @@ -167,6 +167,18 @@ function nnxtest.SpatialReSampling_1() local ferr, berr = nn.Jacobian.testIO(module, input) mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ') mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ') + + local batchSize = math.random(4,8) + local input2 = torch.rand(batchSize,fanin,sizey,sizex) + input2[2]:copy(input) + + local output = module:forward(input):clone() + local output2 = module:forward(input2) + mytester:assertTensorEq(output, output2[2], 0.00001, 'SpatialResampling batch forward err') + if true then return end + local gradInput = module:backward(input, output):clone() + local gradInput2 = module:backward(input2, output2) + mytester:assertTensorEq(gradInput, gradInput2[2], 0.00001, 'SpatialResampling batch backward err') end function nnxtest.SpatialReSampling_2() |