github.com/torch/cunn.git
author    Christian Sarofeen <csarofeen@nvidia.com>  2017-04-21 23:53:50 +0300
committer soumith <soumith@fb.com>                   2017-04-22 11:09:46 +0300
commit    93d31671597158db54e13906cba18e1b955d4562
tree      45463dd2a41c06f24969a147e1e5039df0838fe8
parent    455e488488bdaa20a50f82586975e69a0332c97a
Indexing fix for fused GRU/LSTM kernels when not all tensors are contiguous.
-rw-r--r--  lib/THCUNN/generic/FusedRNNKernel.cu | 84
1 file changed, 49 insertions(+), 35 deletions(-)
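
The core of the change is THNN_(minIndexType): it now insists that every tensor passed in has the same number of dimensions, and it only signals that dimensions may be collapsed when every tensor is contiguous. A minimal standalone sketch of that decision logic, using a hypothetical Tensor struct rather than the real THCTensor/TensorUtils API:

#include <cstdarg>
#include <cstdio>

struct Tensor {
  int dims;         // number of dimensions
  bool contiguous;  // stride-1, densely packed layout
};

// Returns -1 if the tensors disagree on dimensionality (caller must bail),
// -2 if every tensor is contiguous (safe to collapse to a single linear
// index), and otherwise the shared dimensionality (index each dimension
// explicitly through strides).
int minIndexType(int count, ...) {
  va_list list;
  va_start(list, count);

  const Tensor* t = va_arg(list, const Tensor*);
  int startDim = t->dims;
  bool canCollapse = t->contiguous;

  for (int arg = 1; arg < count; ++arg) {
    t = va_arg(list, const Tensor*);
    canCollapse = canCollapse && t->contiguous;
    if (t->dims != startDim) {  // mismatched ranks: no uniform indexing
      va_end(list);
      return -1;
    }
  }
  va_end(list);
  return canCollapse ? -2 : startDim;
}

int main() {
  Tensor a{2, true}, b{2, true}, c{2, false}, d{3, true};
  printf("%d\n", minIndexType(2, &a, &b));  // -2: all contiguous, collapse
  printf("%d\n", minIndexType(2, &a, &c));  //  2: keep strided indexing
  printf("%d\n", minIndexType(2, &a, &d));  // -1: rank mismatch
}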
diff --git a/lib/THCUNN/generic/FusedRNNKernel.cu b/lib/THCUNN/generic/FusedRNNKernel.cu
index 17a6563..6aeba1e 100644
--- a/lib/THCUNN/generic/FusedRNNKernel.cu
+++ b/lib/THCUNN/generic/FusedRNNKernel.cu
@@ -41,16 +41,21 @@ int THNN_(minIndexType)(THCState *state, int count, ...)
va_list list;
va_start(list, count);
- int maxDim = -2;
- for (int arg=0; arg < count; ++arg){
- THCTensor* tens = va_arg(list, THCTensor*);
- if(THCTensor_(isContiguous)(state, tens)) continue;
- int tensdims = TensorUtils<THCTensor>::getDims(state, tens);
- maxDim = (( tensdims> maxDim) ? tensdims : maxDim);
+ THCTensor* tens = va_arg(list, THCTensor*);
+ int startDim = TensorUtils<THCTensor>::getDims(state, tens);
+ bool canCollapse = THCTensor_(isContiguous)(state,tens);
+
+ for (int arg=1; arg < count; ++arg){
+ tens = va_arg(list, THCTensor*);
+ canCollapse = canCollapse && THCTensor_(isContiguous)(state, tens);
+ if(TensorUtils<THCTensor>::getDims(state, tens) != startDim){
+ va_end(list);
+ return -1;
+ }
}
-
va_end(list);
- return maxDim;
+ if(canCollapse) return -2;
+ return startDim;
}
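
The return convention changed along with the loop. The old version skipped contiguous tensors and returned the largest rank among the rest, while the wrappers below called collapseDims() on every TensorInfo unconditionally, so the rank handed to FILL_DIM could disagree with the layouts the kernels actually indexed. The new version returns -1 on a rank mismatch, -2 only when every tensor is contiguous, and the shared rank otherwise, and the wrappers now collapse only in the -2 case, keeping the dispatch rank and the TensorInfo layouts consistent.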
bool THNN_(canUse32BitIndexMath)(THCState *state, int count, ...)
@@ -534,11 +539,13 @@ void THNN_(LSTM_forw_ind_wrap)(
"Bias in pointwise operation is an incorrect size, must be 4 x feature size.");
}
- inputI.collapseDims();
- hiddenI.collapseDims();
- cxI.collapseDims();
- hyI.collapseDims();
- cyI.collapseDims();
+ if(maxDim == -2){
+ inputI.collapseDims();
+ hiddenI.collapseDims();
+ cxI.collapseDims();
+ hyI.collapseDims();
+ cyI.collapseDims();
+ }
INDTYPE zero[1] = {0};
TensorInfo<DATATYPE, INDTYPE> nullinfo =
@@ -549,8 +556,10 @@ void THNN_(LSTM_forw_ind_wrap)(
if(has_bias){
bias1I = getTensorInfo<THCTensor, INDTYPE>(state, bias1);
bias2I = getTensorInfo<THCTensor, INDTYPE>(state, bias2);
- bias1I.collapseDims();
- bias2I.collapseDims();
+ if(maxDim == -2){
+ bias1I.collapseDims();
+ bias2I.collapseDims();
+ }
}
FILL_DIM(INDTYPE, maxDim, LSTM_FORWARD);
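
Why the guard matters: a non-contiguous view cannot be indexed as if its storage were flat. A minimal illustration in plain C++ (hand-rolled sizes/strides, not the TensorInfo API):

#include <cstdio>

int main() {
  // 2x3 row-major storage: element (r, c) lives at r*3 + c.
  float storage[6] = {0, 1, 2, 3, 4, 5};

  // A transposed 3x2 view of the same storage has strides {1, 3}.
  int strides[2] = {1, 3};

  // Correct strided indexing of view element (1, 1): storage[4] = 4.
  printf("view(1,1) via strides  = %g\n",
         storage[1 * strides[0] + 1 * strides[1]]);

  // Treating the view as flat row-major (offset 1*2 + 1 = 3) reads
  // storage[3] = 3 -- the wrong element.
  printf("view(1,1) wrongly flat = %g\n", storage[1 * 2 + 1]);
}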
@@ -628,14 +637,15 @@ void THNN_(LSTM_back_ind_wrap)(
INDTYPE hid_size = gradoutI.sizes[gradoutI.dims-1];
- inputI.collapseDims();
- hiddenI.collapseDims();
- cxI.collapseDims();
- cyI.collapseDims();
- gradoutI.collapseDims();
- gradoutcI.collapseDims();
- gradinI.collapseDims();
-
+ if(maxDim == -2){
+ inputI.collapseDims();
+ hiddenI.collapseDims();
+ cxI.collapseDims();
+ cyI.collapseDims();
+ gradoutI.collapseDims();
+ gradoutcI.collapseDims();
+ gradinI.collapseDims();
+ }
FILL_DIM(INDTYPE, maxDim, LSTM_BACKWARD);
}
@@ -721,11 +731,12 @@ void THNN_(GRU_forw_ind_wrap)(
"Bias in pointwise operation is an incorrect size, must be 3 x feature size.");
}
- inputI.collapseDims();
- hiddenI.collapseDims();
- hyI.collapseDims();
- hxI.collapseDims();
-
+ if(maxDim == -2){
+ inputI.collapseDims();
+ hiddenI.collapseDims();
+ hyI.collapseDims();
+ hxI.collapseDims();
+ }
INDTYPE zero[1] = {0};
TensorInfo<DATATYPE, INDTYPE> nullinfo =
TensorInfo<DATATYPE, INDTYPE>(NULL, 1, zero, zero);
@@ -735,8 +746,10 @@ void THNN_(GRU_forw_ind_wrap)(
if(has_bias){
bias1I = getTensorInfo<THCTensor, INDTYPE>(state, bias1);
bias2I = getTensorInfo<THCTensor, INDTYPE>(state, bias2);
- bias1I.collapseDims();
- bias2I.collapseDims();
+ if(maxDim == -2){
+ bias1I.collapseDims();
+ bias2I.collapseDims();
+ }
}
FILL_DIM(INDTYPE, maxDim, GRU_FORWARD);
@@ -804,11 +817,12 @@ void THNN_(GRU_back_ind_wrap)(
INDTYPE hid_size = gradoutI.sizes[gradoutI.dims-1];
- inputI.collapseDims();
- hiddenI.collapseDims();
- gradoutI.collapseDims();
- gradinI.collapseDims();
-
+ if(maxDim == -2){
+ inputI.collapseDims();
+ hiddenI.collapseDims();
+ gradoutI.collapseDims();
+ gradinI.collapseDims();
+ }
FILL_DIM(INDTYPE, maxDim, GRU_BACKWARD);
}
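
The same guard is applied uniformly across LSTM_forw_ind_wrap, LSTM_back_ind_wrap, GRU_forw_ind_wrap, and GRU_back_ind_wrap: each TensorInfo is collapsed only when minIndexType reported -2, and otherwise the per-dimension strides are left intact for the FILL_DIM dispatch on maxDim.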