diff options
author | soumith <soumith@fb.com> | 2015-07-31 15:24:52 +0300 |
---|---|---|
committer | soumith <soumith@fb.com> | 2015-07-31 15:24:52 +0300 |
commit | 9f8b17b088088b0978ed6fcac52ffd95f646acec (patch) | |
tree | 7f7e7ba9e1ba17b64e12f3814b0c29349ecfa25d | |
parent | b6636d9014e36a58926962200067f1c2abab0ef9 (diff) |
fixing non-contiguous bug in SpatialAveragePooling
-rw-r--r-- | generic/SpatialAveragePooling.c | 8 |
1 files changed, 8 insertions, 0 deletions
diff --git a/generic/SpatialAveragePooling.c b/generic/SpatialAveragePooling.c index 681938e..0a3f782 100644 --- a/generic/SpatialAveragePooling.c +++ b/generic/SpatialAveragePooling.c @@ -51,6 +51,7 @@ static int nn_(SpatialAveragePooling_updateOutput)(lua_State *L) THTensor_(resize4d)(output, input->size[0], nInputPlane, outputHeight, outputWidth); input = THTensor_(newContiguous)(input); + luaL_argcheck(L, THTensor_(isContiguous)(output), 1, ""); input_data = THTensor_(data)(input); output_data = THTensor_(data)(output); @@ -135,6 +136,11 @@ static int nn_(SpatialAveragePooling_updateGradInput)(lua_State *L) input_data = THTensor_(data)(input); THTensor_(resizeAs)(gradInput, input); + + input = THTensor_(newContiguous)(input); + gradOutput = THTensor_(newContiguous)(gradOutput); + luaL_argcheck(L, THTensor_(isContiguous)(gradInput), 1, ""); + gradInput_data = THTensor_(data)(gradInput); gradOutput_data = THTensor_(data)(gradOutput); @@ -171,6 +177,8 @@ static int nn_(SpatialAveragePooling_updateGradInput)(lua_State *L) } } + THTensor_(free)(input); + THTensor_(free)(gradOutput); return 1; } |