github.com/torch/nn.git
author    fsuzanomassa <fvsmassa@gmail.com>  2015-01-20 21:54:54 +0300
committer fsuzanomassa <fvsmassa@gmail.com>  2015-04-21 21:33:24 +0300
commit    1750204c0e35175920de446cfaeb48e4543cfa9f (patch)
tree      b7c426d41800d83a6b92e1ac7c0e0b5ebba45f49
parent    a2db5ec31f2dd236186c376a04daa31af319e39d (diff)
Add strides to avoid copies of non-contiguous tensors
Fix for batched non-contiguous data. Remove an unnecessary retain in SpatialAdaptiveMaxPooling.
-rw-r--r--  generic/SpatialAdaptiveMaxPooling.c  33
-rw-r--r--  test.lua                             25
2 files changed, 47 insertions, 11 deletions
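
What the patch does, in short: before this change updateOutput forced the input to be contiguous with THTensor_(newContiguous)(), which allocates and copies whenever the caller passes a transposed (non-contiguous) tensor. After the change, the input's strides are read off the tensor and passed down to updateOutput_frame, so elements are addressed directly in the original storage and no copy is made. The sketch below is not part of the patch; it only illustrates, in plain C with an illustrative helper (at3 is made up, not a torch/nn function), how stride-based indexing reaches the same elements as the usual contiguous formula.

/* Illustrative only -- not code from the patch.  A contiguous d x h x w
 * tensor is addressed as data[k*h*w + y*w + x]; a non-contiguous view
 * (e.g. a transpose) keeps the same storage but different strides, so
 * element (k, y, x) lives at data[k*stride_d + y*stride_h + x*stride_w].
 * Reading through the strides is what lets the patch drop newContiguous(). */
#include <stdio.h>

/* hypothetical helper: fetch element (k, y, x) through explicit strides */
static float at3(const float *data, long k, long y, long x,
                 long stride_d, long stride_h, long stride_w)
{
  return data[k*stride_d + y*stride_h + x*stride_w];
}

int main(void)
{
  float data[12];                      /* a 2 x 2 x 3 tensor, stored row-major */
  for (int i = 0; i < 12; i++) data[i] = (float)i;

  /* contiguous strides for sizes (2, 2, 3) are (6, 3, 1) */
  printf("%g\n", at3(data, 1, 0, 2, 6, 3, 1));  /* element (1,0,2) -> prints 8 */

  /* the same storage viewed with the last two dims transposed:
   * sizes (2, 3, 2), strides (6, 1, 3) -- no copy, only different strides */
  printf("%g\n", at3(data, 1, 2, 0, 6, 1, 3));  /* element (1,2,0) -> prints 8 */
  return 0;
}
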
diff --git a/generic/SpatialAdaptiveMaxPooling.c b/generic/SpatialAdaptiveMaxPooling.c
index 4c46f28..85f728b 100644
--- a/generic/SpatialAdaptiveMaxPooling.c
+++ b/generic/SpatialAdaptiveMaxPooling.c
@@ -6,7 +6,9 @@ static void nn_(SpatialAdaptiveMaxPooling_updateOutput_frame)(real *input_p,real
real *indx_p, real *indy_p,
long nslices,
long iwidth, long iheight,
- long owidth, long oheight)
+ long owidth, long oheight,
+ long stridew,long strideh,
+ long strided)
{
long k;
#pragma omp parallel for private(k)
@@ -28,7 +30,7 @@ static void nn_(SpatialAdaptiveMaxPooling_updateOutput_frame)(real *input_p,real
int kW = x_end-x_start;
/* local pointers */
- real *ip = input_p + k*iwidth*iheight + y_start*iwidth + x_start;
+ real *ip = input_p + k*strided + y_start*strideh + x_start*stridew;
real *op = output_p + k*owidth*oheight + i*owidth + j;
real *indyp = indy_p + k*owidth*oheight + i*owidth + j;
real *indxp = indx_p + k*owidth*oheight + i*owidth + j;
@@ -42,7 +44,7 @@ static void nn_(SpatialAdaptiveMaxPooling_updateOutput_frame)(real *input_p,real
{
for(x = 0; x < kW; x++)
{
- real val = *(ip + y*iwidth + x);
+ real val = *(ip + y*strideh + x*stridew);
if (val > maxval)
{
maxval = val;
@@ -76,6 +78,11 @@ static int nn_(SpatialAdaptiveMaxPooling_updateOutput)(lua_State *L)
long nslices;
long iheight;
long iwidth;
+
+ long istride_d;
+ long istride_h;
+ long istride_w;
+ long istride_b;
real *input_data;
real *output_data;
@@ -86,6 +93,7 @@ static int nn_(SpatialAdaptiveMaxPooling_updateOutput)(lua_State *L)
if (input->nDimension == 4)
{
+ istride_b = input->stride[0];
nbatch = input->size[0];
dimw++;
dimh++;
@@ -95,9 +103,10 @@ static int nn_(SpatialAdaptiveMaxPooling_updateOutput)(lua_State *L)
nslices = input->size[dimh-1];
iheight = input->size[dimh];
iwidth = input->size[dimw];
-
- /* get contiguous input */
- input = THTensor_(newContiguous)(input);
+ /* strides */
+ istride_d = input->stride[dimh-1];
+ istride_h = input->stride[dimh];
+ istride_w = input->stride[dimw];
/* resize output */
if (input->nDimension == 3)
@@ -114,7 +123,9 @@ static int nn_(SpatialAdaptiveMaxPooling_updateOutput)(lua_State *L)
indices_data+nslices*owidth*oheight, indices_data,
nslices,
iwidth, iheight,
- owidth, oheight);
+ owidth, oheight,
+ istride_w,istride_h,
+ istride_d);
}
else
{
@@ -131,16 +142,16 @@ static int nn_(SpatialAdaptiveMaxPooling_updateOutput)(lua_State *L)
#pragma omp parallel for private(p)
for (p = 0; p < nbatch; p++)
{
- nn_(SpatialAdaptiveMaxPooling_updateOutput_frame)(input_data+p*nslices*iwidth*iheight, output_data+p*nslices*owidth*oheight,
+ nn_(SpatialAdaptiveMaxPooling_updateOutput_frame)(input_data+p*istride_b, output_data+p*nslices*owidth*oheight,
indices_data+(p+nbatch)*nslices*owidth*oheight, indices_data+p*nslices*owidth*oheight,
nslices,
iwidth, iheight,
- owidth, oheight);
+ owidth, oheight,
+ istride_w,istride_h,
+ istride_d);
}
}
- /* cleanup */
- THTensor_(free)(input);
return 1;
}
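
To see how the strides are used inside the kernel, here is a simplified, hypothetical re-rendering of the per-frame loop after the patch (the names and the integer start/end arithmetic are illustrative, not the exact torch/nn code). The freshly resized output stays contiguous and keeps the old k*owidth*oheight + i*owidth + j addressing, while the input is walked through stridew/strideh/strided. In the batched branch the same idea shifts each sample's base pointer by p*istride_b instead of p*nslices*iwidth*iheight, and since newContiguous() is gone there is no temporary tensor left to free, which is why the THTensor_(free)(input) cleanup is dropped.

/* Simplified sketch of the per-frame loop, assuming the integer window
 * arithmetic below matches the adaptive-pooling behaviour closely enough
 * for illustration.  Only the input is addressed through its strides. */
#include <stdio.h>

void adaptive_max_pool_frame(const float *input, float *output,
                             long nslices, long iwidth, long iheight,
                             long owidth, long oheight,
                             long stridew, long strideh, long strided)
{
  for (long k = 0; k < nslices; k++)
    for (long i = 0; i < oheight; i++)
    {
      /* adaptive pooling: each output cell covers a variable-size window */
      long y_start = (i * iheight) / oheight;
      long y_end   = ((i + 1) * iheight + oheight - 1) / oheight;
      for (long j = 0; j < owidth; j++)
      {
        long x_start = (j * iwidth) / owidth;
        long x_end   = ((j + 1) * iwidth + owidth - 1) / owidth;

        /* input addressed through its strides (the heart of the patch) */
        const float *ip = input + k*strided + y_start*strideh + x_start*stridew;
        float maxval = *ip;
        for (long y = 0; y < y_end - y_start; y++)
          for (long x = 0; x < x_end - x_start; x++)
          {
            float val = *(ip + y*strideh + x*stridew);
            if (val > maxval) maxval = val;
          }

        /* output keeps the plain contiguous layout, as in the real code */
        output[k*owidth*oheight + i*owidth + j] = maxval;
      }
    }
}

int main(void)
{
  /* one 1 x 4 x 4 slice with contiguous strides (16, 4, 1) */
  float in[16], out[4];
  for (int i = 0; i < 16; i++) in[i] = (float)i;
  adaptive_max_pool_frame(in, out, 1, 4, 4, 2, 2,
                          /*stridew=*/1, /*strideh=*/4, /*strided=*/16);
  printf("%g %g %g %g\n", out[0], out[1], out[2], out[3]);  /* 5 7 13 15 */
  return 0;
}
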
diff --git a/test.lua b/test.lua
index 326909e..23c7fbd 100644
--- a/test.lua
+++ b/test.lua
@@ -1832,6 +1832,31 @@ function nntest.SpatialAdaptiveMaxPooling()
local ferr, berr = jac.testIO(module, input)
mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err (Batch) ')
mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err (Batch) ')
+
+ -- non-contiguous
+
+ input = torch.rand(from,ini,inj):transpose(2,3)
+ module = nn.SpatialAdaptiveMaxPooling(ki,kj)
+ local inputc = input:contiguous() -- contiguous
+ local output = module:forward(input):clone()
+ local outputc = module:forward(inputc):clone()
+ mytester:asserteq(0, (output-outputc):abs():max(), torch.typename(module) .. ' - non-contiguous err ')
+ local gradInput = module:backward(input, output):clone()
+ local gradInputc = module:backward(inputc, outputc):clone()
+ mytester:asserteq(0, (gradInput-gradInputc):abs():max(), torch.typename(module) .. ' - non-contiguous err ')
+
+ -- non-contiguous batch
+ local nbatch = math.random(1,3)
+ input = torch.rand(nbatch,from,ini,inj):transpose(1,3):transpose(2,4)
+ local inputc = input:contiguous() -- contiguous
+ module = nn.SpatialAdaptiveMaxPooling(ki,kj)
+
+ local output = module:forward(input):clone()
+ local outputc = module:forward(inputc):clone()
+ mytester:asserteq(0, (output-outputc):abs():max(), torch.typename(module) .. ' - batch non-contiguous err ')
+ local gradInput = module:backward(input, output):clone()
+ local gradInputc = module:backward(inputc, outputc):clone()
+ mytester:asserteq(0, (gradInput-gradInputc):abs():max(), torch.typename(module) .. ' - batch non-contiguous err ')
end
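
The new tests exercise exactly this path: transpose(2,3) (and the double transpose in the batch case) produces a view with permuted strides but no copy, and the module's forward and backward results on that view are required to match, element for element, the results obtained from a :contiguous() copy of the same data.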