Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorJonathan Tompson <tompson@cims.nyu.edu>2014-06-26 19:32:57 +0400
committerJonathan Tompson <tompson@cims.nyu.edu>2014-06-26 19:32:57 +0400
commit4725c6b639f8dfc5d0440557c65e5dbc6fec1873 (patch)
treecbce1545c19a1c674c95a09ee2d9201f1adfb178 /test
parent1310a045ebc69a9f9e8c57d07af587a6535d5ae9 (diff)
Added SpatialUpSamplingNearest module.
Diffstat (limited to 'test')
-rw-r--r--test/test.lua64
1 files changed, 64 insertions, 0 deletions
diff --git a/test/test.lua b/test/test.lua
index 5db941a..7a23c5e 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -1886,6 +1886,70 @@ function nntest.View()
"Error in minibatch nElement")
end
+-- Define a test for SpatialUpSamplingCuda
+function nntest.SpatialUpSamplingNearest()
+ local scale = torch.random(2,4)
+ for dim = 3,4 do
+ local m = nn.SpatialUpSamplingNearest(scale)
+
+ -- Create a randomly sized dimD vector
+ local shape = {}
+ for i = 1, dim do
+ table.insert(shape, torch.random(2, 2+dim-1))
+ end
+
+ -- Check that the gradient is correct by using finite elements
+ local input = torch.Tensor(unpack(shape)):zero()
+
+ local err = jac.testJacobian(m, input)
+ mytester:assertlt(err, precision, ' error on state ')
+
+ local ferr, berr = jac.testIO(m, input)
+ mytester:asserteq(ferr, 0, torch.typename(m)..' - i/o forward err ')
+ mytester:asserteq(berr, 0, torch.typename(m)..' - i/o backward err ')
+
+ -- Also check that the forward prop is correct.
+ input = torch.rand(unpack(shape))
+ local output = m:forward(input)
+
+ local feat
+ local nfeats
+ if input:dim() == 3 then
+ nfeats = shape[1]
+ feat = {0}
+ else
+ feat = {0, 0}
+ nfeats = shape[1] * shape[2]
+ end
+ feat[#feat+1] = 0 -- ydim
+ feat[#feat+1] = 0 -- xdim
+ local xdim = input:dim()
+ local ydim = input:dim()-1
+ local err = 0
+ for f = 1, nfeats do
+ if input:dim() == 4 then
+ feat[1] = math.floor((f-1) / shape[1]) + 1
+ feat[2] = math.mod((f-1), shape[2]) + 1
+ else
+ feat[1] = f
+ end
+ for y = 1, input:size(ydim) * scale do
+ for x = 1, input:size(xdim) * scale do
+ feat[ydim] = y
+ feat[xdim] = x
+ local oval = output[feat]
+ feat[ydim] = math.floor((y-1)/scale)+1
+ feat[xdim] = math.floor((x-1)/scale)+1
+ local ival = input[feat]
+ err = math.max(err, math.abs(oval-ival))
+ end
+ end
+ end
+
+ mytester:assertlt(err, precision, ' fprop is incorrect ')
+ end
+end
+
mytester:add(nntest)
if not nn then