Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorNicholas Leonard <nick@nikopia.org>2014-10-08 01:46:19 +0400
committerNicholas Leonard <nick@nikopia.org>2014-10-08 01:46:19 +0400
commit9bd3690b91b350c5aff9b33b32aafe1278b2ba08 (patch)
treefdb80aecb8469bbf57cbaf2c80bc95aedebfd7be /test
parent990243e328d4e235d8a47ed1089a6fabfd20f5cd (diff)
SpatialReSampling:updateGradInput works with batches
Diffstat (limited to 'test')
-rw-r--r--test/test-all.lua3
1 files changed, 2 insertions, 1 deletions
diff --git a/test/test-all.lua b/test/test-all.lua
index e02c810..20821b5 100644
--- a/test/test-all.lua
+++ b/test/test-all.lua
@@ -168,6 +168,7 @@ function nnxtest.SpatialReSampling_1()
mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ')
mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
+ -- test batches (4D input)
local batchSize = math.random(4,8)
local input2 = torch.rand(batchSize,fanin,sizey,sizex)
input2[2]:copy(input)
@@ -175,7 +176,7 @@ function nnxtest.SpatialReSampling_1()
local output = module:forward(input):clone()
local output2 = module:forward(input2)
mytester:assertTensorEq(output, output2[2], 0.00001, 'SpatialResampling batch forward err')
- if true then return end
+
local gradInput = module:backward(input, output):clone()
local gradInput2 = module:backward(input2, output2)
mytester:assertTensorEq(gradInput, gradInput2[2], 0.00001, 'SpatialResampling batch backward err')