Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cunn.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author Christian Sarofeen <csarofeen@nvidia.com> 2017-06-07 06:20:14 +0300
committer Soumith Chintala <soumith@gmail.com> 2017-06-07 06:20:14 +0300
commit 8e3364c64aa85e7ab4970c8d5f6a4f18d0c3eeae (patch)
tree 758c4db20b040f91947d2b6db5c52f378ae457e4
parent 8c6df2a424680dfd3fa1109a2f8e3b62ac41680d (diff)
Remove clone in fused rnn
-rw-r--r-- lib/THCUNN/generic/FusedRNNKernel.cu 297
-rw-r--r-- lib/THCUNN/generic/THCUNN.h 13
2 files changed, 164 insertions, 146 deletions
diff --git a/lib/THCUNN/generic/FusedRNNKernel.cu b/lib/THCUNN/generic/FusedRNNKernel.cu
index 6aeba1e..91f11e8 100644
--- a/lib/THCUNN/generic/FusedRNNKernel.cu
+++ b/lib/THCUNN/generic/FusedRNNKernel.cu
@@ -85,12 +85,13 @@ template <typename T, typename IndexType, int Dims>
__launch_bounds__(32 * 16, 4)
#endif
__global__ void
- THNN_(GRUForward)(TensorInfo<T, IndexType> Input,
+THNN_(GRUForward)(TensorInfo<T, IndexType> Input,
TensorInfo<T, IndexType> Hidden,
TensorInfo<T, IndexType> Bias1,
TensorInfo<T, IndexType> Bias2,
TensorInfo<T, IndexType> _hx,
TensorInfo<T, IndexType> _hy,
+ TensorInfo<T, IndexType> storage,
IndexType hsz,
IndexType totalElements)
{
@@ -101,16 +102,14 @@ __global__ void
IndexType offset = (linearIndex/hsz)*3*hsz+linearIndex%hsz;
- T* ir = &DEVICE_LINEAR_GET(Input, offset+0*hsz);
- T* ii = &DEVICE_LINEAR_GET(Input, offset+1*hsz);
- T* in = &DEVICE_LINEAR_GET(Input, offset+2*hsz);
-
- T* hr = &DEVICE_LINEAR_GET(Hidden,offset+0*hsz);
- T* hi = &DEVICE_LINEAR_GET(Hidden,offset+1*hsz);
+ T ir = DEVICE_LINEAR_GET(Input, offset+0*hsz);
+ T ii = DEVICE_LINEAR_GET(Input, offset+1*hsz);
+ T in = DEVICE_LINEAR_GET(Input, offset+2*hsz);
+ T hr = DEVICE_LINEAR_GET(Hidden,offset+0*hsz);
+ T hi = DEVICE_LINEAR_GET(Hidden,offset+1*hsz);
T hn = DEVICE_LINEAR_GET(Hidden, offset+2*hsz);
T hx = DEVICE_LINEAR_GET(_hx, linearIndex);
-
T* hy = &DEVICE_LINEAR_GET(_hy, linearIndex);
bool has_bias = (Bias1.data != NULL);
@@ -136,45 +135,46 @@ __global__ void
}
+ offset = (linearIndex/hsz)*5*hsz+linearIndex%hsz;
#ifndef THC_REAL_IS_HALF
T rg, ig, ng;
- rg = *ir + *hr + b1r + b2r;
- ig = *ii + *hi + b1i + b2i;
+ rg = ir + hr + b1r + b2r;
+ ig = ii + hi + b1i + b2i;
TensorSigmoidOp<real>()(&rg, &rg);
TensorSigmoidOp<real>()(&ig, &ig);
- ng = *in + b1n + rg * (hn + b2n);
+ ng = in + b1n + rg * (hn + b2n);
ng = THCNumerics<T>::tanh(ng);
*hy = ng + ig * (hx - ng);
//SAVE FOR BACKWARDS
- *ir = rg;
- *ii = ig;
- *in = ng;
- *hr = hx;
- *hi = hn + b2n;
+ DEVICE_LINEAR_GET(storage, offset+0*hsz) = rg;
+ DEVICE_LINEAR_GET(storage, offset+1*hsz) = ig;
+ DEVICE_LINEAR_GET(storage, offset+2*hsz) = ng;
+ DEVICE_LINEAR_GET(storage, offset+3*hsz) = hx;
+ DEVICE_LINEAR_GET(storage, offset+4*hsz) = hn + b2n;
+
#else
float rg, ig, ng;
- rg = H2F(*ir) + H2F(*hr) + H2F(b1r) + H2F(b2r);
- ig = H2F(*ii) + H2F(*hi) + H2F(b1i) + H2F(b2i);
+ rg = H2F(ir) + H2F(hr) + H2F(b1r) + H2F(b2r);
+ ig = H2F(ii) + H2F(hi) + H2F(b1i) + H2F(b2i);
TensorSigmoidOp<float>()(&rg, &rg);
TensorSigmoidOp<float>()(&ig, &ig);
- ng = H2F(*in) + H2F(b1n) + rg*( H2F(hn)+H2F(b2n) );
+ ng = H2F(in) + H2F(b1n) + rg*( H2F(hn)+H2F(b2n) );
ng = THCNumerics<float>::tanh(ng);
*hy = F2H( ng + ig * ( H2F(hx)-ng ) );
//SAVE FOR BACKWARDS
- *ir = F2H(rg);
- *ii = F2H(ig);
- *in = F2H(ng);
- *hr = hx;
- *hi = F2H( H2F(hn) + H2F(b2n) );
-
+ DEVICE_LINEAR_GET(storage, offset+0*hsz) = F2H(rg);
+ DEVICE_LINEAR_GET(storage, offset+1*hsz) = F2H(ig);
+ DEVICE_LINEAR_GET(storage, offset+2*hsz) = F2H(ng);
+ DEVICE_LINEAR_GET(storage, offset+3*hsz) = hx;
+ DEVICE_LINEAR_GET(storage, offset+4*hsz) = F2H(H2F(hn) + H2F(b2n));
#endif
}
}
@@ -184,63 +184,61 @@ template <typename T, typename IndexType, int Dims>
__launch_bounds__(32 * 16, 4)
#endif
__global__ void
-THNN_(GRUBackward)(TensorInfo<T, IndexType> input,
- TensorInfo<T, IndexType> hidden,
- TensorInfo<T, IndexType> gradoutput,
- TensorInfo<T, IndexType> gradinput,
+THNN_(GRUBackward)(TensorInfo<T, IndexType> gradInInput,
+ TensorInfo<T, IndexType> gradInHidden,
+ TensorInfo<T, IndexType> gradOutput,
+ TensorInfo<T, IndexType> gradInputHx,
+ TensorInfo<T, IndexType> storage,
IndexType hsz,
IndexType totalElements)
{
for (IndexType linearIndex = blockIdx.x * blockDim.x + threadIdx.x;
linearIndex < totalElements;
linearIndex += gridDim.x * blockDim.x) {
- IndexType offset = (linearIndex/hsz)*3*hsz+linearIndex%hsz;;
+ IndexType offset = (linearIndex/hsz)*5*hsz+linearIndex%hsz;
- //will return input grads here
- T* rg = &DEVICE_LINEAR_GET(input, offset+0*hsz);
- T* ig = &DEVICE_LINEAR_GET(input, offset+1*hsz);
- T* ng = &DEVICE_LINEAR_GET(input, offset+2*hsz);
- //will return hidden grads here
- T* hx = &DEVICE_LINEAR_GET(hidden, offset+0*hsz);
- T* hn = &DEVICE_LINEAR_GET(hidden, offset+1*hsz);
- T* oghn=&DEVICE_LINEAR_GET(hidden, offset+2*hsz);
+ T rg = DEVICE_LINEAR_GET(storage, offset+0*hsz);
+ T ig = DEVICE_LINEAR_GET(storage, offset+1*hsz);
+ T ng = DEVICE_LINEAR_GET(storage, offset+2*hsz);
+ T hx = DEVICE_LINEAR_GET(storage, offset+3*hsz);
+ T hn = DEVICE_LINEAR_GET(storage, offset+4*hsz);
- T* gi = &DEVICE_LINEAR_GET(gradinput, linearIndex);
+ T go = DEVICE_LINEAR_GET(gradOutput, linearIndex);
- T* go = &DEVICE_LINEAR_GET(gradoutput, linearIndex);
+ offset = (linearIndex/hsz)*3*hsz+linearIndex%hsz;
#ifndef THC_REAL_IS_HALF
- T gig = (*go)*(*hx-*ng)*( 1-(*ig) )*(*ig);
- T ghx = (*go)*(*ig);
- T gin = (*go)*(1-*ig)*( 1-(*ng)*(*ng) );
- T ghn = (gin) * (*rg);
- T grg = (gin)*(*hn)*( 1-(*rg) )*(*rg);
+ T gig = go*(hx-ng)*(1-ig)*(ig);
+ T ghx = go*(ig);
+ T gin = go*(1-ig)*(1-ng*ng);
+ T ghn = gin *rg;
+ T grg = gin*hn*(1-rg)*rg;
- *gi = ghx;
+ DEVICE_LINEAR_GET(gradInputHx, linearIndex) = ghx;
- *rg = grg;
- *ig = gig;
- *ng = gin;
+ DEVICE_LINEAR_GET(gradInInput, offset+0*hsz) = grg;
+ DEVICE_LINEAR_GET(gradInInput, offset+1*hsz) = gig;
+ DEVICE_LINEAR_GET(gradInInput, offset+2*hsz) = gin;
- *hx = grg;
- *hn = gig;
- *oghn = ghn;
+ DEVICE_LINEAR_GET(gradInHidden, offset+0*hsz) = grg;
+ DEVICE_LINEAR_GET(gradInHidden, offset+1*hsz) = gig;
+ DEVICE_LINEAR_GET(gradInHidden, offset+2*hsz) = ghn;
#else
- float gig = H2F(*go)*( H2F(*hx)-H2F(*ng) )*( 1-H2F(*ig) )*H2F(*ig);
- float ghx = H2F(*go)*H2F(*ig);
- float gin = H2F(*go)*( 1-H2F(*ig) )*( 1-H2F(*ng)*H2F(*ng) );
- float ghn = H2F(gin) * H2F(*rg);
- float grg = H2F(gin)*H2F(*hn)*( 1-H2F(*rg) )*H2F(*rg);
+ float gig = H2F(go)*( H2F(hx)-H2F(ng) )*( 1-H2F(ig) )*H2F(ig);
+ float ghx = H2F(go)*H2F(ig);
+ float gin = H2F(go)*( 1-H2F(ig) )*( 1-H2F(ng)*H2F(ng) );
+ float ghn = H2F(gin) * H2F(rg);
+ float grg = H2F(gin)*H2F(hn)*( 1-H2F(rg) )*H2F(rg);
- *gi = F2H(ghx);
+ DEVICE_LINEAR_GET(gradInInput, offset+0*hsz) = F2H(grg);
+ DEVICE_LINEAR_GET(gradInInput, offset+1*hsz) = F2H(gig);
+ DEVICE_LINEAR_GET(gradInInput, offset+2*hsz) = F2H(gin);
- *rg = F2H(grg);
- *ig = F2H(gig);
- *ng = F2H(gin);
+ DEVICE_LINEAR_GET(gradInHidden, offset+0*hsz) = F2H(grg);
+ DEVICE_LINEAR_GET(gradInHidden, offset+1*hsz) = F2H(gig);
+ DEVICE_LINEAR_GET(gradInHidden, offset+2*hsz) = F2H(ghn);
+ DEVICE_LINEAR_GET(gradInputHx, linearIndex) = F2H(ghx);
- *hx = F2H(grg);
- *hn = F2H(gig);
- *oghn = F2H(ghn);
#endif
}
}
@@ -364,13 +362,13 @@ template <typename T, typename IndexType, int Dims>
__launch_bounds__(32 * 16, 4)
#endif
__global__ void
- THNN_(LSTMBackward)(TensorInfo<T, IndexType> input,
- TensorInfo<T, IndexType> hidden,
+ THNN_(LSTMBackward)(TensorInfo<T, IndexType> storage,
+ TensorInfo<T, IndexType> gradInGates,
TensorInfo<T, IndexType> _cx,
TensorInfo<T, IndexType> _cy,
TensorInfo<T, IndexType> gradoutput,
TensorInfo<T, IndexType> gradoutputcell,
- TensorInfo<T, IndexType> gradinput,
+ TensorInfo<T, IndexType> gradInputCx,
IndexType hsz,
IndexType totalElements)
{
@@ -379,21 +377,21 @@ __global__ void
linearIndex += gridDim.x * blockDim.x) {
IndexType offset = (linearIndex/hsz)*4*hsz+linearIndex%hsz;
- T ig = DEVICE_LINEAR_GET(input, offset+0*hsz);
- T fg = DEVICE_LINEAR_GET(input, offset+1*hsz);
- T cg = DEVICE_LINEAR_GET(input, offset+2*hsz);
- T og = DEVICE_LINEAR_GET(input, offset+3*hsz);
+ T ig = DEVICE_LINEAR_GET(storage, offset+0*hsz);
+ T fg = DEVICE_LINEAR_GET(storage, offset+1*hsz);
+ T cg = DEVICE_LINEAR_GET(storage, offset+2*hsz);
+ T og = DEVICE_LINEAR_GET(storage, offset+3*hsz);
- T* ih = &DEVICE_LINEAR_GET(hidden, offset+0*hsz);
- T* fh = &DEVICE_LINEAR_GET(hidden, offset+1*hsz);
- T* ch = &DEVICE_LINEAR_GET(hidden, offset+2*hsz);
- T* oh = &DEVICE_LINEAR_GET(hidden, offset+3*hsz);
+ T* ih = &DEVICE_LINEAR_GET(gradInGates, offset+0*hsz);
+ T* fh = &DEVICE_LINEAR_GET(gradInGates, offset+1*hsz);
+ T* ch = &DEVICE_LINEAR_GET(gradInGates, offset+2*hsz);
+ T* oh = &DEVICE_LINEAR_GET(gradInGates, offset+3*hsz);
//will return hidden grads here
T cx = DEVICE_LINEAR_GET(_cx, linearIndex);
T cy = DEVICE_LINEAR_GET(_cy, linearIndex);
- T* gi = &DEVICE_LINEAR_GET(gradinput, linearIndex);
+ T* gi = &DEVICE_LINEAR_GET(gradInputCx, linearIndex);
T go = DEVICE_LINEAR_GET(gradoutput, linearIndex);
T goc= DEVICE_LINEAR_GET(gradoutputcell, linearIndex);
@@ -474,19 +472,20 @@ __global__ void
#define LSTM_BACKWARD(ITYPE, DIM) THNN_(LSTMBackward) \
<DATATYPE, ITYPE, DIM> \
<<<grid, block, 0, THCState_getCurrentStream(state)>>> \
- (inputI, hiddenI, cxI, cyI, \
- gradoutI, gradoutcI, gradinI, \
+ (storageI, gradingatesI, cxI, cyI, \
+ gradoutI, gradoutcI, gradincxI, \
hid_size, totalElements);
#define GRU_FORWARD(ITYPE, DIM) THNN_(GRUForward)<DATATYPE, ITYPE, DIM> \
<<<grid, block, 0, THCState_getCurrentStream(state)>>> \
- (inputI, hiddenI, bias1I, bias2I, hxI, hyI, \
+ (inputI, hiddenI, bias1I, bias2I, hxI, hyI, storageI, \
hid_size, totalElements);
#define GRU_BACKWARD(ITYPE, DIM) THNN_(GRUBackward) \
<DATATYPE, ITYPE, DIM> \
<<<grid, block, 0, THCState_getCurrentStream(state)>>> \
- (inputI, hiddenI, gradoutI, gradinI, hid_size, totalElements);
+ (gradininputI, gradinhiddenI, gradoutI, gradinhxI, storageI, \
+ hid_size, totalElements);
// ************ END Create actual function calls ************ //
@@ -602,17 +601,17 @@ void THNN_(LSTMFused_updateOutput)(
template<typename INDTYPE>
void THNN_(LSTM_back_ind_wrap)(
THCState *state,
- THCTensor *input,
- THCTensor *hidden,
+ THCTensor *storage,
+ THCTensor *gradInGates,
THCTensor *cx,
THCTensor *cy,
THCTensor *gradOutput,
THCTensor *gradOutputCell,
- THCTensor *gradInput)
+ THCTensor *gradInputCx)
{
int maxDim = THNN_(minIndexType)
- (state, 7, input, hidden, cx, cy,
- gradOutput, gradOutputCell, gradInput);
+ (state, 7, storage, gradInGates, cx, cy,
+ gradOutput, gradOutputCell, gradInputCx);
ptrdiff_t totalElements = TensorUtils<THCTensor>::getNumElements(state, gradOutput);
const dim3 block = getApplyBlock();
@@ -620,10 +619,10 @@ void THNN_(LSTM_back_ind_wrap)(
THAssertMsg(getApplyGrid(state, totalElements, grid),
"Could not get grid size for pointwise apply");
- TensorInfo<DATATYPE, INDTYPE> inputI =
- getTensorInfo<THCTensor, INDTYPE>(state, input);
- TensorInfo<DATATYPE, INDTYPE> hiddenI =
- getTensorInfo<THCTensor, INDTYPE>(state, hidden);
+ TensorInfo<DATATYPE, INDTYPE> storageI =
+ getTensorInfo<THCTensor, INDTYPE>(state, storage);
+ TensorInfo<DATATYPE, INDTYPE> gradingatesI =
+ getTensorInfo<THCTensor, INDTYPE>(state, gradInGates);
TensorInfo<DATATYPE, INDTYPE> cxI =
getTensorInfo<THCTensor, INDTYPE>(state, cx);
TensorInfo<DATATYPE, INDTYPE> cyI =
@@ -632,19 +631,19 @@ void THNN_(LSTM_back_ind_wrap)(
getTensorInfo<THCTensor, INDTYPE>(state, gradOutput);
TensorInfo<DATATYPE, INDTYPE> gradoutcI =
getTensorInfo<THCTensor, INDTYPE>(state, gradOutputCell);
- TensorInfo<DATATYPE, INDTYPE> gradinI =
- getTensorInfo<THCTensor, INDTYPE>(state, gradInput);
+ TensorInfo<DATATYPE, INDTYPE> gradincxI =
+ getTensorInfo<THCTensor, INDTYPE>(state, gradInputCx);
INDTYPE hid_size = gradoutI.sizes[gradoutI.dims-1];
if(maxDim == -2){
- inputI.collapseDims();
- hiddenI.collapseDims();
+ storageI.collapseDims();
+ gradingatesI.collapseDims();
cxI.collapseDims();
cyI.collapseDims();
gradoutI.collapseDims();
gradoutcI.collapseDims();
- gradinI.collapseDims();
+ gradincxI.collapseDims();
}
FILL_DIM(INDTYPE, maxDim, LSTM_BACKWARD);
@@ -652,33 +651,33 @@ void THNN_(LSTM_back_ind_wrap)(
void THNN_(LSTMFused_updateGradInput)(
THCState *state,
- THCTensor *input,
- THCTensor *hidden,
+ THCTensor *storage,
+ THCTensor *gradInGates,
THCTensor *cx,
THCTensor *cy,
THCTensor *gradOutput,
THCTensor *gradOutputCell,
- THCTensor *gradInput)
+ THCTensor *gradInputCx)
{
- THCTensor_(resizeAs)(state, gradInput, gradOutput);
- THCUNN_assertSameGPU(state, 7, input, hidden, cx, cy,
- gradOutput, gradOutputCell, gradInput);
+ THCTensor_(resizeAs)(state, gradInputCx, gradOutput);
+ THCUNN_assertSameGPU(state, 7, storage, gradInGates, cx, cy,
+ gradOutput, gradOutputCell, gradInputCx);
THNN_(FusedRNNAssertSizes)
- (state, 4, 7, input, hidden, cx, cy,
- gradOutput, gradOutputCell, gradInput);
+ (state, 4, 7, storage, gradInGates, cx, cy,
+ gradOutput, gradOutputCell, gradInputCx);
bool canUse32bi = THNN_(canUse32BitIndexMath)
- (state, 7, input, hidden, cx, cy,
- gradOutput, gradOutputCell, gradInput);
+ (state, 7, storage, gradInGates, cx, cy,
+ gradOutput, gradOutputCell, gradInputCx);
if(canUse32bi){
THNN_(LSTM_back_ind_wrap)<unsigned int>
- (state, input, hidden, cx, cy,
- gradOutput, gradOutputCell, gradInput);
+ (state, storage, gradInGates, cx, cy,
+ gradOutput, gradOutputCell, gradInputCx);
}else{
THNN_(LSTM_back_ind_wrap)<unsigned long>
- (state, input, hidden, cx, cy,
- gradOutput, gradOutputCell, gradInput);
+ (state, storage, gradInGates, cx, cy,
+ gradOutput, gradOutputCell, gradInputCx);
}
THCudaCheck(cudaGetLastError());
}
@@ -691,21 +690,22 @@ void THNN_(GRU_forw_ind_wrap)(
THCTensor *bias1,
THCTensor *bias2,
THCTensor *hx,
- THCTensor *hy)
+ THCTensor *hy,
+ THCTensor *storage)
{
bool has_bias = (bias1!=NULL);
int maxDim;
if(has_bias){
THCUNN_assertSameGPU
- (state, 6, input, hidden, hx, hy, bias1, bias2);
+ (state, 7, input, hidden, hx, hy, bias1, bias2, storage);
maxDim = THNN_(minIndexType)
- (state, 6, input, hidden, hx, hy, bias1, bias2);
+ (state, 7, input, hidden, hx, hy, bias1, bias2, storage);
}else{
THCUNN_assertSameGPU
- (state, 4, input, hidden, hx, hy);
+ (state, 5, input, hidden, hx, hy, storage);
maxDim = THNN_(minIndexType)
- (state, 4, input, hidden, hx, hy);
+ (state, 5, input, hidden, hx, hy, storage);
}
ptrdiff_t totalElements = TensorUtils<THCTensor>::getNumElements(state, hx);
@@ -723,6 +723,8 @@ void THNN_(GRU_forw_ind_wrap)(
getTensorInfo<THCTensor, INDTYPE>(state, hx);
TensorInfo<DATATYPE, INDTYPE> hyI =
getTensorInfo<THCTensor, INDTYPE>(state, hy);
+ TensorInfo<DATATYPE, INDTYPE> storageI =
+ getTensorInfo<THCTensor, INDTYPE>(state, storage);
INDTYPE hid_size = hxI.sizes[hxI.dims-1];
if(has_bias){
@@ -736,7 +738,9 @@ void THNN_(GRU_forw_ind_wrap)(
hiddenI.collapseDims();
hyI.collapseDims();
hxI.collapseDims();
+ storageI.collapseDims();
}
+
INDTYPE zero[1] = {0};
TensorInfo<DATATYPE, INDTYPE> nullinfo =
TensorInfo<DATATYPE, INDTYPE>(NULL, 1, zero, zero);
@@ -763,28 +767,33 @@ void THNN_(GRUFused_updateOutput)(
THCTensor *bias1,
THCTensor *bias2,
THCTensor *hx,
- THCTensor *hy)
+ THCTensor *hy,
+ THCTensor *storage)
{
THCTensor_(resizeAs)(state, hy, hx);
THNN_(FusedRNNAssertSizes)(state, 3, 4, input, hidden, hx, hy);
+ THArgCheck(THCTensor_(nElement)(state, storage) ==
+ THCTensor_(nElement)(state, hx)*5,
+ 3, "Storage tensor for fused kernel was not sized correctly.");
+
bool has_bias = (bias1!=NULL);
bool canUse32bi;
if(has_bias){
canUse32bi = THNN_(canUse32BitIndexMath)
- (state, 6, input, hidden, hx, hy, bias1, bias2);
+ (state, 7, input, hidden, hx, hy, bias1, bias2, storage);
}else{
canUse32bi = THNN_(canUse32BitIndexMath)
- (state, 4, input, hidden, hx, hy);
+ (state, 5, input, hidden, hx, hy, storage);
}
if(canUse32bi){
THNN_(GRU_forw_ind_wrap)<unsigned int>
- (state, input, hidden, bias1, bias2, hx, hy);
+ (state, input, hidden, bias1, bias2, hx, hy, storage);
}else{
THNN_(GRU_forw_ind_wrap)<unsigned long>
- (state, input, hidden, bias1, bias2, hx, hy);
+ (state, input, hidden, bias1, bias2, hx, hy, storage);
}
THCudaCheck(cudaGetLastError());
@@ -793,12 +802,15 @@ void THNN_(GRUFused_updateOutput)(
template<typename INDTYPE>
void THNN_(GRU_back_ind_wrap)(
THCState *state,
- THCTensor *input,
- THCTensor *hidden,
+ THCTensor *gradInInput,
+ THCTensor *gradInHidden,
THCTensor *gradOutput,
- THCTensor *gradInput)
+ THCTensor *gradInputHx,
+ THCTensor *storage)
{
- int maxDim = THNN_(minIndexType)(state, 4, input, hidden, gradOutput, gradInput);
+
+ int maxDim = THNN_(minIndexType)(state, 5, gradInInput, gradInHidden, gradOutput,
+ gradInputHx, storage);
ptrdiff_t totalElements = TensorUtils<THCTensor>::getNumElements(state, gradOutput);
const dim3 block = getApplyBlock();
@@ -806,43 +818,48 @@ void THNN_(GRU_back_ind_wrap)(
THAssertMsg(getApplyGrid(state, totalElements, grid),
"Could not get grid size for pointwise apply");
- TensorInfo<DATATYPE, INDTYPE> inputI =
- getTensorInfo<THCTensor, INDTYPE>(state, input);
- TensorInfo<DATATYPE, INDTYPE> hiddenI =
- getTensorInfo<THCTensor, INDTYPE>(state, hidden);
+ TensorInfo<DATATYPE, INDTYPE> gradininputI =
+ getTensorInfo<THCTensor, INDTYPE>(state, gradInInput);
+ TensorInfo<DATATYPE, INDTYPE> gradinhiddenI =
+ getTensorInfo<THCTensor, INDTYPE>(state, gradInHidden);
TensorInfo<DATATYPE, INDTYPE> gradoutI =
getTensorInfo<THCTensor, INDTYPE>(state, gradOutput);
- TensorInfo<DATATYPE, INDTYPE> gradinI =
- getTensorInfo<THCTensor, INDTYPE>(state, gradInput);
+ TensorInfo<DATATYPE, INDTYPE> gradinhxI =
+ getTensorInfo<THCTensor, INDTYPE>(state, gradInputHx);
+ TensorInfo<DATATYPE, INDTYPE> storageI =
+ getTensorInfo<THCTensor, INDTYPE>(state, storage);
INDTYPE hid_size = gradoutI.sizes[gradoutI.dims-1];
if(maxDim == -2){
- inputI.collapseDims();
- hiddenI.collapseDims();
+ gradininputI.collapseDims();
+ gradinhiddenI.collapseDims();
gradoutI.collapseDims();
- gradinI.collapseDims();
+ gradinhxI.collapseDims();
+ storageI.collapseDims();
}
FILL_DIM(INDTYPE, maxDim, GRU_BACKWARD);
}
void THNN_(GRUFused_updateGradInput)(
THCState *state,
- THCTensor *input,
- THCTensor *hidden,
+ THCTensor *gradInInput,
+ THCTensor *gradInHidden,
THCTensor *gradOutput,
- THCTensor *gradInput)
+ THCTensor *gradInputHx,
+ THCTensor *storage)
{
- THCTensor_(resizeAs)(state, gradInput, gradOutput);
- THCUNN_assertSameGPU(state, 4, input, hidden, gradOutput, gradInput);
- THNN_(FusedRNNAssertSizes)(state, 3, 4, input, hidden, gradOutput, gradInput);
- bool canUse32bi = THNN_(canUse32BitIndexMath)(state, 4, input, hidden, gradOutput, gradInput);
+ THCTensor_(resizeAs)(state, gradInputHx, gradOutput);
+ THCUNN_assertSameGPU(state, 5, gradInInput, gradInHidden, gradOutput, gradInputHx, storage);
+ THNN_(FusedRNNAssertSizes)(state, 3, 4, gradInInput, gradInHidden, gradOutput, gradInputHx);
+ bool canUse32bi = THNN_(canUse32BitIndexMath)(state, 5, gradInInput, gradInHidden,
+ gradOutput, gradInputHx, storage);
if(canUse32bi){
THNN_(GRU_back_ind_wrap)<unsigned int>
- (state, input, hidden, gradOutput, gradInput);
+ (state, gradInInput, gradInHidden, gradOutput, gradInputHx, storage);
}else{
THNN_(GRU_back_ind_wrap)<unsigned long>
- (state, input, hidden, gradOutput, gradInput);
+ (state, gradInInput, gradInHidden, gradOutput, gradInputHx, storage);
}
THCudaCheck(cudaGetLastError());
diff --git a/lib/THCUNN/generic/THCUNN.h b/lib/THCUNN/generic/THCUNN.h
index 72ea749..6fc545c 100644
--- a/lib/THCUNN/generic/THCUNN.h
+++ b/lib/THCUNN/generic/THCUNN.h
@@ -175,12 +175,13 @@ TH_API void THNN_(GRUFused_updateOutput)(
THCTensor *bias1, // [OPTIONAL]
THCTensor *bias2, // [OPTIONAL]
THCTensor *hx,
- THCTensor *hy);
+ THCTensor *hy,
+ THCTensor *storage);
TH_API void THNN_(GRUFused_updateGradInput)(
THCState *state,
- THCTensor *input,
- THCTensor *hidden,
+ THCTensor *gradInInput,
+ THCTensor *gradInHidden,
THCTensor *gradOutput,
THCTensor *gradInput);
@@ -196,13 +197,13 @@ TH_API void THNN_(LSTMFused_updateOutput)(
TH_API void THNN_(LSTMFused_updateGradInput)(
THCState *state,
- THCTensor *input,
- THCTensor *hidden,
+ THCTensor *storage,
+ THCTensor *gradInGates,
THCTensor *prevC,
THCTensor *cy,
THCTensor *gradOutput,
THCTensor *gradOutputCell,
- THCTensor *gradInput);
+ THCTensor *gradInputCx);
TH_API void THNN_(LogSigmoid_updateOutput)(
THCState *state,