From a30f4bf6e035e5241a20a9c9d83251f8304de9f2 Mon Sep 17 00:00:00 2001 From: Sergio Gomez Date: Fri, 19 Feb 2016 14:28:01 +0000 Subject: Extra checks and cleanup in C code This fixes a memory leak in SpatialLinear (added a missing THTensor_free()) --- SpatialDownSampling.lua | 5 +++++ SpatialPadding.lua | 2 ++ SpatialReSampling.lua | 3 +++ generic/SoftMaxTree.c | 7 +++---- generic/SpatialDownSampling.c | 2 ++ generic/SpatialLinear.c | 1 + test/test-all.lua | 4 ++-- 7 files changed, 18 insertions(+), 6 deletions(-) diff --git a/SpatialDownSampling.lua b/SpatialDownSampling.lua index b18849f..2aa4216 100644 --- a/SpatialDownSampling.lua +++ b/SpatialDownSampling.lua @@ -26,6 +26,11 @@ function SpatialDownSampling:__init(...) end function SpatialDownSampling:updateOutput(input) + if (input:size(2) / self.rH) < 1 then + error('input too small in dimension 2') + elseif (input:size(3) / self.rW) < 1 then + error('input too small in dimension 3') + end self.output:resize(input:size(1), math.floor(input:size(2) / self.rH), math.floor(input:size(3) / self.rW)) input.nn.SpatialDownSampling_updateOutput(self, input) diff --git a/SpatialPadding.lua b/SpatialPadding.lua index 07bacf8..9c59e9a 100644 --- a/SpatialPadding.lua +++ b/SpatialPadding.lua @@ -22,6 +22,8 @@ function SpatialPadding:__init(pad_l, pad_r, pad_t, pad_b, y_dim, x_dim, val) self.pad_b = pad_b or self.pad_l self.x_dim = x_dim or 3 self.y_dim = y_dim or 2 + if (self.x_dim % 1) ~= 0 then error('x_dim must be integer') end + if (self.y_dim % 1) ~= 0 then error('y_dim must be integer') end self.val = val or 0 end diff --git a/SpatialReSampling.lua b/SpatialReSampling.lua index 7324098..e83f8f2 100644 --- a/SpatialReSampling.lua +++ b/SpatialReSampling.lua @@ -31,6 +31,9 @@ function SpatialReSampling:__init(...) end function SpatialReSampling:updateOutput(input) + assert(input:dim() == 3 or input:dim() == 4, + 'input to SpatialReSampling must be 3D or 4D, received: [' .. + table.concat(input:size():totable(), ', ') .. ']') local hDim, wDim = 2, 3 if input:dim() == 4 then hDim, wDim = 3, 4 diff --git a/generic/SoftMaxTree.c b/generic/SoftMaxTree.c index 9e66a58..074c74e 100644 --- a/generic/SoftMaxTree.c +++ b/generic/SoftMaxTree.c @@ -126,13 +126,12 @@ static int nn_(SoftMaxTree_updateGradInput)(lua_State *L) THTensor *logsoftOutput = luaT_getfieldcheckudata(L, 1, "_multiBuffer", torch_Tensor); THTensor *weight = luaT_getfieldcheckudata(L, 1, "weight", torch_Tensor); - THTensor *output = luaT_getfieldcheckudata(L, 1, "output", torch_Tensor); THTensor *gradInput = luaT_getfieldcheckudata(L, 1, "_gradInput", torch_Tensor); THIntTensor *node; THTensor *nodeWeight, *nodeOutput; - THTensor *nodeGradInput, *nodeGradOutput, *weightTranspose; - real *gradInput_data, *output_data; + THTensor *nodeGradInput, *weightTranspose; + real *output_data; long i, d; @@ -231,7 +230,7 @@ static int nn_(SoftMaxTree_accGradParameters)(lua_State *L) THIntTensor *node; THTensor *nodeGradWeight, *nodeGradBias, *nodeInput, *nodeGradOutput; - long i, d; + long i; luaL_argcheck(L, input->nDimension == 2, 2, "2D(batch mode) tensor expected"); luaL_argcheck(L, input->size[1] == inputSize, 2, "invalid input size"); diff --git a/generic/SpatialDownSampling.c b/generic/SpatialDownSampling.c index e5c3c47..4e1fc5a 100644 --- a/generic/SpatialDownSampling.c +++ b/generic/SpatialDownSampling.c @@ -48,6 +48,8 @@ static int nn_(SpatialDownSampling_updateGradInput)(lua_State *L) { int rW = luaT_getfieldcheckint(L, 1, "rW"); int rH = luaT_getfieldcheckint(L, 1, "rH"); + THArgCheck(gradOutput->nDimension == 3, 2, "gradOutput must be 3D Tensor"); + // dims int owidth = gradOutput->size[2]; int oheight = gradOutput->size[1]; diff --git a/generic/SpatialLinear.c b/generic/SpatialLinear.c index 8a8b756..113eb85 100644 --- a/generic/SpatialLinear.c +++ b/generic/SpatialLinear.c @@ -105,6 +105,7 @@ static int nn_(SpatialLinear_updateGradInput)(lua_State *L) THTensor_(free)(gradOutput_y); THTensor_(free)(input_xy); THTensor_(free)(input_y); + THTensor_(free)(weight_t); return 1; } diff --git a/test/test-all.lua b/test/test-all.lua index 760d452..27efc37 100644 --- a/test/test-all.lua +++ b/test/test-all.lua @@ -137,8 +137,8 @@ end function nnxtest.SpatialDownSampling() local fanin = math.random(1,4) - local sizex = math.random(1,4) - local sizey = math.random(1,4) + local sizex = math.random(11,4) + local sizey = math.random(11,4) local mx = math.random(2,6) local my = math.random(2,6) local module = nn.SpatialDownSampling(mx,my) -- cgit v1.2.3