Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorsoumith <soumith@fb.com>2015-07-31 15:24:52 +0300
committersoumith <soumith@fb.com>2015-07-31 15:24:52 +0300
commit9f8b17b088088b0978ed6fcac52ffd95f646acec (patch)
tree7f7e7ba9e1ba17b64e12f3814b0c29349ecfa25d
parentb6636d9014e36a58926962200067f1c2abab0ef9 (diff)
fixing non-contiguous bug in SpatialAveragePooling
-rw-r--r--generic/SpatialAveragePooling.c8
1 files changed, 8 insertions, 0 deletions
diff --git a/generic/SpatialAveragePooling.c b/generic/SpatialAveragePooling.c
index 681938e..0a3f782 100644
--- a/generic/SpatialAveragePooling.c
+++ b/generic/SpatialAveragePooling.c
@@ -51,6 +51,7 @@ static int nn_(SpatialAveragePooling_updateOutput)(lua_State *L)
THTensor_(resize4d)(output, input->size[0], nInputPlane, outputHeight, outputWidth);
input = THTensor_(newContiguous)(input);
+ luaL_argcheck(L, THTensor_(isContiguous)(output), 1, "");
input_data = THTensor_(data)(input);
output_data = THTensor_(data)(output);
@@ -135,6 +136,11 @@ static int nn_(SpatialAveragePooling_updateGradInput)(lua_State *L)
input_data = THTensor_(data)(input);
THTensor_(resizeAs)(gradInput, input);
+
+ input = THTensor_(newContiguous)(input);
+ gradOutput = THTensor_(newContiguous)(gradOutput);
+ luaL_argcheck(L, THTensor_(isContiguous)(gradInput), 1, "");
+
gradInput_data = THTensor_(data)(gradInput);
gradOutput_data = THTensor_(data)(gradOutput);
@@ -171,6 +177,8 @@ static int nn_(SpatialAveragePooling_updateGradInput)(lua_State *L)
}
}
+ THTensor_(free)(input);
+ THTensor_(free)(gradOutput);
return 1;
}