diff options
author | Clement Farabet <clement.farabet@gmail.com> | 2011-10-02 02:44:22 +0400 |
---|---|---|
committer | Clement Farabet <clement.farabet@gmail.com> | 2011-10-02 02:44:22 +0400 |
commit | c8406078fc871a8c1e0b7b0c967053859bf9d9f7 (patch) | |
tree | 4f01faab8d65b279a07c80cabd65d743d1ed3982 | |
parent | 2c6beea7dbfa04e17be5f5b7286a49c4cd137dbd (diff) |
Fixed SpatialMaxSampling to accept arbitrary inputs (allowing upsampling)
-rw-r--r-- | generic/SpatialMaxSampling.c | 13 | ||||
-rw-r--r-- | test/test-all.lua | 4 |
2 files changed, 8 insertions, 9 deletions
diff --git a/generic/SpatialMaxSampling.c b/generic/SpatialMaxSampling.c index d53b364..5b795b3 100644 --- a/generic/SpatialMaxSampling.c +++ b/generic/SpatialMaxSampling.c @@ -20,7 +20,6 @@ static int nn_(SpatialMaxSampling_forward)(lua_State *L) // check dims luaL_argcheck(L, input->nDimension == 3, 2, "3D tensor expected"); - luaL_argcheck(L, (input->size[1] >= oheight) && (input->size[2] >= owidth), 2, "upsampling not supported"); // dims int ichannels = input->size[0]; @@ -58,10 +57,10 @@ static int nn_(SpatialMaxSampling_forward)(lua_State *L) for(i = 0; i < oheight; i++) { for(j = 0; j < owidth; j++) { // compute nearest offsets - long ixs = (long)(j*dW+0.5); - long iys = (long)(i*dH+0.5); - long ixe = (long)((j+1)*dW+0.5); - long iye = (long)((i+1)*dH+0.5); + long ixs = (long)(j*dW); + long iys = (long)(i*dH); + long ixe = MAX(ixs+1, (long)((j+1)*dW)); + long iye = MAX(iys+1, (long)((i+1)*dH)); // local pointers real *op = output_p + i*owidth + j; @@ -144,8 +143,8 @@ static int nn_(SpatialMaxSampling_backward)(lua_State *L) for(i = 0; i < oheight; i++) { for(j = 0; j < owidth; j++) { // compute nearest offsets - long iys = (long)(i*dH+0.5); - long ixs = (long)(j*dW+0.5); + long iys = (long)(i*dH); + long ixs = (long)(j*dW); // retrieve position of max real *indyp = indy_p + i*owidth + j; diff --git a/test/test-all.lua b/test/test-all.lua index 2dd7b26..d0b071d 100644 --- a/test/test-all.lua +++ b/test/test-all.lua @@ -94,8 +94,8 @@ end function nnxtest.SpatialMaxSampling() local fanin = math.random(1,4) - local sizex = math.random(8,16) - local sizey = math.random(8,16) + local sizex = math.random(1,16) + local sizey = math.random(1,16) local osizex = math.random(2,8) local osizey = math.random(2,8) local module = nn.SpatialMaxSampling(osizex,osizey) |