diff options
author | soumith <soumith@fb.com> | 2015-03-20 03:04:59 +0300 |
---|---|---|
committer | soumith <soumith@fb.com> | 2015-04-07 23:51:42 +0300 |
commit | ab09f77b32119e0c2de49572c8c856c81363c2a0 (patch) | |
tree | f444078b84beca2db77c90fa16c83bf45f1a12c9 /generic | |
parent | a7b5dcd4ec6e0b38f2c88c84e757eb42081e2aca (diff) |
adds in-place ReLU and fixes a potential divide-by-zero in nn.Sqrt
Diffstat (limited to 'generic')
-rw-r--r-- | generic/Sqrt.c | 9 | ||||
-rw-r--r-- | generic/Threshold.c | 37 |
2 files changed, 35 insertions, 11 deletions
diff --git a/generic/Sqrt.c b/generic/Sqrt.c index 0e7cbd7..f3b6d98 100644 --- a/generic/Sqrt.c +++ b/generic/Sqrt.c @@ -42,7 +42,8 @@ static int nn_(Sqrt_updateGradInput)(lua_State *L) !THTensor_(isContiguous)(gradInput)) { TH_TENSOR_APPLY3(real, gradInput, real, gradOutput, real, output, \ - *gradInput_data = 0.5 * (*gradOutput_data / *output_data);); + *gradInput_data = ((*output_data == 0.0) ? 0.0 : \ + (0.5 * (*gradOutput_data / *output_data)));); } else { @@ -52,7 +53,11 @@ static int nn_(Sqrt_updateGradInput)(lua_State *L) long i; #pragma omp parallel for private(i) for(i = 0; i < THTensor_(nElement)(output); i++) - gradInput_data[i] = 0.5 * (gradOutput_data[i] / output_data[i]); + if (output_data[i] == 0.0) { + gradInput_data[i] = 0.0; + } else { + gradInput_data[i] = 0.5 * (gradOutput_data[i] / output_data[i]); + } } return 1; } diff --git a/generic/Threshold.c b/generic/Threshold.c index f21b615..a309f78 100644 --- a/generic/Threshold.c +++ b/generic/Threshold.c @@ -8,10 +8,20 @@ static int nn_(Threshold_updateOutput)(lua_State *L) real val = luaT_getfieldchecknumber(L, 1, "val"); real threshold = luaT_getfieldchecknumber(L, 1, "threshold"); THTensor *output = luaT_getfieldcheckudata(L, 1, "output", torch_Tensor); - - THTensor_(resizeAs)(output, input); - TH_TENSOR_APPLY2(real, output, real, input, \ - *output_data = (*input_data > threshold) ? *input_data : val;); + int inPlace = luaT_getfieldcheckboolean(L, 1, "inplace"); + + if (inPlace) { + TH_TENSOR_APPLY(real, input, \ + if (*input_data <= threshold) { \ + *input_data = val; \ + }); + THTensor_(set)(output, input); + } else { + THTensor_(resizeAs)(output, input); + TH_TENSOR_APPLY2(real, output, real, input, \ + *output_data = (*input_data > threshold) ? *input_data : val;); + + } return 1; } @@ -22,12 +32,21 @@ static int nn_(Threshold_updateGradInput)(lua_State *L) THTensor *gradOutput = luaT_checkudata(L, 3, torch_Tensor); real threshold = luaT_getfieldchecknumber(L, 1, "threshold"); THTensor *gradInput = luaT_getfieldcheckudata(L, 1, "gradInput", torch_Tensor); + int inPlace = luaT_getfieldcheckboolean(L, 1, "inplace"); + + if (inPlace) { + TH_TENSOR_APPLY2(real, gradOutput, real, input, \ + if ((*input_data) <= threshold) { \ + *gradOutput_data = 0; \ + }); + THTensor_(set)(gradInput, gradOutput); + } else { + THTensor_(resizeAs)(gradInput, input); + TH_TENSOR_APPLY3(real, gradInput, real, gradOutput, real, input, \ + if ((*input_data) > threshold) *gradInput_data = *gradOutput_data; \ + else *gradInput_data = 0;); \ + } - THTensor_(resizeAs)(gradInput, input); - TH_TENSOR_APPLY3(real, gradInput, real, gradOutput, real, input, \ - if ((*input_data) > threshold) *gradInput_data = *gradOutput_data; \ - else *gradInput_data = 0;); \ - return 1; } |