Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorChristian Sarofeen <csarofeen@nvidia.com>2017-05-30 19:34:50 +0300
committerSoumith Chintala <soumith@gmail.com>2017-06-07 06:13:07 +0300
commitfa185dd3b5841a27726d63926b54bec37a95604b (patch)
treec47e5dafa6784b5c8b9eff7ea8befb48e4b55eef
parent014d529f20768fc324a8eaee26de4d7bfcd8e5c1 (diff)
More performant fix for fused rnn kernels (#1532) and bugfix for #1721
-rw-r--r--lib/THNN/generic/FusedRNNKernel.c16
-rw-r--r--lib/THNN/generic/THNN.h16
2 files changed, 18 insertions, 14 deletions
diff --git a/lib/THNN/generic/FusedRNNKernel.c b/lib/THNN/generic/FusedRNNKernel.c
index 6126e86..30788b0 100644
--- a/lib/THNN/generic/FusedRNNKernel.c
+++ b/lib/THNN/generic/FusedRNNKernel.c
@@ -9,17 +9,19 @@ void THNN_(GRUFused_updateOutput)(
THTensor *bias1,
THTensor *bias2,
THTensor *hx,
- THTensor *hy)
+ THTensor *hy,
+ THTensor *storage)
{
THAssertMsg(false, "Not implemented for CPU");
}
void THNN_(GRUFused_updateGradInput)(
THNNState *state,
- THTensor *input,
- THTensor *hidden,
+ THTensor *gradInInput,
+ THTensor *gradInHidden,
THTensor *gradOutput,
- THTensor *gradInput)
+ THTensor *gradInputHx,
+ THTensor *storage)
{
THAssertMsg(false, "Not implemented for CPU");
}
@@ -39,13 +41,13 @@ void THNN_(LSTMFused_updateOutput)(
void THNN_(LSTMFused_updateGradInput)(
THNNState *state,
- THTensor *input,
- THTensor *hidden,
+ THTensor *storage,
+ THTensor *gradInGates,
THTensor *prevC,
THTensor *cy,
THTensor *gradOutput,
THTensor *gradOutputCell,
- THTensor *gradInput)
+ THTensor *gradInputCx)
{
THAssertMsg(false, "Not implemented for CPU");
}
diff --git a/lib/THNN/generic/THNN.h b/lib/THNN/generic/THNN.h
index b9fd709..d99b43b 100644
--- a/lib/THNN/generic/THNN.h
+++ b/lib/THNN/generic/THNN.h
@@ -177,13 +177,15 @@ TH_API void THNN_(GRUFused_updateOutput)(
THTensor *bias1, // [OPTIONAL]
THTensor *bias2, // [OPTIONAL]
THTensor *hx,
- THTensor *output);
+ THTensor *output,
+ THTensor *storage);
TH_API void THNN_(GRUFused_updateGradInput)(
THNNState *state,
- THTensor *input,
- THTensor *hidden,
+ THTensor *gradInInput,
+ THTensor *gradInHidden,
THTensor *gradOutput,
- THTensor *gradInput);
+ THTensor *gradInputHx,
+ THTensor *storage);
TH_API void THNN_(LSTMFused_updateOutput)(
THNNState *state,
@@ -196,13 +198,13 @@ TH_API void THNN_(LSTMFused_updateOutput)(
THTensor *outputCell);
TH_API void THNN_(LSTMFused_updateGradInput)(
THNNState *state,
- THTensor *input,
- THTensor *hidden,
+ THTensor *storage,
+ THTensor *gradInGates,
THTensor *cx,
THTensor *cy,
THTensor *gradOutput,
THTensor *gradOutputCell,
- THTensor *gradInput);
+ THTensor *gradInputCx);
TH_API void THNN_(LogSigmoid_updateOutput)(
THNNState *state, // library's state