Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cunn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'test.lua')
-rw-r--r--test.lua216
1 files changed, 112 insertions, 104 deletions
diff --git a/test.lua b/test.lua
index 5aadb20..15f9990 100644
--- a/test.lua
+++ b/test.lua
@@ -1255,22 +1255,24 @@ function cunntest.SpatialConvolutionLocal_forward_single()
local inj = (outj-1)*sj+kj-padH*2
for k, typename in ipairs(typenames) do
- local input = torch.randn(from,inj,ini):type(typename)
-
- local ctype = t2cpu[typename]
- input = makeNonContiguous(input:type(ctype))
- local sconv = nn.SpatialConvolutionLocal(from,to,ini,inj,ki,kj,si,sj,padW,padH):type(ctype)
- local groundtruth = sconv:forward(input)
-
- input = makeNonContiguous(input:type(typename))
- local gconv = nn.SpatialConvolutionLocal(from,to,ini,inj,ki,kj,si,sj,padW,padH):type(typename)
- gconv.weight = sconv.weight:type(typename)
- gconv.bias = sconv.bias:type(typename)
- local rescuda = gconv:forward(input)
-
- local error = rescuda:double() - groundtruth:double()
- mytester:assertlt(error:abs():max(), precision_forward_type(precision_forward, typename),
- string.format('error on state (forward) with %s', typename))
+ if typename ~= "torch.CudaHalfTensor" then
+ local input = torch.randn(from,inj,ini):type(typename)
+
+ local ctype = t2cpu[typename]
+ input = makeNonContiguous(input:type(ctype))
+ local sconv = nn.SpatialConvolutionLocal(from,to,ini,inj,ki,kj,si,sj,padW,padH):type(ctype)
+ local groundtruth = sconv:forward(input)
+
+ input = makeNonContiguous(input:type(typename))
+ local gconv = nn.SpatialConvolutionLocal(from,to,ini,inj,ki,kj,si,sj,padW,padH):type(typename)
+ gconv.weight = sconv.weight:type(typename)
+ gconv.bias = sconv.bias:type(typename)
+ local rescuda = gconv:forward(input)
+
+ local error = rescuda:double() - groundtruth:double()
+ mytester:assertlt(error:abs():max(), precision_forward_type(precision_forward, typename),
+ string.format('error on state (forward) with %s', typename))
+ end
end
end
@@ -1290,22 +1292,24 @@ function cunntest.SpatialConvolutionLocal_forward_batch()
local inj = (outj-1)*sj+kj-padH*2
for k, typename in ipairs(typenames) do
- local input = torch.randn(bs,from,inj,ini):type(typename)
-
- local ctype = t2cpu[typename]
- input = makeNonContiguous(input:type(ctype))
- local sconv = nn.SpatialConvolutionLocal(from,to,ini,inj,ki,kj,si,sj,padW,padH):type(ctype)
- local groundtruth = sconv:forward(input)
-
- input = makeNonContiguous(input:type(typename))
- local gconv = nn.SpatialConvolutionLocal(from,to,ini,inj,ki,kj,si,sj,padW,padH):type(typename)
- gconv.weight = sconv.weight:type(typename)
- gconv.bias = sconv.bias:type(typename)
- local rescuda = gconv:forward(input)
-
- local error = rescuda:double() - groundtruth:double()
- mytester:assertlt(error:abs():max(), precision_forward_type(precision_forward, typename),
- string.format('error on state (forward) with %s', typename))
+ if typename ~= "torch.CudaHalfTensor" then
+ local input = torch.randn(bs,from,inj,ini):type(typename)
+
+ local ctype = t2cpu[typename]
+ input = makeNonContiguous(input:type(ctype))
+ local sconv = nn.SpatialConvolutionLocal(from,to,ini,inj,ki,kj,si,sj,padW,padH):type(ctype)
+ local groundtruth = sconv:forward(input)
+
+ input = makeNonContiguous(input:type(typename))
+ local gconv = nn.SpatialConvolutionLocal(from,to,ini,inj,ki,kj,si,sj,padW,padH):type(typename)
+ gconv.weight = sconv.weight:type(typename)
+ gconv.bias = sconv.bias:type(typename)
+ local rescuda = gconv:forward(input)
+
+ local error = rescuda:double() - groundtruth:double()
+ mytester:assertlt(error:abs():max(), precision_forward_type(precision_forward, typename),
+ string.format('error on state (forward) with %s', typename))
+ end
end
end
@@ -1324,42 +1328,44 @@ function cunntest.SpatialConvolutionLocal_backward_single()
local inj = (outj-1)*sj+kj-padH*2
for k, typename in ipairs(typenames) do
- local input = torch.randn(from,inj,ini):type(typename)
- local gradOutput = torch.randn(to,outj,outi):type(typename)
-
- local ctype = t2cpu[typename]
- input = makeNonContiguous(input:type(ctype))
- gradOutput = makeNonContiguous(gradOutput:type(ctype))
- local sconv = nn.SpatialConvolutionLocal(from,to,ini,inj,ki,kj,si,sj,padW,padH):type(ctype)
- sconv:forward(input)
- sconv:zeroGradParameters()
- local groundgrad = sconv:backward(input, gradOutput)
- local groundweight = sconv.gradWeight
- local groundbias = sconv.gradBias
-
- input = makeNonContiguous(input:type(typename))
- gradOutput = makeNonContiguous(gradOutput:type(typename))
- local gconv = nn.SpatialConvolutionLocal(from,to,ini,inj,ki,kj,si,sj,padW,padH):type(typename)
- gconv.weight = sconv.weight:type(typename)
- gconv.bias = sconv.bias:type(typename)
- gconv:forward(input)
- gconv:zeroGradParameters()
- local rescuda = gconv:backward(input, gradOutput)
- local weightcuda = gconv.gradWeight
- local biascuda = gconv.gradBias
-
- local error = rescuda:double() - groundgrad:double()
- local werror = weightcuda:double() - groundweight:double()
- local berror = biascuda:double() - groundbias:double()
-
- mytester:assertlt(error:abs():max(), precision_backward_type(precision_backward, typename),
- string.format('error on state (backward) with %s', typename))
- mytester:assertlt(werror:abs():max(),
- precision_backward_conv_weightbias(precision_backward, typename, weightcuda:abs():max()),
- string.format('error on weight (backward) with %s', typename))
- mytester:assertlt(berror:abs():max(),
- precision_backward_conv_weightbias(precision_backward, typename, biascuda:abs():max()),
- string.format('error on bias (backward) with %s', typename))
+ if typename ~= "torch.CudaHalfTensor" then
+ local input = torch.randn(from,inj,ini):type(typename)
+ local gradOutput = torch.randn(to,outj,outi):type(typename)
+
+ local ctype = t2cpu[typename]
+ input = makeNonContiguous(input:type(ctype))
+ gradOutput = makeNonContiguous(gradOutput:type(ctype))
+ local sconv = nn.SpatialConvolutionLocal(from,to,ini,inj,ki,kj,si,sj,padW,padH):type(ctype)
+ sconv:forward(input)
+ sconv:zeroGradParameters()
+ local groundgrad = sconv:backward(input, gradOutput)
+ local groundweight = sconv.gradWeight
+ local groundbias = sconv.gradBias
+
+ input = makeNonContiguous(input:type(typename))
+ gradOutput = makeNonContiguous(gradOutput:type(typename))
+ local gconv = nn.SpatialConvolutionLocal(from,to,ini,inj,ki,kj,si,sj,padW,padH):type(typename)
+ gconv.weight = sconv.weight:type(typename)
+ gconv.bias = sconv.bias:type(typename)
+ gconv:forward(input)
+ gconv:zeroGradParameters()
+ local rescuda = gconv:backward(input, gradOutput)
+ local weightcuda = gconv.gradWeight
+ local biascuda = gconv.gradBias
+
+ local error = rescuda:double() - groundgrad:double()
+ local werror = weightcuda:double() - groundweight:double()
+ local berror = biascuda:double() - groundbias:double()
+
+ mytester:assertlt(error:abs():max(), precision_backward_type(precision_backward, typename),
+ string.format('error on state (backward) with %s', typename))
+ mytester:assertlt(werror:abs():max(),
+ precision_backward_conv_weightbias(precision_backward, typename, weightcuda:abs():max()),
+ string.format('error on weight (backward) with %s', typename))
+ mytester:assertlt(berror:abs():max(),
+ precision_backward_conv_weightbias(precision_backward, typename, biascuda:abs():max()),
+ string.format('error on bias (backward) with %s', typename))
+ end
end
end
@@ -1379,42 +1385,44 @@ function cunntest.SpatialConvolutionLocal_backward_batch()
local inj = (outj-1)*sj+kj-padH*2
for k, typename in ipairs(typenames) do
- local input = torch.randn(bs,from,inj,ini):type(typename)
- local gradOutput = torch.randn(bs,to,outj,outi):type(typename)
-
- local ctype = t2cpu[typename]
- input = makeNonContiguous(input:type(ctype))
- gradOutput = makeNonContiguous(gradOutput:type(ctype))
- local sconv = nn.SpatialConvolutionLocal(from,to,ini,inj,ki,kj,si,sj,padW,padH):type(ctype)
- sconv:forward(input)
- sconv:zeroGradParameters()
- local groundgrad = sconv:backward(input, gradOutput)
- local groundweight = sconv.gradWeight
- local groundbias = sconv.gradBias
-
- input = makeNonContiguous(input:type(typename))
- gradOutput = makeNonContiguous(gradOutput:type(typename))
- local gconv = nn.SpatialConvolutionLocal(from,to,ini,inj,ki,kj,si,sj,padW,padH):type(typename)
- gconv.weight = sconv.weight:type(typename)
- gconv.bias = sconv.bias:type(typename)
- gconv:forward(input)
- gconv:zeroGradParameters()
- local rescuda = gconv:backward(input, gradOutput)
- local weightcuda = gconv.gradWeight
- local biascuda = gconv.gradBias
-
- local error = rescuda:double() - groundgrad:double()
- local werror = weightcuda:double() - groundweight:double()
- local berror = biascuda:double() - groundbias:double()
-
- mytester:assertlt(error:abs():max(), precision_backward_type(precision_backward, typename),
- string.format('error on state (backward) with %s', typename))
- mytester:assertlt(werror:abs():max(),
- precision_backward_conv_weightbias(precision_backward, typename, weightcuda:abs():max()),
- string.format('error on weight (backward) with %s', typename))
- mytester:assertlt(berror:abs():max(),
- precision_backward_conv_weightbias(precision_backward, typename, biascuda:abs():max()),
- string.format('error on bias (backward) with %s', typename))
+ if typename ~= "torch.CudaHalfTensor" then
+ local input = torch.randn(bs,from,inj,ini):type(typename)
+ local gradOutput = torch.randn(bs,to,outj,outi):type(typename)
+
+ local ctype = t2cpu[typename]
+ input = makeNonContiguous(input:type(ctype))
+ gradOutput = makeNonContiguous(gradOutput:type(ctype))
+ local sconv = nn.SpatialConvolutionLocal(from,to,ini,inj,ki,kj,si,sj,padW,padH):type(ctype)
+ sconv:forward(input)
+ sconv:zeroGradParameters()
+ local groundgrad = sconv:backward(input, gradOutput)
+ local groundweight = sconv.gradWeight
+ local groundbias = sconv.gradBias
+
+ input = makeNonContiguous(input:type(typename))
+ gradOutput = makeNonContiguous(gradOutput:type(typename))
+ local gconv = nn.SpatialConvolutionLocal(from,to,ini,inj,ki,kj,si,sj,padW,padH):type(typename)
+ gconv.weight = sconv.weight:type(typename)
+ gconv.bias = sconv.bias:type(typename)
+ gconv:forward(input)
+ gconv:zeroGradParameters()
+ local rescuda = gconv:backward(input, gradOutput)
+ local weightcuda = gconv.gradWeight
+ local biascuda = gconv.gradBias
+
+ local error = rescuda:double() - groundgrad:double()
+ local werror = weightcuda:double() - groundweight:double()
+ local berror = biascuda:double() - groundbias:double()
+
+ mytester:assertlt(error:abs():max(), precision_backward_type(precision_backward, typename),
+ string.format('error on state (backward) with %s', typename))
+ mytester:assertlt(werror:abs():max(),
+ precision_backward_conv_weightbias(precision_backward, typename, weightcuda:abs():max()),
+ string.format('error on weight (backward) with %s', typename))
+ mytester:assertlt(berror:abs():max(),
+ precision_backward_conv_weightbias(precision_backward, typename, biascuda:abs():max()),
+ string.format('error on bias (backward) with %s', typename))
+ end
end
end