Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cunn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2017-07-24 18:22:18 +0300
committerGitHub <noreply@github.com>2017-07-24 18:22:18 +0300
commitb336dc940c513d0b42d2ef2940bec9199b4377cf (patch)
tree2f5dc22a4a0e792cc6ff30405bdc8d0c0094e649
parente9d54e162b3628d5f8b4e4a0f6c2191f559ebc12 (diff)
parent528a871d06cb8afe5417fd8c7e075781bcdd10de (diff)
Merge pull request #479 from mikepound/upsampling
Added cunn tests for UpSampling module.
-rw-r--r--test.lua78
1 files changed, 78 insertions, 0 deletions
diff --git a/test.lua b/test.lua
index 7903aa6..670c70e 100644
--- a/test.lua
+++ b/test.lua
@@ -4437,6 +4437,84 @@ function cunntest.SpatialUpSamplingBilinear_backward_batch()
end
end
+function cunntest.UpSampling_forward_batch()
+ local minibatch = torch.random(1, 10)
+ local f = torch.random(3, 10)
+ local d = torch.random(3, 10)
+ local h = torch.random(3, 10)
+ local w = torch.random(3, 10)
+ local scale = torch.random(2,5)
+
+ for k, typename in ipairs(typenames) do
+ for _,mode in pairs({'nearest','linear'}) do
+ for dim = 4,5 do
+ local input
+ if (dim == 4) then
+ input = torch.randn(minibatch, f, h, w):type(typename)
+ else
+ input = torch.randn(minibatch, f, d, h, w):type(typename)
+ end
+
+ local ctype = t2cpu[typename]
+ input = makeNonContiguous(input:type(ctype))
+ local sconv = nn.UpSampling(scale, mode):type(ctype)
+ local groundtruth = sconv:forward(input)
+
+ input = makeNonContiguous(input:type(typename))
+ local gconv = sconv:clone():type(typename)
+ local rescuda = gconv:forward(input)
+
+ local error = rescuda:double() - groundtruth:double()
+ mytester:assertlt(error:abs():max(), precision_forward_type(precision_forward, typename),
+ string.format('error on state (forward) with %s', typename))
+ end
+ end
+ end
+end
+
+function cunntest.UpSampling_backward_batch()
+ local minibatch = torch.random(1, 10)
+ local f = torch.random(3, 10)
+ local d = torch.random(3, 10)
+ local h = torch.random(3, 10)
+ local w = torch.random(3, 10)
+ local scale = torch.random(2,4)
+
+ for k, typename in ipairs(typenames) do
+ for _,mode in pairs({'nearest','linear'}) do
+ for dim = 4,5 do
+ local input, gradOutput
+ if (dim == 4) then
+ input = torch.randn(minibatch, f, h, w):type(typename)
+ gradOutput = torch.randn(minibatch, f, h*scale, w*scale):type(typename)
+ else
+ input = torch.randn(minibatch, f, d, h, w):type(typename)
+ gradOutput = torch.randn(minibatch, f, d*scale, h*scale, w*scale):type(typename)
+ end
+
+ local ctype = t2cpu[typename]
+ input = makeNonContiguous(input:type(ctype))
+ gradOutput = makeNonContiguous(gradOutput:type(ctype))
+ local sconv = nn.UpSampling(scale, mode):type(ctype)
+ sconv:forward(input)
+ sconv:zeroGradParameters()
+ local groundgrad = sconv:backward(input, gradOutput)
+
+ input = makeNonContiguous(input:type(typename))
+ gradOutput = makeNonContiguous(gradOutput:type(typename))
+ local gconv = sconv:clone():type(typename)
+ gconv:forward(input)
+ gconv:zeroGradParameters()
+ local rescuda = gconv:backward(input, gradOutput)
+
+ local error = rescuda:double() - groundgrad:double()
+ mytester:assertlt(error:abs():max(), precision_backward_type(precision_backward, typename),
+ string.format('error on state (backward) with %s', typename))
+ end
+ end
+ end
+end
+
function cunntest.l1cost()
local size = math.random(300,500)