diff options
author | Soumith Chintala <soumith@gmail.com> | 2014-06-26 19:39:28 +0400 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2014-06-26 19:39:28 +0400 |
commit | f6842e9432955daae9d003deb4194c1ba467d8f4 (patch) | |
tree | cbce1545c19a1c674c95a09ee2d9201f1adfb178 /test | |
parent | 1310a045ebc69a9f9e8c57d07af587a6535d5ae9 (diff) | |
parent | 4725c6b639f8dfc5d0440557c65e5dbc6fec1873 (diff) |
Merge pull request #21 from jonathantompson/upsamplingnearest
Added SpatialUpSamplingNearest module.
Diffstat (limited to 'test')
-rw-r--r-- | test/test.lua | 64 |
1 files changed, 64 insertions, 0 deletions
diff --git a/test/test.lua b/test/test.lua index 5db941a..7a23c5e 100644 --- a/test/test.lua +++ b/test/test.lua @@ -1886,6 +1886,70 @@ function nntest.View() "Error in minibatch nElement") end +-- Define a test for SpatialUpSamplingCuda +function nntest.SpatialUpSamplingNearest() + local scale = torch.random(2,4) + for dim = 3,4 do + local m = nn.SpatialUpSamplingNearest(scale) + + -- Create a randomly sized dimD vector + local shape = {} + for i = 1, dim do + table.insert(shape, torch.random(2, 2+dim-1)) + end + + -- Check that the gradient is correct by using finite elements + local input = torch.Tensor(unpack(shape)):zero() + + local err = jac.testJacobian(m, input) + mytester:assertlt(err, precision, ' error on state ') + + local ferr, berr = jac.testIO(m, input) + mytester:asserteq(ferr, 0, torch.typename(m)..' - i/o forward err ') + mytester:asserteq(berr, 0, torch.typename(m)..' - i/o backward err ') + + -- Also check that the forward prop is correct. + input = torch.rand(unpack(shape)) + local output = m:forward(input) + + local feat + local nfeats + if input:dim() == 3 then + nfeats = shape[1] + feat = {0} + else + feat = {0, 0} + nfeats = shape[1] * shape[2] + end + feat[#feat+1] = 0 -- ydim + feat[#feat+1] = 0 -- xdim + local xdim = input:dim() + local ydim = input:dim()-1 + local err = 0 + for f = 1, nfeats do + if input:dim() == 4 then + feat[1] = math.floor((f-1) / shape[1]) + 1 + feat[2] = math.mod((f-1), shape[2]) + 1 + else + feat[1] = f + end + for y = 1, input:size(ydim) * scale do + for x = 1, input:size(xdim) * scale do + feat[ydim] = y + feat[xdim] = x + local oval = output[feat] + feat[ydim] = math.floor((y-1)/scale)+1 + feat[xdim] = math.floor((x-1)/scale)+1 + local ival = input[feat] + err = math.max(err, math.abs(oval-ival)) + end + end + end + + mytester:assertlt(err, precision, ' fprop is incorrect ') + end +end + mytester:add(nntest) if not nn then |