diff options
author | Gregory Chanan <gchanan@fb.com> | 2017-05-03 21:03:58 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2017-05-09 21:54:58 +0300 |
commit | c5e5487075fa7c9bafcbdc9d52e114aedb2bdac3 (patch) | |
tree | 53757aa517d6e1484411cc617ad7ddc23809d7d5 | |
parent | 8e7898166d1844620b1bde58ca05a7ca9291fa39 (diff) |
Add a keepdim parameter for reduction functions over a single dimension.
By default, this parameter is False -- a backwards incompatible change, but
one that follows numpy semantics, e.g. numpy.sum (numpy names the parameter
"keepdims" since you can pass multiple dims to reduction functions).
The old behavior seems desired for normalization type operations
where the tensor will immediately be expanded out again, e.g.:
probs.sum(1).expand_as(probs)
which no longer works because the dimension to expand is missing.
This can be fixed by simply passing True as "keepdim" argument
to the reduction operation, e.g:
probs.sum(1, keepdim=True).expand_as(probs)
-rw-r--r-- | lib/TH/generic/THTensorMath.c | 64 | ||||
-rw-r--r-- | lib/TH/generic/THTensorMath.h | 22 |
2 files changed, 62 insertions, 24 deletions
diff --git a/lib/TH/generic/THTensorMath.c b/lib/TH/generic/THTensorMath.c index e12d2ac..74d5cc5 100644 --- a/lib/TH/generic/THTensorMath.c +++ b/lib/TH/generic/THTensorMath.c @@ -1485,7 +1485,7 @@ ptrdiff_t THTensor_(numel)(THTensor *t) return THTensor_(nElement)(t); } -void THTensor_(max)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension) +void THTensor_(max)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension, int keepdim) { THLongStorage *dim; @@ -1554,9 +1554,14 @@ void THTensor_(max)(THTensor *values_, THLongTensor *indices_, THTensor *t, int THTensor_(free)(tempValues_); THLongTensor_free(tempIndices_); } + + if (!keepdim) { + THTensor_(squeeze1d)(values_, values_, dimension); + THLongTensor_squeeze1d(indices_, indices_, dimension); + } } -void THTensor_(min)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension) +void THTensor_(min)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension, int keepdim) { THLongStorage *dim; @@ -1622,10 +1627,15 @@ void THTensor_(min)(THTensor *values_, THLongTensor *indices_, THTensor *t, int *tempIndices__data = *tempIndices__dimOffset; }); } + + if (!keepdim) { + THTensor_(squeeze1d)(values_, values_, dimension); + THLongTensor_squeeze1d(indices_, indices_, dimension); + } } -void THTensor_(sum)(THTensor *r_, THTensor *t, int dimension) +void THTensor_(sum)(THTensor *r_, THTensor *t, int dimension, int keepdim) { THLongStorage *dim; @@ -1655,9 +1665,13 @@ void THTensor_(sum)(THTensor *r_, THTensor *t, int dimension) TH_TENSOR_APPLY2(real, temp_, real, t, *temp__data = *temp__data + *t_data;); THTensor_(free)(temp_); } + + if (!keepdim) { + THTensor_(squeeze1d)(r_, r_, dimension); + } } -void THTensor_(prod)(THTensor *r_, THTensor *t, int dimension) +void THTensor_(prod)(THTensor *r_, THTensor *t, int dimension, int keepdim) { THLongStorage *dim; @@ -1687,6 +1701,10 @@ void THTensor_(prod)(THTensor *r_, THTensor *t, int dimension) TH_TENSOR_APPLY2(real, temp_, real, t, *temp__data = *temp__data * *t_data;); THTensor_(free)(temp_); } + + if (!keepdim) { + THTensor_(squeeze1d)(r_, r_, dimension); + } } void THTensor_(cumsum)(THTensor *r_, THTensor *t, int dimension) @@ -2255,7 +2273,7 @@ static void THTensor_(quickselect)(real *arr, long *idx, long k, long elements, #undef REAL_SWAP #undef BOTH_SWAP -void THTensor_(mode)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension) +void THTensor_(mode)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension, int keepdim) { THLongStorage *dim; THTensor *temp_; @@ -2313,9 +2331,13 @@ void THTensor_(mode)(THTensor *values_, THLongTensor *indices_, THTensor *t, int THTensor_(free)(temp_); THLongTensor_free(tempi_); + if (!keepdim) { + THTensor_(squeeze1d)(values_, values_, dimension); + THLongTensor_squeeze1d(indices_, indices_, dimension); + } } -void THTensor_(kthvalue)(THTensor *values_, THLongTensor *indices_, THTensor *t, long k, int dimension) +void THTensor_(kthvalue)(THTensor *values_, THLongTensor *indices_, THTensor *t, long k, int dimension, int keepdim) { THLongStorage *dim; THTensor *temp_; @@ -2355,9 +2377,13 @@ void THTensor_(kthvalue)(THTensor *values_, THLongTensor *indices_, THTensor *t, THTensor_(free)(temp_); THLongTensor_free(tempi_); + if (!keepdim) { + THTensor_(squeeze1d)(values_, values_, dimension); + THLongTensor_squeeze1d(indices_, indices_, dimension); + } } -void THTensor_(median)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension) +void THTensor_(median)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension, int keepdim) { long t_size_dim, k; @@ -2366,7 +2392,7 @@ void THTensor_(median)(THTensor *values_, THLongTensor *indices_, THTensor *t, i t_size_dim = THTensor_(size)(t, dimension); k = (t_size_dim-1) >> 1; /* take middle or one-before-middle element */ - THTensor_(kthvalue)(values_, indices_, t, k+1, dimension); + THTensor_(kthvalue)(values_, indices_, t, k+1, dimension, keepdim); } void THTensor_(topk)(THTensor *rt_, THLongTensor *ri_, THTensor *t, long k, int dim, int dir, int sorted) @@ -2759,16 +2785,16 @@ void THTensor_(lerp)(THTensor *r_, THTensor *a, THTensor *b, real weight) TH_TENSOR_APPLY3(real, r_, real, a, real, b, *r__data = TH_lerp(*a_data, *b_data, weight);); } -void THTensor_(mean)(THTensor *r_, THTensor *t, int dimension) +void THTensor_(mean)(THTensor *r_, THTensor *t, int dimension, int keepdim) { THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(t), 2, "invalid dimension %d", dimension + TH_INDEX_BASE); - THTensor_(sum)(r_, t, dimension); + THTensor_(sum)(r_, t, dimension, keepdim); THTensor_(div)(r_, r_, t->size[dimension]); } -void THTensor_(std)(THTensor *r_, THTensor *t, int dimension, int flag) +void THTensor_(std)(THTensor *r_, THTensor *t, int dimension, int flag, int keepdim) { THLongStorage *dim; @@ -2807,9 +2833,13 @@ void THTensor_(std)(THTensor *r_, THTensor *t, int dimension, int flag) sum2 = (sum2 < 0 ? 0 : sum2); *r__data = (real)sqrt(sum2); }); + + if (!keepdim) { + THTensor_(squeeze1d)(r_, r_, dimension); + } } -void THTensor_(var)(THTensor *r_, THTensor *t, int dimension, int flag) +void THTensor_(var)(THTensor *r_, THTensor *t, int dimension, int flag, int keepdim) { THLongStorage *dim; @@ -2848,9 +2878,13 @@ void THTensor_(var)(THTensor *r_, THTensor *t, int dimension, int flag) sum2 = (sum2 < 0 ? 0 : sum2); *r__data = (real)sum2; }); + + if (!keepdim) { + THTensor_(squeeze1d)(r_, r_, dimension); + } } -void THTensor_(norm)(THTensor *r_, THTensor *t, real value, int dimension) +void THTensor_(norm)(THTensor *r_, THTensor *t, real value, int dimension, int keepdim) { THLongStorage *dim; @@ -2877,6 +2911,10 @@ void THTensor_(norm)(THTensor *r_, THTensor *t, real value, int dimension) sum += pow(fabs(t_data[i*t_stride]), value); *r__data = pow(sum, 1.0/value);) } + + if (!keepdim) { + THTensor_(squeeze1d)(r_, r_, dimension); + } } accreal THTensor_(normall)(THTensor *tensor, real value) diff --git a/lib/TH/generic/THTensorMath.h b/lib/TH/generic/THTensorMath.h index 86d36e6..a3cf410 100644 --- a/lib/TH/generic/THTensorMath.h +++ b/lib/TH/generic/THTensorMath.h @@ -69,13 +69,13 @@ TH_API void THTensor_(baddbmm)(THTensor *r_, real beta, THTensor *t, real alpha, TH_API void THTensor_(match)(THTensor *r_, THTensor *m1, THTensor *m2, real gain); TH_API ptrdiff_t THTensor_(numel)(THTensor *t); -TH_API void THTensor_(max)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension); -TH_API void THTensor_(min)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension); -TH_API void THTensor_(kthvalue)(THTensor *values_, THLongTensor *indices_, THTensor *t, long k, int dimension); -TH_API void THTensor_(mode)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension); -TH_API void THTensor_(median)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension); -TH_API void THTensor_(sum)(THTensor *r_, THTensor *t, int dimension); -TH_API void THTensor_(prod)(THTensor *r_, THTensor *t, int dimension); +TH_API void THTensor_(max)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension, int keepdim); +TH_API void THTensor_(min)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension, int keepdim); +TH_API void THTensor_(kthvalue)(THTensor *values_, THLongTensor *indices_, THTensor *t, long k, int dimension, int keepdim); +TH_API void THTensor_(mode)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension, int keepdim); +TH_API void THTensor_(median)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension, int keepdim); +TH_API void THTensor_(sum)(THTensor *r_, THTensor *t, int dimension, int keepdim); +TH_API void THTensor_(prod)(THTensor *r_, THTensor *t, int dimension, int keepdim); TH_API void THTensor_(cumsum)(THTensor *r_, THTensor *t, int dimension); TH_API void THTensor_(cumprod)(THTensor *r_, THTensor *t, int dimension); TH_API void THTensor_(sign)(THTensor *r_, THTensor *t); @@ -165,10 +165,10 @@ TH_API void THTensor_(trunc)(THTensor *r_, THTensor *t); TH_API void THTensor_(frac)(THTensor *r_, THTensor *t); TH_API void THTensor_(lerp)(THTensor *r_, THTensor *a, THTensor *b, real weight); -TH_API void THTensor_(mean)(THTensor *r_, THTensor *t, int dimension); -TH_API void THTensor_(std)(THTensor *r_, THTensor *t, int dimension, int flag); -TH_API void THTensor_(var)(THTensor *r_, THTensor *t, int dimension, int flag); -TH_API void THTensor_(norm)(THTensor *r_, THTensor *t, real value, int dimension); +TH_API void THTensor_(mean)(THTensor *r_, THTensor *t, int dimension, int keepdim); +TH_API void THTensor_(std)(THTensor *r_, THTensor *t, int dimension, int flag, int keepdim); +TH_API void THTensor_(var)(THTensor *r_, THTensor *t, int dimension, int flag, int keepdim); +TH_API void THTensor_(norm)(THTensor *r_, THTensor *t, real value, int dimension, int keepdim); TH_API void THTensor_(renorm)(THTensor *r_, THTensor *t, real value, int dimension, real maxnorm); TH_API accreal THTensor_(dist)(THTensor *a, THTensor *b, real value); TH_API void THTensor_(histc)(THTensor *hist, THTensor *tensor, long nbins, real minvalue, real maxvalue); |