Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorClement Farabet <clement.farabet@gmail.com>2011-10-02 02:44:22 +0400
committerClement Farabet <clement.farabet@gmail.com>2011-10-02 02:44:22 +0400
commitc8406078fc871a8c1e0b7b0c967053859bf9d9f7 (patch)
tree4f01faab8d65b279a07c80cabd65d743d1ed3982
parent2c6beea7dbfa04e17be5f5b7286a49c4cd137dbd (diff)
Fixed SpatialMaxSampling to accept arbitrary inputs (allowing upsampling)
-rw-r--r--generic/SpatialMaxSampling.c13
-rw-r--r--test/test-all.lua4
2 files changed, 8 insertions, 9 deletions
diff --git a/generic/SpatialMaxSampling.c b/generic/SpatialMaxSampling.c
index d53b364..5b795b3 100644
--- a/generic/SpatialMaxSampling.c
+++ b/generic/SpatialMaxSampling.c
@@ -20,7 +20,6 @@ static int nn_(SpatialMaxSampling_forward)(lua_State *L)
// check dims
luaL_argcheck(L, input->nDimension == 3, 2, "3D tensor expected");
- luaL_argcheck(L, (input->size[1] >= oheight) && (input->size[2] >= owidth), 2, "upsampling not supported");
// dims
int ichannels = input->size[0];
@@ -58,10 +57,10 @@ static int nn_(SpatialMaxSampling_forward)(lua_State *L)
for(i = 0; i < oheight; i++) {
for(j = 0; j < owidth; j++) {
// compute nearest offsets
- long ixs = (long)(j*dW+0.5);
- long iys = (long)(i*dH+0.5);
- long ixe = (long)((j+1)*dW+0.5);
- long iye = (long)((i+1)*dH+0.5);
+ long ixs = (long)(j*dW);
+ long iys = (long)(i*dH);
+ long ixe = MAX(ixs+1, (long)((j+1)*dW));
+ long iye = MAX(iys+1, (long)((i+1)*dH));
// local pointers
real *op = output_p + i*owidth + j;
@@ -144,8 +143,8 @@ static int nn_(SpatialMaxSampling_backward)(lua_State *L)
for(i = 0; i < oheight; i++) {
for(j = 0; j < owidth; j++) {
// compute nearest offsets
- long iys = (long)(i*dH+0.5);
- long ixs = (long)(j*dW+0.5);
+ long iys = (long)(i*dH);
+ long ixs = (long)(j*dW);
// retrieve position of max
real *indyp = indy_p + i*owidth + j;
diff --git a/test/test-all.lua b/test/test-all.lua
index 2dd7b26..d0b071d 100644
--- a/test/test-all.lua
+++ b/test/test-all.lua
@@ -94,8 +94,8 @@ end
function nnxtest.SpatialMaxSampling()
local fanin = math.random(1,4)
- local sizex = math.random(8,16)
- local sizey = math.random(8,16)
+ local sizex = math.random(1,16)
+ local sizey = math.random(1,16)
local osizex = math.random(2,8)
local osizey = math.random(2,8)
local module = nn.SpatialMaxSampling(osizex,osizey)