diff options
author | Nicholas Leonard <nick@nikopia.org> | 2014-10-08 01:33:20 +0400 |
---|---|---|
committer | Nicholas Leonard <nick@nikopia.org> | 2014-10-08 01:33:20 +0400 |
commit | 990243e328d4e235d8a47ed1089a6fabfd20f5cd (patch) | |
tree | f2bf4cd79075b049e7f817f0b7f2b85b4542bcb0 /generic | |
parent | 12349b4f2a341099a9c7f7d8bfcc31efffc36c62 (diff) |
SpatialReSampling:updateOutput works with batches
Diffstat (limited to 'generic')
-rw-r--r-- | generic/SpatialReSampling.c | 118 |
1 files changed, 71 insertions, 47 deletions
diff --git a/generic/SpatialReSampling.c b/generic/SpatialReSampling.c index c8d4bdd..81be1d2 100644 --- a/generic/SpatialReSampling.c +++ b/generic/SpatialReSampling.c @@ -12,22 +12,36 @@ static int nn_(SpatialReSampling_updateOutput)(lua_State *L) { // get all params - THTensor *input = luaT_checkudata(L, 2, torch_Tensor); + THTensor *input_ = luaT_checkudata(L, 2, torch_Tensor); int owidth = luaT_getfieldcheckint(L, 1, "owidth"); int oheight = luaT_getfieldcheckint(L, 1, "oheight"); - THTensor *output = luaT_getfieldcheckudata(L, 1, "output", torch_Tensor); + THTensor *output_ = luaT_getfieldcheckudata(L, 1, "output", torch_Tensor); // check dims - luaL_argcheck(L, input->nDimension == 3, 2, "3D tensor expected"); + luaL_argcheck(L, (input_->nDimension == 3) || (input_->nDimension == 4), 2, "3D or 4D tensor expected"); // dims - int iwidth = input->size[2]; - int iheight = input->size[1]; - int ochannels = input->size[0]; + int channelDim = 0; + int batchSize = 1; + if (input_->nDimension == 4){ + channelDim = 1; + batchSize = input_->size[0]; + } + + int iwidth = input_->size[channelDim + 2]; + int iheight = input_->size[channelDim + 1]; + int ochannels = input_->size[channelDim + 0]; // resize output - THTensor_(resize3d)(output, ochannels, oheight, owidth); - + if (input_->nDimension == 3) + THTensor_(resize3d)(output_, ochannels, oheight, owidth); + else + THTensor_(resize4d)(output_, batchSize, ochannels, oheight, owidth); + + // select example + THTensor *output = THTensor_(newWithTensor)(output_); + THTensor *input = THTensor_(newWithTensor)(input_); + // select planes THTensor *outputPlane = THTensor_(new)(); THTensor *inputPlane = THTensor_(new)(); @@ -36,45 +50,53 @@ static int nn_(SpatialReSampling_updateOutput)(lua_State *L) float wratio = (float)(iwidth-1) / (owidth-1); float hratio = (float)(iheight-1) / (oheight-1); - // resample each plane - int k; - for (k=0; k<ochannels; k++) { - // get planes - THTensor_(select)(inputPlane, input, 0, k); - THTensor_(select)(outputPlane, output, 0, k); - - // for each plane, resample - int x,y; - for (y=0; y<oheight; y++) { - for (x=0; x<owidth; x++) { - // subpixel position: - float ix = wratio*x; - float iy = hratio*y; - - // 4 nearest neighbors: - float ix_nw = floor(ix); - float iy_nw = floor(iy); - float ix_ne = ix_nw + 1; - float iy_ne = iy_nw; - float ix_sw = ix_nw; - float iy_sw = iy_nw + 1; - float ix_se = ix_nw + 1; - float iy_se = iy_nw + 1; - - // get surfaces to each neighbor: - float se = (ix-ix_nw)*(iy-iy_nw); - float sw = (ix_ne-ix)*(iy-iy_ne); - float ne = (ix-ix_sw)*(iy_sw-iy); - float nw = (ix_se-ix)*(iy_se-iy); - - // weighted sum of neighbors: - double sum = THTensor_(get2d)(inputPlane, iy_nw, ix_nw) * nw - + THTensor_(get2d)(inputPlane, iy_ne, MIN(ix_ne,iwidth-1)) * ne - + THTensor_(get2d)(inputPlane, MIN(iy_sw,iheight-1), ix_sw) * sw - + THTensor_(get2d)(inputPlane, MIN(iy_se,iheight-1), MIN(ix_se,iwidth-1)) * se; - - // set output - THTensor_(set2d)(outputPlane, y, x, sum); + int b; + for (b=0; b<batchSize; b++) { + if (input_->nDimension == 4) + { + THTensor_(select)(input, input_, 0, b); + THTensor_(select)(output, output_, 0, b); + } + // resample each plane + int k; + for (k=0; k<ochannels; k++) { + // get planes + THTensor_(select)(inputPlane, input, 0, k); + THTensor_(select)(outputPlane, output, 0, k); + + // for each plane, resample + int x,y; + for (y=0; y<oheight; y++) { + for (x=0; x<owidth; x++) { + // subpixel position: + float ix = wratio*x; + float iy = hratio*y; + + // 4 nearest neighbors: + float ix_nw = floor(ix); + float iy_nw = floor(iy); + float ix_ne = ix_nw + 1; + float iy_ne = iy_nw; + float ix_sw = ix_nw; + float iy_sw = iy_nw + 1; + float ix_se = ix_nw + 1; + float iy_se = iy_nw + 1; + + // get surfaces to each neighbor: + float se = (ix-ix_nw)*(iy-iy_nw); + float sw = (ix_ne-ix)*(iy-iy_ne); + float ne = (ix-ix_sw)*(iy_sw-iy); + float nw = (ix_se-ix)*(iy_se-iy); + + // weighted sum of neighbors: + double sum = THTensor_(get2d)(inputPlane, iy_nw, ix_nw) * nw + + THTensor_(get2d)(inputPlane, iy_ne, MIN(ix_ne,iwidth-1)) * ne + + THTensor_(get2d)(inputPlane, MIN(iy_sw,iheight-1), ix_sw) * sw + + THTensor_(get2d)(inputPlane, MIN(iy_se,iheight-1), MIN(ix_se,iwidth-1)) * se; + + // set output + THTensor_(set2d)(outputPlane, y, x, sum); + } } } } @@ -82,6 +104,8 @@ static int nn_(SpatialReSampling_updateOutput)(lua_State *L) // cleanup THTensor_(free)(inputPlane); THTensor_(free)(outputPlane); + THTensor_(free)(input); + THTensor_(free)(output); return 1; } |