Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/torch7.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Gregory Chanan <gchanan@fb.com> 2017-05-03 21:03:58 +0300
committer: Soumith Chintala <soumith@gmail.com> 2017-05-09 21:54:58 +0300
commit: c5e5487075fa7c9bafcbdc9d52e114aedb2bdac3 (patch)
tree: 53757aa517d6e1484411cc617ad7ddc23809d7d5
parent: 8e7898166d1844620b1bde58ca05a7ca9291fa39 (diff)
Add a keepdim parameter for reduction functions over a single dimension.
By default, this parameter is False -- a backwards-incompatible change, but one that follows numpy semantics, e.g. numpy.sum (numpy names the parameter "keepdims" since you can pass multiple dims to reduction functions). The old behavior seems desired for normalization-type operations where the tensor will immediately be expanded out again, e.g.: probs.sum(1).expand_as(probs), which no longer works because the dimension to expand is missing. This can be fixed by simply passing True as the "keepdim" argument to the reduction operation, e.g.: probs.sum(1, keepdim=True).expand_as(probs)
-rw-r--r-- lib/TH/generic/THTensorMath.c 64
-rw-r--r-- lib/TH/generic/THTensorMath.h 22
2 files changed, 62 insertions, 24 deletions
diff --git a/lib/TH/generic/THTensorMath.c b/lib/TH/generic/THTensorMath.c
index e12d2ac..74d5cc5 100644
--- a/lib/TH/generic/THTensorMath.c
+++ b/lib/TH/generic/THTensorMath.c
@@ -1485,7 +1485,7 @@ ptrdiff_t THTensor_(numel)(THTensor *t)
return THTensor_(nElement)(t);
}
-void THTensor_(max)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension)
+void THTensor_(max)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension, int keepdim)
{
THLongStorage *dim;
@@ -1554,9 +1554,14 @@ void THTensor_(max)(THTensor *values_, THLongTensor *indices_, THTensor *t, int
THTensor_(free)(tempValues_);
THLongTensor_free(tempIndices_);
}
+
+ if (!keepdim) {
+ THTensor_(squeeze1d)(values_, values_, dimension);
+ THLongTensor_squeeze1d(indices_, indices_, dimension);
+ }
}
-void THTensor_(min)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension)
+void THTensor_(min)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension, int keepdim)
{
THLongStorage *dim;
@@ -1622,10 +1627,15 @@ void THTensor_(min)(THTensor *values_, THLongTensor *indices_, THTensor *t, int
*tempIndices__data = *tempIndices__dimOffset;
});
}
+
+ if (!keepdim) {
+ THTensor_(squeeze1d)(values_, values_, dimension);
+ THLongTensor_squeeze1d(indices_, indices_, dimension);
+ }
}
-void THTensor_(sum)(THTensor *r_, THTensor *t, int dimension)
+void THTensor_(sum)(THTensor *r_, THTensor *t, int dimension, int keepdim)
{
THLongStorage *dim;
@@ -1655,9 +1665,13 @@ void THTensor_(sum)(THTensor *r_, THTensor *t, int dimension)
TH_TENSOR_APPLY2(real, temp_, real, t, *temp__data = *temp__data + *t_data;);
THTensor_(free)(temp_);
}
+
+ if (!keepdim) {
+ THTensor_(squeeze1d)(r_, r_, dimension);
+ }
}
-void THTensor_(prod)(THTensor *r_, THTensor *t, int dimension)
+void THTensor_(prod)(THTensor *r_, THTensor *t, int dimension, int keepdim)
{
THLongStorage *dim;
@@ -1687,6 +1701,10 @@ void THTensor_(prod)(THTensor *r_, THTensor *t, int dimension)
TH_TENSOR_APPLY2(real, temp_, real, t, *temp__data = *temp__data * *t_data;);
THTensor_(free)(temp_);
}
+
+ if (!keepdim) {
+ THTensor_(squeeze1d)(r_, r_, dimension);
+ }
}
void THTensor_(cumsum)(THTensor *r_, THTensor *t, int dimension)
@@ -2255,7 +2273,7 @@ static void THTensor_(quickselect)(real *arr, long *idx, long k, long elements,
#undef REAL_SWAP
#undef BOTH_SWAP
-void THTensor_(mode)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension)
+void THTensor_(mode)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension, int keepdim)
{
THLongStorage *dim;
THTensor *temp_;
@@ -2313,9 +2331,13 @@ void THTensor_(mode)(THTensor *values_, THLongTensor *indices_, THTensor *t, int
THTensor_(free)(temp_);
THLongTensor_free(tempi_);
+ if (!keepdim) {
+ THTensor_(squeeze1d)(values_, values_, dimension);
+ THLongTensor_squeeze1d(indices_, indices_, dimension);
+ }
}
-void THTensor_(kthvalue)(THTensor *values_, THLongTensor *indices_, THTensor *t, long k, int dimension)
+void THTensor_(kthvalue)(THTensor *values_, THLongTensor *indices_, THTensor *t, long k, int dimension, int keepdim)
{
THLongStorage *dim;
THTensor *temp_;
@@ -2355,9 +2377,13 @@ void THTensor_(kthvalue)(THTensor *values_, THLongTensor *indices_, THTensor *t,
THTensor_(free)(temp_);
THLongTensor_free(tempi_);
+ if (!keepdim) {
+ THTensor_(squeeze1d)(values_, values_, dimension);
+ THLongTensor_squeeze1d(indices_, indices_, dimension);
+ }
}
-void THTensor_(median)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension)
+void THTensor_(median)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension, int keepdim)
{
long t_size_dim, k;
@@ -2366,7 +2392,7 @@ void THTensor_(median)(THTensor *values_, THLongTensor *indices_, THTensor *t, i
t_size_dim = THTensor_(size)(t, dimension);
k = (t_size_dim-1) >> 1; /* take middle or one-before-middle element */
- THTensor_(kthvalue)(values_, indices_, t, k+1, dimension);
+ THTensor_(kthvalue)(values_, indices_, t, k+1, dimension, keepdim);
}
void THTensor_(topk)(THTensor *rt_, THLongTensor *ri_, THTensor *t, long k, int dim, int dir, int sorted)
@@ -2759,16 +2785,16 @@ void THTensor_(lerp)(THTensor *r_, THTensor *a, THTensor *b, real weight)
TH_TENSOR_APPLY3(real, r_, real, a, real, b, *r__data = TH_lerp(*a_data, *b_data, weight););
}
-void THTensor_(mean)(THTensor *r_, THTensor *t, int dimension)
+void THTensor_(mean)(THTensor *r_, THTensor *t, int dimension, int keepdim)
{
THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(t), 2, "invalid dimension %d",
dimension + TH_INDEX_BASE);
- THTensor_(sum)(r_, t, dimension);
+ THTensor_(sum)(r_, t, dimension, keepdim);
THTensor_(div)(r_, r_, t->size[dimension]);
}
-void THTensor_(std)(THTensor *r_, THTensor *t, int dimension, int flag)
+void THTensor_(std)(THTensor *r_, THTensor *t, int dimension, int flag, int keepdim)
{
THLongStorage *dim;
@@ -2807,9 +2833,13 @@ void THTensor_(std)(THTensor *r_, THTensor *t, int dimension, int flag)
sum2 = (sum2 < 0 ? 0 : sum2);
*r__data = (real)sqrt(sum2);
});
+
+ if (!keepdim) {
+ THTensor_(squeeze1d)(r_, r_, dimension);
+ }
}
-void THTensor_(var)(THTensor *r_, THTensor *t, int dimension, int flag)
+void THTensor_(var)(THTensor *r_, THTensor *t, int dimension, int flag, int keepdim)
{
THLongStorage *dim;
@@ -2848,9 +2878,13 @@ void THTensor_(var)(THTensor *r_, THTensor *t, int dimension, int flag)
sum2 = (sum2 < 0 ? 0 : sum2);
*r__data = (real)sum2;
});
+
+ if (!keepdim) {
+ THTensor_(squeeze1d)(r_, r_, dimension);
+ }
}
-void THTensor_(norm)(THTensor *r_, THTensor *t, real value, int dimension)
+void THTensor_(norm)(THTensor *r_, THTensor *t, real value, int dimension, int keepdim)
{
THLongStorage *dim;
@@ -2877,6 +2911,10 @@ void THTensor_(norm)(THTensor *r_, THTensor *t, real value, int dimension)
sum += pow(fabs(t_data[i*t_stride]), value);
*r__data = pow(sum, 1.0/value);)
}
+
+ if (!keepdim) {
+ THTensor_(squeeze1d)(r_, r_, dimension);
+ }
}
accreal THTensor_(normall)(THTensor *tensor, real value)
diff --git a/lib/TH/generic/THTensorMath.h b/lib/TH/generic/THTensorMath.h
index 86d36e6..a3cf410 100644
--- a/lib/TH/generic/THTensorMath.h
+++ b/lib/TH/generic/THTensorMath.h
@@ -69,13 +69,13 @@ TH_API void THTensor_(baddbmm)(THTensor *r_, real beta, THTensor *t, real alpha,
TH_API void THTensor_(match)(THTensor *r_, THTensor *m1, THTensor *m2, real gain);
TH_API ptrdiff_t THTensor_(numel)(THTensor *t);
-TH_API void THTensor_(max)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension);
-TH_API void THTensor_(min)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension);
-TH_API void THTensor_(kthvalue)(THTensor *values_, THLongTensor *indices_, THTensor *t, long k, int dimension);
-TH_API void THTensor_(mode)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension);
-TH_API void THTensor_(median)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension);
-TH_API void THTensor_(sum)(THTensor *r_, THTensor *t, int dimension);
-TH_API void THTensor_(prod)(THTensor *r_, THTensor *t, int dimension);
+TH_API void THTensor_(max)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension, int keepdim);
+TH_API void THTensor_(min)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension, int keepdim);
+TH_API void THTensor_(kthvalue)(THTensor *values_, THLongTensor *indices_, THTensor *t, long k, int dimension, int keepdim);
+TH_API void THTensor_(mode)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension, int keepdim);
+TH_API void THTensor_(median)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension, int keepdim);
+TH_API void THTensor_(sum)(THTensor *r_, THTensor *t, int dimension, int keepdim);
+TH_API void THTensor_(prod)(THTensor *r_, THTensor *t, int dimension, int keepdim);
TH_API void THTensor_(cumsum)(THTensor *r_, THTensor *t, int dimension);
TH_API void THTensor_(cumprod)(THTensor *r_, THTensor *t, int dimension);
TH_API void THTensor_(sign)(THTensor *r_, THTensor *t);
@@ -165,10 +165,10 @@ TH_API void THTensor_(trunc)(THTensor *r_, THTensor *t);
TH_API void THTensor_(frac)(THTensor *r_, THTensor *t);
TH_API void THTensor_(lerp)(THTensor *r_, THTensor *a, THTensor *b, real weight);
-TH_API void THTensor_(mean)(THTensor *r_, THTensor *t, int dimension);
-TH_API void THTensor_(std)(THTensor *r_, THTensor *t, int dimension, int flag);
-TH_API void THTensor_(var)(THTensor *r_, THTensor *t, int dimension, int flag);
-TH_API void THTensor_(norm)(THTensor *r_, THTensor *t, real value, int dimension);
+TH_API void THTensor_(mean)(THTensor *r_, THTensor *t, int dimension, int keepdim);
+TH_API void THTensor_(std)(THTensor *r_, THTensor *t, int dimension, int flag, int keepdim);
+TH_API void THTensor_(var)(THTensor *r_, THTensor *t, int dimension, int flag, int keepdim);
+TH_API void THTensor_(norm)(THTensor *r_, THTensor *t, real value, int dimension, int keepdim);
TH_API void THTensor_(renorm)(THTensor *r_, THTensor *t, real value, int dimension, real maxnorm);
TH_API accreal THTensor_(dist)(THTensor *a, THTensor *b, real value);
TH_API void THTensor_(histc)(THTensor *hist, THTensor *tensor, long nbins, real minvalue, real maxvalue);