Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/torch7.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLuca Antiga <luca.antiga@orobix.com>2017-06-07 03:21:20 +0300
committerSoumith Chintala <soumith@gmail.com>2017-07-04 21:55:19 +0300
commite78e9aacd4894e65e3ce3524c517d2a39d12f338 (patch)
treeea6d50891010af1147619e9adae45cb29b23686f
parent6b32a6149564762a6028703f7655c8ef00dd4852 (diff)
Have median reduce over all dims and return just the value when dim is not provided
-rw-r--r--lib/TH/generic/THTensorMath.c78
-rw-r--r--lib/TH/generic/THTensorMath.h1
2 files changed, 79 insertions, 0 deletions
diff --git a/lib/TH/generic/THTensorMath.c b/lib/TH/generic/THTensorMath.c
index b955319..215c04c 100644
--- a/lib/TH/generic/THTensorMath.c
+++ b/lib/TH/generic/THTensorMath.c
@@ -586,6 +586,34 @@ real THTensor_(maxall)(THTensor *tensor)
return theMax;
}
+static void THTensor_(quickselectnoidx)(real *arr, long k, long elements, long stride);
+
+real THTensor_(medianall)(THTensor *tensor)
+{
+ THArgCheck(tensor->nDimension > 0, 1, "tensor must have one dimension");
+ THArgCheck(THTensor_(isContiguous)(tensor), 1, "input is not contiguous");
+
+ real theMedian;
+ ptrdiff_t numel;
+ long k;
+ THTensor *temp_;
+ real *temp__data;
+
+ numel = THTensor_(nElement)(tensor);
+ k = (numel-1) >> 1;
+
+ temp_ = THTensor_(newClone)(tensor);
+ temp__data = THTensor_(data)(temp_);
+
+ THTensor_(quickselectnoidx)(temp__data, k, numel, 1);
+
+ theMedian = temp__data[k];
+
+ THTensor_(free)(temp_);
+
+ return theMedian;
+}
+
accreal THTensor_(sumall)(THTensor *tensor)
{
accreal sum = 0;
@@ -2044,6 +2072,9 @@ void THTensor_(reshape)(THTensor *r_, THTensor *t, THLongStorage *size)
#define LONG_SWAP(AAA, BBB) swap = AAA; AAA = BBB; BBB = swap
#define REAL_SWAP(AAA, BBB) rswap = AAA; AAA = BBB; BBB = rswap
+#define ARR_SWAP(III, JJJ) \
+ REAL_SWAP(ARR(III), ARR(JJJ));
+
#define BOTH_SWAP(III, JJJ) \
REAL_SWAP(ARR(III), ARR(JJJ)); \
LONG_SWAP(IDX(III), IDX(JJJ))
@@ -2263,6 +2294,53 @@ void THTensor_(sort)(THTensor *rt_, THLongTensor *ri_, THTensor *t, int dimensio
/* Implementation of the Quickselect algorithm, based on Nicolas Devillard's
public domain implementation at http://ndevilla.free.fr/median/median/
+Adapted similarly to the above Quicksort algorithm.
+This version does not produce indices along with values. */
+static void THTensor_(quickselectnoidx)(real *arr, long k, long elements, long stride)
+{
+ long P, L, R, i, j, swap;
+ real rswap, piv;
+ L = 0;
+ R = elements-1;
+
+ do {
+ if (R <= L) /* One element only */
+ return;
+
+ if (R == L+1) { /* Two elements only */
+ if (ARR(L) > ARR(R)) {
+ ARR_SWAP(L, R);
+ }
+ return;
+ }
+
+ /* Use median of three for pivot choice */
+ P=(L+R)>>1;
+ ARR_SWAP(P, L+1);
+ if (ARR(L+1) > ARR(R)) { ARR_SWAP(L+1, R); }
+ if (ARR(L) > ARR(R)) { ARR_SWAP(L, R); }
+ if (ARR(L+1) > ARR(L)) { ARR_SWAP(L+1, L); }
+
+ i = L+1;
+ j = R;
+ piv = ARR(L);
+ do {
+ do i++; while(ARR(i) < piv);
+ do j--; while(ARR(j) > piv);
+ if (j < i)
+ break;
+ ARR_SWAP(i, j);
+ } while(1);
+ ARR_SWAP(L, j);
+
+ /* Re-set active partition */
+ if (j <= k) L=i;
+ if (j >= k) R=j-1;
+ } while(1);
+}
+
+/* Implementation of the Quickselect algorithm, based on Nicolas Devillard's
+public domain implementation at http://ndevilla.free.fr/median/median/
Adapted similarly to the above Quicksort algorithm. */
static void THTensor_(quickselect)(real *arr, long *idx, long k, long elements, long stride)
{
diff --git a/lib/TH/generic/THTensorMath.h b/lib/TH/generic/THTensorMath.h
index 6337533..17e54cc 100644
--- a/lib/TH/generic/THTensorMath.h
+++ b/lib/TH/generic/THTensorMath.h
@@ -25,6 +25,7 @@ TH_API accreal THTensor_(dot)(THTensor *t, THTensor *src);
TH_API real THTensor_(minall)(THTensor *t);
TH_API real THTensor_(maxall)(THTensor *t);
+TH_API real THTensor_(medianall)(THTensor *t);
TH_API accreal THTensor_(sumall)(THTensor *t);
TH_API accreal THTensor_(prodall)(THTensor *t);