diff options
author | Nicholas Leonard <nick@nikopia.org> | 2014-10-08 01:46:19 +0400 |
---|---|---|
committer | Nicholas Leonard <nick@nikopia.org> | 2014-10-08 01:46:19 +0400 |
commit | 9bd3690b91b350c5aff9b33b32aafe1278b2ba08 (patch) | |
tree | fdb80aecb8469bbf57cbaf2c80bc95aedebfd7be /test | |
parent | 990243e328d4e235d8a47ed1089a6fabfd20f5cd (diff) |
SpatialReSampling:updateGradInput works with batches
Diffstat (limited to 'test')
-rw-r--r-- | test/test-all.lua | 3 |
1 files changed, 2 insertions, 1 deletions
diff --git a/test/test-all.lua b/test/test-all.lua index e02c810..20821b5 100644 --- a/test/test-all.lua +++ b/test/test-all.lua @@ -168,6 +168,7 @@ function nnxtest.SpatialReSampling_1() mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ') mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ') + -- test batches (4D input) local batchSize = math.random(4,8) local input2 = torch.rand(batchSize,fanin,sizey,sizex) input2[2]:copy(input) @@ -175,7 +176,7 @@ function nnxtest.SpatialReSampling_1() local output = module:forward(input):clone() local output2 = module:forward(input2) mytester:assertTensorEq(output, output2[2], 0.00001, 'SpatialResampling batch forward err') - if true then return end + local gradInput = module:backward(input, output):clone() local gradInput2 = module:backward(input2, output2) mytester:assertTensorEq(gradInput, gradInput2[2], 0.00001, 'SpatialResampling batch backward err') |