author    Christian Sarofeen <csarofeen@nvidia.com>  2017-06-07 06:20:14 +0300
committer Soumith Chintala <soumith@gmail.com>       2017-06-07 06:20:14 +0300
commit    8e3364c64aa85e7ab4970c8d5f6a4f18d0c3eeae
tree      758c4db20b040f91947d2b6db5c52f378ae457e4
parent    8c6df2a424680dfd3fa1109a2f8e3b62ac41680d
Remove clone in fused rnn
-rw-r--r--  lib/THCUNN/generic/FusedRNNKernel.cu | 297
-rw-r--r--  lib/THCUNN/generic/THCUNN.h          |  13
2 files changed, 164 insertions, 146 deletions
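Before this change, the fused GRU forward kernel saved the intermediates needed for backward by overwriting its Input/Hidden gate buffers in place, which forced callers to clone those tensors first. The hunks below instead route the saved values through a dedicated storage tensor (five values per hidden element, hence the nElement(hx)*5 check added in GRUFused_updateOutput), and the backward kernels read them back from storage; the gates can then be read by value (T instead of T*) since nothing is written back into Input or Hidden. For reference, here is the cell update the forward kernel fuses, as a minimal single-element sketch in plain C; the helper names (gru_cell, sigmoidf, save) are illustrative only and not part of this commit:

#include <math.h>

/* Hypothetical single-element model of THNN_(GRUForward); the gate arguments
   are assumed to already hold the input/hidden matmul results. */
static float sigmoidf(float x) { return 1.0f / (1.0f + expf(-x)); }

static float gru_cell(float ir, float ii, float in,    /* input gates      */
                      float hr, float hi, float hn,    /* hidden gates     */
                      float b1r, float b1i, float b1n, /* input biases     */
                      float b2r, float b2i, float b2n, /* hidden biases    */
                      float hx,                        /* previous hidden  */
                      float save[5])                   /* -> storage       */
{
  float rg = sigmoidf(ir + hr + b1r + b2r);      /* reset gate       */
  float ig = sigmoidf(ii + hi + b1i + b2i);      /* update gate      */
  float ng = tanhf(in + b1n + rg * (hn + b2n));  /* candidate state  */
  float hy = ng + ig * (hx - ng);                /* new hidden state */
  /* the five intermediates the backward kernel reads from storage */
  save[0] = rg; save[1] = ig; save[2] = ng; save[3] = hx; save[4] = hn + b2n;
  return hy;
}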
diff --git a/lib/THCUNN/generic/FusedRNNKernel.cu b/lib/THCUNN/generic/FusedRNNKernel.cu
index 6aeba1e..91f11e8 100644
--- a/lib/THCUNN/generic/FusedRNNKernel.cu
+++ b/lib/THCUNN/generic/FusedRNNKernel.cu
@@ -85,12 +85,13 @@ template <typename T, typename IndexType, int Dims>
 __launch_bounds__(32 * 16, 4)
 #endif
 __global__ void
-  THNN_(GRUForward)(TensorInfo<T, IndexType> Input,
+THNN_(GRUForward)(TensorInfo<T, IndexType> Input,
                   TensorInfo<T, IndexType> Hidden,
                   TensorInfo<T, IndexType> Bias1,
                   TensorInfo<T, IndexType> Bias2,
                   TensorInfo<T, IndexType> _hx,
                   TensorInfo<T, IndexType> _hy,
+                  TensorInfo<T, IndexType> storage,
                   IndexType hsz,
                   IndexType totalElements)
 {
@@ -101,16 +102,14 @@ __global__ void
 
     IndexType offset = (linearIndex/hsz)*3*hsz+linearIndex%hsz;
 
-    T* ir = &DEVICE_LINEAR_GET(Input, offset+0*hsz);
-    T* ii = &DEVICE_LINEAR_GET(Input, offset+1*hsz);
-    T* in = &DEVICE_LINEAR_GET(Input, offset+2*hsz);
-
-    T* hr = &DEVICE_LINEAR_GET(Hidden,offset+0*hsz);
-    T* hi = &DEVICE_LINEAR_GET(Hidden,offset+1*hsz);
+    T ir = DEVICE_LINEAR_GET(Input, offset+0*hsz);
+    T ii = DEVICE_LINEAR_GET(Input, offset+1*hsz);
+    T in = DEVICE_LINEAR_GET(Input, offset+2*hsz);
+    T hr = DEVICE_LINEAR_GET(Hidden,offset+0*hsz);
+    T hi = DEVICE_LINEAR_GET(Hidden,offset+1*hsz);
 
     T hn = DEVICE_LINEAR_GET(Hidden, offset+2*hsz);
     T hx = DEVICE_LINEAR_GET(_hx, linearIndex);
-
     T* hy = &DEVICE_LINEAR_GET(_hy, linearIndex);
 
     bool has_bias = (Bias1.data != NULL);
@@ -136,45 +135,46 @@ __global__ void
     }
 
+    offset = (linearIndex/hsz)*5*hsz+linearIndex%hsz;
 
 #ifndef THC_REAL_IS_HALF
     T rg, ig, ng;
 
-    rg = *ir + *hr + b1r + b2r;
-    ig = *ii + *hi + b1i + b2i;
+    rg = ir + hr + b1r + b2r;
+    ig = ii + hi + b1i + b2i;
 
     TensorSigmoidOp<real>()(&rg, &rg);
     TensorSigmoidOp<real>()(&ig, &ig);
-    ng = *in + b1n + rg * (hn + b2n);
+    ng = in + b1n + rg * (hn + b2n);
     ng = THCNumerics<T>::tanh(ng);
     *hy = ng + ig * (hx - ng);
 
     //SAVE FOR BACKWARDS
-    *ir = rg;
-    *ii = ig;
-    *in = ng;
-    *hr = hx;
-    *hi = hn + b2n;
+    DEVICE_LINEAR_GET(storage, offset+0*hsz) = rg;
+    DEVICE_LINEAR_GET(storage, offset+1*hsz) = ig;
+    DEVICE_LINEAR_GET(storage, offset+2*hsz) = ng;
+    DEVICE_LINEAR_GET(storage, offset+3*hsz) = hx;
+    DEVICE_LINEAR_GET(storage, offset+4*hsz) = hn + b2n;
+
 #else
     float rg, ig, ng;
 
-    rg = H2F(*ir) + H2F(*hr) + H2F(b1r) + H2F(b2r);
-    ig = H2F(*ii) + H2F(*hi) + H2F(b1i) + H2F(b2i);
+    rg = H2F(ir) + H2F(hr) + H2F(b1r) + H2F(b2r);
+    ig = H2F(ii) + H2F(hi) + H2F(b1i) + H2F(b2i);
 
     TensorSigmoidOp<float>()(&rg, &rg);
     TensorSigmoidOp<float>()(&ig, &ig);
-    ng = H2F(*in) + H2F(b1n) + rg*( H2F(hn)+H2F(b2n) );
+    ng = H2F(in) + H2F(b1n) + rg*( H2F(hn)+H2F(b2n) );
    ng = THCNumerics<float>::tanh(ng);
     *hy = F2H( ng + ig * ( H2F(hx)-ng ) );
 
     //SAVE FOR BACKWARDS
-    *ir = F2H(rg);
-    *ii = F2H(ig);
-    *in = F2H(ng);
-    *hr = hx;
-    *hi = F2H( H2F(hn) + H2F(b2n) );
-
+    DEVICE_LINEAR_GET(storage, offset+0*hsz) = F2H(rg);
+    DEVICE_LINEAR_GET(storage, offset+1*hsz) = F2H(ig);
+    DEVICE_LINEAR_GET(storage, offset+2*hsz) = F2H(ng);
+    DEVICE_LINEAR_GET(storage, offset+3*hsz) = hx;
+    DEVICE_LINEAR_GET(storage, offset+4*hsz) = F2H(H2F(hn) + H2F(b2n));
 #endif
   }
 }
@@ -184,63 +184,61 @@ template <typename T, typename IndexType, int Dims>
 __launch_bounds__(32 * 16, 4)
 #endif
 __global__ void
-THNN_(GRUBackward)(TensorInfo<T, IndexType> input,
-                   TensorInfo<T, IndexType> hidden,
-                   TensorInfo<T, IndexType> gradoutput,
-                   TensorInfo<T, IndexType> gradinput,
+THNN_(GRUBackward)(TensorInfo<T, IndexType> gradInInput,
+                   TensorInfo<T, IndexType> gradInHidden,
+                   TensorInfo<T, IndexType> gradOutput,
+                   TensorInfo<T, IndexType> gradInputHx,
+                   TensorInfo<T, IndexType> storage,
                    IndexType hsz,
                    IndexType totalElements)
 {
   for (IndexType linearIndex = blockIdx.x * blockDim.x + threadIdx.x;
        linearIndex < totalElements;
       linearIndex += gridDim.x * blockDim.x)
   {
-    IndexType offset = (linearIndex/hsz)*3*hsz+linearIndex%hsz;;
+    IndexType offset = (linearIndex/hsz)*5*hsz+linearIndex%hsz;
 
-    //will return input grads here
-    T* rg = &DEVICE_LINEAR_GET(input, offset+0*hsz);
-    T* ig = &DEVICE_LINEAR_GET(input, offset+1*hsz);
-    T* ng = &DEVICE_LINEAR_GET(input, offset+2*hsz);
-    //will return hidden grads here
-    T* hx = &DEVICE_LINEAR_GET(hidden, offset+0*hsz);
-    T* hn = &DEVICE_LINEAR_GET(hidden, offset+1*hsz);
-    T* oghn=&DEVICE_LINEAR_GET(hidden, offset+2*hsz);
+    T rg = DEVICE_LINEAR_GET(storage, offset+0*hsz);
+    T ig = DEVICE_LINEAR_GET(storage, offset+1*hsz);
+    T ng = DEVICE_LINEAR_GET(storage, offset+2*hsz);
+    T hx = DEVICE_LINEAR_GET(storage, offset+3*hsz);
+    T hn = DEVICE_LINEAR_GET(storage, offset+4*hsz);
 
-    T* gi = &DEVICE_LINEAR_GET(gradinput, linearIndex);
+    T go = DEVICE_LINEAR_GET(gradOutput, linearIndex);
 
-    T* go = &DEVICE_LINEAR_GET(gradoutput, linearIndex);
+    offset = (linearIndex/hsz)*3*hsz+linearIndex%hsz;
 
 #ifndef THC_REAL_IS_HALF
-    T gig = (*go)*(*hx-*ng)*( 1-(*ig) )*(*ig);
-    T ghx = (*go)*(*ig);
-    T gin = (*go)*(1-*ig)*( 1-(*ng)*(*ng) );
-    T ghn = (gin) * (*rg);
-    T grg = (gin)*(*hn)*( 1-(*rg) )*(*rg);
+    T gig = go*(hx-ng)*(1-ig)*(ig);
+    T ghx = go*(ig);
+    T gin = go*(1-ig)*(1-ng*ng);
+    T ghn = gin *rg;
+    T grg = gin*hn*(1-rg)*rg;
 
-    *gi = ghx;
+    DEVICE_LINEAR_GET(gradInputHx, linearIndex) = ghx;
 
-    *rg = grg;
-    *ig = gig;
-    *ng = gin;
+    DEVICE_LINEAR_GET(gradInInput, offset+0*hsz) = grg;
+    DEVICE_LINEAR_GET(gradInInput, offset+1*hsz) = gig;
+    DEVICE_LINEAR_GET(gradInInput, offset+2*hsz) = gin;
 
-    *hx = grg;
-    *hn = gig;
-    *oghn = ghn;
+    DEVICE_LINEAR_GET(gradInHidden, offset+0*hsz) = grg;
+    DEVICE_LINEAR_GET(gradInHidden, offset+1*hsz) = gig;
+    DEVICE_LINEAR_GET(gradInHidden, offset+2*hsz) = ghn;
 
 #else
-    float gig = H2F(*go)*( H2F(*hx)-H2F(*ng) )*( 1-H2F(*ig) )*H2F(*ig);
-    float ghx = H2F(*go)*H2F(*ig);
-    float gin = H2F(*go)*( 1-H2F(*ig) )*( 1-H2F(*ng)*H2F(*ng) );
-    float ghn = H2F(gin) * H2F(*rg);
-    float grg = H2F(gin)*H2F(*hn)*( 1-H2F(*rg) )*H2F(*rg);
+    float gig = H2F(go)*( H2F(hx)-H2F(ng) )*( 1-H2F(ig) )*H2F(ig);
+    float ghx = H2F(go)*H2F(ig);
+    float gin = H2F(go)*( 1-H2F(ig) )*( 1-H2F(ng)*H2F(ng) );
+    float ghn = H2F(gin) * H2F(rg);
+    float grg = H2F(gin)*H2F(hn)*( 1-H2F(rg) )*H2F(rg);
 
-    *gi = F2H(ghx);
+    DEVICE_LINEAR_GET(gradInInput, offset+0*hsz) = F2H(grg);
+    DEVICE_LINEAR_GET(gradInInput, offset+1*hsz) = F2H(gig);
+    DEVICE_LINEAR_GET(gradInInput, offset+2*hsz) = F2H(gin);
 
-    *rg = F2H(grg);
-    *ig = F2H(gig);
-    *ng = F2H(gin);
+    DEVICE_LINEAR_GET(gradInHidden, offset+0*hsz) = F2H(grg);
+    DEVICE_LINEAR_GET(gradInHidden, offset+1*hsz) = F2H(gig);
+    DEVICE_LINEAR_GET(gradInHidden, offset+2*hsz) = F2H(ghn);
+    DEVICE_LINEAR_GET(gradInputHx, linearIndex) = F2H(ghx);
 
-    *hx = F2H(grg);
-    *hn = F2H(gig);
-    *oghn = F2H(ghn);
 #endif
   }
 }
@@ -364,13 +362,13 @@ template <typename T, typename IndexType, int Dims>
 __launch_bounds__(32 * 16, 4)
 #endif
 __global__ void
-  THNN_(LSTMBackward)(TensorInfo<T, IndexType> input,
-                      TensorInfo<T, IndexType> hidden,
+  THNN_(LSTMBackward)(TensorInfo<T, IndexType> storage,
+                      TensorInfo<T, IndexType> gradInGates,
                       TensorInfo<T, IndexType> _cx,
                       TensorInfo<T, IndexType> _cy,
                       TensorInfo<T, IndexType> gradoutput,
                       TensorInfo<T, IndexType> gradoutputcell,
-                      TensorInfo<T, IndexType> gradinput,
+                      TensorInfo<T, IndexType> gradInputCx,
                       IndexType hsz,
                       IndexType totalElements)
 {
@@ -379,21 +377,21 @@ __global__ void
        linearIndex += gridDim.x * blockDim.x)
   {
     IndexType offset = (linearIndex/hsz)*4*hsz+linearIndex%hsz;
 
-    T ig = DEVICE_LINEAR_GET(input, offset+0*hsz);
-    T fg = DEVICE_LINEAR_GET(input, offset+1*hsz);
-    T cg = DEVICE_LINEAR_GET(input, offset+2*hsz);
-    T og = DEVICE_LINEAR_GET(input, offset+3*hsz);
+    T ig = DEVICE_LINEAR_GET(storage, offset+0*hsz);
+    T fg = DEVICE_LINEAR_GET(storage, offset+1*hsz);
+    T cg = DEVICE_LINEAR_GET(storage, offset+2*hsz);
+    T og = DEVICE_LINEAR_GET(storage, offset+3*hsz);
 
-    T* ih = &DEVICE_LINEAR_GET(hidden, offset+0*hsz);
-    T* fh = &DEVICE_LINEAR_GET(hidden, offset+1*hsz);
-    T* ch = &DEVICE_LINEAR_GET(hidden, offset+2*hsz);
-    T* oh = &DEVICE_LINEAR_GET(hidden, offset+3*hsz);
+    T* ih = &DEVICE_LINEAR_GET(gradInGates, offset+0*hsz);
+    T* fh = &DEVICE_LINEAR_GET(gradInGates, offset+1*hsz);
+    T* ch = &DEVICE_LINEAR_GET(gradInGates, offset+2*hsz);
+    T* oh = &DEVICE_LINEAR_GET(gradInGates, offset+3*hsz);
 
     //will return hidden grads here
     T cx = DEVICE_LINEAR_GET(_cx, linearIndex);
     T cy = DEVICE_LINEAR_GET(_cy, linearIndex);
 
-    T* gi = &DEVICE_LINEAR_GET(gradinput, linearIndex);
+    T* gi = &DEVICE_LINEAR_GET(gradInputCx, linearIndex);
     T go = DEVICE_LINEAR_GET(gradoutput, linearIndex);
     T goc= DEVICE_LINEAR_GET(gradoutputcell, linearIndex);
@@ -474,19 +472,20 @@ __global__ void
 #define LSTM_BACKWARD(ITYPE, DIM) THNN_(LSTMBackward) \
   <DATATYPE, ITYPE, DIM> \
   <<<grid, block, 0, THCState_getCurrentStream(state)>>> \
-  (inputI, hiddenI, cxI, cyI, \
-   gradoutI, gradoutcI, gradinI, \
+  (storageI, gradingatesI, cxI, cyI, \
+   gradoutI, gradoutcI, gradincxI, \
    hid_size, totalElements);
 
 #define GRU_FORWARD(ITYPE, DIM) THNN_(GRUForward)<DATATYPE, ITYPE, DIM> \
   <<<grid, block, 0, THCState_getCurrentStream(state)>>> \
-  (inputI, hiddenI, bias1I, bias2I, hxI, hyI, \
+  (inputI, hiddenI, bias1I, bias2I, hxI, hyI, storageI, \
   hid_size, totalElements);
 
 #define GRU_BACKWARD(ITYPE, DIM) THNN_(GRUBackward) \
   <DATATYPE, ITYPE, DIM> \
   <<<grid, block, 0, THCState_getCurrentStream(state)>>> \
-  (inputI, hiddenI, gradoutI, gradinI, hid_size, totalElements);
+  (gradininputI, gradinhiddenI, gradoutI, gradinhxI, storageI, \
+   hid_size, totalElements);
 
 // ************ END Create actual function calls ************ //
@@ -602,17 +601,17 @@ void THNN_(LSTMFused_updateOutput)(
 template<typename INDTYPE>
 void THNN_(LSTM_back_ind_wrap)(
    THCState *state,
-   THCTensor *input,
-   THCTensor *hidden,
+   THCTensor *storage,
+   THCTensor *gradInGates,
    THCTensor *cx,
    THCTensor *cy,
    THCTensor *gradOutput,
    THCTensor *gradOutputCell,
-   THCTensor *gradInput)
+   THCTensor *gradInputCx)
 {
   int maxDim = THNN_(minIndexType)
-    (state, 7, input, hidden, cx, cy,
-     gradOutput, gradOutputCell, gradInput);
+    (state, 7, storage, gradInGates, cx, cy,
+     gradOutput, gradOutputCell, gradInputCx);
   ptrdiff_t totalElements = TensorUtils<THCTensor>::getNumElements(state, gradOutput);
 
   const dim3 block = getApplyBlock();
@@ -620,10 +619,10 @@ void THNN_(LSTM_back_ind_wrap)(
   THAssertMsg(getApplyGrid(state, totalElements, grid),
              "Could not get grid size for pointwise apply");
 
-  TensorInfo<DATATYPE, INDTYPE> inputI =
-    getTensorInfo<THCTensor, INDTYPE>(state, input);
-  TensorInfo<DATATYPE, INDTYPE> hiddenI =
-    getTensorInfo<THCTensor, INDTYPE>(state, hidden);
+  TensorInfo<DATATYPE, INDTYPE> storageI =
+    getTensorInfo<THCTensor, INDTYPE>(state, storage);
+  TensorInfo<DATATYPE, INDTYPE> gradingatesI =
+    getTensorInfo<THCTensor, INDTYPE>(state, gradInGates);
   TensorInfo<DATATYPE, INDTYPE> cxI =
     getTensorInfo<THCTensor, INDTYPE>(state, cx);
   TensorInfo<DATATYPE, INDTYPE> cyI =
@@ -632,19 +631,19 @@ void THNN_(LSTM_back_ind_wrap)(
     getTensorInfo<THCTensor, INDTYPE>(state, gradOutput);
   TensorInfo<DATATYPE, INDTYPE> gradoutcI =
     getTensorInfo<THCTensor, INDTYPE>(state, gradOutputCell);
-  TensorInfo<DATATYPE, INDTYPE> gradinI =
-    getTensorInfo<THCTensor, INDTYPE>(state, gradInput);
+  TensorInfo<DATATYPE, INDTYPE> gradincxI =
+    getTensorInfo<THCTensor, INDTYPE>(state, gradInputCx);
 
   INDTYPE hid_size = gradoutI.sizes[gradoutI.dims-1];
 
   if(maxDim == -2){
-    inputI.collapseDims();
-    hiddenI.collapseDims();
+    storageI.collapseDims();
+    gradingatesI.collapseDims();
     cxI.collapseDims();
     cyI.collapseDims();
     gradoutI.collapseDims();
     gradoutcI.collapseDims();
-    gradinI.collapseDims();
+    gradincxI.collapseDims();
   }
   FILL_DIM(INDTYPE, maxDim, LSTM_BACKWARD);
@@ -652,33 +651,33 @@ void THNN_(LSTMFused_updateGradInput)(
    THCState *state,
-   THCTensor *input,
-   THCTensor *hidden,
+   THCTensor *storage,
+   THCTensor *gradInGates,
    THCTensor *cx,
    THCTensor *cy,
    THCTensor *gradOutput,
    THCTensor *gradOutputCell,
-   THCTensor *gradInput)
+   THCTensor *gradInputCx)
 {
-  THCTensor_(resizeAs)(state, gradInput, gradOutput);
-  THCUNN_assertSameGPU(state, 7, input, hidden, cx, cy,
-                       gradOutput, gradOutputCell, gradInput);
+  THCTensor_(resizeAs)(state, gradInputCx, gradOutput);
+  THCUNN_assertSameGPU(state, 7, storage, gradInGates, cx, cy,
+                       gradOutput, gradOutputCell, gradInputCx);
   THNN_(FusedRNNAssertSizes)
-    (state, 4, 7, input, hidden, cx, cy,
-     gradOutput, gradOutputCell, gradInput);
+    (state, 4, 7, storage, gradInGates, cx, cy,
+     gradOutput, gradOutputCell, gradInputCx);
   bool canUse32bi = THNN_(canUse32BitIndexMath)
-    (state, 7, input, hidden, cx, cy,
-     gradOutput, gradOutputCell, gradInput);
+    (state, 7, storage, gradInGates, cx, cy,
+     gradOutput, gradOutputCell, gradInputCx);
 
   if(canUse32bi){
     THNN_(LSTM_back_ind_wrap)<unsigned int>
-      (state, input, hidden, cx, cy,
-       gradOutput, gradOutputCell, gradInput);
+      (state, storage, gradInGates, cx, cy,
+       gradOutput, gradOutputCell, gradInputCx);
   }else{
     THNN_(LSTM_back_ind_wrap)<unsigned long>
-      (state, input, hidden, cx, cy,
-       gradOutput, gradOutputCell, gradInput);
+      (state, storage, gradInGates, cx, cy,
+       gradOutput, gradOutputCell, gradInputCx);
   }
   THCudaCheck(cudaGetLastError());
 }
@@ -691,21 +690,22 @@ void THNN_(GRU_forw_ind_wrap)(
    THCTensor *bias1,
    THCTensor *bias2,
    THCTensor *hx,
-   THCTensor *hy)
+   THCTensor *hy,
+   THCTensor *storage)
 {
   bool has_bias = (bias1!=NULL);
 
   int maxDim;
   if(has_bias){
     THCUNN_assertSameGPU
-      (state, 6, input, hidden, hx, hy, bias1, bias2);
+      (state, 7, input, hidden, hx, hy, bias1, bias2, storage);
     maxDim = THNN_(minIndexType)
-      (state, 6, input, hidden, hx, hy, bias1, bias2);
+      (state, 7, input, hidden, hx, hy, bias1, bias2, storage);
   }else{
     THCUNN_assertSameGPU
-      (state, 4, input, hidden, hx, hy);
+      (state, 5, input, hidden, hx, hy, storage);
     maxDim = THNN_(minIndexType)
-      (state, 4, input, hidden, hx, hy);
+      (state, 5, input, hidden, hx, hy, storage);
   }
 
   ptrdiff_t totalElements = TensorUtils<THCTensor>::getNumElements(state, hx);
@@ -723,6 +723,8 @@ void THNN_(GRU_forw_ind_wrap)(
     getTensorInfo<THCTensor, INDTYPE>(state, hx);
   TensorInfo<DATATYPE, INDTYPE> hyI =
     getTensorInfo<THCTensor, INDTYPE>(state, hy);
+  TensorInfo<DATATYPE, INDTYPE> storageI =
+    getTensorInfo<THCTensor, INDTYPE>(state, storage);
 
   INDTYPE hid_size = hxI.sizes[hxI.dims-1];
 
   if(has_bias){
@@ -736,7 +738,9 @@ void THNN_(GRU_forw_ind_wrap)(
     hiddenI.collapseDims();
     hyI.collapseDims();
     hxI.collapseDims();
+    storageI.collapseDims();
   }
+
   INDTYPE zero[1] = {0};
   TensorInfo<DATATYPE, INDTYPE> nullinfo =
     TensorInfo<DATATYPE, INDTYPE>(NULL, 1, zero, zero);
@@ -763,28 +767,33 @@ void THNN_(GRUFused_updateOutput)(
    THCTensor *bias1,
    THCTensor *bias2,
    THCTensor *hx,
-   THCTensor *hy)
+   THCTensor *hy,
+   THCTensor *storage)
 {
   THCTensor_(resizeAs)(state, hy, hx);
   THNN_(FusedRNNAssertSizes)(state, 3, 4, input, hidden, hx, hy);
+  THArgCheck(THCTensor_(nElement)(state, storage) ==
+             THCTensor_(nElement)(state, hx)*5,
+             3, "Storage tensor for fused kernel was not sized correctly.");
+
   bool has_bias = (bias1!=NULL);
 
   bool canUse32bi;
   if(has_bias){
     canUse32bi = THNN_(canUse32BitIndexMath)
-      (state, 6, input, hidden, hx, hy, bias1, bias2);
+      (state, 7, input, hidden, hx, hy, bias1, bias2, storage);
   }else{
     canUse32bi = THNN_(canUse32BitIndexMath)
-      (state, 4, input, hidden, hx, hy);
+      (state, 5, input, hidden, hx, hy, storage);
   }
 
   if(canUse32bi){
     THNN_(GRU_forw_ind_wrap)<unsigned int>
-      (state, input, hidden, bias1, bias2, hx, hy);
+      (state, input, hidden, bias1, bias2, hx, hy, storage);
   }else{
     THNN_(GRU_forw_ind_wrap)<unsigned long>
-      (state, input, hidden, bias1, bias2, hx, hy);
+      (state, input, hidden, bias1, bias2, hx, hy, storage);
   }
 
   THCudaCheck(cudaGetLastError());
@@ -793,12 +802,15 @@ void THNN_(GRUFused_updateOutput)(
 template<typename INDTYPE>
 void THNN_(GRU_back_ind_wrap)(
    THCState *state,
-   THCTensor *input,
-   THCTensor *hidden,
+   THCTensor *gradInInput,
+   THCTensor *gradInHidden,
    THCTensor *gradOutput,
-   THCTensor *gradInput)
+   THCTensor *gradInputHx,
+   THCTensor *storage)
 {
-  int maxDim = THNN_(minIndexType)(state, 4, input, hidden, gradOutput, gradInput);
+
+  int maxDim = THNN_(minIndexType)(state, 5, gradInInput, gradInHidden, gradOutput,
+                                   gradInputHx, storage);
   ptrdiff_t totalElements = TensorUtils<THCTensor>::getNumElements(state, gradOutput);
 
   const dim3 block = getApplyBlock();
@@ -806,43 +818,48 @@ void THNN_(GRU_back_ind_wrap)(
   THAssertMsg(getApplyGrid(state, totalElements, grid),
              "Could not get grid size for pointwise apply");
 
-  TensorInfo<DATATYPE, INDTYPE> inputI =
-    getTensorInfo<THCTensor, INDTYPE>(state, input);
-  TensorInfo<DATATYPE, INDTYPE> hiddenI =
-    getTensorInfo<THCTensor, INDTYPE>(state, hidden);
+  TensorInfo<DATATYPE, INDTYPE> gradininputI =
+    getTensorInfo<THCTensor, INDTYPE>(state, gradInInput);
+  TensorInfo<DATATYPE, INDTYPE> gradinhiddenI =
+    getTensorInfo<THCTensor, INDTYPE>(state, gradInHidden);
   TensorInfo<DATATYPE, INDTYPE> gradoutI =
    getTensorInfo<THCTensor, INDTYPE>(state, gradOutput);
-  TensorInfo<DATATYPE, INDTYPE> gradinI =
-    getTensorInfo<THCTensor, INDTYPE>(state, gradInput);
+  TensorInfo<DATATYPE, INDTYPE> gradinhxI =
+    getTensorInfo<THCTensor, INDTYPE>(state, gradInputHx);
+  TensorInfo<DATATYPE, INDTYPE> storageI =
+    getTensorInfo<THCTensor, INDTYPE>(state, storage);
 
   INDTYPE hid_size = gradoutI.sizes[gradoutI.dims-1];
 
   if(maxDim == -2){
-    inputI.collapseDims();
-    hiddenI.collapseDims();
+    gradininputI.collapseDims();
+    gradinhiddenI.collapseDims();
     gradoutI.collapseDims();
-    gradinI.collapseDims();
+    gradinhxI.collapseDims();
+    storageI.collapseDims();
  }
 
   FILL_DIM(INDTYPE, maxDim, GRU_BACKWARD);
 }
 
 void THNN_(GRUFused_updateGradInput)(
    THCState *state,
-   THCTensor *input,
-   THCTensor *hidden,
+   THCTensor *gradInInput,
+   THCTensor *gradInHidden,
    THCTensor *gradOutput,
-   THCTensor *gradInput)
+   THCTensor *gradInputHx,
+   THCTensor *storage)
 {
-  THCTensor_(resizeAs)(state, gradInput, gradOutput);
-  THCUNN_assertSameGPU(state, 4, input, hidden, gradOutput, gradInput);
-  THNN_(FusedRNNAssertSizes)(state, 3, 4, input, hidden, gradOutput, gradInput);
-  bool canUse32bi = THNN_(canUse32BitIndexMath)(state, 4, input, hidden, gradOutput, gradInput);
+  THCTensor_(resizeAs)(state, gradInputHx, gradOutput);
+  THCUNN_assertSameGPU(state, 5, gradInInput, gradInHidden, gradOutput, gradInputHx, storage);
+  THNN_(FusedRNNAssertSizes)(state, 3, 4, gradInInput, gradInHidden, gradOutput, gradInputHx);
+  bool canUse32bi = THNN_(canUse32BitIndexMath)(state, 5, gradInInput, gradInHidden,
+                                                gradOutput, gradInputHx, storage);
 
   if(canUse32bi){
     THNN_(GRU_back_ind_wrap)<unsigned int>
-      (state, input, hidden, gradOutput, gradInput);
+      (state, gradInInput, gradInHidden, gradOutput, gradInputHx, storage);
   }else{
     THNN_(GRU_back_ind_wrap)<unsigned long>
-      (state, input, hidden, gradOutput, gradInput);
+      (state, gradInInput, gradInHidden, gradOutput, gradInputHx, storage);
   }
 
   THCudaCheck(cudaGetLastError());
diff --git a/lib/THCUNN/generic/THCUNN.h b/lib/THCUNN/generic/THCUNN.h
index 72ea749..6fc545c 100644
--- a/lib/THCUNN/generic/THCUNN.h
+++ b/lib/THCUNN/generic/THCUNN.h
@@ -175,12 +175,13 @@ TH_API void THNN_(GRUFused_updateOutput)(
    THCTensor *bias1,             // [OPTIONAL]
    THCTensor *bias2,             // [OPTIONAL]
    THCTensor *hx,
-   THCTensor *hy);
+   THCTensor *hy,
+   THCTensor *storage);
 
 TH_API void THNN_(GRUFused_updateGradInput)(
    THCState *state,
-   THCTensor *input,
-   THCTensor *hidden,
+   THCTensor *gradInInput,
+   THCTensor *gradInHidden,
    THCTensor *gradOutput,
    THCTensor *gradInput);
 
@@ -196,13 +197,13 @@ TH_API void THNN_(LSTMFused_updateOutput)(
 
 TH_API void THNN_(LSTMFused_updateGradInput)(
    THCState *state,
-   THCTensor *input,
-   THCTensor *hidden,
+   THCTensor *storage,
+   THCTensor *gradInGates,
    THCTensor *prevC,
    THCTensor *cy,
    THCTensor *gradOutput,
    THCTensor *gradOutputCell,
-   THCTensor *gradInput);
+   THCTensor *gradInputCx);
 
 TH_API void THNN_(LogSigmoid_updateOutput)(
    THCState *state,
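For completeness, the matching gradients that THNN_(GRUBackward) now recomputes purely from the five saved storage values plus go = dL/dhy, again as an illustrative single-element C sketch (the function and parameter names are not from the commit):

/* Hypothetical single-element model of THNN_(GRUBackward). */
static void gru_cell_grad(const float save[5], float go,
                          float ggate_in[3],  /* -> gradInInput  (r, i, n) */
                          float ggate_hid[3], /* -> gradInHidden (r, i, n) */
                          float *ghx_out)     /* -> gradInputHx            */
{
  float rg = save[0], ig = save[1], ng = save[2];
  float hx = save[3], hn = save[4];            /* hn here holds hn + b2n   */
  float gig = go * (hx - ng) * (1 - ig) * ig;  /* update-gate pre-act grad */
  float ghx = go * ig;                         /* grad w.r.t. hx           */
  float gin = go * (1 - ig) * (1 - ng * ng);   /* candidate pre-act grad   */
  float ghn = gin * rg;                        /* hidden n-gate grad       */
  float grg = gin * hn * (1 - rg) * rg;        /* reset-gate pre-act grad  */
  ggate_in[0] = grg;  ggate_in[1] = gig;  ggate_in[2] = gin;
  ggate_hid[0] = grg; ggate_hid[1] = gig; ggate_hid[2] = ghn;
  *ghx_out = ghx;
}

Because every input here comes from storage rather than from the forward pass's gate buffers, the wrapper no longer needs to preserve those tensors between forward and backward, which is the clone this commit removes.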