diff options
author | Christian Sarofeen <csarofeen@nvidia.com> | 2017-05-30 19:34:50 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2017-06-07 06:13:07 +0300 |
commit | fa185dd3b5841a27726d63926b54bec37a95604b (patch) | |
tree | c47e5dafa6784b5c8b9eff7ea8befb48e4b55eef | |
parent | 014d529f20768fc324a8eaee26de4d7bfcd8e5c1 (diff) |
More performant fix for fused rnn kernels (#1532) and bugfix for #1721
-rw-r--r-- | lib/THNN/generic/FusedRNNKernel.c | 16 | ||||
-rw-r--r-- | lib/THNN/generic/THNN.h | 16 |
2 files changed, 18 insertions, 14 deletions
diff --git a/lib/THNN/generic/FusedRNNKernel.c b/lib/THNN/generic/FusedRNNKernel.c
index 6126e86..30788b0 100644
--- a/lib/THNN/generic/FusedRNNKernel.c
+++ b/lib/THNN/generic/FusedRNNKernel.c
@@ -9,17 +9,19 @@ void THNN_(GRUFused_updateOutput)(
           THTensor *bias1,
           THTensor *bias2,
           THTensor *hx,
-          THTensor *hy)
+          THTensor *hy,
+          THTensor *storage)
 {
   THAssertMsg(false, "Not implemented for CPU");
 }
 
 void THNN_(GRUFused_updateGradInput)(
           THNNState *state,
-          THTensor *input,
-          THTensor *hidden,
+          THTensor *gradInInput,
+          THTensor *gradInHidden,
           THTensor *gradOutput,
-          THTensor *gradInput)
+          THTensor *gradInputHx,
+          THTensor *storage)
 {
   THAssertMsg(false, "Not implemented for CPU");
 }
@@ -39,13 +41,13 @@ void THNN_(LSTMFused_updateOutput)(
 
 void THNN_(LSTMFused_updateGradInput)(
           THNNState *state,
-          THTensor *input,
-          THTensor *hidden,
+          THTensor *storage,
+          THTensor *gradInGates,
           THTensor *prevC,
           THTensor *cy,
           THTensor *gradOutput,
           THTensor *gradOutputCell,
-          THTensor *gradInput)
+          THTensor *gradInputCx)
 {
   THAssertMsg(false, "Not implemented for CPU");
 }
diff --git a/lib/THNN/generic/THNN.h b/lib/THNN/generic/THNN.h
index b9fd709..d99b43b 100644
--- a/lib/THNN/generic/THNN.h
+++ b/lib/THNN/generic/THNN.h
@@ -177,13 +177,15 @@ TH_API void THNN_(GRUFused_updateOutput)(
           THTensor *bias1, // [OPTIONAL]
           THTensor *bias2, // [OPTIONAL]
           THTensor *hx,
-          THTensor *output);
+          THTensor *output,
+          THTensor *storage);
 TH_API void THNN_(GRUFused_updateGradInput)(
           THNNState *state,
-          THTensor *input,
-          THTensor *hidden,
+          THTensor *gradInInput,
+          THTensor *gradInHidden,
           THTensor *gradOutput,
-          THTensor *gradInput);
+          THTensor *gradInputHx,
+          THTensor *storage);
 
 TH_API void THNN_(LSTMFused_updateOutput)(
           THNNState *state,
@@ -196,13 +198,13 @@ TH_API void THNN_(LSTMFused_updateOutput)(
           THTensor *outputCell);
 TH_API void THNN_(LSTMFused_updateGradInput)(
           THNNState *state,
-          THTensor *input,
-          THTensor *hidden,
+          THTensor *storage,
+          THTensor *gradInGates,
           THTensor *cx,
           THTensor *cy,
           THTensor *gradOutput,
           THTensor *gradOutputCell,
-          THTensor *gradInput);
+          THTensor *gradInputCx);
 
 TH_API void THNN_(LogSigmoid_updateOutput)(
           THNNState *state, // library's state