diff options
-rw-r--r-- | generic/SpatialAdaptiveMaxPooling.c | 33 |
-rw-r--r-- | test.lua | 25 |
2 files changed, 47 insertions, 11 deletions
diff --git a/generic/SpatialAdaptiveMaxPooling.c b/generic/SpatialAdaptiveMaxPooling.c index 4c46f28..85f728b 100644 --- a/generic/SpatialAdaptiveMaxPooling.c +++ b/generic/SpatialAdaptiveMaxPooling.c @@ -6,7 +6,9 @@ static void nn_(SpatialAdaptiveMaxPooling_updateOutput_frame)(real *input_p,real real *indx_p, real *indy_p, long nslices, long iwidth, long iheight, - long owidth, long oheight) + long owidth, long oheight, + long stridew,long strideh, + long strided) { long k; #pragma omp parallel for private(k) @@ -28,7 +30,7 @@ static void nn_(SpatialAdaptiveMaxPooling_updateOutput_frame)(real *input_p,real int kW = x_end-x_start; /* local pointers */ - real *ip = input_p + k*iwidth*iheight + y_start*iwidth + x_start; + real *ip = input_p + k*strided + y_start*strideh + x_start*stridew; real *op = output_p + k*owidth*oheight + i*owidth + j; real *indyp = indy_p + k*owidth*oheight + i*owidth + j; real *indxp = indx_p + k*owidth*oheight + i*owidth + j; @@ -42,7 +44,7 @@ static void nn_(SpatialAdaptiveMaxPooling_updateOutput_frame)(real *input_p,real { for(x = 0; x < kW; x++) { - real val = *(ip + y*iwidth + x); + real val = *(ip + y*strideh + x*stridew); if (val > maxval) { maxval = val; @@ -76,6 +78,11 @@ static int nn_(SpatialAdaptiveMaxPooling_updateOutput)(lua_State *L) long nslices; long iheight; long iwidth; + + long istride_d; + long istride_h; + long istride_w; + long istride_b; real *input_data; real *output_data; @@ -86,6 +93,7 @@ static int nn_(SpatialAdaptiveMaxPooling_updateOutput)(lua_State *L) if (input->nDimension == 4) { + istride_b = input->stride[0]; nbatch = input->size[0]; dimw++; dimh++; @@ -95,9 +103,10 @@ static int nn_(SpatialAdaptiveMaxPooling_updateOutput)(lua_State *L) nslices = input->size[dimh-1]; iheight = input->size[dimh]; iwidth = input->size[dimw]; - - /* get contiguous input */ - input = THTensor_(newContiguous)(input); + /* strides */ + istride_d = input->stride[dimh-1]; + istride_h = input->stride[dimh]; + istride_w = 
input->stride[dimw]; /* resize output */ if (input->nDimension == 3) @@ -114,7 +123,9 @@ static int nn_(SpatialAdaptiveMaxPooling_updateOutput)(lua_State *L) indices_data+nslices*owidth*oheight, indices_data, nslices, iwidth, iheight, - owidth, oheight); + owidth, oheight, + istride_w,istride_h, + istride_d); } else { @@ -131,16 +142,16 @@ static int nn_(SpatialAdaptiveMaxPooling_updateOutput)(lua_State *L) #pragma omp parallel for private(p) for (p = 0; p < nbatch; p++) { - nn_(SpatialAdaptiveMaxPooling_updateOutput_frame)(input_data+p*nslices*iwidth*iheight, output_data+p*nslices*owidth*oheight, + nn_(SpatialAdaptiveMaxPooling_updateOutput_frame)(input_data+p*istride_b, output_data+p*nslices*owidth*oheight, indices_data+(p+nbatch)*nslices*owidth*oheight, indices_data+p*nslices*owidth*oheight, nslices, iwidth, iheight, - owidth, oheight); + owidth, oheight, + istride_w,istride_h, + istride_d); } } - /* cleanup */ - THTensor_(free)(input); return 1; } @@ -1832,6 +1832,31 @@ function nntest.SpatialAdaptiveMaxPooling() local ferr, berr = jac.testIO(module, input) mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err (Batch) ') mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err (Batch) ') + + -- non-contiguous + + input = torch.rand(from,ini,inj):transpose(2,3) + module = nn.SpatialAdaptiveMaxPooling(ki,kj) + local inputc = input:contiguous() -- contiguous + local output = module:forward(input):clone() + local outputc = module:forward(inputc):clone() + mytester:asserteq(0, (output-outputc):abs():max(), torch.typename(module) .. ' - non-contiguous err ') + local gradInput = module:backward(input, output):clone() + local gradInputc = module:backward(inputc, outputc):clone() + mytester:asserteq(0, (gradInput-gradInputc):abs():max(), torch.typename(module) .. 
' - non-contiguous err ') + + -- non-contiguous batch + local nbatch = math.random(1,3) + input = torch.rand(nbatch,from,ini,inj):transpose(1,3):transpose(2,4) + local inputc = input:contiguous() -- contiguous + module = nn.SpatialAdaptiveMaxPooling(ki,kj) + + local output = module:forward(input):clone() + local outputc = module:forward(inputc):clone() + mytester:asserteq(0, (output-outputc):abs():max(), torch.typename(module) .. ' - batch non-contiguous err ') + local gradInput = module:backward(input, output):clone() + local gradInputc = module:backward(inputc, outputc):clone() + mytester:asserteq(0, (gradInput-gradInputc):abs():max(), torch.typename(module) .. ' - batch non-contiguous err ') end |