diff options
author | Michael Pound <michael.pound@nottingham.ac.uk> | 2017-07-24 16:07:08 +0300 |
---|---|---|
committer | Michael Pound <michael.pound@nottingham.ac.uk> | 2017-07-24 16:07:08 +0300 |
commit | 528a871d06cb8afe5417fd8c7e075781bcdd10de (patch) | |
tree | 2f5dc22a4a0e792cc6ff30405bdc8d0c0094e649 | |
parent | e9d54e162b3628d5f8b4e4a0f6c2191f559ebc12 (diff) |
Added cunn tests for UpSampling module.
-rw-r--r-- | test.lua | 78 |
1 files changed, 78 insertions, 0 deletions
@@ -4437,6 +4437,84 @@ function cunntest.SpatialUpSamplingBilinear_backward_batch() end end +function cunntest.UpSampling_forward_batch() + local minibatch = torch.random(1, 10) + local f = torch.random(3, 10) + local d = torch.random(3, 10) + local h = torch.random(3, 10) + local w = torch.random(3, 10) + local scale = torch.random(2,5) + + for k, typename in ipairs(typenames) do + for _,mode in pairs({'nearest','linear'}) do + for dim = 4,5 do + local input + if (dim == 4) then + input = torch.randn(minibatch, f, h, w):type(typename) + else + input = torch.randn(minibatch, f, d, h, w):type(typename) + end + + local ctype = t2cpu[typename] + input = makeNonContiguous(input:type(ctype)) + local sconv = nn.UpSampling(scale, mode):type(ctype) + local groundtruth = sconv:forward(input) + + input = makeNonContiguous(input:type(typename)) + local gconv = sconv:clone():type(typename) + local rescuda = gconv:forward(input) + + local error = rescuda:double() - groundtruth:double() + mytester:assertlt(error:abs():max(), precision_forward_type(precision_forward, typename), + string.format('error on state (forward) with %s', typename)) + end + end + end +end + +function cunntest.UpSampling_backward_batch() + local minibatch = torch.random(1, 10) + local f = torch.random(3, 10) + local d = torch.random(3, 10) + local h = torch.random(3, 10) + local w = torch.random(3, 10) + local scale = torch.random(2,4) + + for k, typename in ipairs(typenames) do + for _,mode in pairs({'nearest','linear'}) do + for dim = 4,5 do + local input, gradOutput + if (dim == 4) then + input = torch.randn(minibatch, f, h, w):type(typename) + gradOutput = torch.randn(minibatch, f, h*scale, w*scale):type(typename) + else + input = torch.randn(minibatch, f, d, h, w):type(typename) + gradOutput = torch.randn(minibatch, f, d*scale, h*scale, w*scale):type(typename) + end + + local ctype = t2cpu[typename] + input = makeNonContiguous(input:type(ctype)) + gradOutput = makeNonContiguous(gradOutput:type(ctype)) + local sconv = nn.UpSampling(scale, mode):type(ctype) + sconv:forward(input) + sconv:zeroGradParameters() + local groundgrad = sconv:backward(input, gradOutput) + + input = makeNonContiguous(input:type(typename)) + gradOutput = makeNonContiguous(gradOutput:type(typename)) + local gconv = sconv:clone():type(typename) + gconv:forward(input) + gconv:zeroGradParameters() + local rescuda = gconv:backward(input, gradOutput) + + local error = rescuda:double() - groundgrad:double() + mytester:assertlt(error:abs():max(), precision_backward_type(precision_backward, typename), + string.format('error on state (backward) with %s', typename)) + end + end + end +end + function cunntest.l1cost() local size = math.random(300,500) |