diff options
author | Christian Sarofeen <csarofeen@nvidia.com> | 2017-04-21 23:53:50 +0300 |
---|---|---|
committer | soumith <soumith@fb.com> | 2017-04-22 11:09:46 +0300 |
commit | 93d31671597158db54e13906cba18e1b955d4562 (patch) | |
tree | 45463dd2a41c06f24969a147e1e5039df0838fe8 | |
parent | 455e488488bdaa20a50f82586975e69a0332c97a (diff) |
Indexing fix for fused GRU/LSTM kernels when not all tensors are contiguous.
-rw-r--r-- | lib/THCUNN/generic/FusedRNNKernel.cu | 84 |
1 file changed, 49 insertions, 35 deletions
diff --git a/lib/THCUNN/generic/FusedRNNKernel.cu b/lib/THCUNN/generic/FusedRNNKernel.cu index 17a6563..6aeba1e 100644 --- a/lib/THCUNN/generic/FusedRNNKernel.cu +++ b/lib/THCUNN/generic/FusedRNNKernel.cu @@ -41,16 +41,21 @@ int THNN_(minIndexType)(THCState *state, int count, ...) va_list list; va_start(list, count); - int maxDim = -2; - for (int arg=0; arg < count; ++arg){ - THCTensor* tens = va_arg(list, THCTensor*); - if(THCTensor_(isContiguous)(state, tens)) continue; - int tensdims = TensorUtils<THCTensor>::getDims(state, tens); - maxDim = (( tensdims> maxDim) ? tensdims : maxDim); + THCTensor* tens = va_arg(list, THCTensor*); + int startDim = TensorUtils<THCTensor>::getDims(state, tens); + bool canCollapse = THCTensor_(isContiguous)(state,tens); + + for (int arg=1; arg < count; ++arg){ + tens = va_arg(list, THCTensor*); + canCollapse = canCollapse && THCTensor_(isContiguous)(state, tens); + if(TensorUtils<THCTensor>::getDims(state, tens) != startDim){ + va_end(list); + return -1; + } } - va_end(list); - return maxDim; + if(canCollapse) return -2; + return startDim; } bool THNN_(canUse32BitIndexMath)(THCState *state, int count, ...) 
@@ -534,11 +539,13 @@ void THNN_(LSTM_forw_ind_wrap)( "Bias in pointwise operation is an incorrect size, must be 4 x feature size."); } - inputI.collapseDims(); - hiddenI.collapseDims(); - cxI.collapseDims(); - hyI.collapseDims(); - cyI.collapseDims(); + if(maxDim == -2){ + inputI.collapseDims(); + hiddenI.collapseDims(); + cxI.collapseDims(); + hyI.collapseDims(); + cyI.collapseDims(); + } INDTYPE zero[1] = {0}; TensorInfo<DATATYPE, INDTYPE> nullinfo = @@ -549,8 +556,10 @@ void THNN_(LSTM_forw_ind_wrap)( if(has_bias){ bias1I = getTensorInfo<THCTensor, INDTYPE>(state, bias1); bias2I = getTensorInfo<THCTensor, INDTYPE>(state, bias2); - bias1I.collapseDims(); - bias2I.collapseDims(); + if(maxDim == -2){ + bias1I.collapseDims(); + bias2I.collapseDims(); + } } FILL_DIM(INDTYPE, maxDim, LSTM_FORWARD); @@ -628,14 +637,15 @@ void THNN_(LSTM_back_ind_wrap)( INDTYPE hid_size = gradoutI.sizes[gradoutI.dims-1]; - inputI.collapseDims(); - hiddenI.collapseDims(); - cxI.collapseDims(); - cyI.collapseDims(); - gradoutI.collapseDims(); - gradoutcI.collapseDims(); - gradinI.collapseDims(); - + if(maxDim == -2){ + inputI.collapseDims(); + hiddenI.collapseDims(); + cxI.collapseDims(); + cyI.collapseDims(); + gradoutI.collapseDims(); + gradoutcI.collapseDims(); + gradinI.collapseDims(); + } FILL_DIM(INDTYPE, maxDim, LSTM_BACKWARD); } @@ -721,11 +731,12 @@ void THNN_(GRU_forw_ind_wrap)( "Bias in pointwise operation is an incorrect size, must be 3 x feature size."); } - inputI.collapseDims(); - hiddenI.collapseDims(); - hyI.collapseDims(); - hxI.collapseDims(); - + if(maxDim == -2){ + inputI.collapseDims(); + hiddenI.collapseDims(); + hyI.collapseDims(); + hxI.collapseDims(); + } INDTYPE zero[1] = {0}; TensorInfo<DATATYPE, INDTYPE> nullinfo = TensorInfo<DATATYPE, INDTYPE>(NULL, 1, zero, zero); @@ -735,8 +746,10 @@ void THNN_(GRU_forw_ind_wrap)( if(has_bias){ bias1I = getTensorInfo<THCTensor, INDTYPE>(state, bias1); bias2I = getTensorInfo<THCTensor, INDTYPE>(state, bias2); - 
bias1I.collapseDims(); - bias2I.collapseDims(); + if(maxDim == -2){ + bias1I.collapseDims(); + bias2I.collapseDims(); + } } FILL_DIM(INDTYPE, maxDim, GRU_FORWARD); @@ -804,11 +817,12 @@ void THNN_(GRU_back_ind_wrap)( INDTYPE hid_size = gradoutI.sizes[gradoutI.dims-1]; - inputI.collapseDims(); - hiddenI.collapseDims(); - gradoutI.collapseDims(); - gradinI.collapseDims(); - + if(maxDim == -2){ + inputI.collapseDims(); + hiddenI.collapseDims(); + gradoutI.collapseDims(); + gradinI.collapseDims(); + } FILL_DIM(INDTYPE, maxDim, GRU_BACKWARD); } |