Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cunn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2016-06-23 21:57:27 +0300
committerGitHub <noreply@github.com>2016-06-23 21:57:27 +0300
commit0e7f4387c6034cdc6af05e4183aeab5702fdb95a (patch)
treeefb2bc59f2142ed8e576e58d6c1b36e2ecf895c4
parent67c87efa98fa9eb79f71e2c3b792a16077eef1fd (diff)
parentb80facc48401a02da024df6e9d831882bfa4ea73 (diff)
Merge pull request #301 from PraveerSINGH/SpatialFullConvolution-noBias
Add noBias for nn.SpatialFullConvolution
-rw-r--r--lib/THCUNN/SpatialFullConvolution.cu46
-rw-r--r--test.lua279
2 files changed, 195 insertions, 130 deletions
diff --git a/lib/THCUNN/SpatialFullConvolution.cu b/lib/THCUNN/SpatialFullConvolution.cu
index b826600..e6714b0 100644
--- a/lib/THCUNN/SpatialFullConvolution.cu
+++ b/lib/THCUNN/SpatialFullConvolution.cu
@@ -21,6 +21,7 @@ void THNN_CudaSpatialFullConvolution_updateOutput(
THCUNN_assertSameGPU(state, 6, input, output, weight,
bias, columns, ones);
+
THArgCheck(input->nDimension == 3 || input->nDimension == 4, 2, "3D or 4D (batch mode) tensor is expected");
int batch = 1;
@@ -100,16 +101,18 @@ void THNN_CudaSpatialFullConvolution_updateOutput(
long k_ = 1;
// Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices)
- THCudaBlas_gemm(
- state,
- 't', 'n',
- n_, m_, k_,
- 1,
- THCudaTensor_data(state, ones), k_,
- THCudaTensor_data(state, bias), k_,
- 1,
- THCudaTensor_data(state, output_n), n_
- );
+ if (bias) {
+ THCudaBlas_gemm(
+ state,
+ 't', 'n',
+ n_, m_, k_,
+ 1,
+ THCudaTensor_data(state, ones), k_,
+ THCudaTensor_data(state, bias), k_,
+ 1,
+ THCudaTensor_data(state, output_n), n_
+ );
+ }
}
@@ -236,6 +239,7 @@ void THNN_CudaSpatialFullConvolution_accGradParameters(
THCUNN_assertSameGPU(state, 6, input, gradOutput, gradWeight,
gradBias, columns, ones);
+
THArgCheck(input->nDimension == 3 || input->nDimension == 4, 2, "3D or 4D (batch mode) tensor is expected");
int batch = 1;
@@ -307,16 +311,18 @@ void THNN_CudaSpatialFullConvolution_accGradParameters(
long k_ = outputHeight * outputWidth;
// Do GEMV (note: this is a bit confusing because gemv assumes column-major matrices)
- THCudaBlas_gemv(
- state,
- 't',
- k_, m_,
- scale,
- THCudaTensor_data(state, gradOutput_n), k_,
- THCudaTensor_data(state, ones), 1,
- 1,
- THCudaTensor_data(state, gradBias), 1
- );
+ if (gradBias) {
+ THCudaBlas_gemv(
+ state,
+ 't',
+ k_, m_,
+ scale,
+ THCudaTensor_data(state, gradOutput_n), k_,
+ THCudaTensor_data(state, ones), 1,
+ 1,
+ THCudaTensor_data(state, gradBias), 1
+ );
+ }
}
// Free
diff --git a/test.lua b/test.lua
index 81b8ab3..49c9ce3 100644
--- a/test.lua
+++ b/test.lua
@@ -1361,29 +1361,43 @@ function cunntest.SpatialFullConvolution_forward_single()
from, inj, ini, kj, ki, to, outj, outi, sj, si, padH, padW, adjH, adjW)
times[title] = tm
- local input = torch.randn(from,inj,ini)
- local sconv = nn.SpatialFullConvolution(from,to,ki,kj,si,sj,padW,padH,adjW,adjH)
- local groundtruth = sconv:forward(input)
- local a = torch.Timer()
- for i = 1,nloop do
- groundtruth = sconv:forward(input)
- end
- tm.cpu = a:time().real
+ local function jacTests(noBias)
+ noBias = noBias or false
+ local input = torch.randn(from,inj,ini)
+ local sconv = nn.SpatialFullConvolution(from,to,ki,kj,si,sj,padW,padH,adjW,adjH)
+ if noBias then
+ sconv:noBias()
+ end
+ local groundtruth = sconv:forward(input)
+ local a = torch.Timer()
+ for i = 1,nloop do
+ groundtruth = sconv:forward(input)
+ end
+ tm.cpu = a:time().real
- input = input:cuda()
- local gconv = nn.SpatialFullConvolution(from,to,ki,kj,si,sj,padW,padH,adjW,adjH):cuda()
- gconv.weight = sconv.weight:cuda()
- gconv.bias = sconv.bias:cuda()
- local rescuda = gconv:forward(input)
- a:reset()
- for i = 1,nloop do
- rescuda = gconv:forward(input)
+ input = input:cuda()
+ local gconv = nn.SpatialFullConvolution(from,to,ki,kj,si,sj,padW,padH,adjW,adjH):cuda()
+ if noBias then
+ gconv:noBias()
+ end
+ gconv.weight = sconv.weight:cuda()
+ if gconv.bias then
+ gconv.bias = sconv.bias:cuda()
+ end
+ local rescuda = gconv:forward(input)
+ a:reset()
+ for i = 1,nloop do
+ rescuda = gconv:forward(input)
+ end
+ cutorch.synchronize()
+ tm.gpu = a:time().real
+
+ local error = rescuda:float() - groundtruth
+ mytester:assertlt(error:abs():max(), precision_forward, 'error on state (forward) ')
end
- cutorch.synchronize()
- tm.gpu = a:time().real
- local error = rescuda:float() - groundtruth
- mytester:assertlt(error:abs():max(), precision_forward, 'error on state (forward) ')
+ jacTests(false)
+ jacTests(true)
end
function cunntest.SpatialFullConvolution_forward_batch()
@@ -1408,29 +1422,43 @@ function cunntest.SpatialFullConvolution_forward_batch()
bs, from, inj, ini, kj, ki, bs, to, outj, outi, sj, si, padH, padW, adjH, adjW)
times[title] = tm
- local input = torch.randn(bs,from,inj,ini)
- local sconv = nn.SpatialFullConvolution(from,to,ki,kj,si,sj,padW,padH,adjW,adjH)
- local groundtruth = sconv:forward(input)
- local a = torch.Timer()
- for i = 1,nloop do
- groundtruth = sconv:forward(input)
- end
- tm.cpu = a:time().real
+ local function jacTests(noBias)
+ noBias = noBias or false
+ local input = torch.randn(bs,from,inj,ini)
+ local sconv = nn.SpatialFullConvolution(from,to,ki,kj,si,sj,padW,padH,adjW,adjH)
+ if noBias then
+ sconv:noBias()
+ end
+ local groundtruth = sconv:forward(input)
+ local a = torch.Timer()
+ for i = 1,nloop do
+ groundtruth = sconv:forward(input)
+ end
+ tm.cpu = a:time().real
- input = input:cuda()
- local gconv = nn.SpatialFullConvolution(from,to,ki,kj,si,sj,padW,padH,adjW,adjH):cuda()
- gconv.weight = sconv.weight:cuda()
- gconv.bias = sconv.bias:cuda()
- local rescuda = gconv:forward(input)
- a:reset()
- for i = 1,nloop do
- rescuda = gconv:forward(input)
+ input = input:cuda()
+ local gconv = nn.SpatialFullConvolution(from,to,ki,kj,si,sj,padW,padH,adjW,adjH):cuda()
+ if noBias then
+ gconv:noBias()
+ end
+ gconv.weight = sconv.weight:cuda()
+ if gconv.bias then
+ gconv.bias = sconv.bias:cuda()
+ end
+ local rescuda = gconv:forward(input)
+ a:reset()
+ for i = 1,nloop do
+ rescuda = gconv:forward(input)
+ end
+ cutorch.synchronize()
+ tm.gpu = a:time().real
+
+ local error = rescuda:float() - groundtruth
+ mytester:assertlt(error:abs():max(), precision_forward, 'error on state (forward) ')
end
- cutorch.synchronize()
- tm.gpu = a:time().real
- local error = rescuda:float() - groundtruth
- mytester:assertlt(error:abs():max(), precision_forward, 'error on state (forward) ')
+ jacTests(false)
+ jacTests(true)
end
function cunntest.SpatialFullConvolution_backward_single()
@@ -1454,46 +1482,62 @@ function cunntest.SpatialFullConvolution_backward_single()
from, inj, ini, kj, ki, to, outj, outi, sj, si, padH, padW, adjH, adjW)
times[title] = tm
- local input = torch.randn(from,inj,ini)
- local sconv = nn.SpatialFullConvolution(from,to,ki,kj,si,sj,padW,padH,adjW,adjH)
- local output = sconv:forward(input)
- local gradOutput = output:clone():normal()
- sconv:zeroGradParameters()
- local groundgrad = sconv:backward(input, gradOutput)
- local a = torch.Timer()
- for i = 1,nloop do
+ local function jacTests(noBias)
+ noBias = noBias or false
+ local input = torch.randn(from,inj,ini)
+ local sconv = nn.SpatialFullConvolution(from,to,ki,kj,si,sj,padW,padH,adjW,adjH)
+ if noBias then
+ sconv:noBias()
+ end
+ local output = sconv:forward(input)
+ local gradOutput = output:clone():normal()
sconv:zeroGradParameters()
- groundgrad = sconv:backward(input, gradOutput)
- end
- local groundweight = sconv.gradWeight
- local groundbias = sconv.gradBias
- tm.cpu = a:time().real
+ local groundgrad = sconv:backward(input, gradOutput)
+ local a = torch.Timer()
+ for i = 1,nloop do
+ sconv:zeroGradParameters()
+ groundgrad = sconv:backward(input, gradOutput)
+ end
+ local groundweight = sconv.gradWeight
+ local groundbias = sconv.gradBias
+ tm.cpu = a:time().real
- input = input:cuda()
- gradOutput = gradOutput:cuda()
- local gconv = nn.SpatialFullConvolution(from,to,ki,kj,si,sj,padW,padH,adjW,adjH):cuda()
- gconv.weight = sconv.weight:cuda()
- gconv.bias = sconv.bias:cuda()
- gconv:forward(input)
- gconv:zeroGradParameters()
- local rescuda = gconv:backward(input, gradOutput)
- a:reset()
- for i = 1,nloop do
+ input = input:cuda()
+ gradOutput = gradOutput:cuda()
+ local gconv = nn.SpatialFullConvolution(from,to,ki,kj,si,sj,padW,padH,adjW,adjH):cuda()
+ if noBias then
+ gconv:noBias()
+ end
+ gconv.weight = sconv.weight:cuda()
+ if gconv.bias then
+ gconv.bias = sconv.bias:cuda()
+ end
+ gconv:forward(input)
gconv:zeroGradParameters()
- rescuda = gconv:backward(input, gradOutput)
- end
- local weightcuda = gconv.gradWeight
- local biascuda = gconv.gradBias
- cutorch.synchronize()
- tm.gpu = a:time().real
+ local rescuda = gconv:backward(input, gradOutput)
+ a:reset()
+ for i = 1,nloop do
+ gconv:zeroGradParameters()
+ rescuda = gconv:backward(input, gradOutput)
+ end
+ local weightcuda = gconv.gradWeight
+ cutorch.synchronize()
+ tm.gpu = a:time().real
- local error = rescuda:float() - groundgrad
- local werror = weightcuda:float() - groundweight
- local berror = biascuda:float() - groundbias
+ local error = rescuda:float() - groundgrad
+ local werror = weightcuda:float() - groundweight
- mytester:assertlt(error:abs():max(), precision_backward, 'error on state (backward) ')
- mytester:assertlt(werror:abs():max(), precision_backward, 'error on weight (backward) ')
- mytester:assertlt(berror:abs():max(), precision_backward, 'error on bias (backward) ')
+ mytester:assertlt(error:abs():max(), precision_backward, 'error on state (backward) ')
+ mytester:assertlt(werror:abs():max(), precision_backward, 'error on weight (backward) ')
+
+ if gconv.bias then
+ local berror = gconv.gradBias:float() - groundbias
+ mytester:assertlt(berror:abs():max(), precision_backward, 'error on bias (backward) ')
+ end
+ end
+
+ jacTests(false)
+ jacTests(true)
end
function cunntest.SpatialFullConvolution_backward_batch()
@@ -1520,46 +1564,61 @@ function cunntest.SpatialFullConvolution_backward_batch()
bs, to, outj, outi, sj, si, padH, padW, adjH, adjW)
times[title] = tm
- local input = torch.randn(bs,from,inj,ini)
- local sconv = nn.SpatialFullConvolution(from,to,ki,kj,si,sj,padW,padH,adjW,adjH)
- local output = sconv:forward(input)
- local gradOutput = output:clone():normal()
- sconv:zeroGradParameters()
- local groundgrad = sconv:backward(input, gradOutput)
- local a = torch.Timer()
- for i = 1,nloop do
+ local function jacTests(noBias)
+ noBias = noBias or false
+ local input = torch.randn(bs,from,inj,ini)
+ local sconv = nn.SpatialFullConvolution(from,to,ki,kj,si,sj,padW,padH,adjW,adjH)
+ if noBias then
+ sconv:noBias()
+ end
+ local output = sconv:forward(input)
+ local gradOutput = output:clone():normal()
sconv:zeroGradParameters()
- groundgrad = sconv:backward(input, gradOutput)
- end
- local groundweight = sconv.gradWeight
- local groundbias = sconv.gradBias
- tm.cpu = a:time().real
+ local groundgrad = sconv:backward(input, gradOutput)
+ local a = torch.Timer()
+ for i = 1,nloop do
+ sconv:zeroGradParameters()
+ groundgrad = sconv:backward(input, gradOutput)
+ end
+ local groundweight = sconv.gradWeight
+ local groundbias = sconv.gradBias
+ tm.cpu = a:time().real
- input = input:cuda()
- gradOutput = gradOutput:cuda()
- local gconv = nn.SpatialFullConvolution(from,to,ki,kj,si,sj,padW,padH,adjW,adjH):cuda()
- gconv.weight = sconv.weight:cuda()
- gconv.bias = sconv.bias:cuda()
- gconv:forward(input)
- gconv:zeroGradParameters()
- local rescuda = gconv:backward(input, gradOutput)
- a:reset()
- for i = 1,nloop do
+ input = input:cuda()
+ gradOutput = gradOutput:cuda()
+ local gconv = nn.SpatialFullConvolution(from,to,ki,kj,si,sj,padW,padH,adjW,adjH):cuda()
+ if noBias then
+ gconv:noBias()
+ end
+ gconv.weight = sconv.weight:cuda()
+ if gconv.bias then
+ gconv.bias = sconv.bias:cuda()
+ end
+ gconv:forward(input)
gconv:zeroGradParameters()
- rescuda = gconv:backward(input, gradOutput)
- end
- local weightcuda = gconv.gradWeight
- local biascuda = gconv.gradBias
- cutorch.synchronize()
- tm.gpu = a:time().real
+ local rescuda = gconv:backward(input, gradOutput)
+ a:reset()
+ for i = 1,nloop do
+ gconv:zeroGradParameters()
+ rescuda = gconv:backward(input, gradOutput)
+ end
+ local weightcuda = gconv.gradWeight
+ cutorch.synchronize()
+ tm.gpu = a:time().real
- local error = rescuda:float() - groundgrad
- local werror = weightcuda:float() - groundweight
- local berror = biascuda:float() - groundbias
+ local error = rescuda:float() - groundgrad
+ local werror = weightcuda:float() - groundweight
- mytester:assertlt(error:abs():max(), precision_backward, 'error on state (backward) ')
- mytester:assertlt(werror:abs():max(), precision_backward, 'error on weight (backward) ')
- mytester:assertlt(berror:abs():max(), precision_backward, 'error on bias (backward) ')
+ mytester:assertlt(error:abs():max(), precision_backward, 'error on state (backward) ')
+ mytester:assertlt(werror:abs():max(), precision_backward, 'error on weight (backward) ')
+ if gconv.bias then
+ local berror = gconv.gradBias:float() - groundbias
+ mytester:assertlt(berror:abs():max(), precision_backward, 'error on bias (backward) ')
+ end
+ end
+
+ jacTests(false)
+ jacTests(true)
end
function cunntest.SpatialDilatedConvolution_forward_single()