From cd13dead56a74a0fdb3522ba42f12c36a75ba9b0 Mon Sep 17 00:00:00 2001 From: Alykhan Tejani Date: Mon, 24 Jul 2017 08:58:43 -1000 Subject: Add numerically stable logsigmoid --- lib/THNN/generic/LogSigmoid.c | 27 +++++++++++++++++++++------ 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/lib/THNN/generic/LogSigmoid.c b/lib/THNN/generic/LogSigmoid.c index 651d560..5569f1f 100644 --- a/lib/THNN/generic/LogSigmoid.c +++ b/lib/THNN/generic/LogSigmoid.c @@ -10,11 +10,12 @@ void THNN_(LogSigmoid_updateOutput)( { THTensor_(resizeAs)(output, input); THTensor_(resizeAs)(buffer, input); - + //Use the LogSumExp trick to make this stable against overflow TH_TENSOR_APPLY3(real, output, real, input, real, buffer, - real z = exp(-*input_data); + real max_elem = fmax(0, -*input_data); + real z = exp(-max_elem) + exp(-*input_data - max_elem); *buffer_data = z; - *output_data = -log(1. + z); + *output_data = -(max_elem + log(z)); ); } @@ -27,10 +28,24 @@ void THNN_(LogSigmoid_updateGradInput)( { THNN_CHECK_NELEMENT(input, gradOutput); THTensor_(resizeAs)(gradInput, buffer); - TH_TENSOR_APPLY3(real, gradInput, real, gradOutput, real, buffer, +/* deriv of -max(0,-x) - log(e(0 - max(0,-x)) + e(-x - max(0,-x)) is + * -max_deriv - (-max_deriv*e(0-max(0,-x)) + (-1 - max_deriv)*e(-x - max(0,-x)))/z + * where z = e(0 - max(0,-x)) + e(-x - max(0,-x)) + * which simplifies to + * -max_deriv - (z-1)/z if x is >= 0 or + * -max_deriv + (z-1)/z if x is < 0 + */ + TH_TENSOR_APPLY3(real, input, real, gradInput, real, buffer, real z = *buffer_data; - *gradInput_data = *gradOutput_data * z / (1. + z); - ); + real max_deriv = 0.0; + real sign = -1.0; + if (*input_data < 0){ + max_deriv = -1.0; + sign = 1.0; + } + *gradInput_data = -max_deriv - sign*((z - 1.0)/ z); + ); + THTensor_(cmul)(gradInput, gradOutput, gradInput); } #endif -- cgit v1.2.3