Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorsoumith <soumith@fb.com>2015-03-20 03:04:59 +0300
committersoumith <soumith@fb.com>2015-04-07 23:51:42 +0300
commitab09f77b32119e0c2de49572c8c856c81363c2a0 (patch)
treef444078b84beca2db77c90fa16c83bf45f1a12c9 /generic
parenta7b5dcd4ec6e0b38f2c88c84e757eb42081e2aca (diff)
adds in-place ReLU and fixes a potential divide-by-zero in nn.Sqrt
Diffstat (limited to 'generic')
-rw-r--r--generic/Sqrt.c9
-rw-r--r--generic/Threshold.c37
2 files changed, 35 insertions, 11 deletions
diff --git a/generic/Sqrt.c b/generic/Sqrt.c
index 0e7cbd7..f3b6d98 100644
--- a/generic/Sqrt.c
+++ b/generic/Sqrt.c
@@ -42,7 +42,8 @@ static int nn_(Sqrt_updateGradInput)(lua_State *L)
!THTensor_(isContiguous)(gradInput))
{
TH_TENSOR_APPLY3(real, gradInput, real, gradOutput, real, output, \
- *gradInput_data = 0.5 * (*gradOutput_data / *output_data););
+ *gradInput_data = ((*output_data == 0.0) ? 0.0 : \
+ (0.5 * (*gradOutput_data / *output_data))););
}
else
{
@@ -52,7 +53,11 @@ static int nn_(Sqrt_updateGradInput)(lua_State *L)
long i;
#pragma omp parallel for private(i)
for(i = 0; i < THTensor_(nElement)(output); i++)
- gradInput_data[i] = 0.5 * (gradOutput_data[i] / output_data[i]);
+ if (output_data[i] == 0.0) {
+ gradInput_data[i] = 0.0;
+ } else {
+ gradInput_data[i] = 0.5 * (gradOutput_data[i] / output_data[i]);
+ }
}
return 1;
}
diff --git a/generic/Threshold.c b/generic/Threshold.c
index f21b615..a309f78 100644
--- a/generic/Threshold.c
+++ b/generic/Threshold.c
@@ -8,10 +8,20 @@ static int nn_(Threshold_updateOutput)(lua_State *L)
real val = luaT_getfieldchecknumber(L, 1, "val");
real threshold = luaT_getfieldchecknumber(L, 1, "threshold");
THTensor *output = luaT_getfieldcheckudata(L, 1, "output", torch_Tensor);
-
- THTensor_(resizeAs)(output, input);
- TH_TENSOR_APPLY2(real, output, real, input, \
- *output_data = (*input_data > threshold) ? *input_data : val;);
+ int inPlace = luaT_getfieldcheckboolean(L, 1, "inplace");
+
+ if (inPlace) {
+ TH_TENSOR_APPLY(real, input, \
+ if (*input_data <= threshold) { \
+ *input_data = val; \
+ });
+ THTensor_(set)(output, input);
+ } else {
+ THTensor_(resizeAs)(output, input);
+ TH_TENSOR_APPLY2(real, output, real, input, \
+ *output_data = (*input_data > threshold) ? *input_data : val;);
+
+ }
return 1;
}
@@ -22,12 +32,21 @@ static int nn_(Threshold_updateGradInput)(lua_State *L)
THTensor *gradOutput = luaT_checkudata(L, 3, torch_Tensor);
real threshold = luaT_getfieldchecknumber(L, 1, "threshold");
THTensor *gradInput = luaT_getfieldcheckudata(L, 1, "gradInput", torch_Tensor);
+ int inPlace = luaT_getfieldcheckboolean(L, 1, "inplace");
+
+ if (inPlace) {
+ TH_TENSOR_APPLY2(real, gradOutput, real, input, \
+ if ((*input_data) <= threshold) { \
+ *gradOutput_data = 0; \
+ });
+ THTensor_(set)(gradInput, gradOutput);
+ } else {
+ THTensor_(resizeAs)(gradInput, input);
+ TH_TENSOR_APPLY3(real, gradInput, real, gradOutput, real, input, \
+ if ((*input_data) > threshold) *gradInput_data = *gradOutput_data; \
+ else *gradInput_data = 0;); \
+ }
- THTensor_(resizeAs)(gradInput, input);
- TH_TENSOR_APPLY3(real, gradInput, real, gradOutput, real, input, \
- if ((*input_data) > threshold) *gradInput_data = *gradOutput_data; \
- else *gradInput_data = 0;); \
-
return 1;
}