diff options
-rw-r--r-- | lib/TH/generic/THTensorMath.c | 8 |
1 files changed, 8 insertions, 0 deletions
diff --git a/lib/TH/generic/THTensorMath.c b/lib/TH/generic/THTensorMath.c index 2ba394c..9d2a7b4 100644 --- a/lib/TH/generic/THTensorMath.c +++ b/lib/TH/generic/THTensorMath.c @@ -1595,6 +1595,10 @@ void THTensor_(max)(THTensor *values_, THLongTensor *indices_, THTensor *t, int THLongTensor_zero(indices_); if(t->size[dimension] == 1) { + if (!keepdim) { + THTensor_(squeeze1d)(values_, values_, dimension); + THLongTensor_squeeze1d(indices_, indices_, dimension); + } return; } @@ -1671,6 +1675,10 @@ void THTensor_(min)(THTensor *values_, THLongTensor *indices_, THTensor *t, int THLongTensor_zero(indices_); if(t->size[dimension] == 1) { + if (!keepdim) { + THTensor_(squeeze1d)(values_, values_, dimension); + THLongTensor_squeeze1d(indices_, indices_, dimension); + } return; } |