Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNicholas Leonard <nick@nikopia.org>2014-10-08 01:33:20 +0400
committerNicholas Leonard <nick@nikopia.org>2014-10-08 01:33:20 +0400
commit990243e328d4e235d8a47ed1089a6fabfd20f5cd (patch)
treef2bf4cd79075b049e7f817f0b7f2b85b4542bcb0 /generic
parent12349b4f2a341099a9c7f7d8bfcc31efffc36c62 (diff)
SpatialReSampling:updateOutput works with batches
Diffstat (limited to 'generic')
-rw-r--r--generic/SpatialReSampling.c118
1 files changed, 71 insertions, 47 deletions
diff --git a/generic/SpatialReSampling.c b/generic/SpatialReSampling.c
index c8d4bdd..81be1d2 100644
--- a/generic/SpatialReSampling.c
+++ b/generic/SpatialReSampling.c
@@ -12,22 +12,36 @@
static int nn_(SpatialReSampling_updateOutput)(lua_State *L)
{
// get all params
- THTensor *input = luaT_checkudata(L, 2, torch_Tensor);
+ THTensor *input_ = luaT_checkudata(L, 2, torch_Tensor);
int owidth = luaT_getfieldcheckint(L, 1, "owidth");
int oheight = luaT_getfieldcheckint(L, 1, "oheight");
- THTensor *output = luaT_getfieldcheckudata(L, 1, "output", torch_Tensor);
+ THTensor *output_ = luaT_getfieldcheckudata(L, 1, "output", torch_Tensor);
// check dims
- luaL_argcheck(L, input->nDimension == 3, 2, "3D tensor expected");
+ luaL_argcheck(L, (input_->nDimension == 3) || (input_->nDimension == 4), 2, "3D or 4D tensor expected");
// dims
- int iwidth = input->size[2];
- int iheight = input->size[1];
- int ochannels = input->size[0];
+ int channelDim = 0;
+ int batchSize = 1;
+ if (input_->nDimension == 4){
+ channelDim = 1;
+ batchSize = input_->size[0];
+ }
+
+ int iwidth = input_->size[channelDim + 2];
+ int iheight = input_->size[channelDim + 1];
+ int ochannels = input_->size[channelDim + 0];
// resize output
- THTensor_(resize3d)(output, ochannels, oheight, owidth);
-
+ if (input_->nDimension == 3)
+ THTensor_(resize3d)(output_, ochannels, oheight, owidth);
+ else
+ THTensor_(resize4d)(output_, batchSize, ochannels, oheight, owidth);
+
+ // select example
+ THTensor *output = THTensor_(newWithTensor)(output_);
+ THTensor *input = THTensor_(newWithTensor)(input_);
+
// select planes
THTensor *outputPlane = THTensor_(new)();
THTensor *inputPlane = THTensor_(new)();
@@ -36,45 +50,53 @@ static int nn_(SpatialReSampling_updateOutput)(lua_State *L)
float wratio = (float)(iwidth-1) / (owidth-1);
float hratio = (float)(iheight-1) / (oheight-1);
- // resample each plane
- int k;
- for (k=0; k<ochannels; k++) {
- // get planes
- THTensor_(select)(inputPlane, input, 0, k);
- THTensor_(select)(outputPlane, output, 0, k);
-
- // for each plane, resample
- int x,y;
- for (y=0; y<oheight; y++) {
- for (x=0; x<owidth; x++) {
- // subpixel position:
- float ix = wratio*x;
- float iy = hratio*y;
-
- // 4 nearest neighbors:
- float ix_nw = floor(ix);
- float iy_nw = floor(iy);
- float ix_ne = ix_nw + 1;
- float iy_ne = iy_nw;
- float ix_sw = ix_nw;
- float iy_sw = iy_nw + 1;
- float ix_se = ix_nw + 1;
- float iy_se = iy_nw + 1;
-
- // get surfaces to each neighbor:
- float se = (ix-ix_nw)*(iy-iy_nw);
- float sw = (ix_ne-ix)*(iy-iy_ne);
- float ne = (ix-ix_sw)*(iy_sw-iy);
- float nw = (ix_se-ix)*(iy_se-iy);
-
- // weighted sum of neighbors:
- double sum = THTensor_(get2d)(inputPlane, iy_nw, ix_nw) * nw
- + THTensor_(get2d)(inputPlane, iy_ne, MIN(ix_ne,iwidth-1)) * ne
- + THTensor_(get2d)(inputPlane, MIN(iy_sw,iheight-1), ix_sw) * sw
- + THTensor_(get2d)(inputPlane, MIN(iy_se,iheight-1), MIN(ix_se,iwidth-1)) * se;
-
- // set output
- THTensor_(set2d)(outputPlane, y, x, sum);
+ int b;
+ for (b=0; b<batchSize; b++) {
+ if (input_->nDimension == 4)
+ {
+ THTensor_(select)(input, input_, 0, b);
+ THTensor_(select)(output, output_, 0, b);
+ }
+ // resample each plane
+ int k;
+ for (k=0; k<ochannels; k++) {
+ // get planes
+ THTensor_(select)(inputPlane, input, 0, k);
+ THTensor_(select)(outputPlane, output, 0, k);
+
+ // for each plane, resample
+ int x,y;
+ for (y=0; y<oheight; y++) {
+ for (x=0; x<owidth; x++) {
+ // subpixel position:
+ float ix = wratio*x;
+ float iy = hratio*y;
+
+ // 4 nearest neighbors:
+ float ix_nw = floor(ix);
+ float iy_nw = floor(iy);
+ float ix_ne = ix_nw + 1;
+ float iy_ne = iy_nw;
+ float ix_sw = ix_nw;
+ float iy_sw = iy_nw + 1;
+ float ix_se = ix_nw + 1;
+ float iy_se = iy_nw + 1;
+
+ // get surfaces to each neighbor:
+ float se = (ix-ix_nw)*(iy-iy_nw);
+ float sw = (ix_ne-ix)*(iy-iy_ne);
+ float ne = (ix-ix_sw)*(iy_sw-iy);
+ float nw = (ix_se-ix)*(iy_se-iy);
+
+ // weighted sum of neighbors:
+ double sum = THTensor_(get2d)(inputPlane, iy_nw, ix_nw) * nw
+ + THTensor_(get2d)(inputPlane, iy_ne, MIN(ix_ne,iwidth-1)) * ne
+ + THTensor_(get2d)(inputPlane, MIN(iy_sw,iheight-1), ix_sw) * sw
+ + THTensor_(get2d)(inputPlane, MIN(iy_se,iheight-1), MIN(ix_se,iwidth-1)) * se;
+
+ // set output
+ THTensor_(set2d)(outputPlane, y, x, sum);
+ }
}
}
}
@@ -82,6 +104,8 @@ static int nn_(SpatialReSampling_updateOutput)(lua_State *L)
// cleanup
THTensor_(free)(inputPlane);
THTensor_(free)(outputPlane);
+ THTensor_(free)(input);
+ THTensor_(free)(output);
return 1;
}