From fa185dd3b5841a27726d63926b54bec37a95604b Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Tue, 30 May 2017 09:34:50 -0700 Subject: More performant fix for fused rnn kernels (#1532) and bugfix for #1721 --- lib/THNN/generic/FusedRNNKernel.c | 16 +++++++++------- lib/THNN/generic/THNN.h | 16 +++++++++------- 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/lib/THNN/generic/FusedRNNKernel.c b/lib/THNN/generic/FusedRNNKernel.c index 6126e86..30788b0 100644 --- a/lib/THNN/generic/FusedRNNKernel.c +++ b/lib/THNN/generic/FusedRNNKernel.c @@ -9,17 +9,19 @@ void THNN_(GRUFused_updateOutput)( THTensor *bias1, THTensor *bias2, THTensor *hx, - THTensor *hy) + THTensor *hy, + THTensor *storage) { THAssertMsg(false, "Not implemented for CPU"); } void THNN_(GRUFused_updateGradInput)( THNNState *state, - THTensor *input, - THTensor *hidden, + THTensor *gradInInput, + THTensor *gradInHidden, THTensor *gradOutput, - THTensor *gradInput) + THTensor *gradInputHx, + THTensor *storage) { THAssertMsg(false, "Not implemented for CPU"); } @@ -39,13 +41,13 @@ void THNN_(LSTMFused_updateOutput)( void THNN_(LSTMFused_updateGradInput)( THNNState *state, - THTensor *input, - THTensor *hidden, + THTensor *storage, + THTensor *gradInGates, THTensor *prevC, THTensor *cy, THTensor *gradOutput, THTensor *gradOutputCell, - THTensor *gradInput) + THTensor *gradInputCx) { THAssertMsg(false, "Not implemented for CPU"); } diff --git a/lib/THNN/generic/THNN.h b/lib/THNN/generic/THNN.h index b9fd709..d99b43b 100644 --- a/lib/THNN/generic/THNN.h +++ b/lib/THNN/generic/THNN.h @@ -177,13 +177,15 @@ TH_API void THNN_(GRUFused_updateOutput)( THTensor *bias1, // [OPTIONAL] THTensor *bias2, // [OPTIONAL] THTensor *hx, - THTensor *output); + THTensor *output, + THTensor *storage); TH_API void THNN_(GRUFused_updateGradInput)( THNNState *state, - THTensor *input, - THTensor *hidden, + THTensor *gradInInput, + THTensor *gradInHidden, THTensor *gradOutput, - THTensor *gradInput); + THTensor *gradInputHx, + THTensor *storage); TH_API void THNN_(LSTMFused_updateOutput)( THNNState *state, @@ -196,13 +198,13 @@ TH_API void THNN_(LSTMFused_updateOutput)( THTensor *outputCell); TH_API void THNN_(LSTMFused_updateGradInput)( THNNState *state, - THTensor *input, - THTensor *hidden, + THTensor *storage, + THTensor *gradInGates, THTensor *cx, THTensor *cy, THTensor *gradOutput, THTensor *gradOutputCell, - THTensor *gradInput); + THTensor *gradInputCx); TH_API void THNN_(LogSigmoid_updateOutput)( THNNState *state, // library's state -- cgit v1.2.3