Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cunn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorsoumith <soumith@fb.com>2016-11-21 20:10:26 +0300
committersoumith <soumith@fb.com>2016-11-21 20:10:26 +0300
commitf94eb5989371e1f17d04190f7a06c62267d6a124 (patch)
treea8283899d0a9cc2fbedc7c44324736fb6ce40be7
parent50224b77d55a0e2cdb4c8e6b1716a1cdc381a12f (diff)
SpatialConvolutionLocal uses baddbmm
-rw-r--r--lib/THCUNN/generic/SpatialConvolutionLocal.cu42
-rw-r--r--test.lua216
2 files changed, 121 insertions, 137 deletions
diff --git a/lib/THCUNN/generic/SpatialConvolutionLocal.cu b/lib/THCUNN/generic/SpatialConvolutionLocal.cu
index 0469029..6fe52a5 100644
--- a/lib/THCUNN/generic/SpatialConvolutionLocal.cu
+++ b/lib/THCUNN/generic/SpatialConvolutionLocal.cu
@@ -149,20 +149,11 @@ void THNN_(SpatialConvolutionLocal_updateOutput)(
THCTensor_(copy)(state, output_n, bias);
- for (int i = 0; i < outputHeight; i++) {
- for(int j = 0; j < outputWidth; j++) {
- int sliceidx = i * outputWidth + j;
- THCTensor_(select)(state, wslice, weight, 0, sliceidx);
- THCTensor_(select)(state, islice, finput3d, 0, sliceidx);
- THCTensor_(select)(state, oslice, output3d, 0, sliceidx);
- THCTensor_(addmm)(state, oslice, ScalarConvert<int, real>::to(1), oslice, ScalarConvert<int, real>::to(1), wslice, islice);
- }
- }
-
-
// weight: oH*oW x nOutputPlane x nInputPlane*kH*kW
// finput3d: oH*oW x nInputPlane*kH*kW x 1
- // THCTensor_(baddbmm)(state, output3d, 1.0, output3d, 1.0, weight, finput3d);
+ THCTensor_(baddbmm)(state, output3d, ScalarConvert<int, real>::to(1),
+ output3d, ScalarConvert<int, real>::to(1),
+ weight, finput3d);
// output3d: oH*oW x nOutputPlane x 1
THCTensor_(free)(state, output3d);
@@ -261,19 +252,12 @@ void THNN_(SpatialConvolutionLocal_updateGradInput)(
kW*kH*nInputPlane, outputHeight*outputWidth,
1, kW*kH*nInputPlane*outputHeight*outputWidth);
- for (int i = 0; i < outputHeight; i++) {
- for(int j = 0; j < outputWidth; j++) {
- int sliceidx = i * outputWidth + j;
- THCTensor_(select)(state, wslice, weight, 0, sliceidx);
- THCTensor_(select)(state, gislice, fgradInput3d, 0, sliceidx);
- THCTensor_(select)(state, goslice, gradOutput3d, 0, sliceidx);
- THCTensor_(addmm)(state, gislice, ScalarConvert<int, real>::to(0), gislice, ScalarConvert<int, real>::to(1), wslice, goslice);
- }
- }
-
// weight: oH*oW x nInputPlane*kH*kW x nOutputPlane
// gradOutput3d: oH*oW x nOutputPlane x 1
- //THCTensor_(baddbmm)(state, fgradInput3d, 0.0, fgradInput3d, 1.0, weight, gradOutput3d);
+ THCTensor_(baddbmm)(state, fgradInput3d,
+ ScalarConvert<int, real>::to(0),
+ fgradInput3d, ScalarConvert<int, real>::to(1),
+ weight, gradOutput3d);
// fgradInput3d: oH*oW x nInputPlane*kH*kW x 1
// Unpack columns back into input:
@@ -385,18 +369,10 @@ void THNN_(SpatialConvolutionLocal_accGradParameters)(
1, 1, THCTensor_(data)(state, finput_n)
);
- for (int i = 0; i < outputHeight; i++) {
- for(int j = 0; j < outputWidth; j++) {
- int sliceidx = i * outputWidth + j;
- THCTensor_(select)(state, gwslice, gradWeight, 0, sliceidx);
- THCTensor_(select)(state, goslice, gradOutput3d, 0, sliceidx);
- THCTensor_(select)(state, islice, finput3d, 0, sliceidx);
- THCTensor_(addmm)(state, gwslice, ScalarConvert<int, real>::to(1), gwslice, scale, goslice, islice);
- }
- }
// gradOutput3d: oH*oW x nOutputPlane x 1
// finput3d: oH*oW x 1 x kW*kH*nInputPlane
- //THCTensor_(baddbmm)(state, gradWeight, 1.0, gradWeight, scale, gradOutput3d, finput3d);
+ THCTensor_(baddbmm)(state, gradWeight, ScalarConvert<int, real>::to(1),
+ gradWeight, scale, gradOutput3d, finput3d);
// gradWeight: oH*oW x nOutputPlane x kW*kH*nInputPlane
THCTensor_(cadd)(state, gradBias, gradBias, scale, gradOutput_n);
diff --git a/test.lua b/test.lua
index 5aadb20..15f9990 100644
--- a/test.lua
+++ b/test.lua
@@ -1255,22 +1255,24 @@ function cunntest.SpatialConvolutionLocal_forward_single()
local inj = (outj-1)*sj+kj-padH*2
for k, typename in ipairs(typenames) do
- local input = torch.randn(from,inj,ini):type(typename)
-
- local ctype = t2cpu[typename]
- input = makeNonContiguous(input:type(ctype))
- local sconv = nn.SpatialConvolutionLocal(from,to,ini,inj,ki,kj,si,sj,padW,padH):type(ctype)
- local groundtruth = sconv:forward(input)
-
- input = makeNonContiguous(input:type(typename))
- local gconv = nn.SpatialConvolutionLocal(from,to,ini,inj,ki,kj,si,sj,padW,padH):type(typename)
- gconv.weight = sconv.weight:type(typename)
- gconv.bias = sconv.bias:type(typename)
- local rescuda = gconv:forward(input)
-
- local error = rescuda:double() - groundtruth:double()
- mytester:assertlt(error:abs():max(), precision_forward_type(precision_forward, typename),
- string.format('error on state (forward) with %s', typename))
+ if typename ~= "torch.CudaHalfTensor" then
+ local input = torch.randn(from,inj,ini):type(typename)
+
+ local ctype = t2cpu[typename]
+ input = makeNonContiguous(input:type(ctype))
+ local sconv = nn.SpatialConvolutionLocal(from,to,ini,inj,ki,kj,si,sj,padW,padH):type(ctype)
+ local groundtruth = sconv:forward(input)
+
+ input = makeNonContiguous(input:type(typename))
+ local gconv = nn.SpatialConvolutionLocal(from,to,ini,inj,ki,kj,si,sj,padW,padH):type(typename)
+ gconv.weight = sconv.weight:type(typename)
+ gconv.bias = sconv.bias:type(typename)
+ local rescuda = gconv:forward(input)
+
+ local error = rescuda:double() - groundtruth:double()
+ mytester:assertlt(error:abs():max(), precision_forward_type(precision_forward, typename),
+ string.format('error on state (forward) with %s', typename))
+ end
end
end
@@ -1290,22 +1292,24 @@ function cunntest.SpatialConvolutionLocal_forward_batch()
local inj = (outj-1)*sj+kj-padH*2
for k, typename in ipairs(typenames) do
- local input = torch.randn(bs,from,inj,ini):type(typename)
-
- local ctype = t2cpu[typename]
- input = makeNonContiguous(input:type(ctype))
- local sconv = nn.SpatialConvolutionLocal(from,to,ini,inj,ki,kj,si,sj,padW,padH):type(ctype)
- local groundtruth = sconv:forward(input)
-
- input = makeNonContiguous(input:type(typename))
- local gconv = nn.SpatialConvolutionLocal(from,to,ini,inj,ki,kj,si,sj,padW,padH):type(typename)
- gconv.weight = sconv.weight:type(typename)
- gconv.bias = sconv.bias:type(typename)
- local rescuda = gconv:forward(input)
-
- local error = rescuda:double() - groundtruth:double()
- mytester:assertlt(error:abs():max(), precision_forward_type(precision_forward, typename),
- string.format('error on state (forward) with %s', typename))
+ if typename ~= "torch.CudaHalfTensor" then
+ local input = torch.randn(bs,from,inj,ini):type(typename)
+
+ local ctype = t2cpu[typename]
+ input = makeNonContiguous(input:type(ctype))
+ local sconv = nn.SpatialConvolutionLocal(from,to,ini,inj,ki,kj,si,sj,padW,padH):type(ctype)
+ local groundtruth = sconv:forward(input)
+
+ input = makeNonContiguous(input:type(typename))
+ local gconv = nn.SpatialConvolutionLocal(from,to,ini,inj,ki,kj,si,sj,padW,padH):type(typename)
+ gconv.weight = sconv.weight:type(typename)
+ gconv.bias = sconv.bias:type(typename)
+ local rescuda = gconv:forward(input)
+
+ local error = rescuda:double() - groundtruth:double()
+ mytester:assertlt(error:abs():max(), precision_forward_type(precision_forward, typename),
+ string.format('error on state (forward) with %s', typename))
+ end
end
end
@@ -1324,42 +1328,44 @@ function cunntest.SpatialConvolutionLocal_backward_single()
local inj = (outj-1)*sj+kj-padH*2
for k, typename in ipairs(typenames) do
- local input = torch.randn(from,inj,ini):type(typename)
- local gradOutput = torch.randn(to,outj,outi):type(typename)
-
- local ctype = t2cpu[typename]
- input = makeNonContiguous(input:type(ctype))
- gradOutput = makeNonContiguous(gradOutput:type(ctype))
- local sconv = nn.SpatialConvolutionLocal(from,to,ini,inj,ki,kj,si,sj,padW,padH):type(ctype)
- sconv:forward(input)
- sconv:zeroGradParameters()
- local groundgrad = sconv:backward(input, gradOutput)
- local groundweight = sconv.gradWeight
- local groundbias = sconv.gradBias
-
- input = makeNonContiguous(input:type(typename))
- gradOutput = makeNonContiguous(gradOutput:type(typename))
- local gconv = nn.SpatialConvolutionLocal(from,to,ini,inj,ki,kj,si,sj,padW,padH):type(typename)
- gconv.weight = sconv.weight:type(typename)
- gconv.bias = sconv.bias:type(typename)
- gconv:forward(input)
- gconv:zeroGradParameters()
- local rescuda = gconv:backward(input, gradOutput)
- local weightcuda = gconv.gradWeight
- local biascuda = gconv.gradBias
-
- local error = rescuda:double() - groundgrad:double()
- local werror = weightcuda:double() - groundweight:double()
- local berror = biascuda:double() - groundbias:double()
-
- mytester:assertlt(error:abs():max(), precision_backward_type(precision_backward, typename),
- string.format('error on state (backward) with %s', typename))
- mytester:assertlt(werror:abs():max(),
- precision_backward_conv_weightbias(precision_backward, typename, weightcuda:abs():max()),
- string.format('error on weight (backward) with %s', typename))
- mytester:assertlt(berror:abs():max(),
- precision_backward_conv_weightbias(precision_backward, typename, biascuda:abs():max()),
- string.format('error on bias (backward) with %s', typename))
+ if typename ~= "torch.CudaHalfTensor" then
+ local input = torch.randn(from,inj,ini):type(typename)
+ local gradOutput = torch.randn(to,outj,outi):type(typename)
+
+ local ctype = t2cpu[typename]
+ input = makeNonContiguous(input:type(ctype))
+ gradOutput = makeNonContiguous(gradOutput:type(ctype))
+ local sconv = nn.SpatialConvolutionLocal(from,to,ini,inj,ki,kj,si,sj,padW,padH):type(ctype)
+ sconv:forward(input)
+ sconv:zeroGradParameters()
+ local groundgrad = sconv:backward(input, gradOutput)
+ local groundweight = sconv.gradWeight
+ local groundbias = sconv.gradBias
+
+ input = makeNonContiguous(input:type(typename))
+ gradOutput = makeNonContiguous(gradOutput:type(typename))
+ local gconv = nn.SpatialConvolutionLocal(from,to,ini,inj,ki,kj,si,sj,padW,padH):type(typename)
+ gconv.weight = sconv.weight:type(typename)
+ gconv.bias = sconv.bias:type(typename)
+ gconv:forward(input)
+ gconv:zeroGradParameters()
+ local rescuda = gconv:backward(input, gradOutput)
+ local weightcuda = gconv.gradWeight
+ local biascuda = gconv.gradBias
+
+ local error = rescuda:double() - groundgrad:double()
+ local werror = weightcuda:double() - groundweight:double()
+ local berror = biascuda:double() - groundbias:double()
+
+ mytester:assertlt(error:abs():max(), precision_backward_type(precision_backward, typename),
+ string.format('error on state (backward) with %s', typename))
+ mytester:assertlt(werror:abs():max(),
+ precision_backward_conv_weightbias(precision_backward, typename, weightcuda:abs():max()),
+ string.format('error on weight (backward) with %s', typename))
+ mytester:assertlt(berror:abs():max(),
+ precision_backward_conv_weightbias(precision_backward, typename, biascuda:abs():max()),
+ string.format('error on bias (backward) with %s', typename))
+ end
end
end
@@ -1379,42 +1385,44 @@ function cunntest.SpatialConvolutionLocal_backward_batch()
local inj = (outj-1)*sj+kj-padH*2
for k, typename in ipairs(typenames) do
- local input = torch.randn(bs,from,inj,ini):type(typename)
- local gradOutput = torch.randn(bs,to,outj,outi):type(typename)
-
- local ctype = t2cpu[typename]
- input = makeNonContiguous(input:type(ctype))
- gradOutput = makeNonContiguous(gradOutput:type(ctype))
- local sconv = nn.SpatialConvolutionLocal(from,to,ini,inj,ki,kj,si,sj,padW,padH):type(ctype)
- sconv:forward(input)
- sconv:zeroGradParameters()
- local groundgrad = sconv:backward(input, gradOutput)
- local groundweight = sconv.gradWeight
- local groundbias = sconv.gradBias
-
- input = makeNonContiguous(input:type(typename))
- gradOutput = makeNonContiguous(gradOutput:type(typename))
- local gconv = nn.SpatialConvolutionLocal(from,to,ini,inj,ki,kj,si,sj,padW,padH):type(typename)
- gconv.weight = sconv.weight:type(typename)
- gconv.bias = sconv.bias:type(typename)
- gconv:forward(input)
- gconv:zeroGradParameters()
- local rescuda = gconv:backward(input, gradOutput)
- local weightcuda = gconv.gradWeight
- local biascuda = gconv.gradBias
-
- local error = rescuda:double() - groundgrad:double()
- local werror = weightcuda:double() - groundweight:double()
- local berror = biascuda:double() - groundbias:double()
-
- mytester:assertlt(error:abs():max(), precision_backward_type(precision_backward, typename),
- string.format('error on state (backward) with %s', typename))
- mytester:assertlt(werror:abs():max(),
- precision_backward_conv_weightbias(precision_backward, typename, weightcuda:abs():max()),
- string.format('error on weight (backward) with %s', typename))
- mytester:assertlt(berror:abs():max(),
- precision_backward_conv_weightbias(precision_backward, typename, biascuda:abs():max()),
- string.format('error on bias (backward) with %s', typename))
+ if typename ~= "torch.CudaHalfTensor" then
+ local input = torch.randn(bs,from,inj,ini):type(typename)
+ local gradOutput = torch.randn(bs,to,outj,outi):type(typename)
+
+ local ctype = t2cpu[typename]
+ input = makeNonContiguous(input:type(ctype))
+ gradOutput = makeNonContiguous(gradOutput:type(ctype))
+ local sconv = nn.SpatialConvolutionLocal(from,to,ini,inj,ki,kj,si,sj,padW,padH):type(ctype)
+ sconv:forward(input)
+ sconv:zeroGradParameters()
+ local groundgrad = sconv:backward(input, gradOutput)
+ local groundweight = sconv.gradWeight
+ local groundbias = sconv.gradBias
+
+ input = makeNonContiguous(input:type(typename))
+ gradOutput = makeNonContiguous(gradOutput:type(typename))
+ local gconv = nn.SpatialConvolutionLocal(from,to,ini,inj,ki,kj,si,sj,padW,padH):type(typename)
+ gconv.weight = sconv.weight:type(typename)
+ gconv.bias = sconv.bias:type(typename)
+ gconv:forward(input)
+ gconv:zeroGradParameters()
+ local rescuda = gconv:backward(input, gradOutput)
+ local weightcuda = gconv.gradWeight
+ local biascuda = gconv.gradBias
+
+ local error = rescuda:double() - groundgrad:double()
+ local werror = weightcuda:double() - groundweight:double()
+ local berror = biascuda:double() - groundbias:double()
+
+ mytester:assertlt(error:abs():max(), precision_backward_type(precision_backward, typename),
+ string.format('error on state (backward) with %s', typename))
+ mytester:assertlt(werror:abs():max(),
+ precision_backward_conv_weightbias(precision_backward, typename, weightcuda:abs():max()),
+ string.format('error on weight (backward) with %s', typename))
+ mytester:assertlt(berror:abs():max(),
+ precision_backward_conv_weightbias(precision_backward, typename, biascuda:abs():max()),
+ string.format('error on bias (backward) with %s', typename))
+ end
end
end